Services: Admin-Lehrer, Backend-Lehrer, Studio v2, Website, Klausur-Service, School-Service, Voice-Service, Geo-Service, BreakPilot Drive, Agent-Core Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
149 lines
4.8 KiB
Python
149 lines
4.8 KiB
Python
"""
Re-Ranking Module for RAG Quality Improvement

NOTE: This module delegates ML-heavy operations to the embedding-service via HTTP.

Implements two-stage retrieval:
1. Initial retrieval with a bi-encoder (fast, many results)
2. Re-ranking with a cross-encoder (slower, but much more accurate)

Re-ranking typically improves RAG accuracy (gains of 20-35% have been reported)
and reduces hallucinations.

Supported re-rankers (configured in the embedding-service):
- local: sentence-transformers cross-encoder (default, no API key needed)
- cohere: Cohere Rerank API (requires COHERE_API_KEY)
"""
|
|
|
|
import os
|
|
import logging
|
|
from typing import List, Tuple
|
|
|
|
logger = logging.getLogger(__name__)

# Environment-driven settings. They remain here for backward compatibility;
# the authoritative reranker configuration lives in the embedding-service.
EMBEDDING_SERVICE_URL = os.environ.get(
    "EMBEDDING_SERVICE_URL", "http://embedding-service:8087"
)
# Passed to httpx as the client timeout (seconds).
EMBEDDING_SERVICE_TIMEOUT = float(
    os.environ.get("EMBEDDING_SERVICE_TIMEOUT", "60.0")
)
RERANKER_BACKEND = os.environ.get("RERANKER_BACKEND", "local")
COHERE_API_KEY = os.environ.get("COHERE_API_KEY", "")
LOCAL_RERANKER_MODEL = os.environ.get(
    "LOCAL_RERANKER_MODEL", "BAAI/bge-reranker-v2-m3"
)
|
|
|
|
|
|
class RerankerError(Exception):
    """Raised when a re-ranking request to the embedding-service fails."""
|
|
|
|
|
|
async def rerank_documents(
    query: str,
    documents: List[str],
    top_k: int = 5
) -> List[Tuple[int, float, str]]:
    """
    Re-rank documents using the embedding-service.

    POSTs ``{"query", "documents", "top_k"}`` as JSON to
    ``{EMBEDDING_SERVICE_URL}/rerank``. Each entry of the response's
    ``results`` list is converted into a tuple.

    Args:
        query: The search query
        documents: List of document texts to re-rank
        top_k: Number of top results to return (forwarded to the service)

    Returns:
        List of (original_index, score, text) tuples, in the order returned by
        the service (expected to be sorted by score descending). Returns an
        empty list without making a request when ``documents`` is empty.

    Raises:
        RerankerError: On timeout, a non-2xx HTTP status, or any other failure
            (transport error, malformed JSON, missing keys in the response).
            The underlying exception is chained as ``__cause__``.
    """
    if not documents:
        return []

    # Imported inside the function (as in the original module), presumably so
    # that importing this module does not require httpx.
    import httpx

    try:
        async with httpx.AsyncClient(timeout=EMBEDDING_SERVICE_TIMEOUT) as client:
            response = await client.post(
                f"{EMBEDDING_SERVICE_URL}/rerank",
                json={
                    "query": query,
                    "documents": documents,
                    "top_k": top_k,
                },
            )
            response.raise_for_status()
            data = response.json()

            return [
                (r["index"], r["score"], r["text"])
                for r in data["results"]
            ]

    # Chain the original exception (`from e`) so its traceback and type
    # remain available to callers and logs.
    except httpx.TimeoutException as e:
        raise RerankerError("Embedding service timeout during re-ranking") from e
    except httpx.HTTPStatusError as e:
        raise RerankerError(
            f"Re-ranking error: {e.response.status_code} - {e.response.text}"
        ) from e
    except Exception as e:
        raise RerankerError(f"Failed to re-rank documents: {e}") from e
|
|
|
|
|
|
async def rerank_search_results(
    query: str,
    results: List[dict],
    text_field: str = "text",
    top_k: int = 5
) -> List[dict]:
    """
    Re-rank search-result dicts by the text stored under ``text_field``.

    Convenience wrapper around ``rerank_documents`` for Qdrant search results.
    A dict without ``text_field`` is ranked on the empty string.

    Args:
        query: The search query
        results: List of search result dicts
        text_field: Key in each dict holding the text to rank on
        top_k: Number of top results to return

    Returns:
        Shallow copies of the selected result dicts, in the order returned by
        ``rerank_documents``. Each copy gets two extra keys: ``rerank_score``
        (the re-ranker score) and ``original_rank`` (the dict's index in
        ``results``). Returns an empty list if ``results`` is empty.

    Raises:
        RerankerError: Propagated from ``rerank_documents``.
    """
    if not results:
        return []

    candidate_texts = [hit.get(text_field, "") for hit in results]
    ranking = await rerank_documents(query, candidate_texts, top_k)

    def _annotate(position: int, rerank_score: float) -> dict:
        # Copy so the caller's original result dicts are not mutated.
        enriched = results[position].copy()
        enriched["rerank_score"] = rerank_score
        enriched["original_rank"] = position
        return enriched

    return [
        _annotate(position, rerank_score)
        for position, rerank_score, _text in ranking
    ]
|
|
|
|
|
|
def _build_reranker_info(
    backend: str,
    model: str,
    model_license: str,
    available: bool,
) -> dict:
    """Assemble the reranker info dict returned by ``get_reranker_info``."""
    return {
        "backend": backend,
        "model": model,
        "model_license": model_license,
        # NOTE(review): reported as True unconditionally, even when the
        # license is unknown because the service is unreachable; confirm
        # this is intended.
        "commercial_safe": True,
        "cohere_configured": bool(COHERE_API_KEY),
        "embedding_service_url": EMBEDDING_SERVICE_URL,
        "embedding_service_available": available,
    }


def get_reranker_info() -> dict:
    """
    Get information about the current reranker configuration.

    Queries ``GET {EMBEDDING_SERVICE_URL}/models`` with a 5 second timeout. On
    HTTP 200, ``reranker_backend`` / ``reranker_model`` from the response are
    used, and any missing key falls back to the local environment config.
    Otherwise the locally configured values are reported with
    ``embedding_service_available`` set to False. This covers an unreachable
    service, a non-200 status, or a response that cannot be read. Every
    fallback is logged as a warning.

    This performs a blocking (synchronous) HTTP request. Calling it from inside
    a running event loop blocks that loop for up to the timeout.

    Returns:
        dict with keys ``backend``, ``model``, ``model_license``,
        ``commercial_safe``, ``cohere_configured``, ``embedding_service_url``
        and ``embedding_service_available``.
    """
    import httpx

    try:
        with httpx.Client(timeout=5.0) as client:
            response = client.get(f"{EMBEDDING_SERVICE_URL}/models")
            if response.status_code == 200:
                data = response.json()
                return _build_reranker_info(
                    backend=data.get("reranker_backend", RERANKER_BACKEND),
                    model=data.get("reranker_model", LOCAL_RERANKER_MODEL),
                    model_license="See embedding-service",
                    available=True,
                )
            # Previously this case fell through to the fallback silently.
            logger.warning(
                "embedding-service /models returned HTTP %s; "
                "using local reranker config",
                response.status_code,
            )
    except Exception as e:
        logger.warning("Could not reach embedding-service: %s", e)

    # Fallback when embedding-service is not available
    return _build_reranker_info(
        backend=RERANKER_BACKEND,
        model=LOCAL_RERANKER_MODEL,
        model_license="Unknown (embedding-service unavailable)",
        available=False,
    )
|