diff --git a/backend-compliance/compliance/services/specialist_agents/dse/_obligation_shadow.py b/backend-compliance/compliance/services/specialist_agents/dse/_obligation_shadow.py new file mode 100644 index 00000000..793d46a7 --- /dev/null +++ b/backend-compliance/compliance/services/specialist_agents/dse/_obligation_shadow.py @@ -0,0 +1,109 @@ +"""DSE Shadow-Verdrahtung der Obligation Aggregation Engine. + +Erzeugt aus den v3-`results` zusätzlich Obligation-Ergebnisse — AUSSCHLIESSLICH +für die Telemetrie (Shadow Mode). Ändert KEINE nutzer-sichtbaren Findings. + +Mapping control-level über generation_metadata.legal_obligations + +applicability.conditional; das `met`-Signal ist das Legacy-`passed` des Controls +(kein zusätzlicher Prüfer-Call, kein Key). Liefert die Vergleichszahlen, mit denen +sich der Umschalt-Entscheid später absichern lässt: + legacy_control_findings · obligation_shadow_results · collapse_factor · + na_count · met_failed_delta · top_collapsed_obligations +""" +from __future__ import annotations + +import logging +import os +from typing import Any, Optional + +logger = logging.getLogger(__name__) + + +async def fetch_obligation_markers(cids: list[str], db_url: str = "") -> dict[str, dict]: + """legal_obligations + applicability.conditional der Controls laden. + Leeres Dict bei Fehler/keiner DB (Shadow fällt still aus).""" + cids = [c for c in cids if c] + if not cids: + return {} + import json + dsn = db_url or os.getenv("DATABASE_URL") or os.getenv("COMPLIANCE_DATABASE_URL") + if not dsn: + return {} + try: + import asyncpg + conn = await asyncpg.connect(dsn) + rows = await conn.fetch( + "select control_id, generation_metadata->'legal_obligations' obl, " + "generation_metadata->'applicability'->>'conditional' cond " + "from compliance.canonical_controls " + "where control_id = any($1::text[]) " + "and generation_metadata ? 'legal_obligations'", cids) + await conn.close() + except Exception as e: + logger.warning("fetch_obligation_markers failed: %s", e) + return {} + out: dict[str, dict] = {} + for r in rows: + obl = r["obl"] + obl = json.loads(obl) if isinstance(obl, str) else obl + if obl: + out[r["control_id"]] = {"obl": obl, "cond": r["cond"]} + return out + + +def compute_obligation_shadow(results: list[dict], text: str, + markers: dict[str, dict]) -> dict[str, Any]: + """Reiner Shadow-Vergleich (keine DB, keine Seiteneffekte). `markers`: + {control_id: {obl:[...], cond:str|None}}. `met` = Legacy-`passed`.""" + from compliance.services.obligation_aggregation import ( + FAILED, LM, NA, PARTIAL, CriterionEval, aggregate_obligations, + ) + from compliance.services.obligation_applicability import applicable + + legacy = 0 + evals: list[Any] = [] + contrib: dict[str, list] = {} + for r in results: + cid = r.get("control_id") + m = markers.get(cid) + if not m: + continue + passed = bool(r.get("passed")) + if not passed: + legacy += 1 + for ob in m["obl"]: + evals.append(CriterionEval(ob, LM, passed, cid, "", "", m.get("cond"))) + contrib.setdefault(ob, []).append((cid, passed)) + if not evals: + return {"status": "no obligation markers on result controls"} + + obls = aggregate_obligations(evals, applicable_fn=applicable, doc_text=text) + findings = sum(1 for o in obls if o.status in (FAILED, PARTIAL)) + na = sum(1 for o in obls if o.status == NA) + top = [] + for o in obls: + cs = contrib.get(o.obligation_id, []) + fehlt = sum(1 for _, p in cs if not p) + if fehlt >= 2: + top.append({"obligation": o.obligation_id, "fehlt": fehlt, + "total": len(cs), "status": o.status}) + top.sort(key=lambda x: -x["fehlt"]) + return { + "legacy_control_findings": legacy, + "obligation_shadow_results": len(obls), + "obligation_findings": findings, + "collapse_factor": round(legacy / findings, 2) if findings else None, + "na_count": na, + "met_failed_delta": legacy - findings, + "top_collapsed_obligations": top[:10], + } + + +async def build_obligation_shadow(results: list[dict], text: str, + db_url: str = "") -> dict[str, Any]: + """Async-Wrapper: Marker laden, dann Shadow rechnen. NIE in `results` schreiben.""" + cids = [r.get("control_id") for r in results if r.get("control_id")] + markers = await fetch_obligation_markers(cids, db_url) + if not markers: + return {"status": "no markers"} + return compute_obligation_shadow(results, text, markers) diff --git a/backend-compliance/compliance/services/specialist_agents/dse/v3_engine.py b/backend-compliance/compliance/services/specialist_agents/dse/v3_engine.py index eee7cfae..901c66e5 100644 --- a/backend-compliance/compliance/services/specialist_agents/dse/v3_engine.py +++ b/backend-compliance/compliance/services/specialist_agents/dse/v3_engine.py @@ -158,6 +158,17 @@ async def run_v3_pipeline( except Exception as e: logger.warning("dse tiered eval skipped: %s", e) + # Layer 4 (SHADOW): Obligation-Aggregation NUR in die Telemetrie. Greift NICHT + # in `results` ein — nutzer-sichtbare Findings bleiben unverändert. Liefert die + # Vergleichszahlen für den späteren Umschalt-Entscheid (collapse_factor etc.). + obligation_shadow: dict[str, Any] = {} + try: + from ._obligation_shadow import build_obligation_shadow + obligation_shadow = await build_obligation_shadow(results, text, db_url) + except Exception as e: + logger.warning("dse obligation shadow skipped: %s", e) + obligation_shadow = {"error": str(e)} + telemetry = { "layer_0_field_hits": len(boost_field_ids), "layer_0_field_ids": boost_field_ids, @@ -169,6 +180,7 @@ async def run_v3_pipeline( "offtopic_dropped": drop_stats.get("offtopic_dropped", 0), "gate_excluded": len(organizational), "organizational_checklist": organizational, + "obligation_shadow": obligation_shadow, } logger.info("dse v3 telemetry: %s", telemetry) return results, telemetry diff --git a/backend-compliance/tests/test_obligation_shadow.py b/backend-compliance/tests/test_obligation_shadow.py new file mode 100644 index 00000000..ad6a7737 --- /dev/null +++ b/backend-compliance/tests/test_obligation_shadow.py @@ -0,0 +1,52 @@ +"""Unit-Tests für die DSE Shadow-Verdrahtung (compute_obligation_shadow, pure).""" +from compliance.services.specialist_agents.dse._obligation_shadow import ( + compute_obligation_shadow, +) + + +def _markers(n, ob, cond=None): + return {f"C{i}": {"obl": [ob], "cond": cond} for i in range(n)} + + +class TestComputeShadow: + def test_collapse_and_delta(self): + results = [{"control_id": f"C{i}", "passed": False} for i in range(5)] + s = compute_obligation_shadow(results, "x", _markers(5, "recipients_disclosed")) + assert s["legacy_control_findings"] == 5 + assert s["obligation_findings"] == 1 # 5 → 1 + assert s["collapse_factor"] == 5.0 + assert s["met_failed_delta"] == 4 + top = s["top_collapsed_obligations"][0] + assert top["obligation"] == "recipients_disclosed" and top["fehlt"] == 5 + + def test_fp_correction_one_passed_collapses_to_met(self): + results = [{"control_id": f"C{i}", "passed": i == 0} for i in range(5)] + s = compute_obligation_shadow(results, "x", _markers(5, "recipients_disclosed")) + assert s["legacy_control_findings"] == 4 + assert s["obligation_findings"] == 0 # MET (anderswo erfüllt) + assert s["met_failed_delta"] == 4 + + def test_na_when_predicate_false(self): + results = [{"control_id": "C0", "passed": False}] + m = {"C0": {"obl": ["third_country_transfer_disclosed"], + "cond": "has_third_country_transfer"}} + s = compute_obligation_shadow(results, "nur innerhalb der eu", m) + assert s["na_count"] == 1 + assert s["obligation_findings"] == 0 # NA statt FEHLT + + def test_na_predicate_true_keeps_finding(self): + results = [{"control_id": "C0", "passed": False}] + m = {"C0": {"obl": ["third_country_transfer_disclosed"], + "cond": "has_third_country_transfer"}} + s = compute_obligation_shadow(results, "übermittlung in ein drittland", m) + assert s["na_count"] == 0 + assert s["obligation_findings"] == 1 + + def test_no_markers_returns_status(self): + s = compute_obligation_shadow([{"control_id": "C0", "passed": False}], "x", {}) + assert "no obligation" in s["status"] + + def test_does_not_mutate_results(self): + results = [{"control_id": "C0", "passed": False}] + compute_obligation_shadow(results, "x", _markers(1, "recipients_disclosed")) + assert set(results[0].keys()) == {"control_id", "passed"} # unverändert