# mypy: disable-error-code="arg-type,assignment,union-attr,no-any-return"
"""
Service for control, export, and admin/seeding business logic.

Phase 1 Step 4: extracted from ``compliance.api.routes``. All handler
logic for controls CRUD, export management, and database seeding lives
here.
"""

import logging
import os
from typing import Any, Optional

from sqlalchemy.orm import Session

from compliance.db import (
    ControlDomainEnum,
    ControlRepository,
    ControlStatusEnum,
    EvidenceRepository,
)
from compliance.db.models import ControlDB, EvidenceDB
from compliance.domain import NotFoundError, ValidationError
from compliance.schemas.control import (
    ControlListResponse,
    ControlResponse,
    ControlReviewRequest,
    ControlUpdate,
    PaginatedControlResponse,
)
from compliance.schemas.common import PaginationMeta

logger = logging.getLogger(__name__)

# (index name, DDL) pairs applied by ``ControlExportService.create_indexes``.
# ``IF NOT EXISTS`` makes re-running the admin endpoint harmless.
_INDEX_DEFINITIONS: tuple[tuple[str, str], ...] = (
    (
        "ix_req_priority_desc",
        "CREATE INDEX IF NOT EXISTS ix_req_priority_desc "
        "ON compliance_requirements (priority DESC)",
    ),
    (
        "ix_req_applicable_status",
        "CREATE INDEX IF NOT EXISTS ix_req_applicable_status "
        "ON compliance_requirements "
        "(is_applicable, implementation_status)",
    ),
    (
        "ix_ctrl_status",
        "CREATE INDEX IF NOT EXISTS ix_ctrl_status "
        "ON compliance_controls (status)",
    ),
    (
        "ix_evidence_collected",
        "CREATE INDEX IF NOT EXISTS ix_evidence_collected "
        "ON compliance_evidence (collected_at DESC)",
    ),
    (
        "ix_risk_level",
        "CREATE INDEX IF NOT EXISTS ix_risk_level "
        "ON compliance_risks (inherent_risk)",
    ),
)


def _control_to_response(
    ctrl: Any, evidence_count: int = 0
) -> ControlResponse:
    """Map a control ORM object onto the ``ControlResponse`` schema.

    Enum-valued attributes (``domain``, ``control_type``, ``status``) are
    emitted as their ``.value`` or ``None`` when unset.

    Args:
        ctrl: A control row (e.g. ``ControlDB``) exposing the attributes
            read below.
        evidence_count: Value reported as ``evidence_count``; the caller
            is responsible for computing it.
    """
    return ControlResponse(
        id=ctrl.id,
        control_id=ctrl.control_id,
        domain=ctrl.domain.value if ctrl.domain else None,
        control_type=(
            ctrl.control_type.value if ctrl.control_type else None
        ),
        title=ctrl.title,
        description=ctrl.description,
        pass_criteria=ctrl.pass_criteria,
        implementation_guidance=ctrl.implementation_guidance,
        code_reference=ctrl.code_reference,
        documentation_url=ctrl.documentation_url,
        is_automated=ctrl.is_automated,
        automation_tool=ctrl.automation_tool,
        automation_config=ctrl.automation_config,
        owner=ctrl.owner,
        review_frequency_days=ctrl.review_frequency_days,
        status=ctrl.status.value if ctrl.status else None,
        status_notes=ctrl.status_notes,
        last_reviewed_at=ctrl.last_reviewed_at,
        next_review_at=ctrl.next_review_at,
        created_at=ctrl.created_at,
        updated_at=ctrl.updated_at,
        evidence_count=evidence_count,
    )


def _parse_domain(domain: Any) -> ControlDomainEnum:
    """Convert *domain* to ``ControlDomainEnum``.

    Raises:
        ValidationError: If *domain* is not a valid enum value.
    """
    try:
        return ControlDomainEnum(domain)
    except ValueError:
        # The message already carries the offending value; the enum's
        # ValueError traceback adds nothing for API callers.
        raise ValidationError(f"Invalid domain: {domain}") from None


def _parse_status(status: Any) -> ControlStatusEnum:
    """Convert *status* to ``ControlStatusEnum``.

    Raises:
        ValidationError: If *status* is not a valid enum value.
    """
    try:
        return ControlStatusEnum(status)
    except ValueError:
        raise ValidationError(f"Invalid status: {status}") from None


class ControlExportService:
    """Business logic for control and export endpoints."""

    def __init__(
        self,
        db: Session,
        control_repo_cls: Any = ControlRepository,
        evidence_repo_cls: Any = EvidenceRepository,
    ) -> None:
        """Bind repositories to *db*.

        The repository classes are injectable so callers (e.g. tests) can
        substitute alternatives; each is instantiated with *db*.
        """
        self.db = db
        self.ctrl_repo = control_repo_cls(db)
        self.evidence_repo = evidence_repo_cls(db)

    # ------------------------------------------------------------------
    # Controls
    # ------------------------------------------------------------------

    def list_controls(
        self,
        domain: Optional[str],
        status: Optional[str],
        is_automated: Optional[bool],
        search: Optional[str],
    ) -> ControlListResponse:
        """List controls, optionally filtered, with per-control evidence counts.

        Filtering precedence: ``domain`` (repository query) wins over
        ``status``; ``is_automated`` and the case-insensitive ``search``
        (matched against ``control_id``, ``title`` and ``description``)
        are then applied in memory.

        Raises:
            ValidationError: If ``domain`` / ``status`` is not a valid enum
                value.
        """
        # NOTE(review): when both ``domain`` and ``status`` are given, the
        # status filter is ignored (not even validated). The paginated
        # variant applies both — confirm whether this is intended.
        if domain:
            controls = self.ctrl_repo.get_by_domain(_parse_domain(domain))
        elif status:
            controls = self.ctrl_repo.get_by_status(_parse_status(status))
        else:
            controls = self.ctrl_repo.get_all()

        if is_automated is not None:
            controls = [
                c for c in controls if c.is_automated == is_automated
            ]

        if search:
            search_lower = search.lower()
            controls = [
                c
                for c in controls
                if search_lower in c.control_id.lower()
                or search_lower in c.title.lower()
                or (
                    c.description
                    and search_lower in c.description.lower()
                )
            ]

        # One evidence query per control (N+1); acceptable for the
        # unpaginated listing but worth revisiting for large catalogues.
        results = [
            _control_to_response(
                ctrl, len(self.evidence_repo.get_by_control(ctrl.id))
            )
            for ctrl in controls
        ]
        return ControlListResponse(controls=results, total=len(results))

    def list_controls_paginated(
        self,
        page: int,
        page_size: int,
        domain: Optional[str],
        status: Optional[str],
        is_automated: Optional[bool],
        search: Optional[str],
    ) -> PaginatedControlResponse:
        """Return one page of controls plus pagination metadata.

        All filters are delegated to ``ctrl_repo.get_paginated``. The
        evidence count comes from each control's ``evidence`` attribute
        instead of a separate query.

        Raises:
            ValidationError: If ``page_size`` is < 1 or ``domain`` /
                ``status`` is not a valid enum value.
        """
        if page_size < 1:
            # Guards the ceiling division below against ZeroDivisionError.
            raise ValidationError(f"Invalid page_size: {page_size}")

        domain_enum = _parse_domain(domain) if domain else None
        status_enum = _parse_status(status) if status else None

        controls, total = self.ctrl_repo.get_paginated(
            page=page,
            page_size=page_size,
            domain=domain_enum,
            status=status_enum,
            is_automated=is_automated,
            search=search,
        )
        # Ceiling division: a partial last page still counts as a page.
        total_pages = (total + page_size - 1) // page_size

        results = [
            _control_to_response(
                c, len(c.evidence) if c.evidence else 0
            )
            for c in controls
        ]

        return PaginatedControlResponse(
            data=results,
            pagination=PaginationMeta(
                page=page,
                page_size=page_size,
                total=total,
                total_pages=total_pages,
                has_next=page < total_pages,
                has_prev=page > 1,
            ),
        )

    def get_control(self, control_id: str) -> ControlResponse:
        """Return a single control by its business ``control_id``.

        Raises:
            NotFoundError: If no control matches.
        """
        control = self.ctrl_repo.get_by_control_id(control_id)
        if not control:
            raise NotFoundError(f"Control {control_id} not found")

        evidence = self.evidence_repo.get_by_control(control.id)
        return _control_to_response(control, len(evidence))

    def update_control(
        self, control_id: str, update: ControlUpdate
    ) -> ControlResponse:
        """Apply the explicitly-set fields of *update* and commit.

        Raises:
            NotFoundError: If no control matches.
            ValidationError: If ``status`` is set to an invalid value.
        """
        control = self.ctrl_repo.get_by_control_id(control_id)
        if not control:
            raise NotFoundError(f"Control {control_id} not found")

        # exclude_unset: only fields the client actually sent are updated.
        update_data = update.model_dump(exclude_unset=True)
        if "status" in update_data:
            update_data["status"] = _parse_status(update_data["status"])

        updated = self.ctrl_repo.update(control.id, **update_data)
        self.db.commit()
        # NOTE(review): evidence_count is reported as 0 here, unlike
        # get_control — confirm whether clients rely on this.
        return _control_to_response(updated)

    def review_control(
        self, control_id: str, review: ControlReviewRequest
    ) -> ControlResponse:
        """Record a review (new status + notes) for a control and commit.

        Raises:
            NotFoundError: If no control matches.
            ValidationError: If ``review.status`` is invalid.
        """
        control = self.ctrl_repo.get_by_control_id(control_id)
        if not control:
            raise NotFoundError(f"Control {control_id} not found")

        status_enum = _parse_status(review.status)
        updated = self.ctrl_repo.mark_reviewed(
            control.id, status_enum, review.status_notes
        )
        self.db.commit()
        # NOTE(review): evidence_count is reported as 0 (see update_control).
        return _control_to_response(updated)

    def get_controls_by_domain(
        self, domain: str
    ) -> ControlListResponse:
        """List all controls in *domain* (evidence counts reported as 0).

        Raises:
            ValidationError: If *domain* is not a valid enum value.
        """
        controls = self.ctrl_repo.get_by_domain(_parse_domain(domain))
        results = [_control_to_response(c) for c in controls]
        return ControlListResponse(controls=results, total=len(results))

    # ------------------------------------------------------------------
    # Export
    # ------------------------------------------------------------------

    def _export_generator(self) -> Any:
        """Build an ``AuditExportGenerator`` bound to this session.

        The import is deferred to call time, as in the original handlers,
        so loading this module does not import the export generator.
        """
        from compliance.services.export_generator import (
            AuditExportGenerator,
        )

        return AuditExportGenerator(self.db)

    def create_export(
        self,
        export_type: str,
        included_regulations: Any,
        included_domains: Any,
        date_range_start: Any,
        date_range_end: Any,
    ) -> dict[str, Any]:
        """Create an audit export and return its serialized record."""
        export = self._export_generator().create_export(
            requested_by="api_user",
            export_type=export_type,
            included_regulations=included_regulations,
            included_domains=included_domains,
            date_range_start=date_range_start,
            date_range_end=date_range_end,
        )
        return self._export_to_dict(export)

    def get_export(self, export_id: str) -> dict[str, Any]:
        """Return the serialized status record of an export.

        Raises:
            NotFoundError: If the export does not exist.
        """
        export = self._export_generator().get_export_status(export_id)
        if not export:
            raise NotFoundError(f"Export {export_id} not found")
        return self._export_to_dict(export)

    def download_export(self, export_id: str) -> str:
        """Return file path for download, or raise.

        Raises:
            NotFoundError: If the export or its file does not exist.
            ValidationError: If the export has not completed.
        """
        export = self._export_generator().get_export_status(export_id)
        if not export:
            raise NotFoundError(f"Export {export_id} not found")

        # Tolerate a missing status (as _export_to_dict does) instead of
        # crashing with AttributeError.
        status_value = export.status.value if export.status else None
        if status_value != "completed":
            raise ValidationError("Export not completed")

        if not export.file_path or not os.path.exists(export.file_path):
            raise NotFoundError("Export file not found")

        return export.file_path

    def list_exports(
        self, limit: int, offset: int
    ) -> dict[str, Any]:
        """Return a window of exports as ``{"exports": [...], "total": n}``."""
        exports = self._export_generator().list_exports(limit, offset)
        results = [self._export_to_dict(e) for e in exports]
        # NOTE(review): "total" is the size of this page, not the overall
        # number of exports — confirm this is what clients expect.
        return {"exports": results, "total": len(results)}

    # ------------------------------------------------------------------
    # Admin / Seeding
    # ------------------------------------------------------------------

    @staticmethod
    def _create_tables(engine: Any, models: Any) -> None:
        """Create each model's table on *engine*, skipping existing ones."""
        for model in models:
            model.__table__.create(engine, checkfirst=True)

    def init_tables(self) -> dict[str, Any]:
        """Create all compliance tables (if missing).

        Errors are logged with traceback and re-raised.
        """
        from classroom_engine.database import engine
        from compliance.db.models import (
            AISystemDB,
            AuditExportDB,
            ControlMappingDB,
            RegulationDB,
            RequirementDB,
            RiskDB,
        )

        try:
            # Order matters: referenced tables come before dependents.
            self._create_tables(
                engine,
                (
                    RegulationDB,
                    RequirementDB,
                    ControlDB,
                    ControlMappingDB,
                    EvidenceDB,
                    RiskDB,
                    AuditExportDB,
                    AISystemDB,
                ),
            )
            return {
                "success": True,
                "message": "Tables created successfully",
            }
        except Exception as e:
            logger.exception("Table creation failed: %s", e)
            raise

    def create_indexes(self) -> dict[str, Any]:
        """Create the performance indexes, each in its own transaction.

        A failing index is recorded in ``errors`` and does not stop the
        remaining ones.
        """
        from sqlalchemy import text

        created: list[str] = []
        errors: list[dict[str, str]] = []

        for idx_name, idx_sql in _INDEX_DEFINITIONS:
            try:
                self.db.execute(text(idx_sql))
                self.db.commit()
                created.append(idx_name)
            except Exception as e:
                # Without a rollback the session's transaction stays
                # aborted (PostgreSQL rejects every further command), so
                # all subsequent indexes would fail as well.
                self.db.rollback()
                errors.append({"index": idx_name, "error": str(e)})
                logger.warning(
                    "Index creation failed for %s: %s", idx_name, e
                )

        message = f"Created {len(created)} indexes"
        if errors:
            message += f", {len(errors)} failed"

        return {
            "success": len(errors) == 0,
            "created": created,
            "errors": errors,
            "message": message,
        }

    def seed_risks_only(self) -> dict[str, Any]:
        """Ensure the risks table exists and seed only risk data.

        Errors are logged with traceback and re-raised.
        """
        from classroom_engine.database import engine
        from compliance.db.models import RiskDB
        from compliance.services.seeder import ComplianceSeeder

        try:
            self._create_tables(engine, (RiskDB,))
            count = ComplianceSeeder(self.db).seed_risks_only()
            return {
                "success": True,
                "message": f"Successfully seeded {count} risks",
                "risks_seeded": count,
            }
        except Exception as e:
            logger.exception("Risk seeding failed: %s", e)
            raise

    def seed_database(self, force: bool) -> dict[str, Any]:
        """Ensure tables exist, then run the full seeder.

        Args:
            force: Passed through to ``ComplianceSeeder.seed_all``.

        Errors are logged with traceback and re-raised.
        """
        from classroom_engine.database import engine
        from compliance.db.models import (
            AuditExportDB,
            ControlMappingDB,
            RegulationDB,
            RequirementDB,
            RiskDB,
        )
        from compliance.services.seeder import ComplianceSeeder

        try:
            self._create_tables(
                engine,
                (
                    RegulationDB,
                    RequirementDB,
                    ControlDB,
                    ControlMappingDB,
                    EvidenceDB,
                    RiskDB,
                    AuditExportDB,
                ),
            )
            counts = ComplianceSeeder(self.db).seed_all(force=force)
            return {
                "success": True,
                "message": "Database seeded successfully",
                "counts": counts,
            }
        except Exception as e:
            logger.exception("Seeding failed: %s", e)
            raise

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _export_to_dict(export: Any) -> dict[str, Any]:
        """Serialize an export record into a plain dict for API responses."""
        return {
            "id": export.id,
            "export_type": export.export_type,
            "export_name": export.export_name,
            "status": (
                export.status.value if export.status else None
            ),
            "requested_by": export.requested_by,
            "requested_at": export.requested_at,
            "completed_at": export.completed_at,
            "file_path": export.file_path,
            "file_hash": export.file_hash,
            "file_size_bytes": export.file_size_bytes,
            "total_controls": export.total_controls,
            "total_evidence": export.total_evidence,
            "compliance_score": export.compliance_score,
            "error_message": export.error_message,
        }