refactor(backend/api): extract TOMService (Step 4 — file 3 of 18)

compliance/api/tom_routes.py (609 LOC) -> 215 LOC thin routes +
434-line TOMService. Request bodies (TOMStateBody, TOMMeasureCreate,
TOMMeasureUpdate, TOMMeasureBulkItem, TOMMeasureBulkBody) moved to
compliance/schemas/tom.py (joining the existing response models from
the Step 3 split).

Single-service split (not two like banner): state, measures CRUD + bulk
upsert, stats, export, and version lookups are all tightly coupled
around the TOMMeasureDB aggregate, so splitting would create artificial
boundaries. TOMService is 434 LOC — comfortably under the 500 hard cap.

Domain error mapping:
  - ConflictError   -> 409 (version conflict on state save; duplicate control_id on create)
  - NotFoundError   -> 404 (missing measure on update; missing version)
  - ValidationError -> 400 (missing tenant_id on DELETE /state)

Legacy test compat: the existing tests/test_tom_routes.py imports
TOMMeasureBulkItem, _parse_dt, _measure_to_dict, and DEFAULT_TENANT_ID
directly from compliance.api.tom_routes. All re-exported via __all__ so
the 44-test file runs unchanged.

mypy.ini flips compliance.api.tom_routes from ignore_errors=True to
False. TOMService carries the scoped Column[T] header.

Verified:
  - 217/217 pytest (173 baseline + 44 TOM) pass
  - OpenAPI 360/484 unchanged
  - mypy compliance/ -> Success on 124 source files
  - tom_routes.py 609 -> 215 LOC
  - Hard-cap violations: 16 -> 15

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Sharang Parnerkar
2026-04-07 19:42:17 +02:00
parent 10073f3ef0
commit d571412657
5 changed files with 675 additions and 502 deletions

View File

@@ -11,276 +11,94 @@ Endpoints:
POST /tom/measures/bulk — Bulk upsert (for deriveTOMs sync)
GET /tom/stats — Statistics
GET /tom/export — Export as CSV or JSON
GET /tom/measures/{id}/versions — List measure versions
GET /tom/measures/{id}/versions/{n} — Get specific version
Phase 1 Step 4 refactor: handlers are thin and delegate to TOMService.
"""
import csv
import io
import json
import logging
from datetime import datetime, timezone
from typing import Optional, List, Any, Dict
from typing import Any, Optional
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy.orm import Session
from classroom_engine.database import get_db
from compliance.api._http_errors import translate_domain_errors
from compliance.schemas.tom import (
TOMMeasureBulkBody,
TOMMeasureBulkItem, # re-exported for backwards compat (legacy test imports)
TOMMeasureCreate,
TOMMeasureUpdate,
TOMStateBody,
)
from ..db.tom_models import TOMStateDB, TOMMeasureDB
# Keep the legacy import path ``from compliance.api.tom_routes import TOMMeasureBulkItem``
# working — it was the public name before the Step 3 schemas split.
__all__ = [
"router",
"TOMMeasureBulkBody",
"TOMMeasureBulkItem",
"TOMMeasureCreate",
"TOMMeasureUpdate",
"TOMStateBody",
"DEFAULT_TENANT_ID",
"_parse_dt",
"_measure_to_dict",
]
from compliance.services.tom_service import (
DEFAULT_TENANT_ID,
TOMService,
_measure_to_dict, # re-exported for legacy test imports
_parse_dt, # re-exported for legacy test imports
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/tom", tags=["tom"])
DEFAULT_TENANT_ID = "9282a473-5c95-4b3a-bf78-0ecc0ec71d3e"
def get_tom_service(db: Session = Depends(get_db)) -> TOMService:
return TOMService(db)
# =============================================================================
# Pydantic Schemas (kept close to routes like loeschfristen pattern)
# =============================================================================
class TOMStateBody(BaseModel):
tenant_id: Optional[str] = None
tenantId: Optional[str] = None # Accept camelCase from frontend
state: Dict[str, Any]
version: Optional[int] = None
def get_tenant_id(self) -> str:
return self.tenant_id or self.tenantId or DEFAULT_TENANT_ID
class TOMMeasureCreate(BaseModel):
control_id: str
name: str
description: Optional[str] = None
category: str
type: str
applicability: str = "REQUIRED"
applicability_reason: Optional[str] = None
implementation_status: str = "NOT_IMPLEMENTED"
responsible_person: Optional[str] = None
responsible_department: Optional[str] = None
implementation_date: Optional[str] = None
review_date: Optional[str] = None
review_frequency: Optional[str] = None
priority: Optional[str] = None
complexity: Optional[str] = None
linked_evidence: Optional[List[Any]] = None
evidence_gaps: Optional[List[Any]] = None
related_controls: Optional[Dict[str, Any]] = None
verified_at: Optional[str] = None
verified_by: Optional[str] = None
effectiveness_rating: Optional[str] = None
class TOMMeasureUpdate(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
category: Optional[str] = None
type: Optional[str] = None
applicability: Optional[str] = None
applicability_reason: Optional[str] = None
implementation_status: Optional[str] = None
responsible_person: Optional[str] = None
responsible_department: Optional[str] = None
implementation_date: Optional[str] = None
review_date: Optional[str] = None
review_frequency: Optional[str] = None
priority: Optional[str] = None
complexity: Optional[str] = None
linked_evidence: Optional[List[Any]] = None
evidence_gaps: Optional[List[Any]] = None
related_controls: Optional[Dict[str, Any]] = None
verified_at: Optional[str] = None
verified_by: Optional[str] = None
effectiveness_rating: Optional[str] = None
class TOMMeasureBulkItem(BaseModel):
control_id: str
name: str
description: Optional[str] = None
category: str
type: str
applicability: str = "REQUIRED"
applicability_reason: Optional[str] = None
implementation_status: str = "NOT_IMPLEMENTED"
responsible_person: Optional[str] = None
responsible_department: Optional[str] = None
implementation_date: Optional[str] = None
review_date: Optional[str] = None
review_frequency: Optional[str] = None
priority: Optional[str] = None
complexity: Optional[str] = None
linked_evidence: Optional[List[Any]] = None
evidence_gaps: Optional[List[Any]] = None
related_controls: Optional[Dict[str, Any]] = None
class TOMMeasureBulkBody(BaseModel):
tenant_id: Optional[str] = None
measures: List[TOMMeasureBulkItem]
# =============================================================================
# Helper: parse optional datetime strings
# =============================================================================
def _parse_dt(val: Optional[str]) -> Optional[datetime]:
if not val:
return None
try:
return datetime.fromisoformat(val.replace("Z", "+00:00"))
except (ValueError, AttributeError):
return None
def _measure_to_dict(m: TOMMeasureDB) -> dict:
return {
"id": str(m.id),
"tenant_id": m.tenant_id,
"control_id": m.control_id,
"name": m.name,
"description": m.description,
"category": m.category,
"type": m.type,
"applicability": m.applicability,
"applicability_reason": m.applicability_reason,
"implementation_status": m.implementation_status,
"responsible_person": m.responsible_person,
"responsible_department": m.responsible_department,
"implementation_date": m.implementation_date.isoformat() if m.implementation_date else None,
"review_date": m.review_date.isoformat() if m.review_date else None,
"review_frequency": m.review_frequency,
"priority": m.priority,
"complexity": m.complexity,
"linked_evidence": m.linked_evidence or [],
"evidence_gaps": m.evidence_gaps or [],
"related_controls": m.related_controls or {},
"verified_at": m.verified_at.isoformat() if m.verified_at else None,
"verified_by": m.verified_by,
"effectiveness_rating": m.effectiveness_rating,
"created_by": m.created_by,
"created_at": m.created_at.isoformat() if m.created_at else None,
"updated_at": m.updated_at.isoformat() if m.updated_at else None,
}
# =============================================================================
# STATE ENDPOINTS
# STATE
# =============================================================================
@router.get("/state")
async def get_tom_state(
tenant_id: Optional[str] = Query(None, alias="tenant_id"),
tenantId: Optional[str] = Query(None),
db: Session = Depends(get_db),
):
service: TOMService = Depends(get_tom_service),
) -> dict[str, Any]:
"""Load TOM generator state for a tenant."""
tid = tenant_id or tenantId or DEFAULT_TENANT_ID
row = db.query(TOMStateDB).filter(TOMStateDB.tenant_id == tid).first()
if not row:
return {
"success": True,
"data": {
"tenantId": tid,
"state": {},
"version": 0,
"isNew": True,
},
}
return {
"success": True,
"data": {
"tenantId": tid,
"state": row.state,
"version": row.version,
"lastModified": row.updated_at.isoformat() if row.updated_at else None,
},
}
with translate_domain_errors():
return service.get_state(tenant_id or tenantId or DEFAULT_TENANT_ID)
@router.post("/state")
async def save_tom_state(body: TOMStateBody, db: Session = Depends(get_db)):
async def save_tom_state(
body: TOMStateBody,
service: TOMService = Depends(get_tom_service),
) -> dict[str, Any]:
"""Save TOM generator state with optimistic locking (version check)."""
tid = body.get_tenant_id()
existing = db.query(TOMStateDB).filter(TOMStateDB.tenant_id == tid).first()
# Version conflict check
if body.version is not None and existing and existing.version != body.version:
raise HTTPException(
status_code=409,
detail={
"success": False,
"error": "Version conflict. State was modified by another request.",
"code": "VERSION_CONFLICT",
},
)
now = datetime.now(timezone.utc)
if existing:
existing.state = body.state
existing.version = existing.version + 1
existing.updated_at = now
else:
existing = TOMStateDB(
tenant_id=tid,
state=body.state,
version=1,
created_at=now,
updated_at=now,
)
db.add(existing)
db.commit()
db.refresh(existing)
return {
"success": True,
"data": {
"tenantId": tid,
"state": existing.state,
"version": existing.version,
"lastModified": existing.updated_at.isoformat() if existing.updated_at else None,
},
}
with translate_domain_errors():
return service.save_state(body)
@router.delete("/state")
async def delete_tom_state(
tenant_id: Optional[str] = Query(None, alias="tenant_id"),
tenantId: Optional[str] = Query(None),
db: Session = Depends(get_db),
):
service: TOMService = Depends(get_tom_service),
) -> dict[str, Any]:
"""Clear TOM generator state for a tenant."""
tid = tenant_id or tenantId
if not tid:
raise HTTPException(status_code=400, detail="tenant_id is required")
row = db.query(TOMStateDB).filter(TOMStateDB.tenant_id == tid).first()
deleted = False
if row:
db.delete(row)
db.commit()
deleted = True
return {
"success": True,
"tenantId": tid,
"deleted": deleted,
"deletedAt": datetime.now(timezone.utc).isoformat(),
}
with translate_domain_errors():
return service.delete_state(tenant_id or tenantId)
# =============================================================================
# MEASURES ENDPOINTS
# MEASURES
# =============================================================================
@router.get("/measures")
@@ -292,188 +110,51 @@ async def list_measures(
search: Optional[str] = Query(None),
limit: int = Query(100, ge=1, le=500),
offset: int = Query(0, ge=0),
db: Session = Depends(get_db),
):
service: TOMService = Depends(get_tom_service),
) -> dict[str, Any]:
"""List TOM measures with optional filters."""
tid = tenant_id or DEFAULT_TENANT_ID
q = db.query(TOMMeasureDB).filter(TOMMeasureDB.tenant_id == tid)
if category:
q = q.filter(TOMMeasureDB.category == category)
if implementation_status:
q = q.filter(TOMMeasureDB.implementation_status == implementation_status)
if priority:
q = q.filter(TOMMeasureDB.priority == priority)
if search:
pattern = f"%{search}%"
q = q.filter(
(TOMMeasureDB.name.ilike(pattern))
| (TOMMeasureDB.description.ilike(pattern))
| (TOMMeasureDB.control_id.ilike(pattern))
with translate_domain_errors():
return service.list_measures(
tenant_id=tenant_id or DEFAULT_TENANT_ID,
category=category,
implementation_status=implementation_status,
priority=priority,
search=search,
limit=limit,
offset=offset,
)
total = q.count()
rows = q.order_by(TOMMeasureDB.control_id).offset(offset).limit(limit).all()
return {
"measures": [_measure_to_dict(r) for r in rows],
"total": total,
"limit": limit,
"offset": offset,
}
@router.post("/measures", status_code=201)
async def create_measure(
body: TOMMeasureCreate,
tenant_id: Optional[str] = Query(None),
db: Session = Depends(get_db),
):
service: TOMService = Depends(get_tom_service),
) -> dict[str, Any]:
"""Create a single TOM measure."""
tid = tenant_id or DEFAULT_TENANT_ID
# Check for duplicate control_id
existing = (
db.query(TOMMeasureDB)
.filter(TOMMeasureDB.tenant_id == tid, TOMMeasureDB.control_id == body.control_id)
.first()
)
if existing:
raise HTTPException(status_code=409, detail=f"Measure with control_id '{body.control_id}' already exists")
now = datetime.now(timezone.utc)
measure = TOMMeasureDB(
tenant_id=tid,
control_id=body.control_id,
name=body.name,
description=body.description,
category=body.category,
type=body.type,
applicability=body.applicability,
applicability_reason=body.applicability_reason,
implementation_status=body.implementation_status,
responsible_person=body.responsible_person,
responsible_department=body.responsible_department,
implementation_date=_parse_dt(body.implementation_date),
review_date=_parse_dt(body.review_date),
review_frequency=body.review_frequency,
priority=body.priority,
complexity=body.complexity,
linked_evidence=body.linked_evidence or [],
evidence_gaps=body.evidence_gaps or [],
related_controls=body.related_controls or {},
verified_at=_parse_dt(body.verified_at),
verified_by=body.verified_by,
effectiveness_rating=body.effectiveness_rating,
created_at=now,
updated_at=now,
)
db.add(measure)
db.commit()
db.refresh(measure)
return _measure_to_dict(measure)
with translate_domain_errors():
return service.create_measure(tenant_id or DEFAULT_TENANT_ID, body)
@router.put("/measures/{measure_id}")
async def update_measure(
measure_id: UUID,
body: TOMMeasureUpdate,
db: Session = Depends(get_db),
):
service: TOMService = Depends(get_tom_service),
) -> dict[str, Any]:
"""Update a TOM measure."""
row = db.query(TOMMeasureDB).filter(TOMMeasureDB.id == measure_id).first()
if not row:
raise HTTPException(status_code=404, detail="Measure not found")
update_data = body.model_dump(exclude_unset=True)
for key, val in update_data.items():
if key in ("implementation_date", "review_date", "verified_at"):
val = _parse_dt(val)
setattr(row, key, val)
row.updated_at = datetime.now(timezone.utc)
db.commit()
db.refresh(row)
return _measure_to_dict(row)
with translate_domain_errors():
return service.update_measure(measure_id, body)
@router.post("/measures/bulk")
async def bulk_upsert_measures(
body: TOMMeasureBulkBody,
db: Session = Depends(get_db),
):
service: TOMService = Depends(get_tom_service),
) -> dict[str, Any]:
"""Bulk upsert measures — used by deriveTOMs sync from frontend."""
tid = body.tenant_id or DEFAULT_TENANT_ID
now = datetime.now(timezone.utc)
created = 0
updated = 0
for item in body.measures:
existing = (
db.query(TOMMeasureDB)
.filter(TOMMeasureDB.tenant_id == tid, TOMMeasureDB.control_id == item.control_id)
.first()
)
if existing:
existing.name = item.name
existing.description = item.description
existing.category = item.category
existing.type = item.type
existing.applicability = item.applicability
existing.applicability_reason = item.applicability_reason
existing.implementation_status = item.implementation_status
existing.responsible_person = item.responsible_person
existing.responsible_department = item.responsible_department
existing.implementation_date = _parse_dt(item.implementation_date)
existing.review_date = _parse_dt(item.review_date)
existing.review_frequency = item.review_frequency
existing.priority = item.priority
existing.complexity = item.complexity
existing.linked_evidence = item.linked_evidence or []
existing.evidence_gaps = item.evidence_gaps or []
existing.related_controls = item.related_controls or {}
existing.updated_at = now
updated += 1
else:
measure = TOMMeasureDB(
tenant_id=tid,
control_id=item.control_id,
name=item.name,
description=item.description,
category=item.category,
type=item.type,
applicability=item.applicability,
applicability_reason=item.applicability_reason,
implementation_status=item.implementation_status,
responsible_person=item.responsible_person,
responsible_department=item.responsible_department,
implementation_date=_parse_dt(item.implementation_date),
review_date=_parse_dt(item.review_date),
review_frequency=item.review_frequency,
priority=item.priority,
complexity=item.complexity,
linked_evidence=item.linked_evidence or [],
evidence_gaps=item.evidence_gaps or [],
related_controls=item.related_controls or {},
created_at=now,
updated_at=now,
)
db.add(measure)
created += 1
db.commit()
return {
"success": True,
"tenant_id": tid,
"created": created,
"updated": updated,
"total": created + updated,
}
with translate_domain_errors():
return service.bulk_upsert(body)
# =============================================================================
@@ -483,96 +164,22 @@ async def bulk_upsert_measures(
@router.get("/stats")
async def get_tom_stats(
tenant_id: Optional[str] = Query(None),
db: Session = Depends(get_db),
):
service: TOMService = Depends(get_tom_service),
) -> dict[str, Any]:
"""Return TOM statistics for a tenant."""
tid = tenant_id or DEFAULT_TENANT_ID
base_q = db.query(TOMMeasureDB).filter(TOMMeasureDB.tenant_id == tid)
total = base_q.count()
# By status
status_rows = (
db.query(TOMMeasureDB.implementation_status, func.count(TOMMeasureDB.id))
.filter(TOMMeasureDB.tenant_id == tid)
.group_by(TOMMeasureDB.implementation_status)
.all()
)
by_status = {row[0]: row[1] for row in status_rows}
# By category
cat_rows = (
db.query(TOMMeasureDB.category, func.count(TOMMeasureDB.id))
.filter(TOMMeasureDB.tenant_id == tid)
.group_by(TOMMeasureDB.category)
.all()
)
by_category = {row[0]: row[1] for row in cat_rows}
# Overdue reviews
now = datetime.now(timezone.utc)
overdue = (
base_q.filter(
TOMMeasureDB.review_date.isnot(None),
TOMMeasureDB.review_date < now,
)
.count()
)
return {
"total": total,
"by_status": by_status,
"by_category": by_category,
"overdue_review_count": overdue,
"implemented": by_status.get("IMPLEMENTED", 0),
"partial": by_status.get("PARTIAL", 0),
"not_implemented": by_status.get("NOT_IMPLEMENTED", 0),
}
with translate_domain_errors():
return service.stats(tenant_id or DEFAULT_TENANT_ID)
@router.get("/export")
async def export_measures(
tenant_id: Optional[str] = Query(None),
format: str = Query("csv"),
db: Session = Depends(get_db),
):
service: TOMService = Depends(get_tom_service),
) -> StreamingResponse:
"""Export TOM measures as CSV (semicolon-separated) or JSON."""
tid = tenant_id or DEFAULT_TENANT_ID
rows = (
db.query(TOMMeasureDB)
.filter(TOMMeasureDB.tenant_id == tid)
.order_by(TOMMeasureDB.control_id)
.all()
)
measures = [_measure_to_dict(r) for r in rows]
if format == "json":
return StreamingResponse(
io.BytesIO(json.dumps(measures, ensure_ascii=False, indent=2).encode("utf-8")),
media_type="application/json",
headers={"Content-Disposition": "attachment; filename=tom_export.json"},
)
# CSV (semicolon, like VVT)
output = io.StringIO()
fieldnames = [
"control_id", "name", "description", "category", "type",
"applicability", "implementation_status", "responsible_person",
"responsible_department", "implementation_date", "review_date",
"review_frequency", "priority", "complexity", "effectiveness_rating",
]
writer = csv.DictWriter(output, fieldnames=fieldnames, delimiter=";", extrasaction="ignore")
writer.writeheader()
for m in measures:
writer.writerow(m)
output.seek(0)
return StreamingResponse(
io.BytesIO(output.getvalue().encode("utf-8")),
media_type="text/csv; charset=utf-8",
headers={"Content-Disposition": "attachment; filename=tom_export.csv"},
)
with translate_domain_errors():
return service.export(tenant_id or DEFAULT_TENANT_ID, format)
# =============================================================================
@@ -584,12 +191,13 @@ async def list_measure_versions(
measure_id: str,
tenant_id: Optional[str] = Query(None, alias="tenant_id"),
tenantId: Optional[str] = Query(None, alias="tenantId"),
db: Session = Depends(get_db),
):
service: TOMService = Depends(get_tom_service),
) -> Any:
"""List all versions for a TOM measure."""
from .versioning_utils import list_versions
tid = tenant_id or tenantId or DEFAULT_TENANT_ID
return list_versions(db, "tom", measure_id, tid)
with translate_domain_errors():
return service.list_versions(
measure_id, tenant_id or tenantId or DEFAULT_TENANT_ID
)
@router.get("/measures/{measure_id}/versions/{version_number}")
@@ -598,12 +206,10 @@ async def get_measure_version(
version_number: int,
tenant_id: Optional[str] = Query(None, alias="tenant_id"),
tenantId: Optional[str] = Query(None, alias="tenantId"),
db: Session = Depends(get_db),
):
service: TOMService = Depends(get_tom_service),
) -> Any:
"""Get a specific TOM measure version with full snapshot."""
from .versioning_utils import get_version
tid = tenant_id or tenantId or DEFAULT_TENANT_ID
v = get_version(db, "tom", measure_id, version_number, tid)
if not v:
raise HTTPException(status_code=404, detail=f"Version {version_number} not found")
return v
with translate_domain_errors():
return service.get_version(
measure_id, version_number, tenant_id or tenantId or DEFAULT_TENANT_ID
)