"""Tests for the rationale backfill endpoint logic.

``backfill_rationale`` is exercised with ``SessionLocal`` and the LLM
provider factory patched out, so no real database or model is contacted.
"""

import sys
from contextlib import contextmanager
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

# Make the current working directory (presumably the repository root)
# importable so that the ``compliance`` package below resolves.
sys.path.insert(0, ".")

from compliance.api.canonical_control_routes import backfill_rationale  # noqa: E402

_SESSION_LOCAL = "compliance.api.canonical_control_routes.SessionLocal"
_GET_LLM_PROVIDER = "compliance.services.llm_provider.get_llm_provider"


@contextmanager
def _mock_db(parents, *, rowcount=None):
    """Patch ``SessionLocal`` so the endpoint's session yields a mock ``db``.

    Every ``db.execute(...)`` call returns the *same* result mock, so
    ``fetchall()`` yields *parents* for any query and ``rowcount`` (only set
    when given) is what any UPDATE statement reports.

    Yields:
        The mock ``db`` object handed out by ``SessionLocal().__enter__``.
    """
    with patch(_SESSION_LOCAL) as mock_session:
        db = MagicMock()
        mock_session.return_value.__enter__ = MagicMock(return_value=db)
        mock_session.return_value.__exit__ = MagicMock(return_value=False)
        db.execute.return_value.fetchall.return_value = parents
        if rowcount is not None:
            db.execute.return_value.rowcount = rowcount
        yield db


@contextmanager
def _mock_llm(*, content=None, error=None):
    """Patch ``get_llm_provider`` to return an ``AsyncMock`` provider.

    Args:
        content: Text placed on the response object returned by the awaited
            ``provider.complete(...)`` call (used when *error* is ``None``).
        error: Exception raised by ``provider.complete(...)`` instead of
            returning a response.

    Yields:
        The mock provider, so tests can inspect ``provider.complete`` calls.
    """
    with patch(_GET_LLM_PROVIDER) as mock_get:
        provider = AsyncMock()
        if error is not None:
            provider.complete.side_effect = error
        else:
            response = MagicMock()
            response.content = content
            provider.complete.return_value = response
        mock_get.return_value = provider
        yield provider


class TestRationaleBackfillDryRun:
    """Dry-run mode should return statistics without modifying the DB."""

    @pytest.mark.asyncio
    async def test_dry_run_returns_stats(self):
        parents = [
            MagicMock(
                parent_uuid="uuid-1",
                control_id="ACC-001",
                title="Access Control",
                category="access",
                source_name="OWASP ASVS",
                child_count=12,
            ),
            MagicMock(
                parent_uuid="uuid-2",
                control_id="SEC-042",
                title="Encryption",
                category="encryption",
                source_name="NIST SP 800-53",
                child_count=5,
            ),
        ]

        with _mock_db(parents):
            result = await backfill_rationale(dry_run=True, batch_size=50, offset=0)

        assert result["dry_run"] is True
        assert result["total_parents"] == 2
        assert result["total_children"] == 17  # 12 + 5
        assert result["estimated_llm_calls"] == 2  # one call per parent
        assert len(result["sample_parents"]) == 2
        assert result["sample_parents"][0]["control_id"] == "ACC-001"


class TestRationaleBackfillExecution:
    """Execution mode should call LLM and update DB."""

    @pytest.mark.asyncio
    async def test_processes_batch_and_updates(self):
        parents = [
            MagicMock(
                parent_uuid="uuid-1",
                control_id="ACC-001",
                title="Access Control",
                category="access",
                source_name="OWASP ASVS",
                child_count=5,
            ),
        ]
        content = (
            "Die uebergeordneten Anforderungen an Zugriffskontrolle aus "
            "OWASP ASVS erfordern eine Zerlegung in atomare Massnahmen, "
            "um jede Einzelmassnahme unabhaengig testbar zu machen."
        )

        with _mock_db(parents, rowcount=5) as db, _mock_llm(content=content) as provider:
            result = await backfill_rationale(
                dry_run=False,
                batch_size=50,
                offset=0,
            )

        assert result["dry_run"] is False
        assert result["processed_parents"] == 1
        assert len(result["errors"]) == 0
        assert len(result["sample_rationales"]) == 1
        # The rationale must actually come from the LLM ...
        provider.complete.assert_awaited()
        # ... and be written back: at least the SELECT plus one UPDATE.
        assert db.execute.call_count >= 2

    @pytest.mark.asyncio
    async def test_empty_batch_returns_done(self):
        # No LLM patch needed: an empty batch must return before any call.
        with _mock_db([]):
            result = await backfill_rationale(
                dry_run=False,
                batch_size=50,
                offset=9999,
            )

        assert result["processed"] == 0
        assert "Kein weiterer Batch" in result["message"]

    @pytest.mark.asyncio
    async def test_llm_error_captured(self):
        parents = [
            MagicMock(
                parent_uuid="uuid-1",
                control_id="SEC-100",
                title="Network Security",
                category="network",
                source_name="ISO 27001",
                child_count=3,
            ),
        ]

        with _mock_db(parents), _mock_llm(error=Exception("Ollama timeout")):
            result = await backfill_rationale(
                dry_run=False,
                batch_size=50,
                offset=0,
            )

        assert result["processed_parents"] == 0
        assert len(result["errors"]) == 1
        assert "Ollama timeout" in result["errors"][0]["error"]

    @pytest.mark.asyncio
    async def test_short_response_skipped(self):
        parents = [
            MagicMock(
                parent_uuid="uuid-1",
                control_id="GOV-001",
                title="Governance",
                category="governance",
                source_name="ISO 27001",
                child_count=2,
            ),
        ]

        # "OK" is deliberately too short to be accepted as a rationale.
        with _mock_db(parents), _mock_llm(content="OK"):
            result = await backfill_rationale(
                dry_run=False,
                batch_size=50,
                offset=0,
            )

        assert result["processed_parents"] == 0
        assert len(result["errors"]) == 1
        assert "zu kurz" in result["errors"][0]["error"]


class TestRationalePagination:
    """Pagination logic should work correctly."""

    @pytest.mark.asyncio
    async def test_next_offset_set_when_more_remain(self):
        # 3 parents, batch_size=2 → next_offset=2
        parents = [
            MagicMock(
                parent_uuid=f"uuid-{i}",
                control_id=f"SEC-{i:03d}",
                title=f"Control {i}",
                category="security",
                source_name="NIST",
                child_count=2,
            )
            for i in range(3)
        ]
        content = (
            "Sicherheitsanforderungen aus NIST erfordern atomare "
            "Massnahmen fuer unabhaengige Testbarkeit."
        )

        with _mock_db(parents, rowcount=2), _mock_llm(content=content):
            result = await backfill_rationale(
                dry_run=False,
                batch_size=2,
                offset=0,
            )

        assert result["next_offset"] == 2
        assert result["processed_parents"] == 2

    @pytest.mark.asyncio
    async def test_next_offset_none_when_done(self):
        parents = [
            MagicMock(
                parent_uuid="uuid-1",
                control_id="SEC-001",
                title="Control 1",
                category="security",
                source_name="NIST",
                child_count=2,
            ),
        ]
        content = "Sicherheitsanforderungen erfordern atomare Massnahmen."

        with _mock_db(parents, rowcount=2), _mock_llm(content=content):
            result = await backfill_rationale(
                dry_run=False,
                batch_size=50,
                offset=0,
            )

        assert result["next_offset"] is None