fix: paginated indexing to avoid OOM on 53k controls
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1999,20 +1999,14 @@ async def _run_harmonization_recheck(req: HarmonizationRecheckRequest, job_id: s
|
|||||||
"""), {"since": req.since, "until": req.until}).fetchall()
|
"""), {"since": req.since, "until": req.until}).fetchall()
|
||||||
|
|
||||||
target_ids = {r[0] for r in targets}
|
target_ids = {r[0] for r in targets}
|
||||||
|
|
||||||
# Load ALL other draft controls (the reference set)
|
|
||||||
all_drafts = db.execute(text("""
|
|
||||||
SELECT id::text, control_id, title, objective
|
|
||||||
FROM compliance.canonical_controls
|
|
||||||
WHERE release_state = 'draft'
|
|
||||||
ORDER BY control_id
|
|
||||||
""")).fetchall()
|
|
||||||
|
|
||||||
# Exclude targets from reference set
|
|
||||||
reference = [r for r in all_drafts if r[0] not in target_ids]
|
|
||||||
|
|
||||||
total_targets = len(targets)
|
total_targets = len(targets)
|
||||||
total_reference = len(reference)
|
|
||||||
|
# Count reference controls (all drafts except targets)
|
||||||
|
total_reference = db.execute(text("""
|
||||||
|
SELECT COUNT(*) FROM compliance.canonical_controls
|
||||||
|
WHERE release_state = 'draft'
|
||||||
|
AND NOT (updated_at >= CAST(:since AS timestamp) AND updated_at < CAST(:until AS timestamp))
|
||||||
|
"""), {"since": req.since, "until": req.until}).scalar()
|
||||||
|
|
||||||
_harmonization_recheck_status[job_id] = {
|
_harmonization_recheck_status[job_id] = {
|
||||||
"status": "phase1_indexing", "total_targets": total_targets,
|
"status": "phase1_indexing", "total_targets": total_targets,
|
||||||
@@ -2037,40 +2031,60 @@ async def _run_harmonization_recheck(req: HarmonizationRecheckRequest, job_id: s
|
|||||||
"vectors": {"size": dim, "distance": "Cosine"},
|
"vectors": {"size": dim, "distance": "Cosine"},
|
||||||
})
|
})
|
||||||
|
|
||||||
# Index reference controls in batches
|
# Index reference controls in paginated batches (never load all into memory)
|
||||||
BATCH = 32
|
EMBED_BATCH = 8
|
||||||
|
DB_PAGE = 500
|
||||||
indexed = 0
|
indexed = 0
|
||||||
for i in range(0, total_reference, BATCH):
|
db_offset = 0
|
||||||
batch = reference[i:i + BATCH]
|
|
||||||
texts = [f"{r[2] or ''} {(r[3] or '')[:200]}" for r in batch]
|
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
while db_offset < total_reference + DB_PAGE:
|
||||||
resp = await client.post(f"{EMBEDDING_URL}/embed", json={"texts": texts})
|
page_rows = db.execute(text("""
|
||||||
if resp.status_code != 200:
|
SELECT id::text, control_id, title, objective
|
||||||
continue
|
FROM compliance.canonical_controls
|
||||||
embeddings = resp.json().get("embeddings", [])
|
WHERE release_state = 'draft'
|
||||||
|
AND NOT (updated_at >= CAST(:since AS timestamp) AND updated_at < CAST(:until AS timestamp))
|
||||||
|
ORDER BY control_id
|
||||||
|
LIMIT :lim OFFSET :off
|
||||||
|
"""), {"since": req.since, "until": req.until, "lim": DB_PAGE, "off": db_offset}).fetchall()
|
||||||
|
|
||||||
points = []
|
if not page_rows:
|
||||||
for j, (r, emb) in enumerate(zip(batch, embeddings)):
|
break
|
||||||
if not emb:
|
|
||||||
continue
|
|
||||||
points.append({
|
|
||||||
"id": str(_uuid.uuid5(_uuid.NAMESPACE_DNS, r[0])),
|
|
||||||
"vector": emb,
|
|
||||||
"payload": {"control_uuid": r[0], "control_id": r[1], "title": r[2] or ""},
|
|
||||||
})
|
|
||||||
|
|
||||||
if points:
|
for eb_start in range(0, len(page_rows), EMBED_BATCH):
|
||||||
await client.put(
|
batch = page_rows[eb_start:eb_start + EMBED_BATCH]
|
||||||
f"{QDRANT_URL}/collections/{RECHECK_COLLECTION}/points",
|
texts = [f"{r[2] or ''} {(r[3] or '')[:200]}" for r in batch]
|
||||||
json={"points": points},
|
|
||||||
)
|
|
||||||
indexed += len(points)
|
|
||||||
|
|
||||||
if (i + BATCH) % 1000 < BATCH:
|
try:
|
||||||
_harmonization_recheck_status[job_id]["indexed"] = indexed
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
_harmonization_recheck_status[job_id]["status"] = "phase1_indexing"
|
resp = await client.post(f"{EMBEDDING_URL}/embed", json={"texts": texts})
|
||||||
logger.info("Recheck indexing: %d/%d reference controls", indexed, total_reference)
|
if resp.status_code != 200:
|
||||||
|
continue
|
||||||
|
embeddings = resp.json().get("embeddings", [])
|
||||||
|
|
||||||
|
points = []
|
||||||
|
for r, emb in zip(batch, embeddings):
|
||||||
|
if not emb:
|
||||||
|
continue
|
||||||
|
points.append({
|
||||||
|
"id": str(_uuid.uuid5(_uuid.NAMESPACE_DNS, r[0])),
|
||||||
|
"vector": emb,
|
||||||
|
"payload": {"control_uuid": r[0], "control_id": r[1], "title": r[2] or ""},
|
||||||
|
})
|
||||||
|
|
||||||
|
if points:
|
||||||
|
await client.put(
|
||||||
|
f"{QDRANT_URL}/collections/{RECHECK_COLLECTION}/points",
|
||||||
|
json={"points": points},
|
||||||
|
)
|
||||||
|
indexed += len(points)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Indexing batch error at offset %d: %s", db_offset + eb_start, e)
|
||||||
|
|
||||||
|
db_offset += DB_PAGE
|
||||||
|
_harmonization_recheck_status[job_id]["indexed"] = indexed
|
||||||
|
_harmonization_recheck_status[job_id]["status"] = "phase1_indexing"
|
||||||
|
if db_offset % 5000 < DB_PAGE:
|
||||||
|
logger.info("Recheck indexing: %d/%d", indexed, total_reference)
|
||||||
|
|
||||||
logger.info("Recheck Phase 1 done: %d reference controls indexed", indexed)
|
logger.info("Recheck Phase 1 done: %d reference controls indexed", indexed)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user