# mypy: disable-error-code="arg-type"
"""
FastAPI routes for Compliance module.

Endpoints:
- /regulations: Manage regulations
- /requirements: Manage requirements
- /controls: Manage controls
- /export: Audit export
- /init-tables, /create-indexes, /seed, /seed-risks: Admin setup

Phase 1 Step 4 refactor: most handlers delegate to RegulationRequirementService
and ControlExportService. The single-control handlers (get / update / review)
and the optional RAG legal-context lookup still operate on the request's DB
session directly. Repository classes are re-exported so existing test patches
(``compliance.api.routes.ControlRepository``, etc.) keep working.
"""

import logging
import os
from typing import Any, Optional

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session

from classroom_engine.database import get_db

from .audit_trail_utils import log_audit_trail
from ..db import (
    ControlDomainEnum,
    ControlRepository,
    ControlStatusEnum,
    EvidenceRepository,
    RegulationRepository,
    RequirementRepository,
)
from ..services.regulation_requirement_service import (
    RegulationRequirementService,
)
from ..services.control_export_service import ControlExportService
from .schemas import (
    ControlListResponse,
    ControlResponse,
    ControlReviewRequest,
    ControlUpdate,
    ExportListResponse,
    ExportRequest,
    ExportResponse,
    PaginatedControlResponse,
    PaginatedRequirementResponse,
    PaginationMeta,
    RegulationListResponse,
    RegulationResponse,
    RequirementCreate,
    RequirementListResponse,
    RequirementResponse,
    SeedRequest,
    SeedResponse,
)
from ._http_errors import translate_domain_errors

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/compliance", tags=["compliance"])


# ---------------------------------------------------------------------------
# Dependency factories
# ---------------------------------------------------------------------------


def get_reg_req_service(
    db: Session = Depends(get_db),
) -> RegulationRequirementService:
    """Build a RegulationRequirementService bound to the request session.

    The repository classes are passed explicitly (looked up from this
    module's namespace at call time) so tests that patch
    ``compliance.api.routes.RegulationRepository`` etc. take effect.
    """
    return RegulationRequirementService(
        db,
        reg_repo_cls=RegulationRepository,
        req_repo_cls=RequirementRepository,
    )


def get_ctrl_export_service(
    db: Session = Depends(get_db),
) -> ControlExportService:
    """Build a ControlExportService bound to the request session.

    Like ``get_reg_req_service``, repository classes are resolved from this
    module so test patches apply.
    """
    return ControlExportService(
        db,
        control_repo_cls=ControlRepository,
        evidence_repo_cls=EvidenceRepository,
    )


# ============================================================================
# Regulations
# ============================================================================


@router.get("/regulations", response_model=RegulationListResponse)
async def list_regulations(
    is_active: Optional[bool] = None,
    regulation_type: Optional[str] = None,
    svc: RegulationRequirementService = Depends(get_reg_req_service),
) -> RegulationListResponse:
    """List all regulations."""
    with translate_domain_errors():
        return svc.list_regulations(is_active, regulation_type)


@router.get("/regulations/{code}", response_model=RegulationResponse)
async def get_regulation(
    code: str,
    svc: RegulationRequirementService = Depends(get_reg_req_service),
) -> RegulationResponse:
    """Get a specific regulation by code."""
    with translate_domain_errors():
        return svc.get_regulation(code)


@router.get(
    "/regulations/{code}/requirements",
    response_model=RequirementListResponse,
)
async def get_regulation_requirements(
    code: str,
    is_applicable: Optional[bool] = None,
    svc: RegulationRequirementService = Depends(get_reg_req_service),
) -> RequirementListResponse:
    """Get requirements for a specific regulation."""
    with translate_domain_errors():
        return svc.get_regulation_requirements(code, is_applicable)


# ============================================================================
# Requirements
# ============================================================================


async def _fetch_legal_context(
    db: Session, requirement_id: str
) -> list[dict[str, Any]]:
    """Return up to three RAG search hits relevant to a requirement.

    Best-effort: any failure (missing requirement, RAG client error, ...) is
    logged as a warning and an empty list is returned, so the surrounding
    endpoint never fails because of the legal-context lookup.
    """
    try:
        # Imported lazily so the RAG/AI stack is only loaded when requested.
        from ..services.rag_client import get_rag_client
        from ..services.ai_compliance_assistant import (
            AIComplianceAssistant,
        )
        from ..db.models import RequirementDB, RegulationDB

        requirement = (
            db.query(RequirementDB)
            .filter(RequirementDB.id == requirement_id)
            .first()
        )
        regulation = (
            db.query(RegulationDB)
            .filter(RegulationDB.id == requirement.regulation_id)
            .first()
        ) if requirement else None

        rag = get_rag_client()
        assistant = AIComplianceAssistant()
        query = (
            f"{requirement.title} {requirement.article or ''}"  # type: ignore[union-attr]
        )
        collection = assistant._collection_for_regulation(
            regulation.code if regulation else ""
        )
        rag_results = await rag.search(
            query, collection=collection, top_k=3
        )
        return [
            {
                "text": r.text,
                "regulation_code": r.regulation_code,
                "regulation_short": r.regulation_short,
                "article": r.article,
                "score": r.score,
                "source_url": r.source_url,
            }
            for r in rag_results
        ]
    except Exception as e:
        logger.warning(
            "Failed to fetch legal context for %s: %s",
            requirement_id,
            e,
        )
        return []


@router.get("/requirements/{requirement_id}")
async def get_requirement(
    requirement_id: str,
    include_legal_context: bool = Query(
        False, description="Include RAG legal context"
    ),
    db: Session = Depends(get_db),
) -> dict[str, Any]:
    """Get a specific requirement by ID, optionally with RAG legal context.

    When ``include_legal_context`` is true, the response gains a
    ``legal_context`` list (possibly empty if the lookup failed).
    """
    # Built through the dependency factory (not a bare constructor) so the
    # module-level repository classes are used, consistent with the other
    # handlers and with test patches. The raw session is still needed for the
    # legal-context lookup below.
    svc = get_reg_req_service(db)
    with translate_domain_errors():
        result = svc.get_requirement(requirement_id, False)

    if include_legal_context:
        result["legal_context"] = await _fetch_legal_context(
            db, requirement_id
        )

    return result


@router.get(
    "/requirements", response_model=PaginatedRequirementResponse
)
async def list_requirements_paginated(
    page: int = Query(1, ge=1, description="Page number"),
    page_size: int = Query(
        50, ge=1, le=500, description="Items per page"
    ),
    regulation_code: Optional[str] = Query(
        None, description="Filter by regulation code"
    ),
    status: Optional[str] = Query(
        None, description="Filter by implementation status"
    ),
    is_applicable: Optional[bool] = Query(
        None, description="Filter by applicability"
    ),
    search: Optional[str] = Query(
        None, description="Search in title/description"
    ),
    svc: RegulationRequirementService = Depends(get_reg_req_service),
) -> PaginatedRequirementResponse:
    """List requirements with pagination."""
    with translate_domain_errors():
        return svc.list_requirements_paginated(
            page,
            page_size,
            regulation_code,
            status,
            is_applicable,
            search,
        )


@router.post("/requirements", response_model=RequirementResponse)
async def create_requirement(
    data: RequirementCreate,
    svc: RegulationRequirementService = Depends(get_reg_req_service),
) -> RequirementResponse:
    """Create a new requirement."""
    with translate_domain_errors():
        return svc.create_requirement(data)


@router.delete("/requirements/{requirement_id}")
async def delete_requirement(
    requirement_id: str,
    svc: RegulationRequirementService = Depends(get_reg_req_service),
) -> dict[str, Any]:
    """Delete a requirement by ID."""
    with translate_domain_errors():
        return svc.delete_requirement(requirement_id)


@router.put("/requirements/{requirement_id}")
async def update_requirement(
    requirement_id: str,
    updates: dict[str, Any],
    svc: RegulationRequirementService = Depends(get_reg_req_service),
) -> dict[str, Any]:
    """Update a requirement with implementation/audit details."""
    with translate_domain_errors():
        return svc.update_requirement(requirement_id, updates)


# ============================================================================
# Controls
# ============================================================================


def _get_control_or_404(repo: ControlRepository, control_id: str) -> Any:
    """Look up a control by its business ``control_id`` or raise HTTP 404."""
    control = repo.get_by_control_id(control_id)
    if not control:
        raise HTTPException(
            status_code=404, detail=f"Control {control_id} not found"
        )
    return control


def _control_response(control: Any, **extra: Any) -> ControlResponse:
    """Convert a control ORM row into a ``ControlResponse``.

    Enum-valued columns are exported as their ``.value`` (or ``None`` when
    unset). ``extra`` carries optional fields such as ``evidence_count`` that
    only some endpoints populate; when omitted the schema default applies.
    """
    return ControlResponse(
        id=control.id,
        control_id=control.control_id,
        domain=control.domain.value if control.domain else None,
        control_type=(
            control.control_type.value if control.control_type else None
        ),
        title=control.title,
        description=control.description,
        pass_criteria=control.pass_criteria,
        implementation_guidance=control.implementation_guidance,
        code_reference=control.code_reference,
        documentation_url=control.documentation_url,
        is_automated=control.is_automated,
        automation_tool=control.automation_tool,
        automation_config=control.automation_config,
        owner=control.owner,
        review_frequency_days=control.review_frequency_days,
        status=control.status.value if control.status else None,
        status_notes=control.status_notes,
        status_justification=control.status_justification,
        last_reviewed_at=control.last_reviewed_at,
        next_review_at=control.next_review_at,
        created_at=control.created_at,
        updated_at=control.updated_at,
        **extra,
    )


@router.get("/controls", response_model=ControlListResponse)
async def list_controls(
    domain: Optional[str] = None,
    status: Optional[str] = None,
    is_automated: Optional[bool] = None,
    search: Optional[str] = None,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> ControlListResponse:
    """List all controls with optional filters."""
    with translate_domain_errors():
        return svc.list_controls(domain, status, is_automated, search)


@router.get(
    "/controls/paginated", response_model=PaginatedControlResponse
)
async def list_controls_paginated(
    page: int = Query(1, ge=1, description="Page number"),
    page_size: int = Query(
        50, ge=1, le=500, description="Items per page"
    ),
    domain: Optional[str] = Query(
        None, description="Filter by domain"
    ),
    status: Optional[str] = Query(
        None, description="Filter by status"
    ),
    is_automated: Optional[bool] = Query(
        None, description="Filter by automation"
    ),
    search: Optional[str] = Query(
        None, description="Search in title/description"
    ),
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> PaginatedControlResponse:
    """List controls with pagination."""
    with translate_domain_errors():
        return svc.list_controls_paginated(
            page,
            page_size,
            domain,
            status,
            is_automated,
            search,
        )


@router.get(
    "/controls/{control_id}", response_model=ControlResponse
)
async def get_control(
    control_id: str,
    svc: ControlExportService = Depends(get_ctrl_export_service),
    db: Session = Depends(get_db),
) -> ControlResponse:
    """Get a specific control by control_id, including its evidence count.

    ``svc`` is kept for signature compatibility; the handler works on ``db``
    directly. FastAPI caches ``get_db`` per request, so ``db`` is the same
    session the service was built with.

    Raises:
        HTTPException: 404 if no control has this ``control_id``.
    """
    control = _get_control_or_404(ControlRepository(db), control_id)
    evidence = EvidenceRepository(db).get_by_control(control.id)
    return _control_response(control, evidence_count=len(evidence))


@router.put(
    "/controls/{control_id}", response_model=ControlResponse
)
async def update_control(
    control_id: str,
    update: ControlUpdate,
    svc: ControlExportService = Depends(get_ctrl_export_service),
    db: Session = Depends(get_db),
) -> ControlResponse:
    """Update a control, enforcing the status state machine.

    Only fields explicitly set on ``update`` are written. A status change is
    validated by ``validate_transition`` against the control's evidence
    (Anti-Fake-Evidence). An actual status change is recorded in the audit
    trail.

    Raises:
        HTTPException: 404 if the control does not exist, 400 for an unknown
            status value, 409 if the status transition is not allowed.
    """
    repo = ControlRepository(db)
    control = _get_control_or_404(repo, control_id)

    update_data = update.model_dump(exclude_unset=True)
    status_requested = "status" in update_data
    # A control without a status is treated as "planned" for transition
    # validation and for the audit trail's old value.
    current_status = control.status.value if control.status else "planned"

    if status_requested:
        try:
            new_status_enum = ControlStatusEnum(update_data["status"])
        except ValueError:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid status: {update_data['status']}",
            ) from None

        # Validate status transition (Anti-Fake-Evidence)
        from ..services.control_status_machine import validate_transition

        # NOTE(review): assumes EvidenceRepository.get_by_control returns the
        # same rows as filtering evidence by control_id (as get_control uses).
        evidence_list = EvidenceRepository(db).get_by_control(control.id)
        allowed, violations = validate_transition(
            current_status=current_status,
            new_status=update_data["status"],
            evidence_list=evidence_list,
            status_justification=(
                update_data.get("status_justification")
                or update_data.get("status_notes")
            ),
        )
        if not allowed:
            raise HTTPException(
                status_code=409,
                detail={
                    "error": "Status transition not allowed",
                    "current_status": current_status,
                    "requested_status": update_data["status"],
                    "violations": violations,
                },
            )
        update_data["status"] = new_status_enum

    updated = repo.update(control.id, **update_data)
    db.commit()

    # Audit trail for status changes
    new_status = updated.status.value if updated.status else None
    if status_requested and current_status != new_status:
        log_audit_trail(
            db,
            "control",
            control.id,
            updated.control_id or updated.title,
            "status_change",
            performed_by=update.owner or "system",
            field_changed="status",
            old_value=current_status,
            new_value=new_status,
        )
        db.commit()

    return _control_response(updated)


@router.put(
    "/controls/{control_id}/review",
    response_model=ControlResponse,
)
async def review_control(
    control_id: str,
    review: ControlReviewRequest,
    svc: ControlExportService = Depends(get_ctrl_export_service),
    db: Session = Depends(get_db),
) -> ControlResponse:
    """Mark a control as reviewed with new status.

    Raises:
        HTTPException: 404 if the control does not exist, 400 for an unknown
            status value.
    """
    repo = ControlRepository(db)
    control = _get_control_or_404(repo, control_id)

    try:
        status_enum = ControlStatusEnum(review.status)
    except ValueError:
        raise HTTPException(
            status_code=400, detail=f"Invalid status: {review.status}"
        ) from None

    updated = repo.mark_reviewed(control.id, status_enum, review.status_notes)
    db.commit()

    return _control_response(updated)


@router.get(
    "/controls/by-domain/{domain}",
    response_model=ControlListResponse,
)
async def get_controls_by_domain(
    domain: str,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> ControlListResponse:
    """Get controls by domain."""
    with translate_domain_errors():
        return svc.get_controls_by_domain(domain)


# ============================================================================
# Export
# ============================================================================


@router.post("/export", response_model=ExportResponse)
async def create_export(
    request: ExportRequest,
    background_tasks: BackgroundTasks,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> ExportResponse:
    """Create a new audit export.

    ``background_tasks`` is accepted but not currently used; the export is
    created synchronously by the service.
    """
    with translate_domain_errors():
        data = svc.create_export(
            export_type=request.export_type,
            included_regulations=request.included_regulations,
            included_domains=request.included_domains,
            date_range_start=request.date_range_start,
            date_range_end=request.date_range_end,
        )
    return ExportResponse(**data)


@router.get("/export/{export_id}", response_model=ExportResponse)
async def get_export(
    export_id: str,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> ExportResponse:
    """Get export status."""
    with translate_domain_errors():
        data = svc.get_export(export_id)
    return ExportResponse(**data)


@router.get("/export/{export_id}/download")
async def download_export(
    export_id: str,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> FileResponse:
    """Download export file as a ZIP attachment."""
    with translate_domain_errors():
        file_path = svc.download_export(export_id)
    return FileResponse(
        file_path,
        media_type="application/zip",
        filename=os.path.basename(file_path),
    )


@router.get("/exports", response_model=ExportListResponse)
async def list_exports(
    limit: int = 20,
    offset: int = 0,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> ExportListResponse:
    """List recent exports."""
    with translate_domain_errors():
        data = svc.list_exports(limit, offset)
    return ExportListResponse(**data)


# ============================================================================
# Seeding / Admin
# ============================================================================


@router.post("/init-tables")
async def init_tables(
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> dict[str, Any]:
    """Create compliance tables if they don't exist."""
    try:
        return svc.init_tables()
    except Exception as e:
        logger.exception("init-tables failed")
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/create-indexes")
async def create_performance_indexes(
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> dict[str, Any]:
    """Create additional performance indexes."""
    return svc.create_indexes()


@router.post("/seed-risks")
async def seed_risks_only(
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> dict[str, Any]:
    """Seed only risks."""
    try:
        return svc.seed_risks_only()
    except Exception as e:
        logger.exception("seed-risks failed")
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/seed", response_model=SeedResponse)
async def seed_database(
    request: SeedRequest,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> SeedResponse:
    """Seed the compliance database with initial data."""
    try:
        data = svc.seed_database(force=request.force)
        return SeedResponse(**data)
    except Exception as e:
        logger.exception("seed failed")
        raise HTTPException(status_code=500, detail=str(e)) from e