# mypy: disable-error-code="arg-type,assignment,union-attr"
"""
Source Policy service — allowed sources, operations matrix, PII rules,
blocked content, audit, stats, compliance report.

Phase 1 Step 4: extracted from ``compliance.api.source_policy_router``.
"""

from datetime import datetime, timezone
from typing import Any, Optional

from sqlalchemy.orm import Session

from compliance.db.source_policy_models import (
    AllowedSourceDB,
    BlockedContentDB,
    PIIRuleDB,
    SourceOperationDB,
    SourcePolicyAuditDB,
)
from compliance.domain import ConflictError, NotFoundError
from compliance.schemas.source_policy import (
    OperationUpdate,
    PIIRuleCreate,
    PIIRuleUpdate,
    SourceCreate,
    SourceUpdate,
)

# ============================================================================
# Module-level helpers (re-exported by compliance.api.source_policy_router for
# legacy test imports).
# ============================================================================


def _log_audit(
    db: Session,
    action: str,
    entity_type: str,
    entity_id: Any,
    old_values: Optional[dict[str, Any]] = None,
    new_values: Optional[dict[str, Any]] = None,
) -> None:
    """Stage a ``SourcePolicyAuditDB`` row on *db*.

    The row is only added to the session; nothing is flushed or committed
    here, so the caller decides when the entry is persisted. ``user_id`` is
    always recorded as ``"system"``.

    Args:
        db: Session the audit row is added to.
        action: Name of the performed action (e.g. ``"create"``).
        entity_type: Kind of entity the action targeted (e.g. ``"source"``).
        entity_id: Identifier of the affected entity.
        old_values: Optional snapshot of the entity before the change.
        new_values: Optional snapshot of the values after the change.
    """
    db.add(
        SourcePolicyAuditDB(
            action=action,
            entity_type=entity_type,
            entity_id=entity_id,
            old_values=old_values,
            new_values=new_values,
            user_id="system",
        )
    )


def _source_to_dict(s: AllowedSourceDB) -> dict[str, Any]:
    """Return the core fields of an allowed source as a plain dict.

    The ``id`` is stringified; all other attributes are copied as-is.
    """
    return {
        "id": str(s.id),
        "domain": s.domain,
        "name": s.name,
        "description": s.description,
        "license": s.license,
        "legal_basis": s.legal_basis,
        "trust_boost": s.trust_boost,
        "source_type": s.source_type,
        "active": s.active,
    }


def _full_source_dict(s: AllowedSourceDB) -> dict[str, Any]:
    """Return ``_source_to_dict(s)`` extended with metadata and timestamps.

    ``created_at`` / ``updated_at`` are rendered with ``isoformat()`` and are
    ``None`` when the attribute is unset.
    """
    return {
        **_source_to_dict(s),
        "metadata": s.metadata_,
        "created_at": s.created_at.isoformat() if s.created_at else None,
        "updated_at": s.updated_at.isoformat() if s.updated_at else None,
    }


def _parse_iso_optional(value: Optional[str]) -> Optional[datetime]:
    """Parse an ISO-8601 string into a ``datetime``, or return ``None``.

    ``None`` is returned for empty input and for text that cannot be parsed;
    no exception is raised.

    A trailing ``Z``/``z`` (UTC designator) is handled explicitly because
    ``datetime.fromisoformat`` only accepts it from Python 3.11 onwards. On
    older interpreters such a value would otherwise fail to parse and come
    back as ``None``.

    Args:
        value: ISO-8601 date/datetime text, or ``None``.

    Returns:
        The parsed ``datetime`` (timezone-aware UTC when the input carried a
        ``Z`` suffix), or ``None`` when *value* is empty or invalid.
    """
    if not value:
        return None

    is_zulu = value[-1] in "Zz"
    text = value[:-1] if is_zulu else value
    try:
        parsed = datetime.fromisoformat(text)
    except ValueError:
        return None

    if not is_zulu:
        return parsed
    if parsed.tzinfo is not None:
        # An explicit offset followed by "Z" is contradictory; treat as invalid.
        return None
    return parsed.replace(tzinfo=timezone.utc)
# ============================================================================
# Service
# ============================================================================


class SourcePolicyService:
    """Business logic for the source policy admin surface.

    All methods operate on the SQLAlchemy session passed to the constructor.
    Every mutating method stages an audit row via ``_log_audit`` and commits
    the session before returning.
    """

    def __init__(self, db: Session) -> None:
        """Bind the service to *db*, the session used for every query."""
        self.db = db

    # ------------------------------------------------------------------
    # Sources CRUD
    # ------------------------------------------------------------------

    def list_sources(
        self,
        active_only: bool,
        source_type: Optional[str],
        license: Optional[str],
    ) -> dict[str, Any]:
        """List allowed sources ordered by name.

        Args:
            active_only: When true, restrict to rows whose ``active`` flag is set.
            source_type: Optional exact-match filter on ``source_type``.
            license: Optional exact-match filter on ``license``.

        Returns:
            ``{"sources": [...], "count": n}`` where each source is the full
            dict produced by ``_full_source_dict``.
        """
        q = self.db.query(AllowedSourceDB)
        if active_only:
            q = q.filter(AllowedSourceDB.active)
        if source_type:
            q = q.filter(AllowedSourceDB.source_type == source_type)
        if license:
            q = q.filter(AllowedSourceDB.license == license)
        sources = q.order_by(AllowedSourceDB.name).all()
        return {
            "sources": [_full_source_dict(s) for s in sources],
            "count": len(sources),
        }

    def create_source(self, data: SourceCreate) -> dict[str, Any]:
        """Create a new allowed source and audit the creation.

        Args:
            data: Validated creation payload.

        Returns:
            Dict with the new source's ``id``, ``domain``, ``name`` and
            ``created_at`` (ISO string, or ``None`` if unset).

        Raises:
            ConflictError: A source with the same ``domain`` already exists.
        """
        existing = (
            self.db.query(AllowedSourceDB)
            .filter(AllowedSourceDB.domain == data.domain)
            .first()
        )
        if existing:
            raise ConflictError(
                f"Source with domain '{data.domain}' already exists"
            )

        source = AllowedSourceDB(
            domain=data.domain,
            name=data.name,
            description=data.description,
            license=data.license,
            legal_basis=data.legal_basis,
            trust_boost=data.trust_boost,
            source_type=data.source_type,
            active=data.active,
            metadata_=data.metadata,
        )
        self.db.add(source)
        # Flush first so INSERT-time column defaults are populated before the
        # audit row captures ``source.id``; otherwise the audit entry could
        # record a ``None`` entity id.
        # NOTE(review): assumes the primary key is assigned by a column
        # default at INSERT time — the model is not visible here; confirm.
        self.db.flush()
        _log_audit(
            self.db,
            "create",
            "source",
            source.id,
            new_values=_source_to_dict(source),
        )
        self.db.commit()
        self.db.refresh(source)
        return {
            "id": str(source.id),
            "domain": source.domain,
            "name": source.name,
            # Guarded like the other serializers in this module.
            "created_at": (
                source.created_at.isoformat() if source.created_at else None
            ),
        }

    def _source_or_raise(self, source_id: str) -> AllowedSourceDB:
        """Fetch the source with *source_id* or raise ``NotFoundError``."""
        source = (
            self.db.query(AllowedSourceDB)
            .filter(AllowedSourceDB.id == source_id)
            .first()
        )
        if not source:
            raise NotFoundError("Source not found")
        return source

    def get_source(self, source_id: str) -> dict[str, Any]:
        """Return the full dict for one source.

        Raises:
            NotFoundError: No source has the given id.
        """
        return _full_source_dict(self._source_or_raise(source_id))

    def update_source(self, source_id: str, data: SourceUpdate) -> dict[str, Any]:
        """Apply a partial update to a source and audit old/new values.

        Only fields explicitly set on *data* are applied. The payload's
        ``metadata`` key is written to the model's ``metadata_`` attribute.

        Raises:
            NotFoundError: No source has the given id.
        """
        source = self._source_or_raise(source_id)
        old_values = _source_to_dict(source)
        update_data = data.model_dump(exclude_unset=True)
        if "metadata" in update_data:
            update_data["metadata_"] = update_data.pop("metadata")
        for key, value in update_data.items():
            setattr(source, key, value)
        _log_audit(
            self.db,
            "update",
            "source",
            source.id,
            old_values=old_values,
            new_values=update_data,
        )
        self.db.commit()
        self.db.refresh(source)
        return {"status": "updated", "id": str(source.id)}

    def delete_source(self, source_id: str) -> dict[str, Any]:
        """Delete a source together with its operation rows.

        The audit row (holding the pre-delete snapshot) is staged first, then
        the source's ``SourceOperationDB`` rows are bulk-deleted, then the
        source itself.

        Raises:
            NotFoundError: No source has the given id.
        """
        source = self._source_or_raise(source_id)
        old_values = _source_to_dict(source)
        _log_audit(self.db, "delete", "source", source.id, old_values=old_values)
        self.db.query(SourceOperationDB).filter(
            SourceOperationDB.source_id == source_id
        ).delete()
        self.db.delete(source)
        self.db.commit()
        return {"status": "deleted", "id": source_id}

    # ------------------------------------------------------------------
    # Operations matrix
    # ------------------------------------------------------------------

    def get_operations_matrix(self) -> dict[str, Any]:
        """Return every source operation row as ``{"operations", "count"}``."""
        operations = self.db.query(SourceOperationDB).all()
        return {
            "operations": [
                {
                    "id": str(op.id),
                    "source_id": str(op.source_id),
                    "operation": op.operation,
                    "allowed": op.allowed,
                    "conditions": op.conditions,
                }
                for op in operations
            ],
            "count": len(operations),
        }

    def update_operation(
        self, operation_id: str, data: OperationUpdate
    ) -> dict[str, Any]:
        """Update an operation's ``allowed`` flag and, optionally, conditions.

        ``conditions`` is only overwritten when the payload supplies a
        non-``None`` value. The audit row records the new ``allowed`` flag.

        Raises:
            NotFoundError: No operation has the given id.
        """
        op = (
            self.db.query(SourceOperationDB)
            .filter(SourceOperationDB.id == operation_id)
            .first()
        )
        if not op:
            raise NotFoundError("Operation not found")

        op.allowed = data.allowed
        if data.conditions is not None:
            op.conditions = data.conditions
        _log_audit(
            self.db,
            "update",
            "operation",
            op.id,
            new_values={"allowed": data.allowed},
        )
        self.db.commit()
        return {"status": "updated", "id": str(op.id)}

    # ------------------------------------------------------------------
    # PII rules
    # ------------------------------------------------------------------

    def list_pii_rules(self, category: Optional[str]) -> dict[str, Any]:
        """List PII rules ordered by category then name.

        Args:
            category: Optional exact-match filter on ``category``.

        Returns:
            ``{"rules": [...], "count": n}``.
        """
        q = self.db.query(PIIRuleDB)
        if category:
            q = q.filter(PIIRuleDB.category == category)
        rules = q.order_by(PIIRuleDB.category, PIIRuleDB.name).all()
        return {
            "rules": [
                {
                    "id": str(r.id),
                    "name": r.name,
                    "description": r.description,
                    "pattern": r.pattern,
                    "category": r.category,
                    "action": r.action,
                    "active": r.active,
                    "created_at": (
                        r.created_at.isoformat() if r.created_at else None
                    ),
                }
                for r in rules
            ],
            "count": len(rules),
        }

    def create_pii_rule(self, data: PIIRuleCreate) -> dict[str, Any]:
        """Create a PII rule, audit it, and return its ``id`` and ``name``."""
        rule = PIIRuleDB(
            name=data.name,
            description=data.description,
            pattern=data.pattern,
            category=data.category,
            action=data.action,
            active=data.active,
        )
        self.db.add(rule)
        # Flush so ``rule.id`` is populated before the audit row references it.
        # NOTE(review): assumes the primary key is assigned by a column
        # default at INSERT time — the model is not visible here; confirm.
        self.db.flush()
        _log_audit(
            self.db,
            "create",
            "pii_rule",
            rule.id,
            new_values={"name": data.name, "category": data.category},
        )
        self.db.commit()
        self.db.refresh(rule)
        return {"id": str(rule.id), "name": rule.name}

    def _rule_or_raise(self, rule_id: str) -> PIIRuleDB:
        """Fetch the PII rule with *rule_id* or raise ``NotFoundError``."""
        rule = self.db.query(PIIRuleDB).filter(PIIRuleDB.id == rule_id).first()
        if not rule:
            raise NotFoundError("PII rule not found")
        return rule

    def update_pii_rule(self, rule_id: str, data: PIIRuleUpdate) -> dict[str, Any]:
        """Apply the explicitly-set fields of *data* to a PII rule.

        Raises:
            NotFoundError: No rule has the given id.
        """
        rule = self._rule_or_raise(rule_id)
        update_data = data.model_dump(exclude_unset=True)
        for key, value in update_data.items():
            setattr(rule, key, value)
        _log_audit(self.db, "update", "pii_rule", rule.id, new_values=update_data)
        self.db.commit()
        return {"status": "updated", "id": str(rule.id)}

    def delete_pii_rule(self, rule_id: str) -> dict[str, Any]:
        """Delete a PII rule, auditing its name and category beforehand.

        Raises:
            NotFoundError: No rule has the given id.
        """
        rule = self._rule_or_raise(rule_id)
        _log_audit(
            self.db,
            "delete",
            "pii_rule",
            rule.id,
            old_values={"name": rule.name, "category": rule.category},
        )
        self.db.delete(rule)
        self.db.commit()
        return {"status": "deleted", "id": rule_id}

    # ------------------------------------------------------------------
    # Blocked content + audit
    # ------------------------------------------------------------------

    def list_blocked_content(
        self,
        limit: int,
        offset: int,
        domain: Optional[str],
        date_from: Optional[str],
        date_to: Optional[str],
    ) -> dict[str, Any]:
        """Page through blocked-content entries, newest first.

        Args:
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.
            domain: Optional exact-match filter on ``domain``.
            date_from: Optional ISO lower bound (inclusive) on ``created_at``.
            date_to: Optional ISO upper bound (inclusive) on ``created_at``.

        Date bounds that ``_parse_iso_optional`` cannot parse are ignored.

        Returns:
            ``{"blocked": [...], "total": n}`` where ``total`` is the count of
            all rows matching the filters, before pagination.
        """
        q = self.db.query(BlockedContentDB)
        if domain:
            q = q.filter(BlockedContentDB.domain == domain)
        from_dt = _parse_iso_optional(date_from)
        if from_dt:
            q = q.filter(BlockedContentDB.created_at >= from_dt)
        to_dt = _parse_iso_optional(date_to)
        if to_dt:
            q = q.filter(BlockedContentDB.created_at <= to_dt)

        total = q.count()
        entries = (
            q.order_by(BlockedContentDB.created_at.desc())
            .offset(offset)
            .limit(limit)
            .all()
        )
        return {
            "blocked": [
                {
                    "id": str(e.id),
                    "url": e.url,
                    "domain": e.domain,
                    "block_reason": e.block_reason,
                    "rule_id": str(e.rule_id) if e.rule_id else None,
                    "details": e.details,
                    "created_at": (
                        e.created_at.isoformat() if e.created_at else None
                    ),
                }
                for e in entries
            ],
            "total": total,
        }

    def get_audit(
        self,
        limit: int,
        offset: int,
        entity_type: Optional[str],
        date_from: Optional[str],
        date_to: Optional[str],
    ) -> dict[str, Any]:
        """Page through audit entries, newest first.

        Args:
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.
            entity_type: Optional exact-match filter on ``entity_type``.
            date_from: Optional ISO lower bound (inclusive) on ``created_at``.
            date_to: Optional ISO upper bound (inclusive) on ``created_at``.

        Date bounds that ``_parse_iso_optional`` cannot parse are ignored.

        Returns:
            ``{"entries": [...], "total": n, "limit": limit, "offset": offset}``
            where ``total`` counts all matching rows before pagination.
        """
        q = self.db.query(SourcePolicyAuditDB)
        if entity_type:
            q = q.filter(SourcePolicyAuditDB.entity_type == entity_type)
        from_dt = _parse_iso_optional(date_from)
        if from_dt:
            q = q.filter(SourcePolicyAuditDB.created_at >= from_dt)
        to_dt = _parse_iso_optional(date_to)
        if to_dt:
            q = q.filter(SourcePolicyAuditDB.created_at <= to_dt)

        total = q.count()
        entries = (
            q.order_by(SourcePolicyAuditDB.created_at.desc())
            .offset(offset)
            .limit(limit)
            .all()
        )
        return {
            "entries": [
                {
                    "id": str(e.id),
                    "action": e.action,
                    "entity_type": e.entity_type,
                    "entity_id": str(e.entity_id) if e.entity_id else None,
                    "old_values": e.old_values,
                    "new_values": e.new_values,
                    "user_id": e.user_id,
                    "created_at": (
                        e.created_at.isoformat() if e.created_at else None
                    ),
                }
                for e in entries
            ],
            "total": total,
            "limit": limit,
            "offset": offset,
        }

    # ------------------------------------------------------------------
    # Stats + report
    # ------------------------------------------------------------------

    def stats(self) -> dict[str, Any]:
        """Return headline counters for the admin dashboard.

        ``active_policies`` is the number of active sources,
        ``allowed_sources`` the number of all sources, ``pii_rules`` the
        number of active PII rules, and ``blocked_today`` counts blocked
        entries created since 00:00 UTC of the current day.
        """
        total_sources = self.db.query(AllowedSourceDB).count()
        active_sources = (
            self.db.query(AllowedSourceDB).filter(AllowedSourceDB.active).count()
        )
        pii_rules = self.db.query(PIIRuleDB).filter(PIIRuleDB.active).count()

        today_start = datetime.now(timezone.utc).replace(
            hour=0, minute=0, second=0, microsecond=0
        )
        blocked_today = (
            self.db.query(BlockedContentDB)
            .filter(BlockedContentDB.created_at >= today_start)
            .count()
        )
        blocked_total = self.db.query(BlockedContentDB).count()

        return {
            "active_policies": active_sources,
            "allowed_sources": total_sources,
            "pii_rules": pii_rules,
            "blocked_today": blocked_today,
            "blocked_total": blocked_total,
        }

    def compliance_report(self) -> dict[str, Any]:
        """Build a snapshot report of active sources and active PII rules.

        Returns:
            Dict with ``report_date`` (current UTC time, ISO format), a
            ``summary`` of counts plus the distinct source types and licenses,
            and the per-source / per-rule detail lists.
        """
        sources = (
            self.db.query(AllowedSourceDB).filter(AllowedSourceDB.active).all()
        )
        pii_rules = self.db.query(PIIRuleDB).filter(PIIRuleDB.active).all()

        # Sort the distinct values so the report is reproducible between runs;
        # plain set iteration order is arbitrary. ``key=str`` keeps the sort
        # safe if a source has no ``source_type``.
        source_types = sorted({s.source_type for s in sources}, key=str)
        licenses = sorted({s.license for s in sources if s.license}, key=str)

        return {
            "report_date": datetime.now(timezone.utc).isoformat(),
            "summary": {
                "active_sources": len(sources),
                "active_pii_rules": len(pii_rules),
                "source_types": source_types,
                "licenses": licenses,
            },
            "sources": [
                {
                    "domain": s.domain,
                    "name": s.name,
                    "license": s.license,
                    "legal_basis": s.legal_basis,
                    "trust_boost": s.trust_boost,
                }
                for s in sources
            ],
            "pii_rules": [
                {"name": r.name, "category": r.category, "action": r.action}
                for r in pii_rules
            ],
        }