diff --git a/control-pipeline/api/control_generator_routes.py b/control-pipeline/api/control_generator_routes.py index d544b03..dd05be0 100644 --- a/control-pipeline/api/control_generator_routes.py +++ b/control-pipeline/api/control_generator_routes.py @@ -1959,22 +1959,25 @@ async def get_anchor_backfill_status(backfill_id: str): # ============================================================================= -# HARMONIZATION RECHECK — verify promoted controls against Qdrant +# HARMONIZATION RECHECK — index ALL drafts, then check target controls # ============================================================================= class HarmonizationRecheckRequest(BaseModel): dry_run: bool = True - since: str = "2026-04-24 08:30:00" # timestamp filter for promoted controls + since: str = "2026-04-24 08:30:00" until: str = "2026-04-24 09:00:00" _harmonization_recheck_status: dict = {} +RECHECK_COLLECTION = "draft_controls_recheck" + async def _run_harmonization_recheck(req: HarmonizationRecheckRequest, job_id: str): - """Re-check promoted controls via Embedding + LLM dedup.""" + """Two-phase recheck: (1) index ALL drafts, (2) search target controls against them.""" import os import httpx + import uuid as _uuid QDRANT_URL = os.getenv("QDRANT_URL", "http://qdrant:6333") EMBEDDING_URL = os.getenv("EMBEDDING_URL", "http://embedding-service:8087") @@ -1982,11 +1985,11 @@ async def _run_harmonization_recheck(req: HarmonizationRecheckRequest, job_id: s OLLAMA_MODEL = os.getenv("CONTROL_GEN_OLLAMA_MODEL", "qwen3.5:35b-a3b") AUTO_DUP = 0.92 THRESHOLD = 0.85 - COLLECTION = "atomic_controls_dedup" db = SessionLocal() try: - rows = db.execute(text(""" + # Load target controls (the ones we want to check) + targets = db.execute(text(""" SELECT id::text, control_id, title, objective FROM compliance.canonical_controls WHERE release_state = 'draft' @@ -1995,23 +1998,94 @@ async def _run_harmonization_recheck(req: HarmonizationRecheckRequest, job_id: s ORDER BY control_id """), {"since": req.since, "until": req.until}).fetchall() - total = len(rows) + target_ids = {r[0] for r in targets} + + # Load ALL other draft controls (the reference set) + all_drafts = db.execute(text(""" + SELECT id::text, control_id, title, objective + FROM compliance.canonical_controls + WHERE release_state = 'draft' + ORDER BY control_id + """)).fetchall() + + # Exclude targets from reference set + reference = [r for r in all_drafts if r[0] not in target_ids] + + total_targets = len(targets) + total_reference = len(reference) + + _harmonization_recheck_status[job_id] = { + "status": "phase1_indexing", "total_targets": total_targets, + "total_reference": total_reference, "indexed": 0, + "processed": 0, "unique": 0, "duplicate": 0, + "llm_calls": 0, "errors": 0, "dry_run": req.dry_run, + } + + logger.info("Harmonization recheck: %d targets, %d reference controls", total_targets, total_reference) + + # Phase 1: Create temporary Qdrant collection and index reference controls + async with httpx.AsyncClient(timeout=30.0) as client: + # Delete old collection if exists + await client.delete(f"{QDRANT_URL}/collections/{RECHECK_COLLECTION}") + + # Get embedding dimension + resp = await client.post(f"{EMBEDDING_URL}/embed", json={"texts": ["test"]}) + dim = len(resp.json().get("embeddings", [[]])[0]) + + # Create collection + await client.put(f"{QDRANT_URL}/collections/{RECHECK_COLLECTION}", json={ + "vectors": {"size": dim, "distance": "Cosine"}, + }) + + # Index reference controls in batches + BATCH = 32 + indexed = 0 + for i in range(0, total_reference, BATCH): + batch = reference[i:i + BATCH] + texts = [f"{r[2] or ''} {(r[3] or '')[:200]}" for r in batch] + + async with httpx.AsyncClient(timeout=60.0) as client: + resp = await client.post(f"{EMBEDDING_URL}/embed", json={"texts": texts}) + if resp.status_code != 200: + continue + embeddings = resp.json().get("embeddings", []) + + points = [] + for j, (r, emb) in enumerate(zip(batch, embeddings)): + if not emb: + continue + points.append({ + "id": str(_uuid.uuid5(_uuid.NAMESPACE_DNS, r[0])), + "vector": emb, + "payload": {"control_uuid": r[0], "control_id": r[1], "title": r[2] or ""}, + }) + + if points: + await client.put( + f"{QDRANT_URL}/collections/{RECHECK_COLLECTION}/points", + json={"points": points}, + ) + indexed += len(points) + + if (i + BATCH) % 1000 < BATCH: + _harmonization_recheck_status[job_id]["indexed"] = indexed + _harmonization_recheck_status[job_id]["status"] = "phase1_indexing" + logger.info("Recheck indexing: %d/%d reference controls", indexed, total_reference) + + logger.info("Recheck Phase 1 done: %d reference controls indexed", indexed) + + # Phase 2: Check each target against the reference collection + _harmonization_recheck_status[job_id]["status"] = "phase2_checking" unique = 0 duplicate = 0 llm_calls = 0 no_match = 0 errors = 0 - _harmonization_recheck_status[job_id] = { - "status": "running", "total": total, "processed": 0, - "unique": 0, "duplicate": 0, "llm_calls": 0, "dry_run": req.dry_run, - } - - for i, row in enumerate(rows): + for i, row in enumerate(targets): try: - search_text = f"{row.title or ''} {(row.objective or '')[:200]}" + search_text = f"{row[2] or ''} {(row[3] or '')[:200]}" - # Get embedding async with httpx.AsyncClient(timeout=15.0) as client: resp = await client.post(f"{EMBEDDING_URL}/embed", json={"texts": [search_text]}) @@ -2023,17 +2097,13 @@ async def _run_harmonization_recheck(req: HarmonizationRecheckRequest, job_id: s errors += 1 continue - # Search Qdrant async with httpx.AsyncClient(timeout=15.0) as client: resp = await client.post( - f"{QDRANT_URL}/collections/{COLLECTION}/points/search", + f"{QDRANT_URL}/collections/{RECHECK_COLLECTION}/points/search", json={"vector": emb, "limit": 3, "score_threshold": THRESHOLD, "with_payload": {"include": ["control_id", "title"]}}) results = resp.json().get("result", []) if resp.status_code == 200 else [] - # Exclude self - results = [r for r in results - if r.get("payload", {}).get("control_uuid") != row[0]] if not results: no_match += 1 @@ -2053,7 +2123,6 @@ async def _run_harmonization_recheck(req: HarmonizationRecheckRequest, job_id: s duplicate += 1 elif best_score >= THRESHOLD: - # LLM verification try: async with httpx.AsyncClient(timeout=30.0) as client: resp = await client.post(f"{OLLAMA_URL}/api/chat", json={ @@ -2065,7 +2134,7 @@ async def _run_harmonization_recheck(req: HarmonizationRecheckRequest, job_id: s 'Antworte NUR mit JSON: {"verdict":"DUPLIKAT" oder "VERSCHIEDEN","reason":"..."}' )}, {"role": "user", "content": ( - f"Control A:\n{row.title or ''}\n\n" + f"Control A:\n{row[2] or ''}\n\n" f"Control B:\n{best_title}\n\nDuplikat?" )}, ], @@ -2091,14 +2160,15 @@ async def _run_harmonization_recheck(req: HarmonizationRecheckRequest, job_id: s except Exception as e: errors += 1 - logger.warning("Harmonization recheck error %s: %s", row[1], e) + logger.warning("Recheck error %s: %s", row[1], e) - if (i + 1) % 100 == 0: + if (i + 1) % 50 == 0: if not req.dry_run: db.commit() _harmonization_recheck_status[job_id] = { - "status": "running", "total": total, "processed": i + 1, - "unique": unique, "duplicate": duplicate, + "status": "phase2_checking", "total_targets": total_targets, + "total_reference": total_reference, "indexed": indexed, + "processed": i + 1, "unique": unique, "duplicate": duplicate, "llm_calls": llm_calls, "no_match": no_match, "errors": errors, "dry_run": req.dry_run, } @@ -2106,14 +2176,19 @@ async def _run_harmonization_recheck(req: HarmonizationRecheckRequest, job_id: s if not req.dry_run: db.commit() + # Cleanup temporary collection + async with httpx.AsyncClient(timeout=30.0) as client: + await client.delete(f"{QDRANT_URL}/collections/{RECHECK_COLLECTION}") + _harmonization_recheck_status[job_id] = { - "status": "completed", "total": total, "processed": total, - "unique": unique, "duplicate": duplicate, + "status": "completed", "total_targets": total_targets, + "total_reference": total_reference, "indexed": indexed, + "processed": total_targets, "unique": unique, "duplicate": duplicate, "llm_calls": llm_calls, "no_match": no_match, "errors": errors, "dry_run": req.dry_run, } - logger.info("Harmonization recheck %s: %d total, %d unique, %d dup, %d llm, %d err", - job_id, total, unique, duplicate, llm_calls, errors) + logger.info("Recheck DONE: %d targets, %d unique, %d dup, %d llm, %d err", + total_targets, unique, duplicate, llm_calls, errors) except Exception as e: logger.error("Harmonization recheck %s failed: %s", job_id, e) @@ -2124,8 +2199,9 @@ async def _run_harmonization_recheck(req: HarmonizationRecheckRequest, job_id: s @router.post("/generate/harmonization-recheck") async def start_harmonization_recheck(req: HarmonizationRecheckRequest): - """Re-check promoted controls against Qdrant dedup collection. - Uses Embedding + LLM verification for borderline matches. + """Re-check promoted controls against ALL other draft controls. + Phase 1: Index all non-target drafts into temp Qdrant collection. + Phase 2: Search each target control, Embedding + LLM for borderline. """ import uuid job_id = str(uuid.uuid4())[:8]