Files
breakpilot-core/control-pipeline/scripts/gpre1_object_clustering.py
T
Benjamin Admin ad24835940 feat(pipeline): G-pre1/2/3 — Object Clustering + Master Controls + API
G-pre1: 144k objects clustered into 7,466 groups via Mini-Batch K-Means
  on bge-m3 embeddings. Two-stage: a k=5000 base clustering, after which
  groups with more than 50 members are sub-clustered.
G-pre2: 5,114 Master Controls from lifecycle phase chains
  (define→implement→test→monitor), linking 172,504 atomic controls.
G-pre3: REST API for Master Controls
  GET /v1/master-controls (list, search, filter)
  GET /v1/master-controls/stats
  GET /v1/master-controls/{mc_id} (detail with phase-controls)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-06 15:11:38 +02:00

220 lines
8.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
G-pre1: Object Clustering via Mini-Batch K-Means on Embeddings.
Clusters ~144k unique normalized objects into ~15-25k semantic groups
using bge-m3 embeddings and Mini-Batch K-Means.
Usage (inside control-pipeline container):
python3 /app/scripts/gpre1_object_clustering.py --k 20000
python3 /app/scripts/gpre1_object_clustering.py --k 20000 --dry-run
"""
import argparse
import json
import logging
import sys
import time
from collections import Counter
import httpx
import numpy as np
from sklearn.cluster import MiniBatchKMeans
from sqlalchemy import create_engine, text
import os

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("gpre1")

# Database connection string.
# NOTE(review): the fallback embeds a literal password; prefer always
# supplying DATABASE_URL from the environment / secret store.
DB_URL = os.getenv("DATABASE_URL", "postgresql://breakpilot:breakpilot123@postgres:5432/breakpilot_db")
# Base URL of the embedding service (the "/embed" path is appended by the
# caller). Overridable like DB_URL so the script can also run outside the
# compose network; a trailing slash is stripped to avoid "//embed".
EMBEDDING_URL = os.getenv("EMBEDDING_URL", "http://embedding-service:8087").rstrip("/")
# Number of texts sent per embedding request.
BATCH_SIZE = 64
def extract_objects(engine) -> tuple[list[str], dict[str, int]]:
    """Extract unique normalized objects and their control frequencies.

    Reads the second ``:``-separated field of
    ``generation_metadata->>'merge_group_hint'`` from ``canonical_controls``
    and counts controls per raw value. Each raw value is passed through
    ``normalize_object``. Counts of raw spellings that normalize to the same
    value are summed.

    Args:
        engine: SQLAlchemy engine connected to the pipeline database.

    Returns:
        ``(objects, freqs)``:
        - ``objects`` is the sorted list of unique normalized objects.
        - ``freqs`` maps each object to its summed control count.
    """
    from services.control_dedup import normalize_object

    logger.info("Extracting objects from canonical_controls...")
    with engine.connect() as c:
        rows = c.execute(text("""
            SELECT split_part(generation_metadata->>'merge_group_hint', ':', 2) AS obj,
                   count(*) AS freq
            FROM canonical_controls
            WHERE generation_metadata->>'merge_group_hint' IS NOT NULL
              AND generation_metadata->>'merge_group_hint' != ''
            GROUP BY 1
        """)).fetchall()

    # Normalize and aggregate.
    norm_freq: Counter = Counter()
    for raw_obj, freq in rows:
        if not raw_obj or not raw_obj.strip():
            continue
        normed = normalize_object(raw_obj)
        # Skip values that normalize to nothing. This mirrors the raw-empty
        # check above; an empty string would only pollute a cluster.
        if not normed:
            continue
        norm_freq[normed] += freq

    # GROUP BY gives no row-order guarantee. Sort so that the embedding
    # matrix, the K-Means input order and therefore the clustering are
    # reproducible across runs.
    objects = sorted(norm_freq)
    freqs = {obj: norm_freq[obj] for obj in objects}
    logger.info("Extracted %d unique normalized objects (from %d raw)", len(objects), len(rows))
    return objects, freqs
def _embed_batch(client: httpx.Client, batch: list[str]) -> np.ndarray:
    """POST one batch to the embedding service and return a float32 matrix.

    Raises when the HTTP call fails or when the response does not hold exactly
    one vector per input text. The caller's retry logic therefore treats
    short or malformed responses as failures instead of writing partial rows.
    """
    resp = client.post(f"{EMBEDDING_URL}/embed", json={"texts": batch})
    resp.raise_for_status()
    embeddings = np.asarray(resp.json().get("embeddings", []), dtype=np.float32)
    if embeddings.ndim != 2 or embeddings.shape[0] != len(batch):
        raise ValueError(
            f"expected {len(batch)} embeddings, got array of shape {embeddings.shape}"
        )
    return embeddings


def generate_embeddings(objects: list[str]) -> np.ndarray:
    """Generate embeddings via embedding-service in batches.

    Uses a pre-allocated numpy array to avoid Python list memory overhead
    (Python float = 28 bytes vs numpy float32 = 4 bytes).

    Each batch of ``BATCH_SIZE`` texts is tried up to 3 times, with a 2 s
    pause between attempts. Batches that still fail are logged and skipped,
    and their rows stay all-zero in the returned matrix.

    Args:
        objects: Texts to embed, in output row order.

    Returns:
        A ``(len(objects), 1024)`` float32 matrix. Row ``i`` is the embedding
        of ``objects[i]``, or zeros if its batch failed.
    """
    total = len(objects)
    max_attempts = 3
    progress_every = 5000
    # Pre-allocate: 144k × 1024 × 4 bytes = ~590 MB (vs ~4 GB with Python lists)
    result = np.zeros((total, 1024), dtype=np.float32)
    logger.info("Generating embeddings for %d objects (pre-allocated %.0f MB)...",
                total, result.nbytes / 1024 / 1024)
    failed_batches: list[int] = []
    # One client for the whole run: connections are pooled and reused
    # instead of being re-established for every batch and retry.
    with httpx.Client(timeout=httpx.Timeout(60.0, connect=10.0)) as client:
        for i in range(0, total, BATCH_SIZE):
            batch = objects[i:i + BATCH_SIZE]
            for attempt in range(max_attempts):
                try:
                    # Shape mismatches (e.g. wrong vector dimension) raise
                    # ValueError here and are retried like HTTP errors.
                    result[i:i + len(batch)] = _embed_batch(client, batch)
                    break
                except Exception as e:
                    if attempt < max_attempts - 1:
                        logger.warning("Batch %d attempt %d failed: %s — retrying", i, attempt + 1, e)
                        time.sleep(2)
                    else:
                        logger.error("Batch %d failed after 3 attempts: %s", i, e)
                        failed_batches.append(i)
            done = i + len(batch)
            # Offsets advance in steps of BATCH_SIZE, which rarely land exactly
            # on a multiple of progress_every. Log whenever a boundary is
            # crossed instead.
            if done // progress_every > i // progress_every or done >= total:
                logger.info(" Embedded %d/%d (%.1f%%) [%d failed]",
                            done, total, done / total * 100, len(failed_batches))
    if failed_batches:
        failed_rows = sum(min(BATCH_SIZE, total - off) for off in failed_batches)
        logger.warning(
            "%d batch(es) failed permanently (first offsets: %s); %d objects keep "
            "all-zero embedding rows and will be clustered as identical points",
            len(failed_batches), failed_batches[:20], failed_rows)
    return result
def cluster_objects(embeddings: np.ndarray, k: int, *, batch_size: int = 1000,
                    max_iter: int = 100, random_state: int = 42) -> np.ndarray:
    """Run Mini-Batch K-Means clustering on L2-normalized embeddings.

    Rows are scaled to unit length first. For unit vectors the squared
    Euclidean distance is ``2 - 2 * cos``, so K-Means groups by cosine
    similarity. All-zero rows stay zero: their norm is replaced by 1 to
    avoid division by zero.

    Args:
        embeddings: ``(n_objects, dim)`` embedding matrix.
        k: Number of clusters.
        batch_size: Mini-batch size passed to ``MiniBatchKMeans``.
        max_iter: Maximum number of iterations passed to ``MiniBatchKMeans``.
        random_state: Seed passed to ``MiniBatchKMeans`` for reproducibility.

    Returns:
        Array of cluster labels, one per input row.
    """
    logger.info("Clustering %d objects into %d groups (Mini-Batch K-Means)...", len(embeddings), k)
    # Normalize embeddings for cosine-like clustering.
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    norms[norms == 0] = 1
    normalized = embeddings / norms
    kmeans = MiniBatchKMeans(
        n_clusters=k,
        batch_size=batch_size,
        max_iter=max_iter,
        random_state=random_state,
        verbose=0,
    )
    labels = kmeans.fit_predict(normalized)
    logger.info("Clustering done. Inertia: %.2f", kmeans.inertia_)
    return labels
def _build_groups(objects: list[str], freqs: dict[str, int],
                  labels: np.ndarray) -> list[dict]:
    """Bucket objects by cluster label into rows for the object_groups table.

    A group's canonical name is its most frequent member. Ties keep the order
    in which members appear in ``objects``, because the sort is stable.
    """
    if len(labels) != len(objects):
        raise ValueError(f"got {len(labels)} labels for {len(objects)} objects")
    groups: dict[int, list[tuple[str, int]]] = {}
    for obj, label in zip(objects, labels):
        groups.setdefault(int(label), []).append((obj, freqs.get(obj, 0)))

    results = []
    for gid, members in groups.items():
        members_sorted = sorted(members, key=lambda m: -m[1])
        canonical, top_count = members_sorted[0]
        results.append({
            "group_id": gid,
            "canonical_name": canonical,
            "member_count": len(members),
            "members": json.dumps([name for name, _ in members_sorted]),
            "top_controls_count": top_count,
        })
    return results


def _log_group_stats(results: list[dict]) -> None:
    """Log the group-size distribution and the 10 largest groups."""
    sizes = [r["member_count"] for r in results]
    logger.info("Groups: %d total", len(results))
    if not sizes:
        return
    logger.info(" Singletons: %d", sum(1 for s in sizes if s == 1))
    logger.info(" Groups 2-5: %d", sum(1 for s in sizes if 2 <= s <= 5))
    logger.info(" Groups 6-20: %d", sum(1 for s in sizes if 6 <= s <= 20))
    logger.info(" Groups 21-100: %d", sum(1 for s in sizes if 21 <= s <= 100))
    logger.info(" Groups >100: %d", sum(1 for s in sizes if s > 100))
    logger.info(" Max group size: %d", max(sizes))
    logger.info(" Avg group size: %.1f", sum(sizes) / len(sizes))
    top10 = sorted(results, key=lambda g: -g["member_count"])[:10]
    logger.info("\nTop 10 largest groups:")
    for g in top10:
        members_list = json.loads(g["members"])
        logger.info(" [%d] %s (%d members): %s",
                    g["group_id"], g["canonical_name"], g["member_count"],
                    ", ".join(members_list[:5]))


def store_results(engine, objects: list[str], freqs: dict[str, int],
                  labels: np.ndarray, dry_run: bool):
    """Build groups from cluster labels, log statistics and persist them.

    Args:
        engine: SQLAlchemy engine used for the write.
        objects: Clustered objects; ``labels[i]`` is the cluster of ``objects[i]``.
        freqs: Control count per object, used to pick canonical names.
        labels: Cluster label per object.
        dry_run: If true, only log the statistics and do not touch the database.

    Raises:
        ValueError: If ``labels`` and ``objects`` differ in length.
    """
    results = _build_groups(objects, freqs, labels)
    if not results:
        # Nothing was clustered; keep the previous run's groups rather than
        # wiping the table.
        logger.warning("No groups to store — leaving object_groups untouched")
        return
    _log_group_stats(results)
    if dry_run:
        logger.info("DRY RUN — not writing to DB")
        return
    # engine.begin() commits on success and rolls back on error, so the
    # DELETE and the INSERTs replace the table contents atomically.
    with engine.begin() as conn:
        conn.execute(text("SET search_path TO compliance, public"))
        conn.execute(text("DELETE FROM object_groups"))  # Clear old results
        # Passing a list of parameter dicts performs one executemany instead
        # of one round trip per group.
        conn.execute(text("""
            INSERT INTO object_groups (group_id, canonical_name, member_count, members, top_controls_count)
            VALUES (:group_id, :canonical_name, :member_count, CAST(:members AS jsonb), :top_controls_count)
        """), results)
    logger.info("Wrote %d groups to object_groups table", len(results))
def main():
    """Command-line entry point: extract, embed, cluster and store objects."""
    parser = argparse.ArgumentParser(
        description="G-pre1: cluster normalized objects via Mini-Batch K-Means on embeddings.")
    parser.add_argument("--k", type=int, default=20000, help="Number of clusters")
    parser.add_argument("--dry-run", action="store_true",
                        help="Log statistics only; do not write to the database")
    args = parser.parse_args()
    if args.k < 1:
        parser.error("--k must be a positive integer")

    engine = create_engine(DB_URL, connect_args={"options": "-c search_path=compliance,public"})
    try:
        # Step 1: Extract
        objects, freqs = extract_objects(engine)
        # k is capped at len(objects) // 2, which is 0 for fewer than two
        # objects. MiniBatchKMeans would then fail, so stop early with a clear
        # message.
        if len(objects) < 2:
            logger.error("Need at least 2 objects to cluster, found %d — aborting", len(objects))
            sys.exit(1)

        # Step 2: Embed
        embeddings = generate_embeddings(objects)
        logger.info("Embedding matrix: %s (%.1f MB)", embeddings.shape,
                    embeddings.nbytes / 1024 / 1024)

        # Adjust k if we have fewer objects
        k = min(args.k, len(objects) // 2)
        logger.info("Using k=%d (requested %d, objects=%d)", k, args.k, len(objects))

        # Step 3: Cluster
        labels = cluster_objects(embeddings, k)

        # Step 4: Store
        store_results(engine, objects, freqs, labels, args.dry_run)
    finally:
        engine.dispose()


if __name__ == "__main__":
    main()