fix: paginated indexing to avoid OOM on 53k controls

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-04-24 16:31:20 +02:00
parent 043bcb65d8
commit 1a3101066e

View File

@@ -1999,20 +1999,14 @@ async def _run_harmonization_recheck(req: HarmonizationRecheckRequest, job_id: s
"""), {"since": req.since, "until": req.until}).fetchall() """), {"since": req.since, "until": req.until}).fetchall()
target_ids = {r[0] for r in targets} target_ids = {r[0] for r in targets}
# Load ALL other draft controls (the reference set)
all_drafts = db.execute(text("""
SELECT id::text, control_id, title, objective
FROM compliance.canonical_controls
WHERE release_state = 'draft'
ORDER BY control_id
""")).fetchall()
# Exclude targets from reference set
reference = [r for r in all_drafts if r[0] not in target_ids]
total_targets = len(targets) total_targets = len(targets)
total_reference = len(reference)
# Count reference controls (all drafts except targets)
total_reference = db.execute(text("""
SELECT COUNT(*) FROM compliance.canonical_controls
WHERE release_state = 'draft'
AND NOT (updated_at >= CAST(:since AS timestamp) AND updated_at < CAST(:until AS timestamp))
"""), {"since": req.since, "until": req.until}).scalar()
_harmonization_recheck_status[job_id] = { _harmonization_recheck_status[job_id] = {
"status": "phase1_indexing", "total_targets": total_targets, "status": "phase1_indexing", "total_targets": total_targets,
@@ -2037,21 +2031,38 @@ async def _run_harmonization_recheck(req: HarmonizationRecheckRequest, job_id: s
"vectors": {"size": dim, "distance": "Cosine"}, "vectors": {"size": dim, "distance": "Cosine"},
}) })
# Index reference controls in batches # Index reference controls in paginated batches (never load all into memory)
BATCH = 32 EMBED_BATCH = 8
DB_PAGE = 500
indexed = 0 indexed = 0
for i in range(0, total_reference, BATCH): db_offset = 0
batch = reference[i:i + BATCH]
while db_offset < total_reference + DB_PAGE:
page_rows = db.execute(text("""
SELECT id::text, control_id, title, objective
FROM compliance.canonical_controls
WHERE release_state = 'draft'
AND NOT (updated_at >= CAST(:since AS timestamp) AND updated_at < CAST(:until AS timestamp))
ORDER BY control_id
LIMIT :lim OFFSET :off
"""), {"since": req.since, "until": req.until, "lim": DB_PAGE, "off": db_offset}).fetchall()
if not page_rows:
break
for eb_start in range(0, len(page_rows), EMBED_BATCH):
batch = page_rows[eb_start:eb_start + EMBED_BATCH]
texts = [f"{r[2] or ''} {(r[3] or '')[:200]}" for r in batch] texts = [f"{r[2] or ''} {(r[3] or '')[:200]}" for r in batch]
async with httpx.AsyncClient(timeout=60.0) as client: try:
async with httpx.AsyncClient(timeout=30.0) as client:
resp = await client.post(f"{EMBEDDING_URL}/embed", json={"texts": texts}) resp = await client.post(f"{EMBEDDING_URL}/embed", json={"texts": texts})
if resp.status_code != 200: if resp.status_code != 200:
continue continue
embeddings = resp.json().get("embeddings", []) embeddings = resp.json().get("embeddings", [])
points = [] points = []
for j, (r, emb) in enumerate(zip(batch, embeddings)): for r, emb in zip(batch, embeddings):
if not emb: if not emb:
continue continue
points.append({ points.append({
@@ -2066,11 +2077,14 @@ async def _run_harmonization_recheck(req: HarmonizationRecheckRequest, job_id: s
json={"points": points}, json={"points": points},
) )
indexed += len(points) indexed += len(points)
except Exception as e:
logger.warning("Indexing batch error at offset %d: %s", db_offset + eb_start, e)
if (i + BATCH) % 1000 < BATCH: db_offset += DB_PAGE
_harmonization_recheck_status[job_id]["indexed"] = indexed _harmonization_recheck_status[job_id]["indexed"] = indexed
_harmonization_recheck_status[job_id]["status"] = "phase1_indexing" _harmonization_recheck_status[job_id]["status"] = "phase1_indexing"
logger.info("Recheck indexing: %d/%d reference controls", indexed, total_reference) if db_offset % 5000 < DB_PAGE:
logger.info("Recheck indexing: %d/%d", indexed, total_reference)
logger.info("Recheck Phase 1 done: %d reference controls indexed", indexed) logger.info("Recheck Phase 1 done: %d reference controls indexed", indexed)