- _write_atomic_control() now uses RETURNING id and inserts into
control_parent_links (M:N) with source_regulation, source_article,
and obligation_candidate_id parsed from parent's source_citation
- New _parse_citation() helper for JSONB source_citation extraction
- New GET /controls/{id}/traceability endpoint returning full chain:
parent links with obligations, child controls, source_count
- Backend: control_type filter (atomic/rich) for controls + count
- Frontend: Rechtsgrundlagen section in ControlDetail showing all
parent links per source regulation with obligation text + strength
- Frontend: Atomic/Rich filter dropdown in Control Library list
- Frontend: GenerationStrategyBadge recognizes 'pass0b' strategy
- Tests: 3 new tests for parent_link creation + citation parsing,
existing batch test mock updated for RETURNING clause
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1799 lines
63 KiB
Python
1799 lines
63 KiB
Python
"""Tests for Decomposition Pass (Pass 0a + 0b).

Covers:
- ObligationCandidate / AtomicControlCandidate dataclasses
- Normative signal detection (regex patterns)
- Quality Gate (all 6 checks)
- passes_quality_gate logic
- _compute_extraction_confidence
- _parse_json_array / _parse_json_object
- _format_field / _format_citation
- _normalize_severity
- _template_fallback
- _build_pass0a_prompt / _build_pass0b_prompt
- DecompositionPass.run_pass0a (mocked LLM + DB)
- DecompositionPass.run_pass0b (mocked LLM + DB)
- DecompositionPass.decomposition_status (mocked DB)
"""
|
|
|
|
import json
|
|
import pytest
|
|
from unittest.mock import MagicMock, patch, AsyncMock
|
|
|
|
from compliance.services.decomposition_pass import (
|
|
ObligationCandidate,
|
|
AtomicControlCandidate,
|
|
quality_gate,
|
|
passes_quality_gate,
|
|
classify_obligation_type,
|
|
_NORMATIVE_RE,
|
|
_PFLICHT_RE,
|
|
_EMPFEHLUNG_RE,
|
|
_KANN_RE,
|
|
_RATIONALE_RE,
|
|
_TEST_RE,
|
|
_REPORTING_RE,
|
|
_parse_json_array,
|
|
_parse_json_object,
|
|
_ensure_list,
|
|
_format_field,
|
|
_format_citation,
|
|
_compute_extraction_confidence,
|
|
_normalize_severity,
|
|
_template_fallback,
|
|
_fallback_obligation,
|
|
_build_pass0a_prompt,
|
|
_build_pass0b_prompt,
|
|
_build_pass0a_batch_prompt,
|
|
_build_pass0b_batch_prompt,
|
|
_PASS0A_SYSTEM_PROMPT,
|
|
_PASS0B_SYSTEM_PROMPT,
|
|
DecompositionPass,
|
|
_classify_trigger_type,
|
|
_is_implementation_specific_text,
|
|
_text_similar,
|
|
_is_more_implementation_specific,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# DATACLASS TESTS
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestObligationCandidate:
    """Exercise the ObligationCandidate dataclass: defaults, serialization, construction."""

    def test_defaults(self):
        """A candidate built without arguments carries the expected default values."""
        candidate = ObligationCandidate()
        assert candidate.candidate_id == ""
        assert candidate.normative_strength == "must"
        assert candidate.is_test_obligation is False
        assert candidate.release_state == "extracted"
        assert candidate.quality_flags == {}

    def test_to_dict(self):
        """to_dict() publishes the ``object_`` attribute under the key ``object``."""
        candidate = ObligationCandidate(
            candidate_id="OC-001-01",
            parent_control_uuid="uuid-1",
            obligation_text="Betreiber müssen MFA implementieren",
            action="implementieren",
            object_="MFA",
        )
        serialized = candidate.to_dict()
        assert serialized["candidate_id"] == "OC-001-01"
        assert serialized["object"] == "MFA"
        # The trailing underscore is the attribute name only; the dict key is "object".
        assert "object_" not in serialized

    def test_full_creation(self):
        """Explicitly supplied fields are stored unchanged on the instance."""
        field_values = dict(
            candidate_id="OC-MICA-0001-01",
            parent_control_uuid="uuid-abc",
            obligation_text="Betreiber müssen Kontinuität sicherstellen",
            action="sicherstellen",
            object_="Dienstleistungskontinuität",
            condition="bei Ausfall des Handelssystems",
            normative_strength="must",
            is_test_obligation=False,
            is_reporting_obligation=False,
            extraction_confidence=0.90,
        )
        candidate = ObligationCandidate(**field_values)
        assert candidate.condition == "bei Ausfall des Handelssystems"
        assert candidate.extraction_confidence == 0.90
|
|
|
|
|
|
class TestAtomicControlCandidate:
    """Exercise the AtomicControlCandidate dataclass: defaults and serialization."""

    def test_defaults(self):
        """A bare candidate has medium severity and empty list fields."""
        control = AtomicControlCandidate()
        assert control.severity == "medium"
        assert control.requirements == []
        assert control.test_procedure == []

    def test_to_dict(self):
        """to_dict() carries the title and the requirements list through."""
        control = AtomicControlCandidate(
            candidate_id="AC-FIN-001",
            title="Service Continuity Mechanism",
            objective="Ensure continuity upon failure.",
            requirements=["Failover mechanism"],
        )
        serialized = control.to_dict()
        assert serialized["title"] == "Service Continuity Mechanism"
        assert len(serialized["requirements"]) == 1
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# NORMATIVE SIGNAL DETECTION TESTS
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestNormativeSignals:
    """Check that the normative/rationale/test/reporting patterns fire on sample phrases."""

    def test_muessen_detected(self):
        hit = _NORMATIVE_RE.search("Betreiber müssen sicherstellen")
        assert hit

    def test_muss_detected(self):
        hit = _NORMATIVE_RE.search("Das System muss implementiert sein")
        assert hit

    def test_hat_sicherzustellen(self):
        hit = _NORMATIVE_RE.search("Der Verantwortliche hat sicherzustellen")
        assert hit

    def test_sind_verpflichtet(self):
        hit = _NORMATIVE_RE.search("Anbieter sind verpflichtet zu melden")
        assert hit

    def test_ist_zu_dokumentieren(self):
        hit = _NORMATIVE_RE.search("Der Vorfall ist zu dokumentieren")
        assert hit

    def test_shall(self):
        hit = _NORMATIVE_RE.search("The operator shall implement MFA")
        assert hit

    def test_no_signal(self):
        """A plain descriptive sentence must not be matched as normative."""
        hit = _NORMATIVE_RE.search("Die Sonne scheint heute")
        assert not hit

    def test_rationale_detected(self):
        hit = _RATIONALE_RE.search("da schwache Passwörter Risiken bergen")
        assert hit

    def test_test_signal_detected(self):
        hit = _TEST_RE.search("regelmäßige Tests der Wirksamkeit")
        assert hit

    def test_reporting_signal_detected(self):
        hit = _REPORTING_RE.search("Behörden sind zu unterrichten")
        assert hit
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# QUALITY GATE TESTS
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestQualityGate:
    """Check the individual flags produced by quality_gate()."""

    @staticmethod
    def _gate(text, parent="uuid-1"):
        """Run quality_gate() on a candidate built from ``text`` and ``parent``."""
        candidate = ObligationCandidate(
            parent_control_uuid=parent,
            obligation_text=text,
        )
        return quality_gate(candidate)

    def test_valid_normative_obligation(self):
        """A well-formed 'müssen' obligation satisfies the checked flags."""
        flags = self._gate("Betreiber müssen Verschlüsselung implementieren")
        assert flags["has_normative_signal"] is True
        assert flags["not_evidence_only"] is True
        assert flags["min_length"] is True
        assert flags["has_parent_link"] is True

    def test_rationale_detected(self):
        """Explanatory ('weil ...') text is flagged as rationale."""
        flags = self._gate(
            "Dies liegt daran, weil schwache Konfigurationen ein Risiko darstellen"
        )
        assert flags["not_rationale"] is False

    def test_evidence_only_rejected(self):
        """A bare evidence artefact is not an obligation."""
        flags = self._gate("Screenshot der Konfiguration")
        assert flags["not_evidence_only"] is False

    def test_too_short_rejected(self):
        flags = self._gate("MFA")
        assert flags["min_length"] is False

    def test_no_parent_link(self):
        """An empty parent UUID clears the has_parent_link flag."""
        flags = self._gate("Betreiber müssen MFA implementieren", parent="")
        assert flags["has_parent_link"] is False

    def test_multi_verb_detected(self):
        """Several chained action verbs clear the single_action flag."""
        flags = self._gate(
            "Betreiber müssen implementieren und dokumentieren sowie regelmäßig testen"
        )
        assert flags["single_action"] is False

    def test_single_verb_passes(self):
        flags = self._gate(
            "Betreiber müssen MFA für alle privilegierten Konten implementieren"
        )
        assert flags["single_action"] is True

    def test_no_normative_signal(self):
        """Without a normative signal the obligation type is 'empfehlung'."""
        flags = self._gate(
            "Ein DR-Plan beschreibt die Wiederherstellungsprozeduren im Detail"
        )
        assert flags["has_normative_signal"] is False
        assert flags["obligation_type"] == "empfehlung"

    def test_obligation_type_in_flags(self):
        """The flags dict carries the classified obligation type."""
        flags = self._gate("Der Betreiber muss alle Daten verschlüsseln.")
        assert flags["obligation_type"] == "pflicht"
|
|
|
|
|
|
class TestPassesQualityGate:
    """Tests for passes_quality_gate function.

    Note: has_normative_signal is NO LONGER critical — obligations without
    normative signal are classified as 'empfehlung' instead of being rejected.
    """

    @staticmethod
    def _flags(**overrides):
        """Return an all-passing 'pflicht' flag dict with ``overrides`` applied in place."""
        flags = {
            "has_normative_signal": True,
            "obligation_type": "pflicht",
            "single_action": True,
            "not_rationale": True,
            "not_evidence_only": True,
            "min_length": True,
            "has_parent_link": True,
        }
        # update() keeps the original key order, so only values differ between tests.
        flags.update(overrides)
        return flags

    def test_all_critical_pass(self):
        verdict = passes_quality_gate(self._flags())
        assert verdict is True

    def test_no_normative_signal_still_passes(self):
        """No normative signal no longer causes rejection — classified as empfehlung."""
        verdict = passes_quality_gate(
            self._flags(has_normative_signal=False, obligation_type="empfehlung")
        )
        assert verdict is True

    def test_evidence_only_fails(self):
        """A failed not_evidence_only check blocks the candidate."""
        verdict = passes_quality_gate(self._flags(not_evidence_only=False))
        assert verdict is False

    def test_non_critical_dont_block(self):
        """single_action, not_rationale, has_normative_signal are NOT critical."""
        verdict = passes_quality_gate(
            self._flags(
                has_normative_signal=False,  # Not critical
                obligation_type="empfehlung",
                single_action=False,  # Not critical
                not_rationale=False,  # Not critical
            )
        )
        assert verdict is True
|
|
|
|
|
|
class TestClassifyObligationType:
    """Tests for the 3-tier obligation type classification (pflicht / empfehlung / kann)."""

    def test_pflicht_muss(self):
        kind = classify_obligation_type("Der Betreiber muss alle Daten verschlüsseln")
        assert kind == "pflicht"

    def test_pflicht_ist_zu(self):
        kind = classify_obligation_type(
            "Die Meldung ist innerhalb von 72 Stunden zu erstatten"
        )
        assert kind == "pflicht"

    def test_pflicht_shall(self):
        kind = classify_obligation_type(
            "The controller shall implement appropriate measures"
        )
        assert kind == "pflicht"

    def test_empfehlung_soll(self):
        kind = classify_obligation_type(
            "Der Betreiber soll regelmäßige Audits durchführen"
        )
        assert kind == "empfehlung"

    def test_empfehlung_should(self):
        kind = classify_obligation_type(
            "Organizations should implement security controls"
        )
        assert kind == "empfehlung"

    def test_empfehlung_sicherstellen(self):
        kind = classify_obligation_type("Die Verfügbarkeit der Systeme sicherstellen")
        assert kind == "empfehlung"

    def test_kann(self):
        kind = classify_obligation_type(
            "Der Betreiber kann zusätzliche Maßnahmen ergreifen"
        )
        assert kind == "kann"

    def test_kann_may(self):
        kind = classify_obligation_type(
            "The organization may implement optional safeguards"
        )
        assert kind == "kann"

    def test_no_signal_defaults_to_empfehlung(self):
        """Text without any modal signal is classified as 'empfehlung'."""
        kind = classify_obligation_type("Regelmäßige Überprüfung der Zugriffsrechte")
        assert kind == "empfehlung"

    def test_pflicht_overrides_empfehlung(self):
        """If both pflicht and empfehlung signals present, pflicht wins."""
        sentence = "Der Betreiber muss sicherstellen, dass alle Daten verschlüsselt werden"
        kind = classify_obligation_type(sentence)
        assert kind == "pflicht"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# HELPER TESTS
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestComputeExtractionConfidence:
    """Tests for _compute_extraction_confidence.

    The expected scores are sums of decimal weights (see the per-flag comments
    in ``test_partial_flags``). Such sums are not guaranteed to be bit-exact in
    binary floating point and can change with summation order, so the
    comparisons use ``pytest.approx`` instead of ``==``.
    """

    def test_all_flags_pass(self):
        """Every flag set yields full confidence."""
        flags = {
            "has_normative_signal": True,
            "single_action": True,
            "not_rationale": True,
            "not_evidence_only": True,
            "min_length": True,
            "has_parent_link": True,
        }
        assert _compute_extraction_confidence(flags) == pytest.approx(1.0)

    def test_no_flags_pass(self):
        """No flag set yields zero confidence."""
        flags = {
            "has_normative_signal": False,
            "single_action": False,
            "not_rationale": False,
            "not_evidence_only": False,
            "min_length": False,
            "has_parent_link": False,
        }
        assert _compute_extraction_confidence(flags) == pytest.approx(0.0)

    def test_partial_flags(self):
        """Only the weights of the set flags contribute to the score."""
        flags = {
            "has_normative_signal": True,  # 0.30
            "single_action": False,
            "not_rationale": True,  # 0.20
            "not_evidence_only": True,  # 0.15
            "min_length": True,  # 0.10
            "has_parent_link": True,  # 0.05
        }
        assert _compute_extraction_confidence(flags) == pytest.approx(0.80)
|
|
|
|
|
|
class TestParseJsonArray:
    """Tests for _parse_json_array on clean, wrapped, embedded and invalid input."""

    def test_valid_array(self):
        items = _parse_json_array('[{"a": 1}, {"a": 2}]')
        assert len(items) == 2
        assert items[0]["a"] == 1

    def test_single_object_wrapped(self):
        """A lone JSON object comes back as a one-element list."""
        items = _parse_json_array('{"a": 1}')
        assert len(items) == 1

    def test_embedded_in_text(self):
        """An array surrounded by free text is still extracted."""
        items = _parse_json_array('Here is the result:\n[{"a": 1}]\nDone.')
        assert len(items) == 1

    def test_invalid_returns_empty(self):
        items = _parse_json_array("not json at all")
        assert items == []

    def test_empty_array(self):
        items = _parse_json_array("[]")
        assert items == []
|
|
|
|
|
|
class TestParseJsonObject:
    """Tests for _parse_json_object on clean, fenced and invalid input."""

    def test_valid_object(self):
        parsed = _parse_json_object('{"title": "MFA"}')
        assert parsed["title"] == "MFA"

    def test_embedded_in_text(self):
        """An object wrapped in a Markdown code fence is still extracted."""
        parsed = _parse_json_object('```json\n{"title": "MFA"}\n```')
        assert parsed["title"] == "MFA"

    def test_invalid_returns_empty(self):
        parsed = _parse_json_object("not json")
        assert parsed == {}
|
|
|
|
|
|
class TestEnsureList:
    """Tests for _ensure_list coercion of assorted input types."""

    def test_list_passthrough(self):
        coerced = _ensure_list(["a", "b"])
        assert coerced == ["a", "b"]

    def test_string_wrapped(self):
        """A non-empty string becomes a one-element list."""
        coerced = _ensure_list("hello")
        assert coerced == ["hello"]

    def test_empty_string(self):
        coerced = _ensure_list("")
        assert coerced == []

    def test_none(self):
        coerced = _ensure_list(None)
        assert coerced == []

    def test_int(self):
        """An integer is expected to yield an empty list."""
        coerced = _ensure_list(42)
        assert coerced == []
|
|
|
|
|
|
class TestFormatField:
    """Tests for _format_field rendering of strings and lists."""

    def test_string_passthrough(self):
        rendered = _format_field("hello")
        assert rendered == "hello"

    def test_json_list_string(self):
        """A JSON-encoded list string is rendered as '- item' entries."""
        rendered = _format_field('["Req 1", "Req 2"]')
        for bullet in ("- Req 1", "- Req 2"):
            assert bullet in rendered

    def test_list_input(self):
        """A real list is rendered as '- item' entries."""
        rendered = _format_field(["A", "B"])
        for bullet in ("- A", "- B"):
            assert bullet in rendered

    def test_empty(self):
        """Empty string and None both render as the empty string."""
        for blank in ("", None):
            assert _format_field(blank) == ""
|
|
|
|
|
|
class TestFormatCitation:
    """Tests for _format_citation on JSON, plain and empty input."""

    def test_json_dict(self):
        """Source and article from a JSON citation both appear in the output."""
        rendered = _format_citation('{"source": "MiCA", "article": "Art. 8"}')
        for part in ("MiCA", "Art. 8"):
            assert part in rendered

    def test_plain_string(self):
        rendered = _format_citation("MiCA Art. 8")
        assert rendered == "MiCA Art. 8"

    def test_empty(self):
        """Empty string and None both render as the empty string."""
        for blank in ("", None):
            assert _format_citation(blank) == ""
|
|
|
|
|
|
class TestNormalizeSeverity:
    """Tests for _normalize_severity."""

    def test_valid_values(self):
        """Known severities are lower-cased and stripped of surrounding whitespace."""
        cases = (
            ("critical", "critical"),
            ("HIGH", "high"),
            (" Medium ", "medium"),
            ("low", "low"),
        )
        for raw, expected in cases:
            assert _normalize_severity(raw) == expected

    def test_invalid_defaults_to_medium(self):
        """Unknown, empty and None inputs all map to 'medium'."""
        for raw in ("unknown", "", None):
            assert _normalize_severity(raw) == "medium"
|
|
|
|
|
|
class TestTemplateFallback:
    """Tests for _template_fallback (control built without a usable LLM answer)."""

    def test_normal_obligation(self):
        """A regular obligation gets a capitalized action title and parent severity."""
        arguments = dict(
            obligation_text="Betreiber müssen MFA implementieren",
            action="implementieren",
            object_="MFA",
            parent_title="Authentication Controls",
            parent_severity="high",
            parent_category="authentication",
            is_test=False,
            is_reporting=False,
        )
        control = _template_fallback(**arguments)
        assert "Implementieren" in control.title
        assert control.severity == "high"
        assert len(control.requirements) == 1

    def test_test_obligation(self):
        """A test obligation is titled 'Test:' and asks for a test protocol."""
        arguments = dict(
            obligation_text="MFA muss regelmäßig getestet werden",
            action="testen",
            object_="MFA-Wirksamkeit",
            parent_title="MFA Control",
            parent_severity="medium",
            parent_category="auth",
            is_test=True,
            is_reporting=False,
        )
        control = _template_fallback(**arguments)
        assert "Test:" in control.title
        assert "Testprotokoll" in control.evidence

    def test_reporting_obligation(self):
        """A reporting obligation is titled 'Meldepflicht:' with matching evidence."""
        arguments = dict(
            obligation_text="Behörden sind über Vorfälle zu informieren",
            action="informieren",
            object_="zuständige Behörden",
            parent_title="Incident Reporting",
            parent_severity="high",
            parent_category="governance",
            is_test=False,
            is_reporting=True,
        )
        control = _template_fallback(**arguments)
        assert "Meldepflicht:" in control.title
        assert "Meldeprozess-Dokumentation" in control.evidence
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# PROMPT BUILDER TESTS
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPromptBuilders:
    """Tests for the single-item LLM prompt builders and the system prompts."""

    def test_pass0a_prompt_contains_all_fields(self):
        """The Pass 0a user prompt echoes the control fields and asks for a JSON array."""
        prompt = _build_pass0a_prompt(
            title="MFA Control",
            objective="Implement MFA",
            requirements="- Require TOTP\n- Hardware key",
            test_procedure="- Test login",
            source_ref="DSGVO Art. 32",
        )
        expected_fragments = (
            "MFA Control",
            "Implement MFA",
            "Require TOTP",
            "DSGVO Art. 32",
            "JSON-Array",
        )
        for fragment in expected_fragments:
            assert fragment in prompt

    def test_pass0b_prompt_contains_all_fields(self):
        """The Pass 0b user prompt echoes the obligation and parent title."""
        prompt = _build_pass0b_prompt(
            obligation_text="MFA implementieren",
            action="implementieren",
            object_="MFA",
            parent_title="Auth Controls",
            parent_category="authentication",
            source_ref="DSGVO Art. 32",
        )
        # NOTE(review): "implementieren" is a substring of "MFA implementieren",
        # so the second check cannot fail independently of the first.
        expected_fragments = (
            "MFA implementieren",
            "implementieren",
            "Auth Controls",
            "JSON",
        )
        for fragment in expected_fragments:
            assert fragment in prompt

    def test_system_prompts_exist(self):
        """Both system prompts contain their characteristic marker words."""
        assert "REGELN" in _PASS0A_SYSTEM_PROMPT
        assert "atomares" in _PASS0B_SYSTEM_PROMPT

    def test_pass0a_prompt_contains_cot_steps(self):
        """Pass 0a system prompt must include Chain-of-Thought analysis steps."""
        required_markers = (
            "ANALYSE-SCHRITTE",
            "Adressaten",
            "Handlung",
            "normative Staerke",
            "Meldepflicht",
            "NICHT im Output",
        )
        for marker in required_markers:
            assert marker in _PASS0A_SYSTEM_PROMPT

    def test_pass0b_prompt_contains_cot_steps(self):
        """Pass 0b system prompt must include Chain-of-Thought analysis steps."""
        required_markers = (
            "ANALYSE-SCHRITTE",
            "Anforderung",
            "Massnahme",
            "Pruefverfahren",
            "Nachweis",
            "NICHT im Output",
        )
        for marker in required_markers:
            assert marker in _PASS0B_SYSTEM_PROMPT
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# DECOMPOSITION PASS INTEGRATION TESTS
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestDecompositionPassRun0a:
    """Tests for DecompositionPass.run_pass0a with a mocked DB session and LLM.

    Every test feeds 8-tuples as control rows. Judging by the sample values the
    columns are presumably (uuid, control_id, title, objective, requirements,
    test_procedure, source_citation, category) — confirm against the SELECT in
    run_pass0a.
    """

    # Patch target for the LLM call; shared so all tests stay in sync.
    _LLM_TARGET = "compliance.services.obligation_extractor._llm_ollama"

    @staticmethod
    def _db_with_controls(rows):
        """Return a MagicMock session whose execute() result yields ``rows`` from fetchall()."""
        result = MagicMock()
        result.fetchall.return_value = rows
        db = MagicMock()
        db.execute.return_value = result
        return db

    @pytest.mark.asyncio
    async def test_pass0a_extracts_obligations(self):
        """Two valid obligations from the LLM are extracted, validated and committed."""
        # Rich control to decompose
        mock_db = self._db_with_controls([
            (
                "uuid-1", "CTRL-001",
                "Service Continuity",
                "Sicherstellen der Dienstleistungskontinuität",
                '["Mechanismen implementieren", "Systeme testen"]',
                '["Prüfung der Mechanismen"]',
                '{"source": "MiCA", "article": "Art. 8"}',
                "finance",
            ),
        ])

        llm_response = json.dumps([
            {
                "obligation_text": "Betreiber müssen Mechanismen zur Dienstleistungskontinuität implementieren",
                "action": "implementieren",
                "object": "Kontinuitätsmechanismen",
                "condition": "bei Ausfall des Handelssystems",
                "normative_strength": "must",
                "is_test_obligation": False,
                "is_reporting_obligation": False,
            },
            {
                "obligation_text": "Kontinuitätsmechanismen müssen regelmäßig getestet werden",
                "action": "testen",
                "object": "Kontinuitätsmechanismen",
                "condition": None,
                "normative_strength": "must",
                "is_test_obligation": True,
                "is_reporting_obligation": False,
            },
        ])

        with patch(self._LLM_TARGET, new_callable=AsyncMock) as mock_llm:
            mock_llm.return_value = llm_response

            decomp = DecompositionPass(db=mock_db)
            stats = await decomp.run_pass0a(limit=10)

        assert stats["controls_processed"] == 1
        assert stats["obligations_extracted"] == 2
        assert stats["obligations_validated"] == 2
        assert stats["errors"] == 0

        # DB writes: at least 1 SELECT + 2 INSERTs go through execute();
        # the commit is a separate call and must happen exactly once.
        assert mock_db.execute.call_count >= 3
        mock_db.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_pass0a_fallback_on_empty_llm(self):
        """An unparseable LLM answer falls back to one obligation from the objective."""
        mock_db = self._db_with_controls([
            (
                "uuid-1", "CTRL-001",
                "MFA Control",
                "Betreiber müssen MFA implementieren",
                "", "", "", "auth",
            ),
        ])

        with patch(self._LLM_TARGET, new_callable=AsyncMock) as mock_llm:
            mock_llm.return_value = "I cannot help with that."  # Invalid JSON

            decomp = DecompositionPass(db=mock_db)
            stats = await decomp.run_pass0a(limit=10)

        assert stats["controls_processed"] == 1
        # Fallback should create 1 obligation from the objective
        assert stats["obligations_extracted"] == 1

    @pytest.mark.asyncio
    async def test_pass0a_skips_empty_controls(self):
        """A control with no title/objective text is skipped without asking the LLM."""
        mock_db = self._db_with_controls([
            ("uuid-1", "CTRL-001", "", "", "", "", "", ""),
        ])

        # Empty controls are expected to be skipped before any LLM call. The LLM
        # is patched anyway so a regression cannot reach a real backend, and the
        # "not awaited" check turns that expectation into an assertion.
        with patch(self._LLM_TARGET, new_callable=AsyncMock) as mock_llm:
            decomp = DecompositionPass(db=mock_db)
            stats = await decomp.run_pass0a(limit=10)

        mock_llm.assert_not_awaited()
        assert stats["controls_skipped_empty"] == 1
        assert stats["controls_processed"] == 0

    @pytest.mark.asyncio
    async def test_pass0a_rejects_evidence_only(self):
        """An evidence-only obligation is counted as extracted and then rejected."""
        mock_db = self._db_with_controls([
            (
                "uuid-1", "CTRL-001",
                "Evidence List",
                "Betreiber müssen Nachweise erbringen",
                "", "", "", "governance",
            ),
        ])

        llm_response = json.dumps([
            {
                "obligation_text": "Dokumentation der Konfiguration",
                "action": "dokumentieren",
                "object": "Konfiguration",
                "condition": None,
                "normative_strength": "must",
                "is_test_obligation": False,
                "is_reporting_obligation": False,
            },
        ])

        with patch(self._LLM_TARGET, new_callable=AsyncMock) as mock_llm:
            mock_llm.return_value = llm_response

            decomp = DecompositionPass(db=mock_db)
            stats = await decomp.run_pass0a(limit=10)

        assert stats["obligations_extracted"] == 1
        assert stats["obligations_rejected"] == 1
|
|
|
|
|
|
class TestDecompositionPassRun0b:
    """Tests for DecompositionPass.run_pass0b with a mocked DB session and LLM."""

    @pytest.mark.asyncio
    async def test_pass0b_creates_atomic_controls(self):
        """A valid LLM JSON object yields one atomic control and no LLM failure."""
        # Validated obligation candidates returned by the first execute() call.
        candidate_result = MagicMock()
        candidate_result.fetchall.return_value = [
            (
                "oc-uuid-1", "OC-CTRL-001-01", "parent-uuid-1",
                "Betreiber müssen Kontinuität sicherstellen",
                "sicherstellen", "Dienstleistungskontinuität",
                False, False,  # is_test, is_reporting
                "Service Continuity", "finance",
                '{"source": "MiCA", "article": "Art. 8"}',
                "high", "FIN-001",
                "continuous", False,  # trigger_type, is_implementation_specific
            ),
        ]

        # Result of the _next_atomic_seq lookup (second execute() call).
        seq_result = MagicMock()
        seq_result.fetchone.return_value = (0,)

        # Call sequence: 1=SELECT, 2=_next_atomic_seq, then INSERT/UPDATE calls,
        # each of which receives a fresh MagicMock.
        scripted_results = iter([candidate_result, seq_result])

        def execute_stub(*args, **kwargs):
            return next(scripted_results, MagicMock())

        session = MagicMock()
        session.execute.side_effect = execute_stub

        control_payload = {
            "title": "Dienstleistungskontinuität bei Systemausfall",
            "objective": "Sicherstellen, dass Dienstleistungen fortgeführt werden.",
            "requirements": ["Failover-Mechanismus implementieren"],
            "test_procedure": ["Failover-Test durchführen"],
            "evidence": ["Systemarchitektur", "DR-Plan"],
            "severity": "high",
            "category": "operations",
        }

        with patch(
            "compliance.services.obligation_extractor._llm_ollama",
            new_callable=AsyncMock,
        ) as llm_stub:
            llm_stub.return_value = json.dumps(control_payload)

            runner = DecompositionPass(db=session)
            stats = await runner.run_pass0b(limit=10)

        assert stats["candidates_processed"] == 1
        assert stats["controls_created"] == 1
        assert stats["llm_failures"] == 0

    @pytest.mark.asyncio
    async def test_pass0b_template_fallback(self):
        """An unparseable LLM answer still creates a control and counts one LLM failure."""
        candidate_result = MagicMock()
        candidate_result.fetchall.return_value = [
            (
                "oc-uuid-1", "OC-CTRL-001-01", "parent-uuid-1",
                "Betreiber müssen MFA implementieren",
                "implementieren", "MFA",
                False, False,
                "Auth Controls", "authentication",
                "", "high", "AUTH-001",
                "continuous", False,
            ),
        ]

        seq_result = MagicMock()
        seq_result.fetchone.return_value = (0,)

        # First call: candidates; second call: sequence; afterwards fresh mocks.
        scripted_results = iter([candidate_result, seq_result])

        def execute_stub(*args, **kwargs):
            return next(scripted_results, MagicMock())

        session = MagicMock()
        session.execute.side_effect = execute_stub

        with patch(
            "compliance.services.obligation_extractor._llm_ollama",
            new_callable=AsyncMock,
        ) as llm_stub:
            llm_stub.return_value = "Sorry, invalid response"  # LLM fails

            runner = DecompositionPass(db=session)
            stats = await runner.run_pass0b(limit=10)

        assert stats["controls_created"] == 1
        assert stats["llm_failures"] == 1
|
|
|
|
|
|
class TestDecompositionStatus:
    """Tests for DecompositionPass.decomposition_status with a mocked DB row."""

    def test_returns_status(self):
        """The nine-column row is mapped to named counters plus derived percentages."""
        row_result = MagicMock()
        # 9 columns: rich, decomposed, total, validated, rejected, composed, atomic, merged, enriched
        row_result.fetchone.return_value = (5000, 1000, 3000, 2500, 200, 2000, 1800, 100, 2400)
        session = MagicMock()
        session.execute.return_value = row_result

        status = DecompositionPass(db=session).decomposition_status()

        expected = {
            "rich_controls": 5000,
            "decomposed_controls": 1000,
            "total_candidates": 3000,
            "validated": 2500,
            "rejected": 200,
            "composed": 2000,
            "atomic_controls": 1800,
            "merged": 100,
            "enriched": 2400,
            "ready_for_pass0b": 2400,  # 2500 validated - 100 merged
            "decomposition_pct": 20.0,
            # composition_pct: 2000 composed / 2400 ready_for_pass0b
            "composition_pct": 83.3,
        }
        for key, value in expected.items():
            assert status[key] == value

    def test_handles_zero_division(self):
        """An all-zero row yields 0.0 percentages rather than a division error."""
        row_result = MagicMock()
        row_result.fetchone.return_value = (0, 0, 0, 0, 0, 0, 0, 0, 0)
        session = MagicMock()
        session.execute.return_value = row_result

        status = DecompositionPass(db=session).decomposition_status()

        assert status["decomposition_pct"] == 0.0
        assert status["composition_pct"] == 0.0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# MIGRATION 061 SCHEMA TESTS
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestMigration061:
    """Tests for the migration 061 SQL file (obligation candidates schema)."""

    @staticmethod
    def _migration_path():
        """Return the path of migration 061, resolved relative to this test module."""
        from pathlib import Path

        return Path(__file__).parent.parent / "migrations" / "061_obligation_candidates.sql"

    def test_migration_file_exists(self):
        """The migration file is present next to the test package."""
        assert self._migration_path().exists(), "Migration 061 file missing"

    def test_migration_contains_required_tables(self):
        """The migration mentions the table and the columns this pass relies on."""
        # Read with an explicit encoding so the result does not depend on the
        # locale of the machine running the tests.
        content = self._migration_path().read_text(encoding="utf-8")
        required_identifiers = (
            "obligation_candidates",
            "parent_control_uuid",
            "decomposition_method",
            "candidate_id",
            "quality_flags",
        )
        for identifier in required_identifiers:
            assert identifier in content, f"Migration 061 does not mention {identifier!r}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# BATCH PROMPT TESTS
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestBatchPromptBuilders:
    """Tests for the batch prompt builders of Pass 0a and Pass 0b."""

    def test_pass0a_batch_prompt_contains_all_controls(self):
        """Each control is numbered and echoed, and the total count is stated."""
        mfa_control = {
            "control_id": "AUTH-001",
            "title": "MFA Control",
            "objective": "Implement MFA",
            "requirements": "- TOTP required",
            "test_procedure": "- Test login",
            "source_ref": "DSGVO Art. 32",
        }
        password_control = {
            "control_id": "AUTH-002",
            "title": "Password Policy",
            "objective": "Enforce strong passwords",
            "requirements": "- Min 12 chars",
            "test_procedure": "- Test weak password",
            "source_ref": "BSI IT-Grundschutz",
        }
        prompt = _build_pass0a_batch_prompt([mfa_control, password_control])
        expected_fragments = (
            "AUTH-001",
            "AUTH-002",
            "MFA Control",
            "Password Policy",
            "CONTROL 1",
            "CONTROL 2",
            "2 Controls",
        )
        for fragment in expected_fragments:
            assert fragment in prompt

    def test_pass0a_batch_prompt_single_control(self):
        """A one-element batch still states its count as '1 Controls'."""
        only_control = {
            "control_id": "AUTH-001",
            "title": "MFA",
            "objective": "MFA",
            "requirements": "",
            "test_procedure": "",
            "source_ref": "",
        }
        prompt = _build_pass0a_batch_prompt([only_control])
        assert "AUTH-001" in prompt
        assert "1 Controls" in prompt

    def test_pass0b_batch_prompt_contains_all_obligations(self):
        """Each obligation is numbered and echoed, and the total count is stated."""
        implement_obligation = {
            "candidate_id": "OC-AUTH-001-01",
            "obligation_text": "MFA implementieren",
            "action": "implementieren",
            "object": "MFA",
            "parent_title": "Auth Controls",
            "parent_category": "authentication",
            "source_ref": "DSGVO Art. 32",
        }
        test_obligation = {
            "candidate_id": "OC-AUTH-001-02",
            "obligation_text": "MFA testen",
            "action": "testen",
            "object": "MFA",
            "parent_title": "Auth Controls",
            "parent_category": "authentication",
            "source_ref": "DSGVO Art. 32",
        }
        prompt = _build_pass0b_batch_prompt([implement_obligation, test_obligation])
        expected_fragments = (
            "OC-AUTH-001-01",
            "OC-AUTH-001-02",
            "PFLICHT 1",
            "PFLICHT 2",
            "2 Pflichten",
        )
        for fragment in expected_fragments:
            assert fragment in prompt
|
|
|
|
|
|
class TestFallbackObligation:
    """Checks which control field the _fallback_obligation helper uses as text."""

    def test_uses_objective_when_available(self):
        """A non-empty objective becomes the obligation text; the action is 'sicherstellen'."""
        fallback = _fallback_obligation(
            {"title": "MFA", "objective": "Implement MFA for all users"}
        )

        assert fallback["obligation_text"] == "Implement MFA for all users"
        assert fallback["action"] == "sicherstellen"

    def test_uses_title_when_no_objective(self):
        """With an empty objective the title is used as the obligation text."""
        fallback = _fallback_obligation({"title": "MFA Control", "objective": ""})

        assert fallback["obligation_text"] == "MFA Control"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# ANTHROPIC BATCHING INTEGRATION TESTS
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestDecompositionPassAnthropicBatch:
    """Exercises Pass 0a/0b with the Anthropic provider patched out and a mocked DB."""

    # Dotted path of the LLM entry point that every test here replaces.
    _LLM_TARGET = "compliance.services.decomposition_pass._llm_anthropic"

    @staticmethod
    def _cursor(rows):
        """Return a mock result object whose ``fetchall()`` yields ``rows``."""
        cursor = MagicMock()
        cursor.fetchall.return_value = rows
        return cursor

    @staticmethod
    def _must_obligation(text, action, obj):
        """Build one 'must' obligation dict as used in the canned Pass 0a replies."""
        return {
            "obligation_text": text,
            "action": action,
            "object": obj,
            "normative_strength": "must",
            "is_test_obligation": False,
            "is_reporting_obligation": False,
        }

    @pytest.mark.asyncio
    async def test_pass0a_anthropic_batched(self):
        """Two controls with batch_size=2 are handled by a single LLM call."""
        db = MagicMock()
        db.execute.return_value = self._cursor([
            ("uuid-1", "CTRL-001", "MFA Control", "Implement MFA",
             "", "", "", "security"),
            ("uuid-2", "CTRL-002", "Encryption", "Encrypt data at rest",
             "", "", "", "security"),
        ])

        # The batched reply is a JSON object keyed by control_id.
        llm_reply = json.dumps({
            "CTRL-001": [
                self._must_obligation(
                    "MFA muss implementiert werden", "implementieren", "MFA"
                ),
            ],
            "CTRL-002": [
                self._must_obligation(
                    "Daten müssen verschlüsselt werden", "verschlüsseln", "Daten"
                ),
            ],
        })

        with patch(self._LLM_TARGET, new_callable=AsyncMock) as fake_llm:
            fake_llm.return_value = llm_reply
            runner = DecompositionPass(db=db)
            stats = await runner.run_pass0a(
                limit=10, batch_size=2, use_anthropic=True,
            )

        assert stats["controls_processed"] == 2
        assert stats["obligations_extracted"] == 2
        assert stats["llm_calls"] == 1  # one API call covers both controls
        assert stats["provider"] == "anthropic"

    @pytest.mark.asyncio
    async def test_pass0a_anthropic_single(self):
        """One control with batch_size=1 (no batching) uses exactly one LLM call."""
        db = MagicMock()
        db.execute.return_value = self._cursor([
            ("uuid-1", "CTRL-001", "MFA Control", "Implement MFA",
             "", "", "", "security"),
        ])

        # The unbatched reply is a plain JSON array of obligations.
        llm_reply = json.dumps([
            self._must_obligation(
                "MFA muss implementiert werden", "implementieren", "MFA"
            ),
        ])

        with patch(self._LLM_TARGET, new_callable=AsyncMock) as fake_llm:
            fake_llm.return_value = llm_reply
            runner = DecompositionPass(db=db)
            stats = await runner.run_pass0a(
                limit=10, batch_size=1, use_anthropic=True,
            )

        assert stats["controls_processed"] == 1
        assert stats["llm_calls"] == 1
        assert stats["provider"] == "anthropic"

    @pytest.mark.asyncio
    async def test_pass0b_anthropic_batched(self):
        """Two obligation candidates with batch_size=2 yield two controls from one LLM call."""
        candidates = self._cursor([
            ("oc-uuid-1", "OC-CTRL-001-01", "parent-uuid-1",
             "MFA implementieren", "implementieren", "MFA",
             False, False, "Auth", "security",
             '{"source": "DSGVO", "article": "Art. 32"}',
             "high", "CTRL-001",
             "continuous", False),
            ("oc-uuid-2", "OC-CTRL-001-02", "parent-uuid-1",
             "MFA testen", "testen", "MFA",
             True, False, "Auth", "security",
             '{"source": "DSGVO", "article": "Art. 32"}',
             "high", "CTRL-001",
             "periodic", False),
        ])

        sequence_result = MagicMock()
        sequence_result.fetchone.return_value = (0,)

        calls_seen = 0

        def fake_execute(*args, **kwargs):
            """Answer each ``db.execute`` call according to its 1-based position."""
            nonlocal calls_seen
            calls_seen += 1
            if calls_seen == 1:
                return candidates  # SELECT candidates
            if calls_seen in (2, 6):
                # _next_atomic_seq: call 2 (control 1), call 6 (control 2)
                return sequence_result
            if calls_seen in (3, 7):
                # INSERT RETURNING: call 3 (control 1), call 7 (control 2)
                inserted = MagicMock()
                inserted.fetchone.return_value = (f"new-uuid-{calls_seen}",)
                return inserted
            return MagicMock()  # parent_links INSERT / UPDATE

        db = MagicMock()
        db.execute.side_effect = fake_execute

        llm_reply = json.dumps({
            "OC-CTRL-001-01": {
                "title": "MFA implementieren",
                "objective": "MFA fuer alle Konten.",
                "requirements": ["TOTP einrichten"],
                "test_procedure": ["Login testen"],
                "evidence": ["Konfigurationsnachweis"],
                "severity": "high",
                "category": "security",
            },
            "OC-CTRL-001-02": {
                "title": "MFA-Wirksamkeit testen",
                "objective": "Regelmaessige MFA-Tests.",
                "requirements": ["Testplan erstellen"],
                "test_procedure": ["Testdurchfuehrung"],
                "evidence": ["Testprotokoll"],
                "severity": "high",
                "category": "security",
            },
        })

        with patch(self._LLM_TARGET, new_callable=AsyncMock) as fake_llm:
            fake_llm.return_value = llm_reply
            runner = DecompositionPass(db=db)
            stats = await runner.run_pass0b(
                limit=10, batch_size=2, use_anthropic=True,
            )

        assert stats["controls_created"] == 2
        assert stats["llm_calls"] == 1
        assert stats["provider"] == "anthropic"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# SOURCE FILTER TESTS
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSourceFilter:
    """Checks how Pass 0a's source_filter argument shapes the first SQL query."""

    @staticmethod
    def _db_with_rows(rows):
        """Return a mock DB whose every ``execute`` result yields ``rows`` from ``fetchall()``."""
        cursor = MagicMock()
        cursor.fetchall.return_value = rows
        db = MagicMock()
        db.execute.return_value = cursor
        return db

    @staticmethod
    def _first_query(db):
        """Return the first positional argument of the first ``execute`` call as text."""
        return str(db.execute.call_args_list[0].args[0])

    @pytest.mark.asyncio
    async def test_pass0a_source_filter_builds_ilike_query(self):
        """A comma-separated source_filter puts ILIKE into the query."""
        db = self._db_with_rows([
            ("uuid-1", "CTRL-001", "Machine Safety", "Ensure safety",
             "", "", '{"source": "Maschinenverordnung (EU) 2023/1230"}', "security"),
        ])

        llm_reply = json.dumps([
            {
                "obligation_text": "Sicherheit gewaehrleisten",
                "action": "gewaehrleisten",
                "object": "Sicherheit",
                "normative_strength": "must",
                "is_test_obligation": False,
                "is_reporting_obligation": False,
            },
        ])

        with patch(
            "compliance.services.decomposition_pass._llm_anthropic",
            new_callable=AsyncMock,
        ) as fake_llm:
            fake_llm.return_value = llm_reply
            runner = DecompositionPass(db=db)
            stats = await runner.run_pass0a(
                limit=10, batch_size=1, use_anthropic=True,
                source_filter="Maschinenverordnung,Cyber Resilience Act",
            )

        assert stats["controls_processed"] == 1
        assert "ILIKE" in self._first_query(db)

    @pytest.mark.asyncio
    async def test_pass0a_source_filter_none_no_clause(self):
        """With source_filter=None the query contains no ILIKE."""
        db = self._db_with_rows([])

        runner = DecompositionPass(db=db)
        await runner.run_pass0a(
            limit=10, use_anthropic=True, source_filter=None,
        )

        assert "ILIKE" not in self._first_query(db)

    @pytest.mark.asyncio
    async def test_pass0a_combined_category_and_source_filter(self):
        """category_filter and source_filter together produce both clauses."""
        db = self._db_with_rows([])

        runner = DecompositionPass(db=db)
        await runner.run_pass0a(
            limit=10, use_anthropic=True,
            category_filter="security,operations",
            source_filter="Maschinenverordnung",
        )

        query_text = self._first_query(db)
        assert "IN :cats" in query_text
        assert "ILIKE" in query_text
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# TRIGGER TYPE CLASSIFICATION TESTS
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestClassifyTriggerType:
    """Expected labels from _classify_trigger_type for sample obligation texts.

    Each case passes two strings to the helper and compares the returned
    label with "event", "periodic" or "continuous".
    """

    def test_event_trigger_vorfall(self):
        label = _classify_trigger_type(
            "Bei einem Sicherheitsvorfall muss gemeldet werden", ""
        )
        assert label == "event"

    def test_event_trigger_condition_field(self):
        # Here the triggering wording sits in the second argument only.
        label = _classify_trigger_type(
            "Melden", "wenn ein Datenverlust festgestellt wird"
        )
        assert label == "event"

    def test_event_trigger_breach(self):
        label = _classify_trigger_type(
            "In case of a data breach, notify authorities", ""
        )
        assert label == "event"

    def test_periodic_trigger_jaehrlich(self):
        label = _classify_trigger_type(
            "Jährlich ist eine Überprüfung durchzuführen", ""
        )
        assert label == "periodic"

    def test_periodic_trigger_regelmaessig(self):
        label = _classify_trigger_type(
            "Regelmäßig muss ein Audit stattfinden", ""
        )
        assert label == "periodic"

    def test_periodic_trigger_quarterly(self):
        label = _classify_trigger_type(
            "Quarterly review of access controls", ""
        )
        assert label == "periodic"

    def test_continuous_default(self):
        label = _classify_trigger_type(
            "Betreiber müssen Zugangskontrollen implementieren", ""
        )
        assert label == "continuous"

    def test_continuous_empty_text(self):
        label = _classify_trigger_type("", "")
        assert label == "continuous"

    def test_event_takes_precedence_over_periodic(self):
        # The text contains both "Vorfall" and "regelmäßig"; "event" is expected.
        label = _classify_trigger_type(
            "Bei einem Vorfall ist regelmäßig zu prüfen", ""
        )
        assert label == "event"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# IMPLEMENTATION-SPECIFIC DETECTION TESTS
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestIsImplementationSpecific:
    """Expected verdicts from _is_implementation_specific_text.

    Every case passes (obligation text, action, object) and checks only the
    truthiness of the result.
    """

    def test_tls_is_implementation_specific(self):
        verdict = _is_implementation_specific_text(
            "Verschlüsselung mittels TLS 1.3 sicherstellen",
            "sicherstellen",
            "Verschlüsselung",
        )
        assert verdict

    def test_mfa_is_implementation_specific(self):
        verdict = _is_implementation_specific_text(
            "MFA muss für alle Konten aktiviert werden",
            "aktivieren",
            "MFA",
        )
        assert verdict

    def test_siem_is_implementation_specific(self):
        verdict = _is_implementation_specific_text(
            "Ein SIEM-System muss betrieben werden",
            "betreiben",
            "SIEM-System",
        )
        assert verdict

    def test_abstract_obligation_not_specific(self):
        verdict = _is_implementation_specific_text(
            "Zugriffskontrollen müssen implementiert werden",
            "implementieren",
            "Zugriffskontrollen",
        )
        assert not verdict

    def test_generic_encryption_not_specific(self):
        verdict = _is_implementation_specific_text(
            "Daten müssen verschlüsselt gespeichert werden",
            "verschlüsseln",
            "Daten",
        )
        assert not verdict
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# TEXT SIMILARITY TESTS
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestTextSimilar:
    """Expected outcomes of the _text_similar helper for a few string pairs."""

    def test_identical_strings(self):
        """Two equal strings count as similar at the default threshold."""
        phrase = "implementieren mfa"
        outcome = _text_similar(phrase, "implementieren mfa")
        assert outcome

    def test_similar_strings(self):
        """A string and the same string plus one extra word pass at threshold 0.60."""
        outcome = _text_similar(
            "implementieren zugangskontrolle",
            "implementieren zugangskontrolle system",
            threshold=0.60,
        )
        assert outcome

    def test_different_strings(self):
        """Strings without a common word fail at threshold 0.75."""
        outcome = _text_similar(
            "implementieren mfa",
            "dokumentieren audit",
            threshold=0.75,
        )
        assert not outcome

    def test_empty_string(self):
        """An empty first argument is never similar to a non-empty one."""
        outcome = _text_similar("", "something")
        assert not outcome

    def test_both_empty(self):
        """Two empty strings are reported as not similar."""
        outcome = _text_similar("", "")
        assert not outcome
|
|
|
|
|
|
class TestIsMoreImplementationSpecific:
    """Pairwise comparisons checked against _is_more_implementation_specific."""

    def test_concrete_vs_abstract(self):
        """The TLS/SMS wording ranks above the generic encryption wording."""
        specific_text = "SMS-Versand muss über TLS verschlüsselt werden"
        generic_text = "Kommunikation muss verschlüsselt werden"

        outcome = _is_more_implementation_specific(specific_text, generic_text)

        assert outcome

    def test_abstract_vs_concrete(self):
        """With the generic wording passed first, the result is falsy."""
        specific_text = "Firewall-Regeln müssen konfiguriert werden"
        generic_text = "Netzwerksicherheit muss gewährleistet werden"

        outcome = _is_more_implementation_specific(generic_text, specific_text)

        assert not outcome

    def test_equal_specificity_longer_wins(self):
        """Of two texts about the same subject, the longer one passed first wins."""
        longer_text = (
            "Zugriffskontrollen müssen implementiert werden und dokumentiert werden"
        )
        shorter_text = "Zugriffskontrollen implementieren"

        outcome = _is_more_implementation_specific(longer_text, shorter_text)

        assert outcome
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# MERGE PASS TESTS
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestMergePass:
    """Runs DecompositionPass.run_merge_pass against a scripted mock database."""

    @staticmethod
    def _cursor(fetchall=None, fetchone=None):
        """Return a mock result with the given ``fetchall``/``fetchone`` values configured."""
        cursor = MagicMock()
        if fetchall is not None:
            cursor.fetchall.return_value = fetchall
        if fetchone is not None:
            cursor.fetchone.return_value = fetchone
        return cursor

    @staticmethod
    def _scripted_db(script):
        """Return a mock DB whose n-th ``execute`` call (1-based) returns ``script[n]``.

        Calls whose position is not a key of ``script`` receive a fresh
        ``MagicMock``.
        """
        db = MagicMock()
        calls_seen = 0

        def fake_execute(*args, **kwargs):
            nonlocal calls_seen
            calls_seen += 1
            if calls_seen in script:
                return script[calls_seen]
            return MagicMock()

        db.execute.side_effect = fake_execute
        return db

    def test_merge_pass_merges_similar_obligations(self):
        """Two obligations sharing action and object are merged; the third is kept."""
        # Call 1: parents with more than one validated obligation.
        parents = self._cursor(fetchall=[
            ("parent-uuid-1", 3),
        ])
        # Call 2: the obligations of that parent.
        obligations = self._cursor(fetchall=[
            ("obl-1", "OC-001-01",
             "Betreiber müssen Verschlüsselung implementieren",
             "implementieren", "verschlüsselung"),
            ("obl-2", "OC-001-02",
             "Betreiber müssen Verschlüsselung mittels TLS implementieren",
             "implementieren", "verschlüsselung"),
            ("obl-3", "OC-001-03",
             "Betreiber müssen Zugriffsprotokolle führen",
             "führen", "zugriffsprotokolle"),
        ])
        # Call 4: the final count (call 3 is the UPDATE and gets a plain mock).
        remaining = self._cursor(fetchone=(2,))

        db = self._scripted_db({1: parents, 2: obligations, 4: remaining})

        stats = DecompositionPass(db=db).run_merge_pass()

        assert stats["parents_checked"] == 1
        assert stats["obligations_merged"] == 1  # obl-2 merged into obl-1
        assert stats["obligations_kept"] == 2

    def test_merge_pass_no_merge_when_different_actions(self):
        """Obligations with different actions and objects are left unmerged."""
        parents = self._cursor(fetchall=[
            ("parent-uuid-1", 2),
        ])
        obligations = self._cursor(fetchall=[
            ("obl-1", "OC-001-01",
             "Verschlüsselung implementieren",
             "implementieren", "verschlüsselung"),
            ("obl-2", "OC-001-02",
             "Zugriffsprotokolle dokumentieren",
             "dokumentieren", "zugriffsprotokolle"),
        ])
        # Without a merge there is no UPDATE, so the count arrives as call 3.
        remaining = self._cursor(fetchone=(2,))

        db = self._scripted_db({1: parents, 2: obligations, 3: remaining})

        stats = DecompositionPass(db=db).run_merge_pass()

        assert stats["obligations_merged"] == 0
        assert stats["obligations_kept"] == 2
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# ENRICH PASS TESTS
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestEnrichPass:
    """Runs DecompositionPass.enrich_obligations against a mocked database."""

    def test_enrich_classifies_trigger_types(self):
        """Three sample obligations give one of each trigger type and one implementation-specific hit."""
        selected = MagicMock()
        selected.fetchall.return_value = [
            ("obl-1", "Bei Vorfall melden", "Sicherheitsvorfall",
             "melden", "Vorfall"),
            ("obl-2", "Jährlich Audit durchführen", "",
             "durchführen", "Audit"),
            ("obl-3", "Verschlüsselung mittels TLS implementieren", "",
             "implementieren", "Verschlüsselung"),
        ]

        calls_seen = 0

        def fake_execute(*args, **kwargs):
            """Serve the obligation rows first, then plain mocks for the UPDATE statements."""
            nonlocal calls_seen
            calls_seen += 1
            return selected if calls_seen == 1 else MagicMock()

        db = MagicMock()
        db.execute.side_effect = fake_execute

        stats = DecompositionPass(db=db).enrich_obligations()

        expected = {
            "enriched": 3,
            "trigger_event": 1,
            "trigger_periodic": 1,
            "trigger_continuous": 1,
            "implementation_specific": 1,
        }
        observed = {key: stats[key] for key in expected}
        assert observed == expected
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# MIGRATION 075 TESTS
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestMigration075:
    """Sanity checks on the migration 075 SQL file.

    The file is only checked for existence and searched for plain
    substrings; the SQL is neither parsed nor executed.
    """

    @staticmethod
    def _migration_path():
        """Return the path of ``075_obligation_refinement.sql`` relative to this module."""
        from pathlib import Path

        return (
            Path(__file__).parent.parent
            / "migrations"
            / "075_obligation_refinement.sql"
        )

    def test_migration_file_exists(self):
        """The migration file is present in the sibling ``migrations`` directory."""
        assert self._migration_path().exists(), "Migration 075 file missing"

    def test_migration_contains_required_fields(self):
        """The migration text mentions every expected identifier and the 'merged' literal."""
        # Decode explicitly instead of relying on the platform's locale encoding.
        # NOTE(review): assumes the migration file is stored as UTF-8 - confirm.
        content = self._migration_path().read_text(encoding="utf-8")

        required_fragments = (
            "merged_into_id",
            "trigger_type",
            "is_implementation_specific",
            "'merged'",
        )
        for fragment in required_fragments:
            # Name the missing fragment so a failure is self-explanatory.
            assert fragment in content, f"{fragment!r} missing from migration 075"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# PASS 0B ENRICHMENT INTEGRATION TESTS
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPass0bWithEnrichment:
    """Checks that Pass 0b honours enrichment metadata and writes parent links.

    The coroutine-driving tests are native ``@pytest.mark.asyncio`` tests,
    like the other async tests in this module. They do not use
    ``asyncio.get_event_loop().run_until_complete(...)``, which is deprecated
    when no event loop is current.
    """

    @staticmethod
    def _fresh_stats():
        """Return a zeroed stats dict as passed to ``_process_pass0b_control``."""
        return {
            "controls_created": 0,
            "candidates_processed": 0,
            "llm_failures": 0,
            "dedup_linked": 0,
            "dedup_review": 0,
        }

    @staticmethod
    def _db_for_single_write(new_uuid):
        """Return a mock DB scripted for one atomic-control write.

        Call 1 (``_next_atomic_seq``) yields ``(0,)`` from ``fetchone()``;
        call 2 (INSERT ... RETURNING id) yields ``(new_uuid,)``; every later
        call receives a fresh ``MagicMock``.
        """
        sequence_result = MagicMock()
        sequence_result.fetchone.return_value = (0,)
        insert_result = MagicMock()
        insert_result.fetchone.return_value = (new_uuid,)
        script = {1: sequence_result, 2: insert_result}

        db = MagicMock()
        calls_seen = 0

        def fake_execute(*args, **kwargs):
            nonlocal calls_seen
            calls_seen += 1
            if calls_seen in script:
                return script[calls_seen]
            return MagicMock()

        db.execute.side_effect = fake_execute
        return db

    @pytest.mark.asyncio
    async def test_pass0b_query_skips_merged(self):
        """The first Pass 0b query contains a ``merged_into_id IS NULL`` filter."""
        no_rows = MagicMock()
        no_rows.fetchall.return_value = []
        db = MagicMock()
        db.execute.return_value = no_rows

        await DecompositionPass(db=db).run_pass0b(limit=10, use_anthropic=True)

        first_query = str(db.execute.call_args_list[0].args[0])
        assert "merged_into_id IS NULL" in first_query

    @pytest.mark.asyncio
    async def test_severity_capped_for_implementation_specific(self):
        """An implementation-specific obligation is written with severity 'medium'.

        The parsed LLM output asks for 'critical'. The test also checks that
        the third DB call inserts into ``control_parent_links`` with the new
        control UUID and the parent UUID.
        """
        obl = {
            "oc_id": "oc-1",
            "candidate_id": "OC-001-01",
            "parent_uuid": "p-uuid",
            "obligation_text": "TLS implementieren",
            "action": "implementieren",
            "object": "TLS",
            "is_test": False,
            "is_reporting": False,
            "parent_title": "Encryption",
            "parent_category": "security",
            "parent_citation": "",
            "parent_severity": "high",
            "parent_control_id": "SEC-001",
            "source_ref": "",
            "trigger_type": "continuous",
            "is_implementation_specific": True,
        }
        parsed = {
            "title": "TLS implementieren",
            "objective": "TLS für alle Verbindungen",
            "requirements": ["TLS 1.3"],
            "test_procedure": ["Scan"],
            "evidence": ["Zertifikat"],
            "severity": "critical",
            "category": "security",
        }
        db = self._db_for_single_write("new-uuid-1")

        await DecompositionPass(db=db)._process_pass0b_control(
            obl, parsed, self._fresh_stats()
        )

        # Call #2 is the control INSERT: db.execute(text(...), {params}).
        insert_params = db.execute.call_args_list[1].args[1]
        assert insert_params["severity"] == "medium"

        # Call #3 is the parent-link INSERT.
        link_call = db.execute.call_args_list[2]
        assert "control_parent_links" in str(link_call.args[0])
        link_params = link_call.args[1]
        assert link_params["cu"] == "new-uuid-1"
        assert link_params["pu"] == "p-uuid"

    @pytest.mark.asyncio
    async def test_test_obligation_gets_testing_category(self):
        """A test obligation is written with category 'testing', not the LLM's 'security'."""
        obl = {
            "oc_id": "oc-1",
            "candidate_id": "OC-001-01",
            "parent_uuid": "p-uuid",
            "obligation_text": "MFA testen",
            "action": "testen",
            "object": "MFA",
            "is_test": True,
            "is_reporting": False,
            "parent_title": "Auth",
            "parent_category": "security",
            "parent_citation": "",
            "parent_severity": "high",
            "parent_control_id": "AUTH-001",
            "source_ref": "",
            "trigger_type": "periodic",
            "is_implementation_specific": False,
        }
        parsed = {
            "title": "MFA-Wirksamkeit testen",
            "objective": "Regelmäßig MFA testen",
            "requirements": ["Testplan"],
            "test_procedure": ["Durchführung"],
            "evidence": ["Protokoll"],
            "severity": "high",
            "category": "security",  # the LLM says security
        }
        db = self._db_for_single_write("new-uuid-2")

        await DecompositionPass(db=db)._process_pass0b_control(
            obl, parsed, self._fresh_stats()
        )

        # Call #2 is the control INSERT: db.execute(text(...), {params}).
        insert_params = db.execute.call_args_list[1].args[1]
        assert insert_params["category"] == "testing"

    @pytest.mark.asyncio
    async def test_parent_link_created_with_source_citation(self):
        """The parent-link INSERT carries regulation and article from ``parent_citation``.

        ``parent_citation`` is supplied as a JSON string. The link parameters
        must contain its ``source`` and ``article`` values plus the child,
        parent and obligation-candidate identifiers.
        """
        obl = {
            "oc_id": "oc-link-1",
            "candidate_id": "OC-DSGVO-01",
            "parent_uuid": "p-uuid-dsgvo",
            "obligation_text": "Daten minimieren",
            "action": "minimieren",
            "object": "personenbezogene Daten",
            "is_test": False,
            "is_reporting": False,
            "parent_title": "Datenminimierung",
            "parent_category": "privacy",
            "parent_citation": json.dumps({
                "source": "DSGVO",
                "article": "Art. 5 Abs. 1 lit. c",
                "paragraph": "",
            }),
            "parent_severity": "high",
            "parent_control_id": "PRIV-001",
            "source_ref": "DSGVO Art. 5 Abs. 1 lit. c",
            "trigger_type": "continuous",
            "is_implementation_specific": False,
        }
        parsed = {
            "title": "Personenbezogene Daten minimieren",
            "objective": "Nur erforderliche Daten erheben",
            "requirements": ["Datenminimierung"],
            "test_procedure": ["Audit"],
            "evidence": ["Protokoll"],
            "severity": "high",
            "category": "privacy",
        }
        db = self._db_for_single_write("new-uuid-dsgvo")

        await DecompositionPass(db=db)._process_pass0b_control(
            obl, parsed, self._fresh_stats()
        )

        # Call #3 is the parent-link INSERT.
        link_call = db.execute.call_args_list[2]
        assert "control_parent_links" in str(link_call.args[0])
        link_params = link_call.args[1]
        assert link_params["cu"] == "new-uuid-dsgvo"
        assert link_params["pu"] == "p-uuid-dsgvo"
        assert link_params["sr"] == "DSGVO"
        assert link_params["sa"] == "Art. 5 Abs. 1 lit. c"
        assert link_params["oci"] == "oc-link-1"

    def test_parse_citation_handles_formats(self):
        """``_parse_citation`` accepts a JSON string or a dict and returns ``{}`` otherwise."""
        from compliance.services.decomposition_pass import _parse_citation

        # JSON string
        from_text = _parse_citation(
            json.dumps({"source": "NIS2", "article": "Art. 21"})
        )
        assert from_text["source"] == "NIS2"
        assert from_text["article"] == "Art. 21"

        # Already a dict
        from_dict = _parse_citation({"source": "DSGVO", "article": "Art. 5"})
        assert from_dict["source"] == "DSGVO"

        # Empty / None
        assert _parse_citation("") == {}
        assert _parse_citation(None) == {}

        # Invalid JSON
        assert _parse_citation("not json") == {}
|