# Self-RAG / Corrective RAG module.
"""
|
|
Self-RAG / Corrective RAG Module
|
|
|
|
Implements self-reflective RAG that can:
|
|
1. Grade retrieved documents for relevance
|
|
2. Decide if more retrieval is needed
|
|
3. Reformulate queries if initial retrieval fails
|
|
4. Filter irrelevant passages before generation
|
|
5. Grade answers for groundedness and hallucination
|
|
|
|
Based on research:
|
|
- Self-RAG (Asai et al., 2023): Learning to retrieve, generate, and critique
|
|
- Corrective RAG (Yan et al., 2024): Self-correcting retrieval augmented generation
|
|
|
|
This is especially useful for German educational documents where:
|
|
- Queries may use informal language
|
|
- Documents use formal/technical terminology
|
|
- Context must be precisely matched to scoring criteria
|
|
"""
|
|
|
|
import asyncio
import os
import re
from enum import Enum
from typing import Dict, List, Optional, Tuple

import httpx

|
|
# Configuration
|
|
# IMPORTANT: Self-RAG is DISABLED by default for privacy reasons!
|
|
# When enabled, search queries and retrieved documents are sent to OpenAI API
|
|
# for relevance grading and query reformulation. This exposes user data to third parties.
|
|
# Only enable if you have explicit user consent for data processing.
|
|
SELF_RAG_ENABLED = os.getenv("SELF_RAG_ENABLED", "false").lower() == "true"
|
|
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
|
|
SELF_RAG_MODEL = os.getenv("SELF_RAG_MODEL", "gpt-4o-mini")
|
|
|
|
# Thresholds for self-reflection
|
|
RELEVANCE_THRESHOLD = float(os.getenv("SELF_RAG_RELEVANCE_THRESHOLD", "0.6"))
|
|
GROUNDING_THRESHOLD = float(os.getenv("SELF_RAG_GROUNDING_THRESHOLD", "0.7"))
|
|
MAX_RETRIEVAL_ATTEMPTS = int(os.getenv("SELF_RAG_MAX_ATTEMPTS", "2"))
|
|
|
|
|
|
class RetrievalDecision(Enum):
|
|
"""Decision after grading retrieval."""
|
|
SUFFICIENT = "sufficient" # Context is good, proceed to generation
|
|
NEEDS_MORE = "needs_more" # Need to retrieve more documents
|
|
REFORMULATE = "reformulate" # Query needs reformulation
|
|
FALLBACK = "fallback" # Use fallback (no good context found)
|
|
|
|
|
|
class SelfRAGError(Exception):
|
|
"""Error during Self-RAG processing."""
|
|
pass
|
|
|
|
|
|
async def grade_document_relevance(
|
|
query: str,
|
|
document: str,
|
|
) -> Tuple[float, str]:
|
|
"""
|
|
Grade whether a document is relevant to the query.
|
|
|
|
Returns a score between 0 (irrelevant) and 1 (highly relevant)
|
|
along with an explanation.
|
|
"""
|
|
if not OPENAI_API_KEY:
|
|
# Fallback: simple keyword overlap
|
|
query_words = set(query.lower().split())
|
|
doc_words = set(document.lower().split())
|
|
overlap = len(query_words & doc_words) / max(len(query_words), 1)
|
|
return min(overlap * 2, 1.0), "Keyword-based relevance (no LLM)"
|
|
|
|
prompt = f"""Bewerte, ob das folgende Dokument relevant für die Suchanfrage ist.
|
|
|
|
SUCHANFRAGE: {query}
|
|
|
|
DOKUMENT:
|
|
{document[:2000]}
|
|
|
|
Ist dieses Dokument relevant, um die Anfrage zu beantworten?
|
|
Berücksichtige:
|
|
- Thematische Übereinstimmung
|
|
- Enthält das Dokument spezifische Informationen zur Anfrage?
|
|
- Würde dieses Dokument bei der Beantwortung helfen?
|
|
|
|
Antworte im Format:
|
|
SCORE: [0.0-1.0]
|
|
BEGRÜNDUNG: [Kurze Erklärung]"""
|
|
|
|
try:
|
|
async with httpx.AsyncClient() as client:
|
|
response = await client.post(
|
|
"https://api.openai.com/v1/chat/completions",
|
|
headers={
|
|
"Authorization": f"Bearer {OPENAI_API_KEY}",
|
|
"Content-Type": "application/json"
|
|
},
|
|
json={
|
|
"model": SELF_RAG_MODEL,
|
|
"messages": [{"role": "user", "content": prompt}],
|
|
"max_tokens": 150,
|
|
"temperature": 0.0,
|
|
},
|
|
timeout=30.0
|
|
)
|
|
|
|
if response.status_code != 200:
|
|
return 0.5, f"API error: {response.status_code}"
|
|
|
|
result = response.json()["choices"][0]["message"]["content"]
|
|
|
|
import re
|
|
score_match = re.search(r'SCORE:\s*([\d.]+)', result)
|
|
score = float(score_match.group(1)) if score_match else 0.5
|
|
|
|
reason_match = re.search(r'BEGRÜNDUNG:\s*(.+)', result, re.DOTALL)
|
|
reason = reason_match.group(1).strip() if reason_match else result
|
|
|
|
return min(max(score, 0.0), 1.0), reason
|
|
|
|
except Exception as e:
|
|
return 0.5, f"Grading error: {str(e)}"
|
|
|
|
|
|
async def grade_documents_batch(
|
|
query: str,
|
|
documents: List[str],
|
|
) -> List[Tuple[float, str]]:
|
|
"""
|
|
Grade multiple documents for relevance.
|
|
|
|
Returns list of (score, reason) tuples.
|
|
"""
|
|
results = []
|
|
for doc in documents:
|
|
score, reason = await grade_document_relevance(query, doc)
|
|
results.append((score, reason))
|
|
return results
|
|
|
|
|
|
async def filter_relevant_documents(
|
|
query: str,
|
|
documents: List[Dict],
|
|
threshold: float = RELEVANCE_THRESHOLD,
|
|
) -> Tuple[List[Dict], List[Dict]]:
|
|
"""
|
|
Filter documents by relevance, separating relevant from irrelevant.
|
|
|
|
Args:
|
|
query: The search query
|
|
documents: List of document dicts with 'text' field
|
|
threshold: Minimum relevance score to keep
|
|
|
|
Returns:
|
|
Tuple of (relevant_docs, filtered_out_docs)
|
|
"""
|
|
relevant = []
|
|
filtered = []
|
|
|
|
for doc in documents:
|
|
text = doc.get("text", "")
|
|
score, reason = await grade_document_relevance(query, text)
|
|
|
|
doc_with_grade = doc.copy()
|
|
doc_with_grade["relevance_score"] = score
|
|
doc_with_grade["relevance_reason"] = reason
|
|
|
|
if score >= threshold:
|
|
relevant.append(doc_with_grade)
|
|
else:
|
|
filtered.append(doc_with_grade)
|
|
|
|
# Sort relevant by score
|
|
relevant.sort(key=lambda x: x.get("relevance_score", 0), reverse=True)
|
|
|
|
return relevant, filtered
|
|
|
|
|
|
async def decide_retrieval_strategy(
|
|
query: str,
|
|
documents: List[Dict],
|
|
attempt: int = 1,
|
|
) -> Tuple[RetrievalDecision, Dict]:
|
|
"""
|
|
Decide what to do based on current retrieval results.
|
|
|
|
Args:
|
|
query: The search query
|
|
documents: Retrieved documents with relevance scores
|
|
attempt: Current retrieval attempt number
|
|
|
|
Returns:
|
|
Tuple of (decision, metadata)
|
|
"""
|
|
if not documents:
|
|
if attempt >= MAX_RETRIEVAL_ATTEMPTS:
|
|
return RetrievalDecision.FALLBACK, {"reason": "No documents found after max attempts"}
|
|
return RetrievalDecision.REFORMULATE, {"reason": "No documents retrieved"}
|
|
|
|
# Check average relevance
|
|
scores = [doc.get("relevance_score", 0.5) for doc in documents]
|
|
avg_score = sum(scores) / len(scores)
|
|
max_score = max(scores)
|
|
|
|
if max_score >= RELEVANCE_THRESHOLD and avg_score >= RELEVANCE_THRESHOLD * 0.7:
|
|
return RetrievalDecision.SUFFICIENT, {
|
|
"avg_relevance": avg_score,
|
|
"max_relevance": max_score,
|
|
"doc_count": len(documents),
|
|
}
|
|
|
|
if attempt >= MAX_RETRIEVAL_ATTEMPTS:
|
|
if max_score >= RELEVANCE_THRESHOLD * 0.5:
|
|
# At least some relevant context, proceed with caution
|
|
return RetrievalDecision.SUFFICIENT, {
|
|
"avg_relevance": avg_score,
|
|
"warning": "Low relevance after max attempts",
|
|
}
|
|
return RetrievalDecision.FALLBACK, {"reason": "Max attempts reached, low relevance"}
|
|
|
|
if avg_score < 0.3:
|
|
return RetrievalDecision.REFORMULATE, {
|
|
"reason": "Very low relevance, query reformulation needed",
|
|
"avg_relevance": avg_score,
|
|
}
|
|
|
|
return RetrievalDecision.NEEDS_MORE, {
|
|
"reason": "Moderate relevance, retrieving more documents",
|
|
"avg_relevance": avg_score,
|
|
}
|
|
|
|
|
|
async def reformulate_query(
|
|
original_query: str,
|
|
context: Optional[str] = None,
|
|
previous_results_summary: Optional[str] = None,
|
|
) -> str:
|
|
"""
|
|
Reformulate a query to improve retrieval.
|
|
|
|
Uses LLM to generate a better query based on:
|
|
- Original query
|
|
- Optional context (subject, niveau, etc.)
|
|
- Summary of why previous retrieval failed
|
|
"""
|
|
if not OPENAI_API_KEY:
|
|
# Simple reformulation: expand abbreviations, add synonyms
|
|
reformulated = original_query
|
|
expansions = {
|
|
"EA": "erhöhtes Anforderungsniveau",
|
|
"eA": "erhöhtes Anforderungsniveau",
|
|
"GA": "grundlegendes Anforderungsniveau",
|
|
"gA": "grundlegendes Anforderungsniveau",
|
|
"AFB": "Anforderungsbereich",
|
|
"Abi": "Abitur",
|
|
}
|
|
for abbr, expansion in expansions.items():
|
|
if abbr in original_query:
|
|
reformulated = reformulated.replace(abbr, f"{abbr} ({expansion})")
|
|
return reformulated
|
|
|
|
prompt = f"""Du bist ein Experte für deutsche Bildungsstandards und Prüfungsanforderungen.
|
|
|
|
Die folgende Suchanfrage hat keine guten Ergebnisse geliefert:
|
|
ORIGINAL: {original_query}
|
|
|
|
{f"KONTEXT: {context}" if context else ""}
|
|
{f"PROBLEM MIT VORHERIGEN ERGEBNISSEN: {previous_results_summary}" if previous_results_summary else ""}
|
|
|
|
Formuliere die Anfrage so um, dass sie:
|
|
1. Formellere/technischere Begriffe verwendet (wie in offiziellen Dokumenten)
|
|
2. Relevante Synonyme oder verwandte Begriffe einschließt
|
|
3. Spezifischer auf Erwartungshorizonte/Bewertungskriterien ausgerichtet ist
|
|
|
|
Antworte NUR mit der umformulierten Suchanfrage, ohne Erklärung."""
|
|
|
|
try:
|
|
async with httpx.AsyncClient() as client:
|
|
response = await client.post(
|
|
"https://api.openai.com/v1/chat/completions",
|
|
headers={
|
|
"Authorization": f"Bearer {OPENAI_API_KEY}",
|
|
"Content-Type": "application/json"
|
|
},
|
|
json={
|
|
"model": SELF_RAG_MODEL,
|
|
"messages": [{"role": "user", "content": prompt}],
|
|
"max_tokens": 100,
|
|
"temperature": 0.3,
|
|
},
|
|
timeout=30.0
|
|
)
|
|
|
|
if response.status_code != 200:
|
|
return original_query
|
|
|
|
return response.json()["choices"][0]["message"]["content"].strip()
|
|
|
|
except Exception:
|
|
return original_query
|
|
|
|
|
|
async def grade_answer_groundedness(
|
|
answer: str,
|
|
contexts: List[str],
|
|
) -> Tuple[float, List[str]]:
|
|
"""
|
|
Grade whether an answer is grounded in the provided contexts.
|
|
|
|
Returns:
|
|
Tuple of (grounding_score, list of unsupported claims)
|
|
"""
|
|
if not OPENAI_API_KEY:
|
|
return 0.5, ["LLM not configured for grounding check"]
|
|
|
|
context_text = "\n---\n".join(contexts[:5])
|
|
|
|
prompt = f"""Analysiere, ob die folgende Antwort vollständig durch die Kontexte gestützt wird.
|
|
|
|
KONTEXTE:
|
|
{context_text}
|
|
|
|
ANTWORT:
|
|
{answer}
|
|
|
|
Identifiziere:
|
|
1. Welche Aussagen sind durch die Kontexte belegt?
|
|
2. Welche Aussagen sind NICHT belegt (potenzielle Halluzinationen)?
|
|
|
|
Antworte im Format:
|
|
SCORE: [0.0-1.0] (1.0 = vollständig belegt)
|
|
NICHT_BELEGT: [Liste der nicht belegten Aussagen, eine pro Zeile, oder "Keine"]"""
|
|
|
|
try:
|
|
async with httpx.AsyncClient() as client:
|
|
response = await client.post(
|
|
"https://api.openai.com/v1/chat/completions",
|
|
headers={
|
|
"Authorization": f"Bearer {OPENAI_API_KEY}",
|
|
"Content-Type": "application/json"
|
|
},
|
|
json={
|
|
"model": SELF_RAG_MODEL,
|
|
"messages": [{"role": "user", "content": prompt}],
|
|
"max_tokens": 300,
|
|
"temperature": 0.0,
|
|
},
|
|
timeout=30.0
|
|
)
|
|
|
|
if response.status_code != 200:
|
|
return 0.5, [f"API error: {response.status_code}"]
|
|
|
|
result = response.json()["choices"][0]["message"]["content"]
|
|
|
|
import re
|
|
score_match = re.search(r'SCORE:\s*([\d.]+)', result)
|
|
score = float(score_match.group(1)) if score_match else 0.5
|
|
|
|
unsupported_match = re.search(r'NICHT_BELEGT:\s*(.+)', result, re.DOTALL)
|
|
unsupported_text = unsupported_match.group(1).strip() if unsupported_match else ""
|
|
|
|
if unsupported_text.lower() == "keine":
|
|
unsupported = []
|
|
else:
|
|
unsupported = [line.strip() for line in unsupported_text.split("\n") if line.strip()]
|
|
|
|
return min(max(score, 0.0), 1.0), unsupported
|
|
|
|
except Exception as e:
|
|
return 0.5, [f"Grounding check error: {str(e)}"]
|
|
|
|
|
|
async def self_rag_retrieve(
|
|
query: str,
|
|
search_func,
|
|
subject: Optional[str] = None,
|
|
niveau: Optional[str] = None,
|
|
initial_top_k: int = 10,
|
|
final_top_k: int = 5,
|
|
**search_kwargs
|
|
) -> Dict:
|
|
"""
|
|
Perform Self-RAG enhanced retrieval with reflection and correction.
|
|
|
|
This implements a retrieval loop that:
|
|
1. Retrieves initial documents
|
|
2. Grades them for relevance
|
|
3. Decides if more retrieval is needed
|
|
4. Reformulates query if necessary
|
|
5. Returns filtered, high-quality context
|
|
|
|
Args:
|
|
query: The search query
|
|
search_func: Async function to perform the actual search
|
|
subject: Optional subject context
|
|
niveau: Optional niveau context
|
|
initial_top_k: Number of documents for initial retrieval
|
|
final_top_k: Maximum documents to return
|
|
**search_kwargs: Additional args for search_func
|
|
|
|
Returns:
|
|
Dict with results, metadata, and reflection trace
|
|
"""
|
|
if not SELF_RAG_ENABLED:
|
|
# Fall back to simple search
|
|
results = await search_func(query=query, limit=final_top_k, **search_kwargs)
|
|
return {
|
|
"results": results,
|
|
"self_rag_enabled": False,
|
|
"query_used": query,
|
|
}
|
|
|
|
trace = []
|
|
current_query = query
|
|
attempt = 1
|
|
|
|
while attempt <= MAX_RETRIEVAL_ATTEMPTS:
|
|
# Step 1: Retrieve documents
|
|
results = await search_func(query=current_query, limit=initial_top_k, **search_kwargs)
|
|
|
|
trace.append({
|
|
"attempt": attempt,
|
|
"query": current_query,
|
|
"retrieved_count": len(results) if results else 0,
|
|
})
|
|
|
|
if not results:
|
|
attempt += 1
|
|
if attempt <= MAX_RETRIEVAL_ATTEMPTS:
|
|
current_query = await reformulate_query(
|
|
query,
|
|
context=f"Fach: {subject}" if subject else None,
|
|
previous_results_summary="Keine Dokumente gefunden"
|
|
)
|
|
trace[-1]["action"] = "reformulate"
|
|
trace[-1]["new_query"] = current_query
|
|
continue
|
|
|
|
# Step 2: Grade documents for relevance
|
|
relevant, filtered = await filter_relevant_documents(current_query, results)
|
|
|
|
trace[-1]["relevant_count"] = len(relevant)
|
|
trace[-1]["filtered_count"] = len(filtered)
|
|
|
|
# Step 3: Decide what to do
|
|
decision, decision_meta = await decide_retrieval_strategy(
|
|
current_query, relevant, attempt
|
|
)
|
|
|
|
trace[-1]["decision"] = decision.value
|
|
trace[-1]["decision_meta"] = decision_meta
|
|
|
|
if decision == RetrievalDecision.SUFFICIENT:
|
|
# We have good context, return it
|
|
return {
|
|
"results": relevant[:final_top_k],
|
|
"self_rag_enabled": True,
|
|
"query_used": current_query,
|
|
"original_query": query if current_query != query else None,
|
|
"attempts": attempt,
|
|
"decision": decision.value,
|
|
"trace": trace,
|
|
"filtered_out_count": len(filtered),
|
|
}
|
|
|
|
elif decision == RetrievalDecision.REFORMULATE:
|
|
# Reformulate and try again
|
|
avg_score = decision_meta.get("avg_relevance", 0)
|
|
current_query = await reformulate_query(
|
|
query,
|
|
context=f"Fach: {subject}" if subject else None,
|
|
previous_results_summary=f"Durchschnittliche Relevanz: {avg_score:.2f}"
|
|
)
|
|
trace[-1]["action"] = "reformulate"
|
|
trace[-1]["new_query"] = current_query
|
|
|
|
elif decision == RetrievalDecision.NEEDS_MORE:
|
|
# Retrieve more with expanded query
|
|
current_query = f"{current_query} Bewertungskriterien Anforderungen"
|
|
trace[-1]["action"] = "expand_query"
|
|
trace[-1]["new_query"] = current_query
|
|
|
|
elif decision == RetrievalDecision.FALLBACK:
|
|
# Return what we have, even if not ideal
|
|
return {
|
|
"results": (relevant or results)[:final_top_k],
|
|
"self_rag_enabled": True,
|
|
"query_used": current_query,
|
|
"original_query": query if current_query != query else None,
|
|
"attempts": attempt,
|
|
"decision": decision.value,
|
|
"warning": "Fallback mode - low relevance context",
|
|
"trace": trace,
|
|
}
|
|
|
|
attempt += 1
|
|
|
|
# Max attempts reached
|
|
return {
|
|
"results": results[:final_top_k] if results else [],
|
|
"self_rag_enabled": True,
|
|
"query_used": current_query,
|
|
"original_query": query if current_query != query else None,
|
|
"attempts": attempt - 1,
|
|
"decision": "max_attempts",
|
|
"warning": "Max retrieval attempts reached",
|
|
"trace": trace,
|
|
}
|
|
|
|
|
|
def get_self_rag_info() -> dict:
|
|
"""Get information about Self-RAG configuration."""
|
|
return {
|
|
"enabled": SELF_RAG_ENABLED,
|
|
"llm_configured": bool(OPENAI_API_KEY),
|
|
"model": SELF_RAG_MODEL,
|
|
"relevance_threshold": RELEVANCE_THRESHOLD,
|
|
"grounding_threshold": GROUNDING_THRESHOLD,
|
|
"max_retrieval_attempts": MAX_RETRIEVAL_ATTEMPTS,
|
|
"features": [
|
|
"document_grading",
|
|
"relevance_filtering",
|
|
"query_reformulation",
|
|
"answer_grounding_check",
|
|
"retrieval_decision",
|
|
],
|
|
"sends_data_externally": True, # ALWAYS true when enabled - documents sent to OpenAI
|
|
"privacy_warning": "When enabled, queries and documents are sent to OpenAI API for grading",
|
|
"default_enabled": False, # Disabled by default for privacy
|
|
}
|