#!/usr/bin/env python3
"""
G-pre3: Split large Master Controls by regulation source.

For each MC with more than ``--threshold`` controls (default 200):
  1. Load member controls with the parent's source_citation->>'source'
  2. Group by regulation source
  3. Sources with >= MIN_SOURCE_SIZE controls → new sub-MC
  4. Small sources → merged into a "mixed" bucket
  5. UNKNOWN (no source_citation) → "general" bucket
  6. Any bucket larger than MAX_MC is sub-clustered via
     ``subcluster_controls``
  7. Delete the original large MC, create the new sub-MCs

Usage:
    python3 /app/scripts/gpre3_regulation_split.py --dry-run
    python3 /app/scripts/gpre3_regulation_split.py --min-source 15 --max-mc 100
"""

import argparse
import json
import logging
import os
import re
from collections import Counter, defaultdict

from sqlalchemy import create_engine, text

from services.embedding_utils import subcluster_controls

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
)
logger = logging.getLogger("gpre3")

# NOTE(review): the fallback URL embeds credentials. It is kept so existing
# deployments keep working; prefer setting DATABASE_URL explicitly.
DB_URL = os.getenv(
    "DATABASE_URL",
    "postgresql://breakpilot:breakpilot123@postgres:5432/breakpilot_db",
)

# ── Source key normalization ────────────────────────────────────────

# fmt: off
_SOURCE_SHORT: dict[str, str] = {
    "DSGVO (EU) 2016/679": "dsgvo",
    "NIS2-Richtlinie (EU) 2022/2555": "nis2",
    "KI-Verordnung (EU) 2024/1689": "ai_act",
    "Cyber Resilience Act (CRA)": "cra",
    "Digital Services Act (DSA)": "dsa",
    "Digital Markets Act (DMA)": "dma",
    "Digital Operational Resilience Act": "dora",
    "Data Governance Act (DGA)": "dga",
    "Data Act": "data_act",
    "Maschinenverordnung (EU) 2023/1230": "machinery_reg",
    "Medizinprodukteverordnung (EU) 2017/745 (MDR)": "mdr",
    "European Health Data Space": "ehds",
    "European Accessibility Act": "eaa",
    "EU Cybersecurity Act": "eu_csa",
    "EU Blue Guide 2022": "eu_blue_guide",
    "EU-US Data Privacy Framework": "eu_us_dpf",
    "Markets in Crypto-Assets (MiCA)": "mica",
    "Standardvertragsklauseln (SCC)": "scc",
    "ePrivacy-Richtlinie": "eprivacy",
    "Batterieverordnung (EU) 2023/1542": "battery_reg",
    "Bundesdatenschutzgesetz (BDSG)": "bdsg",
    "BSI-Gesetz (BSIG 2025, NIS2-Umsetzung)": "bsig",
    "BSI-Kritisverordnung (BSI-KritisV)": "bsi_kritisv",
    "Geldwaeschegesetz (GwG)": "gwg",
    "Hinweisgeberschutzgesetz (HinSchG)": "hinschg",
    "Lieferkettensorgfaltspflichtengesetz (LkSG)": "lksg",
    "KRITIS-Dachgesetz (KRITISDachG)": "kritisdachg",
    "NIST SP 800-53 Rev. 5": "nist_800_53",
    "NIST Cybersecurity Framework 2.0": "nist_csf",
    "NIST Privacy Framework 1.0": "nist_privacy",
    "NIST SP 800-207 (Zero Trust)": "nist_zero_trust",
    "NIST SP 800-218 (SSDF)": "nist_ssdf",
    "NIST SP 800-63-3": "nist_800_63",
    "NIST AI Risk Management Framework": "nist_ai_rmf",
    "NISTIR 8259A IoT Security": "nist_iot",
    "OWASP Top 10 (2021)": "owasp_top10",
    "OWASP API Security Top 10 (2023)": "owasp_api",
    "OWASP ASVS 4.0": "owasp_asvs",
    "OWASP SAMM 2.0": "owasp_samm",
    "OWASP MASVS 2.0": "owasp_masvs",
    "OWASP Mobile Top 10": "owasp_mobile",
    "ENISA": "enisa",
    "TDDDG": "tdddg",
    "TKG": "tkg",
    "TMG": "tmg",
    "BGB": "bgb",
    "UWG": "uwg",
    "UrhG": "urhg",
    "BAIT (BaFin 2024)": "bait",
    "VAIT (BaFin 2022)": "vait",
    "AML-Verordnung": "aml_reg",
    "Zahlungsdiensterichtlinie 2": "psd2",
    "Telekommunikationsgesetz Oesterreich": "at_tkg",
    "Österreichisches Datenschutzgesetz (DSG)": "at_dsg",
    "Allgemeines Gleichbehandlungsgesetz (AGG)": "agg",
    "Aktiengesetz (AktG)": "aktg",
    "Handelsgesetzbuch (HGB)": "hgb",
    "GmbH-Gesetz (GmbHG)": "gmbhg",
    "Insolvenzordnung (InsO)": "inso",
    "Gewerbeordnung (GewO)": "gewo",
    "Abgabenordnung (AO)": "ao",
}
# fmt: on

# Compiled once; source_to_key is called for every regulation source.
_PAREN_RE = re.compile(r"\(.*?\)")
_NON_SLUG_RE = re.compile(r"[^a-z0-9äöüß]+")
_MAX_KEY_LEN = 40


def source_to_key(source: str) -> str:
    """Normalize a regulation source name to a short slug key.

    Known sources are looked up in ``_SOURCE_SHORT``. Anything else is
    lower-cased, stripped of parenthesised parts, reduced to
    ``[a-z0-9äöüß]`` runs joined by single underscores and truncated to
    40 characters.

    Returns:
        The slug, or ``"unknown"`` when nothing usable remains.
    """
    if source in _SOURCE_SHORT:
        return _SOURCE_SHORT[source]
    slug = _PAREN_RE.sub("", source.lower())
    # Every run of non-slug characters (underscores included) collapses to a
    # single "_", so no separate de-duplication pass is needed.
    slug = _NON_SLUG_RE.sub("_", slug).strip("_")
    # Truncation may cut right after a separator; drop the dangling "_".
    slug = slug[:_MAX_KEY_LEN].rstrip("_")
    return slug if slug else "unknown"


# ── Main ───────────────────────────────────────────────────────────

def main():
    """Parse CLI arguments, plan the splits and (unless --dry-run) apply them."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--min-source", type=int, default=15,
                        help="Min controls per source for own sub-MC")
    parser.add_argument("--max-mc", type=int, default=100,
                        help="Max controls per sub-MC before sub-clustering")
    parser.add_argument("--threshold", type=int, default=200,
                        help="Only split MCs with more than N controls")
    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args()

    engine = create_engine(
        DB_URL, connect_args={"options": "-c search_path=compliance,public"}
    )

    # Step 1: Find large master controls
    with engine.connect() as c:
        large_mcs = c.execute(text("""
            SELECT mc.id, mc.master_control_id, mc.object_group_id,
                   mc.canonical_name, mc.total_controls
            FROM master_controls mc
            WHERE mc.total_controls > :threshold
            ORDER BY mc.total_controls DESC
        """), {"threshold": args.threshold}).fetchall()

    logger.info("Found %d MCs with >%d controls", len(large_mcs), args.threshold)
    if not large_mcs:
        return

    # Step 2: Build split plans
    all_splits = [
        _build_split_plan(engine, mc_uuid, mc_id, og_id, canonical, total, args)
        for mc_uuid, mc_id, og_id, canonical, total in large_mcs
    ]

    total_new = sum(len(sp["sub_groups"]) for sp in all_splits)
    total_covered = sum(
        len(sg["controls"]) for sp in all_splits for sg in sp["sub_groups"]
    )
    logger.info("SUMMARY: %d large MCs → %d sub-MCs (%d controls)",
                len(all_splits), total_new, total_covered)

    if args.dry_run:
        logger.info("DRY RUN — not writing to DB")
        return

    _write_splits(engine, all_splits, threshold=args.threshold)


def _build_split_plan(engine, mc_uuid, mc_id, og_id, canonical, total, args) -> dict:
    """Build a regulation-source split plan for one large MC.

    Loads the MC's member controls together with the ``source`` of the
    parent control's ``source_citation`` ('UNKNOWN' when absent), groups
    them by source and distributes them into sub-groups:

    * sources with at least ``args.min_source`` controls get their own group,
    * smaller sources are pooled into ``<canonical>_mixed``,
    * controls without a source go into ``<canonical>_general``.

    Returns:
        A dict with keys ``mc_uuid``, ``mc_id``, ``og_id``, ``canonical``,
        ``total`` and ``sub_groups`` (list of ``{"name", "source",
        "controls"}`` dicts).
    """
    logger.info("\n━━━ %s: %s (%d controls) ━━━", mc_id, canonical, total)

    with engine.connect() as c:
        members = c.execute(text("""
            SELECT mcm.control_uuid, mcm.phase, mcm.action,
                   cc.control_id, cc.title,
                   COALESCE(pc.source_citation->>'source', 'UNKNOWN') AS src
            FROM master_control_members mcm
            JOIN canonical_controls cc ON cc.id = mcm.control_uuid
            LEFT JOIN canonical_controls pc ON pc.id = cc.parent_control_uuid
            WHERE mcm.master_control_uuid = CAST(:mc_uuid AS uuid)
        """), {"mc_uuid": str(mc_uuid)}).fetchall()

    by_source: dict[str, list[dict]] = defaultdict(list)
    for ctrl_uuid, phase, action, cid, title, src in members:
        by_source[src].append({
            "control_uuid": str(ctrl_uuid),
            "phase": phase,
            "action": action,
            "control_id": cid,
            "title": title,
        })

    # Largest sources first; only the top 8 are logged to keep output short.
    sorted_sources = sorted(by_source.items(), key=lambda x: -len(x[1]))
    for src, ctrls in sorted_sources[:8]:
        logger.info(" %4d %s", len(ctrls), src)
    if len(sorted_sources) > 8:
        logger.info(" ... +%d more sources", len(sorted_sources) - 8)

    plan = {"mc_uuid": str(mc_uuid), "mc_id": mc_id, "og_id": og_id,
            "canonical": canonical, "total": total, "sub_groups": []}

    own_mc_sources = []
    mixed_controls = []
    for src, ctrls in sorted_sources:
        if src == "UNKNOWN":
            continue
        if len(ctrls) >= args.min_source:
            own_mc_sources.append((src, ctrls))
        else:
            mixed_controls.extend(ctrls)
    unknown_controls = by_source.get("UNKNOWN", [])

    # (a) Named regulation sub-MCs
    for src, ctrls in own_mc_sources:
        key = source_to_key(src)
        _add_subgroups(plan, f"{canonical}_{key}", src, ctrls, args.max_mc)

    # (b) Mixed small-source bucket
    if mixed_controls:
        _add_subgroups(plan, f"{canonical}_mixed", "mixed",
                       mixed_controls, args.max_mc)

    # (c) UNKNOWN bucket
    if unknown_controls:
        _add_subgroups(plan, f"{canonical}_general", "general",
                       unknown_controls, args.max_mc)

    logger.info(" → %d sub-groups:", len(plan["sub_groups"]))
    for sg in sorted(plan["sub_groups"], key=lambda x: -len(x["controls"])):
        logger.info(" %4d %s", len(sg["controls"]), sg["name"])

    # The original MC is deleted when the plan is applied, so any control that
    # sub-clustering dropped (or duplicated) would change membership silently.
    planned = sum(len(sg["controls"]) for sg in plan["sub_groups"])
    if planned != len(members):
        logger.warning("%s: plan covers %d controls but %d members were loaded",
                       mc_id, planned, len(members))

    return plan


def _add_subgroups(plan: dict, name: str, source: str,
                   controls: list[dict], max_mc: int):
    """Add controls as one or more sub-groups to the plan.

    Up to ``max_mc`` controls become a single sub-group called ``name``.
    Larger sets are handed to ``subcluster_controls``; each returned cluster
    becomes its own sub-group, suffixed ``_1``, ``_2``, … when there is more
    than one.
    """
    if len(controls) <= max_mc:
        plan["sub_groups"].append({"name": name, "source": source,
                                   "controls": controls})
        return

    # NOTE(review): presumably returns a list of control lists of at most
    # max_mc entries each — confirm against services.embedding_utils.
    clusters = subcluster_controls(controls, max_mc)
    for i, cluster in enumerate(clusters):
        sub_name = f"{name}_{i+1}" if len(clusters) > 1 else name
        plan["sub_groups"].append({"name": sub_name, "source": source,
                                   "controls": cluster})


def _insert_sub_mc(c, gid: int, sg: dict) -> int:
    """Insert one sub-group as object_group + master_control + member rows.

    Args:
        c: Open SQLAlchemy connection (inside the caller's transaction).
        gid: ``group_id`` to assign; also used to build ``MC-<gid>``.
        sg: Sub-group dict with ``name`` and a non-empty ``controls`` list.

    Returns:
        Number of ``master_control_members`` rows inserted.
    """
    controls = sg["controls"]

    # De-duplicate in first-seen order so the stored JSON is reproducible
    # (iterating a set yields an arbitrary order).
    members_list = list(dict.fromkeys(ctrl["control_id"] for ctrl in controls))

    c.execute(text("""
        INSERT INTO object_groups
            (group_id, canonical_name, member_count, members,
             top_controls_count)
        VALUES (:gid, :name, :cnt, CAST(:members AS jsonb), 0)
    """), {"gid": gid, "name": sg["name"],
           "cnt": len(members_list),
           "members": json.dumps(members_list)})

    phase_counts = dict(Counter(ctrl["phase"] for ctrl in controls))
    sorted_phases = sorted(phase_counts)

    # RETURNING yields the id of exactly the row inserted here; looking it up
    # by master_control_id afterwards could match an older row with that id.
    mc_uuid = c.execute(text("""
        INSERT INTO master_controls
            (master_control_id, object_group_id, canonical_name,
             phases_covered, phase_control_count, total_controls)
        VALUES (:mcid, :gid, :name,
                CAST(:phases AS jsonb), CAST(:pcounts AS jsonb),
                :total)
        RETURNING id
    """), {"mcid": f"MC-{gid}", "gid": gid, "name": sg["name"],
           "phases": json.dumps(sorted_phases),
           "pcounts": json.dumps(phase_counts),
           "total": len(controls)}).scalar()

    # A list of parameter dicts makes SQLAlchemy issue one executemany call
    # instead of one round trip per member.
    rows = [
        {"mc": str(mc_uuid), "ctrl": ctrl["control_uuid"],
         "phase": ctrl["phase"], "action": ctrl["action"]}
        for ctrl in controls
    ]
    c.execute(text("""
        INSERT INTO master_control_members
            (master_control_uuid, control_uuid, phase, action)
        VALUES (CAST(:mc AS uuid), CAST(:ctrl AS uuid),
                :phase, :action)
    """), rows)
    return len(rows)


def _write_splits(engine, splits: list[dict], threshold: int = 200):
    """Apply split plans: delete old MCs, create new object_groups + MCs.

    All plans are applied in a single transaction. A plan without any
    non-empty sub-group is skipped, so an MC is never deleted without a
    replacement. Afterwards summary statistics are logged.

    Args:
        engine: SQLAlchemy engine.
        splits: Plans as produced by ``_build_split_plan``.
        threshold: Size limit used for the "still too large" statistic.
    """
    total_mc = 0
    total_mem = 0

    with engine.begin() as c:
        c.execute(text("SET search_path TO compliance, public"))

        max_gid = c.execute(
            text("SELECT COALESCE(MAX(group_id), 0) FROM object_groups")
        ).scalar()
        next_gid = max_gid + 1

        for sp in splits:
            groups = [sg for sg in sp["sub_groups"] if sg["controls"]]
            if not groups:
                logger.warning("Skipping %s (%s): no sub-groups planned, "
                               "original MC kept", sp["mc_id"], sp["canonical"])
                continue

            c.execute(text(
                "DELETE FROM master_control_members "
                "WHERE master_control_uuid = CAST(:u AS uuid)"
            ), {"u": sp["mc_uuid"]})
            c.execute(text(
                "DELETE FROM master_controls WHERE id = CAST(:u AS uuid)"
            ), {"u": sp["mc_uuid"]})
            logger.info("Deleted %s (%s)", sp["mc_id"], sp["canonical"])

            for sg in groups:
                total_mem += _insert_sub_mc(c, next_gid, sg)
                next_gid += 1
                total_mc += 1

    logger.info("Created %d new MCs with %d members", total_mc, total_mem)

    with engine.connect() as c:
        # COALESCE: AVG() is NULL on an empty table, which would break %d.
        stats = c.execute(text("""
            SELECT count(*),
                   count(CASE WHEN total_controls > :threshold THEN 1 END),
                   COALESCE(AVG(total_controls), 0)::int
            FROM compliance.master_controls
        """), {"threshold": threshold}).fetchone()
    logger.info("Final: %d MCs, %d still >%d, avg %d controls/MC",
                stats[0], stats[1], threshold, stats[2])


if __name__ == "__main__":
    main()