From a14e2f3a002df2e2d4bd63c251777233ef362128 Mon Sep 17 00:00:00 2001 From: Benjamin Admin Date: Sat, 21 Mar 2026 22:27:09 +0100 Subject: [PATCH] feat(decomposition): add merge pass, enrichment, and Pass 0b refinements Add obligation refinement pipeline between Pass 0a and 0b: - Merge pass: rule-based dedup of implementation-level duplicate obligations within the same parent control (Jaccard similarity on action+object) - Enrich pass: classify trigger_type (event/periodic/continuous) and detect is_implementation_specific from obligation text (regex-based, no LLM) - Pass 0b: skip merged obligations, cap severity for impl-specific, override category to 'testing' for test obligations - Migration 075: merged_into_id, trigger_type, is_implementation_specific - Two new API endpoints: merge-obligations, enrich-obligations - 30+ new tests (122 total, all passing) Co-Authored-By: Claude Opus 4.6 --- .../compliance/api/crosswalk_routes.py | 49 ++ .../compliance/services/decomposition_pass.py | 268 +++++++++- .../migrations/075_obligation_refinement.sql | 38 ++ .../tests/test_decomposition_pass.py | 461 +++++++++++++++++- 4 files changed, 804 insertions(+), 12 deletions(-) create mode 100644 backend-compliance/migrations/075_obligation_refinement.sql diff --git a/backend-compliance/compliance/api/crosswalk_routes.py b/backend-compliance/compliance/api/crosswalk_routes.py index 1a0c668..3d5f754 100644 --- a/backend-compliance/compliance/api/crosswalk_routes.py +++ b/backend-compliance/compliance/api/crosswalk_routes.py @@ -13,6 +13,8 @@ Endpoints: GET /v1/canonical/crosswalk/stats — Coverage statistics POST /v1/canonical/migrate/decompose — Pass 0a: Obligation extraction + POST /v1/canonical/migrate/merge-obligations — Merge implementation-level dupes + POST /v1/canonical/migrate/enrich-obligations — Add trigger_type, impl metadata POST /v1/canonical/migrate/compose-atomic — Pass 0b: Atomic control composition POST /v1/canonical/migrate/link-obligations — Pass 1: Obligation linkage POST /v1/canonical/migrate/classify-patterns — Pass 2: Pattern classification @@ -157,6 +159,9 @@ class DecompositionStatusResponse(BaseModel): rejected: int = 0 composed: int = 0 atomic_controls: int = 0 + merged: int = 0 + enriched: int = 0 + ready_for_pass0b: int = 0 decomposition_pct: float = 0.0 composition_pct: float = 0.0 @@ -488,6 +493,50 @@ async def migrate_decompose(req: MigrationRequest): db.close() +@router.post("/migrate/merge-obligations", response_model=MigrationResponse) +async def migrate_merge_obligations(): + """Merge implementation-level duplicate obligations within each parent. + + Run AFTER Pass 0a, BEFORE Pass 0b. No LLM calls — rule-based. + Merges obligations that share similar action+object into the more + abstract survivor, marking the concrete duplicate as 'merged'. + """ + from compliance.services.decomposition_pass import DecompositionPass + + db = SessionLocal() + try: + decomp = DecompositionPass(db=db) + stats = decomp.run_merge_pass() + return MigrationResponse(status="completed", stats=stats) + except Exception as e: + logger.error("Merge pass failed: %s", e) + raise HTTPException(status_code=500, detail=str(e)) + finally: + db.close() + + +@router.post("/migrate/enrich-obligations", response_model=MigrationResponse) +async def migrate_enrich_obligations(): + """Add trigger_type and is_implementation_specific metadata. + + Run AFTER merge pass, BEFORE Pass 0b. No LLM calls — rule-based. + Classifies trigger_type (event/periodic/continuous) from obligation text + and detects implementation-specific obligations (concrete tools/protocols). + """ + from compliance.services.decomposition_pass import DecompositionPass + + db = SessionLocal() + try: + decomp = DecompositionPass(db=db) + stats = decomp.enrich_obligations() + return MigrationResponse(status="completed", stats=stats) + except Exception as e: + logger.error("Enrich pass failed: %s", e) + raise HTTPException(status_code=500, detail=str(e)) + finally: + db.close() + + @router.post("/migrate/compose-atomic", response_model=MigrationResponse) async def migrate_compose_atomic(req: MigrationRequest): """Pass 0b: Compose atomic controls from obligation candidates. diff --git a/backend-compliance/compliance/services/decomposition_pass.py b/backend-compliance/compliance/services/decomposition_pass.py index 6092b9f..034e02e 100644 --- a/backend-compliance/compliance/services/decomposition_pass.py +++ b/backend-compliance/compliance/services/decomposition_pass.py @@ -126,6 +126,83 @@ _REPORTING_SIGNALS = [ _REPORTING_RE = re.compile("|".join(_REPORTING_SIGNALS), re.IGNORECASE) +# --------------------------------------------------------------------------- +# Merge & Enrichment helpers +# --------------------------------------------------------------------------- + +# Trigger-type detection patterns +_EVENT_TRIGGERS = re.compile( + r"\b(vorfall|incident|breach|verletzung|sicherheitsvorfall|meldung|entdeckung" + r"|feststellung|erkennung|ereignis|eintritt|bei\s+auftreten|im\s+falle" + r"|wenn\s+ein|sobald|unverzüglich|upon|in\s+case\s+of|when\s+a)\b", + re.IGNORECASE, +) +_PERIODIC_TRIGGERS = re.compile( + r"\b(jährlich|monatlich|quartalsweise|regelmäßig|periodisch|annually" + r"|monthly|quarterly|periodic|mindestens\s+(einmal|alle)|turnusmäßig" + r"|wiederkehrend|in\s+regelmäßigen\s+abständen)\b", + re.IGNORECASE, +) + +# Implementation-specific keywords (concrete tools/protocols/formats) +_IMPL_SPECIFIC_PATTERNS = re.compile( + r"\b(TLS|SSL|AES|RSA|SHA-\d|HTTPS|LDAP|SAML|OAuth|OIDC|MFA|2FA" + r"|SIEM|IDS|IPS|WAF|VPN|VLAN|DMZ|HSM|PKI|RBAC|ABAC" + r"|ISO\s*27\d{3}|SOC\s*2|PCI[\s-]DSS|NIST" + r"|Firewall|Antivirus|EDR|XDR|SOAR|DLP" + r"|SMS|E-Mail|Fax|Telefon" + r"|JSON|XML|CSV|PDF|YAML" + r"|PostgreSQL|MySQL|MongoDB|Redis|Kafka" + r"|Docker|Kubernetes|AWS|Azure|GCP" + r"|Active\s*Directory|RADIUS|Kerberos" + r"|RSyslog|Splunk|ELK|Grafana|Prometheus" + r"|Git|Jenkins|Terraform|Ansible)\b", + re.IGNORECASE, +) + + +def _classify_trigger_type(obligation_text: str, condition: str) -> str: + """Classify when an obligation is triggered: event/periodic/continuous.""" + combined = f"{obligation_text} {condition}" + if _EVENT_TRIGGERS.search(combined): + return "event" + if _PERIODIC_TRIGGERS.search(combined): + return "periodic" + return "continuous" + + +def _is_implementation_specific_text( + obligation_text: str, action: str, obj: str +) -> bool: + """Check if an obligation references concrete implementation details.""" + combined = f"{obligation_text} {action} {obj}" + matches = _IMPL_SPECIFIC_PATTERNS.findall(combined) + return len(matches) >= 1 + + +def _text_similar(a: str, b: str, threshold: float = 0.75) -> bool: + """Quick token-overlap similarity check (Jaccard on words).""" + if not a or not b: + return False + tokens_a = set(a.split()) + tokens_b = set(b.split()) + if not tokens_a or not tokens_b: + return False + intersection = tokens_a & tokens_b + union = tokens_a | tokens_b + return len(intersection) / len(union) >= threshold + + +def _is_more_implementation_specific(text_a: str, text_b: str) -> bool: + """Return True if text_a is more implementation-specific than text_b.""" + matches_a = len(_IMPL_SPECIFIC_PATTERNS.findall(text_a)) + matches_b = len(_IMPL_SPECIFIC_PATTERNS.findall(text_b)) + if matches_a != matches_b: + return matches_a > matches_b + # Tie-break: longer text is usually more specific + return len(text_a) > len(text_b) + + # --------------------------------------------------------------------------- # Data classes # --------------------------------------------------------------------------- @@ -864,12 +941,17 @@ class DecompositionPass: ) stats["controls_processed"] += 1 + # Commit after each successful sub-batch to avoid losing work + self.db.commit() + except Exception as e: ids = ", ".join(c["control_id"] for c in batch) logger.error("Pass 0a failed for [%s]: %s", ids, e) stats["errors"] += 1 - - self.db.commit() + try: + self.db.rollback() + except Exception: + pass logger.info("Pass 0a: %s", stats) return stats @@ -944,10 +1026,13 @@ class DecompositionPass: cc.category AS parent_category, cc.source_citation AS parent_citation, cc.severity AS parent_severity, - cc.control_id AS parent_control_id + cc.control_id AS parent_control_id, + oc.trigger_type, + oc.is_implementation_specific FROM obligation_candidates oc JOIN canonical_controls cc ON cc.id = oc.parent_control_uuid WHERE oc.release_state = 'validated' + AND oc.merged_into_id IS NULL AND NOT EXISTS ( SELECT 1 FROM canonical_controls ac WHERE ac.parent_control_uuid = oc.parent_control_uuid @@ -971,6 +1056,7 @@ class DecompositionPass: "dedup_enabled": self._dedup is not None, "dedup_linked": 0, "dedup_review": 0, + "skipped_merged": 0, } # Prepare obligation data @@ -991,6 +1077,8 @@ class DecompositionPass: "parent_severity": row[11] or "medium", "parent_control_id": row[12] or "", "source_ref": _format_citation(row[10] or ""), + "trigger_type": row[13] or "continuous", + "is_implementation_specific": row[14] or False, }) # Process in batches @@ -1044,12 +1132,17 @@ class DecompositionPass: parsed = _parse_json_object(llm_response) await self._process_pass0b_control(obl, parsed, stats) + # Commit after each successful sub-batch + self.db.commit() + except Exception as e: ids = ", ".join(o["candidate_id"] for o in batch) logger.error("Pass 0b failed for [%s]: %s", ids, e) stats["errors"] += 1 - - self.db.commit() + try: + self.db.rollback() + except Exception: + pass logger.info("Pass 0b: %s", stats) return stats @@ -1090,6 +1183,16 @@ class DecompositionPass: atomic.parent_control_uuid = obl["parent_uuid"] atomic.obligation_candidate_id = obl["candidate_id"] + # Cap severity for implementation-specific obligations + if obl.get("is_implementation_specific") and atomic.severity in ( + "critical", "high" + ): + atomic.severity = "medium" + + # Override category for test obligations + if obl.get("is_test"): + atomic.category = "testing" + # ── Dedup check (if enabled) ──────────────────────────── if self._dedup: pattern_id = None @@ -1182,6 +1285,150 @@ class DecompositionPass: stats["controls_created"] += 1 stats["candidates_processed"] += 1 + # ------------------------------------------------------------------- + # Merge Pass: Deduplicate implementation-level obligations + # ------------------------------------------------------------------- + + def run_merge_pass(self) -> dict: + """Merge implementation-level duplicate obligations within each parent. + + When the same parent control has multiple obligations with nearly + identical action+object (e.g. "SMS-Verbot" + "Policy-as-Code" both + implementing a communication restriction), keep the more abstract one + and mark the concrete one as merged. + + No LLM calls — purely rule-based using text similarity. + """ + stats = { + "parents_checked": 0, + "obligations_merged": 0, + "obligations_kept": 0, + } + + # Get all parents that have >1 validated obligation + parents = self.db.execute(text(""" + SELECT parent_control_uuid, count(*) AS cnt + FROM obligation_candidates + WHERE release_state = 'validated' + AND merged_into_id IS NULL + GROUP BY parent_control_uuid + HAVING count(*) > 1 + """)).fetchall() + + for parent_uuid, cnt in parents: + stats["parents_checked"] += 1 + obligs = self.db.execute(text(""" + SELECT id, candidate_id, obligation_text, action, object + FROM obligation_candidates + WHERE parent_control_uuid = CAST(:pid AS uuid) + AND release_state = 'validated' + AND merged_into_id IS NULL + ORDER BY created_at + """), {"pid": str(parent_uuid)}).fetchall() + + merged_ids = set() + oblig_list = list(obligs) + + for i in range(len(oblig_list)): + if str(oblig_list[i][0]) in merged_ids: + continue + for j in range(i + 1, len(oblig_list)): + if str(oblig_list[j][0]) in merged_ids: + continue + + action_i = (oblig_list[i][3] or "").lower().strip() + action_j = (oblig_list[j][3] or "").lower().strip() + obj_i = (oblig_list[i][4] or "").lower().strip() + obj_j = (oblig_list[j][4] or "").lower().strip() + + # Check if actions are similar enough to be duplicates + if not _text_similar(action_i, action_j, threshold=0.75): + continue + if not _text_similar(obj_i, obj_j, threshold=0.60): + continue + + # Keep the more abstract one (shorter text = less specific) + text_i = oblig_list[i][2] or "" + text_j = oblig_list[j][2] or "" + if _is_more_implementation_specific(text_j, text_i): + survivor_id = str(oblig_list[i][0]) + merged_id = str(oblig_list[j][0]) + else: + survivor_id = str(oblig_list[j][0]) + merged_id = str(oblig_list[i][0]) + + self.db.execute(text(""" + UPDATE obligation_candidates + SET release_state = 'merged', + merged_into_id = CAST(:survivor AS uuid) + WHERE id = CAST(:merged AS uuid) + """), {"survivor": survivor_id, "merged": merged_id}) + + merged_ids.add(merged_id) + stats["obligations_merged"] += 1 + + # Commit per parent to avoid large transactions + self.db.commit() + + stats["obligations_kept"] = self.db.execute(text(""" + SELECT count(*) FROM obligation_candidates + WHERE release_state = 'validated' AND merged_into_id IS NULL + """)).fetchone()[0] + + logger.info("Merge pass: %s", stats) + return stats + + # ------------------------------------------------------------------- + # Enrich Pass: Add metadata to obligations + # ------------------------------------------------------------------- + + def enrich_obligations(self) -> dict: + """Add trigger_type and is_implementation_specific to obligations. + + Rule-based enrichment — no LLM calls. + """ + stats = { + "enriched": 0, + "trigger_event": 0, + "trigger_periodic": 0, + "trigger_continuous": 0, + "implementation_specific": 0, + } + + obligs = self.db.execute(text(""" + SELECT id, obligation_text, condition, action, object + FROM obligation_candidates + WHERE release_state = 'validated' + AND merged_into_id IS NULL + AND trigger_type IS NULL + """)).fetchall() + + for row in obligs: + oc_id = str(row[0]) + obl_text = row[1] or "" + condition = row[2] or "" + action = row[3] or "" + obj = row[4] or "" + + trigger = _classify_trigger_type(obl_text, condition) + impl = _is_implementation_specific_text(obl_text, action, obj) + + self.db.execute(text(""" + UPDATE obligation_candidates + SET trigger_type = :trigger, + is_implementation_specific = :impl + WHERE id = CAST(:oid AS uuid) + """), {"trigger": trigger, "impl": impl, "oid": oc_id}) + + stats["enriched"] += 1 + stats[f"trigger_{trigger}"] += 1 + if impl: + stats["implementation_specific"] += 1 + + self.db.commit() + logger.info("Enrich pass: %s", stats) + return stats + # ------------------------------------------------------------------- # Decomposition Status # ------------------------------------------------------------------- @@ -1198,9 +1445,13 @@ class DecompositionPass: (SELECT count(*) FROM obligation_candidates WHERE release_state = 'validated') AS validated, (SELECT count(*) FROM obligation_candidates WHERE release_state = 'rejected') AS rejected, (SELECT count(*) FROM obligation_candidates WHERE release_state = 'composed') AS composed, - (SELECT count(*) FROM canonical_controls WHERE parent_control_uuid IS NOT NULL) AS atomic_controls + (SELECT count(*) FROM canonical_controls WHERE parent_control_uuid IS NOT NULL) AS atomic_controls, + (SELECT count(*) FROM obligation_candidates WHERE release_state = 'merged') AS merged, + (SELECT count(*) FROM obligation_candidates WHERE trigger_type IS NOT NULL) AS enriched """)).fetchone() + validated_for_0b = row[3] - (row[7] or 0) # validated minus merged + return { "rich_controls": row[0], "decomposed_controls": row[1], @@ -1209,8 +1460,11 @@ class DecompositionPass: "rejected": row[4], "composed": row[5], "atomic_controls": row[6], + "merged": row[7] or 0, + "enriched": row[8] or 0, + "ready_for_pass0b": validated_for_0b, "decomposition_pct": round(row[1] / max(row[0], 1) * 100, 1), - "composition_pct": round(row[5] / max(row[3], 1) * 100, 1), + "composition_pct": round(row[5] / max(validated_for_0b, 1) * 100, 1), } # ------------------------------------------------------------------- diff --git a/backend-compliance/migrations/075_obligation_refinement.sql b/backend-compliance/migrations/075_obligation_refinement.sql new file mode 100644 index 0000000..13841ef --- /dev/null +++ b/backend-compliance/migrations/075_obligation_refinement.sql @@ -0,0 +1,38 @@ +-- Migration 075: Obligation Refinement Fields +-- Supports Merge Pass (implementation-level dedup) and metadata enrichment. +-- +-- New fields: +-- merged_into_id — points to survivor obligation when merged +-- trigger_type — event / periodic / continuous +-- is_implementation_specific — true if obligation references concrete tool/protocol + +-- ============================================================================= +-- 1. Add merge tracking +-- ============================================================================= + +ALTER TABLE obligation_candidates + ADD COLUMN IF NOT EXISTS merged_into_id UUID + REFERENCES obligation_candidates(id); + +CREATE INDEX IF NOT EXISTS idx_oc_merged_into + ON obligation_candidates(merged_into_id) + WHERE merged_into_id IS NOT NULL; + +-- Allow 'merged' as release_state +ALTER TABLE obligation_candidates + DROP CONSTRAINT IF EXISTS obligation_candidates_release_state_check; + +ALTER TABLE obligation_candidates + ADD CONSTRAINT obligation_candidates_release_state_check + CHECK (release_state IN ('extracted', 'validated', 'rejected', 'composed', 'merged')); + +-- ============================================================================= +-- 2. Add enrichment metadata +-- ============================================================================= + +ALTER TABLE obligation_candidates + ADD COLUMN IF NOT EXISTS trigger_type VARCHAR(20) DEFAULT NULL + CHECK (trigger_type IS NULL OR trigger_type IN ('event', 'periodic', 'continuous')); + +ALTER TABLE obligation_candidates + ADD COLUMN IF NOT EXISTS is_implementation_specific BOOLEAN DEFAULT FALSE; diff --git a/backend-compliance/tests/test_decomposition_pass.py b/backend-compliance/tests/test_decomposition_pass.py index f0f0d2c..e42591b 100644 --- a/backend-compliance/tests/test_decomposition_pass.py +++ b/backend-compliance/tests/test_decomposition_pass.py @@ -49,6 +49,10 @@ from compliance.services.decomposition_pass import ( _PASS0A_SYSTEM_PROMPT, _PASS0B_SYSTEM_PROMPT, DecompositionPass, + _classify_trigger_type, + _is_implementation_specific_text, + _text_similar, + _is_more_implementation_specific, ) @@ -757,6 +761,7 @@ class TestDecompositionPassRun0b: "Service Continuity", "finance", '{"source": "MiCA", "article": "Art. 8"}', "high", "FIN-001", + "continuous", False, # trigger_type, is_implementation_specific ), ] @@ -809,6 +814,7 @@ class TestDecompositionPassRun0b: False, False, "Auth Controls", "authentication", "", "high", "AUTH-001", + "continuous", False, ), ] @@ -842,7 +848,8 @@ class TestDecompositionStatus: def test_returns_status(self): mock_db = MagicMock() mock_result = MagicMock() - mock_result.fetchone.return_value = (5000, 1000, 3000, 2500, 200, 2000, 1800) + # 9 columns: rich, decomposed, total, validated, rejected, composed, atomic, merged, enriched + mock_result.fetchone.return_value = (5000, 1000, 3000, 2500, 200, 2000, 1800, 100, 2400) mock_db.execute.return_value = mock_result decomp = DecompositionPass(db=mock_db) @@ -855,13 +862,17 @@ class TestDecompositionStatus: assert status["rejected"] == 200 assert status["composed"] == 2000 assert status["atomic_controls"] == 1800 + assert status["merged"] == 100 + assert status["enriched"] == 2400 + assert status["ready_for_pass0b"] == 2400 # 2500 validated - 100 merged assert status["decomposition_pct"] == 20.0 - assert status["composition_pct"] == 80.0 + # composition_pct: 2000 composed / 2400 ready_for_pass0b + assert status["composition_pct"] == 83.3 def test_handles_zero_division(self): mock_db = MagicMock() mock_result = MagicMock() - mock_result.fetchone.return_value = (0, 0, 0, 0, 0, 0, 0) + mock_result.fetchone.return_value = (0, 0, 0, 0, 0, 0, 0, 0, 0) mock_db.execute.return_value = mock_result decomp = DecompositionPass(db=mock_db) @@ -1089,12 +1100,14 @@ class TestDecompositionPassAnthropicBatch: "MFA implementieren", "implementieren", "MFA", False, False, "Auth", "security", '{"source": "DSGVO", "article": "Art. 32"}', - "high", "CTRL-001"), + "high", "CTRL-001", + "continuous", False), ("oc-uuid-2", "OC-CTRL-001-02", "parent-uuid-1", "MFA testen", "testen", "MFA", True, False, "Auth", "security", '{"source": "DSGVO", "article": "Art. 32"}', - "high", "CTRL-001"), + "high", "CTRL-001", + "periodic", False), ] mock_seq = MagicMock() @@ -1232,3 +1245,441 @@ class TestSourceFilter: query_str = str(call_args[0][0]) assert "IN :cats" in query_str assert "ILIKE" in query_str + + +# --------------------------------------------------------------------------- +# TRIGGER TYPE CLASSIFICATION TESTS +# --------------------------------------------------------------------------- + + +class TestClassifyTriggerType: + """Tests for _classify_trigger_type helper.""" + + def test_event_trigger_vorfall(self): + assert _classify_trigger_type( + "Bei einem Sicherheitsvorfall muss gemeldet werden", "" + ) == "event" + + def test_event_trigger_condition_field(self): + assert _classify_trigger_type( + "Melden", "wenn ein Datenverlust festgestellt wird" + ) == "event" + + def test_event_trigger_breach(self): + assert _classify_trigger_type( + "In case of a data breach, notify authorities", "" + ) == "event" + + def test_periodic_trigger_jaehrlich(self): + assert _classify_trigger_type( + "Jährlich ist eine Überprüfung durchzuführen", "" + ) == "periodic" + + def test_periodic_trigger_regelmaessig(self): + assert _classify_trigger_type( + "Regelmäßig muss ein Audit stattfinden", "" + ) == "periodic" + + def test_periodic_trigger_quarterly(self): + assert _classify_trigger_type( + "Quarterly review of access controls", "" + ) == "periodic" + + def test_continuous_default(self): + assert _classify_trigger_type( + "Betreiber müssen Zugangskontrollen implementieren", "" + ) == "continuous" + + def test_continuous_empty_text(self): + assert _classify_trigger_type("", "") == "continuous" + + def test_event_takes_precedence_over_periodic(self): + # "Vorfall" + "regelmäßig" → event wins + assert _classify_trigger_type( + "Bei einem Vorfall ist regelmäßig zu prüfen", "" + ) == "event" + + +# --------------------------------------------------------------------------- +# IMPLEMENTATION-SPECIFIC DETECTION TESTS +# --------------------------------------------------------------------------- + + +class TestIsImplementationSpecific: + """Tests for _is_implementation_specific_text helper.""" + + def test_tls_is_implementation_specific(self): + assert _is_implementation_specific_text( + "Verschlüsselung mittels TLS 1.3 sicherstellen", + "sicherstellen", "Verschlüsselung" + ) + + def test_mfa_is_implementation_specific(self): + assert _is_implementation_specific_text( + "MFA muss für alle Konten aktiviert werden", + "aktivieren", "MFA" + ) + + def test_siem_is_implementation_specific(self): + assert _is_implementation_specific_text( + "Ein SIEM-System muss betrieben werden", + "betreiben", "SIEM-System" + ) + + def test_abstract_obligation_not_specific(self): + assert not _is_implementation_specific_text( + "Zugriffskontrollen müssen implementiert werden", + "implementieren", "Zugriffskontrollen" + ) + + def test_generic_encryption_not_specific(self): + assert not _is_implementation_specific_text( + "Daten müssen verschlüsselt gespeichert werden", + "verschlüsseln", "Daten" + ) + + +# --------------------------------------------------------------------------- +# TEXT SIMILARITY TESTS +# --------------------------------------------------------------------------- + + +class TestTextSimilar: + """Tests for _text_similar Jaccard helper.""" + + def test_identical_strings(self): + assert _text_similar("implementieren mfa", "implementieren mfa") + + def test_similar_strings(self): + assert _text_similar( + "implementieren zugangskontrolle", + "implementieren zugangskontrolle system", + threshold=0.60, + ) + + def test_different_strings(self): + assert not _text_similar( + "implementieren mfa", + "dokumentieren audit", + threshold=0.75, + ) + + def test_empty_string(self): + assert not _text_similar("", "something") + + def test_both_empty(self): + assert not _text_similar("", "") + + +class TestIsMoreImplementationSpecific: + """Tests for _is_more_implementation_specific.""" + + def test_concrete_vs_abstract(self): + concrete = "SMS-Versand muss über TLS verschlüsselt werden" + abstract = "Kommunikation muss verschlüsselt werden" + assert _is_more_implementation_specific(concrete, abstract) + + def test_abstract_vs_concrete(self): + concrete = "Firewall-Regeln müssen konfiguriert werden" + abstract = "Netzwerksicherheit muss gewährleistet werden" + assert not _is_more_implementation_specific(abstract, concrete) + + def test_equal_specificity_longer_wins(self): + a = "Zugriffskontrollen müssen implementiert werden und dokumentiert werden" + b = "Zugriffskontrollen implementieren" + assert _is_more_implementation_specific(a, b) + + +# --------------------------------------------------------------------------- +# MERGE PASS TESTS +# --------------------------------------------------------------------------- + + +class TestMergePass: + """Tests for DecompositionPass.run_merge_pass.""" + + def test_merge_pass_merges_similar_obligations(self): + mock_db = MagicMock() + + # Step 1: Parents with >1 validated obligation + mock_parents = MagicMock() + mock_parents.fetchall.return_value = [ + ("parent-uuid-1", 3), + ] + + # Step 2: Obligations for that parent + mock_obligs = MagicMock() + mock_obligs.fetchall.return_value = [ + ("obl-1", "OC-001-01", + "Betreiber müssen Verschlüsselung implementieren", + "implementieren", "verschlüsselung"), + ("obl-2", "OC-001-02", + "Betreiber müssen Verschlüsselung mittels TLS implementieren", + "implementieren", "verschlüsselung"), + ("obl-3", "OC-001-03", + "Betreiber müssen Zugriffsprotokolle führen", + "führen", "zugriffsprotokolle"), + ] + + # Step 3: Final count + mock_count = MagicMock() + mock_count.fetchone.return_value = (2,) + + call_count = [0] + def side_effect(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + return mock_parents + if call_count[0] == 2: + return mock_obligs + if call_count[0] == 3: + return MagicMock() # UPDATE + if call_count[0] == 4: + return mock_count # Final count + return MagicMock() + mock_db.execute.side_effect = side_effect + + decomp = DecompositionPass(db=mock_db) + stats = decomp.run_merge_pass() + + assert stats["parents_checked"] == 1 + assert stats["obligations_merged"] == 1 # obl-2 merged into obl-1 + assert stats["obligations_kept"] == 2 + + def test_merge_pass_no_merge_when_different_actions(self): + mock_db = MagicMock() + + mock_parents = MagicMock() + mock_parents.fetchall.return_value = [ + ("parent-uuid-1", 2), + ] + + mock_obligs = MagicMock() + mock_obligs.fetchall.return_value = [ + ("obl-1", "OC-001-01", + "Verschlüsselung implementieren", + "implementieren", "verschlüsselung"), + ("obl-2", "OC-001-02", + "Zugriffsprotokolle dokumentieren", + "dokumentieren", "zugriffsprotokolle"), + ] + + mock_count = MagicMock() + mock_count.fetchone.return_value = (2,) + + call_count = [0] + def side_effect(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + return mock_parents + if call_count[0] == 2: + return mock_obligs + if call_count[0] == 3: + return mock_count + return MagicMock() + mock_db.execute.side_effect = side_effect + + decomp = DecompositionPass(db=mock_db) + stats = decomp.run_merge_pass() + + assert stats["obligations_merged"] == 0 + assert stats["obligations_kept"] == 2 + + +# --------------------------------------------------------------------------- +# ENRICH PASS TESTS +# --------------------------------------------------------------------------- + + +class TestEnrichPass: + """Tests for DecompositionPass.enrich_obligations.""" + + def test_enrich_classifies_trigger_types(self): + mock_db = MagicMock() + + mock_obligs = MagicMock() + mock_obligs.fetchall.return_value = [ + ("obl-1", "Bei Vorfall melden", "Sicherheitsvorfall", + "melden", "Vorfall"), + ("obl-2", "Jährlich Audit durchführen", "", + "durchführen", "Audit"), + ("obl-3", "Verschlüsselung mittels TLS implementieren", "", + "implementieren", "Verschlüsselung"), + ] + + call_count = [0] + def side_effect(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + return mock_obligs + return MagicMock() # UPDATE statements + mock_db.execute.side_effect = side_effect + + decomp = DecompositionPass(db=mock_db) + stats = decomp.enrich_obligations() + + assert stats["enriched"] == 3 + assert stats["trigger_event"] == 1 + assert stats["trigger_periodic"] == 1 + assert stats["trigger_continuous"] == 1 + assert stats["implementation_specific"] == 1 + + +# --------------------------------------------------------------------------- +# MIGRATION 075 TESTS +# --------------------------------------------------------------------------- + + +class TestMigration075: + """Tests for migration 075 SQL file.""" + + def test_migration_file_exists(self): + from pathlib import Path + migration = Path(__file__).parent.parent / "migrations" / "075_obligation_refinement.sql" + assert migration.exists(), "Migration 075 file missing" + + def test_migration_contains_required_fields(self): + from pathlib import Path + migration = Path(__file__).parent.parent / "migrations" / "075_obligation_refinement.sql" + content = migration.read_text() + assert "merged_into_id" in content + assert "trigger_type" in content + assert "is_implementation_specific" in content + assert "'merged'" in content + + +# --------------------------------------------------------------------------- +# PASS 0B ENRICHMENT INTEGRATION TESTS +# --------------------------------------------------------------------------- + + +class TestPass0bWithEnrichment: + """Tests that Pass 0b uses enrichment metadata correctly.""" + + def test_pass0b_query_skips_merged(self): + """Verify Pass 0b query includes merged_into_id IS NULL filter.""" + mock_db = MagicMock() + mock_rows = MagicMock() + mock_rows.fetchall.return_value = [] + mock_db.execute.return_value = mock_rows + + import asyncio + decomp = DecompositionPass(db=mock_db) + stats = asyncio.get_event_loop().run_until_complete( + decomp.run_pass0b(limit=10, use_anthropic=True) + ) + + call_args = mock_db.execute.call_args_list[0] + query_str = str(call_args[0][0]) + assert "merged_into_id IS NULL" in query_str + + def test_severity_capped_for_implementation_specific(self): + """Implementation-specific obligations get max severity=medium.""" + obl = { + "oc_id": "oc-1", + "candidate_id": "OC-001-01", + "parent_uuid": "p-uuid", + "obligation_text": "TLS implementieren", + "action": "implementieren", + "object": "TLS", + "is_test": False, + "is_reporting": False, + "parent_title": "Encryption", + "parent_category": "security", + "parent_citation": "", + "parent_severity": "high", + "parent_control_id": "SEC-001", + "source_ref": "", + "trigger_type": "continuous", + "is_implementation_specific": True, + } + parsed = { + "title": "TLS implementieren", + "objective": "TLS für alle Verbindungen", + "requirements": ["TLS 1.3"], + "test_procedure": ["Scan"], + "evidence": ["Zertifikat"], + "severity": "critical", + "category": "security", + } + stats = {"controls_created": 0, "candidates_processed": 0, + "llm_failures": 0, "dedup_linked": 0, "dedup_review": 0} + + mock_db = MagicMock() + mock_seq = MagicMock() + mock_seq.fetchone.return_value = (0,) + + call_count = [0] + def side_effect(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + return mock_seq # _next_atomic_seq + return MagicMock() + mock_db.execute.side_effect = side_effect + + import asyncio + decomp = DecompositionPass(db=mock_db) + asyncio.get_event_loop().run_until_complete( + decomp._process_pass0b_control(obl, parsed, stats) + ) + + # _write_atomic_control is call #2: db.execute(text(...), {params}) + insert_call = mock_db.execute.call_args_list[1] + # positional args: (text_obj, params_dict) + insert_params = insert_call[0][1] + assert insert_params["severity"] == "medium" + + def test_test_obligation_gets_testing_category(self): + """Test obligations should get category='testing'.""" + obl = { + "oc_id": "oc-1", + "candidate_id": "OC-001-01", + "parent_uuid": "p-uuid", + "obligation_text": "MFA testen", + "action": "testen", + "object": "MFA", + "is_test": True, + "is_reporting": False, + "parent_title": "Auth", + "parent_category": "security", + "parent_citation": "", + "parent_severity": "high", + "parent_control_id": "AUTH-001", + "source_ref": "", + "trigger_type": "periodic", + "is_implementation_specific": False, + } + parsed = { + "title": "MFA-Wirksamkeit testen", + "objective": "Regelmäßig MFA testen", + "requirements": ["Testplan"], + "test_procedure": ["Durchführung"], + "evidence": ["Protokoll"], + "severity": "high", + "category": "security", # LLM says security + } + stats = {"controls_created": 0, "candidates_processed": 0, + "llm_failures": 0, "dedup_linked": 0, "dedup_review": 0} + + mock_db = MagicMock() + mock_seq = MagicMock() + mock_seq.fetchone.return_value = (0,) + + call_count = [0] + def side_effect(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + return mock_seq + return MagicMock() + mock_db.execute.side_effect = side_effect + + import asyncio + decomp = DecompositionPass(db=mock_db) + asyncio.get_event_loop().run_until_complete( + decomp._process_pass0b_control(obl, parsed, stats) + ) + + # _write_atomic_control is call #2: db.execute(text(...), {params}) + insert_call = mock_db.execute.call_args_list[1] + insert_params = insert_call[0][1] + assert insert_params["category"] == "testing"