# mypy: disable-error-code="arg-type"
"""
FastAPI routes for Compliance module.

Endpoints:
- /regulations: Manage regulations
- /requirements: Manage requirements
- /controls: Manage controls
- /export: Audit export
- /init-tables, /create-indexes, /seed, /seed-risks: Admin setup

Phase 1 Step 4 refactor: handlers delegate to RegulationRequirementService
and ControlExportService. Repository classes are re-exported so existing
test patches (``compliance.api.routes.ControlRepository``, etc.) keep
working.
"""

import logging
import os
from typing import Any, Optional

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session

from classroom_engine.database import get_db

from ..db import (
    ControlDomainEnum,
    ControlRepository,
    ControlStatusEnum,
    EvidenceRepository,
    RegulationRepository,
    RequirementRepository,
)
from ..services.regulation_requirement_service import (
    RegulationRequirementService,
)
from ..services.control_export_service import ControlExportService
from .schemas import (
    ControlListResponse,
    ControlResponse,
    ControlReviewRequest,
    ControlUpdate,
    ExportListResponse,
    ExportRequest,
    ExportResponse,
    PaginatedControlResponse,
    PaginatedRequirementResponse,
    PaginationMeta,
    RegulationListResponse,
    RegulationResponse,
    RequirementCreate,
    RequirementListResponse,
    RequirementResponse,
    SeedRequest,
    SeedResponse,
)
from ._http_errors import translate_domain_errors

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/compliance", tags=["compliance"])


# ---------------------------------------------------------------------------
# Dependency factories
# ---------------------------------------------------------------------------
# The repository classes are passed in from this module's namespace so that
# patching ``compliance.api.routes.<Repo>`` affects the services built here.


def get_reg_req_service(
    db: Session = Depends(get_db),
) -> RegulationRequirementService:
    """Build a RegulationRequirementService bound to the request session."""
    return RegulationRequirementService(
        db,
        reg_repo_cls=RegulationRepository,
        req_repo_cls=RequirementRepository,
    )


def get_ctrl_export_service(
    db: Session = Depends(get_db),
) -> ControlExportService:
    """Build a ControlExportService bound to the request session."""
    return ControlExportService(
        db,
        control_repo_cls=ControlRepository,
        evidence_repo_cls=EvidenceRepository,
    )


# ============================================================================
# Regulations
# ============================================================================
# Handlers below run their service call inside ``translate_domain_errors()``
# (imported from ``._http_errors``); presumably it maps domain exceptions to
# HTTP errors -- see that module for the exact mapping.


@router.get("/regulations", response_model=RegulationListResponse)
async def list_regulations(
    is_active: Optional[bool] = None,
    regulation_type: Optional[str] = None,
    svc: RegulationRequirementService = Depends(get_reg_req_service),
) -> RegulationListResponse:
    """List all regulations."""
    with translate_domain_errors():
        return svc.list_regulations(is_active, regulation_type)


@router.get("/regulations/{code}", response_model=RegulationResponse)
async def get_regulation(
    code: str,
    svc: RegulationRequirementService = Depends(get_reg_req_service),
) -> RegulationResponse:
    """Get a specific regulation by code."""
    with translate_domain_errors():
        return svc.get_regulation(code)


@router.get(
    "/regulations/{code}/requirements",
    response_model=RequirementListResponse,
)
async def get_regulation_requirements(
    code: str,
    is_applicable: Optional[bool] = None,
    svc: RegulationRequirementService = Depends(get_reg_req_service),
) -> RequirementListResponse:
    """Get requirements for a specific regulation."""
    with translate_domain_errors():
        return svc.get_regulation_requirements(code, is_applicable)


# ============================================================================
# Requirements
# ============================================================================


async def _fetch_legal_context(
    db: Session, requirement_id: str
) -> list[dict[str, Any]]:
    """Return RAG search hits describing the legal context of a requirement.

    Looks up the requirement row and its regulation, builds a query from
    the requirement's title and article, and searches the collection the
    AI assistant selects for the regulation code (top 3 results).

    Args:
        db: SQLAlchemy session used for the two lookups.
        requirement_id: Primary key of the requirement.

    Returns:
        One dict per hit with the keys ``text``, ``regulation_code``,
        ``regulation_short``, ``article``, ``score`` and ``source_url``;
        an empty list when no requirement row exists.

    Raises:
        Exception: Import, database and RAG failures propagate; the caller
            treats them as "no legal context available".
    """
    # Imports stay function-local (as before) so that an import failure is
    # raised from here and lands in the caller's best-effort fallback.
    from ..services.rag_client import get_rag_client
    from ..services.ai_compliance_assistant import (
        AIComplianceAssistant,
    )
    from ..db.models import RequirementDB, RegulationDB

    requirement = (
        db.query(RequirementDB)
        .filter(RequirementDB.id == requirement_id)
        .first()
    )
    if requirement is None:
        # Nothing to build a query from. Previously this path dereferenced
        # ``None`` and relied on the caller's broad ``except`` to recover.
        return []

    regulation = (
        db.query(RegulationDB)
        .filter(RegulationDB.id == requirement.regulation_id)
        .first()
    )

    rag = get_rag_client()
    assistant = AIComplianceAssistant()
    query = f"{requirement.title} {requirement.article or ''}"
    collection = assistant._collection_for_regulation(
        regulation.code if regulation else ""
    )
    rag_results = await rag.search(query, collection=collection, top_k=3)
    return [
        {
            "text": r.text,
            "regulation_code": r.regulation_code,
            "regulation_short": r.regulation_short,
            "article": r.article,
            "score": r.score,
            "source_url": r.source_url,
        }
        for r in rag_results
    ]


@router.get("/requirements/{requirement_id}")
async def get_requirement(
    requirement_id: str,
    include_legal_context: bool = Query(
        False, description="Include RAG legal context"
    ),
    db: Session = Depends(get_db),
) -> dict[str, Any]:
    """Get a specific requirement by ID, optionally with RAG legal context."""
    # NOTE(review): the service is built directly from ``db`` here rather
    # than through ``get_reg_req_service``; kept as-is to preserve behavior.
    svc = RegulationRequirementService(db)
    with translate_domain_errors():
        # ``False``: the service is asked for the plain requirement; the
        # legal context is fetched below because it needs ``await``.
        result = svc.get_requirement(requirement_id, False)

    # Handle async legal context fetching inline to preserve behavior
    if include_legal_context:
        try:
            result["legal_context"] = await _fetch_legal_context(
                db, requirement_id
            )
        except Exception as e:
            # Deliberately best-effort: the requirement is still returned
            # when the RAG enrichment is unavailable.
            logger.warning(
                "Failed to fetch legal context for %s: %s",
                requirement_id,
                e,
            )
            result["legal_context"] = []

    return result


@router.get(
    "/requirements", response_model=PaginatedRequirementResponse
)
async def list_requirements_paginated(
    page: int = Query(1, ge=1, description="Page number"),
    page_size: int = Query(
        50, ge=1, le=500, description="Items per page"
    ),
    regulation_code: Optional[str] = Query(
        None, description="Filter by regulation code"
    ),
    status: Optional[str] = Query(
        None, description="Filter by implementation status"
    ),
    is_applicable: Optional[bool] = Query(
        None, description="Filter by applicability"
    ),
    search: Optional[str] = Query(
        None, description="Search in title/description"
    ),
    svc: RegulationRequirementService = Depends(get_reg_req_service),
) -> PaginatedRequirementResponse:
    """List requirements with pagination."""
    with translate_domain_errors():
        return svc.list_requirements_paginated(
            page,
            page_size,
            regulation_code,
            status,
            is_applicable,
            search,
        )


@router.post("/requirements", response_model=RequirementResponse)
async def create_requirement(
    data: RequirementCreate,
    svc: RegulationRequirementService = Depends(get_reg_req_service),
) -> RequirementResponse:
    """Create a new requirement."""
    with translate_domain_errors():
        return svc.create_requirement(data)


@router.delete("/requirements/{requirement_id}")
async def delete_requirement(
    requirement_id: str,
    svc: RegulationRequirementService = Depends(get_reg_req_service),
) -> dict[str, Any]:
    """Delete a requirement by ID."""
    with translate_domain_errors():
        return svc.delete_requirement(requirement_id)


@router.put("/requirements/{requirement_id}")
async def update_requirement(
    requirement_id: str,
    updates: dict,
    svc: RegulationRequirementService = Depends(get_reg_req_service),
) -> dict[str, Any]:
    """Update a requirement with implementation/audit details."""
    # ``updates`` is an untyped JSON object; it is handed to the service
    # unchanged.
    with translate_domain_errors():
        return svc.update_requirement(requirement_id, updates)


# ============================================================================
# Controls
# ============================================================================


@router.get("/controls", response_model=ControlListResponse)
async def list_controls(
    domain: Optional[str] = None,
    status: Optional[str] = None,
    is_automated: Optional[bool] = None,
    search: Optional[str] = None,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> ControlListResponse:
    """List all controls with optional filters."""
    with translate_domain_errors():
        return svc.list_controls(domain, status, is_automated, search)


# Registered before ``/controls/{control_id}`` so that the literal path
# ``/controls/paginated`` is matched by this route first.
@router.get(
    "/controls/paginated", response_model=PaginatedControlResponse
)
async def list_controls_paginated(
    page: int = Query(1, ge=1, description="Page number"),
    page_size: int = Query(
        50, ge=1, le=500, description="Items per page"
    ),
    domain: Optional[str] = Query(
        None, description="Filter by domain"
    ),
    status: Optional[str] = Query(
        None, description="Filter by status"
    ),
    is_automated: Optional[bool] = Query(
        None, description="Filter by automation"
    ),
    search: Optional[str] = Query(
        None, description="Search in title/description"
    ),
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> PaginatedControlResponse:
    """List controls with pagination."""
    with translate_domain_errors():
        return svc.list_controls_paginated(
            page,
            page_size,
            domain,
            status,
            is_automated,
            search,
        )


@router.get(
    "/controls/{control_id}", response_model=ControlResponse
)
async def get_control(
    control_id: str,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> ControlResponse:
    """Get a specific control by control_id."""
    with translate_domain_errors():
        return svc.get_control(control_id)


@router.put(
    "/controls/{control_id}", response_model=ControlResponse
)
async def update_control(
    control_id: str,
    update: ControlUpdate,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> ControlResponse:
    """Update a control."""
    with translate_domain_errors():
        return svc.update_control(control_id, update)


@router.put(
    "/controls/{control_id}/review",
    response_model=ControlResponse,
)
async def review_control(
    control_id: str,
    review: ControlReviewRequest,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> ControlResponse:
    """Mark a control as reviewed with new status."""
    with translate_domain_errors():
        return svc.review_control(control_id, review)


@router.get(
    "/controls/by-domain/{domain}",
    response_model=ControlListResponse,
)
async def get_controls_by_domain(
    domain: str,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> ControlListResponse:
    """Get controls by domain."""
    with translate_domain_errors():
        return svc.get_controls_by_domain(domain)


# ============================================================================
# Export
# ============================================================================


@router.post("/export", response_model=ExportResponse)
async def create_export(
    request: ExportRequest,
    background_tasks: BackgroundTasks,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> ExportResponse:
    """Create a new audit export."""
    # ``background_tasks`` is accepted but not used in this handler; it is
    # kept so the signature stays unchanged.
    with translate_domain_errors():
        data = svc.create_export(
            export_type=request.export_type,
            included_regulations=request.included_regulations,
            included_domains=request.included_domains,
            date_range_start=request.date_range_start,
            date_range_end=request.date_range_end,
        )
        return ExportResponse(**data)


@router.get("/export/{export_id}", response_model=ExportResponse)
async def get_export(
    export_id: str,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> ExportResponse:
    """Get export status."""
    with translate_domain_errors():
        data = svc.get_export(export_id)
        return ExportResponse(**data)


@router.get("/export/{export_id}/download")
async def download_export(
    export_id: str,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> FileResponse:
    """Download export file."""
    with translate_domain_errors():
        file_path = svc.download_export(export_id)
        # The download name is the file's base name; it is always served
        # with the ZIP media type.
        return FileResponse(
            file_path,
            media_type="application/zip",
            filename=os.path.basename(file_path),
        )


@router.get("/exports", response_model=ExportListResponse)
async def list_exports(
    limit: int = 20,
    offset: int = 0,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> ExportListResponse:
    """List recent exports."""
    # NOTE(review): ``limit``/``offset`` are not range-checked here, unlike
    # the ``Query(ge=..., le=...)`` bounds on the paginated endpoints.
    with translate_domain_errors():
        data = svc.list_exports(limit, offset)
        return ExportListResponse(**data)


# ============================================================================
# Seeding / Admin
# ============================================================================


@router.post("/init-tables")
async def init_tables(
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> dict[str, Any]:
    """Create compliance tables if they don't exist."""
    try:
        return svc.init_tables()
    except Exception as e:
        # Admin boundary: keep the traceback in the log and chain the cause
        # while still answering with a 500 that carries the message.
        logger.exception("Failed to initialize compliance tables")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/create-indexes")
async def create_performance_indexes(
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> dict[str, Any]:
    """Create additional performance indexes."""
    # No try/except here: exceptions from the service propagate unchanged
    # instead of being converted into an HTTPException with a detail body.
    return svc.create_indexes()


@router.post("/seed-risks")
async def seed_risks_only(
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> dict[str, Any]:
    """Seed only risks."""
    try:
        return svc.seed_risks_only()
    except Exception as e:
        # Admin boundary: keep the traceback in the log and chain the cause
        # while still answering with a 500 that carries the message.
        logger.exception("Failed to seed risks")
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/seed", response_model=SeedResponse)
async def seed_database(
    request: SeedRequest,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> SeedResponse:
    """Seed the compliance database with initial data."""
    try:
        data = svc.seed_database(force=request.force)
        # Building the response stays inside the ``try`` so that a payload
        # that does not fit SeedResponse is also reported as a 500.
        return SeedResponse(**data)
    except Exception as e:
        logger.exception("Failed to seed compliance database")
        raise HTTPException(status_code=500, detail=str(e)) from e