Files
breakpilot-core/control-pipeline/scripts/gpre1_refine_large.py
T
Benjamin Admin 0092c4fe47 feat(pipeline): G-pre1 refinement script for large object groups
Splits master controls with >200 members by re-clustering their object groups
with k=4-20 per group. First round: 38 groups → 325 sub-groups → 253 new MCs.
25 generic MCs remain (monitoring, procedure, etc.); these still need to be split by regulation source.

Session summary: Block F complete, Control Generation (1,599+), Pass 0a/0b,
Production Sync, G-pre1/2/3 Object Clustering + Master Controls + API,
G1-G4 Compliance Execution Layer (Decision Trace, Commit Ledger, Decision Memory,
Pre-Deployment Enforcement).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-06 20:41:49 +02:00

255 lines
10 KiB
Python

#!/usr/bin/env python3
"""
G-pre1 Refinement: Re-cluster large object groups (those whose master control
covers >200 controls) into k sub-clusters for finer granularity, where
k = max(4, min(member_count // 15, 20)), i.e. 4-20 sub-clusters per group.
Replaces the large master controls with smaller, more specific ones.

Side effects: for every group that is split, its master_controls rows, their
master_control_members rows and the object_groups row are deleted; the new
sub-groups are inserted into object_groups and master controls are then
regenerated for them.

Requires DATABASE_URL (PostgreSQL) and a reachable embedding service at
EMBEDDING_URL exposing POST /embed.
"""
import json
import logging
import os
import httpx
import numpy as np
from sklearn.cluster import MiniBatchKMeans
from sqlalchemy import create_engine, text
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("gpre1-refine")

# NOTE(review): the fallback URL embeds a plaintext password; outside local
# development DATABASE_URL should always be set explicitly.
DB_URL = os.getenv("DATABASE_URL", "postgresql://breakpilot:breakpilot123@postgres:5432/breakpilot_db")
# Base URL of the embedding service (texts are POSTed to "<EMBEDDING_URL>/embed").
# Overridable like DATABASE_URL so the script can run outside the compose network.
EMBEDDING_URL = os.getenv("EMBEDDING_URL", "http://embedding-service:8087")
def main():
    """Split oversized master controls by re-clustering their object groups.

    Steps:
      1. Load master controls covering more than 200 controls, together with
         the member list of their object group.
      2. Embed each group's members and split them into
         k = max(4, min(len(members) // 15, 20)) clusters with MiniBatchKMeans.
         Groups with fewer than 20 members, or whose embedding fails, are left
         untouched. Each object group is split at most once, even if several
         master controls point at it.
      3. In one transaction, delete the old master controls, their
         master_control_members rows and the old object groups, and insert
         the new sub-groups (with fresh group_ids above the current maximum).
      4. Regenerate master controls for the new sub-groups and log totals.

    If nothing qualifies for splitting, the database is not modified.
    """
    engine = create_engine(DB_URL, connect_args={"options": "-c search_path=compliance,public"})

    # Step 1: Find large master controls and their object_group_ids
    large_mcs = _load_large_master_controls(engine)
    logger.info("Found %d large master controls to refine", len(large_mcs))

    # Step 2: For each large group, re-cluster the object members
    with engine.connect() as c:
        max_gid = c.execute(text("SELECT COALESCE(MAX(group_id), 0) FROM object_groups")).scalar()
    next_gid = max_gid + 1
    groups_to_delete: list[int] = []
    new_groups: list[dict] = []
    seen_groups: set[int] = set()
    for mc_id, og_id, canonical, _total, members_json, _member_count in large_mcs:
        if og_id in seen_groups:
            # The JOIN returns one row per master control; splitting the same
            # object group twice would create duplicate, orphaned sub-groups.
            continue
        seen_groups.add(og_id)

        members = json.loads(members_json) if isinstance(members_json, str) else (members_json or [])
        if len(members) < 20:
            logger.info(" Skip %s (%d members) — too few to split", canonical, len(members))
            continue

        clustered = _cluster_members(members)
        if clustered is None:
            logger.error(" Failed to embed %s", canonical)
            continue
        k, subs = clustered

        for sub_members in subs.values():
            new_groups.append({
                "group_id": next_gid,
                "canonical_name": sub_members[0],
                "member_count": len(sub_members),
                "members": json.dumps(sub_members),
                "top_controls_count": 0,
            })
            next_gid += 1
        groups_to_delete.append(og_id)
        logger.info(" %s (%s, %d members) → %d sub-groups (k=%d)",
                    mc_id, canonical, len(members), len(subs), k)
    logger.info("Refinement: %d groups → %d sub-groups", len(groups_to_delete), len(new_groups))

    if not groups_to_delete:
        # Nothing was split: skip the write transaction and the regeneration scan.
        _log_final_stats(engine)
        return

    # Step 3: Update DB — replace old object_groups, delete old master_controls
    _replace_object_groups(engine, groups_to_delete, new_groups)
    logger.info("DB updated: %d old groups deleted, %d new groups inserted", len(groups_to_delete), len(new_groups))

    # Step 4: Re-run master control generation for affected groups
    logger.info("Re-generating master controls for new sub-groups...")
    _regenerate_master_controls(engine, [g["group_id"] for g in new_groups])

    _log_final_stats(engine)


def _load_large_master_controls(engine):
    """Return master controls with >200 controls joined with their object group.

    Each row is (master_control_id, object_group_id, canonical_name,
    total_controls, members, member_count), largest first.
    """
    with engine.connect() as c:
        return c.execute(text("""
            SELECT mc.master_control_id, mc.object_group_id, mc.canonical_name, mc.total_controls,
                   og.members, og.member_count
            FROM master_controls mc
            JOIN object_groups og ON og.group_id = mc.object_group_id
            WHERE mc.total_controls > 200
            ORDER BY mc.total_controls DESC
        """)).fetchall()


def _cluster_members(members: list[str]) -> tuple[int, dict[int, list[str]]] | None:
    """Embed ``members`` and partition them with MiniBatchKMeans.

    Returns ``(k, {cluster_label: [member, ...]})`` with members kept in their
    original order inside each cluster, or ``None`` if embedding failed.
    """
    k = max(4, min(len(members) // 15, 20))  # 4-20 sub-clusters
    embeddings = _embed_texts(members)
    if embeddings is None:
        return None

    # L2-normalise so Euclidean k-means behaves like cosine clustering;
    # zero vectors keep norm 1 to avoid division by zero.
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    norms[norms == 0] = 1
    normalized = embeddings / norms
    kmeans = MiniBatchKMeans(n_clusters=k, batch_size=min(100, len(members)),
                             max_iter=50, random_state=42)
    labels = kmeans.fit_predict(normalized)

    subs: dict[int, list[str]] = {}
    for member, label in zip(members, labels):
        subs.setdefault(int(label), []).append(member)
    return k, subs


def _replace_object_groups(engine, groups_to_delete: list[int], new_groups: list[dict]) -> None:
    """Atomically drop the old groups (plus their master controls) and insert ``new_groups``."""
    with engine.begin() as c:
        c.execute(text("SET search_path TO compliance, public"))
        # Delete old master controls and their members for affected groups.
        # Members go first (they reference master_controls.id, presumably via FK).
        for og_id in groups_to_delete:
            c.execute(text("""
                DELETE FROM master_control_members
                WHERE master_control_uuid IN (
                    SELECT id FROM master_controls WHERE object_group_id = :gid
                )
            """), {"gid": og_id})
            c.execute(text("DELETE FROM master_controls WHERE object_group_id = :gid"), {"gid": og_id})
            c.execute(text("DELETE FROM object_groups WHERE group_id = :gid"), {"gid": og_id})
        # Insert new sub-groups in a single executemany round trip.
        if new_groups:
            c.execute(text("""
                INSERT INTO object_groups (group_id, canonical_name, member_count, members, top_controls_count)
                VALUES (:group_id, :canonical_name, :member_count, CAST(:members AS jsonb), :top_controls_count)
            """), new_groups)


def _log_final_stats(engine) -> None:
    """Log totals for master controls, object groups and MCs still above 200 controls."""
    with engine.connect() as c:
        mc_count = c.execute(text("SELECT count(*) FROM master_controls")).scalar()
        og_count = c.execute(text("SELECT count(*) FROM object_groups")).scalar()
        large = c.execute(text("SELECT count(*) FROM master_controls WHERE total_controls > 200")).scalar()
    logger.info("Final: %d master controls, %d object groups, %d still >200", mc_count, og_count, large)
def _regenerate_master_controls(engine, group_ids: list[int]):
    """Create master controls (and their member rows) for the given object groups.

    Every non-deprecated, non-rejected control whose
    ``generation_metadata->>'merge_group_hint'`` has the form
    ``action:object[:phase]`` (phase defaults to ``"implementation"``) is
    assigned to the group that lists ``object``. The object is matched after
    ``normalize_object`` first, then verbatim. Groups whose controls span at
    least two phases get a master control ``MC-<group_id>`` plus one
    master_control_members row per control. Single-phase groups are skipped.

    Args:
        engine: SQLAlchemy engine for the compliance database.
        group_ids: ``object_groups.group_id`` values to generate for. If an
            object is listed in several of these groups, the group appearing
            later in ``group_ids`` wins.
    """
    from collections import defaultdict

    from sqlalchemy import bindparam

    from services.control_dedup import normalize_object

    if not group_ids:
        # Avoid scanning every canonical control when there is nothing to map.
        logger.info("Created %d new master controls with %d members", 0, 0)
        return

    # Build reverse index (object -> (group_id, canonical_name)) for these groups only.
    position = {gid: i for i, gid in enumerate(group_ids)}
    select_groups = text(
        "SELECT group_id, canonical_name, members FROM object_groups WHERE group_id IN :gids"
    ).bindparams(bindparam("gids", expanding=True))
    with engine.connect() as c:
        group_rows = c.execute(select_groups, {"gids": list(position)}).fetchall()
    object_to_group: dict[str, tuple[int, str]] = {}
    # Apply rows in group_ids order so later groups overwrite earlier ones.
    for gid, canonical, members_json in sorted(group_rows, key=lambda r: position[r[0]]):
        members = json.loads(members_json) if isinstance(members_json, str) else (members_json or [])
        for m in members:
            object_to_group[m] = (gid, canonical)

    # Load controls for these objects
    with engine.connect() as c:
        rows = c.execute(text("""
            SELECT id, control_id, generation_metadata->>'merge_group_hint' AS hint
            FROM canonical_controls
            WHERE generation_metadata->>'merge_group_hint' IS NOT NULL
              AND release_state NOT IN ('deprecated', 'rejected')
        """)).fetchall()

    group_phases: dict[int, dict[str, list]] = defaultdict(lambda: defaultdict(list))
    group_names: dict[int, str] = {}
    for uuid, control_id, hint in rows:
        parts = hint.split(":", 2)
        if len(parts) < 2:
            continue
        action, obj = parts[0], parts[1]
        phase = parts[2] if len(parts) > 2 else "implementation"
        # Prefer the normalised object name; fall back to the raw one.
        target = object_to_group.get(normalize_object(obj)) or object_to_group.get(obj)
        if target is None:
            continue
        gid, canonical = target
        group_phases[gid][phase].append((str(uuid), control_id, action))
        group_names[gid] = canonical

    insert_mc = text("""
        INSERT INTO master_controls
            (master_control_id, object_group_id, canonical_name,
             phases_covered, phase_control_count, total_controls)
        VALUES (:mcid, :gid, :name,
                CAST(:phases AS jsonb), CAST(:pcounts AS jsonb), :total)
        RETURNING id
    """)
    insert_member = text("""
        INSERT INTO master_control_members
            (master_control_uuid, control_uuid, phase, action)
        VALUES (CAST(:mc AS uuid), CAST(:ctrl AS uuid), :phase, :action)
    """)

    # Create master controls
    mc_count = 0
    mem_count = 0
    with engine.begin() as c:
        c.execute(text("SET search_path TO compliance, public"))
        for gid, phases in group_phases.items():
            if len(phases) < 2:
                continue  # a master control must span at least two phases
            phase_counts = {p: len(ctrls) for p, ctrls in phases.items()}
            # RETURNING gives the new row's id directly, instead of re-selecting
            # by master_control_id (which could match a pre-existing row).
            mc_uuid = c.execute(insert_mc, {
                "mcid": "MC-%d" % gid,
                "gid": gid,
                "name": group_names.get(gid, "unknown"),
                "phases": json.dumps(sorted(phases)),
                "pcounts": json.dumps(phase_counts),
                "total": sum(phase_counts.values()),
            }).scalar()
            member_rows = [
                {"mc": str(mc_uuid), "ctrl": ctrl_uuid, "phase": phase, "action": action}
                for phase, controls in phases.items()
                for ctrl_uuid, _ctrl_id, action in controls
            ]
            if member_rows:
                c.execute(insert_member, member_rows)  # executemany
            mem_count += len(member_rows)
            mc_count += 1
    logger.info("Created %d new master controls with %d members", mc_count, mem_count)
def _embed_texts(
    texts: list[str],
    *,
    batch_size: int = 64,
    retries: int = 3,
    retry_delay: float = 2.0,
) -> np.ndarray | None:
    """Embed ``texts`` via the embedding service, in batches, with retries.

    Each batch is POSTed as ``{"texts": [...]}`` to ``{EMBEDDING_URL}/embed``,
    and the ``"embeddings"`` list of the JSON response is used. An attempt
    fails if the request errors, the HTTP status is not 2xx, or the response
    does not contain exactly one embedding row per input text.

    Args:
        texts: Strings to embed; row ``i`` of the result belongs to ``texts[i]``.
        batch_size: Number of texts per request.
        retries: Attempts per batch before the whole call gives up (min 1).
        retry_delay: Seconds to sleep between attempts of the same batch.

    Returns:
        A float32 matrix of shape ``(len(texts), dim)``, where ``dim`` is
        whatever the service returns (empty input yields ``(0, 1024)``), or
        ``None`` if any batch could not be embedded. Partial results are never
        returned: zero-filled rows would silently distort downstream clustering.
    """
    import time

    if not texts:
        return np.zeros((0, 1024), dtype=np.float32)

    attempts = max(1, retries)
    chunks: list[np.ndarray] = []
    try:
        with httpx.Client(timeout=httpx.Timeout(60.0, connect=10.0)) as client:
            for start in range(0, len(texts), batch_size):
                batch = texts[start:start + batch_size]
                for attempt in range(1, attempts + 1):
                    try:
                        resp = client.post(f"{EMBEDDING_URL}/embed", json={"texts": batch})
                        resp.raise_for_status()
                        embs = np.asarray(resp.json().get("embeddings", []), dtype=np.float32)
                        if embs.ndim != 2 or embs.shape[0] != len(batch):
                            raise ValueError(
                                f"expected {len(batch)} embeddings, got array of shape {embs.shape}"
                            )
                        chunks.append(embs)
                        break
                    except Exception as e:  # network, HTTP status, JSON and shape errors alike
                        if attempt == attempts:
                            logger.error("Embed batch %d failed: %s", start, e)
                            return None
                        time.sleep(retry_delay)  # no sleep after the final attempt
        # vstack raises if batches disagree on dimension; handled below.
        return np.vstack(chunks)
    except Exception as e:
        logger.error("Embedding failed: %s", e)
        return None
if __name__ == "__main__":
    # Script entry point only: importing this module defines functions and
    # configures logging but opens no database connection.
    main()