feat: add compliance modules 2-5 (dashboard, security templates, process manager, evidence collector)
All checks were successful
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Successful in 32s
CI/CD / test-python-backend-compliance (push) Successful in 34s
CI/CD / test-python-document-crawler (push) Successful in 23s
CI/CD / test-python-dsms-gateway (push) Successful in 21s
CI/CD / validate-canonical-controls (push) Successful in 11s
CI/CD / Deploy (push) Successful in 2s

Module 2: Extended Compliance Dashboard with roadmap, module-status, next-actions, snapshots, score-history
Module 3: 7 German security document templates (IT-Sicherheitskonzept, Datenschutz, Backup, Logging, Incident-Response, Zugriff, Risikomanagement)
Module 4: Compliance Process Manager with CRUD, complete/skip/seed, ~50 seed tasks, 3-tab UI
Module 5: Evidence Collector Extended with automated checks, control-mapping, coverage report, 4-tab UI

Also includes: canonical control library enhancements (verification method, categories, dedup), control generator improvements, RAG client extensions

52 tests pass, frontend builds clean.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-03-14 21:03:04 +01:00
parent 13d13c8226
commit 49ce417428
35 changed files with 8741 additions and 422 deletions

View File

@@ -46,7 +46,7 @@ EMBEDDING_URL = os.getenv("EMBEDDING_URL", "http://embedding-service:8087")
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
ANTHROPIC_MODEL = os.getenv("CONTROL_GEN_ANTHROPIC_MODEL", "claude-sonnet-4-6")
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://host.docker.internal:11434")
OLLAMA_MODEL = os.getenv("CONTROL_GEN_OLLAMA_MODEL", "qwen3:30b-a3b")
OLLAMA_MODEL = os.getenv("CONTROL_GEN_OLLAMA_MODEL", "qwen3.5:35b-a3b")
LLM_TIMEOUT = float(os.getenv("CONTROL_GEN_LLM_TIMEOUT", "120"))
HARMONIZATION_THRESHOLD = 0.85 # Cosine similarity above this = duplicate
@@ -157,6 +157,9 @@ REGULATION_LICENSE_MAP: dict[str, dict] = {
"edpb_transfers_07_2020":{"license": "EU_PUBLIC", "rule": 1, "name": "EDPB Transfers 07/2020"},
"edpb_video_03_2019": {"license": "EU_PUBLIC", "rule": 1, "name": "EDPB Video Surveillance"},
"edps_dpia_list": {"license": "EU_PUBLIC", "rule": 1, "name": "EDPS DPIA Liste"},
"edpb_certification_01_2018": {"license": "EU_PUBLIC", "rule": 1, "name": "EDPB Certification 01/2018"},
"edpb_certification_01_2019": {"license": "EU_PUBLIC", "rule": 1, "name": "EDPB Certification 01/2019"},
"eaa": {"license": "EU_LAW", "rule": 1, "name": "European Accessibility Act"},
# WP29 (pre-EDPB) Guidelines
"wp244_profiling": {"license": "EU_PUBLIC", "rule": 1, "name": "WP29 Profiling"},
"wp251_profiling": {"license": "EU_PUBLIC", "rule": 1, "name": "WP29 Data Portability"},
@@ -240,6 +243,83 @@ DOMAIN_KEYWORDS = {
}
# Bilingual (English/German) keyword lists keyed by control category (17 total).
# Used by _detect_category: input text is lower-cased and each keyword counts
# one hit if it occurs as a substring; the category with the most hits wins.
CATEGORY_KEYWORDS = {
    "encryption": ["encryption", "cryptography", "tls", "ssl", "certificate", "hashing",
                   "aes", "rsa", "verschlüsselung", "kryptographie", "zertifikat", "cipher"],
    "authentication": ["authentication", "login", "password", "credential", "mfa", "2fa",
                       "session", "oauth", "authentifizierung", "anmeldung", "passwort"],
    "network": ["network", "firewall", "dns", "vpn", "proxy", "segmentation",
                "netzwerk", "routing", "port", "intrusion", "ids", "ips"],
    "data_protection": ["data protection", "privacy", "personal data", "datenschutz",
                        "personenbezogen", "dsgvo", "gdpr", "löschung", "verarbeitung", "einwilligung"],
    "logging": ["logging", "monitoring", "audit trail", "siem", "alert", "anomaly",
                "protokollierung", "überwachung", "nachvollziehbar"],
    "incident": ["incident", "response", "breach", "recovery", "vorfall", "sicherheitsvorfall"],
    "continuity": ["backup", "disaster recovery", "notfall", "wiederherstellung", "notfallplan",
                   "business continuity", "ausfallsicherheit"],
    "compliance": ["compliance", "audit", "regulation", "certification", "konformität",
                   "prüfung", "zertifizierung", "nachweis"],
    "supply_chain": ["supplier", "vendor", "third party", "lieferant", "auftragnehmer",
                     "unterauftragnehmer", "supply chain", "dienstleister"],
    "physical": ["physical", "building", "access zone", "physisch", "gebäude", "zutritt",
                 "schließsystem", "rechenzentrum"],
    "personnel": ["training", "awareness", "employee", "schulung", "mitarbeiter",
                  "sensibilisierung", "personal", "unterweisung"],
    "application": ["application", "software", "code review", "sdlc", "secure coding",
                    "anwendung", "entwicklung", "software-entwicklung", "api"],
    "system": ["hardening", "patch", "configuration", "update", "härtung", "konfiguration",
               "betriebssystem", "system", "server"],
    "risk": ["risk assessment", "risk management", "risiko", "bewertung", "risikobewertung",
             "risikoanalyse", "bedrohung", "threat"],
    "governance": ["governance", "policy", "organization", "isms", "sicherheitsorganisation",
                   "richtlinie", "verantwortlichkeit", "rolle"],
    "hardware": ["hardware", "platform", "firmware", "bios", "tpm", "chip",
                 "plattform", "geräte"],
    "identity": ["identity", "iam", "directory", "ldap", "sso", "provisioning",
                 "identität", "identitätsmanagement", "benutzerverzeichnis"],
}
# Keyword lists (English + German) per verification method, consumed by
# _detect_verification_method. "hybrid" is intentionally empty: it is never
# scored directly, only assigned when the top two methods score similarly.
VERIFICATION_KEYWORDS = {
    "code_review": ["source code", "code review", "static analysis", "sast", "dast",
                    "dependency check", "quellcode", "codeanalyse", "secure coding",
                    "software development", "api", "input validation", "output encoding"],
    "document": ["policy", "procedure", "documentation", "training", "awareness",
                 "richtlinie", "dokumentation", "schulung", "nachweis", "vertrag",
                 "organizational", "process", "role", "responsibility"],
    "tool": ["scanner", "monitoring", "siem", "ids", "ips", "firewall", "antivirus",
             "vulnerability scan", "penetration test", "tool", "automated"],
    "hybrid": [],  # Assigned when multiple methods match equally
}
def _detect_category(text: str) -> Optional[str]:
"""Detect the most likely category from text content."""
text_lower = text.lower()
scores: dict[str, int] = {}
for cat, keywords in CATEGORY_KEYWORDS.items():
scores[cat] = sum(1 for kw in keywords if kw in text_lower)
if not scores or max(scores.values()) == 0:
return None
return max(scores, key=scores.get)
def _detect_verification_method(text: str) -> Optional[str]:
"""Detect verification method from text content."""
text_lower = text.lower()
scores: dict[str, int] = {}
for method, keywords in VERIFICATION_KEYWORDS.items():
if method == "hybrid":
continue
scores[method] = sum(1 for kw in keywords if kw in text_lower)
if not scores or max(scores.values()) == 0:
return None
top = sorted(scores.items(), key=lambda x: -x[1])
# If top two are close, it's hybrid
if len(top) >= 2 and top[0][1] > 0 and top[1][1] > 0 and top[1][1] >= top[0][1] * 0.7:
return "hybrid"
return top[0][0] if top[0][1] > 0 else None
def _detect_domain(text: str) -> str:
"""Detect the most likely domain from text content."""
text_lower = text.lower()
@@ -259,10 +339,11 @@ class GeneratorConfig(BaseModel):
collections: Optional[List[str]] = None
domain: Optional[str] = None
batch_size: int = 5
max_controls: int = 50
max_controls: int = 0 # 0 = unlimited (process ALL chunks)
skip_processed: bool = True
skip_web_search: bool = False
dry_run: bool = False
existing_job_id: Optional[str] = None # If set, reuse this job instead of creating a new one
@dataclass
@@ -287,6 +368,9 @@ class GeneratedControl:
source_citation: Optional[dict] = None
customer_visible: bool = True
generation_metadata: dict = field(default_factory=dict)
# Classification fields
verification_method: Optional[str] = None # code_review, document, tool, hybrid
category: Optional[str] = None # one of 17 categories
@dataclass
@@ -299,6 +383,7 @@ class GeneratorResult:
controls_needs_review: int = 0
controls_too_close: int = 0
controls_duplicates_found: int = 0
chunks_skipped_prefilter: int = 0
errors: list = field(default_factory=list)
controls: list = field(default_factory=list)
@@ -310,14 +395,68 @@ class GeneratorResult:
async def _llm_chat(prompt: str, system_prompt: Optional[str] = None) -> str:
    """Route a chat request to the best available LLM backend.

    Anthropic is used when an API key is configured; an empty Anthropic
    reply (or no key at all) falls through to the local Ollama model.
    """
    if not ANTHROPIC_API_KEY:
        # No key configured — go straight to the local model.
        logger.info("Calling Ollama (model=%s)...", OLLAMA_MODEL)
        return await _llm_ollama(prompt, system_prompt)

    logger.info("Calling Anthropic API (model=%s)...", ANTHROPIC_MODEL)
    reply = await _llm_anthropic(prompt, system_prompt)
    if reply:
        logger.info("Anthropic API success (%d chars)", len(reply))
        return reply

    logger.warning("Anthropic failed, falling back to Ollama")
    logger.info("Calling Ollama (model=%s)...", OLLAMA_MODEL)
    return await _llm_ollama(prompt, system_prompt)
async def _llm_local(prompt: str, system_prompt: Optional[str] = None) -> str:
    """Call local Ollama LLM only (for pre-filtering and classification tasks).

    Unlike ``_llm_chat``, this never routes to Anthropic even when an API key
    is configured — bulk, low-stakes calls stay on the cheap local model.
    """
    return await _llm_ollama(prompt, system_prompt)
# System prompt (German) for the local-LLM relevance pre-filter.
# It instructs the model to reply with nothing but a JSON object of the form
# {"relevant": true/false, "reason": "..."} — _prefilter_chunk parses exactly that.
PREFILTER_SYSTEM_PROMPT = """Du bist ein Compliance-Analyst. Deine Aufgabe: Prüfe ob ein Textabschnitt eine konkrete Sicherheitsanforderung, Datenschutzpflicht, oder technische/organisatorische Maßnahme enthält.
Antworte NUR mit einem JSON-Objekt: {"relevant": true/false, "reason": "kurze Begründung"}
Relevant = true wenn der Text mindestens EINE der folgenden enthält:
- Konkrete Pflicht/Anforderung ("muss", "soll", "ist sicherzustellen")
- Technische Sicherheitsmaßnahme (Verschlüsselung, Zugriffskontrolle, Logging)
- Organisatorische Maßnahme (Schulung, Dokumentation, Audit)
- Datenschutz-Vorgabe (Löschpflicht, Einwilligung, Zweckbindung)
- Risikomanagement-Anforderung
Relevant = false wenn der Text NUR enthält:
- Definitionen ohne Pflichten
- Inhaltsverzeichnisse oder Verweise
- Reine Begriffsbestimmungen
- Übergangsvorschriften ohne Substanz
- Adressaten/Geltungsbereich ohne Anforderung"""
async def _prefilter_chunk(chunk_text: str) -> tuple[bool, str]:
    """Ask the local LLM whether a chunk contains an actionable requirement.

    Returns (is_relevant, reason). Runs on the local model only, so it is much
    cheaper than sending every chunk to Anthropic. Any parse failure or
    exception is treated as relevant so that no chunk is silently dropped.
    """
    prompt = f"""Prüfe ob dieser Textabschnitt eine konkrete Sicherheitsanforderung oder Compliance-Pflicht enthält.
Text:
---
{chunk_text[:1500]}
---
Antworte NUR mit JSON: {{"relevant": true/false, "reason": "kurze Begründung"}}"""
    try:
        raw_reply = await _llm_local(prompt, PREFILTER_SYSTEM_PROMPT)
        parsed = _parse_llm_json(raw_reply)
        if parsed:
            return parsed.get("relevant", True), parsed.get("reason", "")
        # Unparseable model output: keep the chunk rather than lose it.
        return True, "parse_failed"
    except Exception as e:
        logger.warning("Prefilter failed: %s — treating as relevant", e)
        return True, f"error: {e}"
async def _llm_anthropic(prompt: str, system_prompt: Optional[str] = None) -> str:
"""Call Anthropic Messages API."""
headers = {
@@ -364,6 +503,8 @@ async def _llm_ollama(prompt: str, system_prompt: Optional[str] = None) -> str:
"model": OLLAMA_MODEL,
"messages": messages,
"stream": False,
"options": {"num_predict": 512}, # Limit response length for speed
"think": False, # Disable thinking for faster responses
}
try:
@@ -397,6 +538,26 @@ async def _get_embedding(text: str) -> list[float]:
return []
async def _get_embeddings_batch(texts: list[str], batch_size: int = 32) -> list[list[float]]:
    """Get embedding vectors for multiple texts in batches.

    Args:
        texts: Texts to embed.
        batch_size: How many texts to send per request to the embedding service.

    Returns:
        One vector per input text, in input order. Batches that fail (network
        error, non-2xx status, or a reply whose length does not match the
        request) contribute empty vectors, so the result always has exactly
        ``len(texts)`` entries and stays positionally aligned with the input.
    """
    all_embeddings: list[list[float]] = []
    # Reuse one client (and its connection pool) for the whole run instead of
    # opening a fresh client per batch.
    async with httpx.AsyncClient(timeout=30.0) as client:
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            try:
                resp = await client.post(
                    f"{EMBEDDING_URL}/embed",
                    json={"texts": batch},
                )
                resp.raise_for_status()
                embeddings = resp.json().get("embeddings", [])
                if len(embeddings) != len(batch):
                    # A short/over-long reply would silently misalign every
                    # subsequent (key, vector) pairing downstream.
                    raise ValueError(
                        f"embedding service returned {len(embeddings)} vectors for {len(batch)} texts"
                    )
                all_embeddings.extend(embeddings)
            except Exception as e:
                logger.warning("Batch embedding failed for %d texts: %s", len(batch), e)
                all_embeddings.extend([[] for _ in batch])
    return all_embeddings
def _cosine_sim(a: list[float], b: list[float]) -> float:
"""Compute cosine similarity between two vectors."""
if not a or not b or len(a) != len(b):
@@ -464,50 +625,96 @@ class ControlGeneratorPipeline:
# ── Stage 1: RAG Scan ──────────────────────────────────────────────
async def _scan_rag(self, config: GeneratorConfig) -> list[RAGSearchResult]:
"""Load unprocessed chunks from RAG collections."""
"""Scroll through ALL chunks in RAG collections.
Uses the scroll endpoint to iterate over every chunk (not just top-K search).
Filters out already-processed chunks by hash.
"""
collections = config.collections or ALL_COLLECTIONS
all_results: list[RAGSearchResult] = []
queries = [
"security requirement control measure",
"Sicherheitsanforderung Maßnahme Prüfaspekt",
"compliance requirement audit criterion",
"data protection privacy obligation",
"access control authentication authorization",
]
# Pre-load all processed hashes for fast filtering
processed_hashes: set[str] = set()
if config.skip_processed:
try:
result = self.db.execute(
text("SELECT chunk_hash FROM canonical_processed_chunks")
)
processed_hashes = {row[0] for row in result}
logger.info("Loaded %d processed chunk hashes", len(processed_hashes))
except Exception as e:
logger.warning("Error loading processed hashes: %s", e)
if config.domain:
domain_kw = DOMAIN_KEYWORDS.get(config.domain, [])
if domain_kw:
queries.append(" ".join(domain_kw[:5]))
seen_hashes: set[str] = set()
for collection in collections:
for query in queries:
results = await self.rag.search(
query=query,
offset = None
page = 0
collection_total = 0
collection_new = 0
seen_offsets: set[str] = set() # Detect scroll loops
while True:
chunks, next_offset = await self.rag.scroll(
collection=collection,
top_k=20,
offset=offset,
limit=200,
)
all_results.extend(results)
# Deduplicate by text hash
seen_hashes: set[str] = set()
unique: list[RAGSearchResult] = []
for r in all_results:
h = hashlib.sha256(r.text.encode()).hexdigest()
if h not in seen_hashes:
seen_hashes.add(h)
unique.append(r)
if not chunks:
break
# Filter out already-processed chunks
if config.skip_processed and unique:
hashes = [hashlib.sha256(r.text.encode()).hexdigest() for r in unique]
processed = self._get_processed_hashes(hashes)
unique = [r for r, h in zip(unique, hashes) if h not in processed]
collection_total += len(chunks)
logger.info("RAG scan: %d unique chunks (%d after filtering processed)",
len(seen_hashes), len(unique))
return unique[:config.max_controls * 3] # Over-fetch to account for duplicates
for chunk in chunks:
if not chunk.text or len(chunk.text.strip()) < 50:
continue # Skip empty/tiny chunks
h = hashlib.sha256(chunk.text.encode()).hexdigest()
# Skip duplicates (same text in multiple collections)
if h in seen_hashes:
continue
seen_hashes.add(h)
# Skip already-processed
if h in processed_hashes:
continue
all_results.append(chunk)
collection_new += 1
page += 1
if page % 50 == 0:
logger.info(
"Scrolling %s: page %d, %d total chunks, %d new unprocessed",
collection, page, collection_total, collection_new,
)
# Stop conditions
if not next_offset:
break
# Detect infinite scroll loops (Qdrant mixed ID types)
if next_offset in seen_offsets:
logger.warning(
"Scroll loop detected in %s at offset %s (page %d) — stopping",
collection, next_offset, page,
)
break
seen_offsets.add(next_offset)
offset = next_offset
logger.info(
"Collection %s: %d total chunks scrolled, %d new unprocessed",
collection, collection_total, collection_new,
)
logger.info(
"RAG scroll complete: %d total unique seen, %d new unprocessed to process",
len(seen_hashes), len(all_results),
)
return all_results
def _get_processed_hashes(self, hashes: list[str]) -> set[str]:
"""Check which chunk hashes are already processed."""
@@ -568,6 +775,8 @@ Quelle: {chunk.regulation_name} ({chunk.regulation_code}), {chunk.article}"""
"url": chunk.source_url or "",
}
control.customer_visible = True
control.verification_method = _detect_verification_method(chunk.text)
control.category = _detect_category(chunk.text)
control.generation_metadata = {
"processing_path": "structured",
"license_rule": 1,
@@ -617,6 +826,8 @@ Quelle: {chunk.regulation_name}, {chunk.article}"""
"url": chunk.source_url or "",
}
control.customer_visible = True
control.verification_method = _detect_verification_method(chunk.text)
control.category = _detect_category(chunk.text)
control.generation_metadata = {
"processing_path": "structured",
"license_rule": 2,
@@ -661,6 +872,8 @@ Gib JSON zurück mit diesen Feldern:
control.source_original_text = None # NEVER store original
control.source_citation = None # NEVER cite source
control.customer_visible = False # Only our formulation
control.verification_method = _detect_verification_method(chunk.text)
control.category = _detect_category(chunk.text)
# generation_metadata: NO source names, NO original texts
control.generation_metadata = {
"processing_path": "llm_reform",
@@ -676,6 +889,10 @@ Gib JSON zurück mit diesen Feldern:
if not existing:
return None
# Pre-load all existing embeddings in batch (once per pipeline run)
if not self._existing_embeddings:
await self._preload_embeddings(existing)
new_text = f"{new_control.title} {new_control.objective}"
new_emb = await _get_embedding(new_text)
if not new_emb:
@@ -684,14 +901,7 @@ Gib JSON zurück mit diesen Feldern:
similar = []
for ex in existing:
ex_key = ex.get("control_id", "")
ex_text = f"{ex.get('title', '')} {ex.get('objective', '')}"
# Get or compute embedding for existing control
if ex_key not in self._existing_embeddings:
emb = await _get_embedding(ex_text)
self._existing_embeddings[ex_key] = emb
ex_emb = self._existing_embeddings.get(ex_key, [])
if not ex_emb:
continue
@@ -705,6 +915,20 @@ Gib JSON zurück mit diesen Feldern:
return similar if similar else None
async def _preload_embeddings(self, existing: list[dict]):
"""Pre-load embeddings for all existing controls in batches."""
texts = [f"{ex.get('title', '')} {ex.get('objective', '')}" for ex in existing]
keys = [ex.get("control_id", "") for ex in existing]
logger.info("Pre-loading embeddings for %d existing controls...", len(texts))
embeddings = await _get_embeddings_batch(texts)
for key, emb in zip(keys, embeddings):
self._existing_embeddings[key] = emb
loaded = sum(1 for emb in embeddings if emb)
logger.info("Pre-loaded %d/%d embeddings", loaded, len(texts))
def _load_existing_controls(self) -> list[dict]:
"""Load existing controls from DB (cached per pipeline run)."""
if self._existing_controls is not None:
@@ -799,10 +1023,11 @@ Gib JSON zurück mit diesen Feldern:
return str(uuid.uuid4())
def _update_job(self, job_id: str, result: GeneratorResult):
"""Update job with final stats."""
"""Update job with current stats. Sets completed_at only when status is final."""
is_final = result.status in ("completed", "failed")
try:
self.db.execute(
text("""
text(f"""
UPDATE canonical_generation_jobs
SET status = :status,
total_chunks_scanned = :scanned,
@@ -811,8 +1036,8 @@ Gib JSON zurück mit diesen Feldern:
controls_needs_review = :needs_review,
controls_too_close = :too_close,
controls_duplicates_found = :duplicates,
errors = :errors,
completed_at = NOW()
errors = :errors
{"" if not is_final else ", completed_at = NOW()"}
WHERE id = CAST(:job_id AS uuid)
"""),
{
@@ -857,14 +1082,16 @@ Gib JSON zurück mit diesen Feldern:
severity, risk_score, implementation_effort,
open_anchors, release_state, tags,
license_rule, source_original_text, source_citation,
customer_visible, generation_metadata
customer_visible, generation_metadata,
verification_method, category
) VALUES (
:framework_id, :control_id, :title, :objective, :rationale,
:scope, :requirements, :test_procedure, :evidence,
:severity, :risk_score, :implementation_effort,
:open_anchors, :release_state, :tags,
:license_rule, :source_original_text, :source_citation,
:customer_visible, :generation_metadata
:customer_visible, :generation_metadata,
:verification_method, :category
)
ON CONFLICT (framework_id, control_id) DO NOTHING
RETURNING id
@@ -890,6 +1117,8 @@ Gib JSON zurück mit diesen Feldern:
"source_citation": json.dumps(control.source_citation) if control.source_citation else None,
"customer_visible": control.customer_visible,
"generation_metadata": json.dumps(control.generation_metadata) if control.generation_metadata else None,
"verification_method": control.verification_method,
"category": control.category,
},
)
self.db.commit()
@@ -926,7 +1155,7 @@ Gib JSON zurück mit diesen Feldern:
"""),
{
"hash": chunk_hash,
"collection": "bp_compliance_ce", # Default, we don't track collection per result
"collection": chunk.collection or "bp_compliance_ce",
"regulation_code": chunk.regulation_code,
"doc_version": "1.0",
"license": license_info.get("license", ""),
@@ -946,8 +1175,11 @@ Gib JSON zurück mit diesen Feldern:
"""Execute the full 7-stage pipeline."""
result = GeneratorResult()
# Create job
job_id = self._create_job(config)
# Create or reuse job
if config.existing_job_id:
job_id = config.existing_job_id
else:
job_id = self._create_job(config)
result.job_id = job_id
try:
@@ -962,13 +1194,37 @@ Gib JSON zurück mit diesen Feldern:
# Process chunks
controls_count = 0
for chunk in chunks:
if controls_count >= config.max_controls:
break
chunks_skipped_prefilter = 0
for i, chunk in enumerate(chunks):
try:
# Progress logging every 50 chunks
if i > 0 and i % 50 == 0:
logger.info(
"Progress: %d/%d chunks processed, %d controls generated, %d skipped by prefilter",
i, len(chunks), controls_count, chunks_skipped_prefilter,
)
self._update_job(job_id, result)
# Stage 1.5: Local LLM pre-filter — skip chunks without requirements
if not config.dry_run:
is_relevant, prefilter_reason = await _prefilter_chunk(chunk.text)
if not is_relevant:
chunks_skipped_prefilter += 1
# Mark as processed so we don't re-check next time
license_info = self._classify_license(chunk)
self._mark_chunk_processed(
chunk, license_info, "prefilter_skip", [], job_id
)
continue
control = await self._process_single_chunk(chunk, config, job_id)
if control is None:
# No control generated — still mark as processed
if not config.dry_run:
license_info = self._classify_license(chunk)
self._mark_chunk_processed(
chunk, license_info, "no_control", [], job_id
)
continue
# Count by state
@@ -989,6 +1245,12 @@ Gib JSON zurück mit diesen Feldern:
license_info = self._classify_license(chunk)
path = "llm_reform" if license_info["rule"] == 3 else "structured"
self._mark_chunk_processed(chunk, license_info, path, [ctrl_uuid], job_id)
else:
# Store failed — still mark as processed
license_info = self._classify_license(chunk)
self._mark_chunk_processed(
chunk, license_info, "store_failed", [], job_id
)
result.controls_generated += 1
result.controls.append(asdict(control))
@@ -1006,6 +1268,21 @@ Gib JSON zurück mit diesen Feldern:
error_msg = f"Error processing chunk {chunk.regulation_code}/{chunk.article}: {e}"
logger.error(error_msg)
result.errors.append(error_msg)
# Mark failed chunks as processed too (so we don't retry endlessly)
try:
if not config.dry_run:
license_info = self._classify_license(chunk)
self._mark_chunk_processed(
chunk, license_info, "error", [], job_id
)
except Exception:
pass
result.chunks_skipped_prefilter = chunks_skipped_prefilter
logger.info(
"Pipeline complete: %d controls generated, %d chunks skipped by prefilter, %d total chunks",
controls_count, chunks_skipped_prefilter, len(chunks),
)
result.status = "completed"

View File

@@ -33,6 +33,7 @@ class RAGSearchResult:
paragraph: str
source_url: str
score: float
collection: str = ""
class ComplianceRAGClient:
@@ -91,6 +92,7 @@ class ComplianceRAGClient:
paragraph=r.get("paragraph", ""),
source_url=r.get("source_url", ""),
score=r.get("score", 0.0),
collection=collection,
))
return results
@@ -98,6 +100,54 @@ class ComplianceRAGClient:
logger.warning("RAG search failed: %s", e)
return []
async def scroll(
self,
collection: str,
offset: Optional[str] = None,
limit: int = 100,
) -> tuple[List[RAGSearchResult], Optional[str]]:
"""
Scroll through ALL chunks in a collection (paginated).
Returns (chunks, next_offset). next_offset is None when done.
"""
scroll_url = self._search_url.replace("/search", "/scroll")
params = {"collection": collection, "limit": str(limit)}
if offset:
params["offset"] = offset
try:
async with httpx.AsyncClient(timeout=30.0) as client:
resp = await client.get(scroll_url, params=params)
if resp.status_code != 200:
logger.warning(
"RAG scroll returned %d: %s", resp.status_code, resp.text[:200]
)
return [], None
data = resp.json()
results = []
for r in data.get("chunks", []):
results.append(RAGSearchResult(
text=r.get("text", ""),
regulation_code=r.get("regulation_code", ""),
regulation_name=r.get("regulation_name", ""),
regulation_short=r.get("regulation_short", ""),
category=r.get("category", ""),
article=r.get("article", ""),
paragraph=r.get("paragraph", ""),
source_url=r.get("source_url", ""),
score=0.0,
collection=collection,
))
next_offset = data.get("next_offset") or None
return results, next_offset
except Exception as e:
logger.warning("RAG scroll failed: %s", e)
return [], None
def format_for_prompt(
self, results: List[RAGSearchResult], max_results: int = 5
) -> str: