Files
breakpilot-compliance/backend-compliance/tests/test_reranker.py
Benjamin Admin c52dbdb8f1
Some checks failed
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Failing after 42s
CI/CD / test-python-backend-compliance (push) Successful in 1m38s
CI/CD / test-python-document-crawler (push) Successful in 20s
CI/CD / test-python-dsms-gateway (push) Successful in 17s
CI/CD / validate-canonical-controls (push) Successful in 10s
CI/CD / Deploy (push) Has been skipped
feat(rag): optimize RAG pipeline — JSON-Mode, CoT, Hybrid Search, Re-Ranking, Cross-Reg Dedup, chunk 1024
Phase 1 (LLM Quality):
- Add format=json to all Ollama payloads (obligation_extractor, control_generator, citation_backfill)
- Add Chain-of-Thought analysis steps to Pass 0a/0b system prompts

Phase 2 (Retrieval Quality):
- Hybrid search via Qdrant Query API with RRF fusion + automatic text index (legal_rag.go)
- Fallback to dense-only search if Query API unavailable
- Cross-encoder re-ranking with BGE Reranker v2 (RERANK_ENABLED=false by default)
- CPU-only PyTorch dependency to keep Docker image small

Phase 3 (Data Layer):
- Cross-regulation dedup pass (threshold 0.95) links controls across regulations
- DedupResult.link_type field distinguishes dedup_merge vs cross_regulation
- Chunk size defaults updated 512/50 → 1024/128 for new ingestions only
- Existing collections and controls are NOT affected

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-21 11:49:43 +01:00

192 lines
7.2 KiB
Python

"""Tests for Cross-Encoder Re-Ranking module."""
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from compliance.services.reranker import Reranker, get_reranker, RERANK_ENABLED
from compliance.services.rag_client import ComplianceRAGClient, RAGSearchResult
# =============================================================================
# Reranker Unit Tests
# =============================================================================
class TestReranker:
    """Unit tests for the ``Reranker`` cross-encoder wrapper.

    Every test that needs scores injects a ``MagicMock`` as ``_model`` so no
    real cross-encoder is ever loaded.
    """

    @staticmethod
    def _with_fake_model(scores):
        """Return ``(reranker, fake_model)`` where the fake predicts ``scores``."""
        instance = Reranker()
        fake = MagicMock()
        fake.predict.return_value = scores
        # Setting _model directly skips the lazy load.
        instance._model = fake
        return instance, fake

    def test_rerank_empty_texts(self):
        """An empty candidate list yields an empty index list."""
        assert Reranker().rerank("query", [], top_k=5) == []

    def test_rerank_returns_correct_indices(self):
        """Indices come back ordered from highest to lowest model score."""
        # Per-text scores: "low"=0.1, "high"=0.9, "mid"=0.5.
        reranker, _ = self._with_fake_model([0.1, 0.9, 0.5])
        ranked = reranker.rerank("test query", ["low", "high", "mid"], top_k=3)
        assert ranked == [1, 2, 0]  # best score first

    def test_rerank_top_k_limits_results(self):
        """Only the ``top_k`` best indices are returned."""
        reranker, _ = self._with_fake_model([0.1, 0.9, 0.5, 0.7, 0.3])
        ranked = reranker.rerank("query", list("abcde"), top_k=2)
        assert len(ranked) == 2
        best, runner_up = ranked
        assert best == 1  # score 0.9
        assert runner_up == 3  # score 0.7

    def test_rerank_sends_pairs_to_model(self):
        """The model is fed one ``[query, text]`` pair per candidate."""
        reranker, fake = self._with_fake_model([0.5, 0.8])
        reranker.rerank("my query", ["text A", "text B"], top_k=2)
        positional, _kwargs = fake.predict.call_args
        assert positional[0] == [["my query", "text A"], ["my query", "text B"]]

    def test_lazy_init_not_loaded_until_rerank(self):
        """Constructing a Reranker must not load the model."""
        assert Reranker()._model is None

    def test_ensure_model_skips_if_already_loaded(self):
        """``_ensure_model`` leaves an already-set model untouched."""
        reranker = Reranker()
        preloaded = MagicMock()
        reranker._model = preloaded
        # Repeated calls must short-circuit instead of loading a new model.
        for _ in range(2):
            reranker._ensure_model()
        assert reranker._model is preloaded
# =============================================================================
# get_reranker Tests
# =============================================================================
class TestGetReranker:
    """Tests for the ``get_reranker`` factory.

    ``get_reranker`` caches its instance in the module global ``_reranker``.
    Each test patches that global (and ``RERANK_ENABLED``) with
    ``patch.object``. The original values are restored when the ``with``
    block exits, even if an assertion inside it fails, so no singleton
    state leaks into other tests.
    """

    def test_disabled_returns_none(self):
        """When RERANK_ENABLED=false, get_reranker returns None."""
        import compliance.services.reranker as mod

        with patch.object(mod, "RERANK_ENABLED", False), \
                patch.object(mod, "_reranker", None):
            result = mod.get_reranker()
        assert result is None

    def test_enabled_returns_reranker(self):
        """When RERANK_ENABLED=true, get_reranker returns a Reranker instance."""
        import compliance.services.reranker as mod

        with patch.object(mod, "RERANK_ENABLED", True), \
                patch.object(mod, "_reranker", None):
            result = mod.get_reranker()
            assert isinstance(result, Reranker)

    def test_singleton_returns_same_instance(self):
        """get_reranker returns the same instance on repeated calls."""
        import compliance.services.reranker as mod

        with patch.object(mod, "RERANK_ENABLED", True), \
                patch.object(mod, "_reranker", None):
            first = mod.get_reranker()
            second = mod.get_reranker()
            assert first is second
# =============================================================================
# search_with_rerank Integration Tests
# =============================================================================
class TestSearchWithRerank:
    """Tests for ComplianceRAGClient.search_with_rerank.

    ``client.search`` is replaced with an ``AsyncMock``, so no HTTP request
    is made. ``get_reranker`` is patched at
    ``compliance.services.reranker.get_reranker``.

    NOTE(review): that patch only takes effect if search_with_rerank looks
    the function up at call time (e.g. via a function-level import) — confirm
    against rag_client.
    """

    def _make_result(self, text: str, score: float) -> RAGSearchResult:
        """Build a minimal DSGVO search result with the given text and score."""
        return RAGSearchResult(
            text=text, regulation_code="eu_2016_679",
            regulation_name="DSGVO", regulation_short="DSGVO",
            category="regulation", article="", paragraph="",
            source_url="", score=score,
        )

    @pytest.mark.asyncio
    async def test_rerank_disabled_falls_through(self):
        """When reranker is disabled, search_with_rerank calls regular search."""
        client = ComplianceRAGClient(base_url="http://fake")
        results = [self._make_result("text1", 0.9)]
        with patch.object(client, "search", new_callable=AsyncMock, return_value=results):
            with patch("compliance.services.reranker.get_reranker", return_value=None):
                got = await client.search_with_rerank("query", top_k=5)
        assert len(got) == 1
        assert got[0].text == "text1"

    @pytest.mark.asyncio
    async def test_rerank_reorders_results(self):
        """When reranker is enabled, results are re-ordered."""
        client = ComplianceRAGClient(base_url="http://fake")
        # The vector scores deliberately disagree with the rerank order below.
        candidates = [
            self._make_result("low relevance", 0.9),
            self._make_result("high relevance", 0.7),
            self._make_result("medium relevance", 0.8),
        ]
        mock_reranker = MagicMock()
        # The mock ignores top_k and always returns the full ranking:
        # index 1 is best, then 2, then 0.
        mock_reranker.rerank.return_value = [1, 2, 0]
        with patch.object(client, "search", new_callable=AsyncMock, return_value=candidates):
            with patch("compliance.services.reranker.get_reranker", return_value=mock_reranker):
                got = await client.search_with_rerank("query", top_k=2)
        mock_reranker.rerank.assert_called_once()
        # Output must follow the reranker's index order, not the vector scores:
        # candidates[1] first, then candidates[2].
        assert got[0].text == "high relevance"
        assert got[1].text == "medium relevance"

    @pytest.mark.asyncio
    async def test_rerank_failure_returns_unranked(self):
        """If reranker fails, fall back to unranked results."""
        client = ComplianceRAGClient(base_url="http://fake")
        # Distinct texts (same score) let us check *which* candidates survive
        # the fallback, not just how many.
        candidates = [self._make_result(f"text{i}", 0.9) for i in range(5)]
        mock_reranker = MagicMock()
        mock_reranker.rerank.side_effect = RuntimeError("model error")
        with patch.object(client, "search", new_callable=AsyncMock, return_value=candidates):
            with patch("compliance.services.reranker.get_reranker", return_value=mock_reranker):
                got = await client.search_with_rerank("query", top_k=3)
        # The failure path must actually have been exercised...
        assert mock_reranker.rerank.called
        # ...and it falls back to the first top_k candidates in original order.
        assert [r.text for r in got] == ["text0", "text1", "text2"]