"""
Unit tests for Compliance Routes — Requirements CRUD + AI Systems CRUD.

Tests the new POST/DELETE requirements endpoints and AI system endpoints
added during the Analyse-Module → 100% sprint.

Run with: pytest compliance/tests/test_compliance_routes.py -v
"""

import pytest
from datetime import datetime, timezone
from unittest.mock import MagicMock
from uuid import uuid4

from compliance.db.models import (
    RequirementDB,
    RegulationDB,
    AISystemDB,
    AIClassificationEnum,
    AISystemStatusEnum,
    RiskDB,
    RiskLevelEnum,
)
from compliance.db.repository import RequirementRepository


# ============================================================================
# Test Fixtures
# ============================================================================

@pytest.fixture
def mock_db():
    """Create a mock database session.

    By default ``query(...).filter(...).first()`` returns ``None`` (record not
    found); individual tests override this to return a fixture object.
    """
    db = MagicMock()
    db.query.return_value.filter.return_value.first.return_value = None
    return db


@pytest.fixture
def sample_regulation():
    """Create a sample regulation (DSGVO, Regulation (EU) 2016/679)."""
    # One timestamp so created_at and updated_at are identical.
    now = datetime.now(timezone.utc)
    return RegulationDB(
        id=str(uuid4()),
        code="DSGVO",
        name="Datenschutz-Grundverordnung",
        full_name="Verordnung (EU) 2016/679",
        is_active=True,
        created_at=now,
        updated_at=now,
    )


@pytest.fixture
def sample_requirement(sample_regulation):
    """Create a sample requirement linked to ``sample_regulation``."""
    now = datetime.now(timezone.utc)
    return RequirementDB(
        id=str(uuid4()),
        regulation_id=sample_regulation.id,
        article="Art. 6",
        title="Rechtmaessigkeit der Verarbeitung",
        description="Personenbezogene Daten duerfen nur verarbeitet werden, wenn eine Rechtsgrundlage vorliegt.",
        priority=4,
        is_applicable=True,
        created_at=now,
        updated_at=now,
    )


@pytest.fixture
def sample_ai_system():
    """Create a sample, unclassified AI system in DRAFT status."""
    now = datetime.now(timezone.utc)
    return AISystemDB(
        id=str(uuid4()),
        name="Dokumenten-Scanner",
        description="OCR-basiertes System zur Texterkennung",
        purpose="Texterkennung in eingescannten Dokumenten",
        sector="Verwaltung",
        classification=AIClassificationEnum.UNCLASSIFIED,
        status=AISystemStatusEnum.DRAFT,
        obligations=[],
        created_at=now,
        updated_at=now,
    )


# ============================================================================
# Test: Requirement Creation (POST /requirements)
# ============================================================================

class TestCreateRequirement:
    """Tests for POST /compliance/requirements endpoint."""

    def test_create_requirement_valid_data(self, sample_regulation):
        """Creating a requirement with valid data should set all fields."""
        now = datetime.now(timezone.utc)
        req = RequirementDB(
            id=str(uuid4()),
            regulation_id=sample_regulation.id,
            article="Art. 32",
            title="Sicherheit der Verarbeitung",
            description="Geeignete technische Massnahmen",
            priority=3,
            is_applicable=True,
            created_at=now,
            updated_at=now,
        )
        assert req.regulation_id == sample_regulation.id
        assert req.article == "Art. 32"
        assert req.title == "Sicherheit der Verarbeitung"
        assert req.priority == 3
        assert req.is_applicable is True

    def test_create_requirement_default_priority(self, sample_regulation):
        """Constructing a requirement without a priority must succeed.

        The priority default is applied in ``RequirementRepository.create()``,
        not by the model constructor, so the default value itself is not
        asserted here.
        NOTE(review): the default (described as 2 in an earlier version of
        this test) is only covered if a repository-level test exists — confirm.
        """
        req = RequirementDB(
            id=str(uuid4()),
            regulation_id=sample_regulation.id,
            article="Art. 5",
            title="Grundsaetze",
        )
        # Priority defaults handled in repository.create()
        assert req.article == "Art. 5"

    def test_create_requirement_requires_regulation_id(self):
        """The model stores ``regulation_id`` exactly as given.

        The referenced regulation's existence is not checked when the model is
        constructed; presumably the endpoint or a DB foreign key enforces it —
        that path is not covered here.
        """
        req = RequirementDB(
            id=str(uuid4()),
            regulation_id="non-existent-id",
            article="Art. 1",
            title="Test",
        )
        assert req.regulation_id == "non-existent-id"

    def test_create_requirement_with_optional_fields(self, sample_regulation):
        """Creating a requirement with all optional fields should work."""
        req = RequirementDB(
            id=str(uuid4()),
            regulation_id=sample_regulation.id,
            article="Art. 25",
            paragraph="Abs. 1",
            title="Datenschutz durch Technikgestaltung",
            description="Privacy by Design",
            requirement_text="Der Verantwortliche trifft...",
            breakpilot_interpretation="Implementierung von PbD",
            is_applicable=True,
            applicability_reason="Relevant fuer alle Verarbeitungen",
            priority=4,
        )
        assert req.paragraph == "Abs. 1"
        assert req.requirement_text is not None
        assert req.breakpilot_interpretation is not None
        assert req.applicability_reason is not None


# ============================================================================
# Test: Requirement Deletion (DELETE /requirements/{id})
# ============================================================================

class TestDeleteRequirement:
    """Tests for DELETE /compliance/requirements/{requirement_id} endpoint."""

    def test_delete_existing_requirement_returns_true(self, mock_db, sample_requirement):
        """Deleting an existing requirement should return True and commit."""
        mock_db.query.return_value.filter.return_value.first.return_value = sample_requirement
        repo = RequirementRepository(mock_db)
        result = repo.delete(sample_requirement.id)
        assert result is True
        mock_db.delete.assert_called_once_with(sample_requirement)
        mock_db.commit.assert_called_once()

    def test_delete_nonexistent_requirement_returns_false(self, mock_db):
        """Deleting a non-existent requirement should return False and not delete."""
        mock_db.query.return_value.filter.return_value.first.return_value = None
        repo = RequirementRepository(mock_db)
        result = repo.delete("nonexistent-id")
        assert result is False
        mock_db.delete.assert_not_called()


# ============================================================================
# Test: Requirement Update with Rollback logic
# ============================================================================

class TestUpdateRequirement:
    """Tests for PUT /compliance/requirements/{requirement_id} endpoint.

    NOTE(review): the first two tests only assign attributes on the model and
    read them back; they do not call the PUT handler or repository.
    """

    def test_update_implementation_status(self, sample_requirement):
        """Updating implementation_status should change the field."""
        sample_requirement.implementation_status = "implemented"
        assert sample_requirement.implementation_status == "implemented"

    def test_update_audit_status_sets_audit_date(self, sample_requirement):
        """Setting audit_status and last_audit_date should store both values."""
        sample_requirement.audit_status = "compliant"
        sample_requirement.last_audit_date = datetime.now(timezone.utc)
        assert sample_requirement.audit_status == "compliant"
        assert sample_requirement.last_audit_date is not None

    def test_update_allowed_fields_only(self, sample_requirement):
        """Every field in the update allow-list must exist on the model.

        This does not verify that non-allowed fields are rejected.
        """
        allowed_fields = (
            'implementation_status', 'implementation_details', 'code_references',
            'documentation_links', 'evidence_description', 'evidence_artifacts',
            'auditor_notes', 'audit_status', 'is_applicable', 'applicability_reason',
            'breakpilot_interpretation',
        )
        # Collect every missing field so one failure reports all of them.
        missing = [field for field in allowed_fields if not hasattr(sample_requirement, field)]
        assert not missing, missing


# ============================================================================
# Test: AI System Model
# ============================================================================

class TestAISystemModel:
    """Tests for AISystemDB model and enums."""

    def test_ai_classification_enum_values(self):
        """AI classification enum should have all 5 risk levels."""
        assert AIClassificationEnum.PROHIBITED.value == "prohibited"
        assert AIClassificationEnum.HIGH_RISK.value == "high-risk"
        assert AIClassificationEnum.LIMITED_RISK.value == "limited-risk"
        assert AIClassificationEnum.MINIMAL_RISK.value == "minimal-risk"
        assert AIClassificationEnum.UNCLASSIFIED.value == "unclassified"

    def test_ai_system_status_enum_values(self):
        """AI system status enum should have all 4 statuses."""
        assert AISystemStatusEnum.DRAFT.value == "draft"
        assert AISystemStatusEnum.CLASSIFIED.value == "classified"
        assert AISystemStatusEnum.COMPLIANT.value == "compliant"
        assert AISystemStatusEnum.NON_COMPLIANT.value == "non-compliant"

    def test_invalid_classification_raises_error(self):
        """Invalid classification should raise ValueError."""
        with pytest.raises(ValueError):
            AIClassificationEnum("super-risk")

    def test_invalid_status_raises_error(self):
        """Invalid status should raise ValueError."""
        with pytest.raises(ValueError):
            AISystemStatusEnum("pending")

    def test_ai_system_creation(self, sample_ai_system):
        """Creating an AI system should set all fields correctly."""
        assert sample_ai_system.name == "Dokumenten-Scanner"
        assert sample_ai_system.classification == AIClassificationEnum.UNCLASSIFIED
        assert sample_ai_system.status == AISystemStatusEnum.DRAFT
        assert sample_ai_system.sector == "Verwaltung"

    def test_ai_system_repr(self, sample_ai_system):
        """AI system repr should show name and classification."""
        repr_str = repr(sample_ai_system)
        assert "Dokumenten-Scanner" in repr_str
        assert "unclassified" in repr_str


# ============================================================================
# Test: AI System CRUD Operations
# ============================================================================

class TestAISystemCRUD:
    """Tests for AI system CRUD operations."""

    def test_create_ai_system_with_defaults(self):
        """Constructing with only id and name leaves optional text fields unset."""
        system = AISystemDB(
            id=str(uuid4()),
            name="Test System",
        )
        assert system.name == "Test System"
        assert system.description is None
        assert system.purpose is None

    def test_update_ai_system_classification(self, sample_ai_system):
        """Updating classification should change the enum value."""
        sample_ai_system.classification = AIClassificationEnum.HIGH_RISK
        assert sample_ai_system.classification == AIClassificationEnum.HIGH_RISK

    def test_update_ai_system_with_assessment(self, sample_ai_system):
        """After assessment, system should have assessment_date and result."""
        sample_ai_system.assessment_date = datetime.now(timezone.utc)
        sample_ai_system.assessment_result = {
            "overall_risk": "high",
            "risk_factors": [{"factor": "education sector", "severity": "high"}],
            "recommendations": ["Implement Art. 9 measures"],
        }
        sample_ai_system.classification = AIClassificationEnum.HIGH_RISK
        sample_ai_system.status = AISystemStatusEnum.CLASSIFIED
        sample_ai_system.obligations = [
            "Risikomanagementsystem (Art. 9)",
            "Transparenz (Art. 13)",
        ]
        assert sample_ai_system.assessment_date is not None
        assert sample_ai_system.assessment_result["overall_risk"] == "high"
        assert len(sample_ai_system.obligations) == 2
        assert sample_ai_system.status == AISystemStatusEnum.CLASSIFIED

    def test_delete_ai_system(self, mock_db, sample_ai_system):
        """Deleting an AI system should call db.delete and commit.

        NOTE(review): this drives the mock session directly and does not call
        the DELETE handler, so it only checks the mock wiring — consider
        routing it through the real endpoint/repository.
        """
        mock_db.query.return_value.filter.return_value.first.return_value = sample_ai_system
        system = mock_db.query(AISystemDB).filter(AISystemDB.id == sample_ai_system.id).first()
        assert system is not None
        mock_db.delete(system)
        mock_db.commit()
        mock_db.delete.assert_called_once_with(sample_ai_system)
        mock_db.commit.assert_called_once()


# ============================================================================
# Test: AI Act Rule-Based Assessment
# ============================================================================

class TestAIActRuleBasedAssessment:
    """Tests for the rule-based AI Act classification fallback.

    The helpers are imported inside each test so that an import failure in
    ``compliance.api.ai_routes`` only fails these tests, not the whole module.
    """

    def test_prohibited_keywords_detected(self):
        """Systems with prohibited keywords should score >= 10."""
        from compliance.api.ai_routes import _rule_based_assessment

        system = MagicMock()
        system.description = "Social scoring system for citizens"
        system.purpose = "Social scoring"
        system.sector = "Government"
        result = _rule_based_assessment(system)
        assert result["risk_score"] >= 10
        assert result["overall_risk"] == "critical"

    def test_high_risk_sector_detected(self):
        """Systems in high-risk sectors should score >= 5."""
        from compliance.api.ai_routes import _rule_based_assessment

        system = MagicMock()
        system.description = "Automated grading system"
        system.purpose = "Education assessment"
        system.sector = "education"
        result = _rule_based_assessment(system)
        assert result["risk_score"] >= 5
        assert len(result["risk_factors"]) > 0

    def test_minimal_risk_system(self):
        """Systems without risk indicators should have low risk."""
        from compliance.api.ai_routes import _rule_based_assessment

        system = MagicMock()
        system.description = "Simple spell checker"
        system.purpose = "Text correction"
        system.sector = "Office"
        result = _rule_based_assessment(system)
        assert result["risk_score"] < 3
        assert result["overall_risk"] == "low"

    def test_derive_classification(self):
        """Classification should be derived from risk score."""
        from compliance.api.ai_routes import _derive_classification

        assert _derive_classification({"risk_score": 15, "overall_risk": "critical"}) == "prohibited"
        assert _derive_classification({"risk_score": 7, "overall_risk": "high"}) == "high-risk"
        assert _derive_classification({"risk_score": 3, "overall_risk": "medium"}) == "limited-risk"
        assert _derive_classification({"risk_score": 0, "overall_risk": "low"}) == "minimal-risk"

    def test_derive_obligations_high_risk(self):
        """High-risk classification should have 8 obligations."""
        from compliance.api.ai_routes import _derive_obligations

        obligations = _derive_obligations("high-risk")
        assert len(obligations) == 8
        assert any("Art. 9" in o for o in obligations)
        assert any("Art. 13" in o for o in obligations)

    def test_derive_obligations_minimal_risk(self):
        """Minimal-risk classification should have 1 obligation."""
        from compliance.api.ai_routes import _derive_obligations

        obligations = _derive_obligations("minimal-risk")
        assert len(obligations) == 1
        # NOTE(review): "Art. 69" (codes of conduct) is the numbering of the
        # 2021 AI Act proposal; in the final text (Regulation (EU) 2024/1689)
        # codes of conduct are Art. 95. Confirm which numbering is intended.
        assert "Art. 69" in obligations[0]

    def test_derive_obligations_prohibited(self):
        """Prohibited classification should have 1 obligation about ban."""
        from compliance.api.ai_routes import _derive_obligations

        obligations = _derive_obligations("prohibited")
        assert len(obligations) == 1
        assert "verboten" in obligations[0].lower()


# ============================================================================
# Test: Evidence Pagination
# ============================================================================

class TestEvidencePagination:
    """Tests for evidence pagination logic.

    NOTE(review): these tests re-implement the offset arithmetic inline and do
    not call the evidence endpoint, so they cannot catch regressions there.
    """

    def test_pagination_calculates_offset(self):
        """Pagination should calculate correct offset."""
        page = 3
        limit = 20
        offset = (page - 1) * limit
        assert offset == 40

    def test_pagination_first_page_no_offset(self):
        """First page should have offset 0."""
        page = 1
        limit = 20
        offset = (page - 1) * limit
        assert offset == 0

    def test_total_reflects_full_count(self):
        """Total should reflect all items, not just the page."""
        all_items = list(range(100))
        page = 2
        limit = 20
        total = len(all_items)
        page_items = all_items[(page - 1) * limit : page * limit]
        assert total == 100
        assert len(page_items) == 20
        assert page_items[0] == 20


# ============================================================================
# Test: Risk Status Workflow
# ============================================================================

class TestRiskStatusWorkflow:
    """Tests for risk status transitions."""

    @pytest.mark.parametrize(
        "status",
        ["identified", "assessed", "mitigated", "accepted", "closed"],
    )
    def test_risk_status_values(self, status):
        """Risk status should support every workflow state (one case per state)."""
        risk = RiskDB(
            id=str(uuid4()),
            risk_id="RISK-001",
            title="Test Risk",
            category="data_breach",
            likelihood=3,
            impact=4,
            inherent_risk=RiskLevelEnum.HIGH,
            status=status,
        )
        assert risk.status == status

    def test_residual_risk_after_mitigation(self):
        """After mitigation, residual risk should be lower."""
        risk = RiskDB(
            id=str(uuid4()),
            risk_id="RISK-002",
            title="Data Breach Risk",
            category="data_breach",
            likelihood=4,
            impact=5,
            inherent_risk=RiskLevelEnum.CRITICAL,
            residual_likelihood=2,
            residual_impact=3,
            residual_risk=RiskLevelEnum.MEDIUM,
        )
        inherent_score = risk.likelihood * risk.impact
        residual_score = risk.residual_likelihood * risk.residual_impact
        assert residual_score < inherent_score
        assert risk.inherent_risk == RiskLevelEnum.CRITICAL
        assert risk.residual_risk == RiskLevelEnum.MEDIUM


# ============================================================================
# Test: Risk Level Calculation
# ============================================================================

class TestRiskLevelCalculation:
    """Tests for risk level calculation from likelihood x impact."""

    def test_critical_risk_level(self):
        """Score >= 20 should be CRITICAL."""
        assert RiskDB.calculate_risk_level(5, 4) == RiskLevelEnum.CRITICAL
        assert RiskDB.calculate_risk_level(4, 5) == RiskLevelEnum.CRITICAL
        assert RiskDB.calculate_risk_level(5, 5) == RiskLevelEnum.CRITICAL

    def test_high_risk_level(self):
        """Score >= 12 and < 20 should be HIGH."""
        assert RiskDB.calculate_risk_level(3, 4) == RiskLevelEnum.HIGH
        assert RiskDB.calculate_risk_level(4, 4) == RiskLevelEnum.HIGH

    def test_medium_risk_level(self):
        """Score >= 6 and < 12 should be MEDIUM."""
        assert RiskDB.calculate_risk_level(2, 3) == RiskLevelEnum.MEDIUM
        assert RiskDB.calculate_risk_level(3, 3) == RiskLevelEnum.MEDIUM

    def test_low_risk_level(self):
        """Score < 6 should be LOW."""
        assert RiskDB.calculate_risk_level(1, 1) == RiskLevelEnum.LOW
        assert RiskDB.calculate_risk_level(1, 5) == RiskLevelEnum.LOW
        assert RiskDB.calculate_risk_level(2, 2) == RiskLevelEnum.LOW