""" Repository layer for Compliance module. Provides CRUD operations and business logic queries for all compliance entities. """ from __future__ import annotations import uuid from datetime import datetime, date from typing import List, Optional, Dict, Any from sqlalchemy.orm import Session as DBSession, selectinload, joinedload from sqlalchemy import func, and_, or_ from typing import Tuple from .models import ( RegulationDB, RequirementDB, ControlDB, ControlMappingDB, EvidenceDB, RiskDB, AuditExportDB, AuditSessionDB, AuditSignOffDB, AuditResultEnum, AuditSessionStatusEnum, RegulationTypeEnum, ControlDomainEnum, ControlStatusEnum, RiskLevelEnum, EvidenceStatusEnum, ExportStatusEnum, ServiceModuleDB, ModuleRegulationMappingDB, ) class RegulationRepository: """Repository for regulations/standards.""" def __init__(self, db: DBSession): self.db = db def create( self, code: str, name: str, regulation_type: RegulationTypeEnum, full_name: Optional[str] = None, source_url: Optional[str] = None, local_pdf_path: Optional[str] = None, effective_date: Optional[date] = None, description: Optional[str] = None, ) -> RegulationDB: """Create a new regulation.""" regulation = RegulationDB( id=str(uuid.uuid4()), code=code, name=name, full_name=full_name, regulation_type=regulation_type, source_url=source_url, local_pdf_path=local_pdf_path, effective_date=effective_date, description=description, ) self.db.add(regulation) self.db.commit() self.db.refresh(regulation) return regulation def get_by_id(self, regulation_id: str) -> Optional[RegulationDB]: """Get regulation by ID.""" return self.db.query(RegulationDB).filter(RegulationDB.id == regulation_id).first() def get_by_code(self, code: str) -> Optional[RegulationDB]: """Get regulation by code (e.g., 'GDPR').""" return self.db.query(RegulationDB).filter(RegulationDB.code == code).first() def get_all( self, regulation_type: Optional[RegulationTypeEnum] = None, is_active: Optional[bool] = True ) -> List[RegulationDB]: """Get all regulations with optional filters.""" query = self.db.query(RegulationDB) if regulation_type: query = query.filter(RegulationDB.regulation_type == regulation_type) if is_active is not None: query = query.filter(RegulationDB.is_active == is_active) return query.order_by(RegulationDB.code).all() def update(self, regulation_id: str, **kwargs) -> Optional[RegulationDB]: """Update a regulation.""" regulation = self.get_by_id(regulation_id) if not regulation: return None for key, value in kwargs.items(): if hasattr(regulation, key): setattr(regulation, key, value) regulation.updated_at = datetime.utcnow() self.db.commit() self.db.refresh(regulation) return regulation def delete(self, regulation_id: str) -> bool: """Delete a regulation.""" regulation = self.get_by_id(regulation_id) if not regulation: return False self.db.delete(regulation) self.db.commit() return True def get_active(self) -> List[RegulationDB]: """Get all active regulations.""" return self.get_all(is_active=True) def count(self) -> int: """Count all regulations.""" return self.db.query(func.count(RegulationDB.id)).scalar() or 0 class RequirementRepository: """Repository for requirements.""" def __init__(self, db: DBSession): self.db = db def create( self, regulation_id: str, article: str, title: str, paragraph: Optional[str] = None, description: Optional[str] = None, requirement_text: Optional[str] = None, breakpilot_interpretation: Optional[str] = None, is_applicable: bool = True, priority: int = 2, ) -> RequirementDB: """Create a new requirement.""" requirement = RequirementDB( id=str(uuid.uuid4()), regulation_id=regulation_id, article=article, paragraph=paragraph, title=title, description=description, requirement_text=requirement_text, breakpilot_interpretation=breakpilot_interpretation, is_applicable=is_applicable, priority=priority, ) self.db.add(requirement) self.db.commit() self.db.refresh(requirement) return requirement def get_by_id(self, requirement_id: str) -> Optional[RequirementDB]: """Get requirement by ID with eager-loaded relationships.""" return ( self.db.query(RequirementDB) .options( selectinload(RequirementDB.control_mappings).selectinload(ControlMappingDB.control), joinedload(RequirementDB.regulation) ) .filter(RequirementDB.id == requirement_id) .first() ) def get_by_regulation( self, regulation_id: str, is_applicable: Optional[bool] = None ) -> List[RequirementDB]: """Get all requirements for a regulation with eager-loaded controls.""" query = ( self.db.query(RequirementDB) .options( selectinload(RequirementDB.control_mappings).selectinload(ControlMappingDB.control), joinedload(RequirementDB.regulation) ) .filter(RequirementDB.regulation_id == regulation_id) ) if is_applicable is not None: query = query.filter(RequirementDB.is_applicable == is_applicable) return query.order_by(RequirementDB.article, RequirementDB.paragraph).all() def get_by_regulation_code(self, code: str) -> List[RequirementDB]: """Get requirements by regulation code with eager-loaded relationships.""" return ( self.db.query(RequirementDB) .options( selectinload(RequirementDB.control_mappings).selectinload(ControlMappingDB.control), joinedload(RequirementDB.regulation) ) .join(RegulationDB) .filter(RegulationDB.code == code) .order_by(RequirementDB.article, RequirementDB.paragraph) .all() ) def get_all(self, is_applicable: Optional[bool] = None) -> List[RequirementDB]: """Get all requirements with optional filter and eager-loading.""" query = ( self.db.query(RequirementDB) .options( selectinload(RequirementDB.control_mappings).selectinload(ControlMappingDB.control), joinedload(RequirementDB.regulation) ) ) if is_applicable is not None: query = query.filter(RequirementDB.is_applicable == is_applicable) return query.order_by(RequirementDB.article, RequirementDB.paragraph).all() def get_paginated( self, page: int = 1, page_size: int = 50, regulation_code: Optional[str] = None, status: Optional[str] = None, is_applicable: Optional[bool] = None, search: Optional[str] = None, ) -> Tuple[List[RequirementDB], int]: """ Get paginated requirements with eager-loaded relationships. Returns tuple of (items, total_count). """ query = ( self.db.query(RequirementDB) .options( selectinload(RequirementDB.control_mappings).selectinload(ControlMappingDB.control), joinedload(RequirementDB.regulation) ) ) # Filters if regulation_code: query = query.join(RegulationDB).filter(RegulationDB.code == regulation_code) if status: query = query.filter(RequirementDB.implementation_status == status) if is_applicable is not None: query = query.filter(RequirementDB.is_applicable == is_applicable) if search: search_term = f"%{search}%" query = query.filter( or_( RequirementDB.title.ilike(search_term), RequirementDB.description.ilike(search_term), RequirementDB.article.ilike(search_term), ) ) # Count before pagination total = query.count() # Apply pagination and ordering items = ( query .order_by(RequirementDB.priority.desc(), RequirementDB.article, RequirementDB.paragraph) .offset((page - 1) * page_size) .limit(page_size) .all() ) return items, total def delete(self, requirement_id: str) -> bool: """Delete a requirement.""" requirement = self.db.query(RequirementDB).filter(RequirementDB.id == requirement_id).first() if not requirement: return False self.db.delete(requirement) self.db.commit() return True def count(self) -> int: """Count all requirements.""" return self.db.query(func.count(RequirementDB.id)).scalar() or 0 class ControlRepository: """Repository for controls.""" def __init__(self, db: DBSession): self.db = db def create( self, control_id: str, domain: ControlDomainEnum, control_type: str, title: str, pass_criteria: str, description: Optional[str] = None, implementation_guidance: Optional[str] = None, code_reference: Optional[str] = None, is_automated: bool = False, automation_tool: Optional[str] = None, owner: Optional[str] = None, review_frequency_days: int = 90, ) -> ControlDB: """Create a new control.""" control = ControlDB( id=str(uuid.uuid4()), control_id=control_id, domain=domain, control_type=control_type, title=title, description=description, pass_criteria=pass_criteria, implementation_guidance=implementation_guidance, code_reference=code_reference, is_automated=is_automated, automation_tool=automation_tool, owner=owner, review_frequency_days=review_frequency_days, ) self.db.add(control) self.db.commit() self.db.refresh(control) return control def get_by_id(self, control_uuid: str) -> Optional[ControlDB]: """Get control by UUID with eager-loaded relationships.""" return ( self.db.query(ControlDB) .options( selectinload(ControlDB.mappings).selectinload(ControlMappingDB.requirement), selectinload(ControlDB.evidence) ) .filter(ControlDB.id == control_uuid) .first() ) def get_by_control_id(self, control_id: str) -> Optional[ControlDB]: """Get control by control_id (e.g., 'PRIV-001') with eager-loaded relationships.""" return ( self.db.query(ControlDB) .options( selectinload(ControlDB.mappings).selectinload(ControlMappingDB.requirement), selectinload(ControlDB.evidence) ) .filter(ControlDB.control_id == control_id) .first() ) def get_all( self, domain: Optional[ControlDomainEnum] = None, status: Optional[ControlStatusEnum] = None, is_automated: Optional[bool] = None, ) -> List[ControlDB]: """Get all controls with optional filters and eager-loading.""" query = ( self.db.query(ControlDB) .options( selectinload(ControlDB.mappings), selectinload(ControlDB.evidence) ) ) if domain: query = query.filter(ControlDB.domain == domain) if status: query = query.filter(ControlDB.status == status) if is_automated is not None: query = query.filter(ControlDB.is_automated == is_automated) return query.order_by(ControlDB.control_id).all() def get_paginated( self, page: int = 1, page_size: int = 50, domain: Optional[ControlDomainEnum] = None, status: Optional[ControlStatusEnum] = None, is_automated: Optional[bool] = None, search: Optional[str] = None, ) -> Tuple[List[ControlDB], int]: """ Get paginated controls with eager-loaded relationships. Returns tuple of (items, total_count). """ query = ( self.db.query(ControlDB) .options( selectinload(ControlDB.mappings), selectinload(ControlDB.evidence) ) ) if domain: query = query.filter(ControlDB.domain == domain) if status: query = query.filter(ControlDB.status == status) if is_automated is not None: query = query.filter(ControlDB.is_automated == is_automated) if search: search_term = f"%{search}%" query = query.filter( or_( ControlDB.title.ilike(search_term), ControlDB.description.ilike(search_term), ControlDB.control_id.ilike(search_term), ) ) total = query.count() items = ( query .order_by(ControlDB.control_id) .offset((page - 1) * page_size) .limit(page_size) .all() ) return items, total def get_by_domain(self, domain: ControlDomainEnum) -> List[ControlDB]: """Get all controls in a domain.""" return self.get_all(domain=domain) def get_by_status(self, status: ControlStatusEnum) -> List[ControlDB]: """Get all controls with a specific status.""" return self.get_all(status=status) def update_status( self, control_id: str, status: ControlStatusEnum, status_notes: Optional[str] = None ) -> Optional[ControlDB]: """Update control status.""" control = self.get_by_control_id(control_id) if not control: return None control.status = status if status_notes: control.status_notes = status_notes control.updated_at = datetime.utcnow() self.db.commit() self.db.refresh(control) return control def mark_reviewed(self, control_id: str) -> Optional[ControlDB]: """Mark control as reviewed.""" control = self.get_by_control_id(control_id) if not control: return None control.last_reviewed_at = datetime.utcnow() from datetime import timedelta control.next_review_at = datetime.utcnow() + timedelta(days=control.review_frequency_days) control.updated_at = datetime.utcnow() self.db.commit() self.db.refresh(control) return control def get_due_for_review(self) -> List[ControlDB]: """Get controls due for review.""" return ( self.db.query(ControlDB) .filter( or_( ControlDB.next_review_at is None, ControlDB.next_review_at <= datetime.utcnow() ) ) .order_by(ControlDB.next_review_at) .all() ) def get_statistics(self) -> Dict[str, Any]: """Get control statistics by status and domain.""" total = self.db.query(func.count(ControlDB.id)).scalar() by_status = dict( self.db.query(ControlDB.status, func.count(ControlDB.id)) .group_by(ControlDB.status) .all() ) by_domain = dict( self.db.query(ControlDB.domain, func.count(ControlDB.id)) .group_by(ControlDB.domain) .all() ) passed = by_status.get(ControlStatusEnum.PASS, 0) partial = by_status.get(ControlStatusEnum.PARTIAL, 0) score = 0.0 if total > 0: score = ((passed + (partial * 0.5)) / total) * 100 return { "total": total, "by_status": {str(k.value) if k else "none": v for k, v in by_status.items()}, "by_domain": {str(k.value) if k else "none": v for k, v in by_domain.items()}, "compliance_score": round(score, 1), } class ControlMappingRepository: """Repository for requirement-control mappings.""" def __init__(self, db: DBSession): self.db = db def create( self, requirement_id: str, control_id: str, coverage_level: str = "full", notes: Optional[str] = None, ) -> ControlMappingDB: """Create a mapping.""" # Get the control UUID from control_id control = self.db.query(ControlDB).filter(ControlDB.control_id == control_id).first() if not control: raise ValueError(f"Control {control_id} not found") mapping = ControlMappingDB( id=str(uuid.uuid4()), requirement_id=requirement_id, control_id=control.id, coverage_level=coverage_level, notes=notes, ) self.db.add(mapping) self.db.commit() self.db.refresh(mapping) return mapping def get_by_requirement(self, requirement_id: str) -> List[ControlMappingDB]: """Get all mappings for a requirement.""" return ( self.db.query(ControlMappingDB) .filter(ControlMappingDB.requirement_id == requirement_id) .all() ) def get_by_control(self, control_uuid: str) -> List[ControlMappingDB]: """Get all mappings for a control.""" return ( self.db.query(ControlMappingDB) .filter(ControlMappingDB.control_id == control_uuid) .all() ) class EvidenceRepository: """Repository for evidence.""" def __init__(self, db: DBSession): self.db = db def create( self, control_id: str, evidence_type: str, title: str, description: Optional[str] = None, artifact_path: Optional[str] = None, artifact_url: Optional[str] = None, artifact_hash: Optional[str] = None, file_size_bytes: Optional[int] = None, mime_type: Optional[str] = None, valid_until: Optional[datetime] = None, source: str = "manual", ci_job_id: Optional[str] = None, uploaded_by: Optional[str] = None, ) -> EvidenceDB: """Create evidence record.""" # Get control UUID control = self.db.query(ControlDB).filter(ControlDB.control_id == control_id).first() if not control: raise ValueError(f"Control {control_id} not found") evidence = EvidenceDB( id=str(uuid.uuid4()), control_id=control.id, evidence_type=evidence_type, title=title, description=description, artifact_path=artifact_path, artifact_url=artifact_url, artifact_hash=artifact_hash, file_size_bytes=file_size_bytes, mime_type=mime_type, valid_until=valid_until, source=source, ci_job_id=ci_job_id, uploaded_by=uploaded_by, ) self.db.add(evidence) self.db.commit() self.db.refresh(evidence) return evidence def get_by_id(self, evidence_id: str) -> Optional[EvidenceDB]: """Get evidence by ID.""" return self.db.query(EvidenceDB).filter(EvidenceDB.id == evidence_id).first() def get_by_control( self, control_id: str, status: Optional[EvidenceStatusEnum] = None ) -> List[EvidenceDB]: """Get all evidence for a control.""" control = self.db.query(ControlDB).filter(ControlDB.control_id == control_id).first() if not control: return [] query = self.db.query(EvidenceDB).filter(EvidenceDB.control_id == control.id) if status: query = query.filter(EvidenceDB.status == status) return query.order_by(EvidenceDB.collected_at.desc()).all() def get_all( self, evidence_type: Optional[str] = None, status: Optional[EvidenceStatusEnum] = None, limit: int = 100, ) -> List[EvidenceDB]: """Get all evidence with filters.""" query = self.db.query(EvidenceDB) if evidence_type: query = query.filter(EvidenceDB.evidence_type == evidence_type) if status: query = query.filter(EvidenceDB.status == status) return query.order_by(EvidenceDB.collected_at.desc()).limit(limit).all() def update_status(self, evidence_id: str, status: EvidenceStatusEnum) -> Optional[EvidenceDB]: """Update evidence status.""" evidence = self.get_by_id(evidence_id) if not evidence: return None evidence.status = status evidence.updated_at = datetime.utcnow() self.db.commit() self.db.refresh(evidence) return evidence def get_statistics(self) -> Dict[str, Any]: """Get evidence statistics.""" total = self.db.query(func.count(EvidenceDB.id)).scalar() by_type = dict( self.db.query(EvidenceDB.evidence_type, func.count(EvidenceDB.id)) .group_by(EvidenceDB.evidence_type) .all() ) by_status = dict( self.db.query(EvidenceDB.status, func.count(EvidenceDB.id)) .group_by(EvidenceDB.status) .all() ) valid = by_status.get(EvidenceStatusEnum.VALID, 0) coverage = (valid / total * 100) if total > 0 else 0 return { "total": total, "by_type": by_type, "by_status": {str(k.value) if k else "none": v for k, v in by_status.items()}, "coverage_percent": round(coverage, 1), } class RiskRepository: """Repository for risks.""" def __init__(self, db: DBSession): self.db = db def create( self, risk_id: str, title: str, category: str, likelihood: int, impact: int, description: Optional[str] = None, mitigating_controls: Optional[List[str]] = None, owner: Optional[str] = None, treatment_plan: Optional[str] = None, ) -> RiskDB: """Create a risk.""" inherent_risk = RiskDB.calculate_risk_level(likelihood, impact) risk = RiskDB( id=str(uuid.uuid4()), risk_id=risk_id, title=title, description=description, category=category, likelihood=likelihood, impact=impact, inherent_risk=inherent_risk, mitigating_controls=mitigating_controls or [], owner=owner, treatment_plan=treatment_plan, ) self.db.add(risk) self.db.commit() self.db.refresh(risk) return risk def get_by_id(self, risk_uuid: str) -> Optional[RiskDB]: """Get risk by UUID.""" return self.db.query(RiskDB).filter(RiskDB.id == risk_uuid).first() def get_by_risk_id(self, risk_id: str) -> Optional[RiskDB]: """Get risk by risk_id (e.g., 'RISK-001').""" return self.db.query(RiskDB).filter(RiskDB.risk_id == risk_id).first() def get_all( self, category: Optional[str] = None, status: Optional[str] = None, min_risk_level: Optional[RiskLevelEnum] = None, ) -> List[RiskDB]: """Get all risks with filters.""" query = self.db.query(RiskDB) if category: query = query.filter(RiskDB.category == category) if status: query = query.filter(RiskDB.status == status) if min_risk_level: risk_order = { RiskLevelEnum.LOW: 1, RiskLevelEnum.MEDIUM: 2, RiskLevelEnum.HIGH: 3, RiskLevelEnum.CRITICAL: 4, } min_order = risk_order.get(min_risk_level, 1) query = query.filter( RiskDB.inherent_risk.in_( [k for k, v in risk_order.items() if v >= min_order] ) ) return query.order_by(RiskDB.risk_id).all() def update(self, risk_id: str, **kwargs) -> Optional[RiskDB]: """Update a risk.""" risk = self.get_by_risk_id(risk_id) if not risk: return None for key, value in kwargs.items(): if hasattr(risk, key): setattr(risk, key, value) # Recalculate risk levels if likelihood/impact changed if 'likelihood' in kwargs or 'impact' in kwargs: risk.inherent_risk = RiskDB.calculate_risk_level(risk.likelihood, risk.impact) if 'residual_likelihood' in kwargs or 'residual_impact' in kwargs: if risk.residual_likelihood and risk.residual_impact: risk.residual_risk = RiskDB.calculate_risk_level( risk.residual_likelihood, risk.residual_impact ) risk.updated_at = datetime.utcnow() self.db.commit() self.db.refresh(risk) return risk def get_matrix_data(self) -> Dict[str, Any]: """Get data for risk matrix visualization.""" risks = self.get_all() matrix = {} for risk in risks: key = f"{risk.likelihood}_{risk.impact}" if key not in matrix: matrix[key] = [] matrix[key].append({ "risk_id": risk.risk_id, "title": risk.title, "inherent_risk": risk.inherent_risk.value if risk.inherent_risk else None, }) return { "matrix": matrix, "total_risks": len(risks), "by_level": { "critical": len([r for r in risks if r.inherent_risk == RiskLevelEnum.CRITICAL]), "high": len([r for r in risks if r.inherent_risk == RiskLevelEnum.HIGH]), "medium": len([r for r in risks if r.inherent_risk == RiskLevelEnum.MEDIUM]), "low": len([r for r in risks if r.inherent_risk == RiskLevelEnum.LOW]), } } class AuditExportRepository: """Repository for audit exports.""" def __init__(self, db: DBSession): self.db = db def create( self, export_type: str, requested_by: str, export_name: Optional[str] = None, included_regulations: Optional[List[str]] = None, included_domains: Optional[List[str]] = None, date_range_start: Optional[date] = None, date_range_end: Optional[date] = None, ) -> AuditExportDB: """Create an export request.""" export = AuditExportDB( id=str(uuid.uuid4()), export_type=export_type, export_name=export_name or f"audit_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}", requested_by=requested_by, included_regulations=included_regulations, included_domains=included_domains, date_range_start=date_range_start, date_range_end=date_range_end, ) self.db.add(export) self.db.commit() self.db.refresh(export) return export def get_by_id(self, export_id: str) -> Optional[AuditExportDB]: """Get export by ID.""" return self.db.query(AuditExportDB).filter(AuditExportDB.id == export_id).first() def get_all(self, limit: int = 50) -> List[AuditExportDB]: """Get all exports.""" return ( self.db.query(AuditExportDB) .order_by(AuditExportDB.requested_at.desc()) .limit(limit) .all() ) def update_status( self, export_id: str, status: ExportStatusEnum, file_path: Optional[str] = None, file_hash: Optional[str] = None, file_size_bytes: Optional[int] = None, error_message: Optional[str] = None, total_controls: Optional[int] = None, total_evidence: Optional[int] = None, compliance_score: Optional[float] = None, ) -> Optional[AuditExportDB]: """Update export status.""" export = self.get_by_id(export_id) if not export: return None export.status = status if file_path: export.file_path = file_path if file_hash: export.file_hash = file_hash if file_size_bytes: export.file_size_bytes = file_size_bytes if error_message: export.error_message = error_message if total_controls is not None: export.total_controls = total_controls if total_evidence is not None: export.total_evidence = total_evidence if compliance_score is not None: export.compliance_score = compliance_score if status == ExportStatusEnum.COMPLETED: export.completed_at = datetime.utcnow() export.updated_at = datetime.utcnow() self.db.commit() self.db.refresh(export) return export class ServiceModuleRepository: """Repository for service modules (Sprint 3).""" def __init__(self, db: DBSession): self.db = db def create( self, name: str, display_name: str, service_type: str, description: Optional[str] = None, port: Optional[int] = None, technology_stack: Optional[List[str]] = None, repository_path: Optional[str] = None, docker_image: Optional[str] = None, data_categories: Optional[List[str]] = None, processes_pii: bool = False, processes_health_data: bool = False, ai_components: bool = False, criticality: str = "medium", owner_team: Optional[str] = None, owner_contact: Optional[str] = None, ) -> "ServiceModuleDB": """Create a service module.""" from .models import ServiceModuleDB, ServiceTypeEnum module = ServiceModuleDB( id=str(uuid.uuid4()), name=name, display_name=display_name, description=description, service_type=ServiceTypeEnum(service_type), port=port, technology_stack=technology_stack or [], repository_path=repository_path, docker_image=docker_image, data_categories=data_categories or [], processes_pii=processes_pii, processes_health_data=processes_health_data, ai_components=ai_components, criticality=criticality, owner_team=owner_team, owner_contact=owner_contact, ) self.db.add(module) self.db.commit() self.db.refresh(module) return module def get_by_id(self, module_id: str) -> Optional["ServiceModuleDB"]: """Get module by ID.""" from .models import ServiceModuleDB return self.db.query(ServiceModuleDB).filter(ServiceModuleDB.id == module_id).first() def get_by_name(self, name: str) -> Optional["ServiceModuleDB"]: """Get module by name.""" from .models import ServiceModuleDB return self.db.query(ServiceModuleDB).filter(ServiceModuleDB.name == name).first() def get_all( self, service_type: Optional[str] = None, criticality: Optional[str] = None, processes_pii: Optional[bool] = None, ai_components: Optional[bool] = None, ) -> List["ServiceModuleDB"]: """Get all modules with filters.""" from .models import ServiceModuleDB, ServiceTypeEnum query = self.db.query(ServiceModuleDB).filter(ServiceModuleDB.is_active) if service_type: query = query.filter(ServiceModuleDB.service_type == ServiceTypeEnum(service_type)) if criticality: query = query.filter(ServiceModuleDB.criticality == criticality) if processes_pii is not None: query = query.filter(ServiceModuleDB.processes_pii == processes_pii) if ai_components is not None: query = query.filter(ServiceModuleDB.ai_components == ai_components) return query.order_by(ServiceModuleDB.name).all() def get_with_regulations(self, module_id: str) -> Optional["ServiceModuleDB"]: """Get module with regulation mappings loaded.""" from .models import ServiceModuleDB, ModuleRegulationMappingDB from sqlalchemy.orm import selectinload return ( self.db.query(ServiceModuleDB) .options( selectinload(ServiceModuleDB.regulation_mappings) .selectinload(ModuleRegulationMappingDB.regulation) ) .filter(ServiceModuleDB.id == module_id) .first() ) def add_regulation_mapping( self, module_id: str, regulation_id: str, relevance_level: str = "medium", notes: Optional[str] = None, applicable_articles: Optional[List[str]] = None, ) -> "ModuleRegulationMappingDB": """Add a regulation mapping to a module.""" from .models import ModuleRegulationMappingDB, RelevanceLevelEnum mapping = ModuleRegulationMappingDB( id=str(uuid.uuid4()), module_id=module_id, regulation_id=regulation_id, relevance_level=RelevanceLevelEnum(relevance_level), notes=notes, applicable_articles=applicable_articles, ) self.db.add(mapping) self.db.commit() self.db.refresh(mapping) return mapping def get_overview(self) -> Dict[str, Any]: """Get overview statistics for all modules.""" from .models import ModuleRegulationMappingDB modules = self.get_all() total = len(modules) by_type = {} by_criticality = {} pii_count = 0 ai_count = 0 for m in modules: type_key = m.service_type.value if m.service_type else "unknown" by_type[type_key] = by_type.get(type_key, 0) + 1 by_criticality[m.criticality] = by_criticality.get(m.criticality, 0) + 1 if m.processes_pii: pii_count += 1 if m.ai_components: ai_count += 1 # Get regulation coverage regulation_coverage = {} mappings = self.db.query(ModuleRegulationMappingDB).all() for mapping in mappings: reg = mapping.regulation if reg: code = reg.code regulation_coverage[code] = regulation_coverage.get(code, 0) + 1 # Calculate average compliance score scores = [m.compliance_score for m in modules if m.compliance_score is not None] avg_score = sum(scores) / len(scores) if scores else None return { "total_modules": total, "modules_by_type": by_type, "modules_by_criticality": by_criticality, "modules_processing_pii": pii_count, "modules_with_ai": ai_count, "average_compliance_score": round(avg_score, 1) if avg_score else None, "regulations_coverage": regulation_coverage, } def seed_from_data(self, services_data: List[Dict[str, Any]], force: bool = False) -> Dict[str, int]: """Seed modules from service_modules.py data.""" modules_created = 0 mappings_created = 0 for svc in services_data: # Check if module exists existing = self.get_by_name(svc["name"]) if existing and not force: continue if existing and force: # Delete existing module (cascades to mappings) self.db.delete(existing) self.db.commit() # Create module module = self.create( name=svc["name"], display_name=svc["display_name"], description=svc.get("description"), service_type=svc["service_type"], port=svc.get("port"), technology_stack=svc.get("technology_stack"), repository_path=svc.get("repository_path"), docker_image=svc.get("docker_image"), data_categories=svc.get("data_categories"), processes_pii=svc.get("processes_pii", False), processes_health_data=svc.get("processes_health_data", False), ai_components=svc.get("ai_components", False), criticality=svc.get("criticality", "medium"), owner_team=svc.get("owner_team"), ) modules_created += 1 # Create regulation mappings for reg_data in svc.get("regulations", []): # Find regulation by code reg = self.db.query(RegulationDB).filter( RegulationDB.code == reg_data["code"] ).first() if reg: self.add_regulation_mapping( module_id=module.id, regulation_id=reg.id, relevance_level=reg_data.get("relevance", "medium"), notes=reg_data.get("notes"), ) mappings_created += 1 return { "modules_created": modules_created, "mappings_created": mappings_created, } class AuditSessionRepository: """Repository for audit sessions (Sprint 3: Auditor-Verbesserungen).""" def __init__(self, db: DBSession): self.db = db def create( self, name: str, auditor_name: str, description: Optional[str] = None, auditor_email: Optional[str] = None, regulation_ids: Optional[List[str]] = None, ) -> AuditSessionDB: """Create a new audit session.""" session = AuditSessionDB( id=str(uuid.uuid4()), name=name, description=description, auditor_name=auditor_name, auditor_email=auditor_email, regulation_ids=regulation_ids, status=AuditSessionStatusEnum.DRAFT, ) self.db.add(session) self.db.commit() self.db.refresh(session) return session def get_by_id(self, session_id: str) -> Optional[AuditSessionDB]: """Get audit session by ID with eager-loaded signoffs.""" return ( self.db.query(AuditSessionDB) .options( selectinload(AuditSessionDB.signoffs) .selectinload(AuditSignOffDB.requirement) ) .filter(AuditSessionDB.id == session_id) .first() ) def get_all( self, status: Optional[AuditSessionStatusEnum] = None, limit: int = 50, ) -> List[AuditSessionDB]: """Get all audit sessions with optional status filter.""" query = self.db.query(AuditSessionDB) if status: query = query.filter(AuditSessionDB.status == status) return query.order_by(AuditSessionDB.created_at.desc()).limit(limit).all() def update_status( self, session_id: str, status: AuditSessionStatusEnum, ) -> Optional[AuditSessionDB]: """Update session status and set appropriate timestamps.""" session = self.get_by_id(session_id) if not session: return None session.status = status if status == AuditSessionStatusEnum.IN_PROGRESS and not session.started_at: session.started_at = datetime.utcnow() elif status == AuditSessionStatusEnum.COMPLETED: session.completed_at = datetime.utcnow() session.updated_at = datetime.utcnow() self.db.commit() self.db.refresh(session) return session def update_progress( self, session_id: str, total_items: Optional[int] = None, completed_items: Optional[int] = None, ) -> Optional[AuditSessionDB]: """Update session progress counters.""" session = self.db.query(AuditSessionDB).filter( AuditSessionDB.id == session_id ).first() if not session: return None if total_items is not None: session.total_items = total_items if completed_items is not None: session.completed_items = completed_items session.updated_at = datetime.utcnow() self.db.commit() self.db.refresh(session) return session def start_session(self, session_id: str) -> Optional[AuditSessionDB]: """ Start an audit session: - Set status to IN_PROGRESS - Initialize total_items based on requirements count """ session = self.get_by_id(session_id) if not session: return None # Count requirements for this session query = self.db.query(func.count(RequirementDB.id)) if session.regulation_ids: query = query.join(RegulationDB).filter( RegulationDB.id.in_(session.regulation_ids) ) total_requirements = query.scalar() or 0 session.status = AuditSessionStatusEnum.IN_PROGRESS session.started_at = datetime.utcnow() session.total_items = total_requirements session.updated_at = datetime.utcnow() self.db.commit() self.db.refresh(session) return session def delete(self, session_id: str) -> bool: """Delete an audit session (cascades to signoffs).""" session = self.db.query(AuditSessionDB).filter( AuditSessionDB.id == session_id ).first() if not session: return False self.db.delete(session) self.db.commit() return True def get_statistics(self, session_id: str) -> Dict[str, Any]: """Get detailed statistics for an audit session.""" session = self.get_by_id(session_id) if not session: return {} signoffs = session.signoffs or [] stats = { "total": session.total_items or 0, "completed": len([s for s in signoffs if s.result != AuditResultEnum.PENDING]), "compliant": len([s for s in signoffs if s.result == AuditResultEnum.COMPLIANT]), "compliant_with_notes": len([s for s in signoffs if s.result == AuditResultEnum.COMPLIANT_WITH_NOTES]), "non_compliant": len([s for s in signoffs if s.result == AuditResultEnum.NON_COMPLIANT]), "not_applicable": len([s for s in signoffs if s.result == AuditResultEnum.NOT_APPLICABLE]), "pending": len([s for s in signoffs if s.result == AuditResultEnum.PENDING]), "signed": len([s for s in signoffs if s.signature_hash]), } total = stats["total"] if stats["total"] > 0 else 1 stats["completion_percentage"] = round( (stats["completed"] / total) * 100, 1 ) return stats class AuditSignOffRepository: """Repository for audit sign-offs (Sprint 3: Auditor-Verbesserungen).""" def __init__(self, db: DBSession): self.db = db def create( self, session_id: str, requirement_id: str, result: AuditResultEnum = AuditResultEnum.PENDING, notes: Optional[str] = None, ) -> AuditSignOffDB: """Create a new sign-off for a requirement.""" signoff = AuditSignOffDB( id=str(uuid.uuid4()), session_id=session_id, requirement_id=requirement_id, result=result, notes=notes, ) self.db.add(signoff) self.db.commit() self.db.refresh(signoff) return signoff def get_by_id(self, signoff_id: str) -> Optional[AuditSignOffDB]: """Get sign-off by ID.""" return ( self.db.query(AuditSignOffDB) .options(joinedload(AuditSignOffDB.requirement)) .filter(AuditSignOffDB.id == signoff_id) .first() ) def get_by_session_and_requirement( self, session_id: str, requirement_id: str, ) -> Optional[AuditSignOffDB]: """Get sign-off by session and requirement ID.""" return ( self.db.query(AuditSignOffDB) .filter( and_( AuditSignOffDB.session_id == session_id, AuditSignOffDB.requirement_id == requirement_id, ) ) .first() ) def get_by_session( self, session_id: str, result_filter: Optional[AuditResultEnum] = None, ) -> List[AuditSignOffDB]: """Get all sign-offs for a session.""" query = ( self.db.query(AuditSignOffDB) .options(joinedload(AuditSignOffDB.requirement)) .filter(AuditSignOffDB.session_id == session_id) ) if result_filter: query = query.filter(AuditSignOffDB.result == result_filter) return query.order_by(AuditSignOffDB.created_at).all() def update( self, signoff_id: str, result: Optional[AuditResultEnum] = None, notes: Optional[str] = None, sign: bool = False, signed_by: Optional[str] = None, ) -> Optional[AuditSignOffDB]: """Update a sign-off with optional digital signature.""" signoff = self.db.query(AuditSignOffDB).filter( AuditSignOffDB.id == signoff_id ).first() if not signoff: return None if result is not None: signoff.result = result if notes is not None: signoff.notes = notes if sign and signed_by: signoff.create_signature(signed_by) signoff.updated_at = datetime.utcnow() self.db.commit() self.db.refresh(signoff) # Update session progress self._update_session_progress(signoff.session_id) return signoff def sign_off( self, session_id: str, requirement_id: str, result: AuditResultEnum, notes: Optional[str] = None, sign: bool = False, signed_by: Optional[str] = None, ) -> AuditSignOffDB: """ Create or update a sign-off for a requirement. This is the main method for auditors to record their findings. """ # Check if sign-off already exists signoff = self.get_by_session_and_requirement(session_id, requirement_id) if signoff: # Update existing signoff.result = result if notes is not None: signoff.notes = notes if sign and signed_by: signoff.create_signature(signed_by) signoff.updated_at = datetime.utcnow() else: # Create new signoff = AuditSignOffDB( id=str(uuid.uuid4()), session_id=session_id, requirement_id=requirement_id, result=result, notes=notes, ) if sign and signed_by: signoff.create_signature(signed_by) self.db.add(signoff) self.db.commit() self.db.refresh(signoff) # Update session progress self._update_session_progress(session_id) return signoff def _update_session_progress(self, session_id: str) -> None: """Update the session's completed_items count.""" completed = ( self.db.query(func.count(AuditSignOffDB.id)) .filter( and_( AuditSignOffDB.session_id == session_id, AuditSignOffDB.result != AuditResultEnum.PENDING, ) ) .scalar() ) or 0 session = self.db.query(AuditSessionDB).filter( AuditSessionDB.id == session_id ).first() if session: session.completed_items = completed session.updated_at = datetime.utcnow() self.db.commit() def get_checklist( self, session_id: str, page: int = 1, page_size: int = 50, result_filter: Optional[AuditResultEnum] = None, regulation_code: Optional[str] = None, search: Optional[str] = None, ) -> Tuple[List[Dict[str, Any]], int]: """ Get audit checklist items for a session with pagination. Returns requirements with their sign-off status. """ session = self.db.query(AuditSessionDB).filter( AuditSessionDB.id == session_id ).first() if not session: return [], 0 # Base query for requirements query = ( self.db.query(RequirementDB) .options( joinedload(RequirementDB.regulation), selectinload(RequirementDB.control_mappings), ) ) # Filter by session's regulation_ids if set if session.regulation_ids: query = query.filter(RequirementDB.regulation_id.in_(session.regulation_ids)) # Filter by regulation code if regulation_code: query = query.join(RegulationDB).filter(RegulationDB.code == regulation_code) # Search if search: search_term = f"%{search}%" query = query.filter( or_( RequirementDB.title.ilike(search_term), RequirementDB.article.ilike(search_term), ) ) # Get existing sign-offs for this session signoffs_map = {} signoffs = ( self.db.query(AuditSignOffDB) .filter(AuditSignOffDB.session_id == session_id) .all() ) for s in signoffs: signoffs_map[s.requirement_id] = s # Filter by result if specified if result_filter: if result_filter == AuditResultEnum.PENDING: # Requirements without sign-off or with pending status signed_req_ids = [ s.requirement_id for s in signoffs if s.result != AuditResultEnum.PENDING ] if signed_req_ids: query = query.filter(~RequirementDB.id.in_(signed_req_ids)) else: # Requirements with specific result matching_req_ids = [ s.requirement_id for s in signoffs if s.result == result_filter ] if matching_req_ids: query = query.filter(RequirementDB.id.in_(matching_req_ids)) else: return [], 0 # Count and paginate total = query.count() requirements = ( query .order_by(RequirementDB.article, RequirementDB.paragraph) .offset((page - 1) * page_size) .limit(page_size) .all() ) # Build checklist items items = [] for req in requirements: signoff = signoffs_map.get(req.id) items.append({ "requirement_id": req.id, "regulation_code": req.regulation.code if req.regulation else None, "regulation_name": req.regulation.name if req.regulation else None, "article": req.article, "paragraph": req.paragraph, "title": req.title, "description": req.description, "current_result": signoff.result.value if signoff else AuditResultEnum.PENDING.value, "notes": signoff.notes if signoff else None, "is_signed": bool(signoff.signature_hash) if signoff else False, "signed_at": signoff.signed_at if signoff else None, "signed_by": signoff.signed_by if signoff else None, "evidence_count": len(req.control_mappings) if req.control_mappings else 0, "controls_mapped": len(req.control_mappings) if req.control_mappings else 0, }) return items, total def delete(self, signoff_id: str) -> bool: """Delete a sign-off.""" signoff = self.db.query(AuditSignOffDB).filter( AuditSignOffDB.id == signoff_id ).first() if not signoff: return False session_id = signoff.session_id self.db.delete(signoff) self.db.commit() # Update session progress self._update_session_progress(session_id) return True