"""Tests for Obligation Extractor — Phase 4 of Multi-Layer Control Architecture. Validates: - Regulation code normalization (_normalize_regulation) - Article reference normalization (_normalize_article) - Cosine similarity (_cosine_sim) - JSON parsing from LLM responses (_parse_json) - Obligation loading from v2 framework - 3-Tier extraction: exact_match → embedding_match → llm_extracted - ObligationMatch serialization - Edge cases: empty inputs, missing data, fallback behavior """ import json import math from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import pytest from compliance.services.obligation_extractor import ( EMBEDDING_CANDIDATE_THRESHOLD, EMBEDDING_MATCH_THRESHOLD, ObligationExtractor, ObligationMatch, _ObligationEntry, _cosine_sim, _find_obligations_dir, _normalize_article, _normalize_regulation, _parse_json, ) REPO_ROOT = Path(__file__).resolve().parent.parent.parent V2_DIR = REPO_ROOT / "ai-compliance-sdk" / "policies" / "obligations" / "v2" # ============================================================================= # Tests: _normalize_regulation # ============================================================================= class TestNormalizeRegulation: """Tests for regulation code normalization.""" def test_dsgvo_eu_code(self): assert _normalize_regulation("eu_2016_679") == "dsgvo" def test_dsgvo_short(self): assert _normalize_regulation("dsgvo") == "dsgvo" def test_gdpr_alias(self): assert _normalize_regulation("gdpr") == "dsgvo" def test_ai_act_eu_code(self): assert _normalize_regulation("eu_2024_1689") == "ai_act" def test_ai_act_short(self): assert _normalize_regulation("ai_act") == "ai_act" def test_nis2_eu_code(self): assert _normalize_regulation("eu_2022_2555") == "nis2" def test_nis2_short(self): assert _normalize_regulation("nis2") == "nis2" def test_bsig_alias(self): assert _normalize_regulation("bsig") == "nis2" def test_bdsg(self): assert _normalize_regulation("bdsg") == "bdsg" def test_ttdsg(self): assert _normalize_regulation("ttdsg") == "ttdsg" def test_dsa_eu_code(self): assert _normalize_regulation("eu_2022_2065") == "dsa" def test_data_act_eu_code(self): assert _normalize_regulation("eu_2023_2854") == "data_act" def test_eu_machinery_eu_code(self): assert _normalize_regulation("eu_2023_1230") == "eu_machinery" def test_dora_eu_code(self): assert _normalize_regulation("eu_2022_2554") == "dora" def test_case_insensitive(self): assert _normalize_regulation("DSGVO") == "dsgvo" assert _normalize_regulation("AI_ACT") == "ai_act" assert _normalize_regulation("NIS2") == "nis2" def test_whitespace_stripped(self): assert _normalize_regulation(" dsgvo ") == "dsgvo" def test_empty_string(self): assert _normalize_regulation("") is None def test_none(self): assert _normalize_regulation(None) is None def test_unknown_code(self): assert _normalize_regulation("mica") is None def test_prefix_matching(self): """EU codes with suffixes should still match via prefix.""" assert _normalize_regulation("eu_2016_679_consolidated") == "dsgvo" def test_all_nine_regulations_covered(self): """Every regulation in the manifest should be normalizable.""" regulation_ids = ["dsgvo", "ai_act", "nis2", "bdsg", "ttdsg", "dsa", "data_act", "eu_machinery", "dora"] for reg_id in regulation_ids: result = _normalize_regulation(reg_id) assert result == reg_id, f"Regulation {reg_id} not found" # ============================================================================= # Tests: _normalize_article # ============================================================================= class TestNormalizeArticle: """Tests for article reference normalization.""" def test_art_with_dot(self): assert _normalize_article("Art. 30") == "art. 30" def test_article_english(self): assert _normalize_article("Article 10") == "art. 10" def test_artikel_german(self): assert _normalize_article("Artikel 35") == "art. 35" def test_paragraph_symbol(self): assert _normalize_article("§ 38") == "§ 38" def test_paragraph_with_law_suffix(self): """§ 38 BDSG → § 38 (law name stripped).""" assert _normalize_article("§ 38 BDSG") == "§ 38" def test_paragraph_with_dsgvo_suffix(self): assert _normalize_article("Art. 6 DSGVO") == "art. 6" def test_removes_absatz(self): """Art. 30 Abs. 1 → art. 30""" assert _normalize_article("Art. 30 Abs. 1") == "art. 30" def test_removes_paragraph(self): assert _normalize_article("Art. 5 paragraph 2") == "art. 5" def test_removes_lit(self): assert _normalize_article("Art. 6 lit. a") == "art. 6" def test_removes_satz(self): assert _normalize_article("Art. 12 Satz 3") == "art. 12" def test_lowercase_output(self): assert _normalize_article("ART. 30") == "art. 30" assert _normalize_article("ARTICLE 10") == "art. 10" def test_whitespace_stripped(self): assert _normalize_article(" Art. 30 ") == "art. 30" def test_empty_string(self): assert _normalize_article("") == "" def test_none(self): assert _normalize_article(None) == "" def test_complex_reference(self): """Art. 30 Abs. 1 Satz 2 lit. c DSGVO → art. 30""" result = _normalize_article("Art. 30 Abs. 1 Satz 2 lit. c DSGVO") # Should at minimum remove DSGVO and Abs references assert result.startswith("art. 30") def test_nis2_article(self): assert _normalize_article("Art. 21 NIS2") == "art. 21" def test_dora_article(self): assert _normalize_article("Art. 5 DORA") == "art. 5" def test_ai_act_article(self): result = _normalize_article("Article 6 AI Act") assert result == "art. 6" # ============================================================================= # Tests: _cosine_sim # ============================================================================= class TestCosineSim: """Tests for cosine similarity calculation.""" def test_identical_vectors(self): v = [1.0, 2.0, 3.0] assert abs(_cosine_sim(v, v) - 1.0) < 1e-6 def test_orthogonal_vectors(self): a = [1.0, 0.0] b = [0.0, 1.0] assert abs(_cosine_sim(a, b)) < 1e-6 def test_opposite_vectors(self): a = [1.0, 2.0, 3.0] b = [-1.0, -2.0, -3.0] assert abs(_cosine_sim(a, b) - (-1.0)) < 1e-6 def test_known_value(self): a = [1.0, 0.0] b = [1.0, 1.0] expected = 1.0 / math.sqrt(2) assert abs(_cosine_sim(a, b) - expected) < 1e-6 def test_empty_vectors(self): assert _cosine_sim([], []) == 0.0 def test_one_empty(self): assert _cosine_sim([1.0, 2.0], []) == 0.0 assert _cosine_sim([], [1.0, 2.0]) == 0.0 def test_different_lengths(self): assert _cosine_sim([1.0, 2.0], [1.0]) == 0.0 def test_zero_vector(self): assert _cosine_sim([0.0, 0.0], [1.0, 2.0]) == 0.0 def test_both_zero(self): assert _cosine_sim([0.0, 0.0], [0.0, 0.0]) == 0.0 def test_high_dimensional(self): """Test with realistic embedding dimensions (1024).""" import random random.seed(42) a = [random.gauss(0, 1) for _ in range(1024)] b = [random.gauss(0, 1) for _ in range(1024)] score = _cosine_sim(a, b) assert -1.0 <= score <= 1.0 # ============================================================================= # Tests: _parse_json # ============================================================================= class TestParseJson: """Tests for JSON extraction from LLM responses.""" def test_direct_json(self): text = '{"obligation_text": "Test", "actor": "Controller"}' result = _parse_json(text) assert result["obligation_text"] == "Test" assert result["actor"] == "Controller" def test_json_in_markdown_block(self): """LLMs often wrap JSON in markdown code blocks.""" text = '''Some explanation text ```json {"obligation_text": "Test"} ``` More text''' result = _parse_json(text) assert result.get("obligation_text") == "Test" def test_json_with_prefix_text(self): text = 'Here is the result: {"obligation_text": "Pflicht", "actor": "Verantwortlicher"}' result = _parse_json(text) assert result["obligation_text"] == "Pflicht" def test_invalid_json(self): result = _parse_json("not json at all") assert result == {} def test_empty_string(self): result = _parse_json("") assert result == {} def test_nested_braces_picks_first(self): """With nested objects, the regex picks the inner simple object.""" text = '{"outer": {"inner": "value"}}' result = _parse_json(text) # Direct parse should work for valid nested JSON assert "outer" in result def test_json_with_german_umlauts(self): text = '{"obligation_text": "Pflicht zur Datenschutz-Folgenabschaetzung"}' result = _parse_json(text) assert "Datenschutz" in result["obligation_text"] # ============================================================================= # Tests: ObligationMatch # ============================================================================= class TestObligationMatch: """Tests for the ObligationMatch dataclass.""" def test_defaults(self): match = ObligationMatch() assert match.obligation_id is None assert match.obligation_title is None assert match.obligation_text is None assert match.method == "none" assert match.confidence == 0.0 assert match.regulation_id is None def test_to_dict(self): match = ObligationMatch( obligation_id="DSGVO-OBL-001", obligation_title="Verarbeitungsverzeichnis", obligation_text="Fuehrung eines Verzeichnisses...", method="exact_match", confidence=1.0, regulation_id="dsgvo", ) d = match.to_dict() assert d["obligation_id"] == "DSGVO-OBL-001" assert d["method"] == "exact_match" assert d["confidence"] == 1.0 assert d["regulation_id"] == "dsgvo" def test_to_dict_keys(self): match = ObligationMatch() d = match.to_dict() expected_keys = { "obligation_id", "obligation_title", "obligation_text", "method", "confidence", "regulation_id", } assert set(d.keys()) == expected_keys def test_to_dict_none_values(self): match = ObligationMatch() d = match.to_dict() assert d["obligation_id"] is None assert d["obligation_title"] is None # ============================================================================= # Tests: _find_obligations_dir # ============================================================================= class TestFindObligationsDir: """Tests for finding the v2 obligations directory.""" def test_finds_v2_directory(self): """Should find the v2 dir relative to the source file.""" result = _find_obligations_dir() # May be None in CI without the SDK, but if found, verify it's valid if result is not None: assert result.is_dir() assert (result / "_manifest.json").exists() def test_v2_dir_exists_in_repo(self): """The v2 dir should exist in the repo for local tests.""" assert V2_DIR.exists(), f"v2 dir not found at {V2_DIR}" assert (V2_DIR / "_manifest.json").exists() # ============================================================================= # Tests: ObligationExtractor — _load_obligations # ============================================================================= class TestObligationExtractorLoad: """Tests for obligation loading from v2 JSON files.""" def test_load_obligations_populates_lookup(self): extractor = ObligationExtractor() extractor._load_obligations() assert len(extractor._obligations) > 0 def test_load_obligations_count(self): """Should load all 325 obligations from 9 regulations.""" extractor = ObligationExtractor() extractor._load_obligations() assert len(extractor._obligations) == 325 def test_article_lookup_populated(self): """Article lookup should have entries for obligations with legal_basis.""" extractor = ObligationExtractor() extractor._load_obligations() assert len(extractor._article_lookup) > 0 def test_article_lookup_dsgvo_art30(self): """DSGVO Art. 30 should resolve to DSGVO-OBL-001.""" extractor = ObligationExtractor() extractor._load_obligations() key = "dsgvo/art. 30" assert key in extractor._article_lookup assert "DSGVO-OBL-001" in extractor._article_lookup[key] def test_obligations_have_required_fields(self): """Every loaded obligation should have id, title, description, regulation_id.""" extractor = ObligationExtractor() extractor._load_obligations() for obl_id, entry in extractor._obligations.items(): assert entry.id == obl_id assert entry.title, f"{obl_id}: empty title" assert entry.description, f"{obl_id}: empty description" assert entry.regulation_id, f"{obl_id}: empty regulation_id" def test_all_nine_regulations_loaded(self): """All 9 regulations from the manifest should be loaded.""" extractor = ObligationExtractor() extractor._load_obligations() regulation_ids = {e.regulation_id for e in extractor._obligations.values()} expected = {"dsgvo", "ai_act", "nis2", "bdsg", "ttdsg", "dsa", "data_act", "eu_machinery", "dora"} assert regulation_ids == expected def test_obligation_id_format(self): """All obligation IDs should follow the pattern {REG}-OBL-{NNN}.""" extractor = ObligationExtractor() extractor._load_obligations() import re # Allow letters, digits, underscores in prefix (e.g. NIS2-OBL-001, EU_MACHINERY-OBL-001) pattern = re.compile(r"^[A-Z0-9_]+-OBL-\d{3}$") for obl_id in extractor._obligations: assert pattern.match(obl_id), f"Invalid obligation ID format: {obl_id}" def test_no_duplicate_obligation_ids(self): """All obligation IDs should be unique.""" extractor = ObligationExtractor() extractor._load_obligations() ids = list(extractor._obligations.keys()) assert len(ids) == len(set(ids)) # ============================================================================= # Tests: ObligationExtractor — Tier 1 (Exact Match) # ============================================================================= class TestTier1ExactMatch: """Tests for Tier 1 exact article lookup.""" def setup_method(self): self.extractor = ObligationExtractor() self.extractor._load_obligations() def test_exact_match_dsgvo_art30(self): match = self.extractor._tier1_exact("dsgvo", "Art. 30") assert match is not None assert match.obligation_id == "DSGVO-OBL-001" assert match.method == "exact_match" assert match.confidence == 1.0 assert match.regulation_id == "dsgvo" def test_exact_match_case_insensitive_article(self): match = self.extractor._tier1_exact("dsgvo", "ART. 30") assert match is not None assert match.obligation_id == "DSGVO-OBL-001" def test_exact_match_article_variant(self): """'Article 30' should normalize to 'art. 30' and match.""" match = self.extractor._tier1_exact("dsgvo", "Article 30") assert match is not None assert match.obligation_id == "DSGVO-OBL-001" def test_exact_match_artikel_variant(self): match = self.extractor._tier1_exact("dsgvo", "Artikel 30") assert match is not None assert match.obligation_id == "DSGVO-OBL-001" def test_exact_match_strips_absatz(self): """Art. 30 Abs. 1 → art. 30 → should match.""" match = self.extractor._tier1_exact("dsgvo", "Art. 30 Abs. 1") assert match is not None assert match.obligation_id == "DSGVO-OBL-001" def test_no_match_wrong_article(self): match = self.extractor._tier1_exact("dsgvo", "Art. 999") assert match is None def test_no_match_unknown_regulation(self): match = self.extractor._tier1_exact("unknown_reg", "Art. 30") assert match is None def test_no_match_none_regulation(self): match = self.extractor._tier1_exact(None, "Art. 30") assert match is None def test_match_has_title(self): match = self.extractor._tier1_exact("dsgvo", "Art. 30") assert match is not None assert match.obligation_title is not None assert len(match.obligation_title) > 0 def test_match_has_text(self): match = self.extractor._tier1_exact("dsgvo", "Art. 30") assert match is not None assert match.obligation_text is not None assert len(match.obligation_text) > 20 # ============================================================================= # Tests: ObligationExtractor — Tier 2 (Embedding Match) # ============================================================================= class TestTier2EmbeddingMatch: """Tests for Tier 2 embedding-based matching.""" def setup_method(self): self.extractor = ObligationExtractor() self.extractor._load_obligations() # Prepare fake embeddings for testing (no real embedding service) self.extractor._obligation_ids = list(self.extractor._obligations.keys()) # Create simple 3D embeddings per obligation — avoid zero vectors self.extractor._obligation_embeddings = [] for i in range(len(self.extractor._obligation_ids)): # Each obligation gets a unique-ish non-zero vector self.extractor._obligation_embeddings.append( [float(i % 10 + 1), float((i * 3) % 10 + 1), float((i * 7) % 10 + 1)] ) @pytest.mark.asyncio async def test_embedding_match_above_threshold(self): """When cosine > 0.80, should return embedding_match.""" # Mock the embedding service to return a vector very similar to obligation 0 target_embedding = self.extractor._obligation_embeddings[0] with patch( "compliance.services.obligation_extractor._get_embedding", new_callable=AsyncMock, return_value=target_embedding, ): match = await self.extractor._tier2_embedding("test text", "dsgvo") # Should find a match (cosine = 1.0 for identical vector) assert match is not None assert match.method == "embedding_match" assert match.confidence >= EMBEDDING_MATCH_THRESHOLD @pytest.mark.asyncio async def test_embedding_match_returns_none_below_threshold(self): """When cosine < 0.80, should return None.""" # Return a vector orthogonal to all obligations orthogonal = [100.0, -100.0, 0.0] with patch( "compliance.services.obligation_extractor._get_embedding", new_callable=AsyncMock, return_value=orthogonal, ): match = await self.extractor._tier2_embedding("unrelated text", None) # May or may not match depending on vector distribution # But we can verify it's either None or has correct method if match is not None: assert match.method == "embedding_match" @pytest.mark.asyncio async def test_embedding_match_empty_embeddings(self): """When no embeddings loaded, should return None.""" self.extractor._obligation_embeddings = [] match = await self.extractor._tier2_embedding("any text", "dsgvo") assert match is None @pytest.mark.asyncio async def test_embedding_match_failed_embedding(self): """When embedding service returns empty, should return None.""" with patch( "compliance.services.obligation_extractor._get_embedding", new_callable=AsyncMock, return_value=[], ): match = await self.extractor._tier2_embedding("some text", "dsgvo") assert match is None @pytest.mark.asyncio async def test_domain_bonus_same_regulation(self): """Matching regulation should add +0.05 bonus.""" # Set up two obligations with same embeddings but different regulations self.extractor._obligation_ids = ["DSGVO-OBL-001", "NIS2-OBL-001"] self.extractor._obligation_embeddings = [ [1.0, 0.0, 0.0], [1.0, 0.0, 0.0], ] with patch( "compliance.services.obligation_extractor._get_embedding", new_callable=AsyncMock, return_value=[1.0, 0.0, 0.0], ): match = await self.extractor._tier2_embedding("test", "dsgvo") # Should match (cosine = 1.0 ≥ 0.80) assert match is not None assert match.method == "embedding_match" # With domain bonus, DSGVO should be preferred assert match.regulation_id == "dsgvo" @pytest.mark.asyncio async def test_confidence_capped_at_1(self): """Confidence should not exceed 1.0 even with domain bonus.""" self.extractor._obligation_ids = ["DSGVO-OBL-001"] self.extractor._obligation_embeddings = [[1.0, 0.0, 0.0]] with patch( "compliance.services.obligation_extractor._get_embedding", new_callable=AsyncMock, return_value=[1.0, 0.0, 0.0], ): match = await self.extractor._tier2_embedding("test", "dsgvo") assert match is not None assert match.confidence <= 1.0 # ============================================================================= # Tests: ObligationExtractor — Tier 3 (LLM Extraction) # ============================================================================= class TestTier3LLMExtraction: """Tests for Tier 3 LLM-based obligation extraction.""" def setup_method(self): self.extractor = ObligationExtractor() @pytest.mark.asyncio async def test_llm_extraction_success(self): """Successful LLM extraction returns obligation_text with confidence 0.60.""" llm_response = json.dumps({ "obligation_text": "Pflicht zur Fuehrung eines Verarbeitungsverzeichnisses", "actor": "Verantwortlicher", "action": "Verarbeitungsverzeichnis fuehren", "normative_strength": "muss", }) with patch( "compliance.services.obligation_extractor._llm_ollama", new_callable=AsyncMock, return_value=llm_response, ): match = await self.extractor._tier3_llm( "Der Verantwortliche fuehrt ein Verzeichnis...", "eu_2016_679", "Art. 30", ) assert match.method == "llm_extracted" assert match.confidence == 0.60 assert "Verarbeitungsverzeichnis" in match.obligation_text assert match.obligation_id is None # LLM doesn't assign IDs assert match.regulation_id == "dsgvo" @pytest.mark.asyncio async def test_llm_extraction_failure(self): """When LLM returns empty, should return match with confidence 0.""" with patch( "compliance.services.obligation_extractor._llm_ollama", new_callable=AsyncMock, return_value="", ): match = await self.extractor._tier3_llm("some text", "dsgvo", "Art. 1") assert match.method == "llm_extracted" assert match.confidence == 0.0 assert match.obligation_text is None @pytest.mark.asyncio async def test_llm_extraction_malformed_json(self): """When LLM returns non-JSON, should use raw text as fallback.""" with patch( "compliance.services.obligation_extractor._llm_ollama", new_callable=AsyncMock, return_value="Dies ist die Pflicht: Daten schuetzen", ): match = await self.extractor._tier3_llm("some text", "dsgvo", None) assert match.method == "llm_extracted" assert match.confidence == 0.60 # Fallback: uses first 500 chars of response as obligation_text assert "Pflicht" in match.obligation_text or "Daten" in match.obligation_text @pytest.mark.asyncio async def test_llm_regulation_normalization(self): """Regulation code should be normalized in result.""" with patch( "compliance.services.obligation_extractor._llm_ollama", new_callable=AsyncMock, return_value='{"obligation_text": "Test"}', ): match = await self.extractor._tier3_llm( "text", "eu_2024_1689", "Art. 6" ) assert match.regulation_id == "ai_act" # ============================================================================= # Tests: ObligationExtractor — Full 3-Tier extract() # ============================================================================= class TestExtractFullFlow: """Tests for the full 3-tier extraction flow.""" def setup_method(self): self.extractor = ObligationExtractor() self.extractor._load_obligations() # Mark as initialized to skip async initialize self.extractor._initialized = True # Empty embeddings — Tier 2 will return None self.extractor._obligation_embeddings = [] self.extractor._obligation_ids = [] @pytest.mark.asyncio async def test_tier1_takes_priority(self): """When Tier 1 matches, Tier 2 and 3 should not be called.""" with patch.object( self.extractor, "_tier2_embedding", new_callable=AsyncMock ) as mock_t2, patch.object( self.extractor, "_tier3_llm", new_callable=AsyncMock ) as mock_t3: match = await self.extractor.extract( chunk_text="irrelevant", regulation_code="eu_2016_679", article="Art. 30", ) assert match.method == "exact_match" mock_t2.assert_not_called() mock_t3.assert_not_called() @pytest.mark.asyncio async def test_tier2_when_tier1_misses(self): """When Tier 1 misses, Tier 2 should be tried.""" tier2_result = ObligationMatch( obligation_id="DSGVO-OBL-050", method="embedding_match", confidence=0.85, regulation_id="dsgvo", ) with patch.object( self.extractor, "_tier2_embedding", new_callable=AsyncMock, return_value=tier2_result, ) as mock_t2, patch.object( self.extractor, "_tier3_llm", new_callable=AsyncMock ) as mock_t3: match = await self.extractor.extract( chunk_text="some compliance text", regulation_code="eu_2016_679", article="Art. 999", # Non-matching article ) assert match.method == "embedding_match" mock_t2.assert_called_once() mock_t3.assert_not_called() @pytest.mark.asyncio async def test_tier3_when_tier1_and_2_miss(self): """When Tier 1 and 2 miss, Tier 3 should be called.""" tier3_result = ObligationMatch( obligation_text="LLM extracted obligation", method="llm_extracted", confidence=0.60, ) with patch.object( self.extractor, "_tier2_embedding", new_callable=AsyncMock, return_value=None, ), patch.object( self.extractor, "_tier3_llm", new_callable=AsyncMock, return_value=tier3_result, ): match = await self.extractor.extract( chunk_text="unrelated text", regulation_code="unknown_reg", article="Art. 999", ) assert match.method == "llm_extracted" @pytest.mark.asyncio async def test_no_article_skips_tier1(self): """When no article is provided, Tier 1 should be skipped.""" with patch.object( self.extractor, "_tier2_embedding", new_callable=AsyncMock, return_value=None, ) as mock_t2, patch.object( self.extractor, "_tier3_llm", new_callable=AsyncMock, return_value=ObligationMatch(method="llm_extracted", confidence=0.60), ): match = await self.extractor.extract( chunk_text="some text", regulation_code="dsgvo", article=None, ) # Tier 2 should be called (Tier 1 skipped due to no article) mock_t2.assert_called_once() @pytest.mark.asyncio async def test_auto_initialize(self): """If not initialized, extract should call initialize().""" extractor = ObligationExtractor() assert not extractor._initialized with patch.object( extractor, "initialize", new_callable=AsyncMock ) as mock_init: # After mock init, set initialized to True async def side_effect(): extractor._initialized = True extractor._load_obligations() extractor._obligation_embeddings = [] extractor._obligation_ids = [] mock_init.side_effect = side_effect with patch.object( extractor, "_tier2_embedding", new_callable=AsyncMock, return_value=None, ), patch.object( extractor, "_tier3_llm", new_callable=AsyncMock, return_value=ObligationMatch(method="llm_extracted", confidence=0.60), ): await extractor.extract( chunk_text="test", regulation_code="dsgvo", article=None, ) mock_init.assert_called_once() # ============================================================================= # Tests: ObligationExtractor — stats() # ============================================================================= class TestExtractorStats: """Tests for the stats() method.""" def test_stats_before_init(self): extractor = ObligationExtractor() stats = extractor.stats() assert stats["total_obligations"] == 0 assert stats["article_lookups"] == 0 assert stats["initialized"] is False def test_stats_after_load(self): extractor = ObligationExtractor() extractor._load_obligations() stats = extractor.stats() assert stats["total_obligations"] == 325 assert stats["article_lookups"] > 0 assert "dsgvo" in stats["regulations"] assert stats["initialized"] is False # not fully initialized (no embeddings) def test_stats_regulations_complete(self): extractor = ObligationExtractor() extractor._load_obligations() stats = extractor.stats() expected_regs = {"dsgvo", "ai_act", "nis2", "bdsg", "ttdsg", "dsa", "data_act", "eu_machinery", "dora"} assert set(stats["regulations"]) == expected_regs # ============================================================================= # Tests: Integration — Regulation-to-Obligation mapping coverage # ============================================================================= class TestRegulationObligationCoverage: """Verify that the article lookup provides reasonable coverage.""" def setup_method(self): self.extractor = ObligationExtractor() self.extractor._load_obligations() def test_dsgvo_has_article_lookups(self): """DSGVO (80 obligations) should have many article lookups.""" dsgvo_keys = [k for k in self.extractor._article_lookup if k.startswith("dsgvo/")] assert len(dsgvo_keys) >= 20, f"Only {len(dsgvo_keys)} DSGVO article lookups" def test_ai_act_has_article_lookups(self): ai_keys = [k for k in self.extractor._article_lookup if k.startswith("ai_act/")] assert len(ai_keys) >= 10, f"Only {len(ai_keys)} AI Act article lookups" def test_nis2_has_article_lookups(self): nis2_keys = [k for k in self.extractor._article_lookup if k.startswith("nis2/")] assert len(nis2_keys) >= 5, f"Only {len(nis2_keys)} NIS2 article lookups" def test_all_article_lookup_values_are_valid(self): """Every obligation ID in article_lookup should exist in _obligations.""" for key, obl_ids in self.extractor._article_lookup.items(): for obl_id in obl_ids: assert obl_id in self.extractor._obligations, ( f"Article lookup {key} references missing obligation {obl_id}" ) def test_article_lookup_key_format(self): """All keys should be in format 'regulation_id/normalized_article'.""" for key in self.extractor._article_lookup: parts = key.split("/", 1) assert len(parts) == 2, f"Invalid key format: {key}" reg_id, article = parts assert reg_id, f"Empty regulation ID in key: {key}" assert article, f"Empty article in key: {key}" assert article == article.lower(), f"Article not lowercase: {key}" # ============================================================================= # Tests: Constants and thresholds # ============================================================================= class TestConstants: """Tests for module-level constants.""" def test_embedding_thresholds_ordering(self): """Match threshold should be higher than candidate threshold.""" assert EMBEDDING_MATCH_THRESHOLD > EMBEDDING_CANDIDATE_THRESHOLD def test_embedding_thresholds_range(self): """Thresholds should be between 0 and 1.""" assert 0 < EMBEDDING_MATCH_THRESHOLD <= 1.0 assert 0 < EMBEDDING_CANDIDATE_THRESHOLD <= 1.0 def test_match_threshold_is_80(self): assert EMBEDDING_MATCH_THRESHOLD == 0.80 def test_candidate_threshold_is_60(self): assert EMBEDDING_CANDIDATE_THRESHOLD == 0.60