#!/usr/bin/env python3
"""
Phase 0: Quality Audit for Master Control Assignments.

Uses Claude Sonnet to validate whether controls are correctly assigned to
their Master Controls. Samples controls from large and small MCs.

Usage:
    python3 /app/scripts/gpre_quality_audit.py
    python3 /app/scripts/gpre_quality_audit.py --large-sample 50 --small-sample 10
    python3 /app/scripts/gpre_quality_audit.py --mc MC-8292  # single MC
"""

import argparse
import json
import logging
import os
import random
import time
from collections import defaultdict

import httpx
from sqlalchemy import create_engine, text

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
)
logger = logging.getLogger("quality-audit")

# NOTE(review): the fallback DSN embeds credentials; consider requiring
# DATABASE_URL to be set explicitly.
DB_URL = os.getenv(
    "DATABASE_URL",
    "postgresql://breakpilot:breakpilot123@postgres:5432/breakpilot_db",
)
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
ANTHROPIC_MODEL = os.getenv("AUDIT_MODEL", "claude-sonnet-4-20250514")
ANTHROPIC_URL = "https://api.anthropic.com/v1/messages"

# MCs with more member controls than this are "large" (sampled with
# --large-sample); MCs with SMALL_MC_MIN_CONTROLS..LARGE_MC_THRESHOLD members
# are "small" (sampled with --small-sample).
LARGE_MC_THRESHOLD = 200
SMALL_MC_MIN_CONTROLS = 10

# Attempts per API batch before the batch is given up.
MAX_ATTEMPTS = 3

# Objective text is cut to this many characters in the prompt.
OBJECTIVE_PROMPT_CHARS = 150

# Estimated USD per million tokens (Sonnet pricing). If AUDIT_MODEL selects a
# different model, the cost estimate in the report will be off.
COST_PER_MTOK_INPUT = 3.0
COST_PER_MTOK_OUTPUT = 15.0

SYSTEM_PROMPT = """Du bist ein Compliance-Experte der prüft ob Controls korrekt zu Master Controls zugeordnet sind.

Für jeden Control beantworte:
1. MATCH: Gehört dieser Control thematisch zum Master Control Topic?
2. CONFIDENCE: Wie sicher bist du? (0.0-1.0)
3. REASON: Kurze Begründung (max 1 Satz)
4. SUGGESTED_TOPIC: Falls MATCH=false, welches Topic wäre korrekt?

Wichtige Unterscheidungen:
- "monitoring" = kontinuierliche Überwachung, Alerting, Log-Analyse
- "training" = Schulung, Awareness, Lernmaterialien
- "personal_data" = personenbezogene Daten, DSGVO-Betroffenenrechte
- "procedure" = Verfahren, Prozesse (aber NICHT wenn es spezifisch um Incidents geht)
- "incident" = Sicherheitsvorfälle, Breach Notification, Recovery
- "policy" = Richtlinien, Regelwerke, Governance-Dokumente
- "encryption" = Verschlüsselung, Kryptografie, Key Management
- "audit_logging" = Protokollierung, Audit Trail, Nachvollziehbarkeit

Antworte NUR als JSON-Array, ein Objekt pro Control."""


def _truncate(value: str, limit: int) -> str:
    """Return *value* cut to *limit* characters, appending "..." only if it was cut."""
    return value if len(value) <= limit else value[:limit] + "..."


def _as_float(value: object, default: float = 0.0) -> float:
    """Convert *value* to float, returning *default* for None or non-numeric input."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _parse_match(value: object):
    """Interpret the model's ``match`` field.

    Returns True/False for a real boolean or a recognised string
    ("true"/"false", "yes"/"no", "ja"/"nein", case-insensitive), and None
    for anything else. A missing or ambiguous verdict must not be counted
    as a correct assignment.
    """
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        lowered = value.strip().lower()
        if lowered in ("true", "yes", "ja"):
            return True
        if lowered in ("false", "no", "nein"):
            return False
    return None


def _response_text(data: object) -> str:
    """Return the text of the first content block of a Messages API reply, or ""."""
    if not isinstance(data, dict):
        return ""
    blocks = data.get("content") or []
    if not blocks or not isinstance(blocks[0], dict):
        return ""
    return blocks[0].get("text", "") or ""


def call_claude(controls_batch: list[dict], mc_topic: str) -> tuple[list[dict], dict]:
    """Ask Claude whether each control in *controls_batch* belongs to *mc_topic*.

    Args:
        controls_batch: Control dicts with keys ``control_id``, ``title``,
            ``objective``, ``phase`` and ``action``.
        mc_topic: Canonical name of the Master Control being audited.

    Returns:
        ``(results, usage)``. ``results`` holds the dict elements of the
        first JSON array found in the model's reply (the prompt asks for
        ``control_id``, ``match``, ``confidence``, ``reason`` and
        ``suggested_topic``). ``usage`` holds ``input_tokens`` and
        ``output_tokens`` summed over every attempt that got a response.
        ``results`` is empty when no usable array was obtained.

    Retry policy: 429 and 5xx responses, network errors and malformed JSON
    are retried up to MAX_ATTEMPTS times. Other HTTP errors (e.g. 401, 400)
    abort immediately because repeating them cannot succeed.
    """
    # json.dumps quotes and escapes the free-text fields, so a title that
    # contains '"' cannot break the prompt's structure.
    items = [
        f"- Control '{c['control_id']}': "
        f"Titel={json.dumps(c['title'], ensure_ascii=False)}, "
        f"Objective={json.dumps(_truncate(c['objective'], OBJECTIVE_PROMPT_CHARS), ensure_ascii=False)}, "
        f"Phase={c['phase']}, Action={c['action']}"
        for c in controls_batch
    ]

    prompt = (
        f"Master Control Topic: \"{mc_topic}\"\n\n"
        f"Prüfe diese {len(controls_batch)} Controls:\n\n"
        + "\n".join(items)
        + "\n\nAntwort als JSON-Array mit Feldern: "
        "control_id, match (bool), confidence (float), reason (str), "
        "suggested_topic (str, nur wenn match=false)."
    )

    headers = {
        "x-api-key": ANTHROPIC_API_KEY,
        "anthropic-version": "2023-06-01",
        "content-type": "application/json",
    }
    payload = {
        "model": ANTHROPIC_MODEL,
        "max_tokens": 2048,
        "temperature": 0.1,
        "system": SYSTEM_PROMPT,
        "messages": [{"role": "user", "content": prompt}],
    }

    # Summed across attempts so retries caused by malformed output still
    # show up in the cost estimate.
    spent = {"input_tokens": 0, "output_tokens": 0}

    for attempt in range(1, MAX_ATTEMPTS + 1):
        last_attempt = attempt == MAX_ATTEMPTS
        try:
            resp = httpx.post(
                ANTHROPIC_URL, headers=headers, json=payload, timeout=60.0,
            )
            resp.raise_for_status()
            data = resp.json()
        except httpx.HTTPStatusError as e:
            status = e.response.status_code
            if status != 429 and status < 500:
                # Client errors (bad key, bad request) will not succeed on retry.
                logger.error("API error: %s", e)
                return [], spent
            if not last_attempt:
                # Rate limits back off longer than server errors; honour the
                # server's retry-after hint when it asks for more.
                backoff = 30 * attempt if status == 429 else 5 * attempt
                wait = max(_as_float(e.response.headers.get("retry-after")), backoff)
                logger.warning("API returned %d, waiting %.0fs (attempt %d/%d)...",
                               status, wait, attempt, MAX_ATTEMPTS)
                time.sleep(wait)
            continue
        except (httpx.HTTPError, ValueError) as e:
            # Network-level failure (timeout, connection error) or a response
            # body that is not JSON.
            logger.error("Request failed (attempt %d): %s", attempt, e)
            if not last_attempt:
                time.sleep(5)
            continue

        usage = (data.get("usage") or {}) if isinstance(data, dict) else {}
        for key in spent:
            spent[key] += usage.get(key, 0)

        # The model may wrap the array in prose; take the outermost [...] span.
        content = _response_text(data)
        start = content.find("[")
        end = content.rfind("]") + 1
        if start < 0 or end <= start:
            logger.warning("No JSON array in response: %s", content[:200])
            return [], spent

        try:
            results = json.loads(content[start:end])
        except json.JSONDecodeError as e:
            logger.warning("Malformed JSON array in response (attempt %d): %s",
                           attempt, e)
            if not last_attempt:
                time.sleep(5)
            continue

        # Callers use r.get(...); drop anything that is not an object.
        return [r for r in results if isinstance(r, dict)], spent

    return [], spent


def _load_mcs(engine, args) -> list:
    """Return ``(id, master_control_id, canonical_name, total_controls)`` rows to audit.

    With ``--mc`` only that MC is returned. Otherwise the result is all large
    MCs (most members first) plus ``--small-mc-count`` random small MCs.
    """
    with engine.connect() as conn:
        if args.mc:
            rows = conn.execute(text("""
                SELECT id, master_control_id, canonical_name, total_controls
                FROM master_controls
                WHERE master_control_id = :mc
            """), {"mc": args.mc}).fetchall()
            return list(rows)

        large = conn.execute(text("""
            SELECT id, master_control_id, canonical_name, total_controls
            FROM master_controls
            WHERE total_controls > :threshold
            ORDER BY total_controls DESC
        """), {"threshold": LARGE_MC_THRESHOLD}).fetchall()

        small = conn.execute(text("""
            SELECT id, master_control_id, canonical_name, total_controls
            FROM master_controls
            WHERE total_controls BETWEEN :min_controls AND :threshold
            ORDER BY RANDOM()
            LIMIT :cnt
        """), {
            "min_controls": SMALL_MC_MIN_CONTROLS,
            "threshold": LARGE_MC_THRESHOLD,
            "cnt": args.small_mc_count,
        }).fetchall()

    return list(large) + list(small)


def _sample_controls(engine, mc_uuid, sample_size: int) -> list[dict]:
    """Return up to *sample_size* random member controls of one MC as prompt-ready dicts.

    A fresh connection is used per MC so no connection is held open during
    the (potentially long) API calls that follow.
    """
    with engine.connect() as conn:
        rows = conn.execute(text("""
            SELECT mcm.control_uuid, mcm.phase, mcm.action,
                   cc.control_id, cc.title,
                   COALESCE(cc.objective, '') as objective
            FROM master_control_members mcm
            JOIN canonical_controls cc ON cc.id = mcm.control_uuid
            WHERE mcm.master_control_uuid = CAST(:mc AS uuid)
            ORDER BY RANDOM()
            LIMIT :n
        """), {"mc": str(mc_uuid), "n": sample_size}).fetchall()

    return [
        {
            "control_uuid": str(row[0]),
            "phase": row[1],
            "action": row[2],
            "control_id": row[3],
            "title": row[4] or "",
            "objective": row[5] or "",
        }
        for row in rows
    ]


def _run_audit(engine, args) -> None:
    """Sample controls for every selected MC, collect Claude's verdicts, and log the report."""
    mcs = _load_mcs(engine, args)
    logger.info("Auditing %d Master Controls", len(mcs))
    if not mcs:
        logger.warning("No Master Controls matched the selection")

    total_checked = 0
    total_match = 0
    total_mismatch = 0
    total_unverified = 0
    total_input_tokens = 0
    total_output_tokens = 0
    mc_results: dict[str, dict] = {}
    all_mismatches: list[dict] = []

    for mc_uuid, mc_id, canonical, total in mcs:
        total = total or 0  # total_controls may be NULL for a --mc lookup
        is_large = total > LARGE_MC_THRESHOLD
        sample_size = args.large_sample if is_large else args.small_sample

        control_dicts = _sample_controls(engine, mc_uuid, sample_size)
        if not control_dicts:
            continue

        logger.info("\n%s: %s (%d total, sampling %d)",
                    mc_id, canonical, total, len(control_dicts))

        mc_match = 0
        mc_mismatch = 0

        for i in range(0, len(control_dicts), args.batch_size):
            batch = control_dicts[i:i + args.batch_size]
            results, usage = call_claude(batch, canonical)
            total_input_tokens += usage.get("input_tokens", 0)
            total_output_tokens += usage.get("output_tokens", 0)

            verdicts = 0
            for r in results:
                verdict = _parse_match(r.get("match"))
                if verdict is None:
                    # Missing or unrecognised verdicts are not counted as
                    # correct; they are tallied as unverified below.
                    logger.warning("  No usable match verdict for %s: %r",
                                   r.get("control_id", "?"), r.get("match"))
                    continue
                verdicts += 1
                if verdict:
                    mc_match += 1
                    continue
                mc_mismatch += 1
                all_mismatches.append({
                    "mc_id": mc_id,
                    "mc_topic": canonical,
                    "control_id": str(r.get("control_id", "?")),
                    "confidence": _as_float(r.get("confidence")),
                    "reason": r.get("reason") or "",
                    "suggested_topic": r.get("suggested_topic") or "",
                })

            if verdicts != len(batch):
                logger.warning("  Batch of %d controls yielded %d verdicts",
                               len(batch), verdicts)
            # Controls without a verdict (API failure, dropped or unparseable
            # results) are counted instead of disappearing silently.
            total_unverified += max(len(batch) - verdicts, 0)

            # Rate limit
            time.sleep(1)

        checked = mc_match + mc_mismatch
        total_match += mc_match
        total_mismatch += mc_mismatch
        total_checked += checked

        if checked == 0:
            # Reporting 0/0 as 100% accuracy would hide a failed audit.
            logger.warning("  → no verdicts returned; %s excluded from report", mc_id)
            continue

        accuracy = mc_match / checked
        mc_results[mc_id] = {
            "canonical": canonical,
            "total": total,
            "checked": checked,
            "match": mc_match,
            "mismatch": mc_mismatch,
            "accuracy": accuracy,
        }
        logger.info("  → %d/%d correct (%.1f%%)", mc_match, checked, accuracy * 100)

    # Final report
    _print_report(mc_results, all_mismatches, total_checked,
                  total_match, total_mismatch,
                  total_input_tokens, total_output_tokens,
                  unverified=total_unverified)


def main():
    """Parse CLI arguments, run the audit against the database, and log a report."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--large-sample", type=int, default=50,
                        help="Controls to sample per large MC")
    parser.add_argument("--small-sample", type=int, default=10,
                        help="Controls to sample per small MC")
    parser.add_argument("--small-mc-count", type=int, default=50,
                        help="Number of small MCs to audit")
    parser.add_argument("--mc", type=str, default=None,
                        help="Audit a single MC by ID (e.g., MC-8292)")
    parser.add_argument("--batch-size", type=int, default=10,
                        help="Controls per API call")
    args = parser.parse_args()

    # A batch size of 0 makes range() raise; a negative one skips every batch.
    if args.batch_size < 1:
        parser.error("--batch-size must be at least 1")
    # Without a key every request fails with 401 and no verdicts are produced.
    if not ANTHROPIC_API_KEY:
        parser.error("ANTHROPIC_API_KEY is not set")

    engine = create_engine(
        DB_URL, connect_args={"options": "-c search_path=compliance,public"}
    )
    try:
        _run_audit(engine, args)
    finally:
        engine.dispose()


def _print_report(mc_results, mismatches, checked, match, mismatch,
                  input_tok, output_tok, unverified=0):
    """Log the quality audit report.

    Args:
        mc_results: Per-MC stats keyed by master_control_id, each with keys
            ``canonical``, ``total``, ``checked``, ``match``, ``mismatch``
            and ``accuracy``.
        mismatches: One dict per control judged as wrongly assigned.
        checked: Number of controls that received a verdict.
        match: Verdicts saying the assignment is correct.
        mismatch: Verdicts saying the assignment is wrong.
        input_tok: Total API input tokens.
        output_tok: Total API output tokens.
        unverified: Sampled controls that received no usable verdict.
    """
    logger.info("\n" + "=" * 70)
    logger.info("QUALITY AUDIT REPORT")
    logger.info("=" * 70)
    logger.info("Total controls checked: %d", checked)
    logger.info("Correct assignments: %d (%.1f%%)",
                match, match / max(checked, 1) * 100)
    logger.info("Wrong assignments: %d (%.1f%%)",
                mismatch, mismatch / max(checked, 1) * 100)
    if unverified:
        logger.info("Controls without verdict (API failure / unparseable): %d",
                    unverified)

    # Cost estimate
    cost_input = input_tok / 1_000_000 * COST_PER_MTOK_INPUT
    cost_output = output_tok / 1_000_000 * COST_PER_MTOK_OUTPUT
    logger.info("\nAPI Usage: %d input + %d output tokens", input_tok, output_tok)
    logger.info("Estimated cost: $%.2f", cost_input + cost_output)

    # Per-MC breakdown (worst first)
    logger.info("\n--- Per-MC Accuracy (worst first) ---")
    for mc in sorted(mc_results.values(), key=lambda x: x["accuracy"]):
        flag = "❌" if mc["accuracy"] < 0.9 else "⚠️" if mc["accuracy"] < 0.95 else "✅"
        logger.info("  %s %s (%s): %d/%d = %.1f%% [total: %d]",
                    flag, (mc["canonical"] or "")[:30].ljust(30),
                    "large" if mc["total"] > LARGE_MC_THRESHOLD else "small",
                    mc["match"], mc["checked"],
                    mc["accuracy"] * 100, mc["total"])

    # Mismatches, most confident first
    if mismatches:
        logger.info("\n--- Mismatches (all %d) ---", len(mismatches))
        ordered = sorted(mismatches,
                         key=lambda x: _as_float(x.get("confidence")),
                         reverse=True)
        for m in ordered:
            logger.info("  %s in %s (%s) → should be '%s' (conf %.2f): %s",
                        m["control_id"], m["mc_id"], m["mc_topic"],
                        m["suggested_topic"], _as_float(m.get("confidence")),
                        m["reason"])

    # Size-class breakdown: pooled accuracy over all checked controls per class
    def pooled_accuracy(group):
        return sum(m["match"] for m in group) / max(sum(m["checked"] for m in group), 1)

    large_mcs = [m for m in mc_results.values() if m["total"] > LARGE_MC_THRESHOLD]
    small_mcs = [m for m in mc_results.values() if m["total"] <= LARGE_MC_THRESHOLD]
    if large_mcs:
        logger.info("\nLarge MCs (>%d): %.1f%% accuracy (%d MCs)",
                    LARGE_MC_THRESHOLD, pooled_accuracy(large_mcs) * 100,
                    len(large_mcs))
    if small_mcs:
        logger.info("Small MCs (≤%d): %.1f%% accuracy (%d MCs)",
                    LARGE_MC_THRESHOLD, pooled_accuracy(small_mcs) * 100,
                    len(small_mcs))


if __name__ == "__main__":
    main()