"""
Compliance repositories — extracted from compliance/db/repository.py.

Phase 1 Step 5: the monolithic repository module is decomposed per aggregate.
Every repository class is re-exported from ``compliance.db.repository`` for
backwards compatibility.
"""

import uuid
from datetime import datetime, date, timezone
from typing import List, Optional, Dict, Any, Tuple

from sqlalchemy.orm import Session as DBSession, selectinload, joinedload
from sqlalchemy import func, and_, or_

from compliance.db.models import (
    RegulationDB,
    RequirementDB,
    ControlDB,
    ControlMappingDB,
    EvidenceDB,
    RiskDB,
    AuditExportDB,
    AuditSessionDB,
    AuditSignOffDB,
    AuditResultEnum,
    AuditSessionStatusEnum,
    RegulationTypeEnum,
    ControlDomainEnum,
    ControlStatusEnum,
    RiskLevelEnum,
    EvidenceStatusEnum,
    ExportStatusEnum,
    ServiceModuleDB,
    ModuleRegulationMappingDB,
)

# Escape character for user-supplied LIKE/ILIKE patterns. "/" is the same
# character SQLAlchemy uses for its own ``autoescape`` option and avoids
# dialect-specific backslash-literal quirks.
_LIKE_ESCAPE = "/"


def _escape_like(value: str, escape: str = _LIKE_ESCAPE) -> str:
    """Escape LIKE wildcards in *value* so it matches literally.

    Escapes the escape character itself first, then ``%`` and ``_``. The
    result must be used together with ``ilike(..., escape=escape)``.
    """
    return (
        value.replace(escape, escape + escape)
        .replace("%", escape + "%")
        .replace("_", escape + "_")
    )


class RegulationRepository:
    """Repository for regulations/standards.

    Every mutating method (``create``, ``update``, ``delete``) commits the
    session passed to the constructor.
    """

    # Fields that ``update`` refuses to overwrite via ``**kwargs`` (the
    # primary key must never be reassigned through a generic update).
    _PROTECTED_FIELDS = frozenset({"id"})

    def __init__(self, db: DBSession):
        self.db = db

    def create(
        self,
        code: str,
        name: str,
        regulation_type: RegulationTypeEnum,
        full_name: Optional[str] = None,
        source_url: Optional[str] = None,
        local_pdf_path: Optional[str] = None,
        effective_date: Optional[date] = None,
        description: Optional[str] = None,
    ) -> RegulationDB:
        """Create, commit and return a new regulation with a fresh UUID4 id."""
        regulation = RegulationDB(
            id=str(uuid.uuid4()),
            code=code,
            name=name,
            full_name=full_name,
            regulation_type=regulation_type,
            source_url=source_url,
            local_pdf_path=local_pdf_path,
            effective_date=effective_date,
            description=description,
        )
        self.db.add(regulation)
        self.db.commit()
        self.db.refresh(regulation)
        return regulation

    def get_by_id(self, regulation_id: str) -> Optional[RegulationDB]:
        """Get regulation by ID, or None if it does not exist."""
        return self.db.query(RegulationDB).filter(RegulationDB.id == regulation_id).first()

    def get_by_code(self, code: str) -> Optional[RegulationDB]:
        """Get regulation by code (e.g., 'GDPR'), or None if it does not exist."""
        return self.db.query(RegulationDB).filter(RegulationDB.code == code).first()

    def get_all(
        self,
        regulation_type: Optional[RegulationTypeEnum] = None,
        is_active: Optional[bool] = True,
    ) -> List[RegulationDB]:
        """Get regulations ordered by code, with optional filters.

        Args:
            regulation_type: Only regulations of this type when not None.
            is_active: Filter on the active flag. Defaults to active-only;
                pass None to include both active and inactive regulations.
        """
        query = self.db.query(RegulationDB)
        if regulation_type is not None:
            query = query.filter(RegulationDB.regulation_type == regulation_type)
        if is_active is not None:
            query = query.filter(RegulationDB.is_active == is_active)
        return query.order_by(RegulationDB.code).all()

    def update(self, regulation_id: str, **kwargs) -> Optional[RegulationDB]:
        """Update a regulation's attributes and commit.

        Keys that are not existing attributes of the model, or that are in
        ``_PROTECTED_FIELDS``, are ignored. ``updated_at`` is always set to
        the current UTC time.

        Returns:
            The refreshed regulation, or None if ``regulation_id`` is unknown.
        """
        regulation = self.get_by_id(regulation_id)
        if not regulation:
            return None
        for key, value in kwargs.items():
            if key in self._PROTECTED_FIELDS:
                continue
            if hasattr(regulation, key):
                setattr(regulation, key, value)
        regulation.updated_at = datetime.now(timezone.utc)
        self.db.commit()
        self.db.refresh(regulation)
        return regulation

    def delete(self, regulation_id: str) -> bool:
        """Delete a regulation. Returns False if it does not exist."""
        regulation = self.get_by_id(regulation_id)
        if not regulation:
            return False
        self.db.delete(regulation)
        self.db.commit()
        return True

    def get_active(self) -> List[RegulationDB]:
        """Get all active regulations."""
        return self.get_all(is_active=True)

    def count(self) -> int:
        """Count all regulations (active and inactive)."""
        return self.db.query(func.count(RegulationDB.id)).scalar() or 0


class RequirementRepository:
    """Repository for requirements.

    Read methods eager-load each requirement's control mappings (with their
    controls) and its regulation.
    """

    def __init__(self, db: DBSession):
        self.db = db

    def _query_with_relations(self):
        """Return a RequirementDB query with the standard eager-load options."""
        return self.db.query(RequirementDB).options(
            # Separate SELECT ... IN queries for mappings and their controls,
            selectinload(RequirementDB.control_mappings).selectinload(ControlMappingDB.control),
            # and a JOIN for the owning regulation.
            joinedload(RequirementDB.regulation),
        )

    def create(
        self,
        regulation_id: str,
        article: str,
        title: str,
        paragraph: Optional[str] = None,
        description: Optional[str] = None,
        requirement_text: Optional[str] = None,
        breakpilot_interpretation: Optional[str] = None,
        is_applicable: bool = True,
        priority: int = 2,
    ) -> RequirementDB:
        """Create, commit and return a new requirement with a fresh UUID4 id."""
        requirement = RequirementDB(
            id=str(uuid.uuid4()),
            regulation_id=regulation_id,
            article=article,
            paragraph=paragraph,
            title=title,
            description=description,
            requirement_text=requirement_text,
            breakpilot_interpretation=breakpilot_interpretation,
            is_applicable=is_applicable,
            priority=priority,
        )
        self.db.add(requirement)
        self.db.commit()
        self.db.refresh(requirement)
        return requirement

    def get_by_id(self, requirement_id: str) -> Optional[RequirementDB]:
        """Get requirement by ID with eager-loaded relationships, or None."""
        return (
            self._query_with_relations()
            .filter(RequirementDB.id == requirement_id)
            .first()
        )

    def get_by_regulation(
        self, regulation_id: str, is_applicable: Optional[bool] = None
    ) -> List[RequirementDB]:
        """Get all requirements for a regulation with eager-loaded controls.

        Results are ordered by article, then paragraph.
        """
        query = self._query_with_relations().filter(
            RequirementDB.regulation_id == regulation_id
        )
        if is_applicable is not None:
            query = query.filter(RequirementDB.is_applicable == is_applicable)
        return query.order_by(RequirementDB.article, RequirementDB.paragraph).all()

    def get_by_regulation_code(self, code: str) -> List[RequirementDB]:
        """Get requirements by regulation code with eager-loaded relationships."""
        return (
            self._query_with_relations()
            .join(RegulationDB)
            .filter(RegulationDB.code == code)
            .order_by(RequirementDB.article, RequirementDB.paragraph)
            .all()
        )

    def get_all(self, is_applicable: Optional[bool] = None) -> List[RequirementDB]:
        """Get all requirements with optional filter and eager-loading."""
        query = self._query_with_relations()
        if is_applicable is not None:
            query = query.filter(RequirementDB.is_applicable == is_applicable)
        return query.order_by(RequirementDB.article, RequirementDB.paragraph).all()

    def get_paginated(
        self,
        page: int = 1,
        page_size: int = 50,
        regulation_code: Optional[str] = None,
        status: Optional[str] = None,
        is_applicable: Optional[bool] = None,
        search: Optional[str] = None,
    ) -> Tuple[List[RequirementDB], int]:
        """
        Get paginated requirements with eager-loaded relationships.

        Args:
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Maximum items per page; negative values are treated as 0.
            regulation_code: Only requirements whose regulation has this code.
            status: Exact match on ``implementation_status``.
            is_applicable: Filter on applicability when not None.
            search: Case-insensitive substring matched against title,
                description and article. ``%`` and ``_`` in the input are
                matched literally, not as wildcards.

        Returns tuple of (items, total_count); total_count is the number of
        matching rows before pagination.
        """
        # Negative OFFSET/LIMIT values are rejected by e.g. PostgreSQL.
        page = max(page, 1)
        page_size = max(page_size, 0)

        query = self._query_with_relations()

        # Filters
        if regulation_code:
            query = query.join(RegulationDB).filter(RegulationDB.code == regulation_code)
        if status:
            query = query.filter(RequirementDB.implementation_status == status)
        if is_applicable is not None:
            query = query.filter(RequirementDB.is_applicable == is_applicable)
        if search:
            # Escape wildcards so user input is matched as a literal substring.
            search_term = f"%{_escape_like(search)}%"
            query = query.filter(
                or_(
                    RequirementDB.title.ilike(search_term, escape=_LIKE_ESCAPE),
                    RequirementDB.description.ilike(search_term, escape=_LIKE_ESCAPE),
                    RequirementDB.article.ilike(search_term, escape=_LIKE_ESCAPE),
                )
            )

        # Count before pagination
        total = query.count()

        # Apply ordering and pagination. ``id`` is a unique tiebreaker: without
        # a total order, rows with equal sort keys may shift between pages and
        # be returned twice or not at all.
        items = (
            query.order_by(
                RequirementDB.priority.desc(),
                RequirementDB.article,
                RequirementDB.paragraph,
                RequirementDB.id,
            )
            .offset((page - 1) * page_size)
            .limit(page_size)
            .all()
        )
        return items, total

    def delete(self, requirement_id: str) -> bool:
        """Delete a requirement. Returns False if it does not exist."""
        # Plain query (no eager-load options): the relationships are not needed
        # just to delete the row.
        requirement = (
            self.db.query(RequirementDB).filter(RequirementDB.id == requirement_id).first()
        )
        if not requirement:
            return False
        self.db.delete(requirement)
        self.db.commit()
        return True

    def count(self) -> int:
        """Count all requirements."""
        return self.db.query(func.count(RequirementDB.id)).scalar() or 0