Some checks failed
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Failing after 42s
CI/CD / test-python-backend-compliance (push) Successful in 1m38s
CI/CD / test-python-document-crawler (push) Successful in 20s
CI/CD / test-python-dsms-gateway (push) Successful in 17s
CI/CD / validate-canonical-controls (push) Successful in 10s
CI/CD / Deploy (push) Has been skipped
Phase 1 (LLM Quality): - Add format=json to all Ollama payloads (obligation_extractor, control_generator, citation_backfill) - Add Chain-of-Thought analysis steps to Pass 0a/0b system prompts Phase 2 (Retrieval Quality): - Hybrid search via Qdrant Query API with RRF fusion + automatic text index (legal_rag.go) - Fallback to dense-only search if Query API unavailable - Cross-encoder re-ranking with BGE Reranker v2 (RERANK_ENABLED=false by default) - CPU-only PyTorch dependency to keep Docker image small Phase 3 (Data Layer): - Cross-regulation dedup pass (threshold 0.95) links controls across regulations - DedupResult.link_type field distinguishes dedup_merge vs cross_regulation - Chunk size defaults updated 512/50 → 1024/128 for new ingestions only - Existing collections and controls are NOT affected Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
192 lines
7.2 KiB
Python
192 lines
7.2 KiB
Python
"""Tests for Cross-Encoder Re-Ranking module."""
|
|
|
|
import pytest
|
|
from unittest.mock import MagicMock, patch, AsyncMock
|
|
|
|
from compliance.services.reranker import Reranker, get_reranker, RERANK_ENABLED
|
|
from compliance.services.rag_client import ComplianceRAGClient, RAGSearchResult
|
|
|
|
|
|
# =============================================================================
# Reranker Unit Tests
# =============================================================================


class TestReranker:
    """Unit tests for the Reranker class, driven by a mocked cross-encoder."""

    @staticmethod
    def _reranker_with_mock_model(scores=None):
        """Return ``(reranker, model)`` with a MagicMock installed as ``_model``.

        When *scores* is given, ``model.predict`` returns it; otherwise the
        mock's ``predict`` is left unconfigured.
        """
        reranker = Reranker()
        model = MagicMock()
        if scores is not None:
            model.predict.return_value = scores
        reranker._model = model
        return reranker, model

    def test_rerank_empty_texts(self):
        """An empty candidate list yields an empty index list."""
        assert Reranker().rerank("query", [], top_k=5) == []

    def test_rerank_returns_correct_indices(self):
        """Indices come back ordered from highest to lowest model score."""
        # Per-text scores: "low"=0.1, "high"=0.9, "mid"=0.5
        reranker, _ = self._reranker_with_mock_model([0.1, 0.9, 0.5])
        ranked = reranker.rerank("test query", ["low", "high", "mid"], top_k=3)
        assert ranked == [1, 2, 0]

    def test_rerank_top_k_limits_results(self):
        """Only the top_k best-scoring indices are returned."""
        reranker, _ = self._reranker_with_mock_model([0.1, 0.9, 0.5, 0.7, 0.3])
        ranked = reranker.rerank("query", ["a", "b", "c", "d", "e"], top_k=2)
        assert len(ranked) == 2
        best, runner_up = ranked
        assert best == 1  # score 0.9
        assert runner_up == 3  # score 0.7

    def test_rerank_sends_pairs_to_model(self):
        """predict() receives one [query, text] pair per candidate text."""
        reranker, model = self._reranker_with_mock_model([0.5, 0.8])
        reranker.rerank("my query", ["text A", "text B"], top_k=2)
        sent_pairs = model.predict.call_args.args[0]
        assert sent_pairs == [["my query", "text A"], ["my query", "text B"]]

    def test_lazy_init_not_loaded_until_rerank(self):
        """Constructing a Reranker must not load the model."""
        assert Reranker()._model is None

    def test_ensure_model_skips_if_already_loaded(self):
        """_ensure_model leaves an already-installed model in place."""
        reranker, model = self._reranker_with_mock_model()
        # Repeated calls should short-circuit because _model is already set.
        for _ in range(2):
            reranker._ensure_model()
        assert reranker._model is model


# =============================================================================
# get_reranker Tests
# =============================================================================


class TestGetReranker:
    """Tests for the get_reranker factory.

    These tests treat ``compliance.services.reranker._reranker`` as the cached
    singleton behind ``get_reranker``. Each test starts with it cleared, and
    ``teardown_method`` restores the previous value. Restoring there, rather
    than with a trailing ``mod._reranker = None`` inside the test body, keeps
    the singleton from leaking into other tests when an assertion fails.
    """

    def setup_method(self):
        """Clear the module-level singleton, remembering its prior value."""
        import compliance.services.reranker as mod

        self.mod = mod
        self._saved_reranker = mod._reranker
        mod._reranker = None

    def teardown_method(self):
        """Restore the singleton that was in place before the test ran."""
        self.mod._reranker = self._saved_reranker

    def test_disabled_returns_none(self):
        """When RERANK_ENABLED=false, get_reranker returns None."""
        with patch.object(self.mod, "RERANK_ENABLED", False):
            result = self.mod.get_reranker()
        assert result is None

    def test_enabled_returns_reranker(self):
        """When RERANK_ENABLED=true, get_reranker returns a Reranker instance."""
        with patch.object(self.mod, "RERANK_ENABLED", True):
            result = self.mod.get_reranker()
        assert isinstance(result, Reranker)

    def test_singleton_returns_same_instance(self):
        """get_reranker returns the same instance on repeated calls."""
        with patch.object(self.mod, "RERANK_ENABLED", True):
            first = self.mod.get_reranker()
            second = self.mod.get_reranker()
        assert first is second


# =============================================================================
# search_with_rerank Integration Tests
# =============================================================================


class TestSearchWithRerank:
    """Tests for ComplianceRAGClient.search_with_rerank.

    ``client.search`` is replaced with an AsyncMock, so no RAG service is
    contacted. ``compliance.services.reranker.get_reranker`` is patched to
    turn re-ranking off (None) or to inject a mock reranker.
    """

    def _make_result(self, text: str, score: float) -> RAGSearchResult:
        """Build a DSGVO search result that differs only in *text* and *score*."""
        return RAGSearchResult(
            text=text,
            regulation_code="eu_2016_679",
            regulation_name="DSGVO",
            regulation_short="DSGVO",
            category="regulation",
            article="",
            paragraph="",
            source_url="",
            score=score,
        )

    @pytest.mark.asyncio
    async def test_rerank_disabled_falls_through(self):
        """When reranker is disabled, search_with_rerank calls regular search."""
        client = ComplianceRAGClient(base_url="http://fake")
        results = [self._make_result("text1", 0.9)]

        with patch.object(
            client, "search", new_callable=AsyncMock, return_value=results
        ) as mock_search, patch(
            "compliance.services.reranker.get_reranker", return_value=None
        ):
            got = await client.search_with_rerank("query", top_k=5)

        mock_search.assert_awaited()
        assert len(got) == 1
        assert got[0].text == "text1"

    @pytest.mark.asyncio
    async def test_rerank_reorders_results(self):
        """When reranker is enabled, results are re-ordered."""
        client = ComplianceRAGClient(base_url="http://fake")

        # The vector scores deliberately disagree with the reranker's ordering,
        # so the test fails if the original (or score-sorted) order is kept.
        candidates = [
            self._make_result("low relevance", 0.9),
            self._make_result("high relevance", 0.7),
            self._make_result("medium relevance", 0.8),
        ]

        mock_reranker = MagicMock()
        # The mock always returns [1, 2, 0] (index 1 best, then 2, then 0),
        # whatever top_k it is called with.
        mock_reranker.rerank.return_value = [1, 2, 0]

        with patch.object(
            client, "search", new_callable=AsyncMock, return_value=candidates
        ), patch(
            "compliance.services.reranker.get_reranker", return_value=mock_reranker
        ):
            got = await client.search_with_rerank("query", top_k=2)

        mock_reranker.rerank.assert_called_once()
        # Results follow the reranker's index order: candidates[1], then candidates[2].
        assert got[0].text == "high relevance"
        assert got[1].text == "medium relevance"

    @pytest.mark.asyncio
    async def test_rerank_failure_returns_unranked(self):
        """If reranker fails, fall back to unranked results."""
        client = ComplianceRAGClient(base_url="http://fake")

        # Distinct candidates (descending scores) so the assertion can tell
        # *which* results came back, not merely how many.
        candidates = [
            self._make_result(f"text{i}", score)
            for i, score in enumerate((0.9, 0.8, 0.7, 0.6, 0.5))
        ]

        mock_reranker = MagicMock()
        mock_reranker.rerank.side_effect = RuntimeError("model error")

        with patch.object(
            client, "search", new_callable=AsyncMock, return_value=candidates
        ), patch(
            "compliance.services.reranker.get_reranker", return_value=mock_reranker
        ):
            got = await client.search_with_rerank("query", top_k=3)

        # Confirm the failure path was actually exercised.
        mock_reranker.rerank.assert_called_once()
        # Falls back to the first top_k candidates, in their original order.
        assert [r.text for r in got] == ["text0", "text1", "text2"]