feat(rag): optimize RAG pipeline — JSON-Mode, CoT, Hybrid Search, Re-Ranking, Cross-Reg Dedup, chunk 1024
Some checks failed
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Failing after 42s
CI/CD / test-python-backend-compliance (push) Successful in 1m38s
CI/CD / test-python-document-crawler (push) Successful in 20s
CI/CD / test-python-dsms-gateway (push) Successful in 17s
CI/CD / validate-canonical-controls (push) Successful in 10s
CI/CD / Deploy (push) Has been skipped

Phase 1 (LLM Quality):
- Add format=json to all Ollama payloads (obligation_extractor, control_generator, citation_backfill)
- Add Chain-of-Thought analysis steps to Pass 0a/0b system prompts

Phase 2 (Retrieval Quality):
- Hybrid search via Qdrant Query API with RRF fusion + automatic text index (legal_rag.go)
- Fallback to dense-only search if Query API unavailable
- Cross-encoder re-ranking with BGE Reranker v2 (RERANK_ENABLED=false by default)
- CPU-only PyTorch dependency to keep Docker image small

Phase 3 (Data Layer):
- Cross-regulation dedup pass (threshold 0.95) links controls across regulations
- DedupResult.link_type field distinguishes dedup_merge vs cross_regulation
- Chunk size defaults updated 512/50 → 1024/128 for new ingestions only
- Existing collections and controls are NOT affected

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-03-21 11:49:43 +01:00
parent c3a53fe5d2
commit c52dbdb8f1
24 changed files with 2620 additions and 139 deletions

View File

@@ -219,3 +219,36 @@ class TestCitationBackfillMatching:
sql_text = str(self.db.execute.call_args[0][0].text)
assert "license_rule IN (1, 2)" in sql_text
assert "source_citation IS NOT NULL" in sql_text
# =============================================================================
# Tests: Ollama JSON-Mode
# =============================================================================
class TestOllamaJsonMode:
"""Verify that citation_backfill Ollama payloads include format=json."""
@pytest.mark.asyncio
async def test_ollama_payload_contains_format_json(self):
"""_llm_ollama must send format='json' in the request payload."""
from compliance.services.citation_backfill import _llm_ollama
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"message": {"content": '{"article": "Art. 1"}'}
}
with patch("compliance.services.citation_backfill.httpx.AsyncClient") as mock_cls:
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
mock_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
mock_cls.return_value.__aexit__ = AsyncMock(return_value=False)
await _llm_ollama("test prompt", "system prompt")
mock_client.post.assert_called_once()
call_kwargs = mock_client.post.call_args
payload = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json")
assert payload["format"] == "json"

View File

@@ -0,0 +1,625 @@
"""Tests for Control Deduplication Engine (4-Stage Matching Pipeline).
Covers:
- normalize_action(): German → canonical English verb mapping
- normalize_object(): Compliance object normalization
- canonicalize_text(): Canonicalization layer for embedding
- cosine_similarity(): Vector math
- DedupResult dataclass
- ControlDedupChecker.check_duplicate() — all 4 stages and verdicts
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from compliance.services.control_dedup import (
normalize_action,
normalize_object,
canonicalize_text,
cosine_similarity,
DedupResult,
ControlDedupChecker,
LINK_THRESHOLD,
REVIEW_THRESHOLD,
LINK_THRESHOLD_DIFF_OBJECT,
CROSS_REG_LINK_THRESHOLD,
)
# ---------------------------------------------------------------------------
# normalize_action TESTS
# ---------------------------------------------------------------------------
class TestNormalizeAction:
"""Stage 2: Action normalization (German → canonical English)."""
def test_german_implement_synonyms(self):
for verb in ["implementieren", "umsetzen", "einrichten", "einführen", "aktivieren"]:
assert normalize_action(verb) == "implement", f"{verb} should map to implement"
def test_german_test_synonyms(self):
for verb in ["testen", "prüfen", "überprüfen", "verifizieren", "validieren"]:
assert normalize_action(verb) == "test", f"{verb} should map to test"
def test_german_monitor_synonyms(self):
for verb in ["überwachen", "monitoring", "beobachten"]:
assert normalize_action(verb) == "monitor", f"{verb} should map to monitor"
def test_german_encrypt(self):
assert normalize_action("verschlüsseln") == "encrypt"
def test_german_log_synonyms(self):
for verb in ["protokollieren", "aufzeichnen", "loggen"]:
assert normalize_action(verb) == "log", f"{verb} should map to log"
def test_german_restrict_synonyms(self):
for verb in ["beschränken", "einschränken", "begrenzen"]:
assert normalize_action(verb) == "restrict", f"{verb} should map to restrict"
def test_english_passthrough(self):
assert normalize_action("implement") == "implement"
assert normalize_action("test") == "test"
assert normalize_action("monitor") == "monitor"
assert normalize_action("encrypt") == "encrypt"
def test_case_insensitive(self):
assert normalize_action("IMPLEMENTIEREN") == "implement"
assert normalize_action("Testen") == "test"
def test_whitespace_handling(self):
assert normalize_action(" implementieren ") == "implement"
def test_empty_string(self):
assert normalize_action("") == ""
def test_unknown_verb_passthrough(self):
assert normalize_action("fluxkapazitieren") == "fluxkapazitieren"
def test_german_authorize_synonyms(self):
for verb in ["autorisieren", "genehmigen", "freigeben"]:
assert normalize_action(verb) == "authorize", f"{verb} should map to authorize"
def test_german_notify_synonyms(self):
for verb in ["benachrichtigen", "informieren"]:
assert normalize_action(verb) == "notify", f"{verb} should map to notify"
# ---------------------------------------------------------------------------
# normalize_object TESTS
# ---------------------------------------------------------------------------
class TestNormalizeObject:
"""Stage 3: Object normalization (compliance objects → canonical tokens)."""
def test_mfa_synonyms(self):
for obj in ["MFA", "2FA", "multi-faktor-authentifizierung", "two-factor"]:
assert normalize_object(obj) == "multi_factor_auth", f"{obj} should → multi_factor_auth"
def test_password_synonyms(self):
for obj in ["Passwort", "Kennwort", "password"]:
assert normalize_object(obj) == "password_policy", f"{obj} should → password_policy"
def test_privileged_access(self):
for obj in ["Admin-Konten", "admin accounts", "privilegierte Zugriffe"]:
assert normalize_object(obj) == "privileged_access", f"{obj} should → privileged_access"
def test_remote_access(self):
for obj in ["Remote-Zugriff", "Fernzugriff", "remote access"]:
assert normalize_object(obj) == "remote_access", f"{obj} should → remote_access"
def test_encryption_synonyms(self):
for obj in ["Verschlüsselung", "encryption", "Kryptografie"]:
assert normalize_object(obj) == "encryption", f"{obj} should → encryption"
def test_key_management(self):
for obj in ["Schlüssel", "key management", "Schlüsselverwaltung"]:
assert normalize_object(obj) == "key_management", f"{obj} should → key_management"
def test_transport_encryption(self):
for obj in ["TLS", "SSL", "HTTPS"]:
assert normalize_object(obj) == "transport_encryption", f"{obj} should → transport_encryption"
def test_audit_logging(self):
for obj in ["Audit-Log", "audit log", "Protokoll", "logging"]:
assert normalize_object(obj) == "audit_logging", f"{obj} should → audit_logging"
def test_vulnerability(self):
assert normalize_object("Schwachstelle") == "vulnerability"
assert normalize_object("vulnerability") == "vulnerability"
def test_patch_management(self):
for obj in ["Patch", "patching"]:
assert normalize_object(obj) == "patch_management", f"{obj} should → patch_management"
def test_case_insensitive(self):
assert normalize_object("FIREWALL") == "firewall"
assert normalize_object("VPN") == "vpn"
def test_empty_string(self):
assert normalize_object("") == ""
def test_substring_match(self):
"""Longer phrases containing known keywords should match."""
assert normalize_object("die Admin-Konten des Unternehmens") == "privileged_access"
assert normalize_object("zentrale Schlüsselverwaltung") == "key_management"
def test_unknown_object_fallback(self):
"""Unknown objects get cleaned and underscore-joined."""
result = normalize_object("Quantencomputer Resistenz")
assert "_" in result or result == "quantencomputer_resistenz"
def test_articles_stripped_in_fallback(self):
"""German/English articles should be stripped in fallback."""
result = normalize_object("der grosse Quantencomputer")
# "der" and "grosse" (>2 chars) → tokens without articles
assert "der" not in result
# ---------------------------------------------------------------------------
# canonicalize_text TESTS
# ---------------------------------------------------------------------------
class TestCanonicalizeText:
"""Canonicalization layer: German compliance text → normalized English for embedding."""
def test_basic_canonicalization(self):
result = canonicalize_text("implementieren", "MFA")
assert "implement" in result
assert "multi_factor_auth" in result
def test_with_title(self):
result = canonicalize_text("testen", "Firewall", "Netzwerk-Firewall regelmässig prüfen")
assert "test" in result
assert "firewall" in result
def test_title_filler_stripped(self):
result = canonicalize_text("implementieren", "VPN", "für den Zugriff gemäß Richtlinie")
# "für", "den", "gemäß" should be stripped
assert "für" not in result
assert "gemäß" not in result
def test_empty_action_and_object(self):
result = canonicalize_text("", "")
assert result.strip() == ""
def test_example_from_spec(self):
"""The canonical form of the spec example."""
result = canonicalize_text("implementieren", "MFA", "Administratoren müssen MFA verwenden")
assert "implement" in result
assert "multi_factor_auth" in result
# ---------------------------------------------------------------------------
# cosine_similarity TESTS
# ---------------------------------------------------------------------------
class TestCosineSimilarity:
def test_identical_vectors(self):
v = [1.0, 0.0, 0.0]
assert cosine_similarity(v, v) == pytest.approx(1.0)
def test_orthogonal_vectors(self):
a = [1.0, 0.0]
b = [0.0, 1.0]
assert cosine_similarity(a, b) == pytest.approx(0.0)
def test_opposite_vectors(self):
a = [1.0, 0.0]
b = [-1.0, 0.0]
assert cosine_similarity(a, b) == pytest.approx(-1.0)
def test_empty_vectors(self):
assert cosine_similarity([], []) == 0.0
def test_mismatched_lengths(self):
assert cosine_similarity([1.0], [1.0, 2.0]) == 0.0
def test_zero_vector(self):
assert cosine_similarity([0.0, 0.0], [1.0, 1.0]) == 0.0
# ---------------------------------------------------------------------------
# DedupResult TESTS
# ---------------------------------------------------------------------------
class TestDedupResult:
def test_defaults(self):
r = DedupResult(verdict="new")
assert r.verdict == "new"
assert r.matched_control_uuid is None
assert r.stage == ""
assert r.similarity_score == 0.0
assert r.details == {}
def test_link_result(self):
r = DedupResult(
verdict="link",
matched_control_uuid="abc-123",
matched_control_id="AUTH-2001",
stage="embedding_match",
similarity_score=0.95,
)
assert r.verdict == "link"
assert r.matched_control_id == "AUTH-2001"
# ---------------------------------------------------------------------------
# ControlDedupChecker TESTS (mocked DB + embedding)
# ---------------------------------------------------------------------------
class TestControlDedupChecker:
"""Integration tests for the 4-stage dedup pipeline with mocks."""
def _make_checker(self, existing_controls=None, search_results=None):
"""Build a ControlDedupChecker with mocked dependencies."""
db = MagicMock()
# Mock DB query for existing controls
if existing_controls is not None:
mock_rows = []
for c in existing_controls:
row = (c["uuid"], c["control_id"], c["title"], c["objective"],
c.get("pattern_id", "CP-AUTH-001"), c.get("obligation_type"))
mock_rows.append(row)
db.execute.return_value.fetchall.return_value = mock_rows
# Mock embedding function
async def fake_embed(text):
return [0.1] * 1024
# Mock Qdrant search
async def fake_search(embedding, pattern_id, top_k=10):
return search_results or []
return ControlDedupChecker(db, embed_fn=fake_embed, search_fn=fake_search)
@pytest.mark.asyncio
async def test_no_pattern_id_returns_new(self):
checker = self._make_checker()
result = await checker.check_duplicate("implement", "MFA", "Test", pattern_id=None)
assert result.verdict == "new"
assert result.stage == "no_pattern"
@pytest.mark.asyncio
async def test_no_existing_controls_returns_new(self):
checker = self._make_checker(existing_controls=[])
result = await checker.check_duplicate("implement", "MFA", "Test", pattern_id="CP-AUTH-001")
assert result.verdict == "new"
assert result.stage == "pattern_gate"
@pytest.mark.asyncio
async def test_no_qdrant_matches_returns_new(self):
checker = self._make_checker(
existing_controls=[{"uuid": "a1", "control_id": "AUTH-2001", "title": "t", "objective": "o"}],
search_results=[],
)
result = await checker.check_duplicate("implement", "MFA", "Test", pattern_id="CP-AUTH-001")
assert result.verdict == "new"
assert result.stage == "no_qdrant_matches"
@pytest.mark.asyncio
async def test_action_mismatch_returns_new(self):
"""Stage 2: Different action verbs → always NEW, even if embedding is high."""
checker = self._make_checker(
existing_controls=[{"uuid": "a1", "control_id": "AUTH-2001", "title": "t", "objective": "o"}],
search_results=[{
"score": 0.96,
"payload": {
"control_uuid": "a1", "control_id": "AUTH-2001",
"action_normalized": "test",
"object_normalized": "multi_factor_auth",
"title": "MFA testen",
},
}],
)
result = await checker.check_duplicate("implementieren", "MFA", "MFA implementieren", pattern_id="CP-AUTH-001")
assert result.verdict == "new"
assert result.stage == "action_mismatch"
assert result.details["candidate_action"] == "implement"
assert result.details["existing_action"] == "test"
@pytest.mark.asyncio
async def test_object_mismatch_high_score_links(self):
"""Stage 3: Different objects but similarity > 0.95 → LINK."""
checker = self._make_checker(
existing_controls=[{"uuid": "a1", "control_id": "AUTH-2001", "title": "t", "objective": "o"}],
search_results=[{
"score": 0.96,
"payload": {
"control_uuid": "a1", "control_id": "AUTH-2001",
"action_normalized": "implement",
"object_normalized": "remote_access",
"title": "Remote-Zugriff MFA",
},
}],
)
result = await checker.check_duplicate("implementieren", "Admin-Konten", "Admin MFA", pattern_id="CP-AUTH-001")
assert result.verdict == "link"
assert result.stage == "embedding_diff_object"
@pytest.mark.asyncio
async def test_object_mismatch_low_score_returns_new(self):
"""Stage 3: Different objects and similarity < 0.95 → NEW."""
checker = self._make_checker(
existing_controls=[{"uuid": "a1", "control_id": "AUTH-2001", "title": "t", "objective": "o"}],
search_results=[{
"score": 0.88,
"payload": {
"control_uuid": "a1", "control_id": "AUTH-2001",
"action_normalized": "implement",
"object_normalized": "remote_access",
"title": "Remote-Zugriff MFA",
},
}],
)
result = await checker.check_duplicate("implementieren", "Admin-Konten", "Admin MFA", pattern_id="CP-AUTH-001")
assert result.verdict == "new"
assert result.stage == "object_mismatch_below_threshold"
@pytest.mark.asyncio
async def test_same_action_object_high_score_links(self):
"""Stage 4: Same action + object + similarity > 0.92 → LINK."""
checker = self._make_checker(
existing_controls=[{"uuid": "a1", "control_id": "AUTH-2001", "title": "t", "objective": "o"}],
search_results=[{
"score": 0.94,
"payload": {
"control_uuid": "a1", "control_id": "AUTH-2001",
"action_normalized": "implement",
"object_normalized": "multi_factor_auth",
"title": "MFA implementieren",
},
}],
)
result = await checker.check_duplicate("implementieren", "MFA", "MFA umsetzen", pattern_id="CP-AUTH-001")
assert result.verdict == "link"
assert result.stage == "embedding_match"
assert result.similarity_score == 0.94
@pytest.mark.asyncio
async def test_same_action_object_review_range(self):
"""Stage 4: Same action + object + 0.85 < similarity < 0.92 → REVIEW."""
checker = self._make_checker(
existing_controls=[{"uuid": "a1", "control_id": "AUTH-2001", "title": "t", "objective": "o"}],
search_results=[{
"score": 0.88,
"payload": {
"control_uuid": "a1", "control_id": "AUTH-2001",
"action_normalized": "implement",
"object_normalized": "multi_factor_auth",
"title": "MFA implementieren",
},
}],
)
result = await checker.check_duplicate("implementieren", "MFA", "MFA für Admins", pattern_id="CP-AUTH-001")
assert result.verdict == "review"
assert result.stage == "embedding_review"
@pytest.mark.asyncio
async def test_same_action_object_low_score_new(self):
"""Stage 4: Same action + object but similarity < 0.85 → NEW."""
checker = self._make_checker(
existing_controls=[{"uuid": "a1", "control_id": "AUTH-2001", "title": "t", "objective": "o"}],
search_results=[{
"score": 0.72,
"payload": {
"control_uuid": "a1", "control_id": "AUTH-2001",
"action_normalized": "implement",
"object_normalized": "multi_factor_auth",
"title": "MFA implementieren",
},
}],
)
result = await checker.check_duplicate("implementieren", "MFA", "Ganz anderer MFA Kontext", pattern_id="CP-AUTH-001")
assert result.verdict == "new"
assert result.stage == "embedding_below_threshold"
@pytest.mark.asyncio
async def test_embedding_failure_returns_new(self):
"""If embedding service is down, default to NEW."""
db = MagicMock()
db.execute.return_value.fetchall.return_value = [
("a1", "AUTH-2001", "t", "o", "CP-AUTH-001", None)
]
async def failing_embed(text):
return []
checker = ControlDedupChecker(db, embed_fn=failing_embed)
result = await checker.check_duplicate("implement", "MFA", "Test", pattern_id="CP-AUTH-001")
assert result.verdict == "new"
assert result.stage == "embedding_unavailable"
@pytest.mark.asyncio
async def test_spec_false_positive_example(self):
"""The spec example: Admin-MFA vs Remote-MFA should NOT dedup.
Even if embedding says >0.9, different objects (privileged_access vs remote_access)
and score < 0.95 means NEW.
"""
checker = self._make_checker(
existing_controls=[{"uuid": "a1", "control_id": "AUTH-2001", "title": "t", "objective": "o"}],
search_results=[{
"score": 0.91,
"payload": {
"control_uuid": "a1", "control_id": "AUTH-2001",
"action_normalized": "implement",
"object_normalized": "remote_access",
"title": "Remote-Zugriffe müssen MFA nutzen",
},
}],
)
result = await checker.check_duplicate(
"implementieren", "Admin-Konten",
"Admin-Zugriffe müssen MFA nutzen",
pattern_id="CP-AUTH-001",
)
assert result.verdict == "new"
assert result.stage == "object_mismatch_below_threshold"
# ---------------------------------------------------------------------------
# THRESHOLD CONFIGURATION TESTS
# ---------------------------------------------------------------------------
class TestThresholds:
"""Verify the configured threshold values match the spec."""
def test_link_threshold(self):
assert LINK_THRESHOLD == 0.92
def test_review_threshold(self):
assert REVIEW_THRESHOLD == 0.85
def test_diff_object_threshold(self):
assert LINK_THRESHOLD_DIFF_OBJECT == 0.95
def test_threshold_ordering(self):
assert LINK_THRESHOLD_DIFF_OBJECT > LINK_THRESHOLD > REVIEW_THRESHOLD
def test_cross_reg_threshold(self):
assert CROSS_REG_LINK_THRESHOLD == 0.95
def test_cross_reg_threshold_higher_than_intra(self):
assert CROSS_REG_LINK_THRESHOLD >= LINK_THRESHOLD
# ---------------------------------------------------------------------------
# CROSS-REGULATION DEDUP TESTS
# ---------------------------------------------------------------------------
class TestCrossRegulationDedup:
"""Tests for cross-regulation linking (second dedup pass)."""
def _make_checker(self):
"""Create a checker with mocked DB, embedding, and search."""
mock_db = MagicMock()
mock_db.execute.return_value.fetchall.return_value = [
("uuid-1", "CTRL-001", "MFA", "Enable MFA", "SEC-AUTH", "pflicht"),
]
embed_fn = AsyncMock(return_value=[0.1] * 1024)
search_fn = AsyncMock(return_value=[]) # no intra-pattern matches
return ControlDedupChecker(db=mock_db, embed_fn=embed_fn, search_fn=search_fn)
@pytest.mark.asyncio
async def test_cross_reg_triggered_when_intra_is_new(self):
"""Cross-reg runs when intra-pattern returns 'new'."""
checker = self._make_checker()
cross_results = [{
"score": 0.96,
"payload": {
"control_uuid": "cross-uuid-1",
"control_id": "NIS2-CTRL-001",
"title": "MFA (NIS2)",
},
}]
with patch(
"compliance.services.control_dedup.qdrant_search_cross_regulation",
new_callable=AsyncMock,
return_value=cross_results,
):
result = await checker.check_duplicate(
action="implement", obj="MFA", title="MFA", pattern_id="SEC-AUTH"
)
assert result.verdict == "link"
assert result.stage == "cross_regulation"
assert result.link_type == "cross_regulation"
assert result.matched_control_id == "NIS2-CTRL-001"
assert result.similarity_score == 0.96
@pytest.mark.asyncio
async def test_cross_reg_not_triggered_when_intra_is_link(self):
"""Cross-reg should NOT run when intra-pattern already found a link."""
mock_db = MagicMock()
mock_db.execute.return_value.fetchall.return_value = [
("uuid-1", "CTRL-001", "MFA", "Enable MFA", "SEC-AUTH", "pflicht"),
]
embed_fn = AsyncMock(return_value=[0.1] * 1024)
# Intra-pattern search returns a high match
search_fn = AsyncMock(return_value=[{
"score": 0.95,
"payload": {
"control_uuid": "intra-uuid",
"control_id": "CTRL-001",
"title": "MFA",
"action_normalized": "implement",
"object_normalized": "multi_factor_auth",
},
}])
checker = ControlDedupChecker(db=mock_db, embed_fn=embed_fn, search_fn=search_fn)
with patch(
"compliance.services.control_dedup.qdrant_search_cross_regulation",
new_callable=AsyncMock,
) as mock_cross:
result = await checker.check_duplicate(
action="implement", obj="MFA", title="MFA", pattern_id="SEC-AUTH"
)
assert result.verdict == "link"
assert result.stage == "embedding_match"
assert result.link_type == "dedup_merge" # not cross_regulation
mock_cross.assert_not_called()
@pytest.mark.asyncio
async def test_cross_reg_below_threshold_keeps_new(self):
"""Cross-reg score below 0.95 keeps the verdict as 'new'."""
checker = self._make_checker()
cross_results = [{
"score": 0.93, # below CROSS_REG_LINK_THRESHOLD
"payload": {
"control_uuid": "cross-uuid-2",
"control_id": "NIS2-CTRL-002",
"title": "Similar control",
},
}]
with patch(
"compliance.services.control_dedup.qdrant_search_cross_regulation",
new_callable=AsyncMock,
return_value=cross_results,
):
result = await checker.check_duplicate(
action="implement", obj="MFA", title="MFA", pattern_id="SEC-AUTH"
)
assert result.verdict == "new"
@pytest.mark.asyncio
async def test_cross_reg_no_results_keeps_new(self):
"""No cross-reg results keeps the verdict as 'new'."""
checker = self._make_checker()
with patch(
"compliance.services.control_dedup.qdrant_search_cross_regulation",
new_callable=AsyncMock,
return_value=[],
):
result = await checker.check_duplicate(
action="implement", obj="MFA", title="MFA", pattern_id="SEC-AUTH"
)
assert result.verdict == "new"
class TestDedupResultLinkType:
"""Tests for the link_type field on DedupResult."""
def test_default_link_type(self):
r = DedupResult(verdict="new")
assert r.link_type == "dedup_merge"
def test_cross_regulation_link_type(self):
r = DedupResult(verdict="link", link_type="cross_regulation")
assert r.link_type == "cross_regulation"

View File

@@ -30,7 +30,7 @@ class TestLicenseMapping:
def test_rule1_eu_law(self):
info = _classify_regulation("eu_2016_679")
assert info["rule"] == 1
assert info["name"] == "DSGVO"
assert "DSGVO" in info["name"]
assert info["source_type"] == "law"
def test_rule1_nist(self):
@@ -42,7 +42,7 @@ class TestLicenseMapping:
def test_rule1_german_law(self):
info = _classify_regulation("bdsg")
assert info["rule"] == 1
assert info["name"] == "BDSG"
assert "BDSG" in info["name"]
assert info["source_type"] == "law"
def test_rule2_owasp(self):
@@ -199,7 +199,7 @@ class TestAnchorFinder:
async def test_rag_anchor_search_filters_restricted(self):
"""Only Rule 1+2 sources are returned as anchors."""
mock_rag = AsyncMock()
mock_rag.search.return_value = [
mock_rag.search_with_rerank.return_value = [
RAGSearchResult(
text="OWASP requirement",
regulation_code="owasp_asvs",
@@ -231,7 +231,7 @@ class TestAnchorFinder:
# Only OWASP should be returned (Rule 2), BSI should be filtered out (Rule 3)
assert len(anchors) == 1
assert anchors[0].framework == "OWASP ASVS"
assert "OWASP ASVS" in anchors[0].framework
@pytest.mark.asyncio
async def test_web_search_identifies_frameworks(self):
@@ -1668,3 +1668,36 @@ class TestApplicabilityFields:
control = pipeline._build_control_from_json(data, "SEC")
assert "applicable_industries" not in control.generation_metadata
assert "applicable_company_size" not in control.generation_metadata
# =============================================================================
# Tests: Ollama JSON-Mode
# =============================================================================
class TestOllamaJsonMode:
"""Verify that control_generator Ollama payloads include format=json."""
@pytest.mark.asyncio
async def test_ollama_payload_contains_format_json(self):
"""_llm_ollama must send format='json' in the request payload."""
from compliance.services.control_generator import _llm_ollama
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"message": {"content": '{"test": true}'}
}
with patch("compliance.services.control_generator.httpx.AsyncClient") as mock_cls:
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
mock_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
mock_cls.return_value.__aexit__ = AsyncMock(return_value=False)
await _llm_ollama("test prompt", "system prompt")
mock_client.post.assert_called_once()
call_kwargs = mock_client.post.call_args
payload = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json")
assert payload["format"] == "json"

View File

@@ -25,7 +25,11 @@ from compliance.services.decomposition_pass import (
AtomicControlCandidate,
quality_gate,
passes_quality_gate,
classify_obligation_type,
_NORMATIVE_RE,
_PFLICHT_RE,
_EMPFEHLUNG_RE,
_KANN_RE,
_RATIONALE_RE,
_TEST_RE,
_REPORTING_RE,
@@ -176,7 +180,7 @@ class TestQualityGate:
def test_rationale_detected(self):
oc = ObligationCandidate(
parent_control_uuid="uuid-1",
obligation_text="Schwache Passwörter können zu Risiken führen, weil sie leicht zu erraten sind",
obligation_text="Dies liegt daran, weil schwache Konfigurationen ein Risiko darstellen",
)
flags = quality_gate(oc)
assert flags["not_rationale"] is False
@@ -228,14 +232,28 @@ class TestQualityGate:
)
flags = quality_gate(oc)
assert flags["has_normative_signal"] is False
assert flags["obligation_type"] == "empfehlung"
def test_obligation_type_in_flags(self):
oc = ObligationCandidate(
parent_control_uuid="uuid-1",
obligation_text="Der Betreiber muss alle Daten verschlüsseln.",
)
flags = quality_gate(oc)
assert flags["obligation_type"] == "pflicht"
class TestPassesQualityGate:
"""Tests for passes_quality_gate function."""
"""Tests for passes_quality_gate function.
Note: has_normative_signal is NO LONGER critical — obligations without
normative signal are classified as 'empfehlung' instead of being rejected.
"""
def test_all_critical_pass(self):
flags = {
"has_normative_signal": True,
"obligation_type": "pflicht",
"single_action": True,
"not_rationale": True,
"not_evidence_only": True,
@@ -244,20 +262,23 @@ class TestPassesQualityGate:
}
assert passes_quality_gate(flags) is True
def test_no_normative_signal_fails(self):
def test_no_normative_signal_still_passes(self):
"""No normative signal no longer causes rejection — classified as empfehlung."""
flags = {
"has_normative_signal": False,
"obligation_type": "empfehlung",
"single_action": True,
"not_rationale": True,
"not_evidence_only": True,
"min_length": True,
"has_parent_link": True,
}
assert passes_quality_gate(flags) is False
assert passes_quality_gate(flags) is True
def test_evidence_only_fails(self):
flags = {
"has_normative_signal": True,
"obligation_type": "pflicht",
"single_action": True,
"not_rationale": True,
"not_evidence_only": False,
@@ -267,9 +288,10 @@ class TestPassesQualityGate:
assert passes_quality_gate(flags) is False
def test_non_critical_dont_block(self):
"""single_action and not_rationale are NOT critical — should still pass."""
"""single_action, not_rationale, has_normative_signal are NOT critical."""
flags = {
"has_normative_signal": True,
"has_normative_signal": False, # Not critical
"obligation_type": "empfehlung",
"single_action": False, # Not critical
"not_rationale": False, # Not critical
"not_evidence_only": True,
@@ -279,6 +301,42 @@ class TestPassesQualityGate:
assert passes_quality_gate(flags) is True
class TestClassifyObligationType:
"""Tests for the 3-tier obligation type classification."""
def test_pflicht_muss(self):
assert classify_obligation_type("Der Betreiber muss alle Daten verschlüsseln") == "pflicht"
def test_pflicht_ist_zu(self):
assert classify_obligation_type("Die Meldung ist innerhalb von 72 Stunden zu erstatten") == "pflicht"
def test_pflicht_shall(self):
assert classify_obligation_type("The controller shall implement appropriate measures") == "pflicht"
def test_empfehlung_soll(self):
assert classify_obligation_type("Der Betreiber soll regelmäßige Audits durchführen") == "empfehlung"
def test_empfehlung_should(self):
assert classify_obligation_type("Organizations should implement security controls") == "empfehlung"
def test_empfehlung_sicherstellen(self):
assert classify_obligation_type("Die Verfügbarkeit der Systeme sicherstellen") == "empfehlung"
def test_kann(self):
assert classify_obligation_type("Der Betreiber kann zusätzliche Maßnahmen ergreifen") == "kann"
def test_kann_may(self):
assert classify_obligation_type("The organization may implement optional safeguards") == "kann"
def test_no_signal_defaults_to_empfehlung(self):
assert classify_obligation_type("Regelmäßige Überprüfung der Zugriffsrechte") == "empfehlung"
def test_pflicht_overrides_empfehlung(self):
"""If both pflicht and empfehlung signals present, pflicht wins."""
txt = "Der Betreiber muss sicherstellen, dass alle Daten verschlüsselt werden"
assert classify_obligation_type(txt) == "pflicht"
# ---------------------------------------------------------------------------
# HELPER TESTS
# ---------------------------------------------------------------------------
@@ -520,6 +578,24 @@ class TestPromptBuilders:
assert "REGELN" in _PASS0A_SYSTEM_PROMPT
assert "atomares" in _PASS0B_SYSTEM_PROMPT
def test_pass0a_prompt_contains_cot_steps(self):
"""Pass 0a system prompt must include Chain-of-Thought analysis steps."""
assert "ANALYSE-SCHRITTE" in _PASS0A_SYSTEM_PROMPT
assert "Adressaten" in _PASS0A_SYSTEM_PROMPT
assert "Handlung" in _PASS0A_SYSTEM_PROMPT
assert "normative Staerke" in _PASS0A_SYSTEM_PROMPT
assert "Meldepflicht" in _PASS0A_SYSTEM_PROMPT
assert "NICHT im Output" in _PASS0A_SYSTEM_PROMPT
def test_pass0b_prompt_contains_cot_steps(self):
"""Pass 0b system prompt must include Chain-of-Thought analysis steps."""
assert "ANALYSE-SCHRITTE" in _PASS0B_SYSTEM_PROMPT
assert "Anforderung" in _PASS0B_SYSTEM_PROMPT
assert "Massnahme" in _PASS0B_SYSTEM_PROMPT
assert "Pruefverfahren" in _PASS0B_SYSTEM_PROMPT
assert "Nachweis" in _PASS0B_SYSTEM_PROMPT
assert "NICHT im Output" in _PASS0B_SYSTEM_PROMPT
# ---------------------------------------------------------------------------
# DECOMPOSITION PASS INTEGRATION TESTS

View File

@@ -937,3 +937,36 @@ class TestConstants:
def test_candidate_threshold_is_60(self):
assert EMBEDDING_CANDIDATE_THRESHOLD == 0.60
# =============================================================================
# Tests: Ollama JSON-Mode
# =============================================================================
class TestOllamaJsonMode:
"""Verify that Ollama payloads include format=json."""
@pytest.mark.asyncio
async def test_ollama_payload_contains_format_json(self):
"""_llm_ollama must send format='json' in the request payload."""
from compliance.services.obligation_extractor import _llm_ollama
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"message": {"content": '{"test": true}'}
}
with patch("compliance.services.obligation_extractor.httpx.AsyncClient") as mock_cls:
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
mock_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
mock_cls.return_value.__aexit__ = AsyncMock(return_value=False)
await _llm_ollama("test prompt", "system prompt")
mock_client.post.assert_called_once()
call_kwargs = mock_client.post.call_args
payload = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json")
assert payload["format"] == "json"

View File

@@ -0,0 +1,191 @@
"""Tests for Cross-Encoder Re-Ranking module."""
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from compliance.services.reranker import Reranker, get_reranker, RERANK_ENABLED
from compliance.services.rag_client import ComplianceRAGClient, RAGSearchResult
# =============================================================================
# Reranker Unit Tests
# =============================================================================
class TestReranker:
"""Tests for Reranker class."""
def test_rerank_empty_texts(self):
"""Empty texts list returns empty indices."""
reranker = Reranker()
assert reranker.rerank("query", [], top_k=5) == []
def test_rerank_returns_correct_indices(self):
"""Reranker returns indices sorted by score descending."""
reranker = Reranker()
# Mock the cross-encoder model
mock_model = MagicMock()
# Scores: text[0]=0.1, text[1]=0.9, text[2]=0.5
mock_model.predict.return_value = [0.1, 0.9, 0.5]
reranker._model = mock_model
indices = reranker.rerank("test query", ["low", "high", "mid"], top_k=3)
assert indices == [1, 2, 0] # sorted by score desc
def test_rerank_top_k_limits_results(self):
"""top_k limits the number of returned indices."""
reranker = Reranker()
mock_model = MagicMock()
mock_model.predict.return_value = [0.1, 0.9, 0.5, 0.7, 0.3]
reranker._model = mock_model
indices = reranker.rerank("query", ["a", "b", "c", "d", "e"], top_k=2)
assert len(indices) == 2
assert indices[0] == 1 # highest score
assert indices[1] == 3 # second highest
def test_rerank_sends_pairs_to_model(self):
"""Model receives [[query, text], ...] pairs."""
reranker = Reranker()
mock_model = MagicMock()
mock_model.predict.return_value = [0.5, 0.8]
reranker._model = mock_model
reranker.rerank("my query", ["text A", "text B"], top_k=2)
call_args = mock_model.predict.call_args[0][0]
assert call_args == [["my query", "text A"], ["my query", "text B"]]
def test_lazy_init_not_loaded_until_rerank(self):
"""Model should not be loaded at construction time."""
reranker = Reranker()
assert reranker._model is None
def test_ensure_model_skips_if_already_loaded(self):
"""_ensure_model should not reload when model is already set."""
reranker = Reranker()
mock_model = MagicMock()
reranker._model = mock_model
# Call _ensure_model — should short-circuit since _model is set
reranker._ensure_model()
reranker._ensure_model()
# Model should still be the same mock
assert reranker._model is mock_model
# =============================================================================
# get_reranker Tests
# =============================================================================
class TestGetReranker:
"""Tests for the get_reranker factory."""
def test_disabled_returns_none(self):
"""When RERANK_ENABLED=false, get_reranker returns None."""
with patch("compliance.services.reranker.RERANK_ENABLED", False):
# Reset singleton
import compliance.services.reranker as mod
mod._reranker = None
result = mod.get_reranker()
assert result is None
def test_enabled_returns_reranker(self):
"""When RERANK_ENABLED=true, get_reranker returns a Reranker instance."""
import compliance.services.reranker as mod
mod._reranker = None
with patch.object(mod, "RERANK_ENABLED", True):
result = mod.get_reranker()
assert isinstance(result, Reranker)
mod._reranker = None # cleanup
def test_singleton_returns_same_instance(self):
"""get_reranker returns the same instance on repeated calls."""
import compliance.services.reranker as mod
mod._reranker = None
with patch.object(mod, "RERANK_ENABLED", True):
r1 = mod.get_reranker()
r2 = mod.get_reranker()
assert r1 is r2
mod._reranker = None # cleanup
# =============================================================================
# search_with_rerank Integration Tests
# =============================================================================
class TestSearchWithRerank:
"""Tests for ComplianceRAGClient.search_with_rerank."""
def _make_result(self, text: str, score: float) -> RAGSearchResult:
return RAGSearchResult(
text=text, regulation_code="eu_2016_679",
regulation_name="DSGVO", regulation_short="DSGVO",
category="regulation", article="", paragraph="",
source_url="", score=score,
)
@pytest.mark.asyncio
async def test_rerank_disabled_falls_through(self):
"""When reranker is disabled, search_with_rerank calls regular search."""
client = ComplianceRAGClient(base_url="http://fake")
results = [self._make_result("text1", 0.9)]
with patch.object(client, "search", new_callable=AsyncMock, return_value=results):
with patch("compliance.services.reranker.get_reranker", return_value=None):
got = await client.search_with_rerank("query", top_k=5)
assert len(got) == 1
assert got[0].text == "text1"
@pytest.mark.asyncio
async def test_rerank_reorders_results(self):
"""When reranker is enabled, results are re-ordered."""
client = ComplianceRAGClient(base_url="http://fake")
candidates = [
self._make_result("low relevance", 0.9),
self._make_result("high relevance", 0.7),
self._make_result("medium relevance", 0.8),
]
mock_reranker = MagicMock()
# Reranker says index 1 is best, then 2, then 0
mock_reranker.rerank.return_value = [1, 2, 0]
with patch.object(client, "search", new_callable=AsyncMock, return_value=candidates):
with patch("compliance.services.reranker.get_reranker", return_value=mock_reranker):
got = await client.search_with_rerank("query", top_k=2)
# Should get reranked top 2 (but our mock returns [1,2,0] and top_k=2
# means reranker.rerank is called with top_k=2, which returns [1, 2])
mock_reranker.rerank.assert_called_once()
# The rerank mock returns [1,2,0], so we get candidates[1] and candidates[2]
assert got[0].text == "high relevance"
assert got[1].text == "medium relevance"
@pytest.mark.asyncio
async def test_rerank_failure_returns_unranked(self):
"""If reranker fails, fall back to unranked results."""
client = ComplianceRAGClient(base_url="http://fake")
candidates = [self._make_result("text", 0.9)] * 5
mock_reranker = MagicMock()
mock_reranker.rerank.side_effect = RuntimeError("model error")
with patch.object(client, "search", new_callable=AsyncMock, return_value=candidates):
with patch("compliance.services.reranker.get_reranker", return_value=mock_reranker):
got = await client.search_with_rerank("query", top_k=3)
assert len(got) == 3 # falls back to first top_k