From 14a99322eb206e47012a57761080556cbc748f25 Mon Sep 17 00:00:00 2001 From: Benjamin Admin Date: Mon, 2 Mar 2026 08:57:39 +0100 Subject: [PATCH] =?UTF-8?q?feat:=20Phase=202=20=E2=80=94=20RAG=20integrati?= =?UTF-8?q?on=20in=20Requirements=20+=20DSFA=20Draft?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add legal context enrichment from Qdrant vector corpus to the two highest-priority modules (Requirements AI assistant and DSFA drafting engine). Go SDK: - Add SearchCollection() with collection override + whitelist validation - Refactor Search() to delegate to shared searchInternal() Python backend: - New ComplianceRAGClient proxying POST /sdk/v1/rag/search (error-tolerant) - AI assistant: enrich interpret_requirement() and suggest_controls() with RAG - Requirements API: add ?include_legal_context=true query parameter Admin (Next.js): - Extract shared queryRAG() utility from chat route - Inject RAG legal context into v1 and v2 draft pipelines Tests for all three layers (Go, Python, TypeScript shared utility). Co-Authored-By: Claude Opus 4.6 --- .../app/api/sdk/drafting-engine/chat/route.ts | 30 +-- .../api/sdk/drafting-engine/draft/route.ts | 25 ++- .../lib/sdk/drafting-engine/rag-query.ts | 40 ++++ .../internal/api/handlers/rag_handlers.go | 21 ++- .../api/handlers/rag_handlers_test.go | 109 +++++++++++ ai-compliance-sdk/internal/ucca/legal_rag.go | 16 +- .../internal/ucca/legal_rag_test.go | 157 ++++++++++++++++ backend-compliance/compliance/api/routes.py | 37 +++- .../services/ai_compliance_assistant.py | 36 ++++ .../compliance/services/rag_client.py | 129 +++++++++++++ backend-compliance/tests/__init__.py | 0 backend-compliance/tests/test_rag_client.py | 175 ++++++++++++++++++ 12 files changed, 739 insertions(+), 36 deletions(-) create mode 100644 admin-compliance/lib/sdk/drafting-engine/rag-query.ts create mode 100644 ai-compliance-sdk/internal/api/handlers/rag_handlers_test.go create mode 100644 ai-compliance-sdk/internal/ucca/legal_rag_test.go create mode 100644 backend-compliance/compliance/services/rag_client.py create mode 100644 backend-compliance/tests/__init__.py create mode 100644 backend-compliance/tests/test_rag_client.py diff --git a/admin-compliance/app/api/sdk/drafting-engine/chat/route.ts b/admin-compliance/app/api/sdk/drafting-engine/chat/route.ts index e24b211..0db1e63 100644 --- a/admin-compliance/app/api/sdk/drafting-engine/chat/route.ts +++ b/admin-compliance/app/api/sdk/drafting-engine/chat/route.ts @@ -7,8 +7,8 @@ */ import { NextRequest, NextResponse } from 'next/server' +import { queryRAG } from '@/lib/sdk/drafting-engine/rag-query' -const KLAUSUR_SERVICE_URL = process.env.KLAUSUR_SERVICE_URL || 'http://klausur-service:8086' const OLLAMA_URL = process.env.OLLAMA_URL || 'http://host.docker.internal:11434' const LLM_MODEL = process.env.COMPLIANCE_LLM_MODEL || 'qwen2.5vl:32b' @@ -30,34 +30,6 @@ Konsistenz zwischen Dokumenten sicherzustellen. ## Kompetenzbereich DSGVO, BDSG, AI Act, TTDSG, DSK-Kurzpapiere, SDM V3.0, BSI-Grundschutz, ISO 27001/27701, EDPB Guidelines, WP248` -/** - * Query the RAG corpus for relevant documents - */ -async function queryRAG(query: string): Promise { - try { - const url = `${KLAUSUR_SERVICE_URL}/api/v1/dsfa-rag/search?query=${encodeURIComponent(query)}&top_k=3` - const res = await fetch(url, { - headers: { 'Content-Type': 'application/json' }, - signal: AbortSignal.timeout(10000), - }) - - if (!res.ok) return '' - - const data = await res.json() - if (data.results?.length > 0) { - return data.results - .map( - (r: { source_name?: string; source_code?: string; content?: string }, i: number) => - `[Quelle ${i + 1}: ${r.source_name || r.source_code || 'Unbekannt'}]\n${r.content || ''}` - ) - .join('\n\n---\n\n') - } - return '' - } catch { - return '' - } -} - export async function POST(request: NextRequest) { try { const body = await request.json() diff --git a/admin-compliance/app/api/sdk/drafting-engine/draft/route.ts b/admin-compliance/app/api/sdk/drafting-engine/draft/route.ts index dde1950..cbd1c04 100644 --- a/admin-compliance/app/api/sdk/drafting-engine/draft/route.ts +++ b/admin-compliance/app/api/sdk/drafting-engine/draft/route.ts @@ -31,6 +31,7 @@ import { sanitizeAllowedFacts, validateNoRemainingPII, SanitizationError } from import { terminologyToPromptString, styleContractToPromptString } from '@/lib/sdk/drafting-engine/terminology' import { executeRepairLoop, type ProseBlockOutput, type RepairAudit } from '@/lib/sdk/drafting-engine/prose-validator' import { ProseCacheManager, computeChecksumSync, type CacheKeyParams } from '@/lib/sdk/drafting-engine/cache' +import { queryRAG } from '@/lib/sdk/drafting-engine/rag-query' // ============================================================================ // Shared State @@ -103,9 +104,20 @@ async function handleV1Draft(body: Record): Promise): Promise): Promise { + try { + const url = `${KLAUSUR_SERVICE_URL}/api/v1/dsfa-rag/search?query=${encodeURIComponent(query)}&top_k=${topK}` + const res = await fetch(url, { + headers: { 'Content-Type': 'application/json' }, + signal: AbortSignal.timeout(10000), + }) + + if (!res.ok) return '' + + const data = await res.json() + if (data.results?.length > 0) { + return data.results + .map( + (r: { source_name?: string; source_code?: string; content?: string }, i: number) => + `[Quelle ${i + 1}: ${r.source_name || r.source_code || 'Unbekannt'}]\n${r.content || ''}` + ) + .join('\n\n---\n\n') + } + return '' + } catch { + return '' + } +} diff --git a/ai-compliance-sdk/internal/api/handlers/rag_handlers.go b/ai-compliance-sdk/internal/api/handlers/rag_handlers.go index fb9d497..2ce83a5 100644 --- a/ai-compliance-sdk/internal/api/handlers/rag_handlers.go +++ b/ai-compliance-sdk/internal/api/handlers/rag_handlers.go @@ -21,9 +21,20 @@ func NewRAGHandlers(corpusVersionStore *ucca.CorpusVersionStore) *RAGHandlers { } } +// AllowedCollections is the whitelist of Qdrant collections that can be queried. +var AllowedCollections = map[string]bool{ + "bp_compliance_ce": true, + "bp_compliance_recht": true, + "bp_compliance_gesetze": true, + "bp_compliance_datenschutz": true, + "bp_dsfa_corpus": true, + "bp_legal_templates": true, +} + // SearchRequest represents a RAG search request. type SearchRequest struct { Query string `json:"query" binding:"required"` + Collection string `json:"collection,omitempty"` Regulations []string `json:"regulations,omitempty"` TopK int `json:"top_k,omitempty"` } @@ -41,7 +52,15 @@ func (h *RAGHandlers) Search(c *gin.Context) { req.TopK = 5 } - results, err := h.ragClient.Search(c.Request.Context(), req.Query, req.Regulations, req.TopK) + // Validate collection if specified + if req.Collection != "" { + if !AllowedCollections[req.Collection] { + c.JSON(http.StatusBadRequest, gin.H{"error": "Unknown collection: " + req.Collection + ". Allowed: bp_compliance_ce, bp_compliance_recht, bp_compliance_gesetze, bp_compliance_datenschutz, bp_dsfa_corpus, bp_legal_templates"}) + return + } + } + + results, err := h.ragClient.SearchCollection(c.Request.Context(), req.Collection, req.Query, req.Regulations, req.TopK) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "RAG search failed: " + err.Error()}) return diff --git a/ai-compliance-sdk/internal/api/handlers/rag_handlers_test.go b/ai-compliance-sdk/internal/api/handlers/rag_handlers_test.go new file mode 100644 index 0000000..3f61c7f --- /dev/null +++ b/ai-compliance-sdk/internal/api/handlers/rag_handlers_test.go @@ -0,0 +1,109 @@ +package handlers + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" +) + +func TestAllowedCollections(t *testing.T) { + allowed := []string{ + "bp_compliance_ce", + "bp_compliance_recht", + "bp_compliance_gesetze", + "bp_compliance_datenschutz", + "bp_dsfa_corpus", + "bp_legal_templates", + } + + for _, c := range allowed { + if !AllowedCollections[c] { + t.Errorf("Expected %s to be in AllowedCollections", c) + } + } + + disallowed := []string{ + "bp_unknown", + "", + "some_random_collection", + } + + for _, c := range disallowed { + if AllowedCollections[c] { + t.Errorf("Expected %s to NOT be in AllowedCollections", c) + } + } +} + +func TestSearch_InvalidCollection_Returns400(t *testing.T) { + gin.SetMode(gin.TestMode) + + handler := &RAGHandlers{} + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body := SearchRequest{ + Query: "test query", + Collection: "bp_evil_collection", + TopK: 5, + } + bodyBytes, _ := json.Marshal(body) + c.Request, _ = http.NewRequest("POST", "/sdk/v1/rag/search", bytes.NewReader(bodyBytes)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.Search(c) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected 400, got %d", w.Code) + } + + var resp map[string]interface{} + json.Unmarshal(w.Body.Bytes(), &resp) + + errMsg, ok := resp["error"].(string) + if !ok || errMsg == "" { + t.Error("Expected error message in response") + } +} + +func TestSearch_WithCollectionParam_BindsCorrectly(t *testing.T) { + // Test that the SearchRequest struct correctly binds the collection field + body := `{"query":"DSGVO Art. 35","collection":"bp_compliance_recht","top_k":3}` + var req SearchRequest + err := json.Unmarshal([]byte(body), &req) + if err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + if req.Query != "DSGVO Art. 35" { + t.Errorf("Expected query 'DSGVO Art. 35', got '%s'", req.Query) + } + if req.Collection != "bp_compliance_recht" { + t.Errorf("Expected collection 'bp_compliance_recht', got '%s'", req.Collection) + } + if req.TopK != 3 { + t.Errorf("Expected top_k 3, got %d", req.TopK) + } +} + +func TestSearch_EmptyCollection_IsAllowed(t *testing.T) { + // Empty collection should be allowed (falls back to default in the handler) + body := `{"query":"test"}` + var req SearchRequest + err := json.Unmarshal([]byte(body), &req) + if err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + if req.Collection != "" { + t.Errorf("Expected empty collection, got '%s'", req.Collection) + } + + // Empty string is not in AllowedCollections, but the handler + // should skip validation for empty collection +} diff --git a/ai-compliance-sdk/internal/ucca/legal_rag.go b/ai-compliance-sdk/internal/ucca/legal_rag.go index 8c17eb8..886b88f 100644 --- a/ai-compliance-sdk/internal/ucca/legal_rag.go +++ b/ai-compliance-sdk/internal/ucca/legal_rag.go @@ -173,8 +173,22 @@ func (c *LegalRAGClient) generateEmbedding(ctx context.Context, text string) ([] return embResp.Embedding, nil } +// SearchCollection queries a specific Qdrant collection for relevant passages. +// If collection is empty, it falls back to the default collection (bp_compliance_ce). +func (c *LegalRAGClient) SearchCollection(ctx context.Context, collection string, query string, regulationIDs []string, topK int) ([]LegalSearchResult, error) { + if collection == "" { + collection = c.collection + } + return c.searchInternal(ctx, collection, query, regulationIDs, topK) +} + // Search queries the compliance CE corpus for relevant passages. func (c *LegalRAGClient) Search(ctx context.Context, query string, regulationIDs []string, topK int) ([]LegalSearchResult, error) { + return c.searchInternal(ctx, c.collection, query, regulationIDs, topK) +} + +// searchInternal performs the actual search against a given collection. +func (c *LegalRAGClient) searchInternal(ctx context.Context, collection string, query string, regulationIDs []string, topK int) ([]LegalSearchResult, error) { // Generate query embedding via Ollama bge-m3 embedding, err := c.generateEmbedding(ctx, query) if err != nil { @@ -206,7 +220,7 @@ func (c *LegalRAGClient) Search(ctx context.Context, query string, regulationIDs } // Call Qdrant - url := fmt.Sprintf("http://%s:%s/collections/%s/points/search", c.qdrantHost, c.qdrantPort, c.collection) + url := fmt.Sprintf("http://%s:%s/collections/%s/points/search", c.qdrantHost, c.qdrantPort, collection) req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody)) if err != nil { return nil, fmt.Errorf("failed to create search request: %w", err) diff --git a/ai-compliance-sdk/internal/ucca/legal_rag_test.go b/ai-compliance-sdk/internal/ucca/legal_rag_test.go new file mode 100644 index 0000000..d0db6e5 --- /dev/null +++ b/ai-compliance-sdk/internal/ucca/legal_rag_test.go @@ -0,0 +1,157 @@ +package ucca + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestSearchCollection_UsesCorrectCollection(t *testing.T) { + // Track which collection was requested + var requestedURL string + + // Mock Ollama (embedding) + ollamaMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(ollamaEmbeddingResponse{ + Embedding: make([]float64, 1024), + }) + })) + defer ollamaMock.Close() + + // Mock Qdrant + qdrantMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestedURL = r.URL.Path + json.NewEncoder(w).Encode(qdrantSearchResponse{ + Result: []qdrantSearchHit{}, + }) + })) + defer qdrantMock.Close() + + // Parse qdrant mock host/port + qdrantAddr := strings.TrimPrefix(qdrantMock.URL, "http://") + parts := strings.Split(qdrantAddr, ":") + + client := &LegalRAGClient{ + qdrantHost: parts[0], + qdrantPort: parts[1], + ollamaURL: ollamaMock.URL, + embeddingModel: "bge-m3", + collection: "bp_compliance_ce", + httpClient: http.DefaultClient, + } + + // Test with explicit collection + _, err := client.SearchCollection(context.Background(), "bp_compliance_recht", "test query", nil, 3) + if err != nil { + t.Fatalf("SearchCollection failed: %v", err) + } + + if !strings.Contains(requestedURL, "/collections/bp_compliance_recht/") { + t.Errorf("Expected collection bp_compliance_recht in URL, got: %s", requestedURL) + } +} + +func TestSearchCollection_FallbackDefault(t *testing.T) { + var requestedURL string + + ollamaMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(ollamaEmbeddingResponse{ + Embedding: make([]float64, 1024), + }) + })) + defer ollamaMock.Close() + + qdrantMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestedURL = r.URL.Path + json.NewEncoder(w).Encode(qdrantSearchResponse{ + Result: []qdrantSearchHit{}, + }) + })) + defer qdrantMock.Close() + + qdrantAddr := strings.TrimPrefix(qdrantMock.URL, "http://") + parts := strings.Split(qdrantAddr, ":") + + client := &LegalRAGClient{ + qdrantHost: parts[0], + qdrantPort: parts[1], + ollamaURL: ollamaMock.URL, + embeddingModel: "bge-m3", + collection: "bp_compliance_ce", + httpClient: http.DefaultClient, + } + + // Test with empty collection (should fall back to default) + _, err := client.SearchCollection(context.Background(), "", "test query", nil, 3) + if err != nil { + t.Fatalf("SearchCollection failed: %v", err) + } + + if !strings.Contains(requestedURL, "/collections/bp_compliance_ce/") { + t.Errorf("Expected default collection bp_compliance_ce in URL, got: %s", requestedURL) + } +} + +func TestSearch_StillWorks(t *testing.T) { + var requestedURL string + + ollamaMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(ollamaEmbeddingResponse{ + Embedding: make([]float64, 1024), + }) + })) + defer ollamaMock.Close() + + qdrantMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestedURL = r.URL.Path + json.NewEncoder(w).Encode(qdrantSearchResponse{ + Result: []qdrantSearchHit{ + { + ID: "1", + Score: 0.95, + Payload: map[string]interface{}{ + "chunk_text": "Test content", + "regulation_id": "eu_2016_679", + "regulation_name_de": "DSGVO", + "regulation_short": "DSGVO", + "category": "regulation", + "source": "https://example.com", + }, + }, + }, + }) + })) + defer qdrantMock.Close() + + qdrantAddr := strings.TrimPrefix(qdrantMock.URL, "http://") + parts := strings.Split(qdrantAddr, ":") + + client := &LegalRAGClient{ + qdrantHost: parts[0], + qdrantPort: parts[1], + ollamaURL: ollamaMock.URL, + embeddingModel: "bge-m3", + collection: "bp_compliance_ce", + httpClient: http.DefaultClient, + } + + results, err := client.Search(context.Background(), "DSGVO Art. 35", nil, 5) + if err != nil { + t.Fatalf("Search failed: %v", err) + } + + if len(results) != 1 { + t.Fatalf("Expected 1 result, got %d", len(results)) + } + + if results[0].RegulationCode != "eu_2016_679" { + t.Errorf("Expected regulation_code eu_2016_679, got %s", results[0].RegulationCode) + } + + if !strings.Contains(requestedURL, "/collections/bp_compliance_ce/") { + t.Errorf("Expected default collection in URL, got: %s", requestedURL) + } +} diff --git a/backend-compliance/compliance/api/routes.py b/backend-compliance/compliance/api/routes.py index 2b697d1..814b4bb 100644 --- a/backend-compliance/compliance/api/routes.py +++ b/backend-compliance/compliance/api/routes.py @@ -177,8 +177,12 @@ async def get_regulation_requirements( @router.get("/requirements/{requirement_id}") -async def get_requirement(requirement_id: str, db: Session = Depends(get_db)): - """Get a specific requirement by ID.""" +async def get_requirement( + requirement_id: str, + include_legal_context: bool = Query(False, description="Include RAG legal context"), + db: Session = Depends(get_db), +): + """Get a specific requirement by ID, optionally with RAG legal context.""" from ..db.models import RequirementDB, RegulationDB requirement = db.query(RequirementDB).filter(RequirementDB.id == requirement_id).first() @@ -187,7 +191,7 @@ async def get_requirement(requirement_id: str, db: Session = Depends(get_db)): regulation = db.query(RegulationDB).filter(RegulationDB.id == requirement.regulation_id).first() - return { + result = { "id": requirement.id, "regulation_id": requirement.regulation_id, "regulation_code": regulation.code if regulation else None, @@ -214,6 +218,33 @@ async def get_requirement(requirement_id: str, db: Session = Depends(get_db)): "source_section": requirement.source_section, } + if include_legal_context: + try: + from ..services.rag_client import get_rag_client + from ..services.ai_compliance_assistant import AIComplianceAssistant + + rag = get_rag_client() + assistant = AIComplianceAssistant() + query = f"{requirement.title} {requirement.article or ''}" + collection = assistant._collection_for_regulation(regulation.code if regulation else "") + rag_results = await rag.search(query, collection=collection, top_k=3) + result["legal_context"] = [ + { + "text": r.text, + "regulation_code": r.regulation_code, + "regulation_short": r.regulation_short, + "article": r.article, + "score": r.score, + "source_url": r.source_url, + } + for r in rag_results + ] + except Exception as e: + logger.warning("Failed to fetch legal context for %s: %s", requirement_id, e) + result["legal_context"] = [] + + return result + @router.get("/requirements", response_model=PaginatedRequirementResponse) async def list_requirements_paginated( diff --git a/backend-compliance/compliance/services/ai_compliance_assistant.py b/backend-compliance/compliance/services/ai_compliance_assistant.py index 46c417a..6889086 100644 --- a/backend-compliance/compliance/services/ai_compliance_assistant.py +++ b/backend-compliance/compliance/services/ai_compliance_assistant.py @@ -16,6 +16,7 @@ from typing import List, Optional, Dict, Any from enum import Enum from .llm_provider import LLMProvider, get_shared_provider, LLMResponse +from .rag_client import get_rag_client logger = logging.getLogger(__name__) @@ -199,9 +200,23 @@ Bewerte die Abdeckung und identifiziere Lücken im JSON-Format: Gib NUR das JSON zurück.""" + # EU regulation codes → bp_compliance_ce, DE codes → bp_compliance_recht + _EU_CODES = {"DSGVO", "GDPR", "AIACT", "AI_ACT", "NIS2", "CRA"} + _DE_CODES = {"BDSG", "TDDDG", "DDG", "URHG", "TMG", "TKG"} + def __init__(self, llm_provider: Optional[LLMProvider] = None): """Initialize the assistant with an LLM provider.""" self.llm = llm_provider or get_shared_provider() + self.rag = get_rag_client() + + def _collection_for_regulation(self, regulation_code: str) -> str: + """Determine the RAG collection based on regulation code.""" + code_upper = regulation_code.upper() + if any(c in code_upper for c in self._EU_CODES): + return "bp_compliance_ce" + elif any(c in code_upper for c in self._DE_CODES): + return "bp_compliance_recht" + return "bp_compliance_ce" async def interpret_requirement( self, @@ -226,6 +241,17 @@ Gib NUR das JSON zurück.""" requirement_text=requirement_text or "Kein Text verfügbar" ) + # Enrich prompt with RAG legal context + try: + rag_query = f"{regulation_name} {article} {title}" + collection = self._collection_for_regulation(regulation_code) + rag_results = await self.rag.search(rag_query, collection=collection, top_k=3) + rag_context = self.rag.format_for_prompt(rag_results) + if rag_context: + prompt += f"\n\n{rag_context}" + except Exception as e: + logger.warning("RAG enrichment failed for interpret_requirement: %s", e) + try: response = await self.llm.complete( prompt=prompt, @@ -282,6 +308,16 @@ Gib NUR das JSON zurück.""" affected_modules=", ".join(affected_modules) if affected_modules else "Alle Module" ) + # Enrich prompt with RAG legal context + try: + rag_query = f"{regulation_name} {requirement_title} Massnahmen Controls" + rag_results = await self.rag.search(rag_query, top_k=3) + rag_context = self.rag.format_for_prompt(rag_results) + if rag_context: + prompt += f"\n\n{rag_context}" + except Exception as e: + logger.warning("RAG enrichment failed for suggest_controls: %s", e) + try: response = await self.llm.complete( prompt=prompt, diff --git a/backend-compliance/compliance/services/rag_client.py b/backend-compliance/compliance/services/rag_client.py new file mode 100644 index 0000000..2e0a22e --- /dev/null +++ b/backend-compliance/compliance/services/rag_client.py @@ -0,0 +1,129 @@ +""" +Compliance RAG Client — Proxy to Go SDK RAG Search. + +Lightweight HTTP client that queries the Go AI Compliance SDK's +POST /sdk/v1/rag/search endpoint. This avoids needing embedding +models or direct Qdrant access in Python. + +Error-tolerant: RAG failures never break the calling function. +""" + +import logging +import os +from dataclasses import dataclass +from typing import List, Optional + +import httpx + +logger = logging.getLogger(__name__) + +SDK_URL = os.getenv("SDK_URL", "http://ai-compliance-sdk:8090") +RAG_SEARCH_TIMEOUT = 15.0 # seconds + + +@dataclass +class RAGSearchResult: + """A single search result from the compliance corpus.""" + text: str + regulation_code: str + regulation_name: str + regulation_short: str + category: str + article: str + paragraph: str + source_url: str + score: float + + +class ComplianceRAGClient: + """ + RAG client that proxies search requests to the Go SDK. + + Usage: + client = get_rag_client() + results = await client.search("DSGVO Art. 35", collection="bp_compliance_recht") + context_str = client.format_for_prompt(results) + """ + + def __init__(self, base_url: str = SDK_URL): + self._search_url = f"{base_url}/sdk/v1/rag/search" + + async def search( + self, + query: str, + collection: str = "bp_compliance_ce", + regulations: Optional[List[str]] = None, + top_k: int = 5, + ) -> List[RAGSearchResult]: + """ + Search the RAG corpus via Go SDK. + + Returns an empty list on any error (never raises). + """ + payload = { + "query": query, + "collection": collection, + "top_k": top_k, + } + if regulations: + payload["regulations"] = regulations + + try: + async with httpx.AsyncClient(timeout=RAG_SEARCH_TIMEOUT) as client: + resp = await client.post(self._search_url, json=payload) + + if resp.status_code != 200: + logger.warning( + "RAG search returned %d: %s", resp.status_code, resp.text[:200] + ) + return [] + + data = resp.json() + results = [] + for r in data.get("results", []): + results.append(RAGSearchResult( + text=r.get("text", ""), + regulation_code=r.get("regulation_code", ""), + regulation_name=r.get("regulation_name", ""), + regulation_short=r.get("regulation_short", ""), + category=r.get("category", ""), + article=r.get("article", ""), + paragraph=r.get("paragraph", ""), + source_url=r.get("source_url", ""), + score=r.get("score", 0.0), + )) + return results + + except Exception as e: + logger.warning("RAG search failed: %s", e) + return [] + + def format_for_prompt( + self, results: List[RAGSearchResult], max_results: int = 5 + ) -> str: + """Format search results as Markdown for inclusion in an LLM prompt.""" + if not results: + return "" + + lines = ["## Relevanter Rechtskontext\n"] + for i, r in enumerate(results[:max_results]): + header = f"{i + 1}. **{r.regulation_short}** ({r.regulation_code})" + if r.article: + header += f" — {r.article}" + lines.append(header) + text = r.text[:400] + "..." if len(r.text) > 400 else r.text + lines.append(f" > {text}\n") + + return "\n".join(lines) + + +# Singleton +_rag_client: Optional[ComplianceRAGClient] = None + + +def get_rag_client() -> ComplianceRAGClient: + """Get the shared RAG client instance.""" + global _rag_client + if _rag_client is None: + _rag_client = ComplianceRAGClient() + return _rag_client diff --git a/backend-compliance/tests/__init__.py b/backend-compliance/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend-compliance/tests/test_rag_client.py b/backend-compliance/tests/test_rag_client.py new file mode 100644 index 0000000..2921cab --- /dev/null +++ b/backend-compliance/tests/test_rag_client.py @@ -0,0 +1,175 @@ +"""Tests for ComplianceRAGClient.""" + +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from compliance.services.rag_client import ComplianceRAGClient, RAGSearchResult + + +class TestComplianceRAGClient: + """Tests for the RAG client proxy.""" + + def setup_method(self): + self.client = ComplianceRAGClient(base_url="http://test-sdk:8090") + + @pytest.mark.asyncio + async def test_search_success(self): + """Successful search returns parsed results.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "query": "DSGVO Art. 35", + "results": [ + { + "text": "Art. 35 DSGVO regelt die Datenschutz-Folgenabschaetzung...", + "regulation_code": "eu_2016_679", + "regulation_name": "DSGVO", + "regulation_short": "DSGVO", + "category": "regulation", + "article": "Art. 35", + "paragraph": "", + "source_url": "https://example.com", + "score": 0.92, + } + ], + "count": 1, + } + + with patch("httpx.AsyncClient") as MockClient: + mock_instance = AsyncMock() + mock_instance.post.return_value = mock_response + mock_instance.__aenter__ = AsyncMock(return_value=mock_instance) + mock_instance.__aexit__ = AsyncMock(return_value=False) + MockClient.return_value = mock_instance + + results = await self.client.search("DSGVO Art. 35", collection="bp_compliance_ce") + + assert len(results) == 1 + assert results[0].regulation_code == "eu_2016_679" + assert results[0].score == 0.92 + assert "Art. 35" in results[0].text + + @pytest.mark.asyncio + async def test_search_failure_returns_empty(self): + """Network errors return empty list, never raise.""" + with patch("httpx.AsyncClient") as MockClient: + mock_instance = AsyncMock() + mock_instance.post.side_effect = Exception("Connection refused") + mock_instance.__aenter__ = AsyncMock(return_value=mock_instance) + mock_instance.__aexit__ = AsyncMock(return_value=False) + MockClient.return_value = mock_instance + + results = await self.client.search("test query") + + assert results == [] + + @pytest.mark.asyncio + async def test_search_http_error_returns_empty(self): + """HTTP errors return empty list.""" + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.text = "Internal Server Error" + + with patch("httpx.AsyncClient") as MockClient: + mock_instance = AsyncMock() + mock_instance.post.return_value = mock_response + mock_instance.__aenter__ = AsyncMock(return_value=mock_instance) + mock_instance.__aexit__ = AsyncMock(return_value=False) + MockClient.return_value = mock_instance + + results = await self.client.search("test query") + + assert results == [] + + def test_format_for_prompt(self): + """format_for_prompt produces Markdown output.""" + results = [ + RAGSearchResult( + text="Die Verarbeitung personenbezogener Daten...", + regulation_code="eu_2016_679", + regulation_name="DSGVO", + regulation_short="DSGVO", + category="regulation", + article="Art. 35", + paragraph="", + source_url="https://example.com", + score=0.9, + ), + RAGSearchResult( + text="Risikobewertung fuer KI-Systeme...", + regulation_code="eu_2024_1689", + regulation_name="AI Act", + regulation_short="AI Act", + category="regulation", + article="", + paragraph="", + source_url="https://example.com", + score=0.85, + ), + ] + + output = self.client.format_for_prompt(results) + + assert "## Relevanter Rechtskontext" in output + assert "**DSGVO**" in output + assert "Art. 35" in output + assert "**AI Act**" in output + + def test_format_for_prompt_empty(self): + """Empty results return empty string.""" + assert self.client.format_for_prompt([]) == "" + + def test_format_for_prompt_truncation(self): + """Long text is truncated to 400 chars.""" + results = [ + RAGSearchResult( + text="A" * 500, + regulation_code="test", + regulation_name="Test", + regulation_short="Test", + category="test", + article="", + paragraph="", + source_url="", + score=0.5, + ), + ] + + output = self.client.format_for_prompt(results) + assert "..." in output + + +class TestCollectionMapping: + """Tests for regulation → collection mapping in AIComplianceAssistant.""" + + def test_eu_regulations_map_to_ce(self): + from compliance.services.ai_compliance_assistant import AIComplianceAssistant + assistant = AIComplianceAssistant.__new__(AIComplianceAssistant) + + assert assistant._collection_for_regulation("DSGVO") == "bp_compliance_ce" + assert assistant._collection_for_regulation("GDPR") == "bp_compliance_ce" + assert assistant._collection_for_regulation("AI_ACT") == "bp_compliance_ce" + assert assistant._collection_for_regulation("NIS2") == "bp_compliance_ce" + assert assistant._collection_for_regulation("CRA") == "bp_compliance_ce" + + def test_de_regulations_map_to_recht(self): + from compliance.services.ai_compliance_assistant import AIComplianceAssistant + assistant = AIComplianceAssistant.__new__(AIComplianceAssistant) + + assert assistant._collection_for_regulation("BDSG") == "bp_compliance_recht" + assert assistant._collection_for_regulation("TDDDG") == "bp_compliance_recht" + assert assistant._collection_for_regulation("TKG") == "bp_compliance_recht" + + def test_unknown_regulation_defaults_to_ce(self): + from compliance.services.ai_compliance_assistant import AIComplianceAssistant + assistant = AIComplianceAssistant.__new__(AIComplianceAssistant) + + assert assistant._collection_for_regulation("UNKNOWN") == "bp_compliance_ce" + assert assistant._collection_for_regulation("") == "bp_compliance_ce" + + def test_case_insensitive(self): + from compliance.services.ai_compliance_assistant import AIComplianceAssistant + assistant = AIComplianceAssistant.__new__(AIComplianceAssistant) + + assert assistant._collection_for_regulation("dsgvo") == "bp_compliance_ce" + assert assistant._collection_for_regulation("bdsg") == "bp_compliance_recht"