""" Repository layer for Compliance module. Provides CRUD operations and business logic queries for all compliance entities. """ import uuid from datetime import datetime, date from typing import List, Optional, Dict, Any from sqlalchemy.orm import Session as DBSession, selectinload, joinedload from sqlalchemy import func, and_, or_ from typing import Tuple from .models import ( RegulationDB, RequirementDB, ControlDB, ControlMappingDB, EvidenceDB, RiskDB, AuditExportDB, AuditSessionDB, AuditSignOffDB, AuditResultEnum, AuditSessionStatusEnum, RegulationTypeEnum, ControlDomainEnum, ControlStatusEnum, RiskLevelEnum, EvidenceStatusEnum, ExportStatusEnum ) class RegulationRepository: """Repository for regulations/standards.""" def __init__(self, db: DBSession): self.db = db def create( self, code: str, name: str, regulation_type: RegulationTypeEnum, full_name: Optional[str] = None, source_url: Optional[str] = None, local_pdf_path: Optional[str] = None, effective_date: Optional[date] = None, description: Optional[str] = None, ) -> RegulationDB: """Create a new regulation.""" regulation = RegulationDB( id=str(uuid.uuid4()), code=code, name=name, full_name=full_name, regulation_type=regulation_type, source_url=source_url, local_pdf_path=local_pdf_path, effective_date=effective_date, description=description, ) self.db.add(regulation) self.db.commit() self.db.refresh(regulation) return regulation def get_by_id(self, regulation_id: str) -> Optional[RegulationDB]: """Get regulation by ID.""" return self.db.query(RegulationDB).filter(RegulationDB.id == regulation_id).first() def get_by_code(self, code: str) -> Optional[RegulationDB]: """Get regulation by code (e.g., 'GDPR').""" return self.db.query(RegulationDB).filter(RegulationDB.code == code).first() def get_all( self, regulation_type: Optional[RegulationTypeEnum] = None, is_active: Optional[bool] = True ) -> List[RegulationDB]: """Get all regulations with optional filters.""" query = self.db.query(RegulationDB) if regulation_type: query = query.filter(RegulationDB.regulation_type == regulation_type) if is_active is not None: query = query.filter(RegulationDB.is_active == is_active) return query.order_by(RegulationDB.code).all() def update(self, regulation_id: str, **kwargs) -> Optional[RegulationDB]: """Update a regulation.""" regulation = self.get_by_id(regulation_id) if not regulation: return None for key, value in kwargs.items(): if hasattr(regulation, key): setattr(regulation, key, value) regulation.updated_at = datetime.utcnow() self.db.commit() self.db.refresh(regulation) return regulation def delete(self, regulation_id: str) -> bool: """Delete a regulation.""" regulation = self.get_by_id(regulation_id) if not regulation: return False self.db.delete(regulation) self.db.commit() return True def get_active(self) -> List[RegulationDB]: """Get all active regulations.""" return self.get_all(is_active=True) def count(self) -> int: """Count all regulations.""" return self.db.query(func.count(RegulationDB.id)).scalar() or 0 class RequirementRepository: """Repository for requirements.""" def __init__(self, db: DBSession): self.db = db def create( self, regulation_id: str, article: str, title: str, paragraph: Optional[str] = None, description: Optional[str] = None, requirement_text: Optional[str] = None, breakpilot_interpretation: Optional[str] = None, is_applicable: bool = True, priority: int = 2, ) -> RequirementDB: """Create a new requirement.""" requirement = RequirementDB( id=str(uuid.uuid4()), regulation_id=regulation_id, article=article, paragraph=paragraph, title=title, description=description, requirement_text=requirement_text, breakpilot_interpretation=breakpilot_interpretation, is_applicable=is_applicable, priority=priority, ) self.db.add(requirement) self.db.commit() self.db.refresh(requirement) return requirement def get_by_id(self, requirement_id: str) -> Optional[RequirementDB]: """Get requirement by ID with eager-loaded relationships.""" return ( self.db.query(RequirementDB) .options( selectinload(RequirementDB.control_mappings).selectinload(ControlMappingDB.control), joinedload(RequirementDB.regulation) ) .filter(RequirementDB.id == requirement_id) .first() ) def get_by_regulation( self, regulation_id: str, is_applicable: Optional[bool] = None ) -> List[RequirementDB]: """Get all requirements for a regulation with eager-loaded controls.""" query = ( self.db.query(RequirementDB) .options( selectinload(RequirementDB.control_mappings).selectinload(ControlMappingDB.control), joinedload(RequirementDB.regulation) ) .filter(RequirementDB.regulation_id == regulation_id) ) if is_applicable is not None: query = query.filter(RequirementDB.is_applicable == is_applicable) return query.order_by(RequirementDB.article, RequirementDB.paragraph).all() def get_by_regulation_code(self, code: str) -> List[RequirementDB]: """Get requirements by regulation code with eager-loaded relationships.""" return ( self.db.query(RequirementDB) .options( selectinload(RequirementDB.control_mappings).selectinload(ControlMappingDB.control), joinedload(RequirementDB.regulation) ) .join(RegulationDB) .filter(RegulationDB.code == code) .order_by(RequirementDB.article, RequirementDB.paragraph) .all() ) def get_all(self, is_applicable: Optional[bool] = None) -> List[RequirementDB]: """Get all requirements with optional filter and eager-loading.""" query = ( self.db.query(RequirementDB) .options( selectinload(RequirementDB.control_mappings).selectinload(ControlMappingDB.control), joinedload(RequirementDB.regulation) ) ) if is_applicable is not None: query = query.filter(RequirementDB.is_applicable == is_applicable) return query.order_by(RequirementDB.article, RequirementDB.paragraph).all() def get_paginated( self, page: int = 1, page_size: int = 50, regulation_code: Optional[str] = None, status: Optional[str] = None, is_applicable: Optional[bool] = None, search: Optional[str] = None, ) -> Tuple[List[RequirementDB], int]: """ Get paginated requirements with eager-loaded relationships. Returns tuple of (items, total_count). """ query = ( self.db.query(RequirementDB) .options( selectinload(RequirementDB.control_mappings).selectinload(ControlMappingDB.control), joinedload(RequirementDB.regulation) ) ) # Filters if regulation_code: query = query.join(RegulationDB).filter(RegulationDB.code == regulation_code) if status: query = query.filter(RequirementDB.implementation_status == status) if is_applicable is not None: query = query.filter(RequirementDB.is_applicable == is_applicable) if search: search_term = f"%{search}%" query = query.filter( or_( RequirementDB.title.ilike(search_term), RequirementDB.description.ilike(search_term), RequirementDB.article.ilike(search_term), ) ) # Count before pagination total = query.count() # Apply pagination and ordering items = ( query .order_by(RequirementDB.priority.desc(), RequirementDB.article, RequirementDB.paragraph) .offset((page - 1) * page_size) .limit(page_size) .all() ) return items, total def count(self) -> int: """Count all requirements.""" return self.db.query(func.count(RequirementDB.id)).scalar() or 0 class ControlRepository: """Repository for controls.""" def __init__(self, db: DBSession): self.db = db def create( self, control_id: str, domain: ControlDomainEnum, control_type: str, title: str, pass_criteria: str, description: Optional[str] = None, implementation_guidance: Optional[str] = None, code_reference: Optional[str] = None, is_automated: bool = False, automation_tool: Optional[str] = None, owner: Optional[str] = None, review_frequency_days: int = 90, ) -> ControlDB: """Create a new control.""" control = ControlDB( id=str(uuid.uuid4()), control_id=control_id, domain=domain, control_type=control_type, title=title, description=description, pass_criteria=pass_criteria, implementation_guidance=implementation_guidance, code_reference=code_reference, is_automated=is_automated, automation_tool=automation_tool, owner=owner, review_frequency_days=review_frequency_days, ) self.db.add(control) self.db.commit() self.db.refresh(control) return control def get_by_id(self, control_uuid: str) -> Optional[ControlDB]: """Get control by UUID with eager-loaded relationships.""" return ( self.db.query(ControlDB) .options( selectinload(ControlDB.mappings).selectinload(ControlMappingDB.requirement), selectinload(ControlDB.evidence) ) .filter(ControlDB.id == control_uuid) .first() ) def get_by_control_id(self, control_id: str) -> Optional[ControlDB]: """Get control by control_id (e.g., 'PRIV-001') with eager-loaded relationships.""" return ( self.db.query(ControlDB) .options( selectinload(ControlDB.mappings).selectinload(ControlMappingDB.requirement), selectinload(ControlDB.evidence) ) .filter(ControlDB.control_id == control_id) .first() ) def get_all( self, domain: Optional[ControlDomainEnum] = None, status: Optional[ControlStatusEnum] = None, is_automated: Optional[bool] = None, ) -> List[ControlDB]: """Get all controls with optional filters and eager-loading.""" query = ( self.db.query(ControlDB) .options( selectinload(ControlDB.mappings), selectinload(ControlDB.evidence) ) ) if domain: query = query.filter(ControlDB.domain == domain) if status: query = query.filter(ControlDB.status == status) if is_automated is not None: query = query.filter(ControlDB.is_automated == is_automated) return query.order_by(ControlDB.control_id).all() def get_paginated( self, page: int = 1, page_size: int = 50, domain: Optional[ControlDomainEnum] = None, status: Optional[ControlStatusEnum] = None, is_automated: Optional[bool] = None, search: Optional[str] = None, ) -> Tuple[List[ControlDB], int]: """ Get paginated controls with eager-loaded relationships. Returns tuple of (items, total_count). """ query = ( self.db.query(ControlDB) .options( selectinload(ControlDB.mappings), selectinload(ControlDB.evidence) ) ) if domain: query = query.filter(ControlDB.domain == domain) if status: query = query.filter(ControlDB.status == status) if is_automated is not None: query = query.filter(ControlDB.is_automated == is_automated) if search: search_term = f"%{search}%" query = query.filter( or_( ControlDB.title.ilike(search_term), ControlDB.description.ilike(search_term), ControlDB.control_id.ilike(search_term), ) ) total = query.count() items = ( query .order_by(ControlDB.control_id) .offset((page - 1) * page_size) .limit(page_size) .all() ) return items, total def get_by_domain(self, domain: ControlDomainEnum) -> List[ControlDB]: """Get all controls in a domain.""" return self.get_all(domain=domain) def get_by_status(self, status: ControlStatusEnum) -> List[ControlDB]: """Get all controls with a specific status.""" return self.get_all(status=status) def update_status( self, control_id: str, status: ControlStatusEnum, status_notes: Optional[str] = None ) -> Optional[ControlDB]: """Update control status.""" control = self.get_by_control_id(control_id) if not control: return None control.status = status if status_notes: control.status_notes = status_notes control.updated_at = datetime.utcnow() self.db.commit() self.db.refresh(control) return control def mark_reviewed(self, control_id: str) -> Optional[ControlDB]: """Mark control as reviewed.""" control = self.get_by_control_id(control_id) if not control: return None control.last_reviewed_at = datetime.utcnow() from datetime import timedelta control.next_review_at = datetime.utcnow() + timedelta(days=control.review_frequency_days) control.updated_at = datetime.utcnow() self.db.commit() self.db.refresh(control) return control def get_due_for_review(self) -> List[ControlDB]: """Get controls due for review.""" return ( self.db.query(ControlDB) .filter( or_( ControlDB.next_review_at == None, ControlDB.next_review_at <= datetime.utcnow() ) ) .order_by(ControlDB.next_review_at) .all() ) def get_statistics(self) -> Dict[str, Any]: """Get control statistics by status and domain.""" total = self.db.query(func.count(ControlDB.id)).scalar() by_status = dict( self.db.query(ControlDB.status, func.count(ControlDB.id)) .group_by(ControlDB.status) .all() ) by_domain = dict( self.db.query(ControlDB.domain, func.count(ControlDB.id)) .group_by(ControlDB.domain) .all() ) passed = by_status.get(ControlStatusEnum.PASS, 0) partial = by_status.get(ControlStatusEnum.PARTIAL, 0) score = 0.0 if total > 0: score = ((passed + (partial * 0.5)) / total) * 100 return { "total": total, "by_status": {str(k.value) if k else "none": v for k, v in by_status.items()}, "by_domain": {str(k.value) if k else "none": v for k, v in by_domain.items()}, "compliance_score": round(score, 1), } class ControlMappingRepository: """Repository for requirement-control mappings.""" def __init__(self, db: DBSession): self.db = db def create( self, requirement_id: str, control_id: str, coverage_level: str = "full", notes: Optional[str] = None, ) -> ControlMappingDB: """Create a mapping.""" # Get the control UUID from control_id control = self.db.query(ControlDB).filter(ControlDB.control_id == control_id).first() if not control: raise ValueError(f"Control {control_id} not found") mapping = ControlMappingDB( id=str(uuid.uuid4()), requirement_id=requirement_id, control_id=control.id, coverage_level=coverage_level, notes=notes, ) self.db.add(mapping) self.db.commit() self.db.refresh(mapping) return mapping def get_by_requirement(self, requirement_id: str) -> List[ControlMappingDB]: """Get all mappings for a requirement.""" return ( self.db.query(ControlMappingDB) .filter(ControlMappingDB.requirement_id == requirement_id) .all() ) def get_by_control(self, control_uuid: str) -> List[ControlMappingDB]: """Get all mappings for a control.""" return ( self.db.query(ControlMappingDB) .filter(ControlMappingDB.control_id == control_uuid) .all() ) class EvidenceRepository: """Repository for evidence.""" def __init__(self, db: DBSession): self.db = db def create( self, control_id: str, evidence_type: str, title: str, description: Optional[str] = None, artifact_path: Optional[str] = None, artifact_url: Optional[str] = None, artifact_hash: Optional[str] = None, file_size_bytes: Optional[int] = None, mime_type: Optional[str] = None, valid_until: Optional[datetime] = None, source: str = "manual", ci_job_id: Optional[str] = None, uploaded_by: Optional[str] = None, ) -> EvidenceDB: """Create evidence record.""" # Get control UUID control = self.db.query(ControlDB).filter(ControlDB.control_id == control_id).first() if not control: raise ValueError(f"Control {control_id} not found") evidence = EvidenceDB( id=str(uuid.uuid4()), control_id=control.id, evidence_type=evidence_type, title=title, description=description, artifact_path=artifact_path, artifact_url=artifact_url, artifact_hash=artifact_hash, file_size_bytes=file_size_bytes, mime_type=mime_type, valid_until=valid_until, source=source, ci_job_id=ci_job_id, uploaded_by=uploaded_by, ) self.db.add(evidence) self.db.commit() self.db.refresh(evidence) return evidence def get_by_id(self, evidence_id: str) -> Optional[EvidenceDB]: """Get evidence by ID.""" return self.db.query(EvidenceDB).filter(EvidenceDB.id == evidence_id).first() def get_by_control( self, control_id: str, status: Optional[EvidenceStatusEnum] = None ) -> List[EvidenceDB]: """Get all evidence for a control.""" control = self.db.query(ControlDB).filter(ControlDB.control_id == control_id).first() if not control: return [] query = self.db.query(EvidenceDB).filter(EvidenceDB.control_id == control.id) if status: query = query.filter(EvidenceDB.status == status) return query.order_by(EvidenceDB.collected_at.desc()).all() def get_all( self, evidence_type: Optional[str] = None, status: Optional[EvidenceStatusEnum] = None, limit: int = 100, ) -> List[EvidenceDB]: """Get all evidence with filters.""" query = self.db.query(EvidenceDB) if evidence_type: query = query.filter(EvidenceDB.evidence_type == evidence_type) if status: query = query.filter(EvidenceDB.status == status) return query.order_by(EvidenceDB.collected_at.desc()).limit(limit).all() def update_status(self, evidence_id: str, status: EvidenceStatusEnum) -> Optional[EvidenceDB]: """Update evidence status.""" evidence = self.get_by_id(evidence_id) if not evidence: return None evidence.status = status evidence.updated_at = datetime.utcnow() self.db.commit() self.db.refresh(evidence) return evidence def get_statistics(self) -> Dict[str, Any]: """Get evidence statistics.""" total = self.db.query(func.count(EvidenceDB.id)).scalar() by_type = dict( self.db.query(EvidenceDB.evidence_type, func.count(EvidenceDB.id)) .group_by(EvidenceDB.evidence_type) .all() ) by_status = dict( self.db.query(EvidenceDB.status, func.count(EvidenceDB.id)) .group_by(EvidenceDB.status) .all() ) valid = by_status.get(EvidenceStatusEnum.VALID, 0) coverage = (valid / total * 100) if total > 0 else 0 return { "total": total, "by_type": by_type, "by_status": {str(k.value) if k else "none": v for k, v in by_status.items()}, "coverage_percent": round(coverage, 1), } class RiskRepository: """Repository for risks.""" def __init__(self, db: DBSession): self.db = db def create( self, risk_id: str, title: str, category: str, likelihood: int, impact: int, description: Optional[str] = None, mitigating_controls: Optional[List[str]] = None, owner: Optional[str] = None, treatment_plan: Optional[str] = None, ) -> RiskDB: """Create a risk.""" inherent_risk = RiskDB.calculate_risk_level(likelihood, impact) risk = RiskDB( id=str(uuid.uuid4()), risk_id=risk_id, title=title, description=description, category=category, likelihood=likelihood, impact=impact, inherent_risk=inherent_risk, mitigating_controls=mitigating_controls or [], owner=owner, treatment_plan=treatment_plan, ) self.db.add(risk) self.db.commit() self.db.refresh(risk) return risk def get_by_id(self, risk_uuid: str) -> Optional[RiskDB]: """Get risk by UUID.""" return self.db.query(RiskDB).filter(RiskDB.id == risk_uuid).first() def get_by_risk_id(self, risk_id: str) -> Optional[RiskDB]: """Get risk by risk_id (e.g., 'RISK-001').""" return self.db.query(RiskDB).filter(RiskDB.risk_id == risk_id).first() def get_all( self, category: Optional[str] = None, status: Optional[str] = None, min_risk_level: Optional[RiskLevelEnum] = None, ) -> List[RiskDB]: """Get all risks with filters.""" query = self.db.query(RiskDB) if category: query = query.filter(RiskDB.category == category) if status: query = query.filter(RiskDB.status == status) if min_risk_level: risk_order = { RiskLevelEnum.LOW: 1, RiskLevelEnum.MEDIUM: 2, RiskLevelEnum.HIGH: 3, RiskLevelEnum.CRITICAL: 4, } min_order = risk_order.get(min_risk_level, 1) query = query.filter( RiskDB.inherent_risk.in_( [k for k, v in risk_order.items() if v >= min_order] ) ) return query.order_by(RiskDB.risk_id).all() def update(self, risk_id: str, **kwargs) -> Optional[RiskDB]: """Update a risk.""" risk = self.get_by_risk_id(risk_id) if not risk: return None for key, value in kwargs.items(): if hasattr(risk, key): setattr(risk, key, value) # Recalculate risk levels if likelihood/impact changed if 'likelihood' in kwargs or 'impact' in kwargs: risk.inherent_risk = RiskDB.calculate_risk_level(risk.likelihood, risk.impact) if 'residual_likelihood' in kwargs or 'residual_impact' in kwargs: if risk.residual_likelihood and risk.residual_impact: risk.residual_risk = RiskDB.calculate_risk_level( risk.residual_likelihood, risk.residual_impact ) risk.updated_at = datetime.utcnow() self.db.commit() self.db.refresh(risk) return risk def get_matrix_data(self) -> Dict[str, Any]: """Get data for risk matrix visualization.""" risks = self.get_all() matrix = {} for risk in risks: key = f"{risk.likelihood}_{risk.impact}" if key not in matrix: matrix[key] = [] matrix[key].append({ "risk_id": risk.risk_id, "title": risk.title, "inherent_risk": risk.inherent_risk.value if risk.inherent_risk else None, }) return { "matrix": matrix, "total_risks": len(risks), "by_level": { "critical": len([r for r in risks if r.inherent_risk == RiskLevelEnum.CRITICAL]), "high": len([r for r in risks if r.inherent_risk == RiskLevelEnum.HIGH]), "medium": len([r for r in risks if r.inherent_risk == RiskLevelEnum.MEDIUM]), "low": len([r for r in risks if r.inherent_risk == RiskLevelEnum.LOW]), } } class AuditExportRepository: """Repository for audit exports.""" def __init__(self, db: DBSession): self.db = db def create( self, export_type: str, requested_by: str, export_name: Optional[str] = None, included_regulations: Optional[List[str]] = None, included_domains: Optional[List[str]] = None, date_range_start: Optional[date] = None, date_range_end: Optional[date] = None, ) -> AuditExportDB: """Create an export request.""" export = AuditExportDB( id=str(uuid.uuid4()), export_type=export_type, export_name=export_name or f"audit_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}", requested_by=requested_by, included_regulations=included_regulations, included_domains=included_domains, date_range_start=date_range_start, date_range_end=date_range_end, ) self.db.add(export) self.db.commit() self.db.refresh(export) return export def get_by_id(self, export_id: str) -> Optional[AuditExportDB]: """Get export by ID.""" return self.db.query(AuditExportDB).filter(AuditExportDB.id == export_id).first() def get_all(self, limit: int = 50) -> List[AuditExportDB]: """Get all exports.""" return ( self.db.query(AuditExportDB) .order_by(AuditExportDB.requested_at.desc()) .limit(limit) .all() ) def update_status( self, export_id: str, status: ExportStatusEnum, file_path: Optional[str] = None, file_hash: Optional[str] = None, file_size_bytes: Optional[int] = None, error_message: Optional[str] = None, total_controls: Optional[int] = None, total_evidence: Optional[int] = None, compliance_score: Optional[float] = None, ) -> Optional[AuditExportDB]: """Update export status.""" export = self.get_by_id(export_id) if not export: return None export.status = status if file_path: export.file_path = file_path if file_hash: export.file_hash = file_hash if file_size_bytes: export.file_size_bytes = file_size_bytes if error_message: export.error_message = error_message if total_controls is not None: export.total_controls = total_controls if total_evidence is not None: export.total_evidence = total_evidence if compliance_score is not None: export.compliance_score = compliance_score if status == ExportStatusEnum.COMPLETED: export.completed_at = datetime.utcnow() export.updated_at = datetime.utcnow() self.db.commit() self.db.refresh(export) return export class ServiceModuleRepository: """Repository for service modules (Sprint 3).""" def __init__(self, db: DBSession): self.db = db def create( self, name: str, display_name: str, service_type: str, description: Optional[str] = None, port: Optional[int] = None, technology_stack: Optional[List[str]] = None, repository_path: Optional[str] = None, docker_image: Optional[str] = None, data_categories: Optional[List[str]] = None, processes_pii: bool = False, processes_health_data: bool = False, ai_components: bool = False, criticality: str = "medium", owner_team: Optional[str] = None, owner_contact: Optional[str] = None, ) -> "ServiceModuleDB": """Create a service module.""" from .models import ServiceModuleDB, ServiceTypeEnum module = ServiceModuleDB( id=str(uuid.uuid4()), name=name, display_name=display_name, description=description, service_type=ServiceTypeEnum(service_type), port=port, technology_stack=technology_stack or [], repository_path=repository_path, docker_image=docker_image, data_categories=data_categories or [], processes_pii=processes_pii, processes_health_data=processes_health_data, ai_components=ai_components, criticality=criticality, owner_team=owner_team, owner_contact=owner_contact, ) self.db.add(module) self.db.commit() self.db.refresh(module) return module def get_by_id(self, module_id: str) -> Optional["ServiceModuleDB"]: """Get module by ID.""" from .models import ServiceModuleDB return self.db.query(ServiceModuleDB).filter(ServiceModuleDB.id == module_id).first() def get_by_name(self, name: str) -> Optional["ServiceModuleDB"]: """Get module by name.""" from .models import ServiceModuleDB return self.db.query(ServiceModuleDB).filter(ServiceModuleDB.name == name).first() def get_all( self, service_type: Optional[str] = None, criticality: Optional[str] = None, processes_pii: Optional[bool] = None, ai_components: Optional[bool] = None, ) -> List["ServiceModuleDB"]: """Get all modules with filters.""" from .models import ServiceModuleDB, ServiceTypeEnum query = self.db.query(ServiceModuleDB).filter(ServiceModuleDB.is_active == True) if service_type: query = query.filter(ServiceModuleDB.service_type == ServiceTypeEnum(service_type)) if criticality: query = query.filter(ServiceModuleDB.criticality == criticality) if processes_pii is not None: query = query.filter(ServiceModuleDB.processes_pii == processes_pii) if ai_components is not None: query = query.filter(ServiceModuleDB.ai_components == ai_components) return query.order_by(ServiceModuleDB.name).all() def get_with_regulations(self, module_id: str) -> Optional["ServiceModuleDB"]: """Get module with regulation mappings loaded.""" from .models import ServiceModuleDB, ModuleRegulationMappingDB from sqlalchemy.orm import selectinload return ( self.db.query(ServiceModuleDB) .options( selectinload(ServiceModuleDB.regulation_mappings) .selectinload(ModuleRegulationMappingDB.regulation) ) .filter(ServiceModuleDB.id == module_id) .first() ) def add_regulation_mapping( self, module_id: str, regulation_id: str, relevance_level: str = "medium", notes: Optional[str] = None, applicable_articles: Optional[List[str]] = None, ) -> "ModuleRegulationMappingDB": """Add a regulation mapping to a module.""" from .models import ModuleRegulationMappingDB, RelevanceLevelEnum mapping = ModuleRegulationMappingDB( id=str(uuid.uuid4()), module_id=module_id, regulation_id=regulation_id, relevance_level=RelevanceLevelEnum(relevance_level), notes=notes, applicable_articles=applicable_articles, ) self.db.add(mapping) self.db.commit() self.db.refresh(mapping) return mapping def get_overview(self) -> Dict[str, Any]: """Get overview statistics for all modules.""" from .models import ServiceModuleDB, ModuleRegulationMappingDB from sqlalchemy import func modules = self.get_all() total = len(modules) by_type = {} by_criticality = {} pii_count = 0 ai_count = 0 for m in modules: type_key = m.service_type.value if m.service_type else "unknown" by_type[type_key] = by_type.get(type_key, 0) + 1 by_criticality[m.criticality] = by_criticality.get(m.criticality, 0) + 1 if m.processes_pii: pii_count += 1 if m.ai_components: ai_count += 1 # Get regulation coverage regulation_coverage = {} mappings = self.db.query(ModuleRegulationMappingDB).all() for mapping in mappings: reg = mapping.regulation if reg: code = reg.code regulation_coverage[code] = regulation_coverage.get(code, 0) + 1 # Calculate average compliance score scores = [m.compliance_score for m in modules if m.compliance_score is not None] avg_score = sum(scores) / len(scores) if scores else None return { "total_modules": total, "modules_by_type": by_type, "modules_by_criticality": by_criticality, "modules_processing_pii": pii_count, "modules_with_ai": ai_count, "average_compliance_score": round(avg_score, 1) if avg_score else None, "regulations_coverage": regulation_coverage, } def seed_from_data(self, services_data: List[Dict[str, Any]], force: bool = False) -> Dict[str, int]: """Seed modules from service_modules.py data.""" from .models import ServiceModuleDB modules_created = 0 mappings_created = 0 for svc in services_data: # Check if module exists existing = self.get_by_name(svc["name"]) if existing and not force: continue if existing and force: # Delete existing module (cascades to mappings) self.db.delete(existing) self.db.commit() # Create module module = self.create( name=svc["name"], display_name=svc["display_name"], description=svc.get("description"), service_type=svc["service_type"], port=svc.get("port"), technology_stack=svc.get("technology_stack"), repository_path=svc.get("repository_path"), docker_image=svc.get("docker_image"), data_categories=svc.get("data_categories"), processes_pii=svc.get("processes_pii", False), processes_health_data=svc.get("processes_health_data", False), ai_components=svc.get("ai_components", False), criticality=svc.get("criticality", "medium"), owner_team=svc.get("owner_team"), ) modules_created += 1 # Create regulation mappings for reg_data in svc.get("regulations", []): # Find regulation by code reg = self.db.query(RegulationDB).filter( RegulationDB.code == reg_data["code"] ).first() if reg: self.add_regulation_mapping( module_id=module.id, regulation_id=reg.id, relevance_level=reg_data.get("relevance", "medium"), notes=reg_data.get("notes"), ) mappings_created += 1 return { "modules_created": modules_created, "mappings_created": mappings_created, } class AuditSessionRepository: """Repository for audit sessions (Sprint 3: Auditor-Verbesserungen).""" def __init__(self, db: DBSession): self.db = db def create( self, name: str, auditor_name: str, description: Optional[str] = None, auditor_email: Optional[str] = None, regulation_ids: Optional[List[str]] = None, ) -> AuditSessionDB: """Create a new audit session.""" session = AuditSessionDB( id=str(uuid.uuid4()), name=name, description=description, auditor_name=auditor_name, auditor_email=auditor_email, regulation_ids=regulation_ids, status=AuditSessionStatusEnum.DRAFT, ) self.db.add(session) self.db.commit() self.db.refresh(session) return session def get_by_id(self, session_id: str) -> Optional[AuditSessionDB]: """Get audit session by ID with eager-loaded signoffs.""" return ( self.db.query(AuditSessionDB) .options( selectinload(AuditSessionDB.signoffs) .selectinload(AuditSignOffDB.requirement) ) .filter(AuditSessionDB.id == session_id) .first() ) def get_all( self, status: Optional[AuditSessionStatusEnum] = None, limit: int = 50, ) -> List[AuditSessionDB]: """Get all audit sessions with optional status filter.""" query = self.db.query(AuditSessionDB) if status: query = query.filter(AuditSessionDB.status == status) return query.order_by(AuditSessionDB.created_at.desc()).limit(limit).all() def update_status( self, session_id: str, status: AuditSessionStatusEnum, ) -> Optional[AuditSessionDB]: """Update session status and set appropriate timestamps.""" session = self.get_by_id(session_id) if not session: return None session.status = status if status == AuditSessionStatusEnum.IN_PROGRESS and not session.started_at: session.started_at = datetime.utcnow() elif status == AuditSessionStatusEnum.COMPLETED: session.completed_at = datetime.utcnow() session.updated_at = datetime.utcnow() self.db.commit() self.db.refresh(session) return session def update_progress( self, session_id: str, total_items: Optional[int] = None, completed_items: Optional[int] = None, ) -> Optional[AuditSessionDB]: """Update session progress counters.""" session = self.db.query(AuditSessionDB).filter( AuditSessionDB.id == session_id ).first() if not session: return None if total_items is not None: session.total_items = total_items if completed_items is not None: session.completed_items = completed_items session.updated_at = datetime.utcnow() self.db.commit() self.db.refresh(session) return session def start_session(self, session_id: str) -> Optional[AuditSessionDB]: """ Start an audit session: - Set status to IN_PROGRESS - Initialize total_items based on requirements count """ session = self.get_by_id(session_id) if not session: return None # Count requirements for this session query = self.db.query(func.count(RequirementDB.id)) if session.regulation_ids: query = query.join(RegulationDB).filter( RegulationDB.id.in_(session.regulation_ids) ) total_requirements = query.scalar() or 0 session.status = AuditSessionStatusEnum.IN_PROGRESS session.started_at = datetime.utcnow() session.total_items = total_requirements session.updated_at = datetime.utcnow() self.db.commit() self.db.refresh(session) return session def delete(self, session_id: str) -> bool: """Delete an audit session (cascades to signoffs).""" session = self.db.query(AuditSessionDB).filter( AuditSessionDB.id == session_id ).first() if not session: return False self.db.delete(session) self.db.commit() return True def get_statistics(self, session_id: str) -> Dict[str, Any]: """Get detailed statistics for an audit session.""" session = self.get_by_id(session_id) if not session: return {} signoffs = session.signoffs or [] stats = { "total": session.total_items or 0, "completed": len([s for s in signoffs if s.result != AuditResultEnum.PENDING]), "compliant": len([s for s in signoffs if s.result == AuditResultEnum.COMPLIANT]), "compliant_with_notes": len([s for s in signoffs if s.result == AuditResultEnum.COMPLIANT_WITH_NOTES]), "non_compliant": len([s for s in signoffs if s.result == AuditResultEnum.NON_COMPLIANT]), "not_applicable": len([s for s in signoffs if s.result == AuditResultEnum.NOT_APPLICABLE]), "pending": len([s for s in signoffs if s.result == AuditResultEnum.PENDING]), "signed": len([s for s in signoffs if s.signature_hash]), } total = stats["total"] if stats["total"] > 0 else 1 stats["completion_percentage"] = round( (stats["completed"] / total) * 100, 1 ) return stats class AuditSignOffRepository: """Repository for audit sign-offs (Sprint 3: Auditor-Verbesserungen).""" def __init__(self, db: DBSession): self.db = db def create( self, session_id: str, requirement_id: str, result: AuditResultEnum = AuditResultEnum.PENDING, notes: Optional[str] = None, ) -> AuditSignOffDB: """Create a new sign-off for a requirement.""" signoff = AuditSignOffDB( id=str(uuid.uuid4()), session_id=session_id, requirement_id=requirement_id, result=result, notes=notes, ) self.db.add(signoff) self.db.commit() self.db.refresh(signoff) return signoff def get_by_id(self, signoff_id: str) -> Optional[AuditSignOffDB]: """Get sign-off by ID.""" return ( self.db.query(AuditSignOffDB) .options(joinedload(AuditSignOffDB.requirement)) .filter(AuditSignOffDB.id == signoff_id) .first() ) def get_by_session_and_requirement( self, session_id: str, requirement_id: str, ) -> Optional[AuditSignOffDB]: """Get sign-off by session and requirement ID.""" return ( self.db.query(AuditSignOffDB) .filter( and_( AuditSignOffDB.session_id == session_id, AuditSignOffDB.requirement_id == requirement_id, ) ) .first() ) def get_by_session( self, session_id: str, result_filter: Optional[AuditResultEnum] = None, ) -> List[AuditSignOffDB]: """Get all sign-offs for a session.""" query = ( self.db.query(AuditSignOffDB) .options(joinedload(AuditSignOffDB.requirement)) .filter(AuditSignOffDB.session_id == session_id) ) if result_filter: query = query.filter(AuditSignOffDB.result == result_filter) return query.order_by(AuditSignOffDB.created_at).all() def update( self, signoff_id: str, result: Optional[AuditResultEnum] = None, notes: Optional[str] = None, sign: bool = False, signed_by: Optional[str] = None, ) -> Optional[AuditSignOffDB]: """Update a sign-off with optional digital signature.""" signoff = self.db.query(AuditSignOffDB).filter( AuditSignOffDB.id == signoff_id ).first() if not signoff: return None if result is not None: signoff.result = result if notes is not None: signoff.notes = notes if sign and signed_by: signoff.create_signature(signed_by) signoff.updated_at = datetime.utcnow() self.db.commit() self.db.refresh(signoff) # Update session progress self._update_session_progress(signoff.session_id) return signoff def sign_off( self, session_id: str, requirement_id: str, result: AuditResultEnum, notes: Optional[str] = None, sign: bool = False, signed_by: Optional[str] = None, ) -> AuditSignOffDB: """ Create or update a sign-off for a requirement. This is the main method for auditors to record their findings. """ # Check if sign-off already exists signoff = self.get_by_session_and_requirement(session_id, requirement_id) if signoff: # Update existing signoff.result = result if notes is not None: signoff.notes = notes if sign and signed_by: signoff.create_signature(signed_by) signoff.updated_at = datetime.utcnow() else: # Create new signoff = AuditSignOffDB( id=str(uuid.uuid4()), session_id=session_id, requirement_id=requirement_id, result=result, notes=notes, ) if sign and signed_by: signoff.create_signature(signed_by) self.db.add(signoff) self.db.commit() self.db.refresh(signoff) # Update session progress self._update_session_progress(session_id) return signoff def _update_session_progress(self, session_id: str) -> None: """Update the session's completed_items count.""" completed = ( self.db.query(func.count(AuditSignOffDB.id)) .filter( and_( AuditSignOffDB.session_id == session_id, AuditSignOffDB.result != AuditResultEnum.PENDING, ) ) .scalar() ) or 0 session = self.db.query(AuditSessionDB).filter( AuditSessionDB.id == session_id ).first() if session: session.completed_items = completed session.updated_at = datetime.utcnow() self.db.commit() def get_checklist( self, session_id: str, page: int = 1, page_size: int = 50, result_filter: Optional[AuditResultEnum] = None, regulation_code: Optional[str] = None, search: Optional[str] = None, ) -> Tuple[List[Dict[str, Any]], int]: """ Get audit checklist items for a session with pagination. Returns requirements with their sign-off status. """ session = self.db.query(AuditSessionDB).filter( AuditSessionDB.id == session_id ).first() if not session: return [], 0 # Base query for requirements query = ( self.db.query(RequirementDB) .options( joinedload(RequirementDB.regulation), selectinload(RequirementDB.control_mappings), ) ) # Filter by session's regulation_ids if set if session.regulation_ids: query = query.filter(RequirementDB.regulation_id.in_(session.regulation_ids)) # Filter by regulation code if regulation_code: query = query.join(RegulationDB).filter(RegulationDB.code == regulation_code) # Search if search: search_term = f"%{search}%" query = query.filter( or_( RequirementDB.title.ilike(search_term), RequirementDB.article.ilike(search_term), ) ) # Get existing sign-offs for this session signoffs_map = {} signoffs = ( self.db.query(AuditSignOffDB) .filter(AuditSignOffDB.session_id == session_id) .all() ) for s in signoffs: signoffs_map[s.requirement_id] = s # Filter by result if specified if result_filter: if result_filter == AuditResultEnum.PENDING: # Requirements without sign-off or with pending status signed_req_ids = [ s.requirement_id for s in signoffs if s.result != AuditResultEnum.PENDING ] if signed_req_ids: query = query.filter(~RequirementDB.id.in_(signed_req_ids)) else: # Requirements with specific result matching_req_ids = [ s.requirement_id for s in signoffs if s.result == result_filter ] if matching_req_ids: query = query.filter(RequirementDB.id.in_(matching_req_ids)) else: return [], 0 # Count and paginate total = query.count() requirements = ( query .order_by(RequirementDB.article, RequirementDB.paragraph) .offset((page - 1) * page_size) .limit(page_size) .all() ) # Build checklist items items = [] for req in requirements: signoff = signoffs_map.get(req.id) items.append({ "requirement_id": req.id, "regulation_code": req.regulation.code if req.regulation else None, "regulation_name": req.regulation.name if req.regulation else None, "article": req.article, "paragraph": req.paragraph, "title": req.title, "description": req.description, "current_result": signoff.result.value if signoff else AuditResultEnum.PENDING.value, "notes": signoff.notes if signoff else None, "is_signed": bool(signoff.signature_hash) if signoff else False, "signed_at": signoff.signed_at if signoff else None, "signed_by": signoff.signed_by if signoff else None, "evidence_count": len(req.control_mappings) if req.control_mappings else 0, "controls_mapped": len(req.control_mappings) if req.control_mappings else 0, }) return items, total def delete(self, signoff_id: str) -> bool: """Delete a sign-off.""" signoff = self.db.query(AuditSignOffDB).filter( AuditSignOffDB.id == signoff_id ).first() if not signoff: return False session_id = signoff.session_id self.db.delete(signoff) self.db.commit() # Update session progress self._update_session_progress(session_id) return True