From 35784c35ebc552237d77c328b41f7f402e1d6ac2 Mon Sep 17 00:00:00 2001 From: Benjamin Admin Date: Tue, 24 Mar 2026 07:06:38 +0100 Subject: [PATCH] =?UTF-8?q?feat:=20Batch=20Dedup=20Runner=20=E2=80=94=2085?= =?UTF-8?q?k=E2=86=92~18-25k=20Master=20Controls?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds batch orchestration for deduplicating ~85k Pass 0b atomic controls into ~18-25k unique masters with M:N parent linking. New files: - migrations/078_batch_dedup.sql: merged_into_uuid column, perf indexes, link_type CHECK extended for cross_regulation - batch_dedup_runner.py: BatchDedupRunner with quality scoring, merge-hint grouping, title-identical short-circuit, parent-link transfer, and cross-regulation pass - tests/test_batch_dedup_runner.py: 21 tests (all passing) Modified: - control_dedup.py: optional collection param on Qdrant functions - crosswalk_routes.py: POST/GET batch-dedup endpoints Co-Authored-By: Claude Opus 4.6 --- .../compliance/api/crosswalk_routes.py | 69 +++ .../compliance/services/batch_dedup_runner.py | 560 ++++++++++++++++++ .../compliance/services/control_dedup.py | 32 +- .../migrations/078_batch_dedup.sql | 42 ++ .../tests/test_batch_dedup_runner.py | 433 ++++++++++++++ 5 files changed, 1126 insertions(+), 10 deletions(-) create mode 100644 backend-compliance/compliance/services/batch_dedup_runner.py create mode 100644 backend-compliance/migrations/078_batch_dedup.sql create mode 100644 backend-compliance/tests/test_batch_dedup_runner.py diff --git a/backend-compliance/compliance/api/crosswalk_routes.py b/backend-compliance/compliance/api/crosswalk_routes.py index 3d5f754..7f31e52 100644 --- a/backend-compliance/compliance/api/crosswalk_routes.py +++ b/backend-compliance/compliance/api/crosswalk_routes.py @@ -764,6 +764,75 @@ async def decomposition_status(): db.close() +# ============================================================================= +# BATCH DEDUP ENDPOINTS +# 
@router.post("/migrate/batch-dedup", response_model=MigrationResponse)
async def migrate_batch_dedup(
    dry_run: bool = Query(False, description="Preview mode — no DB changes"),
    pattern_id: Optional[str] = Query(None, description="Only process this pattern"),
):
    """Batch dedup: reduce ~85k Pass 0b controls to ~18-25k masters.

    Groups controls by pattern_id + merge_group_hint, picks the best
    quality master, and links duplicates via control_parent_links.

    Args:
        dry_run: Forwarded to ``BatchDedupRunner.run``; preview without DB changes.
        pattern_id: Forwarded as ``pattern_filter``; restricts the run to one pattern.

    Returns:
        ``MigrationResponse`` with status ``"completed"`` and the runner's stats.

    Raises:
        HTTPException: 409 if a batch dedup run is already registered in this
            process, 500 if the run itself fails.
    """
    global _batch_dedup_runner
    from compliance.services.batch_dedup_runner import BatchDedupRunner

    # The module-level reference doubles as the "running" flag for the status
    # endpoint. Without this guard a second request would overwrite it, and the
    # first run's ``finally`` would clear it while the second is still working.
    if _batch_dedup_runner is not None:
        raise HTTPException(status_code=409, detail="Batch dedup is already running")

    db = SessionLocal()
    try:
        runner = BatchDedupRunner(db=db)
        _batch_dedup_runner = runner
        stats = await runner.run(dry_run=dry_run, pattern_filter=pattern_id)
        return MigrationResponse(status="completed", stats=stats)
    except HTTPException:
        # Never re-wrap an HTTP error as a generic 500.
        raise
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Batch dedup failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        _batch_dedup_runner = None
        db.close()
def _count_items(value) -> int:
    """Return how many entries a list-valued control field holds.

    Accepts an already-parsed list/tuple or a JSON string. Anything else
    (None, unparsable text, JSON scalars or objects) counts as zero entries.
    """
    if isinstance(value, str):
        try:
            value = json.loads(value)
        except (json.JSONDecodeError, TypeError):
            return 0
    # Only sequences of entries are counted; a JSON scalar such as "5" would
    # otherwise make len() raise, and a JSON object would be scored by its keys.
    return len(value) if isinstance(value, (list, tuple)) else 0


def quality_score(control: dict) -> float:
    """Score a control by richness of requirements, tests, evidence, and objective.

    Higher score = better candidate for master control.

    Args:
        control: Control dict. ``requirements``, ``test_procedure`` and
            ``evidence`` may be lists or JSON-encoded strings; ``objective``
            is free text. Missing or None fields contribute nothing.

    Returns:
        2.0 per requirement + 1.5 per test step + 1.0 per evidence item
        + objective length / 200, with the objective part capped at 3.0.
    """
    weighted_fields = (
        ("requirements", 2.0),
        ("test_procedure", 1.5),
        ("evidence", 1.0),
    )

    score = 0.0
    for field, weight in weighted_fields:
        score += _count_items(control.get(field)) * weight

    objective = control.get("objective") or ""
    score += min(len(objective) / 200, 3.0)

    return score
+ """ + start = time.monotonic() + logger.info("BatchDedup starting (dry_run=%s, pattern_filter=%s)", + dry_run, pattern_filter) + + # Ensure Qdrant collection + await ensure_qdrant_collection(collection=self.collection) + + # Phase 1: Intra-pattern dedup + groups = self._load_pattern_groups(pattern_filter) + for pattern_id, controls in groups: + try: + await self._process_pattern_group(pattern_id, controls, dry_run) + self.stats["patterns_processed"] += 1 + except Exception as e: + logger.error("BatchDedup error on pattern %s: %s", pattern_id, e) + self.stats["errors"] += 1 + + # Phase 2: Cross-regulation dedup (skip in dry_run for speed) + if not dry_run: + await self._run_cross_regulation_pass() + + elapsed = time.monotonic() - start + self.stats["elapsed_seconds"] = round(elapsed, 1) + logger.info("BatchDedup completed in %.1fs: %s", elapsed, self.stats) + return self.stats + + def _load_pattern_groups(self, pattern_filter: str = None) -> list: + """Load all Pass 0b controls grouped by pattern_id, largest first.""" + conditions = [ + "decomposition_method = 'pass0b'", + "release_state != 'deprecated'", + "release_state != 'duplicate'", + ] + params = {} + + if pattern_filter: + conditions.append("pattern_id = :pf") + params["pf"] = pattern_filter + + where = " AND ".join(conditions) + rows = self.db.execute(text(f""" + SELECT id::text, control_id, title, objective, + pattern_id, requirements::text, test_procedure::text, + evidence::text, release_state, + generation_metadata->>'merge_group_hint' as merge_group_hint, + generation_metadata->>'action_object_class' as action_object_class + FROM canonical_controls + WHERE {where} + ORDER BY pattern_id, control_id + """), params).fetchall() + + # Group by pattern_id + by_pattern = defaultdict(list) + for r in rows: + by_pattern[r[4]].append({ + "uuid": r[0], + "control_id": r[1], + "title": r[2], + "objective": r[3], + "pattern_id": r[4], + "requirements": r[5], + "test_procedure": r[6], + "evidence": r[7], + 
"release_state": r[8], + "merge_group_hint": r[9] or "", + "action_object_class": r[10] or "", + }) + + self.stats["total_controls"] = len(rows) + + # Sort patterns by group size (descending) for progress visibility + sorted_groups = sorted(by_pattern.items(), key=lambda x: len(x[1]), reverse=True) + logger.info("BatchDedup loaded %d controls in %d patterns", + len(rows), len(sorted_groups)) + return sorted_groups + + def _sub_group_by_merge_hint(self, controls: list) -> dict: + """Group controls by merge_group_hint composite key.""" + groups = defaultdict(list) + for c in controls: + hint = c["merge_group_hint"] + if hint: + groups[hint].append(c) + else: + # No hint → each control is its own group + groups[f"__no_hint_{c['uuid']}"].append(c) + return dict(groups) + + async def _process_pattern_group( + self, + pattern_id: str, + controls: list, + dry_run: bool, + ): + """Process all controls within a single pattern_id.""" + self._progress_pattern = pattern_id + self._progress_count = 0 + total = len(controls) + + sub_groups = self._sub_group_by_merge_hint(controls) + + for hint, group in sub_groups.items(): + if len(group) < 2: + # Single control → always master + master = group[0] + self.stats["masters"] += 1 + if not dry_run: + await self._embed_and_index(master) + self._progress_count += 1 + continue + + # Sort by quality score (best first) + sorted_group = sorted(group, key=quality_score, reverse=True) + master = sorted_group[0] + self.stats["masters"] += 1 + + if not dry_run: + await self._embed_and_index(master) + + for candidate in sorted_group[1:]: + await self._check_and_link(master, candidate, pattern_id, dry_run) + self._progress_count += 1 + + self.stats["sub_groups_processed"] += 1 + + # Progress logging every 100 controls + if self._progress_count > 0 and self._progress_count % 100 == 0: + logger.info( + "BatchDedup [%s] %d/%d — masters=%d, linked=%d, review=%d", + pattern_id, self._progress_count, total, + self.stats["masters"], 
async def _check_and_link(
    self,
    master: dict,
    candidate: dict,
    pattern_id: str,
    dry_run: bool,
):
    """Classify ``candidate`` against the dedup index and link it if it is a duplicate.

    Exactly one outcome applies, and each bumps a counter in ``self.stats``:

    * Same title (case/whitespace-insensitive) and same non-empty
      merge_group_hint as ``master``: ``linked`` and ``skipped_title_identical``;
      the candidate is marked a duplicate of ``master`` with confidence 1.0.
    * Best usable Qdrant hit scores above LINK_THRESHOLD: ``linked``; the
      candidate is marked a duplicate of the *matched* control.
    * Best usable hit scores above REVIEW_THRESHOLD: ``review``; a review row
      is written.
    * Otherwise (no embedding, no usable hit, low score): ``new_controls``;
      the candidate is indexed so later candidates can match it.

    With ``dry_run=True`` only the counters change; embedding and search
    still run, but nothing is written to the DB or to Qdrant.
    """
    from compliance.services.control_dedup import LINK_THRESHOLD, REVIEW_THRESHOLD

    hint = candidate["merge_group_hint"]

    # Short-circuit: identical titles within same merge_group → direct link
    if (hint
            and hint == master["merge_group_hint"]
            and candidate["title"].strip().lower() == master["title"].strip().lower()):
        self.stats["linked"] += 1
        self.stats["skipped_title_identical"] += 1
        if not dry_run:
            await self._mark_duplicate(master, candidate, confidence=1.0)
        return

    # Extract action/object from merge_group_hint
    # (format: "action_type:norm_obj:trigger_key")
    parts = hint.split(":", 2)
    action = parts[0]
    obj = parts[1] if len(parts) > 1 else ""

    # Build canonical text and get embedding for candidate
    canonical = canonicalize_text(action, obj, candidate["title"])
    embedding = await get_embedding(canonical)

    if not embedding:
        # Can't embed → keep as new control. _embed_and_index makes one more
        # embedding attempt and silently skips indexing if that fails too.
        self.stats["new_controls"] += 1
        if not dry_run:
            await self._embed_and_index(candidate)
        return

    # Search the dedup collection for similar controls
    results = await qdrant_search(
        embedding, pattern_id, top_k=5, collection=self.collection,
    )

    # Pick the best hit that can actually serve as a master. Hits without a
    # control_uuid cannot be linked. A hit for the candidate's own point (the
    # collection may already hold it, since new controls are indexed under
    # their uuid) must never make the control a duplicate of itself.
    match = None
    for hit in results or []:
        hit_uuid = (hit.get("payload") or {}).get("control_uuid")
        if hit_uuid and hit_uuid != candidate["uuid"]:
            match = hit
            break

    if match is None:
        # No usable match → new master; reuse the embedding we already have
        self.stats["new_controls"] += 1
        if not dry_run:
            await self._index_with_embedding(candidate, embedding)
        return

    best_score = match.get("score", 0.0)
    best_payload = match.get("payload") or {}

    # Same action+object (since same merge_group_hint) → use standard thresholds
    if best_score > LINK_THRESHOLD:
        self.stats["linked"] += 1
        if not dry_run:
            # Link to the matched master (which may differ from our `master`)
            await self._mark_duplicate_to(
                master_uuid=best_payload["control_uuid"],
                candidate=candidate,
                confidence=best_score,
            )
    elif best_score > REVIEW_THRESHOLD:
        self.stats["review"] += 1
        if not dry_run:
            self._write_review(candidate, best_payload, best_score)
    else:
        # Below threshold → becomes a new master
        self.stats["new_controls"] += 1
        if not dry_run:
            await self._index_with_embedding(candidate, embedding)
async def _mark_duplicate(self, master: dict, candidate: dict, confidence: float):
    """Mark candidate as duplicate of master, transfer parent links.

    Thin wrapper over ``_mark_duplicate_to`` for callers that hold the
    master as a control dict rather than as a bare uuid.
    """
    await self._mark_duplicate_to(
        master_uuid=master["uuid"],
        candidate=candidate,
        confidence=confidence,
    )


async def _mark_duplicate_to(self, master_uuid: str, candidate: dict, confidence: float):
    """Mark candidate as duplicate of the control identified by ``master_uuid``.

    In one transaction this:
      1. sets the candidate's ``release_state`` to 'duplicate' and its
         ``merged_into_uuid`` to the master,
      2. inserts a 'dedup_merge' row into ``control_parent_links``
         (existing pairs are left untouched via ON CONFLICT DO NOTHING),
      3. copies the candidate's parent links to the master.

    Raises:
        Exception: Any DB error is re-raised after the session has been
            rolled back, so the shared session stays usable for the next
            control instead of remaining in a failed transaction.
    """
    params = {"master": master_uuid, "cand": candidate["uuid"]}
    try:
        self.db.execute(text("""
            UPDATE canonical_controls
            SET release_state = 'duplicate', merged_into_uuid = CAST(:master AS uuid)
            WHERE id = CAST(:cand AS uuid)
        """), params)

        # Add dedup_merge link
        self.db.execute(text("""
            INSERT INTO control_parent_links
                (control_uuid, parent_control_uuid, link_type, confidence)
            VALUES (CAST(:master AS uuid), CAST(:cand_parent AS uuid), 'dedup_merge', :conf)
            ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING
        """), {"master": master_uuid, "cand_parent": candidate["uuid"], "conf": confidence})

        # Transfer parent links from candidate to master
        transferred = self._transfer_parent_links(master_uuid, candidate["uuid"])

        self.db.commit()
    except Exception:
        self.db.rollback()
        raise

    # Count only links that were actually committed.
    self.stats["parent_links_transferred"] += transferred
def _write_review(self, candidate: dict, matched_payload: dict, score: float):
    """Write a dedup review entry for borderline matches.

    Args:
        candidate: Control dict of the control that might be a duplicate.
        matched_payload: Qdrant payload of the matched control; its
            ``control_uuid`` and ``control_id`` are stored with the review.
        score: Similarity score of the match.

    Raises:
        Exception: Any DB error is re-raised after rolling the session back.
    """
    details = json.dumps({
        "merge_group_hint": candidate["merge_group_hint"],
        "pattern_id": candidate["pattern_id"],
    })
    try:
        # CAST(:dd AS jsonb) instead of ":dd::jsonb": text() does not recognise
        # a bind parameter that is directly followed by "::", so the
        # PostgreSQL shorthand cast breaks parameter binding.
        self.db.execute(text("""
            INSERT INTO control_dedup_reviews
                (candidate_control_id, candidate_title, candidate_objective,
                 matched_control_uuid, matched_control_id,
                 similarity_score, dedup_stage, dedup_details)
            VALUES (:ccid, :ct, :co, CAST(:mcu AS uuid), :mci,
                    :ss, 'batch_dedup', CAST(:dd AS jsonb))
        """), {
            "ccid": candidate["control_id"],
            "ct": candidate["title"],
            "co": candidate.get("objective", ""),
            "mcu": matched_payload.get("control_uuid"),
            "mci": matched_payload.get("control_id"),
            "ss": score,
            "dd": details,
        })
        self.db.commit()
    except Exception:
        self.db.rollback()
        raise


async def _run_cross_regulation_pass(self):
    """Phase 2: Find cross-regulation duplicates among surviving masters.

    Every non-duplicate, non-deprecated pass0b control is embedded and
    searched against the dedup collection without a pattern filter. The
    best hit that is a *different* control from a *different* pattern is
    linked with link_type 'cross_regulation' when its score exceeds
    CROSS_REG_LINK_THRESHOLD. Each pair is linked at most once per pass.
    The number of links is stored in ``self.stats["cross_reg_linked"]``.
    """
    logger.info("BatchDedup Phase 2: Cross-regulation pass starting...")

    # Load all non-duplicate pass0b controls that are now masters
    rows = self.db.execute(text("""
        SELECT id::text, control_id, title, pattern_id,
               generation_metadata->>'merge_group_hint' as merge_group_hint
        FROM canonical_controls
        WHERE decomposition_method = 'pass0b'
          AND release_state != 'duplicate'
          AND release_state != 'deprecated'
        ORDER BY control_id
    """)).fetchall()

    logger.info("BatchDedup Cross-reg: %d masters to check", len(rows))
    cross_linked = 0
    linked_pairs = set()  # unordered uuid pairs, so A→B is not repeated as B→A

    for i, r in enumerate(rows):
        control_uuid, title, own_pattern = r[0], r[2], r[3]
        hint = r[4] or ""
        parts = hint.split(":", 2)
        action = parts[0]
        obj = parts[1] if len(parts) > 1 else ""

        canonical = canonicalize_text(action, obj, title)
        embedding = await get_embedding(canonical)
        if not embedding:
            continue

        results = await qdrant_search_cross_regulation(
            embedding, top_k=5, collection=self.collection,
        )
        if not results:
            continue

        # Masters are indexed in this same collection, so the top hit is
        # normally the control itself. Looking only at results[0] would
        # therefore never produce a link; walk the hits instead and take the
        # first one that is another control from another pattern.
        match_uuid = None
        match_score = 0.0
        for hit in results:
            payload = hit.get("payload") or {}
            other_uuid = payload.get("control_uuid")
            if (other_uuid
                    and other_uuid != control_uuid
                    and payload.get("pattern_id") != own_pattern):
                match_uuid = other_uuid
                match_score = hit.get("score", 0.0)
                break

        if match_uuid is not None and match_score > CROSS_REG_LINK_THRESHOLD:
            pair = frozenset((control_uuid, match_uuid))
            if pair not in linked_pairs:
                # Cross-regulation link
                self.db.execute(text("""
                    INSERT INTO control_parent_links
                        (control_uuid, parent_control_uuid, link_type, confidence)
                    VALUES (CAST(:cu AS uuid), CAST(:pu AS uuid), 'cross_regulation', :conf)
                    ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING
                """), {
                    "cu": match_uuid,
                    "pu": control_uuid,
                    "conf": match_score,
                })
                self.db.commit()
                linked_pairs.add(pair)
                cross_linked += 1

        if (i + 1) % 500 == 0:
            logger.info("BatchDedup Cross-reg: %d/%d checked, %d linked",
                        i + 1, len(rows), cross_linked)

    self.stats["cross_reg_linked"] = cross_linked
    logger.info("BatchDedup Cross-reg complete: %d links created", cross_linked)


def get_status(self) -> dict:
    """Return current progress stats (for status endpoint).

    The result holds the pattern currently being processed (``pattern``),
    the per-pattern progress counter (``progress``) and a copy of every
    entry in ``self.stats``.
    """
    return {
        "pattern": self._progress_pattern,
        "progress": self._progress_count,
        **self.stats,
    }
@@ -356,6 +359,7 @@ async def qdrant_search_cross_regulation( """ if not embedding: return [] + coll = collection or QDRANT_COLLECTION body: dict = { "vector": embedding, "limit": top_k, @@ -364,7 +368,7 @@ async def qdrant_search_cross_regulation( try: async with httpx.AsyncClient(timeout=10.0) as client: resp = await client.post( - f"{QDRANT_URL}/collections/{QDRANT_COLLECTION}/points/search", + f"{QDRANT_URL}/collections/{coll}/points/search", json=body, ) if resp.status_code != 200: @@ -380,10 +384,12 @@ async def qdrant_upsert( point_id: str, embedding: list[float], payload: dict, + collection: Optional[str] = None, ) -> bool: - """Upsert a single point into the atomic_controls Qdrant collection.""" + """Upsert a single point into a Qdrant collection.""" if not embedding: return False + coll = collection or QDRANT_COLLECTION body = { "points": [{ "id": point_id, @@ -394,7 +400,7 @@ async def qdrant_upsert( try: async with httpx.AsyncClient(timeout=10.0) as client: resp = await client.put( - f"{QDRANT_URL}/collections/{QDRANT_COLLECTION}/points", + f"{QDRANT_URL}/collections/{coll}/points", json=body, ) return resp.status_code == 200 @@ -403,27 +409,31 @@ async def qdrant_upsert( return False -async def ensure_qdrant_collection(vector_size: int = 1024) -> bool: - """Create the Qdrant collection if it doesn't exist (idempotent).""" +async def ensure_qdrant_collection( + vector_size: int = 1024, + collection: Optional[str] = None, +) -> bool: + """Create a Qdrant collection if it doesn't exist (idempotent).""" + coll = collection or QDRANT_COLLECTION try: async with httpx.AsyncClient(timeout=10.0) as client: # Check if exists - resp = await client.get(f"{QDRANT_URL}/collections/{QDRANT_COLLECTION}") + resp = await client.get(f"{QDRANT_URL}/collections/{coll}") if resp.status_code == 200: return True # Create resp = await client.put( - f"{QDRANT_URL}/collections/{QDRANT_COLLECTION}", + f"{QDRANT_URL}/collections/{coll}", json={ "vectors": {"size": vector_size, 
"distance": "Cosine"}, }, ) if resp.status_code == 200: - logger.info("Created Qdrant collection: %s", QDRANT_COLLECTION) + logger.info("Created Qdrant collection: %s", coll) # Create payload indexes for field_name in ["pattern_id", "action_normalized", "object_normalized", "control_id"]: await client.put( - f"{QDRANT_URL}/collections/{QDRANT_COLLECTION}/index", + f"{QDRANT_URL}/collections/{coll}/index", json={"field_name": field_name, "field_schema": "keyword"}, ) return True @@ -710,6 +720,7 @@ class ControlDedupChecker: action: str, obj: str, pattern_id: str, + collection: Optional[str] = None, ) -> bool: """Index a new atomic control in Qdrant for future dedup checks.""" norm_action = normalize_action(action) @@ -730,4 +741,5 @@ class ControlDedupChecker: "object_normalized": norm_object, "canonical_text": canonical, }, + collection=collection, ) diff --git a/backend-compliance/migrations/078_batch_dedup.sql b/backend-compliance/migrations/078_batch_dedup.sql new file mode 100644 index 0000000..a91a56b --- /dev/null +++ b/backend-compliance/migrations/078_batch_dedup.sql @@ -0,0 +1,42 @@ +-- Migration 078: Batch Dedup — Schema extensions for 85k→~18-25k reduction +-- Adds merged_into_uuid tracking, performance indexes for batch dedup, +-- and extends link_type CHECK to include 'cross_regulation'. + +BEGIN; + +-- ============================================================================= +-- 1. merged_into_uuid: Track which master a duplicate was merged into +-- ============================================================================= + +ALTER TABLE canonical_controls + ADD COLUMN IF NOT EXISTS merged_into_uuid UUID REFERENCES canonical_controls(id); + +CREATE INDEX IF NOT EXISTS idx_cc_merged_into + ON canonical_controls(merged_into_uuid) WHERE merged_into_uuid IS NOT NULL; + +-- ============================================================================= +-- 2. 
Performance indexes for batch dedup queries +-- ============================================================================= + +-- Index on merge_group_hint inside generation_metadata (for sub-grouping) +CREATE INDEX IF NOT EXISTS idx_cc_merge_group_hint + ON canonical_controls ((generation_metadata->>'merge_group_hint')) + WHERE decomposition_method = 'pass0b'; + +-- Composite index for pattern-based dedup loading +CREATE INDEX IF NOT EXISTS idx_cc_pattern_dedup + ON canonical_controls (pattern_id, release_state) + WHERE decomposition_method = 'pass0b'; + +-- ============================================================================= +-- 3. Extend link_type CHECK to include 'cross_regulation' +-- ============================================================================= + +ALTER TABLE control_parent_links + DROP CONSTRAINT IF EXISTS control_parent_links_link_type_check; + +ALTER TABLE control_parent_links + ADD CONSTRAINT control_parent_links_link_type_check + CHECK (link_type IN ('decomposition', 'dedup_merge', 'manual', 'crosswalk', 'cross_regulation')); + +COMMIT; diff --git a/backend-compliance/tests/test_batch_dedup_runner.py b/backend-compliance/tests/test_batch_dedup_runner.py new file mode 100644 index 0000000..95e7224 --- /dev/null +++ b/backend-compliance/tests/test_batch_dedup_runner.py @@ -0,0 +1,433 @@ +"""Tests for Batch Dedup Runner (batch_dedup_runner.py). 
class TestQualityScore:
    """Quality scoring: richer controls should score higher."""

    def test_empty_control(self):
        score = quality_score({})
        assert score == 0.0

    def test_requirements_weight(self):
        score = quality_score({"requirements": json.dumps(["r1", "r2", "r3"])})
        assert score == pytest.approx(6.0)  # 3 * 2.0

    def test_test_procedure_weight(self):
        score = quality_score({"test_procedure": json.dumps(["t1", "t2"])})
        assert score == pytest.approx(3.0)  # 2 * 1.5

    def test_evidence_weight(self):
        score = quality_score({"evidence": json.dumps(["e1"])})
        assert score == pytest.approx(1.0)  # 1 * 1.0

    def test_objective_weight_capped(self):
        short = quality_score({"objective": "x" * 100})
        long = quality_score({"objective": "x" * 1000})
        assert short == pytest.approx(0.5)  # 100/200
        assert long == pytest.approx(3.0)   # capped at 3.0

    def test_combined_score(self):
        control = {
            "requirements": json.dumps(["r1", "r2"]),
            "test_procedure": json.dumps(["t1"]),
            "evidence": json.dumps(["e1", "e2"]),
            "objective": "x" * 400,
        }
        # 2*2 + 1*1.5 + 2*1.0 + min(400/200, 3) = 4 + 1.5 + 2 + 2 = 9.5
        assert quality_score(control) == pytest.approx(9.5)

    def test_json_string_vs_list(self):
        """Both JSON strings and already-parsed lists should work."""
        # One input is the JSON text, the other the parsed list. Comparing two
        # JSON strings (as before) never exercised the list code path.
        from_string = quality_score({"requirements": json.dumps(["r1", "r2"])})
        from_list = quality_score({"requirements": ["r1", "r2"]})
        assert from_string == from_list == pytest.approx(4.0)

    def test_null_fields(self):
        """None values should not crash."""
        score = quality_score({
            "requirements": None,
            "test_procedure": None,
            "evidence": None,
            "objective": None,
        })
        assert score == 0.0

    def test_ranking_order(self):
        """Rich control should rank above sparse control."""
        rich = {
            "requirements": json.dumps(["r1", "r2", "r3"]),
            "test_procedure": json.dumps(["t1", "t2"]),
            "evidence": json.dumps(["e1"]),
            "objective": "A comprehensive objective for this control.",
        }
        sparse = {
            "requirements": json.dumps(["r1"]),
            "objective": "Short",
        }
        assert quality_score(rich) > quality_score(sparse)
runner._sub_group_by_merge_hint(controls) + assert len(groups) == 1 + + +# --------------------------------------------------------------------------- +# Master Selection TESTS +# --------------------------------------------------------------------------- + + +class TestMasterSelection: + """Best quality score should become master.""" + + @pytest.mark.asyncio + async def test_highest_score_is_master(self): + """In a group, the control with highest quality_score is master.""" + db = MagicMock() + db.execute = MagicMock() + db.commit = MagicMock() + + runner = BatchDedupRunner(db=db) + + sparse = _make_control("s1", reqs=1, hint="implement:mfa:none") + rich = _make_control("r1", reqs=5, tests=3, evidence=2, hint="implement:mfa:none") + medium = _make_control("m1", reqs=2, tests=1, hint="implement:mfa:none") + + controls = [sparse, medium, rich] + + # Mock embedding to avoid real API calls + with patch("compliance.services.batch_dedup_runner.get_embedding", + new_callable=AsyncMock, return_value=[0.1] * 1024), \ + patch("compliance.services.batch_dedup_runner.qdrant_upsert", + new_callable=AsyncMock, return_value=True), \ + patch("compliance.services.batch_dedup_runner.qdrant_search", + new_callable=AsyncMock, return_value=[{ + "score": 0.95, + "payload": {"control_uuid": rich["uuid"], + "control_id": rich["control_id"]}, + }]): + await runner._process_pattern_group("CP-AUTH-001", controls, dry_run=True) + + # Rich should be master (1 master), others linked (2 linked) + assert runner.stats["masters"] == 1 + assert runner.stats["linked"] == 2 + + +# --------------------------------------------------------------------------- +# Dry Run TESTS +# --------------------------------------------------------------------------- + + +class TestDryRun: + """Dry run should compute stats but NOT modify DB.""" + + @pytest.mark.asyncio + async def test_dry_run_no_db_writes(self): + db = MagicMock() + db.execute = MagicMock() + db.commit = MagicMock() + + runner = 
BatchDedupRunner(db=db) + + controls = [ + _make_control("a", reqs=3, hint="implement:mfa:none"), + _make_control("b", reqs=1, hint="implement:mfa:none"), + ] + + with patch("compliance.services.batch_dedup_runner.get_embedding", + new_callable=AsyncMock, return_value=[0.1] * 1024), \ + patch("compliance.services.batch_dedup_runner.qdrant_upsert", + new_callable=AsyncMock, return_value=True), \ + patch("compliance.services.batch_dedup_runner.qdrant_search", + new_callable=AsyncMock, return_value=[{ + "score": 0.95, + "payload": {"control_uuid": "a-uuid", + "control_id": "AUTH-001"}, + }]): + await runner._process_pattern_group("CP-AUTH-001", controls, dry_run=True) + + # No DB execute calls for UPDATE/INSERT (only the initial load query was mocked) + # In dry_run, _mark_duplicate and _embed_and_index are skipped + assert runner.stats["masters"] == 1 + # qdrant_upsert should NOT have been called (dry_run skips indexing) + from compliance.services.batch_dedup_runner import qdrant_upsert + # No commit for dedup operations + db.commit.assert_not_called() + + +# --------------------------------------------------------------------------- +# Parent Link Transfer TESTS +# --------------------------------------------------------------------------- + + +class TestParentLinkTransfer: + """Parent links should migrate from duplicate to master.""" + + def test_transfer_parent_links(self): + db = MagicMock() + # Mock: duplicate has 2 parent links + db.execute.return_value.fetchall.return_value = [ + ("parent-1", "decomposition", 1.0, "DSGVO", "Art. 32", "obl-1"), + ("parent-2", "decomposition", 0.9, "NIS2", "Art. 
21", "obl-2"), + ] + + runner = BatchDedupRunner(db=db) + count = runner._transfer_parent_links("master-uuid", "dup-uuid") + + assert count == 2 + # Two INSERT calls for the transferred links + assert db.execute.call_count == 3 # 1 SELECT + 2 INSERTs + + def test_transfer_skips_self_reference(self): + db = MagicMock() + # Parent link points to master itself → should be skipped + db.execute.return_value.fetchall.return_value = [ + ("master-uuid", "decomposition", 1.0, "DSGVO", "Art. 32", "obl-1"), + ] + + runner = BatchDedupRunner(db=db) + count = runner._transfer_parent_links("master-uuid", "dup-uuid") + + assert count == 0 + + +# --------------------------------------------------------------------------- +# Title-identical Short-circuit TESTS +# --------------------------------------------------------------------------- + + +class TestTitleIdenticalShortCircuit: + + @pytest.mark.asyncio + async def test_identical_titles_skip_embedding(self): + """Controls with identical titles in same merge group → direct link.""" + db = MagicMock() + db.execute = MagicMock() + db.commit = MagicMock() + # Mock the parent link transfer query + db.execute.return_value.fetchall.return_value = [] + + runner = BatchDedupRunner(db=db) + + master = _make_control("m", reqs=3, hint="implement:mfa:none", + title="MFA implementieren") + candidate = _make_control("c", reqs=1, hint="implement:mfa:none", + title="MFA implementieren") + + with patch("compliance.services.batch_dedup_runner.get_embedding", + new_callable=AsyncMock) as mock_embed: + await runner._check_and_link(master, candidate, "CP-AUTH-001", dry_run=False) + + # Embedding should NOT be called (title-identical short-circuit) + mock_embed.assert_not_called() + assert runner.stats["linked"] == 1 + assert runner.stats["skipped_title_identical"] == 1 + + +# --------------------------------------------------------------------------- +# Cross-Regulation Pass TESTS +# 
# ---------------------------------------------------------------------------


class TestCrossRegulationPass:
    """Cross-regulation pass: only matches from a different pattern are linked."""

    @staticmethod
    def _runner_with_masters(master_rows):
        """Build a runner whose mocked session returns *master_rows* on fetchall()."""
        session = MagicMock()
        session.execute = MagicMock()
        session.commit = MagicMock()
        # Every execute(...).fetchall() yields the given master rows.
        session.execute.return_value.fetchall.return_value = master_rows
        return BatchDedupRunner(db=session)

    @staticmethod
    async def _run_pass(runner, search_hits):
        """Run the cross-regulation pass with embedding and search mocked out."""
        embedding_patch = patch(
            "compliance.services.batch_dedup_runner.get_embedding",
            new_callable=AsyncMock,
            return_value=[0.1] * 1024,
        )
        search_patch = patch(
            "compliance.services.batch_dedup_runner.qdrant_search_cross_regulation",
            new_callable=AsyncMock,
            return_value=search_hits,
        )
        with embedding_patch, search_patch:
            await runner._run_cross_regulation_pass()

    @pytest.mark.asyncio
    async def test_cross_reg_creates_link(self):
        """A high-score hit from another pattern is counted as a cross-reg link."""
        runner = self._runner_with_masters([
            ("uuid-1", "AUTH-001", "MFA implementieren", "CP-AUTH-001",
             "implement:multi_factor_auth:none"),
        ])

        hit_payload = {
            "control_uuid": "uuid-2",
            "control_id": "SEC-001",
            "pattern_id": "CP-SEC-001",  # different pattern!
        }
        await self._run_pass(runner, [{"score": 0.96, "payload": hit_payload}])

        assert runner.stats["cross_reg_linked"] == 1

    @pytest.mark.asyncio
    async def test_cross_reg_ignores_same_pattern(self):
        """Cross-reg should NOT link controls from same pattern."""
        runner = self._runner_with_masters([
            ("uuid-1", "AUTH-001", "MFA", "CP-AUTH-001", "implement:mfa:none"),
        ])

        # Match from SAME pattern
        hit_payload = {
            "control_uuid": "uuid-3",
            "control_id": "AUTH-002",
            "pattern_id": "CP-AUTH-001",  # same pattern
        }
        await self._run_pass(runner, [{"score": 0.97, "payload": hit_payload}])

        assert runner.stats["cross_reg_linked"] == 0


#
# ---------------------------------------------------------------------------
# Progress Stats TESTS
# ---------------------------------------------------------------------------


class TestProgressStats:
    """get_status() exposes the progress fields and the stats counters."""

    def test_get_status(self):
        """Values set on the runner show up under the expected status keys."""
        runner = BatchDedupRunner(db=MagicMock())
        runner.stats["masters"] = 42
        runner.stats["linked"] = 100
        runner._progress_pattern = "CP-AUTH-001"
        runner._progress_count = 500

        expected = {
            "pattern": "CP-AUTH-001",
            "progress": 500,
            "masters": 42,
            "linked": 100,
        }
        status = runner.get_status()
        for key, value in expected.items():
            assert status[key] == value


# ---------------------------------------------------------------------------
# Route endpoint TESTS
# ---------------------------------------------------------------------------


class TestBatchDedupRoutes:
    """Test the batch-dedup API endpoints."""

    def test_status_endpoint_not_running(self):
        """With no runner active, the status endpoint reports running=False."""
        from fastapi import FastAPI
        from fastapi.testclient import TestClient
        from compliance.api.crosswalk_routes import router

        app = FastAPI()
        app.include_router(router, prefix="/api/compliance")
        client = TestClient(app)

        fake_session = MagicMock()
        # Both status queries go through execute(...).fetchone().
        fake_session.execute.return_value.fetchone.return_value = (85000, 0, 85000)

        with patch("compliance.api.crosswalk_routes.SessionLocal",
                   return_value=fake_session):
            resp = client.get("/api/compliance/v1/canonical/migrate/batch-dedup/status")
            assert resp.status_code == 200
            body = resp.json()
            assert body["running"] is False


# ---------------------------------------------------------------------------
# HELPERS
# ---------------------------------------------------------------------------


def _make_control(
    prefix: str,
    reqs: int = 0,
    tests: int = 0,
    evidence: int = 0,
    hint: str = "",
    title: str = None,
    pattern_id: str = "CP-AUTH-001",
) -> dict:
    """Build a mock control dict for testing.

    ``reqs``, ``tests`` and ``evidence`` give the number of placeholder
    entries in the JSON-encoded ``requirements``, ``test_procedure`` and
    ``evidence`` fields. A falsy ``title`` falls back to ``"Control <prefix>"``.
    """

    def _encoded(tag: str, count: int) -> str:
        # JSON array of placeholder items, e.g. ["r0", "r1"].
        return json.dumps([f"{tag}{i}" for i in range(count)])

    control = {}
    control["uuid"] = f"{prefix}-uuid"
    control["control_id"] = f"AUTH-{prefix}"
    control["title"] = title if title else f"Control {prefix}"
    control["objective"] = f"Objective for {prefix}"
    control["pattern_id"] = pattern_id
    control["requirements"] = _encoded("r", reqs)
    control["test_procedure"] = _encoded("t", tests)
    control["evidence"] = _encoded("e", evidence)
    control["release_state"] = "draft"
    control["merge_group_hint"] = hint
    control["action_object_class"] = ""
    return control