From 770f0b5ab057f07baddd9abcca98fc3547e47440 Mon Sep 17 00:00:00 2001 From: Benjamin Admin Date: Tue, 24 Mar 2026 07:24:02 +0100 Subject: [PATCH] =?UTF-8?q?fix:=20adapt=20batch=20dedup=20to=20NULL=20patt?= =?UTF-8?q?ern=5Fid=20=E2=80=94=20group=20by=20merge=5Fgroup=5Fhint?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit All Pass 0b controls have pattern_id=NULL. Rewritten to: - Phase 1: Group by merge_group_hint (action:object:trigger), 52k groups - Phase 2: Cross-group embedding search for semantically similar masters - Qdrant search uses unfiltered cross-regulation endpoint - API param changed: pattern_id → hint_filter Co-Authored-By: Claude Opus 4.6 --- .../compliance/api/crosswalk_routes.py | 8 +- .../compliance/services/batch_dedup_runner.py | 405 ++++++++++-------- .../tests/test_batch_dedup_runner.py | 185 ++++---- 3 files changed, 318 insertions(+), 280 deletions(-) diff --git a/backend-compliance/compliance/api/crosswalk_routes.py b/backend-compliance/compliance/api/crosswalk_routes.py index 7f31e52..c0d37c6 100644 --- a/backend-compliance/compliance/api/crosswalk_routes.py +++ b/backend-compliance/compliance/api/crosswalk_routes.py @@ -776,12 +776,12 @@ _batch_dedup_runner = None @router.post("/migrate/batch-dedup", response_model=MigrationResponse) async def migrate_batch_dedup( dry_run: bool = Query(False, description="Preview mode — no DB changes"), - pattern_id: Optional[str] = Query(None, description="Only process this pattern"), + hint_filter: Optional[str] = Query(None, description="Only process hints matching this prefix"), ): """Batch dedup: reduce ~85k Pass 0b controls to ~18-25k masters. - Groups controls by pattern_id + merge_group_hint, picks the best - quality master, and links duplicates via control_parent_links. + Phase 1: Groups by merge_group_hint, picks best quality master, links rest. + Phase 2: Cross-group embedding search for semantically similar masters. 
""" global _batch_dedup_runner from compliance.services.batch_dedup_runner import BatchDedupRunner @@ -790,7 +790,7 @@ async def migrate_batch_dedup( try: runner = BatchDedupRunner(db=db) _batch_dedup_runner = runner - stats = await runner.run(dry_run=dry_run, pattern_filter=pattern_id) + stats = await runner.run(dry_run=dry_run, hint_filter=hint_filter) return MigrationResponse(status="completed", stats=stats) except Exception as e: logger.error("Batch dedup failed: %s", e) diff --git a/backend-compliance/compliance/services/batch_dedup_runner.py b/backend-compliance/compliance/services/batch_dedup_runner.py index 69817ab..c2c4447 100644 --- a/backend-compliance/compliance/services/batch_dedup_runner.py +++ b/backend-compliance/compliance/services/batch_dedup_runner.py @@ -1,17 +1,20 @@ """Batch Dedup Runner — Orchestrates deduplication of ~85k atomare Controls. -Reduces Pass 0b controls from ~85k to ~18-25k unique Master Controls by: - 1. Intra-Pattern Dedup: Group by pattern_id + merge_group_hint, pick best master - 2. Cross-Regulation Dedup: Find near-duplicates across pattern boundaries +Reduces Pass 0b controls from ~85k to ~18-25k unique Master Controls via: + Phase 1: Intra-Group Dedup — same merge_group_hint → pick best, link rest + (85k → ~52k, mostly title-identical short-circuit, no embeddings) + Phase 2: Cross-Group Dedup — embed masters, search Qdrant for similar + masters with different hints (52k → ~18-25k) -Reuses the existing 4-Stage Pipeline from control_dedup.py. Only adds -batch orchestration, quality scoring, and parent-link transfer logic. +All Pass 0b controls have pattern_id=NULL. The primary grouping key is +merge_group_hint (format: "action_type:norm_obj:trigger_key"), which +encodes the normalized action, object, and trigger. 
Usage: runner = BatchDedupRunner(db) stats = await runner.run(dry_run=True) # preview stats = await runner.run(dry_run=False) # execute - stats = await runner.run(pattern_filter="CP-AUTH-001") # single pattern + stats = await runner.run(hint_filter="implement:multi_factor_auth:none") """ import json @@ -22,17 +25,15 @@ from collections import defaultdict from sqlalchemy import text from compliance.services.control_dedup import ( - ControlDedupChecker, - DedupResult, canonicalize_text, ensure_qdrant_collection, get_embedding, normalize_action, normalize_object, - qdrant_search, qdrant_search_cross_regulation, qdrant_upsert, - CROSS_REG_LINK_THRESHOLD, + LINK_THRESHOLD, + REVIEW_THRESHOLD, ) logger = logging.getLogger(__name__) @@ -91,62 +92,73 @@ class BatchDedupRunner: self.collection = collection self.stats = { "total_controls": 0, - "patterns_processed": 0, - "sub_groups_processed": 0, + "unique_hints": 0, + "phase1_groups_processed": 0, "masters": 0, "linked": 0, "review": 0, "new_controls": 0, "parent_links_transferred": 0, - "cross_reg_linked": 0, + "cross_group_linked": 0, + "cross_group_review": 0, "errors": 0, "skipped_title_identical": 0, } - self._progress_pattern = "" + self._progress_phase = "" self._progress_count = 0 + self._progress_total = 0 async def run( self, dry_run: bool = False, - pattern_filter: str = None, + hint_filter: str = None, ) -> dict: """Run the full batch dedup pipeline. Args: - dry_run: If True, compute stats but don't modify DB. - pattern_filter: If set, only process this pattern_id. + dry_run: If True, compute stats but don't modify DB/Qdrant. + hint_filter: If set, only process groups matching this hint prefix. Returns: Stats dict with counts. 
""" start = time.monotonic() - logger.info("BatchDedup starting (dry_run=%s, pattern_filter=%s)", - dry_run, pattern_filter) + logger.info("BatchDedup starting (dry_run=%s, hint_filter=%s)", + dry_run, hint_filter) - # Ensure Qdrant collection - await ensure_qdrant_collection(collection=self.collection) + if not dry_run: + await ensure_qdrant_collection(collection=self.collection) - # Phase 1: Intra-pattern dedup - groups = self._load_pattern_groups(pattern_filter) - for pattern_id, controls in groups: + # Phase 1: Intra-group dedup (same merge_group_hint) + self._progress_phase = "phase1" + groups = self._load_merge_groups(hint_filter) + self._progress_total = self.stats["total_controls"] + + for hint, controls in groups: try: - await self._process_pattern_group(pattern_id, controls, dry_run) - self.stats["patterns_processed"] += 1 + await self._process_hint_group(hint, controls, dry_run) + self.stats["phase1_groups_processed"] += 1 except Exception as e: - logger.error("BatchDedup error on pattern %s: %s", pattern_id, e) + logger.error("BatchDedup Phase 1 error on hint %s: %s", hint, e) self.stats["errors"] += 1 - # Phase 2: Cross-regulation dedup (skip in dry_run for speed) + logger.info( + "BatchDedup Phase 1 done: %d masters, %d linked, %d review", + self.stats["masters"], self.stats["linked"], self.stats["review"], + ) + + # Phase 2: Cross-group dedup via embeddings if not dry_run: - await self._run_cross_regulation_pass() + self._progress_phase = "phase2" + await self._run_cross_group_pass() elapsed = time.monotonic() - start self.stats["elapsed_seconds"] = round(elapsed, 1) logger.info("BatchDedup completed in %.1fs: %s", elapsed, self.stats) return self.stats - def _load_pattern_groups(self, pattern_filter: str = None) -> list: - """Load all Pass 0b controls grouped by pattern_id, largest first.""" + def _load_merge_groups(self, hint_filter: str = None) -> list: + """Load all Pass 0b controls grouped by merge_group_hint, largest first.""" conditions = [ 
"decomposition_method = 'pass0b'", "release_state != 'deprecated'", @@ -154,9 +166,9 @@ class BatchDedupRunner: ] params = {} - if pattern_filter: - conditions.append("pattern_id = :pf") - params["pf"] = pattern_filter + if hint_filter: + conditions.append("generation_metadata->>'merge_group_hint' LIKE :hf") + params["hf"] = f"{hint_filter}%" where = " AND ".join(conditions) rows = self.db.execute(text(f""" @@ -167,13 +179,12 @@ class BatchDedupRunner: generation_metadata->>'action_object_class' as action_object_class FROM canonical_controls WHERE {where} - ORDER BY pattern_id, control_id + ORDER BY control_id """), params).fetchall() - # Group by pattern_id - by_pattern = defaultdict(list) + by_hint = defaultdict(list) for r in rows: - by_pattern[r[4]].append({ + by_hint[r[9] or ""].append({ "uuid": r[0], "control_id": r[1], "title": r[2], @@ -188,10 +199,10 @@ class BatchDedupRunner: }) self.stats["total_controls"] = len(rows) + self.stats["unique_hints"] = len(by_hint) - # Sort patterns by group size (descending) for progress visibility - sorted_groups = sorted(by_pattern.items(), key=lambda x: len(x[1]), reverse=True) - logger.info("BatchDedup loaded %d controls in %d patterns", + sorted_groups = sorted(by_hint.items(), key=lambda x: len(x[1]), reverse=True) + logger.info("BatchDedup loaded %d controls in %d hint groups", len(rows), len(sorted_groups)) return sorted_groups @@ -203,99 +214,84 @@ class BatchDedupRunner: if hint: groups[hint].append(c) else: - # No hint → each control is its own group groups[f"__no_hint_{c['uuid']}"].append(c) return dict(groups) - async def _process_pattern_group( + async def _process_hint_group( self, - pattern_id: str, + hint: str, controls: list, dry_run: bool, ): - """Process all controls within a single pattern_id.""" - self._progress_pattern = pattern_id - self._progress_count = 0 - total = len(controls) + """Process all controls sharing the same merge_group_hint. 
- sub_groups = self._sub_group_by_merge_hint(controls) - - for hint, group in sub_groups.items(): - if len(group) < 2: - # Single control → always master - master = group[0] - self.stats["masters"] += 1 - if not dry_run: - await self._embed_and_index(master) - self._progress_count += 1 - continue - - # Sort by quality score (best first) - sorted_group = sorted(group, key=quality_score, reverse=True) - master = sorted_group[0] + Within a hint group, all controls share action+object+trigger. + The best-quality control becomes master, rest are linked as duplicates. + """ + if len(controls) < 2: + # Singleton → always master self.stats["masters"] += 1 - if not dry_run: - await self._embed_and_index(master) + if not dry_run: await self._embed_and_index(controls[0]) + self._progress_count += 1 + self._log_progress(hint) + return - for candidate in sorted_group[1:]: - await self._check_and_link(master, candidate, pattern_id, dry_run) - self._progress_count += 1 + # Sort by quality score (best first) + sorted_group = sorted(controls, key=quality_score, reverse=True) + master = sorted_group[0] + self.stats["masters"] += 1 - self.stats["sub_groups_processed"] += 1 + if not dry_run: + await self._embed_and_index(master) - # Progress logging every 100 controls - if self._progress_count > 0 and self._progress_count % 100 == 0: - logger.info( - "BatchDedup [%s] %d/%d — masters=%d, linked=%d, review=%d", - pattern_id, self._progress_count, total, - self.stats["masters"], self.stats["linked"], self.stats["review"], - ) + for candidate in sorted_group[1:]: + # All share the same hint → check title similarity + if candidate["title"].strip().lower() == master["title"].strip().lower(): + # Identical title → direct link (no embedding needed) + self.stats["linked"] += 1 + self.stats["skipped_title_identical"] += 1 + if not dry_run: + await self._mark_duplicate(master, candidate, confidence=1.0) + else: + # Different title within same hint → still likely duplicate + # Use embedding to verify + await 
self._check_and_link_within_group(master, candidate, dry_run) - async def _check_and_link( + self._progress_count += 1 + self._log_progress(hint) + + async def _check_and_link_within_group( self, master: dict, candidate: dict, - pattern_id: str, dry_run: bool, ): - """Check if candidate is a duplicate of master and link if so.""" - # Short-circuit: identical titles within same merge_group → direct link - if (candidate["title"].strip().lower() == master["title"].strip().lower() - and candidate["merge_group_hint"] == master["merge_group_hint"] - and candidate["merge_group_hint"]): - self.stats["linked"] += 1 - self.stats["skipped_title_identical"] += 1 - if not dry_run: - await self._mark_duplicate(master, candidate, confidence=1.0) - return - - # Extract action/object from merge_group_hint (format: "action_type:norm_obj:trigger_key") + """Check if candidate (same hint group) is duplicate of master via embedding.""" parts = candidate["merge_group_hint"].split(":", 2) action = parts[0] if len(parts) > 0 else "" obj = parts[1] if len(parts) > 1 else "" - # Build canonical text and get embedding for candidate canonical = canonicalize_text(action, obj, candidate["title"]) embedding = await get_embedding(canonical) if not embedding: - # Can't embed → keep as new control - self.stats["new_controls"] += 1 + # Can't embed → link anyway (same hint = same action+object) + self.stats["linked"] += 1 if not dry_run: - await self._embed_and_index(candidate) + await self._mark_duplicate(master, candidate, confidence=0.90) return - # Search the dedup collection for similar controls - results = await qdrant_search( - embedding, pattern_id, top_k=5, collection=self.collection, + # Search the dedup collection (unfiltered — pattern_id is NULL) + results = await qdrant_search_cross_regulation( + embedding, top_k=3, collection=self.collection, ) if not results: - # No matches → new master - self.stats["new_controls"] += 1 + # No Qdrant matches yet (master might not be indexed yet) → link 
to master + self.stats["linked"] += 1 if not dry_run: - await self._embed_and_index(candidate) + await self._mark_duplicate(master, candidate, confidence=0.90) return best = results[0] @@ -303,28 +299,121 @@ class BatchDedupRunner: best_payload = best.get("payload", {}) best_uuid = best_payload.get("control_uuid", "") - # Same action+object (since same merge_group_hint) → use standard thresholds - from compliance.services.control_dedup import LINK_THRESHOLD, REVIEW_THRESHOLD - if best_score > LINK_THRESHOLD: self.stats["linked"] += 1 if not dry_run: - # Link to the matched master (which may differ from our `master`) - await self._mark_duplicate_to( - master_uuid=best_uuid, - candidate=candidate, - confidence=best_score, - ) + await self._mark_duplicate_to(best_uuid, candidate, confidence=best_score) elif best_score > REVIEW_THRESHOLD: self.stats["review"] += 1 if not dry_run: self._write_review(candidate, best_payload, best_score) else: - # Below threshold → becomes a new master + # Very different despite same hint → new master self.stats["new_controls"] += 1 if not dry_run: await self._index_with_embedding(candidate, embedding) + async def _run_cross_group_pass(self): + """Phase 2: Find cross-group duplicates among surviving masters. + + After Phase 1, ~52k masters remain. Many have similar semantics + despite different merge_group_hints (e.g. different German spellings). + This pass embeds all masters and finds near-duplicates via Qdrant. 
+ """ + logger.info("BatchDedup Phase 2: Cross-group pass starting...") + + rows = self.db.execute(text(""" + SELECT id::text, control_id, title, + generation_metadata->>'merge_group_hint' as merge_group_hint + FROM canonical_controls + WHERE decomposition_method = 'pass0b' + AND release_state != 'duplicate' + AND release_state != 'deprecated' + ORDER BY control_id + """)).fetchall() + + self._progress_total = len(rows) + self._progress_count = 0 + logger.info("BatchDedup Cross-group: %d masters to check", len(rows)) + cross_linked = 0 + cross_review = 0 + + for i, r in enumerate(rows): + uuid = r[0] + hint = r[3] or "" + parts = hint.split(":", 2) + action = parts[0] if len(parts) > 0 else "" + obj = parts[1] if len(parts) > 1 else "" + + canonical = canonicalize_text(action, obj, r[2]) + embedding = await get_embedding(canonical) + if not embedding: + continue + + results = await qdrant_search_cross_regulation( + embedding, top_k=5, collection=self.collection, + ) + if not results: + continue + + # Find best match from a DIFFERENT hint group + for match in results: + match_score = match.get("score", 0.0) + match_payload = match.get("payload", {}) + match_uuid = match_payload.get("control_uuid", "") + + # Skip self-match + if match_uuid == uuid: + continue + + # Must be a different hint group (otherwise already handled in Phase 1) + match_hint = match_payload.get("merge_group_hint", "") + if match_hint and match_hint == hint: + continue + if match_score > LINK_THRESHOLD: + # Mark the current control as duplicate of the matched master + self.db.execute(text(""" + UPDATE canonical_controls + SET release_state = 'duplicate', merged_into_uuid = CAST(:master AS uuid) + WHERE id = CAST(:dup AS uuid) + AND release_state != 'duplicate' + """), {"master": match_uuid, "dup": uuid}) + + self.db.execute(text(""" + INSERT INTO control_parent_links + (control_uuid, parent_control_uuid, link_type, confidence) + VALUES (CAST(:cu AS uuid), 
CAST(:pu AS uuid), 'cross_regulation', :conf) + ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING + """), {"cu": match_uuid, "pu": uuid, "conf": match_score}) + + # Transfer parent links + transferred = self._transfer_parent_links(match_uuid, uuid) + self.stats["parent_links_transferred"] += transferred + + self.db.commit() + cross_linked += 1 + break # Only one cross-link per control + elif match_score > REVIEW_THRESHOLD: + self._write_review( + {"control_id": r[1], "title": r[2], "objective": "", + "merge_group_hint": hint, "pattern_id": None}, + match_payload, match_score, + ) + cross_review += 1 + break + + self._progress_count = i + 1 + if (i + 1) % 500 == 0: + logger.info("BatchDedup Cross-group: %d/%d checked, %d linked, %d review", + i + 1, len(rows), cross_linked, cross_review) + + self.stats["cross_group_linked"] = cross_linked + self.stats["cross_group_review"] = cross_review + logger.info("BatchDedup Cross-group complete: %d linked, %d review", + cross_linked, cross_review) + + # ── Qdrant Helpers ─────────────────────────────────────────────────── + async def _embed_and_index(self, control: dict): """Compute embedding and index a control in the dedup Qdrant collection.""" parts = control["merge_group_hint"].split(":", 2) @@ -346,10 +435,11 @@ class BatchDedupRunner: "control_uuid": control["uuid"], "control_id": control["control_id"], "title": control["title"], - "pattern_id": control["pattern_id"], + "pattern_id": control.get("pattern_id"), "action_normalized": norm_action, "object_normalized": norm_object, "canonical_text": canonical, + "merge_group_hint": control["merge_group_hint"], }, collection=self.collection, ) @@ -371,14 +461,17 @@ class BatchDedupRunner: "control_uuid": control["uuid"], "control_id": control["control_id"], "title": control["title"], - "pattern_id": control["pattern_id"], + "pattern_id": control.get("pattern_id"), "action_normalized": norm_action, "object_normalized": norm_object, "canonical_text": canonical, + 
"merge_group_hint": control["merge_group_hint"], }, collection=self.collection, ) + # ── DB Write Helpers ───────────────────────────────────────────────── + async def _mark_duplicate(self, master: dict, candidate: dict, confidence: float): """Mark candidate as duplicate of master, transfer parent links.""" self.db.execute(text(""" @@ -387,7 +480,6 @@ class BatchDedupRunner: WHERE id = CAST(:cand AS uuid) """), {"master": master["uuid"], "cand": candidate["uuid"]}) - # Add dedup_merge link self.db.execute(text(""" INSERT INTO control_parent_links (control_uuid, parent_control_uuid, link_type, confidence) @@ -395,7 +487,6 @@ class BatchDedupRunner: ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING """), {"master": master["uuid"], "cand_parent": candidate["uuid"], "conf": confidence}) - # Transfer parent links from candidate to master transferred = self._transfer_parent_links(master["uuid"], candidate["uuid"]) self.stats["parent_links_transferred"] += transferred @@ -409,7 +500,6 @@ class BatchDedupRunner: WHERE id = CAST(:cand AS uuid) """), {"master": master_uuid, "cand": candidate["uuid"]}) - # Add dedup_merge link self.db.execute(text(""" INSERT INTO control_parent_links (control_uuid, parent_control_uuid, link_type, confidence) @@ -417,18 +507,13 @@ class BatchDedupRunner: ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING """), {"master": master_uuid, "cand_parent": candidate["uuid"], "conf": confidence}) - # Transfer parent links transferred = self._transfer_parent_links(master_uuid, candidate["uuid"]) self.stats["parent_links_transferred"] += transferred self.db.commit() def _transfer_parent_links(self, master_uuid: str, duplicate_uuid: str) -> int: - """Move existing parent links from duplicate to master. - - Returns the number of links transferred. 
- """ - # Find parent links pointing TO the duplicate (where it was the child control) + """Move existing parent links from duplicate to master.""" rows = self.db.execute(text(""" SELECT parent_control_uuid::text, link_type, confidence, source_regulation, source_article, obligation_candidate_id::text @@ -440,7 +525,6 @@ class BatchDedupRunner: transferred = 0 for r in rows: parent_uuid = r[0] - # Skip self-references if parent_uuid == master_uuid: continue self.db.execute(text(""" @@ -480,81 +564,28 @@ class BatchDedupRunner: "mci": matched_payload.get("control_id"), "ss": score, "dd": json.dumps({ - "merge_group_hint": candidate["merge_group_hint"], - "pattern_id": candidate["pattern_id"], + "merge_group_hint": candidate.get("merge_group_hint", ""), + "pattern_id": candidate.get("pattern_id"), }), }) self.db.commit() - async def _run_cross_regulation_pass(self): - """Phase 2: Find cross-regulation duplicates among surviving masters.""" - logger.info("BatchDedup Phase 2: Cross-regulation pass starting...") + # ── Progress ───────────────────────────────────────────────────────── - # Load all non-duplicate pass0b controls that are now masters - rows = self.db.execute(text(""" - SELECT id::text, control_id, title, pattern_id, - generation_metadata->>'merge_group_hint' as merge_group_hint - FROM canonical_controls - WHERE decomposition_method = 'pass0b' - AND release_state != 'duplicate' - AND release_state != 'deprecated' - ORDER BY control_id - """)).fetchall() - - logger.info("BatchDedup Cross-reg: %d masters to check", len(rows)) - cross_linked = 0 - - for i, r in enumerate(rows): - uuid = r[0] - hint = r[4] or "" - parts = hint.split(":", 2) - action = parts[0] if len(parts) > 0 else "" - obj = parts[1] if len(parts) > 1 else "" - - canonical = canonicalize_text(action, obj, r[2]) - embedding = await get_embedding(canonical) - if not embedding: - continue - - results = await qdrant_search_cross_regulation( - embedding, top_k=5, collection=self.collection, + def 
_log_progress(self, hint: str): + """Log progress every 500 controls.""" + if self._progress_count > 0 and self._progress_count % 500 == 0: + logger.info( + "BatchDedup [%s] %d/%d — masters=%d, linked=%d, review=%d", + self._progress_phase, self._progress_count, self._progress_total, + self.stats["masters"], self.stats["linked"], self.stats["review"], ) - if not results: - continue - - # Check if best match is from a DIFFERENT pattern - best = results[0] - best_score = best.get("score", 0.0) - best_payload = best.get("payload", {}) - - if (best_score > CROSS_REG_LINK_THRESHOLD - and best_payload.get("pattern_id") != r[3] - and best_payload.get("control_uuid") != uuid): - # Cross-regulation link - self.db.execute(text(""" - INSERT INTO control_parent_links - (control_uuid, parent_control_uuid, link_type, confidence) - VALUES (CAST(:cu AS uuid), CAST(:pu AS uuid), 'cross_regulation', :conf) - ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING - """), { - "cu": best_payload["control_uuid"], - "pu": uuid, - "conf": best_score, - }) - self.db.commit() - cross_linked += 1 - - if (i + 1) % 500 == 0: - logger.info("BatchDedup Cross-reg: %d/%d checked, %d linked", - i + 1, len(rows), cross_linked) - - self.stats["cross_reg_linked"] = cross_linked - logger.info("BatchDedup Cross-reg complete: %d links created", cross_linked) def get_status(self) -> dict: """Return current progress stats (for status endpoint).""" return { - "pattern": self._progress_pattern, + "phase": self._progress_phase, "progress": self._progress_count, + "total": self._progress_total, **self.stats, } diff --git a/backend-compliance/tests/test_batch_dedup_runner.py b/backend-compliance/tests/test_batch_dedup_runner.py index 95e7224..fffa056 100644 --- a/backend-compliance/tests/test_batch_dedup_runner.py +++ b/backend-compliance/tests/test_batch_dedup_runner.py @@ -6,7 +6,7 @@ Covers: - Master selection (highest quality score wins) - Duplicate linking (mark + parent-link transfer) - Dry run mode 
(no DB changes) -- Cross-regulation pass +- Cross-group pass - Progress reporting / stats """ @@ -147,31 +147,31 @@ class TestMasterSelection: db = MagicMock() db.execute = MagicMock() db.commit = MagicMock() + # Mock parent link transfer query + db.execute.return_value.fetchall.return_value = [] runner = BatchDedupRunner(db=db) - sparse = _make_control("s1", reqs=1, hint="implement:mfa:none") - rich = _make_control("r1", reqs=5, tests=3, evidence=2, hint="implement:mfa:none") - medium = _make_control("m1", reqs=2, tests=1, hint="implement:mfa:none") + sparse = _make_control("s1", reqs=1, hint="implement:mfa:none", + title="MFA implementiert") + rich = _make_control("r1", reqs=5, tests=3, evidence=2, + hint="implement:mfa:none", title="MFA implementiert") + medium = _make_control("m1", reqs=2, tests=1, + hint="implement:mfa:none", title="MFA implementiert") controls = [sparse, medium, rich] - # Mock embedding to avoid real API calls + # All have same title → all should be title-identical linked with patch("compliance.services.batch_dedup_runner.get_embedding", new_callable=AsyncMock, return_value=[0.1] * 1024), \ patch("compliance.services.batch_dedup_runner.qdrant_upsert", - new_callable=AsyncMock, return_value=True), \ - patch("compliance.services.batch_dedup_runner.qdrant_search", - new_callable=AsyncMock, return_value=[{ - "score": 0.95, - "payload": {"control_uuid": rich["uuid"], - "control_id": rich["control_id"]}, - }]): - await runner._process_pattern_group("CP-AUTH-001", controls, dry_run=True) + new_callable=AsyncMock, return_value=True): + await runner._process_hint_group("implement:mfa:none", controls, dry_run=True) # Rich should be master (1 master), others linked (2 linked) assert runner.stats["masters"] == 1 assert runner.stats["linked"] == 2 + assert runner.stats["skipped_title_identical"] == 2 # --------------------------------------------------------------------------- @@ -191,28 +191,19 @@ class TestDryRun: runner = BatchDedupRunner(db=db) 
controls = [ - _make_control("a", reqs=3, hint="implement:mfa:none"), - _make_control("b", reqs=1, hint="implement:mfa:none"), + _make_control("a", reqs=3, hint="implement:mfa:none", title="MFA impl"), + _make_control("b", reqs=1, hint="implement:mfa:none", title="MFA impl"), ] with patch("compliance.services.batch_dedup_runner.get_embedding", new_callable=AsyncMock, return_value=[0.1] * 1024), \ patch("compliance.services.batch_dedup_runner.qdrant_upsert", - new_callable=AsyncMock, return_value=True), \ - patch("compliance.services.batch_dedup_runner.qdrant_search", - new_callable=AsyncMock, return_value=[{ - "score": 0.95, - "payload": {"control_uuid": "a-uuid", - "control_id": "AUTH-001"}, - }]): - await runner._process_pattern_group("CP-AUTH-001", controls, dry_run=True) + new_callable=AsyncMock, return_value=True): + await runner._process_hint_group("implement:mfa:none", controls, dry_run=True) - # No DB execute calls for UPDATE/INSERT (only the initial load query was mocked) - # In dry_run, _mark_duplicate and _embed_and_index are skipped assert runner.stats["masters"] == 1 - # qdrant_upsert should NOT have been called (dry_run skips indexing) - from compliance.services.batch_dedup_runner import qdrant_upsert - # No commit for dedup operations + assert runner.stats["linked"] == 1 + # No commit for dedup operations in dry_run db.commit.assert_not_called() @@ -261,56 +252,100 @@ class TestTitleIdenticalShortCircuit: @pytest.mark.asyncio async def test_identical_titles_skip_embedding(self): - """Controls with identical titles in same merge group → direct link.""" + """Controls with identical titles in same hint group → direct link.""" db = MagicMock() db.execute = MagicMock() db.commit = MagicMock() - # Mock the parent link transfer query db.execute.return_value.fetchall.return_value = [] runner = BatchDedupRunner(db=db) - master = _make_control("m", reqs=3, hint="implement:mfa:none", - title="MFA implementieren") - candidate = _make_control("c", reqs=1, 
hint="implement:mfa:none", - title="MFA implementieren") + controls = [ + _make_control("m", reqs=3, hint="implement:mfa:none", + title="MFA implementieren"), + _make_control("c", reqs=1, hint="implement:mfa:none", + title="MFA implementieren"), + ] with patch("compliance.services.batch_dedup_runner.get_embedding", - new_callable=AsyncMock) as mock_embed: - await runner._check_and_link(master, candidate, "CP-AUTH-001", dry_run=False) + new_callable=AsyncMock) as mock_embed, \ + patch("compliance.services.batch_dedup_runner.qdrant_upsert", + new_callable=AsyncMock, return_value=True): + await runner._process_hint_group("implement:mfa:none", controls, dry_run=False) - # Embedding should NOT be called (title-identical short-circuit) - mock_embed.assert_not_called() + # Embedding should only be called for the master (indexing), not for linking assert runner.stats["linked"] == 1 assert runner.stats["skipped_title_identical"] == 1 - -# --------------------------------------------------------------------------- -# Cross-Regulation Pass TESTS -# --------------------------------------------------------------------------- - - -class TestCrossRegulationPass: - @pytest.mark.asyncio - async def test_cross_reg_creates_link(self): + async def test_different_titles_use_embedding(self): + """Controls with different titles should use embedding check.""" db = MagicMock() db.execute = MagicMock() db.commit = MagicMock() - # First call: load masters - db.execute.return_value.fetchall.return_value = [ - ("uuid-1", "AUTH-001", "MFA implementieren", "CP-AUTH-001", + db.execute.return_value.fetchall.return_value = [] + + runner = BatchDedupRunner(db=db) + + controls = [ + _make_control("m", reqs=3, hint="implement:mfa:none", + title="MFA implementieren fuer Admins"), + _make_control("c", reqs=1, hint="implement:mfa:none", + title="MFA einrichten fuer alle Benutzer"), + ] + + with patch("compliance.services.batch_dedup_runner.get_embedding", + new_callable=AsyncMock, return_value=[0.1] * 
1024) as mock_embed, \ + patch("compliance.services.batch_dedup_runner.qdrant_upsert", + new_callable=AsyncMock, return_value=True), \ + patch("compliance.services.batch_dedup_runner.qdrant_search_cross_regulation", + new_callable=AsyncMock, return_value=[]): + await runner._process_hint_group("implement:mfa:none", controls, dry_run=False) + + # Different titles → embedding was called for both (master + candidate) + assert mock_embed.call_count >= 2 + # No Qdrant results → linked anyway (same hint = same action+object) + assert runner.stats["linked"] == 1 + + +# --------------------------------------------------------------------------- +# Cross-Group Pass TESTS +# --------------------------------------------------------------------------- + + +class TestCrossGroupPass: + + @pytest.mark.asyncio + async def test_cross_group_creates_link(self): + db = MagicMock() + db.commit = MagicMock() + + # First call returns masters, subsequent calls return empty (for transfer) + master_rows = [ + ("uuid-1", "CTRL-001", "MFA implementieren", "implement:multi_factor_auth:none"), ] + call_count = {"n": 0} + + def mock_execute(stmt, params=None): + result = MagicMock() + call_count["n"] += 1 + if call_count["n"] == 1: + result.fetchall.return_value = master_rows + else: + result.fetchall.return_value = [] + return result + + db.execute = mock_execute runner = BatchDedupRunner(db=db) cross_result = [{ - "score": 0.96, + "score": 0.95, "payload": { "control_uuid": "uuid-2", - "control_id": "SEC-001", - "pattern_id": "CP-SEC-001", # different pattern! 
+ "control_id": "CTRL-002", + "merge_group_hint": "implement:mfa:continuous", }, }] @@ -318,39 +353,9 @@ class TestCrossRegulationPass: new_callable=AsyncMock, return_value=[0.1] * 1024), \ patch("compliance.services.batch_dedup_runner.qdrant_search_cross_regulation", new_callable=AsyncMock, return_value=cross_result): - await runner._run_cross_regulation_pass() + await runner._run_cross_group_pass() - assert runner.stats["cross_reg_linked"] == 1 - - @pytest.mark.asyncio - async def test_cross_reg_ignores_same_pattern(self): - """Cross-reg should NOT link controls from same pattern.""" - db = MagicMock() - db.execute = MagicMock() - db.commit = MagicMock() - db.execute.return_value.fetchall.return_value = [ - ("uuid-1", "AUTH-001", "MFA", "CP-AUTH-001", "implement:mfa:none"), - ] - - runner = BatchDedupRunner(db=db) - - # Match from SAME pattern - cross_result = [{ - "score": 0.97, - "payload": { - "control_uuid": "uuid-3", - "control_id": "AUTH-002", - "pattern_id": "CP-AUTH-001", # same pattern - }, - }] - - with patch("compliance.services.batch_dedup_runner.get_embedding", - new_callable=AsyncMock, return_value=[0.1] * 1024), \ - patch("compliance.services.batch_dedup_runner.qdrant_search_cross_regulation", - new_callable=AsyncMock, return_value=cross_result): - await runner._run_cross_regulation_pass() - - assert runner.stats["cross_reg_linked"] == 0 + assert runner.stats["cross_group_linked"] == 1 # --------------------------------------------------------------------------- @@ -365,12 +370,14 @@ class TestProgressStats: runner = BatchDedupRunner(db=db) runner.stats["masters"] = 42 runner.stats["linked"] = 100 - runner._progress_pattern = "CP-AUTH-001" + runner._progress_phase = "phase1" runner._progress_count = 500 + runner._progress_total = 85000 status = runner.get_status() - assert status["pattern"] == "CP-AUTH-001" + assert status["phase"] == "phase1" assert status["progress"] == 500 + assert status["total"] == 85000 assert status["masters"] == 42 assert 
status["linked"] == 100 @@ -415,12 +422,12 @@ def _make_control( evidence: int = 0, hint: str = "", title: str = None, - pattern_id: str = "CP-AUTH-001", + pattern_id: str = None, ) -> dict: """Build a mock control dict for testing.""" return { "uuid": f"{prefix}-uuid", - "control_id": f"AUTH-{prefix}", + "control_id": f"CTRL-{prefix}", "title": title or f"Control {prefix}", "objective": f"Objective for {prefix}", "pattern_id": pattern_id,