#!/usr/bin/env python3
"""
G-pre1 Refinement: re-cluster large object groups into finer sub-groups.

An object group counts as "large" when its master control aggregates more
than 200 controls (``master_controls.total_controls > 200``).  The members of
each such group are embedded, clustered with MiniBatchKMeans into 4-20
sub-clusters (depending on the group size), and written back as new
``object_groups`` rows.

The large master controls are deleted and replaced with smaller, more
specific ones generated from the new sub-groups.
"""

import json
import logging
import os
import time
from collections import defaultdict

import httpx
import numpy as np
from sklearn.cluster import MiniBatchKMeans
from sqlalchemy import create_engine, text

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("gpre1-refine")

# NOTE(security): the fallback URL embeds credentials. It is kept unchanged so
# existing deployments behave the same; set DATABASE_URL in any real environment.
DB_URL = os.getenv("DATABASE_URL", "postgresql://breakpilot:breakpilot123@postgres:5432/breakpilot_db")
EMBEDDING_URL = "http://embedding-service:8087"

# Width used only for the shape of an empty result; for non-empty input the
# width is taken from the embedding service's first response.
EMBEDDING_DIM = 1024
EMBED_BATCH_SIZE = 64
EMBED_MAX_ATTEMPTS = 3
EMBED_RETRY_DELAY_S = 2

# Groups with fewer members than this are left untouched.
MIN_MEMBERS_TO_SPLIT = 20


def _choose_k(member_count: int) -> int:
    """Return the number of sub-clusters for a group: one per ~15 members, clamped to 4-20."""
    return max(4, min(member_count // 15, 20))


def _load_members(raw) -> list:
    """Decode a ``members`` column value (JSON text or already-decoded list); NULL becomes []."""
    members = json.loads(raw) if isinstance(raw, str) else raw
    return members or []


def _cluster_members(members: list[str], embeddings: np.ndarray, k: int) -> list[list[str]]:
    """Split *members* into at most *k* sub-groups by k-means over L2-normalized embeddings.

    Sub-groups are returned in order of first appearance of their cluster
    label, and members keep their original relative order inside each one.
    """
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    norms[norms == 0] = 1  # avoid division by zero for all-zero vectors
    normalized = embeddings / norms

    kmeans = MiniBatchKMeans(n_clusters=k, batch_size=min(100, len(members)),
                             max_iter=50, random_state=42)
    labels = kmeans.fit_predict(normalized)

    subs: dict[int, list[str]] = {}
    for member, label in zip(members, labels):
        subs.setdefault(int(label), []).append(member)
    return list(subs.values())


def main():
    """Refine every large object group and regenerate its master controls.

    Steps:
      1. Select master controls with ``total_controls > 200`` and their object groups.
      2. Embed and re-cluster each group's members into sub-groups.
      3. In one transaction, delete the old master controls / object groups
         and insert the new sub-groups.
      4. Regenerate master controls for the new sub-groups and log final counts.

    Groups that are too small or whose members cannot be embedded are skipped
    and left unchanged.
    """
    engine = create_engine(DB_URL, connect_args={"options": "-c search_path=compliance,public"})

    # Step 1: Find large master controls and their object_group_ids
    with engine.connect() as c:
        large_mcs = c.execute(text("""
            SELECT mc.master_control_id, mc.object_group_id, mc.canonical_name, mc.total_controls,
                   og.members, og.member_count
            FROM master_controls mc
            JOIN object_groups og ON og.group_id = mc.object_group_id
            WHERE mc.total_controls > 200
            ORDER BY mc.total_controls DESC
        """)).fetchall()

    logger.info("Found %d large master controls to refine", len(large_mcs))

    # Step 2: For each large group, re-cluster the object members
    with engine.connect() as c:
        max_gid = c.execute(text("SELECT COALESCE(MAX(group_id), 0) FROM object_groups")).scalar()
    next_gid = max_gid + 1

    groups_to_delete = []
    new_groups = []
    seen_group_ids = set()

    for mc_id, og_id, canonical, _total, members_json, _member_count in large_mcs:
        # Several master controls may reference the same object group; split
        # each group only once so no duplicate sub-groups are inserted.
        if og_id in seen_group_ids:
            continue
        seen_group_ids.add(og_id)

        members = _load_members(members_json)
        if len(members) < MIN_MEMBERS_TO_SPLIT:
            logger.info("  Skip %s (%d members) — too few to split", canonical, len(members))
            continue

        k = _choose_k(len(members))

        embeddings = _embed_texts(members)
        if embeddings is None:
            logger.error("  Failed to embed %s", canonical)
            continue

        sub_groups = _cluster_members(members, embeddings, k)
        for sub_members in sub_groups:
            new_groups.append({
                "group_id": next_gid,
                # The first member of a sub-group serves as its display name.
                "canonical_name": sub_members[0],
                "member_count": len(sub_members),
                "members": json.dumps(sub_members),
                "top_controls_count": 0,
            })
            next_gid += 1

        groups_to_delete.append(og_id)
        logger.info("  %s (%s, %d members) → %d sub-groups (k=%d)",
                    mc_id, canonical, len(members), len(sub_groups), k)

    logger.info("Refinement: %d groups → %d sub-groups", len(groups_to_delete), len(new_groups))

    # Step 3: Update DB — replace old object_groups, delete old master_controls
    # NOTE(review): the deletion here and the regeneration in step 4 run in
    # separate transactions; a failure in step 4 leaves the old master
    # controls deleted without replacements.
    with engine.begin() as c:
        c.execute(text("SET search_path TO compliance, public"))

        # Delete old master controls and their members for affected groups
        for og_id in groups_to_delete:
            c.execute(text("""
                DELETE FROM master_control_members
                WHERE master_control_uuid IN (
                    SELECT id FROM master_controls WHERE object_group_id = :gid
                )
            """), {"gid": og_id})
            c.execute(text("DELETE FROM master_controls WHERE object_group_id = :gid"), {"gid": og_id})
            c.execute(text("DELETE FROM object_groups WHERE group_id = :gid"), {"gid": og_id})

        # Insert new sub-groups
        for g in new_groups:
            c.execute(text("""
                INSERT INTO object_groups (group_id, canonical_name, member_count, members, top_controls_count)
                VALUES (:group_id, :canonical_name, :member_count, CAST(:members AS jsonb), :top_controls_count)
            """), g)

    logger.info("DB updated: %d old groups deleted, %d new groups inserted", len(groups_to_delete), len(new_groups))

    # Step 4: Re-run master control generation for affected groups
    logger.info("Re-generating master controls for new sub-groups...")
    _regenerate_master_controls(engine, [g["group_id"] for g in new_groups])

    # Final stats
    with engine.connect() as c:
        mc_count = c.execute(text("SELECT count(*) FROM master_controls")).scalar()
        og_count = c.execute(text("SELECT count(*) FROM object_groups")).scalar()
        large = c.execute(text("SELECT count(*) FROM master_controls WHERE total_controls > 200")).scalar()
    logger.info("Final: %d master controls, %d object groups, %d still >200", mc_count, og_count, large)


def _regenerate_master_controls(engine, group_ids: list[int]):
    """Re-create master controls for specific object_group_ids.

    Controls are matched to a group through the object part of their
    ``merge_group_hint`` (``action:object[:phase]``; the phase defaults to
    ``"implementation"``), first by the normalized object and then by the raw
    object string.  A master control is only created for groups whose
    controls span at least two phases.

    Args:
        engine: SQLAlchemy engine connected to the compliance database.
        group_ids: ``object_groups.group_id`` values to generate master
            controls for.
    """
    # Imported lazily, as in the original script, so the module can be loaded
    # without the project package on the path.
    from services.control_dedup import normalize_object

    # Build reverse index (member object -> group) for the new groups only
    object_to_group = {}
    with engine.connect() as c:
        for gid in group_ids:
            row = c.execute(text(
                "SELECT group_id, canonical_name, members FROM object_groups WHERE group_id = :gid"
            ), {"gid": gid}).fetchone()
            if not row:
                continue
            for member in _load_members(row[2]):
                object_to_group[member] = (row[0], row[1])

    if not object_to_group:
        # Nothing can match; skip the scan over canonical_controls.
        logger.info("Created %d new master controls with %d members", 0, 0)
        return

    # Load controls that carry a merge hint
    with engine.connect() as c:
        rows = c.execute(text("""
            SELECT id, control_id, generation_metadata->>'merge_group_hint' AS hint
            FROM canonical_controls
            WHERE generation_metadata->>'merge_group_hint' IS NOT NULL
              AND release_state NOT IN ('deprecated', 'rejected')
        """)).fetchall()

    group_phases: dict[int, dict[str, list]] = defaultdict(lambda: defaultdict(list))
    group_names: dict[int, str] = {}

    for ctrl_uuid, control_id, hint in rows:
        parts = hint.split(":", 2)
        if len(parts) < 2:
            continue
        action, obj = parts[0], parts[1]
        phase = parts[2] if len(parts) > 2 else "implementation"

        target = object_to_group.get(normalize_object(obj))
        if target is None:
            target = object_to_group.get(obj)
        if target is None:
            continue
        gid, canonical = target

        group_phases[gid][phase].append((str(ctrl_uuid), control_id, action))
        group_names[gid] = canonical

    # Create master controls
    mc_count = 0
    mem_count = 0
    with engine.begin() as c:
        c.execute(text("SET search_path TO compliance, public"))
        for gid, phases in group_phases.items():
            if len(phases) < 2:
                continue

            mc_id = "MC-%d" % gid
            phase_counts = {p: len(ctrls) for p, ctrls in phases.items()}

            # RETURNING yields the id of exactly the row inserted here, which
            # a follow-up SELECT by master_control_id cannot guarantee.
            mc_uuid = c.execute(text("""
                INSERT INTO master_controls
                    (master_control_id, object_group_id, canonical_name,
                     phases_covered, phase_control_count, total_controls)
                VALUES (:mcid, :gid, :name,
                        CAST(:phases AS jsonb), CAST(:pcounts AS jsonb), :total)
                RETURNING id
            """), {
                "mcid": mc_id, "gid": gid, "name": group_names.get(gid, "unknown"),
                "phases": json.dumps(sorted(phases)),
                "pcounts": json.dumps(phase_counts),
                "total": sum(phase_counts.values()),
            }).scalar()

            member_rows = [
                {"mc": str(mc_uuid), "ctrl": member_uuid, "phase": phase, "action": action}
                for phase, controls in phases.items()
                for member_uuid, _ctrl_id, action in controls
            ]
            # A list of parameter dicts is sent as a single executemany call.
            c.execute(text("""
                INSERT INTO master_control_members
                    (master_control_uuid, control_uuid, phase, action)
                VALUES (CAST(:mc AS uuid), CAST(:ctrl AS uuid), :phase, :action)
            """), member_rows)

            mem_count += len(member_rows)
            mc_count += 1

    logger.info("Created %d new master controls with %d members", mc_count, mem_count)


def _embed_batch(client, batch: list[str], offset: int) -> np.ndarray | None:
    """Embed one batch, retrying on any error.

    Returns a float32 array with one row per text in *batch*, or None when
    every attempt failed.  *offset* is the index of the batch's first text in
    the full input and is used for logging only.
    """
    for attempt in range(1, EMBED_MAX_ATTEMPTS + 1):
        try:
            resp = client.post(f"{EMBEDDING_URL}/embed", json={"texts": batch})
            resp.raise_for_status()
            vectors = np.asarray(resp.json().get("embeddings", []), dtype=np.float32)
            if vectors.ndim != 2 or vectors.shape[0] != len(batch):
                raise ValueError(
                    "expected %d embeddings, got array of shape %s" % (len(batch), vectors.shape)
                )
            return vectors
        except Exception as e:  # retry boundary: transport, HTTP and payload errors alike
            if attempt == EMBED_MAX_ATTEMPTS:
                logger.error("Embed batch %d failed: %s", offset, e)
                return None
            time.sleep(EMBED_RETRY_DELAY_S)
    return None


def _embed_texts(texts: list[str]) -> np.ndarray | None:
    """Embed texts with retry logic.

    Texts are sent to the embedding service in batches; each batch is tried
    up to three times.

    Returns:
        A float32 array of shape ``(len(texts), dim)``, or None if any batch
        could not be embedded.  Returning None instead of a partially filled
        array keeps the caller from clustering all-zero placeholder vectors.
    """
    if not texts:
        return np.zeros((0, EMBEDDING_DIM), dtype=np.float32)

    try:
        result = None
        with httpx.Client(timeout=httpx.Timeout(60.0, connect=10.0)) as client:
            for start in range(0, len(texts), EMBED_BATCH_SIZE):
                batch = texts[start:start + EMBED_BATCH_SIZE]
                vectors = _embed_batch(client, batch, start)
                if vectors is None:
                    return None
                if result is None:
                    # Allocate once the service has revealed the vector width.
                    result = np.zeros((len(texts), vectors.shape[1]), dtype=np.float32)
                result[start:start + len(batch)] = vectors
        return result
    except Exception as e:
        logger.error("Embedding failed: %s", e)
        return None


if __name__ == "__main__":
    main()