feat: Batch Dedup Runner — 85k→~18-25k Master Controls
All checks were successful
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Successful in 32s
CI/CD / test-python-backend-compliance (push) Successful in 30s
CI/CD / test-python-document-crawler (push) Successful in 20s
CI/CD / test-python-dsms-gateway (push) Successful in 16s
CI/CD / validate-canonical-controls (push) Successful in 9s
CI/CD / Deploy (push) Successful in 1s
All checks were successful
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Successful in 32s
CI/CD / test-python-backend-compliance (push) Successful in 30s
CI/CD / test-python-document-crawler (push) Successful in 20s
CI/CD / test-python-dsms-gateway (push) Successful in 16s
CI/CD / validate-canonical-controls (push) Successful in 9s
CI/CD / Deploy (push) Successful in 1s
Adds batch orchestration for deduplicating ~85k Pass 0b atomic controls into ~18-25k unique masters with M:N parent linking. New files: - migrations/078_batch_dedup.sql: merged_into_uuid column, perf indexes, link_type CHECK extended for cross_regulation - batch_dedup_runner.py: BatchDedupRunner with quality scoring, merge-hint grouping, title-identical short-circuit, parent-link transfer, and cross-regulation pass - tests/test_batch_dedup_runner.py: 21 tests (all passing) Modified: - control_dedup.py: optional collection param on Qdrant functions - crosswalk_routes.py: POST/GET batch-dedup endpoints Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -764,6 +764,75 @@ async def decomposition_status():
|
|||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# BATCH DEDUP ENDPOINTS
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level runner reference for status polling
|
||||||
|
_batch_dedup_runner = None
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/migrate/batch-dedup", response_model=MigrationResponse)
async def migrate_batch_dedup(
    dry_run: bool = Query(False, description="Preview mode — no DB changes"),
    pattern_id: Optional[str] = Query(None, description="Only process this pattern"),
):
    """Batch dedup: reduce ~85k Pass 0b controls to ~18-25k masters.

    Groups controls by pattern_id + merge_group_hint, picks the best
    quality master, and links duplicates via control_parent_links.

    Args:
        dry_run: If True, compute stats only — no DB writes are performed.
        pattern_id: If set, restrict the run to this single pattern.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    global _batch_dedup_runner
    # Imported lazily so this route module stays importable even if the
    # runner's dependencies are unavailable at import time.
    from compliance.services.batch_dedup_runner import BatchDedupRunner

    db = SessionLocal()
    try:
        runner = BatchDedupRunner(db=db)
        # Publish the runner so /migrate/batch-dedup/status can poll progress.
        _batch_dedup_runner = runner
        stats = await runner.run(dry_run=dry_run, pattern_filter=pattern_id)
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        # logger.exception preserves the traceback (logger.error dropped it).
        logger.exception("Batch dedup failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Always clear the global and release the session, success or failure.
        _batch_dedup_runner = None
        db.close()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/migrate/batch-dedup/status")
async def batch_dedup_status():
    """Get current batch dedup progress (while running)."""
    runner = _batch_dedup_runner
    if runner is not None:
        # A run is in flight — report its live counters.
        return {"running": True, **runner.get_status()}

    # Not running — show DB stats
    db = SessionLocal()
    try:
        counts = db.execute(text("""
            SELECT
                count(*) FILTER (WHERE decomposition_method = 'pass0b') AS total_pass0b,
                count(*) FILTER (WHERE decomposition_method = 'pass0b'
                                 AND release_state = 'duplicate') AS duplicates,
                count(*) FILTER (WHERE decomposition_method = 'pass0b'
                                 AND release_state != 'duplicate'
                                 AND release_state != 'deprecated') AS masters
            FROM canonical_controls
        """)).fetchone()
        pending = db.execute(text(
            "SELECT count(*) FROM control_dedup_reviews WHERE review_status = 'pending'"
        )).fetchone()[0]

        total_pass0b, duplicates, masters = counts[0], counts[1], counts[2]
        return {
            "running": False,
            "total_pass0b": total_pass0b,
            "duplicates": duplicates,
            "masters": masters,
            "pending_reviews": pending,
        }
    finally:
        db.close()
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# HELPERS
|
# HELPERS
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|||||||
560
backend-compliance/compliance/services/batch_dedup_runner.py
Normal file
560
backend-compliance/compliance/services/batch_dedup_runner.py
Normal file
@@ -0,0 +1,560 @@
|
|||||||
|
"""Batch Dedup Runner — Orchestrates deduplication of ~85k atomare Controls.
|
||||||
|
|
||||||
|
Reduces Pass 0b controls from ~85k to ~18-25k unique Master Controls by:
|
||||||
|
1. Intra-Pattern Dedup: Group by pattern_id + merge_group_hint, pick best master
|
||||||
|
2. Cross-Regulation Dedup: Find near-duplicates across pattern boundaries
|
||||||
|
|
||||||
|
Reuses the existing 4-Stage Pipeline from control_dedup.py. Only adds
|
||||||
|
batch orchestration, quality scoring, and parent-link transfer logic.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
runner = BatchDedupRunner(db)
|
||||||
|
stats = await runner.run(dry_run=True) # preview
|
||||||
|
stats = await runner.run(dry_run=False) # execute
|
||||||
|
stats = await runner.run(pattern_filter="CP-AUTH-001") # single pattern
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from compliance.services.control_dedup import (
|
||||||
|
ControlDedupChecker,
|
||||||
|
DedupResult,
|
||||||
|
canonicalize_text,
|
||||||
|
ensure_qdrant_collection,
|
||||||
|
get_embedding,
|
||||||
|
normalize_action,
|
||||||
|
normalize_object,
|
||||||
|
qdrant_search,
|
||||||
|
qdrant_search_cross_regulation,
|
||||||
|
qdrant_upsert,
|
||||||
|
CROSS_REG_LINK_THRESHOLD,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEDUP_COLLECTION = "atomic_controls_dedup"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Quality Score ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_json_list(value) -> list:
    """Return *value* as a list: parse JSON strings, map None/invalid JSON to []."""
    value = value or "[]"
    if isinstance(value, str):
        try:
            value = json.loads(value)
        except (json.JSONDecodeError, TypeError):
            value = []
    return value


def quality_score(control: dict) -> float:
    """Score a control by richness of requirements, tests, evidence, and objective.

    Weights: 2.0 per requirement, 1.5 per test step, 1.0 per evidence item,
    plus up to 3.0 for objective length (1 point per 200 chars, capped).

    Higher score = better candidate for master control.

    Args:
        control: Control dict; list fields may be JSON strings or lists.

    Returns:
        Non-negative float score.
    """
    # The three list fields share identical parse-or-empty semantics, so the
    # previously triplicated try/except JSON block lives in _coerce_json_list.
    score = 0.0
    score += len(_coerce_json_list(control.get("requirements"))) * 2.0
    score += len(_coerce_json_list(control.get("test_procedure"))) * 1.5
    score += len(_coerce_json_list(control.get("evidence"))) * 1.0

    objective = control.get("objective") or ""
    score += min(len(objective) / 200, 3.0)

    return score
|
||||||
|
|
||||||
|
|
||||||
|
# ── Batch Dedup Runner ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class BatchDedupRunner:
    """Batch dedup orchestrator for existing Pass 0b atomic controls."""

    def __init__(self, db, collection: str = DEDUP_COLLECTION):
        """Initialize the runner.

        Args:
            db: SQLAlchemy session used for all reads and writes.
            collection: Qdrant collection name used for dedup embeddings.
        """
        self.db = db
        self.collection = collection
        # Running counters — returned by run() and exposed via get_status().
        self.stats = {
            "total_controls": 0,            # Pass 0b controls loaded
            "patterns_processed": 0,        # pattern groups completed without error
            "sub_groups_processed": 0,      # multi-control merge-hint sub-groups
            "masters": 0,                   # controls kept/selected as masters
            "linked": 0,                    # duplicates linked to a master
            "review": 0,                    # borderline matches queued for review
            "new_controls": 0,              # candidates below threshold → new masters
            "parent_links_transferred": 0,  # lineage links moved onto masters
            "cross_reg_linked": 0,          # Phase 2 cross-regulation links
            "errors": 0,                    # pattern groups that raised
            "skipped_title_identical": 0,   # linked via title short-circuit
        }
        # Live progress markers for the status endpoint.
        self._progress_pattern = ""
        self._progress_count = 0
|
||||||
|
|
||||||
|
async def run(
    self,
    dry_run: bool = False,
    pattern_filter: str = None,
) -> dict:
    """Run the full batch dedup pipeline.

    Args:
        dry_run: If True, compute stats but don't modify DB.
        pattern_filter: If set, only process this pattern_id.

    Returns:
        Stats dict with counts.
    """
    started_at = time.monotonic()
    logger.info("BatchDedup starting (dry_run=%s, pattern_filter=%s)",
                dry_run, pattern_filter)

    # Ensure Qdrant collection
    await ensure_qdrant_collection(collection=self.collection)

    # Phase 1: Intra-pattern dedup — errors are counted per pattern and
    # never abort the whole run.
    for pattern_id, pattern_controls in self._load_pattern_groups(pattern_filter):
        try:
            await self._process_pattern_group(pattern_id, pattern_controls, dry_run)
        except Exception as e:
            logger.error("BatchDedup error on pattern %s: %s", pattern_id, e)
            self.stats["errors"] += 1
        else:
            self.stats["patterns_processed"] += 1

    # Phase 2: Cross-regulation dedup (skip in dry_run for speed)
    if not dry_run:
        await self._run_cross_regulation_pass()

    duration = time.monotonic() - started_at
    self.stats["elapsed_seconds"] = round(duration, 1)
    logger.info("BatchDedup completed in %.1fs: %s", duration, self.stats)
    return self.stats
|
||||||
|
|
||||||
|
def _load_pattern_groups(self, pattern_filter: str = None) -> list:
    """Load all Pass 0b controls grouped by pattern_id, largest first."""
    where_clauses = [
        "decomposition_method = 'pass0b'",
        "release_state != 'deprecated'",
        "release_state != 'duplicate'",
    ]
    params = {}
    if pattern_filter:
        where_clauses.append("pattern_id = :pf")
        params["pf"] = pattern_filter

    where = " AND ".join(where_clauses)
    rows = self.db.execute(text(f"""
        SELECT id::text, control_id, title, objective,
               pattern_id, requirements::text, test_procedure::text,
               evidence::text, release_state,
               generation_metadata->>'merge_group_hint' as merge_group_hint,
               generation_metadata->>'action_object_class' as action_object_class
        FROM canonical_controls
        WHERE {where}
        ORDER BY pattern_id, control_id
    """), params).fetchall()

    # Dict keys in SELECT-list order.
    fields = (
        "uuid", "control_id", "title", "objective", "pattern_id",
        "requirements", "test_procedure", "evidence", "release_state",
        "merge_group_hint", "action_object_class",
    )
    by_pattern = defaultdict(list)
    for row in rows:
        record = dict(zip(fields, row))
        # NULL hint/class become "" so downstream string ops are safe.
        record["merge_group_hint"] = record["merge_group_hint"] or ""
        record["action_object_class"] = record["action_object_class"] or ""
        by_pattern[record["pattern_id"]].append(record)

    self.stats["total_controls"] = len(rows)

    # Sort patterns by group size (descending) for progress visibility
    ordered = sorted(by_pattern.items(), key=lambda kv: len(kv[1]), reverse=True)
    logger.info("BatchDedup loaded %d controls in %d patterns",
                len(rows), len(ordered))
    return ordered
|
||||||
|
|
||||||
|
def _sub_group_by_merge_hint(self, controls: list) -> dict:
|
||||||
|
"""Group controls by merge_group_hint composite key."""
|
||||||
|
groups = defaultdict(list)
|
||||||
|
for c in controls:
|
||||||
|
hint = c["merge_group_hint"]
|
||||||
|
if hint:
|
||||||
|
groups[hint].append(c)
|
||||||
|
else:
|
||||||
|
# No hint → each control is its own group
|
||||||
|
groups[f"__no_hint_{c['uuid']}"].append(c)
|
||||||
|
return dict(groups)
|
||||||
|
|
||||||
|
async def _process_pattern_group(
    self,
    pattern_id: str,
    controls: list,
    dry_run: bool,
):
    """Process all controls within a single pattern_id.

    Splits the pattern's controls into merge-hint sub-groups; within each
    multi-control sub-group the highest quality_score control becomes the
    master and the remainder are checked/linked against it.

    Args:
        pattern_id: Pattern whose controls are being deduplicated.
        controls: Control dicts as built by _load_pattern_groups.
        dry_run: If True, only counters are updated — no indexing/DB writes.
    """
    # Expose live progress for the status endpoint.
    self._progress_pattern = pattern_id
    self._progress_count = 0
    total = len(controls)

    sub_groups = self._sub_group_by_merge_hint(controls)

    for hint, group in sub_groups.items():
        if len(group) < 2:
            # Single control → always master
            master = group[0]
            self.stats["masters"] += 1
            if not dry_run:
                await self._embed_and_index(master)
            self._progress_count += 1
            # NOTE: singleton groups are not counted in sub_groups_processed.
            continue

        # Sort by quality score (best first)
        sorted_group = sorted(group, key=quality_score, reverse=True)
        master = sorted_group[0]
        self.stats["masters"] += 1

        # Master is indexed before candidates so they can match against it.
        if not dry_run:
            await self._embed_and_index(master)

        for candidate in sorted_group[1:]:
            await self._check_and_link(master, candidate, pattern_id, dry_run)
            self._progress_count += 1

        self.stats["sub_groups_processed"] += 1

        # Progress logging every 100 controls (checked per sub-group, so a
        # large sub-group may skip past a multiple of 100 without logging).
        if self._progress_count > 0 and self._progress_count % 100 == 0:
            logger.info(
                "BatchDedup [%s] %d/%d — masters=%d, linked=%d, review=%d",
                pattern_id, self._progress_count, total,
                self.stats["masters"], self.stats["linked"], self.stats["review"],
            )
|
||||||
|
|
||||||
|
async def _check_and_link(
    self,
    master: dict,
    candidate: dict,
    pattern_id: str,
    dry_run: bool,
):
    """Check if candidate is a duplicate of master and link if so.

    Decision ladder:
      1. Identical title within the same merge group → link, confidence 1.0.
      2. Embedding similarity against the dedup collection:
         > LINK_THRESHOLD   → link to the best Qdrant match,
         > REVIEW_THRESHOLD → queue a manual review row,
         otherwise          → candidate becomes a new master and is indexed.
    An embedding failure or empty search result also keeps the candidate
    as a new master. In dry_run mode only the counters are updated.
    """
    # Short-circuit: identical titles within same merge_group → direct link
    if (candidate["title"].strip().lower() == master["title"].strip().lower()
            and candidate["merge_group_hint"] == master["merge_group_hint"]
            and candidate["merge_group_hint"]):
        self.stats["linked"] += 1
        self.stats["skipped_title_identical"] += 1
        if not dry_run:
            await self._mark_duplicate(master, candidate, confidence=1.0)
        return

    # Extract action/object from merge_group_hint (format: "action_type:norm_obj:trigger_key")
    parts = candidate["merge_group_hint"].split(":", 2)
    action = parts[0] if len(parts) > 0 else ""
    obj = parts[1] if len(parts) > 1 else ""

    # Build canonical text and get embedding for candidate
    canonical = canonicalize_text(action, obj, candidate["title"])
    embedding = await get_embedding(canonical)

    if not embedding:
        # Can't embed → keep as new control
        self.stats["new_controls"] += 1
        if not dry_run:
            await self._embed_and_index(candidate)
        return

    # Search the dedup collection for similar controls
    results = await qdrant_search(
        embedding, pattern_id, top_k=5, collection=self.collection,
    )

    if not results:
        # No matches → new master
        self.stats["new_controls"] += 1
        if not dry_run:
            await self._embed_and_index(candidate)
        return

    best = results[0]
    best_score = best.get("score", 0.0)
    best_payload = best.get("payload", {})
    best_uuid = best_payload.get("control_uuid", "")

    # Same action+object (since same merge_group_hint) → use standard thresholds.
    # Imported locally rather than at module top — presumably to avoid an
    # import cycle with control_dedup; TODO confirm.
    from compliance.services.control_dedup import LINK_THRESHOLD, REVIEW_THRESHOLD

    if best_score > LINK_THRESHOLD:
        self.stats["linked"] += 1
        if not dry_run:
            # Link to the matched master (which may differ from our `master`)
            await self._mark_duplicate_to(
                master_uuid=best_uuid,
                candidate=candidate,
                confidence=best_score,
            )
    elif best_score > REVIEW_THRESHOLD:
        self.stats["review"] += 1
        if not dry_run:
            # Synchronous DB write — no await needed.
            self._write_review(candidate, best_payload, best_score)
    else:
        # Below threshold → becomes a new master
        self.stats["new_controls"] += 1
        if not dry_run:
            # Reuse the embedding already computed above.
            await self._index_with_embedding(candidate, embedding)
|
||||||
|
|
||||||
|
async def _embed_and_index(self, control: dict):
    """Compute embedding and index a control in the dedup Qdrant collection.

    No-op when the embedding service returns nothing. Delegates the upsert
    to _index_with_embedding so the payload construction (previously
    duplicated verbatim between the two methods) lives in one place.

    Args:
        control: Control dict with at least uuid/control_id/title/pattern_id
            and merge_group_hint.
    """
    # merge_group_hint format: "action_type:norm_obj:trigger_key"
    parts = control["merge_group_hint"].split(":", 2)
    action = parts[0] if len(parts) > 0 else ""
    obj = parts[1] if len(parts) > 1 else ""

    canonical = canonicalize_text(action, obj, control["title"])
    embedding = await get_embedding(canonical)

    if not embedding:
        # Nothing to index without a vector.
        return

    await self._index_with_embedding(control, embedding)
|
||||||
|
|
||||||
|
async def _index_with_embedding(self, control: dict, embedding: list):
    """Index a control with a pre-computed embedding.

    Derives the normalized action/object from the control's merge_group_hint
    and upserts the point into the dedup collection keyed by control UUID.
    """
    hint_parts = control["merge_group_hint"].split(":", 2)
    action = hint_parts[0] if len(hint_parts) > 0 else ""
    obj = hint_parts[1] if len(hint_parts) > 1 else ""

    payload = {
        "control_uuid": control["uuid"],
        "control_id": control["control_id"],
        "title": control["title"],
        "pattern_id": control["pattern_id"],
        "action_normalized": normalize_action(action),
        "object_normalized": normalize_object(obj),
        "canonical_text": canonicalize_text(action, obj, control["title"]),
    }

    await qdrant_upsert(
        point_id=control["uuid"],
        embedding=embedding,
        payload=payload,
        collection=self.collection,
    )
|
||||||
|
|
||||||
|
async def _mark_duplicate(self, master: dict, candidate: dict, confidence: float):
    """Mark candidate as duplicate of master, transfer parent links.

    The previous body repeated _mark_duplicate_to's SQL verbatim; the only
    difference was taking the master as a dict instead of a UUID string,
    so this now delegates (identical UPDATE, dedup_merge INSERT, link
    transfer, and commit).

    Args:
        master: Control dict of the surviving master.
        candidate: Control dict of the duplicate being retired.
        confidence: Similarity score stored on the dedup_merge link.
    """
    await self._mark_duplicate_to(
        master_uuid=master["uuid"],
        candidate=candidate,
        confidence=confidence,
    )
|
||||||
|
|
||||||
|
async def _mark_duplicate_to(self, master_uuid: str, candidate: dict, confidence: float):
    """Mark candidate as duplicate of a Qdrant-matched master.

    Sets the candidate's release_state to 'duplicate', records the merge
    target in merged_into_uuid, inserts a 'dedup_merge' parent link from
    the master back to the candidate, moves the candidate's decomposition
    links onto the master, and commits everything in one transaction.

    Args:
        master_uuid: UUID string of the surviving master control.
        candidate: Control dict of the duplicate being retired.
        confidence: Similarity score stored on the dedup_merge link.
    """
    self.db.execute(text("""
        UPDATE canonical_controls
        SET release_state = 'duplicate', merged_into_uuid = CAST(:master AS uuid)
        WHERE id = CAST(:cand AS uuid)
    """), {"master": master_uuid, "cand": candidate["uuid"]})

    # Add dedup_merge link (idempotent via ON CONFLICT DO NOTHING).
    self.db.execute(text("""
        INSERT INTO control_parent_links
            (control_uuid, parent_control_uuid, link_type, confidence)
        VALUES (CAST(:master AS uuid), CAST(:cand_parent AS uuid), 'dedup_merge', :conf)
        ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING
    """), {"master": master_uuid, "cand_parent": candidate["uuid"], "conf": confidence})

    # Transfer parent links
    transferred = self._transfer_parent_links(master_uuid, candidate["uuid"])
    self.stats["parent_links_transferred"] += transferred

    self.db.commit()
|
||||||
|
|
||||||
|
def _transfer_parent_links(self, master_uuid: str, duplicate_uuid: str) -> int:
    """Move existing parent links from duplicate to master.

    Only 'decomposition' links are copied; the originals on the duplicate
    are left in place, and links the master already has are skipped via
    ON CONFLICT DO NOTHING. Does NOT commit — the caller commits as part
    of its own transaction.

    Returns the number of links transferred.
    """
    # Find parent links pointing TO the duplicate (where it was the child control)
    rows = self.db.execute(text("""
        SELECT parent_control_uuid::text, link_type, confidence,
               source_regulation, source_article, obligation_candidate_id::text
        FROM control_parent_links
        WHERE control_uuid = CAST(:dup AS uuid)
          AND link_type = 'decomposition'
    """), {"dup": duplicate_uuid}).fetchall()

    transferred = 0
    for r in rows:
        parent_uuid = r[0]
        # Skip self-references
        if parent_uuid == master_uuid:
            continue
        self.db.execute(text("""
            INSERT INTO control_parent_links
                (control_uuid, parent_control_uuid, link_type, confidence,
                 source_regulation, source_article, obligation_candidate_id)
            VALUES (CAST(:cu AS uuid), CAST(:pu AS uuid), :lt, :conf,
                    :sr, :sa, CAST(:oci AS uuid))
            ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING
        """), {
            "cu": master_uuid,
            "pu": parent_uuid,
            "lt": r[1],
            # NULL confidence defaults to 1.0 (note: 0.0 is also coerced).
            "conf": float(r[2]) if r[2] else 1.0,
            "sr": r[3],
            "sa": r[4],
            "oci": r[5],
        })
        # Counted even when ON CONFLICT skipped the insert.
        transferred += 1

    return transferred
|
||||||
|
|
||||||
|
def _write_review(self, candidate: dict, matched_payload: dict, score: float):
    """Write a dedup review entry for borderline matches.

    Inserts (and commits) a row in control_dedup_reviews with dedup_stage
    'batch_dedup' so a human can confirm or reject the merge later.

    Args:
        candidate: Control dict of the possible duplicate.
        matched_payload: Qdrant payload of the matched master.
        score: Similarity score that landed in the review band.
    """
    self.db.execute(text("""
        INSERT INTO control_dedup_reviews
            (candidate_control_id, candidate_title, candidate_objective,
             matched_control_uuid, matched_control_id,
             similarity_score, dedup_stage, dedup_details)
        VALUES (:ccid, :ct, :co, CAST(:mcu AS uuid), :mci,
                :ss, 'batch_dedup', :dd::jsonb)
    """), {
        "ccid": candidate["control_id"],
        "ct": candidate["title"],
        "co": candidate.get("objective", ""),
        "mcu": matched_payload.get("control_uuid"),
        "mci": matched_payload.get("control_id"),
        "ss": score,
        # Reviewer context, stored as jsonb.
        "dd": json.dumps({
            "merge_group_hint": candidate["merge_group_hint"],
            "pattern_id": candidate["pattern_id"],
        }),
    })
    self.db.commit()
|
||||||
|
|
||||||
|
async def _run_cross_regulation_pass(self):
    """Phase 2: Find cross-regulation duplicates among surviving masters.

    For each surviving Pass 0b master: re-embed its canonical text and
    search the dedup collection WITHOUT a pattern filter. When the single
    best hit is above CROSS_REG_LINK_THRESHOLD, belongs to a different
    pattern, and is not the control itself, insert a 'cross_regulation'
    parent link (committed per link). Masters are only linked here, never
    merged. NOTE: only the top-1 result is considered — a cross-pattern
    match ranked second behind a same-pattern hit is ignored.
    """
    logger.info("BatchDedup Phase 2: Cross-regulation pass starting...")

    # Load all non-duplicate pass0b controls that are now masters
    rows = self.db.execute(text("""
        SELECT id::text, control_id, title, pattern_id,
               generation_metadata->>'merge_group_hint' as merge_group_hint
        FROM canonical_controls
        WHERE decomposition_method = 'pass0b'
          AND release_state != 'duplicate'
          AND release_state != 'deprecated'
        ORDER BY control_id
    """)).fetchall()

    logger.info("BatchDedup Cross-reg: %d masters to check", len(rows))
    cross_linked = 0

    for i, r in enumerate(rows):
        uuid = r[0]
        hint = r[4] or ""
        # hint format: "action_type:norm_obj:trigger_key"
        parts = hint.split(":", 2)
        action = parts[0] if len(parts) > 0 else ""
        obj = parts[1] if len(parts) > 1 else ""

        canonical = canonicalize_text(action, obj, r[2])
        embedding = await get_embedding(canonical)
        if not embedding:
            # Embedding failure → silently skip this master.
            continue

        results = await qdrant_search_cross_regulation(
            embedding, top_k=5, collection=self.collection,
        )
        if not results:
            continue

        # Check if best match is from a DIFFERENT pattern
        best = results[0]
        best_score = best.get("score", 0.0)
        best_payload = best.get("payload", {})

        if (best_score > CROSS_REG_LINK_THRESHOLD
                and best_payload.get("pattern_id") != r[3]
                and best_payload.get("control_uuid") != uuid):
            # Cross-regulation link: matched control becomes the child,
            # the current master the parent.
            self.db.execute(text("""
                INSERT INTO control_parent_links
                    (control_uuid, parent_control_uuid, link_type, confidence)
                VALUES (CAST(:cu AS uuid), CAST(:pu AS uuid), 'cross_regulation', :conf)
                ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING
            """), {
                "cu": best_payload["control_uuid"],
                "pu": uuid,
                "conf": best_score,
            })
            # Commit per link so progress survives a mid-pass failure.
            self.db.commit()
            cross_linked += 1

        if (i + 1) % 500 == 0:
            logger.info("BatchDedup Cross-reg: %d/%d checked, %d linked",
                        i + 1, len(rows), cross_linked)

    self.stats["cross_reg_linked"] = cross_linked
    logger.info("BatchDedup Cross-reg complete: %d links created", cross_linked)
|
||||||
|
|
||||||
|
def get_status(self) -> dict:
|
||||||
|
"""Return current progress stats (for status endpoint)."""
|
||||||
|
return {
|
||||||
|
"pattern": self._progress_pattern,
|
||||||
|
"progress": self._progress_count,
|
||||||
|
**self.stats,
|
||||||
|
}
|
||||||
@@ -317,10 +317,12 @@ async def qdrant_search(
|
|||||||
embedding: list[float],
|
embedding: list[float],
|
||||||
pattern_id: str,
|
pattern_id: str,
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
|
collection: Optional[str] = None,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""Search Qdrant for similar atomic controls, filtered by pattern_id."""
|
"""Search Qdrant for similar atomic controls, filtered by pattern_id."""
|
||||||
if not embedding:
|
if not embedding:
|
||||||
return []
|
return []
|
||||||
|
coll = collection or QDRANT_COLLECTION
|
||||||
body: dict = {
|
body: dict = {
|
||||||
"vector": embedding,
|
"vector": embedding,
|
||||||
"limit": top_k,
|
"limit": top_k,
|
||||||
@@ -334,7 +336,7 @@ async def qdrant_search(
|
|||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
resp = await client.post(
|
resp = await client.post(
|
||||||
f"{QDRANT_URL}/collections/{QDRANT_COLLECTION}/points/search",
|
f"{QDRANT_URL}/collections/{coll}/points/search",
|
||||||
json=body,
|
json=body,
|
||||||
)
|
)
|
||||||
if resp.status_code != 200:
|
if resp.status_code != 200:
|
||||||
@@ -349,6 +351,7 @@ async def qdrant_search(
|
|||||||
async def qdrant_search_cross_regulation(
|
async def qdrant_search_cross_regulation(
|
||||||
embedding: list[float],
|
embedding: list[float],
|
||||||
top_k: int = 5,
|
top_k: int = 5,
|
||||||
|
collection: Optional[str] = None,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""Search Qdrant for similar controls across ALL regulations (no pattern_id filter).
|
"""Search Qdrant for similar controls across ALL regulations (no pattern_id filter).
|
||||||
|
|
||||||
@@ -356,6 +359,7 @@ async def qdrant_search_cross_regulation(
|
|||||||
"""
|
"""
|
||||||
if not embedding:
|
if not embedding:
|
||||||
return []
|
return []
|
||||||
|
coll = collection or QDRANT_COLLECTION
|
||||||
body: dict = {
|
body: dict = {
|
||||||
"vector": embedding,
|
"vector": embedding,
|
||||||
"limit": top_k,
|
"limit": top_k,
|
||||||
@@ -364,7 +368,7 @@ async def qdrant_search_cross_regulation(
|
|||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
resp = await client.post(
|
resp = await client.post(
|
||||||
f"{QDRANT_URL}/collections/{QDRANT_COLLECTION}/points/search",
|
f"{QDRANT_URL}/collections/{coll}/points/search",
|
||||||
json=body,
|
json=body,
|
||||||
)
|
)
|
||||||
if resp.status_code != 200:
|
if resp.status_code != 200:
|
||||||
@@ -380,10 +384,12 @@ async def qdrant_upsert(
|
|||||||
point_id: str,
|
point_id: str,
|
||||||
embedding: list[float],
|
embedding: list[float],
|
||||||
payload: dict,
|
payload: dict,
|
||||||
|
collection: Optional[str] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Upsert a single point into the atomic_controls Qdrant collection."""
|
"""Upsert a single point into a Qdrant collection."""
|
||||||
if not embedding:
|
if not embedding:
|
||||||
return False
|
return False
|
||||||
|
coll = collection or QDRANT_COLLECTION
|
||||||
body = {
|
body = {
|
||||||
"points": [{
|
"points": [{
|
||||||
"id": point_id,
|
"id": point_id,
|
||||||
@@ -394,7 +400,7 @@ async def qdrant_upsert(
|
|||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
resp = await client.put(
|
resp = await client.put(
|
||||||
f"{QDRANT_URL}/collections/{QDRANT_COLLECTION}/points",
|
f"{QDRANT_URL}/collections/{coll}/points",
|
||||||
json=body,
|
json=body,
|
||||||
)
|
)
|
||||||
return resp.status_code == 200
|
return resp.status_code == 200
|
||||||
@@ -403,27 +409,31 @@ async def qdrant_upsert(
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def ensure_qdrant_collection(vector_size: int = 1024) -> bool:
|
async def ensure_qdrant_collection(
|
||||||
"""Create the Qdrant collection if it doesn't exist (idempotent)."""
|
vector_size: int = 1024,
|
||||||
|
collection: Optional[str] = None,
|
||||||
|
) -> bool:
|
||||||
|
"""Create a Qdrant collection if it doesn't exist (idempotent)."""
|
||||||
|
coll = collection or QDRANT_COLLECTION
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
# Check if exists
|
# Check if exists
|
||||||
resp = await client.get(f"{QDRANT_URL}/collections/{QDRANT_COLLECTION}")
|
resp = await client.get(f"{QDRANT_URL}/collections/{coll}")
|
||||||
if resp.status_code == 200:
|
if resp.status_code == 200:
|
||||||
return True
|
return True
|
||||||
# Create
|
# Create
|
||||||
resp = await client.put(
|
resp = await client.put(
|
||||||
f"{QDRANT_URL}/collections/{QDRANT_COLLECTION}",
|
f"{QDRANT_URL}/collections/{coll}",
|
||||||
json={
|
json={
|
||||||
"vectors": {"size": vector_size, "distance": "Cosine"},
|
"vectors": {"size": vector_size, "distance": "Cosine"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if resp.status_code == 200:
|
if resp.status_code == 200:
|
||||||
logger.info("Created Qdrant collection: %s", QDRANT_COLLECTION)
|
logger.info("Created Qdrant collection: %s", coll)
|
||||||
# Create payload indexes
|
# Create payload indexes
|
||||||
for field_name in ["pattern_id", "action_normalized", "object_normalized", "control_id"]:
|
for field_name in ["pattern_id", "action_normalized", "object_normalized", "control_id"]:
|
||||||
await client.put(
|
await client.put(
|
||||||
f"{QDRANT_URL}/collections/{QDRANT_COLLECTION}/index",
|
f"{QDRANT_URL}/collections/{coll}/index",
|
||||||
json={"field_name": field_name, "field_schema": "keyword"},
|
json={"field_name": field_name, "field_schema": "keyword"},
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
@@ -710,6 +720,7 @@ class ControlDedupChecker:
|
|||||||
action: str,
|
action: str,
|
||||||
obj: str,
|
obj: str,
|
||||||
pattern_id: str,
|
pattern_id: str,
|
||||||
|
collection: Optional[str] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Index a new atomic control in Qdrant for future dedup checks."""
|
"""Index a new atomic control in Qdrant for future dedup checks."""
|
||||||
norm_action = normalize_action(action)
|
norm_action = normalize_action(action)
|
||||||
@@ -730,4 +741,5 @@ class ControlDedupChecker:
|
|||||||
"object_normalized": norm_object,
|
"object_normalized": norm_object,
|
||||||
"canonical_text": canonical,
|
"canonical_text": canonical,
|
||||||
},
|
},
|
||||||
|
collection=collection,
|
||||||
)
|
)
|
||||||
|
|||||||
42
backend-compliance/migrations/078_batch_dedup.sql
Normal file
42
backend-compliance/migrations/078_batch_dedup.sql
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
-- Migration 078: Batch Dedup — Schema extensions for 85k→~18-25k reduction
|
||||||
|
-- Adds merged_into_uuid tracking, performance indexes for batch dedup,
|
||||||
|
-- and extends link_type CHECK to include 'cross_regulation'.
|
||||||
|
|
||||||
|
BEGIN;
|
||||||
|
|
||||||
|
-- =============================================================================
|
||||||
|
-- 1. merged_into_uuid: Track which master a duplicate was merged into
|
||||||
|
-- =============================================================================
|
||||||
|
|
||||||
|
ALTER TABLE canonical_controls
|
||||||
|
ADD COLUMN IF NOT EXISTS merged_into_uuid UUID REFERENCES canonical_controls(id);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_cc_merged_into
|
||||||
|
ON canonical_controls(merged_into_uuid) WHERE merged_into_uuid IS NOT NULL;
|
||||||
|
|
||||||
|
-- =============================================================================
|
||||||
|
-- 2. Performance indexes for batch dedup queries
|
||||||
|
-- =============================================================================
|
||||||
|
|
||||||
|
-- Index on merge_group_hint inside generation_metadata (for sub-grouping)
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_cc_merge_group_hint
|
||||||
|
ON canonical_controls ((generation_metadata->>'merge_group_hint'))
|
||||||
|
WHERE decomposition_method = 'pass0b';
|
||||||
|
|
||||||
|
-- Composite index for pattern-based dedup loading
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_cc_pattern_dedup
|
||||||
|
ON canonical_controls (pattern_id, release_state)
|
||||||
|
WHERE decomposition_method = 'pass0b';
|
||||||
|
|
||||||
|
-- =============================================================================
|
||||||
|
-- 3. Extend link_type CHECK to include 'cross_regulation'
|
||||||
|
-- =============================================================================
|
||||||
|
|
||||||
|
ALTER TABLE control_parent_links
|
||||||
|
DROP CONSTRAINT IF EXISTS control_parent_links_link_type_check;
|
||||||
|
|
||||||
|
ALTER TABLE control_parent_links
|
||||||
|
ADD CONSTRAINT control_parent_links_link_type_check
|
||||||
|
CHECK (link_type IN ('decomposition', 'dedup_merge', 'manual', 'crosswalk', 'cross_regulation'));
|
||||||
|
|
||||||
|
COMMIT;
|
||||||
433
backend-compliance/tests/test_batch_dedup_runner.py
Normal file
433
backend-compliance/tests/test_batch_dedup_runner.py
Normal file
@@ -0,0 +1,433 @@
|
|||||||
|
"""Tests for Batch Dedup Runner (batch_dedup_runner.py).
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- quality_score(): Richness ranking
|
||||||
|
- BatchDedupRunner._sub_group_by_merge_hint(): Composite key grouping
|
||||||
|
- Master selection (highest quality score wins)
|
||||||
|
- Duplicate linking (mark + parent-link transfer)
|
||||||
|
- Dry run mode (no DB changes)
|
||||||
|
- Cross-regulation pass
|
||||||
|
- Progress reporting / stats
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, AsyncMock, patch, call
|
||||||
|
|
||||||
|
from compliance.services.batch_dedup_runner import (
|
||||||
|
quality_score,
|
||||||
|
BatchDedupRunner,
|
||||||
|
DEDUP_COLLECTION,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# quality_score TESTS
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestQualityScore:
|
||||||
|
"""Quality scoring: richer controls should score higher."""
|
||||||
|
|
||||||
|
def test_empty_control(self):
|
||||||
|
score = quality_score({})
|
||||||
|
assert score == 0.0
|
||||||
|
|
||||||
|
def test_requirements_weight(self):
|
||||||
|
score = quality_score({"requirements": json.dumps(["r1", "r2", "r3"])})
|
||||||
|
assert score == pytest.approx(6.0) # 3 * 2.0
|
||||||
|
|
||||||
|
def test_test_procedure_weight(self):
|
||||||
|
score = quality_score({"test_procedure": json.dumps(["t1", "t2"])})
|
||||||
|
assert score == pytest.approx(3.0) # 2 * 1.5
|
||||||
|
|
||||||
|
def test_evidence_weight(self):
|
||||||
|
score = quality_score({"evidence": json.dumps(["e1"])})
|
||||||
|
assert score == pytest.approx(1.0) # 1 * 1.0
|
||||||
|
|
||||||
|
def test_objective_weight_capped(self):
|
||||||
|
short = quality_score({"objective": "x" * 100})
|
||||||
|
long = quality_score({"objective": "x" * 1000})
|
||||||
|
assert short == pytest.approx(0.5) # 100/200
|
||||||
|
assert long == pytest.approx(3.0) # capped at 3.0
|
||||||
|
|
||||||
|
def test_combined_score(self):
|
||||||
|
control = {
|
||||||
|
"requirements": json.dumps(["r1", "r2"]),
|
||||||
|
"test_procedure": json.dumps(["t1"]),
|
||||||
|
"evidence": json.dumps(["e1", "e2"]),
|
||||||
|
"objective": "x" * 400,
|
||||||
|
}
|
||||||
|
# 2*2 + 1*1.5 + 2*1.0 + min(400/200, 3) = 4 + 1.5 + 2 + 2 = 9.5
|
||||||
|
assert quality_score(control) == pytest.approx(9.5)
|
||||||
|
|
||||||
|
def test_json_string_vs_list(self):
|
||||||
|
"""Both JSON strings and already-parsed lists should work."""
|
||||||
|
a = quality_score({"requirements": json.dumps(["r1", "r2"])})
|
||||||
|
b = quality_score({"requirements": '["r1", "r2"]'})
|
||||||
|
assert a == b
|
||||||
|
|
||||||
|
def test_null_fields(self):
|
||||||
|
"""None values should not crash."""
|
||||||
|
score = quality_score({
|
||||||
|
"requirements": None,
|
||||||
|
"test_procedure": None,
|
||||||
|
"evidence": None,
|
||||||
|
"objective": None,
|
||||||
|
})
|
||||||
|
assert score == 0.0
|
||||||
|
|
||||||
|
def test_ranking_order(self):
|
||||||
|
"""Rich control should rank above sparse control."""
|
||||||
|
rich = {
|
||||||
|
"requirements": json.dumps(["r1", "r2", "r3"]),
|
||||||
|
"test_procedure": json.dumps(["t1", "t2"]),
|
||||||
|
"evidence": json.dumps(["e1"]),
|
||||||
|
"objective": "A comprehensive objective for this control.",
|
||||||
|
}
|
||||||
|
sparse = {
|
||||||
|
"requirements": json.dumps(["r1"]),
|
||||||
|
"objective": "Short",
|
||||||
|
}
|
||||||
|
assert quality_score(rich) > quality_score(sparse)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Sub-grouping TESTS
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSubGrouping:
|
||||||
|
def _make_runner(self):
|
||||||
|
db = MagicMock()
|
||||||
|
return BatchDedupRunner(db=db)
|
||||||
|
|
||||||
|
def test_groups_by_merge_hint(self):
|
||||||
|
runner = self._make_runner()
|
||||||
|
controls = [
|
||||||
|
{"uuid": "a", "merge_group_hint": "implement:mfa:none"},
|
||||||
|
{"uuid": "b", "merge_group_hint": "implement:mfa:none"},
|
||||||
|
{"uuid": "c", "merge_group_hint": "test:firewall:periodic"},
|
||||||
|
]
|
||||||
|
groups = runner._sub_group_by_merge_hint(controls)
|
||||||
|
assert len(groups) == 2
|
||||||
|
assert len(groups["implement:mfa:none"]) == 2
|
||||||
|
assert len(groups["test:firewall:periodic"]) == 1
|
||||||
|
|
||||||
|
def test_empty_hint_gets_own_group(self):
|
||||||
|
runner = self._make_runner()
|
||||||
|
controls = [
|
||||||
|
{"uuid": "x", "merge_group_hint": ""},
|
||||||
|
{"uuid": "y", "merge_group_hint": ""},
|
||||||
|
]
|
||||||
|
groups = runner._sub_group_by_merge_hint(controls)
|
||||||
|
# Each empty-hint control gets its own group
|
||||||
|
assert len(groups) == 2
|
||||||
|
|
||||||
|
def test_single_control_single_group(self):
|
||||||
|
runner = self._make_runner()
|
||||||
|
controls = [
|
||||||
|
{"uuid": "a", "merge_group_hint": "implement:mfa:none"},
|
||||||
|
]
|
||||||
|
groups = runner._sub_group_by_merge_hint(controls)
|
||||||
|
assert len(groups) == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Master Selection TESTS
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMasterSelection:
|
||||||
|
"""Best quality score should become master."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_highest_score_is_master(self):
|
||||||
|
"""In a group, the control with highest quality_score is master."""
|
||||||
|
db = MagicMock()
|
||||||
|
db.execute = MagicMock()
|
||||||
|
db.commit = MagicMock()
|
||||||
|
|
||||||
|
runner = BatchDedupRunner(db=db)
|
||||||
|
|
||||||
|
sparse = _make_control("s1", reqs=1, hint="implement:mfa:none")
|
||||||
|
rich = _make_control("r1", reqs=5, tests=3, evidence=2, hint="implement:mfa:none")
|
||||||
|
medium = _make_control("m1", reqs=2, tests=1, hint="implement:mfa:none")
|
||||||
|
|
||||||
|
controls = [sparse, medium, rich]
|
||||||
|
|
||||||
|
# Mock embedding to avoid real API calls
|
||||||
|
with patch("compliance.services.batch_dedup_runner.get_embedding",
|
||||||
|
new_callable=AsyncMock, return_value=[0.1] * 1024), \
|
||||||
|
patch("compliance.services.batch_dedup_runner.qdrant_upsert",
|
||||||
|
new_callable=AsyncMock, return_value=True), \
|
||||||
|
patch("compliance.services.batch_dedup_runner.qdrant_search",
|
||||||
|
new_callable=AsyncMock, return_value=[{
|
||||||
|
"score": 0.95,
|
||||||
|
"payload": {"control_uuid": rich["uuid"],
|
||||||
|
"control_id": rich["control_id"]},
|
||||||
|
}]):
|
||||||
|
await runner._process_pattern_group("CP-AUTH-001", controls, dry_run=True)
|
||||||
|
|
||||||
|
# Rich should be master (1 master), others linked (2 linked)
|
||||||
|
assert runner.stats["masters"] == 1
|
||||||
|
assert runner.stats["linked"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Dry Run TESTS
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestDryRun:
|
||||||
|
"""Dry run should compute stats but NOT modify DB."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dry_run_no_db_writes(self):
|
||||||
|
db = MagicMock()
|
||||||
|
db.execute = MagicMock()
|
||||||
|
db.commit = MagicMock()
|
||||||
|
|
||||||
|
runner = BatchDedupRunner(db=db)
|
||||||
|
|
||||||
|
controls = [
|
||||||
|
_make_control("a", reqs=3, hint="implement:mfa:none"),
|
||||||
|
_make_control("b", reqs=1, hint="implement:mfa:none"),
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch("compliance.services.batch_dedup_runner.get_embedding",
|
||||||
|
new_callable=AsyncMock, return_value=[0.1] * 1024), \
|
||||||
|
patch("compliance.services.batch_dedup_runner.qdrant_upsert",
|
||||||
|
new_callable=AsyncMock, return_value=True), \
|
||||||
|
patch("compliance.services.batch_dedup_runner.qdrant_search",
|
||||||
|
new_callable=AsyncMock, return_value=[{
|
||||||
|
"score": 0.95,
|
||||||
|
"payload": {"control_uuid": "a-uuid",
|
||||||
|
"control_id": "AUTH-001"},
|
||||||
|
}]):
|
||||||
|
await runner._process_pattern_group("CP-AUTH-001", controls, dry_run=True)
|
||||||
|
|
||||||
|
# No DB execute calls for UPDATE/INSERT (only the initial load query was mocked)
|
||||||
|
# In dry_run, _mark_duplicate and _embed_and_index are skipped
|
||||||
|
assert runner.stats["masters"] == 1
|
||||||
|
# qdrant_upsert should NOT have been called (dry_run skips indexing)
|
||||||
|
from compliance.services.batch_dedup_runner import qdrant_upsert
|
||||||
|
# No commit for dedup operations
|
||||||
|
db.commit.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Parent Link Transfer TESTS
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestParentLinkTransfer:
|
||||||
|
"""Parent links should migrate from duplicate to master."""
|
||||||
|
|
||||||
|
def test_transfer_parent_links(self):
|
||||||
|
db = MagicMock()
|
||||||
|
# Mock: duplicate has 2 parent links
|
||||||
|
db.execute.return_value.fetchall.return_value = [
|
||||||
|
("parent-1", "decomposition", 1.0, "DSGVO", "Art. 32", "obl-1"),
|
||||||
|
("parent-2", "decomposition", 0.9, "NIS2", "Art. 21", "obl-2"),
|
||||||
|
]
|
||||||
|
|
||||||
|
runner = BatchDedupRunner(db=db)
|
||||||
|
count = runner._transfer_parent_links("master-uuid", "dup-uuid")
|
||||||
|
|
||||||
|
assert count == 2
|
||||||
|
# Two INSERT calls for the transferred links
|
||||||
|
assert db.execute.call_count == 3 # 1 SELECT + 2 INSERTs
|
||||||
|
|
||||||
|
def test_transfer_skips_self_reference(self):
|
||||||
|
db = MagicMock()
|
||||||
|
# Parent link points to master itself → should be skipped
|
||||||
|
db.execute.return_value.fetchall.return_value = [
|
||||||
|
("master-uuid", "decomposition", 1.0, "DSGVO", "Art. 32", "obl-1"),
|
||||||
|
]
|
||||||
|
|
||||||
|
runner = BatchDedupRunner(db=db)
|
||||||
|
count = runner._transfer_parent_links("master-uuid", "dup-uuid")
|
||||||
|
|
||||||
|
assert count == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Title-identical Short-circuit TESTS
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestTitleIdenticalShortCircuit:
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_identical_titles_skip_embedding(self):
|
||||||
|
"""Controls with identical titles in same merge group → direct link."""
|
||||||
|
db = MagicMock()
|
||||||
|
db.execute = MagicMock()
|
||||||
|
db.commit = MagicMock()
|
||||||
|
# Mock the parent link transfer query
|
||||||
|
db.execute.return_value.fetchall.return_value = []
|
||||||
|
|
||||||
|
runner = BatchDedupRunner(db=db)
|
||||||
|
|
||||||
|
master = _make_control("m", reqs=3, hint="implement:mfa:none",
|
||||||
|
title="MFA implementieren")
|
||||||
|
candidate = _make_control("c", reqs=1, hint="implement:mfa:none",
|
||||||
|
title="MFA implementieren")
|
||||||
|
|
||||||
|
with patch("compliance.services.batch_dedup_runner.get_embedding",
|
||||||
|
new_callable=AsyncMock) as mock_embed:
|
||||||
|
await runner._check_and_link(master, candidate, "CP-AUTH-001", dry_run=False)
|
||||||
|
|
||||||
|
# Embedding should NOT be called (title-identical short-circuit)
|
||||||
|
mock_embed.assert_not_called()
|
||||||
|
assert runner.stats["linked"] == 1
|
||||||
|
assert runner.stats["skipped_title_identical"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Cross-Regulation Pass TESTS
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCrossRegulationPass:
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cross_reg_creates_link(self):
|
||||||
|
db = MagicMock()
|
||||||
|
db.execute = MagicMock()
|
||||||
|
db.commit = MagicMock()
|
||||||
|
# First call: load masters
|
||||||
|
db.execute.return_value.fetchall.return_value = [
|
||||||
|
("uuid-1", "AUTH-001", "MFA implementieren", "CP-AUTH-001",
|
||||||
|
"implement:multi_factor_auth:none"),
|
||||||
|
]
|
||||||
|
|
||||||
|
runner = BatchDedupRunner(db=db)
|
||||||
|
|
||||||
|
cross_result = [{
|
||||||
|
"score": 0.96,
|
||||||
|
"payload": {
|
||||||
|
"control_uuid": "uuid-2",
|
||||||
|
"control_id": "SEC-001",
|
||||||
|
"pattern_id": "CP-SEC-001", # different pattern!
|
||||||
|
},
|
||||||
|
}]
|
||||||
|
|
||||||
|
with patch("compliance.services.batch_dedup_runner.get_embedding",
|
||||||
|
new_callable=AsyncMock, return_value=[0.1] * 1024), \
|
||||||
|
patch("compliance.services.batch_dedup_runner.qdrant_search_cross_regulation",
|
||||||
|
new_callable=AsyncMock, return_value=cross_result):
|
||||||
|
await runner._run_cross_regulation_pass()
|
||||||
|
|
||||||
|
assert runner.stats["cross_reg_linked"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cross_reg_ignores_same_pattern(self):
|
||||||
|
"""Cross-reg should NOT link controls from same pattern."""
|
||||||
|
db = MagicMock()
|
||||||
|
db.execute = MagicMock()
|
||||||
|
db.commit = MagicMock()
|
||||||
|
db.execute.return_value.fetchall.return_value = [
|
||||||
|
("uuid-1", "AUTH-001", "MFA", "CP-AUTH-001", "implement:mfa:none"),
|
||||||
|
]
|
||||||
|
|
||||||
|
runner = BatchDedupRunner(db=db)
|
||||||
|
|
||||||
|
# Match from SAME pattern
|
||||||
|
cross_result = [{
|
||||||
|
"score": 0.97,
|
||||||
|
"payload": {
|
||||||
|
"control_uuid": "uuid-3",
|
||||||
|
"control_id": "AUTH-002",
|
||||||
|
"pattern_id": "CP-AUTH-001", # same pattern
|
||||||
|
},
|
||||||
|
}]
|
||||||
|
|
||||||
|
with patch("compliance.services.batch_dedup_runner.get_embedding",
|
||||||
|
new_callable=AsyncMock, return_value=[0.1] * 1024), \
|
||||||
|
patch("compliance.services.batch_dedup_runner.qdrant_search_cross_regulation",
|
||||||
|
new_callable=AsyncMock, return_value=cross_result):
|
||||||
|
await runner._run_cross_regulation_pass()
|
||||||
|
|
||||||
|
assert runner.stats["cross_reg_linked"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Progress Stats TESTS
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestProgressStats:
|
||||||
|
|
||||||
|
def test_get_status(self):
|
||||||
|
db = MagicMock()
|
||||||
|
runner = BatchDedupRunner(db=db)
|
||||||
|
runner.stats["masters"] = 42
|
||||||
|
runner.stats["linked"] = 100
|
||||||
|
runner._progress_pattern = "CP-AUTH-001"
|
||||||
|
runner._progress_count = 500
|
||||||
|
|
||||||
|
status = runner.get_status()
|
||||||
|
assert status["pattern"] == "CP-AUTH-001"
|
||||||
|
assert status["progress"] == 500
|
||||||
|
assert status["masters"] == 42
|
||||||
|
assert status["linked"] == 100
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Route endpoint TESTS
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestBatchDedupRoutes:
|
||||||
|
"""Test the batch-dedup API endpoints."""
|
||||||
|
|
||||||
|
def test_status_endpoint_not_running(self):
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from compliance.api.crosswalk_routes import router
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
app.include_router(router, prefix="/api/compliance")
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
with patch("compliance.api.crosswalk_routes.SessionLocal") as mock_session:
|
||||||
|
mock_db = MagicMock()
|
||||||
|
mock_session.return_value = mock_db
|
||||||
|
mock_db.execute.return_value.fetchone.return_value = (85000, 0, 85000)
|
||||||
|
|
||||||
|
resp = client.get("/api/compliance/v1/canonical/migrate/batch-dedup/status")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["running"] is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# HELPERS
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_control(
|
||||||
|
prefix: str,
|
||||||
|
reqs: int = 0,
|
||||||
|
tests: int = 0,
|
||||||
|
evidence: int = 0,
|
||||||
|
hint: str = "",
|
||||||
|
title: str = None,
|
||||||
|
pattern_id: str = "CP-AUTH-001",
|
||||||
|
) -> dict:
|
||||||
|
"""Build a mock control dict for testing."""
|
||||||
|
return {
|
||||||
|
"uuid": f"{prefix}-uuid",
|
||||||
|
"control_id": f"AUTH-{prefix}",
|
||||||
|
"title": title or f"Control {prefix}",
|
||||||
|
"objective": f"Objective for {prefix}",
|
||||||
|
"pattern_id": pattern_id,
|
||||||
|
"requirements": json.dumps([f"r{i}" for i in range(reqs)]),
|
||||||
|
"test_procedure": json.dumps([f"t{i}" for i in range(tests)]),
|
||||||
|
"evidence": json.dumps([f"e{i}" for i in range(evidence)]),
|
||||||
|
"release_state": "draft",
|
||||||
|
"merge_group_hint": hint,
|
||||||
|
"action_object_class": "",
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user