#!/usr/bin/env python3
"""
G-pre1 Refinement: re-cluster large object groups.

Object groups whose master control covers more than 200 controls
(``master_controls.total_controls > 200``) are split into 4-20 sub-clusters,
scaled with the number of group members, for finer granularity. The large
master controls are then replaced with smaller, more specific ones.
"""

import json
import logging
import os
import time
from collections import defaultdict

import httpx
import numpy as np
from sklearn.cluster import MiniBatchKMeans
from sqlalchemy import create_engine, text

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("gpre1-refine")

# NOTE(review): the fallback URL embeds credentials; it looks like a
# development default. Set DATABASE_URL explicitly outside of dev setups.
DB_URL = os.getenv("DATABASE_URL", "postgresql://breakpilot:breakpilot123@postgres:5432/breakpilot_db")
EMBEDDING_URL = "http://embedding-service:8087"

# Groups with fewer members than this are left untouched.
MIN_MEMBERS_TO_SPLIT = 20
# Embedding request tuning.
EMBED_BATCH_SIZE = 64
EMBED_MAX_ATTEMPTS = 3
EMBED_RETRY_DELAY_S = 2
# Width used only for the empty-input result; otherwise the width is taken
# from what the embedding service returns.
EMBEDDING_DIM = 1024


def _choose_k(member_count: int) -> int:
    """Return the number of sub-clusters for a group: size // 15, clamped to 4-20."""
    return max(4, min(member_count // 15, 20))


def _cluster_members(members: list[str], embeddings: np.ndarray, k: int) -> list[list[str]]:
    """Partition *members* into at most *k* sub-groups by k-means on unit vectors.

    Sub-groups are returned in order of first appearance of their cluster
    label, and members keep their original relative order inside each one.
    """
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    norms[norms == 0] = 1  # avoid division by zero for all-zero vectors
    normalized = embeddings / norms

    kmeans = MiniBatchKMeans(
        n_clusters=k,
        batch_size=min(100, len(members)),
        max_iter=50,
        random_state=42,
    )
    labels = kmeans.fit_predict(normalized)

    subs: dict[int, list[str]] = defaultdict(list)
    for label, member in zip(labels, members):
        subs[int(label)].append(member)
    return list(subs.values())


def main():
    """Split oversized object groups and rebuild their master controls.

    Steps: (1) load master controls with more than 200 controls together with
    their object groups, (2) embed and cluster each group's members,
    (3) replace the old groups and master controls in one transaction,
    (4) regenerate master controls for the new sub-groups and log final stats.

    A group whose members cannot be embedded is skipped and left unchanged.
    """
    engine = create_engine(DB_URL, connect_args={"options": "-c search_path=compliance,public"})

    # Step 1: Find large master controls and their object_group_ids
    with engine.connect() as c:
        large_mcs = c.execute(text("""
            SELECT mc.master_control_id, mc.object_group_id, mc.canonical_name,
                   mc.total_controls, og.members, og.member_count
            FROM master_controls mc
            JOIN object_groups og ON og.group_id = mc.object_group_id
            WHERE mc.total_controls > 200
            ORDER BY mc.total_controls DESC
        """)).fetchall()
        max_gid = c.execute(text("SELECT COALESCE(MAX(group_id), 0) FROM object_groups")).scalar()

    logger.info("Found %d large master controls to refine", len(large_mcs))

    # Step 2: For each large group, re-cluster the object members
    next_gid = max_gid + 1
    groups_to_delete = []
    new_groups = []
    seen_group_ids = set()

    for mc_id, og_id, canonical, _total, members_json, _member_count in large_mcs:
        # Several master controls may point at the same object group; split
        # each group only once so no duplicate sub-groups are created.
        if og_id in seen_group_ids:
            continue
        seen_group_ids.add(og_id)

        members = json.loads(members_json) if isinstance(members_json, str) else members_json
        if len(members) < MIN_MEMBERS_TO_SPLIT:
            logger.info("  Skip %s (%d members) — too few to split", canonical, len(members))
            continue

        k = _choose_k(len(members))

        embeddings = _embed_texts(members)
        if embeddings is None:
            # Leave the group untouched rather than clustering on bad vectors.
            logger.error("  Failed to embed %s", canonical)
            continue

        sub_groups = _cluster_members(members, embeddings, k)
        for sub_members in sub_groups:
            new_groups.append({
                "group_id": next_gid,
                "canonical_name": sub_members[0],
                "member_count": len(sub_members),
                "members": json.dumps(sub_members),
                "top_controls_count": 0,
            })
            next_gid += 1

        groups_to_delete.append(og_id)
        logger.info("  %s (%s, %d members) → %d sub-groups (k=%d)",
                    mc_id, canonical, len(members), len(sub_groups), k)

    logger.info("Refinement: %d groups → %d sub-groups", len(groups_to_delete), len(new_groups))

    # Step 3: Update DB — replace old object_groups, delete old master_controls
    with engine.begin() as c:
        c.execute(text("SET search_path TO compliance, public"))

        # Delete old master controls and their members for affected groups
        for og_id in groups_to_delete:
            c.execute(text("""
                DELETE FROM master_control_members
                WHERE master_control_uuid IN (
                    SELECT id FROM master_controls WHERE object_group_id = :gid
                )
            """), {"gid": og_id})
            c.execute(text("DELETE FROM master_controls WHERE object_group_id = :gid"), {"gid": og_id})
            c.execute(text("DELETE FROM object_groups WHERE group_id = :gid"), {"gid": og_id})

        # Insert new sub-groups
        for g in new_groups:
            c.execute(text("""
                INSERT INTO object_groups (group_id, canonical_name, member_count, members, top_controls_count)
                VALUES (:group_id, :canonical_name, :member_count, CAST(:members AS jsonb), :top_controls_count)
            """), g)

    logger.info("DB updated: %d old groups deleted, %d new groups inserted",
                len(groups_to_delete), len(new_groups))

    # Step 4: Re-run master control generation for affected groups
    logger.info("Re-generating master controls for new sub-groups...")
    _regenerate_master_controls(engine, [g["group_id"] for g in new_groups])

    # Final stats
    with engine.connect() as c:
        mc_count = c.execute(text("SELECT count(*) FROM master_controls")).scalar()
        og_count = c.execute(text("SELECT count(*) FROM object_groups")).scalar()
        large = c.execute(text("SELECT count(*) FROM master_controls WHERE total_controls > 200")).scalar()
    logger.info("Final: %d master controls, %d object groups, %d still >200", mc_count, og_count, large)


def _regenerate_master_controls(engine, group_ids: list[int]):
    """Re-create master controls for specific object_group_ids.

    Controls are matched to a group through the object part of their
    ``merge_group_hint`` (``action:object[:phase]``; phase defaults to
    ``"implementation"``). A master control is only created for groups whose
    controls span at least two phases.
    """
    if not group_ids:
        logger.info("Created %d new master controls with %d members", 0, 0)
        return

    # Project-local import kept lazy, as in the original script.
    from services.control_dedup import normalize_object

    # Build reverse index (member object -> group) for the new groups only
    object_to_group = {}
    with engine.connect() as c:
        for gid in group_ids:
            row = c.execute(text(
                "SELECT group_id, canonical_name, members FROM object_groups WHERE group_id = :gid"
            ), {"gid": gid}).fetchone()
            if not row:
                continue
            members = json.loads(row[2]) if isinstance(row[2], str) else row[2]
            for m in members:
                object_to_group[m] = (row[0], row[1])

        # Load controls that carry a merge hint and are still active
        rows = c.execute(text("""
            SELECT id, control_id, generation_metadata->>'merge_group_hint' AS hint
            FROM canonical_controls
            WHERE generation_metadata->>'merge_group_hint' IS NOT NULL
              AND release_state NOT IN ('deprecated', 'rejected')
        """)).fetchall()

    group_phases: dict[int, dict[str, list]] = defaultdict(lambda: defaultdict(list))
    group_names: dict[int, str] = {}

    for ctrl_uuid, control_id, hint in rows:
        parts = hint.split(":", 2)
        if len(parts) < 2:
            continue
        action, obj = parts[0], parts[1]
        phase = parts[2] if len(parts) > 2 else "implementation"

        # Prefer the normalized object name, fall back to the raw one.
        normed = normalize_object(obj)
        if normed in object_to_group:
            gid, canonical = object_to_group[normed]
        elif obj in object_to_group:
            gid, canonical = object_to_group[obj]
        else:
            continue

        group_phases[gid][phase].append((str(ctrl_uuid), control_id, action))
        group_names[gid] = canonical

    # Create master controls
    mc_count = 0
    mem_count = 0
    with engine.begin() as c:
        c.execute(text("SET search_path TO compliance, public"))

        for gid, phases in group_phases.items():
            if len(phases) < 2:
                continue

            mc_id = "MC-%d" % gid
            phase_counts = {p: len(ctrls) for p, ctrls in phases.items()}

            # RETURNING hands back the id of exactly the row just inserted.
            mc_uuid = c.execute(text("""
                INSERT INTO master_controls (master_control_id, object_group_id, canonical_name,
                    phases_covered, phase_control_count, total_controls)
                VALUES (:mcid, :gid, :name, CAST(:phases AS jsonb), CAST(:pcounts AS jsonb), :total)
                RETURNING id
            """), {
                "mcid": mc_id,
                "gid": gid,
                "name": group_names.get(gid, "unknown"),
                "phases": json.dumps(sorted(phases)),
                "pcounts": json.dumps(phase_counts),
                "total": sum(phase_counts.values()),
            }).scalar()

            member_rows = [
                {"mc": str(mc_uuid), "ctrl": member_uuid, "phase": phase, "action": action}
                for phase, controls in phases.items()
                for member_uuid, _ctrl_id, action in controls
            ]
            if member_rows:
                # A list of parameter dicts is sent as one executemany batch.
                c.execute(text("""
                    INSERT INTO master_control_members (master_control_uuid, control_uuid, phase, action)
                    VALUES (CAST(:mc AS uuid), CAST(:ctrl AS uuid), :phase, :action)
                """), member_rows)
            mem_count += len(member_rows)
            mc_count += 1

    logger.info("Created %d new master controls with %d members", mc_count, mem_count)


def _embed_batch(client: httpx.Client, batch: list[str], offset: int) -> np.ndarray | None:
    """Embed one batch with retries; return a (len(batch), dim) array or None."""
    for attempt in range(1, EMBED_MAX_ATTEMPTS + 1):
        try:
            resp = client.post(f"{EMBEDDING_URL}/embed", json={"texts": batch})
            resp.raise_for_status()
            vectors = np.asarray(resp.json().get("embeddings", []), dtype=np.float32)
            if vectors.ndim != 2 or vectors.shape[0] != len(batch):
                raise ValueError(
                    "expected %d embeddings, got array of shape %s" % (len(batch), vectors.shape)
                )
            return vectors
        except (httpx.HTTPError, ValueError, TypeError, AttributeError) as e:
            if attempt == EMBED_MAX_ATTEMPTS:
                logger.error("Embed batch %d failed: %s", offset, e)
                return None
            time.sleep(EMBED_RETRY_DELAY_S)  # back off only when another attempt follows
    return None


def _embed_texts(texts: list[str]) -> np.ndarray | None:
    """Embed texts with retry logic.

    Returns a float32 array with one row per input text, or ``None`` if any
    batch could not be embedded. Partial results are never returned, so
    callers cannot mistake missing rows for real zero vectors.
    """
    if not texts:
        return np.zeros((0, EMBEDDING_DIM), dtype=np.float32)

    try:
        chunks = []
        # One client for all batches so the connection is reused.
        with httpx.Client(timeout=httpx.Timeout(60.0, connect=10.0)) as client:
            for start in range(0, len(texts), EMBED_BATCH_SIZE):
                vectors = _embed_batch(client, texts[start:start + EMBED_BATCH_SIZE], start)
                if vectors is None:
                    return None
                chunks.append(vectors)
        # vstack raises if batches disagree on dimension; handled below.
        return np.vstack(chunks)
    except Exception as e:
        logger.error("Embedding failed: %s", e)
        return None


if __name__ == "__main__":
    main()