Files
breakpilot-compliance/backend-compliance/tests/test_obligation_extractor.py
Benjamin Admin 825e070ed9
Some checks failed
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Failing after 47s
CI/CD / test-python-backend-compliance (push) Successful in 33s
CI/CD / test-python-document-crawler (push) Successful in 24s
CI/CD / test-python-dsms-gateway (push) Successful in 18s
CI/CD / validate-canonical-controls (push) Successful in 11s
CI/CD / Deploy (push) Has been skipped
feat(multi-layer): complete Multi-Layer Control Architecture (Phases 1-8 + Pass 0)
Implements the full Multi-Layer Control Architecture for migrating ~25,000
Rich Controls into atomic, deduplicated Master Controls with full traceability.

Architecture: Legal Source → Obligation → Control Pattern → Master Control → Customer Instance

New services:
- ObligationExtractor: 3-tier extraction (exact → embedding → LLM)
- PatternMatcher: 2-tier matching (keyword + embedding + domain-bonus)
- ControlComposer: Pattern + Obligation → Master Control
- PipelineAdapter: Pipeline integration + Migration Passes 1-5
- DecompositionPass: Pass 0a/0b — Rich Control → atomic Controls
- CrosswalkRoutes: 15 API endpoints under /v1/canonical/

New DB schema:
- Migration 060: obligation_extractions, control_patterns, crosswalk_matrix
- Migration 061: obligation_candidates, parent_control_uuid tracking

Pattern Library: 50 YAML patterns (30 core + 20 IT-security)
Go SDK: Pattern loader with YAML validation and indexing
Documentation: MkDocs updated with full architecture overview

500 Python tests passing across all components.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-17 09:00:37 +01:00

940 lines
34 KiB
Python

"""Tests for Obligation Extractor — Phase 4 of Multi-Layer Control Architecture.
Validates:
- Regulation code normalization (_normalize_regulation)
- Article reference normalization (_normalize_article)
- Cosine similarity (_cosine_sim)
- JSON parsing from LLM responses (_parse_json)
- Obligation loading from v2 framework
- 3-Tier extraction: exact_match → embedding_match → llm_extracted
- ObligationMatch serialization
- Edge cases: empty inputs, missing data, fallback behavior
"""
import json
import math
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from compliance.services.obligation_extractor import (
EMBEDDING_CANDIDATE_THRESHOLD,
EMBEDDING_MATCH_THRESHOLD,
ObligationExtractor,
ObligationMatch,
_ObligationEntry,
_cosine_sim,
_find_obligations_dir,
_normalize_article,
_normalize_regulation,
_parse_json,
)
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
V2_DIR = REPO_ROOT / "ai-compliance-sdk" / "policies" / "obligations" / "v2"
# =============================================================================
# Tests: _normalize_regulation
# =============================================================================
class TestNormalizeRegulation:
    """Regulation codes and aliases must normalize to canonical regulation IDs."""

    def test_dsgvo_eu_code(self):
        normalized = _normalize_regulation("eu_2016_679")
        assert normalized == "dsgvo"

    def test_dsgvo_short(self):
        normalized = _normalize_regulation("dsgvo")
        assert normalized == "dsgvo"

    def test_gdpr_alias(self):
        normalized = _normalize_regulation("gdpr")
        assert normalized == "dsgvo"

    def test_ai_act_eu_code(self):
        normalized = _normalize_regulation("eu_2024_1689")
        assert normalized == "ai_act"

    def test_ai_act_short(self):
        normalized = _normalize_regulation("ai_act")
        assert normalized == "ai_act"

    def test_nis2_eu_code(self):
        normalized = _normalize_regulation("eu_2022_2555")
        assert normalized == "nis2"

    def test_nis2_short(self):
        normalized = _normalize_regulation("nis2")
        assert normalized == "nis2"

    def test_bsig_alias(self):
        normalized = _normalize_regulation("bsig")
        assert normalized == "nis2"

    def test_bdsg(self):
        normalized = _normalize_regulation("bdsg")
        assert normalized == "bdsg"

    def test_ttdsg(self):
        normalized = _normalize_regulation("ttdsg")
        assert normalized == "ttdsg"

    def test_dsa_eu_code(self):
        normalized = _normalize_regulation("eu_2022_2065")
        assert normalized == "dsa"

    def test_data_act_eu_code(self):
        normalized = _normalize_regulation("eu_2023_2854")
        assert normalized == "data_act"

    def test_eu_machinery_eu_code(self):
        normalized = _normalize_regulation("eu_2023_1230")
        assert normalized == "eu_machinery"

    def test_dora_eu_code(self):
        normalized = _normalize_regulation("eu_2022_2554")
        assert normalized == "dora"

    def test_case_insensitive(self):
        """Upper-case input yields the same lower-case canonical ID."""
        upper_to_canonical = (
            ("DSGVO", "dsgvo"),
            ("AI_ACT", "ai_act"),
            ("NIS2", "nis2"),
        )
        for raw_code, canonical in upper_to_canonical:
            assert _normalize_regulation(raw_code) == canonical

    def test_whitespace_stripped(self):
        padded = " dsgvo "
        assert _normalize_regulation(padded) == "dsgvo"

    def test_empty_string(self):
        normalized = _normalize_regulation("")
        assert normalized is None

    def test_none(self):
        normalized = _normalize_regulation(None)
        assert normalized is None

    def test_unknown_code(self):
        normalized = _normalize_regulation("mica")
        assert normalized is None

    def test_prefix_matching(self):
        """An EU code followed by an extra suffix still resolves via its prefix."""
        normalized = _normalize_regulation("eu_2016_679_consolidated")
        assert normalized == "dsgvo"

    def test_all_nine_regulations_covered(self):
        """Each canonical regulation ID normalizes to itself."""
        canonical_ids = (
            "dsgvo", "ai_act", "nis2", "bdsg", "ttdsg", "dsa",
            "data_act", "eu_machinery", "dora",
        )
        for reg_id in canonical_ids:
            assert _normalize_regulation(reg_id) == reg_id, f"Regulation {reg_id} not found"
# =============================================================================
# Tests: _normalize_article
# =============================================================================
class TestNormalizeArticle:
    """Article references must collapse to a lower-case 'art. N' / '§ N' form."""

    def test_art_with_dot(self):
        normalized = _normalize_article("Art. 30")
        assert normalized == "art. 30"

    def test_article_english(self):
        normalized = _normalize_article("Article 10")
        assert normalized == "art. 10"

    def test_artikel_german(self):
        normalized = _normalize_article("Artikel 35")
        assert normalized == "art. 35"

    def test_paragraph_symbol(self):
        normalized = _normalize_article("§ 38")
        assert normalized == "§ 38"

    def test_paragraph_with_law_suffix(self):
        """A trailing law name is dropped: '§ 38 BDSG' becomes '§ 38'."""
        normalized = _normalize_article("§ 38 BDSG")
        assert normalized == "§ 38"

    def test_paragraph_with_dsgvo_suffix(self):
        normalized = _normalize_article("Art. 6 DSGVO")
        assert normalized == "art. 6"

    def test_removes_absatz(self):
        """The 'Abs.' sub-reference is dropped: 'Art. 30 Abs. 1' becomes 'art. 30'."""
        normalized = _normalize_article("Art. 30 Abs. 1")
        assert normalized == "art. 30"

    def test_removes_paragraph(self):
        normalized = _normalize_article("Art. 5 paragraph 2")
        assert normalized == "art. 5"

    def test_removes_lit(self):
        normalized = _normalize_article("Art. 6 lit. a")
        assert normalized == "art. 6"

    def test_removes_satz(self):
        normalized = _normalize_article("Art. 12 Satz 3")
        assert normalized == "art. 12"

    def test_lowercase_output(self):
        """Upper-case spellings produce the lower-case normalized form."""
        upper_to_normalized = (
            ("ART. 30", "art. 30"),
            ("ARTICLE 10", "art. 10"),
        )
        for raw_ref, expected in upper_to_normalized:
            assert _normalize_article(raw_ref) == expected

    def test_whitespace_stripped(self):
        padded = " Art. 30 "
        assert _normalize_article(padded) == "art. 30"

    def test_empty_string(self):
        normalized = _normalize_article("")
        assert normalized == ""

    def test_none(self):
        normalized = _normalize_article(None)
        assert normalized == ""

    def test_complex_reference(self):
        """A fully qualified reference still starts with its article number."""
        full_reference = "Art. 30 Abs. 1 Satz 2 lit. c DSGVO"
        normalized = _normalize_article(full_reference)
        # Only the prefix is pinned; the exact tail handling is not asserted here.
        assert normalized.startswith("art. 30")

    def test_nis2_article(self):
        normalized = _normalize_article("Art. 21 NIS2")
        assert normalized == "art. 21"

    def test_dora_article(self):
        normalized = _normalize_article("Art. 5 DORA")
        assert normalized == "art. 5"

    def test_ai_act_article(self):
        normalized = _normalize_article("Article 6 AI Act")
        assert normalized == "art. 6"
# =============================================================================
# Tests: _cosine_sim
# =============================================================================
class TestCosineSim:
    """Tests for cosine similarity calculation."""

    # Absolute tolerance used when comparing floating-point similarity scores.
    TOLERANCE = 1e-6

    def test_identical_vectors(self):
        """A vector compared with itself has similarity 1."""
        v = [1.0, 2.0, 3.0]
        assert math.isclose(_cosine_sim(v, v), 1.0, abs_tol=self.TOLERANCE)

    def test_orthogonal_vectors(self):
        """Perpendicular vectors have similarity 0."""
        a = [1.0, 0.0]
        b = [0.0, 1.0]
        assert math.isclose(_cosine_sim(a, b), 0.0, abs_tol=self.TOLERANCE)

    def test_opposite_vectors(self):
        """Vectors pointing in opposite directions have similarity -1."""
        a = [1.0, 2.0, 3.0]
        b = [-1.0, -2.0, -3.0]
        assert math.isclose(_cosine_sim(a, b), -1.0, abs_tol=self.TOLERANCE)

    def test_known_value(self):
        """A 45 degree angle gives 1 / sqrt(2)."""
        a = [1.0, 0.0]
        b = [1.0, 1.0]
        expected = 1.0 / math.sqrt(2)
        assert math.isclose(_cosine_sim(a, b), expected, abs_tol=self.TOLERANCE)

    def test_empty_vectors(self):
        assert _cosine_sim([], []) == 0.0

    def test_one_empty(self):
        assert _cosine_sim([1.0, 2.0], []) == 0.0
        assert _cosine_sim([], [1.0, 2.0]) == 0.0

    def test_different_lengths(self):
        assert _cosine_sim([1.0, 2.0], [1.0]) == 0.0

    def test_zero_vector(self):
        assert _cosine_sim([0.0, 0.0], [1.0, 2.0]) == 0.0

    def test_both_zero(self):
        assert _cosine_sim([0.0, 0.0], [0.0, 0.0]) == 0.0

    def test_high_dimensional(self):
        """Test with realistic embedding dimensions (1024)."""
        import random

        # A private generator keeps the run reproducible without reseeding the
        # process-wide RNG, which would leak into every test that runs later.
        rng = random.Random(42)
        a = [rng.gauss(0, 1) for _ in range(1024)]
        b = [rng.gauss(0, 1) for _ in range(1024)]
        score = _cosine_sim(a, b)
        assert -1.0 <= score <= 1.0
# =============================================================================
# Tests: _parse_json
# =============================================================================
class TestParseJson:
    """Tests for JSON extraction from LLM responses."""

    def test_direct_json(self):
        """A response that is pure JSON is returned as a dict."""
        text = '{"obligation_text": "Test", "actor": "Controller"}'
        result = _parse_json(text)
        assert result["obligation_text"] == "Test"
        assert result["actor"] == "Controller"

    def test_json_in_markdown_block(self):
        """LLMs often wrap JSON in markdown code blocks."""
        text = (
            "Some explanation text\n"
            "```json\n"
            '{"obligation_text": "Test"}\n'
            "```\n"
            "More text"
        )
        result = _parse_json(text)
        assert result.get("obligation_text") == "Test"

    def test_json_with_prefix_text(self):
        """A JSON object preceded by free text is still extracted."""
        text = 'Here is the result: {"obligation_text": "Pflicht", "actor": "Verantwortlicher"}'
        result = _parse_json(text)
        assert result["obligation_text"] == "Pflicht"

    def test_invalid_json(self):
        """Text without any JSON yields an empty dict."""
        result = _parse_json("not json at all")
        assert result == {}

    def test_empty_string(self):
        result = _parse_json("")
        assert result == {}

    def test_nested_braces_picks_first(self):
        """A response that is valid nested JSON is returned with its nesting intact."""
        text = '{"outer": {"inner": "value"}}'
        result = _parse_json(text)
        # The whole text is valid JSON, so the outer object must survive rather
        # than being reduced to the inner one.
        assert result == {"outer": {"inner": "value"}}

    def test_json_with_german_umlauts(self):
        """Non-ASCII characters in values survive parsing unchanged."""
        # Uses a real umlaut; the ASCII transliteration "ae" would not exercise
        # non-ASCII handling at all.
        text = '{"obligation_text": "Pflicht zur Datenschutz-Folgenabschätzung"}'
        result = _parse_json(text)
        assert result["obligation_text"] == "Pflicht zur Datenschutz-Folgenabschätzung"
# =============================================================================
# Tests: ObligationMatch
# =============================================================================
class TestObligationMatch:
    """Default values and dict serialization of ObligationMatch."""

    def test_defaults(self):
        """An argument-less match carries no obligation data and method 'none'."""
        blank = ObligationMatch()
        for attribute in ("obligation_id", "obligation_title", "obligation_text"):
            assert getattr(blank, attribute) is None
        assert blank.method == "none"
        assert blank.confidence == 0.0
        assert blank.regulation_id is None

    def test_to_dict(self):
        """Values passed to the constructor appear under the same keys in to_dict()."""
        fields = {
            "obligation_id": "DSGVO-OBL-001",
            "obligation_title": "Verarbeitungsverzeichnis",
            "obligation_text": "Fuehrung eines Verzeichnisses...",
            "method": "exact_match",
            "confidence": 1.0,
            "regulation_id": "dsgvo",
        }
        serialized = ObligationMatch(**fields).to_dict()
        assert serialized["obligation_id"] == "DSGVO-OBL-001"
        assert serialized["method"] == "exact_match"
        assert serialized["confidence"] == 1.0
        assert serialized["regulation_id"] == "dsgvo"

    def test_to_dict_keys(self):
        """to_dict() exposes exactly the six expected keys."""
        serialized = ObligationMatch().to_dict()
        wanted = {
            "obligation_id", "obligation_title", "obligation_text",
            "method", "confidence", "regulation_id",
        }
        assert set(serialized.keys()) == wanted

    def test_to_dict_none_values(self):
        """Unset fields are serialized as None rather than omitted."""
        serialized = ObligationMatch().to_dict()
        assert serialized["obligation_id"] is None
        assert serialized["obligation_title"] is None
# =============================================================================
# Tests: _find_obligations_dir
# =============================================================================
class TestFindObligationsDir:
    """Tests for finding the v2 obligations directory."""

    def test_finds_v2_directory(self):
        """When a directory is found it must exist and contain the manifest."""
        result = _find_obligations_dir()
        if result is None:
            # Report the missing directory as a skip instead of a silent pass,
            # so an environment without the SDK is visible in the test report.
            pytest.skip("v2 obligations directory not found in this environment")
        assert result.is_dir()
        assert (result / "_manifest.json").exists()

    def test_v2_dir_exists_in_repo(self):
        """The v2 dir should exist in the repo for local tests."""
        assert V2_DIR.exists(), f"v2 dir not found at {V2_DIR}"
        assert (V2_DIR / "_manifest.json").exists()
# =============================================================================
# Tests: ObligationExtractor — _load_obligations
# =============================================================================
class TestObligationExtractorLoad:
    """Checks on the data produced by ObligationExtractor._load_obligations()."""

    @staticmethod
    def _loaded_extractor():
        """Return a fresh extractor on which _load_obligations() has been run."""
        fresh = ObligationExtractor()
        fresh._load_obligations()
        return fresh

    def test_load_obligations_populates_lookup(self):
        loaded = self._loaded_extractor()
        assert len(loaded._obligations) > 0

    def test_load_obligations_count(self):
        """Exactly 325 obligations are expected after loading."""
        loaded = self._loaded_extractor()
        assert len(loaded._obligations) == 325

    def test_article_lookup_populated(self):
        """The article lookup must not be empty after loading."""
        loaded = self._loaded_extractor()
        assert len(loaded._article_lookup) > 0

    def test_article_lookup_dsgvo_art30(self):
        """The key 'dsgvo/art. 30' must list DSGVO-OBL-001."""
        loaded = self._loaded_extractor()
        key = "dsgvo/art. 30"
        lookup = loaded._article_lookup
        assert key in lookup
        assert "DSGVO-OBL-001" in lookup[key]

    def test_obligations_have_required_fields(self):
        """Each entry has a matching id and non-empty title, description, regulation_id."""
        loaded = self._loaded_extractor()
        for obl_id, entry in loaded._obligations.items():
            assert entry.id == obl_id
            assert entry.title, f"{obl_id}: empty title"
            assert entry.description, f"{obl_id}: empty description"
            assert entry.regulation_id, f"{obl_id}: empty regulation_id"

    def test_all_nine_regulations_loaded(self):
        """The set of loaded regulation IDs equals the nine expected ones."""
        loaded = self._loaded_extractor()
        wanted = {
            "dsgvo", "ai_act", "nis2", "bdsg", "ttdsg", "dsa",
            "data_act", "eu_machinery", "dora",
        }
        seen = set()
        for entry in loaded._obligations.values():
            seen.add(entry.regulation_id)
        assert seen == wanted

    def test_obligation_id_format(self):
        """Every obligation ID matches '<PREFIX>-OBL-<three digits>'."""
        import re

        loaded = self._loaded_extractor()
        # Prefix may contain upper-case letters, digits and underscores
        # (e.g. NIS2-OBL-001, EU_MACHINERY-OBL-001).
        id_shape = re.compile(r"^[A-Z0-9_]+-OBL-\d{3}$")
        for obl_id in loaded._obligations:
            assert id_shape.match(obl_id), f"Invalid obligation ID format: {obl_id}"

    def test_no_duplicate_obligation_ids(self):
        """No obligation ID occurs more than once."""
        loaded = self._loaded_extractor()
        all_ids = list(loaded._obligations.keys())
        unique_ids = set(all_ids)
        assert len(all_ids) == len(unique_ids)
# =============================================================================
# Tests: ObligationExtractor — Tier 1 (Exact Match)
# =============================================================================
class TestTier1ExactMatch:
    """Tier 1: direct lookup of an obligation by regulation and article."""

    def setup_method(self):
        """Create an extractor with obligations loaded before each test."""
        extractor = ObligationExtractor()
        extractor._load_obligations()
        self.extractor = extractor

    def _lookup(self, regulation, article):
        """Run the Tier 1 lookup on the prepared extractor."""
        return self.extractor._tier1_exact(regulation, article)

    def test_exact_match_dsgvo_art30(self):
        found = self._lookup("dsgvo", "Art. 30")
        assert found is not None
        assert found.obligation_id == "DSGVO-OBL-001"
        assert found.method == "exact_match"
        assert found.confidence == 1.0
        assert found.regulation_id == "dsgvo"

    def test_exact_match_case_insensitive_article(self):
        found = self._lookup("dsgvo", "ART. 30")
        assert found is not None
        assert found.obligation_id == "DSGVO-OBL-001"

    def test_exact_match_article_variant(self):
        """The English spelling 'Article 30' resolves to the same obligation."""
        found = self._lookup("dsgvo", "Article 30")
        assert found is not None
        assert found.obligation_id == "DSGVO-OBL-001"

    def test_exact_match_artikel_variant(self):
        found = self._lookup("dsgvo", "Artikel 30")
        assert found is not None
        assert found.obligation_id == "DSGVO-OBL-001"

    def test_exact_match_strips_absatz(self):
        """A trailing 'Abs. 1' does not prevent the match."""
        found = self._lookup("dsgvo", "Art. 30 Abs. 1")
        assert found is not None
        assert found.obligation_id == "DSGVO-OBL-001"

    def test_no_match_wrong_article(self):
        found = self._lookup("dsgvo", "Art. 999")
        assert found is None

    def test_no_match_unknown_regulation(self):
        found = self._lookup("unknown_reg", "Art. 30")
        assert found is None

    def test_no_match_none_regulation(self):
        found = self._lookup(None, "Art. 30")
        assert found is None

    def test_match_has_title(self):
        found = self._lookup("dsgvo", "Art. 30")
        assert found is not None
        title = found.obligation_title
        assert title is not None
        assert len(title) > 0

    def test_match_has_text(self):
        found = self._lookup("dsgvo", "Art. 30")
        assert found is not None
        body = found.obligation_text
        assert body is not None
        assert len(body) > 20
# =============================================================================
# Tests: ObligationExtractor — Tier 2 (Embedding Match)
# =============================================================================
class TestTier2EmbeddingMatch:
    """Tests for Tier 2 embedding-based matching."""

    def setup_method(self):
        """Load obligations and attach deterministic fake embeddings to them."""
        self.extractor = ObligationExtractor()
        self.extractor._load_obligations()
        # Prepare fake embeddings for testing (no real embedding service)
        self.extractor._obligation_ids = list(self.extractor._obligations.keys())
        # Simple 3D vectors; the "+ 1" keeps every component non-zero.
        self.extractor._obligation_embeddings = [
            [float(i % 10 + 1), float((i * 3) % 10 + 1), float((i * 7) % 10 + 1)]
            for i in range(len(self.extractor._obligation_ids))
        ]

    @pytest.mark.asyncio
    async def test_embedding_match_above_threshold(self):
        """When cosine > 0.80, should return embedding_match."""
        # The mocked embedding equals the vector stored for the first obligation.
        target_embedding = self.extractor._obligation_embeddings[0]
        with patch(
            "compliance.services.obligation_extractor._get_embedding",
            new_callable=AsyncMock,
            return_value=target_embedding,
        ):
            match = await self.extractor._tier2_embedding("test text", "dsgvo")
        # Should find a match (cosine = 1.0 for identical vector)
        assert match is not None
        assert match.method == "embedding_match"
        assert match.confidence >= EMBEDDING_MATCH_THRESHOLD

    @pytest.mark.asyncio
    async def test_embedding_match_returns_none_below_threshold(self):
        """When cosine < 0.80, should return None."""
        # Use one known obligation vector and a query that is exactly
        # orthogonal to it, so the similarity is 0 and the outcome does not
        # depend on how the fake embeddings happen to be distributed.
        self.extractor._obligation_ids = ["DSGVO-OBL-001"]
        self.extractor._obligation_embeddings = [[1.0, 0.0, 0.0]]
        orthogonal = [0.0, 1.0, 0.0]
        with patch(
            "compliance.services.obligation_extractor._get_embedding",
            new_callable=AsyncMock,
            return_value=orthogonal,
        ):
            # regulation_code=None: no regulation is given to match against.
            match = await self.extractor._tier2_embedding("unrelated text", None)
        assert match is None

    @pytest.mark.asyncio
    async def test_embedding_match_empty_embeddings(self):
        """When no embeddings loaded, should return None."""
        self.extractor._obligation_embeddings = []
        match = await self.extractor._tier2_embedding("any text", "dsgvo")
        assert match is None

    @pytest.mark.asyncio
    async def test_embedding_match_failed_embedding(self):
        """When embedding service returns empty, should return None."""
        with patch(
            "compliance.services.obligation_extractor._get_embedding",
            new_callable=AsyncMock,
            return_value=[],
        ):
            match = await self.extractor._tier2_embedding("some text", "dsgvo")
        assert match is None

    @pytest.mark.asyncio
    async def test_domain_bonus_same_regulation(self):
        """With identical embeddings, the obligation of the requested regulation wins."""
        # Two obligations with the same embedding but different regulations
        self.extractor._obligation_ids = ["DSGVO-OBL-001", "NIS2-OBL-001"]
        self.extractor._obligation_embeddings = [
            [1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
        ]
        with patch(
            "compliance.services.obligation_extractor._get_embedding",
            new_callable=AsyncMock,
            return_value=[1.0, 0.0, 0.0],
        ):
            match = await self.extractor._tier2_embedding("test", "dsgvo")
        # Should match (cosine = 1.0 ≥ 0.80)
        assert match is not None
        assert match.method == "embedding_match"
        # With domain bonus, DSGVO should be preferred
        assert match.regulation_id == "dsgvo"

    @pytest.mark.asyncio
    async def test_confidence_capped_at_1(self):
        """Confidence should not exceed 1.0 even with domain bonus."""
        self.extractor._obligation_ids = ["DSGVO-OBL-001"]
        self.extractor._obligation_embeddings = [[1.0, 0.0, 0.0]]
        with patch(
            "compliance.services.obligation_extractor._get_embedding",
            new_callable=AsyncMock,
            return_value=[1.0, 0.0, 0.0],
        ):
            match = await self.extractor._tier2_embedding("test", "dsgvo")
        assert match is not None
        assert match.confidence <= 1.0
# =============================================================================
# Tests: ObligationExtractor — Tier 3 (LLM Extraction)
# =============================================================================
class TestTier3LLMExtraction:
    """Tier 3: obligation extraction via a mocked LLM call."""

    def setup_method(self):
        """Create a plain extractor; no obligations are loaded for these tests."""
        self.extractor = ObligationExtractor()

    @staticmethod
    def _llm_returning(response):
        """Build a patch that makes the module's _llm_ollama return `response`."""
        return patch(
            "compliance.services.obligation_extractor._llm_ollama",
            new_callable=AsyncMock,
            return_value=response,
        )

    @pytest.mark.asyncio
    async def test_llm_extraction_success(self):
        """A JSON answer yields the obligation text with confidence 0.60."""
        payload = {
            "obligation_text": "Pflicht zur Fuehrung eines Verarbeitungsverzeichnisses",
            "actor": "Verantwortlicher",
            "action": "Verarbeitungsverzeichnis fuehren",
            "normative_strength": "muss",
        }
        with self._llm_returning(json.dumps(payload)):
            extracted = await self.extractor._tier3_llm(
                "Der Verantwortliche fuehrt ein Verzeichnis...",
                "eu_2016_679",
                "Art. 30",
            )
        assert extracted.method == "llm_extracted"
        assert extracted.confidence == 0.60
        assert "Verarbeitungsverzeichnis" in extracted.obligation_text
        # No obligation ID is expected from the LLM tier.
        assert extracted.obligation_id is None
        assert extracted.regulation_id == "dsgvo"

    @pytest.mark.asyncio
    async def test_llm_extraction_failure(self):
        """An empty LLM answer yields confidence 0 and no obligation text."""
        with self._llm_returning(""):
            extracted = await self.extractor._tier3_llm("some text", "dsgvo", "Art. 1")
        assert extracted.method == "llm_extracted"
        assert extracted.confidence == 0.0
        assert extracted.obligation_text is None

    @pytest.mark.asyncio
    async def test_llm_extraction_malformed_json(self):
        """A non-JSON answer is used as obligation text instead of being dropped."""
        with self._llm_returning("Dies ist die Pflicht: Daten schuetzen"):
            extracted = await self.extractor._tier3_llm("some text", "dsgvo", None)
        assert extracted.method == "llm_extracted"
        assert extracted.confidence == 0.60
        # The raw answer (or a prefix of it) should show up in the text.
        text = extracted.obligation_text
        assert "Pflicht" in text or "Daten" in text

    @pytest.mark.asyncio
    async def test_llm_regulation_normalization(self):
        """The EU code passed in is reported as its canonical regulation ID."""
        with self._llm_returning('{"obligation_text": "Test"}'):
            extracted = await self.extractor._tier3_llm(
                "text", "eu_2024_1689", "Art. 6"
            )
        assert extracted.regulation_id == "ai_act"
# =============================================================================
# Tests: ObligationExtractor — Full 3-Tier extract()
# =============================================================================
class TestExtractFullFlow:
    """Tests for the full 3-tier extraction flow."""

    def setup_method(self):
        """Prepare an extractor that is marked initialized and has no embeddings."""
        self.extractor = ObligationExtractor()
        self.extractor._load_obligations()
        # Mark as initialized to skip async initialize
        self.extractor._initialized = True
        # Empty embeddings — Tier 2 will return None
        self.extractor._obligation_embeddings = []
        self.extractor._obligation_ids = []

    @pytest.mark.asyncio
    async def test_tier1_takes_priority(self):
        """When Tier 1 matches, Tier 2 and 3 should not be called."""
        tier2 = patch.object(self.extractor, "_tier2_embedding", new_callable=AsyncMock)
        tier3 = patch.object(self.extractor, "_tier3_llm", new_callable=AsyncMock)
        with tier2 as mock_t2, tier3 as mock_t3:
            match = await self.extractor.extract(
                chunk_text="irrelevant",
                regulation_code="eu_2016_679",
                article="Art. 30",
            )
        assert match.method == "exact_match"
        mock_t2.assert_not_called()
        mock_t3.assert_not_called()

    @pytest.mark.asyncio
    async def test_tier2_when_tier1_misses(self):
        """When Tier 1 misses, Tier 2 should be tried."""
        tier2_result = ObligationMatch(
            obligation_id="DSGVO-OBL-050",
            method="embedding_match",
            confidence=0.85,
            regulation_id="dsgvo",
        )
        tier2 = patch.object(
            self.extractor, "_tier2_embedding",
            new_callable=AsyncMock,
            return_value=tier2_result,
        )
        tier3 = patch.object(self.extractor, "_tier3_llm", new_callable=AsyncMock)
        with tier2 as mock_t2, tier3 as mock_t3:
            match = await self.extractor.extract(
                chunk_text="some compliance text",
                regulation_code="eu_2016_679",
                article="Art. 999",  # Non-matching article
            )
        assert match.method == "embedding_match"
        mock_t2.assert_called_once()
        mock_t3.assert_not_called()

    @pytest.mark.asyncio
    async def test_tier3_when_tier1_and_2_miss(self):
        """When Tier 1 and 2 miss, Tier 3 should be called."""
        tier3_result = ObligationMatch(
            obligation_text="LLM extracted obligation",
            method="llm_extracted",
            confidence=0.60,
        )
        tier2 = patch.object(
            self.extractor, "_tier2_embedding",
            new_callable=AsyncMock,
            return_value=None,
        )
        tier3 = patch.object(
            self.extractor, "_tier3_llm",
            new_callable=AsyncMock,
            return_value=tier3_result,
        )
        with tier2, tier3:
            match = await self.extractor.extract(
                chunk_text="unrelated text",
                regulation_code="unknown_reg",
                article="Art. 999",
            )
        assert match.method == "llm_extracted"

    @pytest.mark.asyncio
    async def test_no_article_skips_tier1(self):
        """When no article is provided, Tier 1 should be skipped."""
        tier2 = patch.object(
            self.extractor, "_tier2_embedding",
            new_callable=AsyncMock,
            return_value=None,
        )
        tier3 = patch.object(
            self.extractor, "_tier3_llm",
            new_callable=AsyncMock,
            return_value=ObligationMatch(method="llm_extracted", confidence=0.60),
        )
        with tier2 as mock_t2, tier3:
            match = await self.extractor.extract(
                chunk_text="some text",
                regulation_code="dsgvo",
                article=None,
            )
        # Tier 2 should be called (Tier 1 skipped due to no article)
        mock_t2.assert_called_once()
        # Tier 2 returned None, so the mocked Tier 3 result must come back.
        assert match.method == "llm_extracted"

    @pytest.mark.asyncio
    async def test_auto_initialize(self):
        """If not initialized, extract should call initialize()."""
        extractor = ObligationExtractor()
        assert not extractor._initialized

        async def fake_initialize():
            # Stand-in for initialize(): marks the extractor ready and loads
            # obligations, but leaves the embeddings empty.
            extractor._initialized = True
            extractor._load_obligations()
            extractor._obligation_embeddings = []
            extractor._obligation_ids = []

        init = patch.object(extractor, "initialize", new_callable=AsyncMock)
        tier2 = patch.object(
            extractor, "_tier2_embedding",
            new_callable=AsyncMock,
            return_value=None,
        )
        tier3 = patch.object(
            extractor, "_tier3_llm",
            new_callable=AsyncMock,
            return_value=ObligationMatch(method="llm_extracted", confidence=0.60),
        )
        with init as mock_init:
            mock_init.side_effect = fake_initialize
            with tier2, tier3:
                await extractor.extract(
                    chunk_text="test",
                    regulation_code="dsgvo",
                    article=None,
                )
            mock_init.assert_called_once()
# =============================================================================
# Tests: ObligationExtractor — stats()
# =============================================================================
class TestExtractorStats:
    """Contents of the dict returned by ObligationExtractor.stats()."""

    def test_stats_before_init(self):
        """A fresh extractor reports zero counts and is not initialized."""
        report = ObligationExtractor().stats()
        assert report["total_obligations"] == 0
        assert report["article_lookups"] == 0
        assert report["initialized"] is False

    def test_stats_after_load(self):
        """Loading obligations fills the counts but does not set 'initialized'."""
        loaded = ObligationExtractor()
        loaded._load_obligations()
        report = loaded.stats()
        assert report["total_obligations"] == 325
        assert report["article_lookups"] > 0
        assert "dsgvo" in report["regulations"]
        # Only _load_obligations() ran, so the flag is still False.
        assert report["initialized"] is False

    def test_stats_regulations_complete(self):
        """The reported regulations are exactly the nine expected IDs."""
        loaded = ObligationExtractor()
        loaded._load_obligations()
        wanted = {
            "dsgvo", "ai_act", "nis2", "bdsg", "ttdsg",
            "dsa", "data_act", "eu_machinery", "dora",
        }
        reported = set(loaded.stats()["regulations"])
        assert reported == wanted
# =============================================================================
# Tests: Integration — Regulation-to-Obligation mapping coverage
# =============================================================================
class TestRegulationObligationCoverage:
    """Sanity checks on the size and shape of the article lookup table."""

    def setup_method(self):
        """Create an extractor with obligations loaded before each test."""
        extractor = ObligationExtractor()
        extractor._load_obligations()
        self.extractor = extractor

    def _keys_starting_with(self, prefix):
        """Return the article-lookup keys that begin with `prefix`."""
        return [k for k in self.extractor._article_lookup if k.startswith(prefix)]

    def test_dsgvo_has_article_lookups(self):
        """At least 20 lookup keys are expected under 'dsgvo/'."""
        dsgvo_keys = self._keys_starting_with("dsgvo/")
        assert len(dsgvo_keys) >= 20, f"Only {len(dsgvo_keys)} DSGVO article lookups"

    def test_ai_act_has_article_lookups(self):
        """At least 10 lookup keys are expected under 'ai_act/'."""
        ai_keys = self._keys_starting_with("ai_act/")
        assert len(ai_keys) >= 10, f"Only {len(ai_keys)} AI Act article lookups"

    def test_nis2_has_article_lookups(self):
        """At least 5 lookup keys are expected under 'nis2/'."""
        nis2_keys = self._keys_starting_with("nis2/")
        assert len(nis2_keys) >= 5, f"Only {len(nis2_keys)} NIS2 article lookups"

    def test_all_article_lookup_values_are_valid(self):
        """Each obligation ID listed in the lookup exists in _obligations."""
        known = self.extractor._obligations
        for key, obl_ids in self.extractor._article_lookup.items():
            for obl_id in obl_ids:
                assert obl_id in known, (
                    f"Article lookup {key} references missing obligation {obl_id}"
                )

    def test_article_lookup_key_format(self):
        """Each key is '<regulation_id>/<article>' with a lower-case article part."""
        for key in self.extractor._article_lookup:
            reg_id, separator, article = key.partition("/")
            # An empty separator means the key contained no '/' at all.
            assert separator == "/", f"Invalid key format: {key}"
            assert reg_id, f"Empty regulation ID in key: {key}"
            assert article, f"Empty article in key: {key}"
            assert article == article.lower(), f"Article not lowercase: {key}"
# =============================================================================
# Tests: Constants and thresholds
# =============================================================================
class TestConstants:
    """Pins the values and relationship of the embedding thresholds."""

    def test_embedding_thresholds_ordering(self):
        """The candidate threshold lies strictly below the match threshold."""
        assert EMBEDDING_CANDIDATE_THRESHOLD < EMBEDDING_MATCH_THRESHOLD

    def test_embedding_thresholds_range(self):
        """Both thresholds lie in the interval (0, 1]."""
        for threshold in (EMBEDDING_MATCH_THRESHOLD, EMBEDDING_CANDIDATE_THRESHOLD):
            assert 0 < threshold <= 1.0

    def test_match_threshold_is_80(self):
        expected = 0.80
        assert EMBEDDING_MATCH_THRESHOLD == expected

    def test_candidate_threshold_is_60(self):
        expected = 0.60
        assert EMBEDDING_CANDIDATE_THRESHOLD == expected