Some checks failed
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Failing after 47s
CI/CD / test-python-backend-compliance (push) Successful in 33s
CI/CD / test-python-document-crawler (push) Successful in 24s
CI/CD / test-python-dsms-gateway (push) Successful in 18s
CI/CD / validate-canonical-controls (push) Successful in 11s
CI/CD / Deploy (push) Has been skipped
Implements the full Multi-Layer Control Architecture for migrating ~25,000 Rich Controls into atomic, deduplicated Master Controls with full traceability. Architecture: Legal Source → Obligation → Control Pattern → Master Control → Customer Instance New services: - ObligationExtractor: 3-tier extraction (exact → embedding → LLM) - PatternMatcher: 2-tier matching (keyword + embedding + domain-bonus) - ControlComposer: Pattern + Obligation → Master Control - PipelineAdapter: Pipeline integration + Migration Passes 1-5 - DecompositionPass: Pass 0a/0b — Rich Control → atomic Controls - CrosswalkRoutes: 15 API endpoints under /v1/canonical/ New DB schema: - Migration 060: obligation_extractions, control_patterns, crosswalk_matrix - Migration 061: obligation_candidates, parent_control_uuid tracking Pattern Library: 50 YAML patterns (30 core + 20 IT-security) Go SDK: Pattern loader with YAML validation and indexing Documentation: MkDocs updated with full architecture overview 500 Python tests passing across all components. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
940 lines
34 KiB
Python
940 lines
34 KiB
Python
"""Tests for Obligation Extractor — Phase 4 of Multi-Layer Control Architecture.
|
|
|
|
Validates:
|
|
- Regulation code normalization (_normalize_regulation)
|
|
- Article reference normalization (_normalize_article)
|
|
- Cosine similarity (_cosine_sim)
|
|
- JSON parsing from LLM responses (_parse_json)
|
|
- Obligation loading from v2 framework
|
|
- 3-Tier extraction: exact_match → embedding_match → llm_extracted
|
|
- ObligationMatch serialization
|
|
- Edge cases: empty inputs, missing data, fallback behavior
|
|
"""
|
|
|
|
import json
|
|
import math
|
|
from pathlib import Path
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from compliance.services.obligation_extractor import (
|
|
EMBEDDING_CANDIDATE_THRESHOLD,
|
|
EMBEDDING_MATCH_THRESHOLD,
|
|
ObligationExtractor,
|
|
ObligationMatch,
|
|
_ObligationEntry,
|
|
_cosine_sim,
|
|
_find_obligations_dir,
|
|
_normalize_article,
|
|
_normalize_regulation,
|
|
_parse_json,
|
|
)
|
|
|
|
# Repository root: three directory levels above this test file.
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
# Location of the v2 obligation JSON files shipped with the Go SDK checkout.
V2_DIR = REPO_ROOT.joinpath("ai-compliance-sdk", "policies", "obligations", "v2")
|
|
|
|
|
|
# =============================================================================
|
|
# Tests: _normalize_regulation
|
|
# =============================================================================
|
|
|
|
|
|
class TestNormalizeRegulation:
    """Check that _normalize_regulation maps EU codes and aliases to short IDs."""

    def test_dsgvo_eu_code(self):
        normalized = _normalize_regulation("eu_2016_679")
        assert normalized == "dsgvo"

    def test_dsgvo_short(self):
        normalized = _normalize_regulation("dsgvo")
        assert normalized == "dsgvo"

    def test_gdpr_alias(self):
        # The English acronym is accepted as an alias of the German short code.
        normalized = _normalize_regulation("gdpr")
        assert normalized == "dsgvo"

    def test_ai_act_eu_code(self):
        normalized = _normalize_regulation("eu_2024_1689")
        assert normalized == "ai_act"

    def test_ai_act_short(self):
        normalized = _normalize_regulation("ai_act")
        assert normalized == "ai_act"

    def test_nis2_eu_code(self):
        normalized = _normalize_regulation("eu_2022_2555")
        assert normalized == "nis2"

    def test_nis2_short(self):
        normalized = _normalize_regulation("nis2")
        assert normalized == "nis2"

    def test_bsig_alias(self):
        # The German implementing law maps onto the NIS2 regulation.
        normalized = _normalize_regulation("bsig")
        assert normalized == "nis2"

    def test_bdsg(self):
        normalized = _normalize_regulation("bdsg")
        assert normalized == "bdsg"

    def test_ttdsg(self):
        normalized = _normalize_regulation("ttdsg")
        assert normalized == "ttdsg"

    def test_dsa_eu_code(self):
        normalized = _normalize_regulation("eu_2022_2065")
        assert normalized == "dsa"

    def test_data_act_eu_code(self):
        normalized = _normalize_regulation("eu_2023_2854")
        assert normalized == "data_act"

    def test_eu_machinery_eu_code(self):
        normalized = _normalize_regulation("eu_2023_1230")
        assert normalized == "eu_machinery"

    def test_dora_eu_code(self):
        normalized = _normalize_regulation("eu_2022_2554")
        assert normalized == "dora"

    def test_case_insensitive(self):
        """Upper-case input normalizes the same as lower-case input."""
        cases = (("DSGVO", "dsgvo"), ("AI_ACT", "ai_act"), ("NIS2", "nis2"))
        for raw, expected in cases:
            assert _normalize_regulation(raw) == expected

    def test_whitespace_stripped(self):
        normalized = _normalize_regulation(" dsgvo ")
        assert normalized == "dsgvo"

    def test_empty_string(self):
        normalized = _normalize_regulation("")
        assert normalized is None

    def test_none(self):
        normalized = _normalize_regulation(None)
        assert normalized is None

    def test_unknown_code(self):
        # A regulation outside the supported set yields None rather than raising.
        normalized = _normalize_regulation("mica")
        assert normalized is None

    def test_prefix_matching(self):
        """A suffixed EU code is still recognised by its prefix."""
        normalized = _normalize_regulation("eu_2016_679_consolidated")
        assert normalized == "dsgvo"

    def test_all_nine_regulations_covered(self):
        """Each manifest regulation ID must normalize to itself."""
        manifest_ids = (
            "dsgvo", "ai_act", "nis2", "bdsg", "ttdsg", "dsa",
            "data_act", "eu_machinery", "dora",
        )
        for reg_id in manifest_ids:
            assert _normalize_regulation(reg_id) == reg_id, f"Regulation {reg_id} not found"
|
|
|
|
|
|
# =============================================================================
|
|
# Tests: _normalize_article
|
|
# =============================================================================
|
|
|
|
|
|
class TestNormalizeArticle:
    """Check that _normalize_article reduces references to a bare 'art. N' / '§ N'."""

    def test_art_with_dot(self):
        normalized = _normalize_article("Art. 30")
        assert normalized == "art. 30"

    def test_article_english(self):
        # The English long form collapses to the abbreviated form.
        normalized = _normalize_article("Article 10")
        assert normalized == "art. 10"

    def test_artikel_german(self):
        # The German long form collapses to the same abbreviated form.
        normalized = _normalize_article("Artikel 35")
        assert normalized == "art. 35"

    def test_paragraph_symbol(self):
        normalized = _normalize_article("§ 38")
        assert normalized == "§ 38"

    def test_paragraph_with_law_suffix(self):
        """A trailing law name ('BDSG') is dropped from a § reference."""
        normalized = _normalize_article("§ 38 BDSG")
        assert normalized == "§ 38"

    def test_paragraph_with_dsgvo_suffix(self):
        normalized = _normalize_article("Art. 6 DSGVO")
        assert normalized == "art. 6"

    def test_removes_absatz(self):
        """The sub-reference 'Abs. 1' is stripped."""
        normalized = _normalize_article("Art. 30 Abs. 1")
        assert normalized == "art. 30"

    def test_removes_paragraph(self):
        normalized = _normalize_article("Art. 5 paragraph 2")
        assert normalized == "art. 5"

    def test_removes_lit(self):
        normalized = _normalize_article("Art. 6 lit. a")
        assert normalized == "art. 6"

    def test_removes_satz(self):
        normalized = _normalize_article("Art. 12 Satz 3")
        assert normalized == "art. 12"

    def test_lowercase_output(self):
        """Upper-case input comes back lower-cased."""
        for raw, expected in (("ART. 30", "art. 30"), ("ARTICLE 10", "art. 10")):
            assert _normalize_article(raw) == expected

    def test_whitespace_stripped(self):
        normalized = _normalize_article(" Art. 30 ")
        assert normalized == "art. 30"

    def test_empty_string(self):
        normalized = _normalize_article("")
        assert normalized == ""

    def test_none(self):
        # None is treated like an empty reference.
        normalized = _normalize_article(None)
        assert normalized == ""

    def test_complex_reference(self):
        """Abs./Satz/lit. sub-references plus a law suffix still start with 'art. 30'."""
        normalized = _normalize_article("Art. 30 Abs. 1 Satz 2 lit. c DSGVO")
        # Only the leading article is pinned; the exact tail is not asserted.
        assert normalized.startswith("art. 30")

    def test_nis2_article(self):
        normalized = _normalize_article("Art. 21 NIS2")
        assert normalized == "art. 21"

    def test_dora_article(self):
        normalized = _normalize_article("Art. 5 DORA")
        assert normalized == "art. 5"

    def test_ai_act_article(self):
        # A multi-word law suffix ("AI Act") is removed as well.
        assert _normalize_article("Article 6 AI Act") == "art. 6"
|
|
|
|
|
|
# =============================================================================
|
|
# Tests: _cosine_sim
|
|
# =============================================================================
|
|
|
|
|
|
class TestCosineSim:
    """Tests for cosine similarity calculation.

    Float results are compared with ``pytest.approx`` using an absolute
    tolerance of 1e-6, which equals the previous ``abs(a - b) < 1e-6``
    checks and also works for expected values of zero.
    """

    def test_identical_vectors(self):
        v = [1.0, 2.0, 3.0]
        assert _cosine_sim(v, v) == pytest.approx(1.0, abs=1e-6)

    def test_orthogonal_vectors(self):
        a = [1.0, 0.0]
        b = [0.0, 1.0]
        assert _cosine_sim(a, b) == pytest.approx(0.0, abs=1e-6)

    def test_opposite_vectors(self):
        a = [1.0, 2.0, 3.0]
        b = [-1.0, -2.0, -3.0]
        assert _cosine_sim(a, b) == pytest.approx(-1.0, abs=1e-6)

    def test_known_value(self):
        a = [1.0, 0.0]
        b = [1.0, 1.0]
        expected = 1.0 / math.sqrt(2)
        assert _cosine_sim(a, b) == pytest.approx(expected, abs=1e-6)

    def test_empty_vectors(self):
        assert _cosine_sim([], []) == 0.0

    def test_one_empty(self):
        assert _cosine_sim([1.0, 2.0], []) == 0.0
        assert _cosine_sim([], [1.0, 2.0]) == 0.0

    def test_different_lengths(self):
        # Mismatched dimensions yield 0.0 instead of raising.
        assert _cosine_sim([1.0, 2.0], [1.0]) == 0.0

    def test_zero_vector(self):
        # A zero-norm input yields 0.0 instead of dividing by zero.
        assert _cosine_sim([0.0, 0.0], [1.0, 2.0]) == 0.0

    def test_both_zero(self):
        assert _cosine_sim([0.0, 0.0], [0.0, 0.0]) == 0.0

    def test_high_dimensional(self):
        """Test with realistic embedding dimensions (1024).

        A private ``random.Random`` instance keeps the data reproducible
        without reseeding the process-global RNG, which other tests may use.
        """
        import random

        rng = random.Random(42)
        a = [rng.gauss(0, 1) for _ in range(1024)]
        b = [rng.gauss(0, 1) for _ in range(1024)]
        score = _cosine_sim(a, b)
        assert -1.0 <= score <= 1.0
        # Cosine similarity is symmetric in its arguments.
        assert _cosine_sim(b, a) == pytest.approx(score)
|
|
|
|
|
|
# =============================================================================
|
|
# Tests: _parse_json
|
|
# =============================================================================
|
|
|
|
|
|
class TestParseJson:
    """Tests for JSON extraction from LLM responses."""

    def test_direct_json(self):
        text = '{"obligation_text": "Test", "actor": "Controller"}'
        result = _parse_json(text)
        assert result["obligation_text"] == "Test"
        assert result["actor"] == "Controller"

    def test_json_in_markdown_block(self):
        """LLMs often wrap JSON in markdown code blocks."""
        text = '''Some explanation text
```json
{"obligation_text": "Test"}
```
More text'''
        result = _parse_json(text)
        assert result.get("obligation_text") == "Test"

    def test_json_with_prefix_text(self):
        text = 'Here is the result: {"obligation_text": "Pflicht", "actor": "Verantwortlicher"}'
        result = _parse_json(text)
        assert result["obligation_text"] == "Pflicht"

    def test_invalid_json(self):
        result = _parse_json("not json at all")
        assert result == {}

    def test_empty_string(self):
        result = _parse_json("")
        assert result == {}

    def test_nested_braces_picks_first(self):
        """Valid nested JSON is parsed as a whole, keeping the outer object.

        The result must be the complete outer object, not just the inner
        ``{"inner": "value"}`` fragment that a brace-matching regex could pick.
        """
        text = '{"outer": {"inner": "value"}}'
        result = _parse_json(text)
        assert result == {"outer": {"inner": "value"}}

    def test_json_with_german_umlauts(self):
        """Real non-ASCII characters (umlauts, ß) must survive parsing unchanged.

        The previous fixture used the ASCII transliteration ("ae"), so it
        never exercised non-ASCII input.
        """
        text = '{"obligation_text": "Pflicht zur Datenschutz-Folgenabschätzung gemäß Art. 35"}'
        result = _parse_json(text)
        assert result["obligation_text"] == "Pflicht zur Datenschutz-Folgenabschätzung gemäß Art. 35"
|
|
|
|
|
|
# =============================================================================
|
|
# Tests: ObligationMatch
|
|
# =============================================================================
|
|
|
|
|
|
class TestObligationMatch:
    """ObligationMatch default values and its to_dict() serialization."""

    def test_defaults(self):
        match = ObligationMatch()
        # All identifying/text fields start out unset.
        for attr in ("obligation_id", "obligation_title", "obligation_text", "regulation_id"):
            assert getattr(match, attr) is None
        assert match.method == "none"
        assert match.confidence == 0.0

    def test_to_dict(self):
        fields = {
            "obligation_id": "DSGVO-OBL-001",
            "obligation_title": "Verarbeitungsverzeichnis",
            "obligation_text": "Fuehrung eines Verzeichnisses...",
            "method": "exact_match",
            "confidence": 1.0,
            "regulation_id": "dsgvo",
        }
        serialized = ObligationMatch(**fields).to_dict()
        assert serialized["obligation_id"] == "DSGVO-OBL-001"
        assert serialized["method"] == "exact_match"
        assert serialized["confidence"] == 1.0
        assert serialized["regulation_id"] == "dsgvo"

    def test_to_dict_keys(self):
        """to_dict() exposes exactly the six public fields."""
        serialized = ObligationMatch().to_dict()
        assert set(serialized) == {
            "obligation_id", "obligation_title", "obligation_text",
            "method", "confidence", "regulation_id",
        }

    def test_to_dict_none_values(self):
        serialized = ObligationMatch().to_dict()
        assert serialized["obligation_id"] is None
        assert serialized["obligation_title"] is None
|
|
|
|
|
|
# =============================================================================
|
|
# Tests: _find_obligations_dir
|
|
# =============================================================================
|
|
|
|
|
|
class TestFindObligationsDir:
    """Locating the v2 obligations directory on disk."""

    def test_finds_v2_directory(self):
        """If the helper finds a directory, it must be a real v2 dir with a manifest."""
        found = _find_obligations_dir()
        # The SDK checkout may be absent (e.g. in CI); None is tolerated here.
        if found is None:
            return
        assert found.is_dir()
        assert (found / "_manifest.json").exists()

    def test_v2_dir_exists_in_repo(self):
        """Local checkouts are expected to ship the v2 obligations directory."""
        manifest = V2_DIR / "_manifest.json"
        assert V2_DIR.exists(), f"v2 dir not found at {V2_DIR}"
        assert manifest.exists()
|
|
|
|
|
|
# =============================================================================
|
|
# Tests: ObligationExtractor — _load_obligations
|
|
# =============================================================================
|
|
|
|
|
|
class TestObligationExtractorLoad:
    """Tests for obligation loading from v2 JSON files.

    Each test gets a fresh extractor whose obligations were loaded exactly
    once in ``setup_method``. This is the same fixture style the tier test
    classes use, and it replaces the boilerplate previously repeated in
    every test.
    """

    def setup_method(self):
        self.extractor = ObligationExtractor()
        self.extractor._load_obligations()

    def test_load_obligations_populates_lookup(self):
        assert len(self.extractor._obligations) > 0

    def test_load_obligations_count(self):
        """Should load all 325 obligations from 9 regulations."""
        assert len(self.extractor._obligations) == 325

    def test_article_lookup_populated(self):
        """Article lookup should have entries for obligations with legal_basis."""
        assert len(self.extractor._article_lookup) > 0

    def test_article_lookup_dsgvo_art30(self):
        """DSGVO Art. 30 should resolve to DSGVO-OBL-001."""
        key = "dsgvo/art. 30"
        assert key in self.extractor._article_lookup
        assert "DSGVO-OBL-001" in self.extractor._article_lookup[key]

    def test_obligations_have_required_fields(self):
        """Every loaded obligation should have id, title, description, regulation_id."""
        for obl_id, entry in self.extractor._obligations.items():
            assert entry.id == obl_id
            assert entry.title, f"{obl_id}: empty title"
            assert entry.description, f"{obl_id}: empty description"
            assert entry.regulation_id, f"{obl_id}: empty regulation_id"

    def test_all_nine_regulations_loaded(self):
        """All 9 regulations from the manifest should be loaded."""
        regulation_ids = {e.regulation_id for e in self.extractor._obligations.values()}
        expected = {"dsgvo", "ai_act", "nis2", "bdsg", "ttdsg", "dsa",
                    "data_act", "eu_machinery", "dora"}
        assert regulation_ids == expected

    def test_obligation_id_format(self):
        """All obligation IDs should follow the pattern {REG}-OBL-{NNN}."""
        import re

        # Allow letters, digits, underscores in prefix (e.g. NIS2-OBL-001, EU_MACHINERY-OBL-001)
        pattern = re.compile(r"^[A-Z0-9_]+-OBL-\d{3}$")
        for obl_id in self.extractor._obligations:
            assert pattern.match(obl_id), f"Invalid obligation ID format: {obl_id}"

    def test_no_duplicate_obligation_ids(self):
        """All obligation IDs should be unique.

        NOTE(review): ``_obligations`` is a dict keyed by obligation ID, so its
        keys are unique by construction and this check cannot fail. A duplicate
        ID in the v2 source files would presumably be overwritten at load time.
        The closest real guard is the exact 325-count test in this class. A true
        duplicate check would have to scan the raw v2 JSON files (TODO: confirm
        their schema).
        """
        ids = list(self.extractor._obligations.keys())
        assert len(ids) == len(set(ids))
|
|
|
|
|
|
# =============================================================================
|
|
# Tests: ObligationExtractor — Tier 1 (Exact Match)
|
|
# =============================================================================
|
|
|
|
|
|
class TestTier1ExactMatch:
    """Tier 1: direct (regulation, normalized article) lookup."""

    def setup_method(self):
        extractor = ObligationExtractor()
        extractor._load_obligations()
        self.extractor = extractor

    def _lookup(self, regulation, article):
        """Shorthand for the Tier 1 call under test."""
        return self.extractor._tier1_exact(regulation, article)

    def test_exact_match_dsgvo_art30(self):
        found = self._lookup("dsgvo", "Art. 30")
        assert found is not None
        assert found.obligation_id == "DSGVO-OBL-001"
        assert found.method == "exact_match"
        assert found.confidence == 1.0
        assert found.regulation_id == "dsgvo"

    def test_exact_match_case_insensitive_article(self):
        found = self._lookup("dsgvo", "ART. 30")
        assert found is not None
        assert found.obligation_id == "DSGVO-OBL-001"

    def test_exact_match_article_variant(self):
        """The English long form normalizes to 'art. 30' before lookup."""
        found = self._lookup("dsgvo", "Article 30")
        assert found is not None
        assert found.obligation_id == "DSGVO-OBL-001"

    def test_exact_match_artikel_variant(self):
        found = self._lookup("dsgvo", "Artikel 30")
        assert found is not None
        assert found.obligation_id == "DSGVO-OBL-001"

    def test_exact_match_strips_absatz(self):
        """The 'Abs. 1' sub-reference is dropped, so the article still matches."""
        found = self._lookup("dsgvo", "Art. 30 Abs. 1")
        assert found is not None
        assert found.obligation_id == "DSGVO-OBL-001"

    def test_no_match_wrong_article(self):
        assert self._lookup("dsgvo", "Art. 999") is None

    def test_no_match_unknown_regulation(self):
        assert self._lookup("unknown_reg", "Art. 30") is None

    def test_no_match_none_regulation(self):
        assert self._lookup(None, "Art. 30") is None

    def test_match_has_title(self):
        found = self._lookup("dsgvo", "Art. 30")
        assert found is not None
        title = found.obligation_title
        assert title is not None
        assert len(title) > 0

    def test_match_has_text(self):
        found = self._lookup("dsgvo", "Art. 30")
        assert found is not None
        text = found.obligation_text
        assert text is not None
        assert len(text) > 20
|
|
|
|
|
|
# =============================================================================
|
|
# Tests: ObligationExtractor — Tier 2 (Embedding Match)
|
|
# =============================================================================
|
|
|
|
|
|
class TestTier2EmbeddingMatch:
    """Tests for Tier 2 embedding-based matching."""

    def setup_method(self):
        self.extractor = ObligationExtractor()
        self.extractor._load_obligations()
        # Fake embeddings stand in for the real embedding service.
        self.extractor._obligation_ids = list(self.extractor._obligations.keys())
        # Deterministic 3D vectors with every component in 1..10 (never the
        # zero vector). They repeat with period 10, so several obligations
        # share the same vector.
        self.extractor._obligation_embeddings = [
            [float(i % 10 + 1), float((i * 3) % 10 + 1), float((i * 7) % 10 + 1)]
            for i in range(len(self.extractor._obligation_ids))
        ]

    @pytest.mark.asyncio
    async def test_embedding_match_above_threshold(self):
        """When cosine > 0.80, should return embedding_match."""
        # Query with obligation 0's own vector, so at least one cosine is 1.0.
        target_embedding = self.extractor._obligation_embeddings[0]

        with patch(
            "compliance.services.obligation_extractor._get_embedding",
            new_callable=AsyncMock,
            return_value=target_embedding,
        ):
            match = await self.extractor._tier2_embedding("test text", "dsgvo")

        assert match is not None
        assert match.method == "embedding_match"
        assert match.confidence >= EMBEDDING_MATCH_THRESHOLD

    @pytest.mark.asyncio
    async def test_embedding_match_returns_none_below_threshold(self):
        """When every cosine is far below the thresholds, should return None."""
        # Not orthogonal, but pointing away from every fake vector. All fake
        # components lie in 1..10, so the best cosine against [1, -1, 0] is
        # about 0.33 (for [8, 2, 10]). That is below both
        # EMBEDDING_CANDIDATE_THRESHOLD and EMBEDDING_MATCH_THRESHOLD, and no
        # domain bonus applies because no regulation is given.
        away = [100.0, -100.0, 0.0]

        with patch(
            "compliance.services.obligation_extractor._get_embedding",
            new_callable=AsyncMock,
            return_value=away,
        ):
            match = await self.extractor._tier2_embedding("unrelated text", None)

        assert match is None

    @pytest.mark.asyncio
    async def test_embedding_match_empty_embeddings(self):
        """When no embeddings loaded, should return None."""
        self.extractor._obligation_embeddings = []
        match = await self.extractor._tier2_embedding("any text", "dsgvo")
        assert match is None

    @pytest.mark.asyncio
    async def test_embedding_match_failed_embedding(self):
        """When embedding service returns empty, should return None."""
        with patch(
            "compliance.services.obligation_extractor._get_embedding",
            new_callable=AsyncMock,
            return_value=[],
        ):
            match = await self.extractor._tier2_embedding("some text", "dsgvo")
        assert match is None

    @pytest.mark.asyncio
    async def test_domain_bonus_same_regulation(self):
        """Matching regulation should add +0.05 bonus.

        The NIS2 candidate is listed first and is slightly more similar
        (0.90 vs 0.88), so DSGVO can only win through the domain bonus
        (0.88 + 0.05 = 0.93). With identical vectors and DSGVO first, a
        first-wins tie-break would pass the test without any bonus.
        """
        self.extractor._obligation_ids = ["NIS2-OBL-001", "DSGVO-OBL-001"]
        self.extractor._obligation_embeddings = [
            [0.90, math.sqrt(1 - 0.90 ** 2), 0.0],  # cosine 0.90 vs query
            [0.88, math.sqrt(1 - 0.88 ** 2), 0.0],  # cosine 0.88 vs query
        ]

        with patch(
            "compliance.services.obligation_extractor._get_embedding",
            new_callable=AsyncMock,
            return_value=[1.0, 0.0, 0.0],
        ):
            match = await self.extractor._tier2_embedding("test", "dsgvo")

        # Both raw scores already clear EMBEDDING_MATCH_THRESHOLD.
        assert match is not None
        assert match.method == "embedding_match"
        # Only the domain bonus can make DSGVO outrank the closer NIS2 vector.
        assert match.obligation_id == "DSGVO-OBL-001"
        assert match.regulation_id == "dsgvo"

    @pytest.mark.asyncio
    async def test_confidence_capped_at_1(self):
        """Confidence should not exceed 1.0 even with domain bonus."""
        self.extractor._obligation_ids = ["DSGVO-OBL-001"]
        self.extractor._obligation_embeddings = [[1.0, 0.0, 0.0]]

        with patch(
            "compliance.services.obligation_extractor._get_embedding",
            new_callable=AsyncMock,
            return_value=[1.0, 0.0, 0.0],
        ):
            match = await self.extractor._tier2_embedding("test", "dsgvo")

        assert match is not None
        assert match.confidence <= 1.0
|
|
|
|
|
|
# =============================================================================
|
|
# Tests: ObligationExtractor — Tier 3 (LLM Extraction)
|
|
# =============================================================================
|
|
|
|
|
|
class TestTier3LLMExtraction:
    """Tier 3: free-form obligation extraction through a mocked Ollama call."""

    def setup_method(self):
        self.extractor = ObligationExtractor()

    @pytest.mark.asyncio
    async def test_llm_extraction_success(self):
        """A well-formed JSON answer yields llm_extracted with confidence 0.60."""
        payload = {
            "obligation_text": "Pflicht zur Fuehrung eines Verarbeitungsverzeichnisses",
            "actor": "Verantwortlicher",
            "action": "Verarbeitungsverzeichnis fuehren",
            "normative_strength": "muss",
        }
        fake_llm = AsyncMock(return_value=json.dumps(payload))

        with patch("compliance.services.obligation_extractor._llm_ollama", fake_llm):
            result = await self.extractor._tier3_llm(
                "Der Verantwortliche fuehrt ein Verzeichnis...",
                "eu_2016_679",
                "Art. 30",
            )

        assert result.method == "llm_extracted"
        assert result.confidence == 0.60
        assert "Verarbeitungsverzeichnis" in result.obligation_text
        # The LLM tier never assigns a catalogue obligation ID.
        assert result.obligation_id is None
        assert result.regulation_id == "dsgvo"

    @pytest.mark.asyncio
    async def test_llm_extraction_failure(self):
        """An empty LLM answer yields a zero-confidence llm_extracted result."""
        fake_llm = AsyncMock(return_value="")

        with patch("compliance.services.obligation_extractor._llm_ollama", fake_llm):
            result = await self.extractor._tier3_llm("some text", "dsgvo", "Art. 1")

        assert result.method == "llm_extracted"
        assert result.confidence == 0.0
        assert result.obligation_text is None

    @pytest.mark.asyncio
    async def test_llm_extraction_malformed_json(self):
        """A non-JSON answer is kept as raw obligation text."""
        fake_llm = AsyncMock(return_value="Dies ist die Pflicht: Daten schuetzen")

        with patch("compliance.services.obligation_extractor._llm_ollama", fake_llm):
            result = await self.extractor._tier3_llm("some text", "dsgvo", None)

        assert result.method == "llm_extracted"
        assert result.confidence == 0.60
        # The raw answer (or a prefix of it) becomes the obligation text.
        text = result.obligation_text
        assert "Pflicht" in text or "Daten" in text

    @pytest.mark.asyncio
    async def test_llm_regulation_normalization(self):
        """The EU code passed in is normalized in the returned match."""
        fake_llm = AsyncMock(return_value='{"obligation_text": "Test"}')

        with patch("compliance.services.obligation_extractor._llm_ollama", fake_llm):
            result = await self.extractor._tier3_llm("text", "eu_2024_1689", "Art. 6")

        assert result.regulation_id == "ai_act"
|
|
|
|
|
|
# =============================================================================
|
|
# Tests: ObligationExtractor — Full 3-Tier extract()
|
|
# =============================================================================
|
|
|
|
|
|
class TestExtractFullFlow:
    """Tests for the full 3-tier extraction flow."""

    def setup_method(self):
        self.extractor = ObligationExtractor()
        self.extractor._load_obligations()
        # Mark as initialized to skip async initialize
        self.extractor._initialized = True
        # Empty embeddings — Tier 2 will return None
        self.extractor._obligation_embeddings = []
        self.extractor._obligation_ids = []

    @pytest.mark.asyncio
    async def test_tier1_takes_priority(self):
        """When Tier 1 matches, Tier 2 and 3 should not be called."""
        with patch.object(
            self.extractor, "_tier2_embedding", new_callable=AsyncMock
        ) as mock_t2, patch.object(
            self.extractor, "_tier3_llm", new_callable=AsyncMock
        ) as mock_t3:
            match = await self.extractor.extract(
                chunk_text="irrelevant",
                regulation_code="eu_2016_679",
                article="Art. 30",
            )

        assert match.method == "exact_match"
        mock_t2.assert_not_called()
        mock_t3.assert_not_called()

    @pytest.mark.asyncio
    async def test_tier2_when_tier1_misses(self):
        """When Tier 1 misses, Tier 2 should be tried."""
        tier2_result = ObligationMatch(
            obligation_id="DSGVO-OBL-050",
            method="embedding_match",
            confidence=0.85,
            regulation_id="dsgvo",
        )

        with patch.object(
            self.extractor, "_tier2_embedding",
            new_callable=AsyncMock,
            return_value=tier2_result,
        ) as mock_t2, patch.object(
            self.extractor, "_tier3_llm", new_callable=AsyncMock
        ) as mock_t3:
            match = await self.extractor.extract(
                chunk_text="some compliance text",
                regulation_code="eu_2016_679",
                article="Art. 999",  # Non-matching article
            )

        assert match.method == "embedding_match"
        mock_t2.assert_called_once()
        mock_t3.assert_not_called()

    @pytest.mark.asyncio
    async def test_tier3_when_tier1_and_2_miss(self):
        """When Tier 1 and 2 miss, Tier 3 should be called."""
        tier3_result = ObligationMatch(
            obligation_text="LLM extracted obligation",
            method="llm_extracted",
            confidence=0.60,
        )

        with patch.object(
            self.extractor, "_tier2_embedding",
            new_callable=AsyncMock,
            return_value=None,
        ), patch.object(
            self.extractor, "_tier3_llm",
            new_callable=AsyncMock,
            return_value=tier3_result,
        ) as mock_t3:
            match = await self.extractor.extract(
                chunk_text="unrelated text",
                regulation_code="unknown_reg",
                article="Art. 999",
            )

        assert match.method == "llm_extracted"
        # The fallback must have actually gone through Tier 3, exactly once.
        mock_t3.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_no_article_skips_tier1(self):
        """When no article is provided, Tier 1 should be skipped."""
        with patch.object(
            self.extractor, "_tier2_embedding",
            new_callable=AsyncMock,
            return_value=None,
        ) as mock_t2, patch.object(
            self.extractor, "_tier3_llm",
            new_callable=AsyncMock,
            return_value=ObligationMatch(method="llm_extracted", confidence=0.60),
        ):
            match = await self.extractor.extract(
                chunk_text="some text",
                regulation_code="dsgvo",
                article=None,
            )

        # Tier 2 should be called (Tier 1 skipped due to no article)
        mock_t2.assert_called_once()
        # With Tier 2 missing as well, the result comes from the Tier 3 mock.
        assert match.method == "llm_extracted"

    @pytest.mark.asyncio
    async def test_auto_initialize(self):
        """If not initialized, extract should call initialize()."""
        extractor = ObligationExtractor()
        assert not extractor._initialized

        with patch.object(
            extractor, "initialize", new_callable=AsyncMock
        ) as mock_init:
            # Stand-in for initialize(): load obligations without embeddings.
            async def side_effect():
                extractor._initialized = True
                extractor._load_obligations()
                extractor._obligation_embeddings = []
                extractor._obligation_ids = []

            mock_init.side_effect = side_effect

            with patch.object(
                extractor, "_tier2_embedding",
                new_callable=AsyncMock,
                return_value=None,
            ), patch.object(
                extractor, "_tier3_llm",
                new_callable=AsyncMock,
                return_value=ObligationMatch(method="llm_extracted", confidence=0.60),
            ):
                await extractor.extract(
                    chunk_text="test",
                    regulation_code="dsgvo",
                    article=None,
                )

            mock_init.assert_called_once()
|
|
|
|
|
|
# =============================================================================
|
|
# Tests: ObligationExtractor — stats()
|
|
# =============================================================================
|
|
|
|
|
|
class TestExtractorStats:
    """stats() reporting before and after obligations are loaded."""

    @staticmethod
    def _stats_after_loading():
        """Load obligations into a fresh extractor (no embeddings) and report stats."""
        extractor = ObligationExtractor()
        extractor._load_obligations()
        return extractor.stats()

    def test_stats_before_init(self):
        report = ObligationExtractor().stats()
        assert report["total_obligations"] == 0
        assert report["article_lookups"] == 0
        assert report["initialized"] is False

    def test_stats_after_load(self):
        report = self._stats_after_loading()
        assert report["total_obligations"] == 325
        assert report["article_lookups"] > 0
        assert "dsgvo" in report["regulations"]
        # Loading alone does not flip the flag; initialize() was never awaited.
        assert report["initialized"] is False

    def test_stats_regulations_complete(self):
        report = self._stats_after_loading()
        assert set(report["regulations"]) == {
            "dsgvo", "ai_act", "nis2", "bdsg", "ttdsg",
            "dsa", "data_act", "eu_machinery", "dora",
        }
|
|
|
|
|
|
# =============================================================================
|
|
# Tests: Integration — Regulation-to-Obligation mapping coverage
|
|
# =============================================================================
|
|
|
|
|
|
class TestRegulationObligationCoverage:
    """Sanity checks on how well the article lookup covers each regulation."""

    def setup_method(self):
        extractor = ObligationExtractor()
        extractor._load_obligations()
        self.extractor = extractor

    def _lookup_count(self, regulation_id):
        """Count article-lookup keys that belong to *regulation_id*."""
        prefix = f"{regulation_id}/"
        return sum(1 for key in self.extractor._article_lookup if key.startswith(prefix))

    def test_dsgvo_has_article_lookups(self):
        """DSGVO, the largest catalogue, should expose many article keys."""
        count = self._lookup_count("dsgvo")
        assert count >= 20, f"Only {count} DSGVO article lookups"

    def test_ai_act_has_article_lookups(self):
        count = self._lookup_count("ai_act")
        assert count >= 10, f"Only {count} AI Act article lookups"

    def test_nis2_has_article_lookups(self):
        count = self._lookup_count("nis2")
        assert count >= 5, f"Only {count} NIS2 article lookups"

    def test_all_article_lookup_values_are_valid(self):
        """Each ID referenced by the lookup must be a loaded obligation."""
        known = self.extractor._obligations
        for key, referenced_ids in self.extractor._article_lookup.items():
            for obl_id in referenced_ids:
                assert obl_id in known, (
                    f"Article lookup {key} references missing obligation {obl_id}"
                )

    def test_article_lookup_key_format(self):
        """Keys look like '<regulation_id>/<lower-case article>'."""
        for key in self.extractor._article_lookup:
            reg_id, separator, article = key.partition("/")
            assert separator, f"Invalid key format: {key}"
            assert reg_id, f"Empty regulation ID in key: {key}"
            assert article, f"Empty article in key: {key}"
            assert article == article.lower(), f"Article not lowercase: {key}"
|
|
|
|
|
|
# =============================================================================
|
|
# Tests: Constants and thresholds
|
|
# =============================================================================
|
|
|
|
|
|
class TestConstants:
    """Sanity checks on the module-level embedding thresholds."""

    def test_embedding_thresholds_ordering(self):
        """A full match must demand more similarity than a mere candidate."""
        assert EMBEDDING_CANDIDATE_THRESHOLD < EMBEDDING_MATCH_THRESHOLD

    def test_embedding_thresholds_range(self):
        """Both thresholds lie in the half-open interval (0, 1]."""
        for threshold in (EMBEDDING_MATCH_THRESHOLD, EMBEDDING_CANDIDATE_THRESHOLD):
            assert 0 < threshold <= 1.0

    def test_match_threshold_is_80(self):
        expected = 0.80
        assert EMBEDDING_MATCH_THRESHOLD == expected

    def test_candidate_threshold_is_60(self):
        expected = 0.60
        assert EMBEDDING_CANDIDATE_THRESHOLD == expected
|