[split-required] Split final 43 files (500-668 LOC) to complete refactoring
klausur-service (11 files): - cv_gutter_repair, ocr_pipeline_regression, upload_api - ocr_pipeline_sessions, smart_spell, nru_worksheet_generator - ocr_pipeline_overlays, mail/aggregator, zeugnis_api - cv_syllable_detect, self_rag backend-lehrer (17 files): - classroom_engine/suggestions, generators/quiz_generator - worksheets_api, llm_gateway/comparison, state_engine_api - classroom/models (→ 4 submodules), services/file_processor - alerts_agent/api/wizard+digests+routes, content_generators/pdf - classroom/routes/sessions, llm_gateway/inference - classroom_engine/analytics, auth/keycloak_auth - alerts_agent/processing/rule_engine, ai_processor/print_versions agent-core (5 files): - brain/memory_store, brain/knowledge_graph, brain/context_manager - orchestrator/supervisor, sessions/session_manager admin-lehrer (5 components): - GridOverlay, StepGridReview, DevOpsPipelineSidebar - DataFlowDiagram, sbom/wizard/page website (2 files): - DependencyMap, lehrer/abitur-archiv Other: nibis_ingestion, grid_detection_service, export-doclayout-onnx Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
285
klausur-service/backend/self_rag_grading.py
Normal file
285
klausur-service/backend/self_rag_grading.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
Self-RAG Grading — document relevance grading, filtering, retrieval decisions.
|
||||
|
||||
Extracted from self_rag.py for modularity.
|
||||
|
||||
Based on research:
|
||||
- Self-RAG (Asai et al., 2023)
|
||||
- Corrective RAG (Yan et al., 2024)
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from enum import Enum
|
||||
import httpx
|
||||
|
||||
# Configuration
# Feature flag: Self-RAG is opt-in; anything other than "true" (case-insensitive) disables it.
SELF_RAG_ENABLED = os.getenv("SELF_RAG_ENABLED", "false").lower() == "true"
# When empty, the graders below fall back to non-LLM heuristics instead of calling OpenAI.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
# Model used for all grading prompts (relevance + groundedness).
SELF_RAG_MODEL = os.getenv("SELF_RAG_MODEL", "gpt-4o-mini")

# Thresholds for self-reflection
# Minimum relevance score (0.0-1.0) for a document to be kept by filter_relevant_documents.
RELEVANCE_THRESHOLD = float(os.getenv("SELF_RAG_RELEVANCE_THRESHOLD", "0.6"))
# Minimum groundedness score (0.0-1.0).
# NOTE(review): unused within this module — presumably consumed by callers of
# grade_answer_groundedness; confirm before removing.
GROUNDING_THRESHOLD = float(os.getenv("SELF_RAG_GROUNDING_THRESHOLD", "0.7"))
# Cap on retrieval retries before decide_retrieval_strategy gives up (fallback).
MAX_RETRIEVAL_ATTEMPTS = int(os.getenv("SELF_RAG_MAX_ATTEMPTS", "2"))
|
||||
|
||||
|
||||
class RetrievalDecision(Enum):
    """Possible outcomes of grading a retrieval attempt."""

    # Context is good enough; proceed straight to generation.
    SUFFICIENT = "sufficient"
    # Retrieve additional documents before generating.
    NEEDS_MORE = "needs_more"
    # The query itself should be rewritten and retried.
    REFORMULATE = "reformulate"
    # No usable context was found; take the fallback path.
    FALLBACK = "fallback"
|
||||
|
||||
|
||||
class SelfRAGError(Exception):
    """Error during Self-RAG processing."""
|
||||
|
||||
|
||||
async def grade_document_relevance(
    query: str,
    document: str,
) -> Tuple[float, str]:
    """Grade how relevant *document* is to *query*.

    Returns:
        A ``(score, explanation)`` tuple. The score is clamped to
        [0.0, 1.0]. Without an OpenAI key the score degrades to a
        simple keyword-overlap heuristic instead of an LLM call.
    """
    if not OPENAI_API_KEY:
        # No LLM configured: approximate relevance via word overlap,
        # scaled up so that 50% overlap already counts as fully relevant.
        q_tokens = set(query.lower().split())
        d_tokens = set(document.lower().split())
        overlap_ratio = len(q_tokens & d_tokens) / max(len(q_tokens), 1)
        return min(overlap_ratio * 2, 1.0), "Keyword-based relevance (no LLM)"

    # German grading prompt; only the first 2000 chars of the document are sent.
    prompt = f"""Bewerte, ob das folgende Dokument relevant fuer die Suchanfrage ist.

SUCHANFRAGE: {query}

DOKUMENT:
{document[:2000]}

Ist dieses Dokument relevant, um die Anfrage zu beantworten?
Beruecksichtige:
- Thematische Uebereinstimmung
- Enthaelt das Dokument spezifische Informationen zur Anfrage?
- Wuerde dieses Dokument bei der Beantwortung helfen?

Antworte im Format:
SCORE: [0.0-1.0]
BEGRUENDUNG: [Kurze Erklaerung]"""

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                "https://api.openai.com/v1/chat/completions",
                headers={
                    "Authorization": f"Bearer {OPENAI_API_KEY}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": SELF_RAG_MODEL,
                    "messages": [{"role": "user", "content": prompt}],
                    "max_tokens": 150,
                    # Deterministic grading: no sampling randomness.
                    "temperature": 0.0,
                },
                timeout=30.0
            )

        # Non-200 responses are degraded to a neutral score, not raised.
        if response.status_code != 200:
            return 0.5, f"API error: {response.status_code}"

        content = response.json()["choices"][0]["message"]["content"]

        import re

        score_m = re.search(r'SCORE:\s*([\d.]+)', content)
        parsed_score = float(score_m.group(1)) if score_m else 0.5

        reason_m = re.search(r'BEGRUENDUNG:\s*(.+)', content, re.DOTALL)
        explanation = reason_m.group(1).strip() if reason_m else content

        return min(max(parsed_score, 0.0), 1.0), explanation

    except Exception as e:
        # Best-effort grader: any failure yields a neutral 0.5 score.
        return 0.5, f"Grading error: {str(e)}"
|
||||
|
||||
|
||||
async def grade_documents_batch(
    query: str,
    documents: List[str],
) -> List[Tuple[float, str]]:
    """
    Grade multiple documents for relevance against *query*.

    Documents are graded concurrently rather than one-by-one: each
    grade is an independent HTTP round trip to the LLM, so awaiting
    them sequentially serialized the latency for no benefit.

    Returns:
        List of (score, reason) tuples, in the same order as *documents*.
    """
    import asyncio

    if not documents:
        return []
    # asyncio.gather preserves input order, matching the old sequential
    # loop's result exactly; exceptions are already swallowed inside
    # grade_document_relevance, so none propagate from gather.
    return list(
        await asyncio.gather(
            *(grade_document_relevance(query, doc) for doc in documents)
        )
    )
|
||||
|
||||
|
||||
async def filter_relevant_documents(
    query: str,
    documents: List[Dict],
    threshold: float = RELEVANCE_THRESHOLD,
) -> Tuple[List[Dict], List[Dict]]:
    """
    Filter documents by relevance, separating relevant from irrelevant.

    Grading now runs concurrently (one independent LLM call per
    document) instead of sequentially, cutting wall-clock latency.

    Args:
        query: The search query
        documents: List of document dicts with 'text' field
        threshold: Minimum relevance score to keep

    Returns:
        Tuple of (relevant_docs, filtered_out_docs); relevant docs are
        sorted by descending relevance_score, and every returned dict
        is a copy annotated with 'relevance_score' and 'relevance_reason'.
    """
    import asyncio

    if not documents:
        return [], []

    # Fire all grading calls at once; gather preserves input order so
    # results line up with `documents` for the zip below.
    grades = await asyncio.gather(
        *(grade_document_relevance(query, doc.get("text", "")) for doc in documents)
    )

    relevant: List[Dict] = []
    filtered: List[Dict] = []
    for doc, (score, reason) in zip(documents, grades):
        # Copy so callers' original dicts are never mutated.
        graded = doc.copy()
        graded["relevance_score"] = score
        graded["relevance_reason"] = reason
        if score >= threshold:
            relevant.append(graded)
        else:
            filtered.append(graded)

    # Sort relevant by score, best first.
    relevant.sort(key=lambda d: d.get("relevance_score", 0), reverse=True)

    return relevant, filtered
|
||||
|
||||
|
||||
async def decide_retrieval_strategy(
    query: str,
    documents: List[Dict],
    attempt: int = 1,
) -> Tuple[RetrievalDecision, Dict]:
    """Choose the next Self-RAG step from the current retrieval results.

    Args:
        query: The search query.
        documents: Retrieved documents, each optionally carrying a
            'relevance_score' (treated as 0.5 when absent).
        attempt: 1-based retrieval attempt counter.

    Returns:
        Tuple of (decision, metadata) where metadata explains the decision.
    """
    exhausted = attempt >= MAX_RETRIEVAL_ATTEMPTS

    # Nothing retrieved at all: reformulate, or give up once exhausted.
    if not documents:
        if exhausted:
            return RetrievalDecision.FALLBACK, {"reason": "No documents found after max attempts"}
        return RetrievalDecision.REFORMULATE, {"reason": "No documents retrieved"}

    score_list = [d.get("relevance_score", 0.5) for d in documents]
    avg_score = sum(score_list) / len(score_list)
    best_score = max(score_list)

    # Strong signal: the best document clears the bar and the average
    # is within 70% of it — proceed to generation.
    if best_score >= RELEVANCE_THRESHOLD and avg_score >= RELEVANCE_THRESHOLD * 0.7:
        return RetrievalDecision.SUFFICIENT, {
            "avg_relevance": avg_score,
            "max_relevance": best_score,
            "doc_count": len(documents),
        }

    if exhausted:
        # Out of retries: accept marginal context rather than fail outright.
        if best_score >= RELEVANCE_THRESHOLD * 0.5:
            return RetrievalDecision.SUFFICIENT, {
                "avg_relevance": avg_score,
                "warning": "Low relevance after max attempts",
            }
        return RetrievalDecision.FALLBACK, {"reason": "Max attempts reached, low relevance"}

    # Retries remain: very weak results suggest the query is the problem;
    # moderate results just need more documents.
    if avg_score < 0.3:
        return RetrievalDecision.REFORMULATE, {
            "reason": "Very low relevance, query reformulation needed",
            "avg_relevance": avg_score,
        }

    return RetrievalDecision.NEEDS_MORE, {
        "reason": "Moderate relevance, retrieving more documents",
        "avg_relevance": avg_score,
    }
|
||||
|
||||
|
||||
async def grade_answer_groundedness(
    answer: str,
    contexts: List[str],
) -> Tuple[float, List[str]]:
    """Check how well *answer* is supported by the given *contexts*.

    Returns:
        Tuple of (grounding_score, unsupported_claims). The score is
        clamped to [0.0, 1.0]; the list contains claims the LLM flagged
        as not backed by the contexts (empty when fully grounded).
    """
    if not OPENAI_API_KEY:
        # No LLM configured: neutral score with an explanatory marker.
        return 0.5, ["LLM not configured for grounding check"]

    # Only the first five contexts are sent, to keep the prompt bounded.
    context_text = "\n---\n".join(contexts[:5])

    prompt = f"""Analysiere, ob die folgende Antwort vollstaendig durch die Kontexte gestuetzt wird.

KONTEXTE:
{context_text}

ANTWORT:
{answer}

Identifiziere:
1. Welche Aussagen sind durch die Kontexte belegt?
2. Welche Aussagen sind NICHT belegt (potenzielle Halluzinationen)?

Antworte im Format:
SCORE: [0.0-1.0] (1.0 = vollstaendig belegt)
NICHT_BELEGT: [Liste der nicht belegten Aussagen, eine pro Zeile, oder "Keine"]"""

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                "https://api.openai.com/v1/chat/completions",
                headers={
                    "Authorization": f"Bearer {OPENAI_API_KEY}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": SELF_RAG_MODEL,
                    "messages": [{"role": "user", "content": prompt}],
                    "max_tokens": 300,
                    # Deterministic grading: no sampling randomness.
                    "temperature": 0.0,
                },
                timeout=30.0
            )

        # Non-200 responses degrade to a neutral score, not an exception.
        if response.status_code != 200:
            return 0.5, [f"API error: {response.status_code}"]

        content = response.json()["choices"][0]["message"]["content"]

        import re

        score_m = re.search(r'SCORE:\s*([\d.]+)', content)
        grounding = float(score_m.group(1)) if score_m else 0.5

        claims_m = re.search(r'NICHT_BELEGT:\s*(.+)', content, re.DOTALL)
        raw_claims = claims_m.group(1).strip() if claims_m else ""

        # "Keine" (German for "none") means everything was supported.
        if raw_claims.lower() == "keine":
            unsupported: List[str] = []
        else:
            unsupported = [ln.strip() for ln in raw_claims.split("\n") if ln.strip()]

        return min(max(grounding, 0.0), 1.0), unsupported

    except Exception as e:
        # Best-effort check: any failure yields a neutral 0.5 score.
        return 0.5, [f"Grounding check error: {str(e)}"]
|
||||
Reference in New Issue
Block a user