"""Tests for Decomposition Pass (Pass 0a + 0b). Covers: - ObligationCandidate / AtomicControlCandidate dataclasses - Normative signal detection (regex patterns) - Quality Gate (all 6 checks) - passes_quality_gate logic - _compute_extraction_confidence - _parse_json_array / _parse_json_object - _format_field / _format_citation - _normalize_severity - _template_fallback - _build_pass0a_prompt / _build_pass0b_prompt - DecompositionPass.run_pass0a (mocked LLM + DB) - DecompositionPass.run_pass0b (mocked LLM + DB) - DecompositionPass.decomposition_status (mocked DB) """ import json import pytest from unittest.mock import MagicMock, patch, AsyncMock from compliance.services.decomposition_pass import ( ObligationCandidate, AtomicControlCandidate, quality_gate, passes_quality_gate, _NORMATIVE_RE, _RATIONALE_RE, _TEST_RE, _REPORTING_RE, _parse_json_array, _parse_json_object, _ensure_list, _format_field, _format_citation, _compute_extraction_confidence, _normalize_severity, _template_fallback, _fallback_obligation, _build_pass0a_prompt, _build_pass0b_prompt, _build_pass0a_batch_prompt, _build_pass0b_batch_prompt, _PASS0A_SYSTEM_PROMPT, _PASS0B_SYSTEM_PROMPT, DecompositionPass, ) # --------------------------------------------------------------------------- # DATACLASS TESTS # --------------------------------------------------------------------------- class TestObligationCandidate: """Tests for ObligationCandidate dataclass.""" def test_defaults(self): oc = ObligationCandidate() assert oc.candidate_id == "" assert oc.normative_strength == "must" assert oc.is_test_obligation is False assert oc.release_state == "extracted" assert oc.quality_flags == {} def test_to_dict(self): oc = ObligationCandidate( candidate_id="OC-001-01", parent_control_uuid="uuid-1", obligation_text="Betreiber müssen MFA implementieren", action="implementieren", object_="MFA", ) d = oc.to_dict() assert d["candidate_id"] == "OC-001-01" assert d["object"] == "MFA" assert "object_" not in d # should be "object" in dict def test_full_creation(self): oc = ObligationCandidate( candidate_id="OC-MICA-0001-01", parent_control_uuid="uuid-abc", obligation_text="Betreiber müssen Kontinuität sicherstellen", action="sicherstellen", object_="Dienstleistungskontinuität", condition="bei Ausfall des Handelssystems", normative_strength="must", is_test_obligation=False, is_reporting_obligation=False, extraction_confidence=0.90, ) assert oc.condition == "bei Ausfall des Handelssystems" assert oc.extraction_confidence == 0.90 class TestAtomicControlCandidate: """Tests for AtomicControlCandidate dataclass.""" def test_defaults(self): ac = AtomicControlCandidate() assert ac.severity == "medium" assert ac.requirements == [] assert ac.test_procedure == [] def test_to_dict(self): ac = AtomicControlCandidate( candidate_id="AC-FIN-001", title="Service Continuity Mechanism", objective="Ensure continuity upon failure.", requirements=["Failover mechanism"], ) d = ac.to_dict() assert d["title"] == "Service Continuity Mechanism" assert len(d["requirements"]) == 1 # --------------------------------------------------------------------------- # NORMATIVE SIGNAL DETECTION TESTS # --------------------------------------------------------------------------- class TestNormativeSignals: """Tests for normative regex patterns.""" def test_muessen_detected(self): assert _NORMATIVE_RE.search("Betreiber müssen sicherstellen") def test_muss_detected(self): assert _NORMATIVE_RE.search("Das System muss implementiert sein") def test_hat_sicherzustellen(self): assert _NORMATIVE_RE.search("Der Verantwortliche hat sicherzustellen") def test_sind_verpflichtet(self): assert _NORMATIVE_RE.search("Anbieter sind verpflichtet zu melden") def test_ist_zu_dokumentieren(self): assert _NORMATIVE_RE.search("Der Vorfall ist zu dokumentieren") def test_shall(self): assert _NORMATIVE_RE.search("The operator shall implement MFA") def test_no_signal(self): assert not _NORMATIVE_RE.search("Die Sonne scheint heute") def test_rationale_detected(self): assert _RATIONALE_RE.search("da schwache Passwörter Risiken bergen") def test_test_signal_detected(self): assert _TEST_RE.search("regelmäßige Tests der Wirksamkeit") def test_reporting_signal_detected(self): assert _REPORTING_RE.search("Behörden sind zu unterrichten") # --------------------------------------------------------------------------- # QUALITY GATE TESTS # --------------------------------------------------------------------------- class TestQualityGate: """Tests for quality_gate function.""" def test_valid_normative_obligation(self): oc = ObligationCandidate( parent_control_uuid="uuid-1", obligation_text="Betreiber müssen Verschlüsselung implementieren", ) flags = quality_gate(oc) assert flags["has_normative_signal"] is True assert flags["not_evidence_only"] is True assert flags["min_length"] is True assert flags["has_parent_link"] is True def test_rationale_detected(self): oc = ObligationCandidate( parent_control_uuid="uuid-1", obligation_text="Schwache Passwörter können zu Risiken führen, weil sie leicht zu erraten sind", ) flags = quality_gate(oc) assert flags["not_rationale"] is False def test_evidence_only_rejected(self): oc = ObligationCandidate( parent_control_uuid="uuid-1", obligation_text="Screenshot der Konfiguration", ) flags = quality_gate(oc) assert flags["not_evidence_only"] is False def test_too_short_rejected(self): oc = ObligationCandidate( parent_control_uuid="uuid-1", obligation_text="MFA", ) flags = quality_gate(oc) assert flags["min_length"] is False def test_no_parent_link(self): oc = ObligationCandidate( parent_control_uuid="", obligation_text="Betreiber müssen MFA implementieren", ) flags = quality_gate(oc) assert flags["has_parent_link"] is False def test_multi_verb_detected(self): oc = ObligationCandidate( parent_control_uuid="uuid-1", obligation_text="Betreiber müssen implementieren und dokumentieren sowie regelmäßig testen", ) flags = quality_gate(oc) assert flags["single_action"] is False def test_single_verb_passes(self): oc = ObligationCandidate( parent_control_uuid="uuid-1", obligation_text="Betreiber müssen MFA für alle privilegierten Konten implementieren", ) flags = quality_gate(oc) assert flags["single_action"] is True def test_no_normative_signal(self): oc = ObligationCandidate( parent_control_uuid="uuid-1", obligation_text="Ein DR-Plan beschreibt die Wiederherstellungsprozeduren im Detail", ) flags = quality_gate(oc) assert flags["has_normative_signal"] is False class TestPassesQualityGate: """Tests for passes_quality_gate function.""" def test_all_critical_pass(self): flags = { "has_normative_signal": True, "single_action": True, "not_rationale": True, "not_evidence_only": True, "min_length": True, "has_parent_link": True, } assert passes_quality_gate(flags) is True def test_no_normative_signal_fails(self): flags = { "has_normative_signal": False, "single_action": True, "not_rationale": True, "not_evidence_only": True, "min_length": True, "has_parent_link": True, } assert passes_quality_gate(flags) is False def test_evidence_only_fails(self): flags = { "has_normative_signal": True, "single_action": True, "not_rationale": True, "not_evidence_only": False, "min_length": True, "has_parent_link": True, } assert passes_quality_gate(flags) is False def test_non_critical_dont_block(self): """single_action and not_rationale are NOT critical — should still pass.""" flags = { "has_normative_signal": True, "single_action": False, # Not critical "not_rationale": False, # Not critical "not_evidence_only": True, "min_length": True, "has_parent_link": True, } assert passes_quality_gate(flags) is True # --------------------------------------------------------------------------- # HELPER TESTS # --------------------------------------------------------------------------- class TestComputeExtractionConfidence: """Tests for _compute_extraction_confidence.""" def test_all_flags_pass(self): flags = { "has_normative_signal": True, "single_action": True, "not_rationale": True, "not_evidence_only": True, "min_length": True, "has_parent_link": True, } assert _compute_extraction_confidence(flags) == 1.0 def test_no_flags_pass(self): flags = { "has_normative_signal": False, "single_action": False, "not_rationale": False, "not_evidence_only": False, "min_length": False, "has_parent_link": False, } assert _compute_extraction_confidence(flags) == 0.0 def test_partial_flags(self): flags = { "has_normative_signal": True, # 0.30 "single_action": False, "not_rationale": True, # 0.20 "not_evidence_only": True, # 0.15 "min_length": True, # 0.10 "has_parent_link": True, # 0.05 } assert _compute_extraction_confidence(flags) == 0.80 class TestParseJsonArray: """Tests for _parse_json_array.""" def test_valid_array(self): result = _parse_json_array('[{"a": 1}, {"a": 2}]') assert len(result) == 2 assert result[0]["a"] == 1 def test_single_object_wrapped(self): result = _parse_json_array('{"a": 1}') assert len(result) == 1 def test_embedded_in_text(self): result = _parse_json_array('Here is the result:\n[{"a": 1}]\nDone.') assert len(result) == 1 def test_invalid_returns_empty(self): result = _parse_json_array("not json at all") assert result == [] def test_empty_array(self): result = _parse_json_array("[]") assert result == [] class TestParseJsonObject: """Tests for _parse_json_object.""" def test_valid_object(self): result = _parse_json_object('{"title": "MFA"}') assert result["title"] == "MFA" def test_embedded_in_text(self): result = _parse_json_object('```json\n{"title": "MFA"}\n```') assert result["title"] == "MFA" def test_invalid_returns_empty(self): result = _parse_json_object("not json") assert result == {} class TestEnsureList: """Tests for _ensure_list.""" def test_list_passthrough(self): assert _ensure_list(["a", "b"]) == ["a", "b"] def test_string_wrapped(self): assert _ensure_list("hello") == ["hello"] def test_empty_string(self): assert _ensure_list("") == [] def test_none(self): assert _ensure_list(None) == [] def test_int(self): assert _ensure_list(42) == [] class TestFormatField: """Tests for _format_field.""" def test_string_passthrough(self): assert _format_field("hello") == "hello" def test_json_list_string(self): result = _format_field('["Req 1", "Req 2"]') assert "- Req 1" in result assert "- Req 2" in result def test_list_input(self): result = _format_field(["A", "B"]) assert "- A" in result assert "- B" in result def test_empty(self): assert _format_field("") == "" assert _format_field(None) == "" class TestFormatCitation: """Tests for _format_citation.""" def test_json_dict(self): result = _format_citation('{"source": "MiCA", "article": "Art. 8"}') assert "MiCA" in result assert "Art. 8" in result def test_plain_string(self): assert _format_citation("MiCA Art. 8") == "MiCA Art. 8" def test_empty(self): assert _format_citation("") == "" assert _format_citation(None) == "" class TestNormalizeSeverity: """Tests for _normalize_severity.""" def test_valid_values(self): assert _normalize_severity("critical") == "critical" assert _normalize_severity("HIGH") == "high" assert _normalize_severity(" Medium ") == "medium" assert _normalize_severity("low") == "low" def test_invalid_defaults_to_medium(self): assert _normalize_severity("unknown") == "medium" assert _normalize_severity("") == "medium" assert _normalize_severity(None) == "medium" class TestTemplateFallback: """Tests for _template_fallback.""" def test_normal_obligation(self): ac = _template_fallback( obligation_text="Betreiber müssen MFA implementieren", action="implementieren", object_="MFA", parent_title="Authentication Controls", parent_severity="high", parent_category="authentication", is_test=False, is_reporting=False, ) assert "Implementieren" in ac.title assert ac.severity == "high" assert len(ac.requirements) == 1 def test_test_obligation(self): ac = _template_fallback( obligation_text="MFA muss regelmäßig getestet werden", action="testen", object_="MFA-Wirksamkeit", parent_title="MFA Control", parent_severity="medium", parent_category="auth", is_test=True, is_reporting=False, ) assert "Test:" in ac.title assert "Testprotokoll" in ac.evidence def test_reporting_obligation(self): ac = _template_fallback( obligation_text="Behörden sind über Vorfälle zu informieren", action="informieren", object_="zuständige Behörden", parent_title="Incident Reporting", parent_severity="high", parent_category="governance", is_test=False, is_reporting=True, ) assert "Meldepflicht:" in ac.title assert "Meldeprozess-Dokumentation" in ac.evidence # --------------------------------------------------------------------------- # PROMPT BUILDER TESTS # --------------------------------------------------------------------------- class TestPromptBuilders: """Tests for LLM prompt builders.""" def test_pass0a_prompt_contains_all_fields(self): prompt = _build_pass0a_prompt( title="MFA Control", objective="Implement MFA", requirements="- Require TOTP\n- Hardware key", test_procedure="- Test login", source_ref="DSGVO Art. 32", ) assert "MFA Control" in prompt assert "Implement MFA" in prompt assert "Require TOTP" in prompt assert "DSGVO Art. 32" in prompt assert "JSON-Array" in prompt def test_pass0b_prompt_contains_all_fields(self): prompt = _build_pass0b_prompt( obligation_text="MFA implementieren", action="implementieren", object_="MFA", parent_title="Auth Controls", parent_category="authentication", source_ref="DSGVO Art. 32", ) assert "MFA implementieren" in prompt assert "implementieren" in prompt assert "Auth Controls" in prompt assert "JSON" in prompt def test_system_prompts_exist(self): assert "REGELN" in _PASS0A_SYSTEM_PROMPT assert "atomares" in _PASS0B_SYSTEM_PROMPT # --------------------------------------------------------------------------- # DECOMPOSITION PASS INTEGRATION TESTS # --------------------------------------------------------------------------- class TestDecompositionPassRun0a: """Tests for DecompositionPass.run_pass0a.""" @pytest.mark.asyncio async def test_pass0a_extracts_obligations(self): mock_db = MagicMock() # Rich controls to decompose mock_rows = MagicMock() mock_rows.fetchall.return_value = [ ( "uuid-1", "CTRL-001", "Service Continuity", "Sicherstellen der Dienstleistungskontinuität", '["Mechanismen implementieren", "Systeme testen"]', '["Prüfung der Mechanismen"]', '{"source": "MiCA", "article": "Art. 8"}', "finance", ), ] mock_db.execute.return_value = mock_rows llm_response = json.dumps([ { "obligation_text": "Betreiber müssen Mechanismen zur Dienstleistungskontinuität implementieren", "action": "implementieren", "object": "Kontinuitätsmechanismen", "condition": "bei Ausfall des Handelssystems", "normative_strength": "must", "is_test_obligation": False, "is_reporting_obligation": False, }, { "obligation_text": "Kontinuitätsmechanismen müssen regelmäßig getestet werden", "action": "testen", "object": "Kontinuitätsmechanismen", "condition": None, "normative_strength": "must", "is_test_obligation": True, "is_reporting_obligation": False, }, ]) with patch("compliance.services.obligation_extractor._llm_ollama", new_callable=AsyncMock) as mock_llm: mock_llm.return_value = llm_response decomp = DecompositionPass(db=mock_db) stats = await decomp.run_pass0a(limit=10) assert stats["controls_processed"] == 1 assert stats["obligations_extracted"] == 2 assert stats["obligations_validated"] == 2 assert stats["errors"] == 0 # Verify DB writes: 1 SELECT + 2 INSERTs + 1 COMMIT assert mock_db.execute.call_count >= 3 mock_db.commit.assert_called_once() @pytest.mark.asyncio async def test_pass0a_fallback_on_empty_llm(self): mock_db = MagicMock() mock_rows = MagicMock() mock_rows.fetchall.return_value = [ ( "uuid-1", "CTRL-001", "MFA Control", "Betreiber müssen MFA implementieren", "", "", "", "auth", ), ] mock_db.execute.return_value = mock_rows with patch("compliance.services.obligation_extractor._llm_ollama", new_callable=AsyncMock) as mock_llm: mock_llm.return_value = "I cannot help with that." # Invalid JSON decomp = DecompositionPass(db=mock_db) stats = await decomp.run_pass0a(limit=10) assert stats["controls_processed"] == 1 # Fallback should create 1 obligation from the objective assert stats["obligations_extracted"] == 1 @pytest.mark.asyncio async def test_pass0a_skips_empty_controls(self): mock_db = MagicMock() mock_rows = MagicMock() mock_rows.fetchall.return_value = [ ("uuid-1", "CTRL-001", "", "", "", "", "", ""), ] mock_db.execute.return_value = mock_rows # No LLM call needed — empty controls are skipped before LLM decomp = DecompositionPass(db=mock_db) stats = await decomp.run_pass0a(limit=10) assert stats["controls_skipped_empty"] == 1 assert stats["controls_processed"] == 0 @pytest.mark.asyncio async def test_pass0a_rejects_evidence_only(self): mock_db = MagicMock() mock_rows = MagicMock() mock_rows.fetchall.return_value = [ ( "uuid-1", "CTRL-001", "Evidence List", "Betreiber müssen Nachweise erbringen", "", "", "", "governance", ), ] mock_db.execute.return_value = mock_rows llm_response = json.dumps([ { "obligation_text": "Dokumentation der Konfiguration", "action": "dokumentieren", "object": "Konfiguration", "condition": None, "normative_strength": "must", "is_test_obligation": False, "is_reporting_obligation": False, }, ]) with patch("compliance.services.obligation_extractor._llm_ollama", new_callable=AsyncMock) as mock_llm: mock_llm.return_value = llm_response decomp = DecompositionPass(db=mock_db) stats = await decomp.run_pass0a(limit=10) assert stats["obligations_extracted"] == 1 assert stats["obligations_rejected"] == 1 class TestDecompositionPassRun0b: """Tests for DecompositionPass.run_pass0b.""" @pytest.mark.asyncio async def test_pass0b_creates_atomic_controls(self): mock_db = MagicMock() # Validated obligation candidates mock_rows = MagicMock() mock_rows.fetchall.return_value = [ ( "oc-uuid-1", "OC-CTRL-001-01", "parent-uuid-1", "Betreiber müssen Kontinuität sicherstellen", "sicherstellen", "Dienstleistungskontinuität", False, False, # is_test, is_reporting "Service Continuity", "finance", '{"source": "MiCA", "article": "Art. 8"}', "high", "FIN-001", ), ] # Mock _next_atomic_seq result mock_seq = MagicMock() mock_seq.fetchone.return_value = (0,) # Call sequence: 1=SELECT, 2=_next_atomic_seq, 3=INSERT control, 4=UPDATE oc call_count = [0] def side_effect(*args, **kwargs): call_count[0] += 1 if call_count[0] == 1: return mock_rows # SELECT candidates if call_count[0] == 2: return mock_seq # _next_atomic_seq return MagicMock() # INSERT/UPDATE mock_db.execute.side_effect = side_effect llm_response = json.dumps({ "title": "Dienstleistungskontinuität bei Systemausfall", "objective": "Sicherstellen, dass Dienstleistungen fortgeführt werden.", "requirements": ["Failover-Mechanismus implementieren"], "test_procedure": ["Failover-Test durchführen"], "evidence": ["Systemarchitektur", "DR-Plan"], "severity": "high", "category": "operations", }) with patch("compliance.services.obligation_extractor._llm_ollama", new_callable=AsyncMock) as mock_llm: mock_llm.return_value = llm_response decomp = DecompositionPass(db=mock_db) stats = await decomp.run_pass0b(limit=10) assert stats["candidates_processed"] == 1 assert stats["controls_created"] == 1 assert stats["llm_failures"] == 0 @pytest.mark.asyncio async def test_pass0b_template_fallback(self): mock_db = MagicMock() mock_rows = MagicMock() mock_rows.fetchall.return_value = [ ( "oc-uuid-1", "OC-CTRL-001-01", "parent-uuid-1", "Betreiber müssen MFA implementieren", "implementieren", "MFA", False, False, "Auth Controls", "authentication", "", "high", "AUTH-001", ), ] mock_seq = MagicMock() mock_seq.fetchone.return_value = (0,) call_count = [0] def side_effect(*args, **kwargs): call_count[0] += 1 if call_count[0] == 1: return mock_rows if call_count[0] == 2: return mock_seq return MagicMock() mock_db.execute.side_effect = side_effect with patch("compliance.services.obligation_extractor._llm_ollama", new_callable=AsyncMock) as mock_llm: mock_llm.return_value = "Sorry, invalid response" # LLM fails decomp = DecompositionPass(db=mock_db) stats = await decomp.run_pass0b(limit=10) assert stats["controls_created"] == 1 assert stats["llm_failures"] == 1 class TestDecompositionStatus: """Tests for DecompositionPass.decomposition_status.""" def test_returns_status(self): mock_db = MagicMock() mock_result = MagicMock() mock_result.fetchone.return_value = (5000, 1000, 3000, 2500, 200, 2000, 1800) mock_db.execute.return_value = mock_result decomp = DecompositionPass(db=mock_db) status = decomp.decomposition_status() assert status["rich_controls"] == 5000 assert status["decomposed_controls"] == 1000 assert status["total_candidates"] == 3000 assert status["validated"] == 2500 assert status["rejected"] == 200 assert status["composed"] == 2000 assert status["atomic_controls"] == 1800 assert status["decomposition_pct"] == 20.0 assert status["composition_pct"] == 80.0 def test_handles_zero_division(self): mock_db = MagicMock() mock_result = MagicMock() mock_result.fetchone.return_value = (0, 0, 0, 0, 0, 0, 0) mock_db.execute.return_value = mock_result decomp = DecompositionPass(db=mock_db) status = decomp.decomposition_status() assert status["decomposition_pct"] == 0.0 assert status["composition_pct"] == 0.0 # --------------------------------------------------------------------------- # MIGRATION 061 SCHEMA TESTS # --------------------------------------------------------------------------- class TestMigration061: """Tests for migration 061 SQL file.""" def test_migration_file_exists(self): from pathlib import Path migration = Path(__file__).parent.parent / "migrations" / "061_obligation_candidates.sql" assert migration.exists(), "Migration 061 file missing" def test_migration_contains_required_tables(self): from pathlib import Path migration = Path(__file__).parent.parent / "migrations" / "061_obligation_candidates.sql" content = migration.read_text() assert "obligation_candidates" in content assert "parent_control_uuid" in content assert "decomposition_method" in content assert "candidate_id" in content assert "quality_flags" in content # --------------------------------------------------------------------------- # BATCH PROMPT TESTS # --------------------------------------------------------------------------- class TestBatchPromptBuilders: """Tests for batch prompt builders.""" def test_pass0a_batch_prompt_contains_all_controls(self): controls = [ { "control_id": "AUTH-001", "title": "MFA Control", "objective": "Implement MFA", "requirements": "- TOTP required", "test_procedure": "- Test login", "source_ref": "DSGVO Art. 32", }, { "control_id": "AUTH-002", "title": "Password Policy", "objective": "Enforce strong passwords", "requirements": "- Min 12 chars", "test_procedure": "- Test weak password", "source_ref": "BSI IT-Grundschutz", }, ] prompt = _build_pass0a_batch_prompt(controls) assert "AUTH-001" in prompt assert "AUTH-002" in prompt assert "MFA Control" in prompt assert "Password Policy" in prompt assert "CONTROL 1" in prompt assert "CONTROL 2" in prompt assert "2 Controls" in prompt def test_pass0a_batch_prompt_single_control(self): controls = [ { "control_id": "AUTH-001", "title": "MFA", "objective": "MFA", "requirements": "", "test_procedure": "", "source_ref": "", }, ] prompt = _build_pass0a_batch_prompt(controls) assert "AUTH-001" in prompt assert "1 Controls" in prompt def test_pass0b_batch_prompt_contains_all_obligations(self): obligations = [ { "candidate_id": "OC-AUTH-001-01", "obligation_text": "MFA implementieren", "action": "implementieren", "object": "MFA", "parent_title": "Auth Controls", "parent_category": "authentication", "source_ref": "DSGVO Art. 32", }, { "candidate_id": "OC-AUTH-001-02", "obligation_text": "MFA testen", "action": "testen", "object": "MFA", "parent_title": "Auth Controls", "parent_category": "authentication", "source_ref": "DSGVO Art. 32", }, ] prompt = _build_pass0b_batch_prompt(obligations) assert "OC-AUTH-001-01" in prompt assert "OC-AUTH-001-02" in prompt assert "PFLICHT 1" in prompt assert "PFLICHT 2" in prompt assert "2 Pflichten" in prompt class TestFallbackObligation: """Tests for _fallback_obligation helper.""" def test_uses_objective_when_available(self): ctrl = {"title": "MFA", "objective": "Implement MFA for all users"} result = _fallback_obligation(ctrl) assert result["obligation_text"] == "Implement MFA for all users" assert result["action"] == "sicherstellen" def test_uses_title_when_no_objective(self): ctrl = {"title": "MFA Control", "objective": ""} result = _fallback_obligation(ctrl) assert result["obligation_text"] == "MFA Control" # --------------------------------------------------------------------------- # ANTHROPIC BATCHING INTEGRATION TESTS # --------------------------------------------------------------------------- class TestDecompositionPassAnthropicBatch: """Tests for batched Anthropic API calls in Pass 0a/0b.""" @pytest.mark.asyncio async def test_pass0a_anthropic_batched(self): """Test Pass 0a with Anthropic API and batch_size=2.""" mock_db = MagicMock() mock_rows = MagicMock() mock_rows.fetchall.return_value = [ ("uuid-1", "CTRL-001", "MFA Control", "Implement MFA", "", "", "", "security"), ("uuid-2", "CTRL-002", "Encryption", "Encrypt data at rest", "", "", "", "security"), ] mock_db.execute.return_value = mock_rows # Anthropic returns JSON object keyed by control_id batched_response = json.dumps({ "CTRL-001": [ {"obligation_text": "MFA muss implementiert werden", "action": "implementieren", "object": "MFA", "normative_strength": "must", "is_test_obligation": False, "is_reporting_obligation": False}, ], "CTRL-002": [ {"obligation_text": "Daten müssen verschlüsselt werden", "action": "verschlüsseln", "object": "Daten", "normative_strength": "must", "is_test_obligation": False, "is_reporting_obligation": False}, ], }) with patch( "compliance.services.decomposition_pass._llm_anthropic", new_callable=AsyncMock, ) as mock_llm: mock_llm.return_value = batched_response decomp = DecompositionPass(db=mock_db) stats = await decomp.run_pass0a( limit=10, batch_size=2, use_anthropic=True, ) assert stats["controls_processed"] == 2 assert stats["obligations_extracted"] == 2 assert stats["llm_calls"] == 1 # Only 1 API call for 2 controls assert stats["provider"] == "anthropic" @pytest.mark.asyncio async def test_pass0a_anthropic_single(self): """Test Pass 0a with Anthropic API, batch_size=1 (no batching).""" mock_db = MagicMock() mock_rows = MagicMock() mock_rows.fetchall.return_value = [ ("uuid-1", "CTRL-001", "MFA Control", "Implement MFA", "", "", "", "security"), ] mock_db.execute.return_value = mock_rows response = json.dumps([ {"obligation_text": "MFA muss implementiert werden", "action": "implementieren", "object": "MFA", "normative_strength": "must", "is_test_obligation": False, "is_reporting_obligation": False}, ]) with patch( "compliance.services.decomposition_pass._llm_anthropic", new_callable=AsyncMock, ) as mock_llm: mock_llm.return_value = response decomp = DecompositionPass(db=mock_db) stats = await decomp.run_pass0a( limit=10, batch_size=1, use_anthropic=True, ) assert stats["controls_processed"] == 1 assert stats["llm_calls"] == 1 assert stats["provider"] == "anthropic" @pytest.mark.asyncio async def test_pass0b_anthropic_batched(self): """Test Pass 0b with Anthropic API and batch_size=2.""" mock_db = MagicMock() mock_rows = MagicMock() mock_rows.fetchall.return_value = [ ("oc-uuid-1", "OC-CTRL-001-01", "parent-uuid-1", "MFA implementieren", "implementieren", "MFA", False, False, "Auth", "security", '{"source": "DSGVO", "article": "Art. 32"}', "high", "CTRL-001"), ("oc-uuid-2", "OC-CTRL-001-02", "parent-uuid-1", "MFA testen", "testen", "MFA", True, False, "Auth", "security", '{"source": "DSGVO", "article": "Art. 32"}', "high", "CTRL-001"), ] mock_seq = MagicMock() mock_seq.fetchone.return_value = (0,) call_count = [0] def side_effect(*args, **kwargs): call_count[0] += 1 if call_count[0] == 1: return mock_rows # SELECT candidates # _next_atomic_seq calls (every 3rd after first: 2, 5, 8, ...) if call_count[0] in (2, 5): return mock_seq return MagicMock() # INSERT/UPDATE mock_db.execute.side_effect = side_effect batched_response = json.dumps({ "OC-CTRL-001-01": { "title": "MFA implementieren", "objective": "MFA fuer alle Konten.", "requirements": ["TOTP einrichten"], "test_procedure": ["Login testen"], "evidence": ["Konfigurationsnachweis"], "severity": "high", "category": "security", }, "OC-CTRL-001-02": { "title": "MFA-Wirksamkeit testen", "objective": "Regelmaessige MFA-Tests.", "requirements": ["Testplan erstellen"], "test_procedure": ["Testdurchfuehrung"], "evidence": ["Testprotokoll"], "severity": "high", "category": "security", }, }) with patch( "compliance.services.decomposition_pass._llm_anthropic", new_callable=AsyncMock, ) as mock_llm: mock_llm.return_value = batched_response decomp = DecompositionPass(db=mock_db) stats = await decomp.run_pass0b( limit=10, batch_size=2, use_anthropic=True, ) assert stats["controls_created"] == 2 assert stats["llm_calls"] == 1 assert stats["provider"] == "anthropic" # --------------------------------------------------------------------------- # SOURCE FILTER TESTS # --------------------------------------------------------------------------- class TestSourceFilter: """Tests for source_filter parameter in Pass 0a.""" @pytest.mark.asyncio async def test_pass0a_source_filter_builds_ilike_query(self): """Verify source_filter adds ILIKE clauses to query.""" mock_db = MagicMock() mock_rows = MagicMock() mock_rows.fetchall.return_value = [ ("uuid-1", "CTRL-001", "Machine Safety", "Ensure safety", "", "", '{"source": "Maschinenverordnung (EU) 2023/1230"}', "security"), ] mock_db.execute.return_value = mock_rows response = json.dumps([ {"obligation_text": "Sicherheit gewaehrleisten", "action": "gewaehrleisten", "object": "Sicherheit", "normative_strength": "must", "is_test_obligation": False, "is_reporting_obligation": False}, ]) with patch( "compliance.services.decomposition_pass._llm_anthropic", new_callable=AsyncMock, ) as mock_llm: mock_llm.return_value = response decomp = DecompositionPass(db=mock_db) stats = await decomp.run_pass0a( limit=10, batch_size=1, use_anthropic=True, source_filter="Maschinenverordnung,Cyber Resilience Act", ) assert stats["controls_processed"] == 1 # Verify the SQL query contained ILIKE clauses call_args = mock_db.execute.call_args_list[0] query_str = str(call_args[0][0]) assert "ILIKE" in query_str @pytest.mark.asyncio async def test_pass0a_source_filter_none_no_clause(self): """Verify no ILIKE clause when source_filter is None.""" mock_db = MagicMock() mock_rows = MagicMock() mock_rows.fetchall.return_value = [] mock_db.execute.return_value = mock_rows decomp = DecompositionPass(db=mock_db) stats = await decomp.run_pass0a( limit=10, use_anthropic=True, source_filter=None, ) call_args = mock_db.execute.call_args_list[0] query_str = str(call_args[0][0]) assert "ILIKE" not in query_str @pytest.mark.asyncio async def test_pass0a_combined_category_and_source_filter(self): """Verify both category_filter and source_filter can be used together.""" mock_db = MagicMock() mock_rows = MagicMock() mock_rows.fetchall.return_value = [] mock_db.execute.return_value = mock_rows decomp = DecompositionPass(db=mock_db) await decomp.run_pass0a( limit=10, use_anthropic=True, category_filter="security,operations", source_filter="Maschinenverordnung", ) call_args = mock_db.execute.call_args_list[0] query_str = str(call_args[0][0]) assert "IN :cats" in query_str assert "ILIKE" in query_str