fix: Pipeline-Skalierung — 6 Optimierungen für 80k+ Controls

1. control_generator: GeneratorResult.status Default "completed" → "running" (Bug)
2. control_generator: Anthropic API mit Phase-Timeouts + Retry bei Disconnect
3. control_generator: regulation_exclude Filter + Harmonization via Qdrant statt In-Memory
4. decomposition_pass: Enrich Pass Batch-UPDATEs (400k → ~400 DB-Calls)
5. decomposition_pass: Merge Pass single Query statt N+1
6. batch_dedup_runner: Cross-Group Dedup parallelisiert (asyncio.gather)
7. canonical_control_routes: Framework Controls API Pagination (limit/offset)
8. DB-Indizes: idx_oc_parent_release, idx_oc_trigger_null, idx_cc_framework

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-04-11 14:09:32 +02:00
parent fc71117bf2
commit f89ce46631
5 changed files with 291 additions and 141 deletions

View File

@@ -274,6 +274,8 @@ async def list_framework_controls(
verification_method: Optional[str] = Query(None), verification_method: Optional[str] = Query(None),
category: Optional[str] = Query(None), category: Optional[str] = Query(None),
target_audience: Optional[str] = Query(None), target_audience: Optional[str] = Query(None),
limit: Optional[int] = Query(None, ge=1, le=5000),
offset: Optional[int] = Query(None, ge=0),
): ):
"""List controls belonging to a framework.""" """List controls belonging to a framework."""
with SessionLocal() as db: with SessionLocal() as db:
@@ -309,6 +311,12 @@ async def list_framework_controls(
params["ta"] = json.dumps([target_audience]) params["ta"] = json.dumps([target_audience])
query += " ORDER BY control_id" query += " ORDER BY control_id"
if limit is not None:
query += " LIMIT :lim"
params["lim"] = limit
if offset is not None:
query += " OFFSET :off"
params["off"] = offset
rows = db.execute(text(query), params).fetchall() rows = db.execute(text(query), params).fetchall()
return [_control_row(r) for r in rows] return [_control_row(r) for r in rows]

View File

@@ -55,6 +55,7 @@ class GenerateRequest(BaseModel):
skip_web_search: bool = False skip_web_search: bool = False
dry_run: bool = False dry_run: bool = False
regulation_filter: Optional[List[str]] = None # Only process these regulation_code prefixes regulation_filter: Optional[List[str]] = None # Only process these regulation_code prefixes
regulation_exclude: Optional[List[str]] = None # Skip these regulation_code prefixes
skip_prefilter: bool = False # Skip local LLM pre-filter, send all chunks to API skip_prefilter: bool = False # Skip local LLM pre-filter, send all chunks to API
@@ -148,6 +149,7 @@ async def start_generation(req: GenerateRequest):
skip_web_search=req.skip_web_search, skip_web_search=req.skip_web_search,
dry_run=req.dry_run, dry_run=req.dry_run,
regulation_filter=req.regulation_filter, regulation_filter=req.regulation_filter,
regulation_exclude=req.regulation_exclude,
skip_prefilter=req.skip_prefilter, skip_prefilter=req.skip_prefilter,
) )

View File

@@ -17,6 +17,7 @@ Usage:
stats = await runner.run(hint_filter="implement:multi_factor_auth:none") stats = await runner.run(hint_filter="implement:multi_factor_auth:none")
""" """
import asyncio
import json import json
import logging import logging
import time import time
@@ -342,67 +343,78 @@ class BatchDedupRunner:
cross_linked = 0 cross_linked = 0
cross_review = 0 cross_review = 0
for i, r in enumerate(rows): # Process in parallel batches for embedding + Qdrant search
uuid = r[0] PARALLEL_BATCH = 10
async def _embed_and_search(r):
"""Embed one control and search Qdrant — safe for asyncio.gather."""
hint = r[3] or "" hint = r[3] or ""
parts = hint.split(":", 2) parts = hint.split(":", 2)
action = parts[0] if len(parts) > 0 else "" action = parts[0] if len(parts) > 0 else ""
obj = parts[1] if len(parts) > 1 else "" obj = parts[1] if len(parts) > 1 else ""
canonical = canonicalize_text(action, obj, r[2]) canonical = canonicalize_text(action, obj, r[2])
embedding = await get_embedding(canonical) embedding = await get_embedding(canonical)
if not embedding: if not embedding:
continue return None
results = await qdrant_search_cross_regulation( results = await qdrant_search_cross_regulation(
embedding, top_k=5, collection=self.collection, embedding, top_k=5, collection=self.collection,
) )
return (r, results)
for batch_start in range(0, len(rows), PARALLEL_BATCH):
batch = rows[batch_start:batch_start + PARALLEL_BATCH]
tasks = [_embed_and_search(r) for r in batch]
results_batch = await asyncio.gather(*tasks, return_exceptions=True)
for res in results_batch:
if res is None or isinstance(res, Exception):
if isinstance(res, Exception):
logger.error("BatchDedup embed/search error: %s", res)
self.stats["errors"] += 1
continue
r, results = res
ctrl_uuid = r[0]
hint = r[3] or ""
if not results: if not results:
continue continue
# Find best match from a DIFFERENT hint group
for match in results: for match in results:
match_score = match.get("score", 0.0) match_score = match.get("score", 0.0)
match_payload = match.get("payload", {}) match_payload = match.get("payload", {})
match_uuid = match_payload.get("control_uuid", "") match_uuid = match_payload.get("control_uuid", "")
# Skip self-match if match_uuid == ctrl_uuid:
if match_uuid == uuid:
continue continue
# Must be a different hint group (otherwise already handled in Phase 1)
match_action = match_payload.get("action_normalized", "")
match_object = match_payload.get("object_normalized", "")
# Simple check: different control UUID is enough
if match_score > LINK_THRESHOLD: if match_score > LINK_THRESHOLD:
# Mark the worse one as duplicate
try: try:
self.db.execute(text(""" self.db.execute(text("""
UPDATE canonical_controls UPDATE canonical_controls
SET release_state = 'duplicate', merged_into_uuid = CAST(:master AS uuid) SET release_state = 'duplicate', merged_into_uuid = CAST(:master AS uuid)
WHERE id = CAST(:dup AS uuid) WHERE id = CAST(:dup AS uuid)
AND release_state != 'duplicate' AND release_state != 'duplicate'
"""), {"master": match_uuid, "dup": uuid}) """), {"master": match_uuid, "dup": ctrl_uuid})
self.db.execute(text(""" self.db.execute(text("""
INSERT INTO control_parent_links INSERT INTO control_parent_links
(control_uuid, parent_control_uuid, link_type, confidence) (control_uuid, parent_control_uuid, link_type, confidence)
VALUES (CAST(:cu AS uuid), CAST(:pu AS uuid), 'cross_regulation', :conf) VALUES (CAST(:cu AS uuid), CAST(:pu AS uuid), 'cross_regulation', :conf)
ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING
"""), {"cu": match_uuid, "pu": uuid, "conf": match_score}) """), {"cu": match_uuid, "pu": ctrl_uuid, "conf": match_score})
# Transfer parent links transferred = self._transfer_parent_links(match_uuid, ctrl_uuid)
transferred = self._transfer_parent_links(match_uuid, uuid)
self.stats["parent_links_transferred"] += transferred self.stats["parent_links_transferred"] += transferred
self.db.commit() self.db.commit()
cross_linked += 1 cross_linked += 1
except Exception as e: except Exception as e:
logger.error("BatchDedup cross-group link error %s%s: %s", logger.error("BatchDedup cross-group link error %s%s: %s",
uuid, match_uuid, e) ctrl_uuid, match_uuid, e)
self.db.rollback() self.db.rollback()
self.stats["errors"] += 1 self.stats["errors"] += 1
break # Only one cross-link per control break
elif match_score > REVIEW_THRESHOLD: elif match_score > REVIEW_THRESHOLD:
self._write_review( self._write_review(
{"control_id": r[1], "title": r[2], "objective": "", {"control_id": r[1], "title": r[2], "objective": "",
@@ -412,10 +424,11 @@ class BatchDedupRunner:
cross_review += 1 cross_review += 1
break break
self._progress_count = i + 1 processed = min(batch_start + PARALLEL_BATCH, len(rows))
if (i + 1) % 500 == 0: self._progress_count = processed
if processed % 500 < PARALLEL_BATCH:
logger.info("BatchDedup Cross-group: %d/%d checked, %d linked, %d review", logger.info("BatchDedup Cross-group: %d/%d checked, %d linked, %d review",
i + 1, len(rows), cross_linked, cross_review) processed, len(rows), cross_linked, cross_review)
self.stats["cross_group_linked"] = cross_linked self.stats["cross_group_linked"] = cross_linked
self.stats["cross_group_review"] = cross_review self.stats["cross_group_review"] = cross_review

View File

@@ -92,6 +92,7 @@ REGULATION_LICENSE_MAP: dict[str, dict] = {
"eucsa": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "EU Cybersecurity Act"}, "eucsa": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "EU Cybersecurity Act"},
"dataact": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Data Act"}, "dataact": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Data Act"},
"dora": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Digital Operational Resilience Act"}, "dora": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Digital Operational Resilience Act"},
"eu_2017_745": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Medizinprodukteverordnung (EU) 2017/745 (MDR)"},
"ehds": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "European Health Data Space"}, "ehds": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "European Health Data Space"},
"gpsr": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Allgemeine Produktsicherheitsverordnung"}, "gpsr": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Allgemeine Produktsicherheitsverordnung"},
"eu_2023_988": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Allgemeine Produktsicherheitsverordnung (GPSR)"}, "eu_2023_988": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Allgemeine Produktsicherheitsverordnung (GPSR)"},
@@ -132,6 +133,39 @@ REGULATION_LICENSE_MAP: dict[str, dict] = {
"ao": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Abgabenordnung (AO)"}, "ao": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Abgabenordnung (AO)"},
"ao_komplett": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Abgabenordnung (AO)"}, "ao_komplett": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Abgabenordnung (AO)"},
"battdg": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Batteriegesetz"}, "battdg": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Batteriegesetz"},
# New German Laws (2026-04-10 ingestion)
"de_bsig_2025": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "BSI-Gesetz (BSIG 2025, NIS2-Umsetzung)"},
"de_tdddg": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "TDDDG"},
"de_gwg_2017": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Geldwaeschegesetz (GwG)"},
"de_agg": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Allgemeines Gleichbehandlungsgesetz (AGG)"},
"de_hinschg": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Hinweisgeberschutzgesetz (HinSchG)"},
"de_lksg": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Lieferkettensorgfaltspflichtengesetz (LkSG)"},
"de_kritisdachg": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "KRITIS-Dachgesetz (KRITISDachG)"},
"de_bsi_kritisv": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "BSI-Kritisverordnung (BSI-KritisV)"},
# DSK/BfDI Guidance (amtliche Orientierungshilfen)
"dsk_oh_ki_datenschutz": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH KI und Datenschutz"},
"dsk_oh_ki_systeme_tom": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH TOM bei KI-Systemen"},
"dsk_oh_ki_rag": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Generative KI mit RAG"},
"dsk_oh_telemedien": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Telemedien"},
"dsk_oh_digitale_dienste": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Digitale Dienste"},
"dsk_oh_direktwerbung": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Direktwerbung"},
"dsk_oh_videokonferenz": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Videokonferenzsysteme"},
"dsk_oh_videoueberwachung": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Videoueberwachung"},
"dsk_oh_email_verschluesselung": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH E-Mail-Verschluesselung"},
"dsk_oh_whistleblowing": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Whistleblowing"},
"dsk_oh_onlinedienste_zugang": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Onlinedienste Zugang"},
"dsk_oh_datenuebermittlung_drittlaender": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Datenuebermittlung Drittlaender"},
"dsk_sdm_methode": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK Standard-Datenschutzmodell SDM V3.1a"},
"dsk_ah_eu_us_dpf": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK AH EU-US Data Privacy Framework"},
"dsk_ah_bussgeldkonzept": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK Bussgeldkonzept"},
"dsk_ah_dsfa_mussliste": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK DSFA Muss-Liste"},
"dsk_ah_verzeichnis_vvt": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK Hinweise VVT"},
"dsk_ah_zertifizierung": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK Zertifizierungskriterien"},
"dsk_beschluss_ms365": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK Festlegung Microsoft 365"},
"dsk_pos_ki_verordnung": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK Positionspapier KI-Verordnung"},
"dsk_entschl_beschaeftigtendatenschutz": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK Entschliessung Beschaeftigtendatenschutz"},
"bfdi_handreichung_ki_behoerden": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "BfDI Handreichung KI in Behoerden"},
"bfdi_handreichung_ki_sicherheit": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "BfDI Handreichung KI Sicherheitsbehoerden"},
# Austrian Laws # Austrian Laws
"at_dsg": {"license": "AT_LAW", "rule": 1, "source_type": "law", "name": "Österreichisches Datenschutzgesetz (DSG)"}, "at_dsg": {"license": "AT_LAW", "rule": 1, "source_type": "law", "name": "Österreichisches Datenschutzgesetz (DSG)"},
"at_abgb": {"license": "AT_LAW", "rule": 1, "source_type": "law", "name": "AT ABGB"}, "at_abgb": {"license": "AT_LAW", "rule": 1, "source_type": "law", "name": "AT ABGB"},
@@ -459,6 +493,7 @@ class GeneratorConfig(BaseModel):
dry_run: bool = False dry_run: bool = False
existing_job_id: Optional[str] = None # If set, reuse this job instead of creating a new one existing_job_id: Optional[str] = None # If set, reuse this job instead of creating a new one
regulation_filter: Optional[List[str]] = None # Only process chunks matching these regulation_code prefixes regulation_filter: Optional[List[str]] = None # Only process chunks matching these regulation_code prefixes
regulation_exclude: Optional[List[str]] = None # Skip chunks matching these regulation_code prefixes
skip_prefilter: bool = False # If True, skip local LLM pre-filter (send all chunks to API) skip_prefilter: bool = False # If True, skip local LLM pre-filter (send all chunks to API)
@@ -501,7 +536,7 @@ class GeneratedControl:
@dataclass @dataclass
class GeneratorResult: class GeneratorResult:
job_id: str = "" job_id: str = ""
status: str = "completed" status: str = "running"
total_chunks_scanned: int = 0 total_chunks_scanned: int = 0
controls_generated: int = 0 controls_generated: int = 0
controls_verified: int = 0 controls_verified: int = 0
@@ -583,8 +618,8 @@ Antworte NUR mit JSON: {{"relevant": true/false, "reason": "kurze Begründung"}}
return True, f"error: {e}" return True, f"error: {e}"
async def _llm_anthropic(prompt: str, system_prompt: Optional[str] = None) -> str: async def _llm_anthropic(prompt: str, system_prompt: Optional[str] = None, max_retries: int = 2) -> str:
"""Call Anthropic Messages API.""" """Call Anthropic Messages API with retry on timeout."""
headers = { headers = {
"x-api-key": ANTHROPIC_API_KEY, "x-api-key": ANTHROPIC_API_KEY,
"anthropic-version": "2023-06-01", "anthropic-version": "2023-06-01",
@@ -598,8 +633,12 @@ async def _llm_anthropic(prompt: str, system_prompt: Optional[str] = None) -> st
if system_prompt: if system_prompt:
payload["system"] = system_prompt payload["system"] = system_prompt
# Use explicit per-phase timeouts to prevent indefinite hangs
timeout = httpx.Timeout(connect=30.0, read=LLM_TIMEOUT, write=30.0, pool=30.0)
for attempt in range(1 + max_retries):
try: try:
async with httpx.AsyncClient(timeout=LLM_TIMEOUT) as client: async with httpx.AsyncClient(timeout=timeout) as client:
resp = await client.post( resp = await client.post(
"https://api.anthropic.com/v1/messages", "https://api.anthropic.com/v1/messages",
headers=headers, headers=headers,
@@ -613,6 +652,14 @@ async def _llm_anthropic(prompt: str, system_prompt: Optional[str] = None) -> st
if content and isinstance(content, list): if content and isinstance(content, list):
return content[0].get("text", "") return content[0].get("text", "")
return "" return ""
except (httpx.TimeoutException, httpx.RemoteProtocolError) as e:
if attempt < max_retries:
logger.warning("Anthropic request attempt %d/%d failed: %s — retrying...", attempt + 1, max_retries + 1, e)
import asyncio
await asyncio.sleep(2 ** attempt)
continue
logger.error("Anthropic request failed after %d attempts: %s (type: %s)", max_retries + 1, e, type(e).__name__)
return ""
except Exception as e: except Exception as e:
logger.error("Anthropic request failed: %s (type: %s)", e, type(e).__name__) logger.error("Anthropic request failed: %s (type: %s)", e, type(e).__name__)
return "" return ""
@@ -941,6 +988,12 @@ class ControlGeneratorPipeline:
if not any(code_lower.startswith(f.lower()) for f in config.regulation_filter): if not any(code_lower.startswith(f.lower()) for f in config.regulation_filter):
continue continue
# Exclude specific regulation_codes
if config.regulation_exclude and reg_code:
code_lower = reg_code.lower()
if any(code_lower.startswith(f.lower()) for f in config.regulation_exclude):
continue
reg_name = (payload.get("regulation_name_de", "") reg_name = (payload.get("regulation_name_de", "")
or payload.get("regulation_name", "") or payload.get("regulation_name", "")
or payload.get("source_name", "") or payload.get("source_name", "")
@@ -1423,20 +1476,31 @@ Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Elementen. Fuer Aspekte ohne
if structure_items: if structure_items:
s_chunks = [c for c, _ in structure_items] s_chunks = [c for c, _ in structure_items]
s_lics = [l for _, l in structure_items] s_lics = [l for _, l in structure_items]
try:
s_controls = await self._structure_batch(s_chunks, s_lics) s_controls = await self._structure_batch(s_chunks, s_lics)
except Exception as e:
import traceback
logger.warning("Batch structure failed: %s — creating fallback controls\n%s", e, traceback.format_exc())
s_controls = [self._fallback_control(c) for c in s_chunks]
for (chunk, _), ctrl in zip(structure_items, s_controls): for (chunk, _), ctrl in zip(structure_items, s_controls):
orig_idx = next(i for i, (c, _) in enumerate(batch_items) if c is chunk) orig_idx = next(i for i, (c, _) in enumerate(batch_items) if c is chunk)
all_controls[orig_idx] = ctrl all_controls[orig_idx] = ctrl
if reform_items: if reform_items:
r_chunks = [c for c, _ in reform_items] r_chunks = [c for c, _ in reform_items]
try:
r_controls = await self._reformulate_batch(r_chunks, config) r_controls = await self._reformulate_batch(r_chunks, config)
except Exception as e:
logger.warning("Batch reform failed: %s — creating fallback controls", e)
r_controls = [self._fallback_control(c) for c in r_chunks]
for (chunk, _), ctrl in zip(reform_items, r_controls): for (chunk, _), ctrl in zip(reform_items, r_controls):
orig_idx = next(i for i, (c, _) in enumerate(batch_items) if c is chunk) orig_idx = next(i for i, (c, _) in enumerate(batch_items) if c is chunk)
if ctrl: if ctrl:
# Too-Close-Check for Rule 3 # Too-Close-Check for Rule 3
similarity = await check_similarity(chunk.text, f"{ctrl.objective} {ctrl.rationale}") similarity = await check_similarity(chunk.text, f"{ctrl.objective} {ctrl.rationale}")
if similarity.status == "FAIL": if similarity is None:
logger.warning("Similarity check returned None — skipping too-close check")
elif similarity.status == "FAIL":
ctrl.release_state = "too_close" ctrl.release_state = "too_close"
ctrl.generation_metadata["similarity_status"] = "FAIL" ctrl.generation_metadata["similarity_status"] = "FAIL"
ctrl.generation_metadata["similarity_scores"] = { ctrl.generation_metadata["similarity_scores"] = {
@@ -1502,36 +1566,42 @@ Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Elementen. Fuer Aspekte ohne
# ── Stage 4: Harmonization ───────────────────────────────────────── # ── Stage 4: Harmonization ─────────────────────────────────────────
async def _check_harmonization(self, new_control: GeneratedControl) -> Optional[list]: async def _check_harmonization(self, new_control: GeneratedControl) -> Optional[list]:
"""Check if a new control duplicates existing ones via embedding similarity.""" """Check if a new control duplicates existing ones via Qdrant vector search.
existing = self._load_existing_controls()
if not existing:
return None
# Pre-load all existing embeddings in batch (once per pipeline run)
if not self._existing_embeddings:
await self._preload_embeddings(existing)
Uses the atomic_controls_dedup collection for fast nearest-neighbor lookup
instead of pre-loading all embeddings into memory.
"""
new_text = f"{new_control.title} {new_control.objective}" new_text = f"{new_control.title} {new_control.objective}"
new_emb = await _get_embedding(new_text) new_emb = await _get_embedding(new_text)
if not new_emb: if not new_emb:
return None return None
similar = [] try:
for ex in existing: async with httpx.AsyncClient(timeout=10.0) as client:
ex_key = ex.get("control_id", "") resp = await client.post(
ex_emb = self._existing_embeddings.get(ex_key, []) f"{QDRANT_URL}/collections/atomic_controls_dedup/points/search",
if not ex_emb: json={
continue "vector": new_emb,
"limit": 5,
"score_threshold": HARMONIZATION_THRESHOLD,
"with_payload": {"include": ["control_id", "title"]},
},
)
if resp.status_code == 200:
results = resp.json().get("result", [])
if results:
return [
{
"control_id": r["payload"].get("control_id", ""),
"title": r["payload"].get("title", ""),
"similarity": round(r["score"], 3),
}
for r in results
]
except Exception as e:
logger.warning("Qdrant dedup search failed: %s — skipping harmonization", e)
cosine = _cosine_sim(new_emb, ex_emb) return None
if cosine > HARMONIZATION_THRESHOLD:
similar.append({
"control_id": ex.get("control_id", ""),
"title": ex.get("title", ""),
"similarity": round(cosine, 3),
})
return similar if similar else None
async def _preload_embeddings(self, existing: list[dict]): async def _preload_embeddings(self, existing: list[dict]):
"""Pre-load embeddings for all existing controls in batches.""" """Pre-load embeddings for all existing controls in batches."""
@@ -1580,9 +1650,11 @@ Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Elementen. Fuer Aspekte ohne
if severity not in ("low", "medium", "high", "critical"): if severity not in ("low", "medium", "high", "critical"):
severity = "medium" severity = "medium"
tags = data.get("tags", []) tags = data.get("tags") or []
if isinstance(tags, str): if isinstance(tags, str):
tags = [t.strip() for t in tags.split(",")] tags = [t.strip() for t in tags.split(",")]
if not isinstance(tags, list):
tags = []
# Use LLM-provided domain if available, fallback to keyword-detected domain # Use LLM-provided domain if available, fallback to keyword-detected domain
llm_domain = data.get("domain") llm_domain = data.get("domain")

View File

@@ -349,8 +349,15 @@ Antworte NUR mit einem JSON-Array. Keine Erklärungen."""
def _build_pass0a_prompt( def _build_pass0a_prompt(
title: str, objective: str, requirements: str, title: str, objective: str, requirements: str,
test_procedure: str, source_ref: str test_procedure: str, source_ref: str,
source_original_text: str = ""
) -> str: ) -> str:
original_block = ""
if source_original_text:
original_block = f"""
ORIGINALTEXT (Gesetz/Verordnung — nutze fuer praezisere Pflichtableitung):
{source_original_text[:3000]}
"""
return f"""\ return f"""\
Analysiere das folgende Control und extrahiere alle einzelnen normativen \ Analysiere das folgende Control und extrahiere alle einzelnen normativen \
Pflichten als JSON-Array. Pflichten als JSON-Array.
@@ -361,7 +368,7 @@ Ziel: {objective}
Anforderungen: {requirements} Anforderungen: {requirements}
Prüfverfahren: {test_procedure} Prüfverfahren: {test_procedure}
Quellreferenz: {source_ref} Quellreferenz: {source_ref}
{original_block}
Antworte als JSON-Array: Antworte als JSON-Array:
[ [
{{ {{
@@ -2407,7 +2414,8 @@ class DecompositionPass:
query = """ query = """
SELECT cc.id, cc.control_id, cc.title, cc.objective, SELECT cc.id, cc.control_id, cc.title, cc.objective,
cc.requirements, cc.test_procedure, cc.requirements, cc.test_procedure,
cc.source_citation, cc.category cc.source_citation, cc.category,
cc.source_original_text
FROM canonical_controls cc FROM canonical_controls cc
WHERE cc.release_state NOT IN ('deprecated') WHERE cc.release_state NOT IN ('deprecated')
AND cc.parent_control_uuid IS NULL AND cc.parent_control_uuid IS NULL
@@ -2473,6 +2481,7 @@ class DecompositionPass:
"test_procedure": test_str, "test_procedure": test_str,
"source_ref": source_str, "source_ref": source_str,
"category": row[7] or "", "category": row[7] or "",
"source_original_text": row[8] or "",
}) })
# Process in batches # Process in batches
@@ -2507,6 +2516,7 @@ class DecompositionPass:
requirements=ctrl["requirements"], requirements=ctrl["requirements"],
test_procedure=ctrl["test_procedure"], test_procedure=ctrl["test_procedure"],
source_ref=ctrl["source_ref"], source_ref=ctrl["source_ref"],
source_original_text=ctrl.get("source_original_text", ""),
) )
llm_response = await _llm_anthropic( llm_response = await _llm_anthropic(
prompt=prompt, prompt=prompt,
@@ -2529,6 +2539,7 @@ class DecompositionPass:
requirements=ctrl["requirements"], requirements=ctrl["requirements"],
test_procedure=ctrl["test_procedure"], test_procedure=ctrl["test_procedure"],
source_ref=ctrl["source_ref"], source_ref=ctrl["source_ref"],
source_original_text=ctrl.get("source_original_text", ""),
) )
llm_response = await _llm_ollama( llm_response = await _llm_ollama(
prompt=prompt, prompt=prompt,
@@ -3008,29 +3019,36 @@ class DecompositionPass:
"obligations_kept": 0, "obligations_kept": 0,
} }
# Get all parents that have >1 validated obligation # Load ALL obligations in one query (avoids N+1 per parent)
parents = self.db.execute(text(""" all_obligs = self.db.execute(text("""
SELECT parent_control_uuid, count(*) AS cnt SELECT oc.id, oc.candidate_id, oc.obligation_text, oc.action, oc.object,
oc.parent_control_uuid
FROM obligation_candidates oc
WHERE oc.release_state = 'validated'
AND oc.merged_into_id IS NULL
AND oc.parent_control_uuid IN (
SELECT parent_control_uuid
FROM obligation_candidates FROM obligation_candidates
WHERE release_state = 'validated' WHERE release_state = 'validated'
AND merged_into_id IS NULL AND merged_into_id IS NULL
GROUP BY parent_control_uuid GROUP BY parent_control_uuid
HAVING count(*) > 1 HAVING count(*) > 1
)
ORDER BY oc.parent_control_uuid, oc.created_at
""")).fetchall() """)).fetchall()
for parent_uuid, cnt in parents: # Group by parent in Python
stats["parents_checked"] += 1 from collections import defaultdict
obligs = self.db.execute(text(""" parent_groups: dict[str, list] = defaultdict(list)
SELECT id, candidate_id, obligation_text, action, object for row in all_obligs:
FROM obligation_candidates parent_groups[str(row[5])].append(row)
WHERE parent_control_uuid = CAST(:pid AS uuid)
AND release_state = 'validated'
AND merged_into_id IS NULL
ORDER BY created_at
"""), {"pid": str(parent_uuid)}).fetchall()
merged_ids = set() merge_batch: list[dict] = []
oblig_list = list(obligs) MERGE_FLUSH_SIZE = 200
for parent_uuid, oblig_list in parent_groups.items():
stats["parents_checked"] += 1
merged_ids: set[str] = set()
for i in range(len(oblig_list)): for i in range(len(oblig_list)):
if str(oblig_list[i][0]) in merged_ids: if str(oblig_list[i][0]) in merged_ids:
@@ -3044,13 +3062,11 @@ class DecompositionPass:
obj_i = (oblig_list[i][4] or "").lower().strip() obj_i = (oblig_list[i][4] or "").lower().strip()
obj_j = (oblig_list[j][4] or "").lower().strip() obj_j = (oblig_list[j][4] or "").lower().strip()
# Check if actions are similar enough to be duplicates
if not _text_similar(action_i, action_j, threshold=0.75): if not _text_similar(action_i, action_j, threshold=0.75):
continue continue
if not _text_similar(obj_i, obj_j, threshold=0.60): if not _text_similar(obj_i, obj_j, threshold=0.60):
continue continue
# Keep the more abstract one (shorter text = less specific)
text_i = oblig_list[i][2] or "" text_i = oblig_list[i][2] or ""
text_j = oblig_list[j][2] or "" text_j = oblig_list[j][2] or ""
if _is_more_implementation_specific(text_j, text_i): if _is_more_implementation_specific(text_j, text_i):
@@ -3060,17 +3076,30 @@ class DecompositionPass:
survivor_id = str(oblig_list[j][0]) survivor_id = str(oblig_list[j][0])
merged_id = str(oblig_list[i][0]) merged_id = str(oblig_list[i][0])
merge_batch.append({"survivor": survivor_id, "merged": merged_id})
merged_ids.add(merged_id)
stats["obligations_merged"] += 1
# Flush batch periodically
if len(merge_batch) >= MERGE_FLUSH_SIZE:
for m in merge_batch:
self.db.execute(text(""" self.db.execute(text("""
UPDATE obligation_candidates UPDATE obligation_candidates
SET release_state = 'merged', SET release_state = 'merged',
merged_into_id = CAST(:survivor AS uuid) merged_into_id = CAST(:survivor AS uuid)
WHERE id = CAST(:merged AS uuid) WHERE id = CAST(:merged AS uuid)
"""), {"survivor": survivor_id, "merged": merged_id}) """), m)
self.db.commit()
merge_batch.clear()
merged_ids.add(merged_id) # Flush remainder
stats["obligations_merged"] += 1 for m in merge_batch:
self.db.execute(text("""
# Commit per parent to avoid large transactions UPDATE obligation_candidates
SET release_state = 'merged',
merged_into_id = CAST(:survivor AS uuid)
WHERE id = CAST(:merged AS uuid)
"""), m)
self.db.commit() self.db.commit()
stats["obligations_kept"] = self.db.execute(text(""" stats["obligations_kept"] = self.db.execute(text("""
@@ -3106,6 +3135,10 @@ class DecompositionPass:
AND trigger_type IS NULL AND trigger_type IS NULL
""")).fetchall() """)).fetchall()
# Classify all obligations first, then batch-update
BATCH_SIZE = 500
pending_updates: list[dict] = []
for row in obligs: for row in obligs:
oc_id = str(row[0]) oc_id = str(row[0])
obl_text = row[1] or "" obl_text = row[1] or ""
@@ -3116,22 +3149,42 @@ class DecompositionPass:
trigger = _classify_trigger_type(obl_text, condition) trigger = _classify_trigger_type(obl_text, condition)
impl = _is_implementation_specific_text(obl_text, action, obj) impl = _is_implementation_specific_text(obl_text, action, obj)
self.db.execute(text(""" pending_updates.append({"oid": oc_id, "trigger": trigger, "impl": impl})
UPDATE obligation_candidates
SET trigger_type = :trigger,
is_implementation_specific = :impl
WHERE id = CAST(:oid AS uuid)
"""), {"trigger": trigger, "impl": impl, "oid": oc_id})
stats["enriched"] += 1 stats["enriched"] += 1
stats[f"trigger_{trigger}"] += 1 stats[f"trigger_{trigger}"] = stats.get(f"trigger_{trigger}", 0) + 1
if impl: if impl:
stats["implementation_specific"] += 1 stats["implementation_specific"] += 1
# Flush batch
if len(pending_updates) >= BATCH_SIZE:
self._flush_enrich_batch(pending_updates)
pending_updates.clear()
# Flush remainder
if pending_updates:
self._flush_enrich_batch(pending_updates)
self.db.commit() self.db.commit()
logger.info("Enrich pass: %s", stats) logger.info("Enrich pass: %s", stats)
return stats return stats
def _flush_enrich_batch(self, updates: list[dict]):
"""Batch-UPDATE obligation_candidates for enrich pass."""
# Group by (trigger, impl) to minimize UPDATE statements
from collections import defaultdict
groups: dict[tuple, list[str]] = defaultdict(list)
for u in updates:
groups[(u["trigger"], u["impl"])].append(u["oid"])
for (trigger, impl), ids in groups.items():
# Use ANY(ARRAY[...]) for batch WHERE clause
self.db.execute(text("""
UPDATE obligation_candidates
SET trigger_type = :trigger,
is_implementation_specific = :impl
WHERE id = ANY(CAST(:ids AS uuid[]))
"""), {"trigger": trigger, "impl": impl, "ids": ids})
# ------------------------------------------------------------------- # -------------------------------------------------------------------
# Decomposition Status # Decomposition Status
# ------------------------------------------------------------------- # -------------------------------------------------------------------
@@ -3365,7 +3418,8 @@ class DecompositionPass:
query = """ query = """
SELECT cc.id, cc.control_id, cc.title, cc.objective, SELECT cc.id, cc.control_id, cc.title, cc.objective,
cc.requirements, cc.test_procedure, cc.requirements, cc.test_procedure,
cc.source_citation, cc.category cc.source_citation, cc.category,
cc.source_original_text
FROM canonical_controls cc FROM canonical_controls cc
WHERE cc.release_state NOT IN ('deprecated') WHERE cc.release_state NOT IN ('deprecated')
AND cc.parent_control_uuid IS NULL AND cc.parent_control_uuid IS NULL
@@ -3414,6 +3468,7 @@ class DecompositionPass:
"test_procedure": _format_field(row[5] or ""), "test_procedure": _format_field(row[5] or ""),
"source_ref": _format_citation(row[6] or ""), "source_ref": _format_citation(row[6] or ""),
"category": row[7] or "", "category": row[7] or "",
"source_original_text": row[8] or "",
}) })
if not prepared: if not prepared: