Files
Benjamin Admin c52dbdb8f1
Some checks failed
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Failing after 42s
CI/CD / test-python-backend-compliance (push) Successful in 1m38s
CI/CD / test-python-document-crawler (push) Successful in 20s
CI/CD / test-python-dsms-gateway (push) Successful in 17s
CI/CD / validate-canonical-controls (push) Successful in 10s
CI/CD / Deploy (push) Has been skipped
feat(rag): optimize RAG pipeline — JSON-Mode, CoT, Hybrid Search, Re-Ranking, Cross-Reg Dedup, chunk 1024
Phase 1 (LLM Quality):
- Add format=json to all Ollama payloads (obligation_extractor, control_generator, citation_backfill)
- Add Chain-of-Thought analysis steps to Pass 0a/0b system prompts

Phase 2 (Retrieval Quality):
- Hybrid search via Qdrant Query API with RRF fusion + automatic text index (legal_rag.go)
- Fallback to dense-only search if Query API unavailable
- Cross-encoder re-ranking with BGE Reranker v2 (RERANK_ENABLED=false by default)
- CPU-only PyTorch dependency to keep Docker image small

Phase 3 (Data Layer):
- Cross-regulation dedup pass (threshold 0.95) links controls across regulations
- DedupResult.link_type field distinguishes dedup_merge vs cross_regulation
- Chunk size defaults updated 512/50 → 1024/128 for new ingestions only
- Existing collections and controls are NOT affected

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-21 11:49:43 +01:00

214 lines
6.8 KiB
Python

"""
Compliance RAG Client — Proxy to Go SDK RAG Search.
Lightweight HTTP client that queries the Go AI Compliance SDK's
POST /sdk/v1/rag/search endpoint. This avoids needing embedding
models or direct Qdrant access in Python.
Error-tolerant: RAG failures never break the calling function.
"""
import logging
import os
from dataclasses import dataclass
from typing import List, Optional
import httpx
# Logger named after this module; used for all warnings emitted by the client.
logger = logging.getLogger(__name__)

# Base URL of the Go AI Compliance SDK. The SDK_URL environment variable
# overrides the default.
SDK_URL = os.environ.get("SDK_URL", "http://ai-compliance-sdk:8090")

# Timeout (in seconds) handed to httpx for search requests.
RAG_SEARCH_TIMEOUT = 15.0
@dataclass
class RAGSearchResult:
    """One chunk of the compliance corpus together with its source metadata.

    Field order matters: instances may be built positionally, and
    ``collection`` is the only field with a default.
    """

    # The chunk's text content.
    text: str

    # Identification of the regulation the chunk belongs to.
    regulation_code: str
    regulation_name: str
    regulation_short: str
    category: str

    # Location of the chunk inside the regulation and a link to the source.
    article: str
    paragraph: str
    source_url: str

    # Relevance score attached to the chunk.
    score: float

    # Name of the collection the chunk was read from (empty if not recorded).
    collection: str = ""
class ComplianceRAGClient:
    """
    RAG client that proxies search requests to the Go SDK.

    All network-facing methods are best-effort: HTTP errors, timeouts and
    malformed responses are logged as warnings and reported to the caller
    as an empty result instead of an exception.

    Usage:
        client = get_rag_client()
        results = await client.search("DSGVO Art. 35", collection="bp_compliance_recht")
        context_str = client.format_for_prompt(results)
    """

    def __init__(self, base_url: str = SDK_URL):
        """Remember the search endpoint below *base_url*.

        Args:
            base_url: Root URL of the Go SDK. Trailing slashes are ignored.
        """
        # Strip trailing slashes so "http://host/" does not become "//sdk/...".
        self._search_url = f"{base_url.rstrip('/')}/sdk/v1/rag/search"

    @staticmethod
    def _to_result(raw: dict, collection: str, score: float) -> RAGSearchResult:
        """Build a RAGSearchResult from one JSON object of an SDK response.

        Missing or ``null`` string fields become ``""`` so that later string
        operations (slicing, ``len``) on the result cannot fail.
        """
        return RAGSearchResult(
            text=raw.get("text") or "",
            regulation_code=raw.get("regulation_code") or "",
            regulation_name=raw.get("regulation_name") or "",
            regulation_short=raw.get("regulation_short") or "",
            category=raw.get("category") or "",
            article=raw.get("article") or "",
            paragraph=raw.get("paragraph") or "",
            source_url=raw.get("source_url") or "",
            score=score,
            collection=collection,
        )

    async def search(
        self,
        query: str,
        collection: str = "bp_compliance_ce",
        regulations: Optional[List[str]] = None,
        top_k: int = 5,
    ) -> List[RAGSearchResult]:
        """
        Search the RAG corpus via Go SDK.

        Args:
            query: Free-text search query.
            collection: Name of the collection to search.
            regulations: Optional list of regulation identifiers; only sent
                to the SDK when non-empty.
            top_k: Number of results requested from the SDK.

        Returns:
            The parsed results, or an empty list on any error (never raises).
        """
        payload = {
            "query": query,
            "collection": collection,
            "top_k": top_k,
        }
        if regulations:
            payload["regulations"] = regulations

        try:
            async with httpx.AsyncClient(timeout=RAG_SEARCH_TIMEOUT) as client:
                resp = await client.post(self._search_url, json=payload)

            if resp.status_code != 200:
                logger.warning(
                    "RAG search returned %d: %s", resp.status_code, resp.text[:200]
                )
                return []

            data = resp.json()
            # "results" may be absent or JSON null (empty result set); both
            # mean "no hits", not an error.
            return [
                self._to_result(r, collection, r.get("score") or 0.0)
                for r in data.get("results") or []
            ]
        except Exception as e:
            # Deliberately broad: callers rely on search() never raising.
            logger.warning("RAG search failed: %s", e)
            return []

    async def search_with_rerank(
        self,
        query: str,
        collection: str = "bp_compliance_ce",
        regulations: Optional[List[str]] = None,
        top_k: int = 5,
    ) -> List[RAGSearchResult]:
        """
        Search with optional cross-encoder re-ranking.

        Fetches max(top_k * 4, 20) candidates from RAG, re-ranks them and
        returns the top_k best. Falls back to a regular search when no
        reranker is available (get_reranker() returns None), and to the
        unranked first top_k candidates when re-ranking itself fails.
        """
        from .reranker import get_reranker

        reranker = get_reranker()
        if reranker is None:
            return await self.search(query, collection, regulations, top_k)

        # Fetch more candidates than needed so the reranker has a choice.
        candidates = await self.search(
            query, collection, regulations, top_k=max(top_k * 4, 20)
        )
        if not candidates:
            return []

        texts = [c.text for c in candidates]
        try:
            # NOTE(review): rerank() is called synchronously; if it is
            # CPU-heavy it will block the event loop while it runs — confirm.
            ranked_indices = reranker.rerank(query, texts, top_k=top_k)
            return [candidates[i] for i in ranked_indices]
        except Exception as e:
            logger.warning("Reranking failed, returning unranked: %s", e)
            return candidates[:top_k]

    async def scroll(
        self,
        collection: str,
        offset: Optional[str] = None,
        limit: int = 100,
    ) -> tuple[List[RAGSearchResult], Optional[str]]:
        """
        Scroll through ALL chunks in a collection (paginated).

        Args:
            collection: Name of the collection to read.
            offset: Continuation token from a previous call; omitted from
                the request when falsy.
            limit: Page size sent to the SDK.

        Returns:
            (chunks, next_offset). next_offset is None when done. On any
            error the result is ([], None); this method never raises.
            Chunks returned here always carry a score of 0.0.
        """
        # Swap only the trailing path segment; a plain str.replace would also
        # rewrite a "/search" that happens to occur inside the base URL.
        scroll_url = self._search_url.removesuffix("/search") + "/scroll"
        params = {"collection": collection, "limit": str(limit)}
        if offset:
            params["offset"] = offset

        try:
            async with httpx.AsyncClient(timeout=30.0) as client:
                resp = await client.get(scroll_url, params=params)

            if resp.status_code != 200:
                logger.warning(
                    "RAG scroll returned %d: %s", resp.status_code, resp.text[:200]
                )
                return [], None

            data = resp.json()
            # "chunks" may be absent or JSON null for an empty page.
            results = [
                self._to_result(r, collection, 0.0)
                for r in data.get("chunks") or []
            ]
            # Normalise an empty-string offset to None ("no more pages").
            next_offset = data.get("next_offset") or None
            return results, next_offset
        except Exception as e:
            # Deliberately broad: callers rely on scroll() never raising.
            logger.warning("RAG scroll failed: %s", e)
            return [], None

    def format_for_prompt(
        self, results: List[RAGSearchResult], max_results: int = 5
    ) -> str:
        """Format search results as Markdown for inclusion in an LLM prompt.

        Only the first *max_results* entries are used and each chunk text is
        cut to 400 characters (with a trailing "..."). Returns "" when
        *results* is empty.
        """
        if not results:
            return ""

        lines = ["## Relevanter Rechtskontext\n"]
        for number, r in enumerate(results[:max_results], start=1):
            header = f"{number}. **{r.regulation_short}** ({r.regulation_code})"
            if r.article:
                # Separate the article from the closing parenthesis.
                header += f" {r.article}"
            lines.append(header)
            text = r.text[:400] + "..." if len(r.text) > 400 else r.text
            lines.append(f" > {text}\n")
        return "\n".join(lines)
# Module-level cache for the shared client; filled lazily by get_rag_client().
_rag_client: Optional[ComplianceRAGClient] = None


def get_rag_client() -> ComplianceRAGClient:
    """Return the shared ComplianceRAGClient, creating it on the first call."""
    global _rag_client
    if _rag_client is not None:
        return _rag_client
    _rag_client = ComplianceRAGClient()
    return _rag_client