"""Tests for Cross-Encoder Re-Ranking module.""" import pytest from unittest.mock import MagicMock, patch, AsyncMock from compliance.services.reranker import Reranker, get_reranker, RERANK_ENABLED from compliance.services.rag_client import ComplianceRAGClient, RAGSearchResult # ============================================================================= # Reranker Unit Tests # ============================================================================= class TestReranker: """Tests for Reranker class.""" def test_rerank_empty_texts(self): """Empty texts list returns empty indices.""" reranker = Reranker() assert reranker.rerank("query", [], top_k=5) == [] def test_rerank_returns_correct_indices(self): """Reranker returns indices sorted by score descending.""" reranker = Reranker() # Mock the cross-encoder model mock_model = MagicMock() # Scores: text[0]=0.1, text[1]=0.9, text[2]=0.5 mock_model.predict.return_value = [0.1, 0.9, 0.5] reranker._model = mock_model indices = reranker.rerank("test query", ["low", "high", "mid"], top_k=3) assert indices == [1, 2, 0] # sorted by score desc def test_rerank_top_k_limits_results(self): """top_k limits the number of returned indices.""" reranker = Reranker() mock_model = MagicMock() mock_model.predict.return_value = [0.1, 0.9, 0.5, 0.7, 0.3] reranker._model = mock_model indices = reranker.rerank("query", ["a", "b", "c", "d", "e"], top_k=2) assert len(indices) == 2 assert indices[0] == 1 # highest score assert indices[1] == 3 # second highest def test_rerank_sends_pairs_to_model(self): """Model receives [[query, text], ...] pairs.""" reranker = Reranker() mock_model = MagicMock() mock_model.predict.return_value = [0.5, 0.8] reranker._model = mock_model reranker.rerank("my query", ["text A", "text B"], top_k=2) call_args = mock_model.predict.call_args[0][0] assert call_args == [["my query", "text A"], ["my query", "text B"]] def test_lazy_init_not_loaded_until_rerank(self): """Model should not be loaded at construction time.""" reranker = Reranker() assert reranker._model is None def test_ensure_model_skips_if_already_loaded(self): """_ensure_model should not reload when model is already set.""" reranker = Reranker() mock_model = MagicMock() reranker._model = mock_model # Call _ensure_model — should short-circuit since _model is set reranker._ensure_model() reranker._ensure_model() # Model should still be the same mock assert reranker._model is mock_model # ============================================================================= # get_reranker Tests # ============================================================================= class TestGetReranker: """Tests for the get_reranker factory.""" def test_disabled_returns_none(self): """When RERANK_ENABLED=false, get_reranker returns None.""" with patch("compliance.services.reranker.RERANK_ENABLED", False): # Reset singleton import compliance.services.reranker as mod mod._reranker = None result = mod.get_reranker() assert result is None def test_enabled_returns_reranker(self): """When RERANK_ENABLED=true, get_reranker returns a Reranker instance.""" import compliance.services.reranker as mod mod._reranker = None with patch.object(mod, "RERANK_ENABLED", True): result = mod.get_reranker() assert isinstance(result, Reranker) mod._reranker = None # cleanup def test_singleton_returns_same_instance(self): """get_reranker returns the same instance on repeated calls.""" import compliance.services.reranker as mod mod._reranker = None with patch.object(mod, "RERANK_ENABLED", True): r1 = mod.get_reranker() r2 = mod.get_reranker() assert r1 is r2 mod._reranker = None # cleanup # ============================================================================= # search_with_rerank Integration Tests # ============================================================================= class TestSearchWithRerank: """Tests for ComplianceRAGClient.search_with_rerank.""" def _make_result(self, text: str, score: float) -> RAGSearchResult: return RAGSearchResult( text=text, regulation_code="eu_2016_679", regulation_name="DSGVO", regulation_short="DSGVO", category="regulation", article="", paragraph="", source_url="", score=score, ) @pytest.mark.asyncio async def test_rerank_disabled_falls_through(self): """When reranker is disabled, search_with_rerank calls regular search.""" client = ComplianceRAGClient(base_url="http://fake") results = [self._make_result("text1", 0.9)] with patch.object(client, "search", new_callable=AsyncMock, return_value=results): with patch("compliance.services.reranker.get_reranker", return_value=None): got = await client.search_with_rerank("query", top_k=5) assert len(got) == 1 assert got[0].text == "text1" @pytest.mark.asyncio async def test_rerank_reorders_results(self): """When reranker is enabled, results are re-ordered.""" client = ComplianceRAGClient(base_url="http://fake") candidates = [ self._make_result("low relevance", 0.9), self._make_result("high relevance", 0.7), self._make_result("medium relevance", 0.8), ] mock_reranker = MagicMock() # Reranker says index 1 is best, then 2, then 0 mock_reranker.rerank.return_value = [1, 2, 0] with patch.object(client, "search", new_callable=AsyncMock, return_value=candidates): with patch("compliance.services.reranker.get_reranker", return_value=mock_reranker): got = await client.search_with_rerank("query", top_k=2) # Should get reranked top 2 (but our mock returns [1,2,0] and top_k=2 # means reranker.rerank is called with top_k=2, which returns [1, 2]) mock_reranker.rerank.assert_called_once() # The rerank mock returns [1,2,0], so we get candidates[1] and candidates[2] assert got[0].text == "high relevance" assert got[1].text == "medium relevance" @pytest.mark.asyncio async def test_rerank_failure_returns_unranked(self): """If reranker fails, fall back to unranked results.""" client = ComplianceRAGClient(base_url="http://fake") candidates = [self._make_result("text", 0.9)] * 5 mock_reranker = MagicMock() mock_reranker.rerank.side_effect = RuntimeError("model error") with patch.object(client, "search", new_callable=AsyncMock, return_value=candidates): with patch("compliance.services.reranker.get_reranker", return_value=mock_reranker): got = await client.search_with_rerank("query", top_k=3) assert len(got) == 3 # falls back to first top_k