# mypy: disable-error-code="arg-type,assignment"
# SQLAlchemy 1.x Column() descriptors are Column[T] statically, T at runtime.
"""
TOM service — Technisch-Organisatorische Massnahmen (Art. 32 DSGVO).

Phase 1 Step 4: extracted from ``compliance.api.tom_routes``. Covers TOM
generator state persistence, the measures CRUD + bulk upsert, stats,
CSV/JSON export, and version lookups via the shared
``compliance.api.versioning_utils``.
"""

import csv
import io
import json
from datetime import datetime, timezone
from typing import Any, Optional

from fastapi.responses import StreamingResponse
from sqlalchemy import func
from sqlalchemy.orm import Session

from compliance.db.tom_models import TOMMeasureDB, TOMStateDB
from compliance.domain import ConflictError, NotFoundError, ValidationError
from compliance.schemas.tom import (
    TOMMeasureBulkBody,
    TOMMeasureCreate,
    TOMMeasureUpdate,
    TOMStateBody,
)

DEFAULT_TENANT_ID = "9282a473-5c95-4b3a-bf78-0ecc0ec71d3e"

_CSV_FIELDS = [
    "control_id",
    "name",
    "description",
    "category",
    "type",
    "applicability",
    "implementation_status",
    "responsible_person",
    "responsible_department",
    "implementation_date",
    "review_date",
    "review_frequency",
    "priority",
    "complexity",
    "effectiveness_rating",
]

# Measure attributes that arrive as ISO-8601 strings and are stored as datetimes.
_DATETIME_FIELDS = ("implementation_date", "review_date", "verified_at")

# Escape character for user-supplied LIKE patterns. "/" is used instead of a
# backslash so no dialect-specific string-literal escaping is involved.
_LIKE_ESCAPE = "/"


def _parse_dt(val: Any) -> Optional[datetime]:
    """Parse an ISO-8601 string (accepting trailing 'Z') or return None.

    Falsy input and anything that cannot be parsed yield ``None`` (best-effort
    parsing). A value that already is a ``datetime`` is returned unchanged.
    """
    if not val:
        return None
    if isinstance(val, datetime):
        return val
    try:
        return datetime.fromisoformat(val.replace("Z", "+00:00"))
    except (ValueError, AttributeError, TypeError):
        # TypeError covers non-str objects that do have ``replace`` with a
        # different signature (e.g. ``date``), keeping the best-effort contract.
        return None


def _escape_like(term: str) -> str:
    """Escape LIKE wildcards in *term* so it is matched literally.

    The escape character itself is doubled first; ``%`` and ``_`` are then
    prefixed with it. Use together with ``escape=_LIKE_ESCAPE``.
    """
    return (
        term.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )


def _bulk_item_values(item: Any) -> dict[str, Any]:
    """Return the column values ``bulk_upsert`` writes for one payload item.

    Shared by the insert and the update branch so both stay in sync.
    ``tenant_id``, ``control_id`` and the timestamps are handled by the caller.
    """
    return {
        "name": item.name,
        "description": item.description,
        "category": item.category,
        "type": item.type,
        "applicability": item.applicability,
        "applicability_reason": item.applicability_reason,
        "implementation_status": item.implementation_status,
        "responsible_person": item.responsible_person,
        "responsible_department": item.responsible_department,
        "implementation_date": _parse_dt(item.implementation_date),
        "review_date": _parse_dt(item.review_date),
        "review_frequency": item.review_frequency,
        "priority": item.priority,
        "complexity": item.complexity,
        "linked_evidence": item.linked_evidence or [],
        "evidence_gaps": item.evidence_gaps or [],
        "related_controls": item.related_controls or {},
    }


def _measure_to_dict(m: TOMMeasureDB) -> dict[str, Any]:
    """Serialize a measure row into a plain dict.

    ``id`` is stringified, datetime columns become ISO-8601 strings (or
    ``None``), and empty ``linked_evidence`` / ``evidence_gaps`` /
    ``related_controls`` are normalized to ``[]`` / ``[]`` / ``{}``.
    """
    return {
        "id": str(m.id),
        "tenant_id": m.tenant_id,
        "control_id": m.control_id,
        "name": m.name,
        "description": m.description,
        "category": m.category,
        "type": m.type,
        "applicability": m.applicability,
        "applicability_reason": m.applicability_reason,
        "implementation_status": m.implementation_status,
        "responsible_person": m.responsible_person,
        "responsible_department": m.responsible_department,
        "implementation_date": (
            m.implementation_date.isoformat() if m.implementation_date else None
        ),
        "review_date": m.review_date.isoformat() if m.review_date else None,
        "review_frequency": m.review_frequency,
        "priority": m.priority,
        "complexity": m.complexity,
        "linked_evidence": m.linked_evidence or [],
        "evidence_gaps": m.evidence_gaps or [],
        "related_controls": m.related_controls or {},
        "verified_at": m.verified_at.isoformat() if m.verified_at else None,
        "verified_by": m.verified_by,
        "effectiveness_rating": m.effectiveness_rating,
        "created_by": m.created_by,
        "created_at": m.created_at.isoformat() if m.created_at else None,
        "updated_at": m.updated_at.isoformat() if m.updated_at else None,
    }


class TOMService:
    """Business logic for TOM state, measures, stats, and export."""

    def __init__(self, db: Session) -> None:
        """Bind the service to the SQLAlchemy session used for all queries."""
        self.db = db

    # ------------------------------------------------------------------
    # State endpoints
    # ------------------------------------------------------------------

    def get_state(self, tenant_id: str) -> dict[str, Any]:
        """Return the stored TOM generator state for *tenant_id*.

        When no row exists, an empty state with ``version`` 0 and
        ``isNew`` set to True is returned instead of raising.
        """
        row = (
            self.db.query(TOMStateDB)
            .filter(TOMStateDB.tenant_id == tenant_id)
            .first()
        )
        if not row:
            return {
                "success": True,
                "data": {
                    "tenantId": tenant_id,
                    "state": {},
                    "version": 0,
                    "isNew": True,
                },
            }
        return {
            "success": True,
            "data": {
                "tenantId": tenant_id,
                "state": row.state,
                "version": row.version,
                "lastModified": (
                    row.updated_at.isoformat() if row.updated_at else None
                ),
            },
        }

    def save_state(self, body: TOMStateBody) -> dict[str, Any]:
        """Create or replace the tenant's state and bump its version.

        Raises:
            ConflictError: ``body.version`` is given, a row exists, and the
                stored version differs from it.
        """
        tid = body.get_tenant_id()
        existing = (
            self.db.query(TOMStateDB).filter(TOMStateDB.tenant_id == tid).first()
        )

        # Optimistic check: only enforced when the client sent a version.
        if body.version is not None and existing and existing.version != body.version:
            raise ConflictError(
                "Version conflict. State was modified by another request."
            )

        now = datetime.now(timezone.utc)
        if existing:
            existing.state = body.state
            existing.version = existing.version + 1
            existing.updated_at = now
        else:
            existing = TOMStateDB(
                tenant_id=tid,
                state=body.state,
                version=1,
                created_at=now,
                updated_at=now,
            )
            self.db.add(existing)

        self.db.commit()
        self.db.refresh(existing)
        return {
            "success": True,
            "data": {
                "tenantId": tid,
                "state": existing.state,
                "version": existing.version,
                "lastModified": (
                    existing.updated_at.isoformat() if existing.updated_at else None
                ),
            },
        }

    def delete_state(self, tenant_id: Optional[str]) -> dict[str, Any]:
        """Delete the tenant's state row if present.

        ``deleted`` in the result tells whether a row was actually removed.

        Raises:
            ValidationError: *tenant_id* is missing or empty.
        """
        if not tenant_id:
            raise ValidationError("tenant_id is required")

        row = (
            self.db.query(TOMStateDB).filter(TOMStateDB.tenant_id == tenant_id).first()
        )
        deleted = False
        if row:
            self.db.delete(row)
            self.db.commit()
            deleted = True

        return {
            "success": True,
            "tenantId": tenant_id,
            "deleted": deleted,
            "deletedAt": datetime.now(timezone.utc).isoformat(),
        }

    # ------------------------------------------------------------------
    # Measures CRUD
    # ------------------------------------------------------------------

    def list_measures(
        self,
        tenant_id: str,
        category: Optional[str],
        implementation_status: Optional[str],
        priority: Optional[str],
        search: Optional[str],
        limit: int,
        offset: int,
    ) -> dict[str, Any]:
        """List a tenant's measures with optional filters and paging.

        *search* is matched case-insensitively as a literal substring of
        ``name``, ``description`` or ``control_id``. ``total`` is the number
        of rows matching the filters before *offset* / *limit* are applied.
        """
        q = self.db.query(TOMMeasureDB).filter(TOMMeasureDB.tenant_id == tenant_id)
        if category:
            q = q.filter(TOMMeasureDB.category == category)
        if implementation_status:
            q = q.filter(TOMMeasureDB.implementation_status == implementation_status)
        if priority:
            q = q.filter(TOMMeasureDB.priority == priority)
        if search:
            # Escape % and _ so user input cannot act as LIKE wildcards.
            pattern = f"%{_escape_like(search)}%"
            q = q.filter(
                (TOMMeasureDB.name.ilike(pattern, escape=_LIKE_ESCAPE))
                | (TOMMeasureDB.description.ilike(pattern, escape=_LIKE_ESCAPE))
                | (TOMMeasureDB.control_id.ilike(pattern, escape=_LIKE_ESCAPE))
            )

        total = q.count()
        rows = q.order_by(TOMMeasureDB.control_id).offset(offset).limit(limit).all()
        return {
            "measures": [_measure_to_dict(r) for r in rows],
            "total": total,
            "limit": limit,
            "offset": offset,
        }

    def create_measure(
        self, tenant_id: str, body: TOMMeasureCreate
    ) -> dict[str, Any]:
        """Insert a new measure for *tenant_id* and return it serialized.

        Raises:
            ConflictError: the tenant already has a measure with
                ``body.control_id``.
        """
        existing = (
            self.db.query(TOMMeasureDB)
            .filter(
                TOMMeasureDB.tenant_id == tenant_id,
                TOMMeasureDB.control_id == body.control_id,
            )
            .first()
        )
        if existing:
            raise ConflictError(
                f"Measure with control_id '{body.control_id}' already exists"
            )

        now = datetime.now(timezone.utc)
        measure = TOMMeasureDB(
            tenant_id=tenant_id,
            control_id=body.control_id,
            name=body.name,
            description=body.description,
            category=body.category,
            type=body.type,
            applicability=body.applicability,
            applicability_reason=body.applicability_reason,
            implementation_status=body.implementation_status,
            responsible_person=body.responsible_person,
            responsible_department=body.responsible_department,
            implementation_date=_parse_dt(body.implementation_date),
            review_date=_parse_dt(body.review_date),
            review_frequency=body.review_frequency,
            priority=body.priority,
            complexity=body.complexity,
            linked_evidence=body.linked_evidence or [],
            evidence_gaps=body.evidence_gaps or [],
            related_controls=body.related_controls or {},
            verified_at=_parse_dt(body.verified_at),
            verified_by=body.verified_by,
            effectiveness_rating=body.effectiveness_rating,
            created_at=now,
            updated_at=now,
        )
        self.db.add(measure)
        self.db.commit()
        self.db.refresh(measure)
        return _measure_to_dict(measure)

    def update_measure(
        self,
        measure_id: Any,
        body: TOMMeasureUpdate,
        tenant_id: Optional[str] = None,
    ) -> dict[str, Any]:
        """Apply the explicitly-set fields of *body* to an existing measure.

        Args:
            measure_id: primary key of the measure.
            body: partial update; only fields the client set are written.
            tenant_id: when given, the lookup is additionally restricted to
                this tenant, so a measure of another tenant is reported as
                not found. ``None`` keeps the lookup by id only.

        Raises:
            NotFoundError: no matching measure exists.
        """
        query = self.db.query(TOMMeasureDB).filter(TOMMeasureDB.id == measure_id)
        if tenant_id is not None:
            query = query.filter(TOMMeasureDB.tenant_id == tenant_id)
        row = query.first()
        if not row:
            raise NotFoundError("Measure not found")

        for key, val in body.model_dump(exclude_unset=True).items():
            if key in _DATETIME_FIELDS:
                val = _parse_dt(val)
            setattr(row, key, val)
        row.updated_at = datetime.now(timezone.utc)

        self.db.commit()
        self.db.refresh(row)
        return _measure_to_dict(row)

    def bulk_upsert(self, body: TOMMeasureBulkBody) -> dict[str, Any]:
        """Insert or update every measure in *body*, keyed by ``control_id``.

        Falls back to ``DEFAULT_TENANT_ID`` when the body carries no tenant.
        All changes are committed together at the end. A ``control_id`` that
        occurs more than once in the payload is inserted once and then
        updated by the later occurrences.
        """
        tid = body.tenant_id or DEFAULT_TENANT_ID
        now = datetime.now(timezone.utc)
        created = 0
        updated = 0

        # Rows added in this batch but possibly not flushed yet. Without this
        # map a repeated control_id relies on Session autoflush to be found by
        # the lookup below and would otherwise be inserted twice.
        pending: dict[Any, TOMMeasureDB] = {}

        for item in body.measures:
            existing = pending.get(item.control_id)
            if existing is None:
                existing = (
                    self.db.query(TOMMeasureDB)
                    .filter(
                        TOMMeasureDB.tenant_id == tid,
                        TOMMeasureDB.control_id == item.control_id,
                    )
                    .first()
                )

            values = _bulk_item_values(item)
            if existing is not None:
                for column, value in values.items():
                    setattr(existing, column, value)
                existing.updated_at = now
                updated += 1
            else:
                row = TOMMeasureDB(
                    tenant_id=tid,
                    control_id=item.control_id,
                    created_at=now,
                    updated_at=now,
                    **values,
                )
                self.db.add(row)
                pending[item.control_id] = row
                created += 1

        self.db.commit()
        return {
            "success": True,
            "tenant_id": tid,
            "created": created,
            "updated": updated,
            "total": created + updated,
        }

    # ------------------------------------------------------------------
    # Stats + export
    # ------------------------------------------------------------------

    def stats(self, tenant_id: str) -> dict[str, Any]:
        """Return counts of the tenant's measures.

        Includes the total, a breakdown by ``implementation_status`` and by
        ``category``, and the number of measures whose ``review_date`` lies
        before the current UTC time.
        """
        base_q = self.db.query(TOMMeasureDB).filter(
            TOMMeasureDB.tenant_id == tenant_id
        )
        total = base_q.count()

        status_rows = (
            self.db.query(
                TOMMeasureDB.implementation_status, func.count(TOMMeasureDB.id)
            )
            .filter(TOMMeasureDB.tenant_id == tenant_id)
            .group_by(TOMMeasureDB.implementation_status)
            .all()
        )
        by_status: dict[str, int] = {row[0]: row[1] for row in status_rows}

        cat_rows = (
            self.db.query(TOMMeasureDB.category, func.count(TOMMeasureDB.id))
            .filter(TOMMeasureDB.tenant_id == tenant_id)
            .group_by(TOMMeasureDB.category)
            .all()
        )
        by_category: dict[str, int] = {row[0]: row[1] for row in cat_rows}

        now = datetime.now(timezone.utc)
        overdue = base_q.filter(
            TOMMeasureDB.review_date.isnot(None),
            TOMMeasureDB.review_date < now,
        ).count()

        return {
            "total": total,
            "by_status": by_status,
            "by_category": by_category,
            "overdue_review_count": overdue,
            "implemented": by_status.get("IMPLEMENTED", 0),
            "partial": by_status.get("PARTIAL", 0),
            "not_implemented": by_status.get("NOT_IMPLEMENTED", 0),
        }

    def export(self, tenant_id: str, fmt: str) -> StreamingResponse:
        """Export the tenant's measures as a downloadable attachment.

        ``fmt == "json"`` yields the full serialized measures as JSON; any
        other value yields a semicolon-separated CSV limited to
        ``_CSV_FIELDS``. Rows are ordered by ``control_id``.
        """
        rows = (
            self.db.query(TOMMeasureDB)
            .filter(TOMMeasureDB.tenant_id == tenant_id)
            .order_by(TOMMeasureDB.control_id)
            .all()
        )
        measures = [_measure_to_dict(r) for r in rows]

        if fmt == "json":
            return StreamingResponse(
                io.BytesIO(
                    json.dumps(measures, ensure_ascii=False, indent=2).encode("utf-8")
                ),
                media_type="application/json",
                headers={
                    "Content-Disposition": "attachment; filename=tom_export.json"
                },
            )

        # CSV (semicolon-separated to match VVT convention)
        # NOTE(review): cell values are written verbatim; text starting with
        # "=", "+", "-" or "@" may be interpreted as a formula by spreadsheet
        # applications. Confirm whether neutralization is wanted here.
        output = io.StringIO()
        writer = csv.DictWriter(
            output, fieldnames=_CSV_FIELDS, delimiter=";", extrasaction="ignore"
        )
        writer.writeheader()
        for m in measures:
            writer.writerow(m)

        return StreamingResponse(
            io.BytesIO(output.getvalue().encode("utf-8")),
            media_type="text/csv; charset=utf-8",
            headers={"Content-Disposition": "attachment; filename=tom_export.csv"},
        )

    # ------------------------------------------------------------------
    # Versioning (delegates to shared versioning_utils)
    # ------------------------------------------------------------------

    def list_versions(self, measure_id: str, tenant_id: str) -> Any:
        """Return the stored versions of a measure via ``versioning_utils``."""
        from compliance.api.versioning_utils import list_versions

        return list_versions(self.db, "tom", measure_id, tenant_id)

    def get_version(
        self, measure_id: str, version_number: int, tenant_id: str
    ) -> Any:
        """Return one specific version of a measure via ``versioning_utils``.

        Raises:
            NotFoundError: the lookup returned nothing for *version_number*.
        """
        from compliance.api.versioning_utils import get_version

        v = get_version(self.db, "tom", measure_id, version_number, tenant_id)
        if not v:
            raise NotFoundError(f"Version {version_number} not found")
        return v