breakpilot-core/control-pipeline/scripts/gpre0_validate_hints.py
Benjamin Admin 8510af46eb feat(pipeline): MC Quality Overhaul — 74.5% → 92.8% accuracy, 5.3K → 13.6K MCs
Phase 0: Quality Audit script (Claude Sonnet, 1750 samples)
Phase 1: Object ontology expanded 31 → 74 tokens with descriptions + boundaries
Phase 2: 174K controls re-classified via Haiku (10 batches, $50)
  - Generic tokens removed (documentation, procedure, process)
  - L2 sub-topics added (108K + 64K controls)
  - Bad subtopics fixed (stakeholder_*, escalation fragments)
Phase 3: Re-clustering K=18704 (37K objects → 16.7K groups)
Phase 4: Direct MC generation from canonical tokens (gpre2_direct_mc.py)
Phase 5: Regulation-source split (gpre3, dry-run tested)

New features:
- Tenant-isolated document upload API (rag-service)
- BAuA crawler (Playwright, 131 PDFs downloaded)
- OSHA Technical Manual crawler (23 chapters)
- CE obligation extractor (6141 obligations from Qdrant)

RAG ingestion:
- 126 BAuA PDFs (TRBS/TRGS/ASR): 27,664 chunks
- OSHA Technical Manual: 7,241 chunks
- OSHA 1910 Subpart O (full): 745 chunks
- EuGH C-588/21 P: 216 chunks
- EU 2018/1725: 842 chunks

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-10 15:08:15 +02:00

#!/usr/bin/env python3
"""
Phase 2: Validate and correct merge_group_hints using Claude Haiku.

Re-classifies each control's object token against the expanded ontology
(74 canonical tokens) and corrects wrong hints in the DB.

SAFETY: The controls are split into --total-batches slices (default 10) and
each run processes one slice. NEVER retries on timeout (double-billing!).
A checkpoint is written after each API call; pass --resume to continue from it.
--dry-run still calls (and pays for) the API; it only skips the DB update.

Usage:
    python3 /app/scripts/gpre0_validate_hints.py --batch-id 1 --dry-run
    python3 /app/scripts/gpre0_validate_hints.py --batch-id 1
    python3 /app/scripts/gpre0_validate_hints.py --batch-id 2
    (one run per batch, up to --batch-id 10)
    python3 /app/scripts/gpre0_validate_hints.py --batch-id 2 --resume
"""
import argparse
import json
import logging
import os
import time
from collections import defaultdict
from pathlib import Path
import httpx
from sqlalchemy import create_engine, text

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
)
logger = logging.getLogger("gpre0-validate")

DB_URL = os.getenv(
    "DATABASE_URL",
    "postgresql://breakpilot:breakpilot123@postgres:5432/breakpilot_db",
)
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
ANTHROPIC_MODEL = "claude-haiku-4-5-20251001"
ANTHROPIC_URL = "https://api.anthropic.com/v1/messages"
CHECKPOINT_DIR = Path("/tmp/gpre0_checkpoints")
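
# System prompt (German, sent verbatim): assign each control exactly one token
# from TOKENS, fall back to OTHER only if no token fits even remotely (target
# <1% of cases), prefer the broadest fitting token when in doubt, respect the
# ABGRENZUNGEN (boundary) rules, and reply with a bare JSON array only.
SYSTEM_PROMPT = """Du bist ein Compliance-Klassifizierer. Ordne jeden Control GENAU EINEM Token zu.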
REGEL: Waehle IMMER den naechstbesten Token aus der Liste. OTHER nur wenn ABSOLUT
kein Token auch nur entfernt passt (<1% der Faelle). Im Zweifel: den breitesten
passenden Token waehlen (z.B. "policy" fuer Governance-Dokumente, "procedure" fuer
Ablauf-Definitionen, "risk_management" fuer Bewertungen).
TOKENS:
SECURITY: multi_factor_auth, password_policy, credentials, session_management,
privileged_access, access_control, encryption, transport_encryption,
key_management, certificate_management, network_security, network_segmentation,
firewall, vpn, remote_access, monitoring (NUR Echtzeit-Systemueberwachung),
audit_logging (Protokollierung/Audit Trail), siem, alerting (Meldepflichten),
compliance_audit (externe Pruefungen), vulnerability, patch_management,
backup, disaster_recovery, physical_security, secure_development,
api_security, input_validation, container_security, logging_configuration
DATA_PROTECTION: personal_data (DSGVO-Verarbeitung), sensitive_data (Art.9),
health_data, consent, data_subject_rights, data_retention, data_transfer,
data_breach_notification, dpia, data_processing_agreement, privacy_by_design,
data_processing_register, data_classification, cookie_consent, video_surveillance
GOVERNANCE: policy (Richtlinie definieren), procedure (Verfahren definieren),
process (Betriebsprozess ausfuehren), training (Schulung), awareness,
incident (Vorfallsbehandlung), risk_management, third_party_management,
change_management, documentation, records_management, compliance_reporting,
asset_management, human_resources_security
REGULATORY: supervisory_authority, certification (Zertifizierung/Konformitaet),
product_safety, ai_system, financial_reporting, aml, whistleblowing,
consumer_protection, ecommerce, telecommunications, medical_device,
payment_services, critical_infrastructure, supply_chain_due_diligence,
sustainability_reporting
ABGRENZUNGEN:
- monitoring = NUR Echtzeit-Systemueberwachung, NICHT Audit/Schulung/Bewertung
- audit_logging = Protokollierung, NICHT externe Pruefung (→ compliance_audit)
- procedure = Verfahren DEFINIEREN, NICHT Vorfaelle behandeln (→ incident)
- personal_data = DSGVO-Verarbeitung, NICHT Zertifizierung (→ certification)
- alerting = Meldepflichten, NICHT Vorfallsbehandlung (→ incident)
Antworte NUR als JSON-Array: [{"id":"...","token":"...","conf":0.9}, ...]
KEIN weiterer Text. Nur das Array."""


def call_claude(controls_batch: list[dict]) -> tuple[list[dict], dict]:
    """Send one batch to Claude and return (results, usage).

    NO RETRY on timeout (double-billing risk!): a timed-out request may still
    have been processed and billed, so the batch is skipped instead.
    """
    # One compact line per control: id, current object token, title, and the
    # first 100 characters of the objective.
    items = []
    for c in controls_batch:
        items.append(
            f'- id="{c["control_id"]}" '
            f'cur="{c["current_object"]}" '
            f't="{c["title"]}" '
            f'o="{c["objective"][:100]}"'
        )
    prompt = "Klassifiziere:\n" + "\n".join(items)
    headers = {
        "x-api-key": ANTHROPIC_API_KEY,
        "anthropic-version": "2023-06-01",
        "content-type": "application/json",
    }
    payload = {
        "model": ANTHROPIC_MODEL,
        "max_tokens": 1500,
        "temperature": 0.0,
        "system": SYSTEM_PROMPT,
        "messages": [{"role": "user", "content": prompt}],
    }
    try:
        resp = httpx.post(
            ANTHROPIC_URL, headers=headers, json=payload, timeout=45.0
        )
        resp.raise_for_status()
        data = resp.json()
        blocks = data.get("content") or [{}]
        content = blocks[0].get("text", "")
        usage = data.get("usage", {})
        # The model may wrap the array in prose; take the outermost [...] span.
        start = content.find("[")
        end = content.rfind("]") + 1
        if start >= 0 and end > start:
            try:
                return json.loads(content[start:end]), usage
            except json.JSONDecodeError:
                # Keep the usage so tokens already paid for are still counted.
                logger.warning("Malformed JSON array in response")
                return [], usage
        logger.warning("No JSON array in response")
        return [], usage
    except httpx.TimeoutException:
        # CRITICAL: Do NOT retry! Log and skip.
        logger.error("TIMEOUT: skipping batch (NOT retrying to avoid double-billing)")
        return [], {}
    except httpx.HTTPStatusError as e:
        if e.response.status_code == 429:
            # Back off so the next batch is not rate limited too; this one is still skipped.
            logger.warning("Rate limited: waiting 60s, then skipping this batch")
            time.sleep(60)
        else:
            logger.error("API error %d: skipping batch", e.response.status_code)
        return [], {}
    except Exception as e:
        logger.error("Request failed, skipping: %s", e)
        return [], {}


def load_checkpoint(batch_id: int) -> int:
    """Load last processed index for this batch."""
    cp_file = CHECKPOINT_DIR / f"batch_{batch_id}.json"
    if cp_file.exists():
        data = json.loads(cp_file.read_text())
        return data.get("last_index", 0)
    return 0
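

def save_checkpoint(batch_id: int, last_index: int, stats: dict):
    """Save progress checkpoint.

    Writes to a temporary file first and then renames it over the checkpoint,
    so an interruption mid-write cannot leave a truncated JSON file behind.
    """
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    cp_file = CHECKPOINT_DIR / f"batch_{batch_id}.json"
    tmp_file = cp_file.with_name(cp_file.name + ".tmp")
    tmp_file.write_text(json.dumps({
        "batch_id": batch_id,
        "last_index": last_index,
        **stats,
    }))
    tmp_file.replace(cp_file)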
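

def main():
    parser = argparse.ArgumentParser(
        description="Validate and correct merge_group_hints with Claude Haiku."
    )
    parser.add_argument("--batch-id", type=int, required=True,
                        help="1-based slice of the controls to process")
    parser.add_argument("--total-batches", type=int, default=10,
                        help="number of slices; use the same value for every run")
    parser.add_argument("--batch-size", type=int, default=20,
                        help="controls per API call")
    parser.add_argument("--dry-run", action="store_true",
                        help="save corrections to a file but do not update the DB "
                             "(the API is still called and billed)")
    parser.add_argument("--resume", action="store_true",
                        help="Resume from checkpoint")
    args = parser.parse_args()
    if not 1 <= args.batch_id <= args.total_batches:
        parser.error("--batch-id must be between 1 and --total-batches")
    if not ANTHROPIC_API_KEY:
        # Every call would fail and be skipped while the checkpoint still advances.
        parser.error("ANTHROPIC_API_KEY is not set")
    engine = create_engine(
        DB_URL, connect_args={"options": "-c search_path=compliance,public"}
    )
    # Load ALL control IDs in a deterministic order, then select this batch's
    # slice. Slices only line up across runs while this result set is unchanged.
    with engine.connect() as c:
        all_ids = c.execute(text("""
            SELECT cc.id
            FROM canonical_controls cc
            WHERE cc.generation_metadata->>'merge_group_hint' IS NOT NULL
              AND cc.generation_metadata->>'merge_group_hint' != ''
              AND cc.release_state NOT IN ('deprecated', 'rejected')
            ORDER BY cc.id
        """)).fetchall()
    total = len(all_ids)
    # Equal slices; the last batch also takes the remainder.
    chunk = total // args.total_batches
    start_idx = (args.batch_id - 1) * chunk
    end_idx = total if args.batch_id == args.total_batches else args.batch_id * chunk
    batch_ids = [str(r[0]) for r in all_ids[start_idx:end_idx]]
    logger.info("Batch %d/%d: controls %d-%d (%d controls of %d total)",
                args.batch_id, args.total_batches, start_idx, end_idx, len(batch_ids), total)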

    # Load full data for this batch (IDs bound as one uuid[] parameter instead
    # of being spliced into the SQL text)
    with engine.connect() as c:
        rows = c.execute(text("""
            SELECT cc.id, cc.control_id, cc.title,
                   COALESCE(cc.objective, '') as objective,
                   cc.generation_metadata->>'merge_group_hint' as hint
            FROM canonical_controls cc
            WHERE cc.id = ANY(CAST(:ids AS uuid[]))
            ORDER BY cc.id
        """), {"ids": batch_ids}).fetchall()
    # Hints have the form "action:object:phase"; only the object is re-classified.
    controls = []
    for uuid, cid, title, objective, hint in rows:
        parts = hint.split(":", 2) if hint else []
        controls.append({
            "uuid": str(uuid), "control_id": cid,
            "title": title or "", "objective": objective or "",
            "current_hint": hint, "current_object": parts[1] if len(parts) > 1 else hint,
        })

    # Resume from checkpoint? Counters and the corrections found so far are
    # restored too, so work already paid for is not lost on restart.
    start_from = 0
    saved: dict = {}
    if args.resume:
        start_from = load_checkpoint(args.batch_id)
        cp_file = CHECKPOINT_DIR / f"batch_{args.batch_id}.json"
        if cp_file.exists():
            saved = json.loads(cp_file.read_text())
        if start_from > 0:
            logger.info("Resuming from index %d", start_from)

    # Process
    total_same = saved.get("same", 0)
    total_changed = saved.get("changed", 0)
    total_other = saved.get("other", 0)
    total_skipped = saved.get("skipped", 0)
    total_input_tokens = saved.get("input_tokens", 0)
    total_output_tokens = saved.get("output_tokens", 0)
    corrections: list[dict] = saved.get("corrections", [])
    # change_stats only covers the current run (it is not checkpointed).
    change_stats: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))

    def checkpoint(next_index: int) -> None:
        save_checkpoint(args.batch_id, next_index, {
            "same": total_same, "changed": total_changed,
            "other": total_other, "skipped": total_skipped,
            "input_tokens": total_input_tokens,
            "output_tokens": total_output_tokens,
            "corrections": corrections,
        })

    for i in range(start_from, len(controls), args.batch_size):
        batch = controls[i:i + args.batch_size]
        results, usage = call_claude(batch)
        total_input_tokens += usage.get("input_tokens", 0)
        total_output_tokens += usage.get("output_tokens", 0)
        if not results:
            # Skipped, not retried: these controls keep their current hint.
            total_skipped += len(batch)
            checkpoint(i + args.batch_size)
            continue
        result_map = {r.get("id", ""): r for r in results if isinstance(r, dict)}
        for ctrl in batch:
            r = result_map.get(ctrl["control_id"], {})
            new_token = r.get("token", "")
            if not new_token:
                total_skipped += 1
                continue
            old_obj = ctrl["current_object"]
            if new_token == "OTHER":
                total_other += 1
            elif new_token == old_obj:
                total_same += 1
            else:
                total_changed += 1
                # Keep action and phase, swap only the object. A hint without
                # a colon is a bare object, so it must not become the action.
                parts = ctrl["current_hint"].split(":", 2)
                action = parts[0] if len(parts) > 1 else "implement"
                phase = parts[2] if len(parts) > 2 else "implementation"
                corrections.append({
                    "uuid": ctrl["uuid"],
                    "old_hint": ctrl["current_hint"],
                    "new_hint": f"{action}:{new_token}:{phase}",
                })
                change_stats[old_obj][new_token] += 1
        # Checkpoint after every API call
        checkpoint(i + args.batch_size)
        processed = min(i + args.batch_size, len(controls))
        if processed % 1000 < args.batch_size or processed >= len(controls):
            logger.info(
                "Batch %d: %d/%d (same=%d changed=%d other=%d skip=%d)",
                args.batch_id, processed, len(controls),
                total_same, total_changed, total_other, total_skipped,
            )
        time.sleep(0.3)

    # Report (hard-coded Haiku rates in USD per million tokens)
    cost_in = total_input_tokens / 1_000_000 * 0.80  # input tokens
    cost_out = total_output_tokens / 1_000_000 * 4.00  # output tokens
    total_cost = cost_in + cost_out
    total_proc = total_same + total_changed + total_other
    logger.info("\n" + "=" * 60)
    logger.info("BATCH %d REPORT", args.batch_id)
    logger.info("=" * 60)
    logger.info("Processed: %d | Skipped: %d", total_proc, total_skipped)
    logger.info("Same: %d (%.1f%%)", total_same, total_same / max(total_proc, 1) * 100)
    logger.info("Changed: %d (%.1f%%)", total_changed, total_changed / max(total_proc, 1) * 100)
    logger.info("OTHER: %d (%.1f%%)", total_other, total_other / max(total_proc, 1) * 100)
    logger.info("Cost: $%.2f (Haiku)", total_cost)
    logger.info("Cost/ctrl: $%.5f", total_cost / max(total_proc, 1))

    # Top changes (old object -> new token), most frequent first
    flat = []
    for old, news in change_stats.items():
        for new, cnt in news.items():
            flat.append((cnt, old, new))
    logger.info("\nTop Changes:")
    for cnt, old, new in sorted(flat, reverse=True)[:20]:
        logger.info("  %4d × %s -> %s", cnt, old, new)

    # Save corrections to a file before touching the DB (recovery safety)
    corr_file = CHECKPOINT_DIR / f"corrections_batch_{args.batch_id}.json"
    if corrections:
        CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
        corr_file.write_text(json.dumps(corrections))
        logger.info("Saved %d corrections to %s", len(corrections), corr_file)
    if args.dry_run:
        logger.info("\nDRY RUN: not updating DB")
        return

    # Apply corrections in a single transaction (one executemany over all rows)
    if corrections:
        logger.info("\nApplying %d corrections...", len(corrections))
        with engine.begin() as c:
            c.execute(text("SET search_path TO compliance, public"))
            c.execute(
                text("""
                    UPDATE canonical_controls
                    SET generation_metadata = jsonb_set(
                        generation_metadata,
                        '{merge_group_hint}',
                        to_jsonb(CAST(:new_hint AS text))
                    )
                    WHERE id = CAST(:uuid AS uuid)
                """),
                [{"uuid": corr["uuid"], "new_hint": corr["new_hint"]}
                 for corr in corrections],
            )
        logger.info("Done. %d hints corrected.", len(corrections))
    else:
        logger.info("No corrections needed.")


if __name__ == "__main__":
    main()
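`call_claude` wraps each field in double quotes by plain string interpolation, so a title or objective that itself contains `"` or a newline blurs the `id=/cur=/t=/o=` layout the model has to read. A minimal sketch of one alternative, assuming the model handles JSON-escaped strings as well as raw ones (not verified here), is to quote each field with `json.dumps`. `format_item` and `build_prompt` are hypothetical names, not part of the script, and the example control is made up.

```python
import json


def format_item(control: dict, objective_chars: int = 100) -> str:
    """Render one control as a prompt line, JSON-escaping every field.

    json.dumps adds the surrounding quotes and escapes embedded quotes and
    newlines, so each control stays on exactly one line.
    """
    fields = {
        "id": control["control_id"],
        "cur": control["current_object"],
        "t": control["title"],
        "o": control["objective"][:objective_chars],
    }
    parts = [f"{key}={json.dumps(value, ensure_ascii=False)}" for key, value in fields.items()]
    return "- " + " ".join(parts)


def build_prompt(controls: list[dict]) -> str:
    # Same "Klassifiziere:" header the script sends.
    return "Klassifiziere:\n" + "\n".join(format_item(c) for c in controls)


# Hypothetical control whose title and objective would break naive quoting.
example = {
    "control_id": "CTRL-001",
    "current_object": "password_policy",
    "title": 'Use "strong" passwords',
    "objective": "Enforce a minimum length.\nRotate on compromise.",
}
print(build_prompt([example]))
```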
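The response handling in `call_claude` slices from the first `[` to the last `]`, which tolerates prose around the array. The sketch below isolates that step as a hypothetical `extract_json_array` helper; it additionally turns a malformed array into an empty result and drops non-object entries, so a bad reply cannot crash the loop that later calls `r.get(...)` on each element.

```python
import json


def extract_json_array(content: str) -> list[dict]:
    """Return the JSON array embedded in a model reply, or [] if none parses.

    Like the script, slice from the first "[" to the last "]" so that
    surrounding prose is ignored; unlike the script, a malformed array yields
    an empty list instead of raising.
    """
    start = content.find("[")
    end = content.rfind("]") + 1
    if start < 0 or end <= start:
        return []
    try:
        parsed = json.loads(content[start:end])
    except json.JSONDecodeError:
        return []
    # Keep only object entries; anything else cannot carry an id/token pair.
    return [item for item in parsed if isinstance(item, dict)]


reply = 'Sure:\n[{"id":"CTRL-001","token":"encryption","conf":0.9}]\nDone.'
print(extract_json_array(reply))  # [{'id': 'CTRL-001', 'token': 'encryption', 'conf': 0.9}]
```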
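`save_checkpoint` is called after every API call, and `load_checkpoint` hands back only `last_index`. A sketch of a checkpoint that survives interruptions and carries the full state, assuming the checkpoint directory sits on a single local POSIX filesystem: write to a temporary file in the same directory, then swap it in with `os.replace`, so a crash leaves either the old or the new state. `save_checkpoint_atomic`, `load_checkpoint_state` and the `corrections` key are illustrative, not the script's API.

```python
import json
import os
import tempfile
from pathlib import Path


def save_checkpoint_atomic(cp_dir: Path, batch_id: int, state: dict) -> Path:
    """Write batch_<id>.json via a temp file in the same directory.

    os.replace() renames atomically on POSIX filesystems, so readers see
    either the previous checkpoint or the new one, never half a file.
    """
    cp_dir.mkdir(parents=True, exist_ok=True)
    cp_file = cp_dir / f"batch_{batch_id}.json"
    fd, tmp_name = tempfile.mkstemp(dir=cp_dir, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as fh:
            json.dump({"batch_id": batch_id, **state}, fh)
        os.replace(tmp_name, cp_file)
    except BaseException:
        # Do not leave a stray temp file behind if the write fails.
        os.unlink(tmp_name)
        raise
    return cp_file


def load_checkpoint_state(cp_dir: Path, batch_id: int) -> dict:
    """Return the saved state, or a fresh-run default if there is none."""
    cp_file = cp_dir / f"batch_{batch_id}.json"
    if not cp_file.exists():
        return {"last_index": 0, "corrections": []}
    return json.loads(cp_file.read_text())


with tempfile.TemporaryDirectory() as d:
    save_checkpoint_atomic(Path(d), 3, {"last_index": 40, "corrections": [{"uuid": "u1"}]})
    print(load_checkpoint_state(Path(d), 3)["last_index"])  # 40
```

Keeping the corrections in the checkpoint means a `--resume` run can still apply everything found before the interruption, which is the point of never retrying paid calls.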
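The slice for `--batch-id` is plain integer arithmetic: each of the `--total-batches` slices gets `total // total_batches` controls and the last one also takes the remainder. The sketch below pulls that into a hypothetical `batch_bounds` function and checks that the slices cover every index exactly once. Coverage only holds if every run sees the same ordered ID list and uses the same `--total-batches`.

```python
def batch_bounds(total: int, total_batches: int, batch_id: int) -> tuple[int, int]:
    """Return the [start, end) slice for a 1-based batch_id.

    Mirrors the script: every batch gets total // total_batches items and
    the last batch also takes the remainder.
    """
    if not 1 <= batch_id <= total_batches:
        raise ValueError("batch_id must be in 1..total_batches")
    chunk = total // total_batches
    start = (batch_id - 1) * chunk
    end = total if batch_id == total_batches else batch_id * chunk
    return start, end


# Every index 0..total-1 lands in exactly one batch.
total, k = 1003, 10
covered = [i for b in range(1, k + 1) for i in range(*batch_bounds(total, k, b))]
assert covered == list(range(total))
print(batch_bounds(total, k, 10))  # (900, 1003)
```

With fewer controls than batches, `chunk` is 0, so every batch except the last is empty and the last one takes everything.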
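`merge_group_hint` values have the form `action:object:phase`; the script keeps the action and phase, swaps only the object, and falls back to `implement` and `implementation` when parts are missing. A sketch of that step as two hypothetical helpers. It treats a hint with no colon as a bare object token, so the old token is never copied into the action position. The example hints are made up.

```python
def split_hint(hint: str) -> tuple[str, str, str]:
    """Split an 'action:object:phase' hint, filling the script's defaults.

    A hint without any colon is treated as a bare object token.
    """
    parts = hint.split(":", 2)
    if len(parts) == 1:
        return "implement", parts[0], "implementation"
    action, obj = parts[0], parts[1]
    phase = parts[2] if len(parts) > 2 else "implementation"
    return action, obj, phase


def replace_hint_object(hint: str, new_token: str) -> str:
    """Swap only the object part of a merge_group_hint."""
    action, _old_object, phase = split_hint(hint)
    return f"{action}:{new_token}:{phase}"


print(replace_hint_object("review:policy:operation", "procedure"))  # review:procedure:operation
print(replace_hint_object("encryption", "key_management"))  # implement:key_management:implementation
```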
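The script stores whatever token string the model returns, so a misspelled or invented token would be written to the DB. One way to guard against that, sketched here under the assumption that the prompt keeps its current layout (tokens listed between `TOKENS:` and `ABGRENZUNGEN:`, descriptions in parentheses, upper-case category labels), is to derive the allowed vocabulary from the prompt itself. `PROMPT_EXCERPT`, `parse_token_vocab` and `accept_token` are hypothetical; on the full `SYSTEM_PROMPT` the parser should yield the 74 tokens the docstring mentions.

```python
import re

# Trimmed, hypothetical excerpt in the same layout as the script's SYSTEM_PROMPT.
PROMPT_EXCERPT = """TOKENS:
SECURITY: multi_factor_auth, encryption, monitoring (NUR Echtzeit-Systemueberwachung),
audit_logging (Protokollierung/Audit Trail)
DATA_PROTECTION: personal_data (DSGVO-Verarbeitung), sensitive_data (Art.9), dpia
ABGRENZUNGEN:
- monitoring = NUR Echtzeit-Systemueberwachung"""


def parse_token_vocab(prompt: str) -> set[str]:
    """Extract the allowed tokens between 'TOKENS:' and 'ABGRENZUNGEN:'.

    Parenthesised descriptions are removed first; the upper-case category
    labels (SECURITY:, ...) then never match the lower-case token pattern.
    """
    section = prompt.split("TOKENS:", 1)[1].split("ABGRENZUNGEN:", 1)[0]
    section = re.sub(r"\([^)]*\)", "", section)
    return set(re.findall(r"\b[a-z][a-z0-9_]*\b", section))


def accept_token(token: str, vocab: set[str]) -> bool:
    """True for a known token or the explicit OTHER fallback."""
    return token == "OTHER" or token in vocab


vocab = parse_token_vocab(PROMPT_EXCERPT)
print(sorted(vocab))
print(accept_token("encryption", vocab), accept_token("encrypton", vocab))  # True False
```

A rejected token could then be counted like an empty answer (skipped, hint unchanged) rather than written to the DB.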
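The report prices tokens at the rates hard-coded in the script, $0.80 per million input tokens and $4.00 per million output tokens, so for example 2.5M input and 0.4M output tokens come to 2.5 × 0.80 + 0.4 × 4.00 = $3.60. A small sketch of the same formula using `Decimal` for exact decimal arithmetic; `haiku_cost` is a hypothetical name, and the rates are copied from the script rather than checked against current pricing.

```python
from decimal import Decimal

# Rates as hard-coded in the script's report section (USD per 1M tokens).
INPUT_USD_PER_MTOK = Decimal("0.80")
OUTPUT_USD_PER_MTOK = Decimal("4.00")


def haiku_cost(input_tokens: int, output_tokens: int) -> Decimal:
    """Dollar cost of a run from its total input and output token counts."""
    million = Decimal(1_000_000)
    return (Decimal(input_tokens) / million * INPUT_USD_PER_MTOK
            + Decimal(output_tokens) / million * OUTPUT_USD_PER_MTOK)


print(f"${haiku_cost(2_500_000, 400_000):.2f}")  # $3.60
```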
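The "Top Changes" block flattens the nested `change_stats` dict into `(count, old, new)` tuples and sorts them. `collections.Counter` offers the same ranking directly through `most_common`; this sketch (hypothetical `top_changes`, made-up example pairs) counts `(old_object, new_token)` pairs. One difference: on equal counts `most_common` keeps first-seen order, while the script's `sorted(..., reverse=True)` breaks ties by reverse alphabetical order of the names.

```python
from collections import Counter


def top_changes(pairs: list[tuple[str, str]], n: int = 20) -> list[tuple[int, str, str]]:
    """Count (old_object, new_token) pairs and return the n most frequent.

    Counter.most_common orders by count, highest first; ties keep the order
    in which the pairs were first seen.
    """
    counts = Counter(pairs)
    return [(cnt, old, new) for (old, new), cnt in counts.most_common(n)]


changes = [
    ("documentation", "records_management"),
    ("procedure", "incident"),
    ("documentation", "records_management"),
]
for cnt, old, new in top_changes(changes):
    print(f"{cnt:4d} × {old} -> {new}")
```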