"""Tests for TOM routes (tom_routes.py, tom_models.py).""" import pytest import uuid from unittest.mock import MagicMock, patch, PropertyMock from datetime import datetime, timezone from fastapi.testclient import TestClient from fastapi import FastAPI from compliance.api.tom_routes import ( router, TOMStateBody, TOMMeasureCreate, TOMMeasureUpdate, TOMMeasureBulkBody, TOMMeasureBulkItem, _parse_dt, _measure_to_dict, DEFAULT_TENANT_ID, ) from compliance.db.tom_models import TOMStateDB, TOMMeasureDB from compliance.api.schemas import TOMStatsResponse, TOMMeasureResponse # ============================================================================= # Test App Setup # ============================================================================= app = FastAPI() app.include_router(router) DEFAULT_TENANT = DEFAULT_TENANT_ID MEASURE_ID = "ffffffff-0001-0001-0001-000000000001" UNKNOWN_ID = "aaaaaaaa-9999-9999-9999-999999999999" # ============================================================================= # Helper: create mock DB session # ============================================================================= def _make_mock_db(): db = MagicMock() db.query.return_value = db db.filter.return_value = db db.first.return_value = None db.count.return_value = 0 db.all.return_value = [] db.offset.return_value = db db.limit.return_value = db db.order_by.return_value = db db.group_by.return_value = db return db def _make_state_row(tenant_id=DEFAULT_TENANT, version=1, state=None): row = TOMStateDB() row.id = uuid.uuid4() row.tenant_id = tenant_id row.state = state or {"steps": [], "derivedTOMs": []} row.version = version row.created_at = datetime(2024, 1, 1, tzinfo=timezone.utc) row.updated_at = datetime(2024, 1, 2, tzinfo=timezone.utc) return row def _make_measure_row(control_id="TOM.GOV.01", **kwargs): m = TOMMeasureDB() m.id = uuid.UUID(kwargs.get("id", MEASURE_ID)) m.tenant_id = kwargs.get("tenant_id", DEFAULT_TENANT) m.control_id = control_id m.name = kwargs.get("name", "Datenschutzrichtlinie") m.description = kwargs.get("description", "Beschreibung") m.category = kwargs.get("category", "GOVERNANCE") m.type = kwargs.get("type", "ORGANIZATIONAL") m.applicability = kwargs.get("applicability", "REQUIRED") m.applicability_reason = kwargs.get("applicability_reason", None) m.implementation_status = kwargs.get("implementation_status", "NOT_IMPLEMENTED") m.responsible_person = kwargs.get("responsible_person", None) m.responsible_department = kwargs.get("responsible_department", None) m.implementation_date = kwargs.get("implementation_date", None) m.review_date = kwargs.get("review_date", None) m.review_frequency = kwargs.get("review_frequency", "ANNUAL") m.priority = kwargs.get("priority", "HIGH") m.complexity = kwargs.get("complexity", "MEDIUM") m.linked_evidence = kwargs.get("linked_evidence", []) m.evidence_gaps = kwargs.get("evidence_gaps", []) m.related_controls = kwargs.get("related_controls", {}) m.verified_at = kwargs.get("verified_at", None) m.verified_by = kwargs.get("verified_by", None) m.effectiveness_rating = kwargs.get("effectiveness_rating", None) m.created_by = kwargs.get("created_by", "system") m.created_at = kwargs.get("created_at", datetime(2024, 1, 1, tzinfo=timezone.utc)) m.updated_at = kwargs.get("updated_at", datetime(2024, 1, 2, tzinfo=timezone.utc)) return m # ============================================================================= # Schema Tests # ============================================================================= class TestTOMStateBody: def test_get_tenant_id_from_tenant_id(self): body = TOMStateBody(tenant_id="abc", state={}) assert body.get_tenant_id() == "abc" def test_get_tenant_id_from_camelcase(self): body = TOMStateBody(tenantId="def", state={}) assert body.get_tenant_id() == "def" def test_get_tenant_id_default(self): body = TOMStateBody(state={}) assert body.get_tenant_id() == DEFAULT_TENANT def test_version_optional(self): body = TOMStateBody(tenant_id="x", state={"foo": "bar"}) assert body.version is None class TestTOMMeasureCreate: def test_defaults(self): mc = TOMMeasureCreate( control_id="TOM.GOV.01", name="Test", category="GOVERNANCE", type="ORGANIZATIONAL", ) assert mc.applicability == "REQUIRED" assert mc.implementation_status == "NOT_IMPLEMENTED" assert mc.priority is None assert mc.linked_evidence is None def test_full_values(self): mc = TOMMeasureCreate( control_id="TOM.ACC.02", name="Zugriffskontrolle", description="RBAC implementieren", category="ACCESS_CONTROL", type="TECHNICAL", applicability="REQUIRED", implementation_status="IMPLEMENTED", priority="CRITICAL", complexity="HIGH", ) assert mc.control_id == "TOM.ACC.02" assert mc.priority == "CRITICAL" class TestTOMMeasureUpdate: def test_partial(self): mu = TOMMeasureUpdate(implementation_status="IMPLEMENTED") data = mu.model_dump(exclude_unset=True) assert data == {"implementation_status": "IMPLEMENTED"} def test_empty(self): mu = TOMMeasureUpdate() data = mu.model_dump(exclude_unset=True) assert data == {} class TestTOMStatsResponse: def test_defaults(self): stats = TOMStatsResponse() assert stats.total == 0 assert stats.by_status == {} assert stats.overdue_review_count == 0 def test_full(self): stats = TOMStatsResponse( total=10, by_status={"IMPLEMENTED": 5, "NOT_IMPLEMENTED": 3, "PARTIAL": 2}, by_category={"GOVERNANCE": 4, "ACCESS_CONTROL": 6}, overdue_review_count=2, implemented=5, partial=2, not_implemented=3, ) assert stats.total == 10 assert stats.implemented == 5 class TestTOMMeasureResponse: def test_from_dict(self): resp = TOMMeasureResponse( id="abc", tenant_id=DEFAULT_TENANT, control_id="TOM.GOV.01", name="Test", category="GOVERNANCE", type="ORGANIZATIONAL", ) assert resp.id == "abc" assert resp.linked_evidence == [] # ============================================================================= # DB Model Tests # ============================================================================= class TestTOMModels: def test_state_repr(self): s = TOMStateDB() s.tenant_id = "test" s.version = 3 assert "test" in repr(s) assert "v3" in repr(s) def test_measure_repr(self): m = TOMMeasureDB() m.control_id = "TOM.ACC.01" m.name = "Zugriffskontrolle" assert "TOM.ACC.01" in repr(m) # ============================================================================= # Helper Function Tests # ============================================================================= class TestParseDt: def test_none(self): assert _parse_dt(None) is None def test_empty_string(self): assert _parse_dt("") is None def test_iso_format(self): dt = _parse_dt("2024-01-15T10:30:00+00:00") assert dt is not None assert dt.year == 2024 assert dt.month == 1 def test_iso_with_z(self): dt = _parse_dt("2024-06-15T12:00:00Z") assert dt is not None assert dt.year == 2024 def test_invalid_string(self): assert _parse_dt("not-a-date") is None class TestMeasureToDict: def test_full_conversion(self): m = _make_measure_row() d = _measure_to_dict(m) assert d["id"] == MEASURE_ID assert d["control_id"] == "TOM.GOV.01" assert d["name"] == "Datenschutzrichtlinie" assert d["category"] == "GOVERNANCE" assert d["type"] == "ORGANIZATIONAL" assert d["linked_evidence"] == [] assert d["related_controls"] == {} assert d["created_at"] is not None def test_with_dates(self): m = _make_measure_row( implementation_date=datetime(2024, 3, 1, tzinfo=timezone.utc), review_date=datetime(2025, 3, 1, tzinfo=timezone.utc), ) d = _measure_to_dict(m) assert "2024-03-01" in d["implementation_date"] assert "2025-03-01" in d["review_date"] def test_null_dates(self): m = _make_measure_row() d = _measure_to_dict(m) assert d["implementation_date"] is None assert d["review_date"] is None assert d["verified_at"] is None # ============================================================================= # Route Tests (with mocked DB) # ============================================================================= from classroom_engine.database import get_db def override_get_db(mock_db): def _override(): return mock_db return _override class TestStateRoutes: def test_get_state_new_tenant(self): db = _make_mock_db() db.filter.return_value.first.return_value = None app.dependency_overrides[get_db] = override_get_db(db) client = TestClient(app) resp = client.get("/tom/state?tenant_id=new-tenant") assert resp.status_code == 200 data = resp.json() assert data["success"] is True assert data["data"]["isNew"] is True assert data["data"]["version"] == 0 app.dependency_overrides.clear() def test_get_state_existing(self): db = _make_mock_db() row = _make_state_row(state={"steps": [1, 2, 3]}) db.filter.return_value.first.return_value = row app.dependency_overrides[get_db] = override_get_db(db) client = TestClient(app) resp = client.get(f"/tom/state?tenant_id={DEFAULT_TENANT}") assert resp.status_code == 200 data = resp.json() assert data["success"] is True assert data["data"]["version"] == 1 assert data["data"]["state"]["steps"] == [1, 2, 3] app.dependency_overrides.clear() def test_post_state_new(self): db = _make_mock_db() db.filter.return_value.first.return_value = None def mock_refresh(obj): obj.version = 1 obj.updated_at = datetime(2024, 1, 1, tzinfo=timezone.utc) obj.state = {"test": True} db.refresh = mock_refresh app.dependency_overrides[get_db] = override_get_db(db) client = TestClient(app) resp = client.post("/tom/state", json={ "tenant_id": DEFAULT_TENANT, "state": {"test": True}, }) assert resp.status_code == 200 data = resp.json() assert data["success"] is True db.add.assert_called_once() app.dependency_overrides.clear() def test_post_state_version_conflict(self): db = _make_mock_db() row = _make_state_row(version=5) db.filter.return_value.first.return_value = row app.dependency_overrides[get_db] = override_get_db(db) client = TestClient(app) resp = client.post("/tom/state", json={ "tenant_id": DEFAULT_TENANT, "state": {"test": True}, "version": 3, # Expected 3, actual 5 }) assert resp.status_code == 409 app.dependency_overrides.clear() def test_post_state_update_existing(self): db = _make_mock_db() row = _make_state_row(version=2) db.filter.return_value.first.return_value = row def mock_refresh(obj): pass # row already has attributes db.refresh = mock_refresh app.dependency_overrides[get_db] = override_get_db(db) client = TestClient(app) resp = client.post("/tom/state", json={ "tenant_id": DEFAULT_TENANT, "state": {"new": "data"}, "version": 2, }) assert resp.status_code == 200 assert row.version == 3 assert row.state == {"new": "data"} app.dependency_overrides.clear() def test_delete_state(self): db = _make_mock_db() row = _make_state_row() db.filter.return_value.first.return_value = row app.dependency_overrides[get_db] = override_get_db(db) client = TestClient(app) resp = client.delete(f"/tom/state?tenant_id={DEFAULT_TENANT}") assert resp.status_code == 200 data = resp.json() assert data["deleted"] is True db.delete.assert_called_once() app.dependency_overrides.clear() def test_delete_state_not_found(self): db = _make_mock_db() db.filter.return_value.first.return_value = None app.dependency_overrides[get_db] = override_get_db(db) client = TestClient(app) resp = client.delete(f"/tom/state?tenant_id={DEFAULT_TENANT}") assert resp.status_code == 200 data = resp.json() assert data["deleted"] is False app.dependency_overrides.clear() def test_delete_state_missing_tenant(self): db = _make_mock_db() app.dependency_overrides[get_db] = override_get_db(db) client = TestClient(app) resp = client.delete("/tom/state") assert resp.status_code == 400 app.dependency_overrides.clear() class TestMeasureRoutes: def test_list_measures_empty(self): db = _make_mock_db() db.filter.return_value.order_by.return_value.offset.return_value.limit.return_value.all.return_value = [] db.filter.return_value.count.return_value = 0 app.dependency_overrides[get_db] = override_get_db(db) client = TestClient(app) resp = client.get("/tom/measures") assert resp.status_code == 200 data = resp.json() assert data["measures"] == [] assert data["total"] == 0 app.dependency_overrides.clear() def test_list_measures_with_data(self): db = _make_mock_db() measures = [_make_measure_row("TOM.GOV.01"), _make_measure_row("TOM.ACC.01", name="Zugriff")] db.filter.return_value.order_by.return_value.offset.return_value.limit.return_value.all.return_value = measures db.filter.return_value.count.return_value = 2 app.dependency_overrides[get_db] = override_get_db(db) client = TestClient(app) resp = client.get("/tom/measures") assert resp.status_code == 200 data = resp.json() assert len(data["measures"]) == 2 assert data["total"] == 2 app.dependency_overrides.clear() def test_create_measure(self): db = _make_mock_db() # No existing measure with same control_id db.filter.return_value.first.return_value = None def mock_refresh(obj): obj.id = uuid.UUID(MEASURE_ID) obj.created_at = datetime(2024, 1, 1, tzinfo=timezone.utc) obj.updated_at = datetime(2024, 1, 1, tzinfo=timezone.utc) db.refresh = mock_refresh app.dependency_overrides[get_db] = override_get_db(db) client = TestClient(app) resp = client.post("/tom/measures", json={ "control_id": "TOM.GOV.01", "name": "Datenschutzrichtlinie", "category": "GOVERNANCE", "type": "ORGANIZATIONAL", }) assert resp.status_code == 201 data = resp.json() assert data["control_id"] == "TOM.GOV.01" db.add.assert_called_once() app.dependency_overrides.clear() def test_create_measure_duplicate(self): db = _make_mock_db() db.filter.return_value.first.return_value = _make_measure_row() app.dependency_overrides[get_db] = override_get_db(db) client = TestClient(app) resp = client.post("/tom/measures", json={ "control_id": "TOM.GOV.01", "name": "Duplicate", "category": "GOVERNANCE", "type": "ORGANIZATIONAL", }) assert resp.status_code == 409 app.dependency_overrides.clear() def test_update_measure(self): db = _make_mock_db() row = _make_measure_row() db.filter.return_value.first.return_value = row def mock_refresh(obj): pass db.refresh = mock_refresh app.dependency_overrides[get_db] = override_get_db(db) client = TestClient(app) resp = client.put(f"/tom/measures/{MEASURE_ID}", json={ "implementation_status": "IMPLEMENTED", "responsible_person": "Max Mustermann", }) assert resp.status_code == 200 assert row.implementation_status == "IMPLEMENTED" assert row.responsible_person == "Max Mustermann" app.dependency_overrides.clear() def test_update_measure_not_found(self): db = _make_mock_db() db.filter.return_value.first.return_value = None app.dependency_overrides[get_db] = override_get_db(db) client = TestClient(app) resp = client.put(f"/tom/measures/{UNKNOWN_ID}", json={ "implementation_status": "IMPLEMENTED", }) assert resp.status_code == 404 app.dependency_overrides.clear() def test_bulk_upsert_create(self): db = _make_mock_db() # No existing measures db.filter.return_value.first.return_value = None app.dependency_overrides[get_db] = override_get_db(db) client = TestClient(app) resp = client.post("/tom/measures/bulk", json={ "tenant_id": DEFAULT_TENANT, "measures": [ { "control_id": "TOM.GOV.01", "name": "Datenschutzrichtlinie", "category": "GOVERNANCE", "type": "ORGANIZATIONAL", }, { "control_id": "TOM.ACC.01", "name": "Zugriffskontrolle", "category": "ACCESS_CONTROL", "type": "TECHNICAL", }, ], }) assert resp.status_code == 200 data = resp.json() assert data["success"] is True assert data["created"] == 2 assert data["updated"] == 0 assert data["total"] == 2 app.dependency_overrides.clear() def test_bulk_upsert_update(self): db = _make_mock_db() existing = _make_measure_row("TOM.GOV.01") db.filter.return_value.first.return_value = existing app.dependency_overrides[get_db] = override_get_db(db) client = TestClient(app) resp = client.post("/tom/measures/bulk", json={ "measures": [ { "control_id": "TOM.GOV.01", "name": "Updated Name", "category": "GOVERNANCE", "type": "ORGANIZATIONAL", }, ], }) assert resp.status_code == 200 data = resp.json() assert data["updated"] == 1 assert data["created"] == 0 assert existing.name == "Updated Name" app.dependency_overrides.clear() class TestStatsRoute: def test_stats_empty(self): db = _make_mock_db() db.filter.return_value.count.return_value = 0 db.filter.return_value.group_by.return_value.all.return_value = [] app.dependency_overrides[get_db] = override_get_db(db) client = TestClient(app) resp = client.get("/tom/stats") assert resp.status_code == 200 data = resp.json() assert data["total"] == 0 assert data["by_status"] == {} assert data["by_category"] == {} app.dependency_overrides.clear() def test_stats_with_data(self): db = _make_mock_db() # Total count base_q = MagicMock() base_q.count.return_value = 10 base_q.filter.return_value.count.return_value = 2 # overdue # Status group_by status_q = MagicMock() status_q.all.return_value = [("IMPLEMENTED", 5), ("NOT_IMPLEMENTED", 3), ("PARTIAL", 2)] # Category group_by cat_q = MagicMock() cat_q.all.return_value = [("GOVERNANCE", 4), ("ACCESS_CONTROL", 6)] call_count = [0] original_filter = db.query.return_value.filter def mock_filter(*args): call_count[0] += 1 if call_count[0] == 1: return base_q elif call_count[0] == 2: mock_gby = MagicMock() mock_gby.all.return_value = [("IMPLEMENTED", 5), ("NOT_IMPLEMENTED", 3), ("PARTIAL", 2)] result = MagicMock() result.group_by.return_value = mock_gby return result elif call_count[0] == 3: mock_gby = MagicMock() mock_gby.all.return_value = [("GOVERNANCE", 4), ("ACCESS_CONTROL", 6)] result = MagicMock() result.group_by.return_value = mock_gby return result return MagicMock() db.query.return_value.filter = mock_filter app.dependency_overrides[get_db] = override_get_db(db) client = TestClient(app) resp = client.get("/tom/stats") assert resp.status_code == 200 data = resp.json() assert data["total"] == 10 app.dependency_overrides.clear() class TestExportRoute: def test_export_json(self): db = _make_mock_db() measures = [_make_measure_row("TOM.GOV.01")] db.filter.return_value.order_by.return_value.all.return_value = measures app.dependency_overrides[get_db] = override_get_db(db) client = TestClient(app) resp = client.get("/tom/export?format=json") assert resp.status_code == 200 assert "application/json" in resp.headers.get("content-type", "") data = resp.json() assert len(data) == 1 assert data[0]["control_id"] == "TOM.GOV.01" app.dependency_overrides.clear() def test_export_csv(self): db = _make_mock_db() measures = [_make_measure_row("TOM.GOV.01")] db.filter.return_value.order_by.return_value.all.return_value = measures app.dependency_overrides[get_db] = override_get_db(db) client = TestClient(app) resp = client.get("/tom/export?format=csv") assert resp.status_code == 200 assert "text/csv" in resp.headers.get("content-type", "") content = resp.text assert "control_id" in content # Header assert "TOM.GOV.01" in content app.dependency_overrides.clear() def test_export_csv_empty(self): db = _make_mock_db() db.filter.return_value.order_by.return_value.all.return_value = [] app.dependency_overrides[get_db] = override_get_db(db) client = TestClient(app) resp = client.get("/tom/export?format=csv") assert resp.status_code == 200 content = resp.text assert "control_id" in content # Header still present app.dependency_overrides.clear() # ============================================================================= # camelCase tenantId alias tests # ============================================================================= class TestTenantIdAlias: def test_get_state_camelcase(self): db = _make_mock_db() db.filter.return_value.first.return_value = None app.dependency_overrides[get_db] = override_get_db(db) client = TestClient(app) resp = client.get(f"/tom/state?tenantId={DEFAULT_TENANT}") assert resp.status_code == 200 data = resp.json() assert data["data"]["tenantId"] == DEFAULT_TENANT app.dependency_overrides.clear() def test_post_state_camelcase(self): db = _make_mock_db() db.filter.return_value.first.return_value = None def mock_refresh(obj): obj.version = 1 obj.updated_at = datetime(2024, 1, 1, tzinfo=timezone.utc) obj.state = {} db.refresh = mock_refresh app.dependency_overrides[get_db] = override_get_db(db) client = TestClient(app) resp = client.post("/tom/state", json={ "tenantId": DEFAULT_TENANT, "state": {}, }) assert resp.status_code == 200 app.dependency_overrides.clear()