""" Self-RAG Grading — document relevance grading, filtering, retrieval decisions. Extracted from self_rag.py for modularity. Based on research: - Self-RAG (Asai et al., 2023) - Corrective RAG (Yan et al., 2024) """ import os from typing import List, Dict, Optional, Tuple from enum import Enum import httpx # Configuration SELF_RAG_ENABLED = os.getenv("SELF_RAG_ENABLED", "false").lower() == "true" OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "") SELF_RAG_MODEL = os.getenv("SELF_RAG_MODEL", "gpt-4o-mini") # Thresholds for self-reflection RELEVANCE_THRESHOLD = float(os.getenv("SELF_RAG_RELEVANCE_THRESHOLD", "0.6")) GROUNDING_THRESHOLD = float(os.getenv("SELF_RAG_GROUNDING_THRESHOLD", "0.7")) MAX_RETRIEVAL_ATTEMPTS = int(os.getenv("SELF_RAG_MAX_ATTEMPTS", "2")) class RetrievalDecision(Enum): """Decision after grading retrieval.""" SUFFICIENT = "sufficient" # Context is good, proceed to generation NEEDS_MORE = "needs_more" # Need to retrieve more documents REFORMULATE = "reformulate" # Query needs reformulation FALLBACK = "fallback" # Use fallback (no good context found) class SelfRAGError(Exception): """Error during Self-RAG processing.""" pass async def grade_document_relevance( query: str, document: str, ) -> Tuple[float, str]: """ Grade whether a document is relevant to the query. Returns a score between 0 (irrelevant) and 1 (highly relevant) along with an explanation. """ if not OPENAI_API_KEY: # Fallback: simple keyword overlap query_words = set(query.lower().split()) doc_words = set(document.lower().split()) overlap = len(query_words & doc_words) / max(len(query_words), 1) return min(overlap * 2, 1.0), "Keyword-based relevance (no LLM)" prompt = f"""Bewerte, ob das folgende Dokument relevant fuer die Suchanfrage ist. SUCHANFRAGE: {query} DOKUMENT: {document[:2000]} Ist dieses Dokument relevant, um die Anfrage zu beantworten? Beruecksichtige: - Thematische Uebereinstimmung - Enthaelt das Dokument spezifische Informationen zur Anfrage? - Wuerde dieses Dokument bei der Beantwortung helfen? Antworte im Format: SCORE: [0.0-1.0] BEGRUENDUNG: [Kurze Erklaerung]""" try: async with httpx.AsyncClient() as client: response = await client.post( "https://api.openai.com/v1/chat/completions", headers={ "Authorization": f"Bearer {OPENAI_API_KEY}", "Content-Type": "application/json" }, json={ "model": SELF_RAG_MODEL, "messages": [{"role": "user", "content": prompt}], "max_tokens": 150, "temperature": 0.0, }, timeout=30.0 ) if response.status_code != 200: return 0.5, f"API error: {response.status_code}" result = response.json()["choices"][0]["message"]["content"] import re score_match = re.search(r'SCORE:\s*([\d.]+)', result) score = float(score_match.group(1)) if score_match else 0.5 reason_match = re.search(r'BEGRUENDUNG:\s*(.+)', result, re.DOTALL) reason = reason_match.group(1).strip() if reason_match else result return min(max(score, 0.0), 1.0), reason except Exception as e: return 0.5, f"Grading error: {str(e)}" async def grade_documents_batch( query: str, documents: List[str], ) -> List[Tuple[float, str]]: """ Grade multiple documents for relevance. Returns list of (score, reason) tuples. """ results = [] for doc in documents: score, reason = await grade_document_relevance(query, doc) results.append((score, reason)) return results async def filter_relevant_documents( query: str, documents: List[Dict], threshold: float = RELEVANCE_THRESHOLD, ) -> Tuple[List[Dict], List[Dict]]: """ Filter documents by relevance, separating relevant from irrelevant. Args: query: The search query documents: List of document dicts with 'text' field threshold: Minimum relevance score to keep Returns: Tuple of (relevant_docs, filtered_out_docs) """ relevant = [] filtered = [] for doc in documents: text = doc.get("text", "") score, reason = await grade_document_relevance(query, text) doc_with_grade = doc.copy() doc_with_grade["relevance_score"] = score doc_with_grade["relevance_reason"] = reason if score >= threshold: relevant.append(doc_with_grade) else: filtered.append(doc_with_grade) # Sort relevant by score relevant.sort(key=lambda x: x.get("relevance_score", 0), reverse=True) return relevant, filtered async def decide_retrieval_strategy( query: str, documents: List[Dict], attempt: int = 1, ) -> Tuple[RetrievalDecision, Dict]: """ Decide what to do based on current retrieval results. Args: query: The search query documents: Retrieved documents with relevance scores attempt: Current retrieval attempt number Returns: Tuple of (decision, metadata) """ if not documents: if attempt >= MAX_RETRIEVAL_ATTEMPTS: return RetrievalDecision.FALLBACK, {"reason": "No documents found after max attempts"} return RetrievalDecision.REFORMULATE, {"reason": "No documents retrieved"} # Check average relevance scores = [doc.get("relevance_score", 0.5) for doc in documents] avg_score = sum(scores) / len(scores) max_score = max(scores) if max_score >= RELEVANCE_THRESHOLD and avg_score >= RELEVANCE_THRESHOLD * 0.7: return RetrievalDecision.SUFFICIENT, { "avg_relevance": avg_score, "max_relevance": max_score, "doc_count": len(documents), } if attempt >= MAX_RETRIEVAL_ATTEMPTS: if max_score >= RELEVANCE_THRESHOLD * 0.5: # At least some relevant context, proceed with caution return RetrievalDecision.SUFFICIENT, { "avg_relevance": avg_score, "warning": "Low relevance after max attempts", } return RetrievalDecision.FALLBACK, {"reason": "Max attempts reached, low relevance"} if avg_score < 0.3: return RetrievalDecision.REFORMULATE, { "reason": "Very low relevance, query reformulation needed", "avg_relevance": avg_score, } return RetrievalDecision.NEEDS_MORE, { "reason": "Moderate relevance, retrieving more documents", "avg_relevance": avg_score, } async def grade_answer_groundedness( answer: str, contexts: List[str], ) -> Tuple[float, List[str]]: """ Grade whether an answer is grounded in the provided contexts. Returns: Tuple of (grounding_score, list of unsupported claims) """ if not OPENAI_API_KEY: return 0.5, ["LLM not configured for grounding check"] context_text = "\n---\n".join(contexts[:5]) prompt = f"""Analysiere, ob die folgende Antwort vollstaendig durch die Kontexte gestuetzt wird. KONTEXTE: {context_text} ANTWORT: {answer} Identifiziere: 1. Welche Aussagen sind durch die Kontexte belegt? 2. Welche Aussagen sind NICHT belegt (potenzielle Halluzinationen)? Antworte im Format: SCORE: [0.0-1.0] (1.0 = vollstaendig belegt) NICHT_BELEGT: [Liste der nicht belegten Aussagen, eine pro Zeile, oder "Keine"]""" try: async with httpx.AsyncClient() as client: response = await client.post( "https://api.openai.com/v1/chat/completions", headers={ "Authorization": f"Bearer {OPENAI_API_KEY}", "Content-Type": "application/json" }, json={ "model": SELF_RAG_MODEL, "messages": [{"role": "user", "content": prompt}], "max_tokens": 300, "temperature": 0.0, }, timeout=30.0 ) if response.status_code != 200: return 0.5, [f"API error: {response.status_code}"] result = response.json()["choices"][0]["message"]["content"] import re score_match = re.search(r'SCORE:\s*([\d.]+)', result) score = float(score_match.group(1)) if score_match else 0.5 unsupported_match = re.search(r'NICHT_BELEGT:\s*(.+)', result, re.DOTALL) unsupported_text = unsupported_match.group(1).strip() if unsupported_match else "" if unsupported_text.lower() == "keine": unsupported = [] else: unsupported = [line.strip() for line in unsupported_text.split("\n") if line.strip()] return min(max(score, 0.0), 1.0), unsupported except Exception as e: return 0.5, [f"Grounding check error: {str(e)}"]