# mypy: disable-error-code="arg-type,assignment,union-attr,no-any-return"
"""
Service for regulation and requirement business logic.

Phase 1 Step 4: extracted from ``compliance.api.routes``. All handler logic
for regulations CRUD and requirements CRUD lives here. The route module
delegates to this service and translates domain errors to HTTP responses.
"""

import asyncio
import logging
from datetime import datetime, timezone
from typing import Any, Optional

from sqlalchemy.orm import Session

from compliance.db import RegulationRepository, RequirementRepository
from compliance.db.models import RegulationDB, RegulationTypeEnum, RequirementDB
from compliance.domain import NotFoundError, ValidationError
from compliance.schemas.regulation import (
    RegulationListResponse,
    RegulationResponse,
)
from compliance.schemas.requirement import (
    PaginatedRequirementResponse,
    RequirementCreate,
    RequirementListResponse,
    RequirementResponse,
)
from compliance.schemas.common import PaginationMeta

logger = logging.getLogger(__name__)


def _regulation_to_response(
    reg: Any, requirement_count: int
) -> RegulationResponse:
    """Map a regulation ORM row to its API schema.

    ``requirement_count`` is supplied by the caller because the row itself
    does not carry it. ``regulation_type`` is unwrapped to its enum value,
    or ``None`` when unset.
    """
    return RegulationResponse(
        id=reg.id,
        code=reg.code,
        name=reg.name,
        full_name=reg.full_name,
        regulation_type=(
            reg.regulation_type.value if reg.regulation_type else None
        ),
        source_url=reg.source_url,
        local_pdf_path=reg.local_pdf_path,
        effective_date=reg.effective_date,
        description=reg.description,
        is_active=reg.is_active,
        created_at=reg.created_at,
        updated_at=reg.updated_at,
        requirement_count=requirement_count,
    )


def _requirement_to_response(
    r: Any, code: Optional[str] = None
) -> RequirementResponse:
    """Map a requirement ORM row to the basic (non-audit) API schema.

    ``code`` is the owning regulation's code, passed in by the caller so no
    extra lookup is needed here.
    """
    return RequirementResponse(
        id=r.id,
        regulation_id=r.regulation_id,
        regulation_code=code,
        article=r.article,
        paragraph=r.paragraph,
        title=r.title,
        description=r.description,
        requirement_text=r.requirement_text,
        breakpilot_interpretation=r.breakpilot_interpretation,
        is_applicable=r.is_applicable,
        applicability_reason=r.applicability_reason,
        priority=r.priority,
        created_at=r.created_at,
        updated_at=r.updated_at,
    )


class RegulationRequirementService:
    """Business logic for regulation and requirement endpoints."""

    # Requirement columns that ``update_requirement`` is allowed to
    # overwrite from the caller-supplied ``updates`` mapping.
    _UPDATABLE_FIELDS = (
        "implementation_status",
        "implementation_details",
        "code_references",
        "documentation_links",
        "evidence_description",
        "evidence_artifacts",
        "auditor_notes",
        "audit_status",
        "is_applicable",
        "applicability_reason",
        "breakpilot_interpretation",
    )

    def __init__(
        self,
        db: Session,
        reg_repo_cls: Any = RegulationRepository,
        req_repo_cls: Any = RequirementRepository,
    ) -> None:
        """Bind the service to a DB session.

        The repository classes are injectable (e.g. for tests). Each is
        instantiated with ``db``.
        """
        self.db = db
        self.reg_repo = reg_repo_cls(db)
        self.req_repo = req_repo_cls(db)

    # ------------------------------------------------------------------
    # Regulations
    # ------------------------------------------------------------------

    def list_regulations(
        self,
        is_active: Optional[bool],
        regulation_type: Optional[str],
    ) -> RegulationListResponse:
        """List regulations, optionally filtered, each with its requirement count.

        ``is_active=True`` restricts the list to active regulations. ``None``
        and ``False`` both return every regulation.
        NOTE(review): ``False`` does not select *inactive* regulations;
        confirm this is the intended API contract.

        An unknown ``regulation_type`` is ignored (no type filter applied)
        rather than rejected.
        """
        if is_active:
            regulations = self.reg_repo.get_active()
        else:
            regulations = self.reg_repo.get_all()

        if regulation_type:
            try:
                reg_type = RegulationTypeEnum(regulation_type)
            except ValueError:
                # Best-effort filter: unknown types fall back to "no filter".
                logger.debug(
                    "Ignoring unknown regulation_type filter %r",
                    regulation_type,
                )
            else:
                regulations = [
                    r for r in regulations if r.regulation_type == reg_type
                ]

        # NOTE(review): loads every requirement per regulation just to count
        # them (N+1 queries); a repository-level count would be cheaper.
        results = [
            _regulation_to_response(
                reg, len(self.req_repo.get_by_regulation(reg.id))
            )
            for reg in regulations
        ]
        return RegulationListResponse(regulations=results, total=len(results))

    def get_regulation(self, code: str) -> RegulationResponse:
        """Return one regulation by code.

        Raises:
            NotFoundError: if no regulation has this code.
        """
        regulation = self.reg_repo.get_by_code(code)
        if not regulation:
            raise NotFoundError(f"Regulation {code} not found")
        reqs = self.req_repo.get_by_regulation(regulation.id)
        return _regulation_to_response(regulation, len(reqs))

    def get_regulation_requirements(
        self,
        code: str,
        is_applicable: Optional[bool],
    ) -> RequirementListResponse:
        """List the requirements of the regulation identified by ``code``.

        ``is_applicable=True`` uses the repository's ``get_applicable``. ``None``
        and ``False`` both return all requirements of the regulation.

        Raises:
            NotFoundError: if no regulation has this code.
        """
        regulation = self.reg_repo.get_by_code(code)
        if not regulation:
            raise NotFoundError(f"Regulation {code} not found")

        if is_applicable:
            requirements = self.req_repo.get_applicable(regulation.id)
        else:
            requirements = self.req_repo.get_by_regulation(regulation.id)

        results = [_requirement_to_response(r, code) for r in requirements]
        return RequirementListResponse(
            requirements=results, total=len(results)
        )

    # ------------------------------------------------------------------
    # Requirements
    # ------------------------------------------------------------------

    def get_requirement(
        self,
        requirement_id: str,
        include_legal_context: bool,
    ) -> dict[str, Any]:
        """Return the full detail view of a requirement as a dict.

        Missing ``implementation_status`` / ``audit_status`` default to
        ``"not_started"`` / ``"pending"``. When ``include_legal_context`` is
        set, a ``"legal_context"`` key is added. It holds a best-effort list
        that may be empty (see ``_fetch_legal_context``).

        Raises:
            NotFoundError: if the requirement does not exist.
        """
        requirement = (
            self.db.query(RequirementDB)
            .filter(RequirementDB.id == requirement_id)
            .first()
        )
        if not requirement:
            raise NotFoundError(
                f"Requirement {requirement_id} not found"
            )

        regulation = (
            self.db.query(RegulationDB)
            .filter(RegulationDB.id == requirement.regulation_id)
            .first()
        )

        result: dict[str, Any] = {
            "id": requirement.id,
            "regulation_id": requirement.regulation_id,
            "regulation_code": regulation.code if regulation else None,
            "article": requirement.article,
            "paragraph": requirement.paragraph,
            "title": requirement.title,
            "description": requirement.description,
            "requirement_text": requirement.requirement_text,
            "breakpilot_interpretation": requirement.breakpilot_interpretation,
            "implementation_status": (
                requirement.implementation_status or "not_started"
            ),
            "implementation_details": requirement.implementation_details,
            "code_references": requirement.code_references,
            "documentation_links": requirement.documentation_links,
            "evidence_description": requirement.evidence_description,
            "evidence_artifacts": requirement.evidence_artifacts,
            "auditor_notes": requirement.auditor_notes,
            "audit_status": requirement.audit_status or "pending",
            "last_audit_date": requirement.last_audit_date,
            "last_auditor": requirement.last_auditor,
            "is_applicable": requirement.is_applicable,
            "applicability_reason": requirement.applicability_reason,
            "priority": requirement.priority,
            "source_page": requirement.source_page,
            "source_section": requirement.source_section,
        }

        if include_legal_context:
            result["legal_context"] = self._fetch_legal_context(
                requirement, regulation
            )

        return result

    def list_requirements_paginated(
        self,
        page: int,
        page_size: int,
        regulation_code: Optional[str],
        status: Optional[str],
        is_applicable: Optional[bool],
        search: Optional[str],
    ) -> PaginatedRequirementResponse:
        """Return one page of requirements plus pagination metadata.

        Filtering is delegated to ``RequirementRepository.get_paginated``.

        Raises:
            ValidationError: if ``page_size`` is smaller than 1. Previously
                this surfaced as a ZeroDivisionError when computing
                ``total_pages``.
        """
        if page_size < 1:
            raise ValidationError(f"page_size must be >= 1, got {page_size}")

        requirements, total = self.req_repo.get_paginated(
            page=page,
            page_size=page_size,
            regulation_code=regulation_code,
            status=status,
            is_applicable=is_applicable,
            search=search,
        )

        # Ceiling division: a partial last page still counts as a page.
        total_pages = (total + page_size - 1) // page_size

        results = [
            RequirementResponse(
                id=r.id,
                regulation_id=r.regulation_id,
                regulation_code=(
                    r.regulation.code if r.regulation else None
                ),
                article=r.article,
                paragraph=r.paragraph,
                title=r.title,
                description=r.description,
                requirement_text=r.requirement_text,
                breakpilot_interpretation=r.breakpilot_interpretation,
                is_applicable=r.is_applicable,
                applicability_reason=r.applicability_reason,
                priority=r.priority,
                implementation_status=(
                    r.implementation_status or "not_started"
                ),
                implementation_details=r.implementation_details,
                code_references=r.code_references,
                documentation_links=r.documentation_links,
                evidence_description=r.evidence_description,
                evidence_artifacts=r.evidence_artifacts,
                auditor_notes=r.auditor_notes,
                audit_status=r.audit_status or "pending",
                last_audit_date=r.last_audit_date,
                last_auditor=r.last_auditor,
                source_page=r.source_page,
                source_section=r.source_section,
                created_at=r.created_at,
                updated_at=r.updated_at,
            )
            for r in requirements
        ]

        return PaginatedRequirementResponse(
            data=results,
            pagination=PaginationMeta(
                page=page,
                page_size=page_size,
                total=total,
                total_pages=total_pages,
                has_next=page < total_pages,
                has_prev=page > 1,
            ),
        )

    def create_requirement(
        self, data: RequirementCreate
    ) -> RequirementResponse:
        """Create a requirement under an existing regulation.

        Raises:
            NotFoundError: if ``data.regulation_id`` does not exist.
        """
        regulation = self.reg_repo.get_by_id(data.regulation_id)
        if not regulation:
            raise NotFoundError(
                f"Regulation {data.regulation_id} not found"
            )

        requirement = self.req_repo.create(
            regulation_id=data.regulation_id,
            article=data.article,
            title=data.title,
            paragraph=data.paragraph,
            description=data.description,
            requirement_text=data.requirement_text,
            breakpilot_interpretation=data.breakpilot_interpretation,
            is_applicable=data.is_applicable,
            priority=data.priority,
        )
        return _requirement_to_response(requirement, regulation.code)

    def delete_requirement(self, requirement_id: str) -> dict[str, Any]:
        """Delete a requirement.

        Raises:
            NotFoundError: if the repository reports nothing was deleted.
        """
        deleted = self.req_repo.delete(requirement_id)
        if not deleted:
            raise NotFoundError(
                f"Requirement {requirement_id} not found"
            )
        return {"success": True, "message": "Requirement deleted"}

    def update_requirement(
        self, requirement_id: str, updates: dict[str, Any]
    ) -> dict[str, Any]:
        """Apply a partial update to a requirement and commit it.

        Only keys listed in ``_UPDATABLE_FIELDS`` are written; other keys are
        ignored. When ``audit_status`` is present, ``last_audit_date`` is
        stamped with the current UTC time. ``last_auditor`` is taken from
        ``updates["auditor_name"]`` and defaults to ``"api_user"``.

        Raises:
            NotFoundError: if the requirement does not exist.
        """
        requirement = (
            self.db.query(RequirementDB)
            .filter(RequirementDB.id == requirement_id)
            .first()
        )
        if not requirement:
            raise NotFoundError(
                f"Requirement {requirement_id} not found"
            )

        for field in self._UPDATABLE_FIELDS:
            if field in updates:
                setattr(requirement, field, updates[field])

        if "audit_status" in updates:
            requirement.last_audit_date = datetime.now(timezone.utc)
            requirement.last_auditor = updates.get(
                "auditor_name", "api_user"
            )

        requirement.updated_at = datetime.now(timezone.utc)
        self.db.commit()
        self.db.refresh(requirement)

        return {"success": True, "message": "Requirement updated"}

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    async def _fetch_legal_context_async(
        self, requirement: Any, regulation: Any
    ) -> list[dict[str, Any]]:
        """Fetch up to three RAG passages relevant to ``requirement``.

        The RAG search is awaited directly, so this is the variant to use
        from a running event loop (e.g. an ``async def`` route handler).
        The query is the requirement title plus article. The collection is
        chosen from the regulation code, or ``""`` when the regulation is
        unknown.

        Best-effort: any failure is logged as a warning and ``[]`` is
        returned, so legal context never breaks the caller.
        """
        try:
            from compliance.services.rag_client import get_rag_client
            from compliance.services.ai_compliance_assistant import (
                AIComplianceAssistant,
            )

            rag = get_rag_client()
            assistant = AIComplianceAssistant()
            query = f"{requirement.title} {requirement.article or ''}"
            collection = assistant._collection_for_regulation(
                regulation.code if regulation else ""
            )
            rag_results = await rag.search(
                query, collection=collection, top_k=3
            )
            return [
                {
                    "text": r.text,
                    "regulation_code": r.regulation_code,
                    "regulation_short": r.regulation_short,
                    "article": r.article,
                    "score": r.score,
                    "source_url": r.source_url,
                }
                for r in rag_results
            ]
        except Exception as e:
            logger.warning(
                "Failed to fetch legal context for %s: %s",
                requirement.id,
                e,
            )
            return []

    def _fetch_legal_context(
        self, requirement: Any, regulation: Any
    ) -> list[dict[str, Any]]:
        """Synchronous wrapper around ``_fetch_legal_context_async``.

        If an event loop is already running in this thread, the coroutine
        cannot be blocked on, so ``[]`` is returned. Async callers should
        await ``_fetch_legal_context_async`` instead. Otherwise (sync route
        run in a worker thread, scripts, tests) the coroutine is driven on a
        fresh loop via ``asyncio.run``. This replaces the deprecated
        ``asyncio.get_event_loop()``, which raises in non-main threads.

        Never raises: failures are logged and ``[]`` is returned.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop in this thread: safe to run to completion.
            try:
                return asyncio.run(
                    self._fetch_legal_context_async(requirement, regulation)
                )
            except Exception as e:
                logger.warning(
                    "Failed to fetch legal context for %s: %s",
                    requirement.id,
                    e,
                )
                return []

        # A loop is running: blocking on it here would deadlock.
        return []