Files
breakpilot-lehrer/klausur-service/backend/training_api.py
Benjamin Boenisch 5a31f52310 Initial commit: breakpilot-lehrer - Lehrer KI Platform
Services: Admin-Lehrer, Backend-Lehrer, Studio v2, Website,
Klausur-Service, School-Service, Voice-Service, Geo-Service,
BreakPilot Drive, Agent-Core

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 23:47:26 +01:00

626 lines
21 KiB
Python

"""
Training API - Endpoints for managing AI training jobs
Provides endpoints for:
- Starting/stopping training jobs
- Monitoring training progress
- Managing model versions
- Configuring training parameters
- SSE streaming for real-time metrics
Phase 2.2: Server-Sent Events for live progress
"""
import os
import json
import uuid
import asyncio
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any
from enum import Enum
from dataclasses import dataclass, field, asdict
from fastapi import APIRouter, HTTPException, BackgroundTasks, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
# ============================================================================
# ENUMS & MODELS
# ============================================================================
class TrainingStatus(str, Enum):
    """Lifecycle states of a training job.

    str-valued so the enum serializes directly in JSON responses.
    """
    QUEUED = "queued"          # created, waiting for the background task to pick it up
    PREPARING = "preparing"    # checked by create_training_job's "already running" guard
    TRAINING = "training"      # actively stepping; the simulator loops while in this state
    VALIDATING = "validating"
    COMPLETED = "completed"    # terminal: finished all epochs
    FAILED = "failed"          # terminal
    PAUSED = "paused"          # stopped by /pause; resumable via /resume
    CANCELLED = "cancelled"    # terminal: stopped by /cancel
class ModelType(str, Enum):
    """Kind of model a training job produces (str-valued for JSON)."""
    ZEUGNIS = "zeugnis"    # report-card generation model
    KLAUSUR = "klausur"    # exam-related model
    GENERAL = "general"
# Request/Response Models
class TrainingConfig(BaseModel):
    """Configuration for a training job.

    Hyperparameter bounds are enforced via pydantic Field constraints;
    out-of-range values are rejected with a 422 at the API boundary.
    """
    name: str = Field(..., description="Name for the training job")
    model_type: ModelType = Field(ModelType.ZEUGNIS, description="Type of model to train")
    bundeslaender: List[str] = Field(..., description="List of Bundesland codes to include")
    batch_size: int = Field(16, ge=1, le=128)
    learning_rate: float = Field(0.00005, ge=0.000001, le=0.1)
    epochs: int = Field(10, ge=1, le=100)
    warmup_steps: int = Field(500, ge=0, le=10000)
    weight_decay: float = Field(0.01, ge=0, le=1)
    # Effective batch size = batch_size * gradient_accumulation.
    gradient_accumulation: int = Field(4, ge=1, le=32)
    mixed_precision: bool = Field(True, description="Use FP16 mixed precision training")
class TrainingMetrics(BaseModel):
    """Metrics from a training job.

    precision/recall/f1_score/accuracy are point-in-time values; the two
    history lists carry the most recent loss-curve samples for charting.
    """
    precision: float = 0.0
    recall: float = 0.0
    f1_score: float = 0.0
    accuracy: float = 0.0
    # default_factory instead of a bare [] literal: pydantic does copy
    # mutable defaults per instance, but the explicit factory is the
    # unambiguous, lint-clean spelling and stays correct across versions.
    loss_history: List[float] = Field(default_factory=list)
    val_loss_history: List[float] = Field(default_factory=list)
class TrainingJob(BaseModel):
    """A training job with full details.

    NOTE(review): endpoints currently store/return plain dicts mirroring
    this shape rather than instances of this model — keep in sync.
    """
    id: str                                   # UUID string
    name: str
    model_type: ModelType
    status: TrainingStatus
    progress: float                           # 0-100 percent
    current_epoch: int
    total_epochs: int
    loss: float                               # latest training loss
    val_loss: float                           # latest validation loss
    learning_rate: float
    documents_processed: int
    total_documents: int
    started_at: Optional[datetime]            # None until training actually starts
    estimated_completion: Optional[datetime]  # recomputed while running
    completed_at: Optional[datetime]          # set on completion/cancel
    error_message: Optional[str]
    metrics: TrainingMetrics
    config: TrainingConfig
class ModelVersion(BaseModel):
    """A trained model version produced by a completed training job."""
    id: str                      # UUID string
    job_id: str                  # id of the TrainingJob that produced it
    version: str                 # human-readable tag, e.g. "v3.0"
    model_type: ModelType
    created_at: datetime
    metrics: TrainingMetrics     # final metrics snapshot from the job
    is_active: bool              # at most one active version per model_type
    size_mb: float
    bundeslaender: List[str]     # Bundesland codes the model was trained on
class DatasetStats(BaseModel):
    """Statistics about the training dataset."""
    total_documents: int
    total_chunks: int           # estimated (~12 chunks per document, see endpoint)
    training_allowed: int       # documents cleared for training use
    by_bundesland: Dict[str, int]
    by_doc_type: Dict[str, int]
# ============================================================================
# IN-MEMORY STATE (Replace with database in production)
# ============================================================================
@dataclass
class TrainingState:
    """Global training state.

    In-memory only — lost on restart; replace with a database in production.
    """
    jobs: Dict[str, dict] = field(default_factory=dict)            # job_id -> job dict
    model_versions: Dict[str, dict] = field(default_factory=dict)  # version_id -> version dict
    active_job_id: Optional[str] = None                            # id of the one running job, if any
# Module-level singleton shared by all endpoints in this router.
_state = TrainingState()
# ============================================================================
# HELPER FUNCTIONS
# ============================================================================
async def simulate_training_progress(job_id: str):
    """Drive a simulated training run for *job_id* (replace with real logic).

    Advances the job through 100 simulated steps per epoch, updating
    loss/metrics every 0.5s, and registers a model version on completion.
    Stops cooperatively when the job's status is flipped away from
    TRAINING (the pause/cancel endpoints do exactly that).

    Fix: resuming a paused job previously restarted at step 0 and
    re-stamped started_at, discarding all prior progress; we now resume
    from the recorded progress and stamp started_at only once.
    """
    job = _state.jobs.get(job_id)
    if job is None:
        return
    job["status"] = TrainingStatus.TRAINING.value
    # Preserve the original start timestamp across pause/resume cycles.
    if not job.get("started_at"):
        job["started_at"] = datetime.now().isoformat()
    total_steps = job["total_epochs"] * 100  # Simulate 100 steps per epoch
    # Resume from recorded progress instead of restarting from scratch.
    current_step = int(round((job.get("progress") or 0) / 100 * total_steps))
    while current_step < total_steps and job["status"] == TrainingStatus.TRAINING.value:
        progress = (current_step / total_steps) * 100
        current_epoch = current_step // 100 + 1
        # Simulated loss: linear decay toward 0.1 plus a small per-step wobble.
        base_loss = 0.8 * (1 - progress / 100) + 0.1
        loss = base_loss + (0.05 * (0.5 - (current_step % 100) / 100))
        val_loss = loss * 1.1
        # Update job state
        job["progress"] = progress
        job["current_epoch"] = min(current_epoch, job["total_epochs"])
        job["loss"] = round(loss, 4)
        job["val_loss"] = round(val_loss, 4)
        job["documents_processed"] = int((progress / 100) * job["total_documents"])
        # Metrics trend upward with progress (pure simulation values).
        job["metrics"]["loss_history"].append(round(loss, 4))
        job["metrics"]["val_loss_history"].append(round(val_loss, 4))
        job["metrics"]["precision"] = round(0.5 + (progress / 200), 3)
        job["metrics"]["recall"] = round(0.45 + (progress / 200), 3)
        job["metrics"]["f1_score"] = round(0.47 + (progress / 200), 3)
        job["metrics"]["accuracy"] = round(0.6 + (progress / 250), 3)
        # Bound memory: keep only the 50 most recent history points.
        if len(job["metrics"]["loss_history"]) > 50:
            job["metrics"]["loss_history"] = job["metrics"]["loss_history"][-50:]
            job["metrics"]["val_loss_history"] = job["metrics"]["val_loss_history"][-50:]
        # ETA from average speed so far. NOTE(review): after a resume the
        # elapsed window includes paused time, so the estimate is
        # pessimistic — acceptable for a simulation.
        if progress > 0:
            elapsed = (datetime.now() - datetime.fromisoformat(job["started_at"])).total_seconds()
            remaining = (elapsed / progress) * (100 - progress)
            job["estimated_completion"] = (datetime.now() + timedelta(seconds=remaining)).isoformat()
        current_step += 1
        await asyncio.sleep(0.5)  # Simulate work
    # Only a natural finish (still TRAINING) counts as completion; a
    # paused or cancelled job keeps its interrupted state.
    if job["status"] == TrainingStatus.TRAINING.value:
        job["status"] = TrainingStatus.COMPLETED.value
        job["progress"] = 100
        job["completed_at"] = datetime.now().isoformat()
        # Register the resulting model version.
        version_id = str(uuid.uuid4())
        _state.model_versions[version_id] = {
            "id": version_id,
            "job_id": job_id,
            "version": f"v{len(_state.model_versions) + 1}.0",
            "model_type": job["model_type"],
            "created_at": datetime.now().isoformat(),
            # NOTE(review): shares the job's metrics dict by reference —
            # fine once the job is finished and no longer mutated.
            "metrics": job["metrics"],
            "is_active": True,
            "size_mb": 245.7,
            "bundeslaender": job["config"]["bundeslaender"],
        }
        _state.active_job_id = None
# ============================================================================
# ROUTER
# ============================================================================
# All training endpoints below are mounted under /api/v1/admin/training.
router = APIRouter(prefix="/api/v1/admin/training", tags=["Training"])
@router.get("/jobs", response_model=List[dict])
async def list_training_jobs():
    """Return every known training job (running, finished, or cancelled)."""
    return [job for job in _state.jobs.values()]
@router.get("/jobs/{job_id}", response_model=dict)
async def get_training_job(job_id: str):
    """Get details for a specific training job."""
    # EAFP: single dict lookup, 404 on miss.
    try:
        return _state.jobs[job_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Job not found")
@router.post("/jobs", response_model=dict)
async def create_training_job(config: TrainingConfig, background_tasks: BackgroundTasks):
    """Create and start a new training job.

    Only one job may be preparing/training at a time; a conflicting
    request is rejected with 409. The job runs as a background task.
    """
    # Guard: refuse while another job is still preparing or training.
    busy_states = (TrainingStatus.TRAINING.value, TrainingStatus.PREPARING.value)
    if _state.active_job_id:
        current = _state.jobs.get(_state.active_job_id)
        if current and current["status"] in busy_states:
            raise HTTPException(
                status_code=409,
                detail="Another training job is already running"
            )
    # Assemble the new job record.
    job_id = str(uuid.uuid4())
    fresh_metrics = {
        "precision": 0.0,
        "recall": 0.0,
        "f1_score": 0.0,
        "accuracy": 0.0,
        "loss_history": [],
        "val_loss_history": [],
    }
    _state.jobs[job_id] = {
        "id": job_id,
        "name": config.name,
        "model_type": config.model_type.value,
        "status": TrainingStatus.QUEUED.value,
        "progress": 0,
        "current_epoch": 0,
        "total_epochs": config.epochs,
        "loss": 1.0,
        "val_loss": 1.0,
        "learning_rate": config.learning_rate,
        "documents_processed": 0,
        "total_documents": len(config.bundeslaender) * 50,  # Estimate
        "started_at": None,
        "estimated_completion": None,
        "completed_at": None,
        "error_message": None,
        "metrics": fresh_metrics,
        "config": config.dict(),
    }
    _state.active_job_id = job_id
    # Kick off (simulated) training without blocking the response.
    background_tasks.add_task(simulate_training_progress, job_id)
    return {"id": job_id, "status": "queued", "message": "Training job created"}
@router.post("/jobs/{job_id}/pause", response_model=dict)
async def pause_training_job(job_id: str):
    """Pause a running training job."""
    job = _state.jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    if job["status"] != TrainingStatus.TRAINING.value:
        raise HTTPException(status_code=400, detail="Job is not running")
    # The background simulator polls this flag and exits its loop when
    # the status is no longer TRAINING, so flipping it suffices.
    job["status"] = TrainingStatus.PAUSED.value
    return {"success": True, "message": "Training paused"}
@router.post("/jobs/{job_id}/resume", response_model=dict)
async def resume_training_job(job_id: str, background_tasks: BackgroundTasks):
    """Resume a paused training job."""
    job = _state.jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    if job["status"] != TrainingStatus.PAUSED.value:
        raise HTTPException(status_code=400, detail="Job is not paused")
    job["status"] = TrainingStatus.TRAINING.value
    _state.active_job_id = job_id
    # Relaunch the background driver for this job.
    background_tasks.add_task(simulate_training_progress, job_id)
    return {"success": True, "message": "Training resumed"}
@router.post("/jobs/{job_id}/cancel", response_model=dict)
async def cancel_training_job(job_id: str):
    """Cancel a training job (terminal; cannot be resumed)."""
    try:
        job = _state.jobs[job_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Job not found")
    job["status"] = TrainingStatus.CANCELLED.value
    job["completed_at"] = datetime.now().isoformat()
    # Free the single active-job slot if this job held it.
    if _state.active_job_id == job_id:
        _state.active_job_id = None
    return {"success": True, "message": "Training cancelled"}
@router.delete("/jobs/{job_id}", response_model=dict)
async def delete_training_job(job_id: str):
    """Delete a training job (must not be running)."""
    if job_id not in _state.jobs:
        raise HTTPException(status_code=404, detail="Job not found")
    if _state.jobs[job_id]["status"] == TrainingStatus.TRAINING.value:
        raise HTTPException(status_code=400, detail="Cannot delete running job")
    _state.jobs.pop(job_id)
    return {"success": True, "message": "Job deleted"}
# ============================================================================
# MODEL VERSIONS
# ============================================================================
@router.get("/models", response_model=List[dict])
async def list_model_versions():
    """Return every trained model version."""
    return [version for version in _state.model_versions.values()]
@router.get("/models/{version_id}", response_model=dict)
async def get_model_version(version_id: str):
    """Get details for a specific model version."""
    try:
        return _state.model_versions[version_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Model version not found")
@router.post("/models/{version_id}/activate", response_model=dict)
async def activate_model_version(version_id: str):
    """Set a model version as the active one for its model type."""
    if version_id not in _state.model_versions:
        raise HTTPException(status_code=404, detail="Model version not found")
    target = _state.model_versions[version_id]
    target_type = target["model_type"]
    # At most one active version per type: clear the flag on every
    # version of this type (including the target), then set the target.
    for candidate in _state.model_versions.values():
        if candidate["model_type"] == target_type:
            candidate["is_active"] = False
    target["is_active"] = True
    return {"success": True, "message": "Model activated"}
@router.delete("/models/{version_id}", response_model=dict)
async def delete_model_version(version_id: str):
    """Delete a model version (the active version is protected)."""
    try:
        model = _state.model_versions[version_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Model version not found")
    if model["is_active"]:
        raise HTTPException(status_code=400, detail="Cannot delete active model")
    _state.model_versions.pop(version_id)
    return {"success": True, "message": "Model deleted"}
# ============================================================================
# DATASET STATS
# ============================================================================
@router.get("/dataset/stats", response_model=dict)
async def get_dataset_stats():
    """Get statistics about the training dataset.

    Document counts come from the zeugnis metrics store; chunk counts and
    the by_doc_type breakdown are estimates/placeholders.
    """
    # Local import keeps the metrics dependency off the module import path.
    from metrics_db import get_zeugnis_stats
    stats = await get_zeugnis_stats()
    doc_total = stats.get("total_documents", 0)
    per_bundesland = {
        entry["bundesland"]: entry.get("doc_count", 0)
        for entry in stats.get("per_bundesland", [])
    }
    return {
        "total_documents": doc_total,
        "total_chunks": doc_total * 12,  # Estimate ~12 chunks per doc
        "training_allowed": stats.get("training_allowed_documents", 0),
        "by_bundesland": per_bundesland,
        # Hard-coded placeholder breakdown — TODO confirm against real data.
        "by_doc_type": {
            "verordnung": 150,
            "schulordnung": 80,
            "handreichung": 45,
            "erlass": 30,
        },
    }
# ============================================================================
# TRAINING STATUS
# ============================================================================
@router.get("/status", response_model=dict)
async def get_training_status():
    """Get overall training system status (counts and active-job flag)."""
    active_job = _state.jobs.get(_state.active_job_id) if _state.active_job_id else None
    # Tally terminal outcomes in a single pass over the jobs.
    completed = 0
    failed = 0
    for entry in _state.jobs.values():
        if entry["status"] == TrainingStatus.COMPLETED.value:
            completed += 1
        elif entry["status"] == TrainingStatus.FAILED.value:
            failed += 1
    active_models = sum(1 for m in _state.model_versions.values() if m["is_active"])
    return {
        "is_training": active_job is not None
        and active_job["status"] == TrainingStatus.TRAINING.value,
        "active_job_id": _state.active_job_id,
        "total_jobs": len(_state.jobs),
        "completed_jobs": completed,
        "failed_jobs": failed,
        "model_versions": len(_state.model_versions),
        "active_models": active_models,
    }
# ============================================================================
# SERVER-SENT EVENTS (SSE) ENDPOINTS
# ============================================================================
async def training_metrics_generator(job_id: str, request: Request):
    """
    SSE generator for streaming training metrics.

    Yields JSON-encoded training status updates every 500ms, formatted as
    SSE "data: ...\n\n" frames. Terminates when the client disconnects,
    the job disappears, or the job reaches a terminal state.
    """
    while True:
        # Stop streaming as soon as the client goes away.
        if await request.is_disconnected():
            break
        # Job may have been deleted mid-stream; tell the client and stop.
        if job_id not in _state.jobs:
            yield f"data: {json.dumps({'error': 'Job not found'})}\n\n"
            break
        job = _state.jobs[job_id]
        # Build metrics response
        metrics_data = {
            "job_id": job["id"],
            "status": job["status"],
            "progress": job["progress"],
            "current_epoch": job["current_epoch"],
            "total_epochs": job["total_epochs"],
            # progress (0-100) * epochs equals the simulator's absolute
            # step, because total_steps = epochs * 100.
            "current_step": int(job["progress"] * job["total_epochs"]),
            "total_steps": job["total_epochs"] * 100,
            "elapsed_time_ms": 0,        # filled in below when started_at is set
            "estimated_remaining_ms": 0, # filled in below when an ETA exists
            "metrics": {
                "loss": job["loss"],
                "val_loss": job["val_loss"],
                "accuracy": job["metrics"]["accuracy"],
                "learning_rate": job["learning_rate"]
            },
            # Last 50 loss samples; "epoch"/"step" here are per-sample
            # indices for charting, not true training epochs/steps.
            "history": [
                {
                    "epoch": i + 1,
                    "step": (i + 1) * 10,
                    "loss": loss,
                    "val_loss": job["metrics"]["val_loss_history"][i] if i < len(job["metrics"]["val_loss_history"]) else None,
                    "learning_rate": job["learning_rate"],
                    "timestamp": 0
                }
                for i, loss in enumerate(job["metrics"]["loss_history"][-50:])
            ]
        }
        # Calculate elapsed time since the job actually started.
        if job["started_at"]:
            started = datetime.fromisoformat(job["started_at"])
            metrics_data["elapsed_time_ms"] = int((datetime.now() - started).total_seconds() * 1000)
        # Calculate remaining time, clamped at zero for overdue estimates.
        if job["estimated_completion"]:
            estimated = datetime.fromisoformat(job["estimated_completion"])
            metrics_data["estimated_remaining_ms"] = max(0, int((estimated - datetime.now()).total_seconds() * 1000))
        # Send SSE event
        yield f"data: {json.dumps(metrics_data)}\n\n"
        # Emit one final frame for a terminal state, then close the stream.
        if job["status"] in [TrainingStatus.COMPLETED.value, TrainingStatus.FAILED.value, TrainingStatus.CANCELLED.value]:
            break
        # Wait before next update
        await asyncio.sleep(0.5)
@router.get("/metrics/stream")
async def stream_training_metrics(job_id: str, request: Request):
    """
    SSE endpoint for streaming real-time training progress of one job.

    Frontend usage:
        const eventSource = new EventSource('/api/v1/admin/training/metrics/stream?job_id=xxx')
        eventSource.onmessage = (event) => {
            const data = JSON.parse(event.data)
            console.log(data.progress, data.metrics.loss)
        }
    """
    if job_id not in _state.jobs:
        raise HTTPException(status_code=404, detail="Job not found")
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",  # Disable nginx buffering
    }
    return StreamingResponse(
        training_metrics_generator(job_id, request),
        media_type="text/event-stream",
        headers=sse_headers,
    )
async def batch_ocr_progress_generator(images_count: int, request: Request):
    """
    SSE generator simulating batch OCR progress.

    Emits one 'progress' frame per image and a final 'complete' frame.
    In production, this would integrate with actual OCR processing.
    """
    import random
    for index in range(images_count):
        # Stop early if the client disconnected.
        if await request.is_disconnected():
            break
        # Simulated per-image processing latency.
        await asyncio.sleep(random.uniform(0.3, 0.8))
        done = index + 1
        event = {
            "type": "progress",
            "current": done,
            "total": images_count,
            "progress_percent": (done / images_count) * 100,
            # Timing fields assume a nominal 500ms per image.
            "elapsed_ms": done * 500,
            "estimated_remaining_ms": (images_count - done) * 500,
            "result": {
                "text": f"Sample recognized text for image {done}",
                "confidence": round(random.uniform(0.7, 0.98), 2),
                "processing_time_ms": random.randint(200, 600),
                "from_cache": random.random() < 0.2
            }
        }
        yield f"data: {json.dumps(event)}\n\n"
    # Final summary frame.
    summary = {'type': 'complete', 'total_time_ms': images_count * 500, 'processed_count': images_count}
    yield f"data: {json.dumps(summary)}\n\n"
@router.get("/ocr/stream")
async def stream_batch_ocr(images_count: int, request: Request):
    """
    SSE endpoint for streaming batch OCR progress (simulated).

    In production, integrate with actual TrOCR batch processing.

    Args:
        images_count: Number of images to process (1-100).

    Usage:
        const eventSource = new EventSource('/api/v1/admin/training/ocr/stream?images_count=10')
        eventSource.onmessage = (event) => {
            const data = JSON.parse(event.data)
            if (data.type === 'progress') {
                console.log(`${data.current}/${data.total}`)
            }
        }
    """
    # Bound the batch size to keep the simulated stream short.
    if not 1 <= images_count <= 100:
        raise HTTPException(status_code=400, detail="images_count must be between 1 and 100")
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        batch_ocr_progress_generator(images_count, request),
        media_type="text/event-stream",
        headers=sse_headers,
    )