"""Tests for compliance obligation routes and schemas (obligation_routes.py).""" import pytest from unittest.mock import MagicMock from datetime import datetime from fastapi import FastAPI from fastapi.testclient import TestClient from compliance.api.obligation_routes import ( router, ObligationCreate, ObligationUpdate, ObligationStatusUpdate, ) from compliance.api.db_utils import row_to_dict as _row_to_dict DEFAULT_TENANT_ID = "9282a473-5c95-4b3a-bf78-0ecc0ec71d3e" from classroom_engine.database import get_db # --------------------------------------------------------------------------- # Route-Test infrastructure # --------------------------------------------------------------------------- OBLIGATION_ID = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" app = FastAPI() app.include_router(router) class _MockRow: """Simulates a SQLAlchemy row with _mapping attribute.""" def __init__(self, data: dict): self._mapping = data def __getitem__(self, idx): vals = list(self._mapping.values()) return vals[idx] def _make_obligation_row(overrides=None): now = datetime(2026, 3, 6, 12, 0, 0) data = { "id": OBLIGATION_ID, "tenant_id": DEFAULT_TENANT_ID, "title": "Art. 30 VVT führen", "description": "Pflicht nach DSGVO", "source": "DSGVO", "source_article": "Art. 30", "deadline": None, "status": "pending", "priority": "medium", "responsible": None, "linked_systems": [], "linked_vendor_ids": [], "assessment_id": None, "rule_code": None, "notes": None, "created_at": now, "updated_at": now, } if overrides: data.update(overrides) return _MockRow(data) def _count_row(val): """Simulates a COUNT(*) row — fetchone()[0] returns the value.""" row = MagicMock() row.__getitem__ = lambda self, idx: val return row @pytest.fixture def mock_db(): db = MagicMock() app.dependency_overrides[get_db] = lambda: db yield db app.dependency_overrides.pop(get_db, None) @pytest.fixture def client(mock_db): return TestClient(app) # ============================================================================= # Schema Tests — ObligationCreate # ============================================================================= class TestObligationCreate: def test_minimal_valid(self): req = ObligationCreate(title="Art. 30 VVT führen") assert req.title == "Art. 30 VVT führen" assert req.source == "DSGVO" assert req.status == "pending" assert req.priority == "medium" assert req.description is None assert req.source_article is None assert req.deadline is None assert req.responsible is None assert req.linked_systems is None assert req.assessment_id is None assert req.rule_code is None assert req.notes is None def test_full_values(self): deadline = datetime(2026, 6, 1, 0, 0, 0) req = ObligationCreate( title="DSFA durchführen", description="Pflicht nach Art. 35 DSGVO", source="DSGVO", source_article="Art. 35", deadline=deadline, status="in-progress", priority="critical", responsible="Datenschutzbeauftragter", linked_systems=["CRM", "ERP"], assessment_id="abc123", rule_code="RULE-DSFA-001", notes="Frist ist bindend", ) assert req.title == "DSFA durchführen" assert req.source_article == "Art. 35" assert req.priority == "critical" assert req.responsible == "Datenschutzbeauftragter" assert req.linked_systems == ["CRM", "ERP"] assert req.rule_code == "RULE-DSFA-001" def test_ai_act_source(self): req = ObligationCreate(title="Risikoklasse bestimmen", source="AI Act") assert req.source == "AI Act" assert req.status == "pending" def test_nis2_source(self): req = ObligationCreate(title="Meldepflicht einrichten", source="NIS2") assert req.source == "NIS2" def test_serialization_excludes_none(self): req = ObligationCreate(title="Test", priority="high") data = req.model_dump(exclude_none=True) assert data["title"] == "Test" assert data["priority"] == "high" assert "description" not in data assert "deadline" not in data def test_serialization_includes_set_fields(self): req = ObligationCreate(title="Test", status="overdue", responsible="admin") data = req.model_dump() assert data["status"] == "overdue" assert data["responsible"] == "admin" assert data["source"] == "DSGVO" # ============================================================================= # Schema Tests — ObligationUpdate # ============================================================================= class TestObligationUpdate: def test_empty_update(self): req = ObligationUpdate() data = req.model_dump(exclude_unset=True) assert data == {} def test_partial_update_title(self): req = ObligationUpdate(title="Neuer Titel") data = req.model_dump(exclude_unset=True) assert data == {"title": "Neuer Titel"} def test_partial_update_status_priority(self): req = ObligationUpdate(status="completed", priority="low") data = req.model_dump(exclude_unset=True) assert data["status"] == "completed" assert data["priority"] == "low" assert "title" not in data def test_update_linked_systems(self): req = ObligationUpdate(linked_systems=["CRM"]) data = req.model_dump(exclude_unset=True) assert data["linked_systems"] == ["CRM"] def test_update_clears_linked_systems(self): req = ObligationUpdate(linked_systems=[]) data = req.model_dump(exclude_unset=True) assert data["linked_systems"] == [] def test_update_deadline(self): dl = datetime(2026, 12, 31) req = ObligationUpdate(deadline=dl) data = req.model_dump(exclude_unset=True) assert data["deadline"] == dl def test_full_update(self): req = ObligationUpdate( title="Updated", description="Neue Beschreibung", source="NIS2", source_article="Art. 21", status="in-progress", priority="high", responsible="CISO", notes="Jetzt eilt es", ) data = req.model_dump(exclude_unset=True) assert len(data) == 8 assert data["responsible"] == "CISO" # ============================================================================= # Schema Tests — ObligationStatusUpdate # ============================================================================= class TestObligationStatusUpdate: def test_pending(self): req = ObligationStatusUpdate(status="pending") assert req.status == "pending" def test_in_progress(self): req = ObligationStatusUpdate(status="in-progress") assert req.status == "in-progress" def test_completed(self): req = ObligationStatusUpdate(status="completed") assert req.status == "completed" def test_overdue(self): req = ObligationStatusUpdate(status="overdue") assert req.status == "overdue" def test_serialization(self): req = ObligationStatusUpdate(status="completed") data = req.model_dump() assert data == {"status": "completed"} # ============================================================================= # Helper Tests — _row_to_dict # ============================================================================= class TestRowToDict: def test_basic_conversion(self): row = MagicMock() row._mapping = {"id": "abc-123", "title": "Test Pflicht", "priority": "medium"} result = _row_to_dict(row) assert result["id"] == "abc-123" assert result["title"] == "Test Pflicht" assert result["priority"] == "medium" def test_datetime_serialized(self): ts = datetime(2026, 6, 1, 12, 0, 0) row = MagicMock() row._mapping = {"id": "abc", "created_at": ts, "updated_at": ts} result = _row_to_dict(row) assert result["created_at"] == ts.isoformat() assert result["updated_at"] == ts.isoformat() def test_deadline_serialized(self): dl = datetime(2026, 12, 31, 23, 59, 59) row = MagicMock() row._mapping = {"id": "abc", "deadline": dl} result = _row_to_dict(row) assert result["deadline"] == dl.isoformat() def test_none_values_preserved(self): row = MagicMock() row._mapping = { "id": "abc", "description": None, "deadline": None, "responsible": None, "notes": None, } result = _row_to_dict(row) assert result["description"] is None assert result["deadline"] is None assert result["responsible"] is None assert result["notes"] is None def test_uuid_converted_to_string(self): import uuid uid = uuid.UUID("9282a473-5c95-4b3a-bf78-0ecc0ec71d3e") row = MagicMock() row._mapping = {"id": uid, "tenant_id": uid} result = _row_to_dict(row) assert result["id"] == str(uid) assert result["tenant_id"] == str(uid) def test_string_fields_unchanged(self): row = MagicMock() row._mapping = { "title": "DSFA durchführen", "status": "pending", "source": "DSGVO", "source_article": "Art. 35", "priority": "critical", } result = _row_to_dict(row) assert result["title"] == "DSFA durchführen" assert result["status"] == "pending" assert result["source"] == "DSGVO" assert result["priority"] == "critical" def test_int_and_bool_unchanged(self): row = MagicMock() row._mapping = {"count": 42, "active": True, "flag": False} result = _row_to_dict(row) assert result["count"] == 42 assert result["active"] is True assert result["flag"] is False # ============================================================================= # Business Logic Tests # ============================================================================= class TestObligationBusinessLogic: def test_default_tenant_id_is_valid_uuid(self): import uuid # Should not raise parsed = uuid.UUID(DEFAULT_TENANT_ID) assert str(parsed) == DEFAULT_TENANT_ID def test_valid_statuses(self): valid = {"pending", "in-progress", "completed", "overdue"} # Each status should be a valid string, matching what the route validates assert "pending" in valid assert "in-progress" in valid assert "completed" in valid assert "overdue" in valid def test_valid_priorities(self): valid = {"critical", "high", "medium", "low"} req = ObligationCreate(title="Test", priority="critical") assert req.priority in valid req2 = ObligationCreate(title="Test", priority="low") assert req2.priority in valid def test_priority_order_correctness(self): """Ensure priority values match the SQL CASE ordering in the route.""" priorities_ordered = ["critical", "high", "medium", "low"] for i, p in enumerate(priorities_ordered): req = ObligationCreate(title=f"Test {p}", priority=p) assert req.priority == p def test_linked_systems_defaults_to_none(self): req = ObligationCreate(title="Test") assert req.linked_systems is None def test_linked_systems_can_be_empty_list(self): req = ObligationCreate(title="Test", linked_systems=[]) assert req.linked_systems == [] def test_linked_systems_multiple_items(self): systems = ["CRM", "ERP", "HR-System", "Buchhaltung"] req = ObligationCreate(title="Test", linked_systems=systems) assert len(req.linked_systems) == 4 assert "ERP" in req.linked_systems def test_source_defaults(self): """Verify all common DSGVO/AI Act sources can be stored.""" for source in ["DSGVO", "AI Act", "NIS2", "BDSG", "ISO 27001"]: req = ObligationCreate(title="Test", source=source) assert req.source == source # ============================================================================= # Route Integration Tests — List # ============================================================================= class TestListObligationsRoute: def test_list_empty(self, client, mock_db): mock_db.execute.side_effect = [ MagicMock(fetchone=MagicMock(return_value=_count_row(0))), MagicMock(fetchall=MagicMock(return_value=[])), ] resp = client.get("/obligations") assert resp.status_code == 200 data = resp.json() assert data["obligations"] == [] assert data["total"] == 0 def test_list_with_items(self, client, mock_db): row = _make_obligation_row() mock_db.execute.side_effect = [ MagicMock(fetchone=MagicMock(return_value=_count_row(1))), MagicMock(fetchall=MagicMock(return_value=[row])), ] resp = client.get("/obligations") assert resp.status_code == 200 data = resp.json() assert len(data["obligations"]) == 1 assert data["obligations"][0]["id"] == OBLIGATION_ID assert data["total"] == 1 def test_list_filter_status(self, client, mock_db): row = _make_obligation_row({"status": "overdue"}) mock_db.execute.side_effect = [ MagicMock(fetchone=MagicMock(return_value=_count_row(1))), MagicMock(fetchall=MagicMock(return_value=[row])), ] resp = client.get("/obligations?status=overdue") assert resp.status_code == 200 assert resp.json()["obligations"][0]["status"] == "overdue" def test_list_filter_priority(self, client, mock_db): row = _make_obligation_row({"priority": "critical"}) mock_db.execute.side_effect = [ MagicMock(fetchone=MagicMock(return_value=_count_row(1))), MagicMock(fetchall=MagicMock(return_value=[row])), ] resp = client.get("/obligations?priority=critical") assert resp.status_code == 200 assert resp.json()["obligations"][0]["priority"] == "critical" # ============================================================================= # Route Integration Tests — Create # ============================================================================= class TestCreateObligationRoute: def test_create_basic(self, client, mock_db): row = _make_obligation_row() mock_db.execute.return_value = MagicMock(fetchone=MagicMock(return_value=row)) resp = client.post("/obligations", json={"title": "Art. 30 VVT führen"}) assert resp.status_code == 201 data = resp.json() assert data["id"] == OBLIGATION_ID assert data["title"] == "Art. 30 VVT führen" mock_db.commit.assert_called_once() def test_create_full_fields(self, client, mock_db): row = _make_obligation_row({ "title": "DSFA durchführen", "source": "DSGVO", "source_article": "Art. 35", "priority": "critical", "responsible": "DSB", }) mock_db.execute.return_value = MagicMock(fetchone=MagicMock(return_value=row)) resp = client.post("/obligations", json={ "title": "DSFA durchführen", "source": "DSGVO", "source_article": "Art. 35", "priority": "critical", "responsible": "DSB", }) assert resp.status_code == 201 assert resp.json()["priority"] == "critical" assert resp.json()["responsible"] == "DSB" def test_create_missing_title_422(self, client, mock_db): resp = client.post("/obligations", json={"source": "DSGVO"}) assert resp.status_code == 422 # ============================================================================= # Route Integration Tests — Get # ============================================================================= class TestGetObligationRoute: def test_get_existing(self, client, mock_db): row = _make_obligation_row() mock_db.execute.return_value = MagicMock(fetchone=MagicMock(return_value=row)) resp = client.get(f"/obligations/{OBLIGATION_ID}") assert resp.status_code == 200 assert resp.json()["id"] == OBLIGATION_ID def test_get_not_found(self, client, mock_db): mock_db.execute.return_value = MagicMock(fetchone=MagicMock(return_value=None)) resp = client.get("/obligations/nonexistent-id") assert resp.status_code == 404 # ============================================================================= # Route Integration Tests — Update # ============================================================================= class TestUpdateObligationRoute: def test_update_partial(self, client, mock_db): updated = _make_obligation_row({"priority": "high"}) mock_db.execute.return_value = MagicMock(fetchone=MagicMock(return_value=updated)) resp = client.put(f"/obligations/{OBLIGATION_ID}", json={"priority": "high"}) assert resp.status_code == 200 assert resp.json()["priority"] == "high" mock_db.commit.assert_called_once() def test_update_not_found(self, client, mock_db): mock_db.execute.return_value = MagicMock(fetchone=MagicMock(return_value=None)) resp = client.put(f"/obligations/{OBLIGATION_ID}", json={"title": "Updated"}) assert resp.status_code == 404 def test_update_empty_body_400(self, client, mock_db): resp = client.put(f"/obligations/{OBLIGATION_ID}", json={}) assert resp.status_code == 400 assert "No fields to update" in resp.json()["detail"] # ============================================================================= # Route Integration Tests — Status Update # ============================================================================= class TestUpdateObligationStatusRoute: def test_valid_status(self, client, mock_db): row = _make_obligation_row({"status": "in-progress"}) mock_db.execute.return_value = MagicMock(fetchone=MagicMock(return_value=row)) resp = client.put(f"/obligations/{OBLIGATION_ID}/status", json={"status": "in-progress"}) assert resp.status_code == 200 assert resp.json()["status"] == "in-progress" mock_db.commit.assert_called_once() def test_invalid_status_400(self, client, mock_db): resp = client.put(f"/obligations/{OBLIGATION_ID}/status", json={"status": "invalid"}) assert resp.status_code == 400 assert "Invalid status" in resp.json()["detail"] def test_status_not_found(self, client, mock_db): mock_db.execute.return_value = MagicMock(fetchone=MagicMock(return_value=None)) resp = client.put(f"/obligations/{OBLIGATION_ID}/status", json={"status": "completed"}) assert resp.status_code == 404 def test_status_to_completed(self, client, mock_db): row = _make_obligation_row({"status": "completed"}) mock_db.execute.return_value = MagicMock(fetchone=MagicMock(return_value=row)) resp = client.put(f"/obligations/{OBLIGATION_ID}/status", json={"status": "completed"}) assert resp.status_code == 200 assert resp.json()["status"] == "completed" # ============================================================================= # Route Integration Tests — Delete # ============================================================================= class TestDeleteObligationRoute: def test_delete_existing(self, client, mock_db): mock_db.execute.return_value = MagicMock(rowcount=1) resp = client.delete(f"/obligations/{OBLIGATION_ID}") assert resp.status_code == 204 mock_db.commit.assert_called_once() def test_delete_not_found(self, client, mock_db): mock_db.execute.return_value = MagicMock(rowcount=0) resp = client.delete(f"/obligations/{OBLIGATION_ID}") assert resp.status_code == 404 # ============================================================================= # Route Integration Tests — Stats # ============================================================================= class TestGetObligationStatsRoute: def test_stats_empty(self, client, mock_db): mock_db.execute.return_value = MagicMock(fetchone=MagicMock(return_value=None)) resp = client.get("/obligations/stats") assert resp.status_code == 200 data = resp.json() assert data["total"] == 0 assert data["pending"] == 0 def test_stats_with_data(self, client, mock_db): row = MagicMock() row._mapping = { "pending": 3, "in_progress": 2, "overdue": 1, "completed": 5, "critical": 2, "high": 3, "total": 11, } mock_db.execute.return_value = MagicMock(fetchone=MagicMock(return_value=row)) resp = client.get("/obligations/stats") assert resp.status_code == 200 data = resp.json() assert data["total"] == 11 assert data["pending"] == 3 assert data["critical"] == 2 def test_stats_structure(self, client, mock_db): row = MagicMock() row._mapping = { "pending": 0, "in_progress": 0, "overdue": 0, "completed": 0, "critical": 0, "high": 0, "total": 0, } mock_db.execute.return_value = MagicMock(fetchone=MagicMock(return_value=row)) resp = client.get("/obligations/stats") data = resp.json() expected_keys = {"pending", "in_progress", "overdue", "completed", "critical", "high", "total"} assert set(data.keys()) == expected_keys # ============================================================================= # Route Integration Tests — Search # ============================================================================= class TestObligationSearchRoute: def test_search_param(self, client, mock_db): row = _make_obligation_row({"title": "DSFA Pflicht"}) mock_db.execute.side_effect = [ MagicMock(fetchone=MagicMock(return_value=_count_row(1))), MagicMock(fetchall=MagicMock(return_value=[row])), ] resp = client.get("/obligations?search=DSFA") assert resp.status_code == 200 assert len(resp.json()["obligations"]) == 1 def test_source_filter(self, client, mock_db): row = _make_obligation_row({"source": "AI Act"}) mock_db.execute.side_effect = [ MagicMock(fetchone=MagicMock(return_value=_count_row(1))), MagicMock(fetchall=MagicMock(return_value=[row])), ] resp = client.get("/obligations?source=AI Act") assert resp.status_code == 200 assert resp.json()["obligations"][0]["source"] == "AI Act" # ============================================================================= # Linked Vendor IDs Tests (Art. 28 DSGVO) # ============================================================================= class TestLinkedVendorIds: def test_create_with_linked_vendor_ids(self, client, mock_db): row = _make_obligation_row({"linked_vendor_ids": ["vendor-1", "vendor-2"]}) mock_db.execute.return_value = MagicMock(fetchone=MagicMock(return_value=row)) resp = client.post("/obligations", json={ "title": "Vendor-Prüfung Art. 28", "linked_vendor_ids": ["vendor-1", "vendor-2"], }) assert resp.status_code == 201 assert resp.json()["linked_vendor_ids"] == ["vendor-1", "vendor-2"] def test_create_without_linked_vendor_ids_defaults_empty(self, client, mock_db): row = _make_obligation_row() mock_db.execute.return_value = MagicMock(fetchone=MagicMock(return_value=row)) resp = client.post("/obligations", json={"title": "Ohne Vendor"}) assert resp.status_code == 201 # Schema allows it — linked_vendor_ids defaults to None in the schema schema = ObligationCreate(title="Ohne Vendor") assert schema.linked_vendor_ids is None def test_update_linked_vendor_ids(self, client, mock_db): updated = _make_obligation_row({"linked_vendor_ids": ["v1"]}) mock_db.execute.return_value = MagicMock(fetchone=MagicMock(return_value=updated)) resp = client.put(f"/obligations/{OBLIGATION_ID}", json={ "linked_vendor_ids": ["v1"], }) assert resp.status_code == 200 assert resp.json()["linked_vendor_ids"] == ["v1"] def test_update_clears_linked_vendor_ids(self, client, mock_db): updated = _make_obligation_row({"linked_vendor_ids": []}) mock_db.execute.return_value = MagicMock(fetchone=MagicMock(return_value=updated)) resp = client.put(f"/obligations/{OBLIGATION_ID}", json={ "linked_vendor_ids": [], }) assert resp.status_code == 200 assert resp.json()["linked_vendor_ids"] == [] def test_schema_create_includes_linked_vendor_ids(self): schema = ObligationCreate( title="Test Vendor Link", linked_vendor_ids=["a", "b"], ) assert schema.linked_vendor_ids == ["a", "b"] data = schema.model_dump() assert data["linked_vendor_ids"] == ["a", "b"] def test_schema_update_includes_linked_vendor_ids(self): schema = ObligationUpdate(linked_vendor_ids=["a"]) data = schema.model_dump(exclude_unset=True) assert data["linked_vendor_ids"] == ["a"]