Some checks failed
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Failing after 42s
CI/CD / test-python-backend-compliance (push) Successful in 1m38s
CI/CD / test-python-document-crawler (push) Successful in 20s
CI/CD / test-python-dsms-gateway (push) Successful in 17s
CI/CD / validate-canonical-controls (push) Successful in 10s
CI/CD / Deploy (push) Has been skipped
Phase 1 (LLM Quality): - Add format=json to all Ollama payloads (obligation_extractor, control_generator, citation_backfill) - Add Chain-of-Thought analysis steps to Pass 0a/0b system prompts Phase 2 (Retrieval Quality): - Hybrid search via Qdrant Query API with RRF fusion + automatic text index (legal_rag.go) - Fallback to dense-only search if Query API unavailable - Cross-encoder re-ranking with BGE Reranker v2 (RERANK_ENABLED=false by default) - CPU-only PyTorch dependency to keep Docker image small Phase 3 (Data Layer): - Cross-regulation dedup pass (threshold 0.95) links controls across regulations - DedupResult.link_type field distinguishes dedup_merge vs cross_regulation - Chunk size defaults updated 512/50 → 1024/128 for new ingestions only - Existing collections and controls are NOT affected Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
214 lines
6.8 KiB
Python
214 lines
6.8 KiB
Python
"""
|
|
Compliance RAG Client — Proxy to Go SDK RAG Search.
|
|
|
|
Lightweight HTTP client that queries the Go AI Compliance SDK's
|
|
POST /sdk/v1/rag/search endpoint. This avoids needing embedding
|
|
models or direct Qdrant access in Python.
|
|
|
|
Error-tolerant: RAG failures never break the calling function.
|
|
"""
|
|
|
|
import logging
|
|
import os
|
|
from dataclasses import dataclass
|
|
from typing import List, Optional
|
|
|
|
import httpx
|
|
|
|
logger = logging.getLogger(__name__)

# Base URL of the Go AI Compliance SDK; overridable via the SDK_URL env var.
SDK_URL = os.getenv("SDK_URL", "http://ai-compliance-sdk:8090")

# Per-request timeout for RAG search calls, in seconds.
RAG_SEARCH_TIMEOUT = 15.0
|
@dataclass
class RAGSearchResult:
    """One chunk returned by a compliance-corpus search.

    Carries the chunk text plus the regulation metadata needed to cite it.
    """

    text: str              # chunk body
    regulation_code: str   # regulation identifier code
    regulation_name: str   # full regulation name
    regulation_short: str  # short display name of the regulation
    category: str
    article: str
    paragraph: str
    source_url: str
    score: float           # search relevance score; 0.0 for scrolled chunks
    collection: str = ""   # collection the chunk was read from
|
class ComplianceRAGClient:
    """
    RAG client that proxies search requests to the Go SDK.

    Error-tolerant by design: every public method catches all exceptions,
    logs a warning, and returns an empty result instead of raising.

    Usage:
        client = get_rag_client()
        results = await client.search("DSGVO Art. 35", collection="bp_compliance_recht")
        context_str = client.format_for_prompt(results)
    """

    def __init__(self, base_url: str = SDK_URL):
        # Only the search endpoint is stored; scroll() derives its URL from it.
        self._search_url = f"{base_url}/sdk/v1/rag/search"

    @staticmethod
    def _parse_chunk(raw: dict, collection: str, score: float) -> RAGSearchResult:
        """Convert one JSON result entry into a RAGSearchResult.

        Missing string fields default to "". `score` is supplied by the
        caller: the SDK's score for search hits, 0.0 for scrolled chunks.
        """
        return RAGSearchResult(
            text=raw.get("text", ""),
            regulation_code=raw.get("regulation_code", ""),
            regulation_name=raw.get("regulation_name", ""),
            regulation_short=raw.get("regulation_short", ""),
            category=raw.get("category", ""),
            article=raw.get("article", ""),
            paragraph=raw.get("paragraph", ""),
            source_url=raw.get("source_url", ""),
            score=score,
            collection=collection,
        )

    async def search(
        self,
        query: str,
        collection: str = "bp_compliance_ce",
        regulations: Optional[List[str]] = None,
        top_k: int = 5,
    ) -> List[RAGSearchResult]:
        """
        Search the RAG corpus via Go SDK.

        Args:
            query: Free-text search query.
            collection: Name of the collection to search.
            regulations: Optional list of regulation codes to filter by.
            top_k: Maximum number of results to return.

        Returns an empty list on any error (never raises).
        """
        payload = {
            "query": query,
            "collection": collection,
            "top_k": top_k,
        }
        if regulations:
            payload["regulations"] = regulations

        try:
            async with httpx.AsyncClient(timeout=RAG_SEARCH_TIMEOUT) as client:
                resp = await client.post(self._search_url, json=payload)

                if resp.status_code != 200:
                    # Truncate the body so a large error page can't flood the log.
                    logger.warning(
                        "RAG search returned %d: %s", resp.status_code, resp.text[:200]
                    )
                    return []

                data = resp.json()
                return [
                    self._parse_chunk(r, collection, score=r.get("score", 0.0))
                    for r in data.get("results", [])
                ]

        except Exception as e:
            # Deliberately broad: RAG failures must never break the caller.
            logger.warning("RAG search failed: %s", e)
            return []

    async def search_with_rerank(
        self,
        query: str,
        collection: str = "bp_compliance_ce",
        regulations: Optional[List[str]] = None,
        top_k: int = 5,
    ) -> List[RAGSearchResult]:
        """
        Search with optional cross-encoder re-ranking.

        Fetches top_k*4 results (at least 20) from RAG, then re-ranks with
        the cross-encoder and returns top_k. Falls back to regular search if
        the reranker is disabled, and to the unranked head of the candidate
        list if reranking itself fails.
        """
        # Imported lazily so the reranker model is only loaded when used.
        from .reranker import get_reranker

        reranker = get_reranker()
        if reranker is None:
            return await self.search(query, collection, regulations, top_k)

        # Fetch more candidates than needed so the reranker has room to reorder.
        candidates = await self.search(
            query, collection, regulations, top_k=max(top_k * 4, 20)
        )
        if not candidates:
            return []

        texts = [c.text for c in candidates]
        try:
            ranked_indices = reranker.rerank(query, texts, top_k=top_k)
            return [candidates[i] for i in ranked_indices]
        except Exception as e:
            logger.warning("Reranking failed, returning unranked: %s", e)
            return candidates[:top_k]

    async def scroll(
        self,
        collection: str,
        offset: Optional[str] = None,
        limit: int = 100,
    ) -> tuple[List[RAGSearchResult], Optional[str]]:
        """
        Scroll through ALL chunks in a collection (paginated).

        Args:
            collection: Name of the collection to scroll.
            offset: Opaque pagination cursor from a previous call, or None
                to start from the beginning.
            limit: Page size.

        Returns (chunks, next_offset). next_offset is None when done.
        Returns ([], None) on any error (never raises).
        """
        scroll_url = self._search_url.replace("/search", "/scroll")
        params = {"collection": collection, "limit": str(limit)}
        if offset:
            params["offset"] = offset

        try:
            async with httpx.AsyncClient(timeout=30.0) as client:
                resp = await client.get(scroll_url, params=params)

                if resp.status_code != 200:
                    logger.warning(
                        "RAG scroll returned %d: %s", resp.status_code, resp.text[:200]
                    )
                    return [], None

                data = resp.json()
                # Scrolled chunks carry no relevance score; use 0.0.
                results = [
                    self._parse_chunk(r, collection, score=0.0)
                    for r in data.get("chunks", [])
                ]
                next_offset = data.get("next_offset") or None
                return results, next_offset

        except Exception as e:
            logger.warning("RAG scroll failed: %s", e)
            return [], None

    def format_for_prompt(
        self, results: List[RAGSearchResult], max_results: int = 5
    ) -> str:
        """Format search results as Markdown for inclusion in an LLM prompt.

        Returns "" for an empty result list. Each entry is a numbered header
        (short name, code, article if present) followed by a quoted excerpt
        truncated to 400 characters.
        """
        if not results:
            return ""

        lines = ["## Relevanter Rechtskontext\n"]
        for i, r in enumerate(results[:max_results]):
            header = f"{i + 1}. **{r.regulation_short}** ({r.regulation_code})"
            if r.article:
                header += f" — {r.article}"
            lines.append(header)
            text = r.text[:400] + "..." if len(r.text) > 400 else r.text
            lines.append(f"  > {text}\n")

        return "\n".join(lines)
|
|
# Lazily-created module-level singleton.
_rag_client: Optional[ComplianceRAGClient] = None


def get_rag_client() -> ComplianceRAGClient:
    """Get the shared RAG client instance (constructed on first call)."""
    global _rag_client
    if _rag_client is not None:
        return _rag_client
    _rag_client = ComplianceRAGClient()
    return _rag_client