"""Tests for Control Deduplication Engine (4-Stage Matching Pipeline). Covers: - normalize_action(): German → canonical English verb mapping - normalize_object(): Compliance object normalization - canonicalize_text(): Canonicalization layer for embedding - cosine_similarity(): Vector math - DedupResult dataclass - ControlDedupChecker.check_duplicate() — all 4 stages and verdicts """ import pytest from unittest.mock import MagicMock, AsyncMock, patch from compliance.services.control_dedup import ( normalize_action, normalize_object, canonicalize_text, cosine_similarity, DedupResult, ControlDedupChecker, LINK_THRESHOLD, REVIEW_THRESHOLD, LINK_THRESHOLD_DIFF_OBJECT, CROSS_REG_LINK_THRESHOLD, ) # --------------------------------------------------------------------------- # normalize_action TESTS # --------------------------------------------------------------------------- class TestNormalizeAction: """Stage 2: Action normalization (German → canonical English).""" def test_german_implement_synonyms(self): for verb in ["implementieren", "umsetzen", "einrichten", "einführen", "aktivieren"]: assert normalize_action(verb) == "implement", f"{verb} should map to implement" def test_german_test_synonyms(self): for verb in ["testen", "prüfen", "überprüfen", "verifizieren", "validieren"]: assert normalize_action(verb) == "test", f"{verb} should map to test" def test_german_monitor_synonyms(self): for verb in ["überwachen", "monitoring", "beobachten"]: assert normalize_action(verb) == "monitor", f"{verb} should map to monitor" def test_german_encrypt(self): assert normalize_action("verschlüsseln") == "encrypt" def test_german_log_synonyms(self): for verb in ["protokollieren", "aufzeichnen", "loggen"]: assert normalize_action(verb) == "log", f"{verb} should map to log" def test_german_restrict_synonyms(self): for verb in ["beschränken", "einschränken", "begrenzen"]: assert normalize_action(verb) == "restrict", f"{verb} should map to restrict" def test_english_passthrough(self): 
assert normalize_action("implement") == "implement" assert normalize_action("test") == "test" assert normalize_action("monitor") == "monitor" assert normalize_action("encrypt") == "encrypt" def test_case_insensitive(self): assert normalize_action("IMPLEMENTIEREN") == "implement" assert normalize_action("Testen") == "test" def test_whitespace_handling(self): assert normalize_action(" implementieren ") == "implement" def test_empty_string(self): assert normalize_action("") == "" def test_unknown_verb_passthrough(self): assert normalize_action("fluxkapazitieren") == "fluxkapazitieren" def test_german_authorize_synonyms(self): for verb in ["autorisieren", "genehmigen", "freigeben"]: assert normalize_action(verb) == "authorize", f"{verb} should map to authorize" def test_german_notify_synonyms(self): for verb in ["benachrichtigen", "informieren"]: assert normalize_action(verb) == "notify", f"{verb} should map to notify" # --------------------------------------------------------------------------- # normalize_object TESTS # --------------------------------------------------------------------------- class TestNormalizeObject: """Stage 3: Object normalization (compliance objects → canonical tokens).""" def test_mfa_synonyms(self): for obj in ["MFA", "2FA", "multi-faktor-authentifizierung", "two-factor"]: assert normalize_object(obj) == "multi_factor_auth", f"{obj} should → multi_factor_auth" def test_password_synonyms(self): for obj in ["Passwort", "Kennwort", "password"]: assert normalize_object(obj) == "password_policy", f"{obj} should → password_policy" def test_privileged_access(self): for obj in ["Admin-Konten", "admin accounts", "privilegierte Zugriffe"]: assert normalize_object(obj) == "privileged_access", f"{obj} should → privileged_access" def test_remote_access(self): for obj in ["Remote-Zugriff", "Fernzugriff", "remote access"]: assert normalize_object(obj) == "remote_access", f"{obj} should → remote_access" def test_encryption_synonyms(self): for obj in 
["Verschlüsselung", "encryption", "Kryptografie"]: assert normalize_object(obj) == "encryption", f"{obj} should → encryption" def test_key_management(self): for obj in ["Schlüssel", "key management", "Schlüsselverwaltung"]: assert normalize_object(obj) == "key_management", f"{obj} should → key_management" def test_transport_encryption(self): for obj in ["TLS", "SSL", "HTTPS"]: assert normalize_object(obj) == "transport_encryption", f"{obj} should → transport_encryption" def test_audit_logging(self): for obj in ["Audit-Log", "audit log", "Protokoll", "logging"]: assert normalize_object(obj) == "audit_logging", f"{obj} should → audit_logging" def test_vulnerability(self): assert normalize_object("Schwachstelle") == "vulnerability" assert normalize_object("vulnerability") == "vulnerability" def test_patch_management(self): for obj in ["Patch", "patching"]: assert normalize_object(obj) == "patch_management", f"{obj} should → patch_management" def test_case_insensitive(self): assert normalize_object("FIREWALL") == "firewall" assert normalize_object("VPN") == "vpn" def test_empty_string(self): assert normalize_object("") == "" def test_substring_match(self): """Longer phrases containing known keywords should match.""" assert normalize_object("die Admin-Konten des Unternehmens") == "privileged_access" assert normalize_object("zentrale Schlüsselverwaltung") == "key_management" def test_unknown_object_fallback(self): """Unknown objects get cleaned and underscore-joined.""" result = normalize_object("Quantencomputer Resistenz") assert "_" in result or result == "quantencomputer_resistenz" def test_articles_stripped_in_fallback(self): """German/English articles should be stripped in fallback.""" result = normalize_object("der grosse Quantencomputer") # "der" and "grosse" (>2 chars) → tokens without articles assert "der" not in result # --------------------------------------------------------------------------- # canonicalize_text TESTS # 
# ---------------------------------------------------------------------------


class TestCanonicalizeText:
    """Canonicalization layer: German compliance text → normalized English for embedding."""

    def test_basic_canonicalization(self):
        canonical = canonicalize_text("implementieren", "MFA")
        for token in ("implement", "multi_factor_auth"):
            assert token in canonical

    def test_with_title(self):
        canonical = canonicalize_text(
            "testen",
            "Firewall",
            "Netzwerk-Firewall regelmässig prüfen",
        )
        for token in ("test", "firewall"):
            assert token in canonical

    def test_title_filler_stripped(self):
        canonical = canonicalize_text(
            "implementieren",
            "VPN",
            "für den Zugriff gemäß Richtlinie",
        )
        # Filler words from the title must not reach the canonical text.
        for filler in ("für", "gemäß"):
            assert filler not in canonical

    def test_empty_action_and_object(self):
        canonical = canonicalize_text("", "")
        assert canonical.strip() == ""

    def test_example_from_spec(self):
        """Canonical form of the example given in the spec."""
        canonical = canonicalize_text(
            "implementieren",
            "MFA",
            "Administratoren müssen MFA verwenden",
        )
        for token in ("implement", "multi_factor_auth"):
            assert token in canonical


# ---------------------------------------------------------------------------
# cosine_similarity TESTS
# ---------------------------------------------------------------------------


class TestCosineSimilarity:
    """Vector math, including the inputs that yield exactly 0.0."""

    def test_identical_vectors(self):
        unit = [1.0, 0.0, 0.0]
        assert cosine_similarity(unit, unit) == pytest.approx(1.0)

    def test_orthogonal_vectors(self):
        x_axis, y_axis = [1.0, 0.0], [0.0, 1.0]
        assert cosine_similarity(x_axis, y_axis) == pytest.approx(0.0)

    def test_opposite_vectors(self):
        forward, backward = [1.0, 0.0], [-1.0, 0.0]
        assert cosine_similarity(forward, backward) == pytest.approx(-1.0)

    def test_empty_vectors(self):
        score = cosine_similarity([], [])
        assert score == 0.0

    def test_mismatched_lengths(self):
        score = cosine_similarity([1.0], [1.0, 2.0])
        assert score == 0.0

    def test_zero_vector(self):
        score = cosine_similarity([0.0, 0.0], [1.0, 1.0])
        assert score == 0.0


# ---------------------------------------------------------------------------
# DedupResult TESTS
# ---------------------------------------------------------------------------


class TestDedupResult:
    """Defaults and explicit construction of DedupResult."""

    def test_defaults(self):
        r = DedupResult(verdict="new")
        assert r.verdict == "new"
        assert r.matched_control_uuid is None
        assert r.stage == ""
        assert r.similarity_score == 0.0
        assert r.details == {}

    def test_link_result(self):
        r = DedupResult(
            verdict="link",
            matched_control_uuid="abc-123",
            matched_control_id="AUTH-2001",
            stage="embedding_match",
            similarity_score=0.95,
        )
        assert r.verdict == "link"
        assert r.matched_control_id == "AUTH-2001"


# ---------------------------------------------------------------------------
# ControlDedupChecker TESTS (mocked DB + embedding)
# ---------------------------------------------------------------------------


class TestControlDedupChecker:
    """Integration tests for the 4-stage dedup pipeline with mocks."""

    def _make_checker(self, existing_controls=None, search_results=None):
        """Build a ControlDedupChecker with mocked dependencies.

        Args:
            existing_controls: Optional list of dicts with the keys ``uuid``,
                ``control_id``, ``title``, ``objective`` and optionally
                ``pattern_id`` / ``obligation_type``. Each dict becomes one
                row tuple returned by the mocked ``db.execute(...).fetchall()``.
                When ``None`` the DB mock is left unconfigured.
            search_results: Optional list returned by the fake search
                function; ``None`` is treated as an empty list.
        """
        db = MagicMock()

        if existing_controls is not None:
            db.execute.return_value.fetchall.return_value = [
                (
                    c["uuid"],
                    c["control_id"],
                    c["title"],
                    c["objective"],
                    c.get("pattern_id", "CP-AUTH-001"),
                    c.get("obligation_type"),
                )
                for c in existing_controls
            ]

        async def fake_embed(text):
            # Fixed-size dummy embedding; the tests never inspect its values.
            return [0.1] * 1024

        async def fake_search(embedding, pattern_id, top_k=10):
            return search_results or []

        return ControlDedupChecker(db, embed_fn=fake_embed, search_fn=fake_search)

    @staticmethod
    def _existing():
        """Return a fresh one-element list describing the stored control AUTH-2001."""
        return [{"uuid": "a1", "control_id": "AUTH-2001", "title": "t", "objective": "o"}]

    @staticmethod
    def _hit(score, action, obj, title):
        """Build one fake search hit for control AUTH-2001.

        Args:
            score: Similarity score reported for the hit.
            action: Value stored under ``action_normalized`` in the payload.
            obj: Value stored under ``object_normalized`` in the payload.
            title: Title stored in the payload.
        """
        return {
            "score": score,
            "payload": {
                "control_uuid": "a1",
                "control_id": "AUTH-2001",
                "action_normalized": action,
                "object_normalized": obj,
                "title": title,
            },
        }

    def _checker_with_hit(self, score, action, obj, title):
        """Checker whose DB holds AUTH-2001 and whose search yields a single hit."""
        return self._make_checker(
            existing_controls=self._existing(),
            search_results=[self._hit(score, action, obj, title)],
        )

    @pytest.mark.asyncio
    async def test_no_pattern_id_returns_new(self):
        checker = self._make_checker()
        result = await checker.check_duplicate("implement", "MFA", "Test", pattern_id=None)
        assert result.verdict == "new"
        assert result.stage == "no_pattern"

    @pytest.mark.asyncio
    async def test_no_existing_controls_returns_new(self):
        checker = self._make_checker(existing_controls=[])
        result = await checker.check_duplicate(
            "implement", "MFA", "Test", pattern_id="CP-AUTH-001"
        )
        assert result.verdict == "new"
        assert result.stage == "pattern_gate"

    @pytest.mark.asyncio
    async def test_no_qdrant_matches_returns_new(self):
        checker = self._make_checker(
            existing_controls=self._existing(),
            search_results=[],
        )
        result = await checker.check_duplicate(
            "implement", "MFA", "Test", pattern_id="CP-AUTH-001"
        )
        assert result.verdict == "new"
        assert result.stage == "no_qdrant_matches"

    @pytest.mark.asyncio
    async def test_action_mismatch_returns_new(self):
        """Stage 2: Different action verbs → always NEW, even if embedding is high."""
        checker = self._checker_with_hit(0.96, "test", "multi_factor_auth", "MFA testen")
        result = await checker.check_duplicate(
            "implementieren", "MFA", "MFA implementieren", pattern_id="CP-AUTH-001"
        )
        assert result.verdict == "new"
        assert result.stage == "action_mismatch"
        assert result.details["candidate_action"] == "implement"
        assert result.details["existing_action"] == "test"

    @pytest.mark.asyncio
    async def test_object_mismatch_high_score_links(self):
        """Stage 3: Different objects but similarity > 0.95 → LINK."""
        checker = self._checker_with_hit(
            0.96, "implement", "remote_access", "Remote-Zugriff MFA"
        )
        result = await checker.check_duplicate(
            "implementieren", "Admin-Konten", "Admin MFA", pattern_id="CP-AUTH-001"
        )
        assert result.verdict == "link"
        assert result.stage == "embedding_diff_object"

    @pytest.mark.asyncio
    async def test_object_mismatch_low_score_returns_new(self):
        """Stage 3: Different objects and similarity < 0.95 → NEW."""
        checker = self._checker_with_hit(
            0.88, "implement", "remote_access", "Remote-Zugriff MFA"
        )
        result = await checker.check_duplicate(
            "implementieren", "Admin-Konten", "Admin MFA", pattern_id="CP-AUTH-001"
        )
        assert result.verdict == "new"
        assert result.stage == "object_mismatch_below_threshold"

    @pytest.mark.asyncio
    async def test_same_action_object_high_score_links(self):
        """Stage 4: Same action + object + similarity > 0.92 → LINK."""
        checker = self._checker_with_hit(
            0.94, "implement", "multi_factor_auth", "MFA implementieren"
        )
        result = await checker.check_duplicate(
            "implementieren", "MFA", "MFA umsetzen", pattern_id="CP-AUTH-001"
        )
        assert result.verdict == "link"
        assert result.stage == "embedding_match"
        assert result.similarity_score == 0.94

    @pytest.mark.asyncio
    async def test_same_action_object_review_range(self):
        """Stage 4: Same action + object + 0.85 < similarity < 0.92 → REVIEW."""
        checker = self._checker_with_hit(
            0.88, "implement", "multi_factor_auth", "MFA implementieren"
        )
        result = await checker.check_duplicate(
            "implementieren", "MFA", "MFA für Admins", pattern_id="CP-AUTH-001"
        )
        assert result.verdict == "review"
        assert result.stage == "embedding_review"

    @pytest.mark.asyncio
    async def test_same_action_object_low_score_new(self):
        """Stage 4: Same action + object but similarity < 0.85 → NEW."""
        checker = self._checker_with_hit(
            0.72, "implement", "multi_factor_auth", "MFA implementieren"
        )
        result = await checker.check_duplicate(
            "implementieren", "MFA", "Ganz anderer MFA Kontext", pattern_id="CP-AUTH-001"
        )
        assert result.verdict == "new"
        assert result.stage == "embedding_below_threshold"

    @pytest.mark.asyncio
    async def test_embedding_failure_returns_new(self):
        """If embedding service is down, default to NEW."""
        db = MagicMock()
        db.execute.return_value.fetchall.return_value = [
            ("a1", "AUTH-2001", "t", "o", "CP-AUTH-001", None)
        ]

        async def failing_embed(text):
            # An empty embedding stands in for the unavailable service.
            return []

        # Deliberately no search_fn: the checker is built exactly as a caller
        # without a custom search would build it.
        checker = ControlDedupChecker(db, embed_fn=failing_embed)
        result = await checker.check_duplicate(
            "implement", "MFA", "Test", pattern_id="CP-AUTH-001"
        )
        assert result.verdict == "new"
        assert result.stage == "embedding_unavailable"

    @pytest.mark.asyncio
    async def test_spec_false_positive_example(self):
        """The spec example: Admin-MFA vs Remote-MFA should NOT dedup.

        Even if embedding says >0.9, different objects (privileged_access vs
        remote_access) and score < 0.95 means NEW.
        """
        checker = self._checker_with_hit(
            0.91, "implement", "remote_access", "Remote-Zugriffe müssen MFA nutzen"
        )
        result = await checker.check_duplicate(
            "implementieren",
            "Admin-Konten",
            "Admin-Zugriffe müssen MFA nutzen",
            pattern_id="CP-AUTH-001",
        )
        assert result.verdict == "new"
        assert result.stage == "object_mismatch_below_threshold"


# ---------------------------------------------------------------------------
# THRESHOLD CONFIGURATION TESTS
# ---------------------------------------------------------------------------


class TestThresholds:
    """Verify the configured threshold values match the spec."""

    def test_link_threshold(self):
        assert LINK_THRESHOLD == 0.92

    def test_review_threshold(self):
        assert REVIEW_THRESHOLD == 0.85

    def test_diff_object_threshold(self):
        assert LINK_THRESHOLD_DIFF_OBJECT == 0.95

    def test_threshold_ordering(self):
        assert LINK_THRESHOLD_DIFF_OBJECT > LINK_THRESHOLD > REVIEW_THRESHOLD

    def test_cross_reg_threshold(self):
        assert CROSS_REG_LINK_THRESHOLD == 0.95

    def test_cross_reg_threshold_higher_than_intra(self):
        assert CROSS_REG_LINK_THRESHOLD >= LINK_THRESHOLD


# ---------------------------------------------------------------------------
# CROSS-REGULATION DEDUP TESTS
# ---------------------------------------------------------------------------


class TestCrossRegulationDedup:
    """Tests for cross-regulation linking (second dedup pass)."""

    # Dotted path of the cross-regulation search that every test here patches.
    _CROSS_SEARCH = "compliance.services.control_dedup.qdrant_search_cross_regulation"

    def _make_checker(self, search_results=None):
        """Create a checker with mocked DB, embedding, and search.

        Args:
            search_results: Hits returned by the mocked intra-pattern search.
                Defaults to no hits, so the intra-pattern pass finds nothing.
        """
        mock_db = MagicMock()
        mock_db.execute.return_value.fetchall.return_value = [
            ("uuid-1", "CTRL-001", "MFA", "Enable MFA", "SEC-AUTH", "pflicht"),
        ]
        embed_fn = AsyncMock(return_value=[0.1] * 1024)
        search_fn = AsyncMock(return_value=search_results or [])
        return ControlDedupChecker(db=mock_db, embed_fn=embed_fn, search_fn=search_fn)

    @pytest.mark.asyncio
    async def test_cross_reg_triggered_when_intra_is_new(self):
        """Cross-reg runs when intra-pattern returns 'new'."""
        checker = self._make_checker()
        cross_results = [{
            "score": 0.96,
            "payload": {
                "control_uuid": "cross-uuid-1",
                "control_id": "NIS2-CTRL-001",
                "title": "MFA (NIS2)",
            },
        }]
        with patch(
            self._CROSS_SEARCH,
            new_callable=AsyncMock,
            return_value=cross_results,
        ):
            result = await checker.check_duplicate(
                action="implement", obj="MFA", title="MFA", pattern_id="SEC-AUTH"
            )
        assert result.verdict == "link"
        assert result.stage == "cross_regulation"
        assert result.link_type == "cross_regulation"
        assert result.matched_control_id == "NIS2-CTRL-001"
        assert result.similarity_score == 0.96

    @pytest.mark.asyncio
    async def test_cross_reg_not_triggered_when_intra_is_link(self):
        """Cross-reg should NOT run when intra-pattern already found a link."""
        # Intra-pattern search returns a high match.
        checker = self._make_checker(search_results=[{
            "score": 0.95,
            "payload": {
                "control_uuid": "intra-uuid",
                "control_id": "CTRL-001",
                "title": "MFA",
                "action_normalized": "implement",
                "object_normalized": "multi_factor_auth",
            },
        }])
        with patch(self._CROSS_SEARCH, new_callable=AsyncMock) as mock_cross:
            result = await checker.check_duplicate(
                action="implement", obj="MFA", title="MFA", pattern_id="SEC-AUTH"
            )
        assert result.verdict == "link"
        assert result.stage == "embedding_match"
        assert result.link_type == "dedup_merge"  # not cross_regulation
        mock_cross.assert_not_called()

    @pytest.mark.asyncio
    async def test_cross_reg_below_threshold_keeps_new(self):
        """Cross-reg score below 0.95 keeps the verdict as 'new'."""
        checker = self._make_checker()
        cross_results = [{
            "score": 0.93,  # below CROSS_REG_LINK_THRESHOLD
            "payload": {
                "control_uuid": "cross-uuid-2",
                "control_id": "NIS2-CTRL-002",
                "title": "Similar control",
            },
        }]
        with patch(
            self._CROSS_SEARCH,
            new_callable=AsyncMock,
            return_value=cross_results,
        ) as mock_cross:
            result = await checker.check_duplicate(
                action="implement", obj="MFA", title="MFA", pattern_id="SEC-AUTH"
            )
        assert result.verdict == "new"
        # Guard against a vacuous pass: "new" must be the outcome of a
        # cross-regulation search that actually ran, not of one that was skipped.
        mock_cross.assert_awaited()

    @pytest.mark.asyncio
    async def test_cross_reg_no_results_keeps_new(self):
        """No cross-reg results keeps the verdict as 'new'."""
        checker = self._make_checker()
        with patch(
            self._CROSS_SEARCH,
            new_callable=AsyncMock,
            return_value=[],
        ) as mock_cross:
            result = await checker.check_duplicate(
                action="implement", obj="MFA", title="MFA", pattern_id="SEC-AUTH"
            )
        assert result.verdict == "new"
        # Same guard as above: the search must have been consulted.
        mock_cross.assert_awaited()


class TestDedupResultLinkType:
    """Tests for the link_type field on DedupResult."""

    def test_default_link_type(self):
        r = DedupResult(verdict="new")
        assert r.link_type == "dedup_merge"

    def test_cross_regulation_link_type(self):
        r = DedupResult(verdict="link", link_type="cross_regulation")
        assert r.link_type == "cross_regulation"