diff --git a/backend-compliance/compliance/api/routes.py b/backend-compliance/compliance/api/routes.py index 6c97915..900b93f 100644 --- a/backend-compliance/compliance/api/routes.py +++ b/backend-compliance/compliance/api/routes.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="arg-type" """ FastAPI routes for Compliance module. @@ -5,51 +6,86 @@ Endpoints: - /regulations: Manage regulations - /requirements: Manage requirements - /controls: Manage controls -- /mappings: Requirement-Control mappings -- /evidence: Evidence management -- /risks: Risk management -- /dashboard: Dashboard statistics - /export: Audit export +- /init-tables, /create-indexes, /seed, /seed-risks: Admin setup + +Phase 1 Step 4 refactor: handlers delegate to +RegulationRequirementService and ControlExportService. +Repository classes are re-exported so existing test patches +(``compliance.api.routes.ControlRepository``, etc.) keep working. """ import logging - -logger = logging.getLogger(__name__) import os -from datetime import datetime, timezone -from typing import Optional +from typing import Any, Optional -from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query from fastapi.responses import FileResponse from sqlalchemy.orm import Session from classroom_engine.database import get_db from ..db import ( + ControlDomainEnum, + ControlRepository, + ControlStatusEnum, + EvidenceRepository, RegulationRepository, RequirementRepository, - ControlRepository, - EvidenceRepository, - ControlStatusEnum, - ControlDomainEnum, ) -from ..db.models import EvidenceDB, ControlDB -from ..services.seeder import ComplianceSeeder -from ..services.export_generator import AuditExportGenerator +from ..services.regulation_requirement_service import ( + RegulationRequirementService, +) +from ..services.control_export_service import ControlExportService from .schemas import ( - RegulationResponse, RegulationListResponse, 
- RequirementCreate, RequirementResponse, RequirementListResponse, - ControlUpdate, ControlResponse, ControlListResponse, ControlReviewRequest, - ExportRequest, ExportResponse, ExportListResponse, - SeedRequest, SeedResponse, - # Pagination schemas - PaginationMeta, PaginatedRequirementResponse, PaginatedControlResponse, + ControlListResponse, + ControlResponse, + ControlReviewRequest, + ControlUpdate, + ExportListResponse, + ExportRequest, + ExportResponse, + PaginatedControlResponse, + PaginatedRequirementResponse, + PaginationMeta, + RegulationListResponse, + RegulationResponse, + RequirementCreate, + RequirementListResponse, + RequirementResponse, + SeedRequest, + SeedResponse, ) +from ._http_errors import translate_domain_errors logger = logging.getLogger(__name__) router = APIRouter(prefix="/compliance", tags=["compliance"]) +# --------------------------------------------------------------------------- +# Dependency factories +# --------------------------------------------------------------------------- + +def get_reg_req_service( + db: Session = Depends(get_db), +) -> RegulationRequirementService: + return RegulationRequirementService( + db, + reg_repo_cls=RegulationRepository, + req_repo_cls=RequirementRepository, + ) + + +def get_ctrl_export_service( + db: Session = Depends(get_db), +) -> ControlExportService: + return ControlExportService( + db, + control_repo_cls=ControlRepository, + evidence_repo_cls=EvidenceRepository, + ) + + # ============================================================================ # Regulations # ============================================================================ @@ -58,169 +94,87 @@ router = APIRouter(prefix="/compliance", tags=["compliance"]) async def list_regulations( is_active: Optional[bool] = None, regulation_type: Optional[str] = None, - db: Session = Depends(get_db), -): + svc: RegulationRequirementService = Depends(get_reg_req_service), +) -> RegulationListResponse: """List all regulations.""" - repo = 
RegulationRepository(db) - if is_active is not None: - regulations = repo.get_active() if is_active else repo.get_all() - else: - regulations = repo.get_all() - - if regulation_type: - from ..db.models import RegulationTypeEnum - try: - reg_type = RegulationTypeEnum(regulation_type) - regulations = [r for r in regulations if r.regulation_type == reg_type] - except ValueError: - pass - - # Add requirement counts - req_repo = RequirementRepository(db) - results = [] - for reg in regulations: - reqs = req_repo.get_by_regulation(reg.id) - reg_dict = { - "id": reg.id, - "code": reg.code, - "name": reg.name, - "full_name": reg.full_name, - "regulation_type": reg.regulation_type.value if reg.regulation_type else None, - "source_url": reg.source_url, - "local_pdf_path": reg.local_pdf_path, - "effective_date": reg.effective_date, - "description": reg.description, - "is_active": reg.is_active, - "created_at": reg.created_at, - "updated_at": reg.updated_at, - "requirement_count": len(reqs), - } - results.append(RegulationResponse(**reg_dict)) - - return RegulationListResponse(regulations=results, total=len(results)) + with translate_domain_errors(): + return svc.list_regulations(is_active, regulation_type) @router.get("/regulations/{code}", response_model=RegulationResponse) -async def get_regulation(code: str, db: Session = Depends(get_db)): +async def get_regulation( + code: str, + svc: RegulationRequirementService = Depends(get_reg_req_service), +) -> RegulationResponse: """Get a specific regulation by code.""" - repo = RegulationRepository(db) - regulation = repo.get_by_code(code) - if not regulation: - raise HTTPException(status_code=404, detail=f"Regulation {code} not found") - - req_repo = RequirementRepository(db) - reqs = req_repo.get_by_regulation(regulation.id) - - return RegulationResponse( - id=regulation.id, - code=regulation.code, - name=regulation.name, - full_name=regulation.full_name, - regulation_type=regulation.regulation_type.value if 
regulation.regulation_type else None, - source_url=regulation.source_url, - local_pdf_path=regulation.local_pdf_path, - effective_date=regulation.effective_date, - description=regulation.description, - is_active=regulation.is_active, - created_at=regulation.created_at, - updated_at=regulation.updated_at, - requirement_count=len(reqs), - ) + with translate_domain_errors(): + return svc.get_regulation(code) -@router.get("/regulations/{code}/requirements", response_model=RequirementListResponse) +@router.get( + "/regulations/{code}/requirements", + response_model=RequirementListResponse, +) async def get_regulation_requirements( code: str, is_applicable: Optional[bool] = None, - db: Session = Depends(get_db), -): + svc: RegulationRequirementService = Depends(get_reg_req_service), +) -> RequirementListResponse: """Get requirements for a specific regulation.""" - reg_repo = RegulationRepository(db) - regulation = reg_repo.get_by_code(code) - if not regulation: - raise HTTPException(status_code=404, detail=f"Regulation {code} not found") + with translate_domain_errors(): + return svc.get_regulation_requirements(code, is_applicable) - req_repo = RequirementRepository(db) - if is_applicable is not None: - requirements = req_repo.get_applicable(regulation.id) if is_applicable else req_repo.get_by_regulation(regulation.id) - else: - requirements = req_repo.get_by_regulation(regulation.id) - - results = [ - RequirementResponse( - id=r.id, - regulation_id=r.regulation_id, - regulation_code=code, - article=r.article, - paragraph=r.paragraph, - title=r.title, - description=r.description, - requirement_text=r.requirement_text, - breakpilot_interpretation=r.breakpilot_interpretation, - is_applicable=r.is_applicable, - applicability_reason=r.applicability_reason, - priority=r.priority, - created_at=r.created_at, - updated_at=r.updated_at, - ) - for r in requirements - ] - - return RequirementListResponse(requirements=results, total=len(results)) +# 
============================================================================ +# Requirements +# ============================================================================ @router.get("/requirements/{requirement_id}") async def get_requirement( requirement_id: str, - include_legal_context: bool = Query(False, description="Include RAG legal context"), + include_legal_context: bool = Query( + False, description="Include RAG legal context" + ), db: Session = Depends(get_db), -): +) -> dict[str, Any]: """Get a specific requirement by ID, optionally with RAG legal context.""" - from ..db.models import RequirementDB, RegulationDB - - requirement = db.query(RequirementDB).filter(RequirementDB.id == requirement_id).first() - if not requirement: - raise HTTPException(status_code=404, detail=f"Requirement {requirement_id} not found") - - regulation = db.query(RegulationDB).filter(RegulationDB.id == requirement.regulation_id).first() - - result = { - "id": requirement.id, - "regulation_id": requirement.regulation_id, - "regulation_code": regulation.code if regulation else None, - "article": requirement.article, - "paragraph": requirement.paragraph, - "title": requirement.title, - "description": requirement.description, - "requirement_text": requirement.requirement_text, - "breakpilot_interpretation": requirement.breakpilot_interpretation, - "implementation_status": requirement.implementation_status or "not_started", - "implementation_details": requirement.implementation_details, - "code_references": requirement.code_references, - "documentation_links": requirement.documentation_links, - "evidence_description": requirement.evidence_description, - "evidence_artifacts": requirement.evidence_artifacts, - "auditor_notes": requirement.auditor_notes, - "audit_status": requirement.audit_status or "pending", - "last_audit_date": requirement.last_audit_date, - "last_auditor": requirement.last_auditor, - "is_applicable": requirement.is_applicable, - "applicability_reason": 
requirement.applicability_reason, - "priority": requirement.priority, - "source_page": requirement.source_page, - "source_section": requirement.source_section, - } + svc = RegulationRequirementService(db) + with translate_domain_errors(): + result = svc.get_requirement(requirement_id, False) + # Handle async legal context fetching inline to preserve behavior if include_legal_context: try: from ..services.rag_client import get_rag_client - from ..services.ai_compliance_assistant import AIComplianceAssistant + from ..services.ai_compliance_assistant import ( + AIComplianceAssistant, + ) + from ..db.models import RequirementDB, RegulationDB + + requirement = ( + db.query(RequirementDB) + .filter(RequirementDB.id == requirement_id) + .first() + ) + regulation = ( + db.query(RegulationDB) + .filter( + RegulationDB.id == requirement.regulation_id + ) + .first() + ) if requirement else None rag = get_rag_client() assistant = AIComplianceAssistant() - query = f"{requirement.title} {requirement.article or ''}" - collection = assistant._collection_for_regulation(regulation.code if regulation else "") - rag_results = await rag.search(query, collection=collection, top_k=3) + query = ( + f"{requirement.title} {requirement.article or ''}" + ) + collection = assistant._collection_for_regulation( + regulation.code if regulation else "" + ) + rag_results = await rag.search( + query, collection=collection, top_k=3 + ) result["legal_context"] = [ { "text": r.text, @@ -233,175 +187,75 @@ async def get_requirement( for r in rag_results ] except Exception as e: - logger.warning("Failed to fetch legal context for %s: %s", requirement_id, e) + logger.warning( + "Failed to fetch legal context for %s: %s", + requirement_id, + e, + ) result["legal_context"] = [] return result -@router.get("/requirements", response_model=PaginatedRequirementResponse) +@router.get( + "/requirements", response_model=PaginatedRequirementResponse +) async def list_requirements_paginated( page: int = Query(1, 
ge=1, description="Page number"), - page_size: int = Query(50, ge=1, le=500, description="Items per page"), - regulation_code: Optional[str] = Query(None, description="Filter by regulation code"), - status: Optional[str] = Query(None, description="Filter by implementation status"), - is_applicable: Optional[bool] = Query(None, description="Filter by applicability"), - search: Optional[str] = Query(None, description="Search in title/description"), - db: Session = Depends(get_db), -): - """ - List requirements with pagination and eager-loaded relationships. - - This endpoint is optimized for large datasets (1000+ requirements) with: - - Eager loading to prevent N+1 queries - - Server-side pagination - - Full-text search support - """ - req_repo = RequirementRepository(db) - - # Use the new paginated method with eager loading - requirements, total = req_repo.get_paginated( - page=page, - page_size=page_size, - regulation_code=regulation_code, - status=status, - is_applicable=is_applicable, - search=search, - ) - - # Calculate pagination metadata - total_pages = (total + page_size - 1) // page_size - - results = [ - RequirementResponse( - id=r.id, - regulation_id=r.regulation_id, - regulation_code=r.regulation.code if r.regulation else None, - article=r.article, - paragraph=r.paragraph, - title=r.title, - description=r.description, - requirement_text=r.requirement_text, - breakpilot_interpretation=r.breakpilot_interpretation, - is_applicable=r.is_applicable, - applicability_reason=r.applicability_reason, - priority=r.priority, - implementation_status=r.implementation_status or "not_started", - implementation_details=r.implementation_details, - code_references=r.code_references, - documentation_links=r.documentation_links, - evidence_description=r.evidence_description, - evidence_artifacts=r.evidence_artifacts, - auditor_notes=r.auditor_notes, - audit_status=r.audit_status or "pending", - last_audit_date=r.last_audit_date, - last_auditor=r.last_auditor, - 
source_page=r.source_page, - source_section=r.source_section, - created_at=r.created_at, - updated_at=r.updated_at, + page_size: int = Query( + 50, ge=1, le=500, description="Items per page" + ), + regulation_code: Optional[str] = Query( + None, description="Filter by regulation code" + ), + status: Optional[str] = Query( + None, description="Filter by implementation status" + ), + is_applicable: Optional[bool] = Query( + None, description="Filter by applicability" + ), + search: Optional[str] = Query( + None, description="Search in title/description" + ), + svc: RegulationRequirementService = Depends(get_reg_req_service), +) -> PaginatedRequirementResponse: + """List requirements with pagination.""" + with translate_domain_errors(): + return svc.list_requirements_paginated( + page, page_size, regulation_code, status, + is_applicable, search, ) - for r in requirements - ] - - return PaginatedRequirementResponse( - data=results, - pagination=PaginationMeta( - page=page, - page_size=page_size, - total=total, - total_pages=total_pages, - has_next=page < total_pages, - has_prev=page > 1, - ), - ) @router.post("/requirements", response_model=RequirementResponse) async def create_requirement( data: RequirementCreate, - db: Session = Depends(get_db), -): + svc: RegulationRequirementService = Depends(get_reg_req_service), +) -> RequirementResponse: """Create a new requirement.""" - # Verify regulation exists - reg_repo = RegulationRepository(db) - regulation = reg_repo.get_by_id(data.regulation_id) - if not regulation: - raise HTTPException(status_code=404, detail=f"Regulation {data.regulation_id} not found") - - req_repo = RequirementRepository(db) - requirement = req_repo.create( - regulation_id=data.regulation_id, - article=data.article, - title=data.title, - paragraph=data.paragraph, - description=data.description, - requirement_text=data.requirement_text, - breakpilot_interpretation=data.breakpilot_interpretation, - is_applicable=data.is_applicable, - 
priority=data.priority, - ) - - return RequirementResponse( - id=requirement.id, - regulation_id=requirement.regulation_id, - regulation_code=regulation.code, - article=requirement.article, - paragraph=requirement.paragraph, - title=requirement.title, - description=requirement.description, - requirement_text=requirement.requirement_text, - breakpilot_interpretation=requirement.breakpilot_interpretation, - is_applicable=requirement.is_applicable, - applicability_reason=requirement.applicability_reason, - priority=requirement.priority, - created_at=requirement.created_at, - updated_at=requirement.updated_at, - ) + with translate_domain_errors(): + return svc.create_requirement(data) @router.delete("/requirements/{requirement_id}") -async def delete_requirement(requirement_id: str, db: Session = Depends(get_db)): +async def delete_requirement( + requirement_id: str, + svc: RegulationRequirementService = Depends(get_reg_req_service), +) -> dict[str, Any]: """Delete a requirement by ID.""" - req_repo = RequirementRepository(db) - deleted = req_repo.delete(requirement_id) - if not deleted: - raise HTTPException(status_code=404, detail=f"Requirement {requirement_id} not found") - return {"success": True, "message": "Requirement deleted"} + with translate_domain_errors(): + return svc.delete_requirement(requirement_id) @router.put("/requirements/{requirement_id}") -async def update_requirement(requirement_id: str, updates: dict, db: Session = Depends(get_db)): +async def update_requirement( + requirement_id: str, + updates: dict, + svc: RegulationRequirementService = Depends(get_reg_req_service), +) -> dict[str, Any]: """Update a requirement with implementation/audit details.""" - from ..db.models import RequirementDB - - requirement = db.query(RequirementDB).filter(RequirementDB.id == requirement_id).first() - if not requirement: - raise HTTPException(status_code=404, detail=f"Requirement {requirement_id} not found") - - # Allowed fields to update - allowed_fields = [ - 
'implementation_status', 'implementation_details', 'code_references', - 'documentation_links', 'evidence_description', 'evidence_artifacts', - 'auditor_notes', 'audit_status', 'is_applicable', 'applicability_reason', - 'breakpilot_interpretation' - ] - - for field in allowed_fields: - if field in updates: - setattr(requirement, field, updates[field]) - - # Track audit changes - if 'audit_status' in updates: - requirement.last_audit_date = datetime.now(timezone.utc) - # TODO: Get auditor from auth - requirement.last_auditor = updates.get('auditor_name', 'api_user') - - requirement.updated_at = datetime.now(timezone.utc) - db.commit() - db.refresh(requirement) - - return {"success": True, "message": "Requirement updated"} + with translate_domain_errors(): + return svc.update_requirement(requirement_id, updates) # ============================================================================ @@ -414,409 +268,139 @@ async def list_controls( status: Optional[str] = None, is_automated: Optional[bool] = None, search: Optional[str] = None, - db: Session = Depends(get_db), -): + svc: ControlExportService = Depends(get_ctrl_export_service), +) -> ControlListResponse: """List all controls with optional filters.""" - repo = ControlRepository(db) - - if domain: - try: - domain_enum = ControlDomainEnum(domain) - controls = repo.get_by_domain(domain_enum) - except ValueError: - raise HTTPException(status_code=400, detail=f"Invalid domain: {domain}") - elif status: - try: - status_enum = ControlStatusEnum(status) - controls = repo.get_by_status(status_enum) - except ValueError: - raise HTTPException(status_code=400, detail=f"Invalid status: {status}") - else: - controls = repo.get_all() - - # Apply additional filters - if is_automated is not None: - controls = [c for c in controls if c.is_automated == is_automated] - - if search: - search_lower = search.lower() - controls = [ - c for c in controls - if search_lower in c.control_id.lower() - or search_lower in c.title.lower() - or 
(c.description and search_lower in c.description.lower()) - ] - - # Add counts - evidence_repo = EvidenceRepository(db) - results = [] - for ctrl in controls: - evidence = evidence_repo.get_by_control(ctrl.id) - results.append(ControlResponse( - id=ctrl.id, - control_id=ctrl.control_id, - domain=ctrl.domain.value if ctrl.domain else None, - control_type=ctrl.control_type.value if ctrl.control_type else None, - title=ctrl.title, - description=ctrl.description, - pass_criteria=ctrl.pass_criteria, - implementation_guidance=ctrl.implementation_guidance, - code_reference=ctrl.code_reference, - documentation_url=ctrl.documentation_url, - is_automated=ctrl.is_automated, - automation_tool=ctrl.automation_tool, - automation_config=ctrl.automation_config, - owner=ctrl.owner, - review_frequency_days=ctrl.review_frequency_days, - status=ctrl.status.value if ctrl.status else None, - status_notes=ctrl.status_notes, - last_reviewed_at=ctrl.last_reviewed_at, - next_review_at=ctrl.next_review_at, - created_at=ctrl.created_at, - updated_at=ctrl.updated_at, - evidence_count=len(evidence), - )) - - return ControlListResponse(controls=results, total=len(results)) + with translate_domain_errors(): + return svc.list_controls(domain, status, is_automated, search) -@router.get("/controls/paginated", response_model=PaginatedControlResponse) +@router.get( + "/controls/paginated", response_model=PaginatedControlResponse +) async def list_controls_paginated( page: int = Query(1, ge=1, description="Page number"), - page_size: int = Query(50, ge=1, le=500, description="Items per page"), - domain: Optional[str] = Query(None, description="Filter by domain"), - status: Optional[str] = Query(None, description="Filter by status"), - is_automated: Optional[bool] = Query(None, description="Filter by automation"), - search: Optional[str] = Query(None, description="Search in title/description"), - db: Session = Depends(get_db), -): - """ - List controls with pagination and eager-loaded relationships. 
- - This endpoint is optimized for large datasets with: - - Eager loading to prevent N+1 queries - - Server-side pagination - - Full-text search support - """ - repo = ControlRepository(db) - - # Convert domain/status to enums if provided - domain_enum = None - status_enum = None - if domain: - try: - domain_enum = ControlDomainEnum(domain) - except ValueError: - raise HTTPException(status_code=400, detail=f"Invalid domain: {domain}") - if status: - try: - status_enum = ControlStatusEnum(status) - except ValueError: - raise HTTPException(status_code=400, detail=f"Invalid status: {status}") - - controls, total = repo.get_paginated( - page=page, - page_size=page_size, - domain=domain_enum, - status=status_enum, - is_automated=is_automated, - search=search, - ) - - total_pages = (total + page_size - 1) // page_size - - results = [ - ControlResponse( - id=c.id, - control_id=c.control_id, - domain=c.domain.value if c.domain else None, - control_type=c.control_type.value if c.control_type else None, - title=c.title, - description=c.description, - pass_criteria=c.pass_criteria, - implementation_guidance=c.implementation_guidance, - code_reference=c.code_reference, - documentation_url=c.documentation_url, - is_automated=c.is_automated, - automation_tool=c.automation_tool, - automation_config=c.automation_config, - owner=c.owner, - review_frequency_days=c.review_frequency_days, - status=c.status.value if c.status else None, - status_notes=c.status_notes, - last_reviewed_at=c.last_reviewed_at, - next_review_at=c.next_review_at, - created_at=c.created_at, - updated_at=c.updated_at, - evidence_count=len(c.evidence) if c.evidence else 0, + page_size: int = Query( + 50, ge=1, le=500, description="Items per page" + ), + domain: Optional[str] = Query( + None, description="Filter by domain" + ), + status: Optional[str] = Query( + None, description="Filter by status" + ), + is_automated: Optional[bool] = Query( + None, description="Filter by automation" + ), + search: Optional[str] 
= Query( + None, description="Search in title/description" + ), + svc: ControlExportService = Depends(get_ctrl_export_service), +) -> PaginatedControlResponse: + """List controls with pagination.""" + with translate_domain_errors(): + return svc.list_controls_paginated( + page, page_size, domain, status, is_automated, search, ) - for c in controls - ] - - return PaginatedControlResponse( - data=results, - pagination=PaginationMeta( - page=page, - page_size=page_size, - total=total, - total_pages=total_pages, - has_next=page < total_pages, - has_prev=page > 1, - ), - ) -@router.get("/controls/{control_id}", response_model=ControlResponse) -async def get_control(control_id: str, db: Session = Depends(get_db)): +@router.get( + "/controls/{control_id}", response_model=ControlResponse +) +async def get_control( + control_id: str, + svc: ControlExportService = Depends(get_ctrl_export_service), +) -> ControlResponse: """Get a specific control by control_id.""" - repo = ControlRepository(db) - control = repo.get_by_control_id(control_id) - if not control: - raise HTTPException(status_code=404, detail=f"Control {control_id} not found") - - evidence_repo = EvidenceRepository(db) - evidence = evidence_repo.get_by_control(control.id) - - return ControlResponse( - id=control.id, - control_id=control.control_id, - domain=control.domain.value if control.domain else None, - control_type=control.control_type.value if control.control_type else None, - title=control.title, - description=control.description, - pass_criteria=control.pass_criteria, - implementation_guidance=control.implementation_guidance, - code_reference=control.code_reference, - documentation_url=control.documentation_url, - is_automated=control.is_automated, - automation_tool=control.automation_tool, - automation_config=control.automation_config, - owner=control.owner, - review_frequency_days=control.review_frequency_days, - status=control.status.value if control.status else None, - 
status_notes=control.status_notes, - last_reviewed_at=control.last_reviewed_at, - next_review_at=control.next_review_at, - created_at=control.created_at, - updated_at=control.updated_at, - evidence_count=len(evidence), - ) + with translate_domain_errors(): + return svc.get_control(control_id) -@router.put("/controls/{control_id}", response_model=ControlResponse) +@router.put( + "/controls/{control_id}", response_model=ControlResponse +) async def update_control( control_id: str, update: ControlUpdate, - db: Session = Depends(get_db), -): + svc: ControlExportService = Depends(get_ctrl_export_service), +) -> ControlResponse: """Update a control.""" - repo = ControlRepository(db) - control = repo.get_by_control_id(control_id) - if not control: - raise HTTPException(status_code=404, detail=f"Control {control_id} not found") - - update_data = update.model_dump(exclude_unset=True) - - # Convert status string to enum - if "status" in update_data: - try: - update_data["status"] = ControlStatusEnum(update_data["status"]) - except ValueError: - raise HTTPException(status_code=400, detail=f"Invalid status: {update_data['status']}") - - updated = repo.update(control.id, **update_data) - db.commit() - - return ControlResponse( - id=updated.id, - control_id=updated.control_id, - domain=updated.domain.value if updated.domain else None, - control_type=updated.control_type.value if updated.control_type else None, - title=updated.title, - description=updated.description, - pass_criteria=updated.pass_criteria, - implementation_guidance=updated.implementation_guidance, - code_reference=updated.code_reference, - documentation_url=updated.documentation_url, - is_automated=updated.is_automated, - automation_tool=updated.automation_tool, - automation_config=updated.automation_config, - owner=updated.owner, - review_frequency_days=updated.review_frequency_days, - status=updated.status.value if updated.status else None, - status_notes=updated.status_notes, - 
last_reviewed_at=updated.last_reviewed_at, - next_review_at=updated.next_review_at, - created_at=updated.created_at, - updated_at=updated.updated_at, - ) + with translate_domain_errors(): + return svc.update_control(control_id, update) -@router.put("/controls/{control_id}/review", response_model=ControlResponse) +@router.put( + "/controls/{control_id}/review", + response_model=ControlResponse, +) async def review_control( control_id: str, review: ControlReviewRequest, - db: Session = Depends(get_db), -): + svc: ControlExportService = Depends(get_ctrl_export_service), +) -> ControlResponse: """Mark a control as reviewed with new status.""" - repo = ControlRepository(db) - control = repo.get_by_control_id(control_id) - if not control: - raise HTTPException(status_code=404, detail=f"Control {control_id} not found") - - try: - status_enum = ControlStatusEnum(review.status) - except ValueError: - raise HTTPException(status_code=400, detail=f"Invalid status: {review.status}") - - updated = repo.mark_reviewed(control.id, status_enum, review.status_notes) - db.commit() - - return ControlResponse( - id=updated.id, - control_id=updated.control_id, - domain=updated.domain.value if updated.domain else None, - control_type=updated.control_type.value if updated.control_type else None, - title=updated.title, - description=updated.description, - pass_criteria=updated.pass_criteria, - implementation_guidance=updated.implementation_guidance, - code_reference=updated.code_reference, - documentation_url=updated.documentation_url, - is_automated=updated.is_automated, - automation_tool=updated.automation_tool, - automation_config=updated.automation_config, - owner=updated.owner, - review_frequency_days=updated.review_frequency_days, - status=updated.status.value if updated.status else None, - status_notes=updated.status_notes, - last_reviewed_at=updated.last_reviewed_at, - next_review_at=updated.next_review_at, - created_at=updated.created_at, - updated_at=updated.updated_at, - ) + with 
translate_domain_errors(): + return svc.review_control(control_id, review) -@router.get("/controls/by-domain/{domain}", response_model=ControlListResponse) -async def get_controls_by_domain(domain: str, db: Session = Depends(get_db)): +@router.get( + "/controls/by-domain/{domain}", + response_model=ControlListResponse, +) +async def get_controls_by_domain( + domain: str, + svc: ControlExportService = Depends(get_ctrl_export_service), +) -> ControlListResponse: """Get controls by domain.""" - try: - domain_enum = ControlDomainEnum(domain) - except ValueError: - raise HTTPException(status_code=400, detail=f"Invalid domain: {domain}") + with translate_domain_errors(): + return svc.get_controls_by_domain(domain) - repo = ControlRepository(db) - controls = repo.get_by_domain(domain_enum) - - results = [ - ControlResponse( - id=c.id, - control_id=c.control_id, - domain=c.domain.value if c.domain else None, - control_type=c.control_type.value if c.control_type else None, - title=c.title, - description=c.description, - pass_criteria=c.pass_criteria, - implementation_guidance=c.implementation_guidance, - code_reference=c.code_reference, - documentation_url=c.documentation_url, - is_automated=c.is_automated, - automation_tool=c.automation_tool, - automation_config=c.automation_config, - owner=c.owner, - review_frequency_days=c.review_frequency_days, - status=c.status.value if c.status else None, - status_notes=c.status_notes, - last_reviewed_at=c.last_reviewed_at, - next_review_at=c.next_review_at, - created_at=c.created_at, - updated_at=c.updated_at, - ) - for c in controls - ] - - return ControlListResponse(controls=results, total=len(results)) +# ============================================================================ +# Export +# ============================================================================ @router.post("/export", response_model=ExportResponse) async def create_export( request: ExportRequest, background_tasks: BackgroundTasks, - db: Session = 
Depends(get_db), -): + svc: ControlExportService = Depends(get_ctrl_export_service), +) -> ExportResponse: """Create a new audit export.""" - generator = AuditExportGenerator(db) - export = generator.create_export( - requested_by="api_user", # TODO: Get from auth - export_type=request.export_type, - included_regulations=request.included_regulations, - included_domains=request.included_domains, - date_range_start=request.date_range_start, - date_range_end=request.date_range_end, - ) - - return ExportResponse( - id=export.id, - export_type=export.export_type, - export_name=export.export_name, - status=export.status.value if export.status else None, - requested_by=export.requested_by, - requested_at=export.requested_at, - completed_at=export.completed_at, - file_path=export.file_path, - file_hash=export.file_hash, - file_size_bytes=export.file_size_bytes, - total_controls=export.total_controls, - total_evidence=export.total_evidence, - compliance_score=export.compliance_score, - error_message=export.error_message, - ) + with translate_domain_errors(): + data = svc.create_export( + export_type=request.export_type, + included_regulations=request.included_regulations, + included_domains=request.included_domains, + date_range_start=request.date_range_start, + date_range_end=request.date_range_end, + ) + return ExportResponse(**data) @router.get("/export/{export_id}", response_model=ExportResponse) -async def get_export(export_id: str, db: Session = Depends(get_db)): +async def get_export( + export_id: str, + svc: ControlExportService = Depends(get_ctrl_export_service), +) -> ExportResponse: """Get export status.""" - generator = AuditExportGenerator(db) - export = generator.get_export_status(export_id) - if not export: - raise HTTPException(status_code=404, detail=f"Export {export_id} not found") - - return ExportResponse( - id=export.id, - export_type=export.export_type, - export_name=export.export_name, - status=export.status.value if export.status else None, - 
requested_by=export.requested_by, - requested_at=export.requested_at, - completed_at=export.completed_at, - file_path=export.file_path, - file_hash=export.file_hash, - file_size_bytes=export.file_size_bytes, - total_controls=export.total_controls, - total_evidence=export.total_evidence, - compliance_score=export.compliance_score, - error_message=export.error_message, - ) + with translate_domain_errors(): + data = svc.get_export(export_id) + return ExportResponse(**data) @router.get("/export/{export_id}/download") -async def download_export(export_id: str, db: Session = Depends(get_db)): +async def download_export( + export_id: str, + svc: ControlExportService = Depends(get_ctrl_export_service), +) -> FileResponse: """Download export file.""" - generator = AuditExportGenerator(db) - export = generator.get_export_status(export_id) - if not export: - raise HTTPException(status_code=404, detail=f"Export {export_id} not found") - - if export.status.value != "completed": - raise HTTPException(status_code=400, detail="Export not completed") - - if not export.file_path or not os.path.exists(export.file_path): - raise HTTPException(status_code=404, detail="Export file not found") - + with translate_domain_errors(): + file_path = svc.download_export(export_id) return FileResponse( - export.file_path, + file_path, media_type="application/zip", - filename=os.path.basename(export.file_path), + filename=os.path.basename(file_path), ) @@ -824,168 +408,56 @@ async def download_export(export_id: str, db: Session = Depends(get_db)): async def list_exports( limit: int = 20, offset: int = 0, - db: Session = Depends(get_db), -): + svc: ControlExportService = Depends(get_ctrl_export_service), +) -> ExportListResponse: """List recent exports.""" - generator = AuditExportGenerator(db) - exports = generator.list_exports(limit, offset) - - results = [ - ExportResponse( - id=e.id, - export_type=e.export_type, - export_name=e.export_name, - status=e.status.value if e.status else None, - 
requested_by=e.requested_by, - requested_at=e.requested_at, - completed_at=e.completed_at, - file_path=e.file_path, - file_hash=e.file_hash, - file_size_bytes=e.file_size_bytes, - total_controls=e.total_controls, - total_evidence=e.total_evidence, - compliance_score=e.compliance_score, - error_message=e.error_message, - ) - for e in exports - ] - - return ExportListResponse(exports=results, total=len(results)) + with translate_domain_errors(): + data = svc.list_exports(limit, offset) + return ExportListResponse(**data) # ============================================================================ -# Seeding +# Seeding / Admin # ============================================================================ @router.post("/init-tables") -async def init_tables(db: Session = Depends(get_db)): +async def init_tables( + svc: ControlExportService = Depends(get_ctrl_export_service), +) -> dict[str, Any]: """Create compliance tables if they don't exist.""" - from classroom_engine.database import engine - from ..db.models import ( - RegulationDB, RequirementDB, ControlMappingDB, - RiskDB, AuditExportDB, AISystemDB - ) - try: - # Create all tables - RegulationDB.__table__.create(engine, checkfirst=True) - RequirementDB.__table__.create(engine, checkfirst=True) - ControlDB.__table__.create(engine, checkfirst=True) - ControlMappingDB.__table__.create(engine, checkfirst=True) - EvidenceDB.__table__.create(engine, checkfirst=True) - RiskDB.__table__.create(engine, checkfirst=True) - AuditExportDB.__table__.create(engine, checkfirst=True) - AISystemDB.__table__.create(engine, checkfirst=True) - - return {"success": True, "message": "Tables created successfully"} + return svc.init_tables() except Exception as e: - logger.error(f"Table creation failed: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.post("/create-indexes") -async def create_performance_indexes(db: Session = Depends(get_db)): - """ - Create additional performance indexes for large datasets. 
- - These indexes are optimized for: - - Pagination queries (1000+ requirements) - - Full-text search - - Filtering by status/priority - """ - from sqlalchemy import text - - indexes = [ - # Priority index for sorting (descending, as we want high priority first) - ("ix_req_priority_desc", "CREATE INDEX IF NOT EXISTS ix_req_priority_desc ON compliance_requirements (priority DESC)"), - - # Compound index for common filtering patterns - ("ix_req_applicable_status", "CREATE INDEX IF NOT EXISTS ix_req_applicable_status ON compliance_requirements (is_applicable, implementation_status)"), - - # Control status index - ("ix_ctrl_status", "CREATE INDEX IF NOT EXISTS ix_ctrl_status ON compliance_controls (status)"), - - # Evidence collected_at for timeline queries - ("ix_evidence_collected", "CREATE INDEX IF NOT EXISTS ix_evidence_collected ON compliance_evidence (collected_at DESC)"), - - # Risk inherent risk level - ("ix_risk_level", "CREATE INDEX IF NOT EXISTS ix_risk_level ON compliance_risks (inherent_risk)"), - ] - - created = [] - errors = [] - - for idx_name, idx_sql in indexes: - try: - db.execute(text(idx_sql)) - db.commit() - created.append(idx_name) - except Exception as e: - errors.append({"index": idx_name, "error": str(e)}) - logger.warning(f"Index creation failed for {idx_name}: {e}") - - return { - "success": len(errors) == 0, - "created": created, - "errors": errors, - "message": f"Created {len(created)} indexes" + (f", {len(errors)} failed" if errors else ""), - } +async def create_performance_indexes( + svc: ControlExportService = Depends(get_ctrl_export_service), +) -> dict[str, Any]: + """Create additional performance indexes.""" + return svc.create_indexes() @router.post("/seed-risks") -async def seed_risks_only(db: Session = Depends(get_db)): - """Seed only risks (incremental update for existing databases).""" - from classroom_engine.database import engine - from ..db.models import RiskDB - +async def seed_risks_only( + svc: ControlExportService = 
Depends(get_ctrl_export_service), +) -> dict[str, Any]: + """Seed only risks.""" try: - # Ensure table exists - RiskDB.__table__.create(engine, checkfirst=True) - - seeder = ComplianceSeeder(db) - count = seeder.seed_risks_only() - - return { - "success": True, - "message": f"Successfully seeded {count} risks", - "risks_seeded": count, - } + return svc.seed_risks_only() except Exception as e: - logger.error(f"Risk seeding failed: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.post("/seed", response_model=SeedResponse) async def seed_database( request: SeedRequest, - db: Session = Depends(get_db), -): + svc: ControlExportService = Depends(get_ctrl_export_service), +) -> SeedResponse: """Seed the compliance database with initial data.""" - from classroom_engine.database import engine - from ..db.models import ( - RegulationDB, RequirementDB, ControlMappingDB, - RiskDB, AuditExportDB - ) - try: - # Ensure tables exist first - RegulationDB.__table__.create(engine, checkfirst=True) - RequirementDB.__table__.create(engine, checkfirst=True) - ControlDB.__table__.create(engine, checkfirst=True) - ControlMappingDB.__table__.create(engine, checkfirst=True) - EvidenceDB.__table__.create(engine, checkfirst=True) - RiskDB.__table__.create(engine, checkfirst=True) - AuditExportDB.__table__.create(engine, checkfirst=True) - - seeder = ComplianceSeeder(db) - counts = seeder.seed_all(force=request.force) - return SeedResponse( - success=True, - message="Database seeded successfully", - counts=counts, - ) + data = svc.seed_database(force=request.force) + return SeedResponse(**data) except Exception as e: - logger.error(f"Seeding failed: {e}") raise HTTPException(status_code=500, detail=str(e)) - - diff --git a/backend-compliance/compliance/services/control_export_service.py b/backend-compliance/compliance/services/control_export_service.py new file mode 100644 index 0000000..7294d76 --- /dev/null +++ 
# b/backend-compliance/compliance/services/control_export_service.py @@ -0,0 +1,498 @@
# mypy: disable-error-code="arg-type,assignment,union-attr,no-any-return"
"""
Service for control, export, and admin/seeding business logic.

Phase 1 Step 4: extracted from ``compliance.api.routes``. All handler logic
for controls CRUD, export management, and database seeding lives here.
"""

import logging
import os
from typing import Any, Optional

from sqlalchemy.orm import Session

from compliance.db import (
    ControlDomainEnum,
    ControlRepository,
    ControlStatusEnum,
    EvidenceRepository,
)
from compliance.db.models import ControlDB, EvidenceDB
from compliance.domain import NotFoundError, ValidationError
from compliance.schemas.control import (
    ControlListResponse,
    ControlResponse,
    ControlReviewRequest,
    ControlUpdate,
    PaginatedControlResponse,
)
from compliance.schemas.common import PaginationMeta

logger = logging.getLogger(__name__)


def _control_to_response(
    ctrl: Any, evidence_count: int = 0
) -> ControlResponse:
    """Map a control row onto ``ControlResponse``.

    Enum-valued columns (``domain``, ``control_type``, ``status``) are
    reduced to their ``.value`` and become ``None`` when unset.
    ``evidence_count`` is supplied by the caller and defaults to 0.
    """
    return ControlResponse(
        id=ctrl.id,
        control_id=ctrl.control_id,
        domain=ctrl.domain.value if ctrl.domain else None,
        control_type=(
            ctrl.control_type.value if ctrl.control_type else None
        ),
        title=ctrl.title,
        description=ctrl.description,
        pass_criteria=ctrl.pass_criteria,
        implementation_guidance=ctrl.implementation_guidance,
        code_reference=ctrl.code_reference,
        documentation_url=ctrl.documentation_url,
        is_automated=ctrl.is_automated,
        automation_tool=ctrl.automation_tool,
        automation_config=ctrl.automation_config,
        owner=ctrl.owner,
        review_frequency_days=ctrl.review_frequency_days,
        status=ctrl.status.value if ctrl.status else None,
        status_notes=ctrl.status_notes,
        last_reviewed_at=ctrl.last_reviewed_at,
        next_review_at=ctrl.next_review_at,
        created_at=ctrl.created_at,
        updated_at=ctrl.updated_at,
        evidence_count=evidence_count,
    )


def _parse_domain(domain: Any) -> ControlDomainEnum:
    """Convert a raw domain value to the enum or raise ``ValidationError``."""
    try:
        return ControlDomainEnum(domain)
    except ValueError as exc:
        raise ValidationError(f"Invalid domain: {domain}") from exc


def _parse_status(status: Any) -> ControlStatusEnum:
    """Convert a raw status value to the enum or raise ``ValidationError``."""
    try:
        return ControlStatusEnum(status)
    except ValueError as exc:
        raise ValidationError(f"Invalid status: {status}") from exc


class ControlExportService:
    """Business logic for control and export endpoints."""

    def __init__(
        self,
        db: Session,
        control_repo_cls: Any = ControlRepository,
        evidence_repo_cls: Any = EvidenceRepository,
    ) -> None:
        """Bind the session and instantiate both repositories on it.

        The repository classes are injectable so callers (and tests) can
        substitute their own implementations.
        """
        self.db = db
        self.ctrl_repo = control_repo_cls(db)
        self.evidence_repo = evidence_repo_cls(db)

    # ------------------------------------------------------------------
    # Controls
    # ------------------------------------------------------------------

    def _require_control(self, control_id: str) -> Any:
        """Return the control with this ``control_id`` or raise ``NotFoundError``."""
        control = self.ctrl_repo.get_by_control_id(control_id)
        if not control:
            raise NotFoundError(f"Control {control_id} not found")
        return control

    def list_controls(
        self,
        domain: Optional[str],
        status: Optional[str],
        is_automated: Optional[bool],
        search: Optional[str],
    ) -> ControlListResponse:
        """List controls, optionally filtered, with per-control evidence counts.

        ``domain`` takes precedence over ``status``: when both are given
        only the domain filter is applied. ``is_automated`` and ``search``
        are applied in memory afterwards; ``search`` is a case-insensitive
        substring match on control id, title and description.

        Raises:
            ValidationError: ``domain`` or ``status`` is not a valid enum value.
        """
        # Only enum parsing is guarded: a ValueError coming out of the
        # repository must not be reported as an invalid filter value.
        if domain:
            controls = self.ctrl_repo.get_by_domain(_parse_domain(domain))
        elif status:
            controls = self.ctrl_repo.get_by_status(_parse_status(status))
        else:
            controls = self.ctrl_repo.get_all()

        if is_automated is not None:
            controls = [
                c for c in controls if c.is_automated == is_automated
            ]

        if search:
            needle = search.lower()
            # ``or ""`` keeps a control with a missing text column from
            # crashing the whole listing.
            controls = [
                c
                for c in controls
                if needle in (c.control_id or "").lower()
                or needle in (c.title or "").lower()
                or needle in (c.description or "").lower()
            ]

        results = [
            _control_to_response(
                ctrl, len(self.evidence_repo.get_by_control(ctrl.id))
            )
            for ctrl in controls
        ]
        return ControlListResponse(controls=results, total=len(results))

    def list_controls_paginated(
        self,
        page: int,
        page_size: int,
        domain: Optional[str],
        status: Optional[str],
        is_automated: Optional[bool],
        search: Optional[str],
    ) -> PaginatedControlResponse:
        """Return one page of controls plus pagination metadata.

        Filtering and paging are delegated to the repository's
        ``get_paginated``; the evidence count is read from each control's
        ``evidence`` attribute.

        Raises:
            ValidationError: ``page_size`` is below 1, or ``domain`` /
                ``status`` is not a valid enum value.
        """
        if page_size < 1:
            # Would otherwise surface as ZeroDivisionError below.
            raise ValidationError("page_size must be at least 1")

        domain_enum = _parse_domain(domain) if domain else None
        status_enum = _parse_status(status) if status else None

        controls, total = self.ctrl_repo.get_paginated(
            page=page,
            page_size=page_size,
            domain=domain_enum,
            status=status_enum,
            is_automated=is_automated,
            search=search,
        )

        # Ceiling division without floats.
        total_pages = (total + page_size - 1) // page_size

        results = [
            _control_to_response(c, len(c.evidence) if c.evidence else 0)
            for c in controls
        ]

        return PaginatedControlResponse(
            data=results,
            pagination=PaginationMeta(
                page=page,
                page_size=page_size,
                total=total,
                total_pages=total_pages,
                has_next=page < total_pages,
                has_prev=page > 1,
            ),
        )

    def get_control(self, control_id: str) -> ControlResponse:
        """Return a single control with its evidence count.

        Raises:
            NotFoundError: no control has this ``control_id``.
        """
        control = self._require_control(control_id)
        evidence = self.evidence_repo.get_by_control(control.id)
        return _control_to_response(control, len(evidence))

    def update_control(
        self, control_id: str, update: ControlUpdate
    ) -> ControlResponse:
        """Apply the explicitly-set fields of ``update`` and commit.

        Raises:
            NotFoundError: no control has this ``control_id``.
            ValidationError: ``status`` is present but not a valid enum value.
        """
        control = self._require_control(control_id)

        update_data = update.model_dump(exclude_unset=True)
        if "status" in update_data:
            update_data["status"] = _parse_status(update_data["status"])

        updated = self.ctrl_repo.update(control.id, **update_data)
        self.db.commit()
        return _control_to_response(updated)

    def review_control(
        self, control_id: str, review: ControlReviewRequest
    ) -> ControlResponse:
        """Mark a control as reviewed with a new status and notes, then commit.

        Raises:
            NotFoundError: no control has this ``control_id``.
            ValidationError: ``review.status`` is not a valid enum value.
        """
        control = self._require_control(control_id)
        status_enum = _parse_status(review.status)

        updated = self.ctrl_repo.mark_reviewed(
            control.id, status_enum, review.status_notes
        )
        self.db.commit()
        return _control_to_response(updated)

    def get_controls_by_domain(
        self, domain: str
    ) -> ControlListResponse:
        """List all controls of one domain (evidence counts are left at 0).

        Raises:
            ValidationError: ``domain`` is not a valid enum value.
        """
        controls = self.ctrl_repo.get_by_domain(_parse_domain(domain))
        results = [_control_to_response(c) for c in controls]
        return ControlListResponse(controls=results, total=len(results))

    # ------------------------------------------------------------------
    # Export
    # ------------------------------------------------------------------

    def _export_generator(self) -> Any:
        """Build an ``AuditExportGenerator`` bound to this service's session."""
        # Imported at call time, as the original handlers did, so the
        # generator module is only loaded when an export endpoint is used.
        from compliance.services.export_generator import (
            AuditExportGenerator,
        )

        return AuditExportGenerator(self.db)

    def create_export(
        self,
        export_type: str,
        included_regulations: Any,
        included_domains: Any,
        date_range_start: Any,
        date_range_end: Any,
    ) -> dict[str, Any]:
        """Create an audit export and return it as a plain dict.

        The requester is recorded as the fixed string ``"api_user"``.
        """
        export = self._export_generator().create_export(
            requested_by="api_user",
            export_type=export_type,
            included_regulations=included_regulations,
            included_domains=included_domains,
            date_range_start=date_range_start,
            date_range_end=date_range_end,
        )
        return self._export_to_dict(export)

    def get_export(self, export_id: str) -> dict[str, Any]:
        """Return the status record of one export.

        Raises:
            NotFoundError: the export does not exist.
        """
        export = self._export_generator().get_export_status(export_id)
        if not export:
            raise NotFoundError(f"Export {export_id} not found")
        return self._export_to_dict(export)

    def download_export(self, export_id: str) -> str:
        """Return the file path of a completed export.

        Raises:
            NotFoundError: the export, or its file on disk, does not exist.
            ValidationError: the export is not in the ``completed`` state.
        """
        export = self._export_generator().get_export_status(export_id)
        if not export:
            raise NotFoundError(f"Export {export_id} not found")

        # ``status`` may be unset (``_export_to_dict`` allows for that);
        # treat it as "not completed" instead of raising AttributeError.
        status = export.status.value if export.status else None
        if status != "completed":
            raise ValidationError("Export not completed")

        if not export.file_path or not os.path.exists(export.file_path):
            raise NotFoundError("Export file not found")

        return export.file_path

    def list_exports(
        self, limit: int, offset: int
    ) -> dict[str, Any]:
        """Return one window of exports.

        ``total`` is the number of rows in this window, not the overall
        number of exports.
        """
        exports = self._export_generator().list_exports(limit, offset)
        results = [self._export_to_dict(e) for e in exports]
        return {"exports": results, "total": len(results)}

    # ------------------------------------------------------------------
    # Admin / Seeding
    # ------------------------------------------------------------------

    @staticmethod
    def _create_tables(*models: Any) -> None:
        """Create each model's table, in the given order, if it is missing."""
        from classroom_engine.database import engine

        for model in models:
            model.__table__.create(engine, checkfirst=True)

    def init_tables(self) -> dict[str, Any]:
        """Create all compliance tables that do not exist yet.

        Any failure is logged with its traceback and re-raised.
        """
        from compliance.db.models import (
            AISystemDB,
            AuditExportDB,
            ControlMappingDB,
            RegulationDB,
            RequirementDB,
            RiskDB,
        )

        try:
            self._create_tables(
                RegulationDB,
                RequirementDB,
                ControlDB,
                ControlMappingDB,
                EvidenceDB,
                RiskDB,
                AuditExportDB,
                AISystemDB,
            )
        except Exception as e:
            logger.exception("Table creation failed: %s", e)
            raise
        return {
            "success": True,
            "message": "Tables created successfully",
        }

    def create_indexes(self) -> dict[str, Any]:
        """Create the performance indexes, one statement per commit.

        Each index is attempted independently; failures are collected in
        the ``errors`` list of the result instead of being raised.
        """
        from sqlalchemy import text

        indexes = [
            (
                "ix_req_priority_desc",
                "CREATE INDEX IF NOT EXISTS ix_req_priority_desc "
                "ON compliance_requirements (priority DESC)",
            ),
            (
                "ix_req_applicable_status",
                "CREATE INDEX IF NOT EXISTS ix_req_applicable_status "
                "ON compliance_requirements "
                "(is_applicable, implementation_status)",
            ),
            (
                "ix_ctrl_status",
                "CREATE INDEX IF NOT EXISTS ix_ctrl_status "
                "ON compliance_controls (status)",
            ),
            (
                "ix_evidence_collected",
                "CREATE INDEX IF NOT EXISTS ix_evidence_collected "
                "ON compliance_evidence (collected_at DESC)",
            ),
            (
                "ix_risk_level",
                "CREATE INDEX IF NOT EXISTS ix_risk_level "
                "ON compliance_risks (inherent_risk)",
            ),
        ]

        created: list[str] = []
        errors: list[dict[str, str]] = []

        for idx_name, idx_sql in indexes:
            try:
                self.db.execute(text(idx_sql))
                self.db.commit()
                created.append(idx_name)
            except Exception as e:
                # A failed statement leaves the session's transaction
                # unusable; roll back so the remaining indexes can still
                # be attempted on a clean transaction.
                self.db.rollback()
                errors.append({"index": idx_name, "error": str(e)})
                logger.warning(
                    "Index creation failed for %s: %s", idx_name, e
                )

        failed_suffix = f", {len(errors)} failed" if errors else ""
        return {
            "success": len(errors) == 0,
            "created": created,
            "errors": errors,
            "message": f"Created {len(created)} indexes" + failed_suffix,
        }

    def seed_risks_only(self) -> dict[str, Any]:
        """Ensure the risk table exists, then seed only the risks.

        Any failure is logged with its traceback and re-raised.
        """
        from compliance.db.models import RiskDB
        from compliance.services.seeder import ComplianceSeeder

        try:
            self._create_tables(RiskDB)
            count = ComplianceSeeder(self.db).seed_risks_only()
        except Exception as e:
            logger.exception("Risk seeding failed: %s", e)
            raise
        return {
            "success": True,
            "message": f"Successfully seeded {count} risks",
            "risks_seeded": count,
        }

    def seed_database(self, force: bool) -> dict[str, Any]:
        """Ensure the seeded tables exist, then run the full seeder.

        ``force`` is passed through to ``ComplianceSeeder.seed_all``.
        Any failure is logged with its traceback and re-raised.
        """
        from compliance.db.models import (
            AuditExportDB,
            ControlMappingDB,
            RegulationDB,
            RequirementDB,
            RiskDB,
        )
        from compliance.services.seeder import ComplianceSeeder

        try:
            self._create_tables(
                RegulationDB,
                RequirementDB,
                ControlDB,
                ControlMappingDB,
                EvidenceDB,
                RiskDB,
                AuditExportDB,
            )
            counts = ComplianceSeeder(self.db).seed_all(force=force)
        except Exception as e:
            logger.exception("Seeding failed: %s", e)
            raise
        return {
            "success": True,
            "message": "Database seeded successfully",
            "counts": counts,
        }

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _export_to_dict(export: Any) -> dict[str, Any]:
        """Flatten an export row into the keyword set of ``ExportResponse``."""
        return {
            "id": export.id,
            "export_type": export.export_type,
            "export_name": export.export_name,
            "status": (
                export.status.value if export.status else None
            ),
            "requested_by": export.requested_by,
            "requested_at": export.requested_at,
            "completed_at": export.completed_at,
            "file_path": export.file_path,
            "file_hash": export.file_hash,
            "file_size_bytes": export.file_size_bytes,
            "total_controls": export.total_controls,
            "total_evidence": export.total_evidence,
            "compliance_score": export.compliance_score,
            "error_message": export.error_message,
        }
# diff --git
# a/backend-compliance/compliance/services/regulation_requirement_service.py b/backend-compliance/compliance/services/regulation_requirement_service.py
# new file mode 100644
# index 0000000..086d2cb
# --- /dev/null
# +++ b/backend-compliance/compliance/services/regulation_requirement_service.py
# @@ -0,0 +1,410 @@
# mypy: disable-error-code="arg-type,assignment,union-attr,no-any-return"
"""
Service for regulation and requirement business logic.

Phase 1 Step 4: extracted from ``compliance.api.routes``. All handler logic
for regulations CRUD and requirements CRUD lives here. The route module
delegates to this service and translates domain errors to HTTP responses.
"""

import logging
from datetime import datetime, timezone
from typing import Any, Optional

from sqlalchemy.orm import Session

from compliance.db import RegulationRepository, RequirementRepository
from compliance.db.models import RegulationDB, RequirementDB
from compliance.domain import NotFoundError, ValidationError
from compliance.schemas.regulation import (
    RegulationListResponse,
    RegulationResponse,
)
from compliance.schemas.requirement import (
    PaginatedRequirementResponse,
    RequirementCreate,
    RequirementListResponse,
    RequirementResponse,
)
from compliance.schemas.common import PaginationMeta

logger = logging.getLogger(__name__)

# Fields a client may change through ``update_requirement``; any other key
# in the update payload is ignored.
_UPDATABLE_REQUIREMENT_FIELDS = (
    "implementation_status",
    "implementation_details",
    "code_references",
    "documentation_links",
    "evidence_description",
    "evidence_artifacts",
    "auditor_notes",
    "audit_status",
    "is_applicable",
    "applicability_reason",
    "breakpilot_interpretation",
)


def _regulation_to_response(
    reg: Any, requirement_count: int
) -> RegulationResponse:
    """Map a regulation row onto ``RegulationResponse``.

    ``regulation_type`` is reduced to its enum ``.value`` (``None`` when
    unset); ``requirement_count`` is supplied by the caller.
    """
    return RegulationResponse(
        id=reg.id,
        code=reg.code,
        name=reg.name,
        full_name=reg.full_name,
        regulation_type=(
            reg.regulation_type.value if reg.regulation_type else None
        ),
        source_url=reg.source_url,
        local_pdf_path=reg.local_pdf_path,
        effective_date=reg.effective_date,
        description=reg.description,
        is_active=reg.is_active,
        created_at=reg.created_at,
        updated_at=reg.updated_at,
        requirement_count=requirement_count,
    )


def _requirement_to_response(r: Any, code: Optional[str] = None) -> RequirementResponse:
    """Map a requirement row onto the basic ``RequirementResponse`` fields.

    ``code`` is the owning regulation's code, passed in by the caller.
    """
    return RequirementResponse(
        id=r.id,
        regulation_id=r.regulation_id,
        regulation_code=code,
        article=r.article,
        paragraph=r.paragraph,
        title=r.title,
        description=r.description,
        requirement_text=r.requirement_text,
        breakpilot_interpretation=r.breakpilot_interpretation,
        is_applicable=r.is_applicable,
        applicability_reason=r.applicability_reason,
        priority=r.priority,
        created_at=r.created_at,
        updated_at=r.updated_at,
    )


class RegulationRequirementService:
    """Business logic for regulation and requirement endpoints."""

    def __init__(
        self,
        db: Session,
        reg_repo_cls: Any = RegulationRepository,
        req_repo_cls: Any = RequirementRepository,
    ) -> None:
        """Bind the session and instantiate both repositories on it.

        The repository classes are injectable so callers (and tests) can
        substitute their own implementations.
        """
        self.db = db
        self.reg_repo = reg_repo_cls(db)
        self.req_repo = req_repo_cls(db)

    # ------------------------------------------------------------------
    # Regulations
    # ------------------------------------------------------------------

    def _require_regulation(self, code: str) -> Any:
        """Return the regulation with this code or raise ``NotFoundError``."""
        regulation = self.reg_repo.get_by_code(code)
        if not regulation:
            raise NotFoundError(f"Regulation {code} not found")
        return regulation

    def list_regulations(
        self,
        is_active: Optional[bool],
        regulation_type: Optional[str],
    ) -> RegulationListResponse:
        """List regulations with their requirement counts.

        ``is_active=True`` restricts the list to active regulations;
        ``False`` and ``None`` both return every regulation. An unknown
        ``regulation_type`` is ignored (the list stays unfiltered) and
        logged.
        """
        # NOTE(review): ``is_active=False`` does not select inactive
        # regulations only -- confirm this is the intended contract.
        regulations = (
            self.reg_repo.get_active()
            if is_active
            else self.reg_repo.get_all()
        )

        if regulation_type:
            from compliance.db.models import RegulationTypeEnum

            try:
                reg_type = RegulationTypeEnum(regulation_type)
            except ValueError:
                # Best effort by design: keep the unfiltered list, but
                # leave a trace instead of swallowing the bad value.
                logger.warning(
                    "Ignoring unknown regulation_type filter: %r",
                    regulation_type,
                )
            else:
                regulations = [
                    r for r in regulations if r.regulation_type == reg_type
                ]

        results = [
            _regulation_to_response(
                reg, len(self.req_repo.get_by_regulation(reg.id))
            )
            for reg in regulations
        ]
        return RegulationListResponse(regulations=results, total=len(results))

    def get_regulation(self, code: str) -> RegulationResponse:
        """Return one regulation with its requirement count.

        Raises:
            NotFoundError: no regulation has this code.
        """
        regulation = self._require_regulation(code)
        reqs = self.req_repo.get_by_regulation(regulation.id)
        return _regulation_to_response(regulation, len(reqs))

    def get_regulation_requirements(
        self,
        code: str,
        is_applicable: Optional[bool],
    ) -> RequirementListResponse:
        """List the requirements of one regulation.

        ``is_applicable=True`` restricts the list to applicable
        requirements; ``False`` and ``None`` both return all of them.

        Raises:
            NotFoundError: no regulation has this code.
        """
        regulation = self._require_regulation(code)

        requirements = (
            self.req_repo.get_applicable(regulation.id)
            if is_applicable
            else self.req_repo.get_by_regulation(regulation.id)
        )

        results = [_requirement_to_response(r, code) for r in requirements]
        return RequirementListResponse(
            requirements=results, total=len(results)
        )

    # ------------------------------------------------------------------
    # Requirements
    # ------------------------------------------------------------------

    def _require_requirement(self, requirement_id: str) -> Any:
        """Return the requirement row with this id or raise ``NotFoundError``."""
        requirement = (
            self.db.query(RequirementDB)
            .filter(RequirementDB.id == requirement_id)
            .first()
        )
        if not requirement:
            raise NotFoundError(
                f"Requirement {requirement_id} not found"
            )
        return requirement

    def get_requirement(
        self,
        requirement_id: str,
        include_legal_context: bool,
    ) -> dict[str, Any]:
        """Return the full detail record of one requirement as a dict.

        With ``include_legal_context`` a ``legal_context`` list is added.
        When this method runs on a thread whose event loop is already
        running that list is empty; such callers can await
        ``_fetch_legal_context_async`` themselves.

        Raises:
            NotFoundError: no requirement has this id.
        """
        requirement = self._require_requirement(requirement_id)

        regulation = (
            self.db.query(RegulationDB)
            .filter(RegulationDB.id == requirement.regulation_id)
            .first()
        )

        result: dict[str, Any] = {
            "id": requirement.id,
            "regulation_id": requirement.regulation_id,
            "regulation_code": regulation.code if regulation else None,
            "article": requirement.article,
            "paragraph": requirement.paragraph,
            "title": requirement.title,
            "description": requirement.description,
            "requirement_text": requirement.requirement_text,
            "breakpilot_interpretation": requirement.breakpilot_interpretation,
            "implementation_status": (
                requirement.implementation_status or "not_started"
            ),
            "implementation_details": requirement.implementation_details,
            "code_references": requirement.code_references,
            "documentation_links": requirement.documentation_links,
            "evidence_description": requirement.evidence_description,
            "evidence_artifacts": requirement.evidence_artifacts,
            "auditor_notes": requirement.auditor_notes,
            "audit_status": requirement.audit_status or "pending",
            "last_audit_date": requirement.last_audit_date,
            "last_auditor": requirement.last_auditor,
            "is_applicable": requirement.is_applicable,
            "applicability_reason": requirement.applicability_reason,
            "priority": requirement.priority,
            "source_page": requirement.source_page,
            "source_section": requirement.source_section,
        }

        if include_legal_context:
            result["legal_context"] = self._fetch_legal_context(
                requirement, regulation
            )

        return result

    def list_requirements_paginated(
        self,
        page: int,
        page_size: int,
        regulation_code: Optional[str],
        status: Optional[str],
        is_applicable: Optional[bool],
        search: Optional[str],
    ) -> PaginatedRequirementResponse:
        """Return one page of requirements plus pagination metadata.

        Filtering and paging are delegated to the repository's
        ``get_paginated``.

        Raises:
            ValidationError: ``page_size`` is below 1.
        """
        if page_size < 1:
            # Would otherwise surface as ZeroDivisionError below.
            raise ValidationError("page_size must be at least 1")

        requirements, total = self.req_repo.get_paginated(
            page=page,
            page_size=page_size,
            regulation_code=regulation_code,
            status=status,
            is_applicable=is_applicable,
            search=search,
        )

        # Ceiling division without floats.
        total_pages = (total + page_size - 1) // page_size

        results = [
            RequirementResponse(
                id=r.id,
                regulation_id=r.regulation_id,
                regulation_code=(
                    r.regulation.code if r.regulation else None
                ),
                article=r.article,
                paragraph=r.paragraph,
                title=r.title,
                description=r.description,
                requirement_text=r.requirement_text,
                breakpilot_interpretation=r.breakpilot_interpretation,
                is_applicable=r.is_applicable,
                applicability_reason=r.applicability_reason,
                priority=r.priority,
                implementation_status=(
                    r.implementation_status or "not_started"
                ),
                implementation_details=r.implementation_details,
                code_references=r.code_references,
                documentation_links=r.documentation_links,
                evidence_description=r.evidence_description,
                evidence_artifacts=r.evidence_artifacts,
                auditor_notes=r.auditor_notes,
                audit_status=r.audit_status or "pending",
                last_audit_date=r.last_audit_date,
                last_auditor=r.last_auditor,
                source_page=r.source_page,
                source_section=r.source_section,
                created_at=r.created_at,
                updated_at=r.updated_at,
            )
            for r in requirements
        ]

        return PaginatedRequirementResponse(
            data=results,
            pagination=PaginationMeta(
                page=page,
                page_size=page_size,
                total=total,
                total_pages=total_pages,
                has_next=page < total_pages,
                has_prev=page > 1,
            ),
        )

    def create_requirement(
        self, data: RequirementCreate
    ) -> RequirementResponse:
        """Create a requirement under an existing regulation.

        Raises:
            NotFoundError: ``data.regulation_id`` matches no regulation.
        """
        regulation = self.reg_repo.get_by_id(data.regulation_id)
        if not regulation:
            raise NotFoundError(
                f"Regulation {data.regulation_id} not found"
            )

        requirement = self.req_repo.create(
            regulation_id=data.regulation_id,
            article=data.article,
            title=data.title,
            paragraph=data.paragraph,
            description=data.description,
            requirement_text=data.requirement_text,
            breakpilot_interpretation=data.breakpilot_interpretation,
            is_applicable=data.is_applicable,
            priority=data.priority,
        )

        return _requirement_to_response(requirement, regulation.code)

    def delete_requirement(self, requirement_id: str) -> dict[str, Any]:
        """Delete one requirement.

        Raises:
            NotFoundError: the repository reports nothing was deleted.
        """
        if not self.req_repo.delete(requirement_id):
            raise NotFoundError(
                f"Requirement {requirement_id} not found"
            )
        return {"success": True, "message": "Requirement deleted"}

    def update_requirement(
        self, requirement_id: str, updates: dict[str, Any]
    ) -> dict[str, Any]:
        """Apply whitelisted fields from ``updates`` and commit.

        Keys outside ``_UPDATABLE_REQUIREMENT_FIELDS`` are ignored. When
        ``audit_status`` is part of the update, the audit date and auditor
        (``updates["auditor_name"]``, default ``"api_user"``) are stamped
        as well.

        Raises:
            NotFoundError: no requirement has this id.
        """
        requirement = self._require_requirement(requirement_id)

        for field in _UPDATABLE_REQUIREMENT_FIELDS:
            if field in updates:
                setattr(requirement, field, updates[field])

        # One timestamp so the audit date and ``updated_at`` written by the
        # same call cannot differ.
        now = datetime.now(timezone.utc)
        if "audit_status" in updates:
            requirement.last_audit_date = now
            requirement.last_auditor = updates.get(
                "auditor_name", "api_user"
            )

        requirement.updated_at = now
        self.db.commit()
        self.db.refresh(requirement)

        return {"success": True, "message": "Requirement updated"}

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _legal_context_search(requirement: Any, regulation: Any) -> Any:
        """Start the RAG search for a requirement and return its awaitable.

        The query is the requirement title plus article; the collection is
        chosen by the assistant from the regulation code (empty string when
        there is no regulation). The top 3 hits are requested.
        """
        from compliance.services.rag_client import get_rag_client
        from compliance.services.ai_compliance_assistant import (
            AIComplianceAssistant,
        )

        rag = get_rag_client()
        assistant = AIComplianceAssistant()
        query = f"{requirement.title} {requirement.article or ''}"
        collection = assistant._collection_for_regulation(
            regulation.code if regulation else ""
        )
        return rag.search(query, collection=collection, top_k=3)

    @staticmethod
    def _format_legal_context(rag_results: Any) -> list[dict[str, Any]]:
        """Convert RAG hits into plain dicts for the response payload."""
        return [
            {
                "text": r.text,
                "regulation_code": r.regulation_code,
                "regulation_short": r.regulation_short,
                "article": r.article,
                "score": r.score,
                "source_url": r.source_url,
            }
            for r in rag_results
        ]

    async def _fetch_legal_context_async(
        self, requirement: Any, regulation: Any
    ) -> list[dict[str, Any]]:
        """Fetch RAG legal context from async code.

        Awaits the search directly instead of going through the sync
        helper, which has to give up whenever a loop is running. Never
        raises: any failure is logged and yields ``[]``.
        """
        try:
            rag_results = await self._legal_context_search(
                requirement, regulation
            )
            return self._format_legal_context(rag_results)
        except Exception as e:
            logger.warning(
                "Failed to fetch legal context for %s: %s",
                requirement.id,
                e,
            )
            return []

    def _fetch_legal_context(
        self, requirement: Any, regulation: Any
    ) -> list[dict[str, Any]]:
        """Fetch RAG legal context from synchronous code.

        Returns ``[]`` when the current thread already runs an event loop,
        because blocking on the search there is not possible. Otherwise
        the async variant is driven to completion on a fresh loop. Never
        raises: any failure is logged and yields ``[]``.
        """
        import asyncio

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running on this thread, so it is safe to block.
            pass
        else:
            return []

        try:
            # NOTE(review): assumes the RAG client is not bound to another
            # event loop -- confirm against ``get_rag_client``.
            return asyncio.run(
                self._fetch_legal_context_async(requirement, regulation)
            )
        except Exception as e:
            logger.warning(
                "Failed to fetch legal context for %s: %s",
                requirement.id,
                e,
            )
            return []