#!/usr/bin/env python3
"""
Add L2 sub-topics to broad tokens.

Instead of just "incident", the object part of a control's
``merge_group_hint`` becomes "incident_response_plan", "incident_detection",
etc. The hint keeps its ``action:object:phase`` layout; only the object
segment is extended from ``L1`` to ``L1_L2``.

Only controls whose L1 token is listed in BROAD_TOKENS are processed. The
original note says these are tokens with >500 controls AND <90% audit
accuracy; that selection is not computed by this script.

Usage:
    python3 /app/scripts/gpre0_add_subtopics.py --dry-run
    python3 /app/scripts/gpre0_add_subtopics.py
"""

import argparse
import json
import logging
import os
import re
import time
from collections import defaultdict
from pathlib import Path

import httpx
from sqlalchemy import create_engine, text

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s"
)
logger = logging.getLogger("gpre0-subtopics")

# NOTE(review): the fallback URL embeds credentials; prefer supplying
# DATABASE_URL from the environment.
DB_URL = os.getenv(
    "DATABASE_URL",
    "postgresql://breakpilot:breakpilot123@postgres:5432/breakpilot_db",
)
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
ANTHROPIC_MODEL = "claude-haiku-4-5-20251001"
ANTHROPIC_URL = "https://api.anthropic.com/v1/messages"

CHECKPOINT_DIR = Path("/tmp/gpre0_subtopic_checkpoints")

# How often one batch is attempted while the API answers 429, and how long to
# wait between attempts when the response carries no usable retry-after value.
MAX_RATE_LIMIT_ATTEMPTS = 3
RATE_LIMIT_WAIT_SECONDS = 60.0
MAX_RATE_LIMIT_WAIT_SECONDS = 300.0

# Rates used for the cost line in the final report.
# NOTE(review): presumably USD per million tokens — verify against the
# current pricing of ANTHROPIC_MODEL.
COST_PER_MTOK_INPUT = 0.80
COST_PER_MTOK_OUTPUT = 4.00

# Tokens that are too broad — need L2 sub-topics
BROAD_TOKENS = {
    # Round 1 (already done)
    "risk_management", "policy", "audit_logging", "incident",
    "access_control", "compliance_audit", "asset_management",
    "key_management", "third_party_management", "monitoring",
    "financial_reporting", "data_classification", "change_management",
    "alerting", "multi_factor_auth", "api_security",
    "certificate_management", "human_resources_security", "training",
    "data_processing_agreement", "data_processing_register",
    "consumer_protection", "input_validation", "vulnerability", "dpia",
    "data_breach_notification", "backup", "supply_chain_due_diligence",
    "awareness", "privacy_by_design", "credentials",
    "logging_configuration",
    # Round 2 (remaining large tokens)
    "supervisory_authority", "certification", "secure_development",
    "product_safety", "personal_data", "data_subject_rights", "consent",
    "ai_system", "encryption", "data_retention", "disaster_recovery",
    "data_transfer", "aml", "transport_encryption", "network_security",
    "physical_security", "medical_device", "patch_management",
    "cookie_consent", "video_surveillance", "network_segmentation",
    "telecommunications", "privileged_access", "session_management",
    "password_policy", "governance", "whistleblowing", "payment_services",
    "health_data", "sensitive_data", "ecommerce",
    "sustainability_reporting", "critical_infrastructure", "regulatory",
}

SYSTEM_PROMPT = """Du bist ein Compliance-Spezialist. Jeder Control hat bereits ein Hauptthema (L1 Token).
Deine Aufgabe: Bestimme ein SPEZIFISCHES Sub-Thema (L2) innerhalb des Hauptthemas.

Das L2 Sub-Thema soll den KONKRETEN Aspekt beschreiben. Verwende kurze, klare englische Bezeichnungen.

Beispiele:
- L1=incident, Titel="Incident Response Plan erstellen" → L2="response_plan"
- L1=incident, Titel="Sicherheitsvorfälle erkennen" → L2="detection"
- L1=incident, Titel="Recovery nach Vorfall dokumentieren" → L2="recovery"
- L1=incident, Titel="Forensische Analyse durchführen" → L2="forensics"
- L1=risk_management, Titel="Risikobewertung durchführen" → L2="assessment"
- L1=risk_management, Titel="Risikominderungsmaßnahmen umsetzen" → L2="treatment"
- L1=risk_management, Titel="Restrisiko akzeptieren" → L2="acceptance"
- L1=access_control, Titel="Rollenbasierte Zugriffskontrolle" → L2="rbac"
- L1=access_control, Titel="Zugriffsrechte regelmäßig prüfen" → L2="access_review"
- L1=access_control, Titel="Identitätsmanagement implementieren" → L2="identity_management"
- L1=monitoring, Titel="Systemverfügbarkeit überwachen" → L2="availability"
- L1=monitoring, Titel="Sicherheitsereignisse überwachen" → L2="security_events"
- L1=policy, Titel="Datenschutzrichtlinie erstellen" → L2="data_protection"
- L1=policy, Titel="Acceptable Use Policy definieren" → L2="acceptable_use"
- L1=policy, Titel="Passwortrichtlinie festlegen" → L2="password"
- L1=financial_reporting, Titel="Jahresabschluss erstellen" → L2="annual_accounts"
- L1=financial_reporting, Titel="Steuererklärung einreichen" → L2="tax"
- L1=alerting, Titel="Datenpanne an Behörde melden" → L2="breach_notification"
- L1=alerting, Titel="Sicherheitswarnung eskalieren" → L2="escalation"

REGELN:
- L2 soll 1-3 Wörter sein, snake_case
- L2 soll SPEZIFISCH sein (nicht das L1 wiederholen)
- Verwende konsistente L2-Bezeichnungen für ähnliche Controls

Antworte NUR als JSON-Array:
[{"id":"...","l2":"subtopic"}, ...]"""

# Anything outside [a-z0-9] is folded into a single underscore so a model
# answer can never inject ":" (the hint separator) or whitespace.
_L2_INVALID_CHARS = re.compile(r"[^a-z0-9]+")


def _format_prompt(controls_batch: list[dict]) -> str:
    """Render one batch of controls as the user message sent to the model."""
    items = [
        f'- id="{c["control_id"]}" '
        f'L1="{c["current_object"]}" '
        f't="{c["title"]}" '
        f'o="{c["objective"][:80]}"'
        for c in controls_batch
    ]
    return "Bestimme L2 Sub-Topics:\n" + "\n".join(items)


def _parse_subtopics(content: str) -> list[dict]:
    """Extract the JSON array from a model reply, keeping only dict items.

    Returns an empty list when no array is found or it cannot be decoded.
    """
    start = content.find("[")
    end = content.rfind("]") + 1
    if start < 0 or end <= start:
        return []
    try:
        parsed = json.loads(content[start:end])
    except json.JSONDecodeError as exc:
        logger.error("Failed: %s", exc)
        return []
    if not isinstance(parsed, list):
        return []
    # Non-dict items (strings, numbers, nested lists) cannot carry id/l2 and
    # would break the lookup built by the caller.
    return [item for item in parsed if isinstance(item, dict)]


def _retry_after_seconds(response) -> float:
    """Return how long to wait after a 429, based on the retry-after header.

    Falls back to RATE_LIMIT_WAIT_SECONDS when the header is missing or not a
    positive number, and never exceeds MAX_RATE_LIMIT_WAIT_SECONDS.
    """
    raw = response.headers.get("retry-after", "")
    try:
        wait = float(raw)
    except (TypeError, ValueError):
        return RATE_LIMIT_WAIT_SECONDS
    if not wait > 0:  # also rejects NaN
        return RATE_LIMIT_WAIT_SECONDS
    return min(wait, MAX_RATE_LIMIT_WAIT_SECONDS)


def _normalize_l2(raw) -> str:
    """Turn a model-provided L2 value into a safe snake_case token.

    Returns "" for non-string or effectively empty values so the caller can
    skip the control.
    """
    if not isinstance(raw, str):
        return ""
    return _L2_INVALID_CHARS.sub("_", raw.strip().lower()).strip("_")


def _build_new_hint(old_hint: str, l2: str) -> tuple[str, str]:
    """Return ``(l1, new_hint)`` where new_hint is ``action:L1_L2:phase``."""
    parts = old_hint.split(":", 2)
    action = parts[0] if parts else "implement"
    l1 = parts[1] if len(parts) > 1 else "unknown"
    phase = parts[2] if len(parts) > 2 else "implementation"
    return l1, f"{action}:{l1}_{l2}:{phase}"


def call_claude(controls_batch: list[dict]) -> tuple[list[dict], dict]:
    """Send batch to Claude for L2 sub-topic assignment.

    Args:
        controls_batch: Controls with the keys ``control_id``,
            ``current_object``, ``title`` and ``objective``.

    Returns:
        ``(results, usage)``. ``results`` is a list of dicts parsed from the
        model's JSON array (expected keys: ``id`` and ``l2``); ``usage`` is
        the ``usage`` object of the API response. Both are empty when the
        request failed. When the request succeeded but the reply could not
        be parsed, ``results`` is empty and ``usage`` is still returned so
        the call is counted in the cost report.

    The function never raises: failures are logged and reported as empty
    results so that one bad batch cannot abort the whole run. HTTP 429
    responses are retried up to MAX_RATE_LIMIT_ATTEMPTS times.
    """
    headers = {
        "x-api-key": ANTHROPIC_API_KEY,
        "anthropic-version": "2023-06-01",
        "content-type": "application/json",
    }
    payload = {
        "model": ANTHROPIC_MODEL,
        "max_tokens": 1500,
        "temperature": 0.0,
        "system": SYSTEM_PROMPT,
        "messages": [
            {"role": "user", "content": _format_prompt(controls_batch)}
        ],
    }

    for attempt in range(1, MAX_RATE_LIMIT_ATTEMPTS + 1):
        try:
            resp = httpx.post(
                ANTHROPIC_URL, headers=headers, json=payload, timeout=45.0
            )
            resp.raise_for_status()
            data = resp.json()
            blocks = data.get("content") or [{}]
            content = blocks[0].get("text", "")
            usage = data.get("usage") or {}
            return _parse_subtopics(content), usage
        except httpx.TimeoutException:
            logger.error("TIMEOUT — skipping")
            return [], {}
        except httpx.HTTPStatusError as e:
            if e.response.status_code != 429:
                logger.error("API error %d", e.response.status_code)
                return [], {}
            wait = _retry_after_seconds(e.response)
            logger.warning(
                "Rate limited — waiting %.0fs (attempt %d/%d)",
                wait, attempt, MAX_RATE_LIMIT_ATTEMPTS,
            )
            # Sleep even after the last attempt so the next batch does not
            # immediately run into the same limit.
            time.sleep(wait)
        except Exception as e:
            # Deliberate best-effort boundary: an unexpected response shape
            # must only cost this batch, not the whole run.
            logger.error("Failed: %s", e)
            return [], {}

    logger.error(
        "Rate limited — giving up on batch after %d attempts",
        MAX_RATE_LIMIT_ATTEMPTS,
    )
    return [], {}


def _load_controls(engine) -> list[dict]:
    """Fetch all non-deprecated controls whose L1 token is in BROAD_TOKENS."""
    # The tokens are module constants, so interpolating them is safe. LIKE is
    # only a coarse pre-filter ("_" in a token is a single-character
    # wildcard); the exact membership test below makes the final decision.
    like_clauses = " OR ".join(
        f"cc.generation_metadata->>'merge_group_hint' LIKE '%:{tok}:%'"
        for tok in sorted(BROAD_TOKENS)
    )

    with engine.connect() as c:
        rows = c.execute(text(f"""
            SELECT cc.id, cc.control_id, cc.title,
                   COALESCE(cc.objective, '') as objective,
                   cc.generation_metadata->>'merge_group_hint' as hint
            FROM canonical_controls cc
            WHERE cc.generation_metadata->>'merge_group_hint' IS NOT NULL
              AND cc.release_state NOT IN ('deprecated', 'rejected')
              AND ({like_clauses})
        """)).fetchall()

    controls = []
    for uuid, cid, title, objective, hint in rows:
        parts = hint.split(":", 2) if hint else []
        obj = parts[1] if len(parts) > 1 else ""
        if obj in BROAD_TOKENS:
            controls.append({
                "uuid": str(uuid),
                "control_id": cid,
                "title": title or "",
                "objective": objective or "",
                "current_hint": hint,
                "current_object": obj,
            })
    return controls


def _log_report(
    total: int,
    tagged: int,
    skipped: int,
    input_tokens: int,
    output_tokens: int,
    l2_stats: dict[str, dict[str, int]],
) -> None:
    """Log the run summary, the estimated cost and the top L2 values per L1."""
    cost_in = input_tokens / 1_000_000 * COST_PER_MTOK_INPUT
    cost_out = output_tokens / 1_000_000 * COST_PER_MTOK_OUTPUT

    logger.info("\n" + "=" * 60)
    logger.info("SUBTOPIC REPORT")
    logger.info("=" * 60)
    logger.info("Total: %d | Tagged: %d | Skipped: %d", total, tagged, skipped)
    logger.info("Cost: $%.2f (Haiku)", cost_in + cost_out)

    # Show L2 distribution per L1
    for l1, subs in sorted(l2_stats.items()):
        top_subs = sorted(subs.items(), key=lambda x: -x[1])[:10]
        logger.info("\n%s (%d unique L2):", l1, len(subs))
        for l2, cnt in top_subs:
            logger.info("  %4d  %s_%s", cnt, l1, l2)


def _apply_corrections(engine, corrections: list[dict]) -> None:
    """Write every new hint to canonical_controls in a single transaction."""
    logger.info("Applying %d corrections...", len(corrections))
    with engine.begin() as c:
        c.execute(text("SET search_path TO compliance, public"))
        for corr in corrections:
            c.execute(text("""
                UPDATE canonical_controls
                SET generation_metadata = jsonb_set(
                    generation_metadata,
                    '{merge_group_hint}',
                    to_jsonb(CAST(:new_hint AS text))
                )
                WHERE id = CAST(:uuid AS uuid)
            """), {"uuid": corr["uuid"], "new_hint": corr["new_hint"]})
    logger.info("Done. %d hints updated.", len(corrections))


def main():
    """Tag broad-token controls with L2 sub-topics and update their hints.

    Loads the affected controls, asks the model for an L2 per control in
    batches, writes the proposed hint changes to
    ``CHECKPOINT_DIR/corrections_subtopics.json`` and, unless ``--dry-run``
    is given, applies them to the database. Note that ``--dry-run`` still
    calls the API; it only skips the database update.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, default=20)
    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args()

    # range() rejects a step of 0 and a negative step would silently process
    # nothing, so refuse both up front.
    if args.batch_size < 1:
        parser.error("--batch-size must be >= 1")
    # Without a key every request fails and every control is skipped; stop
    # before issuing one doomed request per batch.
    if not ANTHROPIC_API_KEY:
        raise SystemExit("ANTHROPIC_API_KEY is not set — aborting")

    engine = create_engine(
        DB_URL, connect_args={"options": "-c search_path=compliance,public"}
    )

    controls = _load_controls(engine)
    logger.info(
        "Found %d controls in broad tokens to add L2 sub-topics",
        len(controls),
    )

    # Process
    total_tagged = 0
    total_skipped = 0
    total_input_tokens = 0
    total_output_tokens = 0
    corrections = []
    l2_stats: dict[str, dict[str, int]] = defaultdict(
        lambda: defaultdict(int)
    )

    for start in range(0, len(controls), args.batch_size):
        batch = controls[start:start + args.batch_size]
        results, usage = call_claude(batch)
        total_input_tokens += usage.get("input_tokens", 0) or 0
        total_output_tokens += usage.get("output_tokens", 0) or 0

        # str() keeps the lookup working when the model returns a
        # non-string id. An empty result list leaves the map empty, so
        # every control of the batch is counted as skipped below.
        result_map = {str(r.get("id", "")): r for r in results}
        for ctrl in batch:
            answer = result_map.get(ctrl["control_id"], {})
            l2 = _normalize_l2(answer.get("l2", ""))
            if not l2:
                total_skipped += 1
                continue

            total_tagged += 1
            old_hint = ctrl["current_hint"]
            # New format: action:L1_L2:phase
            l1, new_hint = _build_new_hint(old_hint, l2)
            corrections.append({
                "uuid": ctrl["uuid"],
                "old_hint": old_hint,
                "new_hint": new_hint,
            })
            l2_stats[l1][l2] += 1

        processed = min(start + args.batch_size, len(controls))
        if (
            processed % 5000 < args.batch_size
            or processed >= len(controls)
        ):
            logger.info(
                "Progress: %d/%d (tagged=%d skip=%d)",
                processed, len(controls), total_tagged, total_skipped,
            )
        time.sleep(0.3)

    # Report
    _log_report(
        len(controls),
        total_tagged,
        total_skipped,
        total_input_tokens,
        total_output_tokens,
        l2_stats,
    )

    # Save corrections
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    corr_file = CHECKPOINT_DIR / "corrections_subtopics.json"
    corr_file.write_text(json.dumps(corrections))
    logger.info("\nSaved %d corrections to %s", len(corrections), corr_file)

    if args.dry_run:
        logger.info("DRY RUN — not updating DB")
        return

    if corrections:
        _apply_corrections(engine, corrections)


if __name__ == "__main__":
    main()