"""
Tests for Compliance Repository Layer.

Tests cover:
- RequirementRepository.get_paginated()
- ControlRepository CRUD operations
- EvidenceRepository.create()
- RegulationRepository operations
- Eager loading and relationships
"""

import pytest
from datetime import datetime, timedelta
from unittest.mock import MagicMock

# Test with in-memory SQLite for isolation
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from classroom_engine.database import Base
from compliance.db.models import (
    RegulationDB, RequirementDB, ControlDB, EvidenceDB, ControlMappingDB,
    RegulationTypeEnum, ControlDomainEnum, ControlStatusEnum,
    EvidenceStatusEnum, ControlTypeEnum
)
from compliance.db.repository import (
    RegulationRepository,
    RequirementRepository,
    ControlRepository,
    EvidenceRepository,
    ControlMappingRepository,
)


@pytest.fixture
def db_session():
    """Yield a session bound to a fresh in-memory SQLite database.

    A new engine is created per test and all tables registered on
    ``Base.metadata`` are created in it, so no data is shared between tests.
    On teardown the session is closed and the engine is disposed so its
    connection pool does not outlive the test.
    """
    # check_same_thread=False: the sqlite3 driver otherwise refuses to use a
    # connection from a thread other than the one that created it.
    engine = create_engine(
        "sqlite:///:memory:",
        echo=False,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(engine)
    SessionLocal = sessionmaker(bind=engine)
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
        # Release the pooled connection(s); without this every test leaks an
        # engine and its in-memory database.
        engine.dispose()


@pytest.fixture
def sample_regulation(db_session):
    """Create and return a GDPR regulation via RegulationRepository.create."""
    repo = RegulationRepository(db_session)
    return repo.create(
        code="GDPR",
        name="General Data Protection Regulation",
        regulation_type=RegulationTypeEnum.EU_REGULATION,
        description="EU data protection law",
    )


@pytest.fixture
def sample_control(db_session):
    """Create and return the CRYPTO-001 control via ControlRepository.create."""
    repo = ControlRepository(db_session)
    return repo.create(
        control_id="CRYPTO-001",
        title="TLS Encryption",
        domain=ControlDomainEnum.CRYPTO,
        control_type=ControlTypeEnum.PREVENTIVE,
        pass_criteria="All connections use TLS 1.3",
        description="Enforce TLS 1.3 for all external communication",
    )


# ============================================================================
# RegulationRepository Tests
# ============================================================================


class TestRegulationRepository:
    """Tests for RegulationRepository."""

    def test_create_regulation(self, db_session):
        """Test creating a regulation."""
        repo = RegulationRepository(db_session)
        regulation = repo.create(
            code="GDPR",
            name="General Data Protection Regulation",
            regulation_type=RegulationTypeEnum.EU_REGULATION,
            full_name="Regulation (EU) 2016/679",
            description="EU data protection regulation",
        )

        assert regulation.id is not None
        assert regulation.code == "GDPR"
        assert regulation.name == "General Data Protection Regulation"
        assert regulation.regulation_type == RegulationTypeEnum.EU_REGULATION
        assert regulation.is_active is True

    def test_get_regulation_by_id(self, db_session, sample_regulation):
        """Test getting regulation by ID."""
        repo = RegulationRepository(db_session)
        found = repo.get_by_id(sample_regulation.id)

        assert found is not None
        assert found.id == sample_regulation.id
        assert found.code == "GDPR"

    def test_get_regulation_by_id_not_found(self, db_session):
        """Test getting non-existent regulation."""
        repo = RegulationRepository(db_session)
        found = repo.get_by_id("nonexistent-id")

        assert found is None

    def test_get_regulation_by_code(self, db_session, sample_regulation):
        """Test getting regulation by code."""
        repo = RegulationRepository(db_session)
        found = repo.get_by_code("GDPR")

        assert found is not None
        assert found.code == "GDPR"

    def test_get_all_regulations(self, db_session):
        """Test getting all regulations."""
        repo = RegulationRepository(db_session)
        repo.create(code="GDPR", name="GDPR", regulation_type=RegulationTypeEnum.EU_REGULATION)
        repo.create(code="AI-ACT", name="AI Act", regulation_type=RegulationTypeEnum.EU_REGULATION)
        repo.create(code="BSI-TR", name="BSI", regulation_type=RegulationTypeEnum.BSI_STANDARD)

        all_regs = repo.get_all()

        assert len(all_regs) == 3

    def test_get_regulations_filter_by_type(self, db_session):
        """Test filtering regulations by type."""
        repo = RegulationRepository(db_session)
        repo.create(code="GDPR", name="GDPR", regulation_type=RegulationTypeEnum.EU_REGULATION)
        repo.create(code="BSI-TR", name="BSI", regulation_type=RegulationTypeEnum.BSI_STANDARD)

        eu_regs = repo.get_all(regulation_type=RegulationTypeEnum.EU_REGULATION)

        assert len(eu_regs) == 1
        assert eu_regs[0].code == "GDPR"

    def test_get_regulations_filter_by_active(self, db_session):
        """Test filtering regulations by active status."""
        repo = RegulationRepository(db_session)
        repo.create(code="ACTIVE", name="Active", regulation_type=RegulationTypeEnum.EU_REGULATION)
        inactive = repo.create(code="INACTIVE", name="Inactive", regulation_type=RegulationTypeEnum.EU_REGULATION)
        repo.update(inactive.id, is_active=False)

        active_regs = repo.get_all(is_active=True)

        assert len(active_regs) == 1
        assert active_regs[0].code == "ACTIVE"

    def test_update_regulation(self, db_session, sample_regulation):
        """Test updating a regulation."""
        repo = RegulationRepository(db_session)
        updated = repo.update(
            sample_regulation.id,
            name="Updated Name",
            is_active=False,
        )

        assert updated is not None
        assert updated.name == "Updated Name"
        assert updated.is_active is False

    def test_delete_regulation(self, db_session, sample_regulation):
        """Test deleting a regulation."""
        repo = RegulationRepository(db_session)
        # Capture the primary key up front so the follow-up lookup does not
        # read attributes from an ORM instance that has just been deleted.
        regulation_id = sample_regulation.id

        result = repo.delete(regulation_id)

        assert result is True
        assert repo.get_by_id(regulation_id) is None

    def test_delete_nonexistent_regulation(self, db_session):
        """Test deleting non-existent regulation."""
        repo = RegulationRepository(db_session)
        result = repo.delete("nonexistent-id")

        assert result is False

    def test_get_active_regulations(self, db_session):
        """Test getting only active regulations."""
        repo = RegulationRepository(db_session)
        repo.create(code="ACTIVE1", name="Active 1", regulation_type=RegulationTypeEnum.EU_REGULATION)
        repo.create(code="ACTIVE2", name="Active 2", regulation_type=RegulationTypeEnum.EU_REGULATION)
        inactive = repo.create(code="INACTIVE", name="Inactive", regulation_type=RegulationTypeEnum.EU_REGULATION)
        repo.update(inactive.id, is_active=False)

        active_regs = repo.get_active()

        assert len(active_regs) == 2

    def test_count_regulations(self, db_session):
        """Test counting regulations."""
        repo = RegulationRepository(db_session)
        repo.create(code="REG1", name="Reg 1", regulation_type=RegulationTypeEnum.EU_REGULATION)
        repo.create(code="REG2", name="Reg 2", regulation_type=RegulationTypeEnum.EU_REGULATION)

        count = repo.count()

        assert count == 2


# ============================================================================
# RequirementRepository Tests
# ============================================================================


class TestRequirementRepository:
    """Tests for RequirementRepository."""

    def test_create_requirement(self, db_session, sample_regulation):
        """Test creating a requirement."""
        repo = RequirementRepository(db_session)
        requirement = repo.create(
            regulation_id=sample_regulation.id,
            article="Art. 32",
            title="Security of processing",
            description="Implement appropriate technical measures",
            requirement_text="The controller shall implement appropriate technical and organizational measures...",
            is_applicable=True,
            priority=1,
        )

        assert requirement.id is not None
        assert requirement.article == "Art. 32"
        assert requirement.title == "Security of processing"
        assert requirement.is_applicable is True

    def test_get_requirement_by_id(self, db_session, sample_regulation):
        """Test getting requirement by ID."""
        repo = RequirementRepository(db_session)
        created = repo.create(
            regulation_id=sample_regulation.id,
            article="Art. 32",
            title="Security",
            is_applicable=True,
        )

        found = repo.get_by_id(created.id)

        assert found is not None
        assert found.id == created.id

    def test_get_requirements_by_regulation(self, db_session, sample_regulation):
        """Test getting requirements by regulation."""
        repo = RequirementRepository(db_session)
        repo.create(regulation_id=sample_regulation.id, article="Art. 1", title="Req 1", is_applicable=True)
        repo.create(regulation_id=sample_regulation.id, article="Art. 2", title="Req 2", is_applicable=True)

        requirements = repo.get_by_regulation(sample_regulation.id)

        assert len(requirements) == 2

    def test_get_requirements_filter_by_applicable(self, db_session, sample_regulation):
        """Test filtering requirements by applicability."""
        repo = RequirementRepository(db_session)
        repo.create(regulation_id=sample_regulation.id, article="Art. 1", title="Applicable", is_applicable=True)
        repo.create(regulation_id=sample_regulation.id, article="Art. 2", title="Not Applicable", is_applicable=False)

        applicable = repo.get_by_regulation(sample_regulation.id, is_applicable=True)

        assert len(applicable) == 1
        assert applicable[0].title == "Applicable"

    def test_get_requirements_paginated_basic(self, db_session, sample_regulation):
        """Test basic pagination of requirements.

        Besides page sizes and the total count, this verifies that two
        consecutive pages do not overlap and together contain every row.
        """
        repo = RequirementRepository(db_session)

        # Create 10 requirements
        for i in range(10):
            repo.create(
                regulation_id=sample_regulation.id,
                article=f"Art. {i}",
                title=f"Requirement {i}",
                is_applicable=True,
            )

        # Get first page
        first_page, total = repo.get_paginated(page=1, page_size=5)
        assert len(first_page) == 5
        assert total == 10

        # Get second page
        second_page, total = repo.get_paginated(page=2, page_size=5)
        assert len(second_page) == 5
        assert total == 10

        # A size-only check would still pass if the offset were ignored and
        # the same five rows came back twice.
        first_ids = {item.id for item in first_page}
        second_ids = {item.id for item in second_page}
        assert first_ids.isdisjoint(second_ids)
        assert len(first_ids | second_ids) == 10

    def test_get_requirements_paginated_filter_by_regulation(self, db_session):
        """Test pagination with regulation filter."""
        repo_reg = RegulationRepository(db_session)
        repo_req = RequirementRepository(db_session)

        gdpr = repo_reg.create(code="GDPR", name="GDPR", regulation_type=RegulationTypeEnum.EU_REGULATION)
        bsi = repo_reg.create(code="BSI", name="BSI", regulation_type=RegulationTypeEnum.BSI_STANDARD)

        repo_req.create(regulation_id=gdpr.id, article="Art. 1", title="GDPR Req")
        repo_req.create(regulation_id=bsi.id, article="T.1", title="BSI Req")

        # Filter by GDPR
        items, total = repo_req.get_paginated(regulation_code="GDPR")

        assert total == 1
        assert len(items) == 1
        assert items[0].title == "GDPR Req"

    def test_get_requirements_paginated_filter_by_status(self, db_session, sample_regulation):
        """Test pagination with status filter."""
        repo = RequirementRepository(db_session)

        # Create requirements, then set their statuses on the model directly
        req1 = repo.create(regulation_id=sample_regulation.id, article="Art. 1", title="Implemented")
        req2 = repo.create(regulation_id=sample_regulation.id, article="Art. 2", title="Planned")

        req1.implementation_status = "implemented"
        req2.implementation_status = "planned"
        db_session.commit()

        # Filter by implemented
        items, total = repo.get_paginated(status="implemented")

        assert total == 1
        assert len(items) == 1
        assert items[0].title == "Implemented"

    def test_get_requirements_paginated_search(self, db_session, sample_regulation):
        """Test pagination with search."""
        repo = RequirementRepository(db_session)
        repo.create(regulation_id=sample_regulation.id, article="Art. 1", title="Security of processing")
        repo.create(regulation_id=sample_regulation.id, article="Art. 2", title="Data minimization")

        # Search for "security" (lowercase, while the stored title is capitalized)
        items, total = repo.get_paginated(search="security")

        assert total == 1
        assert len(items) == 1
        assert "security" in items[0].title.lower()

    def test_update_requirement(self, db_session, sample_regulation):
        """Test updating a requirement."""
        repo = RequirementRepository(db_session)
        requirement = repo.create(
            regulation_id=sample_regulation.id,
            article="Art. 32",
            title="Original",
            is_applicable=True,
        )

        # Update via model directly (RequirementRepository doesn't have update method)
        requirement.title = "Updated Title"
        requirement.implementation_status = "implemented"
        db_session.commit()
        db_session.refresh(requirement)

        assert requirement.title == "Updated Title"
        assert requirement.implementation_status == "implemented"


# ============================================================================
# ControlRepository Tests
# ============================================================================


class TestControlRepository:
    """Tests for ControlRepository CRUD operations."""

    def test_create_control(self, db_session):
        """Test creating a control."""
        repo = ControlRepository(db_session)
        control = repo.create(
            control_id="CRYPTO-001",
            title="TLS 1.3 Encryption",
            domain=ControlDomainEnum.CRYPTO,
            control_type=ControlTypeEnum.PREVENTIVE,
            pass_criteria="All external communication uses TLS 1.3",
            description="Enforce TLS 1.3 for all connections",
            is_automated=True,
            automation_tool="NGINX",
        )

        assert control.id is not None
        assert control.control_id == "CRYPTO-001"
        assert control.domain == ControlDomainEnum.CRYPTO
        assert control.is_automated is True

    def test_get_control_by_id(self, db_session, sample_control):
        """Test getting control by UUID."""
        repo = ControlRepository(db_session)
        found = repo.get_by_id(sample_control.id)

        assert found is not None
        assert found.id == sample_control.id

    def test_get_control_by_control_id(self, db_session, sample_control):
        """Test getting control by control_id."""
        repo = ControlRepository(db_session)
        found = repo.get_by_control_id("CRYPTO-001")

        assert found is not None
        assert found.control_id == "CRYPTO-001"

    def test_get_all_controls(self, db_session):
        """Test getting all controls."""
        repo = ControlRepository(db_session)
        repo.create(control_id="CRYPTO-001", title="Crypto", domain=ControlDomainEnum.CRYPTO,
                    control_type=ControlTypeEnum.PREVENTIVE, pass_criteria="Pass")
        repo.create(control_id="IAM-001", title="IAM", domain=ControlDomainEnum.IAM,
                    control_type=ControlTypeEnum.PREVENTIVE, pass_criteria="Pass")

        all_controls = repo.get_all()

        assert len(all_controls) == 2

    def test_get_controls_filter_by_domain(self, db_session):
        """Test filtering controls by domain."""
        repo = ControlRepository(db_session)
        repo.create(control_id="CRYPTO-001", title="Crypto", domain=ControlDomainEnum.CRYPTO,
                    control_type=ControlTypeEnum.PREVENTIVE, pass_criteria="Pass")
        repo.create(control_id="IAM-001", title="IAM", domain=ControlDomainEnum.IAM,
                    control_type=ControlTypeEnum.PREVENTIVE, pass_criteria="Pass")

        crypto_controls = repo.get_all(domain=ControlDomainEnum.CRYPTO)

        assert len(crypto_controls) == 1
        assert crypto_controls[0].control_id == "CRYPTO-001"

    def test_get_controls_filter_by_status(self, db_session):
        """Test filtering controls by status."""
        repo = ControlRepository(db_session)
        repo.create(control_id="PASS-001", title="Pass", domain=ControlDomainEnum.CRYPTO,
                    control_type=ControlTypeEnum.PREVENTIVE, pass_criteria="Pass")
        repo.create(control_id="FAIL-001", title="Fail", domain=ControlDomainEnum.CRYPTO,
                    control_type=ControlTypeEnum.PREVENTIVE, pass_criteria="Pass")

        # Use update_status method with control_id (not UUID)
        repo.update_status("PASS-001", ControlStatusEnum.PASS)
        repo.update_status("FAIL-001", ControlStatusEnum.FAIL)

        passing_controls = repo.get_all(status=ControlStatusEnum.PASS)

        assert len(passing_controls) == 1
        assert passing_controls[0].control_id == "PASS-001"

    def test_get_controls_filter_by_automated(self, db_session):
        """Test filtering controls by automation."""
        repo = ControlRepository(db_session)
        repo.create(control_id="AUTO-001", title="Automated", domain=ControlDomainEnum.CRYPTO,
                    control_type=ControlTypeEnum.PREVENTIVE, pass_criteria="Pass", is_automated=True)
        repo.create(control_id="MANUAL-001", title="Manual", domain=ControlDomainEnum.CRYPTO,
                    control_type=ControlTypeEnum.PREVENTIVE, pass_criteria="Pass", is_automated=False)

        automated = repo.get_all(is_automated=True)

        assert len(automated) == 1
        assert automated[0].control_id == "AUTO-001"

    def test_update_control(self, db_session, sample_control):
        """Test updating a control status."""
        repo = ControlRepository(db_session)
        updated = repo.update_status(
            sample_control.control_id,
            ControlStatusEnum.PASS,
            status_notes="Implemented via NGINX config",
        )

        assert updated is not None
        assert updated.status == ControlStatusEnum.PASS
        assert updated.status_notes == "Implemented via NGINX config"

    def test_delete_control(self, db_session, sample_control):
        """Test deleting a control (via model)."""
        repo = ControlRepository(db_session)
        # Capture the primary key before deleting so the lookup below does not
        # read attributes from the deleted ORM instance.
        control_uuid = sample_control.id

        # Delete via database directly (ControlRepository doesn't have delete method)
        db_session.delete(sample_control)
        db_session.commit()

        assert repo.get_by_id(control_uuid) is None

    def test_get_statistics(self, db_session):
        """Test getting control statistics."""
        repo = ControlRepository(db_session)

        # Create controls, then give them different statuses
        for control_id, title in (
            ("PASS-1", "Pass 1"),
            ("PASS-2", "Pass 2"),
            ("PARTIAL-1", "Partial"),
            ("FAIL-1", "Fail"),
        ):
            repo.create(control_id=control_id, title=title, domain=ControlDomainEnum.CRYPTO,
                        control_type=ControlTypeEnum.PREVENTIVE, pass_criteria="Pass")

        repo.update_status("PASS-1", ControlStatusEnum.PASS)
        repo.update_status("PASS-2", ControlStatusEnum.PASS)
        repo.update_status("PARTIAL-1", ControlStatusEnum.PARTIAL)
        repo.update_status("FAIL-1", ControlStatusEnum.FAIL)

        stats = repo.get_statistics()

        assert stats["total"] == 4
        # by_status is looked up with the lowercase status strings below;
        # .get(..., 0) makes a missing key fail the count assertion rather
        # than raise KeyError.
        by_status = stats["by_status"]
        assert by_status.get("pass", 0) == 2
        assert by_status.get("partial", 0) == 1
        assert by_status.get("fail", 0) == 1

        # Score = (2 pass + 0.5 * 1 partial) / 4 = 62.5%
        expected_score = ((2 + 0.5) / 4) * 100
        assert stats["compliance_score"] == round(expected_score, 1)


# ============================================================================
# EvidenceRepository Tests
# ============================================================================


class TestEvidenceRepository:
    """Tests for EvidenceRepository.create()."""

    def test_create_evidence(self, db_session, sample_control):
        """Test creating evidence."""
        repo = EvidenceRepository(db_session)
        evidence = repo.create(
            control_id=sample_control.control_id,
            evidence_type="report",
            title="SAST Report",
            description="Semgrep scan results",
            artifact_path="/path/to/report.json",
            artifact_hash="abc123",
            source="ci_pipeline",
            ci_job_id="job-123",
        )

        assert evidence.id is not None
        assert evidence.title == "SAST Report"
        assert evidence.source == "ci_pipeline"
        assert evidence.ci_job_id == "job-123"

    def test_create_evidence_control_not_found(self, db_session):
        """Test creating evidence for non-existent control raises error."""
        repo = EvidenceRepository(db_session)

        with pytest.raises(ValueError) as excinfo:
            repo.create(
                control_id="NONEXISTENT-001",
                evidence_type="report",
                title="Test",
            )

        assert "not found" in str(excinfo.value).lower()

    def test_get_evidence_by_id(self, db_session, sample_control):
        """Test getting evidence by ID."""
        repo = EvidenceRepository(db_session)
        created = repo.create(
            control_id=sample_control.control_id,
            evidence_type="report",
            title="Test Evidence",
        )

        found = repo.get_by_id(created.id)

        assert found is not None
        assert found.id == created.id

    def test_get_evidence_by_control(self, db_session, sample_control):
        """Test getting evidence by control."""
        repo = EvidenceRepository(db_session)
        repo.create(control_id=sample_control.control_id, evidence_type="report", title="Evidence 1")
        repo.create(control_id=sample_control.control_id, evidence_type="report", title="Evidence 2")

        evidence_list = repo.get_by_control(sample_control.control_id)

        assert len(evidence_list) == 2

    def test_get_evidence_filter_by_status(self, db_session, sample_control):
        """Test filtering evidence by status."""
        repo = EvidenceRepository(db_session)
        valid = repo.create(control_id=sample_control.control_id, evidence_type="report", title="Valid")
        expired = repo.create(control_id=sample_control.control_id, evidence_type="report", title="Expired")

        # Evidence status is updated by evidence primary key (not control_id)
        repo.update_status(valid.id, EvidenceStatusEnum.VALID)
        repo.update_status(expired.id, EvidenceStatusEnum.EXPIRED)

        valid_evidence = repo.get_by_control(sample_control.control_id, status=EvidenceStatusEnum.VALID)

        assert len(valid_evidence) == 1
        assert valid_evidence[0].title == "Valid"

    def test_create_evidence_with_ci_metadata(self, db_session, sample_control):
        """Test creating evidence with CI/CD metadata."""
        repo = EvidenceRepository(db_session)
        evidence = repo.create(
            control_id=sample_control.control_id,
            evidence_type="sast_report",
            title="Semgrep Scan",
            description="Static analysis results",
            source="ci_pipeline",
            ci_job_id="github-actions-123",
            artifact_hash="sha256:abc123",
            mime_type="application/json",
        )

        assert evidence.source == "ci_pipeline"
        assert evidence.ci_job_id == "github-actions-123"
        assert evidence.mime_type == "application/json"


# ============================================================================
# ControlMappingRepository Tests
# ============================================================================


class TestControlMappingRepository:
    """Tests for requirement-control mappings."""

    def test_create_mapping(self, db_session, sample_regulation, sample_control):
        """Test creating a requirement-control mapping."""
        req_repo = RequirementRepository(db_session)
        mapping_repo = ControlMappingRepository(db_session)

        requirement = req_repo.create(
            regulation_id=sample_regulation.id,
            article="Art. 32",
            title="Security",
            is_applicable=True,
        )

        mapping = mapping_repo.create(
            requirement_id=requirement.id,
            control_id=sample_control.control_id,
            coverage_level="full",
            notes="Fully covered by TLS encryption",
        )

        assert mapping.id is not None
        assert mapping.requirement_id == requirement.id
        assert mapping.coverage_level == "full"

    def test_create_mapping_control_not_found(self, db_session, sample_regulation):
        """Test creating mapping with non-existent control raises error."""
        req_repo = RequirementRepository(db_session)
        mapping_repo = ControlMappingRepository(db_session)

        requirement = req_repo.create(
            regulation_id=sample_regulation.id,
            article="Art. 32",
            title="Security",
            is_applicable=True,
        )

        with pytest.raises(ValueError) as excinfo:
            mapping_repo.create(
                requirement_id=requirement.id,
                control_id="NONEXISTENT-001",
            )

        assert "not found" in str(excinfo.value).lower()

    def test_get_mappings_by_requirement(self, db_session, sample_regulation, sample_control):
        """Test getting mappings by requirement."""
        req_repo = RequirementRepository(db_session)
        mapping_repo = ControlMappingRepository(db_session)

        requirement = req_repo.create(
            regulation_id=sample_regulation.id,
            article="Art. 32",
            title="Security",
            is_applicable=True,
        )

        mapping_repo.create(requirement_id=requirement.id, control_id=sample_control.control_id)

        mappings = mapping_repo.get_by_requirement(requirement.id)

        assert len(mappings) == 1

    def test_get_mappings_by_control(self, db_session, sample_regulation, sample_control):
        """Test getting mappings by control."""
        req_repo = RequirementRepository(db_session)
        mapping_repo = ControlMappingRepository(db_session)

        req1 = req_repo.create(regulation_id=sample_regulation.id, article="Art. 1", title="Req 1", is_applicable=True)
        req2 = req_repo.create(regulation_id=sample_regulation.id, article="Art. 2", title="Req 2", is_applicable=True)

        # Mappings are created with the human-readable control_id ...
        mapping_repo.create(requirement_id=req1.id, control_id=sample_control.control_id)
        mapping_repo.create(requirement_id=req2.id, control_id=sample_control.control_id)

        # ... but queried with the control's UUID primary key.
        mappings = mapping_repo.get_by_control(sample_control.id)

        assert len(mappings) == 2


if __name__ == "__main__":
    pytest.main([__file__, "-v"])