From f89ce466312b7d58fe3dee5a6be5fd28eafdaa42 Mon Sep 17 00:00:00 2001 From: Benjamin Admin Date: Sat, 11 Apr 2026 14:09:32 +0200 Subject: [PATCH] =?UTF-8?q?fix:=20Pipeline-Skalierung=20=E2=80=94=206=20Op?= =?UTF-8?q?timierungen=20f=C3=BCr=2080k+=20Controls?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. control_generator: GeneratorResult.status Default "completed" → "running" (Bug) 2. control_generator: Anthropic API mit Phase-Timeouts + Retry bei Disconnect 3. control_generator: regulation_exclude Filter + Harmonization via Qdrant statt In-Memory 4. decomposition_pass: Enrich Pass Batch-UPDATEs (400k → ~400 DB-Calls) 5. decomposition_pass: Merge Pass single Query statt N+1 6. batch_dedup_runner: Cross-Group Dedup parallelisiert (asyncio.gather) 7. canonical_control_routes: Framework Controls API Pagination (limit/offset) 8. DB-Indizes: idx_oc_parent_release, idx_oc_trigger_null, idx_cc_framework Co-Authored-By: Claude Opus 4.6 (1M context) --- .../api/canonical_control_routes.py | 8 + .../api/control_generator_routes.py | 2 + .../services/batch_dedup_runner.py | 123 +++++++------ .../services/control_generator.py | 164 +++++++++++++----- .../services/decomposition_pass.py | 135 +++++++++----- 5 files changed, 291 insertions(+), 141 deletions(-) diff --git a/control-pipeline/api/canonical_control_routes.py b/control-pipeline/api/canonical_control_routes.py index 525de9d..58426f5 100644 --- a/control-pipeline/api/canonical_control_routes.py +++ b/control-pipeline/api/canonical_control_routes.py @@ -274,6 +274,8 @@ async def list_framework_controls( verification_method: Optional[str] = Query(None), category: Optional[str] = Query(None), target_audience: Optional[str] = Query(None), + limit: Optional[int] = Query(None, ge=1, le=5000), + offset: Optional[int] = Query(None, ge=0), ): """List controls belonging to a framework.""" with SessionLocal() as db: @@ -309,6 +311,12 @@ async def list_framework_controls( params["ta"] = json.dumps([target_audience]) query += " ORDER BY control_id" + if limit is not None: + query += " LIMIT :lim" + params["lim"] = limit + if offset is not None: + query += " OFFSET :off" + params["off"] = offset rows = db.execute(text(query), params).fetchall() return [_control_row(r) for r in rows] diff --git a/control-pipeline/api/control_generator_routes.py b/control-pipeline/api/control_generator_routes.py index efbd311..0d95fcf 100644 --- a/control-pipeline/api/control_generator_routes.py +++ b/control-pipeline/api/control_generator_routes.py @@ -55,6 +55,7 @@ class GenerateRequest(BaseModel): skip_web_search: bool = False dry_run: bool = False regulation_filter: Optional[List[str]] = None # Only process these regulation_code prefixes + regulation_exclude: Optional[List[str]] = None # Skip these regulation_code prefixes skip_prefilter: bool = False # Skip local LLM pre-filter, send all chunks to API @@ -148,6 +149,7 @@ async def start_generation(req: GenerateRequest): skip_web_search=req.skip_web_search, dry_run=req.dry_run, regulation_filter=req.regulation_filter, + regulation_exclude=req.regulation_exclude, skip_prefilter=req.skip_prefilter, ) diff --git a/control-pipeline/services/batch_dedup_runner.py b/control-pipeline/services/batch_dedup_runner.py index fa7b18b..715cf27 100644 --- a/control-pipeline/services/batch_dedup_runner.py +++ b/control-pipeline/services/batch_dedup_runner.py @@ -17,6 +17,7 @@ Usage: stats = await runner.run(hint_filter="implement:multi_factor_auth:none") """ +import asyncio import json import logging import time @@ -342,80 +343,92 @@ class BatchDedupRunner: cross_linked = 0 cross_review = 0 - for i, r in enumerate(rows): - uuid = r[0] + # Process in parallel batches for embedding + Qdrant search + PARALLEL_BATCH = 10 + + async def _embed_and_search(r): + """Embed one control and search Qdrant — safe for asyncio.gather.""" hint = r[3] or "" parts = hint.split(":", 2) action = parts[0] if len(parts) > 0 else "" obj = parts[1] if len(parts) > 1 else "" - canonical = canonicalize_text(action, obj, r[2]) embedding = await get_embedding(canonical) if not embedding: - continue - + return None results = await qdrant_search_cross_regulation( embedding, top_k=5, collection=self.collection, ) - if not results: - continue + return (r, results) - # Find best match from a DIFFERENT hint group - for match in results: - match_score = match.get("score", 0.0) - match_payload = match.get("payload", {}) - match_uuid = match_payload.get("control_uuid", "") + for batch_start in range(0, len(rows), PARALLEL_BATCH): + batch = rows[batch_start:batch_start + PARALLEL_BATCH] + tasks = [_embed_and_search(r) for r in batch] + results_batch = await asyncio.gather(*tasks, return_exceptions=True) - # Skip self-match - if match_uuid == uuid: + for res in results_batch: + if res is None or isinstance(res, Exception): + if isinstance(res, Exception): + logger.error("BatchDedup embed/search error: %s", res) + self.stats["errors"] += 1 continue - # Must be a different hint group (otherwise already handled in Phase 1) - match_action = match_payload.get("action_normalized", "") - match_object = match_payload.get("object_normalized", "") - # Simple check: different control UUID is enough - if match_score > LINK_THRESHOLD: - # Mark the worse one as duplicate - try: - self.db.execute(text(""" - UPDATE canonical_controls - SET release_state = 'duplicate', merged_into_uuid = CAST(:master AS uuid) - WHERE id = CAST(:dup AS uuid) - AND release_state != 'duplicate' - """), {"master": match_uuid, "dup": uuid}) + r, results = res + ctrl_uuid = r[0] + hint = r[3] or "" - self.db.execute(text(""" - INSERT INTO control_parent_links - (control_uuid, parent_control_uuid, link_type, confidence) - VALUES (CAST(:cu AS uuid), CAST(:pu AS uuid), 'cross_regulation', :conf) - ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING - """), {"cu": match_uuid, "pu": uuid, "conf": match_score}) + if not results: + continue - # Transfer parent links - transferred = self._transfer_parent_links(match_uuid, uuid) - self.stats["parent_links_transferred"] += transferred + for match in results: + match_score = match.get("score", 0.0) + match_payload = match.get("payload", {}) + match_uuid = match_payload.get("control_uuid", "") - self.db.commit() - cross_linked += 1 - except Exception as e: - logger.error("BatchDedup cross-group link error %s→%s: %s", - uuid, match_uuid, e) - self.db.rollback() - self.stats["errors"] += 1 - break # Only one cross-link per control - elif match_score > REVIEW_THRESHOLD: - self._write_review( - {"control_id": r[1], "title": r[2], "objective": "", - "merge_group_hint": hint, "pattern_id": None}, - match_payload, match_score, - ) - cross_review += 1 - break + if match_uuid == ctrl_uuid: + continue - self._progress_count = i + 1 - if (i + 1) % 500 == 0: + if match_score > LINK_THRESHOLD: + try: + self.db.execute(text(""" + UPDATE canonical_controls + SET release_state = 'duplicate', merged_into_uuid = CAST(:master AS uuid) + WHERE id = CAST(:dup AS uuid) + AND release_state != 'duplicate' + """), {"master": match_uuid, "dup": ctrl_uuid}) + + self.db.execute(text(""" + INSERT INTO control_parent_links + (control_uuid, parent_control_uuid, link_type, confidence) + VALUES (CAST(:cu AS uuid), CAST(:pu AS uuid), 'cross_regulation', :conf) + ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING + """), {"cu": match_uuid, "pu": ctrl_uuid, "conf": match_score}) + + transferred = self._transfer_parent_links(match_uuid, ctrl_uuid) + self.stats["parent_links_transferred"] += transferred + + self.db.commit() + cross_linked += 1 + except Exception as e: + logger.error("BatchDedup cross-group link error %s→%s: %s", + ctrl_uuid, match_uuid, e) + self.db.rollback() + self.stats["errors"] += 1 + break + elif match_score > REVIEW_THRESHOLD: + self._write_review( + {"control_id": r[1], "title": r[2], "objective": "", + "merge_group_hint": hint, "pattern_id": None}, + match_payload, match_score, + ) + cross_review += 1 + break + + processed = min(batch_start + PARALLEL_BATCH, len(rows)) + self._progress_count = processed + if processed % 500 < PARALLEL_BATCH: logger.info("BatchDedup Cross-group: %d/%d checked, %d linked, %d review", - i + 1, len(rows), cross_linked, cross_review) + processed, len(rows), cross_linked, cross_review) self.stats["cross_group_linked"] = cross_linked self.stats["cross_group_review"] = cross_review diff --git a/control-pipeline/services/control_generator.py b/control-pipeline/services/control_generator.py index b582cb3..a05dacc 100644 --- a/control-pipeline/services/control_generator.py +++ b/control-pipeline/services/control_generator.py @@ -92,6 +92,7 @@ REGULATION_LICENSE_MAP: dict[str, dict] = { "eucsa": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "EU Cybersecurity Act"}, "dataact": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Data Act"}, "dora": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Digital Operational Resilience Act"}, + "eu_2017_745": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Medizinprodukteverordnung (EU) 2017/745 (MDR)"}, "ehds": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "European Health Data Space"}, "gpsr": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Allgemeine Produktsicherheitsverordnung"}, "eu_2023_988": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Allgemeine Produktsicherheitsverordnung (GPSR)"}, @@ -132,6 +133,39 @@ REGULATION_LICENSE_MAP: dict[str, dict] = { "ao": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Abgabenordnung (AO)"}, "ao_komplett": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Abgabenordnung (AO)"}, "battdg": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Batteriegesetz"}, + # New German Laws (2026-04-10 ingestion) + "de_bsig_2025": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "BSI-Gesetz (BSIG 2025, NIS2-Umsetzung)"}, + "de_tdddg": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "TDDDG"}, + "de_gwg_2017": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Geldwaeschegesetz (GwG)"}, + "de_agg": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Allgemeines Gleichbehandlungsgesetz (AGG)"}, + "de_hinschg": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Hinweisgeberschutzgesetz (HinSchG)"}, + "de_lksg": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Lieferkettensorgfaltspflichtengesetz (LkSG)"}, + "de_kritisdachg": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "KRITIS-Dachgesetz (KRITISDachG)"}, + "de_bsi_kritisv": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "BSI-Kritisverordnung (BSI-KritisV)"}, + # DSK/BfDI Guidance (amtliche Orientierungshilfen) + "dsk_oh_ki_datenschutz": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH KI und Datenschutz"}, + "dsk_oh_ki_systeme_tom": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH TOM bei KI-Systemen"}, + "dsk_oh_ki_rag": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Generative KI mit RAG"}, + "dsk_oh_telemedien": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Telemedien"}, + "dsk_oh_digitale_dienste": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Digitale Dienste"}, + "dsk_oh_direktwerbung": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Direktwerbung"}, + "dsk_oh_videokonferenz": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Videokonferenzsysteme"}, + "dsk_oh_videoueberwachung": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Videoueberwachung"}, + "dsk_oh_email_verschluesselung": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH E-Mail-Verschluesselung"}, + "dsk_oh_whistleblowing": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Whistleblowing"}, + "dsk_oh_onlinedienste_zugang": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Onlinedienste Zugang"}, + "dsk_oh_datenuebermittlung_drittlaender": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Datenuebermittlung Drittlaender"}, + "dsk_sdm_methode": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK Standard-Datenschutzmodell SDM V3.1a"}, + "dsk_ah_eu_us_dpf": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK AH EU-US Data Privacy Framework"}, + "dsk_ah_bussgeldkonzept": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK Bussgeldkonzept"}, + "dsk_ah_dsfa_mussliste": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK DSFA Muss-Liste"}, + "dsk_ah_verzeichnis_vvt": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK Hinweise VVT"}, + "dsk_ah_zertifizierung": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK Zertifizierungskriterien"}, + "dsk_beschluss_ms365": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK Festlegung Microsoft 365"}, + "dsk_pos_ki_verordnung": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK Positionspapier KI-Verordnung"}, + "dsk_entschl_beschaeftigtendatenschutz": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK Entschliessung Beschaeftigtendatenschutz"}, + "bfdi_handreichung_ki_behoerden": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "BfDI Handreichung KI in Behoerden"}, + "bfdi_handreichung_ki_sicherheit": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "BfDI Handreichung KI Sicherheitsbehoerden"}, # Austrian Laws "at_dsg": {"license": "AT_LAW", "rule": 1, "source_type": "law", "name": "Österreichisches Datenschutzgesetz (DSG)"}, "at_abgb": {"license": "AT_LAW", "rule": 1, "source_type": "law", "name": "AT ABGB"}, @@ -459,6 +493,7 @@ class GeneratorConfig(BaseModel): dry_run: bool = False existing_job_id: Optional[str] = None # If set, reuse this job instead of creating a new one regulation_filter: Optional[List[str]] = None # Only process chunks matching these regulation_code prefixes + regulation_exclude: Optional[List[str]] = None # Skip chunks matching these regulation_code prefixes skip_prefilter: bool = False # If True, skip local LLM pre-filter (send all chunks to API) @@ -501,7 +536,7 @@ class GeneratedControl: @dataclass class GeneratorResult: job_id: str = "" - status: str = "completed" + status: str = "running" total_chunks_scanned: int = 0 controls_generated: int = 0 controls_verified: int = 0 @@ -583,8 +618,8 @@ Antworte NUR mit JSON: {{"relevant": true/false, "reason": "kurze Begründung"}} return True, f"error: {e}" -async def _llm_anthropic(prompt: str, system_prompt: Optional[str] = None) -> str: - """Call Anthropic Messages API.""" +async def _llm_anthropic(prompt: str, system_prompt: Optional[str] = None, max_retries: int = 2) -> str: + """Call Anthropic Messages API with retry on timeout.""" headers = { "x-api-key": ANTHROPIC_API_KEY, "anthropic-version": "2023-06-01", @@ -598,24 +633,36 @@ async def _llm_anthropic(prompt: str, system_prompt: Optional[str] = None) -> st if system_prompt: payload["system"] = system_prompt - try: - async with httpx.AsyncClient(timeout=LLM_TIMEOUT) as client: - resp = await client.post( - "https://api.anthropic.com/v1/messages", - headers=headers, - json=payload, - ) - if resp.status_code != 200: - logger.error("Anthropic API %d: %s", resp.status_code, resp.text[:300]) + # Use explicit per-phase timeouts to prevent indefinite hangs + timeout = httpx.Timeout(connect=30.0, read=LLM_TIMEOUT, write=30.0, pool=30.0) + + for attempt in range(1 + max_retries): + try: + async with httpx.AsyncClient(timeout=timeout) as client: + resp = await client.post( + "https://api.anthropic.com/v1/messages", + headers=headers, + json=payload, + ) + if resp.status_code != 200: + logger.error("Anthropic API %d: %s", resp.status_code, resp.text[:300]) + return "" + data = resp.json() + content = data.get("content", []) + if content and isinstance(content, list): + return content[0].get("text", "") return "" - data = resp.json() - content = data.get("content", []) - if content and isinstance(content, list): - return content[0].get("text", "") + except (httpx.TimeoutException, httpx.RemoteProtocolError) as e: + if attempt < max_retries: + logger.warning("Anthropic request attempt %d/%d failed: %s — retrying...", attempt + 1, max_retries + 1, e) + import asyncio + await asyncio.sleep(2 ** attempt) + continue + logger.error("Anthropic request failed after %d attempts: %s (type: %s)", max_retries + 1, e, type(e).__name__) + return "" + except Exception as e: + logger.error("Anthropic request failed: %s (type: %s)", e, type(e).__name__) return "" - except Exception as e: - logger.error("Anthropic request failed: %s (type: %s)", e, type(e).__name__) - return "" async def _llm_ollama(prompt: str, system_prompt: Optional[str] = None) -> str: @@ -941,6 +988,12 @@ class ControlGeneratorPipeline: if not any(code_lower.startswith(f.lower()) for f in config.regulation_filter): continue + # Exclude specific regulation_codes + if config.regulation_exclude and reg_code: + code_lower = reg_code.lower() + if any(code_lower.startswith(f.lower()) for f in config.regulation_exclude): + continue + reg_name = (payload.get("regulation_name_de", "") or payload.get("regulation_name", "") or payload.get("source_name", "") @@ -1423,20 +1476,31 @@ Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Elementen. Fuer Aspekte ohne if structure_items: s_chunks = [c for c, _ in structure_items] s_lics = [l for _, l in structure_items] - s_controls = await self._structure_batch(s_chunks, s_lics) + try: + s_controls = await self._structure_batch(s_chunks, s_lics) + except Exception as e: + import traceback + logger.warning("Batch structure failed: %s — creating fallback controls\n%s", e, traceback.format_exc()) + s_controls = [self._fallback_control(c) for c in s_chunks] for (chunk, _), ctrl in zip(structure_items, s_controls): orig_idx = next(i for i, (c, _) in enumerate(batch_items) if c is chunk) all_controls[orig_idx] = ctrl if reform_items: r_chunks = [c for c, _ in reform_items] - r_controls = await self._reformulate_batch(r_chunks, config) + try: + r_controls = await self._reformulate_batch(r_chunks, config) + except Exception as e: + logger.warning("Batch reform failed: %s — creating fallback controls", e) + r_controls = [self._fallback_control(c) for c in r_chunks] for (chunk, _), ctrl in zip(reform_items, r_controls): orig_idx = next(i for i, (c, _) in enumerate(batch_items) if c is chunk) if ctrl: # Too-Close-Check for Rule 3 similarity = await check_similarity(chunk.text, f"{ctrl.objective} {ctrl.rationale}") - if similarity.status == "FAIL": + if similarity is None: + logger.warning("Similarity check returned None — skipping too-close check") + elif similarity.status == "FAIL": ctrl.release_state = "too_close" ctrl.generation_metadata["similarity_status"] = "FAIL" ctrl.generation_metadata["similarity_scores"] = { @@ -1502,36 +1566,42 @@ Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Elementen. Fuer Aspekte ohne # ── Stage 4: Harmonization ───────────────────────────────────────── async def _check_harmonization(self, new_control: GeneratedControl) -> Optional[list]: - """Check if a new control duplicates existing ones via embedding similarity.""" - existing = self._load_existing_controls() - if not existing: - return None - - # Pre-load all existing embeddings in batch (once per pipeline run) - if not self._existing_embeddings: - await self._preload_embeddings(existing) + """Check if a new control duplicates existing ones via Qdrant vector search. + Uses the atomic_controls_dedup collection for fast nearest-neighbor lookup + instead of pre-loading all embeddings into memory. + """ new_text = f"{new_control.title} {new_control.objective}" new_emb = await _get_embedding(new_text) if not new_emb: return None - similar = [] - for ex in existing: - ex_key = ex.get("control_id", "") - ex_emb = self._existing_embeddings.get(ex_key, []) - if not ex_emb: - continue + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.post( + f"{QDRANT_URL}/collections/atomic_controls_dedup/points/search", + json={ + "vector": new_emb, + "limit": 5, + "score_threshold": HARMONIZATION_THRESHOLD, + "with_payload": {"include": ["control_id", "title"]}, + }, + ) + if resp.status_code == 200: + results = resp.json().get("result", []) + if results: + return [ + { + "control_id": r["payload"].get("control_id", ""), + "title": r["payload"].get("title", ""), + "similarity": round(r["score"], 3), + } + for r in results + ] + except Exception as e: + logger.warning("Qdrant dedup search failed: %s — skipping harmonization", e) - cosine = _cosine_sim(new_emb, ex_emb) - if cosine > HARMONIZATION_THRESHOLD: - similar.append({ - "control_id": ex.get("control_id", ""), - "title": ex.get("title", ""), - "similarity": round(cosine, 3), - }) - - return similar if similar else None + return None async def _preload_embeddings(self, existing: list[dict]): """Pre-load embeddings for all existing controls in batches.""" @@ -1580,9 +1650,11 @@ Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Elementen. Fuer Aspekte ohne if severity not in ("low", "medium", "high", "critical"): severity = "medium" - tags = data.get("tags", []) + tags = data.get("tags") or [] if isinstance(tags, str): tags = [t.strip() for t in tags.split(",")] + if not isinstance(tags, list): + tags = [] # Use LLM-provided domain if available, fallback to keyword-detected domain llm_domain = data.get("domain") diff --git a/control-pipeline/services/decomposition_pass.py b/control-pipeline/services/decomposition_pass.py index 7cad3e4..9080d04 100644 --- a/control-pipeline/services/decomposition_pass.py +++ b/control-pipeline/services/decomposition_pass.py @@ -349,8 +349,15 @@ Antworte NUR mit einem JSON-Array. Keine Erklärungen.""" def _build_pass0a_prompt( title: str, objective: str, requirements: str, - test_procedure: str, source_ref: str + test_procedure: str, source_ref: str, + source_original_text: str = "" ) -> str: + original_block = "" + if source_original_text: + original_block = f""" +ORIGINALTEXT (Gesetz/Verordnung — nutze fuer praezisere Pflichtableitung): +{source_original_text[:3000]} +""" return f"""\ Analysiere das folgende Control und extrahiere alle einzelnen normativen \ Pflichten als JSON-Array. @@ -361,7 +368,7 @@ Ziel: {objective} Anforderungen: {requirements} Prüfverfahren: {test_procedure} Quellreferenz: {source_ref} - +{original_block} Antworte als JSON-Array: [ {{ @@ -2407,7 +2414,8 @@ class DecompositionPass: query = """ SELECT cc.id, cc.control_id, cc.title, cc.objective, cc.requirements, cc.test_procedure, - cc.source_citation, cc.category + cc.source_citation, cc.category, + cc.source_original_text FROM canonical_controls cc WHERE cc.release_state NOT IN ('deprecated') AND cc.parent_control_uuid IS NULL @@ -2473,6 +2481,7 @@ class DecompositionPass: "test_procedure": test_str, "source_ref": source_str, "category": row[7] or "", + "source_original_text": row[8] or "", }) # Process in batches @@ -2507,6 +2516,7 @@ class DecompositionPass: requirements=ctrl["requirements"], test_procedure=ctrl["test_procedure"], source_ref=ctrl["source_ref"], + source_original_text=ctrl.get("source_original_text", ""), ) llm_response = await _llm_anthropic( prompt=prompt, @@ -2529,6 +2539,7 @@ class DecompositionPass: requirements=ctrl["requirements"], test_procedure=ctrl["test_procedure"], source_ref=ctrl["source_ref"], + source_original_text=ctrl.get("source_original_text", ""), ) llm_response = await _llm_ollama( prompt=prompt, @@ -3008,29 +3019,36 @@ class DecompositionPass: "obligations_kept": 0, } - # Get all parents that have >1 validated obligation - parents = self.db.execute(text(""" - SELECT parent_control_uuid, count(*) AS cnt - FROM obligation_candidates - WHERE release_state = 'validated' - AND merged_into_id IS NULL - GROUP BY parent_control_uuid - HAVING count(*) > 1 + # Load ALL obligations in one query (avoids N+1 per parent) + all_obligs = self.db.execute(text(""" + SELECT oc.id, oc.candidate_id, oc.obligation_text, oc.action, oc.object, + oc.parent_control_uuid + FROM obligation_candidates oc + WHERE oc.release_state = 'validated' + AND oc.merged_into_id IS NULL + AND oc.parent_control_uuid IN ( + SELECT parent_control_uuid + FROM obligation_candidates + WHERE release_state = 'validated' + AND merged_into_id IS NULL + GROUP BY parent_control_uuid + HAVING count(*) > 1 + ) + ORDER BY oc.parent_control_uuid, oc.created_at """)).fetchall() - for parent_uuid, cnt in parents: - stats["parents_checked"] += 1 - obligs = self.db.execute(text(""" - SELECT id, candidate_id, obligation_text, action, object - FROM obligation_candidates - WHERE parent_control_uuid = CAST(:pid AS uuid) - AND release_state = 'validated' - AND merged_into_id IS NULL - ORDER BY created_at - """), {"pid": str(parent_uuid)}).fetchall() + # Group by parent in Python + from collections import defaultdict + parent_groups: dict[str, list] = defaultdict(list) + for row in all_obligs: + parent_groups[str(row[5])].append(row) - merged_ids = set() - oblig_list = list(obligs) + merge_batch: list[dict] = [] + MERGE_FLUSH_SIZE = 200 + + for parent_uuid, oblig_list in parent_groups.items(): + stats["parents_checked"] += 1 + merged_ids: set[str] = set() for i in range(len(oblig_list)): if str(oblig_list[i][0]) in merged_ids: @@ -3044,13 +3062,11 @@ class DecompositionPass: obj_i = (oblig_list[i][4] or "").lower().strip() obj_j = (oblig_list[j][4] or "").lower().strip() - # Check if actions are similar enough to be duplicates if not _text_similar(action_i, action_j, threshold=0.75): continue if not _text_similar(obj_i, obj_j, threshold=0.60): continue - # Keep the more abstract one (shorter text = less specific) text_i = oblig_list[i][2] or "" text_j = oblig_list[j][2] or "" if _is_more_implementation_specific(text_j, text_i): @@ -3060,18 +3076,31 @@ class DecompositionPass: survivor_id = str(oblig_list[j][0]) merged_id = str(oblig_list[i][0]) + merge_batch.append({"survivor": survivor_id, "merged": merged_id}) + merged_ids.add(merged_id) + stats["obligations_merged"] += 1 + + # Flush batch periodically + if len(merge_batch) >= MERGE_FLUSH_SIZE: + for m in merge_batch: self.db.execute(text(""" UPDATE obligation_candidates SET release_state = 'merged', merged_into_id = CAST(:survivor AS uuid) WHERE id = CAST(:merged AS uuid) - """), {"survivor": survivor_id, "merged": merged_id}) + """), m) + self.db.commit() + merge_batch.clear() - merged_ids.add(merged_id) - stats["obligations_merged"] += 1 - - # Commit per parent to avoid large transactions - self.db.commit() + # Flush remainder + for m in merge_batch: + self.db.execute(text(""" + UPDATE obligation_candidates + SET release_state = 'merged', + merged_into_id = CAST(:survivor AS uuid) + WHERE id = CAST(:merged AS uuid) + """), m) + self.db.commit() stats["obligations_kept"] = self.db.execute(text(""" SELECT count(*) FROM obligation_candidates @@ -3106,6 +3135,10 @@ class DecompositionPass: AND trigger_type IS NULL """)).fetchall() + # Classify all obligations first, then batch-update + BATCH_SIZE = 500 + pending_updates: list[dict] = [] + for row in obligs: oc_id = str(row[0]) obl_text = row[1] or "" @@ -3116,22 +3149,42 @@ class DecompositionPass: trigger = _classify_trigger_type(obl_text, condition) impl = _is_implementation_specific_text(obl_text, action, obj) - self.db.execute(text(""" - UPDATE obligation_candidates - SET trigger_type = :trigger, - is_implementation_specific = :impl - WHERE id = CAST(:oid AS uuid) - """), {"trigger": trigger, "impl": impl, "oid": oc_id}) - + pending_updates.append({"oid": oc_id, "trigger": trigger, "impl": impl}) stats["enriched"] += 1 - stats[f"trigger_{trigger}"] += 1 + stats[f"trigger_{trigger}"] = stats.get(f"trigger_{trigger}", 0) + 1 if impl: stats["implementation_specific"] += 1 + # Flush batch + if len(pending_updates) >= BATCH_SIZE: + self._flush_enrich_batch(pending_updates) + pending_updates.clear() + + # Flush remainder + if pending_updates: + self._flush_enrich_batch(pending_updates) + self.db.commit() logger.info("Enrich pass: %s", stats) return stats + def _flush_enrich_batch(self, updates: list[dict]): + """Batch-UPDATE obligation_candidates for enrich pass.""" + # Group by (trigger, impl) to minimize UPDATE statements + from collections import defaultdict + groups: dict[tuple, list[str]] = defaultdict(list) + for u in updates: + groups[(u["trigger"], u["impl"])].append(u["oid"]) + + for (trigger, impl), ids in groups.items(): + # Use ANY(ARRAY[...]) for batch WHERE clause + self.db.execute(text(""" + UPDATE obligation_candidates + SET trigger_type = :trigger, + is_implementation_specific = :impl + WHERE id = ANY(CAST(:ids AS uuid[])) + """), {"trigger": trigger, "impl": impl, "ids": ids}) + # ------------------------------------------------------------------- # Decomposition Status # ------------------------------------------------------------------- @@ -3365,7 +3418,8 @@ class DecompositionPass: query = """ SELECT cc.id, cc.control_id, cc.title, cc.objective, cc.requirements, cc.test_procedure, - cc.source_citation, cc.category + cc.source_citation, cc.category, + cc.source_original_text FROM canonical_controls cc WHERE cc.release_state NOT IN ('deprecated') AND cc.parent_control_uuid IS NULL @@ -3414,6 +3468,7 @@ class DecompositionPass: "test_procedure": _format_field(row[5] or ""), "source_ref": _format_citation(row[6] or ""), "category": row[7] or "", + "source_original_text": row[8] or "", }) if not prepared: