""" Self-RAG / Corrective RAG Module Implements self-reflective RAG that can: 1. Grade retrieved documents for relevance 2. Decide if more retrieval is needed 3. Reformulate queries if initial retrieval fails 4. Filter irrelevant passages before generation 5. Grade answers for groundedness and hallucination Based on research: - Self-RAG (Asai et al., 2023): Learning to retrieve, generate, and critique - Corrective RAG (Yan et al., 2024): Self-correcting retrieval augmented generation This is especially useful for German educational documents where: - Queries may use informal language - Documents use formal/technical terminology - Context must be precisely matched to scoring criteria """ import os from typing import List, Dict, Optional, Tuple from enum import Enum import httpx # Configuration # IMPORTANT: Self-RAG is DISABLED by default for privacy reasons! # When enabled, search queries and retrieved documents are sent to OpenAI API # for relevance grading and query reformulation. This exposes user data to third parties. # Only enable if you have explicit user consent for data processing. 
import asyncio
import re

SELF_RAG_ENABLED = os.getenv("SELF_RAG_ENABLED", "false").lower() == "true"
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
SELF_RAG_MODEL = os.getenv("SELF_RAG_MODEL", "gpt-4o-mini")

# Thresholds for self-reflection
RELEVANCE_THRESHOLD = float(os.getenv("SELF_RAG_RELEVANCE_THRESHOLD", "0.6"))
GROUNDING_THRESHOLD = float(os.getenv("SELF_RAG_GROUNDING_THRESHOLD", "0.7"))
MAX_RETRIEVAL_ATTEMPTS = int(os.getenv("SELF_RAG_MAX_ATTEMPTS", "2"))

# Pre-compiled pattern for the "SCORE: x.y" line the grading prompts request.
_SCORE_RE = re.compile(r'SCORE:\s*([\d.]+)')


class RetrievalDecision(Enum):
    """Decision after grading retrieval."""
    SUFFICIENT = "sufficient"    # Context is good, proceed to generation
    NEEDS_MORE = "needs_more"    # Need to retrieve more documents
    REFORMULATE = "reformulate"  # Query needs reformulation
    FALLBACK = "fallback"        # Use fallback (no good context found)


class SelfRAGError(Exception):
    """Error during Self-RAG processing."""
    pass


async def _chat_completion(prompt: str, *, max_tokens: int, temperature: float) -> str:
    """Send a single-user-message chat completion request to OpenAI.

    Shared by all grading/reformulation helpers so the HTTP plumbing
    lives in one place.

    Args:
        prompt: The user prompt to send.
        max_tokens: Completion token limit.
        temperature: Sampling temperature.

    Returns:
        The assistant message content.

    Raises:
        SelfRAGError: If the API responds with a non-200 status
            (message is "API error: <status>").
        Exception: Transport-level errors from httpx propagate unchanged.
    """
    async with httpx.AsyncClient() as client:
        response = await client.post(
            "https://api.openai.com/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {OPENAI_API_KEY}",
                "Content-Type": "application/json"
            },
            json={
                "model": SELF_RAG_MODEL,
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": max_tokens,
                "temperature": temperature,
            },
            timeout=30.0
        )
    if response.status_code != 200:
        raise SelfRAGError(f"API error: {response.status_code}")
    return response.json()["choices"][0]["message"]["content"]


def _parse_score(text: str, default: float = 0.5) -> float:
    """Extract a 'SCORE: x.y' value from an LLM response, clamped to [0, 1].

    Returns *default* when the marker is missing or the matched token is
    not a valid float (e.g. a lone ".").
    """
    match = _SCORE_RE.search(text)
    if not match:
        return default
    try:
        score = float(match.group(1))
    except ValueError:
        return default
    return min(max(score, 0.0), 1.0)


async def grade_document_relevance(
    query: str,
    document: str,
) -> Tuple[float, str]:
    """Grade whether a document is relevant to the query.

    Without an API key this degrades to a cheap keyword-overlap heuristic
    so the pipeline still works offline.

    Args:
        query: The search query.
        document: Document text; only the first 2000 chars are sent to the LLM.

    Returns:
        Tuple of (score, reason) with score between 0.0 (irrelevant)
        and 1.0 (highly relevant).
    """
    if not OPENAI_API_KEY:
        # Fallback: simple keyword overlap (no external calls).
        query_words = set(query.lower().split())
        doc_words = set(document.lower().split())
        overlap = len(query_words & doc_words) / max(len(query_words), 1)
        return min(overlap * 2, 1.0), "Keyword-based relevance (no LLM)"

    prompt = f"""Bewerte, ob das folgende Dokument relevant für die Suchanfrage ist.

SUCHANFRAGE: {query}

DOKUMENT: {document[:2000]}

Ist dieses Dokument relevant, um die Anfrage zu beantworten? Berücksichtige:
- Thematische Übereinstimmung
- Enthält das Dokument spezifische Informationen zur Anfrage?
- Würde dieses Dokument bei der Beantwortung helfen?

Antworte im Format:
SCORE: [0.0-1.0]
BEGRÜNDUNG: [Kurze Erklärung]"""

    try:
        result = await _chat_completion(prompt, max_tokens=150, temperature=0.0)
    except SelfRAGError as e:
        # Non-200 from the API: neutral score, message is "API error: <status>".
        return 0.5, str(e)
    except Exception as e:
        return 0.5, f"Grading error: {str(e)}"

    score = _parse_score(result)
    reason_match = re.search(r'BEGRÜNDUNG:\s*(.+)', result, re.DOTALL)
    reason = reason_match.group(1).strip() if reason_match else result
    return score, reason


async def grade_documents_batch(
    query: str,
    documents: List[str],
) -> List[Tuple[float, str]]:
    """Grade multiple documents for relevance concurrently.

    Each grade call handles its own errors, so gather cannot raise here.

    Returns:
        List of (score, reason) tuples, in the same order as *documents*.
    """
    return await asyncio.gather(
        *(grade_document_relevance(query, doc) for doc in documents)
    )


async def filter_relevant_documents(
    query: str,
    documents: List[Dict],
    threshold: float = RELEVANCE_THRESHOLD,
) -> Tuple[List[Dict], List[Dict]]:
    """Filter documents by relevance, separating relevant from irrelevant.

    Args:
        query: The search query
        documents: List of document dicts with 'text' field
        threshold: Minimum relevance score to keep

    Returns:
        Tuple of (relevant_docs, filtered_out_docs). Each returned doc is
        a copy carrying added 'relevance_score' / 'relevance_reason' keys;
        relevant docs are sorted by descending score.
    """
    # Grade all documents concurrently; order of results matches input order.
    grades = await asyncio.gather(
        *(grade_document_relevance(query, doc.get("text", "")) for doc in documents)
    )

    relevant: List[Dict] = []
    filtered: List[Dict] = []
    for doc, (score, reason) in zip(documents, grades):
        doc_with_grade = doc.copy()
        doc_with_grade["relevance_score"] = score
        doc_with_grade["relevance_reason"] = reason
        if score >= threshold:
            relevant.append(doc_with_grade)
        else:
            filtered.append(doc_with_grade)

    # Sort relevant by score, best first.
    relevant.sort(key=lambda x: x.get("relevance_score", 0), reverse=True)
    return relevant, filtered


async def decide_retrieval_strategy(
    query: str,
    documents: List[Dict],
    attempt: int = 1,
) -> Tuple[RetrievalDecision, Dict]:
    """Decide what to do based on current retrieval results.

    Args:
        query: The search query
        documents: Retrieved documents with relevance scores
        attempt: Current retrieval attempt number

    Returns:
        Tuple of (decision, metadata)
    """
    if not documents:
        if attempt >= MAX_RETRIEVAL_ATTEMPTS:
            return RetrievalDecision.FALLBACK, {"reason": "No documents found after max attempts"}
        return RetrievalDecision.REFORMULATE, {"reason": "No documents retrieved"}

    # Check average relevance; missing scores count as neutral 0.5.
    scores = [doc.get("relevance_score", 0.5) for doc in documents]
    avg_score = sum(scores) / len(scores)
    max_score = max(scores)

    if max_score >= RELEVANCE_THRESHOLD and avg_score >= RELEVANCE_THRESHOLD * 0.7:
        return RetrievalDecision.SUFFICIENT, {
            "avg_relevance": avg_score,
            "max_relevance": max_score,
            "doc_count": len(documents),
        }

    if attempt >= MAX_RETRIEVAL_ATTEMPTS:
        if max_score >= RELEVANCE_THRESHOLD * 0.5:
            # At least some relevant context, proceed with caution
            return RetrievalDecision.SUFFICIENT, {
                "avg_relevance": avg_score,
                "warning": "Low relevance after max attempts",
            }
        return RetrievalDecision.FALLBACK, {"reason": "Max attempts reached, low relevance"}

    if avg_score < 0.3:
        return RetrievalDecision.REFORMULATE, {
            "reason": "Very low relevance, query reformulation needed",
            "avg_relevance": avg_score,
        }
    return RetrievalDecision.NEEDS_MORE, {
        "reason": "Moderate relevance, retrieving more documents",
        "avg_relevance": avg_score,
    }


async def reformulate_query(
    original_query: str,
    context: Optional[str] = None,
    previous_results_summary: Optional[str] = None,
) -> str:
    """Reformulate a query to improve retrieval.

    Uses the LLM to generate a better query based on:
    - Original query
    - Optional context (subject, niveau, etc.)
    - Summary of why previous retrieval failed

    Without an API key, falls back to expanding common German exam
    abbreviations in place.
    """
    if not OPENAI_API_KEY:
        expansions = {
            "EA": "erhöhtes Anforderungsniveau",
            "eA": "erhöhtes Anforderungsniveau",
            "GA": "grundlegendes Anforderungsniveau",
            "gA": "grundlegendes Anforderungsniveau",
            "AFB": "Anforderungsbereich",
            "Abi": "Abitur",
        }
        reformulated = original_query
        for abbr, expansion in expansions.items():
            # Word boundaries guard against substring hits: a plain
            # `replace` would mangle "Abitur" into "Abi (Abitur)tur".
            pattern = rf"\b{re.escape(abbr)}\b"
            if re.search(pattern, reformulated):
                reformulated = re.sub(pattern, f"{abbr} ({expansion})", reformulated)
        return reformulated

    prompt = f"""Du bist ein Experte für deutsche Bildungsstandards und Prüfungsanforderungen.

Die folgende Suchanfrage hat keine guten Ergebnisse geliefert:

ORIGINAL: {original_query}
{f"KONTEXT: {context}" if context else ""}
{f"PROBLEM MIT VORHERIGEN ERGEBNISSEN: {previous_results_summary}" if previous_results_summary else ""}

Formuliere die Anfrage so um, dass sie:
1. Formellere/technischere Begriffe verwendet (wie in offiziellen Dokumenten)
2. Relevante Synonyme oder verwandte Begriffe einschließt
3. Spezifischer auf Erwartungshorizonte/Bewertungskriterien ausgerichtet ist

Antworte NUR mit der umformulierten Suchanfrage, ohne Erklärung."""

    try:
        reply = await _chat_completion(prompt, max_tokens=100, temperature=0.3)
        return reply.strip()
    except Exception:
        # Any failure (API error, transport) falls back to the original
        # query so retrieval can still proceed.
        return original_query


async def grade_answer_groundedness(
    answer: str,
    contexts: List[str],
) -> Tuple[float, List[str]]:
    """Grade whether an answer is grounded in the provided contexts.

    Args:
        answer: The generated answer to check.
        contexts: Context passages; only the first 5 are sent to the LLM.

    Returns:
        Tuple of (grounding_score, list of unsupported claims).
    """
    if not OPENAI_API_KEY:
        return 0.5, ["LLM not configured for grounding check"]

    context_text = "\n---\n".join(contexts[:5])
    prompt = f"""Analysiere, ob die folgende Antwort vollständig durch die Kontexte gestützt wird.

KONTEXTE:
{context_text}

ANTWORT: {answer}

Identifiziere:
1. Welche Aussagen sind durch die Kontexte belegt?
2. Welche Aussagen sind NICHT belegt (potenzielle Halluzinationen)?

Antworte im Format:
SCORE: [0.0-1.0] (1.0 = vollständig belegt)
NICHT_BELEGT: [Liste der nicht belegten Aussagen, eine pro Zeile, oder "Keine"]"""

    try:
        result = await _chat_completion(prompt, max_tokens=300, temperature=0.0)
    except SelfRAGError as e:
        return 0.5, [str(e)]
    except Exception as e:
        return 0.5, [f"Grounding check error: {str(e)}"]

    score = _parse_score(result)
    unsupported_match = re.search(r'NICHT_BELEGT:\s*(.+)', result, re.DOTALL)
    unsupported_text = unsupported_match.group(1).strip() if unsupported_match else ""
    if unsupported_text.lower() == "keine":
        unsupported: List[str] = []
    else:
        unsupported = [line.strip() for line in unsupported_text.split("\n") if line.strip()]
    return score, unsupported


async def self_rag_retrieve(
    query: str,
    search_func,
    subject: Optional[str] = None,
    niveau: Optional[str] = None,
    initial_top_k: int = 10,
    final_top_k: int = 5,
    **search_kwargs
) -> Dict:
    """Perform Self-RAG enhanced retrieval with reflection and correction.

    This implements a retrieval loop that:
    1. Retrieves initial documents
    2. Grades them for relevance
    3. Decides if more retrieval is needed
    4. Reformulates query if necessary
    5. Returns filtered, high-quality context

    Args:
        query: The search query
        search_func: Async function to perform the actual search
        subject: Optional subject context
        niveau: Optional niveau context (currently unused; kept for callers)
        initial_top_k: Number of documents for initial retrieval
        final_top_k: Maximum documents to return
        **search_kwargs: Additional args for search_func

    Returns:
        Dict with results, metadata, and reflection trace
    """
    if not SELF_RAG_ENABLED:
        # Privacy-safe default: plain search, nothing is sent to OpenAI.
        results = await search_func(query=query, limit=final_top_k, **search_kwargs)
        return {
            "results": results,
            "self_rag_enabled": False,
            "query_used": query,
        }

    trace: List[Dict] = []
    current_query = query
    attempt = 1
    # Initialized up front so the final return is safe even if
    # SELF_RAG_MAX_ATTEMPTS is configured < 1 (loop never runs).
    results: List[Dict] = []

    while attempt <= MAX_RETRIEVAL_ATTEMPTS:
        # Step 1: Retrieve documents
        results = await search_func(query=current_query, limit=initial_top_k, **search_kwargs)
        trace.append({
            "attempt": attempt,
            "query": current_query,
            "retrieved_count": len(results) if results else 0,
        })

        if not results:
            attempt += 1
            if attempt <= MAX_RETRIEVAL_ATTEMPTS:
                current_query = await reformulate_query(
                    query,
                    context=f"Fach: {subject}" if subject else None,
                    previous_results_summary="Keine Dokumente gefunden"
                )
                trace[-1]["action"] = "reformulate"
                trace[-1]["new_query"] = current_query
            continue

        # Step 2: Grade documents for relevance
        relevant, filtered = await filter_relevant_documents(current_query, results)
        trace[-1]["relevant_count"] = len(relevant)
        trace[-1]["filtered_count"] = len(filtered)

        # Step 3: Decide what to do
        decision, decision_meta = await decide_retrieval_strategy(
            current_query, relevant, attempt
        )
        trace[-1]["decision"] = decision.value
        trace[-1]["decision_meta"] = decision_meta

        if decision == RetrievalDecision.SUFFICIENT:
            # We have good context, return it
            return {
                "results": relevant[:final_top_k],
                "self_rag_enabled": True,
                "query_used": current_query,
                "original_query": query if current_query != query else None,
                "attempts": attempt,
                "decision": decision.value,
                "trace": trace,
                "filtered_out_count": len(filtered),
            }
        elif decision == RetrievalDecision.REFORMULATE:
            # Reformulate and try again
            avg_score = decision_meta.get("avg_relevance", 0)
            current_query = await reformulate_query(
                query,
                context=f"Fach: {subject}" if subject else None,
                previous_results_summary=f"Durchschnittliche Relevanz: {avg_score:.2f}"
            )
            trace[-1]["action"] = "reformulate"
            trace[-1]["new_query"] = current_query
        elif decision == RetrievalDecision.NEEDS_MORE:
            # Retrieve more with expanded query
            current_query = f"{current_query} Bewertungskriterien Anforderungen"
            trace[-1]["action"] = "expand_query"
            trace[-1]["new_query"] = current_query
        elif decision == RetrievalDecision.FALLBACK:
            # Return what we have, even if not ideal
            return {
                "results": (relevant or results)[:final_top_k],
                "self_rag_enabled": True,
                "query_used": current_query,
                "original_query": query if current_query != query else None,
                "attempts": attempt,
                "decision": decision.value,
                "warning": "Fallback mode - low relevance context",
                "trace": trace,
            }
        attempt += 1

    # Max attempts reached
    return {
        "results": results[:final_top_k] if results else [],
        "self_rag_enabled": True,
        "query_used": current_query,
        "original_query": query if current_query != query else None,
        "attempts": attempt - 1,
        "decision": "max_attempts",
        "warning": "Max retrieval attempts reached",
        "trace": trace,
    }


def get_self_rag_info() -> dict:
    """Get information about Self-RAG configuration."""
    return {
        "enabled": SELF_RAG_ENABLED,
        "llm_configured": bool(OPENAI_API_KEY),
        "model": SELF_RAG_MODEL,
        "relevance_threshold": RELEVANCE_THRESHOLD,
        "grounding_threshold": GROUNDING_THRESHOLD,
        "max_retrieval_attempts": MAX_RETRIEVAL_ATTEMPTS,
        "features": [
            "document_grading",
            "relevance_filtering",
            "query_reformulation",
            "answer_grounding_check",
            "retrieval_decision",
        ],
        "sends_data_externally": True,  # ALWAYS true when enabled - documents sent to OpenAI
        "privacy_warning": "When enabled, queries and documents are sent to OpenAI API for grading",
        "default_enabled": False,  # Disabled by default for privacy
    }