Files
breakpilot-compliance/backend-compliance/compliance/services/tom_service.py
Sharang Parnerkar d571412657 refactor(backend/api): extract TOMService (Step 4 — file 3 of 18)
compliance/api/tom_routes.py (609 LOC) -> 215 LOC thin routes +
434-line TOMService. Request bodies (TOMStateBody, TOMMeasureCreate,
TOMMeasureUpdate, TOMMeasureBulkItem, TOMMeasureBulkBody) moved to
compliance/schemas/tom.py (joining the existing response models from
the Step 3 split).

Single-service split (not two like banner): state, measures CRUD + bulk
upsert, stats, export, and version lookups are all tightly coupled
around the TOMMeasureDB aggregate, so splitting would create artificial
boundaries. TOMService is 434 LOC — comfortably under the 500 hard cap.

Domain error mapping:
  - ConflictError   -> 409 (version conflict on state save; duplicate control_id on create)
  - NotFoundError   -> 404 (missing measure on update; missing version)
  - ValidationError -> 400 (missing tenant_id on DELETE /state)

Legacy test compat: the existing tests/test_tom_routes.py imports
TOMMeasureBulkItem, _parse_dt, _measure_to_dict, and DEFAULT_TENANT_ID
directly from compliance.api.tom_routes. All re-exported via __all__ so
the 44-test file runs unchanged.

mypy.ini flips compliance.api.tom_routes from ignore_errors=True to
False. TOMService carries the scoped Column[T] header.

Verified:
  - 217/217 pytest (173 baseline + 44 TOM) pass
  - OpenAPI 360/484 unchanged
  - mypy compliance/ -> Success on 124 source files
  - tom_routes.py 609 -> 215 LOC
  - Hard-cap violations: 16 -> 15

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-07 19:42:17 +02:00

435 lines
16 KiB
Python

# mypy: disable-error-code="arg-type,assignment"
# SQLAlchemy 1.x Column() descriptors are Column[T] statically, T at runtime.
"""
TOM service — Technisch-Organisatorische Massnahmen (Art. 32 DSGVO).
Phase 1 Step 4: extracted from ``compliance.api.tom_routes``. Covers TOM
generator state persistence, the measures CRUD + bulk upsert, stats,
CSV/JSON export, and version lookups via the shared
``compliance.api.versioning_utils``.
"""
import csv
import io
import json
from datetime import datetime, timezone
from typing import Any, Optional
from fastapi.responses import StreamingResponse
from sqlalchemy import func
from sqlalchemy.orm import Session
from compliance.db.tom_models import TOMMeasureDB, TOMStateDB
from compliance.domain import ConflictError, NotFoundError, ValidationError
from compliance.schemas.tom import (
TOMMeasureBulkBody,
TOMMeasureCreate,
TOMMeasureUpdate,
TOMStateBody,
)
# Fallback tenant applied when a request carries no explicit tenant_id
# (see bulk_upsert).  Re-exported for the legacy tests that import it
# from compliance.api.tom_routes.
DEFAULT_TENANT_ID = "9282a473-5c95-4b3a-bf78-0ecc0ec71d3e"
# Column order for the CSV export.  It also acts as a whitelist: the
# DictWriter below is created with extrasaction="ignore", so any dict
# key not listed here is silently dropped from the CSV output.
_CSV_FIELDS = [
    "control_id", "name", "description", "category", "type",
    "applicability", "implementation_status", "responsible_person",
    "responsible_department", "implementation_date", "review_date",
    "review_frequency", "priority", "complexity", "effectiveness_rating",
]
def _parse_dt(val: Optional[str]) -> Optional[datetime]:
"""Parse an ISO-8601 string (accepting trailing 'Z') or return None."""
if not val:
return None
try:
return datetime.fromisoformat(val.replace("Z", "+00:00"))
except (ValueError, AttributeError):
return None
def _measure_to_dict(m: TOMMeasureDB) -> dict[str, Any]:
    """Serialize a TOMMeasureDB row into a JSON-safe dict.

    Datetime columns are rendered as ISO-8601 strings (``None`` stays
    ``None``); JSON columns fall back to empty containers so callers
    never see ``None`` for list/dict fields.
    """

    def iso(dt: Any) -> Optional[str]:
        # NULL-tolerant datetime -> ISO-8601 rendering.
        return dt.isoformat() if dt else None

    return {
        "id": str(m.id),
        "tenant_id": m.tenant_id,
        "control_id": m.control_id,
        "name": m.name,
        "description": m.description,
        "category": m.category,
        "type": m.type,
        "applicability": m.applicability,
        "applicability_reason": m.applicability_reason,
        "implementation_status": m.implementation_status,
        "responsible_person": m.responsible_person,
        "responsible_department": m.responsible_department,
        "implementation_date": iso(m.implementation_date),
        "review_date": iso(m.review_date),
        "review_frequency": m.review_frequency,
        "priority": m.priority,
        "complexity": m.complexity,
        "linked_evidence": m.linked_evidence or [],
        "evidence_gaps": m.evidence_gaps or [],
        "related_controls": m.related_controls or {},
        "verified_at": iso(m.verified_at),
        "verified_by": m.verified_by,
        "effectiveness_rating": m.effectiveness_rating,
        "created_by": m.created_by,
        "created_at": iso(m.created_at),
        "updated_at": iso(m.updated_at),
    }
class TOMService:
    """Business logic for TOM state, measures, stats, and export.

    Thin, stateless wrapper around a request-scoped SQLAlchemy session.
    Domain errors (ConflictError / NotFoundError / ValidationError) are
    raised here and mapped to HTTP status codes by the route layer.
    """

    def __init__(self, db: Session) -> None:
        # One session per service instance; every method commits its own work.
        self.db = db

    # ------------------------------------------------------------------
    # State endpoints
    # ------------------------------------------------------------------
    def get_state(self, tenant_id: str) -> dict[str, Any]:
        """Return the persisted TOM generator state for *tenant_id*.

        A missing row is not an error: an empty state with ``version: 0``
        and ``isNew: True`` is returned so the client can initialise.
        """
        row = (
            self.db.query(TOMStateDB)
            .filter(TOMStateDB.tenant_id == tenant_id)
            .first()
        )
        if not row:
            return {
                "success": True,
                "data": {
                    "tenantId": tenant_id,
                    "state": {},
                    "version": 0,
                    "isNew": True,
                },
            }
        return {
            "success": True,
            "data": {
                "tenantId": tenant_id,
                "state": row.state,
                "version": row.version,
                "lastModified": row.updated_at.isoformat() if row.updated_at else None,
            },
        }

    def save_state(self, body: TOMStateBody) -> dict[str, Any]:
        """Create or update the tenant's TOM state with optimistic locking.

        Raises:
            ConflictError: when ``body.version`` is supplied and does not
                match the stored version (lost-update protection).  When
                ``body.version`` is None the check is skipped entirely.
        """
        tid = body.get_tenant_id()
        existing = self.db.query(TOMStateDB).filter(TOMStateDB.tenant_id == tid).first()
        # Optimistic concurrency: a stale client-supplied version is rejected
        # before any mutation happens.
        if body.version is not None and existing and existing.version != body.version:
            raise ConflictError(
                "Version conflict. State was modified by another request."
            )
        now = datetime.now(timezone.utc)
        if existing:
            existing.state = body.state
            # Server owns the version counter; clients only read it back.
            existing.version = existing.version + 1
            existing.updated_at = now
        else:
            existing = TOMStateDB(
                tenant_id=tid,
                state=body.state,
                version=1,
                created_at=now,
                updated_at=now,
            )
            self.db.add(existing)
        self.db.commit()
        # Refresh so DB-generated values are reflected in the response.
        self.db.refresh(existing)
        return {
            "success": True,
            "data": {
                "tenantId": tid,
                "state": existing.state,
                "version": existing.version,
                "lastModified": existing.updated_at.isoformat() if existing.updated_at else None,
            },
        }

    def delete_state(self, tenant_id: Optional[str]) -> dict[str, Any]:
        """Delete the tenant's TOM state row, if present.

        Deleting a non-existent state is not an error: the response simply
        reports ``deleted: False``.

        Raises:
            ValidationError: when *tenant_id* is missing/empty (mapped to
                400 on DELETE /state).
        """
        if not tenant_id:
            raise ValidationError("tenant_id is required")
        row = (
            self.db.query(TOMStateDB).filter(TOMStateDB.tenant_id == tenant_id).first()
        )
        deleted = False
        if row:
            self.db.delete(row)
            self.db.commit()
            deleted = True
        return {
            "success": True,
            "tenantId": tenant_id,
            "deleted": deleted,
            "deletedAt": datetime.now(timezone.utc).isoformat(),
        }

    # ------------------------------------------------------------------
    # Measures CRUD
    # ------------------------------------------------------------------
    def list_measures(
        self,
        tenant_id: str,
        category: Optional[str],
        implementation_status: Optional[str],
        priority: Optional[str],
        search: Optional[str],
        limit: int,
        offset: int,
    ) -> dict[str, Any]:
        """List measures for a tenant with optional filters and pagination.

        *search* does a case-insensitive substring match across name,
        description, and control_id.  ``total`` is counted before
        limit/offset are applied, so clients can paginate.
        """
        q = self.db.query(TOMMeasureDB).filter(TOMMeasureDB.tenant_id == tenant_id)
        if category:
            q = q.filter(TOMMeasureDB.category == category)
        if implementation_status:
            q = q.filter(TOMMeasureDB.implementation_status == implementation_status)
        if priority:
            q = q.filter(TOMMeasureDB.priority == priority)
        if search:
            pattern = f"%{search}%"
            q = q.filter(
                (TOMMeasureDB.name.ilike(pattern))
                | (TOMMeasureDB.description.ilike(pattern))
                | (TOMMeasureDB.control_id.ilike(pattern))
            )
        total = q.count()
        rows = q.order_by(TOMMeasureDB.control_id).offset(offset).limit(limit).all()
        return {
            "measures": [_measure_to_dict(r) for r in rows],
            "total": total,
            "limit": limit,
            "offset": offset,
        }

    def create_measure(
        self, tenant_id: str, body: TOMMeasureCreate
    ) -> dict[str, Any]:
        """Create a measure; ``control_id`` must be unique per tenant.

        Date-like fields arrive as ISO-8601 strings and are parsed via
        ``_parse_dt`` (unparseable values become None).

        Raises:
            ConflictError: when a measure with the same control_id already
                exists for this tenant (mapped to 409).
        """
        existing = (
            self.db.query(TOMMeasureDB)
            .filter(
                TOMMeasureDB.tenant_id == tenant_id,
                TOMMeasureDB.control_id == body.control_id,
            )
            .first()
        )
        if existing:
            raise ConflictError(
                f"Measure with control_id '{body.control_id}' already exists"
            )
        now = datetime.now(timezone.utc)
        measure = TOMMeasureDB(
            tenant_id=tenant_id,
            control_id=body.control_id,
            name=body.name,
            description=body.description,
            category=body.category,
            type=body.type,
            applicability=body.applicability,
            applicability_reason=body.applicability_reason,
            implementation_status=body.implementation_status,
            responsible_person=body.responsible_person,
            responsible_department=body.responsible_department,
            implementation_date=_parse_dt(body.implementation_date),
            review_date=_parse_dt(body.review_date),
            review_frequency=body.review_frequency,
            priority=body.priority,
            complexity=body.complexity,
            # JSON columns default to empty containers, never NULL.
            linked_evidence=body.linked_evidence or [],
            evidence_gaps=body.evidence_gaps or [],
            related_controls=body.related_controls or {},
            verified_at=_parse_dt(body.verified_at),
            verified_by=body.verified_by,
            effectiveness_rating=body.effectiveness_rating,
            created_at=now,
            updated_at=now,
        )
        self.db.add(measure)
        self.db.commit()
        self.db.refresh(measure)
        return _measure_to_dict(measure)

    def update_measure(self, measure_id: Any, body: TOMMeasureUpdate) -> dict[str, Any]:
        """Partially update a measure (only fields the client actually sent).

        Raises:
            NotFoundError: when no measure with *measure_id* exists.
        """
        row = self.db.query(TOMMeasureDB).filter(TOMMeasureDB.id == measure_id).first()
        if not row:
            raise NotFoundError("Measure not found")
        # exclude_unset: absent fields are left untouched (PATCH semantics).
        for key, val in body.model_dump(exclude_unset=True).items():
            if key in ("implementation_date", "review_date", "verified_at"):
                # These arrive as ISO strings but are datetime columns.
                val = _parse_dt(val)
            setattr(row, key, val)
        row.updated_at = datetime.now(timezone.utc)
        self.db.commit()
        self.db.refresh(row)
        return _measure_to_dict(row)

    def bulk_upsert(self, body: TOMMeasureBulkBody) -> dict[str, Any]:
        """Upsert a batch of measures keyed on (tenant_id, control_id).

        All items are applied in one transaction (single commit at the end).
        Note: on the update path the verification fields (verified_at,
        verified_by, effectiveness_rating) and created_by are deliberately
        left untouched; the insert path does not set them either.
        """
        tid = body.tenant_id or DEFAULT_TENANT_ID
        now = datetime.now(timezone.utc)
        created = 0
        updated = 0
        for item in body.measures:
            existing = (
                self.db.query(TOMMeasureDB)
                .filter(
                    TOMMeasureDB.tenant_id == tid,
                    TOMMeasureDB.control_id == item.control_id,
                )
                .first()
            )
            if existing:
                existing.name = item.name
                existing.description = item.description
                existing.category = item.category
                existing.type = item.type
                existing.applicability = item.applicability
                existing.applicability_reason = item.applicability_reason
                existing.implementation_status = item.implementation_status
                existing.responsible_person = item.responsible_person
                existing.responsible_department = item.responsible_department
                existing.implementation_date = _parse_dt(item.implementation_date)
                existing.review_date = _parse_dt(item.review_date)
                existing.review_frequency = item.review_frequency
                existing.priority = item.priority
                existing.complexity = item.complexity
                existing.linked_evidence = item.linked_evidence or []
                existing.evidence_gaps = item.evidence_gaps or []
                existing.related_controls = item.related_controls or {}
                existing.updated_at = now
                updated += 1
            else:
                self.db.add(
                    TOMMeasureDB(
                        tenant_id=tid,
                        control_id=item.control_id,
                        name=item.name,
                        description=item.description,
                        category=item.category,
                        type=item.type,
                        applicability=item.applicability,
                        applicability_reason=item.applicability_reason,
                        implementation_status=item.implementation_status,
                        responsible_person=item.responsible_person,
                        responsible_department=item.responsible_department,
                        implementation_date=_parse_dt(item.implementation_date),
                        review_date=_parse_dt(item.review_date),
                        review_frequency=item.review_frequency,
                        priority=item.priority,
                        complexity=item.complexity,
                        linked_evidence=item.linked_evidence or [],
                        evidence_gaps=item.evidence_gaps or [],
                        related_controls=item.related_controls or {},
                        created_at=now,
                        updated_at=now,
                    )
                )
                created += 1
        self.db.commit()
        return {
            "success": True,
            "tenant_id": tid,
            "created": created,
            "updated": updated,
            "total": created + updated,
        }

    # ------------------------------------------------------------------
    # Stats + export
    # ------------------------------------------------------------------
    def stats(self, tenant_id: str) -> dict[str, Any]:
        """Aggregate counts for the tenant's measures.

        Returns totals, per-status and per-category breakdowns, the number
        of measures whose review_date is in the past, and convenience
        counters for the IMPLEMENTED / PARTIAL / NOT_IMPLEMENTED statuses.
        """
        base_q = self.db.query(TOMMeasureDB).filter(TOMMeasureDB.tenant_id == tenant_id)
        total = base_q.count()
        status_rows = (
            self.db.query(
                TOMMeasureDB.implementation_status, func.count(TOMMeasureDB.id)
            )
            .filter(TOMMeasureDB.tenant_id == tenant_id)
            .group_by(TOMMeasureDB.implementation_status)
            .all()
        )
        by_status: dict[str, int] = {row[0]: row[1] for row in status_rows}
        cat_rows = (
            self.db.query(TOMMeasureDB.category, func.count(TOMMeasureDB.id))
            .filter(TOMMeasureDB.tenant_id == tenant_id)
            .group_by(TOMMeasureDB.category)
            .all()
        )
        by_category: dict[str, int] = {row[0]: row[1] for row in cat_rows}
        now = datetime.now(timezone.utc)
        # "Overdue" = has a review_date and it is already in the past.
        overdue = base_q.filter(
            TOMMeasureDB.review_date.isnot(None),
            TOMMeasureDB.review_date < now,
        ).count()
        return {
            "total": total,
            "by_status": by_status,
            "by_category": by_category,
            "overdue_review_count": overdue,
            "implemented": by_status.get("IMPLEMENTED", 0),
            "partial": by_status.get("PARTIAL", 0),
            "not_implemented": by_status.get("NOT_IMPLEMENTED", 0),
        }

    def export(self, tenant_id: str, fmt: str) -> StreamingResponse:
        """Export all of the tenant's measures as a downloadable file.

        ``fmt == "json"`` yields a pretty-printed JSON attachment; any
        other value falls through to semicolon-delimited CSV restricted to
        the _CSV_FIELDS columns (extra keys are ignored).
        """
        rows = (
            self.db.query(TOMMeasureDB)
            .filter(TOMMeasureDB.tenant_id == tenant_id)
            .order_by(TOMMeasureDB.control_id)
            .all()
        )
        measures = [_measure_to_dict(r) for r in rows]
        if fmt == "json":
            return StreamingResponse(
                io.BytesIO(
                    json.dumps(measures, ensure_ascii=False, indent=2).encode("utf-8")
                ),
                media_type="application/json",
                headers={"Content-Disposition": "attachment; filename=tom_export.json"},
            )
        # CSV (semicolon-separated to match VVT convention)
        output = io.StringIO()
        writer = csv.DictWriter(
            output, fieldnames=_CSV_FIELDS, delimiter=";", extrasaction="ignore"
        )
        writer.writeheader()
        for m in measures:
            writer.writerow(m)
        output.seek(0)
        return StreamingResponse(
            io.BytesIO(output.getvalue().encode("utf-8")),
            media_type="text/csv; charset=utf-8",
            headers={"Content-Disposition": "attachment; filename=tom_export.csv"},
        )

    # ------------------------------------------------------------------
    # Versioning (delegates to shared versioning_utils)
    # ------------------------------------------------------------------
    def list_versions(self, measure_id: str, tenant_id: str) -> Any:
        """List saved versions of a measure via the shared versioning_utils."""
        # NOTE(review): imported lazily — presumably to avoid an import
        # cycle with compliance.api; confirm before hoisting to module level.
        from compliance.api.versioning_utils import list_versions
        return list_versions(self.db, "tom", measure_id, tenant_id)

    def get_version(
        self, measure_id: str, version_number: int, tenant_id: str
    ) -> Any:
        """Fetch one specific version of a measure.

        Raises:
            NotFoundError: when the requested version does not exist
                (mapped to 404).
        """
        # NOTE(review): lazy import, same reasoning as list_versions.
        from compliance.api.versioning_utils import get_version
        v = get_version(self.db, "tom", measure_id, version_number, tenant_id)
        if not v:
            raise NotFoundError(f"Version {version_number} not found")
        return v