Files
breakpilot-compliance/backend-compliance/tests/test_decomposition_pass.py
Benjamin Admin d22c47c9eb
Some checks failed
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Failing after 35s
CI/CD / test-python-backend-compliance (push) Successful in 34s
CI/CD / test-python-document-crawler (push) Successful in 22s
CI/CD / test-python-dsms-gateway (push) Successful in 19s
CI/CD / validate-canonical-controls (push) Successful in 11s
CI/CD / Deploy (push) Has been skipped
feat(pipeline): Anthropic Batch API, source/regulation filter, cost optimization
- Add Anthropic API support to decomposition Pass 0a/0b (prompt caching, content batching)
- Add Anthropic Batch API (50% cost reduction, async 24h processing)
- Add source_filter (ILIKE on source_citation) for regulation-based filtering
- Add category_filter to Pass 0a for selective decomposition
- Add regulation_filter to control_generator for RAG scan phase filtering
  (prefix match on regulation_code — enables CE + Code Review focus)
- New API endpoints: batch-submit-0a, batch-submit-0b, batch-status, batch-process
- 83 new tests (all passing)

Cost reduction: $2,525 → ~$600-700 with all optimizations combined.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-17 13:22:01 +01:00

1159 lines
39 KiB
Python

"""Tests for Decomposition Pass (Pass 0a + 0b).
Covers:
- ObligationCandidate / AtomicControlCandidate dataclasses
- Normative signal detection (regex patterns)
- Quality Gate (all 6 checks)
- passes_quality_gate logic
- _compute_extraction_confidence
- _parse_json_array / _parse_json_object
- _ensure_list
- _format_field / _format_citation
- _normalize_severity
- _template_fallback / _fallback_obligation
- _build_pass0a_prompt / _build_pass0b_prompt
- _build_pass0a_batch_prompt / _build_pass0b_batch_prompt
- DecompositionPass.run_pass0a (mocked LLM + DB)
- DecompositionPass.run_pass0b (mocked LLM + DB)
- DecompositionPass.decomposition_status (mocked DB)
- Batched Anthropic calls in Pass 0a/0b (mocked LLM + DB)
- source_filter / category_filter query building in Pass 0a
- Migration 061 SQL file presence and contents
"""
import json
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from compliance.services.decomposition_pass import (
ObligationCandidate,
AtomicControlCandidate,
quality_gate,
passes_quality_gate,
_NORMATIVE_RE,
_RATIONALE_RE,
_TEST_RE,
_REPORTING_RE,
_parse_json_array,
_parse_json_object,
_ensure_list,
_format_field,
_format_citation,
_compute_extraction_confidence,
_normalize_severity,
_template_fallback,
_fallback_obligation,
_build_pass0a_prompt,
_build_pass0b_prompt,
_build_pass0a_batch_prompt,
_build_pass0b_batch_prompt,
_PASS0A_SYSTEM_PROMPT,
_PASS0B_SYSTEM_PROMPT,
DecompositionPass,
)
# ---------------------------------------------------------------------------
# DATACLASS TESTS
# ---------------------------------------------------------------------------
class TestObligationCandidate:
    """Construction and serialisation checks for ObligationCandidate."""

    def test_defaults(self):
        """A candidate built without arguments carries the expected defaults."""
        candidate = ObligationCandidate()

        assert candidate.candidate_id == ""
        assert candidate.normative_strength == "must"
        assert candidate.is_test_obligation is False
        assert candidate.release_state == "extracted"
        assert candidate.quality_flags == {}

    def test_to_dict(self):
        """to_dict() publishes the ``object_`` attribute under the key ``object``."""
        candidate = ObligationCandidate(
            candidate_id="OC-001-01",
            parent_control_uuid="uuid-1",
            obligation_text="Betreiber müssen MFA implementieren",
            action="implementieren",
            object_="MFA",
        )

        payload = candidate.to_dict()

        assert payload["candidate_id"] == "OC-001-01"
        assert payload["object"] == "MFA"
        # The trailing-underscore attribute name must not leak into the dict.
        assert "object_" not in payload

    def test_full_creation(self):
        """Explicitly supplied optional fields are stored unchanged."""
        kwargs = {
            "candidate_id": "OC-MICA-0001-01",
            "parent_control_uuid": "uuid-abc",
            "obligation_text": "Betreiber müssen Kontinuität sicherstellen",
            "action": "sicherstellen",
            "object_": "Dienstleistungskontinuität",
            "condition": "bei Ausfall des Handelssystems",
            "normative_strength": "must",
            "is_test_obligation": False,
            "is_reporting_obligation": False,
            "extraction_confidence": 0.90,
        }

        candidate = ObligationCandidate(**kwargs)

        assert candidate.condition == "bei Ausfall des Handelssystems"
        assert candidate.extraction_confidence == 0.90
class TestAtomicControlCandidate:
    """Construction and serialisation checks for AtomicControlCandidate."""

    def test_defaults(self):
        """A control built without arguments has medium severity and empty lists."""
        control = AtomicControlCandidate()

        assert control.severity == "medium"
        assert control.requirements == []
        assert control.test_procedure == []

    def test_to_dict(self):
        """to_dict() carries the title and the requirements list through."""
        control = AtomicControlCandidate(
            candidate_id="AC-FIN-001",
            title="Service Continuity Mechanism",
            objective="Ensure continuity upon failure.",
            requirements=["Failover mechanism"],
        )

        payload = control.to_dict()

        assert payload["title"] == "Service Continuity Mechanism"
        assert len(payload["requirements"]) == 1
# ---------------------------------------------------------------------------
# NORMATIVE SIGNAL DETECTION TESTS
# ---------------------------------------------------------------------------
class TestNormativeSignals:
    """Checks that the signal regexes fire on representative phrases."""

    def test_muessen_detected(self):
        """Plural 'müssen' counts as a normative signal."""
        hit = _NORMATIVE_RE.search("Betreiber müssen sicherstellen")
        assert hit

    def test_muss_detected(self):
        """Singular 'muss' counts as a normative signal."""
        hit = _NORMATIVE_RE.search("Das System muss implementiert sein")
        assert hit

    def test_hat_sicherzustellen(self):
        """The 'hat ... zu' construction counts as a normative signal."""
        hit = _NORMATIVE_RE.search("Der Verantwortliche hat sicherzustellen")
        assert hit

    def test_sind_verpflichtet(self):
        """'sind verpflichtet' counts as a normative signal."""
        hit = _NORMATIVE_RE.search("Anbieter sind verpflichtet zu melden")
        assert hit

    def test_ist_zu_dokumentieren(self):
        """The 'ist zu ...' construction counts as a normative signal."""
        hit = _NORMATIVE_RE.search("Der Vorfall ist zu dokumentieren")
        assert hit

    def test_shall(self):
        """English 'shall' counts as a normative signal."""
        hit = _NORMATIVE_RE.search("The operator shall implement MFA")
        assert hit

    def test_no_signal(self):
        """A purely descriptive sentence produces no normative match."""
        hit = _NORMATIVE_RE.search("Die Sonne scheint heute")
        assert not hit

    def test_rationale_detected(self):
        """A justification clause is picked up by the rationale pattern."""
        hit = _RATIONALE_RE.search("da schwache Passwörter Risiken bergen")
        assert hit

    def test_test_signal_detected(self):
        """A mention of regular tests is picked up by the test pattern."""
        hit = _TEST_RE.search("regelmäßige Tests der Wirksamkeit")
        assert hit

    def test_reporting_signal_detected(self):
        """A duty to inform authorities is picked up by the reporting pattern."""
        hit = _REPORTING_RE.search("Behörden sind zu unterrichten")
        assert hit
# ---------------------------------------------------------------------------
# QUALITY GATE TESTS
# ---------------------------------------------------------------------------
class TestQualityGate:
    """Checks the individual flags returned by quality_gate()."""

    @staticmethod
    def _flags_for(text, parent="uuid-1"):
        """Run quality_gate() on a candidate built from *text* and *parent*."""
        candidate = ObligationCandidate(
            parent_control_uuid=parent,
            obligation_text=text,
        )
        return quality_gate(candidate)

    def test_valid_normative_obligation(self):
        """A well-formed obligation satisfies signal, evidence, length and parent checks."""
        result = self._flags_for("Betreiber müssen Verschlüsselung implementieren")

        assert result["has_normative_signal"] is True
        assert result["not_evidence_only"] is True
        assert result["min_length"] is True
        assert result["has_parent_link"] is True

    def test_rationale_detected(self):
        """A sentence with a 'weil' justification is flagged as rationale."""
        result = self._flags_for(
            "Schwache Passwörter können zu Risiken führen, weil sie leicht zu erraten sind"
        )

        assert result["not_rationale"] is False

    def test_evidence_only_rejected(self):
        """A bare evidence artefact is flagged as evidence-only."""
        result = self._flags_for("Screenshot der Konfiguration")

        assert result["not_evidence_only"] is False

    def test_too_short_rejected(self):
        """A three-character text fails the minimum-length check."""
        result = self._flags_for("MFA")

        assert result["min_length"] is False

    def test_no_parent_link(self):
        """An empty parent UUID fails the parent-link check."""
        result = self._flags_for("Betreiber müssen MFA implementieren", parent="")

        assert result["has_parent_link"] is False

    def test_multi_verb_detected(self):
        """Several coordinated verbs fail the single-action check."""
        result = self._flags_for(
            "Betreiber müssen implementieren und dokumentieren sowie regelmäßig testen"
        )

        assert result["single_action"] is False

    def test_single_verb_passes(self):
        """One verb passes the single-action check."""
        result = self._flags_for(
            "Betreiber müssen MFA für alle privilegierten Konten implementieren"
        )

        assert result["single_action"] is True

    def test_no_normative_signal(self):
        """A descriptive sentence has no normative signal."""
        result = self._flags_for(
            "Ein DR-Plan beschreibt die Wiederherstellungsprozeduren im Detail"
        )

        assert result["has_normative_signal"] is False
class TestPassesQualityGate:
    """Checks which flags passes_quality_gate() treats as blocking."""

    @staticmethod
    def _flags(**overrides):
        """Return an all-True flag dict with the given entries overridden."""
        baseline = {
            "has_normative_signal": True,
            "single_action": True,
            "not_rationale": True,
            "not_evidence_only": True,
            "min_length": True,
            "has_parent_link": True,
        }
        return {**baseline, **overrides}

    def test_all_critical_pass(self):
        """Every flag set: the gate passes."""
        verdict = passes_quality_gate(self._flags())
        assert verdict is True

    def test_no_normative_signal_fails(self):
        """A missing normative signal blocks the gate."""
        verdict = passes_quality_gate(self._flags(has_normative_signal=False))
        assert verdict is False

    def test_evidence_only_fails(self):
        """An evidence-only candidate blocks the gate."""
        verdict = passes_quality_gate(self._flags(not_evidence_only=False))
        assert verdict is False

    def test_non_critical_dont_block(self):
        """single_action and not_rationale are NOT critical — should still pass."""
        verdict = passes_quality_gate(
            self._flags(single_action=False, not_rationale=False)
        )
        assert verdict is True
# ---------------------------------------------------------------------------
# HELPER TESTS
# ---------------------------------------------------------------------------
class TestComputeExtractionConfidence:
    """Tests for _compute_extraction_confidence.

    The expected scores are sums of fractional weights, so results are
    compared with ``pytest.approx`` rather than exact ``==``: a harmless
    change in summation order inside the implementation must not turn a
    last-bit floating-point difference into a test failure.
    """

    _FLAG_NAMES = (
        "has_normative_signal",
        "single_action",
        "not_rationale",
        "not_evidence_only",
        "min_length",
        "has_parent_link",
    )

    def test_all_flags_pass(self):
        """All six flags set yields the maximum score of 1.0."""
        flags = dict.fromkeys(self._FLAG_NAMES, True)
        assert _compute_extraction_confidence(flags) == pytest.approx(1.0)

    def test_no_flags_pass(self):
        """No flag set yields a score of 0.0."""
        flags = dict.fromkeys(self._FLAG_NAMES, False)
        assert _compute_extraction_confidence(flags) == pytest.approx(0.0)

    def test_partial_flags(self):
        """Everything except single_action set yields 0.80."""
        flags = {
            "has_normative_signal": True,  # 0.30
            "single_action": False,
            "not_rationale": True,  # 0.20
            "not_evidence_only": True,  # 0.15
            "min_length": True,  # 0.10
            "has_parent_link": True,  # 0.05
        }
        assert _compute_extraction_confidence(flags) == pytest.approx(0.80)
class TestParseJsonArray:
    """Checks _parse_json_array() on clean, wrapped and broken input."""

    def test_valid_array(self):
        """A plain JSON array is returned element by element."""
        items = _parse_json_array('[{"a": 1}, {"a": 2}]')

        assert len(items) == 2
        assert items[0]["a"] == 1

    def test_single_object_wrapped(self):
        """A lone JSON object comes back as a one-element list."""
        items = _parse_json_array('{"a": 1}')

        assert len(items) == 1

    def test_embedded_in_text(self):
        """An array surrounded by prose is still extracted."""
        raw = 'Here is the result:\n[{"a": 1}]\nDone.'

        items = _parse_json_array(raw)

        assert len(items) == 1

    def test_invalid_returns_empty(self):
        """Text without any JSON yields an empty list."""
        items = _parse_json_array("not json at all")

        assert items == []

    def test_empty_array(self):
        """An empty JSON array yields an empty list."""
        items = _parse_json_array("[]")

        assert items == []
class TestParseJsonObject:
    """Checks _parse_json_object() on clean, fenced and broken input."""

    def test_valid_object(self):
        """A plain JSON object is parsed into a dict."""
        parsed = _parse_json_object('{"title": "MFA"}')

        assert parsed["title"] == "MFA"

    def test_embedded_in_text(self):
        """An object inside a Markdown code fence is still extracted."""
        fenced = '```json\n{"title": "MFA"}\n```'

        parsed = _parse_json_object(fenced)

        assert parsed["title"] == "MFA"

    def test_invalid_returns_empty(self):
        """Text without any JSON yields an empty dict."""
        parsed = _parse_json_object("not json")

        assert parsed == {}
class TestEnsureList:
    """Checks how _ensure_list() coerces different input types."""

    def test_list_passthrough(self):
        """A list is returned with the same contents."""
        coerced = _ensure_list(["a", "b"])
        assert coerced == ["a", "b"]

    def test_string_wrapped(self):
        """A non-empty string becomes a one-element list."""
        coerced = _ensure_list("hello")
        assert coerced == ["hello"]

    def test_empty_string(self):
        """An empty string becomes an empty list."""
        coerced = _ensure_list("")
        assert coerced == []

    def test_none(self):
        """None becomes an empty list."""
        coerced = _ensure_list(None)
        assert coerced == []

    def test_int(self):
        """An integer becomes an empty list."""
        coerced = _ensure_list(42)
        assert coerced == []
class TestFormatField:
    """Checks _format_field() rendering for strings, lists and empty values."""

    def test_string_passthrough(self):
        """A plain string is returned unchanged."""
        rendered = _format_field("hello")
        assert rendered == "hello"

    def test_json_list_string(self):
        """A JSON-encoded list is rendered with one '- ' bullet per item."""
        rendered = _format_field('["Req 1", "Req 2"]')

        assert "- Req 1" in rendered
        assert "- Req 2" in rendered

    def test_list_input(self):
        """A Python list is rendered with one '- ' bullet per item."""
        rendered = _format_field(["A", "B"])

        assert "- A" in rendered
        assert "- B" in rendered

    def test_empty(self):
        """Both the empty string and None render as the empty string."""
        rendered_blank = _format_field("")
        assert rendered_blank == ""
        rendered_none = _format_field(None)
        assert rendered_none == ""
class TestFormatCitation:
    """Checks _format_citation() rendering for JSON, plain and empty values."""

    def test_json_dict(self):
        """Source and article from a JSON citation both appear in the output."""
        rendered = _format_citation('{"source": "MiCA", "article": "Art. 8"}')

        assert "MiCA" in rendered
        assert "Art. 8" in rendered

    def test_plain_string(self):
        """A non-JSON citation is returned unchanged."""
        rendered = _format_citation("MiCA Art. 8")
        assert rendered == "MiCA Art. 8"

    def test_empty(self):
        """Both the empty string and None render as the empty string."""
        rendered_blank = _format_citation("")
        assert rendered_blank == ""
        rendered_none = _format_citation(None)
        assert rendered_none == ""
class TestNormalizeSeverity:
    """Checks _normalize_severity() canonicalisation and fallback."""

    def test_valid_values(self):
        """Known levels are lower-cased and stripped of surrounding spaces."""
        normalised_critical = _normalize_severity("critical")
        assert normalised_critical == "critical"
        normalised_high = _normalize_severity("HIGH")
        assert normalised_high == "high"
        normalised_medium = _normalize_severity(" Medium ")
        assert normalised_medium == "medium"
        normalised_low = _normalize_severity("low")
        assert normalised_low == "low"

    def test_invalid_defaults_to_medium(self):
        """Unknown, empty and None inputs all fall back to 'medium'."""
        from_unknown = _normalize_severity("unknown")
        assert from_unknown == "medium"
        from_blank = _normalize_severity("")
        assert from_blank == "medium"
        from_none = _normalize_severity(None)
        assert from_none == "medium"
class TestTemplateFallback:
    """Checks the controls produced by _template_fallback()."""

    def test_normal_obligation(self):
        """A regular obligation gets a capitalised action title and the parent severity."""
        control = _template_fallback(
            obligation_text="Betreiber müssen MFA implementieren",
            action="implementieren",
            object_="MFA",
            parent_title="Authentication Controls",
            parent_severity="high",
            parent_category="authentication",
            is_test=False,
            is_reporting=False,
        )

        assert "Implementieren" in control.title
        assert control.severity == "high"
        assert len(control.requirements) == 1

    def test_test_obligation(self):
        """A test obligation gets a 'Test:' title and a test protocol as evidence."""
        control = _template_fallback(
            obligation_text="MFA muss regelmäßig getestet werden",
            action="testen",
            object_="MFA-Wirksamkeit",
            parent_title="MFA Control",
            parent_severity="medium",
            parent_category="auth",
            is_test=True,
            is_reporting=False,
        )

        assert "Test:" in control.title
        assert "Testprotokoll" in control.evidence

    def test_reporting_obligation(self):
        """A reporting obligation gets a 'Meldepflicht:' title and process evidence."""
        control = _template_fallback(
            obligation_text="Behörden sind über Vorfälle zu informieren",
            action="informieren",
            object_="zuständige Behörden",
            parent_title="Incident Reporting",
            parent_severity="high",
            parent_category="governance",
            is_test=False,
            is_reporting=True,
        )

        assert "Meldepflicht:" in control.title
        assert "Meldeprozess-Dokumentation" in control.evidence
# ---------------------------------------------------------------------------
# PROMPT BUILDER TESTS
# ---------------------------------------------------------------------------
class TestPromptBuilders:
    """Checks that the single-item LLM prompts embed their inputs."""

    def test_pass0a_prompt_contains_all_fields(self):
        """The Pass 0a prompt contains the control fields and asks for a JSON array."""
        rendered = _build_pass0a_prompt(
            title="MFA Control",
            objective="Implement MFA",
            requirements="- Require TOTP\n- Hardware key",
            test_procedure="- Test login",
            source_ref="DSGVO Art. 32",
        )

        expected_fragments = (
            "MFA Control",
            "Implement MFA",
            "Require TOTP",
            "DSGVO Art. 32",
            "JSON-Array",
        )
        for fragment in expected_fragments:
            assert fragment in rendered

    def test_pass0b_prompt_contains_all_fields(self):
        """The Pass 0b prompt contains the obligation, action, parent title and 'JSON'."""
        rendered = _build_pass0b_prompt(
            obligation_text="MFA implementieren",
            action="implementieren",
            object_="MFA",
            parent_title="Auth Controls",
            parent_category="authentication",
            source_ref="DSGVO Art. 32",
        )

        expected_fragments = (
            "MFA implementieren",
            "implementieren",
            "Auth Controls",
            "JSON",
        )
        for fragment in expected_fragments:
            assert fragment in rendered

    def test_system_prompts_exist(self):
        """Both system prompts contain their characteristic keyword."""
        assert "REGELN" in _PASS0A_SYSTEM_PROMPT
        assert "atomares" in _PASS0B_SYSTEM_PROMPT
# ---------------------------------------------------------------------------
# DECOMPOSITION PASS INTEGRATION TESTS
# ---------------------------------------------------------------------------
class TestDecompositionPassRun0a:
    """Runs DecompositionPass.run_pass0a against a mocked DB and LLM."""

    # Patch target used for the mocked LLM call in these tests.
    _LLM_TARGET = "compliance.services.obligation_extractor._llm_ollama"

    @staticmethod
    def _db_returning(rows):
        """Build a mock session whose execute() result yields *rows* from fetchall()."""
        session = MagicMock()
        result = MagicMock()
        result.fetchall.return_value = rows
        session.execute.return_value = result
        return session

    @pytest.mark.asyncio
    async def test_pass0a_extracts_obligations(self):
        """Two obligations returned by the LLM are both extracted and validated."""
        # One rich control to decompose.
        session = self._db_returning([
            (
                "uuid-1", "CTRL-001",
                "Service Continuity",
                "Sicherstellen der Dienstleistungskontinuität",
                '["Mechanismen implementieren", "Systeme testen"]',
                '["Prüfung der Mechanismen"]',
                '{"source": "MiCA", "article": "Art. 8"}',
                "finance",
            ),
        ])
        implement_obligation = {
            "obligation_text": "Betreiber müssen Mechanismen zur Dienstleistungskontinuität implementieren",
            "action": "implementieren",
            "object": "Kontinuitätsmechanismen",
            "condition": "bei Ausfall des Handelssystems",
            "normative_strength": "must",
            "is_test_obligation": False,
            "is_reporting_obligation": False,
        }
        test_obligation = {
            "obligation_text": "Kontinuitätsmechanismen müssen regelmäßig getestet werden",
            "action": "testen",
            "object": "Kontinuitätsmechanismen",
            "condition": None,
            "normative_strength": "must",
            "is_test_obligation": True,
            "is_reporting_obligation": False,
        }
        llm_reply = json.dumps([implement_obligation, test_obligation])

        with patch(self._LLM_TARGET, new_callable=AsyncMock) as fake_llm:
            fake_llm.return_value = llm_reply
            runner = DecompositionPass(db=session)
            outcome = await runner.run_pass0a(limit=10)

        assert outcome["controls_processed"] == 1
        assert outcome["obligations_extracted"] == 2
        assert outcome["obligations_validated"] == 2
        assert outcome["errors"] == 0
        # Verify DB writes: 1 SELECT + 2 INSERTs + 1 COMMIT
        assert session.execute.call_count >= 3
        session.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_pass0a_fallback_on_empty_llm(self):
        """An unparseable LLM reply still yields one obligation via the fallback."""
        session = self._db_returning([
            (
                "uuid-1", "CTRL-001",
                "MFA Control",
                "Betreiber müssen MFA implementieren",
                "", "", "", "auth",
            ),
        ])

        with patch(self._LLM_TARGET, new_callable=AsyncMock) as fake_llm:
            # Not valid JSON.
            fake_llm.return_value = "I cannot help with that."
            runner = DecompositionPass(db=session)
            outcome = await runner.run_pass0a(limit=10)

        assert outcome["controls_processed"] == 1
        # Fallback should create 1 obligation from the objective
        assert outcome["obligations_extracted"] == 1

    @pytest.mark.asyncio
    async def test_pass0a_skips_empty_controls(self):
        """A control whose text fields are all empty is counted as skipped."""
        session = self._db_returning([
            ("uuid-1", "CTRL-001", "", "", "", "", "", ""),
        ])

        # No LLM call needed — empty controls are skipped before LLM
        runner = DecompositionPass(db=session)
        outcome = await runner.run_pass0a(limit=10)

        assert outcome["controls_skipped_empty"] == 1
        assert outcome["controls_processed"] == 0

    @pytest.mark.asyncio
    async def test_pass0a_rejects_evidence_only(self):
        """An evidence-only obligation is extracted but then rejected."""
        session = self._db_returning([
            (
                "uuid-1", "CTRL-001",
                "Evidence List",
                "Betreiber müssen Nachweise erbringen",
                "", "", "", "governance",
            ),
        ])
        evidence_obligation = {
            "obligation_text": "Dokumentation der Konfiguration",
            "action": "dokumentieren",
            "object": "Konfiguration",
            "condition": None,
            "normative_strength": "must",
            "is_test_obligation": False,
            "is_reporting_obligation": False,
        }
        llm_reply = json.dumps([evidence_obligation])

        with patch(self._LLM_TARGET, new_callable=AsyncMock) as fake_llm:
            fake_llm.return_value = llm_reply
            runner = DecompositionPass(db=session)
            outcome = await runner.run_pass0a(limit=10)

        assert outcome["obligations_extracted"] == 1
        assert outcome["obligations_rejected"] == 1
class TestDecompositionPassRun0b:
    """Runs DecompositionPass.run_pass0b against a mocked DB and LLM."""

    # Patch target used for the mocked LLM call in these tests.
    _LLM_TARGET = "compliance.services.obligation_extractor._llm_ollama"

    @staticmethod
    def _scripted_db(candidate_rows):
        """Build a mock session with a scripted sequence of execute() results.

        Call sequence: 1=SELECT candidates, 2=_next_atomic_seq,
        3=INSERT control, 4=UPDATE oc. The first call yields
        *candidate_rows* from fetchall(), the second yields ``(0,)`` from
        fetchone(), and every later call gets a fresh MagicMock.
        """
        select_result = MagicMock()
        select_result.fetchall.return_value = candidate_rows
        seq_result = MagicMock()
        seq_result.fetchone.return_value = (0,)
        scripted = [select_result, seq_result]

        def fake_execute(*args, **kwargs):
            if scripted:
                return scripted.pop(0)
            return MagicMock()  # INSERT/UPDATE

        session = MagicMock()
        session.execute.side_effect = fake_execute
        return session

    @pytest.mark.asyncio
    async def test_pass0b_creates_atomic_controls(self):
        """A valid LLM reply produces one control and no LLM failure."""
        # One validated obligation candidate.
        session = self._scripted_db([
            (
                "oc-uuid-1", "OC-CTRL-001-01", "parent-uuid-1",
                "Betreiber müssen Kontinuität sicherstellen",
                "sicherstellen", "Dienstleistungskontinuität",
                False, False,  # is_test, is_reporting
                "Service Continuity", "finance",
                '{"source": "MiCA", "article": "Art. 8"}',
                "high", "FIN-001",
            ),
        ])
        llm_reply = json.dumps({
            "title": "Dienstleistungskontinuität bei Systemausfall",
            "objective": "Sicherstellen, dass Dienstleistungen fortgeführt werden.",
            "requirements": ["Failover-Mechanismus implementieren"],
            "test_procedure": ["Failover-Test durchführen"],
            "evidence": ["Systemarchitektur", "DR-Plan"],
            "severity": "high",
            "category": "operations",
        })

        with patch(self._LLM_TARGET, new_callable=AsyncMock) as fake_llm:
            fake_llm.return_value = llm_reply
            runner = DecompositionPass(db=session)
            outcome = await runner.run_pass0b(limit=10)

        assert outcome["candidates_processed"] == 1
        assert outcome["controls_created"] == 1
        assert outcome["llm_failures"] == 0

    @pytest.mark.asyncio
    async def test_pass0b_template_fallback(self):
        """An unparseable LLM reply counts as a failure but still yields a control."""
        session = self._scripted_db([
            (
                "oc-uuid-1", "OC-CTRL-001-01", "parent-uuid-1",
                "Betreiber müssen MFA implementieren",
                "implementieren", "MFA",
                False, False,
                "Auth Controls", "authentication",
                "", "high", "AUTH-001",
            ),
        ])

        with patch(self._LLM_TARGET, new_callable=AsyncMock) as fake_llm:
            # The LLM fails to return JSON.
            fake_llm.return_value = "Sorry, invalid response"
            runner = DecompositionPass(db=session)
            outcome = await runner.run_pass0b(limit=10)

        assert outcome["controls_created"] == 1
        assert outcome["llm_failures"] == 1
class TestDecompositionStatus:
    """Checks DecompositionPass.decomposition_status() against a mocked DB."""

    @staticmethod
    def _status_for(counts):
        """Return decomposition_status() when the DB row is the tuple *counts*."""
        query_result = MagicMock()
        query_result.fetchone.return_value = counts
        session = MagicMock()
        session.execute.return_value = query_result
        return DecompositionPass(db=session).decomposition_status()

    def test_returns_status(self):
        """The seven row values are mapped to named keys plus two percentages."""
        report = self._status_for((5000, 1000, 3000, 2500, 200, 2000, 1800))

        expected = {
            "rich_controls": 5000,
            "decomposed_controls": 1000,
            "total_candidates": 3000,
            "validated": 2500,
            "rejected": 200,
            "composed": 2000,
            "atomic_controls": 1800,
            "decomposition_pct": 20.0,
            "composition_pct": 80.0,
        }
        for key, value in expected.items():
            assert report[key] == value

    def test_handles_zero_division(self):
        """An all-zero row yields 0.0 for both percentages."""
        report = self._status_for((0, 0, 0, 0, 0, 0, 0))

        assert report["decomposition_pct"] == 0.0
        assert report["composition_pct"] == 0.0
# ---------------------------------------------------------------------------
# MIGRATION 061 SCHEMA TESTS
# ---------------------------------------------------------------------------
class TestMigration061:
    """Tests for the migration 061 SQL file.

    The file is located relative to this test module:
    ``<parent of the tests directory>/migrations/061_obligation_candidates.sql``.
    """

    _MIGRATION_NAME = "061_obligation_candidates.sql"

    # Identifiers the migration text must mention.
    _REQUIRED_IDENTIFIERS = (
        "obligation_candidates",
        "parent_control_uuid",
        "decomposition_method",
        "candidate_id",
        "quality_flags",
    )

    @classmethod
    def _migration_path(cls):
        """Return the path of the migration file (it may not exist)."""
        from pathlib import Path

        return Path(__file__).parent.parent / "migrations" / cls._MIGRATION_NAME

    def test_migration_file_exists(self):
        """The migration file is present on disk."""
        assert self._migration_path().exists(), "Migration 061 file missing"

    def test_migration_contains_required_tables(self):
        """The migration text mentions every required identifier."""
        # Decode explicitly so the result does not depend on the platform's
        # default locale encoding. NOTE(review): assumes the SQL file is
        # UTF-8 — confirm.
        content = self._migration_path().read_text(encoding="utf-8")

        for identifier in self._REQUIRED_IDENTIFIERS:
            assert identifier in content, (
                f"Migration 061 does not mention {identifier!r}"
            )
# ---------------------------------------------------------------------------
# BATCH PROMPT TESTS
# ---------------------------------------------------------------------------
class TestBatchPromptBuilders:
    """Checks that the batch prompts enumerate every item they are given."""

    def test_pass0a_batch_prompt_contains_all_controls(self):
        """Two controls appear with IDs, titles, numbered headers and a count."""
        mfa_control = {
            "control_id": "AUTH-001",
            "title": "MFA Control",
            "objective": "Implement MFA",
            "requirements": "- TOTP required",
            "test_procedure": "- Test login",
            "source_ref": "DSGVO Art. 32",
        }
        password_control = {
            "control_id": "AUTH-002",
            "title": "Password Policy",
            "objective": "Enforce strong passwords",
            "requirements": "- Min 12 chars",
            "test_procedure": "- Test weak password",
            "source_ref": "BSI IT-Grundschutz",
        }

        rendered = _build_pass0a_batch_prompt([mfa_control, password_control])

        expected_fragments = (
            "AUTH-001",
            "AUTH-002",
            "MFA Control",
            "Password Policy",
            "CONTROL 1",
            "CONTROL 2",
            "2 Controls",
        )
        for fragment in expected_fragments:
            assert fragment in rendered

    def test_pass0a_batch_prompt_single_control(self):
        """A single control is announced as '1 Controls'."""
        only_control = {
            "control_id": "AUTH-001",
            "title": "MFA",
            "objective": "MFA",
            "requirements": "",
            "test_procedure": "",
            "source_ref": "",
        }

        rendered = _build_pass0a_batch_prompt([only_control])

        assert "AUTH-001" in rendered
        assert "1 Controls" in rendered

    def test_pass0b_batch_prompt_contains_all_obligations(self):
        """Two obligations appear with IDs, numbered headers and a count."""
        implement_obligation = {
            "candidate_id": "OC-AUTH-001-01",
            "obligation_text": "MFA implementieren",
            "action": "implementieren",
            "object": "MFA",
            "parent_title": "Auth Controls",
            "parent_category": "authentication",
            "source_ref": "DSGVO Art. 32",
        }
        test_obligation = {
            "candidate_id": "OC-AUTH-001-02",
            "obligation_text": "MFA testen",
            "action": "testen",
            "object": "MFA",
            "parent_title": "Auth Controls",
            "parent_category": "authentication",
            "source_ref": "DSGVO Art. 32",
        }

        rendered = _build_pass0b_batch_prompt(
            [implement_obligation, test_obligation]
        )

        expected_fragments = (
            "OC-AUTH-001-01",
            "OC-AUTH-001-02",
            "PFLICHT 1",
            "PFLICHT 2",
            "2 Pflichten",
        )
        for fragment in expected_fragments:
            assert fragment in rendered
class TestFallbackObligation:
    """Checks which control field _fallback_obligation() uses as text."""

    def test_uses_objective_when_available(self):
        """A non-empty objective becomes the text; the action is 'sicherstellen'."""
        control = {"title": "MFA", "objective": "Implement MFA for all users"}

        obligation = _fallback_obligation(control)

        assert obligation["obligation_text"] == "Implement MFA for all users"
        assert obligation["action"] == "sicherstellen"

    def test_uses_title_when_no_objective(self):
        """With an empty objective the title becomes the obligation text."""
        control = {"title": "MFA Control", "objective": ""}

        obligation = _fallback_obligation(control)

        assert obligation["obligation_text"] == "MFA Control"
# ---------------------------------------------------------------------------
# ANTHROPIC BATCHING INTEGRATION TESTS
# ---------------------------------------------------------------------------
class TestDecompositionPassAnthropicBatch:
    """Runs Pass 0a/0b with use_anthropic=True against a mocked DB and LLM."""

    # Patch target used for the mocked Anthropic call in these tests.
    _LLM_TARGET = "compliance.services.decomposition_pass._llm_anthropic"

    @staticmethod
    def _db_returning(rows):
        """Build a mock session whose execute() result yields *rows* from fetchall()."""
        session = MagicMock()
        result = MagicMock()
        result.fetchall.return_value = rows
        session.execute.return_value = result
        return session

    @pytest.mark.asyncio
    async def test_pass0a_anthropic_batched(self):
        """Test Pass 0a with Anthropic API and batch_size=2."""
        session = self._db_returning([
            ("uuid-1", "CTRL-001", "MFA Control", "Implement MFA",
             "", "", "", "security"),
            ("uuid-2", "CTRL-002", "Encryption", "Encrypt data at rest",
             "", "", "", "security"),
        ])
        # Anthropic returns JSON object keyed by control_id
        llm_reply = json.dumps({
            "CTRL-001": [
                {"obligation_text": "MFA muss implementiert werden",
                 "action": "implementieren", "object": "MFA",
                 "normative_strength": "must",
                 "is_test_obligation": False, "is_reporting_obligation": False},
            ],
            "CTRL-002": [
                {"obligation_text": "Daten müssen verschlüsselt werden",
                 "action": "verschlüsseln", "object": "Daten",
                 "normative_strength": "must",
                 "is_test_obligation": False, "is_reporting_obligation": False},
            ],
        })

        with patch(self._LLM_TARGET, new_callable=AsyncMock) as fake_llm:
            fake_llm.return_value = llm_reply
            runner = DecompositionPass(db=session)
            outcome = await runner.run_pass0a(
                limit=10, batch_size=2, use_anthropic=True,
            )

        assert outcome["controls_processed"] == 2
        assert outcome["obligations_extracted"] == 2
        # Only 1 API call for 2 controls
        assert outcome["llm_calls"] == 1
        assert outcome["provider"] == "anthropic"

    @pytest.mark.asyncio
    async def test_pass0a_anthropic_single(self):
        """Test Pass 0a with Anthropic API, batch_size=1 (no batching)."""
        session = self._db_returning([
            ("uuid-1", "CTRL-001", "MFA Control", "Implement MFA",
             "", "", "", "security"),
        ])
        llm_reply = json.dumps([
            {"obligation_text": "MFA muss implementiert werden",
             "action": "implementieren", "object": "MFA",
             "normative_strength": "must",
             "is_test_obligation": False, "is_reporting_obligation": False},
        ])

        with patch(self._LLM_TARGET, new_callable=AsyncMock) as fake_llm:
            fake_llm.return_value = llm_reply
            runner = DecompositionPass(db=session)
            outcome = await runner.run_pass0a(
                limit=10, batch_size=1, use_anthropic=True,
            )

        assert outcome["controls_processed"] == 1
        assert outcome["llm_calls"] == 1
        assert outcome["provider"] == "anthropic"

    @pytest.mark.asyncio
    async def test_pass0b_anthropic_batched(self):
        """Test Pass 0b with Anthropic API and batch_size=2."""
        select_result = MagicMock()
        select_result.fetchall.return_value = [
            ("oc-uuid-1", "OC-CTRL-001-01", "parent-uuid-1",
             "MFA implementieren", "implementieren", "MFA",
             False, False, "Auth", "security",
             '{"source": "DSGVO", "article": "Art. 32"}',
             "high", "CTRL-001"),
            ("oc-uuid-2", "OC-CTRL-001-02", "parent-uuid-1",
             "MFA testen", "testen", "MFA",
             True, False, "Auth", "security",
             '{"source": "DSGVO", "article": "Art. 32"}',
             "high", "CTRL-001"),
        ]
        seq_result = MagicMock()
        seq_result.fetchone.return_value = (0,)

        # Results keyed by 1-based execute() call number: call 1 is the
        # candidate SELECT; calls 2 and 5 are the _next_atomic_seq lookups
        # (every 3rd call after the first). All other calls are INSERT/UPDATE.
        scripted = {1: select_result, 2: seq_result, 5: seq_result}
        calls_seen = 0

        def fake_execute(*args, **kwargs):
            nonlocal calls_seen
            calls_seen += 1
            if calls_seen in scripted:
                return scripted[calls_seen]
            return MagicMock()  # INSERT/UPDATE

        session = MagicMock()
        session.execute.side_effect = fake_execute

        llm_reply = json.dumps({
            "OC-CTRL-001-01": {
                "title": "MFA implementieren",
                "objective": "MFA fuer alle Konten.",
                "requirements": ["TOTP einrichten"],
                "test_procedure": ["Login testen"],
                "evidence": ["Konfigurationsnachweis"],
                "severity": "high",
                "category": "security",
            },
            "OC-CTRL-001-02": {
                "title": "MFA-Wirksamkeit testen",
                "objective": "Regelmaessige MFA-Tests.",
                "requirements": ["Testplan erstellen"],
                "test_procedure": ["Testdurchfuehrung"],
                "evidence": ["Testprotokoll"],
                "severity": "high",
                "category": "security",
            },
        })

        with patch(self._LLM_TARGET, new_callable=AsyncMock) as fake_llm:
            fake_llm.return_value = llm_reply
            runner = DecompositionPass(db=session)
            outcome = await runner.run_pass0b(
                limit=10, batch_size=2, use_anthropic=True,
            )

        assert outcome["controls_created"] == 2
        assert outcome["llm_calls"] == 1
        assert outcome["provider"] == "anthropic"
# ---------------------------------------------------------------------------
# SOURCE FILTER TESTS
# ---------------------------------------------------------------------------
class TestSourceFilter:
    """Tests for the source_filter parameter in Pass 0a.

    Each test inspects the text of the first statement passed to
    ``db.execute`` (the control SELECT) to see which filter clauses were
    added.
    """

    @staticmethod
    def _db_returning(rows):
        """Build a mock session whose execute() result yields *rows* from fetchall()."""
        session = MagicMock()
        result = MagicMock()
        result.fetchall.return_value = rows
        session.execute.return_value = result
        return session

    @staticmethod
    def _first_query(session):
        """Return the first statement passed to session.execute(), as a string."""
        first_call = session.execute.call_args_list[0]
        # Index [0] is the positional-args tuple; its first item is the statement.
        return str(first_call[0][0])

    @pytest.mark.asyncio
    async def test_pass0a_source_filter_builds_ilike_query(self):
        """Verify source_filter adds ILIKE clauses to query."""
        session = self._db_returning([
            ("uuid-1", "CTRL-001", "Machine Safety", "Ensure safety",
             "", "", '{"source": "Maschinenverordnung (EU) 2023/1230"}', "security"),
        ])
        llm_reply = json.dumps([
            {"obligation_text": "Sicherheit gewaehrleisten",
             "action": "gewaehrleisten", "object": "Sicherheit",
             "normative_strength": "must",
             "is_test_obligation": False, "is_reporting_obligation": False},
        ])

        with patch(
            "compliance.services.decomposition_pass._llm_anthropic",
            new_callable=AsyncMock,
        ) as fake_llm:
            fake_llm.return_value = llm_reply
            runner = DecompositionPass(db=session)
            outcome = await runner.run_pass0a(
                limit=10, batch_size=1, use_anthropic=True,
                source_filter="Maschinenverordnung,Cyber Resilience Act",
            )

        assert outcome["controls_processed"] == 1
        # Verify the SQL query contained ILIKE clauses
        assert "ILIKE" in self._first_query(session)

    @pytest.mark.asyncio
    async def test_pass0a_source_filter_none_no_clause(self):
        """Verify no ILIKE clause when source_filter is None."""
        session = self._db_returning([])
        runner = DecompositionPass(db=session)

        # The returned stats are not inspected; only the issued query matters.
        await runner.run_pass0a(
            limit=10, use_anthropic=True, source_filter=None,
        )

        assert "ILIKE" not in self._first_query(session)

    @pytest.mark.asyncio
    async def test_pass0a_combined_category_and_source_filter(self):
        """Verify both category_filter and source_filter can be used together."""
        session = self._db_returning([])
        runner = DecompositionPass(db=session)

        await runner.run_pass0a(
            limit=10, use_anthropic=True,
            category_filter="security,operations",
            source_filter="Maschinenverordnung",
        )

        query_text = self._first_query(session)
        assert "IN :cats" in query_text
        assert "ILIKE" in query_text