"""Tests for Risk management routes (risk_routes.py).""" from datetime import datetime, date from unittest.mock import MagicMock, patch from fastapi import FastAPI from fastapi.testclient import TestClient from compliance.api.risk_routes import router as risk_router from classroom_engine.database import get_db # --------------------------------------------------------------------------- # App setup with mocked DB dependency # --------------------------------------------------------------------------- app = FastAPI() app.include_router(risk_router) mock_db = MagicMock() def override_get_db(): yield mock_db app.dependency_overrides[get_db] = override_get_db client = TestClient(app) RISK_UUID = "aaaaaaaa-1111-2222-3333-bbbbbbbbbbbb" NOW = datetime(2024, 3, 1, 12, 0, 0) # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def make_risk(overrides=None): r = MagicMock() r.id = RISK_UUID r.risk_id = "RISK-001" r.title = "Datenleck durch unsichere API" r.description = "API ohne Auth" r.category = "data_breach" r.likelihood = 3 r.impact = 4 # inherent_risk and residual_risk are Enum → need .value r.inherent_risk = MagicMock() r.inherent_risk.value = "high" r.residual_likelihood = 2 r.residual_impact = 3 r.residual_risk = MagicMock() r.residual_risk.value = "medium" r.status = "open" r.mitigating_controls = ["TOM-001"] r.owner = "CISO" r.treatment_plan = "API absichern" r.identified_date = date(2024, 1, 1) r.review_date = date(2024, 6, 1) r.last_assessed_at = NOW r.created_at = NOW r.updated_at = NOW if overrides: for k, v in overrides.items(): setattr(r, k, v) return r # --------------------------------------------------------------------------- # Tests # --------------------------------------------------------------------------- class TestListRisks: """Tests for GET /risks.""" def test_list_empty(self): with patch("compliance.api.risk_routes.RiskRepository") as MockRepo: MockRepo.return_value.get_all.return_value = [] response = client.get("/risks") assert response.status_code == 200 data = response.json() assert data["risks"] == [] assert data["total"] == 0 def test_list_with_risk(self): with patch("compliance.api.risk_routes.RiskRepository") as MockRepo: MockRepo.return_value.get_all.return_value = [make_risk()] response = client.get("/risks") assert response.status_code == 200 data = response.json() assert data["total"] == 1 r = data["risks"][0] assert r["risk_id"] == "RISK-001" assert r["title"] == "Datenleck durch unsichere API" assert r["inherent_risk"] == "high" assert r["status"] == "open" def test_list_filter_category(self): with patch("compliance.api.risk_routes.RiskRepository") as MockRepo: MockRepo.return_value.get_all.return_value = [make_risk()] response = client.get("/risks", params={"category": "data_breach"}) assert response.status_code == 200 assert response.json()["total"] == 1 def test_list_filter_status(self): with patch("compliance.api.risk_routes.RiskRepository") as MockRepo: MockRepo.return_value.get_all.return_value = [] response = client.get("/risks", params={"status": "mitigated"}) assert response.status_code == 200 def test_list_filter_risk_level(self): with patch("compliance.api.risk_routes.RiskRepository") as MockRepo: MockRepo.return_value.get_all.return_value = [make_risk()] response = client.get("/risks", params={"risk_level": "high"}) assert response.status_code == 200 def test_list_multiple(self): r2 = make_risk() r2.id = "bbbbbbbb-2222-2222-2222-bbbbbbbbbbbb" r2.risk_id = "RISK-002" with patch("compliance.api.risk_routes.RiskRepository") as MockRepo: MockRepo.return_value.get_all.return_value = [make_risk(), r2] response = client.get("/risks") assert response.status_code == 200 assert response.json()["total"] == 2 class TestCreateRisk: """Tests for POST /risks.""" def test_create_success(self): risk = make_risk() with patch("compliance.api.risk_routes.RiskRepository") as MockRepo: MockRepo.return_value.create.return_value = risk response = client.post("/risks", json={ "risk_id": "RISK-001", "title": "Datenleck durch unsichere API", "category": "data_breach", "likelihood": 3, "impact": 4, }) assert response.status_code == 200 data = response.json() assert data["risk_id"] == "RISK-001" assert data["inherent_risk"] == "high" def test_create_missing_required_fields(self): """Missing risk_id → 422.""" response = client.post("/risks", json={ "title": "Ohne risk_id", }) assert response.status_code == 422 def test_create_likelihood_out_of_range(self): """likelihood > 5 → 422.""" response = client.post("/risks", json={ "risk_id": "R-999", "title": "Test", "category": "test", "likelihood": 6, "impact": 3, }) assert response.status_code == 422 class TestUpdateRisk: """Tests for PUT /risks/{risk_id}.""" def test_update_success(self): updated = make_risk() updated.title = "Aktualisiertes Risiko" with patch("compliance.api.risk_routes.RiskRepository") as MockRepo: repo = MockRepo.return_value # Update route uses get_by_risk_id (the risk_id string, not UUID) repo.get_by_risk_id.return_value = make_risk() repo.update.return_value = updated response = client.put("/risks/RISK-001", json={"title": "Aktualisiertes Risiko"}) assert response.status_code == 200 assert response.json()["title"] == "Aktualisiertes Risiko" def test_update_not_found(self): with patch("compliance.api.risk_routes.RiskRepository") as MockRepo: MockRepo.return_value.get_by_risk_id.return_value = None response = client.put("/risks/RISK-999", json={"title": "Test"}) assert response.status_code == 404 def test_update_status_change(self): updated = make_risk() updated.status = "closed" with patch("compliance.api.risk_routes.RiskRepository") as MockRepo: repo = MockRepo.return_value repo.get_by_risk_id.return_value = make_risk() repo.update.return_value = updated response = client.put("/risks/RISK-001", json={"status": "closed"}) assert response.status_code == 200 assert response.json()["status"] == "closed" class TestDeleteRisk: """Tests for DELETE /risks/{risk_id}.""" def test_delete_success(self): with patch("compliance.api.risk_routes.RiskRepository") as MockRepo: repo = MockRepo.return_value repo.get_by_risk_id.return_value = make_risk() # Delete uses db.delete(risk) directly response = client.delete("/risks/RISK-001") assert response.status_code == 200 data = response.json() assert data["success"] is True def test_delete_not_found(self): with patch("compliance.api.risk_routes.RiskRepository") as MockRepo: MockRepo.return_value.get_by_risk_id.return_value = None response = client.delete("/risks/RISK-999") assert response.status_code == 404 class TestRiskMatrix: """Tests for GET /risks/matrix.""" def test_matrix_returns_structure(self): # Schema: Dict[str, Dict[str, List[str]]] → {likelihood: {impact: [risk_ids]}} matrix_data = { "3": {"4": ["RISK-001"]}, "1": {"1": []}, } with patch("compliance.api.risk_routes.RiskRepository") as MockRepo: repo = MockRepo.return_value repo.get_risk_matrix.return_value = matrix_data repo.get_all.return_value = [make_risk()] response = client.get("/risks/matrix") assert response.status_code == 200 data = response.json() assert "matrix" in data assert "risks" in data assert len(data["risks"]) == 1 def test_matrix_empty(self): with patch("compliance.api.risk_routes.RiskRepository") as MockRepo: repo = MockRepo.return_value repo.get_risk_matrix.return_value = {} repo.get_all.return_value = [] response = client.get("/risks/matrix") assert response.status_code == 200 data = response.json() assert data["risks"] == []