"""Tests for provenance and atomic-stats endpoints.

Covers:
- GET /v1/canonical/controls/{control_id}/provenance
- GET /v1/canonical/controls/atomic-stats
"""

import uuid
from datetime import datetime
from unittest.mock import MagicMock, patch

import pytest
from fastapi import HTTPException

from compliance.api.canonical_control_routes import (
    get_control_provenance,
    atomic_stats,
)

# Dotted path of the session factory used by the route module.
_SESSION_TARGET = "compliance.api.canonical_control_routes.SessionLocal"


# =============================================================================
# HELPERS
# =============================================================================


def _mock_row(**kwargs):
    """Create a mock DB row exposing each keyword argument as an attribute."""
    obj = MagicMock()
    for k, v in kwargs.items():
        setattr(obj, k, v)
    return obj


def _mock_tuple_row(*values):
    """Create a mock DB row supporting positional access (``row[0]``, ...).

    Built on MagicMock rather than a plain tuple so that attribute access on
    the row still yields a mock instead of raising ``AttributeError``.
    """
    row = MagicMock()
    # Magic-method mocks are invoked without ``self``; only the index arrives.
    row.__getitem__.side_effect = lambda index: values[index]
    return row


def _mock_db_execute(return_values):
    """Return a mock DB session whose ``.execute()`` calls yield results in order.

    Each element of *return_values* configures the result object of one
    sequential ``execute()`` call:

    - ``list``: ``fetchall()`` returns the list and ``fetchone()`` its first
      item (``None`` when the list is empty).
    - ``int`` / ``float``: ``scalar()`` returns the number.
    - ``None``: an empty result -- ``fetchone()`` is ``None``, ``fetchall()``
      is ``[]`` and ``scalar()`` is ``0``.
    - anything else: treated as a single row for ``fetchone()``/``fetchall()``.

    Raises:
        AssertionError: on an ``execute()`` call beyond the configured values.
    """
    mock_db = MagicMock()
    results = iter(return_values)

    def execute_side_effect(*args, **kwargs):
        try:
            result = next(results)
        except StopIteration:
            # A bare StopIteration escaping into an ``async def`` endpoint is
            # converted to an opaque "coroutine raised StopIteration"
            # RuntimeError; fail with a message naming the real problem.
            raise AssertionError(
                "db.execute() called more times than results were configured"
            ) from None

        mock_result = MagicMock()
        if isinstance(result, list):
            mock_result.fetchall.return_value = result
            mock_result.fetchone.return_value = result[0] if result else None
        elif isinstance(result, (int, float)):
            mock_result.scalar.return_value = result
        elif result is None:
            mock_result.fetchone.return_value = None
            mock_result.fetchall.return_value = []
            mock_result.scalar.return_value = 0
        else:
            mock_result.fetchone.return_value = result
            mock_result.fetchall.return_value = [result]
        return mock_result

    mock_db.execute.side_effect = execute_side_effect
    return mock_db


def _patch_session(mock_db):
    """Patch ``SessionLocal`` so entering ``SessionLocal()`` yields *mock_db*.

    ``__exit__`` returns ``False`` so exceptions raised inside the ``with``
    body (e.g. ``HTTPException``) propagate to the caller.
    """
    session_factory = MagicMock()
    session_factory.return_value.__enter__.return_value = mock_db
    session_factory.return_value.__exit__.return_value = False
    return patch(_SESSION_TARGET, session_factory)


# =============================================================================
# PROVENANCE ENDPOINT
# =============================================================================


class TestProvenanceEndpoint:
    """Tests for GET /controls/{control_id}/provenance."""

    @pytest.mark.asyncio
    async def test_provenance_not_found(self):
        """404 when control doesn't exist."""
        mock_db = _mock_db_execute([None])

        with _patch_session(mock_db):
            with pytest.raises(HTTPException) as exc_info:
                await get_control_provenance("NONEXISTENT-999")

        assert exc_info.value.status_code == 404

    @pytest.mark.asyncio
    async def test_provenance_atomic_control(self):
        """Atomic control returns document_references, parent_links, merged_duplicates."""
        ctrl_row = _mock_row(
            id=uuid.uuid4(),
            control_id="SEC-042",
            title="Test Atomic Control",
            parent_control_uuid=None,
            decomposition_method="pass0b",
            source_citation=None,
        )
        parent_link = _mock_row(
            parent_control_uuid=uuid.uuid4(),
            parent_control_id="DATA-005",
            parent_title="Parent Control",
            link_type="decomposition",
            confidence=0.95,
            source_regulation="DSGVO",
            source_article="Art. 32",
            parent_citation=None,
            obligation_text="Must encrypt",
            action="encrypt",
            object="personal data",
            normative_strength="must",
            obligation_candidate_id=None,
        )
        obligation_row = _mock_row(
            candidate_id="OBL-SEC-042-001",
            obligation_text="Test obligation",
            action="encrypt",
            object="data at rest",
            normative_strength="must",
            release_state="composed",
        )
        doc_ref = _mock_row(
            regulation_code="DSGVO",
            article="Art. 32",
            paragraph="Abs. 1 lit. a",
            extraction_method="llm_extracted",
            confidence=0.92,
        )
        merged = _mock_row(
            control_id="SEC-099",
            title="Encryption at rest (NIS2)",
            source_regulation="NIS2",
        )

        mock_db = _mock_db_execute([
            ctrl_row,          # control lookup
            [parent_link],     # parent_links
            [],                # children
            [obligation_row],  # obligations
            [doc_ref],         # document_references
            [merged],          # merged_duplicates
        ])

        with _patch_session(mock_db):
            result = await get_control_provenance("SEC-042")

        assert result["control_id"] == "SEC-042"
        assert result["is_atomic"] is True
        assert len(result["parent_links"]) == 1
        assert result["parent_links"][0]["parent_control_id"] == "DATA-005"
        assert result["obligation_count"] == 1
        assert len(result["document_references"]) == 1
        assert result["document_references"][0]["regulation_code"] == "DSGVO"
        assert len(result["merged_duplicates"]) == 1
        assert result["merged_duplicates"][0]["control_id"] == "SEC-099"

    @pytest.mark.asyncio
    async def test_provenance_rich_control(self):
        """Rich control returns obligations list and children."""
        ctrl_row = _mock_row(
            id=uuid.uuid4(),
            control_id="DATA-005",
            title="Rich Control",
            parent_control_uuid=None,
            decomposition_method=None,
            source_citation={"source": "DSGVO"},
        )
        obligation_row = _mock_row(
            candidate_id="OBL-DATA-005-001",
            obligation_text="Encrypt personal data",
            action="encrypt",
            object="personal data",
            normative_strength="must",
            release_state="composed",
        )
        child_row = _mock_row(
            control_id="SEC-042",
            title="Child Atomic",
            category="encryption",
            severity="high",
            decomposition_method="pass0b",
        )

        mock_db = _mock_db_execute([
            ctrl_row,          # control lookup
            [],                # parent_links
            [child_row],       # children
            [obligation_row],  # obligations
            [],                # document_references
            [],                # merged_duplicates
        ])

        with _patch_session(mock_db):
            result = await get_control_provenance("DATA-005")

        assert result["control_id"] == "DATA-005"
        assert result["is_atomic"] is False
        assert result["obligation_count"] == 1
        assert result["obligations"][0]["candidate_id"] == "OBL-DATA-005-001"
        assert len(result["children"]) == 1
        assert result["children"][0]["control_id"] == "SEC-042"


# =============================================================================
# ATOMIC STATS ENDPOINT
# =============================================================================


class TestAtomicStatsEndpoint:
    """Tests for GET /controls/atomic-stats."""

    @pytest.mark.asyncio
    async def test_atomic_stats_response_shape(self):
        """Stats endpoint returns expected aggregation fields."""
        mock_db = _mock_db_execute([
            18234,                              # total_active
            67000,                              # total_duplicate
            [_mock_tuple_row("SEC", 4200)],     # by_domain
            [_mock_tuple_row("DSGVO", 1200)],   # by_regulation
            2.3,                                # avg_coverage
        ])

        with _patch_session(mock_db):
            result = await atomic_stats()

        assert result["total_active"] == 18234
        assert result["total_duplicate"] == 67000
        assert len(result["by_domain"]) == 1
        assert result["by_domain"][0]["domain"] == "SEC"
        assert len(result["by_regulation"]) == 1
        assert result["by_regulation"][0]["regulation"] == "DSGVO"
        assert result["avg_regulation_coverage"] == 2.3