fix: pipeline scaling - 8 optimizations for 80k+ controls

1. control_generator: GeneratorResult.status default "completed" → "running" (bug fix)
2. control_generator: Anthropic API with per-phase timeouts + retry on disconnect
3. control_generator: regulation_exclude filter + harmonization via Qdrant instead of in-memory
4. decomposition_pass: enrich pass batch UPDATEs (400k → ~400 DB calls)
5. decomposition_pass: merge pass single query instead of N+1
6. batch_dedup_runner: cross-group dedup parallelized (asyncio.gather)
7. canonical_control_routes: framework controls API pagination (limit/offset; sketched below)
8. DB indexes: idx_oc_parent_release, idx_oc_trigger_null, idx_cc_framework (sketched below)
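
Items 7 and 8 touch the two changed files whose diffs are not rendered below (only three of the five files are shown). A minimal sketch of what they describe; the index columns, the framework column name, and the pagination query shape are assumptions inferred from the queries in this commit, not confirmed by the diff:

    from sqlalchemy import create_engine, text

    engine = create_engine("postgresql://localhost/grc")  # placeholder DSN

    with engine.begin() as conn:
        # Merge/dedup passes filter on parent_control_uuid + release_state
        conn.execute(text(
            "CREATE INDEX IF NOT EXISTS idx_oc_parent_release "
            "ON obligation_candidates (parent_control_uuid, release_state)"))
        # Enrich pass scans only rows not yet classified (partial index)
        conn.execute(text(
            "CREATE INDEX IF NOT EXISTS idx_oc_trigger_null "
            "ON obligation_candidates (id) WHERE trigger_type IS NULL"))
        # Framework controls listing pages per framework (column name assumed)
        conn.execute(text(
            "CREATE INDEX IF NOT EXISTS idx_cc_framework "
            "ON canonical_controls (framework_id)"))

        # Item 7: limit/offset pagination as the API route presumably issues it
        rows = conn.execute(
            text("SELECT id, control_id, title FROM canonical_controls "
                 "WHERE framework_id = :fw ORDER BY control_id "
                 "LIMIT :limit OFFSET :offset"),
            {"fw": "iso27001", "limit": 100, "offset": 0},
        ).fetchall()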

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Benjamin Admin
2026-04-11 14:09:32 +02:00
parent fc71117bf2
commit f89ce46631
5 changed files with 291 additions and 141 deletions

batch_dedup_runner.py

@@ -17,6 +17,7 @@ Usage:
     stats = await runner.run(hint_filter="implement:multi_factor_auth:none")
 """
+import asyncio
 import json
 import logging
 import time
@@ -342,80 +343,92 @@ class BatchDedupRunner:
         cross_linked = 0
         cross_review = 0
-        for i, r in enumerate(rows):
-            uuid = r[0]
+        # Process in parallel batches for embedding + Qdrant search
+        PARALLEL_BATCH = 10
+
+        async def _embed_and_search(r):
+            """Embed one control and search Qdrant — safe for asyncio.gather."""
             hint = r[3] or ""
             parts = hint.split(":", 2)
             action = parts[0] if len(parts) > 0 else ""
             obj = parts[1] if len(parts) > 1 else ""
             canonical = canonicalize_text(action, obj, r[2])
             embedding = await get_embedding(canonical)
             if not embedding:
-                continue
+                return None
             results = await qdrant_search_cross_regulation(
                 embedding, top_k=5, collection=self.collection,
             )
-            if not results:
-                continue
-            # Find best match from a DIFFERENT hint group
-            for match in results:
-                match_score = match.get("score", 0.0)
-                match_payload = match.get("payload", {})
-                match_uuid = match_payload.get("control_uuid", "")
-                # Skip self-match
-                if match_uuid == uuid:
-                    continue
-                # Must be a different hint group (otherwise already handled in Phase 1)
-                match_action = match_payload.get("action_normalized", "")
-                match_object = match_payload.get("object_normalized", "")
-                # Simple check: different control UUID is enough
-                if match_score > LINK_THRESHOLD:
-                    # Mark the worse one as duplicate
-                    try:
-                        self.db.execute(text("""
-                            UPDATE canonical_controls
-                            SET release_state = 'duplicate', merged_into_uuid = CAST(:master AS uuid)
-                            WHERE id = CAST(:dup AS uuid)
-                              AND release_state != 'duplicate'
-                        """), {"master": match_uuid, "dup": uuid})
-                        self.db.execute(text("""
-                            INSERT INTO control_parent_links
-                                (control_uuid, parent_control_uuid, link_type, confidence)
-                            VALUES (CAST(:cu AS uuid), CAST(:pu AS uuid), 'cross_regulation', :conf)
-                            ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING
-                        """), {"cu": match_uuid, "pu": uuid, "conf": match_score})
-                        # Transfer parent links
-                        transferred = self._transfer_parent_links(match_uuid, uuid)
-                        self.stats["parent_links_transferred"] += transferred
-                        self.db.commit()
-                        cross_linked += 1
-                    except Exception as e:
-                        logger.error("BatchDedup cross-group link error %s→%s: %s",
-                                     uuid, match_uuid, e)
-                        self.db.rollback()
-                        self.stats["errors"] += 1
-                    break  # Only one cross-link per control
-                elif match_score > REVIEW_THRESHOLD:
-                    self._write_review(
-                        {"control_id": r[1], "title": r[2], "objective": "",
-                         "merge_group_hint": hint, "pattern_id": None},
-                        match_payload, match_score,
-                    )
-                    cross_review += 1
-                    break
-            self._progress_count = i + 1
-            if (i + 1) % 500 == 0:
-                logger.info("BatchDedup Cross-group: %d/%d checked, %d linked, %d review",
-                            i + 1, len(rows), cross_linked, cross_review)
+            return (r, results)
+
+        for batch_start in range(0, len(rows), PARALLEL_BATCH):
+            batch = rows[batch_start:batch_start + PARALLEL_BATCH]
+            tasks = [_embed_and_search(r) for r in batch]
+            results_batch = await asyncio.gather(*tasks, return_exceptions=True)
+            for res in results_batch:
+                if res is None or isinstance(res, Exception):
+                    if isinstance(res, Exception):
+                        logger.error("BatchDedup embed/search error: %s", res)
+                        self.stats["errors"] += 1
+                    continue
+                r, results = res
+                ctrl_uuid = r[0]
+                hint = r[3] or ""
+                if not results:
+                    continue
+                for match in results:
+                    match_score = match.get("score", 0.0)
+                    match_payload = match.get("payload", {})
+                    match_uuid = match_payload.get("control_uuid", "")
+                    # Skip self-match
+                    if match_uuid == ctrl_uuid:
+                        continue
+                    if match_score > LINK_THRESHOLD:
+                        # Mark the worse one as duplicate
+                        try:
+                            self.db.execute(text("""
+                                UPDATE canonical_controls
+                                SET release_state = 'duplicate', merged_into_uuid = CAST(:master AS uuid)
+                                WHERE id = CAST(:dup AS uuid)
+                                  AND release_state != 'duplicate'
+                            """), {"master": match_uuid, "dup": ctrl_uuid})
+                            self.db.execute(text("""
+                                INSERT INTO control_parent_links
+                                    (control_uuid, parent_control_uuid, link_type, confidence)
+                                VALUES (CAST(:cu AS uuid), CAST(:pu AS uuid), 'cross_regulation', :conf)
+                                ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING
+                            """), {"cu": match_uuid, "pu": ctrl_uuid, "conf": match_score})
+                            # Transfer parent links
+                            transferred = self._transfer_parent_links(match_uuid, ctrl_uuid)
+                            self.stats["parent_links_transferred"] += transferred
+                            self.db.commit()
+                            cross_linked += 1
+                        except Exception as e:
+                            logger.error("BatchDedup cross-group link error %s→%s: %s",
+                                         ctrl_uuid, match_uuid, e)
+                            self.db.rollback()
+                            self.stats["errors"] += 1
+                        break  # Only one cross-link per control
+                    elif match_score > REVIEW_THRESHOLD:
+                        self._write_review(
+                            {"control_id": r[1], "title": r[2], "objective": "",
+                             "merge_group_hint": hint, "pattern_id": None},
+                            match_payload, match_score,
+                        )
+                        cross_review += 1
+                        break
+            processed = min(batch_start + PARALLEL_BATCH, len(rows))
+            self._progress_count = processed
+            if processed % 500 < PARALLEL_BATCH:
+                logger.info("BatchDedup Cross-group: %d/%d checked, %d linked, %d review",
+                            processed, len(rows), cross_linked, cross_review)
         self.stats["cross_group_linked"] = cross_linked
         self.stats["cross_group_review"] = cross_review

control_generator.py

@@ -92,6 +92,7 @@ REGULATION_LICENSE_MAP: dict[str, dict] = {
"eucsa": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "EU Cybersecurity Act"},
"dataact": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Data Act"},
"dora": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Digital Operational Resilience Act"},
"eu_2017_745": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Medizinprodukteverordnung (EU) 2017/745 (MDR)"},
"ehds": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "European Health Data Space"},
"gpsr": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Allgemeine Produktsicherheitsverordnung"},
"eu_2023_988": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Allgemeine Produktsicherheitsverordnung (GPSR)"},
@@ -132,6 +133,39 @@ REGULATION_LICENSE_MAP: dict[str, dict] = {
"ao": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Abgabenordnung (AO)"},
"ao_komplett": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Abgabenordnung (AO)"},
"battdg": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Batteriegesetz"},
# New German Laws (2026-04-10 ingestion)
"de_bsig_2025": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "BSI-Gesetz (BSIG 2025, NIS2-Umsetzung)"},
"de_tdddg": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "TDDDG"},
"de_gwg_2017": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Geldwaeschegesetz (GwG)"},
"de_agg": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Allgemeines Gleichbehandlungsgesetz (AGG)"},
"de_hinschg": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Hinweisgeberschutzgesetz (HinSchG)"},
"de_lksg": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Lieferkettensorgfaltspflichtengesetz (LkSG)"},
"de_kritisdachg": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "KRITIS-Dachgesetz (KRITISDachG)"},
"de_bsi_kritisv": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "BSI-Kritisverordnung (BSI-KritisV)"},
# DSK/BfDI Guidance (amtliche Orientierungshilfen)
"dsk_oh_ki_datenschutz": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH KI und Datenschutz"},
"dsk_oh_ki_systeme_tom": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH TOM bei KI-Systemen"},
"dsk_oh_ki_rag": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Generative KI mit RAG"},
"dsk_oh_telemedien": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Telemedien"},
"dsk_oh_digitale_dienste": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Digitale Dienste"},
"dsk_oh_direktwerbung": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Direktwerbung"},
"dsk_oh_videokonferenz": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Videokonferenzsysteme"},
"dsk_oh_videoueberwachung": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Videoueberwachung"},
"dsk_oh_email_verschluesselung": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH E-Mail-Verschluesselung"},
"dsk_oh_whistleblowing": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Whistleblowing"},
"dsk_oh_onlinedienste_zugang": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Onlinedienste Zugang"},
"dsk_oh_datenuebermittlung_drittlaender": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK OH Datenuebermittlung Drittlaender"},
"dsk_sdm_methode": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK Standard-Datenschutzmodell SDM V3.1a"},
"dsk_ah_eu_us_dpf": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK AH EU-US Data Privacy Framework"},
"dsk_ah_bussgeldkonzept": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK Bussgeldkonzept"},
"dsk_ah_dsfa_mussliste": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK DSFA Muss-Liste"},
"dsk_ah_verzeichnis_vvt": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK Hinweise VVT"},
"dsk_ah_zertifizierung": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK Zertifizierungskriterien"},
"dsk_beschluss_ms365": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK Festlegung Microsoft 365"},
"dsk_pos_ki_verordnung": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK Positionspapier KI-Verordnung"},
"dsk_entschl_beschaeftigtendatenschutz": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "DSK Entschliessung Beschaeftigtendatenschutz"},
"bfdi_handreichung_ki_behoerden": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "BfDI Handreichung KI in Behoerden"},
"bfdi_handreichung_ki_sicherheit": {"license": "DE_PUBLIC", "rule": 1, "source_type": "guideline", "name": "BfDI Handreichung KI Sicherheitsbehoerden"},
# Austrian Laws
"at_dsg": {"license": "AT_LAW", "rule": 1, "source_type": "law", "name": "Österreichisches Datenschutzgesetz (DSG)"},
"at_abgb": {"license": "AT_LAW", "rule": 1, "source_type": "law", "name": "AT ABGB"},
@@ -459,6 +493,7 @@ class GeneratorConfig(BaseModel):
     dry_run: bool = False
     existing_job_id: Optional[str] = None  # If set, reuse this job instead of creating a new one
     regulation_filter: Optional[List[str]] = None  # Only process chunks matching these regulation_code prefixes
+    regulation_exclude: Optional[List[str]] = None  # Skip chunks matching these regulation_code prefixes
     skip_prefilter: bool = False  # If True, skip local LLM pre-filter (send all chunks to API)
@@ -501,7 +536,7 @@ class GeneratedControl:
 @dataclass
 class GeneratorResult:
     job_id: str = ""
-    status: str = "completed"
+    status: str = "running"
     total_chunks_scanned: int = 0
     controls_generated: int = 0
     controls_verified: int = 0
@@ -583,8 +618,8 @@ Antworte NUR mit JSON: {{"relevant": true/false, "reason": "kurze Begründung"}}
return True, f"error: {e}"
async def _llm_anthropic(prompt: str, system_prompt: Optional[str] = None) -> str:
"""Call Anthropic Messages API."""
async def _llm_anthropic(prompt: str, system_prompt: Optional[str] = None, max_retries: int = 2) -> str:
"""Call Anthropic Messages API with retry on timeout."""
headers = {
"x-api-key": ANTHROPIC_API_KEY,
"anthropic-version": "2023-06-01",
@@ -598,24 +633,36 @@ async def _llm_anthropic(prompt: str, system_prompt: Optional[str] = None) -> st
     if system_prompt:
         payload["system"] = system_prompt
-    try:
-        async with httpx.AsyncClient(timeout=LLM_TIMEOUT) as client:
-            resp = await client.post(
-                "https://api.anthropic.com/v1/messages",
-                headers=headers,
-                json=payload,
-            )
-        if resp.status_code != 200:
-            logger.error("Anthropic API %d: %s", resp.status_code, resp.text[:300])
-            return ""
-        data = resp.json()
-        content = data.get("content", [])
-        if content and isinstance(content, list):
-            return content[0].get("text", "")
-        return ""
-    except Exception as e:
-        logger.error("Anthropic request failed: %s (type: %s)", e, type(e).__name__)
-        return ""
+    # Use explicit per-phase timeouts to prevent indefinite hangs
+    timeout = httpx.Timeout(connect=30.0, read=LLM_TIMEOUT, write=30.0, pool=30.0)
+    for attempt in range(1 + max_retries):
+        try:
+            async with httpx.AsyncClient(timeout=timeout) as client:
+                resp = await client.post(
+                    "https://api.anthropic.com/v1/messages",
+                    headers=headers,
+                    json=payload,
+                )
+            if resp.status_code != 200:
+                logger.error("Anthropic API %d: %s", resp.status_code, resp.text[:300])
+                return ""
+            data = resp.json()
+            content = data.get("content", [])
+            if content and isinstance(content, list):
+                return content[0].get("text", "")
+            return ""
+        except (httpx.TimeoutException, httpx.RemoteProtocolError) as e:
+            if attempt < max_retries:
+                logger.warning("Anthropic request attempt %d/%d failed: %s — retrying...", attempt + 1, max_retries + 1, e)
+                import asyncio
+                await asyncio.sleep(2 ** attempt)
+                continue
+            logger.error("Anthropic request failed after %d attempts: %s (type: %s)", max_retries + 1, e, type(e).__name__)
+            return ""
+        except Exception as e:
+            logger.error("Anthropic request failed: %s (type: %s)", e, type(e).__name__)
+            return ""

 async def _llm_ollama(prompt: str, system_prompt: Optional[str] = None) -> str:
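
Note on the retry arithmetic above: range(1 + max_retries) with the default max_retries=2 yields three attempts in total, and asyncio.sleep(2 ** attempt) backs off 1s after the first failure and 2s after the second. Only httpx.TimeoutException and httpx.RemoteProtocolError (dropped connections) are retried; a non-200 response still returns "" without retrying.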
@@ -941,6 +988,12 @@ class ControlGeneratorPipeline:
                 if not any(code_lower.startswith(f.lower()) for f in config.regulation_filter):
                     continue
+
+            # Exclude specific regulation_codes
+            if config.regulation_exclude and reg_code:
+                code_lower = reg_code.lower()
+                if any(code_lower.startswith(f.lower()) for f in config.regulation_exclude):
+                    continue
+
             reg_name = (payload.get("regulation_name_de", "")
                         or payload.get("regulation_name", "")
                         or payload.get("source_name", "")
@@ -1423,20 +1476,31 @@ Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Elementen. Fuer Aspekte ohne
         if structure_items:
             s_chunks = [c for c, _ in structure_items]
             s_lics = [l for _, l in structure_items]
-            s_controls = await self._structure_batch(s_chunks, s_lics)
+            try:
+                s_controls = await self._structure_batch(s_chunks, s_lics)
+            except Exception as e:
+                import traceback
+                logger.warning("Batch structure failed: %s — creating fallback controls\n%s", e, traceback.format_exc())
+                s_controls = [self._fallback_control(c) for c in s_chunks]
             for (chunk, _), ctrl in zip(structure_items, s_controls):
                 orig_idx = next(i for i, (c, _) in enumerate(batch_items) if c is chunk)
                 all_controls[orig_idx] = ctrl

         if reform_items:
             r_chunks = [c for c, _ in reform_items]
-            r_controls = await self._reformulate_batch(r_chunks, config)
+            try:
+                r_controls = await self._reformulate_batch(r_chunks, config)
+            except Exception as e:
+                logger.warning("Batch reform failed: %s — creating fallback controls", e)
+                r_controls = [self._fallback_control(c) for c in r_chunks]
             for (chunk, _), ctrl in zip(reform_items, r_controls):
                 orig_idx = next(i for i, (c, _) in enumerate(batch_items) if c is chunk)
                 if ctrl:
                     # Too-Close-Check for Rule 3
                     similarity = await check_similarity(chunk.text, f"{ctrl.objective} {ctrl.rationale}")
-                    if similarity.status == "FAIL":
+                    if similarity is None:
+                        logger.warning("Similarity check returned None — skipping too-close check")
+                    elif similarity.status == "FAIL":
                         ctrl.release_state = "too_close"
                         ctrl.generation_metadata["similarity_status"] = "FAIL"
                         ctrl.generation_metadata["similarity_scores"] = {
@@ -1502,36 +1566,42 @@ Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Elementen. Fuer Aspekte ohne
     # ── Stage 4: Harmonization ─────────────────────────────────────────
     async def _check_harmonization(self, new_control: GeneratedControl) -> Optional[list]:
-        """Check if a new control duplicates existing ones via embedding similarity."""
-        existing = self._load_existing_controls()
-        if not existing:
-            return None
-        # Pre-load all existing embeddings in batch (once per pipeline run)
-        if not self._existing_embeddings:
-            await self._preload_embeddings(existing)
+        """Check if a new control duplicates existing ones via Qdrant vector search.
+        Uses the atomic_controls_dedup collection for fast nearest-neighbor lookup
+        instead of pre-loading all embeddings into memory.
+        """
         new_text = f"{new_control.title} {new_control.objective}"
         new_emb = await _get_embedding(new_text)
         if not new_emb:
             return None
-        similar = []
-        for ex in existing:
-            ex_key = ex.get("control_id", "")
-            ex_emb = self._existing_embeddings.get(ex_key, [])
-            if not ex_emb:
-                continue
-            cosine = _cosine_sim(new_emb, ex_emb)
-            if cosine > HARMONIZATION_THRESHOLD:
-                similar.append({
-                    "control_id": ex.get("control_id", ""),
-                    "title": ex.get("title", ""),
-                    "similarity": round(cosine, 3),
-                })
-        return similar if similar else None
+        try:
+            async with httpx.AsyncClient(timeout=10.0) as client:
+                resp = await client.post(
+                    f"{QDRANT_URL}/collections/atomic_controls_dedup/points/search",
+                    json={
+                        "vector": new_emb,
+                        "limit": 5,
+                        "score_threshold": HARMONIZATION_THRESHOLD,
+                        "with_payload": {"include": ["control_id", "title"]},
+                    },
+                )
+            if resp.status_code == 200:
+                results = resp.json().get("result", [])
+                if results:
+                    return [
+                        {
+                            "control_id": r["payload"].get("control_id", ""),
+                            "title": r["payload"].get("title", ""),
+                            "similarity": round(r["score"], 3),
+                        }
+                        for r in results
+                    ]
+        except Exception as e:
+            logger.warning("Qdrant dedup search failed: %s — skipping harmonization", e)
+        return None

     async def _preload_embeddings(self, existing: list[dict]):
         """Pre-load embeddings for all existing controls in batches."""
@@ -1580,9 +1650,11 @@ Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Elementen. Fuer Aspekte ohne
         if severity not in ("low", "medium", "high", "critical"):
             severity = "medium"
-        tags = data.get("tags", [])
+        tags = data.get("tags") or []
         if isinstance(tags, str):
             tags = [t.strip() for t in tags.split(",")]
+        if not isinstance(tags, list):
+            tags = []
         # Use LLM-provided domain if available, fallback to keyword-detected domain
         llm_domain = data.get("domain")

decomposition_pass.py

@@ -349,8 +349,15 @@ Antworte NUR mit einem JSON-Array. Keine Erklärungen."""
 def _build_pass0a_prompt(
     title: str, objective: str, requirements: str,
-    test_procedure: str, source_ref: str
+    test_procedure: str, source_ref: str,
+    source_original_text: str = ""
 ) -> str:
+    original_block = ""
+    if source_original_text:
+        original_block = f"""
+ORIGINALTEXT (Gesetz/Verordnung — nutze fuer praezisere Pflichtableitung):
+{source_original_text[:3000]}
+"""
     return f"""\
 Analysiere das folgende Control und extrahiere alle einzelnen normativen \
 Pflichten als JSON-Array.
@@ -361,7 +368,7 @@ Ziel: {objective}
 Anforderungen: {requirements}
 Prüfverfahren: {test_procedure}
 Quellreferenz: {source_ref}
-
+{original_block}
 Antworte als JSON-Array:
 [
   {{
@@ -2407,7 +2414,8 @@ class DecompositionPass:
query = """
SELECT cc.id, cc.control_id, cc.title, cc.objective,
cc.requirements, cc.test_procedure,
cc.source_citation, cc.category
cc.source_citation, cc.category,
cc.source_original_text
FROM canonical_controls cc
WHERE cc.release_state NOT IN ('deprecated')
AND cc.parent_control_uuid IS NULL
@@ -2473,6 +2481,7 @@ class DecompositionPass:
"test_procedure": test_str,
"source_ref": source_str,
"category": row[7] or "",
"source_original_text": row[8] or "",
})
# Process in batches
@@ -2507,6 +2516,7 @@ class DecompositionPass:
                 requirements=ctrl["requirements"],
                 test_procedure=ctrl["test_procedure"],
                 source_ref=ctrl["source_ref"],
+                source_original_text=ctrl.get("source_original_text", ""),
             )
             llm_response = await _llm_anthropic(
                 prompt=prompt,
@@ -2529,6 +2539,7 @@ class DecompositionPass:
                 requirements=ctrl["requirements"],
                 test_procedure=ctrl["test_procedure"],
                 source_ref=ctrl["source_ref"],
+                source_original_text=ctrl.get("source_original_text", ""),
             )
             llm_response = await _llm_ollama(
                 prompt=prompt,
@@ -3008,29 +3019,36 @@ class DecompositionPass:
"obligations_kept": 0,
}
# Get all parents that have >1 validated obligation
parents = self.db.execute(text("""
SELECT parent_control_uuid, count(*) AS cnt
FROM obligation_candidates
WHERE release_state = 'validated'
AND merged_into_id IS NULL
GROUP BY parent_control_uuid
HAVING count(*) > 1
# Load ALL obligations in one query (avoids N+1 per parent)
all_obligs = self.db.execute(text("""
SELECT oc.id, oc.candidate_id, oc.obligation_text, oc.action, oc.object,
oc.parent_control_uuid
FROM obligation_candidates oc
WHERE oc.release_state = 'validated'
AND oc.merged_into_id IS NULL
AND oc.parent_control_uuid IN (
SELECT parent_control_uuid
FROM obligation_candidates
WHERE release_state = 'validated'
AND merged_into_id IS NULL
GROUP BY parent_control_uuid
HAVING count(*) > 1
)
ORDER BY oc.parent_control_uuid, oc.created_at
""")).fetchall()
for parent_uuid, cnt in parents:
stats["parents_checked"] += 1
obligs = self.db.execute(text("""
SELECT id, candidate_id, obligation_text, action, object
FROM obligation_candidates
WHERE parent_control_uuid = CAST(:pid AS uuid)
AND release_state = 'validated'
AND merged_into_id IS NULL
ORDER BY created_at
"""), {"pid": str(parent_uuid)}).fetchall()
# Group by parent in Python
from collections import defaultdict
parent_groups: dict[str, list] = defaultdict(list)
for row in all_obligs:
parent_groups[str(row[5])].append(row)
merged_ids = set()
oblig_list = list(obligs)
merge_batch: list[dict] = []
MERGE_FLUSH_SIZE = 200
for parent_uuid, oblig_list in parent_groups.items():
stats["parents_checked"] += 1
merged_ids: set[str] = set()
for i in range(len(oblig_list)):
if str(oblig_list[i][0]) in merged_ids:
@@ -3044,13 +3062,11 @@ class DecompositionPass:
                     obj_i = (oblig_list[i][4] or "").lower().strip()
                     obj_j = (oblig_list[j][4] or "").lower().strip()
                     # Check if actions are similar enough to be duplicates
                     if not _text_similar(action_i, action_j, threshold=0.75):
                         continue
                     if not _text_similar(obj_i, obj_j, threshold=0.60):
                         continue
                     # Keep the more abstract one (shorter text = less specific)
                     text_i = oblig_list[i][2] or ""
                     text_j = oblig_list[j][2] or ""
                     if _is_more_implementation_specific(text_j, text_i):
@@ -3060,18 +3076,31 @@ class DecompositionPass:
                     survivor_id = str(oblig_list[j][0])
                     merged_id = str(oblig_list[i][0])
-                    self.db.execute(text("""
-                        UPDATE obligation_candidates
-                        SET release_state = 'merged',
-                            merged_into_id = CAST(:survivor AS uuid)
-                        WHERE id = CAST(:merged AS uuid)
-                    """), {"survivor": survivor_id, "merged": merged_id})
+                    merge_batch.append({"survivor": survivor_id, "merged": merged_id})
                     merged_ids.add(merged_id)
                     stats["obligations_merged"] += 1
+                    # Flush batch periodically
+                    if len(merge_batch) >= MERGE_FLUSH_SIZE:
+                        for m in merge_batch:
+                            self.db.execute(text("""
+                                UPDATE obligation_candidates
+                                SET release_state = 'merged',
+                                    merged_into_id = CAST(:survivor AS uuid)
+                                WHERE id = CAST(:merged AS uuid)
+                            """), m)
+                        self.db.commit()
+                        merge_batch.clear()

-            # Commit per parent to avoid large transactions
-            self.db.commit()
+        # Flush remainder
+        for m in merge_batch:
+            self.db.execute(text("""
+                UPDATE obligation_candidates
+                SET release_state = 'merged',
+                    merged_into_id = CAST(:survivor AS uuid)
+                WHERE id = CAST(:merged AS uuid)
+            """), m)
+        self.db.commit()

         stats["obligations_kept"] = self.db.execute(text("""
             SELECT count(*) FROM obligation_candidates
@@ -3106,6 +3135,10 @@ class DecompositionPass:
               AND trigger_type IS NULL
         """)).fetchall()

+        # Classify all obligations first, then batch-update
+        BATCH_SIZE = 500
+        pending_updates: list[dict] = []
+
         for row in obligs:
             oc_id = str(row[0])
             obl_text = row[1] or ""
@@ -3116,22 +3149,42 @@ class DecompositionPass:
             trigger = _classify_trigger_type(obl_text, condition)
             impl = _is_implementation_specific_text(obl_text, action, obj)
-            self.db.execute(text("""
-                UPDATE obligation_candidates
-                SET trigger_type = :trigger,
-                    is_implementation_specific = :impl
-                WHERE id = CAST(:oid AS uuid)
-            """), {"trigger": trigger, "impl": impl, "oid": oc_id})
+            pending_updates.append({"oid": oc_id, "trigger": trigger, "impl": impl})
             stats["enriched"] += 1
-            stats[f"trigger_{trigger}"] += 1
+            stats[f"trigger_{trigger}"] = stats.get(f"trigger_{trigger}", 0) + 1
             if impl:
                 stats["implementation_specific"] += 1
+            # Flush batch
+            if len(pending_updates) >= BATCH_SIZE:
+                self._flush_enrich_batch(pending_updates)
+                pending_updates.clear()
+
+        # Flush remainder
+        if pending_updates:
+            self._flush_enrich_batch(pending_updates)
         self.db.commit()

         logger.info("Enrich pass: %s", stats)
         return stats

+    def _flush_enrich_batch(self, updates: list[dict]):
+        """Batch-UPDATE obligation_candidates for enrich pass."""
+        # Group by (trigger, impl) to minimize UPDATE statements
+        from collections import defaultdict
+        groups: dict[tuple, list[str]] = defaultdict(list)
+        for u in updates:
+            groups[(u["trigger"], u["impl"])].append(u["oid"])
+        for (trigger, impl), ids in groups.items():
+            # Use ANY(ARRAY[...]) for batch WHERE clause
+            self.db.execute(text("""
+                UPDATE obligation_candidates
+                SET trigger_type = :trigger,
+                    is_implementation_specific = :impl
+                WHERE id = ANY(CAST(:ids AS uuid[]))
+            """), {"trigger": trigger, "impl": impl, "ids": ids})
+
     # -------------------------------------------------------------------
     # Decomposition Status
     # -------------------------------------------------------------------
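
A standalone sketch of the grouping step in _flush_enrich_batch, with illustrative data (the trigger values are placeholders): every oid sharing a (trigger, impl) classification lands in one list, so a single UPDATE covers the whole group.

    from collections import defaultdict

    updates = [
        {"oid": "a1", "trigger": "event", "impl": False},
        {"oid": "b2", "trigger": "event", "impl": False},
        {"oid": "c3", "trigger": "periodic", "impl": True},
    ]
    groups: dict[tuple, list[str]] = defaultdict(list)
    for u in updates:
        groups[(u["trigger"], u["impl"])].append(u["oid"])
    print(dict(groups))
    # {('event', False): ['a1', 'b2'], ('periodic', True): ['c3']}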
@@ -3365,7 +3418,8 @@ class DecompositionPass:
query = """
SELECT cc.id, cc.control_id, cc.title, cc.objective,
cc.requirements, cc.test_procedure,
cc.source_citation, cc.category
cc.source_citation, cc.category,
cc.source_original_text
FROM canonical_controls cc
WHERE cc.release_state NOT IN ('deprecated')
AND cc.parent_control_uuid IS NULL
@@ -3414,6 +3468,7 @@ class DecompositionPass:
"test_procedure": _format_field(row[5] or ""),
"source_ref": _format_citation(row[6] or ""),
"category": row[7] or "",
"source_original_text": row[8] or "",
})
if not prepared: