diff --git a/control-pipeline/api/control_generator_routes.py b/control-pipeline/api/control_generator_routes.py index dd05be0..37b9ae8 100644 --- a/control-pipeline/api/control_generator_routes.py +++ b/control-pipeline/api/control_generator_routes.py @@ -1999,20 +1999,14 @@ async def _run_harmonization_recheck(req: HarmonizationRecheckRequest, job_id: s """), {"since": req.since, "until": req.until}).fetchall() target_ids = {r[0] for r in targets} - - # Load ALL other draft controls (the reference set) - all_drafts = db.execute(text(""" - SELECT id::text, control_id, title, objective - FROM compliance.canonical_controls - WHERE release_state = 'draft' - ORDER BY control_id - """)).fetchall() - - # Exclude targets from reference set - reference = [r for r in all_drafts if r[0] not in target_ids] - total_targets = len(targets) - total_reference = len(reference) + + # Count reference controls (all drafts except targets) + total_reference = db.execute(text(""" + SELECT COUNT(*) FROM compliance.canonical_controls + WHERE release_state = 'draft' + AND NOT (updated_at >= CAST(:since AS timestamp) AND updated_at < CAST(:until AS timestamp)) + """), {"since": req.since, "until": req.until}).scalar() _harmonization_recheck_status[job_id] = { "status": "phase1_indexing", "total_targets": total_targets, @@ -2037,40 +2031,60 @@ async def _run_harmonization_recheck(req: HarmonizationRecheckRequest, job_id: s "vectors": {"size": dim, "distance": "Cosine"}, }) - # Index reference controls in batches - BATCH = 32 + # Index reference controls in paginated batches (never load all into memory) + EMBED_BATCH = 8 + DB_PAGE = 500 indexed = 0 - for i in range(0, total_reference, BATCH): - batch = reference[i:i + BATCH] - texts = [f"{r[2] or ''} {(r[3] or '')[:200]}" for r in batch] + db_offset = 0 - async with httpx.AsyncClient(timeout=60.0) as client: - resp = await client.post(f"{EMBEDDING_URL}/embed", json={"texts": texts}) - if resp.status_code != 200: - continue - embeddings = resp.json().get("embeddings", []) + while db_offset < total_reference + DB_PAGE: + page_rows = db.execute(text(""" + SELECT id::text, control_id, title, objective + FROM compliance.canonical_controls + WHERE release_state = 'draft' + AND NOT (updated_at >= CAST(:since AS timestamp) AND updated_at < CAST(:until AS timestamp)) + ORDER BY control_id + LIMIT :lim OFFSET :off + """), {"since": req.since, "until": req.until, "lim": DB_PAGE, "off": db_offset}).fetchall() - points = [] - for j, (r, emb) in enumerate(zip(batch, embeddings)): - if not emb: - continue - points.append({ - "id": str(_uuid.uuid5(_uuid.NAMESPACE_DNS, r[0])), - "vector": emb, - "payload": {"control_uuid": r[0], "control_id": r[1], "title": r[2] or ""}, - }) + if not page_rows: + break - if points: - await client.put( - f"{QDRANT_URL}/collections/{RECHECK_COLLECTION}/points", - json={"points": points}, - ) - indexed += len(points) + for eb_start in range(0, len(page_rows), EMBED_BATCH): + batch = page_rows[eb_start:eb_start + EMBED_BATCH] + texts = [f"{r[2] or ''} {(r[3] or '')[:200]}" for r in batch] - if (i + BATCH) % 1000 < BATCH: - _harmonization_recheck_status[job_id]["indexed"] = indexed - _harmonization_recheck_status[job_id]["status"] = "phase1_indexing" - logger.info("Recheck indexing: %d/%d reference controls", indexed, total_reference) + try: + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.post(f"{EMBEDDING_URL}/embed", json={"texts": texts}) + if resp.status_code != 200: + continue + embeddings = resp.json().get("embeddings", []) + + points = [] + for r, emb in zip(batch, embeddings): + if not emb: + continue + points.append({ + "id": str(_uuid.uuid5(_uuid.NAMESPACE_DNS, r[0])), + "vector": emb, + "payload": {"control_uuid": r[0], "control_id": r[1], "title": r[2] or ""}, + }) + + if points: + await client.put( + f"{QDRANT_URL}/collections/{RECHECK_COLLECTION}/points", + json={"points": points}, + ) + indexed += len(points) + except Exception as e: + logger.warning("Indexing batch error at offset %d: %s", db_offset + eb_start, e) + + db_offset += DB_PAGE + _harmonization_recheck_status[job_id]["indexed"] = indexed + _harmonization_recheck_status[job_id]["status"] = "phase1_indexing" + if db_offset % 5000 < DB_PAGE: + logger.info("Recheck indexing: %d/%d", indexed, total_reference) logger.info("Recheck Phase 1 done: %d reference controls indexed", indexed)