#!/usr/bin/env python3
"""
Phase 2: Validate and correct merge_group_hints using Claude Haiku.

Re-classifies each control's object token against the expanded ontology
(74 canonical tokens). Corrects wrong hints in the DB.

SAFETY:
  - Controls are split into --total-batches contiguous slices (default 10);
    run --batch-id 1 through --total-batches to cover every control.
  - NEVER retries on timeout (double-billing!).
  - Writes a checkpoint after each API call (progress, counters and every
    correction found so far), so --resume continues without losing results.
    After a finished --dry-run, rerun the same batch with --resume to apply
    its corrections without calling the API again.

Usage:
  python3 /app/scripts/gpre0_validate_hints.py --batch-id 1 --dry-run
  python3 /app/scripts/gpre0_validate_hints.py --batch-id 1
  python3 /app/scripts/gpre0_validate_hints.py --batch-id 2
  ...
  python3 /app/scripts/gpre0_validate_hints.py --batch-id 10
  python3 /app/scripts/gpre0_validate_hints.py --batch-id 3 --resume
"""
import argparse
import json
import logging
import os
import time
from collections import defaultdict
from pathlib import Path

import httpx
from sqlalchemy import create_engine, text

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
)
logger = logging.getLogger("gpre0-validate")

DB_URL = os.getenv(
    "DATABASE_URL",
    "postgresql://breakpilot:breakpilot123@postgres:5432/breakpilot_db",
)
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
ANTHROPIC_MODEL = "claude-haiku-4-5-20251001"
ANTHROPIC_URL = "https://api.anthropic.com/v1/messages"
# List price in USD per million tokens for ANTHROPIC_MODEL (Claude Haiku 4.5).
# Only used for the cost estimate in the final report.
PRICE_INPUT_PER_MTOK = 1.00
PRICE_OUTPUT_PER_MTOK = 5.00
CHECKPOINT_DIR = Path("/tmp/gpre0_checkpoints")

# Fallback segments for stored hints that are not in full action:object:phase form.
DEFAULT_ACTION = "implement"
DEFAULT_PHASE = "implementation"

SYSTEM_PROMPT = """Du bist ein Compliance-Klassifizierer. Ordne jeden Control GENAU EINEM Token zu.

REGEL: Waehle IMMER den naechstbesten Token aus der Liste. OTHER nur wenn ABSOLUT kein Token auch nur entfernt passt (<1% der Faelle). Im Zweifel: den breitesten passenden Token waehlen (z.B. "policy" fuer Governance-Dokumente, "procedure" fuer Ablauf-Definitionen, "risk_management" fuer Bewertungen).

TOKENS:
SECURITY: multi_factor_auth, password_policy, credentials, session_management, privileged_access, access_control, encryption, transport_encryption, key_management, certificate_management, network_security, network_segmentation, firewall, vpn, remote_access, monitoring (NUR Echtzeit-Systemueberwachung), audit_logging (Protokollierung/Audit Trail), siem, alerting (Meldepflichten), compliance_audit (externe Pruefungen), vulnerability, patch_management, backup, disaster_recovery, physical_security, secure_development, api_security, input_validation, container_security, logging_configuration
DATA_PROTECTION: personal_data (DSGVO-Verarbeitung), sensitive_data (Art.9), health_data, consent, data_subject_rights, data_retention, data_transfer, data_breach_notification, dpia, data_processing_agreement, privacy_by_design, data_processing_register, data_classification, cookie_consent, video_surveillance
GOVERNANCE: policy (Richtlinie definieren), procedure (Verfahren definieren), process (Betriebsprozess ausfuehren), training (Schulung), awareness, incident (Vorfallsbehandlung), risk_management, third_party_management, change_management, documentation, records_management, compliance_reporting, asset_management, human_resources_security
REGULATORY: supervisory_authority, certification (Zertifizierung/Konformitaet), product_safety, ai_system, financial_reporting, aml, whistleblowing, consumer_protection, ecommerce, telecommunications, medical_device, payment_services, critical_infrastructure, supply_chain_due_diligence, sustainability_reporting

ABGRENZUNGEN:
- monitoring = NUR Echtzeit-Systemueberwachung, NICHT Audit/Schulung/Bewertung
- audit_logging = Protokollierung, NICHT externe Pruefung (→ compliance_audit)
- procedure = Verfahren DEFINIEREN, NICHT Vorfaelle behandeln (→ incident)
- personal_data = DSGVO-Verarbeitung, NICHT Zertifizierung (→ certification)
- alerting = Meldepflichten, NICHT Vorfallsbehandlung (→ incident)

Antworte NUR als JSON-Array:
[{"id":"...","token":"...","conf":0.9}, ...]
KEIN weiterer Text. Nur das Array."""

# The 74 tokens listed under TOKENS in SYSTEM_PROMPT. Keep both in sync: a
# token returned by the model that is not in this set (and is not "OTHER")
# is counted as invalid and never written to the DB.
CANONICAL_TOKENS = frozenset({
    # SECURITY
    "multi_factor_auth", "password_policy", "credentials",
    "session_management", "privileged_access", "access_control",
    "encryption", "transport_encryption", "key_management",
    "certificate_management", "network_security", "network_segmentation",
    "firewall", "vpn", "remote_access", "monitoring", "audit_logging",
    "siem", "alerting", "compliance_audit", "vulnerability",
    "patch_management", "backup", "disaster_recovery", "physical_security",
    "secure_development", "api_security", "input_validation",
    "container_security", "logging_configuration",
    # DATA_PROTECTION
    "personal_data", "sensitive_data", "health_data", "consent",
    "data_subject_rights", "data_retention", "data_transfer",
    "data_breach_notification", "dpia", "data_processing_agreement",
    "privacy_by_design", "data_processing_register", "data_classification",
    "cookie_consent", "video_surveillance",
    # GOVERNANCE
    "policy", "procedure", "process", "training", "awareness", "incident",
    "risk_management", "third_party_management", "change_management",
    "documentation", "records_management", "compliance_reporting",
    "asset_management", "human_resources_security",
    # REGULATORY
    "supervisory_authority", "certification", "product_safety", "ai_system",
    "financial_reporting", "aml", "whistleblowing", "consumer_protection",
    "ecommerce", "telecommunications", "medical_device", "payment_services",
    "critical_infrastructure", "supply_chain_due_diligence",
    "sustainability_reporting",
})


def _split_hint(hint: str) -> tuple[str, str, str]:
    """Split an ``action:object:phase`` hint into its three segments.

    A hint without any colon is treated as a bare object token. Missing
    action/phase segments fall back to DEFAULT_ACTION / DEFAULT_PHASE.
    """
    parts = (hint or "").split(":", 2)
    if len(parts) == 1:
        return DEFAULT_ACTION, parts[0], DEFAULT_PHASE
    phase = parts[2] if len(parts) > 2 else DEFAULT_PHASE
    return parts[0], parts[1], phase


def _quote(value) -> str:
    """JSON-quote a value for the prompt, keeping umlauts readable."""
    # Escapes embedded quotes/newlines so every control stays on one line.
    return json.dumps(value, ensure_ascii=False)


def _format_control(c: dict) -> str:
    """Render one control as a single ``- id=... cur=... t=... o=...`` prompt line."""
    return (
        f"- id={_quote(c['control_id'])} "
        f"cur={_quote(c['current_object'])} "
        f"t={_quote(c['title'])} "
        f"o={_quote(c['objective'][:100])}"
    )


def call_claude(controls_batch: list[dict]) -> tuple[list[dict], dict]:
    """Classify one batch of controls with Claude.

    Returns ``(results, usage)``. ``results`` holds the dict items of the
    JSON array found in the model's reply; it is empty when the batch is
    skipped. ``usage`` is the response's token-usage dict; it is empty when
    no usable response body was received.

    NO RETRY on timeout (double-billing risk!): every failure is logged and
    the batch is skipped.
    """
    prompt = "Klassifiziere:\n" + "\n".join(
        _format_control(c) for c in controls_batch
    )
    headers = {
        "x-api-key": ANTHROPIC_API_KEY,
        "anthropic-version": "2023-06-01",
        "content-type": "application/json",
    }
    payload = {
        "model": ANTHROPIC_MODEL,
        "max_tokens": 1500,
        "temperature": 0.0,
        "system": SYSTEM_PROMPT,
        "messages": [{"role": "user", "content": prompt}],
    }

    try:
        resp = httpx.post(
            ANTHROPIC_URL, headers=headers, json=payload, timeout=45.0
        )
        resp.raise_for_status()
        data = resp.json()
    except httpx.TimeoutException:
        # CRITICAL: Do NOT retry! Log and skip.
        logger.error("TIMEOUT — skipping batch (NOT retrying to avoid double-billing)")
        return [], {}
    except httpx.HTTPStatusError as e:
        if e.response.status_code == 429:
            logger.warning("Rate limited — waiting 60s then skipping")
            time.sleep(60)
        else:
            logger.error("API error %d — skipping batch", e.response.status_code)
        return [], {}
    except Exception as e:  # best effort: any other failure skips the batch
        logger.error("Request failed — skipping: %s", e)
        return [], {}

    if not isinstance(data, dict):
        logger.warning("Unexpected response body — skipping batch")
        return [], {}

    # The call was answered (and billed) from here on: always hand back
    # `usage`, even when the reply is unusable, so the cost report is right.
    usage = data.get("usage") or {}
    if data.get("stop_reason") == "max_tokens":
        logger.warning("Response truncated at max_tokens — batch may be incomplete")

    content = "".join(
        block.get("text", "")
        for block in data.get("content") or []
        if isinstance(block, dict)
    )
    start = content.find("[")
    end = content.rfind("]") + 1
    if start < 0 or end <= start:
        logger.warning("No JSON array in response")
        return [], usage
    try:
        parsed = json.loads(content[start:end])
    except json.JSONDecodeError as e:
        logger.warning("Malformed JSON array in response — skipping batch: %s", e)
        return [], usage
    # Drop anything that is not an object so callers can rely on .get().
    return [item for item in parsed if isinstance(item, dict)], usage


def _checkpoint_file(batch_id: int) -> Path:
    """Path of the checkpoint file for ``batch_id``."""
    return CHECKPOINT_DIR / f"batch_{batch_id}.json"


def load_checkpoint_state(batch_id: int) -> dict:
    """Return the full checkpoint dict for this batch, or ``{}`` if none exists."""
    cp_file = _checkpoint_file(batch_id)
    if not cp_file.exists():
        return {}
    return json.loads(cp_file.read_text())


def load_checkpoint(batch_id: int) -> int:
    """Load last processed index for this batch (0 when there is no checkpoint)."""
    return load_checkpoint_state(batch_id).get("last_index", 0)


def save_checkpoint(batch_id: int, last_index: int, stats: dict) -> None:
    """Save a progress checkpoint: batch id, next start index, plus ``stats``.

    Written to a temporary file and then renamed over the old checkpoint, so
    a crash during the write never leaves a truncated file behind.
    """
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    cp_file = _checkpoint_file(batch_id)
    tmp_file = cp_file.with_name(cp_file.name + ".tmp")
    tmp_file.write_text(json.dumps({
        "batch_id": batch_id,
        "last_index": last_index,
        **stats,
    }))
    tmp_file.replace(cp_file)


def _parse_args() -> argparse.Namespace:
    """Parse and sanity-check the command line."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-id", type=int, required=True)
    parser.add_argument("--total-batches", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=20)
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--resume", action="store_true", help="Resume from checkpoint")
    args = parser.parse_args()
    if args.total_batches < 1:
        parser.error("--total-batches must be >= 1")
    if not 1 <= args.batch_id <= args.total_batches:
        parser.error("--batch-id must be between 1 and --total-batches")
    if args.batch_size < 1:
        parser.error("--batch-size must be >= 1")
    return args


def _load_batch_controls(engine, batch_id: int, total_batches: int) -> list[dict]:
    """Load slice ``batch_id`` (1-based) of all hinted, active controls.

    Eligible controls are ordered by id and cut into ``total_batches``
    contiguous slices of equal size; the last slice also takes the remainder.
    """
    with engine.connect() as c:
        rows = c.execute(text("""
            SELECT cc.id, cc.control_id, cc.title,
                   LEFT(COALESCE(cc.objective, ''), 100) AS objective,
                   cc.generation_metadata->>'merge_group_hint' AS hint
            FROM canonical_controls cc
            WHERE cc.generation_metadata->>'merge_group_hint' IS NOT NULL
              AND cc.generation_metadata->>'merge_group_hint' != ''
              AND cc.release_state NOT IN ('deprecated', 'rejected')
            ORDER BY cc.id
        """)).fetchall()

    total = len(rows)
    chunk = total // total_batches
    start_idx = (batch_id - 1) * chunk
    end_idx = total if batch_id == total_batches else batch_id * chunk
    logger.info("Batch %d/%d: controls %d-%d (%d controls of %d total)",
                batch_id, total_batches, start_idx, end_idx,
                end_idx - start_idx, total)

    controls = []
    for row_id, cid, title, objective, hint in rows[start_idx:end_idx]:
        _, current_object, _ = _split_hint(hint)
        controls.append({
            "uuid": str(row_id),
            "control_id": cid,
            "title": title or "",
            "objective": objective or "",
            "current_hint": hint,
            "current_object": current_object,
        })
    return controls


def _new_state(total_batches: int) -> dict:
    """Zeroed counters, token usage and an empty corrections list."""
    return {
        "total_batches": total_batches,
        "same": 0,
        "changed": 0,
        "other": 0,
        "invalid": 0,
        "skipped": 0,
        "input_tokens": 0,
        "output_tokens": 0,
        "corrections": [],
    }


def _resume_state(batch_id: int, total_batches: int, resume: bool) -> tuple[int, dict]:
    """Return ``(start index, state)``, restored from the checkpoint when resuming."""
    state = _new_state(total_batches)
    if not resume:
        return 0, state
    checkpoint = load_checkpoint_state(batch_id)
    start_from = checkpoint.get("last_index", 0)
    if start_from <= 0:
        return 0, state

    # Indices are only meaningful for the slicing that produced them.
    saved_total = checkpoint.get("total_batches", total_batches)
    if saved_total != total_batches:
        raise SystemExit(
            f"Checkpoint for batch {batch_id} was written with "
            f"--total-batches {saved_total}; rerun with that value."
        )
    for key in state:
        if key in checkpoint:
            state[key] = checkpoint[key]
    if "corrections" not in checkpoint:
        logger.warning(
            "Checkpoint has no saved corrections (older format) — corrections "
            "found before index %d are lost", start_from,
        )
    logger.info("Resuming from index %d", start_from)
    return start_from, state


def _record_result(ctrl: dict, result: dict, state: dict) -> None:
    """Count the model's answer for one control and queue a correction if needed."""
    new_token = str(result.get("token") or "").strip()
    old_obj = ctrl["current_object"]
    if not new_token:
        state["skipped"] += 1  # model returned nothing for this control
    elif new_token == "OTHER":
        state["other"] += 1
    elif new_token == old_obj:
        state["same"] += 1
    elif new_token not in CANONICAL_TOKENS:
        # Never write a token outside the ontology into the DB.
        state["invalid"] += 1
        logger.warning("Ignoring unknown token %r for control %s",
                       new_token, ctrl["control_id"])
    else:
        state["changed"] += 1
        action, _, phase = _split_hint(ctrl["current_hint"])
        state["corrections"].append({
            "uuid": ctrl["uuid"],
            "old_hint": ctrl["current_hint"],
            "new_hint": f"{action}:{new_token}:{phase}",
        })


def _classify_controls(controls: list[dict], start_from: int, batch_size: int,
                       batch_id: int, state: dict) -> None:
    """Send ``controls[start_from:]`` to Claude in chunks, updating ``state`` in place."""
    for i in range(start_from, len(controls), batch_size):
        batch = controls[i:i + batch_size]
        results, usage = call_claude(batch)
        state["input_tokens"] += usage.get("input_tokens", 0)
        state["output_tokens"] += usage.get("output_tokens", 0)

        if results:
            result_map = {r.get("id", ""): r for r in results}
            for ctrl in batch:
                _record_result(ctrl, result_map.get(ctrl["control_id"], {}), state)
        else:
            # Skipped batches are not retried; the checkpoint still advances.
            state["skipped"] += len(batch)

        # Checkpoint every API call (incl. corrections) so a crash loses at most one call.
        save_checkpoint(batch_id, i + batch_size, state)

        processed = min(i + batch_size, len(controls))
        if processed % 1000 < batch_size or processed >= len(controls):
            logger.info(
                "Batch %d: %d/%d (same=%d changed=%d other=%d invalid=%d skip=%d)",
                batch_id, processed, len(controls), state["same"],
                state["changed"], state["other"], state["invalid"],
                state["skipped"],
            )
        time.sleep(0.3)


def _report(batch_id: int, state: dict) -> None:
    """Log counters, the estimated cost and the most frequent object changes."""
    cost_in = state["input_tokens"] / 1_000_000 * PRICE_INPUT_PER_MTOK
    cost_out = state["output_tokens"] / 1_000_000 * PRICE_OUTPUT_PER_MTOK
    total_cost = cost_in + cost_out
    same, changed, other = state["same"], state["changed"], state["other"]
    total_proc = same + changed + other
    denom = max(total_proc, 1)

    logger.info("\n" + "=" * 60)
    logger.info("BATCH %d REPORT", batch_id)
    logger.info("=" * 60)
    logger.info("Processed: %d | Skipped: %d | Invalid: %d",
                total_proc, state["skipped"], state["invalid"])
    logger.info("Same: %d (%.1f%%)", same, same / denom * 100)
    logger.info("Changed: %d (%.1f%%)", changed, changed / denom * 100)
    logger.info("OTHER: %d (%.1f%%)", other, other / denom * 100)
    logger.info("Cost: $%.2f (Haiku)", total_cost)
    logger.info("Cost/ctrl: $%.5f", total_cost / denom)

    # Rebuilt from the corrections so restored (resumed) ones are included.
    change_stats: dict[tuple[str, str], int] = defaultdict(int)
    for corr in state["corrections"]:
        old_obj = _split_hint(corr["old_hint"])[1]
        new_obj = _split_hint(corr["new_hint"])[1]
        change_stats[(old_obj, new_obj)] += 1
    ranked = sorted(
        ((cnt, old, new) for (old, new), cnt in change_stats.items()),
        reverse=True,
    )
    logger.info("\nTop Changes:")
    for cnt, old, new in ranked[:20]:
        logger.info(" %4d × %s → %s", cnt, old, new)


def _apply_corrections(engine, corrections: list[dict]) -> None:
    """Write all corrected hints in one transaction."""
    logger.info("\nApplying %d corrections...", len(corrections))
    with engine.begin() as c:
        c.execute(text("SET search_path TO compliance, public"))
        c.execute(
            text("""
                UPDATE canonical_controls
                SET generation_metadata = jsonb_set(
                    generation_metadata, '{merge_group_hint}',
                    to_jsonb(CAST(:new_hint AS text))
                )
                WHERE id = CAST(:uuid AS uuid)
            """),
            [{"uuid": corr["uuid"], "new_hint": corr["new_hint"]}
             for corr in corrections],
        )
    logger.info("Done. %d hints corrected.", len(corrections))


def main():
    """Validate one batch of hints, report, and (unless --dry-run) apply corrections."""
    args = _parse_args()
    if not ANTHROPIC_API_KEY:
        # Without a key every call would fail and be "skipped" while the
        # checkpoint still advances past those controls.
        raise SystemExit("ANTHROPIC_API_KEY is not set")

    engine = create_engine(
        DB_URL, connect_args={"options": "-c search_path=compliance,public"}
    )
    controls = _load_batch_controls(engine, args.batch_id, args.total_batches)
    start_from, state = _resume_state(args.batch_id, args.total_batches, args.resume)

    _classify_controls(controls, start_from, args.batch_size, args.batch_id, state)
    _report(args.batch_id, state)

    corrections = state["corrections"]
    # Save corrections to file (recovery safety)
    corr_file = CHECKPOINT_DIR / f"corrections_batch_{args.batch_id}.json"
    if corrections:
        CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
        corr_file.write_text(json.dumps(corrections))
        logger.info("Saved %d corrections to %s", len(corrections), corr_file)

    if args.dry_run:
        logger.info("\nDRY RUN — not updating DB")
        return

    if corrections:
        _apply_corrections(engine, corrections)
    else:
        logger.info("No corrections needed.")


if __name__ == "__main__":
    main()