feat(multi-layer): complete Multi-Layer Control Architecture (Phases 1-8 + Pass 0)
Some checks failed
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Failing after 47s
CI/CD / test-python-backend-compliance (push) Successful in 33s
CI/CD / test-python-document-crawler (push) Successful in 24s
CI/CD / test-python-dsms-gateway (push) Successful in 18s
CI/CD / validate-canonical-controls (push) Successful in 11s
CI/CD / Deploy (push) Has been skipped
Some checks failed
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Failing after 47s
CI/CD / test-python-backend-compliance (push) Successful in 33s
CI/CD / test-python-document-crawler (push) Successful in 24s
CI/CD / test-python-dsms-gateway (push) Successful in 18s
CI/CD / validate-canonical-controls (push) Successful in 11s
CI/CD / Deploy (push) Has been skipped
Implements the full Multi-Layer Control Architecture for migrating ~25,000 Rich Controls into atomic, deduplicated Master Controls with full traceability. Architecture: Legal Source → Obligation → Control Pattern → Master Control → Customer Instance New services: - ObligationExtractor: 3-tier extraction (exact → embedding → LLM) - PatternMatcher: 2-tier matching (keyword + embedding + domain-bonus) - ControlComposer: Pattern + Obligation → Master Control - PipelineAdapter: Pipeline integration + Migration Passes 1-5 - DecompositionPass: Pass 0a/0b — Rich Control → atomic Controls - CrosswalkRoutes: 15 API endpoints under /v1/canonical/ New DB schema: - Migration 060: obligation_extractions, control_patterns, crosswalk_matrix - Migration 061: obligation_candidates, parent_control_uuid tracking Pattern Library: 50 YAML patterns (30 core + 20 IT-security) Go SDK: Pattern loader with YAML validation and indexing Documentation: MkDocs updated with full architecture overview 500 Python tests passing across all components. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
890
backend-compliance/tests/test_control_composer.py
Normal file
890
backend-compliance/tests/test_control_composer.py
Normal file
@@ -0,0 +1,890 @@
|
||||
"""Tests for Control Composer — Phase 6 of Multi-Layer Control Architecture.
|
||||
|
||||
Validates:
|
||||
- ComposedControl dataclass and serialization
|
||||
- Pattern-guided composition (Tier 1)
|
||||
- Template-only fallback (when LLM fails)
|
||||
- Fallback composition (no pattern)
|
||||
- License rule handling (Rules 1, 2, 3)
|
||||
- Prompt building
|
||||
- Field validation and fixing
|
||||
- Batch composition
|
||||
- Edge cases: empty inputs, missing data, malformed LLM responses
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from compliance.services.control_composer import (
|
||||
ComposedControl,
|
||||
ControlComposer,
|
||||
_anchors_from_pattern,
|
||||
_build_compose_prompt,
|
||||
_build_fallback_prompt,
|
||||
_compose_system_prompt,
|
||||
_ensure_list,
|
||||
_obligation_section,
|
||||
_pattern_section,
|
||||
_severity_to_risk,
|
||||
_validate_control,
|
||||
)
|
||||
from compliance.services.obligation_extractor import ObligationMatch
|
||||
from compliance.services.pattern_matcher import ControlPattern, PatternMatchResult
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _make_obligation(
    obligation_id="DSGVO-OBL-001",
    title="Verarbeitungsverzeichnis fuehren",
    text="Fuehrung eines Verzeichnisses aller Verarbeitungstaetigkeiten.",
    method="exact_match",
    confidence=1.0,
    regulation_id="dsgvo",
) -> ObligationMatch:
    """Build an ObligationMatch fixture with overridable identity fields.

    Defaults describe a GDPR record-of-processing obligation matched
    exactly with full confidence; individual tests override as needed.
    """
    fields = {
        "obligation_id": obligation_id,
        "obligation_title": title,
        "obligation_text": text,
        "method": method,
        "confidence": confidence,
        "regulation_id": regulation_id,
    }
    return ObligationMatch(**fields)
|
||||
|
||||
|
||||
def _make_pattern(
    pattern_id="CP-COMP-001",
    name="compliance_governance",
    name_de="Compliance-Governance",
    domain="COMP",
    category="compliance",
) -> ControlPattern:
    """Build a fully populated ControlPattern fixture.

    Only the identity fields are overridable; the template content is
    fixed so assertions elsewhere can rely on the exact strings.
    """
    # Framework anchors referenced by the anchor-conversion tests.
    anchors = [
        {"framework": "ISO 27001", "ref": "A.18"},
        {"framework": "NIST CSF", "ref": "GV.OC"},
    ]
    template_fields = {
        "description": "Compliance management and governance framework",
        "objective_template": "Sicherstellen, dass ein wirksames Compliance-Management existiert.",
        "rationale_template": "Ohne Governance fehlt die Grundlage fuer Compliance.",
        "requirements_template": [
            "Compliance-Verantwortlichkeiten definieren",
            "Regelmaessige Compliance-Bewertungen durchfuehren",
            "Dokumentationspflichten einhalten",
        ],
        "test_procedure_template": [
            "Pruefung der Compliance-Organisation",
            "Stichproben der Dokumentation",
        ],
        "evidence_template": [
            "Compliance-Handbuch",
            "Pruefberichte",
        ],
        "severity_default": "high",
        "implementation_effort_default": "l",
        "obligation_match_keywords": ["compliance", "governance", "konformitaet"],
        "tags": ["compliance", "governance"],
        "composable_with": ["CP-COMP-002"],
        "open_anchor_refs": anchors,
    }
    return ControlPattern(
        id=pattern_id,
        name=name,
        name_de=name_de,
        domain=domain,
        category=category,
        **template_fields,
    )
|
||||
|
||||
|
||||
def _make_pattern_result(pattern=None, confidence=0.85, method="keyword") -> PatternMatchResult:
    """Wrap a pattern in a PatternMatchResult with plausible match stats.

    When no pattern is supplied, the default compliance-governance
    fixture from _make_pattern() is used.
    """
    chosen = _make_pattern() if pattern is None else pattern
    return PatternMatchResult(
        pattern=chosen,
        pattern_id=chosen.id,
        method=method,
        confidence=confidence,
        keyword_hits=4,
        total_keywords=7,
    )
|
||||
|
||||
|
||||
def _llm_success_response() -> str:
|
||||
return json.dumps({
|
||||
"title": "Compliance-Governance fuer Verarbeitungstaetigkeiten",
|
||||
"objective": "Sicherstellen, dass alle Verarbeitungstaetigkeiten dokumentiert und ueberwacht werden.",
|
||||
"rationale": "Die DSGVO verlangt ein Verarbeitungsverzeichnis als Grundlage der Rechenschaftspflicht.",
|
||||
"requirements": [
|
||||
"Verarbeitungsverzeichnis gemaess Art. 30 DSGVO fuehren",
|
||||
"Regelmaessige Aktualisierung bei Aenderungen",
|
||||
"Verantwortlichkeiten fuer die Pflege zuweisen",
|
||||
],
|
||||
"test_procedure": [
|
||||
"Vollstaendigkeit des Verzeichnisses pruefen",
|
||||
"Aktualitaet der Eintraege verifizieren",
|
||||
],
|
||||
"evidence": [
|
||||
"Verarbeitungsverzeichnis",
|
||||
"Aenderungsprotokoll",
|
||||
],
|
||||
"severity": "high",
|
||||
"implementation_effort": "m",
|
||||
"category": "compliance",
|
||||
"tags": ["dsgvo", "verarbeitungsverzeichnis", "governance"],
|
||||
"target_audience": ["unternehmen", "behoerden"],
|
||||
"verification_method": "document",
|
||||
})
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: ComposedControl
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestComposedControl:
    """Tests for the ComposedControl dataclass.

    Covers default field values and the to_dict() serialization contract.
    """

    def test_defaults(self):
        # A bare ComposedControl carries conservative defaults: medium
        # severity, draft state, customer-visible, pattern-guided method.
        c = ComposedControl()
        assert c.control_id == ""
        assert c.title == ""
        assert c.severity == "medium"
        assert c.risk_score == 5.0
        assert c.implementation_effort == "m"
        assert c.release_state == "draft"
        assert c.license_rule is None
        assert c.customer_visible is True
        assert c.pattern_id is None
        assert c.obligation_ids == []
        assert c.composition_method == "pattern_guided"

    def test_to_dict_keys(self):
        # to_dict() must expose exactly this key set -- no more, no less --
        # since downstream persistence relies on a stable schema.
        c = ComposedControl()
        d = c.to_dict()
        expected_keys = {
            "control_id", "title", "objective", "rationale", "scope",
            "requirements", "test_procedure", "evidence", "severity",
            "risk_score", "implementation_effort", "open_anchors",
            "release_state", "tags", "license_rule", "source_original_text",
            "source_citation", "customer_visible", "verification_method",
            "category", "target_audience", "pattern_id", "obligation_ids",
            "generation_metadata", "composition_method",
        }
        assert set(d.keys()) == expected_keys

    def test_to_dict_values(self):
        # Values set on the dataclass must round-trip into the dict unchanged.
        c = ComposedControl(
            title="Test Control",
            pattern_id="CP-AUTH-001",
            obligation_ids=["DSGVO-OBL-001"],
            severity="high",
            license_rule=1,
        )
        d = c.to_dict()
        assert d["title"] == "Test Control"
        assert d["pattern_id"] == "CP-AUTH-001"
        assert d["obligation_ids"] == ["DSGVO-OBL-001"]
        assert d["severity"] == "high"
        assert d["license_rule"] == 1
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: _ensure_list
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestEnsureList:
    """Tests for the _ensure_list normalization helper."""

    def test_list_passthrough(self):
        # An existing list of strings comes back unchanged.
        result = _ensure_list(["a", "b"])
        assert result == ["a", "b"]

    def test_string_to_list(self):
        # A bare string is wrapped in a single-element list.
        result = _ensure_list("hello")
        assert result == ["hello"]

    def test_none_to_empty(self):
        result = _ensure_list(None)
        assert result == []

    def test_empty_list(self):
        result = _ensure_list([])
        assert result == []

    def test_filters_empty_values(self):
        # Empty strings are dropped from the output.
        result = _ensure_list(["a", "", "b"])
        assert result == ["a", "b"]

    def test_converts_to_strings(self):
        # Non-string items are coerced to their string form.
        result = _ensure_list([1, 2, 3])
        assert result == ["1", "2", "3"]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: _anchors_from_pattern
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAnchorsFromPattern:
    """Tests for _anchors_from_pattern framework-reference conversion."""

    def test_converts_anchors(self):
        # Each pattern "ref" entry becomes an anchor dict keyed by
        # "control_id", with a default alignment_score of 0.8.
        pattern = _make_pattern()
        anchors = _anchors_from_pattern(pattern)
        assert len(anchors) == 2
        assert anchors[0]["framework"] == "ISO 27001"
        assert anchors[0]["control_id"] == "A.18"
        assert anchors[0]["alignment_score"] == 0.8

    def test_empty_anchors(self):
        # A pattern without open_anchor_refs yields an empty anchor list.
        pattern = _make_pattern()
        pattern.open_anchor_refs = []
        anchors = _anchors_from_pattern(pattern)
        assert anchors == []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: _severity_to_risk
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestSeverityToRisk:
    """Tests for the severity-label to numeric risk-score mapping."""

    def test_critical(self):
        score = _severity_to_risk("critical")
        assert score == 9.0

    def test_high(self):
        score = _severity_to_risk("high")
        assert score == 7.0

    def test_medium(self):
        score = _severity_to_risk("medium")
        assert score == 5.0

    def test_low(self):
        score = _severity_to_risk("low")
        assert score == 3.0

    def test_unknown(self):
        # Unrecognized labels fall back to the medium score.
        score = _severity_to_risk("xyz")
        assert score == 5.0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: _validate_control
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestValidateControl:
    """Tests for _validate_control field fixing and minimum-content rules.

    _validate_control mutates the control in place, coercing invalid enum
    values to safe defaults and backfilling empty content fields.
    """

    def test_fixes_invalid_severity(self):
        # Unknown severities are reset to "medium".
        c = ComposedControl(severity="extreme")
        _validate_control(c)
        assert c.severity == "medium"

    def test_keeps_valid_severity(self):
        c = ComposedControl(severity="critical")
        _validate_control(c)
        assert c.severity == "critical"

    def test_fixes_invalid_effort(self):
        # Unknown effort sizes are reset to "m".
        c = ComposedControl(implementation_effort="xxl")
        _validate_control(c)
        assert c.implementation_effort == "m"

    def test_fixes_invalid_verification(self):
        # Unknown verification methods are cleared rather than defaulted.
        c = ComposedControl(verification_method="magic")
        _validate_control(c)
        assert c.verification_method is None

    def test_keeps_valid_verification(self):
        c = ComposedControl(verification_method="code_review")
        _validate_control(c)
        assert c.verification_method == "code_review"

    def test_fixes_risk_score_out_of_range(self):
        # An out-of-range risk score is recomputed from the severity.
        c = ComposedControl(risk_score=15.0, severity="high")
        _validate_control(c)
        assert c.risk_score == 7.0  # from severity

    def test_truncates_long_title(self):
        # Titles are capped to fit the 255-char DB column.
        c = ComposedControl(title="A" * 300)
        _validate_control(c)
        assert len(c.title) <= 255

    def test_ensures_minimum_content(self):
        # Empty content fields are backfilled so a control is never blank:
        # the objective falls back to the title, and each list gets at
        # least one entry.
        c = ComposedControl(
            title="Test",
            objective="",
            rationale="",
            requirements=[],
            test_procedure=[],
            evidence=[],
        )
        _validate_control(c)
        assert c.objective == "Test"  # falls back to title
        assert c.rationale != ""
        assert len(c.requirements) >= 1
        assert len(c.test_procedure) >= 1
        assert len(c.evidence) >= 1
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: Prompt builders
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPromptBuilders:
    """Tests for the prompt-building helpers.

    License rule 1 allows quoting the source; rule 3 requires prompts
    that neither copy nor name the restricted source material.
    """

    def test_compose_system_prompt_rule1(self):
        # Rule 1 prompts have no copy-restriction instructions.
        prompt = _compose_system_prompt(1)
        assert "praxisorientiertes" in prompt
        assert "KOPIERE KEINE" not in prompt

    def test_compose_system_prompt_rule3(self):
        # Rule 3 prompts must forbid copying and naming the source.
        prompt = _compose_system_prompt(3)
        assert "KOPIERE KEINE" in prompt
        assert "NENNE NICHT die Quelle" in prompt

    def test_obligation_section_full(self):
        # A populated obligation renders its title, ID and regulation.
        obl = _make_obligation()
        section = _obligation_section(obl)
        assert "PFLICHT" in section
        assert "Verarbeitungsverzeichnis" in section
        assert "DSGVO-OBL-001" in section
        assert "dsgvo" in section

    def test_obligation_section_minimal(self):
        # An empty obligation renders an explicit "no specific duty" note.
        obl = ObligationMatch()
        section = _obligation_section(obl)
        assert "Keine spezifische Pflicht" in section

    def test_pattern_section(self):
        # The pattern section includes name, ID and template requirements.
        pattern = _make_pattern()
        section = _pattern_section(pattern)
        assert "MUSTER" in section
        assert "Compliance-Governance" in section
        assert "CP-COMP-001" in section
        assert "Compliance-Verantwortlichkeiten" in section

    def test_build_compose_prompt_rule1(self):
        # Rule 1: the original chunk text is embedded in the prompt.
        obl = _make_obligation()
        pattern = _make_pattern()
        prompt = _build_compose_prompt(obl, pattern, "Original text here", 1)
        assert "PFLICHT" in prompt
        assert "MUSTER" in prompt
        assert "KONTEXT (Originaltext)" in prompt
        assert "Original text here" in prompt

    def test_build_compose_prompt_rule3(self):
        # Rule 3: restricted source text must NOT leak into the prompt.
        obl = _make_obligation()
        pattern = _make_pattern()
        prompt = _build_compose_prompt(obl, pattern, "Secret text", 3)
        assert "Intern analysiert" in prompt
        assert "Secret text" not in prompt

    def test_build_fallback_prompt(self):
        obl = _make_obligation()
        prompt = _build_fallback_prompt(obl, "Chunk text", 1)
        assert "PFLICHT" in prompt
        assert "KONTEXT (Originaltext)" in prompt

    def test_build_fallback_prompt_no_chunk(self):
        # Missing chunk text is announced explicitly in the prompt.
        obl = _make_obligation()
        prompt = _build_fallback_prompt(obl, None, 1)
        assert "Kein Originaltext" in prompt
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: ControlComposer — Pattern-guided composition
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestComposeWithPattern:
    """Tests for pattern-guided control composition (Tier 1).

    The module-level LLM call (_llm_ollama) is patched with an AsyncMock,
    so no network or model access happens; the tests assert how the
    composer maps LLM output, pattern data and license rules onto the
    resulting control.
    """

    def setup_method(self):
        # Fresh composer and default GDPR obligation/pattern per test.
        self.composer = ControlComposer()
        self.obligation = _make_obligation()
        self.pattern_result = _make_pattern_result()

    @pytest.mark.asyncio
    async def test_compose_success_rule1(self):
        """Successful LLM composition with Rule 1."""
        with patch(
            "compliance.services.control_composer._llm_ollama",
            new_callable=AsyncMock,
            return_value=_llm_success_response(),
        ):
            control = await self.composer.compose(
                obligation=self.obligation,
                pattern_result=self.pattern_result,
                chunk_text="Der Verantwortliche fuehrt ein Verzeichnis...",
                license_rule=1,
            )

        assert control.composition_method == "pattern_guided"
        assert control.title != ""
        assert "Verarbeitungstaetigkeiten" in control.objective
        assert len(control.requirements) >= 2
        assert len(control.test_procedure) >= 1
        assert len(control.evidence) >= 1
        assert control.severity == "high"
        assert control.category == "compliance"

    @pytest.mark.asyncio
    async def test_compose_sets_linkage(self):
        """Pattern and obligation IDs should be set."""
        with patch(
            "compliance.services.control_composer._llm_ollama",
            new_callable=AsyncMock,
            return_value=_llm_success_response(),
        ):
            control = await self.composer.compose(
                obligation=self.obligation,
                pattern_result=self.pattern_result,
                license_rule=1,
            )

        # Traceability links back to the pattern and obligation fixtures.
        assert control.pattern_id == "CP-COMP-001"
        assert control.obligation_ids == ["DSGVO-OBL-001"]

    @pytest.mark.asyncio
    async def test_compose_sets_metadata(self):
        """Generation metadata should include composition details."""
        with patch(
            "compliance.services.control_composer._llm_ollama",
            new_callable=AsyncMock,
            return_value=_llm_success_response(),
        ):
            control = await self.composer.compose(
                obligation=self.obligation,
                pattern_result=self.pattern_result,
                license_rule=1,
                regulation_code="eu_2016_679",
            )

        meta = control.generation_metadata
        assert meta["composition_method"] == "pattern_guided"
        assert meta["pattern_id"] == "CP-COMP-001"
        assert meta["pattern_confidence"] == 0.85
        assert meta["obligation_id"] == "DSGVO-OBL-001"
        assert meta["license_rule"] == 1
        assert meta["regulation_code"] == "eu_2016_679"

    @pytest.mark.asyncio
    async def test_compose_rule1_stores_original(self):
        """Rule 1: original text should be stored."""
        with patch(
            "compliance.services.control_composer._llm_ollama",
            new_callable=AsyncMock,
            return_value=_llm_success_response(),
        ):
            control = await self.composer.compose(
                obligation=self.obligation,
                pattern_result=self.pattern_result,
                chunk_text="Original DSGVO text",
                license_rule=1,
            )

        assert control.license_rule == 1
        assert control.source_original_text == "Original DSGVO text"
        assert control.customer_visible is True

    @pytest.mark.asyncio
    async def test_compose_rule2_stores_citation(self):
        """Rule 2: citation should be stored."""
        citation = {
            "source": "OWASP ASVS",
            "license": "CC-BY-SA-4.0",
            "license_notice": "OWASP Foundation",
        }
        with patch(
            "compliance.services.control_composer._llm_ollama",
            new_callable=AsyncMock,
            return_value=_llm_success_response(),
        ):
            control = await self.composer.compose(
                obligation=self.obligation,
                pattern_result=self.pattern_result,
                chunk_text="OWASP text",
                license_rule=2,
                source_citation=citation,
            )

        # Rule 2 keeps both the original text and the attribution citation.
        assert control.license_rule == 2
        assert control.source_original_text == "OWASP text"
        assert control.source_citation == citation
        assert control.customer_visible is True

    @pytest.mark.asyncio
    async def test_compose_rule3_no_original(self):
        """Rule 3: no original text, not customer visible."""
        with patch(
            "compliance.services.control_composer._llm_ollama",
            new_callable=AsyncMock,
            return_value=_llm_success_response(),
        ):
            control = await self.composer.compose(
                obligation=self.obligation,
                pattern_result=self.pattern_result,
                chunk_text="BSI restricted text",
                license_rule=3,
            )

        # Restricted sources: nothing of the original may be stored or shown.
        assert control.license_rule == 3
        assert control.source_original_text is None
        assert control.source_citation is None
        assert control.customer_visible is False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: ControlComposer — Template-only fallback (LLM fails)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestTemplateOnlyFallback:
    """Tests for template-only composition when LLM fails.

    When the patched LLM returns empty or unparseable output, the
    composer must fall back to the pattern's templates.
    """

    def setup_method(self):
        self.composer = ControlComposer()
        self.obligation = _make_obligation()
        self.pattern_result = _make_pattern_result()

    @pytest.mark.asyncio
    async def test_template_fallback_on_empty_llm(self):
        """When LLM returns empty, should use template directly."""
        with patch(
            "compliance.services.control_composer._llm_ollama",
            new_callable=AsyncMock,
            return_value="",
        ):
            control = await self.composer.compose(
                obligation=self.obligation,
                pattern_result=self.pattern_result,
                license_rule=1,
            )

        assert control.composition_method == "template_only"
        assert "Compliance-Governance" in control.title
        assert control.severity == "high"  # from pattern
        assert len(control.requirements) >= 2  # from pattern template

    @pytest.mark.asyncio
    async def test_template_fallback_on_invalid_json(self):
        """When LLM returns non-JSON, should use template."""
        with patch(
            "compliance.services.control_composer._llm_ollama",
            new_callable=AsyncMock,
            return_value="This is not JSON at all",
        ):
            control = await self.composer.compose(
                obligation=self.obligation,
                pattern_result=self.pattern_result,
                license_rule=1,
            )

        assert control.composition_method == "template_only"

    @pytest.mark.asyncio
    async def test_template_includes_obligation_title(self):
        """Template fallback should include obligation title in control title."""
        with patch(
            "compliance.services.control_composer._llm_ollama",
            new_callable=AsyncMock,
            return_value="",
        ):
            control = await self.composer.compose(
                obligation=self.obligation,
                pattern_result=self.pattern_result,
                license_rule=1,
            )

        assert "Verarbeitungsverzeichnis" in control.title

    @pytest.mark.asyncio
    async def test_template_has_open_anchors(self):
        """Template fallback should include pattern anchors."""
        with patch(
            "compliance.services.control_composer._llm_ollama",
            new_callable=AsyncMock,
            return_value="",
        ):
            control = await self.composer.compose(
                obligation=self.obligation,
                pattern_result=self.pattern_result,
                license_rule=1,
            )

        # Both anchors from the pattern fixture must carry over.
        assert len(control.open_anchors) == 2
        frameworks = [a["framework"] for a in control.open_anchors]
        assert "ISO 27001" in frameworks
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: ControlComposer — Fallback (no pattern)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestFallbackNoPattern:
    """Tests for fallback composition without a pattern.

    An empty PatternMatchResult simulates "no pattern matched"; controls
    produced this way are flagged for review.
    """

    def setup_method(self):
        self.composer = ControlComposer()
        self.obligation = _make_obligation()

    @pytest.mark.asyncio
    async def test_fallback_with_llm(self):
        """Fallback should work with LLM response."""
        response = json.dumps({
            "title": "Verarbeitungsverzeichnis",
            "objective": "Verzeichnis fuehren",
            "rationale": "DSGVO Art. 30",
            "requirements": ["VVT anlegen"],
            "test_procedure": ["VVT pruefen"],
            "evidence": ["VVT Dokument"],
            "severity": "high",
        })
        with patch(
            "compliance.services.control_composer._llm_ollama",
            new_callable=AsyncMock,
            return_value=response,
        ):
            control = await self.composer.compose(
                obligation=self.obligation,
                pattern_result=PatternMatchResult(),  # No pattern
                license_rule=1,
            )

        # Pattern-less controls are marked as fallback and need review.
        assert control.composition_method == "fallback"
        assert control.pattern_id is None
        assert control.release_state == "needs_review"

    @pytest.mark.asyncio
    async def test_fallback_llm_fails(self):
        """Fallback with LLM failure should still produce a control."""
        with patch(
            "compliance.services.control_composer._llm_ollama",
            new_callable=AsyncMock,
            return_value="",
        ):
            control = await self.composer.compose(
                obligation=self.obligation,
                pattern_result=PatternMatchResult(),
                license_rule=1,
            )

        assert control.composition_method == "fallback"
        assert control.title != ""
        # Validation ensures minimum content
        assert len(control.requirements) >= 1
        assert len(control.test_procedure) >= 1

    @pytest.mark.asyncio
    async def test_fallback_no_obligation_text(self):
        """Fallback with empty obligation should still work."""
        empty_obl = ObligationMatch()
        with patch(
            "compliance.services.control_composer._llm_ollama",
            new_callable=AsyncMock,
            return_value="",
        ):
            control = await self.composer.compose(
                obligation=empty_obl,
                pattern_result=PatternMatchResult(),
                license_rule=3,
            )

        # Even with no inputs a titled, rule-3-hidden control is produced.
        assert control.title != ""
        assert control.customer_visible is False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: ControlComposer — Batch composition
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestComposeBatch:
    """Tests for batch composition via compose_batch().

    Each item dict carries the same keyword arguments that compose()
    accepts; results come back in input order.
    """

    @pytest.mark.asyncio
    async def test_batch_returns_list(self):
        composer = ControlComposer()
        # One pattern-guided item and one pattern-less (fallback) item.
        items = [
            {
                "obligation": _make_obligation(),
                "pattern_result": _make_pattern_result(),
                "license_rule": 1,
            },
            {
                "obligation": _make_obligation(
                    obligation_id="NIS2-OBL-001",
                    title="Incident Meldepflicht",
                    regulation_id="nis2",
                ),
                "pattern_result": PatternMatchResult(),
                "license_rule": 3,
            },
        ]

        with patch(
            "compliance.services.control_composer._llm_ollama",
            new_callable=AsyncMock,
            return_value=_llm_success_response(),
        ):
            results = await composer.compose_batch(items)

        # Order is preserved: first item keeps its pattern link, the
        # pattern-less second item does not.
        assert len(results) == 2
        assert results[0].pattern_id == "CP-COMP-001"
        assert results[1].pattern_id is None

    @pytest.mark.asyncio
    async def test_batch_empty(self):
        # An empty batch returns an empty list rather than raising.
        composer = ControlComposer()
        results = await composer.compose_batch([])
        assert results == []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: Validation integration
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestValidationIntegration:
    """Tests that validation runs during compose.

    Malformed-but-parseable LLM output must be repaired by the same
    rules covered in TestValidateControl.
    """

    @pytest.mark.asyncio
    async def test_compose_validates_severity(self):
        """Invalid severity from LLM should be fixed."""
        response = json.dumps({
            "title": "Test",
            "objective": "Test",
            "severity": "EXTREME",
        })
        composer = ControlComposer()
        with patch(
            "compliance.services.control_composer._llm_ollama",
            new_callable=AsyncMock,
            return_value=response,
        ):
            control = await composer.compose(
                obligation=_make_obligation(),
                pattern_result=_make_pattern_result(),
                license_rule=1,
            )

        # Whatever the LLM said, only canonical severities survive.
        assert control.severity in {"low", "medium", "high", "critical"}

    @pytest.mark.asyncio
    async def test_compose_ensures_minimum_content(self):
        """Empty requirements from LLM should be filled with defaults."""
        response = json.dumps({
            "title": "Test",
            "objective": "Test objective",
            "requirements": [],
        })
        composer = ControlComposer()
        with patch(
            "compliance.services.control_composer._llm_ollama",
            new_callable=AsyncMock,
            return_value=response,
        ):
            control = await composer.compose(
                obligation=_make_obligation(),
                pattern_result=_make_pattern_result(),
                license_rule=1,
            )

        assert len(control.requirements) >= 1
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: License rule edge cases
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestLicenseRuleEdgeCases:
    """Tests for license rule handling edge cases.

    Covers missing chunk text (rule 1), missing citation (rule 2), and
    rule 3 force-clearing source data that was passed in anyway.
    """

    @pytest.mark.asyncio
    async def test_rule1_no_chunk_text(self):
        """Rule 1 without chunk text: original_text should be None."""
        composer = ControlComposer()
        with patch(
            "compliance.services.control_composer._llm_ollama",
            new_callable=AsyncMock,
            return_value=_llm_success_response(),
        ):
            control = await composer.compose(
                obligation=_make_obligation(),
                pattern_result=_make_pattern_result(),
                chunk_text=None,
                license_rule=1,
            )

        assert control.license_rule == 1
        assert control.source_original_text is None
        assert control.customer_visible is True

    @pytest.mark.asyncio
    async def test_rule2_no_citation(self):
        """Rule 2 without citation: citation should be None."""
        composer = ControlComposer()
        with patch(
            "compliance.services.control_composer._llm_ollama",
            new_callable=AsyncMock,
            return_value=_llm_success_response(),
        ):
            control = await composer.compose(
                obligation=_make_obligation(),
                pattern_result=_make_pattern_result(),
                chunk_text="Some text",
                license_rule=2,
                source_citation=None,
            )

        assert control.license_rule == 2
        assert control.source_citation is None

    @pytest.mark.asyncio
    async def test_rule3_overrides_chunk_and_citation(self):
        """Rule 3 should always clear original text and citation."""
        composer = ControlComposer()
        with patch(
            "compliance.services.control_composer._llm_ollama",
            new_callable=AsyncMock,
            return_value=_llm_success_response(),
        ):
            control = await composer.compose(
                obligation=_make_obligation(),
                pattern_result=_make_pattern_result(),
                chunk_text="This should be cleared",
                license_rule=3,
                source_citation={"source": "BSI"},
            )

        # Rule 3 wins even when the caller supplied text and a citation.
        assert control.source_original_text is None
        assert control.source_citation is None
        assert control.customer_visible is False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: Obligation without ID
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestObligationWithoutId:
    """Tests for handling obligations without a known ID."""

    @pytest.mark.asyncio
    async def test_llm_extracted_obligation(self):
        """LLM-extracted obligation (no ID) should still compose."""
        obligation = ObligationMatch(
            obligation_id=None,
            obligation_title=None,
            obligation_text="Pflicht zur Meldung von Sicherheitsvorfaellen",
            method="llm_extracted",
            confidence=0.60,
            regulation_id="nis2",
        )
        composer = ControlComposer()
        llm_patch = patch(
            "compliance.services.control_composer._llm_ollama",
            new_callable=AsyncMock,
            return_value=_llm_success_response(),
        )
        with llm_patch:
            control = await composer.compose(
                obligation=obligation,
                pattern_result=_make_pattern_result(),
                license_rule=1,
            )

        assert control.obligation_ids == []  # No ID to link
        assert control.pattern_id == "CP-COMP-001"
        assert control.generation_metadata["obligation_method"] == "llm_extracted"
|
||||
new file: backend-compliance/tests/test_control_patterns.py (504 lines)
@@ -0,0 +1,504 @@
|
||||
"""Tests for Control Pattern Library (Phase 2).
|
||||
|
||||
Validates:
|
||||
- JSON Schema structure
|
||||
- YAML pattern files against schema
|
||||
- Pattern ID uniqueness and format
|
||||
- Domain/category consistency
|
||||
- Keyword coverage
|
||||
- Cross-references (composable_with)
|
||||
- Template quality (min lengths, no placeholders without defaults)
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from collections import Counter
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
# Repo root: tests/ -> backend-compliance/ -> repository root.
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
# Pattern library shipped with the Go SDK (shared across services).
PATTERNS_DIR = REPO_ROOT / "ai-compliance-sdk" / "policies" / "control_patterns"
SCHEMA_FILE = PATTERNS_DIR / "_pattern_schema.json"
CORE_FILE = PATTERNS_DIR / "core_patterns.yaml"
IT_SEC_FILE = PATTERNS_DIR / "domain_it_security.yaml"

# Domain codes allowed both in pattern IDs (CP-<DOMAIN>-NNN) and in the
# `domain` field; must stay in sync with the JSON schema enum.
VALID_DOMAINS = [
    "AUTH", "CRYP", "NET", "DATA", "LOG", "ACC", "SEC",
    "INC", "AI", "COMP", "GOV", "LAB", "FIN", "TRD", "ENV", "HLT",
]

VALID_SEVERITIES = ["low", "medium", "high", "critical"]  # severity_default enum
VALID_EFFORTS = ["s", "m", "l", "xl"]  # implementation_effort_default enum

# Pattern IDs look like CP-SEC-001; names are snake_case identifiers.
PATTERN_ID_RE = re.compile(r"^CP-[A-Z]+-[0-9]{3}$")
NAME_RE = re.compile(r"^[a-z][a-z0-9_]*$")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
def schema():
    """Load and parse the pattern JSON Schema.

    Returns:
        dict: the parsed schema document.
    """
    assert SCHEMA_FILE.exists(), f"Schema file not found: {SCHEMA_FILE}"
    # Explicit UTF-8: pattern/schema files contain German text and must not
    # depend on the platform's locale encoding (PEP 597).
    with open(SCHEMA_FILE, encoding="utf-8") as f:
        return json.load(f)
|
||||
|
||||
|
||||
@pytest.fixture
def core_patterns():
    """Load the core patterns list from core_patterns.yaml."""
    assert CORE_FILE.exists(), f"Core patterns file not found: {CORE_FILE}"
    # Explicit UTF-8: pattern files contain German umlauts and must not
    # depend on the platform's locale encoding.
    with open(CORE_FILE, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    return data["patterns"]
|
||||
|
||||
|
||||
@pytest.fixture
def it_sec_patterns():
    """Load the IT-security patterns list from domain_it_security.yaml."""
    assert IT_SEC_FILE.exists(), f"IT security patterns file not found: {IT_SEC_FILE}"
    # Explicit UTF-8: pattern files contain German umlauts and must not
    # depend on the platform's locale encoding.
    with open(IT_SEC_FILE, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    return data["patterns"]
|
||||
|
||||
|
||||
@pytest.fixture
def all_patterns(core_patterns, it_sec_patterns):
    """Concatenation of core and IT-security patterns (core first)."""
    combined = list(core_patterns)
    combined.extend(it_sec_patterns)
    return combined
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Schema Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPatternSchema:
|
||||
"""Validate the JSON Schema file itself."""
|
||||
|
||||
def test_schema_exists(self):
|
||||
assert SCHEMA_FILE.exists()
|
||||
|
||||
def test_schema_is_valid_json(self, schema):
|
||||
assert "$schema" in schema
|
||||
assert "properties" in schema
|
||||
|
||||
def test_schema_defines_pattern(self, schema):
|
||||
assert "ControlPattern" in schema.get("$defs", {})
|
||||
|
||||
def test_schema_requires_key_fields(self, schema):
|
||||
pattern_def = schema["$defs"]["ControlPattern"]
|
||||
required = pattern_def["required"]
|
||||
for field in [
|
||||
"id", "name", "name_de", "domain", "category",
|
||||
"description", "objective_template", "rationale_template",
|
||||
"requirements_template", "test_procedure_template",
|
||||
"evidence_template", "severity_default",
|
||||
"obligation_match_keywords", "tags",
|
||||
]:
|
||||
assert field in required, f"Missing required field in schema: {field}"
|
||||
|
||||
def test_schema_domain_enum(self, schema):
|
||||
pattern_def = schema["$defs"]["ControlPattern"]
|
||||
domain_enum = pattern_def["properties"]["domain"]["enum"]
|
||||
assert set(domain_enum) == set(VALID_DOMAINS)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# File Structure Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestFileStructure:
    """Validate top-level YAML file structure (version + description)."""

    @staticmethod
    def _load(path):
        """Parse a YAML pattern file as UTF-8 (locale-independent)."""
        # Explicit encoding: files contain German text; the previous code
        # relied on the platform's locale encoding.
        with open(path, encoding="utf-8") as f:
            return yaml.safe_load(f)

    def test_core_file_exists(self):
        assert CORE_FILE.exists()

    def test_it_sec_file_exists(self):
        assert IT_SEC_FILE.exists()

    def test_core_has_version(self):
        data = self._load(CORE_FILE)
        assert "version" in data
        assert data["version"] == "1.0"

    def test_it_sec_has_version(self):
        data = self._load(IT_SEC_FILE)
        assert "version" in data
        assert data["version"] == "1.0"

    def test_core_has_description(self):
        data = self._load(CORE_FILE)
        assert "description" in data
        assert len(data["description"]) > 20

    def test_it_sec_has_description(self):
        data = self._load(IT_SEC_FILE)
        assert "description" in data
        assert len(data["description"]) > 20
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pattern Count Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPatternCounts:
|
||||
"""Verify expected number of patterns."""
|
||||
|
||||
def test_core_has_30_patterns(self, core_patterns):
|
||||
assert len(core_patterns) == 30, (
|
||||
f"Expected 30 core patterns, got {len(core_patterns)}"
|
||||
)
|
||||
|
||||
def test_it_sec_has_20_patterns(self, it_sec_patterns):
|
||||
assert len(it_sec_patterns) == 20, (
|
||||
f"Expected 20 IT security patterns, got {len(it_sec_patterns)}"
|
||||
)
|
||||
|
||||
def test_total_is_50(self, all_patterns):
|
||||
assert len(all_patterns) == 50, (
|
||||
f"Expected 50 total patterns, got {len(all_patterns)}"
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pattern ID Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPatternIDs:
|
||||
"""Validate pattern ID format and uniqueness."""
|
||||
|
||||
def test_all_ids_match_format(self, all_patterns):
|
||||
for p in all_patterns:
|
||||
assert PATTERN_ID_RE.match(p["id"]), (
|
||||
f"Invalid pattern ID format: {p['id']} (expected CP-DOMAIN-NNN)"
|
||||
)
|
||||
|
||||
def test_all_ids_unique(self, all_patterns):
|
||||
ids = [p["id"] for p in all_patterns]
|
||||
duplicates = [id for id, count in Counter(ids).items() if count > 1]
|
||||
assert not duplicates, f"Duplicate pattern IDs: {duplicates}"
|
||||
|
||||
def test_all_names_unique(self, all_patterns):
|
||||
names = [p["name"] for p in all_patterns]
|
||||
duplicates = [n for n, count in Counter(names).items() if count > 1]
|
||||
assert not duplicates, f"Duplicate pattern names: {duplicates}"
|
||||
|
||||
def test_id_domain_matches_domain_field(self, all_patterns):
|
||||
"""The domain in the ID (CP-{DOMAIN}-NNN) should match the domain field."""
|
||||
for p in all_patterns:
|
||||
id_domain = p["id"].split("-")[1]
|
||||
assert id_domain == p["domain"], (
|
||||
f"Pattern {p['id']}: ID domain '{id_domain}' != field domain '{p['domain']}'"
|
||||
)
|
||||
|
||||
def test_all_names_are_snake_case(self, all_patterns):
|
||||
for p in all_patterns:
|
||||
assert NAME_RE.match(p["name"]), (
|
||||
f"Pattern {p['id']}: name '{p['name']}' is not snake_case"
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Domain & Category Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestDomainCategories:
|
||||
"""Validate domain and category assignments."""
|
||||
|
||||
def test_all_domains_valid(self, all_patterns):
|
||||
for p in all_patterns:
|
||||
assert p["domain"] in VALID_DOMAINS, (
|
||||
f"Pattern {p['id']}: invalid domain '{p['domain']}'"
|
||||
)
|
||||
|
||||
def test_domain_coverage(self, all_patterns):
|
||||
"""At least 5 different domains should be covered."""
|
||||
domains = {p["domain"] for p in all_patterns}
|
||||
assert len(domains) >= 5, (
|
||||
f"Only {len(domains)} domains covered: {domains}"
|
||||
)
|
||||
|
||||
def test_all_have_category(self, all_patterns):
|
||||
for p in all_patterns:
|
||||
assert p.get("category"), (
|
||||
f"Pattern {p['id']}: missing category"
|
||||
)
|
||||
|
||||
def test_category_not_empty(self, all_patterns):
|
||||
for p in all_patterns:
|
||||
assert len(p["category"]) >= 3, (
|
||||
f"Pattern {p['id']}: category too short: '{p['category']}'"
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Template Quality Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestTemplateQuality:
|
||||
"""Validate template content quality."""
|
||||
|
||||
def test_description_min_length(self, all_patterns):
|
||||
for p in all_patterns:
|
||||
desc = p["description"].strip()
|
||||
assert len(desc) >= 30, (
|
||||
f"Pattern {p['id']}: description too short ({len(desc)} chars)"
|
||||
)
|
||||
|
||||
def test_objective_min_length(self, all_patterns):
|
||||
for p in all_patterns:
|
||||
obj = p["objective_template"].strip()
|
||||
assert len(obj) >= 30, (
|
||||
f"Pattern {p['id']}: objective_template too short ({len(obj)} chars)"
|
||||
)
|
||||
|
||||
def test_rationale_min_length(self, all_patterns):
|
||||
for p in all_patterns:
|
||||
rat = p["rationale_template"].strip()
|
||||
assert len(rat) >= 30, (
|
||||
f"Pattern {p['id']}: rationale_template too short ({len(rat)} chars)"
|
||||
)
|
||||
|
||||
def test_requirements_min_count(self, all_patterns):
|
||||
for p in all_patterns:
|
||||
reqs = p["requirements_template"]
|
||||
assert len(reqs) >= 2, (
|
||||
f"Pattern {p['id']}: needs at least 2 requirements, got {len(reqs)}"
|
||||
)
|
||||
|
||||
def test_requirements_not_empty(self, all_patterns):
|
||||
for p in all_patterns:
|
||||
for i, req in enumerate(p["requirements_template"]):
|
||||
assert len(req.strip()) >= 10, (
|
||||
f"Pattern {p['id']}: requirement {i} too short"
|
||||
)
|
||||
|
||||
def test_test_procedure_min_count(self, all_patterns):
|
||||
for p in all_patterns:
|
||||
tests = p["test_procedure_template"]
|
||||
assert len(tests) >= 1, (
|
||||
f"Pattern {p['id']}: needs at least 1 test procedure"
|
||||
)
|
||||
|
||||
def test_evidence_min_count(self, all_patterns):
|
||||
for p in all_patterns:
|
||||
evidence = p["evidence_template"]
|
||||
assert len(evidence) >= 1, (
|
||||
f"Pattern {p['id']}: needs at least 1 evidence item"
|
||||
)
|
||||
|
||||
def test_name_de_exists(self, all_patterns):
|
||||
for p in all_patterns:
|
||||
assert p.get("name_de"), (
|
||||
f"Pattern {p['id']}: missing German name (name_de)"
|
||||
)
|
||||
assert len(p["name_de"]) >= 5, (
|
||||
f"Pattern {p['id']}: name_de too short: '{p['name_de']}'"
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Severity & Effort Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestSeverityEffort:
|
||||
"""Validate severity and effort assignments."""
|
||||
|
||||
def test_all_have_valid_severity(self, all_patterns):
|
||||
for p in all_patterns:
|
||||
assert p["severity_default"] in VALID_SEVERITIES, (
|
||||
f"Pattern {p['id']}: invalid severity '{p['severity_default']}'"
|
||||
)
|
||||
|
||||
def test_all_have_effort(self, all_patterns):
|
||||
for p in all_patterns:
|
||||
if "implementation_effort_default" in p:
|
||||
assert p["implementation_effort_default"] in VALID_EFFORTS, (
|
||||
f"Pattern {p['id']}: invalid effort '{p['implementation_effort_default']}'"
|
||||
)
|
||||
|
||||
def test_severity_distribution(self, all_patterns):
|
||||
"""At least 2 different severity levels should be used."""
|
||||
severities = {p["severity_default"] for p in all_patterns}
|
||||
assert len(severities) >= 2, (
|
||||
f"Only {len(severities)} severity levels used: {severities}"
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Keyword Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestKeywords:
|
||||
"""Validate obligation match keywords."""
|
||||
|
||||
def test_all_have_keywords(self, all_patterns):
|
||||
for p in all_patterns:
|
||||
kws = p["obligation_match_keywords"]
|
||||
assert len(kws) >= 3, (
|
||||
f"Pattern {p['id']}: needs at least 3 keywords, got {len(kws)}"
|
||||
)
|
||||
|
||||
def test_keywords_not_empty(self, all_patterns):
|
||||
for p in all_patterns:
|
||||
for kw in p["obligation_match_keywords"]:
|
||||
assert len(kw.strip()) >= 2, (
|
||||
f"Pattern {p['id']}: empty or too short keyword: '{kw}'"
|
||||
)
|
||||
|
||||
def test_keywords_lowercase(self, all_patterns):
|
||||
for p in all_patterns:
|
||||
for kw in p["obligation_match_keywords"]:
|
||||
assert kw == kw.lower(), (
|
||||
f"Pattern {p['id']}: keyword should be lowercase: '{kw}'"
|
||||
)
|
||||
|
||||
def test_has_german_and_english_keywords(self, all_patterns):
|
||||
"""Each pattern should have keywords in both languages (spot check)."""
|
||||
# At minimum, keywords should have a mix (not all German, not all English)
|
||||
for p in all_patterns:
|
||||
kws = p["obligation_match_keywords"]
|
||||
assert len(kws) >= 3, (
|
||||
f"Pattern {p['id']}: too few keywords for bilingual coverage"
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tags Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestTags:
|
||||
"""Validate tags."""
|
||||
|
||||
def test_all_have_tags(self, all_patterns):
|
||||
for p in all_patterns:
|
||||
assert len(p["tags"]) >= 1, (
|
||||
f"Pattern {p['id']}: needs at least 1 tag"
|
||||
)
|
||||
|
||||
def test_tags_are_strings(self, all_patterns):
|
||||
for p in all_patterns:
|
||||
for tag in p["tags"]:
|
||||
assert isinstance(tag, str) and len(tag) >= 2, (
|
||||
f"Pattern {p['id']}: invalid tag: {tag}"
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Open Anchor Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestOpenAnchors:
|
||||
"""Validate open anchor references."""
|
||||
|
||||
def test_most_have_anchors(self, all_patterns):
|
||||
"""At least 80% of patterns should have open anchor references."""
|
||||
with_anchors = sum(
|
||||
1 for p in all_patterns
|
||||
if p.get("open_anchor_refs") and len(p["open_anchor_refs"]) >= 1
|
||||
)
|
||||
ratio = with_anchors / len(all_patterns)
|
||||
assert ratio >= 0.80, (
|
||||
f"Only {with_anchors}/{len(all_patterns)} ({ratio:.0%}) patterns have "
|
||||
f"open anchor references (need >= 80%)"
|
||||
)
|
||||
|
||||
def test_anchor_structure(self, all_patterns):
|
||||
for p in all_patterns:
|
||||
for anchor in p.get("open_anchor_refs", []):
|
||||
assert "framework" in anchor, (
|
||||
f"Pattern {p['id']}: anchor missing 'framework'"
|
||||
)
|
||||
assert "ref" in anchor, (
|
||||
f"Pattern {p['id']}: anchor missing 'ref'"
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Composability Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestComposability:
|
||||
"""Validate composable_with references."""
|
||||
|
||||
def test_composable_refs_are_valid_ids(self, all_patterns):
|
||||
all_ids = {p["id"] for p in all_patterns}
|
||||
for p in all_patterns:
|
||||
for ref in p.get("composable_with", []):
|
||||
assert PATTERN_ID_RE.match(ref), (
|
||||
f"Pattern {p['id']}: composable_with ref '{ref}' is not valid ID format"
|
||||
)
|
||||
assert ref in all_ids, (
|
||||
f"Pattern {p['id']}: composable_with ref '{ref}' does not exist"
|
||||
)
|
||||
|
||||
def test_no_self_references(self, all_patterns):
|
||||
for p in all_patterns:
|
||||
composable = p.get("composable_with", [])
|
||||
assert p["id"] not in composable, (
|
||||
f"Pattern {p['id']}: composable_with contains self-reference"
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Cross-File Consistency Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestCrossFileConsistency:
|
||||
"""Validate consistency between core and IT security files."""
|
||||
|
||||
def test_no_id_overlap(self, core_patterns, it_sec_patterns):
|
||||
core_ids = {p["id"] for p in core_patterns}
|
||||
it_sec_ids = {p["id"] for p in it_sec_patterns}
|
||||
overlap = core_ids & it_sec_ids
|
||||
assert not overlap, f"ID overlap between files: {overlap}"
|
||||
|
||||
def test_no_name_overlap(self, core_patterns, it_sec_patterns):
|
||||
core_names = {p["name"] for p in core_patterns}
|
||||
it_sec_names = {p["name"] for p in it_sec_patterns}
|
||||
overlap = core_names & it_sec_names
|
||||
assert not overlap, f"Name overlap between files: {overlap}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Placeholder Syntax Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPlaceholderSyntax:
|
||||
"""Validate {placeholder:default} syntax in templates."""
|
||||
|
||||
PLACEHOLDER_RE = re.compile(r"\{(\w+)(?::([^}]+))?\}")
|
||||
|
||||
def test_placeholders_have_defaults(self, all_patterns):
|
||||
"""All placeholders in requirements should have defaults."""
|
||||
for p in all_patterns:
|
||||
for req in p["requirements_template"]:
|
||||
for match in self.PLACEHOLDER_RE.finditer(req):
|
||||
placeholder = match.group(1)
|
||||
default = match.group(2)
|
||||
# Placeholders should have defaults
|
||||
assert default is not None, (
|
||||
f"Pattern {p['id']}: placeholder '{{{placeholder}}}' has no default value"
|
||||
)
|
||||
new file: backend-compliance/tests/test_crosswalk_routes.py (1131 lines) — diff suppressed because it is too large
new file: backend-compliance/tests/test_decomposition_pass.py (816 lines)
@@ -0,0 +1,816 @@
|
||||
"""Tests for Decomposition Pass (Pass 0a + 0b).
|
||||
|
||||
Covers:
|
||||
- ObligationCandidate / AtomicControlCandidate dataclasses
|
||||
- Normative signal detection (regex patterns)
|
||||
- Quality Gate (all 6 checks)
|
||||
- passes_quality_gate logic
|
||||
- _compute_extraction_confidence
|
||||
- _parse_json_array / _parse_json_object
|
||||
- _format_field / _format_citation
|
||||
- _normalize_severity
|
||||
- _template_fallback
|
||||
- _build_pass0a_prompt / _build_pass0b_prompt
|
||||
- DecompositionPass.run_pass0a (mocked LLM + DB)
|
||||
- DecompositionPass.run_pass0b (mocked LLM + DB)
|
||||
- DecompositionPass.decomposition_status (mocked DB)
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
|
||||
from compliance.services.decomposition_pass import (
|
||||
ObligationCandidate,
|
||||
AtomicControlCandidate,
|
||||
quality_gate,
|
||||
passes_quality_gate,
|
||||
_NORMATIVE_RE,
|
||||
_RATIONALE_RE,
|
||||
_TEST_RE,
|
||||
_REPORTING_RE,
|
||||
_parse_json_array,
|
||||
_parse_json_object,
|
||||
_ensure_list,
|
||||
_format_field,
|
||||
_format_citation,
|
||||
_compute_extraction_confidence,
|
||||
_normalize_severity,
|
||||
_template_fallback,
|
||||
_build_pass0a_prompt,
|
||||
_build_pass0b_prompt,
|
||||
_PASS0A_SYSTEM_PROMPT,
|
||||
_PASS0B_SYSTEM_PROMPT,
|
||||
DecompositionPass,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DATACLASS TESTS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestObligationCandidate:
    """Tests for ObligationCandidate dataclass."""

    def test_defaults(self):
        candidate = ObligationCandidate()
        assert candidate.candidate_id == ""
        assert candidate.normative_strength == "must"
        assert candidate.is_test_obligation is False
        assert candidate.release_state == "extracted"
        assert candidate.quality_flags == {}

    def test_to_dict(self):
        candidate = ObligationCandidate(
            candidate_id="OC-001-01",
            parent_control_uuid="uuid-1",
            obligation_text="Betreiber müssen MFA implementieren",
            action="implementieren",
            object_="MFA",
        )
        serialized = candidate.to_dict()
        assert serialized["candidate_id"] == "OC-001-01"
        assert serialized["object"] == "MFA"
        # The trailing-underscore field must serialize as plain "object".
        assert "object_" not in serialized

    def test_full_creation(self):
        candidate = ObligationCandidate(
            candidate_id="OC-MICA-0001-01",
            parent_control_uuid="uuid-abc",
            obligation_text="Betreiber müssen Kontinuität sicherstellen",
            action="sicherstellen",
            object_="Dienstleistungskontinuität",
            condition="bei Ausfall des Handelssystems",
            normative_strength="must",
            is_test_obligation=False,
            is_reporting_obligation=False,
            extraction_confidence=0.90,
        )
        assert candidate.condition == "bei Ausfall des Handelssystems"
        assert candidate.extraction_confidence == 0.90
|
||||
|
||||
|
||||
class TestAtomicControlCandidate:
    """Tests for AtomicControlCandidate dataclass."""

    def test_defaults(self):
        candidate = AtomicControlCandidate()
        assert candidate.severity == "medium"
        assert candidate.requirements == []
        assert candidate.test_procedure == []

    def test_to_dict(self):
        candidate = AtomicControlCandidate(
            candidate_id="AC-FIN-001",
            title="Service Continuity Mechanism",
            objective="Ensure continuity upon failure.",
            requirements=["Failover mechanism"],
        )
        serialized = candidate.to_dict()
        assert serialized["title"] == "Service Continuity Mechanism"
        assert len(serialized["requirements"]) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# NORMATIVE SIGNAL DETECTION TESTS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNormativeSignals:
    """Tests for normative regex patterns (German and English signals)."""

    def test_muessen_detected(self):
        assert _NORMATIVE_RE.search("Betreiber müssen sicherstellen") is not None

    def test_muss_detected(self):
        assert _NORMATIVE_RE.search("Das System muss implementiert sein") is not None

    def test_hat_sicherzustellen(self):
        assert _NORMATIVE_RE.search("Der Verantwortliche hat sicherzustellen") is not None

    def test_sind_verpflichtet(self):
        assert _NORMATIVE_RE.search("Anbieter sind verpflichtet zu melden") is not None

    def test_ist_zu_dokumentieren(self):
        assert _NORMATIVE_RE.search("Der Vorfall ist zu dokumentieren") is not None

    def test_shall(self):
        assert _NORMATIVE_RE.search("The operator shall implement MFA") is not None

    def test_no_signal(self):
        # Purely descriptive text carries no normative signal.
        assert _NORMATIVE_RE.search("Die Sonne scheint heute") is None

    def test_rationale_detected(self):
        assert _RATIONALE_RE.search("da schwache Passwörter Risiken bergen") is not None

    def test_test_signal_detected(self):
        assert _TEST_RE.search("regelmäßige Tests der Wirksamkeit") is not None

    def test_reporting_signal_detected(self):
        assert _REPORTING_RE.search("Behörden sind zu unterrichten") is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# QUALITY GATE TESTS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQualityGate:
    """Tests for quality_gate function."""

    @staticmethod
    def _gate(text, parent="uuid-1"):
        """Run quality_gate on a candidate built from *text*."""
        return quality_gate(
            ObligationCandidate(parent_control_uuid=parent, obligation_text=text)
        )

    def test_valid_normative_obligation(self):
        flags = self._gate("Betreiber müssen Verschlüsselung implementieren")
        assert flags["has_normative_signal"] is True
        assert flags["not_evidence_only"] is True
        assert flags["min_length"] is True
        assert flags["has_parent_link"] is True

    def test_rationale_detected(self):
        flags = self._gate(
            "Schwache Passwörter können zu Risiken führen, weil sie leicht zu erraten sind"
        )
        assert flags["not_rationale"] is False

    def test_evidence_only_rejected(self):
        flags = self._gate("Screenshot der Konfiguration")
        assert flags["not_evidence_only"] is False

    def test_too_short_rejected(self):
        flags = self._gate("MFA")
        assert flags["min_length"] is False

    def test_no_parent_link(self):
        flags = self._gate("Betreiber müssen MFA implementieren", parent="")
        assert flags["has_parent_link"] is False

    def test_multi_verb_detected(self):
        flags = self._gate(
            "Betreiber müssen implementieren und dokumentieren sowie regelmäßig testen"
        )
        assert flags["single_action"] is False

    def test_single_verb_passes(self):
        flags = self._gate(
            "Betreiber müssen MFA für alle privilegierten Konten implementieren"
        )
        assert flags["single_action"] is True

    def test_no_normative_signal(self):
        flags = self._gate(
            "Ein DR-Plan beschreibt die Wiederherstellungsprozeduren im Detail"
        )
        assert flags["has_normative_signal"] is False
|
||||
|
||||
|
||||
class TestPassesQualityGate:
    """Tests for passes_quality_gate function."""

    @staticmethod
    def _flags(**overrides):
        """All-True flag dict with selected checks overridden."""
        flags = {
            "has_normative_signal": True,
            "single_action": True,
            "not_rationale": True,
            "not_evidence_only": True,
            "min_length": True,
            "has_parent_link": True,
        }
        flags.update(overrides)
        return flags

    def test_all_critical_pass(self):
        assert passes_quality_gate(self._flags()) is True

    def test_no_normative_signal_fails(self):
        assert passes_quality_gate(self._flags(has_normative_signal=False)) is False

    def test_evidence_only_fails(self):
        assert passes_quality_gate(self._flags(not_evidence_only=False)) is False

    def test_non_critical_dont_block(self):
        """single_action and not_rationale are NOT critical — should still pass."""
        flags = self._flags(single_action=False, not_rationale=False)
        assert passes_quality_gate(flags) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HELPER TESTS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestComputeExtractionConfidence:
    """Tests for _compute_extraction_confidence."""

    _FLAG_NAMES = (
        "has_normative_signal",
        "single_action",
        "not_rationale",
        "not_evidence_only",
        "min_length",
        "has_parent_link",
    )

    @classmethod
    def _uniform(cls, value):
        """Flag dict with every check set to *value*."""
        return dict.fromkeys(cls._FLAG_NAMES, value)

    def test_all_flags_pass(self):
        assert _compute_extraction_confidence(self._uniform(True)) == 1.0

    def test_no_flags_pass(self):
        assert _compute_extraction_confidence(self._uniform(False)) == 0.0

    def test_partial_flags(self):
        flags = self._uniform(True)
        # Dropping single_action removes its weight; the remaining checks
        # contribute 0.30 + 0.20 + 0.15 + 0.10 + 0.05 = 0.80.
        flags["single_action"] = False
        assert _compute_extraction_confidence(flags) == 0.80
|
||||
|
||||
|
||||
class TestParseJsonArray:
    """Tests for _parse_json_array."""

    def test_valid_array(self):
        parsed = _parse_json_array('[{"a": 1}, {"a": 2}]')
        assert len(parsed) == 2
        assert parsed[0]["a"] == 1

    def test_single_object_wrapped(self):
        # A bare object is promoted to a one-element list.
        assert len(_parse_json_array('{"a": 1}')) == 1

    def test_embedded_in_text(self):
        # Surrounding prose from the LLM must be tolerated.
        parsed = _parse_json_array('Here is the result:\n[{"a": 1}]\nDone.')
        assert len(parsed) == 1

    def test_invalid_returns_empty(self):
        assert _parse_json_array("not json at all") == []

    def test_empty_array(self):
        assert _parse_json_array("[]") == []
|
||||
|
||||
|
||||
class TestParseJsonObject:
    """Tests for _parse_json_object."""

    def test_valid_object(self):
        assert _parse_json_object('{"title": "MFA"}')["title"] == "MFA"

    def test_embedded_in_text(self):
        # Markdown code fences around the JSON must be tolerated.
        parsed = _parse_json_object('```json\n{"title": "MFA"}\n```')
        assert parsed["title"] == "MFA"

    def test_invalid_returns_empty(self):
        assert _parse_json_object("not json") == {}
|
||||
|
||||
|
||||
class TestEnsureList:
    """Tests for _ensure_list."""

    def test_list_passthrough(self):
        """Lists are returned unchanged."""
        assert _ensure_list(["a", "b"]) == ["a", "b"]

    def test_string_wrapped(self):
        """A non-empty string becomes a one-element list."""
        assert _ensure_list("hello") == ["hello"]

    def test_empty_string(self):
        """An empty string maps to an empty list."""
        assert _ensure_list("") == []

    def test_none(self):
        """None maps to an empty list."""
        assert _ensure_list(None) == []

    def test_int(self):
        """Non-string scalars (e.g. ints) map to an empty list."""
        assert _ensure_list(42) == []
|
||||
|
||||
|
||||
class TestFormatField:
    """Tests for _format_field."""

    def test_string_passthrough(self):
        """Plain strings pass through untouched."""
        assert _format_field("hello") == "hello"

    def test_json_list_string(self):
        """A JSON-encoded list renders as markdown bullet points."""
        rendered = _format_field('["Req 1", "Req 2"]')
        for bullet in ("- Req 1", "- Req 2"):
            assert bullet in rendered

    def test_list_input(self):
        """An actual list also renders as bullet points."""
        rendered = _format_field(["A", "B"])
        for bullet in ("- A", "- B"):
            assert bullet in rendered

    def test_empty(self):
        """Empty or missing values format as the empty string."""
        assert _format_field("") == ""
        assert _format_field(None) == ""
|
||||
|
||||
|
||||
class TestFormatCitation:
    """Tests for _format_citation."""

    def test_json_dict(self):
        """A JSON citation dict renders with both source and article."""
        rendered = _format_citation('{"source": "MiCA", "article": "Art. 8"}')
        for fragment in ("MiCA", "Art. 8"):
            assert fragment in rendered

    def test_plain_string(self):
        """A plain-text citation passes through unchanged."""
        assert _format_citation("MiCA Art. 8") == "MiCA Art. 8"

    def test_empty(self):
        """Empty or missing citations format as the empty string."""
        assert _format_citation("") == ""
        assert _format_citation(None) == ""
|
||||
|
||||
|
||||
class TestNormalizeSeverity:
    """Tests for _normalize_severity."""

    def test_valid_values(self):
        """Known severities normalize case- and whitespace-insensitively."""
        cases = {
            "critical": "critical",
            "HIGH": "high",
            " Medium ": "medium",
            "low": "low",
        }
        for raw, expected in cases.items():
            assert _normalize_severity(raw) == expected

    def test_invalid_defaults_to_medium(self):
        """Unknown, empty, or missing values fall back to 'medium'."""
        for raw in ("unknown", "", None):
            assert _normalize_severity(raw) == "medium"
|
||||
|
||||
|
||||
class TestTemplateFallback:
    """Tests for _template_fallback."""

    @staticmethod
    def _build(**overrides):
        """Call _template_fallback with neutral defaults, overridden per test."""
        params = {
            "obligation_text": "",
            "action": "",
            "object_": "",
            "parent_title": "",
            "parent_severity": "medium",
            "parent_category": "",
            "is_test": False,
            "is_reporting": False,
        }
        params.update(overrides)
        return _template_fallback(**params)

    def test_normal_obligation(self):
        """A plain implementation duty yields a titled, single-requirement control."""
        ac = self._build(
            obligation_text="Betreiber müssen MFA implementieren",
            action="implementieren",
            object_="MFA",
            parent_title="Authentication Controls",
            parent_severity="high",
            parent_category="authentication",
        )
        assert "Implementieren" in ac.title
        assert ac.severity == "high"
        assert len(ac.requirements) == 1

    def test_test_obligation(self):
        """A test duty gets a 'Test:' title and test-protocol evidence."""
        ac = self._build(
            obligation_text="MFA muss regelmäßig getestet werden",
            action="testen",
            object_="MFA-Wirksamkeit",
            parent_title="MFA Control",
            parent_severity="medium",
            parent_category="auth",
            is_test=True,
        )
        assert "Test:" in ac.title
        assert "Testprotokoll" in ac.evidence

    def test_reporting_obligation(self):
        """A reporting duty gets a 'Meldepflicht:' title and process evidence."""
        ac = self._build(
            obligation_text="Behörden sind über Vorfälle zu informieren",
            action="informieren",
            object_="zuständige Behörden",
            parent_title="Incident Reporting",
            parent_severity="high",
            parent_category="governance",
            is_reporting=True,
        )
        assert "Meldepflicht:" in ac.title
        assert "Meldeprozess-Dokumentation" in ac.evidence
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PROMPT BUILDER TESTS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPromptBuilders:
    """Tests for LLM prompt builders."""

    def test_pass0a_prompt_contains_all_fields(self):
        """The Pass-0a prompt embeds every control field and demands a JSON array."""
        prompt = _build_pass0a_prompt(
            title="MFA Control",
            objective="Implement MFA",
            requirements="- Require TOTP\n- Hardware key",
            test_procedure="- Test login",
            source_ref="DSGVO Art. 32",
        )
        expected_fragments = (
            "MFA Control",
            "Implement MFA",
            "Require TOTP",
            "DSGVO Art. 32",
            "JSON-Array",
        )
        for fragment in expected_fragments:
            assert fragment in prompt

    def test_pass0b_prompt_contains_all_fields(self):
        """The Pass-0b prompt embeds the obligation fields and demands JSON."""
        prompt = _build_pass0b_prompt(
            obligation_text="MFA implementieren",
            action="implementieren",
            object_="MFA",
            parent_title="Auth Controls",
            parent_category="authentication",
            source_ref="DSGVO Art. 32",
        )
        expected_fragments = (
            "MFA implementieren",
            "implementieren",
            "Auth Controls",
            "JSON",
        )
        for fragment in expected_fragments:
            assert fragment in prompt

    def test_system_prompts_exist(self):
        """Both system prompts carry their key instruction markers."""
        assert "REGELN" in _PASS0A_SYSTEM_PROMPT
        assert "atomares" in _PASS0B_SYSTEM_PROMPT
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DECOMPOSITION PASS INTEGRATION TESTS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDecompositionPassRun0a:
    """Tests for DecompositionPass.run_pass0a.

    Pass 0a reads rich controls from the DB, asks the (mocked) LLM to split
    each into atomic obligation candidates, validates them, and inserts the
    candidates. The DB is a MagicMock; only call counts and the commit are
    verified, not SQL text.
    """

    @pytest.mark.asyncio
    async def test_pass0a_extracts_obligations(self):
        """Happy path: one rich control yields two validated obligations."""
        mock_db = MagicMock()

        # Rich controls to decompose.
        # Row shape (positional): uuid, control_id, title, objective,
        # requirements (JSON), test_procedure (JSON), citation (JSON), category.
        mock_rows = MagicMock()
        mock_rows.fetchall.return_value = [
            (
                "uuid-1", "CTRL-001",
                "Service Continuity",
                "Sicherstellen der Dienstleistungskontinuität",
                '["Mechanismen implementieren", "Systeme testen"]',
                '["Prüfung der Mechanismen"]',
                '{"source": "MiCA", "article": "Art. 8"}',
                "finance",
            ),
        ]
        mock_db.execute.return_value = mock_rows

        # Mocked LLM output: one implementation duty and one test duty, both
        # well-formed, so both should pass candidate validation.
        llm_response = json.dumps([
            {
                "obligation_text": "Betreiber müssen Mechanismen zur Dienstleistungskontinuität implementieren",
                "action": "implementieren",
                "object": "Kontinuitätsmechanismen",
                "condition": "bei Ausfall des Handelssystems",
                "normative_strength": "must",
                "is_test_obligation": False,
                "is_reporting_obligation": False,
            },
            {
                "obligation_text": "Kontinuitätsmechanismen müssen regelmäßig getestet werden",
                "action": "testen",
                "object": "Kontinuitätsmechanismen",
                "condition": None,
                "normative_strength": "must",
                "is_test_obligation": True,
                "is_reporting_obligation": False,
            },
        ])

        with patch("compliance.services.obligation_extractor._llm_ollama", new_callable=AsyncMock) as mock_llm:
            mock_llm.return_value = llm_response

            decomp = DecompositionPass(db=mock_db)
            stats = await decomp.run_pass0a(limit=10)

            assert stats["controls_processed"] == 1
            assert stats["obligations_extracted"] == 2
            assert stats["obligations_validated"] == 2
            assert stats["errors"] == 0

            # Verify DB writes: 1 SELECT + 2 INSERTs + 1 COMMIT
            assert mock_db.execute.call_count >= 3
            mock_db.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_pass0a_fallback_on_empty_llm(self):
        """When the LLM returns non-JSON, a fallback obligation is derived
        from the control's objective text instead of failing the control."""
        mock_db = MagicMock()

        mock_rows = MagicMock()
        mock_rows.fetchall.return_value = [
            (
                "uuid-1", "CTRL-001",
                "MFA Control",
                "Betreiber müssen MFA implementieren",
                "", "", "", "auth",
            ),
        ]
        mock_db.execute.return_value = mock_rows

        with patch("compliance.services.obligation_extractor._llm_ollama", new_callable=AsyncMock) as mock_llm:
            mock_llm.return_value = "I cannot help with that."  # Invalid JSON

            decomp = DecompositionPass(db=mock_db)
            stats = await decomp.run_pass0a(limit=10)

            assert stats["controls_processed"] == 1
            # Fallback should create 1 obligation from the objective
            assert stats["obligations_extracted"] == 1

    @pytest.mark.asyncio
    async def test_pass0a_skips_empty_controls(self):
        """Controls with no title/objective text are skipped pre-LLM."""
        mock_db = MagicMock()

        mock_rows = MagicMock()
        mock_rows.fetchall.return_value = [
            ("uuid-1", "CTRL-001", "", "", "", "", "", ""),
        ]
        mock_db.execute.return_value = mock_rows

        # No LLM call needed — empty controls are skipped before LLM
        decomp = DecompositionPass(db=mock_db)
        stats = await decomp.run_pass0a(limit=10)

        assert stats["controls_skipped_empty"] == 1
        assert stats["controls_processed"] == 0

    @pytest.mark.asyncio
    async def test_pass0a_rejects_evidence_only(self):
        """A documentation-only candidate is extracted but then rejected by
        the quality validation (counted in obligations_rejected)."""
        mock_db = MagicMock()

        mock_rows = MagicMock()
        mock_rows.fetchall.return_value = [
            (
                "uuid-1", "CTRL-001",
                "Evidence List",
                "Betreiber müssen Nachweise erbringen",
                "", "", "", "governance",
            ),
        ]
        mock_db.execute.return_value = mock_rows

        # Evidence-only obligation: action 'dokumentieren' with no normative
        # implementation content — should fail validation.
        llm_response = json.dumps([
            {
                "obligation_text": "Dokumentation der Konfiguration",
                "action": "dokumentieren",
                "object": "Konfiguration",
                "condition": None,
                "normative_strength": "must",
                "is_test_obligation": False,
                "is_reporting_obligation": False,
            },
        ])

        with patch("compliance.services.obligation_extractor._llm_ollama", new_callable=AsyncMock) as mock_llm:
            mock_llm.return_value = llm_response

            decomp = DecompositionPass(db=mock_db)
            stats = await decomp.run_pass0a(limit=10)

            assert stats["obligations_extracted"] == 1
            assert stats["obligations_rejected"] == 1
|
||||
|
||||
|
||||
class TestDecompositionPassRun0b:
    """Tests for DecompositionPass.run_pass0b.

    Pass 0b turns validated obligation candidates into atomic controls.
    The DB mock dispatches on the execute() call index (SELECT, then the
    atomic-sequence lookup, then INSERT/UPDATE), so the call ordering in
    these tests mirrors the implementation's query order.
    """

    @pytest.mark.asyncio
    async def test_pass0b_creates_atomic_controls(self):
        """Happy path: one candidate + valid LLM JSON → one atomic control."""
        mock_db = MagicMock()

        # Validated obligation candidates.
        # Row shape (positional): candidate uuid, candidate_id, parent uuid,
        # obligation_text, action, object, is_test, is_reporting,
        # parent title, parent category, citation JSON, severity, control_id.
        mock_rows = MagicMock()
        mock_rows.fetchall.return_value = [
            (
                "oc-uuid-1", "OC-CTRL-001-01", "parent-uuid-1",
                "Betreiber müssen Kontinuität sicherstellen",
                "sicherstellen", "Dienstleistungskontinuität",
                False, False,  # is_test, is_reporting
                "Service Continuity", "finance",
                '{"source": "MiCA", "article": "Art. 8"}',
                "high", "FIN-001",
            ),
        ]

        # Mock _next_atomic_seq result
        mock_seq = MagicMock()
        mock_seq.fetchone.return_value = (0,)

        # Call sequence: 1=SELECT, 2=_next_atomic_seq, 3=INSERT control, 4=UPDATE oc
        call_count = [0]
        def side_effect(*args, **kwargs):
            call_count[0] += 1
            if call_count[0] == 1:
                return mock_rows  # SELECT candidates
            if call_count[0] == 2:
                return mock_seq  # _next_atomic_seq
            return MagicMock()  # INSERT/UPDATE

        mock_db.execute.side_effect = side_effect

        # Well-formed LLM composition response for the atomic control.
        llm_response = json.dumps({
            "title": "Dienstleistungskontinuität bei Systemausfall",
            "objective": "Sicherstellen, dass Dienstleistungen fortgeführt werden.",
            "requirements": ["Failover-Mechanismus implementieren"],
            "test_procedure": ["Failover-Test durchführen"],
            "evidence": ["Systemarchitektur", "DR-Plan"],
            "severity": "high",
            "category": "operations",
        })

        with patch("compliance.services.obligation_extractor._llm_ollama", new_callable=AsyncMock) as mock_llm:
            mock_llm.return_value = llm_response

            decomp = DecompositionPass(db=mock_db)
            stats = await decomp.run_pass0b(limit=10)

            assert stats["candidates_processed"] == 1
            assert stats["controls_created"] == 1
            assert stats["llm_failures"] == 0

    @pytest.mark.asyncio
    async def test_pass0b_template_fallback(self):
        """When the LLM response is invalid, the template fallback still
        produces a control — counted both as created and as an LLM failure."""
        mock_db = MagicMock()

        mock_rows = MagicMock()
        mock_rows.fetchall.return_value = [
            (
                "oc-uuid-1", "OC-CTRL-001-01", "parent-uuid-1",
                "Betreiber müssen MFA implementieren",
                "implementieren", "MFA",
                False, False,
                "Auth Controls", "authentication",
                "", "high", "AUTH-001",
            ),
        ]

        mock_seq = MagicMock()
        mock_seq.fetchone.return_value = (0,)

        # Same call-index dispatch as the happy-path test above.
        call_count = [0]
        def side_effect(*args, **kwargs):
            call_count[0] += 1
            if call_count[0] == 1:
                return mock_rows
            if call_count[0] == 2:
                return mock_seq
            return MagicMock()

        mock_db.execute.side_effect = side_effect

        with patch("compliance.services.obligation_extractor._llm_ollama", new_callable=AsyncMock) as mock_llm:
            mock_llm.return_value = "Sorry, invalid response"  # LLM fails

            decomp = DecompositionPass(db=mock_db)
            stats = await decomp.run_pass0b(limit=10)

            assert stats["controls_created"] == 1
            assert stats["llm_failures"] == 1
|
||||
|
||||
|
||||
class TestDecompositionStatus:
    """Tests for DecompositionPass.decomposition_status."""

    @staticmethod
    def _status_for(counts):
        """Run decomposition_status against a DB mock returning *counts*."""
        db = MagicMock()
        row = MagicMock()
        row.fetchone.return_value = counts
        db.execute.return_value = row
        return DecompositionPass(db=db).decomposition_status()

    def test_returns_status(self):
        """All counters and derived percentages are surfaced."""
        status = self._status_for((5000, 1000, 3000, 2500, 200, 2000, 1800))
        expected = {
            "rich_controls": 5000,
            "decomposed_controls": 1000,
            "total_candidates": 3000,
            "validated": 2500,
            "rejected": 200,
            "composed": 2000,
            "atomic_controls": 1800,
            "decomposition_pct": 20.0,
            "composition_pct": 80.0,
        }
        for key, value in expected.items():
            assert status[key] == value

    def test_handles_zero_division(self):
        """All-zero counters yield 0.0 percentages rather than a crash."""
        status = self._status_for((0, 0, 0, 0, 0, 0, 0))
        assert status["decomposition_pct"] == 0.0
        assert status["composition_pct"] == 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MIGRATION 061 SCHEMA TESTS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMigration061:
    """Tests for migration 061 SQL file (obligation candidates / Pass 0)."""

    @staticmethod
    def _migration_path():
        """Resolve the 061 migration path relative to this test file.

        Centralised here so the path is built once instead of being
        duplicated in every test.
        """
        from pathlib import Path
        return Path(__file__).parent.parent / "migrations" / "061_obligation_candidates.sql"

    def test_migration_file_exists(self):
        assert self._migration_path().exists(), "Migration 061 file missing"

    def test_migration_contains_required_tables(self):
        # Read with an explicit encoding, consistent with the migration-060
        # tests (plain read_text() depends on the platform's locale default).
        content = self._migration_path().read_text(encoding="utf-8")
        required_identifiers = (
            "obligation_candidates",
            "parent_control_uuid",
            "decomposition_method",
            "candidate_id",
            "quality_flags",
        )
        for identifier in required_identifiers:
            assert identifier in content, f"Migration 061 missing: {identifier}"
|
||||
428
backend-compliance/tests/test_migration_060.py
Normal file
428
backend-compliance/tests/test_migration_060.py
Normal file
@@ -0,0 +1,428 @@
|
||||
"""Tests for Migration 060: Multi-Layer Control Architecture DB Schema.
|
||||
|
||||
Validates SQL syntax, table definitions, constraints, and indexes
|
||||
defined in 060_crosswalk_matrix.sql.
|
||||
|
||||
Uses an in-memory SQLite-compatible approach: we parse the SQL and validate
|
||||
the structure, then run it against a real PostgreSQL test database if available.
|
||||
"""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Absolute path of the migration under test; resolved relative to this test
# file so the suite works regardless of the directory pytest is invoked from.
MIGRATION_FILE = (
    Path(__file__).resolve().parent.parent / "migrations" / "060_crosswalk_matrix.sql"
)
|
||||
|
||||
|
||||
@pytest.fixture
def migration_sql():
    """Load the migration SQL file."""
    path = MIGRATION_FILE
    assert path.exists(), f"Migration file not found: {path}"
    return path.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SQL File Structure Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestMigrationFileStructure:
|
||||
"""Validate the migration file exists and has correct structure."""
|
||||
|
||||
def test_file_exists(self):
|
||||
assert MIGRATION_FILE.exists()
|
||||
|
||||
def test_file_not_empty(self, migration_sql):
|
||||
assert len(migration_sql.strip()) > 100
|
||||
|
||||
def test_has_migration_header_comment(self, migration_sql):
|
||||
assert "Migration 060" in migration_sql
|
||||
assert "Multi-Layer Control Architecture" in migration_sql
|
||||
|
||||
def test_no_explicit_transaction_control(self, migration_sql):
|
||||
"""Migration runner strips BEGIN/COMMIT — file should not contain them."""
|
||||
lines = migration_sql.split("\n")
|
||||
for line in lines:
|
||||
stripped = line.strip().upper()
|
||||
if stripped.startswith("--"):
|
||||
continue
|
||||
assert stripped != "BEGIN;", "Migration should not contain explicit BEGIN"
|
||||
assert stripped != "COMMIT;", "Migration should not contain explicit COMMIT"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Table Definition Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestObligationExtractionsTable:
|
||||
"""Validate obligation_extractions table definition."""
|
||||
|
||||
def test_create_table_present(self, migration_sql):
|
||||
assert "CREATE TABLE IF NOT EXISTS obligation_extractions" in migration_sql
|
||||
|
||||
def test_has_primary_key(self, migration_sql):
|
||||
# Extract the CREATE TABLE block
|
||||
block = _extract_create_table(migration_sql, "obligation_extractions")
|
||||
assert "id UUID PRIMARY KEY" in block
|
||||
|
||||
def test_has_chunk_hash_column(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "obligation_extractions")
|
||||
assert "chunk_hash VARCHAR(64) NOT NULL" in block
|
||||
|
||||
def test_has_collection_column(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "obligation_extractions")
|
||||
assert "collection VARCHAR(100) NOT NULL" in block
|
||||
|
||||
def test_has_regulation_code_column(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "obligation_extractions")
|
||||
assert "regulation_code VARCHAR(100) NOT NULL" in block
|
||||
|
||||
def test_has_obligation_id_column(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "obligation_extractions")
|
||||
assert "obligation_id VARCHAR(50)" in block
|
||||
|
||||
def test_has_confidence_column_with_check(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "obligation_extractions")
|
||||
assert "confidence NUMERIC(3,2)" in block
|
||||
assert "confidence >= 0" in block
|
||||
assert "confidence <= 1" in block
|
||||
|
||||
def test_extraction_method_check_constraint(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "obligation_extractions")
|
||||
assert "extraction_method VARCHAR(30) NOT NULL" in block
|
||||
for method in ("exact_match", "embedding_match", "llm_extracted", "inferred"):
|
||||
assert method in block, f"Missing extraction_method: {method}"
|
||||
|
||||
def test_has_pattern_id_column(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "obligation_extractions")
|
||||
assert "pattern_id VARCHAR(50)" in block
|
||||
|
||||
def test_has_pattern_match_score_with_check(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "obligation_extractions")
|
||||
assert "pattern_match_score NUMERIC(3,2)" in block
|
||||
|
||||
def test_has_control_uuid_fk(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "obligation_extractions")
|
||||
assert "control_uuid UUID REFERENCES canonical_controls(id)" in block
|
||||
|
||||
def test_has_job_id_fk(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "obligation_extractions")
|
||||
assert "job_id UUID REFERENCES canonical_generation_jobs(id)" in block
|
||||
|
||||
def test_has_created_at(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "obligation_extractions")
|
||||
assert "created_at TIMESTAMPTZ" in block
|
||||
|
||||
def test_indexes_created(self, migration_sql):
|
||||
expected_indexes = [
|
||||
"idx_oe_obligation",
|
||||
"idx_oe_pattern",
|
||||
"idx_oe_control",
|
||||
"idx_oe_regulation",
|
||||
"idx_oe_chunk",
|
||||
"idx_oe_method",
|
||||
]
|
||||
for idx in expected_indexes:
|
||||
assert idx in migration_sql, f"Missing index: {idx}"
|
||||
|
||||
|
||||
class TestControlPatternsTable:
|
||||
"""Validate control_patterns table definition."""
|
||||
|
||||
def test_create_table_present(self, migration_sql):
|
||||
assert "CREATE TABLE IF NOT EXISTS control_patterns" in migration_sql
|
||||
|
||||
def test_has_primary_key(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "control_patterns")
|
||||
assert "id UUID PRIMARY KEY" in block
|
||||
|
||||
def test_pattern_id_unique(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "control_patterns")
|
||||
assert "pattern_id VARCHAR(50) UNIQUE NOT NULL" in block
|
||||
|
||||
def test_has_name_column(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "control_patterns")
|
||||
assert "name VARCHAR(255) NOT NULL" in block
|
||||
|
||||
def test_has_name_de_column(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "control_patterns")
|
||||
assert "name_de VARCHAR(255)" in block
|
||||
|
||||
def test_has_domain_column(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "control_patterns")
|
||||
assert "domain VARCHAR(10) NOT NULL" in block
|
||||
|
||||
def test_has_category_column(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "control_patterns")
|
||||
assert "category VARCHAR(50)" in block
|
||||
|
||||
def test_has_template_fields(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "control_patterns")
|
||||
assert "template_objective TEXT" in block
|
||||
assert "template_rationale TEXT" in block
|
||||
assert "template_requirements JSONB" in block
|
||||
assert "template_test_procedure JSONB" in block
|
||||
assert "template_evidence JSONB" in block
|
||||
|
||||
def test_severity_check_constraint(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "control_patterns")
|
||||
for severity in ("low", "medium", "high", "critical"):
|
||||
assert severity in block, f"Missing severity: {severity}"
|
||||
|
||||
def test_effort_check_constraint(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "control_patterns")
|
||||
assert "implementation_effort_default" in block
|
||||
|
||||
def test_has_keyword_and_tag_fields(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "control_patterns")
|
||||
assert "obligation_match_keywords JSONB" in block
|
||||
assert "tags JSONB" in block
|
||||
|
||||
def test_has_anchor_refs(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "control_patterns")
|
||||
assert "open_anchor_refs JSONB" in block
|
||||
|
||||
def test_has_composable_with(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "control_patterns")
|
||||
assert "composable_with JSONB" in block
|
||||
|
||||
def test_has_version(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "control_patterns")
|
||||
assert "version VARCHAR(10)" in block
|
||||
|
||||
def test_indexes_created(self, migration_sql):
|
||||
expected_indexes = ["idx_cp_domain", "idx_cp_category", "idx_cp_pattern_id"]
|
||||
for idx in expected_indexes:
|
||||
assert idx in migration_sql, f"Missing index: {idx}"
|
||||
|
||||
|
||||
class TestCrosswalkMatrixTable:
|
||||
"""Validate crosswalk_matrix table definition."""
|
||||
|
||||
def test_create_table_present(self, migration_sql):
|
||||
assert "CREATE TABLE IF NOT EXISTS crosswalk_matrix" in migration_sql
|
||||
|
||||
def test_has_primary_key(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "crosswalk_matrix")
|
||||
assert "id UUID PRIMARY KEY" in block
|
||||
|
||||
def test_has_regulation_code(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "crosswalk_matrix")
|
||||
assert "regulation_code VARCHAR(100) NOT NULL" in block
|
||||
|
||||
def test_has_article_paragraph(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "crosswalk_matrix")
|
||||
assert "article VARCHAR(100)" in block
|
||||
assert "paragraph VARCHAR(100)" in block
|
||||
|
||||
def test_has_obligation_id(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "crosswalk_matrix")
|
||||
assert "obligation_id VARCHAR(50)" in block
|
||||
|
||||
def test_has_pattern_id(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "crosswalk_matrix")
|
||||
assert "pattern_id VARCHAR(50)" in block
|
||||
|
||||
def test_has_master_control_fields(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "crosswalk_matrix")
|
||||
assert "master_control_id VARCHAR(20)" in block
|
||||
assert "master_control_uuid UUID REFERENCES canonical_controls(id)" in block
|
||||
|
||||
def test_has_tom_control_id(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "crosswalk_matrix")
|
||||
assert "tom_control_id VARCHAR(30)" in block
|
||||
|
||||
def test_confidence_check(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "crosswalk_matrix")
|
||||
assert "confidence NUMERIC(3,2)" in block
|
||||
|
||||
def test_source_check_constraint(self, migration_sql):
|
||||
block = _extract_create_table(migration_sql, "crosswalk_matrix")
|
||||
for source_val in ("manual", "auto", "migrated"):
|
||||
assert source_val in block, f"Missing source value: {source_val}"
|
||||
|
||||
def test_indexes_created(self, migration_sql):
|
||||
expected_indexes = [
|
||||
"idx_cw_regulation",
|
||||
"idx_cw_obligation",
|
||||
"idx_cw_pattern",
|
||||
"idx_cw_control",
|
||||
"idx_cw_tom",
|
||||
]
|
||||
for idx in expected_indexes:
|
||||
assert idx in migration_sql, f"Missing index: {idx}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# ALTER TABLE Tests (canonical_controls extensions)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestCanonicalControlsExtension:
|
||||
"""Validate ALTER TABLE additions to canonical_controls."""
|
||||
|
||||
def test_adds_pattern_id_column(self, migration_sql):
|
||||
assert "ALTER TABLE canonical_controls" in migration_sql
|
||||
assert "pattern_id VARCHAR(50)" in migration_sql
|
||||
|
||||
def test_adds_obligation_ids_column(self, migration_sql):
|
||||
assert "obligation_ids JSONB" in migration_sql
|
||||
|
||||
def test_uses_if_not_exists(self, migration_sql):
|
||||
alter_lines = [
|
||||
line.strip()
|
||||
for line in migration_sql.split("\n")
|
||||
if "ALTER TABLE canonical_controls" in line
|
||||
and "ADD COLUMN" in line
|
||||
]
|
||||
for line in alter_lines:
|
||||
assert "IF NOT EXISTS" in line, (
|
||||
f"ALTER TABLE missing IF NOT EXISTS: {line}"
|
||||
)
|
||||
|
||||
def test_pattern_id_index(self, migration_sql):
|
||||
assert "idx_cc_pattern" in migration_sql
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Cross-Cutting Concerns
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestSQLSafety:
|
||||
"""Validate SQL safety and idempotency."""
|
||||
|
||||
def test_all_tables_use_if_not_exists(self, migration_sql):
|
||||
create_statements = re.findall(
|
||||
r"CREATE TABLE\s+(?:IF NOT EXISTS\s+)?(\w+)", migration_sql
|
||||
)
|
||||
for match in re.finditer(r"CREATE TABLE\s+(\w+)", migration_sql):
|
||||
table_name = match.group(1)
|
||||
if table_name == "IF":
|
||||
continue # This is part of "IF NOT EXISTS"
|
||||
full_match = migration_sql[match.start() : match.start() + 60]
|
||||
assert "IF NOT EXISTS" in full_match, (
|
||||
f"CREATE TABLE {table_name} missing IF NOT EXISTS"
|
||||
)
|
||||
|
||||
def test_all_indexes_use_if_not_exists(self, migration_sql):
|
||||
for match in re.finditer(r"CREATE INDEX\s+(\w+)", migration_sql):
|
||||
idx_name = match.group(1)
|
||||
if idx_name == "IF":
|
||||
continue
|
||||
full_match = migration_sql[match.start() : match.start() + 80]
|
||||
assert "IF NOT EXISTS" in full_match, (
|
||||
f"CREATE INDEX {idx_name} missing IF NOT EXISTS"
|
||||
)
|
||||
|
||||
def test_no_drop_statements(self, migration_sql):
|
||||
"""Migration should only add, never drop."""
|
||||
lines = [
|
||||
l.strip()
|
||||
for l in migration_sql.split("\n")
|
||||
if not l.strip().startswith("--")
|
||||
]
|
||||
sql_content = "\n".join(lines)
|
||||
assert "DROP TABLE" not in sql_content
|
||||
assert "DROP INDEX" not in sql_content
|
||||
assert "DROP COLUMN" not in sql_content
|
||||
|
||||
def test_no_truncate(self, migration_sql):
|
||||
lines = [
|
||||
l.strip()
|
||||
for l in migration_sql.split("\n")
|
||||
if not l.strip().startswith("--")
|
||||
]
|
||||
sql_content = "\n".join(lines)
|
||||
assert "TRUNCATE" not in sql_content
|
||||
|
||||
def test_fk_references_existing_tables(self, migration_sql):
|
||||
"""All REFERENCES must point to canonical_controls or canonical_generation_jobs."""
|
||||
refs = re.findall(r"REFERENCES\s+(\w+)\(", migration_sql)
|
||||
allowed_tables = {"canonical_controls", "canonical_generation_jobs"}
|
||||
for ref in refs:
|
||||
assert ref in allowed_tables, (
|
||||
f"FK reference to unknown table: {ref}"
|
||||
)
|
||||
|
||||
def test_consistent_varchar_sizes(self, migration_sql):
|
||||
"""Key fields should use consistent sizes across tables."""
|
||||
# obligation_id should be VARCHAR(50) everywhere
|
||||
obligation_id_matches = re.findall(
|
||||
r"obligation_id\s+VARCHAR\((\d+)\)", migration_sql
|
||||
)
|
||||
for size in obligation_id_matches:
|
||||
assert size == "50", f"obligation_id should be VARCHAR(50), got {size}"
|
||||
|
||||
# pattern_id should be VARCHAR(50) everywhere
|
||||
pattern_id_matches = re.findall(
|
||||
r"pattern_id\s+VARCHAR\((\d+)\)", migration_sql
|
||||
)
|
||||
for size in pattern_id_matches:
|
||||
assert size == "50", f"pattern_id should be VARCHAR(50), got {size}"
|
||||
|
||||
# regulation_code should be VARCHAR(100) everywhere
|
||||
reg_code_matches = re.findall(
|
||||
r"regulation_code\s+VARCHAR\((\d+)\)", migration_sql
|
||||
)
|
||||
for size in reg_code_matches:
|
||||
assert size == "100", f"regulation_code should be VARCHAR(100), got {size}"
|
||||
|
||||
|
||||
class TestTableComments:
|
||||
"""Validate that all new tables have COMMENT ON TABLE."""
|
||||
|
||||
def test_obligation_extractions_comment(self, migration_sql):
|
||||
assert "COMMENT ON TABLE obligation_extractions" in migration_sql
|
||||
|
||||
def test_control_patterns_comment(self, migration_sql):
|
||||
assert "COMMENT ON TABLE control_patterns" in migration_sql
|
||||
|
||||
def test_crosswalk_matrix_comment(self, migration_sql):
|
||||
assert "COMMENT ON TABLE crosswalk_matrix" in migration_sql
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Data Type Compatibility Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestDataTypeCompatibility:
    """Ensure data types are compatible with existing schema."""

    def test_chunk_hash_matches_processed_chunks(self, migration_sql):
        """chunk_hash in obligation_extractions should match canonical_processed_chunks."""
        columns = _extract_create_table(migration_sql, "obligation_extractions")
        assert "chunk_hash VARCHAR(64)" in columns

    def test_collection_matches_processed_chunks(self, migration_sql):
        """collection size should match canonical_processed_chunks."""
        columns = _extract_create_table(migration_sql, "obligation_extractions")
        assert "collection VARCHAR(100)" in columns

    def test_control_id_size_matches_canonical_controls(self, migration_sql):
        """master_control_id VARCHAR(20) should match canonical_controls.control_id VARCHAR(20)."""
        columns = _extract_create_table(migration_sql, "crosswalk_matrix")
        assert "master_control_id VARCHAR(20)" in columns

    def test_pattern_id_format_documented(self, migration_sql):
        """Pattern ID format CP-{DOMAIN}-{NNN} should be documented."""
        documented = "CP-{DOMAIN}-{NNN}" in migration_sql or "CP-" in migration_sql
        assert documented
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Helpers
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _extract_create_table(sql: str, table_name: str) -> str:
|
||||
"""Extract a CREATE TABLE block from SQL."""
|
||||
pattern = rf"CREATE TABLE IF NOT EXISTS {table_name}\s*\((.*?)\);"
|
||||
match = re.search(pattern, sql, re.DOTALL)
|
||||
if not match:
|
||||
pytest.fail(f"Could not find CREATE TABLE for {table_name}")
|
||||
return match.group(1)
|
||||
939
backend-compliance/tests/test_obligation_extractor.py
Normal file
939
backend-compliance/tests/test_obligation_extractor.py
Normal file
@@ -0,0 +1,939 @@
|
||||
"""Tests for Obligation Extractor — Phase 4 of Multi-Layer Control Architecture.
|
||||
|
||||
Validates:
|
||||
- Regulation code normalization (_normalize_regulation)
|
||||
- Article reference normalization (_normalize_article)
|
||||
- Cosine similarity (_cosine_sim)
|
||||
- JSON parsing from LLM responses (_parse_json)
|
||||
- Obligation loading from v2 framework
|
||||
- 3-Tier extraction: exact_match → embedding_match → llm_extracted
|
||||
- ObligationMatch serialization
|
||||
- Edge cases: empty inputs, missing data, fallback behavior
|
||||
"""
|
||||
|
||||
import json
|
||||
import math
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from compliance.services.obligation_extractor import (
|
||||
EMBEDDING_CANDIDATE_THRESHOLD,
|
||||
EMBEDDING_MATCH_THRESHOLD,
|
||||
ObligationExtractor,
|
||||
ObligationMatch,
|
||||
_ObligationEntry,
|
||||
_cosine_sim,
|
||||
_find_obligations_dir,
|
||||
_normalize_article,
|
||||
_normalize_regulation,
|
||||
_parse_json,
|
||||
)
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
V2_DIR = REPO_ROOT / "ai-compliance-sdk" / "policies" / "obligations" / "v2"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: _normalize_regulation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestNormalizeRegulation:
    """Tests for regulation code normalization."""

    @staticmethod
    def _expect(raw, normalized):
        """Assert that *raw* normalizes to *normalized*."""
        assert _normalize_regulation(raw) == normalized

    def test_dsgvo_eu_code(self):
        self._expect("eu_2016_679", "dsgvo")

    def test_dsgvo_short(self):
        self._expect("dsgvo", "dsgvo")

    def test_gdpr_alias(self):
        self._expect("gdpr", "dsgvo")

    def test_ai_act_eu_code(self):
        self._expect("eu_2024_1689", "ai_act")

    def test_ai_act_short(self):
        self._expect("ai_act", "ai_act")

    def test_nis2_eu_code(self):
        self._expect("eu_2022_2555", "nis2")

    def test_nis2_short(self):
        self._expect("nis2", "nis2")

    def test_bsig_alias(self):
        self._expect("bsig", "nis2")

    def test_bdsg(self):
        self._expect("bdsg", "bdsg")

    def test_ttdsg(self):
        self._expect("ttdsg", "ttdsg")

    def test_dsa_eu_code(self):
        self._expect("eu_2022_2065", "dsa")

    def test_data_act_eu_code(self):
        self._expect("eu_2023_2854", "data_act")

    def test_eu_machinery_eu_code(self):
        self._expect("eu_2023_1230", "eu_machinery")

    def test_dora_eu_code(self):
        self._expect("eu_2022_2554", "dora")

    def test_case_insensitive(self):
        self._expect("DSGVO", "dsgvo")
        self._expect("AI_ACT", "ai_act")
        self._expect("NIS2", "nis2")

    def test_whitespace_stripped(self):
        self._expect(" dsgvo ", "dsgvo")

    def test_empty_string(self):
        assert _normalize_regulation("") is None

    def test_none(self):
        assert _normalize_regulation(None) is None

    def test_unknown_code(self):
        assert _normalize_regulation("mica") is None

    def test_prefix_matching(self):
        """EU codes with suffixes should still match via prefix."""
        self._expect("eu_2016_679_consolidated", "dsgvo")

    def test_all_nine_regulations_covered(self):
        """Every regulation in the manifest should be normalizable."""
        manifest_ids = ["dsgvo", "ai_act", "nis2", "bdsg", "ttdsg", "dsa",
                        "data_act", "eu_machinery", "dora"]
        for reg_id in manifest_ids:
            assert _normalize_regulation(reg_id) == reg_id, f"Regulation {reg_id} not found"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: _normalize_article
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestNormalizeArticle:
    """Tests for article reference normalization."""

    @staticmethod
    def _expect(raw, normalized):
        """Assert that *raw* normalizes to *normalized*."""
        assert _normalize_article(raw) == normalized

    def test_art_with_dot(self):
        self._expect("Art. 30", "art. 30")

    def test_article_english(self):
        self._expect("Article 10", "art. 10")

    def test_artikel_german(self):
        self._expect("Artikel 35", "art. 35")

    def test_paragraph_symbol(self):
        self._expect("§ 38", "§ 38")

    def test_paragraph_with_law_suffix(self):
        """§ 38 BDSG → § 38 (law name stripped)."""
        self._expect("§ 38 BDSG", "§ 38")

    def test_paragraph_with_dsgvo_suffix(self):
        self._expect("Art. 6 DSGVO", "art. 6")

    def test_removes_absatz(self):
        """Art. 30 Abs. 1 → art. 30"""
        self._expect("Art. 30 Abs. 1", "art. 30")

    def test_removes_paragraph(self):
        self._expect("Art. 5 paragraph 2", "art. 5")

    def test_removes_lit(self):
        self._expect("Art. 6 lit. a", "art. 6")

    def test_removes_satz(self):
        self._expect("Art. 12 Satz 3", "art. 12")

    def test_lowercase_output(self):
        self._expect("ART. 30", "art. 30")
        self._expect("ARTICLE 10", "art. 10")

    def test_whitespace_stripped(self):
        self._expect(" Art. 30 ", "art. 30")

    def test_empty_string(self):
        self._expect("", "")

    def test_none(self):
        self._expect(None, "")

    def test_complex_reference(self):
        """Art. 30 Abs. 1 Satz 2 lit. c DSGVO → art. 30"""
        normalized = _normalize_article("Art. 30 Abs. 1 Satz 2 lit. c DSGVO")
        # Should at minimum remove DSGVO and Abs references
        assert normalized.startswith("art. 30")

    def test_nis2_article(self):
        self._expect("Art. 21 NIS2", "art. 21")

    def test_dora_article(self):
        self._expect("Art. 5 DORA", "art. 5")

    def test_ai_act_article(self):
        self._expect("Article 6 AI Act", "art. 6")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: _cosine_sim
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestCosineSim:
    """Tests for cosine similarity calculation."""

    # Absolute tolerance for all float comparisons in this class.
    TOL = 1e-6

    def _assert_close(self, actual, expected):
        """Assert two floats agree within TOL."""
        assert abs(actual - expected) < self.TOL

    def test_identical_vectors(self):
        vec = [1.0, 2.0, 3.0]
        self._assert_close(_cosine_sim(vec, vec), 1.0)

    def test_orthogonal_vectors(self):
        self._assert_close(_cosine_sim([1.0, 0.0], [0.0, 1.0]), 0.0)

    def test_opposite_vectors(self):
        self._assert_close(
            _cosine_sim([1.0, 2.0, 3.0], [-1.0, -2.0, -3.0]), -1.0
        )

    def test_known_value(self):
        # cos(45°) between [1, 0] and [1, 1] is 1/sqrt(2).
        self._assert_close(
            _cosine_sim([1.0, 0.0], [1.0, 1.0]), 1.0 / math.sqrt(2)
        )

    def test_empty_vectors(self):
        assert _cosine_sim([], []) == 0.0

    def test_one_empty(self):
        assert _cosine_sim([1.0, 2.0], []) == 0.0
        assert _cosine_sim([], [1.0, 2.0]) == 0.0

    def test_different_lengths(self):
        assert _cosine_sim([1.0, 2.0], [1.0]) == 0.0

    def test_zero_vector(self):
        assert _cosine_sim([0.0, 0.0], [1.0, 2.0]) == 0.0

    def test_both_zero(self):
        assert _cosine_sim([0.0, 0.0], [0.0, 0.0]) == 0.0

    def test_high_dimensional(self):
        """Test with realistic embedding dimensions (1024)."""
        import random
        random.seed(42)
        lhs = [random.gauss(0, 1) for _ in range(1024)]
        rhs = [random.gauss(0, 1) for _ in range(1024)]
        assert -1.0 <= _cosine_sim(lhs, rhs) <= 1.0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: _parse_json
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestParseJson:
    """Tests for JSON extraction from LLM responses."""

    def test_direct_json(self):
        parsed = _parse_json('{"obligation_text": "Test", "actor": "Controller"}')
        assert parsed["obligation_text"] == "Test"
        assert parsed["actor"] == "Controller"

    def test_json_in_markdown_block(self):
        """LLMs often wrap JSON in markdown code blocks."""
        response = '''Some explanation text
```json
{"obligation_text": "Test"}
```
More text'''
        assert _parse_json(response).get("obligation_text") == "Test"

    def test_json_with_prefix_text(self):
        response = 'Here is the result: {"obligation_text": "Pflicht", "actor": "Verantwortlicher"}'
        assert _parse_json(response)["obligation_text"] == "Pflicht"

    def test_invalid_json(self):
        assert _parse_json("not json at all") == {}

    def test_empty_string(self):
        assert _parse_json("") == {}

    def test_nested_braces_picks_first(self):
        """With nested objects, the regex picks the inner simple object."""
        parsed = _parse_json('{"outer": {"inner": "value"}}')
        # Direct parse should work for valid nested JSON
        assert "outer" in parsed

    def test_json_with_german_umlauts(self):
        parsed = _parse_json(
            '{"obligation_text": "Pflicht zur Datenschutz-Folgenabschaetzung"}'
        )
        assert "Datenschutz" in parsed["obligation_text"]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: ObligationMatch
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestObligationMatch:
    """Tests for the ObligationMatch dataclass."""

    def test_defaults(self):
        blank = ObligationMatch()
        assert blank.obligation_id is None
        assert blank.obligation_title is None
        assert blank.obligation_text is None
        assert blank.method == "none"
        assert blank.confidence == 0.0
        assert blank.regulation_id is None

    def test_to_dict(self):
        populated = ObligationMatch(
            obligation_id="DSGVO-OBL-001",
            obligation_title="Verarbeitungsverzeichnis",
            obligation_text="Fuehrung eines Verzeichnisses...",
            method="exact_match",
            confidence=1.0,
            regulation_id="dsgvo",
        )
        payload = populated.to_dict()
        assert payload["obligation_id"] == "DSGVO-OBL-001"
        assert payload["method"] == "exact_match"
        assert payload["confidence"] == 1.0
        assert payload["regulation_id"] == "dsgvo"

    def test_to_dict_keys(self):
        payload = ObligationMatch().to_dict()
        assert set(payload.keys()) == {
            "obligation_id", "obligation_title", "obligation_text",
            "method", "confidence", "regulation_id",
        }

    def test_to_dict_none_values(self):
        payload = ObligationMatch().to_dict()
        assert payload["obligation_id"] is None
        assert payload["obligation_title"] is None
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: _find_obligations_dir
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestFindObligationsDir:
    """Tests for finding the v2 obligations directory."""

    def test_finds_v2_directory(self):
        """Should find the v2 dir relative to the source file."""
        found = _find_obligations_dir()
        # May be None in CI without the SDK, but if found, verify it's valid
        if found is not None:
            assert found.is_dir()
            assert (found / "_manifest.json").exists()

    def test_v2_dir_exists_in_repo(self):
        """The v2 dir should exist in the repo for local tests."""
        assert V2_DIR.exists(), f"v2 dir not found at {V2_DIR}"
        assert (V2_DIR / "_manifest.json").exists()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: ObligationExtractor — _load_obligations
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestObligationExtractorLoad:
    """Tests for obligation loading from v2 JSON files."""

    @staticmethod
    def _loaded():
        """Return a fresh extractor with the v2 obligations loaded."""
        extractor = ObligationExtractor()
        extractor._load_obligations()
        return extractor

    def test_load_obligations_populates_lookup(self):
        assert len(self._loaded()._obligations) > 0

    def test_load_obligations_count(self):
        """Should load all 325 obligations from 9 regulations."""
        assert len(self._loaded()._obligations) == 325

    def test_article_lookup_populated(self):
        """Article lookup should have entries for obligations with legal_basis."""
        assert len(self._loaded()._article_lookup) > 0

    def test_article_lookup_dsgvo_art30(self):
        """DSGVO Art. 30 should resolve to DSGVO-OBL-001."""
        lookup = self._loaded()._article_lookup
        assert "dsgvo/art. 30" in lookup
        assert "DSGVO-OBL-001" in lookup["dsgvo/art. 30"]

    def test_obligations_have_required_fields(self):
        """Every loaded obligation should have id, title, description, regulation_id."""
        for obl_id, entry in self._loaded()._obligations.items():
            assert entry.id == obl_id
            assert entry.title, f"{obl_id}: empty title"
            assert entry.description, f"{obl_id}: empty description"
            assert entry.regulation_id, f"{obl_id}: empty regulation_id"

    def test_all_nine_regulations_loaded(self):
        """All 9 regulations from the manifest should be loaded."""
        loaded_regs = {
            entry.regulation_id
            for entry in self._loaded()._obligations.values()
        }
        assert loaded_regs == {"dsgvo", "ai_act", "nis2", "bdsg", "ttdsg", "dsa",
                               "data_act", "eu_machinery", "dora"}

    def test_obligation_id_format(self):
        """All obligation IDs should follow the pattern {REG}-OBL-{NNN}."""
        import re
        # Allow letters, digits, underscores in prefix (e.g. NIS2-OBL-001, EU_MACHINERY-OBL-001)
        id_pattern = re.compile(r"^[A-Z0-9_]+-OBL-\d{3}$")
        for obl_id in self._loaded()._obligations:
            assert id_pattern.match(obl_id), f"Invalid obligation ID format: {obl_id}"

    def test_no_duplicate_obligation_ids(self):
        """All obligation IDs should be unique."""
        ids = list(self._loaded()._obligations.keys())
        assert len(ids) == len(set(ids))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: ObligationExtractor — Tier 1 (Exact Match)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestTier1ExactMatch:
    """Tests for Tier 1 exact article lookup."""

    def setup_method(self):
        self.extractor = ObligationExtractor()
        self.extractor._load_obligations()

    def _lookup(self, regulation, article):
        """Run the Tier 1 exact article lookup."""
        return self.extractor._tier1_exact(regulation, article)

    def test_exact_match_dsgvo_art30(self):
        result = self._lookup("dsgvo", "Art. 30")
        assert result is not None
        assert result.obligation_id == "DSGVO-OBL-001"
        assert result.method == "exact_match"
        assert result.confidence == 1.0
        assert result.regulation_id == "dsgvo"

    def test_exact_match_case_insensitive_article(self):
        result = self._lookup("dsgvo", "ART. 30")
        assert result is not None
        assert result.obligation_id == "DSGVO-OBL-001"

    def test_exact_match_article_variant(self):
        """'Article 30' should normalize to 'art. 30' and match."""
        result = self._lookup("dsgvo", "Article 30")
        assert result is not None
        assert result.obligation_id == "DSGVO-OBL-001"

    def test_exact_match_artikel_variant(self):
        result = self._lookup("dsgvo", "Artikel 30")
        assert result is not None
        assert result.obligation_id == "DSGVO-OBL-001"

    def test_exact_match_strips_absatz(self):
        """Art. 30 Abs. 1 → art. 30 → should match."""
        result = self._lookup("dsgvo", "Art. 30 Abs. 1")
        assert result is not None
        assert result.obligation_id == "DSGVO-OBL-001"

    def test_no_match_wrong_article(self):
        assert self._lookup("dsgvo", "Art. 999") is None

    def test_no_match_unknown_regulation(self):
        assert self._lookup("unknown_reg", "Art. 30") is None

    def test_no_match_none_regulation(self):
        assert self._lookup(None, "Art. 30") is None

    def test_match_has_title(self):
        result = self._lookup("dsgvo", "Art. 30")
        assert result is not None
        assert result.obligation_title is not None
        assert len(result.obligation_title) > 0

    def test_match_has_text(self):
        result = self._lookup("dsgvo", "Art. 30")
        assert result is not None
        assert result.obligation_text is not None
        assert len(result.obligation_text) > 20
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: ObligationExtractor — Tier 2 (Embedding Match)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestTier2EmbeddingMatch:
    """Tests for Tier 2 embedding-based matching.

    The real embedding service is never called: setup_method injects small
    synthetic 3D vectors, and each test patches the module-level
    `_get_embedding` coroutine to return a chosen query vector.
    """

    def setup_method(self):
        """Build an extractor with deterministic synthetic embeddings."""
        self.extractor = ObligationExtractor()
        self.extractor._load_obligations()
        # Prepare fake embeddings for testing (no real embedding service)
        self.extractor._obligation_ids = list(self.extractor._obligations.keys())
        # Create simple 3D embeddings per obligation — avoid zero vectors
        self.extractor._obligation_embeddings = []
        for i in range(len(self.extractor._obligation_ids)):
            # Each obligation gets a unique-ish non-zero vector
            self.extractor._obligation_embeddings.append(
                [float(i % 10 + 1), float((i * 3) % 10 + 1), float((i * 7) % 10 + 1)]
            )

    @pytest.mark.asyncio
    async def test_embedding_match_above_threshold(self):
        """When cosine > 0.80, should return embedding_match."""
        # Mock the embedding service to return a vector very similar to obligation 0
        target_embedding = self.extractor._obligation_embeddings[0]

        with patch(
            "compliance.services.obligation_extractor._get_embedding",
            new_callable=AsyncMock,
            return_value=target_embedding,
        ):
            match = await self.extractor._tier2_embedding("test text", "dsgvo")

        # Should find a match (cosine = 1.0 for identical vector)
        assert match is not None
        assert match.method == "embedding_match"
        assert match.confidence >= EMBEDDING_MATCH_THRESHOLD

    @pytest.mark.asyncio
    async def test_embedding_match_returns_none_below_threshold(self):
        """When cosine < 0.80, should return None."""
        # Return a vector orthogonal to all obligations
        orthogonal = [100.0, -100.0, 0.0]

        with patch(
            "compliance.services.obligation_extractor._get_embedding",
            new_callable=AsyncMock,
            return_value=orthogonal,
        ):
            match = await self.extractor._tier2_embedding("unrelated text", None)

        # May or may not match depending on vector distribution
        # But we can verify it's either None or has correct method
        if match is not None:
            assert match.method == "embedding_match"

    @pytest.mark.asyncio
    async def test_embedding_match_empty_embeddings(self):
        """When no embeddings loaded, should return None."""
        self.extractor._obligation_embeddings = []
        match = await self.extractor._tier2_embedding("any text", "dsgvo")
        assert match is None

    @pytest.mark.asyncio
    async def test_embedding_match_failed_embedding(self):
        """When embedding service returns empty, should return None."""
        with patch(
            "compliance.services.obligation_extractor._get_embedding",
            new_callable=AsyncMock,
            return_value=[],
        ):
            match = await self.extractor._tier2_embedding("some text", "dsgvo")
        assert match is None

    @pytest.mark.asyncio
    async def test_domain_bonus_same_regulation(self):
        """Matching regulation should add +0.05 bonus."""
        # Set up two obligations with same embeddings but different regulations,
        # so only the domain bonus can break the tie.
        self.extractor._obligation_ids = ["DSGVO-OBL-001", "NIS2-OBL-001"]
        self.extractor._obligation_embeddings = [
            [1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
        ]

        with patch(
            "compliance.services.obligation_extractor._get_embedding",
            new_callable=AsyncMock,
            return_value=[1.0, 0.0, 0.0],
        ):
            match = await self.extractor._tier2_embedding("test", "dsgvo")

        # Should match (cosine = 1.0 ≥ 0.80)
        assert match is not None
        assert match.method == "embedding_match"
        # With domain bonus, DSGVO should be preferred
        assert match.regulation_id == "dsgvo"

    @pytest.mark.asyncio
    async def test_confidence_capped_at_1(self):
        """Confidence should not exceed 1.0 even with domain bonus."""
        # cosine 1.0 + 0.05 bonus would be 1.05 without the cap.
        self.extractor._obligation_ids = ["DSGVO-OBL-001"]
        self.extractor._obligation_embeddings = [[1.0, 0.0, 0.0]]

        with patch(
            "compliance.services.obligation_extractor._get_embedding",
            new_callable=AsyncMock,
            return_value=[1.0, 0.0, 0.0],
        ):
            match = await self.extractor._tier2_embedding("test", "dsgvo")

        assert match is not None
        assert match.confidence <= 1.0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: ObligationExtractor — Tier 3 (LLM Extraction)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestTier3LLMExtraction:
    """Tests for Tier 3 LLM-based obligation extraction.

    The module-level `_llm_ollama` coroutine is patched in every test, so no
    real LLM call is made.
    """

    def setup_method(self):
        # Tier 3 does not need the obligations loaded — bare extractor suffices.
        self.extractor = ObligationExtractor()

    @pytest.mark.asyncio
    async def test_llm_extraction_success(self):
        """Successful LLM extraction returns obligation_text with confidence 0.60."""
        llm_response = json.dumps({
            "obligation_text": "Pflicht zur Fuehrung eines Verarbeitungsverzeichnisses",
            "actor": "Verantwortlicher",
            "action": "Verarbeitungsverzeichnis fuehren",
            "normative_strength": "muss",
        })

        with patch(
            "compliance.services.obligation_extractor._llm_ollama",
            new_callable=AsyncMock,
            return_value=llm_response,
        ):
            match = await self.extractor._tier3_llm(
                "Der Verantwortliche fuehrt ein Verzeichnis...",
                "eu_2016_679",
                "Art. 30",
            )

        assert match.method == "llm_extracted"
        assert match.confidence == 0.60
        assert "Verarbeitungsverzeichnis" in match.obligation_text
        assert match.obligation_id is None  # LLM doesn't assign IDs
        assert match.regulation_id == "dsgvo"

    @pytest.mark.asyncio
    async def test_llm_extraction_failure(self):
        """When LLM returns empty, should return match with confidence 0."""
        with patch(
            "compliance.services.obligation_extractor._llm_ollama",
            new_callable=AsyncMock,
            return_value="",
        ):
            match = await self.extractor._tier3_llm("some text", "dsgvo", "Art. 1")

        assert match.method == "llm_extracted"
        assert match.confidence == 0.0
        assert match.obligation_text is None

    @pytest.mark.asyncio
    async def test_llm_extraction_malformed_json(self):
        """When LLM returns non-JSON, should use raw text as fallback."""
        with patch(
            "compliance.services.obligation_extractor._llm_ollama",
            new_callable=AsyncMock,
            return_value="Dies ist die Pflicht: Daten schuetzen",
        ):
            match = await self.extractor._tier3_llm("some text", "dsgvo", None)

        assert match.method == "llm_extracted"
        assert match.confidence == 0.60
        # Fallback: uses first 500 chars of response as obligation_text
        assert "Pflicht" in match.obligation_text or "Daten" in match.obligation_text

    @pytest.mark.asyncio
    async def test_llm_regulation_normalization(self):
        """Regulation code should be normalized in result."""
        with patch(
            "compliance.services.obligation_extractor._llm_ollama",
            new_callable=AsyncMock,
            return_value='{"obligation_text": "Test"}',
        ):
            match = await self.extractor._tier3_llm(
                "text", "eu_2024_1689", "Art. 6"
            )

        # eu_2024_1689 is the EU code for the AI Act.
        assert match.regulation_id == "ai_act"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: ObligationExtractor — Full 3-Tier extract()
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestExtractFullFlow:
|
||||
"""Tests for the full 3-tier extraction flow."""
|
||||
|
||||
def setup_method(self):
|
||||
self.extractor = ObligationExtractor()
|
||||
self.extractor._load_obligations()
|
||||
# Mark as initialized to skip async initialize
|
||||
self.extractor._initialized = True
|
||||
# Empty embeddings — Tier 2 will return None
|
||||
self.extractor._obligation_embeddings = []
|
||||
self.extractor._obligation_ids = []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tier1_takes_priority(self):
|
||||
"""When Tier 1 matches, Tier 2 and 3 should not be called."""
|
||||
with patch.object(
|
||||
self.extractor, "_tier2_embedding", new_callable=AsyncMock
|
||||
) as mock_t2, patch.object(
|
||||
self.extractor, "_tier3_llm", new_callable=AsyncMock
|
||||
) as mock_t3:
|
||||
match = await self.extractor.extract(
|
||||
chunk_text="irrelevant",
|
||||
regulation_code="eu_2016_679",
|
||||
article="Art. 30",
|
||||
)
|
||||
|
||||
assert match.method == "exact_match"
|
||||
mock_t2.assert_not_called()
|
||||
mock_t3.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tier2_when_tier1_misses(self):
|
||||
"""When Tier 1 misses, Tier 2 should be tried."""
|
||||
tier2_result = ObligationMatch(
|
||||
obligation_id="DSGVO-OBL-050",
|
||||
method="embedding_match",
|
||||
confidence=0.85,
|
||||
regulation_id="dsgvo",
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
self.extractor, "_tier2_embedding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=tier2_result,
|
||||
) as mock_t2, patch.object(
|
||||
self.extractor, "_tier3_llm", new_callable=AsyncMock
|
||||
) as mock_t3:
|
||||
match = await self.extractor.extract(
|
||||
chunk_text="some compliance text",
|
||||
regulation_code="eu_2016_679",
|
||||
article="Art. 999", # Non-matching article
|
||||
)
|
||||
|
||||
assert match.method == "embedding_match"
|
||||
mock_t2.assert_called_once()
|
||||
mock_t3.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tier3_when_tier1_and_2_miss(self):
|
||||
"""When Tier 1 and 2 miss, Tier 3 should be called."""
|
||||
tier3_result = ObligationMatch(
|
||||
obligation_text="LLM extracted obligation",
|
||||
method="llm_extracted",
|
||||
confidence=0.60,
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
self.extractor, "_tier2_embedding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
), patch.object(
|
||||
self.extractor, "_tier3_llm",
|
||||
new_callable=AsyncMock,
|
||||
return_value=tier3_result,
|
||||
):
|
||||
match = await self.extractor.extract(
|
||||
chunk_text="unrelated text",
|
||||
regulation_code="unknown_reg",
|
||||
article="Art. 999",
|
||||
)
|
||||
|
||||
assert match.method == "llm_extracted"
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_article_skips_tier1(self):
    """When no article is provided, Tier 1 should be skipped."""
    fallback = ObligationMatch(method="llm_extracted", confidence=0.60)
    patched_t2 = patch.object(
        self.extractor,
        "_tier2_embedding",
        new_callable=AsyncMock,
        return_value=None,
    )
    patched_t3 = patch.object(
        self.extractor,
        "_tier3_llm",
        new_callable=AsyncMock,
        return_value=fallback,
    )
    with patched_t2 as mock_t2, patched_t3:
        await self.extractor.extract(
            chunk_text="some text",
            regulation_code="dsgvo",
            article=None,
        )

    # Tier 2 should be called (Tier 1 skipped due to no article)
    mock_t2.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
async def test_auto_initialize(self):
    """If not initialized, extract should call initialize()."""
    extractor = ObligationExtractor()
    assert not extractor._initialized

    with patch.object(
        extractor, "initialize", new_callable=AsyncMock
    ) as mock_init:
        # After mock init, set initialized to True.
        # The real initialize() is mocked out, so this side effect emulates
        # its observable state changes: mark initialized, load the obligation
        # catalog, and leave the embedding caches empty so Tier 2 has no data.
        async def side_effect():
            extractor._initialized = True
            extractor._load_obligations()
            extractor._obligation_embeddings = []
            extractor._obligation_ids = []

        mock_init.side_effect = side_effect

        # Force Tier 2 to miss and Tier 3 to return a fixed match so extract()
        # completes without any external services.
        with patch.object(
            extractor, "_tier2_embedding",
            new_callable=AsyncMock,
            return_value=None,
        ), patch.object(
            extractor, "_tier3_llm",
            new_callable=AsyncMock,
            return_value=ObligationMatch(method="llm_extracted", confidence=0.60),
        ):
            await extractor.extract(
                chunk_text="test",
                regulation_code="dsgvo",
                article=None,
            )

        # The lazy-init path must have triggered exactly one initialize() call.
        mock_init.assert_called_once()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: ObligationExtractor — stats()
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestExtractorStats:
    """Tests for the stats() method."""

    def test_stats_before_init(self):
        # A freshly constructed extractor reports an empty, uninitialized state.
        snapshot = ObligationExtractor().stats()
        assert snapshot["total_obligations"] == 0
        assert snapshot["article_lookups"] == 0
        assert snapshot["initialized"] is False

    def test_stats_after_load(self):
        extractor = ObligationExtractor()
        extractor._load_obligations()
        snapshot = extractor.stats()
        assert snapshot["total_obligations"] == 325
        assert snapshot["article_lookups"] > 0
        assert "dsgvo" in snapshot["regulations"]
        assert snapshot["initialized"] is False  # not fully initialized (no embeddings)

    def test_stats_regulations_complete(self):
        extractor = ObligationExtractor()
        extractor._load_obligations()
        expected = {
            "dsgvo", "ai_act", "nis2", "bdsg", "ttdsg",
            "dsa", "data_act", "eu_machinery", "dora",
        }
        assert set(extractor.stats()["regulations"]) == expected
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: Integration — Regulation-to-Obligation mapping coverage
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestRegulationObligationCoverage:
    """Verify that the article lookup provides reasonable coverage."""

    def setup_method(self):
        self.extractor = ObligationExtractor()
        self.extractor._load_obligations()

    def _keys_for(self, prefix):
        # All article-lookup keys belonging to one regulation prefix.
        return [k for k in self.extractor._article_lookup if k.startswith(prefix)]

    def test_dsgvo_has_article_lookups(self):
        """DSGVO (80 obligations) should have many article lookups."""
        dsgvo_keys = self._keys_for("dsgvo/")
        assert len(dsgvo_keys) >= 20, f"Only {len(dsgvo_keys)} DSGVO article lookups"

    def test_ai_act_has_article_lookups(self):
        ai_keys = self._keys_for("ai_act/")
        assert len(ai_keys) >= 10, f"Only {len(ai_keys)} AI Act article lookups"

    def test_nis2_has_article_lookups(self):
        nis2_keys = self._keys_for("nis2/")
        assert len(nis2_keys) >= 5, f"Only {len(nis2_keys)} NIS2 article lookups"

    def test_all_article_lookup_values_are_valid(self):
        """Every obligation ID in article_lookup should exist in _obligations."""
        known = self.extractor._obligations
        for key, obl_ids in self.extractor._article_lookup.items():
            for obl_id in obl_ids:
                assert obl_id in known, (
                    f"Article lookup {key} references missing obligation {obl_id}"
                )

    def test_article_lookup_key_format(self):
        """All keys should be in format 'regulation_id/normalized_article'."""
        for key in self.extractor._article_lookup:
            parts = key.split("/", 1)
            assert len(parts) == 2, f"Invalid key format: {key}"
            reg_id, article = parts
            assert reg_id, f"Empty regulation ID in key: {key}"
            assert article, f"Empty article in key: {key}"
            assert article == article.lower(), f"Article not lowercase: {key}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: Constants and thresholds
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestConstants:
    """Tests for module-level constants."""

    def test_embedding_thresholds_ordering(self):
        """Match threshold should be higher than candidate threshold."""
        assert EMBEDDING_CANDIDATE_THRESHOLD < EMBEDDING_MATCH_THRESHOLD

    def test_embedding_thresholds_range(self):
        """Thresholds should be between 0 and 1."""
        for threshold in (EMBEDDING_MATCH_THRESHOLD, EMBEDDING_CANDIDATE_THRESHOLD):
            assert 0 < threshold <= 1.0

    def test_match_threshold_is_80(self):
        assert EMBEDDING_MATCH_THRESHOLD == 0.80

    def test_candidate_threshold_is_60(self):
        assert EMBEDDING_CANDIDATE_THRESHOLD == 0.60
|
||||
901
backend-compliance/tests/test_pattern_matcher.py
Normal file
901
backend-compliance/tests/test_pattern_matcher.py
Normal file
@@ -0,0 +1,901 @@
|
||||
"""Tests for Pattern Matcher — Phase 5 of Multi-Layer Control Architecture.
|
||||
|
||||
Validates:
|
||||
- Pattern loading from YAML files
|
||||
- Keyword index construction
|
||||
- Keyword matching (Tier 1)
|
||||
- Embedding matching (Tier 2) with domain bonus
|
||||
- Score combination logic
|
||||
- Domain affinity mapping
|
||||
- Top-N matching
|
||||
- PatternMatchResult serialization
|
||||
- Edge cases: empty inputs, no matches, missing data
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from compliance.services.pattern_matcher import (
|
||||
DOMAIN_BONUS,
|
||||
EMBEDDING_PATTERN_THRESHOLD,
|
||||
KEYWORD_MATCH_MIN_HITS,
|
||||
ControlPattern,
|
||||
PatternMatchResult,
|
||||
PatternMatcher,
|
||||
_REGULATION_DOMAIN_AFFINITY,
|
||||
_find_patterns_dir,
|
||||
)
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
PATTERNS_DIR = REPO_ROOT / "ai-compliance-sdk" / "policies" / "control_patterns"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: _find_patterns_dir
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestFindPatternsDir:
    """Tests for locating the control_patterns directory."""

    def test_finds_patterns_dir(self):
        # The helper may legitimately return None outside a full checkout;
        # when it does find something, it must be a directory.
        located = _find_patterns_dir()
        if located is not None:
            assert located.is_dir()

    def test_patterns_dir_exists_in_repo(self):
        assert PATTERNS_DIR.exists(), f"Patterns dir not found at {PATTERNS_DIR}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: ControlPattern
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestControlPattern:
    """Tests for the ControlPattern dataclass."""

    def test_defaults(self):
        # Only the required fields — everything else must take its default.
        required = dict(
            id="CP-TEST-001",
            name="test_pattern",
            name_de="Test-Muster",
            domain="SEC",
            category="testing",
            description="A test pattern",
            objective_template="Test objective",
            rationale_template="Test rationale",
        )
        pattern = ControlPattern(**required)
        assert pattern.id == "CP-TEST-001"
        assert pattern.severity_default == "medium"
        assert pattern.implementation_effort_default == "m"
        # Collection fields default to empty lists.
        assert pattern.obligation_match_keywords == []
        assert pattern.tags == []
        assert pattern.composable_with == []

    def test_full_pattern(self):
        pattern = ControlPattern(
            id="CP-AUTH-001",
            name="password_policy",
            name_de="Passwortrichtlinie",
            domain="AUTH",
            category="authentication",
            description="Password requirements",
            objective_template="Ensure strong passwords",
            rationale_template="Weak passwords are risky",
            obligation_match_keywords=["passwort", "password", "credential"],
            tags=["authentication", "password"],
            composable_with=["CP-AUTH-002"],
        )
        assert len(pattern.obligation_match_keywords) == 3
        assert "CP-AUTH-002" in pattern.composable_with
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: PatternMatchResult
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPatternMatchResult:
    """Tests for the PatternMatchResult dataclass."""

    def test_defaults(self):
        empty = PatternMatchResult()
        assert empty.pattern is None
        assert empty.pattern_id is None
        assert empty.method == "none"
        assert empty.confidence == 0.0
        assert empty.keyword_hits == 0
        assert empty.embedding_score == 0.0
        assert empty.composable_patterns == []

    def test_to_dict(self):
        # Fully populated result must serialize every field unchanged.
        serialized = PatternMatchResult(
            pattern_id="CP-AUTH-001",
            method="keyword",
            confidence=0.857,
            keyword_hits=6,
            total_keywords=7,
            embedding_score=0.823,
            domain_bonus_applied=True,
            composable_patterns=["CP-AUTH-002"],
        ).to_dict()
        assert serialized["pattern_id"] == "CP-AUTH-001"
        assert serialized["method"] == "keyword"
        assert serialized["confidence"] == 0.857
        assert serialized["keyword_hits"] == 6
        assert serialized["total_keywords"] == 7
        assert serialized["embedding_score"] == 0.823
        assert serialized["domain_bonus_applied"] is True
        assert serialized["composable_patterns"] == ["CP-AUTH-002"]

    def test_to_dict_keys(self):
        # The serialized form exposes exactly these keys, no more, no fewer.
        serialized = PatternMatchResult().to_dict()
        assert set(serialized.keys()) == {
            "pattern_id", "method", "confidence", "keyword_hits",
            "total_keywords", "embedding_score", "domain_bonus_applied",
            "composable_patterns",
        }
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: PatternMatcher — Loading
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPatternMatcherLoad:
    """Tests for loading patterns from YAML."""

    @staticmethod
    def _loaded():
        # Fresh matcher with patterns read from the YAML library.
        matcher = PatternMatcher()
        matcher._load_patterns()
        return matcher

    def test_load_patterns(self):
        assert len(self._loaded()._patterns) == 50

    def test_by_id_populated(self):
        by_id = self._loaded()._by_id
        assert "CP-AUTH-001" in by_id
        assert "CP-CRYP-001" in by_id

    def test_by_domain_populated(self):
        by_domain = self._loaded()._by_domain
        assert "AUTH" in by_domain
        assert "DATA" in by_domain
        assert len(by_domain["AUTH"]) >= 3

    def test_pattern_fields_valid(self):
        """Every loaded pattern should have all required fields."""
        for p in self._loaded()._patterns:
            assert p.id, "Empty pattern ID"
            assert p.name, f"{p.id}: empty name"
            assert p.name_de, f"{p.id}: empty name_de"
            assert p.domain, f"{p.id}: empty domain"
            assert p.category, f"{p.id}: empty category"
            assert p.description, f"{p.id}: empty description"
            assert p.objective_template, f"{p.id}: empty objective_template"
            assert len(p.obligation_match_keywords) >= 3, (
                f"{p.id}: only {len(p.obligation_match_keywords)} keywords"
            )

    def test_no_duplicate_ids(self):
        ids = [p.id for p in self._loaded()._patterns]
        assert len(ids) == len(set(ids))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: PatternMatcher — Keyword Index
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestKeywordIndex:
    """Tests for the reverse keyword index (keyword -> pattern IDs)."""

    def setup_method(self):
        self.matcher = PatternMatcher()
        self.matcher._load_patterns()
        self.matcher._build_keyword_index()

    def test_keyword_index_populated(self):
        # 50 patterns with >= 3 keywords each must yield well over 50 entries.
        assert len(self.matcher._keyword_index) > 50

    def test_keyword_maps_to_patterns(self):
        """'passwort' should map to CP-AUTH-001."""
        assert "passwort" in self.matcher._keyword_index
        assert "CP-AUTH-001" in self.matcher._keyword_index["passwort"]

    def test_keyword_lowercase(self):
        """All keywords in the index should be lowercase."""
        for kw in self.matcher._keyword_index:
            assert kw == kw.lower(), f"Keyword not lowercase: {kw}"

    def test_keyword_shared_across_patterns(self):
        """Some keywords like 'verschluesselung' may appear in multiple patterns."""
        # This just verifies the structure allows multi-pattern keywords.
        # Fix: iterate values() directly — the keyword itself was unused in the
        # original loop over .items() (ruff PERF102). An empty bucket would mean
        # the index builder created a keyword entry without ever filling it.
        for pattern_ids in self.matcher._keyword_index.values():
            assert len(pattern_ids) >= 1
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: PatternMatcher — Tier 1 (Keyword Match)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestTier1KeywordMatch:
    """Tests for keyword-based pattern matching."""

    def setup_method(self):
        self.matcher = PatternMatcher()
        self.matcher._load_patterns()
        self.matcher._build_keyword_index()

    def _tier1(self, text, regulation=None):
        # Thin wrapper so individual tests read as input + expectation.
        return self.matcher._tier1_keyword(text, regulation)

    def test_password_text_matches_auth(self):
        """Text about passwords should match CP-AUTH-001."""
        hit = self._tier1(
            "Die Passwortrichtlinie muss sicherstellen dass Anmeldedaten "
            "und Credentials geschuetzt sind und authentifizierung robust ist"
        )
        assert hit is not None
        assert hit.pattern_id == "CP-AUTH-001"
        assert hit.method == "keyword"
        assert hit.keyword_hits >= KEYWORD_MATCH_MIN_HITS

    def test_encryption_text_matches_cryp(self):
        """Text about encryption should match CP-CRYP-001."""
        hit = self._tier1(
            "Verschluesselung ruhender Daten muss mit AES-256 encryption erfolgen"
        )
        assert hit is not None
        assert hit.pattern_id == "CP-CRYP-001"
        assert hit.keyword_hits >= KEYWORD_MATCH_MIN_HITS

    def test_incident_text_matches_inc(self):
        hit = self._tier1(
            "Ein Vorfall-Reaktionsplan muss fuer Sicherheitsvorfaelle "
            "und incident response bereitstehen"
        )
        assert hit is not None
        assert "INC" in hit.pattern_id

    def test_no_match_for_unrelated_text(self):
        hit = self._tier1(
            "xyzzy foobar completely unrelated text with no keywords"
        )
        assert hit is None

    def test_single_keyword_below_threshold(self):
        """A single keyword hit should not be enough."""
        # Only 1 hit, which is below KEYWORD_MATCH_MIN_HITS (2).
        assert self._tier1("passwort") is None

    def test_domain_bonus_applied(self):
        """Domain bonus should be added when regulation matches."""
        text = (
            "Personenbezogene Daten muessen durch Datenschutz Massnahmen "
            "und datensicherheit geschuetzt werden mit datenminimierung"
        )
        baseline = self._tier1(text)
        boosted = self._tier1(text, "dsgvo")
        if baseline and boosted:
            # With DSGVO regulation, DATA domain patterns should get a bonus
            if boosted.domain_bonus_applied:
                assert boosted.confidence >= baseline.confidence

    def test_keyword_scores_returns_dict(self):
        scores = self.matcher._keyword_scores(
            "Passwort authentifizierung credential zugang",
            None,
        )
        assert isinstance(scores, dict)
        assert "CP-AUTH-001" in scores
        hits, total, confidence = scores["CP-AUTH-001"]
        assert hits >= 3
        assert total > 0
        assert 0 < confidence <= 1.0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: PatternMatcher — Tier 2 (Embedding Match)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestTier2EmbeddingMatch:
    """Tests for embedding-based pattern matching."""

    def setup_method(self):
        self.matcher = PatternMatcher()
        self.matcher._load_patterns()
        self.matcher._build_keyword_index()
        # Set up fake embeddings: one deterministic 3-component vector per
        # pattern so cosine similarity is well-defined without the real
        # embedding service.
        self.matcher._pattern_ids = [p.id for p in self.matcher._patterns]
        self.matcher._pattern_embeddings = []
        for i in range(len(self.matcher._patterns)):
            self.matcher._pattern_embeddings.append(
                [float(i % 10 + 1), float((i * 3) % 10 + 1), float((i * 7) % 10 + 1)]
            )

    @pytest.mark.asyncio
    async def test_embedding_match_identical_vector(self):
        """Identical vector should produce cosine = 1.0 > threshold."""
        target = self.matcher._pattern_embeddings[0]
        # The (mocked) embedding service returns pattern 0's own vector,
        # so pattern 0 must score a perfect cosine similarity.
        with patch(
            "compliance.services.pattern_matcher._get_embedding",
            new_callable=AsyncMock,
            return_value=target,
        ):
            result = await self.matcher._tier2_embedding("test text", None)

        assert result is not None
        assert result.method == "embedding"
        assert result.confidence >= EMBEDDING_PATTERN_THRESHOLD

    @pytest.mark.asyncio
    async def test_embedding_match_empty(self):
        """Empty embeddings should return None."""
        self.matcher._pattern_embeddings = []
        result = await self.matcher._tier2_embedding("test text", None)
        assert result is None

    @pytest.mark.asyncio
    async def test_embedding_match_failed_service(self):
        """Failed embedding service should return None."""
        # An empty vector is the service's failure signal.
        with patch(
            "compliance.services.pattern_matcher._get_embedding",
            new_callable=AsyncMock,
            return_value=[],
        ):
            result = await self.matcher._tier2_embedding("test", None)
        assert result is None

    @pytest.mark.asyncio
    async def test_embedding_domain_bonus(self):
        """Domain bonus should increase score for affine regulation."""
        # Set all patterns to same embedding, so raw cosine is 1.0 everywhere
        # and only the domain bonus can differentiate scores.
        for i in range(len(self.matcher._pattern_embeddings)):
            self.matcher._pattern_embeddings[i] = [1.0, 0.0, 0.0]

        with patch(
            "compliance.services.pattern_matcher._get_embedding",
            new_callable=AsyncMock,
            return_value=[1.0, 0.0, 0.0],
        ):
            scores = await self.matcher._embedding_scores("test", "dsgvo")

        # DATA domain patterns should have bonus applied
        data_patterns = [p.id for p in self.matcher._patterns if p.domain == "DATA"]
        if data_patterns:
            pid = data_patterns[0]
            score, bonus = scores.get(pid, (0, False))
            assert bonus is True
            assert score > 1.0  # 1.0 cosine + 0.10 bonus
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: PatternMatcher — Score Combination
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestScoreCombination:
    """Tests for combining keyword and embedding results."""

    @staticmethod
    def _stub_pattern(pattern_id, name, name_de):
        # Minimal valid pattern; only the identity fields vary between stubs.
        return ControlPattern(
            id=pattern_id, name=name, name_de=name_de,
            domain="SEC", category="test", description="d",
            objective_template="o", rationale_template="r",
        )

    def setup_method(self):
        self.matcher = PatternMatcher()
        self.pattern = self._stub_pattern("CP-TEST-001", "test", "Test")

    def test_both_none(self):
        combined = self.matcher._combine_results(None, None)
        assert combined.method == "none"
        assert combined.confidence == 0.0

    def test_only_keyword(self):
        keyword_side = PatternMatchResult(
            pattern=self.pattern, pattern_id="CP-TEST-001",
            method="keyword", confidence=0.7, keyword_hits=5,
        )
        combined = self.matcher._combine_results(keyword_side, None)
        assert combined.method == "keyword"
        assert combined.confidence == 0.7

    def test_only_embedding(self):
        embedding_side = PatternMatchResult(
            pattern=self.pattern, pattern_id="CP-TEST-001",
            method="embedding", confidence=0.85, embedding_score=0.85,
        )
        combined = self.matcher._combine_results(None, embedding_side)
        assert combined.method == "embedding"
        assert combined.confidence == 0.85

    def test_same_pattern_combined(self):
        """When both tiers agree, confidence gets +0.05 boost."""
        keyword_side = PatternMatchResult(
            pattern=self.pattern, pattern_id="CP-TEST-001",
            method="keyword", confidence=0.7, keyword_hits=5, total_keywords=7,
        )
        embedding_side = PatternMatchResult(
            pattern=self.pattern, pattern_id="CP-TEST-001",
            method="embedding", confidence=0.8, embedding_score=0.8,
        )
        combined = self.matcher._combine_results(keyword_side, embedding_side)
        assert combined.method == "combined"
        # max(0.7, 0.8) + 0.05, compared with a float tolerance.
        assert abs(combined.confidence - 0.85) < 1e-9
        assert combined.keyword_hits == 5
        assert combined.embedding_score == 0.8

    def test_same_pattern_combined_capped(self):
        """Combined confidence should not exceed 1.0."""
        keyword_side = PatternMatchResult(
            pattern=self.pattern, pattern_id="CP-TEST-001",
            method="keyword", confidence=0.95,
        )
        embedding_side = PatternMatchResult(
            pattern=self.pattern, pattern_id="CP-TEST-001",
            method="embedding", confidence=0.98, embedding_score=0.98,
        )
        combined = self.matcher._combine_results(keyword_side, embedding_side)
        assert combined.confidence <= 1.0

    def test_different_patterns_picks_higher(self):
        """When tiers disagree, pick the higher confidence."""
        other = self._stub_pattern("CP-TEST-002", "test2", "Test2")
        keyword_side = PatternMatchResult(
            pattern=self.pattern, pattern_id="CP-TEST-001",
            method="keyword", confidence=0.6,
        )
        embedding_side = PatternMatchResult(
            pattern=other, pattern_id="CP-TEST-002",
            method="embedding", confidence=0.9, embedding_score=0.9,
        )
        combined = self.matcher._combine_results(keyword_side, embedding_side)
        assert combined.pattern_id == "CP-TEST-002"
        assert combined.confidence == 0.9

    def test_different_patterns_keyword_wins(self):
        other = self._stub_pattern("CP-TEST-002", "test2", "Test2")
        keyword_side = PatternMatchResult(
            pattern=self.pattern, pattern_id="CP-TEST-001",
            method="keyword", confidence=0.9,
        )
        embedding_side = PatternMatchResult(
            pattern=other, pattern_id="CP-TEST-002",
            method="embedding", confidence=0.6, embedding_score=0.6,
        )
        combined = self.matcher._combine_results(keyword_side, embedding_side)
        assert combined.pattern_id == "CP-TEST-001"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: PatternMatcher — Domain Affinity
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestDomainAffinity:
    """Tests for regulation-to-domain affinity mapping."""

    # Shorthand for the static predicate under test.
    _affine = staticmethod(PatternMatcher._domain_matches)

    def test_dsgvo_affine_with_data(self):
        assert self._affine("DATA", "dsgvo")

    def test_dsgvo_affine_with_comp(self):
        assert self._affine("COMP", "dsgvo")

    def test_ai_act_affine_with_ai(self):
        assert self._affine("AI", "ai_act")

    def test_nis2_affine_with_sec(self):
        assert self._affine("SEC", "nis2")

    def test_nis2_affine_with_inc(self):
        assert self._affine("INC", "nis2")

    def test_dora_affine_with_fin(self):
        assert self._affine("FIN", "dora")

    def test_no_affinity_auth_dsgvo(self):
        """AUTH is not in DSGVO's affinity list."""
        assert not self._affine("AUTH", "dsgvo")

    def test_unknown_regulation(self):
        assert not self._affine("DATA", "unknown_reg")

    def test_all_regulations_have_affinity(self):
        """All 9 regulations should have at least one affine domain."""
        expected_regs = [
            "dsgvo", "bdsg", "ttdsg", "ai_act", "nis2",
            "dsa", "data_act", "eu_machinery", "dora",
        ]
        for reg in expected_regs:
            assert reg in _REGULATION_DOMAIN_AFFINITY, f"{reg} missing from affinity map"
            assert len(_REGULATION_DOMAIN_AFFINITY[reg]) >= 1
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: PatternMatcher — Full match()
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestMatchFull:
    """Tests for the full match() method."""

    def setup_method(self):
        self.matcher = PatternMatcher()
        self.matcher._load_patterns()
        self.matcher._build_keyword_index()
        self.matcher._initialized = True
        # Empty embeddings — Tier 2 returns None
        self.matcher._pattern_embeddings = []
        self.matcher._pattern_ids = []

    @pytest.mark.asyncio
    async def test_match_password_text(self):
        """Password text should match CP-AUTH-001 via keywords."""
        # Embedding service stubbed to fail, so only Tier 1 can contribute.
        with patch(
            "compliance.services.pattern_matcher._get_embedding",
            new_callable=AsyncMock,
            return_value=[],
        ):
            result = await self.matcher.match(
                obligation_text=(
                    "Passwortrichtlinie muss sicherstellen dass Anmeldedaten "
                    "und credential geschuetzt sind und authentifizierung robust ist"
                ),
                regulation_id="nis2",
            )
            assert result.pattern_id == "CP-AUTH-001"
            assert result.confidence > 0

    @pytest.mark.asyncio
    async def test_match_encryption_text(self):
        with patch(
            "compliance.services.pattern_matcher._get_embedding",
            new_callable=AsyncMock,
            return_value=[],
        ):
            result = await self.matcher.match(
                obligation_text=(
                    "Verschluesselung ruhender Daten muss mit AES-256 encryption "
                    "und schluesselmanagement kryptographie erfolgen"
                ),
            )
            assert result.pattern_id == "CP-CRYP-001"

    @pytest.mark.asyncio
    async def test_match_empty_text(self):
        # Empty input should produce the empty "none" result.
        result = await self.matcher.match(obligation_text="")
        assert result.method == "none"
        assert result.confidence == 0.0

    @pytest.mark.asyncio
    async def test_match_no_patterns(self):
        """When no patterns loaded, should return empty result."""
        matcher = PatternMatcher()
        matcher._initialized = True
        result = await matcher.match(obligation_text="test")
        assert result.method == "none"

    @pytest.mark.asyncio
    async def test_match_composable_patterns(self):
        """Result should include composable_with references."""
        with patch(
            "compliance.services.pattern_matcher._get_embedding",
            new_callable=AsyncMock,
            return_value=[],
        ):
            result = await self.matcher.match(
                obligation_text=(
                    "Passwortrichtlinie muss sicherstellen dass Anmeldedaten "
                    "und credential geschuetzt sind und authentifizierung robust ist"
                ),
            )
            # Only meaningful when the matched pattern declares composables.
            if result.pattern and result.pattern.composable_with:
                assert len(result.composable_patterns) >= 1

    @pytest.mark.asyncio
    async def test_match_with_domain_bonus(self):
        """DSGVO obligation with DATA keywords should get domain bonus."""
        with patch(
            "compliance.services.pattern_matcher._get_embedding",
            new_callable=AsyncMock,
            return_value=[],
        ):
            result = await self.matcher.match(
                obligation_text=(
                    "Personenbezogene Daten muessen durch Datenschutz und "
                    "datensicherheit geschuetzt werden mit datenminimierung "
                    "und speicherbegrenzung und loeschung"
                ),
                regulation_id="dsgvo",
            )
            # Should match a DATA-domain pattern
            if result.pattern and result.pattern.domain == "DATA":
                assert result.domain_bonus_applied is True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: PatternMatcher — match_top_n()
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestMatchTopN:
    """Tests for top-N matching."""

    def setup_method(self):
        self.matcher = PatternMatcher()
        self.matcher._load_patterns()
        self.matcher._build_keyword_index()
        self.matcher._initialized = True
        # No stored pattern embeddings: Tier 2 cannot contribute candidates.
        self.matcher._pattern_embeddings = []
        self.matcher._pattern_ids = []

    @staticmethod
    def _no_embedding_service():
        # Embedding-service stub that always fails (returns an empty vector).
        return patch(
            "compliance.services.pattern_matcher._get_embedding",
            new_callable=AsyncMock,
            return_value=[],
        )

    @pytest.mark.asyncio
    async def test_top_n_returns_list(self):
        with self._no_embedding_service():
            ranked = await self.matcher.match_top_n(
                obligation_text=(
                    "Passwortrichtlinie muss sicherstellen dass Anmeldedaten "
                    "und credential geschuetzt sind und authentifizierung robust ist"
                ),
                n=3,
            )
        assert isinstance(ranked, list)
        assert len(ranked) >= 1

    @pytest.mark.asyncio
    async def test_top_n_sorted_by_confidence(self):
        with self._no_embedding_service():
            ranked = await self.matcher.match_top_n(
                obligation_text=(
                    "Verschluesselung und kryptographie und schluesselmanagement "
                    "und authentifizierung und password und zugriffskontrolle"
                ),
                n=5,
            )
        # Results must be ordered best-first; zip no-ops for fewer than 2 items.
        for better, worse in zip(ranked, ranked[1:]):
            assert better.confidence >= worse.confidence

    @pytest.mark.asyncio
    async def test_top_n_empty_text(self):
        with self._no_embedding_service():
            assert await self.matcher.match_top_n(obligation_text="", n=3) == []

    @pytest.mark.asyncio
    async def test_top_n_respects_limit(self):
        with self._no_embedding_service():
            ranked = await self.matcher.match_top_n(
                obligation_text=(
                    "Verschluesselung und kryptographie und schluesselmanagement "
                    "und authentifizierung und password und zugriffskontrolle"
                ),
                n=2,
            )
        assert len(ranked) <= 2
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: PatternMatcher — Public Helpers
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPublicHelpers:
    """Tests for get_pattern, get_patterns_by_domain, stats."""

    def setup_method(self):
        """Load the pattern library and keyword index; embeddings unneeded."""
        self.matcher = PatternMatcher()
        self.matcher._load_patterns()
        self.matcher._build_keyword_index()

    def test_get_pattern_existing(self):
        """A known pattern id resolves to a pattern carrying that id."""
        found = self.matcher.get_pattern("CP-AUTH-001")
        assert found is not None
        assert found.id == "CP-AUTH-001"

    def test_get_pattern_case_insensitive(self):
        """Pattern lookup ignores the case of the id."""
        assert self.matcher.get_pattern("cp-auth-001") is not None

    def test_get_pattern_nonexistent(self):
        """Unknown ids yield None rather than raising."""
        assert self.matcher.get_pattern("CP-FAKE-999") is None

    def test_get_patterns_by_domain(self):
        """The AUTH domain contains at least three patterns."""
        assert len(self.matcher.get_patterns_by_domain("AUTH")) >= 3

    def test_get_patterns_by_domain_case_insensitive(self):
        """Domain lookup is case-insensitive."""
        assert len(self.matcher.get_patterns_by_domain("auth")) >= 3

    def test_get_patterns_by_domain_unknown(self):
        """An unknown domain yields an empty list."""
        assert self.matcher.get_patterns_by_domain("NOPE") == []

    def test_stats(self):
        """stats() reports library size and the (un)initialized flag."""
        info = self.matcher.stats()
        assert info["total_patterns"] == 50
        assert len(info["domains"]) >= 5
        assert info["keywords"] > 50
        # setup_method never ran initialize(), so the flag stays False.
        assert info["initialized"] is False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: PatternMatcher — auto initialize
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAutoInitialize:
    """Tests for auto-initialization on first match call."""

    @pytest.mark.asyncio
    async def test_auto_init_on_match(self):
        # A fresh matcher starts uninitialized; the first match() call is
        # expected to trigger initialize() exactly once.
        matcher = PatternMatcher()
        assert not matcher._initialized

        with patch.object(
            matcher, "initialize", new_callable=AsyncMock
        ) as mock_init:
            # Stand-in for the real initialize(): load patterns and the
            # keyword index synchronously and flip the flag so match()
            # can proceed after the mocked call returns.
            async def side_effect():
                matcher._initialized = True
                matcher._load_patterns()
                matcher._build_keyword_index()
                matcher._pattern_embeddings = []
                matcher._pattern_ids = []

            mock_init.side_effect = side_effect

            # Stub out embedding lookups so no external service is hit.
            with patch(
                "compliance.services.pattern_matcher._get_embedding",
                new_callable=AsyncMock,
                return_value=[],
            ):
                await matcher.match(obligation_text="test text")

            mock_init.assert_called_once()

    @pytest.mark.asyncio
    async def test_no_double_init(self):
        # An already-initialized matcher must not call initialize() again.
        matcher = PatternMatcher()
        matcher._initialized = True
        # Empty pattern list means match() has nothing to score — that is
        # fine here; only the initialize() call count matters.
        matcher._patterns = []

        with patch.object(
            matcher, "initialize", new_callable=AsyncMock
        ) as mock_init:
            await matcher.match(obligation_text="test text")
            mock_init.assert_not_called()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: Constants
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestConstants:
    """Sanity checks on the module-level tuning constants."""

    def test_keyword_min_hits(self):
        """At least one keyword hit must be required for a match."""
        assert 1 <= KEYWORD_MATCH_MIN_HITS

    def test_embedding_threshold_range(self):
        """The embedding threshold is a valid similarity in (0, 1]."""
        assert 0.0 < EMBEDDING_PATTERN_THRESHOLD <= 1.0

    def test_domain_bonus_range(self):
        """The domain bonus stays a small positive nudge (at most 0.20)."""
        assert 0.0 < DOMAIN_BONUS <= 0.20

    def test_domain_bonus_is_010(self):
        """Pin the currently documented bonus value."""
        assert DOMAIN_BONUS == 0.10

    def test_embedding_threshold_is_075(self):
        """Pin the currently documented threshold value."""
        assert EMBEDDING_PATTERN_THRESHOLD == 0.75
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: Integration — Real keyword matching scenarios
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestRealKeywordScenarios:
    """Integration tests with realistic obligation texts."""

    def setup_method(self):
        """Prepare a matcher with the full pattern library and keyword index."""
        self.matcher = PatternMatcher()
        self.matcher._load_patterns()
        self.matcher._build_keyword_index()

    def _hits_with_prefix(self, text, regulation, prefix):
        """Score *text* and return the matched pattern ids under *prefix*."""
        scores = self.matcher._keyword_scores(text, regulation)
        return [pid for pid in scores if pid.startswith(prefix)]

    def test_dsgvo_consent_obligation(self):
        """DSGVO consent obligation should match data protection patterns."""
        hits = self._hits_with_prefix(
            "Die Einwilligung der betroffenen Person muss freiwillig und "
            "informiert erfolgen. Eine Verarbeitung personenbezogener Daten "
            "ist nur mit gültiger Einwilligung zulaessig. Datenschutz.",
            "dsgvo",
            "CP-DATA",
        )
        assert len(hits) >= 1

    def test_ai_act_risk_assessment(self):
        """AI Act risk assessment should match AI patterns."""
        hits = self._hits_with_prefix(
            "KI-Systeme mit hohem Risiko muessen einer Konformitaetsbewertung "
            "unterzogen werden. Transparenz und Erklaerbarkeit sind Pflicht.",
            "ai_act",
            "CP-AI",
        )
        assert len(hits) >= 1

    def test_nis2_incident_response(self):
        """NIS2 incident text should match INC patterns."""
        hits = self._hits_with_prefix(
            "Sicherheitsvorfaelle muessen innerhalb von 24 Stunden gemeldet "
            "werden. Ein incident response plan und Eskalationsverfahren "
            "sind zu etablieren fuer Vorfall und Wiederherstellung.",
            "nis2",
            "CP-INC",
        )
        assert len(hits) >= 1

    def test_audit_logging_obligation(self):
        """Audit logging obligation should match LOG patterns."""
        hits = self._hits_with_prefix(
            "Alle sicherheitsrelevanten Ereignisse muessen protokolliert werden. "
            "Audit-Trail und Monitoring der Zugriffe sind Pflicht. "
            "Protokollierung muss manipulationssicher sein.",
            None,
            "CP-LOG",
        )
        assert len(hits) >= 1

    def test_access_control_obligation(self):
        """Access control text should match ACC patterns."""
        hits = self._hits_with_prefix(
            "Zugriffskontrolle nach dem Least-Privilege-Prinzip. "
            "Rollenbasierte Autorisierung und Berechtigung fuer alle Systeme.",
            None,
            "CP-ACC",
        )
        assert len(hits) >= 1
|
||||
682
backend-compliance/tests/test_pipeline_adapter.py
Normal file
682
backend-compliance/tests/test_pipeline_adapter.py
Normal file
@@ -0,0 +1,682 @@
|
||||
"""Tests for Pipeline Adapter — Phase 7 of Multi-Layer Control Architecture.
|
||||
|
||||
Validates:
|
||||
- PipelineChunk and PipelineResult dataclasses
|
||||
- PipelineAdapter.process_chunk() — full 3-stage flow
|
||||
- PipelineAdapter.process_batch() — batch processing
|
||||
- PipelineAdapter.write_crosswalk() — DB write logic (mocked)
|
||||
- MigrationPasses — all 5 passes (with mocked DB)
|
||||
- _extract_regulation_article helper
|
||||
- Edge cases: missing data, LLM failures, initialization
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch, call
|
||||
|
||||
import pytest
|
||||
|
||||
from compliance.services.pipeline_adapter import (
|
||||
MigrationPasses,
|
||||
PipelineAdapter,
|
||||
PipelineChunk,
|
||||
PipelineResult,
|
||||
_extract_regulation_article,
|
||||
)
|
||||
from compliance.services.obligation_extractor import ObligationMatch
|
||||
from compliance.services.pattern_matcher import ControlPattern, PatternMatchResult
|
||||
from compliance.services.control_composer import ComposedControl
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: PipelineChunk
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPipelineChunk:
    """Tests for the PipelineChunk dataclass and its hash helper."""

    def test_defaults(self):
        """Only text is required; everything else falls back to defaults."""
        chunk = PipelineChunk(text="test")
        assert chunk.text == "test"
        assert chunk.chunk_hash == ""
        assert chunk.collection == ""
        assert chunk.regulation_code == ""
        assert chunk.license_rule == 3

    def test_compute_hash(self):
        """compute_hash() yields a SHA256 hex digest and caches it."""
        chunk = PipelineChunk(text="hello world")
        digest = chunk.compute_hash()
        assert len(digest) == 64  # SHA256 hex
        assert digest == chunk.chunk_hash  # cached on the instance

    def test_compute_hash_deterministic(self):
        """Identical text on distinct chunks hashes identically."""
        first = PipelineChunk(text="same text")
        second = PipelineChunk(text="same text")
        assert first.compute_hash() == second.compute_hash()

    def test_compute_hash_idempotent(self):
        """Repeated calls on one chunk return the same digest."""
        chunk = PipelineChunk(text="test")
        assert chunk.compute_hash() == chunk.compute_hash()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: PipelineResult
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPipelineResult:
    """Tests for the PipelineResult dataclass and its serialization."""

    def test_defaults(self):
        """A bare result has no control, no error, and no crosswalk yet."""
        result = PipelineResult(chunk=PipelineChunk(text="test"))
        assert result.control is None
        assert result.error is None
        assert result.crosswalk_written is False

    def test_to_dict(self):
        """to_dict() carries hash, stage outputs, and a None error through."""
        chunk = PipelineChunk(text="test")
        chunk.compute_hash()
        obligation = ObligationMatch(
            obligation_id="DSGVO-OBL-001",
            method="exact_match",
            confidence=1.0,
        )
        pattern = PatternMatchResult(
            pattern_id="CP-AUTH-001",
            method="keyword",
            confidence=0.85,
        )
        result = PipelineResult(
            chunk=chunk,
            obligation=obligation,
            pattern_result=pattern,
            control=ComposedControl(title="Test Control"),
        )

        payload = result.to_dict()
        assert payload["chunk_hash"] == chunk.chunk_hash
        assert payload["obligation"]["obligation_id"] == "DSGVO-OBL-001"
        assert payload["pattern"]["pattern_id"] == "CP-AUTH-001"
        assert payload["control"]["title"] == "Test Control"
        assert payload["error"] is None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: _extract_regulation_article
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestExtractRegulationArticle:
    """Tests for the _extract_regulation_article citation/metadata helper."""

    def test_from_citation_json(self):
        """A JSON citation maps source + article to the normalized pair."""
        payload = json.dumps({"source": "eu_2016_679", "article": "Art. 30"})
        regulation, article = _extract_regulation_article(payload, None)
        assert regulation == "dsgvo"
        assert article == "Art. 30"

    def test_from_metadata(self):
        """With no citation, metadata supplies regulation and article."""
        payload = json.dumps(
            {"source_regulation": "eu_2024_1689", "source_article": "Art. 6"}
        )
        regulation, article = _extract_regulation_article(None, payload)
        assert regulation == "ai_act"
        assert article == "Art. 6"

    def test_citation_takes_priority(self):
        """When both are present, the citation wins over metadata."""
        citation = json.dumps({"source": "dsgvo", "article": "Art. 30"})
        metadata = json.dumps(
            {"source_regulation": "nis2", "source_article": "Art. 21"}
        )
        regulation, article = _extract_regulation_article(citation, metadata)
        assert regulation == "dsgvo"
        assert article == "Art. 30"

    def test_empty_inputs(self):
        """Two None inputs produce (None, None)."""
        regulation, article = _extract_regulation_article(None, None)
        assert regulation is None
        assert article is None

    def test_invalid_json(self):
        """Unparseable strings are treated like missing inputs."""
        regulation, article = _extract_regulation_article(
            "not json", "also not json"
        )
        assert regulation is None
        assert article is None

    def test_citation_as_dict(self):
        """An already-parsed dict citation is accepted as-is."""
        regulation, article = _extract_regulation_article(
            {"source": "bdsg", "article": "§ 38"}, None
        )
        assert regulation == "bdsg"
        assert article == "§ 38"

    def test_source_article_key(self):
        """The alternate 'source_article' key is honoured in citations."""
        payload = json.dumps({"source": "dsgvo", "source_article": "Art. 32"})
        regulation, article = _extract_regulation_article(payload, None)
        assert regulation == "dsgvo"
        assert article == "Art. 32"

    def test_unknown_source(self):
        """Unmapped sources normalize to None, but the article survives."""
        payload = json.dumps({"source": "unknown_law", "article": "Art. 1"})
        regulation, article = _extract_regulation_article(payload, None)
        assert regulation is None  # _normalize_regulation returns None
        assert article == "Art. 1"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: PipelineAdapter — process_chunk
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPipelineAdapterProcessChunk:
    """Tests for the full 3-stage chunk processing."""

    @pytest.mark.asyncio
    async def test_process_chunk_full_flow(self):
        """Process a chunk through all 3 stages."""
        adapter = PipelineAdapter()

        # Canned outputs for each stage: obligation extraction,
        # pattern matching, and control composition.
        obligation = ObligationMatch(
            obligation_id="DSGVO-OBL-001",
            obligation_title="Verarbeitungsverzeichnis",
            obligation_text="Fuehrung eines Verzeichnisses",
            method="exact_match",
            confidence=1.0,
            regulation_id="dsgvo",
        )
        pattern_result = PatternMatchResult(
            pattern_id="CP-COMP-001",
            method="keyword",
            confidence=0.85,
        )
        composed = ComposedControl(
            title="Test Control",
            objective="Test objective",
            pattern_id="CP-COMP-001",
        )

        # Patch all three stage services (and their initializers) so no
        # real I/O, DB access, or LLM calls happen during the run.
        with patch.object(
            adapter._extractor, "initialize", new_callable=AsyncMock
        ), patch.object(
            adapter._matcher, "initialize", new_callable=AsyncMock
        ), patch.object(
            adapter._extractor, "extract",
            new_callable=AsyncMock, return_value=obligation,
        ), patch.object(
            adapter._matcher, "match",
            new_callable=AsyncMock, return_value=pattern_result,
        ), patch.object(
            adapter._composer, "compose",
            new_callable=AsyncMock, return_value=composed,
        ):
            adapter._initialized = True
            chunk = PipelineChunk(
                text="Art. 30 DSGVO Verarbeitungsverzeichnis",
                regulation_code="eu_2016_679",
                article="Art. 30",
                license_rule=1,
            )
            result = await adapter.process_chunk(chunk)

            # Each stage's output must be carried onto the result, and the
            # chunk hash must have been computed as a side effect.
            assert result.obligation.obligation_id == "DSGVO-OBL-001"
            assert result.pattern_result.pattern_id == "CP-COMP-001"
            assert result.control.title == "Test Control"
            assert result.error is None
            assert result.chunk.chunk_hash != ""

    @pytest.mark.asyncio
    async def test_process_chunk_error_handling(self):
        """Errors during processing should be captured, not raised."""
        adapter = PipelineAdapter()
        adapter._initialized = True

        with patch.object(
            adapter._extractor, "extract",
            new_callable=AsyncMock, side_effect=Exception("LLM timeout"),
        ):
            chunk = PipelineChunk(text="test text")
            result = await adapter.process_chunk(chunk)

            # The exception message is surfaced on the result; no control
            # is composed for a failed chunk.
            assert result.error == "LLM timeout"
            assert result.control is None

    @pytest.mark.asyncio
    async def test_process_chunk_uses_obligation_text_for_pattern(self):
        """Pattern matcher should receive obligation text, not raw chunk."""
        adapter = PipelineAdapter()
        adapter._initialized = True

        # Obligation carries its own text, which should take precedence
        # over the raw chunk text when matching patterns.
        obligation = ObligationMatch(
            obligation_text="Specific obligation text",
            regulation_id="dsgvo",
        )

        with patch.object(
            adapter._extractor, "extract",
            new_callable=AsyncMock, return_value=obligation,
        ), patch.object(
            adapter._matcher, "match",
            new_callable=AsyncMock, return_value=PatternMatchResult(),
        ) as mock_match, patch.object(
            adapter._composer, "compose",
            new_callable=AsyncMock, return_value=ComposedControl(),
        ):
            await adapter.process_chunk(PipelineChunk(text="raw chunk text"))

            # Pattern matcher should receive the obligation text
            mock_match.assert_called_once()
            call_args = mock_match.call_args
            assert call_args.kwargs["obligation_text"] == "Specific obligation text"

    @pytest.mark.asyncio
    async def test_process_chunk_fallback_to_chunk_text(self):
        """When obligation has no text, use chunk text for pattern matching."""
        adapter = PipelineAdapter()
        adapter._initialized = True

        obligation = ObligationMatch()  # No text

        with patch.object(
            adapter._extractor, "extract",
            new_callable=AsyncMock, return_value=obligation,
        ), patch.object(
            adapter._matcher, "match",
            new_callable=AsyncMock, return_value=PatternMatchResult(),
        ) as mock_match, patch.object(
            adapter._composer, "compose",
            new_callable=AsyncMock, return_value=ComposedControl(),
        ):
            await adapter.process_chunk(PipelineChunk(text="fallback chunk text"))

            # The raw chunk text is used when the obligation carries none.
            call_args = mock_match.call_args
            assert "fallback chunk text" in call_args.kwargs["obligation_text"]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: PipelineAdapter — process_batch
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPipelineAdapterBatch:
    """Tests for batch processing of multiple chunks."""

    @pytest.mark.asyncio
    async def test_process_batch(self):
        """Each input chunk yields exactly one result."""
        adapter = PipelineAdapter()
        adapter._initialized = True

        canned = PipelineResult(chunk=PipelineChunk(text="x"))
        with patch.object(
            adapter, "process_chunk",
            new_callable=AsyncMock,
            return_value=canned,
        ):
            outcome = await adapter.process_batch(
                [PipelineChunk(text="a"), PipelineChunk(text="b")]
            )

        assert len(outcome) == 2

    @pytest.mark.asyncio
    async def test_process_batch_empty(self):
        """An empty batch produces an empty result list."""
        adapter = PipelineAdapter()
        adapter._initialized = True
        assert await adapter.process_batch([]) == []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: PipelineAdapter — write_crosswalk
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestWriteCrosswalk:
    """Tests for PipelineAdapter.write_crosswalk() with a mocked DB."""

    @staticmethod
    def _complete_result():
        """Build a fully-populated PipelineResult ready for a crosswalk write."""
        chunk = PipelineChunk(
            text="test", regulation_code="eu_2016_679",
            article="Art. 30", collection="bp_compliance_ce",
        )
        chunk.compute_hash()
        return PipelineResult(
            chunk=chunk,
            obligation=ObligationMatch(
                obligation_id="DSGVO-OBL-001",
                method="exact_match",
                confidence=1.0,
            ),
            pattern_result=PatternMatchResult(
                pattern_id="CP-COMP-001",
                confidence=0.85,
            ),
            control=ComposedControl(
                control_id="COMP-001",
                pattern_id="CP-COMP-001",
                obligation_ids=["DSGVO-OBL-001"],
            ),
        )

    def test_write_crosswalk_success(self):
        """write_crosswalk should execute 3 DB statements."""
        db = MagicMock()
        db.execute = MagicMock()
        db.commit = MagicMock()
        adapter = PipelineAdapter(db=db)

        assert adapter.write_crosswalk(self._complete_result(), "uuid-123") is True
        assert db.execute.call_count == 3  # insert + insert + update
        db.commit.assert_called_once()

    def test_write_crosswalk_no_db(self):
        """Without a DB session, the write is refused."""
        adapter = PipelineAdapter(db=None)
        result = PipelineResult(
            chunk=PipelineChunk(text="test"), control=ComposedControl()
        )
        assert adapter.write_crosswalk(result, "uuid") is False

    def test_write_crosswalk_no_control(self):
        """Without a composed control, the write is refused."""
        adapter = PipelineAdapter(db=MagicMock())
        result = PipelineResult(chunk=PipelineChunk(text="test"), control=None)
        assert adapter.write_crosswalk(result, "uuid") is False

    def test_write_crosswalk_db_error(self):
        """DB failures roll back and report False instead of raising."""
        db = MagicMock()
        db.execute = MagicMock(side_effect=Exception("DB error"))
        db.rollback = MagicMock()
        adapter = PipelineAdapter(db=db)

        chunk = PipelineChunk(text="test")
        chunk.compute_hash()
        result = PipelineResult(
            chunk=chunk,
            obligation=ObligationMatch(),
            pattern_result=PatternMatchResult(),
            control=ComposedControl(control_id="X-001"),
        )

        assert adapter.write_crosswalk(result, "uuid") is False
        db.rollback.assert_called_once()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: PipelineAdapter — stats and initialization
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPipelineAdapterInit:
    """Tests for adapter stats reporting and lazy initialization."""

    def test_stats_before_init(self):
        # A fresh adapter reports itself as not yet initialized.
        adapter = PipelineAdapter()
        stats = adapter.stats()
        assert stats["initialized"] is False

    @pytest.mark.asyncio
    async def test_auto_initialize(self):
        # process_chunk() on an uninitialized adapter must call
        # initialize() exactly once before running the pipeline stages.
        adapter = PipelineAdapter()
        with patch.object(
            adapter, "initialize", new_callable=AsyncMock,
        ) as mock_init:
            # Mimic the real initialize(): flip the flag so process_chunk()
            # does not attempt initialization again.
            async def side_effect():
                adapter._initialized = True
            mock_init.side_effect = side_effect

            # Stub the three stage services so processing succeeds
            # without any external calls.
            with patch.object(
                adapter._extractor, "extract",
                new_callable=AsyncMock, return_value=ObligationMatch(),
            ), patch.object(
                adapter._matcher, "match",
                new_callable=AsyncMock, return_value=PatternMatchResult(),
            ), patch.object(
                adapter._composer, "compose",
                new_callable=AsyncMock, return_value=ComposedControl(),
            ):
                await adapter.process_chunk(PipelineChunk(text="test"))

            mock_init.assert_called_once()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: MigrationPasses — Pass 1 (Obligation Linkage)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPass1ObligationLinkage:
    """Tests for Migration Pass 1 — linking controls to obligations."""

    @pytest.mark.asyncio
    async def test_pass1_links_controls(self):
        """Pass 1 should link controls with matching articles to obligations."""
        mock_db = MagicMock()

        # Simulate 2 controls: one with citation, one without
        mock_db.execute.return_value.fetchall.return_value = [
            (
                "uuid-1", "COMP-001",
                json.dumps({"source": "eu_2016_679", "article": "Art. 30"}),
                json.dumps({"source_regulation": "eu_2016_679"}),
            ),
            (
                "uuid-2", "SEC-001",
                None,  # No citation
                None,  # No metadata
            ),
        ]

        migration = MigrationPasses(db=mock_db)
        await migration.initialize()

        # Reset mock after initialize queries
        # (initialize() issues its own DB calls; re-arm fetchall with the
        # same two rows so the pass itself sees the fixture.)
        mock_db.execute.reset_mock()
        mock_db.execute.return_value.fetchall.return_value = [
            (
                "uuid-1", "COMP-001",
                json.dumps({"source": "eu_2016_679", "article": "Art. 30"}),
                json.dumps({"source_regulation": "eu_2016_679"}),
            ),
            (
                "uuid-2", "SEC-001",
                None,
                None,
            ),
        ]

        stats = await migration.run_pass1_obligation_linkage()

        # Both controls were scanned; the citation-less one is counted.
        assert stats["total"] == 2
        assert stats["no_citation"] >= 1

    @pytest.mark.asyncio
    async def test_pass1_with_limit(self):
        """Pass 1 should respect limit parameter."""
        mock_db = MagicMock()
        mock_db.execute.return_value.fetchall.return_value = []

        migration = MigrationPasses(db=mock_db)
        # Skip initialize() entirely; load obligations directly so the
        # pass can run against the empty fixture.
        migration._initialized = True
        migration._extractor._load_obligations()

        stats = await migration.run_pass1_obligation_linkage(limit=10)
        assert stats["total"] == 0

        # Check that LIMIT was in the SQL text clause
        query_call = mock_db.execute.call_args
        sql_text_obj = query_call[0][0]  # first positional arg is the text() object
        assert "LIMIT" in sql_text_obj.text
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: MigrationPasses — Pass 2 (Pattern Classification)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPass2PatternClassification:
    """Tests for Migration Pass 2 — classifying controls against patterns."""

    @pytest.mark.asyncio
    async def test_pass2_classifies_controls(self):
        """Pass 2 should match controls to patterns via keywords."""
        mock_db = MagicMock()
        mock_db.execute.return_value.fetchall.return_value = [
            (
                "uuid-1", "AUTH-001",
                "Passwortrichtlinie und Authentifizierung",
                "Sicherstellen dass Anmeldedaten credential geschuetzt sind",
            ),
        ]

        migration = MigrationPasses(db=mock_db)
        await migration.initialize()

        # initialize() issues its own DB calls; reset and re-arm fetchall
        # with the same row so the pass sees the fixture.
        mock_db.execute.reset_mock()
        mock_db.execute.return_value.fetchall.return_value = [
            (
                "uuid-1", "AUTH-001",
                "Passwortrichtlinie und Authentifizierung",
                "Sicherstellen dass Anmeldedaten credential geschuetzt sind",
            ),
        ]

        stats = await migration.run_pass2_pattern_classification()

        assert stats["total"] == 1
        # Should classify because "passwort", "authentifizierung", "anmeldedaten" are keywords
        assert stats["classified"] == 1

    @pytest.mark.asyncio
    async def test_pass2_no_match(self):
        """Controls without keyword matches should be counted as no_match."""
        mock_db = MagicMock()
        mock_db.execute.return_value.fetchall.return_value = [
            (
                "uuid-1", "MISC-001",
                "Completely unrelated title",
                "No keywords match here at all",
            ),
        ]

        migration = MigrationPasses(db=mock_db)
        await migration.initialize()

        # Same reset/re-arm dance as above so the pass sees the fixture.
        mock_db.execute.reset_mock()
        mock_db.execute.return_value.fetchall.return_value = [
            (
                "uuid-1", "MISC-001",
                "Completely unrelated title",
                "No keywords match here at all",
            ),
        ]

        stats = await migration.run_pass2_pattern_classification()
        assert stats["no_match"] == 1
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: MigrationPasses — Pass 3 (Quality Triage)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPass3QualityTriage:
    """Tests for Migration Pass 3 — quality triage updates."""

    def test_pass3_executes_4_updates(self):
        """Pass 3 should execute exactly 4 UPDATE statements."""
        db = MagicMock()
        fake_result = MagicMock()
        fake_result.rowcount = 10
        db.execute.return_value = fake_result

        stats = MigrationPasses(db=db).run_pass3_quality_triage()

        assert db.execute.call_count == 4
        db.commit.assert_called_once()
        # One counter per triage bucket must be reported.
        for bucket in ("review", "needs_obligation", "needs_pattern", "legacy_unlinked"):
            assert bucket in stats
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: MigrationPasses — Pass 4 (Crosswalk Backfill)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPass4CrosswalkBackfill:
    """Tests for Migration Pass 4 — backfilling crosswalk rows."""

    def test_pass4_inserts_crosswalk_rows(self):
        """The reported row count comes straight from the DB rowcount."""
        db = MagicMock()
        fake_result = MagicMock()
        fake_result.rowcount = 42
        db.execute.return_value = fake_result

        stats = MigrationPasses(db=db).run_pass4_crosswalk_backfill()

        assert stats["rows_inserted"] == 42
        db.commit.assert_called_once()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: MigrationPasses — Pass 5 (Deduplication)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPass5Deduplication:
    """Tests for Migration Pass 5 — deduplicating pattern+obligation groups."""

    def test_pass5_no_duplicates(self):
        # With no duplicate groups in the DB, nothing is deprecated.
        mock_db = MagicMock()
        mock_db.execute.return_value.fetchall.return_value = []

        migration = MigrationPasses(db=mock_db)
        stats = migration.run_pass5_deduplication()

        assert stats["groups_found"] == 0
        assert stats["controls_deprecated"] == 0

    def test_pass5_deprecates_duplicates(self):
        """Pass 5 should keep first (highest confidence) and deprecate rest."""
        mock_db = MagicMock()

        # First call: groups query returns one group with 3 controls
        groups_result = MagicMock()
        groups_result.fetchall.return_value = [
            (
                "CP-AUTH-001",  # pattern_id
                "DSGVO-OBL-001",  # obligation_id
                ["uuid-1", "uuid-2", "uuid-3"],  # ids (ordered by confidence)
                3,  # count
            ),
        ]

        # Subsequent calls: UPDATE queries
        update_result = MagicMock()
        update_result.rowcount = 1

        # side_effect order matters: one SELECT for the groups, then one
        # UPDATE per deprecated control (uuid-2 and uuid-3).
        mock_db.execute.side_effect = [groups_result, update_result, update_result]

        migration = MigrationPasses(db=mock_db)
        stats = migration.run_pass5_deduplication()

        assert stats["groups_found"] == 1
        assert stats["controls_deprecated"] == 2  # uuid-2, uuid-3
        mock_db.commit.assert_called_once()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: MigrationPasses — migration_status
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestMigrationStatus:
    """Tests for MigrationPasses.migration_status() coverage reporting."""

    @staticmethod
    def _status_for(row):
        """Run migration_status() against a DB stub returning *row*."""
        db = MagicMock()
        db.execute.return_value.fetchone.return_value = row
        return MigrationPasses(db=db).migration_status()

    def test_migration_status(self):
        """Counts and derived coverage percentages are reported correctly."""
        # Row layout: (total, has_obligation, has_pattern, fully_linked,
        # deprecated) — 60% / 70% / 50% coverage respectively.
        status = self._status_for((4800, 2880, 3360, 2400, 300))

        assert status["total_controls"] == 4800
        assert status["has_obligation"] == 2880
        assert status["has_pattern"] == 3360
        assert status["fully_linked"] == 2400
        assert status["deprecated"] == 300
        assert status["coverage_obligation_pct"] == 60.0
        assert status["coverage_pattern_pct"] == 70.0
        assert status["coverage_full_pct"] == 50.0

    def test_migration_status_empty_db(self):
        """An empty database yields zero counts without division errors."""
        status = self._status_for((0, 0, 0, 0, 0))
        assert status["total_controls"] == 0
        assert status["coverage_obligation_pct"] == 0.0
|
||||
Reference in New Issue
Block a user