""" Compliance RAG Client — Proxy to Go SDK RAG Search. Lightweight HTTP client that queries the Go AI Compliance SDK's POST /sdk/v1/rag/search endpoint. This avoids needing embedding models or direct Qdrant access in Python. Error-tolerant: RAG failures never break the calling function. """ import logging import os from dataclasses import dataclass from typing import List, Optional import httpx logger = logging.getLogger(__name__) SDK_URL = os.getenv("SDK_URL", "http://ai-compliance-sdk:8090") RAG_SEARCH_TIMEOUT = 15.0 # seconds @dataclass class RAGSearchResult: """A single search result from the compliance corpus.""" text: str regulation_code: str regulation_name: str regulation_short: str category: str article: str paragraph: str source_url: str score: float collection: str = "" class ComplianceRAGClient: """ RAG client that proxies search requests to the Go SDK. Usage: client = get_rag_client() results = await client.search("DSGVO Art. 35", collection="bp_compliance_recht") context_str = client.format_for_prompt(results) """ def __init__(self, base_url: str = SDK_URL): self._search_url = f"{base_url}/sdk/v1/rag/search" async def search( self, query: str, collection: str = "bp_compliance_ce", regulations: Optional[List[str]] = None, top_k: int = 5, ) -> List[RAGSearchResult]: """ Search the RAG corpus via Go SDK. Returns an empty list on any error (never raises). """ payload = { "query": query, "collection": collection, "top_k": top_k, } if regulations: payload["regulations"] = regulations try: async with httpx.AsyncClient(timeout=RAG_SEARCH_TIMEOUT) as client: resp = await client.post(self._search_url, json=payload) if resp.status_code != 200: logger.warning( "RAG search returned %d: %s", resp.status_code, resp.text[:200] ) return [] data = resp.json() results = [] for r in data.get("results", []): results.append(RAGSearchResult( text=r.get("text", ""), regulation_code=r.get("regulation_code", ""), regulation_name=r.get("regulation_name", ""), regulation_short=r.get("regulation_short", ""), category=r.get("category", ""), article=r.get("article", ""), paragraph=r.get("paragraph", ""), source_url=r.get("source_url", ""), score=r.get("score", 0.0), collection=collection, )) return results except Exception as e: logger.warning("RAG search failed: %s", e) return [] async def search_with_rerank( self, query: str, collection: str = "bp_compliance_ce", regulations: Optional[List[str]] = None, top_k: int = 5, ) -> List[RAGSearchResult]: """ Search with optional cross-encoder re-ranking. Fetches top_k*4 results from RAG, then re-ranks with cross-encoder and returns top_k. Falls back to regular search if reranker is disabled. """ from .reranker import get_reranker reranker = get_reranker() if reranker is None: return await self.search(query, collection, regulations, top_k) # Fetch more candidates for re-ranking candidates = await self.search( query, collection, regulations, top_k=max(top_k * 4, 20) ) if not candidates: return [] texts = [c.text for c in candidates] try: ranked_indices = reranker.rerank(query, texts, top_k=top_k) return [candidates[i] for i in ranked_indices] except Exception as e: logger.warning("Reranking failed, returning unranked: %s", e) return candidates[:top_k] async def scroll( self, collection: str, offset: Optional[str] = None, limit: int = 100, ) -> tuple[List[RAGSearchResult], Optional[str]]: """ Scroll through ALL chunks in a collection (paginated). Returns (chunks, next_offset). next_offset is None when done. """ scroll_url = self._search_url.replace("/search", "/scroll") params = {"collection": collection, "limit": str(limit)} if offset: params["offset"] = offset try: async with httpx.AsyncClient(timeout=30.0) as client: resp = await client.get(scroll_url, params=params) if resp.status_code != 200: logger.warning( "RAG scroll returned %d: %s", resp.status_code, resp.text[:200] ) return [], None data = resp.json() results = [] for r in data.get("chunks", []): results.append(RAGSearchResult( text=r.get("text", ""), regulation_code=r.get("regulation_code", ""), regulation_name=r.get("regulation_name", ""), regulation_short=r.get("regulation_short", ""), category=r.get("category", ""), article=r.get("article", ""), paragraph=r.get("paragraph", ""), source_url=r.get("source_url", ""), score=0.0, collection=collection, )) next_offset = data.get("next_offset") or None return results, next_offset except Exception as e: logger.warning("RAG scroll failed: %s", e) return [], None def format_for_prompt( self, results: List[RAGSearchResult], max_results: int = 5 ) -> str: """Format search results as Markdown for inclusion in an LLM prompt.""" if not results: return "" lines = ["## Relevanter Rechtskontext\n"] for i, r in enumerate(results[:max_results]): header = f"{i + 1}. **{r.regulation_short}** ({r.regulation_code})" if r.article: header += f" — {r.article}" lines.append(header) text = r.text[:400] + "..." if len(r.text) > 400 else r.text lines.append(f" > {text}\n") return "\n".join(lines) # Singleton _rag_client: Optional[ComplianceRAGClient] = None def get_rag_client() -> ComplianceRAGClient: """Get the shared RAG client instance.""" global _rag_client if _rag_client is None: _rag_client = ComplianceRAGClient() return _rag_client