"""Tests for Canonical Control Library routes (canonical_control_routes.py).

Includes:
- Model validation tests (FrameworkResponse, ControlResponse, etc.)
- _control_row conversion tests
- Server-side pagination, sorting, search, source filter tests
- /controls-count and /controls-meta endpoint tests
- Obligation deduplication endpoint tests
"""

import uuid
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from compliance.api.canonical_control_routes import (
    FrameworkResponse,
    ControlResponse,
    SimilarityCheckRequest,
    SimilarityCheckResponse,
    _control_row,
    router,
)

# ---------------------------------------------------------------------------
# TestClient setup for endpoint tests
# ---------------------------------------------------------------------------

_app = FastAPI()
_app.include_router(router, prefix="/api/compliance")
_client = TestClient(_app)

# Dotted path of the session factory that every endpoint test replaces.
_SESSION_LOCAL = "compliance.api.canonical_control_routes.SessionLocal"


# ---------------------------------------------------------------------------
# Shared mock helpers
# ---------------------------------------------------------------------------

def _mock_with_attrs(attrs):
    """Return a MagicMock with every key of *attrs* set as an attribute."""
    mock = MagicMock()
    for key, value in attrs.items():
        setattr(mock, key, value)
    return mock


def _mock_result(rows=None, scalar=None):
    """Return a mock result object for a single ``db.execute()`` call.

    ``fetchall()`` yields *rows* and ``scalar()`` yields *scalar*; an argument
    left as ``None`` keeps the MagicMock default for that method.
    """
    result = MagicMock()
    if rows is not None:
        result.fetchall.return_value = rows
    if scalar is not None:
        result.scalar.return_value = scalar
    return result


def _mock_db():
    """Return a mock session that works as a context manager yielding itself."""
    db = MagicMock()
    db.__enter__ = MagicMock(return_value=db)
    db.__exit__ = MagicMock(return_value=False)
    return db


def _session_returning(rows=None, scalar=None):
    """Create a mock SessionLocal that returns rows or scalar."""
    db = _mock_db()
    db.execute.return_value = _mock_result(rows=rows, scalar=scalar)
    return db


def _last_execute_args(mock_cls):
    """Return the positional args of the most recent ``execute()`` call."""
    return mock_cls.return_value.__enter__().execute.call_args[0]


def _executed_sql(mock_cls):
    """Return the SQL text of the most recent ``execute()`` call."""
    return str(_last_execute_args(mock_cls)[0].text)


def _executed_params(mock_cls):
    """Return the bind parameters of the most recent ``execute()`` call."""
    return _last_execute_args(mock_cls)[1]


def _make_obligation_entry(candidate_id, obligation_text, day):
    """Build a mock obligation row created on 2026-01-*day* (UTC)."""
    return MagicMock(
        id=uuid.uuid4(),
        candidate_id=candidate_id,
        obligation_text=obligation_text,
        release_state="composed",
        created_at=datetime(2026, 1, day, tzinfo=timezone.utc),
    )


# =============================================================================
# MODEL TESTS
# =============================================================================

class TestFrameworkResponse:
    """Tests for FrameworkResponse model."""

    def test_basic_creation(self):
        resp = FrameworkResponse(
            id="uuid-1",
            framework_id="bp_security_v1",
            name="BreakPilot Security Controls",
            version="1.0",
            release_state="draft",
            created_at="2026-03-12T00:00:00+00:00",
            updated_at="2026-03-12T00:00:00+00:00",
        )
        assert resp.framework_id == "bp_security_v1"
        assert resp.version == "1.0"

    def test_optional_fields(self):
        resp = FrameworkResponse(
            id="uuid-1",
            framework_id="test",
            name="Test",
            version="1.0",
            release_state="draft",
            created_at="2026-03-12T00:00:00+00:00",
            updated_at="2026-03-12T00:00:00+00:00",
        )
        assert resp.description is None
        assert resp.owner is None
        assert resp.policy_version is None


class TestControlResponse:
    """Tests for ControlResponse model."""

    def test_full_control(self):
        resp = ControlResponse(
            id="uuid-1",
            framework_id="uuid-fw",
            control_id="AUTH-001",
            title="Multi-Factor Authentication",
            objective="Require MFA for privileged access.",
            rationale="Passwords alone are insufficient.",
            scope={"platforms": ["web"]},
            requirements=["MFA for admin accounts"],
            test_procedure=["Test admin login without MFA"],
            evidence=[{"type": "config", "description": "MFA config"}],
            severity="high",
            open_anchors=[{"framework": "OWASP ASVS", "ref": "V2.8", "url": "https://owasp.org"}],
            release_state="draft",
            tags=["mfa", "auth"],
            created_at="2026-03-12T00:00:00+00:00",
            updated_at="2026-03-12T00:00:00+00:00",
        )
        assert resp.control_id == "AUTH-001"
        assert resp.severity == "high"
        assert len(resp.open_anchors) == 1

    def test_optional_numeric_fields(self):
        resp = ControlResponse(
            id="uuid-1",
            framework_id="uuid-fw",
            control_id="NET-001",
            title="TLS",
            objective="Encrypt traffic.",
            rationale="Prevent eavesdropping.",
            scope={},
            requirements=[],
            test_procedure=[],
            evidence=[],
            severity="high",
            open_anchors=[],
            release_state="draft",
            tags=[],
            created_at="2026-03-12T00:00:00+00:00",
            updated_at="2026-03-12T00:00:00+00:00",
        )
        assert resp.risk_score is None
        assert resp.implementation_effort is None
        assert resp.evidence_confidence is None


class TestSimilarityCheckRequest:
    """Tests for SimilarityCheckRequest model."""

    def test_valid_request(self):
        req = SimilarityCheckRequest(
            source_text="Die Anwendung muss MFA implementieren.",
            candidate_text="Multi-factor authentication is required.",
        )
        assert req.source_text == "Die Anwendung muss MFA implementieren."
        assert req.candidate_text == "Multi-factor authentication is required."

    def test_empty_strings(self):
        req = SimilarityCheckRequest(source_text="", candidate_text="")
        assert req.source_text == ""


class TestSimilarityCheckResponse:
    """Tests for SimilarityCheckResponse model."""

    def test_pass_status(self):
        resp = SimilarityCheckResponse(
            max_exact_run=2,
            token_overlap=0.05,
            ngram_jaccard=0.03,
            embedding_cosine=0.45,
            lcs_ratio=0.12,
            status="PASS",
            details={
                "max_exact_run": "PASS",
                "token_overlap": "PASS",
                "ngram_jaccard": "PASS",
                "embedding_cosine": "PASS",
                "lcs_ratio": "PASS",
            },
        )
        assert resp.status == "PASS"

    def test_fail_status(self):
        resp = SimilarityCheckResponse(
            max_exact_run=15,
            token_overlap=0.35,
            ngram_jaccard=0.20,
            embedding_cosine=0.95,
            lcs_ratio=0.55,
            status="FAIL",
            details={
                "max_exact_run": "FAIL",
                "token_overlap": "FAIL",
                "ngram_jaccard": "FAIL",
                "embedding_cosine": "FAIL",
                "lcs_ratio": "FAIL",
            },
        )
        assert resp.status == "FAIL"


class TestControlRowConversion:
    """Tests for _control_row helper."""

    def _make_row(self, **overrides):
        """Build a mock DB row for ``_control_row``; *overrides* replace defaults."""
        now = datetime.now(timezone.utc)
        defaults = {
            "id": "uuid-ctrl-1",
            "framework_id": "uuid-fw-1",
            "control_id": "AUTH-001",
            "title": "Multi-Factor Authentication",
            "objective": "Require MFA.",
            "rationale": "Passwords insufficient.",
            "scope": {"platforms": ["web", "mobile"]},
            "requirements": ["Req 1", "Req 2"],
            "test_procedure": ["Test 1"],
            "evidence": [{"type": "config", "description": "MFA config"}],
            "severity": "high",
            "risk_score": 8.5,
            "implementation_effort": "m",
            "evidence_confidence": 0.85,
            "open_anchors": [
                {"framework": "OWASP ASVS", "ref": "V2.8", "url": "https://owasp.org"},
            ],
            "release_state": "draft",
            "tags": ["mfa"],
            "generation_strategy": "ungrouped",
            "parent_control_uuid": None,
            "parent_control_id": None,
            "parent_control_title": None,
            "decomposition_method": None,
            "pipeline_version": None,
            "created_at": now,
            "updated_at": now,
        }
        defaults.update(overrides)
        return _mock_with_attrs(defaults)

    def test_basic_conversion(self):
        row = self._make_row()
        result = _control_row(row)
        assert result["control_id"] == "AUTH-001"
        assert result["severity"] == "high"
        assert result["risk_score"] == 8.5
        assert result["implementation_effort"] == "m"
        assert result["evidence_confidence"] == 0.85
        assert len(result["open_anchors"]) == 1

    def test_null_numeric_fields(self):
        row = self._make_row(risk_score=None, evidence_confidence=None, implementation_effort=None)
        result = _control_row(row)
        assert result["risk_score"] is None
        assert result["evidence_confidence"] is None
        assert result["implementation_effort"] is None

    def test_empty_tags(self):
        row = self._make_row(tags=None)
        result = _control_row(row)
        assert result["tags"] == []

    def test_empty_tags_list(self):
        row = self._make_row(tags=[])
        result = _control_row(row)
        assert result["tags"] == []

    def test_timestamp_format(self):
        now = datetime(2026, 3, 12, 10, 30, 0, tzinfo=timezone.utc)
        row = self._make_row(created_at=now, updated_at=now)
        result = _control_row(row)
        assert "2026-03-12" in result["created_at"]
        assert "10:30" in result["created_at"]

    def test_none_timestamps(self):
        row = self._make_row(created_at=None, updated_at=None)
        result = _control_row(row)
        assert result["created_at"] is None
        assert result["updated_at"] is None

    def test_generation_strategy_default(self):
        row = self._make_row()
        result = _control_row(row)
        assert result["generation_strategy"] == "ungrouped"

    def test_generation_strategy_document_grouped(self):
        row = self._make_row(generation_strategy="document_grouped")
        result = _control_row(row)
        assert result["generation_strategy"] == "document_grouped"


# =============================================================================
# ENDPOINT TESTS — Server-Side Pagination, Sort, Search, Source Filter
# =============================================================================

def _make_mock_row(**overrides):
    """Build a mock Row with all canonical_controls columns."""
    now = datetime.now(timezone.utc)
    defaults = {
        "id": "uuid-ctrl-1",
        "framework_id": "uuid-fw-1",
        "control_id": "AUTH-001",
        "title": "Test Control",
        "objective": "Test obj",
        "rationale": "Test rat",
        "scope": {},
        "requirements": ["Req 1"],
        "test_procedure": ["Test 1"],
        "evidence": [],
        "severity": "high",
        "risk_score": 3.0,
        "implementation_effort": "m",
        "evidence_confidence": None,
        "open_anchors": [],
        "release_state": "draft",
        "tags": [],
        "license_rule": 1,
        "source_original_text": None,
        "source_citation": None,
        "customer_visible": True,
        "verification_method": "automated",
        "category": "authentication",
        "target_audience": "developer",
        "generation_metadata": {},
        "generation_strategy": "ungrouped",
        "created_at": now,
        "updated_at": now,
    }
    defaults.update(overrides)
    return _mock_with_attrs(defaults)


class TestListControlsPagination:
    """GET /controls with limit/offset."""

    @patch(_SESSION_LOCAL)
    def test_limit_param_in_sql(self, mock_cls):
        mock_cls.return_value = _session_returning(rows=[_make_mock_row()])
        resp = _client.get("/api/compliance/v1/canonical/controls?limit=10&offset=20")
        assert resp.status_code == 200
        sql = _executed_sql(mock_cls)
        assert "LIMIT" in sql
        assert "OFFSET" in sql

    @patch(_SESSION_LOCAL)
    def test_no_limit_by_default(self, mock_cls):
        mock_cls.return_value = _session_returning(rows=[])
        resp = _client.get("/api/compliance/v1/canonical/controls")
        assert resp.status_code == 200
        sql = _executed_sql(mock_cls)
        assert "LIMIT" not in sql


class TestListControlsSorting:
    """GET /controls with sort/order."""

    @patch(_SESSION_LOCAL)
    def test_sort_created_at_desc(self, mock_cls):
        mock_cls.return_value = _session_returning(rows=[])
        resp = _client.get("/api/compliance/v1/canonical/controls?sort=created_at&order=desc")
        assert resp.status_code == 200
        sql = _executed_sql(mock_cls)
        assert "created_at DESC" in sql

    @patch(_SESSION_LOCAL)
    def test_default_sort_control_id_asc(self, mock_cls):
        mock_cls.return_value = _session_returning(rows=[])
        resp = _client.get("/api/compliance/v1/canonical/controls")
        assert resp.status_code == 200
        sql = _executed_sql(mock_cls)
        assert "control_id ASC" in sql

    @patch(_SESSION_LOCAL)
    def test_sql_injection_in_sort_blocked(self, mock_cls):
        mock_cls.return_value = _session_returning(rows=[])
        resp = _client.get("/api/compliance/v1/canonical/controls?sort=1;DROP+TABLE")
        assert resp.status_code == 200
        sql = _executed_sql(mock_cls)
        # The malicious sort value must not reach the SQL text; the query is
        # expected to still reference control_id.
        assert "DROP" not in sql
        assert "control_id" in sql

    @patch(_SESSION_LOCAL)
    def test_sort_by_source(self, mock_cls):
        mock_cls.return_value = _session_returning(rows=[])
        resp = _client.get("/api/compliance/v1/canonical/controls?sort=source&order=asc")
        assert resp.status_code == 200
        sql = _executed_sql(mock_cls)
        assert "source_citation" in sql
        assert "control_id ASC" in sql  # secondary sort within source group


class TestListControlsSearch:
    """GET /controls with search."""

    @patch(_SESSION_LOCAL)
    def test_search_uses_ilike(self, mock_cls):
        mock_cls.return_value = _session_returning(rows=[])
        resp = _client.get("/api/compliance/v1/canonical/controls?search=encryption")
        assert resp.status_code == 200
        sql = _executed_sql(mock_cls)
        assert "ILIKE" in sql
        params = _executed_params(mock_cls)
        assert params["q"] == "%encryption%"


class TestListControlsSourceFilter:
    """GET /controls with source filter."""

    @patch(_SESSION_LOCAL)
    def test_specific_source(self, mock_cls):
        mock_cls.return_value = _session_returning(rows=[])
        resp = _client.get("/api/compliance/v1/canonical/controls?source=DSGVO")
        assert resp.status_code == 200
        sql = _executed_sql(mock_cls)
        assert "source_citation" in sql
        params = _executed_params(mock_cls)
        assert params["src"] == "DSGVO"

    @patch(_SESSION_LOCAL)
    def test_no_source_filter(self, mock_cls):
        mock_cls.return_value = _session_returning(rows=[])
        resp = _client.get("/api/compliance/v1/canonical/controls?source=__none__")
        assert resp.status_code == 200
        sql = _executed_sql(mock_cls)
        assert "IS NULL" in sql


class TestControlsCount:
    """GET /controls-count."""

    @patch(_SESSION_LOCAL)
    def test_returns_total(self, mock_cls):
        mock_cls.return_value = _session_returning(scalar=42)
        resp = _client.get("/api/compliance/v1/canonical/controls-count")
        assert resp.status_code == 200
        assert resp.json() == {"total": 42}

    @patch(_SESSION_LOCAL)
    def test_with_filters(self, mock_cls):
        mock_cls.return_value = _session_returning(scalar=5)
        resp = _client.get("/api/compliance/v1/canonical/controls-count?severity=critical&search=mfa")
        assert resp.status_code == 200
        assert resp.json() == {"total": 5}
        sql = _executed_sql(mock_cls)
        assert "severity" in sql
        assert "ILIKE" in sql


class TestControlsMeta:
    """GET /controls-meta."""

    @patch(_SESSION_LOCAL)
    def test_returns_structure(self, mock_cls):
        # Faceted meta does many execute() calls — every call gets the same
        # default result (scalar 100, no rows).
        mock_cls.return_value = _session_returning(rows=[], scalar=100)

        resp = _client.get("/api/compliance/v1/canonical/controls-meta")
        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 100
        assert isinstance(data["domains"], list)
        assert isinstance(data["sources"], list)
        assert "type_counts" in data
        assert "severity_counts" in data
        assert "verification_method_counts" in data
        assert "category_counts" in data
        assert "evidence_type_counts" in data
        assert "release_state_counts" in data


class TestObligationDedup:
    """Tests for obligation deduplication endpoints."""

    @patch(_SESSION_LOCAL)
    def test_dedup_dry_run(self, mock_cls):
        db = _mock_db()
        mock_cls.return_value = db

        # Mock: 2 duplicate groups
        dup_groups = [
            MagicMock(candidate_id="OC-AUTH-001-01", cnt=3),
            MagicMock(candidate_id="OC-AUTH-001-02", cnt=2),
        ]
        # Entries for group 1
        group1_entries = [
            _make_obligation_entry("OC-AUTH-001-01", "Text A", 1),
            _make_obligation_entry("OC-AUTH-001-01", "Text B", 2),
            _make_obligation_entry("OC-AUTH-001-01", "Text C", 3),
        ]
        # Entries for group 2
        group2_entries = [
            _make_obligation_entry("OC-AUTH-001-02", "Text D", 1),
            _make_obligation_entry("OC-AUTH-001-02", "Text E", 2),
        ]

        # Side effects: 1) dup groups, 2) total count, 3) entries grp1, 4) entries grp2
        db.execute.side_effect = [
            _mock_result(rows=dup_groups),
            _mock_result(scalar=2),
            _mock_result(rows=group1_entries),
            _mock_result(rows=group2_entries),
        ]

        resp = _client.post("/api/compliance/v1/canonical/obligations/dedup?dry_run=true")
        assert resp.status_code == 200
        data = resp.json()
        assert data["dry_run"] is True
        assert data["stats"]["total_duplicate_groups"] == 2
        assert data["stats"]["kept"] == 2
        assert data["stats"]["marked_duplicate"] == 3  # 2 from grp1 + 1 from grp2
        # Dry run: no commit
        db.commit.assert_not_called()

    @patch(_SESSION_LOCAL)
    def test_dedup_stats(self, mock_cls):
        db = _mock_db()
        mock_cls.return_value = db

        state_rows = [
            MagicMock(release_state="composed", cnt=41217),
            MagicMock(release_state="duplicate", cnt=34829),
        ]
        # Side effects: total, by_state, dup_groups, removable
        db.execute.side_effect = [
            _mock_result(scalar=76046),
            _mock_result(rows=state_rows),
            _mock_result(scalar=0),
            _mock_result(scalar=0),
        ]

        resp = _client.get("/api/compliance/v1/canonical/obligations/dedup-stats")
        assert resp.status_code == 200
        data = resp.json()
        assert data["total_obligations"] == 76046
        assert data["by_state"]["composed"] == 41217
        assert data["by_state"]["duplicate"] == 34829
        assert data["pending_duplicate_groups"] == 0
        assert data["pending_removable_duplicates"] == 0