"""
Compliance repositories — extracted from compliance/db/repository.py.

Phase 1 Step 5: the monolithic repository module is decomposed per
aggregate. Every repository class is re-exported from
``compliance.db.repository`` for backwards compatibility.
"""

import uuid
from datetime import date, datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Tuple

from sqlalchemy import and_, func, or_
from sqlalchemy.orm import Session as DBSession, joinedload, selectinload

from compliance.db.models import (
    RegulationDB,
    RequirementDB,
    ControlDB,
    ControlMappingDB,
    EvidenceDB,
    RiskDB,
    AuditExportDB,
    AuditSessionDB,
    AuditSignOffDB,
    AuditResultEnum,
    AuditSessionStatusEnum,
    RegulationTypeEnum,
    ControlDomainEnum,
    ControlStatusEnum,
    RiskLevelEnum,
    EvidenceStatusEnum,
    ExportStatusEnum,
    ServiceModuleDB,
    ModuleRegulationMappingDB,
)


class ControlRepository:
    """Repository for controls (``ControlDB`` rows)."""

    def __init__(self, db: DBSession):
        self.db = db

    def create(
        self,
        control_id: str,
        domain: ControlDomainEnum,
        control_type: str,
        title: str,
        pass_criteria: str,
        description: Optional[str] = None,
        implementation_guidance: Optional[str] = None,
        code_reference: Optional[str] = None,
        is_automated: bool = False,
        automation_tool: Optional[str] = None,
        owner: Optional[str] = None,
        review_frequency_days: int = 90,
    ) -> ControlDB:
        """Create, commit and return a new control with a fresh UUID primary key.

        Args:
            control_id: Business identifier such as ``'PRIV-001'``.
            review_frequency_days: Interval used by :meth:`mark_reviewed` to
                schedule the next review.
        """
        control = ControlDB(
            id=str(uuid.uuid4()),
            control_id=control_id,
            domain=domain,
            control_type=control_type,
            title=title,
            description=description,
            pass_criteria=pass_criteria,
            implementation_guidance=implementation_guidance,
            code_reference=code_reference,
            is_automated=is_automated,
            automation_tool=automation_tool,
            owner=owner,
            review_frequency_days=review_frequency_days,
        )
        self.db.add(control)
        self.db.commit()
        self.db.refresh(control)
        return control

    def get_by_id(self, control_uuid: str) -> Optional[ControlDB]:
        """Get a control by UUID.

        Mappings (with their requirement) and evidence are eager-loaded.
        Returns ``None`` if no control matches.
        """
        return (
            self.db.query(ControlDB)
            .options(
                selectinload(ControlDB.mappings).selectinload(ControlMappingDB.requirement),
                selectinload(ControlDB.evidence),
            )
            .filter(ControlDB.id == control_uuid)
            .first()
        )

    def get_by_control_id(self, control_id: str) -> Optional[ControlDB]:
        """Get a control by business id (e.g. ``'PRIV-001'``).

        Mappings (with their requirement) and evidence are eager-loaded.
        Returns ``None`` if no control matches.
        """
        return (
            self.db.query(ControlDB)
            .options(
                selectinload(ControlDB.mappings).selectinload(ControlMappingDB.requirement),
                selectinload(ControlDB.evidence),
            )
            .filter(ControlDB.control_id == control_id)
            .first()
        )

    def get_all(
        self,
        domain: Optional[ControlDomainEnum] = None,
        status: Optional[ControlStatusEnum] = None,
        is_automated: Optional[bool] = None,
    ) -> List[ControlDB]:
        """Get all controls matching the optional filters, ordered by ``control_id``.

        Mappings and evidence are eager-loaded. A filter left at ``None`` is
        not applied.
        """
        query = self.db.query(ControlDB).options(
            selectinload(ControlDB.mappings),
            selectinload(ControlDB.evidence),
        )
        if domain:
            query = query.filter(ControlDB.domain == domain)
        if status:
            query = query.filter(ControlDB.status == status)
        if is_automated is not None:
            query = query.filter(ControlDB.is_automated == is_automated)
        return query.order_by(ControlDB.control_id).all()

    def get_paginated(
        self,
        page: int = 1,
        page_size: int = 50,
        domain: Optional[ControlDomainEnum] = None,
        status: Optional[ControlStatusEnum] = None,
        is_automated: Optional[bool] = None,
        search: Optional[str] = None,
    ) -> Tuple[List[ControlDB], int]:
        """Get one page of controls with eager-loaded mappings and evidence.

        Args:
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Maximum number of items per page.
            search: Case-insensitive substring matched against title,
                description and control_id.

        Returns:
            Tuple of (items on this page, total count matching the filters).
        """
        # A page below 1 would yield a negative OFFSET, which some databases reject.
        page = max(page, 1)

        query = self.db.query(ControlDB).options(
            selectinload(ControlDB.mappings),
            selectinload(ControlDB.evidence),
        )
        if domain:
            query = query.filter(ControlDB.domain == domain)
        if status:
            query = query.filter(ControlDB.status == status)
        if is_automated is not None:
            query = query.filter(ControlDB.is_automated == is_automated)
        if search:
            search_term = f"%{search}%"
            query = query.filter(
                or_(
                    ControlDB.title.ilike(search_term),
                    ControlDB.description.ilike(search_term),
                    ControlDB.control_id.ilike(search_term),
                )
            )

        total = query.count()
        items = (
            query.order_by(ControlDB.control_id)
            .offset((page - 1) * page_size)
            .limit(page_size)
            .all()
        )
        return items, total

    def get_by_domain(self, domain: ControlDomainEnum) -> List[ControlDB]:
        """Get all controls in a domain."""
        return self.get_all(domain=domain)

    def get_by_status(self, status: ControlStatusEnum) -> List[ControlDB]:
        """Get all controls with a specific status."""
        return self.get_all(status=status)

    def update_status(
        self,
        control_id: str,
        status: ControlStatusEnum,
        status_notes: Optional[str] = None,
    ) -> Optional[ControlDB]:
        """Set a control's status and commit.

        ``status_notes`` is only written when truthy, so passing ``None`` or
        ``""`` keeps the existing notes. Returns ``None`` if the control
        does not exist.
        """
        control = self.get_by_control_id(control_id)
        if not control:
            return None
        control.status = status
        if status_notes:
            control.status_notes = status_notes
        control.updated_at = datetime.now(timezone.utc)
        self.db.commit()
        self.db.refresh(control)
        return control

    def mark_reviewed(self, control_id: str) -> Optional[ControlDB]:
        """Record a review now and schedule the next one.

        The next review is scheduled ``review_frequency_days`` after now.
        Returns ``None`` if the control does not exist.
        """
        control = self.get_by_control_id(control_id)
        if not control:
            return None
        # Single timestamp so last_reviewed_at, next_review_at and updated_at agree.
        now = datetime.now(timezone.utc)
        control.last_reviewed_at = now
        control.next_review_at = now + timedelta(days=control.review_frequency_days)
        control.updated_at = now
        self.db.commit()
        self.db.refresh(control)
        return control

    def get_due_for_review(self) -> List[ControlDB]:
        """Get controls that are due for review.

        A control is due if it was never scheduled (``next_review_at`` is
        NULL) or its scheduled time is now or in the past. Results are
        ordered by ``next_review_at``; where NULLs sort depends on the
        database backend.
        """
        return (
            self.db.query(ControlDB)
            .filter(
                or_(
                    # `.is_(None)` renders SQL IS NULL; a Python `is None`
                    # would evaluate to the constant False and drop these rows.
                    ControlDB.next_review_at.is_(None),
                    ControlDB.next_review_at <= datetime.now(timezone.utc),
                )
            )
            .order_by(ControlDB.next_review_at)
            .all()
        )

    def get_statistics(self) -> Dict[str, Any]:
        """Get control counts by status and domain, plus a compliance score.

        The score is ``(PASS + 0.5 * PARTIAL) / total * 100``, rounded to one
        decimal, and is 0.0 when there are no controls. NULL status or domain
        values are reported under the key ``"none"``.
        """
        total = self.db.query(func.count(ControlDB.id)).scalar()

        by_status = dict(
            self.db.query(ControlDB.status, func.count(ControlDB.id))
            .group_by(ControlDB.status)
            .all()
        )
        by_domain = dict(
            self.db.query(ControlDB.domain, func.count(ControlDB.id))
            .group_by(ControlDB.domain)
            .all()
        )

        passed = by_status.get(ControlStatusEnum.PASS, 0)
        partial = by_status.get(ControlStatusEnum.PARTIAL, 0)
        score = 0.0
        if total > 0:
            score = ((passed + (partial * 0.5)) / total) * 100

        return {
            "total": total,
            "by_status": {str(k.value) if k else "none": v for k, v in by_status.items()},
            "by_domain": {str(k.value) if k else "none": v for k, v in by_domain.items()},
            "compliance_score": round(score, 1),
        }


class ControlMappingRepository:
    """Repository for requirement-control mappings (``ControlMappingDB`` rows)."""

    def __init__(self, db: DBSession):
        self.db = db

    def create(
        self,
        requirement_id: str,
        control_id: str,
        coverage_level: str = "full",
        notes: Optional[str] = None,
    ) -> ControlMappingDB:
        """Create, commit and return a mapping between a requirement and a control.

        Args:
            requirement_id: Stored as-is on the mapping. Presumably the
                requirement's primary key; verify against callers.
            control_id: Business id of the control (e.g. ``'PRIV-001'``).
                It is resolved to the control's UUID before storing.

        Raises:
            ValueError: If no control with ``control_id`` exists.
        """
        # The mapping stores the control's UUID, not its business id.
        control = self.db.query(ControlDB).filter(ControlDB.control_id == control_id).first()
        if not control:
            raise ValueError(f"Control {control_id} not found")

        mapping = ControlMappingDB(
            id=str(uuid.uuid4()),
            requirement_id=requirement_id,
            control_id=control.id,
            coverage_level=coverage_level,
            notes=notes,
        )
        self.db.add(mapping)
        self.db.commit()
        self.db.refresh(mapping)
        return mapping

    def get_by_requirement(self, requirement_id: str) -> List[ControlMappingDB]:
        """Get all mappings for a requirement."""
        return (
            self.db.query(ControlMappingDB)
            .filter(ControlMappingDB.requirement_id == requirement_id)
            .all()
        )

    def get_by_control(self, control_uuid: str) -> List[ControlMappingDB]:
        """Get all mappings for a control, looked up by its UUID (not its business id)."""
        return (
            self.db.query(ControlMappingDB)
            .filter(ControlMappingDB.control_id == control_uuid)
            .all()
        )