feat: Phase 2 — RAG integration in Requirements + DSFA Draft
All checks were successful
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-ai-compliance (push) Successful in 35s
CI / test-python-backend-compliance (push) Successful in 26s
CI / test-python-document-crawler (push) Successful in 22s
CI / test-python-dsms-gateway (push) Successful in 19s
All checks were successful
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-ai-compliance (push) Successful in 35s
CI / test-python-backend-compliance (push) Successful in 26s
CI / test-python-document-crawler (push) Successful in 22s
CI / test-python-dsms-gateway (push) Successful in 19s
Add legal context enrichment from Qdrant vector corpus to the two highest-priority modules (Requirements AI assistant and DSFA drafting engine). Go SDK: - Add SearchCollection() with collection override + whitelist validation - Refactor Search() to delegate to shared searchInternal() Python backend: - New ComplianceRAGClient proxying POST /sdk/v1/rag/search (error-tolerant) - AI assistant: enrich interpret_requirement() and suggest_controls() with RAG - Requirements API: add ?include_legal_context=true query parameter Admin (Next.js): - Extract shared queryRAG() utility from chat route - Inject RAG legal context into v1 and v2 draft pipelines Tests for all three layers (Go, Python, TypeScript shared utility). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -177,8 +177,12 @@ async def get_regulation_requirements(
|
||||
|
||||
|
||||
@router.get("/requirements/{requirement_id}")
|
||||
async def get_requirement(requirement_id: str, db: Session = Depends(get_db)):
|
||||
"""Get a specific requirement by ID."""
|
||||
async def get_requirement(
|
||||
requirement_id: str,
|
||||
include_legal_context: bool = Query(False, description="Include RAG legal context"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get a specific requirement by ID, optionally with RAG legal context."""
|
||||
from ..db.models import RequirementDB, RegulationDB
|
||||
|
||||
requirement = db.query(RequirementDB).filter(RequirementDB.id == requirement_id).first()
|
||||
@@ -187,7 +191,7 @@ async def get_requirement(requirement_id: str, db: Session = Depends(get_db)):
|
||||
|
||||
regulation = db.query(RegulationDB).filter(RegulationDB.id == requirement.regulation_id).first()
|
||||
|
||||
return {
|
||||
result = {
|
||||
"id": requirement.id,
|
||||
"regulation_id": requirement.regulation_id,
|
||||
"regulation_code": regulation.code if regulation else None,
|
||||
@@ -214,6 +218,33 @@ async def get_requirement(requirement_id: str, db: Session = Depends(get_db)):
|
||||
"source_section": requirement.source_section,
|
||||
}
|
||||
|
||||
if include_legal_context:
|
||||
try:
|
||||
from ..services.rag_client import get_rag_client
|
||||
from ..services.ai_compliance_assistant import AIComplianceAssistant
|
||||
|
||||
rag = get_rag_client()
|
||||
assistant = AIComplianceAssistant()
|
||||
query = f"{requirement.title} {requirement.article or ''}"
|
||||
collection = assistant._collection_for_regulation(regulation.code if regulation else "")
|
||||
rag_results = await rag.search(query, collection=collection, top_k=3)
|
||||
result["legal_context"] = [
|
||||
{
|
||||
"text": r.text,
|
||||
"regulation_code": r.regulation_code,
|
||||
"regulation_short": r.regulation_short,
|
||||
"article": r.article,
|
||||
"score": r.score,
|
||||
"source_url": r.source_url,
|
||||
}
|
||||
for r in rag_results
|
||||
]
|
||||
except Exception as e:
|
||||
logger.warning("Failed to fetch legal context for %s: %s", requirement_id, e)
|
||||
result["legal_context"] = []
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/requirements", response_model=PaginatedRequirementResponse)
|
||||
async def list_requirements_paginated(
|
||||
|
||||
@@ -16,6 +16,7 @@ from typing import List, Optional, Dict, Any
|
||||
from enum import Enum
|
||||
|
||||
from .llm_provider import LLMProvider, get_shared_provider, LLMResponse
|
||||
from .rag_client import get_rag_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -199,9 +200,23 @@ Bewerte die Abdeckung und identifiziere Lücken im JSON-Format:
|
||||
|
||||
Gib NUR das JSON zurück."""
|
||||
|
||||
# EU regulation codes → bp_compliance_ce, DE codes → bp_compliance_recht
|
||||
_EU_CODES = {"DSGVO", "GDPR", "AIACT", "AI_ACT", "NIS2", "CRA"}
|
||||
_DE_CODES = {"BDSG", "TDDDG", "DDG", "URHG", "TMG", "TKG"}
|
||||
|
||||
def __init__(self, llm_provider: Optional[LLMProvider] = None):
|
||||
"""Initialize the assistant with an LLM provider."""
|
||||
self.llm = llm_provider or get_shared_provider()
|
||||
self.rag = get_rag_client()
|
||||
|
||||
def _collection_for_regulation(self, regulation_code: str) -> str:
|
||||
"""Determine the RAG collection based on regulation code."""
|
||||
code_upper = regulation_code.upper()
|
||||
if any(c in code_upper for c in self._EU_CODES):
|
||||
return "bp_compliance_ce"
|
||||
elif any(c in code_upper for c in self._DE_CODES):
|
||||
return "bp_compliance_recht"
|
||||
return "bp_compliance_ce"
|
||||
|
||||
async def interpret_requirement(
|
||||
self,
|
||||
@@ -226,6 +241,17 @@ Gib NUR das JSON zurück."""
|
||||
requirement_text=requirement_text or "Kein Text verfügbar"
|
||||
)
|
||||
|
||||
# Enrich prompt with RAG legal context
|
||||
try:
|
||||
rag_query = f"{regulation_name} {article} {title}"
|
||||
collection = self._collection_for_regulation(regulation_code)
|
||||
rag_results = await self.rag.search(rag_query, collection=collection, top_k=3)
|
||||
rag_context = self.rag.format_for_prompt(rag_results)
|
||||
if rag_context:
|
||||
prompt += f"\n\n{rag_context}"
|
||||
except Exception as e:
|
||||
logger.warning("RAG enrichment failed for interpret_requirement: %s", e)
|
||||
|
||||
try:
|
||||
response = await self.llm.complete(
|
||||
prompt=prompt,
|
||||
@@ -282,6 +308,16 @@ Gib NUR das JSON zurück."""
|
||||
affected_modules=", ".join(affected_modules) if affected_modules else "Alle Module"
|
||||
)
|
||||
|
||||
# Enrich prompt with RAG legal context
|
||||
try:
|
||||
rag_query = f"{regulation_name} {requirement_title} Massnahmen Controls"
|
||||
rag_results = await self.rag.search(rag_query, top_k=3)
|
||||
rag_context = self.rag.format_for_prompt(rag_results)
|
||||
if rag_context:
|
||||
prompt += f"\n\n{rag_context}"
|
||||
except Exception as e:
|
||||
logger.warning("RAG enrichment failed for suggest_controls: %s", e)
|
||||
|
||||
try:
|
||||
response = await self.llm.complete(
|
||||
prompt=prompt,
|
||||
|
||||
129
backend-compliance/compliance/services/rag_client.py
Normal file
129
backend-compliance/compliance/services/rag_client.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""
|
||||
Compliance RAG Client — Proxy to Go SDK RAG Search.
|
||||
|
||||
Lightweight HTTP client that queries the Go AI Compliance SDK's
|
||||
POST /sdk/v1/rag/search endpoint. This avoids needing embedding
|
||||
models or direct Qdrant access in Python.
|
||||
|
||||
Error-tolerant: RAG failures never break the calling function.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SDK_URL = os.getenv("SDK_URL", "http://ai-compliance-sdk:8090")
|
||||
RAG_SEARCH_TIMEOUT = 15.0 # seconds
|
||||
|
||||
|
||||
@dataclass
|
||||
class RAGSearchResult:
|
||||
"""A single search result from the compliance corpus."""
|
||||
text: str
|
||||
regulation_code: str
|
||||
regulation_name: str
|
||||
regulation_short: str
|
||||
category: str
|
||||
article: str
|
||||
paragraph: str
|
||||
source_url: str
|
||||
score: float
|
||||
|
||||
|
||||
class ComplianceRAGClient:
|
||||
"""
|
||||
RAG client that proxies search requests to the Go SDK.
|
||||
|
||||
Usage:
|
||||
client = get_rag_client()
|
||||
results = await client.search("DSGVO Art. 35", collection="bp_compliance_recht")
|
||||
context_str = client.format_for_prompt(results)
|
||||
"""
|
||||
|
||||
def __init__(self, base_url: str = SDK_URL):
|
||||
self._search_url = f"{base_url}/sdk/v1/rag/search"
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
collection: str = "bp_compliance_ce",
|
||||
regulations: Optional[List[str]] = None,
|
||||
top_k: int = 5,
|
||||
) -> List[RAGSearchResult]:
|
||||
"""
|
||||
Search the RAG corpus via Go SDK.
|
||||
|
||||
Returns an empty list on any error (never raises).
|
||||
"""
|
||||
payload = {
|
||||
"query": query,
|
||||
"collection": collection,
|
||||
"top_k": top_k,
|
||||
}
|
||||
if regulations:
|
||||
payload["regulations"] = regulations
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=RAG_SEARCH_TIMEOUT) as client:
|
||||
resp = await client.post(self._search_url, json=payload)
|
||||
|
||||
if resp.status_code != 200:
|
||||
logger.warning(
|
||||
"RAG search returned %d: %s", resp.status_code, resp.text[:200]
|
||||
)
|
||||
return []
|
||||
|
||||
data = resp.json()
|
||||
results = []
|
||||
for r in data.get("results", []):
|
||||
results.append(RAGSearchResult(
|
||||
text=r.get("text", ""),
|
||||
regulation_code=r.get("regulation_code", ""),
|
||||
regulation_name=r.get("regulation_name", ""),
|
||||
regulation_short=r.get("regulation_short", ""),
|
||||
category=r.get("category", ""),
|
||||
article=r.get("article", ""),
|
||||
paragraph=r.get("paragraph", ""),
|
||||
source_url=r.get("source_url", ""),
|
||||
score=r.get("score", 0.0),
|
||||
))
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("RAG search failed: %s", e)
|
||||
return []
|
||||
|
||||
def format_for_prompt(
|
||||
self, results: List[RAGSearchResult], max_results: int = 5
|
||||
) -> str:
|
||||
"""Format search results as Markdown for inclusion in an LLM prompt."""
|
||||
if not results:
|
||||
return ""
|
||||
|
||||
lines = ["## Relevanter Rechtskontext\n"]
|
||||
for i, r in enumerate(results[:max_results]):
|
||||
header = f"{i + 1}. **{r.regulation_short}** ({r.regulation_code})"
|
||||
if r.article:
|
||||
header += f" — {r.article}"
|
||||
lines.append(header)
|
||||
text = r.text[:400] + "..." if len(r.text) > 400 else r.text
|
||||
lines.append(f" > {text}\n")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# Singleton
|
||||
_rag_client: Optional[ComplianceRAGClient] = None
|
||||
|
||||
|
||||
def get_rag_client() -> ComplianceRAGClient:
|
||||
"""Get the shared RAG client instance."""
|
||||
global _rag_client
|
||||
if _rag_client is None:
|
||||
_rag_client = ComplianceRAGClient()
|
||||
return _rag_client
|
||||
0
backend-compliance/tests/__init__.py
Normal file
0
backend-compliance/tests/__init__.py
Normal file
175
backend-compliance/tests/test_rag_client.py
Normal file
175
backend-compliance/tests/test_rag_client.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""Tests for ComplianceRAGClient."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
from compliance.services.rag_client import ComplianceRAGClient, RAGSearchResult
|
||||
|
||||
|
||||
class TestComplianceRAGClient:
|
||||
"""Tests for the RAG client proxy."""
|
||||
|
||||
def setup_method(self):
|
||||
self.client = ComplianceRAGClient(base_url="http://test-sdk:8090")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_success(self):
|
||||
"""Successful search returns parsed results."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"query": "DSGVO Art. 35",
|
||||
"results": [
|
||||
{
|
||||
"text": "Art. 35 DSGVO regelt die Datenschutz-Folgenabschaetzung...",
|
||||
"regulation_code": "eu_2016_679",
|
||||
"regulation_name": "DSGVO",
|
||||
"regulation_short": "DSGVO",
|
||||
"category": "regulation",
|
||||
"article": "Art. 35",
|
||||
"paragraph": "",
|
||||
"source_url": "https://example.com",
|
||||
"score": 0.92,
|
||||
}
|
||||
],
|
||||
"count": 1,
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.post.return_value = mock_response
|
||||
mock_instance.__aenter__ = AsyncMock(return_value=mock_instance)
|
||||
mock_instance.__aexit__ = AsyncMock(return_value=False)
|
||||
MockClient.return_value = mock_instance
|
||||
|
||||
results = await self.client.search("DSGVO Art. 35", collection="bp_compliance_ce")
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].regulation_code == "eu_2016_679"
|
||||
assert results[0].score == 0.92
|
||||
assert "Art. 35" in results[0].text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_failure_returns_empty(self):
|
||||
"""Network errors return empty list, never raise."""
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.post.side_effect = Exception("Connection refused")
|
||||
mock_instance.__aenter__ = AsyncMock(return_value=mock_instance)
|
||||
mock_instance.__aexit__ = AsyncMock(return_value=False)
|
||||
MockClient.return_value = mock_instance
|
||||
|
||||
results = await self.client.search("test query")
|
||||
|
||||
assert results == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_http_error_returns_empty(self):
|
||||
"""HTTP errors return empty list."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.text = "Internal Server Error"
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.post.return_value = mock_response
|
||||
mock_instance.__aenter__ = AsyncMock(return_value=mock_instance)
|
||||
mock_instance.__aexit__ = AsyncMock(return_value=False)
|
||||
MockClient.return_value = mock_instance
|
||||
|
||||
results = await self.client.search("test query")
|
||||
|
||||
assert results == []
|
||||
|
||||
def test_format_for_prompt(self):
|
||||
"""format_for_prompt produces Markdown output."""
|
||||
results = [
|
||||
RAGSearchResult(
|
||||
text="Die Verarbeitung personenbezogener Daten...",
|
||||
regulation_code="eu_2016_679",
|
||||
regulation_name="DSGVO",
|
||||
regulation_short="DSGVO",
|
||||
category="regulation",
|
||||
article="Art. 35",
|
||||
paragraph="",
|
||||
source_url="https://example.com",
|
||||
score=0.9,
|
||||
),
|
||||
RAGSearchResult(
|
||||
text="Risikobewertung fuer KI-Systeme...",
|
||||
regulation_code="eu_2024_1689",
|
||||
regulation_name="AI Act",
|
||||
regulation_short="AI Act",
|
||||
category="regulation",
|
||||
article="",
|
||||
paragraph="",
|
||||
source_url="https://example.com",
|
||||
score=0.85,
|
||||
),
|
||||
]
|
||||
|
||||
output = self.client.format_for_prompt(results)
|
||||
|
||||
assert "## Relevanter Rechtskontext" in output
|
||||
assert "**DSGVO**" in output
|
||||
assert "Art. 35" in output
|
||||
assert "**AI Act**" in output
|
||||
|
||||
def test_format_for_prompt_empty(self):
|
||||
"""Empty results return empty string."""
|
||||
assert self.client.format_for_prompt([]) == ""
|
||||
|
||||
def test_format_for_prompt_truncation(self):
|
||||
"""Long text is truncated to 400 chars."""
|
||||
results = [
|
||||
RAGSearchResult(
|
||||
text="A" * 500,
|
||||
regulation_code="test",
|
||||
regulation_name="Test",
|
||||
regulation_short="Test",
|
||||
category="test",
|
||||
article="",
|
||||
paragraph="",
|
||||
source_url="",
|
||||
score=0.5,
|
||||
),
|
||||
]
|
||||
|
||||
output = self.client.format_for_prompt(results)
|
||||
assert "..." in output
|
||||
|
||||
|
||||
class TestCollectionMapping:
|
||||
"""Tests for regulation → collection mapping in AIComplianceAssistant."""
|
||||
|
||||
def test_eu_regulations_map_to_ce(self):
|
||||
from compliance.services.ai_compliance_assistant import AIComplianceAssistant
|
||||
assistant = AIComplianceAssistant.__new__(AIComplianceAssistant)
|
||||
|
||||
assert assistant._collection_for_regulation("DSGVO") == "bp_compliance_ce"
|
||||
assert assistant._collection_for_regulation("GDPR") == "bp_compliance_ce"
|
||||
assert assistant._collection_for_regulation("AI_ACT") == "bp_compliance_ce"
|
||||
assert assistant._collection_for_regulation("NIS2") == "bp_compliance_ce"
|
||||
assert assistant._collection_for_regulation("CRA") == "bp_compliance_ce"
|
||||
|
||||
def test_de_regulations_map_to_recht(self):
|
||||
from compliance.services.ai_compliance_assistant import AIComplianceAssistant
|
||||
assistant = AIComplianceAssistant.__new__(AIComplianceAssistant)
|
||||
|
||||
assert assistant._collection_for_regulation("BDSG") == "bp_compliance_recht"
|
||||
assert assistant._collection_for_regulation("TDDDG") == "bp_compliance_recht"
|
||||
assert assistant._collection_for_regulation("TKG") == "bp_compliance_recht"
|
||||
|
||||
def test_unknown_regulation_defaults_to_ce(self):
|
||||
from compliance.services.ai_compliance_assistant import AIComplianceAssistant
|
||||
assistant = AIComplianceAssistant.__new__(AIComplianceAssistant)
|
||||
|
||||
assert assistant._collection_for_regulation("UNKNOWN") == "bp_compliance_ce"
|
||||
assert assistant._collection_for_regulation("") == "bp_compliance_ce"
|
||||
|
||||
def test_case_insensitive(self):
|
||||
from compliance.services.ai_compliance_assistant import AIComplianceAssistant
|
||||
assistant = AIComplianceAssistant.__new__(AIComplianceAssistant)
|
||||
|
||||
assert assistant._collection_for_regulation("dsgvo") == "bp_compliance_ce"
|
||||
assert assistant._collection_for_regulation("bdsg") == "bp_compliance_recht"
|
||||
Reference in New Issue
Block a user