"""
Compliance repositories — extracted from compliance/db/repository.py.

Phase 1 Step 5: the monolithic repository module is decomposed per
aggregate. Every repository class is re-exported from
``compliance.db.repository`` for backwards compatibility.
"""

import uuid
from datetime import datetime, date, timezone
from typing import List, Optional, Dict, Any, Tuple

from sqlalchemy.orm import Session as DBSession, selectinload, joinedload
from sqlalchemy import func, and_, or_

from compliance.db.models import (
    RegulationDB,
    RequirementDB,
    ControlDB,
    ControlMappingDB,
    EvidenceDB,
    RiskDB,
    AuditExportDB,
    AuditSessionDB,
    AuditSignOffDB,
    AuditResultEnum,
    AuditSessionStatusEnum,
    RegulationTypeEnum,
    ControlDomainEnum,
    ControlStatusEnum,
    RiskLevelEnum,
    EvidenceStatusEnum,
    ExportStatusEnum,
    ServiceModuleDB,
    ModuleRegulationMappingDB,
)


class RiskRepository:
    """Repository for risks (``RiskDB`` rows).

    Mutating methods (``create``, ``update``) commit the session passed to
    the constructor and refresh the returned instance.
    """

    # Severity rank of each level, lowest first. ``get_all`` uses it to turn a
    # minimum level into the list of levels at or above it; built once at class
    # definition instead of on every call.
    _RISK_ORDER: Dict[RiskLevelEnum, int] = {
        RiskLevelEnum.LOW: 1,
        RiskLevelEnum.MEDIUM: 2,
        RiskLevelEnum.HIGH: 3,
        RiskLevelEnum.CRITICAL: 4,
    }

    # (output label, level) pairs for the ``by_level`` summary produced by
    # ``get_matrix_data``; the tuple order fixes the key order of that dict.
    _MATRIX_LEVELS: Tuple[Tuple[str, RiskLevelEnum], ...] = (
        ("critical", RiskLevelEnum.CRITICAL),
        ("high", RiskLevelEnum.HIGH),
        ("medium", RiskLevelEnum.MEDIUM),
        ("low", RiskLevelEnum.LOW),
    )

    def __init__(self, db: DBSession):
        """Bind the repository to an SQLAlchemy session.

        Args:
            db: Session used for all queries and commits.
        """
        self.db = db

    def create(
        self,
        risk_id: str,
        title: str,
        category: str,
        likelihood: int,
        impact: int,
        description: Optional[str] = None,
        mitigating_controls: Optional[List[str]] = None,
        owner: Optional[str] = None,
        treatment_plan: Optional[str] = None,
    ) -> RiskDB:
        """Create, persist and return a new risk.

        The inherent risk level is derived from ``likelihood`` and ``impact``
        via ``RiskDB.calculate_risk_level``. A fresh UUID4 string is used as
        the primary key.

        Args:
            risk_id: Business identifier (e.g. ``'RISK-001'``).
            title: Short title of the risk.
            category: Category label.
            likelihood: Likelihood score fed to ``calculate_risk_level``.
            impact: Impact score fed to ``calculate_risk_level``.
            description: Optional free-text description.
            mitigating_controls: Optional list of control references;
                stored as an empty list when omitted.
            owner: Optional owner.
            treatment_plan: Optional treatment plan text.

        Returns:
            The committed and refreshed ``RiskDB`` instance.
        """
        inherent_risk = RiskDB.calculate_risk_level(likelihood, impact)
        risk = RiskDB(
            id=str(uuid.uuid4()),
            risk_id=risk_id,
            title=title,
            description=description,
            category=category,
            likelihood=likelihood,
            impact=impact,
            inherent_risk=inherent_risk,
            # Fresh list per row; never share a mutable default.
            mitigating_controls=mitigating_controls or [],
            owner=owner,
            treatment_plan=treatment_plan,
        )
        self.db.add(risk)
        self.db.commit()
        self.db.refresh(risk)
        return risk

    def get_by_id(self, risk_uuid: str) -> Optional[RiskDB]:
        """Return the risk whose primary key (UUID string) is ``risk_uuid``, or None."""
        return self.db.query(RiskDB).filter(RiskDB.id == risk_uuid).first()

    def get_by_risk_id(self, risk_id: str) -> Optional[RiskDB]:
        """Return the risk with business id ``risk_id`` (e.g. ``'RISK-001'``), or None."""
        return self.db.query(RiskDB).filter(RiskDB.risk_id == risk_id).first()

    def get_all(
        self,
        category: Optional[str] = None,
        status: Optional[str] = None,
        min_risk_level: Optional[RiskLevelEnum] = None,
    ) -> List[RiskDB]:
        """Return risks matching the given filters, ordered by ``risk_id``.

        Falsy filter values (None, empty string) are ignored.

        Args:
            category: Only risks with this exact category.
            status: Only risks with this exact status.
            min_risk_level: Only risks whose ``inherent_risk`` ranks at or
                above this level (LOW < MEDIUM < HIGH < CRITICAL). A level
                not in that ranking is treated as LOW, i.e. no filtering.

        Returns:
            A list of ``RiskDB`` rows (possibly empty).
        """
        query = self.db.query(RiskDB)
        if category:
            query = query.filter(RiskDB.category == category)
        if status:
            query = query.filter(RiskDB.status == status)
        if min_risk_level:
            min_order = self._RISK_ORDER.get(min_risk_level, 1)
            allowed_levels = [
                level for level, rank in self._RISK_ORDER.items() if rank >= min_order
            ]
            query = query.filter(RiskDB.inherent_risk.in_(allowed_levels))
        return query.order_by(RiskDB.risk_id).all()

    def update(self, risk_id: str, **kwargs: Any) -> Optional[RiskDB]:
        """Apply field updates to the risk with business id ``risk_id``.

        Each keyword whose name is an attribute of the model is assigned.
        Unknown keywords are silently ignored. Derived levels are kept in
        sync:

        * ``inherent_risk`` is recomputed when ``likelihood`` or ``impact``
          is passed.
        * ``residual_risk`` is recomputed when ``residual_likelihood`` or
          ``residual_impact`` is passed and both are set. If either is then
          unset, the previously derived ``residual_risk`` is cleared, unless
          the caller passed ``residual_risk`` explicitly.

        ``updated_at`` is set to the current UTC time and the session is
        committed.

        Args:
            risk_id: Business identifier of the risk to update.
            **kwargs: Attribute names and new values.

        Returns:
            The refreshed ``RiskDB`` instance, or None if no risk matches.
        """
        risk = self.get_by_risk_id(risk_id)
        if not risk:
            return None

        # NOTE(review): any attribute present on the model can be set here,
        # including identity columns such as ``id``; callers are trusted to
        # pass only editable fields.
        for key, value in kwargs.items():
            if hasattr(risk, key):
                setattr(risk, key, value)

        if 'likelihood' in kwargs or 'impact' in kwargs:
            risk.inherent_risk = RiskDB.calculate_risk_level(risk.likelihood, risk.impact)

        if 'residual_likelihood' in kwargs or 'residual_impact' in kwargs:
            if risk.residual_likelihood and risk.residual_impact:
                risk.residual_risk = RiskDB.calculate_risk_level(
                    risk.residual_likelihood, risk.residual_impact
                )
            elif 'residual_risk' not in kwargs:
                # A residual score was cleared: drop the level that was derived
                # from it rather than leaving a stale value behind.
                risk.residual_risk = None

        risk.updated_at = datetime.now(timezone.utc)
        self.db.commit()
        self.db.refresh(risk)
        return risk

    def get_matrix_data(self) -> Dict[str, Any]:
        """Build the data behind the likelihood x impact risk matrix.

        Returns:
            A dict with:

            * ``"matrix"``: maps ``"<likelihood>_<impact>"`` to a list of
              ``{"risk_id", "title", "inherent_risk"}`` dicts. The
              ``inherent_risk`` entry is the enum's ``.value``, or None.
            * ``"total_risks"``: number of risks.
            * ``"by_level"``: counts keyed ``critical``/``high``/``medium``/``low``.
        """
        risks = self.get_all()
        matrix: Dict[str, List[Dict[str, Any]]] = {}
        by_level: Dict[str, int] = {label: 0 for label, _ in self._MATRIX_LEVELS}

        # Single pass: group into matrix cells and tally levels together.
        for risk in risks:
            key = f"{risk.likelihood}_{risk.impact}"
            matrix.setdefault(key, []).append({
                "risk_id": risk.risk_id,
                "title": risk.title,
                "inherent_risk": risk.inherent_risk.value if risk.inherent_risk else None,
            })
            for label, level in self._MATRIX_LEVELS:
                # ``==`` (not hashing) mirrors how levels were compared before.
                if risk.inherent_risk == level:
                    by_level[label] += 1
                    break

        return {
            "matrix": matrix,
            "total_risks": len(risks),
            "by_level": by_level,
        }