#!/usr/bin/env python3
"""
G-pre1: Object Clustering via Mini-Batch K-Means on Embeddings.

Clusters ~144k unique normalized objects into ~15-25k semantic groups
using bge-m3 embeddings and Mini-Batch K-Means.

Usage (inside control-pipeline container):
    python3 /app/scripts/gpre1_object_clustering.py --k 20000
    python3 /app/scripts/gpre1_object_clustering.py --k 20000 --dry-run
"""
import argparse
import json
import logging
import os
import sys
import time
from collections import Counter

import httpx
import numpy as np
from sklearn.cluster import MiniBatchKMeans
from sqlalchemy import create_engine, text

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("gpre1")

# NOTE(review): the fallback URL embeds development credentials; always set
# DATABASE_URL outside of local/dev environments.
DB_URL = os.getenv("DATABASE_URL", "postgresql://breakpilot:breakpilot123@postgres:5432/breakpilot_db")
# Overridable so the script can also run outside the compose network.
EMBEDDING_URL = os.getenv("EMBEDDING_URL", "http://embedding-service:8087")
EMBEDDING_DIM = 1024  # Vector width expected from the embedding service
BATCH_SIZE = 64  # Embeddings per API call
MAX_ATTEMPTS = 3  # Tries per embedding batch before the batch is given up
RETRY_DELAY_S = 2.0  # Pause between two attempts of the same batch
PROGRESS_EVERY = 5000  # Log embedding progress each time this many objects pass


def extract_objects(engine) -> tuple[list[str], dict[str, int]]:
    """Extract unique normalized objects and their frequencies.

    Takes the second ':'-separated field of every non-empty
    ``generation_metadata->>'merge_group_hint'`` in ``canonical_controls``,
    normalizes it with ``normalize_object`` and sums the control counts of all
    raw objects that normalize to the same string.

    Returns:
        ``(objects, freqs)``: the normalized objects in sorted order, and a
        mapping from each normalized object to its summed control count.
    """
    from services.control_dedup import normalize_object

    logger.info("Extracting objects from canonical_controls...")
    with engine.connect() as c:
        rows = c.execute(text("""
            SELECT split_part(generation_metadata->>'merge_group_hint', ':', 2) AS obj,
                   count(*) AS freq
            FROM canonical_controls
            WHERE generation_metadata->>'merge_group_hint' IS NOT NULL
              AND generation_metadata->>'merge_group_hint' != ''
            GROUP BY 1
        """)).fetchall()

    # Normalize and aggregate: several raw spellings may collapse onto one object.
    norm_freq: Counter = Counter()
    for raw_obj, freq in rows:
        if not raw_obj or not raw_obj.strip():
            continue
        normed = normalize_object(raw_obj)
        if not normed:
            # Nothing meaningful left to embed; don't cluster an empty string.
            continue
        norm_freq[normed] += freq

    # Sorted so the row order fed to K-Means (and therefore the result for a
    # fixed random_state) does not depend on the database's GROUP BY order.
    objects = sorted(norm_freq)
    freqs = dict(norm_freq)
    logger.info("Extracted %d unique normalized objects (from %d raw)", len(objects), len(rows))
    return objects, freqs


def generate_embeddings(objects: list[str]) -> np.ndarray:
    """Generate embeddings via embedding-service in batches.

    For each slice of ``BATCH_SIZE`` objects, posts ``{"texts": batch}`` to
    ``{EMBEDDING_URL}/embed`` and copies the returned ``"embeddings"`` into a
    pre-allocated float32 matrix; row ``i`` belongs to ``objects[i]``.

    Each batch is tried up to ``MAX_ATTEMPTS`` times. A response whose shape
    is not ``(len(batch), EMBEDDING_DIM)`` counts as a failed attempt. A batch
    that still fails is logged, and its rows stay all-zero, so callers can
    detect them with ``~result.any(axis=1)``.

    A numpy array is used instead of nested lists to keep memory low: a boxed
    Python float costs 24 bytes plus an 8-byte list slot, while float32 costs
    4 bytes.

    Returns:
        float32 array of shape ``(len(objects), EMBEDDING_DIM)``.
    """
    total = len(objects)
    # Pre-allocate: 144k x 1024 x 4 bytes ~= 590 MB (vs several GB with Python lists)
    result = np.zeros((total, EMBEDDING_DIM), dtype=np.float32)
    logger.info("Generating embeddings for %d objects (pre-allocated %.0f MB)...",
                total, result.nbytes / 1024 / 1024)

    failed_batches: list[int] = []
    # One client for the whole run, so pooled connections are reused across
    # batches instead of being re-established for every request.
    with httpx.Client(timeout=httpx.Timeout(60.0, connect=10.0)) as client:
        for i in range(0, total, BATCH_SIZE):
            batch = objects[i:i + BATCH_SIZE]
            for attempt in range(1, MAX_ATTEMPTS + 1):
                try:
                    resp = client.post(
                        f"{EMBEDDING_URL}/embed",
                        json={"texts": batch},
                    )
                    resp.raise_for_status()
                    vectors = np.asarray(resp.json().get("embeddings", []), dtype=np.float32)
                    # A short or malformed response would otherwise leave zero
                    # rows (or misaligned data) without any error; retry instead.
                    if vectors.shape != (len(batch), EMBEDDING_DIM):
                        raise ValueError(
                            f"expected embeddings of shape {(len(batch), EMBEDDING_DIM)}, "
                            f"got {vectors.shape}"
                        )
                    result[i:i + len(batch)] = vectors
                    break
                except Exception as e:  # best-effort at the I/O boundary: retry, then skip
                    if attempt < MAX_ATTEMPTS:
                        logger.warning("Batch %d attempt %d failed: %s — retrying", i, attempt, e)
                        time.sleep(RETRY_DELAY_S)
                    else:
                        logger.error("Batch %d failed after %d attempts: %s", i, MAX_ATTEMPTS, e)
                        failed_batches.append(i)

            done = min(i + BATCH_SIZE, total)
            # Log when a PROGRESS_EVERY boundary is crossed. A plain
            # "done % PROGRESS_EVERY == 0" test would rarely match with
            # BATCH_SIZE-sized steps.
            if done // PROGRESS_EVERY > i // PROGRESS_EVERY or done >= total:
                logger.info("  Embedded %d/%d (%.1f%%) [%d failed]",
                            done, total, done / total * 100, len(failed_batches))

    if failed_batches:
        missing_rows = sum(min(BATCH_SIZE, total - start) for start in failed_batches)
        logger.warning("%d batch(es) failed; %d rows have no embedding (all-zero)",
                       len(failed_batches), missing_rows)
    return result


def cluster_objects(embeddings: np.ndarray, k: int) -> np.ndarray:
    """Run Mini-Batch K-Means clustering.

    Rows are L2-normalized first (zero rows are left as zero). For unit
    vectors, squared Euclidean distance equals ``2 - 2 * cosine``, so
    Euclidean K-Means approximates clustering by cosine similarity.

    Args:
        embeddings: 2-D array, one row per object.
        k: Number of clusters; must be between 1 and ``len(embeddings)``.

    Returns:
        Integer cluster label per row of ``embeddings``.
    """
    logger.info("Clustering %d objects into %d groups (Mini-Batch K-Means)...", len(embeddings), k)

    # Normalize embeddings for cosine-like clustering
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    norms[norms == 0] = 1  # avoid division by zero for all-zero rows
    normalized = embeddings / norms

    kmeans = MiniBatchKMeans(
        n_clusters=k,
        batch_size=1000,
        max_iter=100,
        random_state=42,
        verbose=0,
    )
    labels = kmeans.fit_predict(normalized)
    logger.info("Clustering done. Inertia: %.2f", kmeans.inertia_)
    return labels


def store_results(engine, objects: list[str], freqs: dict[str, int],
                  labels: np.ndarray, dry_run: bool):
    """Store clustering results in object_groups table.

    Groups objects by label. The highest-frequency member of each group
    becomes its canonical name; ties keep input order because the sort is
    stable. Logs size statistics and the 10 largest groups. Unless
    ``dry_run`` is set, replaces the full contents of ``object_groups``
    within a single transaction.
    """
    # Build groups: label -> [(object, frequency), ...]
    groups: dict[int, list[tuple[str, int]]] = {}
    for obj, label in zip(objects, labels):
        groups.setdefault(int(label), []).append((obj, freqs.get(obj, 0)))

    if not groups:
        # Nothing to report, and wiping the table with nothing to insert would lose data.
        logger.warning("No groups to store — object_groups left unchanged")
        return

    # Pick canonical name (highest frequency in group)
    results = []
    for gid, members in groups.items():
        members_sorted = sorted(members, key=lambda x: -x[1])
        canonical = members_sorted[0][0]
        results.append({
            "group_id": gid,
            "canonical_name": canonical,
            "member_count": len(members),
            "members": json.dumps([m[0] for m in members_sorted]),
            "top_controls_count": members_sorted[0][1],
        })

    # Stats
    sizes = [r["member_count"] for r in results]
    logger.info("Groups: %d total", len(results))
    logger.info("  Singletons: %d", sum(1 for s in sizes if s == 1))
    logger.info("  Groups 2-5: %d", sum(1 for s in sizes if 2 <= s <= 5))
    logger.info("  Groups 6-20: %d", sum(1 for s in sizes if 6 <= s <= 20))
    logger.info("  Groups 21-100: %d", sum(1 for s in sizes if 21 <= s <= 100))
    logger.info("  Groups >100: %d", sum(1 for s in sizes if s > 100))
    logger.info("  Max group size: %d", max(sizes))
    logger.info("  Avg group size: %.1f", sum(sizes) / len(sizes))

    # Top 10 largest groups
    top10 = sorted(results, key=lambda x: -x["member_count"])[:10]
    logger.info("\nTop 10 largest groups:")
    for g in top10:
        members_list = json.loads(g["members"])
        logger.info("  [%d] %s (%d members): %s",
                    g["group_id"], g["canonical_name"], g["member_count"],
                    ", ".join(members_list[:5]))

    if dry_run:
        logger.info("DRY RUN — not writing to DB")
        return

    # Write to DB: delete + insert in one transaction, so readers never see a half-written table.
    with engine.begin() as conn:
        conn.execute(text("SET search_path TO compliance, public"))
        conn.execute(text("DELETE FROM object_groups"))  # Clear old results
        # Passing the list of parameter dicts makes SQLAlchemy run a single
        # executemany instead of one execute() call per group.
        conn.execute(text("""
            INSERT INTO object_groups (group_id, canonical_name, member_count, members, top_controls_count)
            VALUES (:group_id, :canonical_name, :member_count, CAST(:members AS jsonb), :top_controls_count)
        """), results)

    logger.info("Wrote %d groups to object_groups table", len(results))


def main():
    """Parse CLI args, then extract, embed, cluster and store objects."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--k", type=int, default=20000, help="Number of clusters")
    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args()
    if args.k < 1:
        parser.error("--k must be >= 1")

    engine = create_engine(DB_URL, connect_args={"options": "-c search_path=compliance,public"})

    # Step 1: Extract
    objects, freqs = extract_objects(engine)
    if not objects:
        logger.warning("No objects found — nothing to cluster, object_groups left unchanged")
        return

    # Step 2: Embed
    embeddings = generate_embeddings(objects)
    logger.info("Embedding matrix: %s (%.1f MB)", embeddings.shape, embeddings.nbytes / 1024 / 1024)

    # Rows from failed batches are all-zero. Clustering them would merge every
    # failed object into one meaningless group, so drop them instead.
    has_vector = embeddings.any(axis=1)
    missing = int(np.count_nonzero(~has_vector))
    if missing:
        logger.warning("Dropping %d objects without embeddings from clustering", missing)
        objects = [obj for obj, ok in zip(objects, has_vector) if ok]
        embeddings = embeddings[has_vector]
        if not objects:
            logger.error("No embeddings available — aborting, object_groups left unchanged")
            sys.exit(1)

    # Adjust k if we have fewer objects (at least 1 cluster, at most half the objects)
    k = max(1, min(args.k, len(objects) // 2))
    logger.info("Using k=%d (requested %d, objects=%d)", k, args.k, len(objects))

    # Step 3: Cluster
    labels = cluster_objects(embeddings, k)

    # Step 4: Store
    store_results(engine, objects, freqs, labels, args.dry_run)


if __name__ == "__main__":
    main()