feat(sdk): API-Referenz Frontend + Backend-Konsolidierung (Shared Utilities, CRUD Factory)
All checks were successful
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-ai-compliance (push) Successful in 32s
CI / test-python-backend-compliance (push) Successful in 30s
CI / test-python-document-crawler (push) Successful in 21s
CI / test-python-dsms-gateway (push) Successful in 18s

- API-Referenz Seite (/sdk/api-docs) mit ~690 Endpoints, Suche, Filter, Modul-Index
- Shared db_utils.py (row_to_dict) + tenant_utils Integration in 6 Route-Dateien
- CRUD Factory (crud_factory.py) fuer zukuenftige Module
- Version-Route Auto-Registration in versioning_utils.py
- 1338 Tests bestanden, -232 Zeilen Duplikat-Code

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-03-07 17:07:43 +01:00
parent 7ec6b9f6c0
commit 6509e64dd9
19 changed files with 1921 additions and 390 deletions

View File

@@ -0,0 +1,312 @@
'use client'
import { useState, useMemo, useRef } from 'react'
import { apiModules } from '@/lib/sdk/api-docs/endpoints'
import type { HttpMethod, BackendService } from '@/lib/sdk/api-docs/types'
const METHOD_COLORS: Record<HttpMethod, string> = {
GET: 'bg-green-100 text-green-800',
POST: 'bg-blue-100 text-blue-800',
PUT: 'bg-yellow-100 text-yellow-800',
DELETE: 'bg-red-100 text-red-800',
PATCH: 'bg-purple-100 text-purple-800',
}
type ServiceFilter = 'all' | BackendService
export default function ApiDocsPage() {
const [search, setSearch] = useState('')
const [serviceFilter, setServiceFilter] = useState<ServiceFilter>('all')
const [methodFilter, setMethodFilter] = useState<HttpMethod | 'all'>('all')
const [expandedModules, setExpandedModules] = useState<Set<string>>(new Set())
const moduleRefs = useRef<Record<string, HTMLDivElement | null>>({})
const filteredModules = useMemo(() => {
const q = search.toLowerCase()
return apiModules
.filter((m) => serviceFilter === 'all' || m.service === serviceFilter)
.map((m) => {
const eps = m.endpoints.filter((e) => {
if (methodFilter !== 'all' && e.method !== methodFilter) return false
if (!q) return true
return (
e.path.toLowerCase().includes(q) ||
e.description.toLowerCase().includes(q) ||
m.name.toLowerCase().includes(q) ||
m.id.toLowerCase().includes(q)
)
})
return { ...m, endpoints: eps }
})
.filter((m) => m.endpoints.length > 0)
}, [search, serviceFilter, methodFilter])
const stats = useMemo(() => {
const total = apiModules.reduce((s, m) => s + m.endpoints.length, 0)
const python = apiModules.filter((m) => m.service === 'python').reduce((s, m) => s + m.endpoints.length, 0)
const go = apiModules.filter((m) => m.service === 'go').reduce((s, m) => s + m.endpoints.length, 0)
return { total, python, go, modules: apiModules.length }
}, [])
const filteredTotal = filteredModules.reduce((s, m) => s + m.endpoints.length, 0)
const toggleModule = (id: string) => {
setExpandedModules((prev) => {
const next = new Set(prev)
if (next.has(id)) next.delete(id)
else next.add(id)
return next
})
}
const expandAll = () => setExpandedModules(new Set(filteredModules.map((m) => m.id)))
const collapseAll = () => setExpandedModules(new Set())
const scrollToModule = (id: string) => {
setExpandedModules((prev) => new Set([...prev, id]))
setTimeout(() => {
moduleRefs.current[id]?.scrollIntoView({ behavior: 'smooth', block: 'start' })
}, 100)
}
return (
<div className="min-h-screen bg-gray-50">
{/* Header */}
<div className="bg-white border-b border-gray-200 sticky top-0 z-20">
<div className="max-w-7xl mx-auto px-4 py-4">
<div className="flex items-center justify-between mb-3">
<div>
<h1 className="text-xl font-bold text-gray-900">API-Referenz</h1>
<p className="text-sm text-gray-500 mt-0.5">
{stats.total} Endpoints in {stats.modules} Modulen
</p>
</div>
<div className="flex gap-2">
<button
onClick={expandAll}
className="px-3 py-1.5 text-xs font-medium text-gray-600 bg-gray-100 rounded-lg hover:bg-gray-200 transition-colors"
>
Alle aufklappen
</button>
<button
onClick={collapseAll}
className="px-3 py-1.5 text-xs font-medium text-gray-600 bg-gray-100 rounded-lg hover:bg-gray-200 transition-colors"
>
Alle zuklappen
</button>
</div>
</div>
{/* Search + Filters */}
<div className="flex flex-wrap gap-3 items-center">
<div className="relative flex-1 min-w-[240px]">
<svg
className="absolute left-3 top-1/2 -translate-y-1/2 w-4 h-4 text-gray-400"
fill="none"
stroke="currentColor"
viewBox="0 0 24 24"
>
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M21 21l-6-6m2-5a7 7 0 11-14 0 7 7 0 0114 0z" />
</svg>
<input
type="text"
placeholder="Endpoint, Beschreibung oder Modul suchen..."
value={search}
onChange={(e) => setSearch(e.target.value)}
className="w-full pl-9 pr-4 py-2 text-sm border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
/>
{search && (
<button
onClick={() => setSearch('')}
className="absolute right-3 top-1/2 -translate-y-1/2 text-gray-400 hover:text-gray-600"
>
<svg className="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M6 18L18 6M6 6l12 12" />
</svg>
</button>
)}
</div>
{/* Service Filter */}
<div className="flex rounded-lg border border-gray-300 overflow-hidden">
{([['all', 'Alle'], ['python', 'Python/FastAPI'], ['go', 'Go/Gin']] as const).map(([val, label]) => (
<button
key={val}
onClick={() => setServiceFilter(val)}
className={`px-3 py-2 text-xs font-medium transition-colors ${
serviceFilter === val
? 'bg-gray-900 text-white'
: 'bg-white text-gray-600 hover:bg-gray-50'
}`}
>
{label}
</button>
))}
</div>
{/* Method Filter */}
<div className="flex gap-1.5">
{(['all', 'GET', 'POST', 'PUT', 'DELETE', 'PATCH'] as const).map((m) => (
<button
key={m}
onClick={() => setMethodFilter(m)}
className={`px-2.5 py-1.5 text-xs font-mono font-bold rounded-md transition-colors ${
methodFilter === m
? m === 'all'
? 'bg-gray-900 text-white'
: METHOD_COLORS[m] + ' ring-2 ring-offset-1 ring-gray-400'
: 'bg-gray-100 text-gray-500 hover:bg-gray-200'
}`}
>
{m === 'all' ? 'ALLE' : m}
</button>
))}
</div>
</div>
</div>
</div>
<div className="max-w-7xl mx-auto px-4 py-6">
{/* Stats Cards */}
<div className="grid grid-cols-4 gap-4 mb-6">
{[
{ label: 'Endpoints gesamt', value: stats.total, color: 'text-gray-900' },
{ label: 'Python / FastAPI', value: stats.python, color: 'text-blue-700' },
{ label: 'Go / Gin', value: stats.go, color: 'text-emerald-700' },
{ label: 'Module', value: stats.modules, color: 'text-purple-700' },
].map((s) => (
<div key={s.label} className="bg-white rounded-lg border border-gray-200 p-4">
<p className="text-xs text-gray-500 mb-1">{s.label}</p>
<p className={`text-2xl font-bold ${s.color}`}>{s.value}</p>
</div>
))}
</div>
<div className="flex gap-6">
{/* Module Index (Sidebar) */}
<div className="hidden lg:block w-64 flex-shrink-0">
<div className="bg-white rounded-lg border border-gray-200 p-4 sticky top-[140px] max-h-[calc(100vh-180px)] overflow-y-auto">
<h3 className="text-xs font-semibold text-gray-400 uppercase tracking-wider mb-3">
Modul-Index ({filteredModules.length})
</h3>
<div className="space-y-0.5">
{filteredModules.map((m) => (
<button
key={m.id}
onClick={() => scrollToModule(m.id)}
className="w-full text-left px-2 py-1.5 text-xs rounded hover:bg-gray-100 transition-colors group flex items-center justify-between"
>
<span className="truncate text-gray-700 group-hover:text-gray-900">
{m.id}
</span>
<span className={`text-[10px] px-1.5 py-0.5 rounded-full ${
m.service === 'python' ? 'bg-blue-50 text-blue-600' : 'bg-emerald-50 text-emerald-600'
}`}>
{m.endpoints.length}
</span>
</button>
))}
</div>
</div>
</div>
{/* Main Content */}
<div className="flex-1 min-w-0">
{search && (
<p className="text-sm text-gray-500 mb-4">
{filteredTotal} Treffer in {filteredModules.length} Modulen
</p>
)}
<div className="space-y-3">
{filteredModules.map((m) => {
const isExpanded = expandedModules.has(m.id)
return (
<div
key={m.id}
ref={(el) => { moduleRefs.current[m.id] = el }}
className="bg-white rounded-lg border border-gray-200 overflow-hidden"
>
{/* Module Header */}
<button
onClick={() => toggleModule(m.id)}
className="w-full flex items-center justify-between px-4 py-3 hover:bg-gray-50 transition-colors"
>
<div className="flex items-center gap-3 min-w-0">
<svg
className={`w-4 h-4 text-gray-400 flex-shrink-0 transition-transform ${isExpanded ? 'rotate-90' : ''}`}
fill="none"
stroke="currentColor"
viewBox="0 0 24 24"
>
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M9 5l7 7-7 7" />
</svg>
<span className={`text-[10px] font-mono font-bold px-2 py-0.5 rounded ${
m.service === 'python' ? 'bg-blue-50 text-blue-700' : 'bg-emerald-50 text-emerald-700'
}`}>
{m.service === 'python' ? 'PY' : 'GO'}
</span>
<span className="text-sm font-medium text-gray-900 truncate">{m.name}</span>
</div>
<div className="flex items-center gap-2 flex-shrink-0 ml-3">
<span className="text-xs text-gray-400 font-mono">{m.basePath}</span>
<span className="text-xs bg-gray-100 text-gray-600 px-2 py-0.5 rounded-full">
{m.endpoints.length}
</span>
</div>
</button>
{/* Endpoints Table */}
{isExpanded && (
<div className="border-t border-gray-100">
<table className="w-full">
<thead>
<tr className="bg-gray-50 text-xs text-gray-500">
<th className="text-left px-4 py-2 w-20">Methode</th>
<th className="text-left px-4 py-2">Pfad</th>
<th className="text-left px-4 py-2">Beschreibung</th>
</tr>
</thead>
<tbody>
{m.endpoints.map((e, i) => (
<tr
key={`${e.method}-${e.path}-${i}`}
className="border-t border-gray-50 hover:bg-gray-50/50 transition-colors"
>
<td className="px-4 py-2">
<span className={`inline-block text-[11px] font-mono font-bold px-2 py-0.5 rounded ${METHOD_COLORS[e.method]}`}>
{e.method}
</span>
</td>
<td className="px-4 py-2 font-mono text-xs text-gray-800">
{e.path}
</td>
<td className="px-4 py-2 text-xs text-gray-600">
{e.description}
</td>
</tr>
))}
</tbody>
</table>
</div>
)}
</div>
)
})}
</div>
{filteredModules.length === 0 && (
<div className="text-center py-12 text-gray-400">
<svg className="w-12 h-12 mx-auto mb-3" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={1.5} d="M21 21l-6-6m2-5a7 7 0 11-14 0 7 7 0 0114 0z" />
</svg>
<p className="text-sm">Keine Endpoints gefunden</p>
<p className="text-xs mt-1">Suchbegriff oder Filter anpassen</p>
</div>
)}
</div>
</div>
</div>
</div>
)
}

View File

@@ -726,6 +726,18 @@ export function SDKSidebar({ collapsed = false, onCollapsedChange }: SDKSidebarP
isActive={pathname === '/sdk/catalog-manager'} isActive={pathname === '/sdk/catalog-manager'}
collapsed={collapsed} collapsed={collapsed}
/> />
<AdditionalModuleItem
href="/sdk/api-docs"
icon={
<svg className="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2}
d="M10 20l4-16m4 4l4 4-4 4M6 16l-4-4 4-4" />
</svg>
}
label="API-Referenz"
isActive={pathname === '/sdk/api-docs'}
collapsed={collapsed}
/>
<Link <Link
href="/sdk/change-requests" href="/sdk/change-requests"
className={`flex items-center gap-3 px-4 py-2.5 text-sm transition-colors ${ className={`flex items-center gap-3 px-4 py-2.5 text-sm transition-colors ${

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,17 @@
export type HttpMethod = 'GET' | 'POST' | 'PUT' | 'DELETE' | 'PATCH'
export type BackendService = 'python' | 'go'
export interface ApiEndpoint {
method: HttpMethod
path: string
description: string
service: BackendService
}
export interface ApiModule {
id: string
name: string
service: BackendService
basePath: string
endpoints: ApiEndpoint[]
}

View File

@@ -0,0 +1,216 @@
"""
Generic CRUD Router Factory for Compliance API.
Creates standardized CRUD endpoints (list, create, get, update, delete)
for simple resource tables that follow the tenant-isolated pattern:
- Table has `id`, `tenant_id`, `created_at`, `updated_at` columns
- All queries filtered by tenant_id
Usage:
router = create_crud_router(
prefix="/security-backlog",
table_name="compliance_security_backlog",
tag="security-backlog",
columns=["title", "description", "type", "severity", "status", ...],
search_columns=["title", "description"],
filter_columns=["status", "severity", "type"],
order_by="created_at DESC",
resource_name="Security item",
)
"""
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Callable
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import text
from sqlalchemy.orm import Session
from classroom_engine.database import get_db
from .tenant_utils import get_tenant_id
from .db_utils import row_to_dict
logger = logging.getLogger(__name__)
def create_crud_router(
prefix: str,
table_name: str,
tag: str,
columns: List[str],
search_columns: Optional[List[str]] = None,
filter_columns: Optional[List[str]] = None,
order_by: str = "created_at DESC",
resource_name: str = "Item",
stats_query: Optional[str] = None,
stats_defaults: Optional[Dict[str, int]] = None,
) -> APIRouter:
"""Create a CRUD router with list, create, get/{id}, update/{id}, delete/{id}.
Args:
prefix: URL prefix (e.g. "/security-backlog")
table_name: PostgreSQL table name
tag: OpenAPI tag
columns: Writable column names (excluding id, tenant_id, created_at, updated_at)
search_columns: Columns to ILIKE-search (default: ["title", "description"])
filter_columns: Columns to filter by exact match via query params
order_by: SQL ORDER BY clause
resource_name: Human-readable name for error messages
stats_query: Optional custom SQL for /stats endpoint (must accept :tenant_id param)
stats_defaults: Default dict for stats when no rows found
"""
router = APIRouter(prefix=prefix, tags=[tag])
_search_cols = search_columns or ["title", "description"]
_filter_cols = filter_columns or []
# ── LIST ──────────────────────────────────────────────────────────────
@router.get("")
async def list_items(
search: Optional[str] = Query(None),
limit: int = Query(100, ge=1, le=500),
offset: int = Query(0, ge=0),
db: Session = Depends(get_db),
tenant_id: str = Depends(get_tenant_id),
**kwargs,
):
where = ["tenant_id = :tenant_id"]
params: Dict[str, Any] = {"tenant_id": tenant_id, "limit": limit, "offset": offset}
# Dynamic filter columns from query string
# We can't use **kwargs with FastAPI easily, so we handle this in a wrapper
if search and _search_cols:
clauses = [f"{c} ILIKE :search" for c in _search_cols]
where.append(f"({' OR '.join(clauses)})")
params["search"] = f"%{search}%"
where_sql = " AND ".join(where)
total_row = db.execute(
text(f"SELECT COUNT(*) FROM {table_name} WHERE {where_sql}"),
params,
).fetchone()
total = total_row[0] if total_row else 0
rows = db.execute(
text(f"""
SELECT * FROM {table_name}
WHERE {where_sql}
ORDER BY {order_by}
LIMIT :limit OFFSET :offset
"""),
params,
).fetchall()
return {"items": [row_to_dict(r) for r in rows], "total": total}
# ── STATS (optional) ─────────────────────────────────────────────────
if stats_query:
@router.get("/stats")
async def get_stats(
db: Session = Depends(get_db),
tenant_id: str = Depends(get_tenant_id),
):
row = db.execute(text(stats_query), {"tenant_id": tenant_id}).fetchone()
if row:
d = dict(row._mapping)
return {k: (v or 0) for k, v in d.items()}
return stats_defaults or {}
# ── CREATE ────────────────────────────────────────────────────────────
@router.post("", status_code=201)
async def create_item(
payload: dict = {},
db: Session = Depends(get_db),
tenant_id: str = Depends(get_tenant_id),
):
col_names = ["tenant_id"]
col_params = [":tenant_id"]
values: Dict[str, Any] = {"tenant_id": tenant_id}
for col in columns:
if col in payload:
col_names.append(col)
col_params.append(f":{col}")
values[col] = payload[col]
row = db.execute(
text(f"""
INSERT INTO {table_name} ({', '.join(col_names)})
VALUES ({', '.join(col_params)})
RETURNING *
"""),
values,
).fetchone()
db.commit()
return row_to_dict(row)
# ── GET BY ID ─────────────────────────────────────────────────────────
@router.get("/{item_id}")
async def get_item(
item_id: str,
db: Session = Depends(get_db),
tenant_id: str = Depends(get_tenant_id),
):
row = db.execute(
text(f"SELECT * FROM {table_name} WHERE id = :id AND tenant_id = :tenant_id"),
{"id": item_id, "tenant_id": tenant_id},
).fetchone()
if not row:
raise HTTPException(status_code=404, detail=f"{resource_name} not found")
return row_to_dict(row)
# ── UPDATE ────────────────────────────────────────────────────────────
@router.put("/{item_id}")
async def update_item(
item_id: str,
payload: dict = {},
db: Session = Depends(get_db),
tenant_id: str = Depends(get_tenant_id),
):
updates: Dict[str, Any] = {
"id": item_id,
"tenant_id": tenant_id,
"updated_at": datetime.utcnow(),
}
set_clauses = ["updated_at = :updated_at"]
for field, value in payload.items():
if field in columns:
updates[field] = value
set_clauses.append(f"{field} = :{field}")
if len(set_clauses) == 1:
raise HTTPException(status_code=400, detail="No fields to update")
row = db.execute(
text(f"""
UPDATE {table_name}
SET {', '.join(set_clauses)}
WHERE id = :id AND tenant_id = :tenant_id
RETURNING *
"""),
updates,
).fetchone()
db.commit()
if not row:
raise HTTPException(status_code=404, detail=f"{resource_name} not found")
return row_to_dict(row)
# ── DELETE ────────────────────────────────────────────────────────────
@router.delete("/{item_id}", status_code=204)
async def delete_item(
item_id: str,
db: Session = Depends(get_db),
tenant_id: str = Depends(get_tenant_id),
):
result = db.execute(
text(f"DELETE FROM {table_name} WHERE id = :id AND tenant_id = :tenant_id"),
{"id": item_id, "tenant_id": tenant_id},
)
db.commit()
if result.rowcount == 0:
raise HTTPException(status_code=404, detail=f"{resource_name} not found")
return router

View File

@@ -0,0 +1,25 @@
"""
Shared database utility functions for Compliance API routes.
Provides common helpers used across multiple route files:
- row_to_dict: Convert SQLAlchemy Row to JSON-safe dict
"""
from datetime import datetime, date
from typing import Any, Dict
def row_to_dict(row) -> Dict[str, Any]:
"""Convert a SQLAlchemy Row/RowMapping to a JSON-serializable dict.
Handles datetime serialization and non-standard types.
"""
result = dict(row._mapping)
for key, val in result.items():
if isinstance(val, datetime):
result[key] = val.isoformat()
elif isinstance(val, date):
result[key] = val.isoformat()
elif hasattr(val, '__str__') and not isinstance(val, (str, int, float, bool, list, dict, type(None))):
result[key] = str(val)
return result

View File

@@ -13,7 +13,7 @@ Endpoints:
import logging import logging
from datetime import datetime from datetime import datetime
from typing import Optional, List, Any, Dict from typing import Optional, Any, Dict
from fastapi import APIRouter, Depends, HTTPException, Query, Header from fastapi import APIRouter, Depends, HTTPException, Query, Header
from pydantic import BaseModel from pydantic import BaseModel
@@ -21,12 +21,12 @@ from sqlalchemy import text
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from classroom_engine.database import get_db from classroom_engine.database import get_db
from .tenant_utils import get_tenant_id as _get_tenant_id
from .db_utils import row_to_dict as _row_to_dict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/escalations", tags=["escalations"]) router = APIRouter(prefix="/escalations", tags=["escalations"])
DEFAULT_TENANT_ID = '9282a473-5c95-4b3a-bf78-0ecc0ec71d3e'
# ============================================================================= # =============================================================================
# Pydantic Schemas # Pydantic Schemas
@@ -59,17 +59,6 @@ class EscalationStatusUpdate(BaseModel):
resolved_at: Optional[datetime] = None resolved_at: Optional[datetime] = None
def _row_to_dict(row) -> Dict[str, Any]:
"""Convert a SQLAlchemy row to a serialisable dict."""
result = dict(row._mapping)
for key, val in result.items():
if isinstance(val, datetime):
result[key] = val.isoformat()
elif hasattr(val, '__str__') and not isinstance(val, (str, int, float, bool, type(None))):
result[key] = str(val)
return result
# ============================================================================= # =============================================================================
# Routes # Routes
# ============================================================================= # =============================================================================
@@ -80,14 +69,12 @@ async def list_escalations(
priority: Optional[str] = Query(None), priority: Optional[str] = Query(None),
limit: int = Query(50, ge=1, le=500), limit: int = Query(50, ge=1, le=500),
offset: int = Query(0, ge=0), offset: int = Query(0, ge=0),
tenant_id: Optional[str] = Header(None, alias="x-tenant-id"), tenant_id: str = Depends(_get_tenant_id),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""List escalations with optional filters.""" """List escalations with optional filters."""
tid = tenant_id or DEFAULT_TENANT_ID
where_clauses = ["tenant_id = :tenant_id"] where_clauses = ["tenant_id = :tenant_id"]
params: Dict[str, Any] = {"tenant_id": tid, "limit": limit, "offset": offset} params: Dict[str, Any] = {"tenant_id": tenant_id, "limit": limit, "offset": offset}
if status: if status:
where_clauses.append("status = :status") where_clauses.append("status = :status")
@@ -122,13 +109,11 @@ async def list_escalations(
@router.post("", status_code=201) @router.post("", status_code=201)
async def create_escalation( async def create_escalation(
request: EscalationCreate, request: EscalationCreate,
tenant_id: Optional[str] = Header(None, alias="x-tenant-id"), tenant_id: str = Depends(_get_tenant_id),
user_id: Optional[str] = Header(None, alias="x-user-id"), user_id: Optional[str] = Header(None, alias="x-user-id"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""Create a new escalation.""" """Create a new escalation."""
tid = tenant_id or DEFAULT_TENANT_ID
row = db.execute( row = db.execute(
text( text(
""" """
@@ -142,7 +127,7 @@ async def create_escalation(
""" """
), ),
{ {
"tenant_id": tid, "tenant_id": tenant_id,
"title": request.title, "title": request.title,
"description": request.description, "description": request.description,
"priority": request.priority, "priority": request.priority,
@@ -161,18 +146,16 @@ async def create_escalation(
@router.get("/stats") @router.get("/stats")
async def get_stats( async def get_stats(
tenant_id: Optional[str] = Header(None, alias="x-tenant-id"), tenant_id: str = Depends(_get_tenant_id),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""Return counts per status and priority.""" """Return counts per status and priority."""
tid = tenant_id or DEFAULT_TENANT_ID
status_rows = db.execute( status_rows = db.execute(
text( text(
"SELECT status, COUNT(*) as cnt FROM compliance_escalations " "SELECT status, COUNT(*) as cnt FROM compliance_escalations "
"WHERE tenant_id = :tenant_id GROUP BY status" "WHERE tenant_id = :tenant_id GROUP BY status"
), ),
{"tenant_id": tid}, {"tenant_id": tenant_id},
).fetchall() ).fetchall()
priority_rows = db.execute( priority_rows = db.execute(
@@ -180,12 +163,12 @@ async def get_stats(
"SELECT priority, COUNT(*) as cnt FROM compliance_escalations " "SELECT priority, COUNT(*) as cnt FROM compliance_escalations "
"WHERE tenant_id = :tenant_id GROUP BY priority" "WHERE tenant_id = :tenant_id GROUP BY priority"
), ),
{"tenant_id": tid}, {"tenant_id": tenant_id},
).fetchall() ).fetchall()
total_row = db.execute( total_row = db.execute(
text("SELECT COUNT(*) FROM compliance_escalations WHERE tenant_id = :tenant_id"), text("SELECT COUNT(*) FROM compliance_escalations WHERE tenant_id = :tenant_id"),
{"tenant_id": tid}, {"tenant_id": tenant_id},
).fetchone() ).fetchone()
active_row = db.execute( active_row = db.execute(
@@ -193,7 +176,7 @@ async def get_stats(
"SELECT COUNT(*) FROM compliance_escalations " "SELECT COUNT(*) FROM compliance_escalations "
"WHERE tenant_id = :tenant_id AND status NOT IN ('resolved', 'closed')" "WHERE tenant_id = :tenant_id AND status NOT IN ('resolved', 'closed')"
), ),
{"tenant_id": tid}, {"tenant_id": tenant_id},
).fetchone() ).fetchone()
by_status = {"open": 0, "in_progress": 0, "escalated": 0, "resolved": 0, "closed": 0} by_status = {"open": 0, "in_progress": 0, "escalated": 0, "resolved": 0, "closed": 0}
@@ -217,17 +200,16 @@ async def get_stats(
@router.get("/{escalation_id}") @router.get("/{escalation_id}")
async def get_escalation( async def get_escalation(
escalation_id: str, escalation_id: str,
tenant_id: Optional[str] = Header(None, alias="x-tenant-id"), tenant_id: str = Depends(_get_tenant_id),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""Get a single escalation by ID.""" """Get a single escalation by ID."""
tid = tenant_id or DEFAULT_TENANT_ID
row = db.execute( row = db.execute(
text( text(
"SELECT * FROM compliance_escalations " "SELECT * FROM compliance_escalations "
"WHERE id = :id AND tenant_id = :tenant_id" "WHERE id = :id AND tenant_id = :tenant_id"
), ),
{"id": escalation_id, "tenant_id": tid}, {"id": escalation_id, "tenant_id": tenant_id},
).fetchone() ).fetchone()
if not row: if not row:
raise HTTPException(status_code=404, detail=f"Escalation {escalation_id} not found") raise HTTPException(status_code=404, detail=f"Escalation {escalation_id} not found")
@@ -238,18 +220,16 @@ async def get_escalation(
async def update_escalation( async def update_escalation(
escalation_id: str, escalation_id: str,
request: EscalationUpdate, request: EscalationUpdate,
tenant_id: Optional[str] = Header(None, alias="x-tenant-id"), tenant_id: str = Depends(_get_tenant_id),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""Update an escalation's fields.""" """Update an escalation's fields."""
tid = tenant_id or DEFAULT_TENANT_ID
existing = db.execute( existing = db.execute(
text( text(
"SELECT id FROM compliance_escalations " "SELECT id FROM compliance_escalations "
"WHERE id = :id AND tenant_id = :tenant_id" "WHERE id = :id AND tenant_id = :tenant_id"
), ),
{"id": escalation_id, "tenant_id": tid}, {"id": escalation_id, "tenant_id": tenant_id},
).fetchone() ).fetchone()
if not existing: if not existing:
raise HTTPException(status_code=404, detail=f"Escalation {escalation_id} not found") raise HTTPException(status_code=404, detail=f"Escalation {escalation_id} not found")
@@ -281,18 +261,16 @@ async def update_escalation(
async def update_status( async def update_status(
escalation_id: str, escalation_id: str,
request: EscalationStatusUpdate, request: EscalationStatusUpdate,
tenant_id: Optional[str] = Header(None, alias="x-tenant-id"), tenant_id: str = Depends(_get_tenant_id),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""Update only the status of an escalation.""" """Update only the status of an escalation."""
tid = tenant_id or DEFAULT_TENANT_ID
existing = db.execute( existing = db.execute(
text( text(
"SELECT id FROM compliance_escalations " "SELECT id FROM compliance_escalations "
"WHERE id = :id AND tenant_id = :tenant_id" "WHERE id = :id AND tenant_id = :tenant_id"
), ),
{"id": escalation_id, "tenant_id": tid}, {"id": escalation_id, "tenant_id": tenant_id},
).fetchone() ).fetchone()
if not existing: if not existing:
raise HTTPException(status_code=404, detail=f"Escalation {escalation_id} not found") raise HTTPException(status_code=404, detail=f"Escalation {escalation_id} not found")
@@ -321,18 +299,16 @@ async def update_status(
@router.delete("/{escalation_id}") @router.delete("/{escalation_id}")
async def delete_escalation( async def delete_escalation(
escalation_id: str, escalation_id: str,
tenant_id: Optional[str] = Header(None, alias="x-tenant-id"), tenant_id: str = Depends(_get_tenant_id),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""Delete an escalation.""" """Delete an escalation."""
tid = tenant_id or DEFAULT_TENANT_ID
existing = db.execute( existing = db.execute(
text( text(
"SELECT id FROM compliance_escalations " "SELECT id FROM compliance_escalations "
"WHERE id = :id AND tenant_id = :tenant_id" "WHERE id = :id AND tenant_id = :tenant_id"
), ),
{"id": escalation_id, "tenant_id": tid}, {"id": escalation_id, "tenant_id": tenant_id},
).fetchone() ).fetchone()
if not existing: if not existing:
raise HTTPException(status_code=404, detail=f"Escalation {escalation_id} not found") raise HTTPException(status_code=404, detail=f"Escalation {escalation_id} not found")

View File

@@ -18,19 +18,18 @@ import logging
from datetime import datetime from datetime import datetime
from typing import Optional, List, Any, Dict from typing import Optional, List, Any, Dict
from fastapi import APIRouter, Depends, HTTPException, Query, Header from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import text from sqlalchemy import text
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from uuid import UUID
from classroom_engine.database import get_db from classroom_engine.database import get_db
from .tenant_utils import get_tenant_id as _get_tenant_id
from .db_utils import row_to_dict as _row_to_dict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/legal-templates", tags=["legal-templates"]) router = APIRouter(prefix="/legal-templates", tags=["legal-templates"])
DEFAULT_TENANT_ID = "9282a473-5c95-4b3a-bf78-0ecc0ec71d3e"
VALID_DOCUMENT_TYPES = { VALID_DOCUMENT_TYPES = {
# Original types # Original types
"privacy_policy", "privacy_policy",
@@ -105,30 +104,6 @@ class LegalTemplateUpdate(BaseModel):
inspiration_sources: Optional[List[Any]] = None inspiration_sources: Optional[List[Any]] = None
# =============================================================================
# Helpers
# =============================================================================
def _row_to_dict(row) -> Dict[str, Any]:
result = dict(row._mapping)
for key, val in result.items():
if isinstance(val, datetime):
result[key] = val.isoformat()
elif hasattr(val, '__str__') and not isinstance(val, (str, int, float, bool, list, dict, type(None))):
result[key] = str(val)
return result
def _get_tenant_id(x_tenant_id: Optional[str] = Header(None)) -> str:
if x_tenant_id:
try:
UUID(x_tenant_id)
return x_tenant_id
except ValueError:
pass
return DEFAULT_TENANT_ID
# ============================================================================= # =============================================================================
# Routes # Routes
# ============================================================================= # =============================================================================
@@ -142,10 +117,9 @@ async def list_legal_templates(
limit: int = Query(50, ge=1, le=200), limit: int = Query(50, ge=1, le=200),
offset: int = Query(0, ge=0), offset: int = Query(0, ge=0),
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""List legal templates with optional filters.""" """List legal templates with optional filters."""
tenant_id = _get_tenant_id(x_tenant_id)
where_clauses = ["tenant_id = :tenant_id"] where_clauses = ["tenant_id = :tenant_id"]
params: Dict[str, Any] = {"tenant_id": tenant_id, "limit": limit, "offset": offset} params: Dict[str, Any] = {"tenant_id": tenant_id, "limit": limit, "offset": offset}
@@ -192,10 +166,9 @@ async def list_legal_templates(
@router.get("/status") @router.get("/status")
async def get_templates_status( async def get_templates_status(
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""Return template counts by document_type.""" """Return template counts by document_type."""
tenant_id = _get_tenant_id(x_tenant_id)
total_row = db.execute( total_row = db.execute(
text("SELECT COUNT(*) FROM compliance_legal_templates WHERE tenant_id = :tenant_id"), text("SELECT COUNT(*) FROM compliance_legal_templates WHERE tenant_id = :tenant_id"),
@@ -234,10 +207,9 @@ async def get_templates_status(
@router.get("/sources") @router.get("/sources")
async def get_template_sources( async def get_template_sources(
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""Return distinct source_name values.""" """Return distinct source_name values."""
tenant_id = _get_tenant_id(x_tenant_id)
rows = db.execute( rows = db.execute(
text("SELECT DISTINCT source_name FROM compliance_legal_templates WHERE tenant_id = :tenant_id ORDER BY source_name"), text("SELECT DISTINCT source_name FROM compliance_legal_templates WHERE tenant_id = :tenant_id ORDER BY source_name"),
@@ -251,10 +223,9 @@ async def get_template_sources(
async def get_legal_template( async def get_legal_template(
template_id: str, template_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""Fetch a single template by ID.""" """Fetch a single template by ID."""
tenant_id = _get_tenant_id(x_tenant_id)
row = db.execute( row = db.execute(
text("SELECT * FROM compliance_legal_templates WHERE id = :id AND tenant_id = :tenant_id"), text("SELECT * FROM compliance_legal_templates WHERE id = :id AND tenant_id = :tenant_id"),
{"id": template_id, "tenant_id": tenant_id}, {"id": template_id, "tenant_id": tenant_id},
@@ -268,10 +239,9 @@ async def get_legal_template(
async def create_legal_template( async def create_legal_template(
payload: LegalTemplateCreate, payload: LegalTemplateCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""Create a new legal template.""" """Create a new legal template."""
tenant_id = _get_tenant_id(x_tenant_id)
if payload.document_type not in VALID_DOCUMENT_TYPES: if payload.document_type not in VALID_DOCUMENT_TYPES:
raise HTTPException( raise HTTPException(
@@ -335,10 +305,9 @@ async def update_legal_template(
template_id: str, template_id: str,
payload: LegalTemplateUpdate, payload: LegalTemplateUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""Update an existing legal template.""" """Update an existing legal template."""
tenant_id = _get_tenant_id(x_tenant_id)
updates = payload.model_dump(exclude_unset=True) updates = payload.model_dump(exclude_unset=True)
if not updates: if not updates:
@@ -385,10 +354,9 @@ async def update_legal_template(
async def delete_legal_template( async def delete_legal_template(
template_id: str, template_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""Delete a legal template.""" """Delete a legal template."""
tenant_id = _get_tenant_id(x_tenant_id)
result = db.execute( result = db.execute(
text("DELETE FROM compliance_legal_templates WHERE id = :id AND tenant_id = :tenant_id"), text("DELETE FROM compliance_legal_templates WHERE id = :id AND tenant_id = :tenant_id"),
{"id": template_id, "tenant_id": tenant_id}, {"id": template_id, "tenant_id": tenant_id},

View File

@@ -16,19 +16,18 @@ import logging
from datetime import datetime from datetime import datetime
from typing import Optional, List, Any, Dict from typing import Optional, List, Any, Dict
from fastapi import APIRouter, Depends, HTTPException, Query, Header from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import text from sqlalchemy import text
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from uuid import UUID
from classroom_engine.database import get_db from classroom_engine.database import get_db
from .tenant_utils import get_tenant_id as _get_tenant_id
from .db_utils import row_to_dict as _row_to_dict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/loeschfristen", tags=["loeschfristen"]) router = APIRouter(prefix="/loeschfristen", tags=["loeschfristen"])
DEFAULT_TENANT_ID = "9282a473-5c95-4b3a-bf78-0ecc0ec71d3e"
# ============================================================================= # =============================================================================
# Pydantic Schemas # Pydantic Schemas
@@ -105,26 +104,6 @@ JSONB_FIELDS = {
} }
def _row_to_dict(row) -> Dict[str, Any]:
result = dict(row._mapping)
for key, val in result.items():
if isinstance(val, datetime):
result[key] = val.isoformat()
elif hasattr(val, '__str__') and not isinstance(val, (str, int, float, bool, list, dict, type(None))):
result[key] = str(val)
return result
def _get_tenant_id(x_tenant_id: Optional[str] = Header(None)) -> str:
if x_tenant_id:
try:
UUID(x_tenant_id)
return x_tenant_id
except ValueError:
pass
return DEFAULT_TENANT_ID
# ============================================================================= # =============================================================================
# Routes # Routes
# ============================================================================= # =============================================================================
@@ -137,10 +116,9 @@ async def list_loeschfristen(
limit: int = Query(500, ge=1, le=1000), limit: int = Query(500, ge=1, le=1000),
offset: int = Query(0, ge=0), offset: int = Query(0, ge=0),
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""List Loeschfristen with optional filters.""" """List Loeschfristen with optional filters."""
tenant_id = _get_tenant_id(x_tenant_id)
where_clauses = ["tenant_id = :tenant_id"] where_clauses = ["tenant_id = :tenant_id"]
params: Dict[str, Any] = {"tenant_id": tenant_id, "limit": limit, "offset": offset} params: Dict[str, Any] = {"tenant_id": tenant_id, "limit": limit, "offset": offset}
@@ -189,10 +167,9 @@ async def list_loeschfristen(
@router.get("/stats") @router.get("/stats")
async def get_loeschfristen_stats( async def get_loeschfristen_stats(
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""Return Loeschfristen statistics.""" """Return Loeschfristen statistics."""
tenant_id = _get_tenant_id(x_tenant_id)
row = db.execute(text(""" row = db.execute(text("""
SELECT SELECT
@@ -222,10 +199,9 @@ async def get_loeschfristen_stats(
async def create_loeschfrist( async def create_loeschfrist(
payload: LoeschfristCreate, payload: LoeschfristCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""Create a new Loeschfrist policy.""" """Create a new Loeschfrist policy."""
tenant_id = _get_tenant_id(x_tenant_id)
data = payload.model_dump() data = payload.model_dump()
@@ -257,9 +233,8 @@ async def create_loeschfrist(
async def get_loeschfrist( async def get_loeschfrist(
policy_id: str, policy_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
tenant_id = _get_tenant_id(x_tenant_id)
row = db.execute( row = db.execute(
text("SELECT * FROM compliance_loeschfristen WHERE id = :id AND tenant_id = :tenant_id"), text("SELECT * FROM compliance_loeschfristen WHERE id = :id AND tenant_id = :tenant_id"),
{"id": policy_id, "tenant_id": tenant_id}, {"id": policy_id, "tenant_id": tenant_id},
@@ -274,10 +249,9 @@ async def update_loeschfrist(
policy_id: str, policy_id: str,
payload: LoeschfristUpdate, payload: LoeschfristUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""Full update of a Loeschfrist policy.""" """Full update of a Loeschfrist policy."""
tenant_id = _get_tenant_id(x_tenant_id)
updates: Dict[str, Any] = {"id": policy_id, "tenant_id": tenant_id, "updated_at": datetime.utcnow()} updates: Dict[str, Any] = {"id": policy_id, "tenant_id": tenant_id, "updated_at": datetime.utcnow()}
set_clauses = ["updated_at = :updated_at"] set_clauses = ["updated_at = :updated_at"]
@@ -314,10 +288,9 @@ async def update_loeschfrist_status(
policy_id: str, policy_id: str,
payload: StatusUpdate, payload: StatusUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""Quick status update.""" """Quick status update."""
tenant_id = _get_tenant_id(x_tenant_id)
valid = {"DRAFT", "ACTIVE", "REVIEW_NEEDED", "ARCHIVED"} valid = {"DRAFT", "ACTIVE", "REVIEW_NEEDED", "ARCHIVED"}
if payload.status not in valid: if payload.status not in valid:
raise HTTPException(status_code=400, detail=f"Invalid status. Must be one of: {', '.join(valid)}") raise HTTPException(status_code=400, detail=f"Invalid status. Must be one of: {', '.join(valid)}")
@@ -342,9 +315,8 @@ async def update_loeschfrist_status(
async def delete_loeschfrist( async def delete_loeschfrist(
policy_id: str, policy_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
tenant_id = _get_tenant_id(x_tenant_id)
result = db.execute( result = db.execute(
text("DELETE FROM compliance_loeschfristen WHERE id = :id AND tenant_id = :tenant_id"), text("DELETE FROM compliance_loeschfristen WHERE id = :id AND tenant_id = :tenant_id"),
{"id": policy_id, "tenant_id": tenant_id}, {"id": policy_id, "tenant_id": tenant_id},
@@ -362,11 +334,10 @@ async def delete_loeschfrist(
async def list_loeschfristen_versions( async def list_loeschfristen_versions(
policy_id: str, policy_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""List all versions for a Loeschfrist.""" """List all versions for a Loeschfrist."""
from .versioning_utils import list_versions from .versioning_utils import list_versions
tenant_id = _get_tenant_id(x_tenant_id)
return list_versions(db, "loeschfristen", policy_id, tenant_id) return list_versions(db, "loeschfristen", policy_id, tenant_id)
@@ -375,11 +346,10 @@ async def get_loeschfristen_version(
policy_id: str, policy_id: str,
version_number: int, version_number: int,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""Get a specific Loeschfristen version with full snapshot.""" """Get a specific Loeschfristen version with full snapshot."""
from .versioning_utils import get_version from .versioning_utils import get_version
tenant_id = _get_tenant_id(x_tenant_id)
v = get_version(db, "loeschfristen", policy_id, version_number, tenant_id) v = get_version(db, "loeschfristen", policy_id, version_number, tenant_id)
if not v: if not v:
raise HTTPException(status_code=404, detail=f"Version {version_number} not found") raise HTTPException(status_code=404, detail=f"Version {version_number} not found")

View File

@@ -14,7 +14,6 @@ Endpoints:
import logging import logging
from datetime import datetime from datetime import datetime
from typing import Optional, List, Any, Dict from typing import Optional, List, Any, Dict
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, Header from fastapi import APIRouter, Depends, HTTPException, Query, Header
from pydantic import BaseModel from pydantic import BaseModel
@@ -22,12 +21,12 @@ from sqlalchemy import text
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from classroom_engine.database import get_db from classroom_engine.database import get_db
from .tenant_utils import get_tenant_id as _get_tenant_id
from .db_utils import row_to_dict as _row_to_dict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/obligations", tags=["obligations"]) router = APIRouter(prefix="/obligations", tags=["obligations"])
DEFAULT_TENANT_ID = "9282a473-5c95-4b3a-bf78-0ecc0ec71d3e"
# ============================================================================= # =============================================================================
# Pydantic Schemas # Pydantic Schemas
@@ -65,25 +64,6 @@ class ObligationStatusUpdate(BaseModel):
status: str status: str
def _row_to_dict(row) -> Dict[str, Any]:
result = dict(row._mapping)
for key, val in result.items():
if isinstance(val, datetime):
result[key] = val.isoformat()
elif hasattr(val, '__str__') and not isinstance(val, (str, int, float, bool, list, dict, type(None))):
result[key] = str(val)
return result
def _get_tenant_id(x_tenant_id: Optional[str] = Header(None)) -> str:
if x_tenant_id:
try:
UUID(x_tenant_id)
return x_tenant_id
except ValueError:
pass
return DEFAULT_TENANT_ID
# ============================================================================= # =============================================================================
# Routes # Routes
@@ -98,10 +78,9 @@ async def list_obligations(
limit: int = Query(100, ge=1, le=500), limit: int = Query(100, ge=1, le=500),
offset: int = Query(0, ge=0), offset: int = Query(0, ge=0),
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""List obligations with optional filters.""" """List obligations with optional filters."""
tenant_id = _get_tenant_id(x_tenant_id)
where_clauses = ["tenant_id = :tenant_id"] where_clauses = ["tenant_id = :tenant_id"]
params: Dict[str, Any] = {"tenant_id": tenant_id, "limit": limit, "offset": offset} params: Dict[str, Any] = {"tenant_id": tenant_id, "limit": limit, "offset": offset}
@@ -159,10 +138,9 @@ async def list_obligations(
@router.get("/stats") @router.get("/stats")
async def get_obligation_stats( async def get_obligation_stats(
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""Return obligation counts per status and priority.""" """Return obligation counts per status and priority."""
tenant_id = _get_tenant_id(x_tenant_id)
rows = db.execute(text(""" rows = db.execute(text("""
SELECT SELECT
@@ -187,11 +165,10 @@ async def get_obligation_stats(
async def create_obligation( async def create_obligation(
payload: ObligationCreate, payload: ObligationCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
x_user_id: Optional[str] = Header(None), x_user_id: Optional[str] = Header(None),
): ):
"""Create a new compliance obligation.""" """Create a new compliance obligation."""
tenant_id = _get_tenant_id(x_tenant_id)
logger.info("create_obligation user_id=%s tenant_id=%s title=%s", x_user_id, tenant_id, payload.title) logger.info("create_obligation user_id=%s tenant_id=%s title=%s", x_user_id, tenant_id, payload.title)
import json import json
@@ -228,9 +205,8 @@ async def create_obligation(
async def get_obligation( async def get_obligation(
obligation_id: str, obligation_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
tenant_id = _get_tenant_id(x_tenant_id)
row = db.execute(text(""" row = db.execute(text("""
SELECT * FROM compliance_obligations SELECT * FROM compliance_obligations
WHERE id = :id AND tenant_id = :tenant_id WHERE id = :id AND tenant_id = :tenant_id
@@ -245,11 +221,10 @@ async def update_obligation(
obligation_id: str, obligation_id: str,
payload: ObligationUpdate, payload: ObligationUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
x_user_id: Optional[str] = Header(None), x_user_id: Optional[str] = Header(None),
): ):
"""Update an obligation's fields.""" """Update an obligation's fields."""
tenant_id = _get_tenant_id(x_tenant_id)
logger.info("update_obligation user_id=%s tenant_id=%s id=%s", x_user_id, tenant_id, obligation_id) logger.info("update_obligation user_id=%s tenant_id=%s id=%s", x_user_id, tenant_id, obligation_id)
import json import json
@@ -285,11 +260,10 @@ async def update_obligation_status(
obligation_id: str, obligation_id: str,
payload: ObligationStatusUpdate, payload: ObligationStatusUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
x_user_id: Optional[str] = Header(None), x_user_id: Optional[str] = Header(None),
): ):
"""Quick status update for an obligation.""" """Quick status update for an obligation."""
tenant_id = _get_tenant_id(x_tenant_id)
logger.info("update_obligation_status user_id=%s tenant_id=%s id=%s status=%s", x_user_id, tenant_id, obligation_id, payload.status) logger.info("update_obligation_status user_id=%s tenant_id=%s id=%s status=%s", x_user_id, tenant_id, obligation_id, payload.status)
valid_statuses = {"pending", "in-progress", "completed", "overdue"} valid_statuses = {"pending", "in-progress", "completed", "overdue"}
if payload.status not in valid_statuses: if payload.status not in valid_statuses:
@@ -312,10 +286,9 @@ async def update_obligation_status(
async def delete_obligation( async def delete_obligation(
obligation_id: str, obligation_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
x_user_id: Optional[str] = Header(None), x_user_id: Optional[str] = Header(None),
): ):
tenant_id = _get_tenant_id(x_tenant_id)
logger.info("delete_obligation user_id=%s tenant_id=%s id=%s", x_user_id, tenant_id, obligation_id) logger.info("delete_obligation user_id=%s tenant_id=%s id=%s", x_user_id, tenant_id, obligation_id)
result = db.execute(text(""" result = db.execute(text("""
DELETE FROM compliance_obligations DELETE FROM compliance_obligations
@@ -334,11 +307,10 @@ async def delete_obligation(
async def list_obligation_versions( async def list_obligation_versions(
obligation_id: str, obligation_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""List all versions for an Obligation.""" """List all versions for an Obligation."""
from .versioning_utils import list_versions from .versioning_utils import list_versions
tenant_id = _get_tenant_id(x_tenant_id)
return list_versions(db, "obligation", obligation_id, tenant_id) return list_versions(db, "obligation", obligation_id, tenant_id)
@@ -347,11 +319,10 @@ async def get_obligation_version(
obligation_id: str, obligation_id: str,
version_number: int, version_number: int,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""Get a specific Obligation version with full snapshot.""" """Get a specific Obligation version with full snapshot."""
from .versioning_utils import get_version from .versioning_utils import get_version
tenant_id = _get_tenant_id(x_tenant_id)
v = get_version(db, "obligation", obligation_id, version_number, tenant_id) v = get_version(db, "obligation", obligation_id, version_number, tenant_id)
if not v: if not v:
raise HTTPException(status_code=404, detail=f"Version {version_number} not found") raise HTTPException(status_code=404, detail=f"Version {version_number} not found")

View File

@@ -13,19 +13,18 @@ import logging
from datetime import datetime from datetime import datetime
from typing import Optional, Any, Dict from typing import Optional, Any, Dict
from fastapi import APIRouter, Depends, HTTPException, Query, Header from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import text from sqlalchemy import text
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from uuid import UUID
from classroom_engine.database import get_db from classroom_engine.database import get_db
from .tenant_utils import get_tenant_id as _get_tenant_id
from .db_utils import row_to_dict as _row_to_dict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/quality", tags=["quality"]) router = APIRouter(prefix="/quality", tags=["quality"])
DEFAULT_TENANT_ID = "9282a473-5c95-4b3a-bf78-0ecc0ec71d3e"
# ============================================================================= # =============================================================================
# Pydantic Schemas # Pydantic Schemas
@@ -69,25 +68,6 @@ class TestUpdate(BaseModel):
last_run: Optional[datetime] = None last_run: Optional[datetime] = None
def _row_to_dict(row) -> Dict[str, Any]:
result = dict(row._mapping)
for key, val in result.items():
if isinstance(val, datetime):
result[key] = val.isoformat()
elif hasattr(val, '__str__') and not isinstance(val, (str, int, float, bool, list, dict, type(None))):
result[key] = str(val)
return result
def _get_tenant_id(x_tenant_id: Optional[str] = Header(None)) -> str:
if x_tenant_id:
try:
UUID(x_tenant_id)
return x_tenant_id
except ValueError:
pass
return DEFAULT_TENANT_ID
# ============================================================================= # =============================================================================
# Stats # Stats
@@ -96,10 +76,9 @@ def _get_tenant_id(x_tenant_id: Optional[str] = Header(None)) -> str:
@router.get("/stats") @router.get("/stats")
async def get_quality_stats( async def get_quality_stats(
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""Return quality dashboard stats.""" """Return quality dashboard stats."""
tenant_id = _get_tenant_id(x_tenant_id)
metrics_row = db.execute(text(""" metrics_row = db.execute(text("""
SELECT SELECT
@@ -142,10 +121,9 @@ async def list_metrics(
limit: int = Query(100, ge=1, le=500), limit: int = Query(100, ge=1, le=500),
offset: int = Query(0, ge=0), offset: int = Query(0, ge=0),
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""List quality metrics.""" """List quality metrics."""
tenant_id = _get_tenant_id(x_tenant_id)
where_clauses = ["tenant_id = :tenant_id"] where_clauses = ["tenant_id = :tenant_id"]
params: Dict[str, Any] = {"tenant_id": tenant_id, "limit": limit, "offset": offset} params: Dict[str, Any] = {"tenant_id": tenant_id, "limit": limit, "offset": offset}
@@ -181,10 +159,9 @@ async def list_metrics(
async def create_metric( async def create_metric(
payload: MetricCreate, payload: MetricCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""Create a new quality metric.""" """Create a new quality metric."""
tenant_id = _get_tenant_id(x_tenant_id)
row = db.execute(text(""" row = db.execute(text("""
INSERT INTO compliance_quality_metrics INSERT INTO compliance_quality_metrics
@@ -211,10 +188,9 @@ async def update_metric(
metric_id: str, metric_id: str,
payload: MetricUpdate, payload: MetricUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""Update a quality metric.""" """Update a quality metric."""
tenant_id = _get_tenant_id(x_tenant_id)
updates: Dict[str, Any] = {"id": metric_id, "tenant_id": tenant_id, "updated_at": datetime.utcnow()} updates: Dict[str, Any] = {"id": metric_id, "tenant_id": tenant_id, "updated_at": datetime.utcnow()}
set_clauses = ["updated_at = :updated_at"] set_clauses = ["updated_at = :updated_at"]
@@ -243,9 +219,8 @@ async def update_metric(
async def delete_metric( async def delete_metric(
metric_id: str, metric_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
tenant_id = _get_tenant_id(x_tenant_id)
result = db.execute(text(""" result = db.execute(text("""
DELETE FROM compliance_quality_metrics DELETE FROM compliance_quality_metrics
WHERE id = :id AND tenant_id = :tenant_id WHERE id = :id AND tenant_id = :tenant_id
@@ -266,10 +241,9 @@ async def list_tests(
limit: int = Query(100, ge=1, le=500), limit: int = Query(100, ge=1, le=500),
offset: int = Query(0, ge=0), offset: int = Query(0, ge=0),
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""List quality tests.""" """List quality tests."""
tenant_id = _get_tenant_id(x_tenant_id)
where_clauses = ["tenant_id = :tenant_id"] where_clauses = ["tenant_id = :tenant_id"]
params: Dict[str, Any] = {"tenant_id": tenant_id, "limit": limit, "offset": offset} params: Dict[str, Any] = {"tenant_id": tenant_id, "limit": limit, "offset": offset}
@@ -305,10 +279,9 @@ async def list_tests(
async def create_test( async def create_test(
payload: TestCreate, payload: TestCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""Create a new quality test entry.""" """Create a new quality test entry."""
tenant_id = _get_tenant_id(x_tenant_id)
row = db.execute(text(""" row = db.execute(text("""
INSERT INTO compliance_quality_tests INSERT INTO compliance_quality_tests
@@ -334,10 +307,9 @@ async def update_test(
test_id: str, test_id: str,
payload: TestUpdate, payload: TestUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""Update a quality test.""" """Update a quality test."""
tenant_id = _get_tenant_id(x_tenant_id)
updates: Dict[str, Any] = {"id": test_id, "tenant_id": tenant_id, "updated_at": datetime.utcnow()} updates: Dict[str, Any] = {"id": test_id, "tenant_id": tenant_id, "updated_at": datetime.utcnow()}
set_clauses = ["updated_at = :updated_at"] set_clauses = ["updated_at = :updated_at"]
@@ -366,9 +338,8 @@ async def update_test(
async def delete_test( async def delete_test(
test_id: str, test_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
tenant_id = _get_tenant_id(x_tenant_id)
result = db.execute(text(""" result = db.execute(text("""
DELETE FROM compliance_quality_tests DELETE FROM compliance_quality_tests
WHERE id = :id AND tenant_id = :tenant_id WHERE id = :id AND tenant_id = :tenant_id

View File

@@ -13,19 +13,18 @@ import logging
from datetime import datetime from datetime import datetime
from typing import Optional, Any, Dict from typing import Optional, Any, Dict
from fastapi import APIRouter, Depends, HTTPException, Query, Header from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import text from sqlalchemy import text
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from uuid import UUID
from classroom_engine.database import get_db from classroom_engine.database import get_db
from .tenant_utils import get_tenant_id as _get_tenant_id
from .db_utils import row_to_dict as _row_to_dict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/security-backlog", tags=["security-backlog"]) router = APIRouter(prefix="/security-backlog", tags=["security-backlog"])
DEFAULT_TENANT_ID = "9282a473-5c95-4b3a-bf78-0ecc0ec71d3e"
# ============================================================================= # =============================================================================
# Pydantic Schemas # Pydantic Schemas
@@ -61,25 +60,6 @@ class SecurityItemUpdate(BaseModel):
remediation: Optional[str] = None remediation: Optional[str] = None
def _row_to_dict(row) -> Dict[str, Any]:
result = dict(row._mapping)
for key, val in result.items():
if isinstance(val, datetime):
result[key] = val.isoformat()
elif hasattr(val, '__str__') and not isinstance(val, (str, int, float, bool, list, dict, type(None))):
result[key] = str(val)
return result
def _get_tenant_id(x_tenant_id: Optional[str] = Header(None)) -> str:
if x_tenant_id:
try:
UUID(x_tenant_id)
return x_tenant_id
except ValueError:
pass
return DEFAULT_TENANT_ID
# ============================================================================= # =============================================================================
# Routes # Routes
@@ -94,10 +74,9 @@ async def list_security_items(
limit: int = Query(100, ge=1, le=500), limit: int = Query(100, ge=1, le=500),
offset: int = Query(0, ge=0), offset: int = Query(0, ge=0),
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""List security backlog items with optional filters.""" """List security backlog items with optional filters."""
tenant_id = _get_tenant_id(x_tenant_id)
where_clauses = ["tenant_id = :tenant_id"] where_clauses = ["tenant_id = :tenant_id"]
params: Dict[str, Any] = {"tenant_id": tenant_id, "limit": limit, "offset": offset} params: Dict[str, Any] = {"tenant_id": tenant_id, "limit": limit, "offset": offset}
@@ -155,10 +134,9 @@ async def list_security_items(
@router.get("/stats") @router.get("/stats")
async def get_security_stats( async def get_security_stats(
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""Return security backlog counts.""" """Return security backlog counts."""
tenant_id = _get_tenant_id(x_tenant_id)
rows = db.execute(text(""" rows = db.execute(text("""
SELECT SELECT
@@ -189,10 +167,9 @@ async def get_security_stats(
async def create_security_item( async def create_security_item(
payload: SecurityItemCreate, payload: SecurityItemCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""Create a new security backlog item.""" """Create a new security backlog item."""
tenant_id = _get_tenant_id(x_tenant_id)
row = db.execute(text(""" row = db.execute(text("""
INSERT INTO compliance_security_backlog INSERT INTO compliance_security_backlog
@@ -226,10 +203,9 @@ async def update_security_item(
item_id: str, item_id: str,
payload: SecurityItemUpdate, payload: SecurityItemUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
"""Update a security backlog item.""" """Update a security backlog item."""
tenant_id = _get_tenant_id(x_tenant_id)
updates: Dict[str, Any] = {"id": item_id, "tenant_id": tenant_id, "updated_at": datetime.utcnow()} updates: Dict[str, Any] = {"id": item_id, "tenant_id": tenant_id, "updated_at": datetime.utcnow()}
set_clauses = ["updated_at = :updated_at"] set_clauses = ["updated_at = :updated_at"]
@@ -258,9 +234,8 @@ async def update_security_item(
async def delete_security_item( async def delete_security_item(
item_id: str, item_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_tenant_id: Optional[str] = Header(None), tenant_id: str = Depends(_get_tenant_id),
): ):
tenant_id = _get_tenant_id(x_tenant_id)
result = db.execute(text(""" result = db.execute(text("""
DELETE FROM compliance_security_backlog DELETE FROM compliance_security_backlog
WHERE id = :id AND tenant_id = :tenant_id WHERE id = :id AND tenant_id = :tenant_id

View File

@@ -10,9 +10,13 @@ import logging
from datetime import datetime from datetime import datetime
from typing import Optional, List from typing import Optional, List
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy import text from sqlalchemy import text
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from classroom_engine.database import get_db
from .tenant_utils import get_tenant_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Table → FK column mapping # Table → FK column mapping
@@ -173,3 +177,50 @@ def get_version(
"approved_at": r[8].isoformat() if r[8] else None, "approved_at": r[8].isoformat() if r[8] else None,
"created_at": r[9].isoformat() if r[9] else None, "created_at": r[9].isoformat() if r[9] else None,
} }
def register_version_routes(
router: APIRouter,
doc_type: str,
id_param: str = "item_id",
resource_name: str = "Item",
):
"""Register GET /{id}/versions and GET /{id}/versions/{v} on an existing router.
Uses a standardized path param name `item_id` in the generated routes.
The actual URL path parameter can be customized via `id_param`.
Args:
router: The APIRouter to add version routes to
doc_type: One of the keys in VERSION_TABLES
id_param: Path parameter name in the URL (e.g. "obligation_id")
resource_name: Human-readable name for error messages
"""
# Capture doc_type and resource_name in closure
_doc_type = doc_type
_resource_name = resource_name
@router.get(f"/{{{id_param}}}/versions")
async def list_item_versions(
request: Request,
db: Session = Depends(get_db),
tid: str = Depends(get_tenant_id),
):
doc_id = request.path_params[id_param]
return list_versions(db, _doc_type, doc_id, tid)
@router.get(f"/{{{id_param}}}/versions/{{version_number}}")
async def get_item_version(
version_number: int,
request: Request,
db: Session = Depends(get_db),
tid: str = Depends(get_tenant_id),
):
doc_id = request.path_params[id_param]
v = get_version(db, _doc_type, doc_id, version_number, tid)
if not v:
raise HTTPException(
status_code=404,
detail=f"{_resource_name} version {version_number} not found",
)
return v

View File

@@ -11,9 +11,10 @@ from compliance.api.escalation_routes import (
EscalationCreate, EscalationCreate,
EscalationUpdate, EscalationUpdate,
EscalationStatusUpdate, EscalationStatusUpdate,
_row_to_dict,
DEFAULT_TENANT_ID,
) )
from compliance.api.db_utils import row_to_dict as _row_to_dict
DEFAULT_TENANT_ID = "9282a473-5c95-4b3a-bf78-0ecc0ec71d3e"
from classroom_engine.database import get_db from classroom_engine.database import get_db

View File

@@ -10,13 +10,13 @@ from fastapi import FastAPI
from compliance.api.legal_template_routes import ( from compliance.api.legal_template_routes import (
LegalTemplateCreate, LegalTemplateCreate,
LegalTemplateUpdate, LegalTemplateUpdate,
_row_to_dict,
_get_tenant_id,
DEFAULT_TENANT_ID,
VALID_DOCUMENT_TYPES, VALID_DOCUMENT_TYPES,
VALID_STATUSES, VALID_STATUSES,
router, router,
) )
from compliance.api.db_utils import row_to_dict as _row_to_dict
DEFAULT_TENANT_ID = "9282a473-5c95-4b3a-bf78-0ecc0ec71d3e"
from classroom_engine.database import get_db from classroom_engine.database import get_db
app = FastAPI() app = FastAPI()
@@ -205,22 +205,6 @@ class TestLegalTemplateDB:
assert isinstance(result["placeholders"], list) assert isinstance(result["placeholders"], list)
assert "{{COMPANY_NAME}}" in result["placeholders"] assert "{{COMPANY_NAME}}" in result["placeholders"]
def test_get_tenant_id_default(self):
"""_get_tenant_id returns default when no header provided."""
result = _get_tenant_id(None)
assert result == DEFAULT_TENANT_ID
def test_get_tenant_id_valid_uuid(self):
"""_get_tenant_id returns provided UUID when valid."""
custom_uuid = "12345678-1234-1234-1234-123456789abc"
result = _get_tenant_id(custom_uuid)
assert result == custom_uuid
def test_get_tenant_id_invalid_uuid(self):
"""_get_tenant_id falls back to default for invalid UUID."""
result = _get_tenant_id("not-a-uuid")
assert result == DEFAULT_TENANT_ID
# ============================================================================= # =============================================================================
# TestLegalTemplateSearch # TestLegalTemplateSearch

View File

@@ -10,12 +10,12 @@ from compliance.api.loeschfristen_routes import (
LoeschfristCreate, LoeschfristCreate,
LoeschfristUpdate, LoeschfristUpdate,
StatusUpdate, StatusUpdate,
_row_to_dict,
_get_tenant_id,
DEFAULT_TENANT_ID,
JSONB_FIELDS, JSONB_FIELDS,
router, router,
) )
from compliance.api.db_utils import row_to_dict as _row_to_dict
DEFAULT_TENANT_ID = "9282a473-5c95-4b3a-bf78-0ecc0ec71d3e"
app = FastAPI() app = FastAPI()
app.include_router(router) app.include_router(router)
@@ -128,16 +128,6 @@ class TestRowToDict:
assert result["retention_duration"] == 7 assert result["retention_duration"] == 7
class TestGetTenantId:
def test_valid_uuid_is_returned(self):
assert _get_tenant_id("9282a473-5c95-4b3a-bf78-0ecc0ec71d3e") == "9282a473-5c95-4b3a-bf78-0ecc0ec71d3e"
def test_invalid_uuid_returns_default(self):
assert _get_tenant_id("not-a-uuid") == DEFAULT_TENANT_ID
def test_none_returns_default(self):
assert _get_tenant_id(None) == DEFAULT_TENANT_ID
class TestJsonbFields: class TestJsonbFields:
def test_jsonb_fields_set(self): def test_jsonb_fields_set(self):

View File

@@ -12,10 +12,10 @@ from compliance.api.obligation_routes import (
ObligationCreate, ObligationCreate,
ObligationUpdate, ObligationUpdate,
ObligationStatusUpdate, ObligationStatusUpdate,
_row_to_dict,
_get_tenant_id,
DEFAULT_TENANT_ID,
) )
from compliance.api.db_utils import row_to_dict as _row_to_dict
DEFAULT_TENANT_ID = "9282a473-5c95-4b3a-bf78-0ecc0ec71d3e"
from classroom_engine.database import get_db from classroom_engine.database import get_db
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -308,37 +308,6 @@ class TestRowToDict:
assert result["flag"] is False assert result["flag"] is False
# =============================================================================
# Helper Tests — _get_tenant_id
# =============================================================================
class TestGetTenantId:
def test_valid_uuid_returned(self):
tenant_id = "9282a473-5c95-4b3a-bf78-0ecc0ec71d3e"
result = _get_tenant_id(x_tenant_id=tenant_id)
assert result == tenant_id
def test_different_valid_uuid(self):
tenant_id = "12345678-1234-1234-1234-123456789abc"
result = _get_tenant_id(x_tenant_id=tenant_id)
assert result == tenant_id
def test_none_returns_default(self):
result = _get_tenant_id(x_tenant_id=None)
assert result == DEFAULT_TENANT_ID
def test_invalid_uuid_returns_default(self):
result = _get_tenant_id(x_tenant_id="not-a-valid-uuid")
assert result == DEFAULT_TENANT_ID
def test_empty_string_returns_default(self):
result = _get_tenant_id(x_tenant_id="")
assert result == DEFAULT_TENANT_ID
def test_partial_uuid_returns_default(self):
result = _get_tenant_id(x_tenant_id="9282a473-5c95-4b3a")
assert result == DEFAULT_TENANT_ID
# ============================================================================= # =============================================================================
# Business Logic Tests # Business Logic Tests

View File

@@ -18,10 +18,10 @@ from compliance.api.quality_routes import (
MetricUpdate, MetricUpdate,
TestCreate, TestCreate,
TestUpdate, TestUpdate,
_row_to_dict,
_get_tenant_id,
DEFAULT_TENANT_ID,
) )
from compliance.api.db_utils import row_to_dict as _row_to_dict
DEFAULT_TENANT_ID = "9282a473-5c95-4b3a-bf78-0ecc0ec71d3e"
from classroom_engine.database import get_db from classroom_engine.database import get_db
# ============================================================================= # =============================================================================
@@ -283,31 +283,6 @@ class TestRowToDict:
assert result["count"] == 10 assert result["count"] == 10
# =============================================================================
# Helper Tests — _get_tenant_id
# =============================================================================
class TestGetTenantId:
def test_valid_uuid_returned(self):
result = _get_tenant_id(x_tenant_id=DEFAULT_TENANT)
assert result == DEFAULT_TENANT
def test_none_returns_default(self):
result = _get_tenant_id(x_tenant_id=None)
assert result == DEFAULT_TENANT_ID
def test_invalid_uuid_returns_default(self):
result = _get_tenant_id(x_tenant_id="invalid-uuid")
assert result == DEFAULT_TENANT_ID
def test_empty_string_returns_default(self):
result = _get_tenant_id(x_tenant_id="")
assert result == DEFAULT_TENANT_ID
def test_other_valid_tenant(self):
result = _get_tenant_id(x_tenant_id=OTHER_TENANT)
assert result == OTHER_TENANT
# ============================================================================= # =============================================================================
# HTTP Tests — GET /quality/stats # HTTP Tests — GET /quality/stats
@@ -910,19 +885,12 @@ class TestTenantIsolation:
resp_b = client.get("/quality/tests", headers={"X-Tenant-Id": OTHER_TENANT}) resp_b = client.get("/quality/tests", headers={"X-Tenant-Id": OTHER_TENANT})
assert resp_b.json()["total"] == 0 assert resp_b.json()["total"] == 0
def test_invalid_tenant_header_falls_back_to_default(self, mock_db): def test_invalid_tenant_header_returns_400(self, mock_db):
count_row = MagicMock()
count_row.__getitem__ = lambda self, i: 0
execute_result = MagicMock()
execute_result.fetchone.return_value = count_row
execute_result.fetchall.return_value = []
mock_db.execute.return_value = execute_result
response = client.get( response = client.get(
"/quality/metrics", "/quality/metrics",
headers={"X-Tenant-Id": "bad-uuid"}, headers={"X-Tenant-Id": "bad-uuid"},
) )
assert response.status_code == 200 assert response.status_code == 400
def test_delete_wrong_tenant_returns_404(self, mock_db): def test_delete_wrong_tenant_returns_404(self, mock_db):
"""Deleting a metric that belongs to a different tenant returns 404.""" """Deleting a metric that belongs to a different tenant returns 404."""

View File

@@ -16,10 +16,10 @@ from compliance.api.security_backlog_routes import (
router, router,
SecurityItemCreate, SecurityItemCreate,
SecurityItemUpdate, SecurityItemUpdate,
_row_to_dict,
_get_tenant_id,
DEFAULT_TENANT_ID,
) )
from compliance.api.db_utils import row_to_dict as _row_to_dict
DEFAULT_TENANT_ID = "9282a473-5c95-4b3a-bf78-0ecc0ec71d3e"
from classroom_engine.database import get_db from classroom_engine.database import get_db
# ============================================================================= # =============================================================================
@@ -241,35 +241,6 @@ class TestRowToDict:
assert result["active"] is True assert result["active"] is True
# =============================================================================
# Helper Tests — _get_tenant_id
# =============================================================================
class TestGetTenantId:
def test_valid_uuid_returned(self):
result = _get_tenant_id(x_tenant_id=DEFAULT_TENANT)
assert result == DEFAULT_TENANT
def test_none_returns_default(self):
result = _get_tenant_id(x_tenant_id=None)
assert result == DEFAULT_TENANT_ID
def test_invalid_uuid_returns_default(self):
result = _get_tenant_id(x_tenant_id="not-a-uuid")
assert result == DEFAULT_TENANT_ID
def test_empty_string_returns_default(self):
result = _get_tenant_id(x_tenant_id="")
assert result == DEFAULT_TENANT_ID
def test_different_valid_tenant(self):
result = _get_tenant_id(x_tenant_id=OTHER_TENANT)
assert result == OTHER_TENANT
def test_partial_uuid_returns_default(self):
result = _get_tenant_id(x_tenant_id="9282a473-5c95-4b3a")
assert result == DEFAULT_TENANT_ID
# ============================================================================= # =============================================================================
# HTTP Tests — GET /security-backlog # HTTP Tests — GET /security-backlog
@@ -657,21 +628,12 @@ class TestTenantIsolation:
assert resp_b.status_code == 200 assert resp_b.status_code == 200
assert resp_b.json()["total"] == 0 assert resp_b.json()["total"] == 0
def test_invalid_tenant_header_falls_back_to_default(self, mock_db): def test_invalid_tenant_header_returns_400(self, mock_db):
count_row = MagicMock()
count_row.__getitem__ = lambda self, i: 0
execute_result = MagicMock()
execute_result.fetchone.return_value = count_row
execute_result.fetchall.return_value = []
mock_db.execute.return_value = execute_result
response = client.get( response = client.get(
"/security-backlog", "/security-backlog",
headers={"X-Tenant-Id": "not-a-real-uuid"}, headers={"X-Tenant-Id": "not-a-real-uuid"},
) )
assert response.status_code == 200 assert response.status_code == 400
# Should succeed (falls back to DEFAULT_TENANT_ID)
assert "items" in response.json()
def test_create_uses_tenant_from_header(self, mock_db): def test_create_uses_tenant_from_header(self, mock_db):
created_row = make_item_row({"tenant_id": OTHER_TENANT}) created_row = make_item_row({"tenant_id": OTHER_TENANT})