"""Tests for the Source Policy Router (source_policy_router.py).

Focus: the new filter parameters ``source_type`` (list_sources) and
``category`` (list_pii_rules), plus schema validation and the audit-log
helper.
"""

import uuid
from datetime import datetime
from unittest.mock import MagicMock, patch, call

import pytest

from compliance.api.source_policy_router import (
    SourceCreate,
    SourceUpdate,
    PIIRuleCreate,
    PIIRuleUpdate,
    _log_audit,
)
from compliance.db.source_policy_models import (
    AllowedSourceDB,
    PIIRuleDB,
    SourcePolicyAuditDB,
)


# =============================================================================
# Schema Tests: SourceCreate
# =============================================================================


class TestSourceCreate:
    """Defaults and validation of the ``SourceCreate`` request schema."""

    def test_default_values(self):
        req = SourceCreate(domain="eur-lex.europa.eu", name="EUR-Lex")
        assert req.domain == "eur-lex.europa.eu"
        assert req.name == "EUR-Lex"
        assert req.source_type == "legal"
        assert req.active is True
        assert req.trust_boost == 0.5

    def test_legal_source_type(self):
        req = SourceCreate(domain="gesetze.de", name="Gesetze.de", source_type="legal")
        assert req.source_type == "legal"

    def test_guidance_source_type(self):
        req = SourceCreate(domain="dsb.gv.at", name="DSB Austria", source_type="guidance")
        assert req.source_type == "guidance"

    def test_technical_source_type(self):
        req = SourceCreate(domain="bsi.bund.de", name="BSI", source_type="technical")
        assert req.source_type == "technical"

    def test_trust_boost_range_low(self):
        req = SourceCreate(domain="example.com", name="Test", trust_boost=0.0)
        assert req.trust_boost == 0.0

    def test_trust_boost_range_high(self):
        req = SourceCreate(domain="example.com", name="Test", trust_boost=1.0)
        assert req.trust_boost == 1.0

    def test_trust_boost_invalid_raises(self):
        # Pydantic's ValidationError subclasses ValueError (v1 and v2), so this
        # is much narrower than ``Exception`` (which would also swallow e.g. a
        # TypeError from a typo) without importing pydantic directly.
        with pytest.raises(ValueError):
            SourceCreate(domain="example.com", name="Test", trust_boost=1.5)

    def test_optional_fields_none(self):
        req = SourceCreate(domain="example.com", name="Test")
        assert req.description is None
        assert req.license is None
        assert req.legal_basis is None
        assert req.metadata is None

    def test_full_values(self):
        req = SourceCreate(
            domain="eur-lex.europa.eu",
            name="EUR-Lex",
            description="EU-Rechtsquellen",
            license="CC-BY",
            legal_basis="Art. 5 DSGVO",
            trust_boost=0.9,
            source_type="legal",
            active=True,
            metadata={"region": "EU"},
        )
        assert req.trust_boost == 0.9
        assert req.metadata == {"region": "EU"}


# =============================================================================
# Schema Tests: SourceUpdate
# =============================================================================


class TestSourceUpdate:
    """Partial-update semantics of ``SourceUpdate`` (unset fields are dropped)."""

    def test_partial_update_source_type(self):
        req = SourceUpdate(source_type="guidance")
        data = req.model_dump(exclude_none=True)
        assert data == {"source_type": "guidance"}

    def test_partial_update_active(self):
        req = SourceUpdate(active=False)
        data = req.model_dump(exclude_none=True)
        assert data == {"active": False}

    def test_empty_update(self):
        req = SourceUpdate()
        data = req.model_dump(exclude_none=True)
        assert data == {}

    def test_multi_field_update(self):
        req = SourceUpdate(source_type="technical", trust_boost=0.8, active=True)
        data = req.model_dump(exclude_none=True)
        assert data["source_type"] == "technical"
        assert data["trust_boost"] == 0.8
        assert data["active"] is True


# =============================================================================
# Schema Tests: PIIRuleCreate
# =============================================================================


class TestPIIRuleCreate:
    """Defaults, categories and serialization of ``PIIRuleCreate``."""

    def test_default_values(self):
        req = PIIRuleCreate(name="E-Mail-Erkennung", category="pii")
        assert req.name == "E-Mail-Erkennung"
        assert req.category == "pii"
        assert req.action == "mask"
        assert req.active is True
        assert req.pattern is None

    def test_financial_category(self):
        req = PIIRuleCreate(name="IBAN", category="financial", pattern=r"DE\d{20}")
        assert req.category == "financial"
        assert req.pattern == r"DE\d{20}"

    def test_health_category(self):
        req = PIIRuleCreate(name="Diagnose", category="health")
        assert req.category == "health"

    def test_id_category(self):
        req = PIIRuleCreate(name="Personalausweis", category="id")
        assert req.category == "id"

    def test_action_redact(self):
        req = PIIRuleCreate(name="Test", category="pii", action="redact")
        assert req.action == "redact"

    def test_serialization(self):
        req = PIIRuleCreate(name="Telefon", category="pii", pattern=r"\+49\d+")
        data = req.model_dump()
        assert data["name"] == "Telefon"
        assert data["category"] == "pii"
        assert data["pattern"] == r"\+49\d+"


# =============================================================================
# Schema Tests: PIIRuleUpdate
# =============================================================================


class TestPIIRuleUpdate:
    """Partial-update semantics of ``PIIRuleUpdate`` (unset fields are dropped)."""

    def test_partial_update_category(self):
        req = PIIRuleUpdate(category="financial")
        data = req.model_dump(exclude_none=True)
        assert data == {"category": "financial"}

    def test_partial_update_active(self):
        req = PIIRuleUpdate(active=False)
        data = req.model_dump(exclude_none=True)
        assert data == {"active": False}

    def test_empty_update(self):
        req = PIIRuleUpdate()
        data = req.model_dump(exclude_none=True)
        assert data == {}

    def test_multi_field_update(self):
        req = PIIRuleUpdate(name="Updated", category="id", action="redact")
        data = req.model_dump(exclude_none=True)
        assert data["name"] == "Updated"
        assert data["category"] == "id"
        assert data["action"] == "redact"


# =============================================================================
# DB Model Tests: AllowedSourceDB
# =============================================================================


class TestAllowedSourceDB:
    """ORM mapping details of ``AllowedSourceDB``."""

    def test_default_source_type(self):
        src = AllowedSourceDB(
            id=uuid.uuid4(),
            domain="example.com",
            name="Test Source",
        )
        assert src.domain == "example.com"
        # SQLAlchemy applies a Column(default=...) at INSERT time, not when the
        # object is constructed, so inspect the column definition rather than
        # ``src.source_type`` (which is still None here).
        # NOTE(review): assumes a client-side ``default='legal'`` (not a
        # ``server_default``), as the original comment stated — confirm.
        column = AllowedSourceDB.__table__.columns["source_type"]
        assert column.default is not None
        assert column.default.arg == "legal"

    def test_repr(self):
        src = AllowedSourceDB(domain="bsi.bund.de", name="BSI")
        assert "bsi.bund.de" in repr(src)
        assert "BSI" in repr(src)

    def test_tablename(self):
        assert AllowedSourceDB.__tablename__ == 'compliance_allowed_sources'


# =============================================================================
# DB Model Tests: PIIRuleDB
# =============================================================================


class TestPIIRuleDB:
    """ORM mapping details of ``PIIRuleDB``."""

    def test_tablename(self):
        assert PIIRuleDB.__tablename__ == 'compliance_pii_rules'


# =============================================================================
# Filter Logic Tests (Unit — Mock DB)
# =============================================================================


class TestSourceTypeFilter:
    """Tests that list_sources correctly applies the source_type filter.

    NOTE(review): these tests drive a MagicMock query chain directly and never
    call ``list_sources`` itself, so they cannot detect a regression in the
    router. TODO: invoke the endpoint with a mocked session once its
    signature is pinned down.
    """

    def test_source_type_filter_applied(self):
        """source_type param should be passed to DB query filter."""
        db_mock = MagicMock()
        query_mock = MagicMock()
        db_mock.query.return_value = query_mock
        # Every chained query method returns the same mock, mimicking
        # SQLAlchemy's fluent Query API.
        query_mock.filter.return_value = query_mock
        query_mock.order_by.return_value = query_mock
        query_mock.offset.return_value = query_mock
        query_mock.limit.return_value = query_mock
        query_mock.all.return_value = []

        query = db_mock.query(AllowedSourceDB)
        query.filter(AllowedSourceDB.source_type == "legal")

        query_mock.filter.assert_called_once()

    def test_no_filter_without_source_type(self):
        """Without source_type param, no filter should be applied."""
        db_mock = MagicMock()
        query_mock = MagicMock()
        db_mock.query.return_value = query_mock
        query_mock.order_by.return_value = query_mock
        query_mock.offset.return_value = query_mock
        query_mock.limit.return_value = query_mock
        query_mock.all.return_value = []

        query = db_mock.query(AllowedSourceDB)
        query.order_by(AllowedSourceDB.name)

        # No source_type given → filter() must never be called.
        query_mock.filter.assert_not_called()


class TestCategoryFilter:
    """Tests that list_pii_rules correctly applies the category filter.

    NOTE(review): like TestSourceTypeFilter, the filter test exercises only
    the mock chain, not ``list_pii_rules`` itself. TODO: call the endpoint.
    """

    def test_category_filter_applied(self):
        """category param should be passed to DB query filter."""
        db_mock = MagicMock()
        query_mock = MagicMock()
        db_mock.query.return_value = query_mock
        query_mock.filter.return_value = query_mock
        query_mock.order_by.return_value = query_mock
        query_mock.all.return_value = []

        query = db_mock.query(PIIRuleDB)
        query.filter(PIIRuleDB.category == "financial")

        query_mock.filter.assert_called_once()

    # Parametrized so a failure reports exactly which category was rejected.
    @pytest.mark.parametrize(
        "cat", ["pii", "financial", "health", "id", "location", "other"]
    )
    def test_category_values(self, cat):
        """All valid category values should be accepted by PIIRuleCreate."""
        req = PIIRuleCreate(name=f"Rule {cat}", category=cat)
        assert req.category == cat


# =============================================================================
# Audit Log Helper Tests
# =============================================================================


class TestLogAudit:
    """``_log_audit`` builds a ``SourcePolicyAuditDB`` row and ``db.add``s it."""

    def test_creates_audit_entry(self):
        db_mock = MagicMock()
        entity_id = uuid.uuid4()
        _log_audit(
            db_mock,
            action="create",
            entity_type="source",
            entity_id=entity_id,
            new_values={"name": "Test Source", "domain": "example.com"},
        )
        db_mock.add.assert_called_once()
        audit_obj = db_mock.add.call_args[0][0]
        assert isinstance(audit_obj, SourcePolicyAuditDB)
        assert audit_obj.action == "create"
        assert audit_obj.entity_type == "source"
        assert audit_obj.new_values == {"name": "Test Source", "domain": "example.com"}

    def test_creates_audit_entry_with_old_values(self):
        db_mock = MagicMock()
        entity_id = uuid.uuid4()
        _log_audit(
            db_mock,
            action="update",
            entity_type="source",
            entity_id=entity_id,
            old_values={"name": "Old Name"},
            new_values={"name": "New Name"},
        )
        audit_obj = db_mock.add.call_args[0][0]
        assert audit_obj.action == "update"
        assert audit_obj.old_values == {"name": "Old Name"}
        assert audit_obj.new_values == {"name": "New Name"}

    def test_creates_audit_entry_for_delete(self):
        db_mock = MagicMock()
        entity_id = uuid.uuid4()
        _log_audit(
            db_mock,
            action="delete",
            entity_type="pii_rule",
            entity_id=entity_id,
            old_values={"name": "Deleted Rule"},
        )
        audit_obj = db_mock.add.call_args[0][0]
        assert audit_obj.action == "delete"
        assert audit_obj.entity_type == "pii_rule"
        assert audit_obj.old_values == {"name": "Deleted Rule"}

    def test_add_called_without_commit(self):
        """_log_audit calls db.add() but NOT db.commit() — commit happens at the endpoint level."""
        db_mock = MagicMock()
        _log_audit(db_mock, "create", "source", uuid.uuid4())
        db_mock.add.assert_called_once()
        db_mock.commit.assert_not_called()