"""
Compliance repositories — extracted from compliance/db/repository.py.

Phase 1 Step 5: the monolithic repository module is decomposed per aggregate.
Every repository class is re-exported from ``compliance.db.repository`` for
backwards compatibility.
"""

import uuid
from collections import Counter
from datetime import datetime, date, timezone
from typing import List, Optional, Dict, Any, Tuple

from sqlalchemy.orm import Session as DBSession, selectinload, joinedload
from sqlalchemy import func, and_, or_

from compliance.db.models import (
    RegulationDB,
    RequirementDB,
    ControlDB,
    ControlMappingDB,
    EvidenceDB,
    RiskDB,
    AuditExportDB,
    AuditSessionDB,
    AuditSignOffDB,
    AuditResultEnum,
    AuditSessionStatusEnum,
    RegulationTypeEnum,
    ControlDomainEnum,
    ControlStatusEnum,
    RiskLevelEnum,
    EvidenceStatusEnum,
    ExportStatusEnum,
    ServiceModuleDB,
    ModuleRegulationMappingDB,
)

# Escape character used when building LIKE/ILIKE patterns from user input.
_LIKE_ESCAPE = "\\"


def _escape_like(term: str) -> str:
    """Escape LIKE metacharacters in *term* so it is matched literally.

    ``%`` and ``_`` are wildcards in LIKE patterns; without escaping, a search
    for ``"Art_5"`` would also match ``"Art15"``. The escape character itself
    is escaped first so user input cannot neutralise the following escapes.
    """
    return (
        term.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )


class AuditSessionRepository:
    """Repository for audit sessions (Sprint 3: auditor improvements)."""

    def __init__(self, db: DBSession):
        self.db = db

    def create(
        self,
        name: str,
        auditor_name: str,
        description: Optional[str] = None,
        auditor_email: Optional[str] = None,
        regulation_ids: Optional[List[str]] = None,
    ) -> AuditSessionDB:
        """Create, persist and return a new audit session in DRAFT status."""
        session = AuditSessionDB(
            id=str(uuid.uuid4()),
            name=name,
            description=description,
            auditor_name=auditor_name,
            auditor_email=auditor_email,
            regulation_ids=regulation_ids,
            status=AuditSessionStatusEnum.DRAFT,
        )
        self.db.add(session)
        self.db.commit()
        self.db.refresh(session)
        return session

    def get_by_id(self, session_id: str) -> Optional[AuditSessionDB]:
        """Return the audit session with its sign-offs (and their requirements)
        eagerly loaded, or ``None`` if no session has this ID."""
        return (
            self.db.query(AuditSessionDB)
            .options(
                selectinload(AuditSessionDB.signoffs)
                .selectinload(AuditSignOffDB.requirement)
            )
            .filter(AuditSessionDB.id == session_id)
            .first()
        )

    def get_all(
        self,
        status: Optional[AuditSessionStatusEnum] = None,
        limit: int = 50,
    ) -> List[AuditSessionDB]:
        """Return up to ``limit`` sessions, newest first, optionally filtered
        by ``status``."""
        query = self.db.query(AuditSessionDB)
        if status is not None:
            query = query.filter(AuditSessionDB.status == status)
        return query.order_by(AuditSessionDB.created_at.desc()).limit(limit).all()

    def update_status(
        self,
        session_id: str,
        status: AuditSessionStatusEnum,
    ) -> Optional[AuditSessionDB]:
        """Set the session status and maintain its lifecycle timestamps.

        ``started_at`` is set only the first time the session enters
        IN_PROGRESS; ``completed_at`` is (re)set whenever it becomes COMPLETED.
        Returns ``None`` if the session does not exist.
        """
        session = self.get_by_id(session_id)
        if not session:
            return None

        session.status = status
        if status == AuditSessionStatusEnum.IN_PROGRESS and not session.started_at:
            session.started_at = datetime.now(timezone.utc)
        elif status == AuditSessionStatusEnum.COMPLETED:
            session.completed_at = datetime.now(timezone.utc)

        session.updated_at = datetime.now(timezone.utc)
        self.db.commit()
        self.db.refresh(session)
        return session

    def update_progress(
        self,
        session_id: str,
        total_items: Optional[int] = None,
        completed_items: Optional[int] = None,
    ) -> Optional[AuditSessionDB]:
        """Overwrite the given progress counters (``None`` leaves a counter
        unchanged). Returns ``None`` if the session does not exist."""
        session = self.db.query(AuditSessionDB).filter(
            AuditSessionDB.id == session_id
        ).first()
        if not session:
            return None

        if total_items is not None:
            session.total_items = total_items
        if completed_items is not None:
            session.completed_items = completed_items

        session.updated_at = datetime.now(timezone.utc)
        self.db.commit()
        self.db.refresh(session)
        return session

    def start_session(self, session_id: str) -> Optional[AuditSessionDB]:
        """Start an audit session.

        - Sets status to IN_PROGRESS and stamps ``started_at``.
        - Initialises ``total_items`` with the number of requirements in scope:
          those of ``session.regulation_ids`` if set, otherwise all requirements.

        Returns ``None`` if the session does not exist.
        """
        session = self.get_by_id(session_id)
        if not session:
            return None

        # Count in-scope requirements; filter on the FK column directly, the
        # same way get_checklist scopes requirements to the session.
        query = self.db.query(func.count(RequirementDB.id))
        if session.regulation_ids:
            query = query.filter(
                RequirementDB.regulation_id.in_(session.regulation_ids)
            )
        total_requirements = query.scalar() or 0

        session.status = AuditSessionStatusEnum.IN_PROGRESS
        session.started_at = datetime.now(timezone.utc)
        session.total_items = total_requirements
        session.updated_at = datetime.now(timezone.utc)

        self.db.commit()
        self.db.refresh(session)
        return session

    def delete(self, session_id: str) -> bool:
        """Delete an audit session; returns ``False`` if it did not exist.

        Removal of dependent sign-offs is presumably handled by the ORM
        relationship cascade configured on the model — confirm there.
        """
        session = self.db.query(AuditSessionDB).filter(
            AuditSessionDB.id == session_id
        ).first()
        if not session:
            return False

        self.db.delete(session)
        self.db.commit()
        return True

    def get_statistics(self, session_id: str) -> Dict[str, Any]:
        """Return per-result sign-off counts and a completion percentage.

        ``completion_percentage`` is completed sign-offs relative to
        ``total_items`` (guarded against division by zero). Returns an empty
        dict if the session does not exist.
        """
        session = self.get_by_id(session_id)
        if not session:
            return {}

        signoffs = session.signoffs or []
        # One pass over the sign-offs instead of one list scan per result.
        by_result = Counter(s.result for s in signoffs)
        pending = by_result[AuditResultEnum.PENDING]

        stats = {
            "total": session.total_items or 0,
            "completed": len(signoffs) - pending,
            "compliant": by_result[AuditResultEnum.COMPLIANT],
            "compliant_with_notes": by_result[AuditResultEnum.COMPLIANT_WITH_NOTES],
            "non_compliant": by_result[AuditResultEnum.NON_COMPLIANT],
            "not_applicable": by_result[AuditResultEnum.NOT_APPLICABLE],
            "pending": pending,
            "signed": sum(1 for s in signoffs if s.signature_hash),
        }

        total = stats["total"] if stats["total"] > 0 else 1
        stats["completion_percentage"] = round(
            (stats["completed"] / total) * 100, 1
        )
        return stats


class AuditSignOffRepository:
    """Repository for audit sign-offs (Sprint 3: auditor improvements).

    Every mutation also recomputes the owning session's ``completed_items``
    within the same transaction, so the sign-off and the counter are
    committed together.
    """

    def __init__(self, db: DBSession):
        self.db = db

    def create(
        self,
        session_id: str,
        requirement_id: str,
        result: AuditResultEnum = AuditResultEnum.PENDING,
        notes: Optional[str] = None,
    ) -> AuditSignOffDB:
        """Create a new sign-off for a requirement within a session."""
        signoff = AuditSignOffDB(
            id=str(uuid.uuid4()),
            session_id=session_id,
            requirement_id=requirement_id,
            result=result,
            notes=notes,
        )
        self.db.add(signoff)
        # A PENDING sign-off cannot change completed_items; anything else must
        # be reflected in the session counter, as sign_off()/update() do.
        if result != AuditResultEnum.PENDING:
            self._update_session_progress(session_id, commit=False)
        self.db.commit()
        self.db.refresh(signoff)
        return signoff

    def get_by_id(self, signoff_id: str) -> Optional[AuditSignOffDB]:
        """Return the sign-off (with its requirement joined-loaded) or ``None``."""
        return (
            self.db.query(AuditSignOffDB)
            .options(joinedload(AuditSignOffDB.requirement))
            .filter(AuditSignOffDB.id == signoff_id)
            .first()
        )

    def get_by_session_and_requirement(
        self,
        session_id: str,
        requirement_id: str,
    ) -> Optional[AuditSignOffDB]:
        """Return the sign-off for (session, requirement), or ``None``."""
        return (
            self.db.query(AuditSignOffDB)
            .filter(
                and_(
                    AuditSignOffDB.session_id == session_id,
                    AuditSignOffDB.requirement_id == requirement_id,
                )
            )
            .first()
        )

    def get_by_session(
        self,
        session_id: str,
        result_filter: Optional[AuditResultEnum] = None,
    ) -> List[AuditSignOffDB]:
        """Return a session's sign-offs in creation order, optionally only
        those with result ``result_filter``."""
        query = (
            self.db.query(AuditSignOffDB)
            .options(joinedload(AuditSignOffDB.requirement))
            .filter(AuditSignOffDB.session_id == session_id)
        )
        if result_filter is not None:
            query = query.filter(AuditSignOffDB.result == result_filter)
        return query.order_by(AuditSignOffDB.created_at).all()

    def update(
        self,
        signoff_id: str,
        result: Optional[AuditResultEnum] = None,
        notes: Optional[str] = None,
        sign: bool = False,
        signed_by: Optional[str] = None,
    ) -> Optional[AuditSignOffDB]:
        """Update a sign-off, optionally attaching a digital signature.

        ``None`` arguments leave the corresponding field unchanged. A signature
        is created only when both ``sign`` is true and ``signed_by`` is given.
        Returns ``None`` if the sign-off does not exist.
        """
        signoff = self.db.query(AuditSignOffDB).filter(
            AuditSignOffDB.id == signoff_id
        ).first()
        if not signoff:
            return None

        if result is not None:
            signoff.result = result
        if notes is not None:
            signoff.notes = notes
        if sign and signed_by:
            signoff.create_signature(signed_by)

        signoff.updated_at = datetime.now(timezone.utc)

        # Recompute progress before committing so both changes share one
        # transaction.
        self._update_session_progress(signoff.session_id, commit=False)
        self.db.commit()
        self.db.refresh(signoff)
        return signoff

    def sign_off(
        self,
        session_id: str,
        requirement_id: str,
        result: AuditResultEnum,
        notes: Optional[str] = None,
        sign: bool = False,
        signed_by: Optional[str] = None,
    ) -> AuditSignOffDB:
        """Create or update the sign-off for a requirement (upsert).

        This is the main method for auditors to record their findings. The
        session's ``completed_items`` counter is recomputed in the same
        transaction.
        """
        signoff = self.get_by_session_and_requirement(session_id, requirement_id)

        if signoff:
            signoff.result = result
            if notes is not None:
                signoff.notes = notes
            if sign and signed_by:
                signoff.create_signature(signed_by)
            signoff.updated_at = datetime.now(timezone.utc)
        else:
            signoff = AuditSignOffDB(
                id=str(uuid.uuid4()),
                session_id=session_id,
                requirement_id=requirement_id,
                result=result,
                notes=notes,
            )
            if sign and signed_by:
                signoff.create_signature(signed_by)
            self.db.add(signoff)

        self._update_session_progress(session_id, commit=False)
        self.db.commit()
        self.db.refresh(signoff)
        return signoff

    def _update_session_progress(self, session_id: str, commit: bool = True) -> None:
        """Recount non-PENDING sign-offs into the session's ``completed_items``.

        Pending ORM changes are flushed first so the count includes them. With
        ``commit=False`` the caller is responsible for committing, which lets
        the sign-off change and the counter update land in one transaction.
        """
        self.db.flush()
        completed = (
            self.db.query(func.count(AuditSignOffDB.id))
            .filter(
                and_(
                    AuditSignOffDB.session_id == session_id,
                    AuditSignOffDB.result != AuditResultEnum.PENDING,
                )
            )
            .scalar()
        ) or 0

        session = self.db.query(AuditSessionDB).filter(
            AuditSessionDB.id == session_id
        ).first()
        if session:
            session.completed_items = completed
            session.updated_at = datetime.now(timezone.utc)
        if commit:
            self.db.commit()

    def get_checklist(
        self,
        session_id: str,
        page: int = 1,
        page_size: int = 50,
        result_filter: Optional[AuditResultEnum] = None,
        regulation_code: Optional[str] = None,
        search: Optional[str] = None,
    ) -> Tuple[List[Dict[str, Any]], int]:
        """Return one page of checklist items for a session plus the total count.

        Each item combines a requirement with its sign-off state for this
        session (PENDING if no sign-off exists). ``page`` is 1-based; values
        below 1 are treated as 1. ``search`` is matched literally
        (case-insensitive) against requirement title and article.
        Returns ``([], 0)`` if the session does not exist.
        """
        session = self.db.query(AuditSessionDB).filter(
            AuditSessionDB.id == session_id
        ).first()
        if not session:
            return [], 0

        # Guard against negative OFFSET/LIMIT, which databases reject.
        page = max(page, 1)
        page_size = max(page_size, 0)

        query = (
            self.db.query(RequirementDB)
            .options(
                joinedload(RequirementDB.regulation),
                selectinload(RequirementDB.control_mappings),
            )
        )

        # Restrict to the session's regulations, if it has any.
        if session.regulation_ids:
            query = query.filter(RequirementDB.regulation_id.in_(session.regulation_ids))

        if regulation_code:
            query = query.join(RegulationDB).filter(RegulationDB.code == regulation_code)

        if search:
            search_term = f"%{_escape_like(search)}%"
            query = query.filter(
                or_(
                    RequirementDB.title.ilike(search_term, escape=_LIKE_ESCAPE),
                    RequirementDB.article.ilike(search_term, escape=_LIKE_ESCAPE),
                )
            )

        signoffs = (
            self.db.query(AuditSignOffDB)
            .filter(AuditSignOffDB.session_id == session_id)
            .all()
        )
        signoffs_map = {s.requirement_id: s for s in signoffs}

        if result_filter is not None:
            if result_filter == AuditResultEnum.PENDING:
                # Pending = no sign-off yet, or a sign-off still marked PENDING;
                # i.e. exclude requirements that already have a real result.
                signed_req_ids = [
                    s.requirement_id for s in signoffs
                    if s.result != AuditResultEnum.PENDING
                ]
                if signed_req_ids:
                    query = query.filter(~RequirementDB.id.in_(signed_req_ids))
            else:
                matching_req_ids = [
                    s.requirement_id for s in signoffs
                    if s.result == result_filter
                ]
                if not matching_req_ids:
                    return [], 0
                query = query.filter(RequirementDB.id.in_(matching_req_ids))

        total = query.count()
        requirements = (
            query
            # id is a unique tie-breaker: article/paragraph alone are not
            # unique, which would make page boundaries non-deterministic.
            .order_by(RequirementDB.article, RequirementDB.paragraph, RequirementDB.id)
            .offset((page - 1) * page_size)
            .limit(page_size)
            .all()
        )

        items = []
        for req in requirements:
            signoff = signoffs_map.get(req.id)
            mapping_count = len(req.control_mappings) if req.control_mappings else 0
            items.append({
                "requirement_id": req.id,
                "regulation_code": req.regulation.code if req.regulation else None,
                "regulation_name": req.regulation.name if req.regulation else None,
                "article": req.article,
                "paragraph": req.paragraph,
                "title": req.title,
                "description": req.description,
                "current_result": signoff.result.value if signoff else AuditResultEnum.PENDING.value,
                "notes": signoff.notes if signoff else None,
                "is_signed": bool(signoff.signature_hash) if signoff else False,
                "signed_at": signoff.signed_at if signoff else None,
                "signed_by": signoff.signed_by if signoff else None,
                # NOTE(review): evidence_count currently mirrors controls_mapped
                # (both count control mappings) — confirm whether real evidence
                # records should be counted here instead.
                "evidence_count": mapping_count,
                "controls_mapped": mapping_count,
            })

        return items, total

    def delete(self, signoff_id: str) -> bool:
        """Delete a sign-off and recompute the session's progress in the same
        transaction. Returns ``False`` if the sign-off did not exist."""
        signoff = self.db.query(AuditSignOffDB).filter(
            AuditSignOffDB.id == signoff_id
        ).first()
        if not signoff:
            return False

        session_id = signoff.session_id
        self.db.delete(signoff)
        # The helper flushes the delete before counting, then we commit once.
        self._update_session_progress(session_id, commit=False)
        self.db.commit()
        return True