#!/usr/bin/env python3
"""
G-pre1 Step 2: Sub-cluster large object groups (>50 members) into k=4 sub-groups.

Reads existing object_groups, re-embeds the members of every large group,
applies K-Means (k=4 by default) per group, and writes the sub-groups back in
place of the original large group.

Usage (inside container or with PYTHONPATH):
    python3 /app/scripts/gpre1_subcluster.py
    python3 /app/scripts/gpre1_subcluster.py --min-size 100   # only groups >100
    python3 /app/scripts/gpre1_subcluster.py --sub-k 6        # 6 sub-clusters
    python3 /app/scripts/gpre1_subcluster.py --dry-run        # compute, don't write

Environment:
    DATABASE_URL   SQLAlchemy database URL.
    EMBEDDING_URL  Base URL of the embedding service
                   (default: http://embedding-service:8087).
"""

import argparse
import json
import logging
import os

import httpx
import numpy as np
from sklearn.cluster import MiniBatchKMeans
from sqlalchemy import create_engine, text

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("gpre1-sub")

# NOTE(review): the fallback URL embeds credentials; prefer always setting DATABASE_URL.
DB_URL = os.getenv("DATABASE_URL", "postgresql://breakpilot:breakpilot123@postgres:5432/breakpilot_db")
EMBEDDING_URL = os.getenv("EMBEDDING_URL", "http://embedding-service:8087")

# Width of the array returned for an empty input; otherwise the width is taken
# from the embedding service's response.
_DEFAULT_EMBEDDING_DIM = 1024
_EMBED_TIMEOUT = httpx.Timeout(60.0, connect=10.0)


def main():
    """Sub-cluster every object group larger than --min-size and persist the result.

    Groups with fewer than ``2 * --sub-k`` members, or whose members could not
    be embedded, are left untouched. With --dry-run nothing is written.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--min-size", type=int, default=50,
                        help="Sub-cluster groups with more than this many members")
    parser.add_argument("--sub-k", type=int, default=4, help="Sub-clusters per group")
    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args()
    if args.sub_k < 1:
        # MiniBatchKMeans rejects n_clusters < 1; fail with a clear CLI error instead.
        parser.error("--sub-k must be at least 1")

    engine = create_engine(DB_URL, connect_args={"options": "-c search_path=compliance,public"})

    groups, max_gid = _load_large_groups(engine, args.min_size)
    logger.info("Found %d groups with >%d members to sub-cluster", len(groups), args.min_size)

    next_gid = max_gid + 1
    all_new_rows: list[dict] = []
    groups_to_delete: list[int] = []

    for group_id, canonical_name, _member_count, members_json in groups:
        raw = json.loads(members_json) if isinstance(members_json, str) else members_json
        members = raw or []  # tolerate a NULL members column
        if len(members) < args.sub_k * 2:
            logger.info("  Skip group %d (%s, %d members) — too small for k=%d",
                        group_id, canonical_name, len(members), args.sub_k)
            continue

        embeddings = _embed_batch(members)
        if embeddings is None:
            logger.error("  Failed to embed group %d (%s)", group_id, canonical_name)
            continue

        k = min(args.sub_k, len(members) // 2)
        for sub_canonical, sub_members in _subcluster_group(members, embeddings, k):
            all_new_rows.append({
                "group_id": next_gid,
                "canonical_name": sub_canonical,
                "member_count": len(sub_members),
                "members": json.dumps(sub_members),
                "top_controls_count": 0,
                # NOTE(review): not part of the INSERT below, so the parent link is
                # not persisted; add it there if object_groups has such a column.
                "parent_group_id": group_id,
            })
            next_gid += 1
        groups_to_delete.append(group_id)

        if len(groups_to_delete) % 50 == 0:
            logger.info("  Processed %d/%d groups, %d sub-groups created",
                        len(groups_to_delete), len(groups), len(all_new_rows))

    logger.info("Sub-clustering complete: %d groups → %d sub-groups",
                len(groups_to_delete), len(all_new_rows))

    # Stats
    sub_sizes = [r["member_count"] for r in all_new_rows]
    if sub_sizes:
        logger.info("  Sub-group sizes: avg=%.1f, max=%d, min=%d",
                    sum(sub_sizes) / len(sub_sizes), max(sub_sizes), min(sub_sizes))

    if args.dry_run:
        logger.info("DRY RUN — not writing to DB")
        for r in all_new_rows[:10]:
            logger.info("  [%d] %s (%d members)", r["group_id"], r["canonical_name"], r["member_count"])
        return

    _write_sub_groups(engine, groups_to_delete, all_new_rows)
    logger.info("Wrote %d sub-groups to DB (replaced %d large groups)",
                len(all_new_rows), len(groups_to_delete))

    # Final stats
    with engine.connect() as c:
        total = c.execute(text("SELECT count(*) FROM object_groups")).scalar()
        logger.info("Total groups in DB: %d", total)


def _load_large_groups(engine, min_size: int):
    """Return (rows of groups with member_count > min_size, current MAX(group_id) or 0)."""
    with engine.connect() as c:
        groups = c.execute(text(
            "SELECT group_id, canonical_name, member_count, members "
            "FROM object_groups WHERE member_count > :min ORDER BY member_count DESC"
        ), {"min": min_size}).fetchall()
        max_gid = c.execute(text("SELECT COALESCE(MAX(group_id), 0) FROM object_groups")).scalar()
    return groups, max_gid


def _subcluster_group(
    members: list[str], embeddings: np.ndarray, k: int
) -> list[tuple[str, list[str]]]:
    """Partition *members* into at most *k* clusters by cosine similarity.

    Args:
        members: Member names; row ``i`` of *embeddings* belongs to ``members[i]``.
        embeddings: 2-D array with one embedding row per member.
        k: Number of K-Means clusters requested.

    Returns:
        One ``(canonical_name, sub_members)`` pair per non-empty cluster, ordered
        by each cluster's first member. ``sub_members`` keeps the original member
        order. ``canonical_name`` is the member whose normalized embedding has the
        highest dot product with the cluster centroid.
    """
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    norms[norms == 0] = 1  # leave all-zero rows as-is instead of dividing by zero
    normalized = embeddings / norms  # unit rows: dot product == cosine similarity

    kmeans = MiniBatchKMeans(n_clusters=k, batch_size=min(100, len(members)),
                             max_iter=50, random_state=42)
    labels = kmeans.fit_predict(normalized)

    # Member indices per label; dict insertion order follows first appearance.
    clusters: dict[int, list[int]] = {}
    for idx, label in enumerate(labels):
        clusters.setdefault(int(label), []).append(idx)

    result: list[tuple[str, list[str]]] = []
    for label, idxs in clusters.items():
        sims = normalized[idxs] @ kmeans.cluster_centers_[label]
        canonical = members[idxs[int(np.argmax(sims))]]
        result.append((canonical, [members[i] for i in idxs]))
    return result


def _write_sub_groups(engine, groups_to_delete: list[int], new_rows: list[dict]) -> None:
    """Delete *groups_to_delete* and insert *new_rows* in a single transaction."""
    with engine.begin() as c:  # commits on success, rolls back on any error
        c.execute(text("SET search_path TO compliance, public"))
        # Passing a list of parameter dicts makes SQLAlchemy use executemany;
        # guard empty lists, which would otherwise run the statement unbound.
        if groups_to_delete:
            c.execute(text("DELETE FROM object_groups WHERE group_id = :gid"),
                      [{"gid": gid} for gid in groups_to_delete])
        if new_rows:
            c.execute(text("""
                INSERT INTO object_groups (group_id, canonical_name, member_count, members, top_controls_count)
                VALUES (:group_id, :canonical_name, :member_count, CAST(:members AS jsonb), :top_controls_count)
            """), new_rows)


def _embed_batch(texts: list[str], batch_size: int = 64) -> np.ndarray | None:
    """Embed *texts* via the embedding service, ``batch_size`` texts per request.

    Returns a float32 array with one row per text (same order as *texts*), or
    None if any request fails or the service returns a response whose shape
    does not match the request (wrong count, ragged or inconsistent widths).
    """
    if not texts:
        return np.zeros((0, _DEFAULT_EMBEDDING_DIM), dtype=np.float32)
    try:
        chunks = []
        # One client for all batches so the connection is reused.
        with httpx.Client(timeout=_EMBED_TIMEOUT) as client:
            for i in range(0, len(texts), batch_size):
                batch = texts[i:i + batch_size]
                resp = client.post(f"{EMBEDDING_URL}/embed", json={"texts": batch})
                resp.raise_for_status()
                arr = np.asarray(resp.json().get("embeddings", []), dtype=np.float32)
                # A short or malformed response would otherwise leave zero rows
                # that get clustered as if they were real embeddings.
                if arr.ndim != 2 or arr.shape[0] != len(batch):
                    raise ValueError(
                        f"expected {len(batch)} embeddings, got array of shape {arr.shape}"
                    )
                chunks.append(arr)
        return np.concatenate(chunks, axis=0)  # raises if batch widths differ
    except Exception as e:  # best-effort boundary: caller skips the group on None
        logger.error("Embedding failed: %s", e)
        return None


if __name__ == "__main__":
    main()