Split routes.py (991 LOC) into thin handlers + two service files: - RegulationRequirementService: regulations CRUD, requirements CRUD - ControlExportService: controls CRUD/review/domain, export, admin seeding All 216 tests pass. Route module re-exports repository classes so existing test patches (compliance.api.routes.*Repository) keep working. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
464 lines
14 KiB
Python
464 lines
14 KiB
Python
# mypy: disable-error-code="arg-type"
|
|
"""
|
|
FastAPI routes for Compliance module.
|
|
|
|
Endpoints:
|
|
- /regulations: Manage regulations
|
|
- /requirements: Manage requirements
|
|
- /controls: Manage controls
|
|
- /export: Audit export
|
|
- /init-tables, /create-indexes, /seed, /seed-risks: Admin setup
|
|
|
|
Phase 1 Step 4 refactor: handlers delegate to
|
|
RegulationRequirementService and ControlExportService.
|
|
Repository classes are re-exported so existing test patches
|
|
(``compliance.api.routes.ControlRepository``, etc.) keep working.
|
|
"""
|
|
|
|
import logging
|
|
import os
|
|
from typing import Any, Optional
|
|
|
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
|
|
from fastapi.responses import FileResponse
|
|
from sqlalchemy.orm import Session
|
|
|
|
from classroom_engine.database import get_db
|
|
|
|
from ..db import (
|
|
ControlDomainEnum,
|
|
ControlRepository,
|
|
ControlStatusEnum,
|
|
EvidenceRepository,
|
|
RegulationRepository,
|
|
RequirementRepository,
|
|
)
|
|
from ..services.regulation_requirement_service import (
|
|
RegulationRequirementService,
|
|
)
|
|
from ..services.control_export_service import ControlExportService
|
|
from .schemas import (
|
|
ControlListResponse,
|
|
ControlResponse,
|
|
ControlReviewRequest,
|
|
ControlUpdate,
|
|
ExportListResponse,
|
|
ExportRequest,
|
|
ExportResponse,
|
|
PaginatedControlResponse,
|
|
PaginatedRequirementResponse,
|
|
PaginationMeta,
|
|
RegulationListResponse,
|
|
RegulationResponse,
|
|
RequirementCreate,
|
|
RequirementListResponse,
|
|
RequirementResponse,
|
|
SeedRequest,
|
|
SeedResponse,
|
|
)
|
|
from ._http_errors import translate_domain_errors
|
|
|
|
# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)

# Every endpoint in this module is mounted under /compliance and tagged
# "compliance" in the OpenAPI schema.
router = APIRouter(tags=["compliance"], prefix="/compliance")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Dependency factories
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def get_reg_req_service(
    db: Session = Depends(get_db),
) -> RegulationRequirementService:
    """FastAPI dependency: a RegulationRequirementService bound to ``db``.

    The repository classes are resolved from this module's globals each time
    the dependency runs, so patching
    ``compliance.api.routes.RegulationRepository`` /
    ``compliance.api.routes.RequirementRepository`` takes effect.
    """
    service = RegulationRequirementService(
        db,
        req_repo_cls=RequirementRepository,
        reg_repo_cls=RegulationRepository,
    )
    return service
|
|
|
|
|
|
def get_ctrl_export_service(
    db: Session = Depends(get_db),
) -> ControlExportService:
    """FastAPI dependency: a ControlExportService bound to ``db``.

    ``ControlRepository`` and ``EvidenceRepository`` are looked up from this
    module's globals at call time, which keeps test patches on
    ``compliance.api.routes.*Repository`` effective.
    """
    service = ControlExportService(
        db,
        evidence_repo_cls=EvidenceRepository,
        control_repo_cls=ControlRepository,
    )
    return service
|
|
|
|
|
|
# ============================================================================
|
|
# Regulations
|
|
# ============================================================================
|
|
|
|
@router.get("/regulations", response_model=RegulationListResponse)
async def list_regulations(
    is_active: Optional[bool] = None,
    regulation_type: Optional[str] = None,
    svc: RegulationRequirementService = Depends(get_reg_req_service),
) -> RegulationListResponse:
    """List all regulations."""
    # Both optional query filters are forwarded to the service as-is
    # (None when the client omits them).
    with translate_domain_errors():
        regulations = svc.list_regulations(is_active, regulation_type)
        return regulations
|
|
|
|
|
|
@router.get("/regulations/{code}", response_model=RegulationResponse)
async def get_regulation(
    code: str,
    svc: RegulationRequirementService = Depends(get_reg_req_service),
) -> RegulationResponse:
    """Get a specific regulation by code."""
    # ``code`` is taken from the URL path; lookup errors raised by the
    # service are handled by translate_domain_errors.
    with translate_domain_errors():
        regulation = svc.get_regulation(code)
        return regulation
|
|
|
|
|
|
@router.get("/regulations/{code}/requirements", response_model=RequirementListResponse)
async def get_regulation_requirements(
    code: str,
    is_applicable: Optional[bool] = None,
    svc: RegulationRequirementService = Depends(get_reg_req_service),
) -> RequirementListResponse:
    """Get requirements for a specific regulation."""
    # Regulation code comes from the path; ``is_applicable`` is an optional
    # query filter passed through unchanged.
    with translate_domain_errors():
        requirements = svc.get_regulation_requirements(code, is_applicable)
        return requirements
|
|
|
|
|
|
# ============================================================================
|
|
# Requirements
|
|
# ============================================================================
|
|
|
|
async def _fetch_legal_context(
    db: Session, requirement_id: str
) -> list[dict[str, Any]]:
    """Return up to three RAG search hits giving legal context for a requirement.

    Best-effort by design: any exception (import, DB query, RAG search) is
    logged as a warning and an empty list is returned, so the caller can
    still serve the requirement itself.
    """
    try:
        # Imported inside the function so these modules are only loaded
        # when legal context is actually requested.
        from ..services.rag_client import get_rag_client
        from ..services.ai_compliance_assistant import (
            AIComplianceAssistant,
        )
        from ..db.models import RequirementDB, RegulationDB

        requirement = (
            db.query(RequirementDB)
            .filter(RequirementDB.id == requirement_id)
            .first()
        )
        if requirement is None:
            # Guard explicitly instead of letting ``requirement.title``
            # fail with an AttributeError further down.
            logger.warning(
                "Requirement %s not found while fetching legal context",
                requirement_id,
            )
            return []

        regulation = (
            db.query(RegulationDB)
            .filter(RegulationDB.id == requirement.regulation_id)
            .first()
        )

        rag = get_rag_client()
        assistant = AIComplianceAssistant()
        search_query = f"{requirement.title} {requirement.article or ''}"
        # Choose the RAG collection from the parent regulation's code; ""
        # is passed when the regulation row is missing.
        # NOTE(review): relies on a private helper of AIComplianceAssistant.
        collection = assistant._collection_for_regulation(
            regulation.code if regulation else ""
        )
        rag_results = await rag.search(
            search_query, collection=collection, top_k=3
        )
        return [
            {
                "text": r.text,
                "regulation_code": r.regulation_code,
                "regulation_short": r.regulation_short,
                "article": r.article,
                "score": r.score,
                "source_url": r.source_url,
            }
            for r in rag_results
        ]
    except Exception as e:
        logger.warning(
            "Failed to fetch legal context for %s: %s",
            requirement_id,
            e,
        )
        return []


@router.get("/requirements/{requirement_id}")
async def get_requirement(
    requirement_id: str,
    include_legal_context: bool = Query(
        False, description="Include RAG legal context"
    ),
    db: Session = Depends(get_db),
) -> dict[str, Any]:
    """Get a specific requirement by ID, optionally with RAG legal context."""
    # Build the service through the shared factory rather than a bare
    # ``RegulationRequirementService(db)``. That way this endpoint uses the
    # module-level repository classes (the ones tests patch via
    # ``compliance.api.routes.*Repository``) like the other requirement
    # endpoints. The raw session is still needed for the legal-context lookup.
    svc = get_reg_req_service(db)
    with translate_domain_errors():
        # The service is asked for the plain requirement only (second
        # argument False); legal context needs ``await`` and is added below.
        result = svc.get_requirement(requirement_id, False)

    if include_legal_context:
        result["legal_context"] = await _fetch_legal_context(
            db, requirement_id
        )

    return result
|
|
|
|
|
|
@router.get("/requirements", response_model=PaginatedRequirementResponse)
async def list_requirements_paginated(
    page: int = Query(1, ge=1, description="Page number"),
    page_size: int = Query(50, ge=1, le=500, description="Items per page"),
    regulation_code: Optional[str] = Query(None, description="Filter by regulation code"),
    status: Optional[str] = Query(None, description="Filter by implementation status"),
    is_applicable: Optional[bool] = Query(None, description="Filter by applicability"),
    search: Optional[str] = Query(None, description="Search in title/description"),
    svc: RegulationRequirementService = Depends(get_reg_req_service),
) -> PaginatedRequirementResponse:
    """List requirements with pagination."""
    # FastAPI rejects page < 1 and page_size outside 1..500 before this runs.
    # The arguments are forwarded positionally, so keep this order in sync
    # with the service method's signature.
    with translate_domain_errors():
        page_result = svc.list_requirements_paginated(
            page,
            page_size,
            regulation_code,
            status,
            is_applicable,
            search,
        )
        return page_result
|
|
|
|
|
|
@router.post("/requirements", response_model=RequirementResponse)
async def create_requirement(
    data: RequirementCreate,
    svc: RegulationRequirementService = Depends(get_reg_req_service),
) -> RequirementResponse:
    """Create a new requirement."""
    # ``data`` is the validated request body; creation is delegated entirely
    # to the service.
    with translate_domain_errors():
        created = svc.create_requirement(data)
        return created
|
|
|
|
|
|
@router.delete("/requirements/{requirement_id}")
async def delete_requirement(
    requirement_id: str,
    svc: RegulationRequirementService = Depends(get_reg_req_service),
) -> dict[str, Any]:
    """Delete a requirement by ID."""
    # The service's result dict is returned to the client unchanged.
    with translate_domain_errors():
        outcome = svc.delete_requirement(requirement_id)
        return outcome
|
|
|
|
|
|
@router.put("/requirements/{requirement_id}")
async def update_requirement(
    requirement_id: str,
    updates: dict,
    svc: RegulationRequirementService = Depends(get_reg_req_service),
) -> dict[str, Any]:
    """Update a requirement with implementation/audit details."""
    # NOTE(review): ``updates`` is an untyped JSON object; which keys are
    # honoured is decided by the service, not validated here.
    with translate_domain_errors():
        updated = svc.update_requirement(requirement_id, updates)
        return updated
|
|
|
|
|
|
# ============================================================================
|
|
# Controls
|
|
# ============================================================================
|
|
|
|
@router.get("/controls", response_model=ControlListResponse)
async def list_controls(
    domain: Optional[str] = None,
    status: Optional[str] = None,
    is_automated: Optional[bool] = None,
    search: Optional[str] = None,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> ControlListResponse:
    """List all controls with optional filters."""
    # All four filters are optional query parameters passed through
    # positionally to the service.
    with translate_domain_errors():
        controls = svc.list_controls(domain, status, is_automated, search)
        return controls
|
|
|
|
|
|
# Registered before "/controls/{control_id}" so that "paginated" is not
# captured as a control_id (FastAPI matches routes in registration order).
@router.get("/controls/paginated", response_model=PaginatedControlResponse)
async def list_controls_paginated(
    page: int = Query(1, ge=1, description="Page number"),
    page_size: int = Query(50, ge=1, le=500, description="Items per page"),
    domain: Optional[str] = Query(None, description="Filter by domain"),
    status: Optional[str] = Query(None, description="Filter by status"),
    is_automated: Optional[bool] = Query(None, description="Filter by automation"),
    search: Optional[str] = Query(None, description="Search in title/description"),
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> PaginatedControlResponse:
    """List controls with pagination."""
    with translate_domain_errors():
        page_result = svc.list_controls_paginated(
            page,
            page_size,
            domain,
            status,
            is_automated,
            search,
        )
        return page_result
|
|
|
|
|
|
@router.get("/controls/{control_id}", response_model=ControlResponse)
async def get_control(
    control_id: str,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> ControlResponse:
    """Get a specific control by control_id."""
    # ``control_id`` is the path value as sent by the client.
    with translate_domain_errors():
        control = svc.get_control(control_id)
        return control
|
|
|
|
|
|
@router.put("/controls/{control_id}", response_model=ControlResponse)
async def update_control(
    control_id: str,
    update: ControlUpdate,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> ControlResponse:
    """Update a control."""
    # The validated ControlUpdate body is handed to the service unchanged.
    with translate_domain_errors():
        updated_control = svc.update_control(control_id, update)
        return updated_control
|
|
|
|
|
|
@router.put("/controls/{control_id}/review", response_model=ControlResponse)
async def review_control(
    control_id: str,
    review: ControlReviewRequest,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> ControlResponse:
    """Mark a control as reviewed with new status."""
    # Review semantics (status change, timestamps, ...) live in the service.
    with translate_domain_errors():
        reviewed = svc.review_control(control_id, review)
        return reviewed
|
|
|
|
|
|
@router.get("/controls/by-domain/{domain}", response_model=ControlListResponse)
async def get_controls_by_domain(
    domain: str,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> ControlListResponse:
    """Get controls by domain."""
    # ``domain`` is passed as the raw path string; any validation against
    # the known domains is up to the service.
    with translate_domain_errors():
        controls = svc.get_controls_by_domain(domain)
        return controls
|
|
|
|
|
|
# ============================================================================
|
|
# Export
|
|
# ============================================================================
|
|
|
|
@router.post("/export", response_model=ExportResponse)
async def create_export(
    request: ExportRequest,
    background_tasks: BackgroundTasks,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> ExportResponse:
    """Create a new audit export."""
    # NOTE(review): ``background_tasks`` is accepted but not used; the
    # service call below runs directly within the request.
    with translate_domain_errors():
        export_data = svc.create_export(
            export_type=request.export_type,
            date_range_start=request.date_range_start,
            date_range_end=request.date_range_end,
            included_regulations=request.included_regulations,
            included_domains=request.included_domains,
        )
        # The service returns plain data; wrap it in the response schema.
        return ExportResponse(**export_data)
|
|
|
|
|
|
@router.get("/export/{export_id}", response_model=ExportResponse)
async def get_export(
    export_id: str,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> ExportResponse:
    """Get export status."""
    # Schema construction stays inside the error-translation context, just
    # like the service call.
    with translate_domain_errors():
        export_payload = svc.get_export(export_id)
        return ExportResponse(**export_payload)
|
|
|
|
|
|
@router.get("/export/{export_id}/download")
async def download_export(
    export_id: str,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> FileResponse:
    """Download export file."""
    with translate_domain_errors():
        zip_path = svc.download_export(export_id)
        # The client-facing filename is just the last path component.
        download_name = os.path.basename(zip_path)
        return FileResponse(
            zip_path,
            filename=download_name,
            media_type="application/zip",
        )
|
|
|
|
|
|
@router.get("/exports", response_model=ExportListResponse)
async def list_exports(
    limit: int = 20,
    offset: int = 0,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> ExportListResponse:
    """List recent exports."""
    # NOTE(review): limit/offset are plain ints without bounds validation.
    with translate_domain_errors():
        listing = svc.list_exports(limit, offset)
        return ExportListResponse(**listing)
|
|
|
|
|
|
# ============================================================================
|
|
# Seeding / Admin
|
|
# ============================================================================
|
|
|
|
@router.post("/init-tables")
async def init_tables(
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> dict[str, Any]:
    """Create compliance tables if they don't exist."""
    try:
        return svc.init_tables()
    except Exception as e:
        # FastAPI turns an HTTPException into a response without logging it,
        # so record the traceback here before converting; keep the cause
        # chained for debuggers and error reporters.
        logger.exception("Failed to initialise compliance tables")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@router.post("/create-indexes")
async def create_performance_indexes(
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> dict[str, Any]:
    """Create additional performance indexes."""
    # Unlike the other admin endpoints, errors are not wrapped in an
    # HTTPException here; they propagate to FastAPI unchanged.
    index_report = svc.create_indexes()
    return index_report
|
|
|
|
|
|
@router.post("/seed-risks")
async def seed_risks_only(
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> dict[str, Any]:
    """Seed only risks."""
    try:
        return svc.seed_risks_only()
    except Exception as e:
        # Log the full traceback server-side (an HTTPException is not logged
        # by FastAPI), then surface a 500 with the cause chained.
        logger.exception("Failed to seed compliance risks")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@router.post("/seed", response_model=SeedResponse)
async def seed_database(
    request: SeedRequest,
    svc: ControlExportService = Depends(get_ctrl_export_service),
) -> SeedResponse:
    """Seed the compliance database with initial data."""
    try:
        data = svc.seed_database(force=request.force)
        # Response-schema construction is inside the try on purpose, so a
        # malformed service result is also reported as a 500.
        return SeedResponse(**data)
    except Exception as e:
        # Log the traceback before converting: FastAPI does not log
        # HTTPExceptions, so it would otherwise be lost.
        logger.exception("Failed to seed compliance database")
        raise HTTPException(status_code=500, detail=str(e)) from e
|