"""Tests for Compliance Scope routes (compliance_scope_routes.py).""" import json import pytest from unittest.mock import MagicMock, patch, call # --------------------------------------------------------------------------- # Helpers / shared fixtures # --------------------------------------------------------------------------- def _make_db_row(tenant_id, scope, created_at="2026-01-01 10:00:00", updated_at="2026-01-01 12:00:00"): """Return a mock DB row tuple for sdk_states queries.""" row = MagicMock() row.__getitem__ = lambda self, i: [tenant_id, scope, created_at, updated_at][i] row[0] = tenant_id row[1] = scope row[2] = created_at row[3] = updated_at return row def _make_row_indexable(tenant_id, scope, created_at="2026-01-01 10:00:00", updated_at="2026-01-01 12:00:00"): """Simple list-based row.""" return [tenant_id, scope, created_at, updated_at] # --------------------------------------------------------------------------- # Unit tests: _get_tid helper # --------------------------------------------------------------------------- class TestGetTid: """Tests for the _get_tid helper function.""" def test_prefers_x_tenant_header(self): from compliance.api.compliance_scope_routes import _get_tid assert _get_tid("header-val", "query-val") == "header-val" def test_falls_back_to_query(self): from compliance.api.compliance_scope_routes import _get_tid assert _get_tid(None, "query-val") == "query-val" def test_falls_back_to_default(self): from compliance.api.compliance_scope_routes import _get_tid assert _get_tid(None, None) == "default" def test_empty_string_as_falsy(self): from compliance.api.compliance_scope_routes import _get_tid assert _get_tid(None, "") == "default" # --------------------------------------------------------------------------- # Unit tests: _row_to_response helper # --------------------------------------------------------------------------- class TestRowToResponse: """Tests for the _row_to_response mapping function.""" def test_maps_correctly(self): from compliance.api.compliance_scope_routes import _row_to_response scope = {"frameworks": ["DSGVO"], "industry": "healthcare"} row = ["tenant-abc", scope, "2026-01-01 10:00:00", "2026-01-02 10:00:00"] result = _row_to_response(row) assert result.tenant_id == "tenant-abc" assert result.scope == scope assert "2026-01-01" in result.created_at assert "2026-01-02" in result.updated_at def test_handles_non_dict_scope(self): from compliance.api.compliance_scope_routes import _row_to_response row = ["t1", None, "2026-01-01", "2026-01-01"] result = _row_to_response(row) assert result.scope == {} def test_handles_empty_scope(self): from compliance.api.compliance_scope_routes import _row_to_response row = ["t1", {}, "2026-01-01", "2026-01-01"] result = _row_to_response(row) assert result.scope == {} def test_scope_nested_objects(self): from compliance.api.compliance_scope_routes import _row_to_response scope = {"frameworks": ["DSGVO", "NIS2"], "nested": {"key": "value"}} row = ["t2", scope, "2026-01-01", "2026-01-01"] result = _row_to_response(row) assert result.scope["frameworks"] == ["DSGVO", "NIS2"] # --------------------------------------------------------------------------- # Integration-style tests: GET endpoint # --------------------------------------------------------------------------- class TestGetComplianceScope: """Tests for GET /v1/compliance-scope.""" @patch("compliance.api.compliance_scope_routes.SessionLocal") def test_returns_scope_when_found(self, mock_session_cls): from compliance.api.compliance_scope_routes import get_compliance_scope import asyncio scope = {"frameworks": ["DSGVO"], "industry": "it_services"} mock_db = MagicMock() mock_db.execute.return_value.fetchone.return_value = [ "tenant-1", scope, "2026-01-01 10:00:00", "2026-01-01 12:00:00" ] mock_session_cls.return_value = mock_db result = asyncio.get_event_loop().run_until_complete( get_compliance_scope(tenant_id="tenant-1") ) assert result.tenant_id == "tenant-1" assert result.scope == scope @patch("compliance.api.compliance_scope_routes.SessionLocal") def test_raises_404_when_not_found(self, mock_session_cls): from compliance.api.compliance_scope_routes import get_compliance_scope from fastapi import HTTPException import asyncio mock_db = MagicMock() mock_db.execute.return_value.fetchone.return_value = None mock_session_cls.return_value = mock_db with pytest.raises(HTTPException) as exc_info: asyncio.get_event_loop().run_until_complete( get_compliance_scope(tenant_id="unknown-tenant") ) assert exc_info.value.status_code == 404 @patch("compliance.api.compliance_scope_routes.SessionLocal") def test_raises_404_when_scope_is_none(self, mock_session_cls): from compliance.api.compliance_scope_routes import get_compliance_scope from fastapi import HTTPException import asyncio mock_db = MagicMock() mock_db.execute.return_value.fetchone.return_value = ["tenant-1", None, "x", "x"] mock_session_cls.return_value = mock_db with pytest.raises(HTTPException) as exc_info: asyncio.get_event_loop().run_until_complete( get_compliance_scope(tenant_id="tenant-1") ) assert exc_info.value.status_code == 404 @patch("compliance.api.compliance_scope_routes.SessionLocal") def test_x_tenant_header_takes_precedence(self, mock_session_cls): from compliance.api.compliance_scope_routes import get_compliance_scope import asyncio scope = {"frameworks": ["ISO27001"]} mock_db = MagicMock() mock_db.execute.return_value.fetchone.return_value = [ "header-tenant", scope, "2026-01-01", "2026-01-01" ] mock_session_cls.return_value = mock_db result = asyncio.get_event_loop().run_until_complete( get_compliance_scope( tenant_id="query-tenant", x_tenant_id="header-tenant", ) ) # The query should use the header value call_args = mock_db.execute.call_args assert "header-tenant" in str(call_args) @patch("compliance.api.compliance_scope_routes.SessionLocal") def test_db_always_closed(self, mock_session_cls): from compliance.api.compliance_scope_routes import get_compliance_scope from fastapi import HTTPException import asyncio mock_db = MagicMock() mock_db.execute.return_value.fetchone.return_value = None mock_session_cls.return_value = mock_db try: asyncio.get_event_loop().run_until_complete( get_compliance_scope(tenant_id="t") ) except HTTPException: pass mock_db.close.assert_called_once() # --------------------------------------------------------------------------- # Integration-style tests: POST endpoint (UPSERT) # --------------------------------------------------------------------------- class TestUpsertComplianceScope: """Tests for POST /v1/compliance-scope.""" @patch("compliance.api.compliance_scope_routes.SessionLocal") def test_creates_new_scope(self, mock_session_cls): from compliance.api.compliance_scope_routes import upsert_compliance_scope, ComplianceScopeRequest import asyncio scope = {"frameworks": ["DSGVO", "NIS2"], "industry": "finance"} mock_db = MagicMock() mock_db.execute.return_value.fetchone.return_value = [ "tenant-1", scope, "2026-01-01", "2026-01-01" ] mock_session_cls.return_value = mock_db body = ComplianceScopeRequest(scope=scope, tenant_id="tenant-1") result = asyncio.get_event_loop().run_until_complete( upsert_compliance_scope(body=body) ) mock_db.execute.assert_called() mock_db.commit.assert_called_once() assert result.scope == scope @patch("compliance.api.compliance_scope_routes.SessionLocal") def test_updates_existing_scope(self, mock_session_cls): from compliance.api.compliance_scope_routes import upsert_compliance_scope, ComplianceScopeRequest import asyncio new_scope = {"frameworks": ["AI Act"], "industry": "healthcare"} mock_db = MagicMock() mock_db.execute.return_value.fetchone.return_value = [ "tenant-2", new_scope, "2026-01-01", "2026-02-01" ] mock_session_cls.return_value = mock_db body = ComplianceScopeRequest(scope=new_scope, tenant_id="tenant-2") result = asyncio.get_event_loop().run_until_complete( upsert_compliance_scope(body=body) ) assert result.scope == new_scope mock_db.commit.assert_called_once() @patch("compliance.api.compliance_scope_routes.SessionLocal") def test_empty_scope_is_accepted(self, mock_session_cls): from compliance.api.compliance_scope_routes import upsert_compliance_scope, ComplianceScopeRequest import asyncio mock_db = MagicMock() mock_db.execute.return_value.fetchone.return_value = [ "t", {}, "2026-01-01", "2026-01-01" ] mock_session_cls.return_value = mock_db body = ComplianceScopeRequest(scope={}) result = asyncio.get_event_loop().run_until_complete( upsert_compliance_scope(body=body) ) assert result.scope == {} @patch("compliance.api.compliance_scope_routes.SessionLocal") def test_raises_500_on_db_error(self, mock_session_cls): from compliance.api.compliance_scope_routes import upsert_compliance_scope, ComplianceScopeRequest from fastapi import HTTPException import asyncio mock_db = MagicMock() mock_db.execute.side_effect = Exception("DB connection error") mock_session_cls.return_value = mock_db body = ComplianceScopeRequest(scope={"frameworks": ["DSGVO"]}) with pytest.raises(HTTPException) as exc_info: asyncio.get_event_loop().run_until_complete( upsert_compliance_scope(body=body) ) assert exc_info.value.status_code == 500 mock_db.rollback.assert_called_once() @patch("compliance.api.compliance_scope_routes.SessionLocal") def test_rollback_called_on_error(self, mock_session_cls): from compliance.api.compliance_scope_routes import upsert_compliance_scope, ComplianceScopeRequest from fastapi import HTTPException import asyncio mock_db = MagicMock() mock_db.execute.side_effect = RuntimeError("unexpected") mock_session_cls.return_value = mock_db body = ComplianceScopeRequest(scope={}) try: asyncio.get_event_loop().run_until_complete( upsert_compliance_scope(body=body) ) except HTTPException: pass mock_db.rollback.assert_called_once() mock_db.close.assert_called_once() @patch("compliance.api.compliance_scope_routes.SessionLocal") def test_db_always_closed_on_success(self, mock_session_cls): from compliance.api.compliance_scope_routes import upsert_compliance_scope, ComplianceScopeRequest import asyncio mock_db = MagicMock() mock_db.execute.return_value.fetchone.return_value = [ "t", {"frameworks": []}, "x", "x" ] mock_session_cls.return_value = mock_db body = ComplianceScopeRequest(scope={"frameworks": []}) asyncio.get_event_loop().run_until_complete( upsert_compliance_scope(body=body) ) mock_db.close.assert_called_once() # --------------------------------------------------------------------------- # Schema / model validation tests # --------------------------------------------------------------------------- class TestComplianceScopeRequest: """Tests for the ComplianceScopeRequest Pydantic model.""" def test_valid_scope(self): from compliance.api.compliance_scope_routes import ComplianceScopeRequest r = ComplianceScopeRequest(scope={"frameworks": ["DSGVO"]}) assert r.scope == {"frameworks": ["DSGVO"]} def test_tenant_id_optional(self): from compliance.api.compliance_scope_routes import ComplianceScopeRequest r = ComplianceScopeRequest(scope={}) assert r.tenant_id is None def test_tenant_id_can_be_set(self): from compliance.api.compliance_scope_routes import ComplianceScopeRequest r = ComplianceScopeRequest(scope={}, tenant_id="abc-123") assert r.tenant_id == "abc-123" def test_complex_scope_accepted(self): from compliance.api.compliance_scope_routes import ComplianceScopeRequest scope = { "frameworks": ["DSGVO", "AI Act", "NIS2"], "industry": "healthcare", "company_size": "medium", "answers": {"q1": True, "q2": "B2B"}, } r = ComplianceScopeRequest(scope=scope) assert len(r.scope["frameworks"]) == 3 class TestComplianceScopeResponse: """Tests for the ComplianceScopeResponse Pydantic model.""" def test_valid_response(self): from compliance.api.compliance_scope_routes import ComplianceScopeResponse r = ComplianceScopeResponse( tenant_id="t1", scope={"frameworks": ["DSGVO"]}, updated_at="2026-01-01", created_at="2026-01-01", ) assert r.tenant_id == "t1" def test_empty_scope_response(self): from compliance.api.compliance_scope_routes import ComplianceScopeResponse r = ComplianceScopeResponse( tenant_id="t1", scope={}, updated_at="x", created_at="x", ) assert r.scope == {} # --------------------------------------------------------------------------- # Router config tests # --------------------------------------------------------------------------- class TestRouterConfig: """Tests for router prefix and tags.""" def test_router_prefix(self): from compliance.api.compliance_scope_routes import router assert router.prefix == "/v1/compliance-scope" def test_router_tags(self): from compliance.api.compliance_scope_routes import router assert "compliance-scope" in router.tags