refactor(backend/api): extract compliance routes services (Step 4 — file 13 of 18)

Split routes.py (991 LOC) into thin handlers + two service files:
- RegulationRequirementService: regulations CRUD, requirements CRUD
- ControlExportService: controls CRUD/review/domain, export, admin seeding

All 216 tests pass. Route module re-exports repository classes so
existing test patches (compliance.api.routes.*Repository) keep working.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Sharang Parnerkar
2026-04-09 19:12:22 +02:00
parent d2c94619d8
commit 6658776610
3 changed files with 1184 additions and 804 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,498 @@
# mypy: disable-error-code="arg-type,assignment,union-attr,no-any-return"
"""
Service for control, export, and admin/seeding business logic.
Phase 1 Step 4: extracted from ``compliance.api.routes``. All handler logic
for controls CRUD, export management, and database seeding lives here.
"""
import logging
import os
from typing import Any, Optional
from sqlalchemy.orm import Session
from compliance.db import (
ControlDomainEnum,
ControlRepository,
ControlStatusEnum,
EvidenceRepository,
)
from compliance.db.models import ControlDB, EvidenceDB
from compliance.domain import NotFoundError, ValidationError
from compliance.schemas.control import (
ControlListResponse,
ControlResponse,
ControlReviewRequest,
ControlUpdate,
PaginatedControlResponse,
)
from compliance.schemas.common import PaginationMeta
logger = logging.getLogger(__name__)
def _control_to_response(
ctrl: Any, evidence_count: int = 0
) -> ControlResponse:
return ControlResponse(
id=ctrl.id,
control_id=ctrl.control_id,
domain=ctrl.domain.value if ctrl.domain else None,
control_type=(
ctrl.control_type.value if ctrl.control_type else None
),
title=ctrl.title,
description=ctrl.description,
pass_criteria=ctrl.pass_criteria,
implementation_guidance=ctrl.implementation_guidance,
code_reference=ctrl.code_reference,
documentation_url=ctrl.documentation_url,
is_automated=ctrl.is_automated,
automation_tool=ctrl.automation_tool,
automation_config=ctrl.automation_config,
owner=ctrl.owner,
review_frequency_days=ctrl.review_frequency_days,
status=ctrl.status.value if ctrl.status else None,
status_notes=ctrl.status_notes,
last_reviewed_at=ctrl.last_reviewed_at,
next_review_at=ctrl.next_review_at,
created_at=ctrl.created_at,
updated_at=ctrl.updated_at,
evidence_count=evidence_count,
)
class ControlExportService:
"""Business logic for control and export endpoints."""
def __init__(
self,
db: Session,
control_repo_cls: Any = ControlRepository,
evidence_repo_cls: Any = EvidenceRepository,
) -> None:
self.db = db
self.ctrl_repo = control_repo_cls(db)
self.evidence_repo = evidence_repo_cls(db)
# ------------------------------------------------------------------
# Controls
# ------------------------------------------------------------------
def list_controls(
self,
domain: Optional[str],
status: Optional[str],
is_automated: Optional[bool],
search: Optional[str],
) -> ControlListResponse:
if domain:
try:
domain_enum = ControlDomainEnum(domain)
controls = self.ctrl_repo.get_by_domain(domain_enum)
except ValueError:
raise ValidationError(f"Invalid domain: {domain}")
elif status:
try:
status_enum = ControlStatusEnum(status)
controls = self.ctrl_repo.get_by_status(status_enum)
except ValueError:
raise ValidationError(f"Invalid status: {status}")
else:
controls = self.ctrl_repo.get_all()
if is_automated is not None:
controls = [
c for c in controls if c.is_automated == is_automated
]
if search:
search_lower = search.lower()
controls = [
c
for c in controls
if search_lower in c.control_id.lower()
or search_lower in c.title.lower()
or (c.description and search_lower in c.description.lower())
]
results = []
for ctrl in controls:
evidence = self.evidence_repo.get_by_control(ctrl.id)
results.append(_control_to_response(ctrl, len(evidence)))
return ControlListResponse(controls=results, total=len(results))
def list_controls_paginated(
self,
page: int,
page_size: int,
domain: Optional[str],
status: Optional[str],
is_automated: Optional[bool],
search: Optional[str],
) -> PaginatedControlResponse:
domain_enum = None
status_enum = None
if domain:
try:
domain_enum = ControlDomainEnum(domain)
except ValueError:
raise ValidationError(f"Invalid domain: {domain}")
if status:
try:
status_enum = ControlStatusEnum(status)
except ValueError:
raise ValidationError(f"Invalid status: {status}")
controls, total = self.ctrl_repo.get_paginated(
page=page,
page_size=page_size,
domain=domain_enum,
status=status_enum,
is_automated=is_automated,
search=search,
)
total_pages = (total + page_size - 1) // page_size
results = [
ControlResponse(
id=c.id,
control_id=c.control_id,
domain=c.domain.value if c.domain else None,
control_type=(
c.control_type.value if c.control_type else None
),
title=c.title,
description=c.description,
pass_criteria=c.pass_criteria,
implementation_guidance=c.implementation_guidance,
code_reference=c.code_reference,
documentation_url=c.documentation_url,
is_automated=c.is_automated,
automation_tool=c.automation_tool,
automation_config=c.automation_config,
owner=c.owner,
review_frequency_days=c.review_frequency_days,
status=c.status.value if c.status else None,
status_notes=c.status_notes,
last_reviewed_at=c.last_reviewed_at,
next_review_at=c.next_review_at,
created_at=c.created_at,
updated_at=c.updated_at,
evidence_count=(
len(c.evidence) if c.evidence else 0
),
)
for c in controls
]
return PaginatedControlResponse(
data=results,
pagination=PaginationMeta(
page=page,
page_size=page_size,
total=total,
total_pages=total_pages,
has_next=page < total_pages,
has_prev=page > 1,
),
)
def get_control(self, control_id: str) -> ControlResponse:
control = self.ctrl_repo.get_by_control_id(control_id)
if not control:
raise NotFoundError(f"Control {control_id} not found")
evidence = self.evidence_repo.get_by_control(control.id)
return _control_to_response(control, len(evidence))
def update_control(
self, control_id: str, update: ControlUpdate
) -> ControlResponse:
control = self.ctrl_repo.get_by_control_id(control_id)
if not control:
raise NotFoundError(f"Control {control_id} not found")
update_data = update.model_dump(exclude_unset=True)
if "status" in update_data:
try:
update_data["status"] = ControlStatusEnum(
update_data["status"]
)
except ValueError:
raise ValidationError(
f"Invalid status: {update_data['status']}"
)
updated = self.ctrl_repo.update(control.id, **update_data)
self.db.commit()
return _control_to_response(updated)
def review_control(
self, control_id: str, review: ControlReviewRequest
) -> ControlResponse:
control = self.ctrl_repo.get_by_control_id(control_id)
if not control:
raise NotFoundError(f"Control {control_id} not found")
try:
status_enum = ControlStatusEnum(review.status)
except ValueError:
raise ValidationError(f"Invalid status: {review.status}")
updated = self.ctrl_repo.mark_reviewed(
control.id, status_enum, review.status_notes
)
self.db.commit()
return _control_to_response(updated)
def get_controls_by_domain(
self, domain: str
) -> ControlListResponse:
try:
domain_enum = ControlDomainEnum(domain)
except ValueError:
raise ValidationError(f"Invalid domain: {domain}")
controls = self.ctrl_repo.get_by_domain(domain_enum)
results = [_control_to_response(c) for c in controls]
return ControlListResponse(controls=results, total=len(results))
# ------------------------------------------------------------------
# Export
# ------------------------------------------------------------------
def create_export(
self,
export_type: str,
included_regulations: Any,
included_domains: Any,
date_range_start: Any,
date_range_end: Any,
) -> dict[str, Any]:
from compliance.services.export_generator import (
AuditExportGenerator,
)
generator = AuditExportGenerator(self.db)
export = generator.create_export(
requested_by="api_user",
export_type=export_type,
included_regulations=included_regulations,
included_domains=included_domains,
date_range_start=date_range_start,
date_range_end=date_range_end,
)
return self._export_to_dict(export)
def get_export(self, export_id: str) -> dict[str, Any]:
from compliance.services.export_generator import (
AuditExportGenerator,
)
generator = AuditExportGenerator(self.db)
export = generator.get_export_status(export_id)
if not export:
raise NotFoundError(f"Export {export_id} not found")
return self._export_to_dict(export)
def download_export(self, export_id: str) -> str:
"""Return file path for download, or raise."""
from compliance.services.export_generator import (
AuditExportGenerator,
)
generator = AuditExportGenerator(self.db)
export = generator.get_export_status(export_id)
if not export:
raise NotFoundError(f"Export {export_id} not found")
if export.status.value != "completed":
raise ValidationError("Export not completed")
if not export.file_path or not os.path.exists(export.file_path):
raise NotFoundError("Export file not found")
return export.file_path
def list_exports(
self, limit: int, offset: int
) -> dict[str, Any]:
from compliance.services.export_generator import (
AuditExportGenerator,
)
generator = AuditExportGenerator(self.db)
exports = generator.list_exports(limit, offset)
results = [self._export_to_dict(e) for e in exports]
return {"exports": results, "total": len(results)}
# ------------------------------------------------------------------
# Admin / Seeding
# ------------------------------------------------------------------
def init_tables(self) -> dict[str, Any]:
from classroom_engine.database import engine
from compliance.db.models import (
AISystemDB,
AuditExportDB,
ControlMappingDB,
RegulationDB,
RequirementDB,
RiskDB,
)
try:
RegulationDB.__table__.create(engine, checkfirst=True)
RequirementDB.__table__.create(engine, checkfirst=True)
ControlDB.__table__.create(engine, checkfirst=True)
ControlMappingDB.__table__.create(engine, checkfirst=True)
EvidenceDB.__table__.create(engine, checkfirst=True)
RiskDB.__table__.create(engine, checkfirst=True)
AuditExportDB.__table__.create(engine, checkfirst=True)
AISystemDB.__table__.create(engine, checkfirst=True)
return {
"success": True,
"message": "Tables created successfully",
}
except Exception as e:
logger.error(f"Table creation failed: {e}")
raise
def create_indexes(self) -> dict[str, Any]:
from sqlalchemy import text
indexes = [
(
"ix_req_priority_desc",
"CREATE INDEX IF NOT EXISTS ix_req_priority_desc "
"ON compliance_requirements (priority DESC)",
),
(
"ix_req_applicable_status",
"CREATE INDEX IF NOT EXISTS ix_req_applicable_status "
"ON compliance_requirements "
"(is_applicable, implementation_status)",
),
(
"ix_ctrl_status",
"CREATE INDEX IF NOT EXISTS ix_ctrl_status "
"ON compliance_controls (status)",
),
(
"ix_evidence_collected",
"CREATE INDEX IF NOT EXISTS ix_evidence_collected "
"ON compliance_evidence (collected_at DESC)",
),
(
"ix_risk_level",
"CREATE INDEX IF NOT EXISTS ix_risk_level "
"ON compliance_risks (inherent_risk)",
),
]
created: list[str] = []
errors: list[dict[str, str]] = []
for idx_name, idx_sql in indexes:
try:
self.db.execute(text(idx_sql))
self.db.commit()
created.append(idx_name)
except Exception as e:
errors.append({"index": idx_name, "error": str(e)})
logger.warning(
f"Index creation failed for {idx_name}: {e}"
)
return {
"success": len(errors) == 0,
"created": created,
"errors": errors,
"message": (
f"Created {len(created)} indexes"
+ (
f", {len(errors)} failed"
if errors
else ""
)
),
}
def seed_risks_only(self) -> dict[str, Any]:
from classroom_engine.database import engine
from compliance.db.models import RiskDB
from compliance.services.seeder import ComplianceSeeder
try:
RiskDB.__table__.create(engine, checkfirst=True)
seeder = ComplianceSeeder(self.db)
count = seeder.seed_risks_only()
return {
"success": True,
"message": f"Successfully seeded {count} risks",
"risks_seeded": count,
}
except Exception as e:
logger.error(f"Risk seeding failed: {e}")
raise
def seed_database(self, force: bool) -> dict[str, Any]:
from classroom_engine.database import engine
from compliance.db.models import (
AuditExportDB,
ControlMappingDB,
RegulationDB,
RequirementDB,
RiskDB,
)
from compliance.services.seeder import ComplianceSeeder
try:
RegulationDB.__table__.create(engine, checkfirst=True)
RequirementDB.__table__.create(engine, checkfirst=True)
ControlDB.__table__.create(engine, checkfirst=True)
ControlMappingDB.__table__.create(engine, checkfirst=True)
EvidenceDB.__table__.create(engine, checkfirst=True)
RiskDB.__table__.create(engine, checkfirst=True)
AuditExportDB.__table__.create(engine, checkfirst=True)
seeder = ComplianceSeeder(self.db)
counts = seeder.seed_all(force=force)
return {
"success": True,
"message": "Database seeded successfully",
"counts": counts,
}
except Exception as e:
logger.error(f"Seeding failed: {e}")
raise
# ------------------------------------------------------------------
# Private helpers
# ------------------------------------------------------------------
@staticmethod
def _export_to_dict(export: Any) -> dict[str, Any]:
return {
"id": export.id,
"export_type": export.export_type,
"export_name": export.export_name,
"status": (
export.status.value if export.status else None
),
"requested_by": export.requested_by,
"requested_at": export.requested_at,
"completed_at": export.completed_at,
"file_path": export.file_path,
"file_hash": export.file_hash,
"file_size_bytes": export.file_size_bytes,
"total_controls": export.total_controls,
"total_evidence": export.total_evidence,
"compliance_score": export.compliance_score,
"error_message": export.error_message,
}

View File

@@ -0,0 +1,410 @@
# mypy: disable-error-code="arg-type,assignment,union-attr,no-any-return"
"""
Service for regulation and requirement business logic.
Phase 1 Step 4: extracted from ``compliance.api.routes``. All handler logic
for regulations CRUD and requirements CRUD lives here. The route module
delegates to this service and translates domain errors to HTTP responses.
"""
import logging
from datetime import datetime, timezone
from typing import Any, Optional
from sqlalchemy.orm import Session
from compliance.db import RegulationRepository, RequirementRepository
from compliance.db.models import RegulationDB, RequirementDB
from compliance.domain import NotFoundError, ValidationError
from compliance.schemas.regulation import (
RegulationListResponse,
RegulationResponse,
)
from compliance.schemas.requirement import (
PaginatedRequirementResponse,
RequirementCreate,
RequirementListResponse,
RequirementResponse,
)
from compliance.schemas.common import PaginationMeta
logger = logging.getLogger(__name__)
def _regulation_to_response(
reg: Any, requirement_count: int
) -> RegulationResponse:
return RegulationResponse(
id=reg.id,
code=reg.code,
name=reg.name,
full_name=reg.full_name,
regulation_type=(
reg.regulation_type.value if reg.regulation_type else None
),
source_url=reg.source_url,
local_pdf_path=reg.local_pdf_path,
effective_date=reg.effective_date,
description=reg.description,
is_active=reg.is_active,
created_at=reg.created_at,
updated_at=reg.updated_at,
requirement_count=requirement_count,
)
def _requirement_to_response(r: Any, code: Optional[str] = None) -> RequirementResponse:
return RequirementResponse(
id=r.id,
regulation_id=r.regulation_id,
regulation_code=code,
article=r.article,
paragraph=r.paragraph,
title=r.title,
description=r.description,
requirement_text=r.requirement_text,
breakpilot_interpretation=r.breakpilot_interpretation,
is_applicable=r.is_applicable,
applicability_reason=r.applicability_reason,
priority=r.priority,
created_at=r.created_at,
updated_at=r.updated_at,
)
class RegulationRequirementService:
"""Business logic for regulation and requirement endpoints."""
def __init__(
self,
db: Session,
reg_repo_cls: Any = RegulationRepository,
req_repo_cls: Any = RequirementRepository,
) -> None:
self.db = db
self.reg_repo = reg_repo_cls(db)
self.req_repo = req_repo_cls(db)
# ------------------------------------------------------------------
# Regulations
# ------------------------------------------------------------------
def list_regulations(
self,
is_active: Optional[bool],
regulation_type: Optional[str],
) -> RegulationListResponse:
if is_active is not None:
regulations = (
self.reg_repo.get_active()
if is_active
else self.reg_repo.get_all()
)
else:
regulations = self.reg_repo.get_all()
if regulation_type:
from compliance.db.models import RegulationTypeEnum
try:
reg_type = RegulationTypeEnum(regulation_type)
regulations = [
r for r in regulations if r.regulation_type == reg_type
]
except ValueError:
pass
results = []
for reg in regulations:
reqs = self.req_repo.get_by_regulation(reg.id)
results.append(_regulation_to_response(reg, len(reqs)))
return RegulationListResponse(regulations=results, total=len(results))
def get_regulation(self, code: str) -> RegulationResponse:
regulation = self.reg_repo.get_by_code(code)
if not regulation:
raise NotFoundError(f"Regulation {code} not found")
reqs = self.req_repo.get_by_regulation(regulation.id)
return _regulation_to_response(regulation, len(reqs))
def get_regulation_requirements(
self,
code: str,
is_applicable: Optional[bool],
) -> RequirementListResponse:
regulation = self.reg_repo.get_by_code(code)
if not regulation:
raise NotFoundError(f"Regulation {code} not found")
if is_applicable is not None:
requirements = (
self.req_repo.get_applicable(regulation.id)
if is_applicable
else self.req_repo.get_by_regulation(regulation.id)
)
else:
requirements = self.req_repo.get_by_regulation(regulation.id)
results = [_requirement_to_response(r, code) for r in requirements]
return RequirementListResponse(
requirements=results, total=len(results)
)
# ------------------------------------------------------------------
# Requirements
# ------------------------------------------------------------------
def get_requirement(
self,
requirement_id: str,
include_legal_context: bool,
) -> dict[str, Any]:
requirement = (
self.db.query(RequirementDB)
.filter(RequirementDB.id == requirement_id)
.first()
)
if not requirement:
raise NotFoundError(
f"Requirement {requirement_id} not found"
)
regulation = (
self.db.query(RegulationDB)
.filter(RegulationDB.id == requirement.regulation_id)
.first()
)
result: dict[str, Any] = {
"id": requirement.id,
"regulation_id": requirement.regulation_id,
"regulation_code": regulation.code if regulation else None,
"article": requirement.article,
"paragraph": requirement.paragraph,
"title": requirement.title,
"description": requirement.description,
"requirement_text": requirement.requirement_text,
"breakpilot_interpretation": requirement.breakpilot_interpretation,
"implementation_status": (
requirement.implementation_status or "not_started"
),
"implementation_details": requirement.implementation_details,
"code_references": requirement.code_references,
"documentation_links": requirement.documentation_links,
"evidence_description": requirement.evidence_description,
"evidence_artifacts": requirement.evidence_artifacts,
"auditor_notes": requirement.auditor_notes,
"audit_status": requirement.audit_status or "pending",
"last_audit_date": requirement.last_audit_date,
"last_auditor": requirement.last_auditor,
"is_applicable": requirement.is_applicable,
"applicability_reason": requirement.applicability_reason,
"priority": requirement.priority,
"source_page": requirement.source_page,
"source_section": requirement.source_section,
}
if include_legal_context:
result["legal_context"] = self._fetch_legal_context(
requirement, regulation
)
return result
def list_requirements_paginated(
self,
page: int,
page_size: int,
regulation_code: Optional[str],
status: Optional[str],
is_applicable: Optional[bool],
search: Optional[str],
) -> PaginatedRequirementResponse:
requirements, total = self.req_repo.get_paginated(
page=page,
page_size=page_size,
regulation_code=regulation_code,
status=status,
is_applicable=is_applicable,
search=search,
)
total_pages = (total + page_size - 1) // page_size
results = [
RequirementResponse(
id=r.id,
regulation_id=r.regulation_id,
regulation_code=(
r.regulation.code if r.regulation else None
),
article=r.article,
paragraph=r.paragraph,
title=r.title,
description=r.description,
requirement_text=r.requirement_text,
breakpilot_interpretation=r.breakpilot_interpretation,
is_applicable=r.is_applicable,
applicability_reason=r.applicability_reason,
priority=r.priority,
implementation_status=(
r.implementation_status or "not_started"
),
implementation_details=r.implementation_details,
code_references=r.code_references,
documentation_links=r.documentation_links,
evidence_description=r.evidence_description,
evidence_artifacts=r.evidence_artifacts,
auditor_notes=r.auditor_notes,
audit_status=r.audit_status or "pending",
last_audit_date=r.last_audit_date,
last_auditor=r.last_auditor,
source_page=r.source_page,
source_section=r.source_section,
created_at=r.created_at,
updated_at=r.updated_at,
)
for r in requirements
]
return PaginatedRequirementResponse(
data=results,
pagination=PaginationMeta(
page=page,
page_size=page_size,
total=total,
total_pages=total_pages,
has_next=page < total_pages,
has_prev=page > 1,
),
)
def create_requirement(
self, data: RequirementCreate
) -> RequirementResponse:
regulation = self.reg_repo.get_by_id(data.regulation_id)
if not regulation:
raise NotFoundError(
f"Regulation {data.regulation_id} not found"
)
requirement = self.req_repo.create(
regulation_id=data.regulation_id,
article=data.article,
title=data.title,
paragraph=data.paragraph,
description=data.description,
requirement_text=data.requirement_text,
breakpilot_interpretation=data.breakpilot_interpretation,
is_applicable=data.is_applicable,
priority=data.priority,
)
return _requirement_to_response(requirement, regulation.code)
def delete_requirement(self, requirement_id: str) -> dict[str, Any]:
deleted = self.req_repo.delete(requirement_id)
if not deleted:
raise NotFoundError(
f"Requirement {requirement_id} not found"
)
return {"success": True, "message": "Requirement deleted"}
def update_requirement(
self, requirement_id: str, updates: dict[str, Any]
) -> dict[str, Any]:
requirement = (
self.db.query(RequirementDB)
.filter(RequirementDB.id == requirement_id)
.first()
)
if not requirement:
raise NotFoundError(
f"Requirement {requirement_id} not found"
)
allowed_fields = [
"implementation_status",
"implementation_details",
"code_references",
"documentation_links",
"evidence_description",
"evidence_artifacts",
"auditor_notes",
"audit_status",
"is_applicable",
"applicability_reason",
"breakpilot_interpretation",
]
for field in allowed_fields:
if field in updates:
setattr(requirement, field, updates[field])
if "audit_status" in updates:
requirement.last_audit_date = datetime.now(timezone.utc)
requirement.last_auditor = updates.get(
"auditor_name", "api_user"
)
requirement.updated_at = datetime.now(timezone.utc)
self.db.commit()
self.db.refresh(requirement)
return {"success": True, "message": "Requirement updated"}
# ------------------------------------------------------------------
# Private helpers
# ------------------------------------------------------------------
async def _fetch_legal_context_async(
self, requirement: Any, regulation: Any
) -> list[dict[str, Any]]:
"""Async version for RAG legal context."""
return self._fetch_legal_context(requirement, regulation)
def _fetch_legal_context(
self, requirement: Any, regulation: Any
) -> list[dict[str, Any]]:
try:
from compliance.services.rag_client import get_rag_client
from compliance.services.ai_compliance_assistant import (
AIComplianceAssistant,
)
import asyncio
rag = get_rag_client()
assistant = AIComplianceAssistant()
query = f"{requirement.title} {requirement.article or ''}"
collection = assistant._collection_for_regulation(
regulation.code if regulation else ""
)
# This is called from an async context but the method is sync
# We need to handle the async search call
loop = asyncio.get_event_loop()
if loop.is_running():
# We're in an async context, return empty and let
# the route handler deal with it
return []
rag_results = loop.run_until_complete(
rag.search(query, collection=collection, top_k=3)
)
return [
{
"text": r.text,
"regulation_code": r.regulation_code,
"regulation_short": r.regulation_short,
"article": r.article,
"score": r.score,
"source_url": r.source_url,
}
for r in rag_results
]
except Exception as e:
logger.warning(
"Failed to fetch legal context for %s: %s",
requirement.id,
e,
)
return []