""" Training API - Endpoints for managing AI training jobs Provides endpoints for: - Starting/stopping training jobs - Monitoring training progress - Managing model versions - Configuring training parameters - SSE streaming for real-time metrics Phase 2.2: Server-Sent Events for live progress """ import os import json import uuid import asyncio from datetime import datetime, timedelta from typing import Optional, List, Dict, Any from enum import Enum from dataclasses import dataclass, field, asdict from fastapi import APIRouter, HTTPException, BackgroundTasks, Request from fastapi.responses import StreamingResponse from pydantic import BaseModel, Field # ============================================================================ # ENUMS & MODELS # ============================================================================ class TrainingStatus(str, Enum): QUEUED = "queued" PREPARING = "preparing" TRAINING = "training" VALIDATING = "validating" COMPLETED = "completed" FAILED = "failed" PAUSED = "paused" CANCELLED = "cancelled" class ModelType(str, Enum): ZEUGNIS = "zeugnis" KLAUSUR = "klausur" GENERAL = "general" # Request/Response Models class TrainingConfig(BaseModel): """Configuration for a training job.""" name: str = Field(..., description="Name for the training job") model_type: ModelType = Field(ModelType.ZEUGNIS, description="Type of model to train") bundeslaender: List[str] = Field(..., description="List of Bundesland codes to include") batch_size: int = Field(16, ge=1, le=128) learning_rate: float = Field(0.00005, ge=0.000001, le=0.1) epochs: int = Field(10, ge=1, le=100) warmup_steps: int = Field(500, ge=0, le=10000) weight_decay: float = Field(0.01, ge=0, le=1) gradient_accumulation: int = Field(4, ge=1, le=32) mixed_precision: bool = Field(True, description="Use FP16 mixed precision training") class TrainingMetrics(BaseModel): """Metrics from a training job.""" precision: float = 0.0 recall: float = 0.0 f1_score: float = 0.0 accuracy: float = 0.0 
loss_history: List[float] = [] val_loss_history: List[float] = [] class TrainingJob(BaseModel): """A training job with full details.""" id: str name: str model_type: ModelType status: TrainingStatus progress: float current_epoch: int total_epochs: int loss: float val_loss: float learning_rate: float documents_processed: int total_documents: int started_at: Optional[datetime] estimated_completion: Optional[datetime] completed_at: Optional[datetime] error_message: Optional[str] metrics: TrainingMetrics config: TrainingConfig class ModelVersion(BaseModel): """A trained model version.""" id: str job_id: str version: str model_type: ModelType created_at: datetime metrics: TrainingMetrics is_active: bool size_mb: float bundeslaender: List[str] class DatasetStats(BaseModel): """Statistics about the training dataset.""" total_documents: int total_chunks: int training_allowed: int by_bundesland: Dict[str, int] by_doc_type: Dict[str, int] # ============================================================================ # IN-MEMORY STATE (Replace with database in production) # ============================================================================ @dataclass class TrainingState: """Global training state.""" jobs: Dict[str, dict] = field(default_factory=dict) model_versions: Dict[str, dict] = field(default_factory=dict) active_job_id: Optional[str] = None _state = TrainingState() # ============================================================================ # HELPER FUNCTIONS # ============================================================================ async def simulate_training_progress(job_id: str): """Simulate training progress (replace with actual training logic).""" global _state if job_id not in _state.jobs: return job = _state.jobs[job_id] job["status"] = TrainingStatus.TRAINING.value job["started_at"] = datetime.now().isoformat() total_steps = job["total_epochs"] * 100 # Simulate 100 steps per epoch current_step = 0 while current_step < total_steps and 
job["status"] == TrainingStatus.TRAINING.value: # Update progress progress = (current_step / total_steps) * 100 current_epoch = current_step // 100 + 1 # Simulate decreasing loss base_loss = 0.8 * (1 - progress / 100) + 0.1 loss = base_loss + (0.05 * (0.5 - (current_step % 100) / 100)) val_loss = loss * 1.1 # Update job state job["progress"] = progress job["current_epoch"] = min(current_epoch, job["total_epochs"]) job["loss"] = round(loss, 4) job["val_loss"] = round(val_loss, 4) job["documents_processed"] = int((progress / 100) * job["total_documents"]) # Update metrics job["metrics"]["loss_history"].append(round(loss, 4)) job["metrics"]["val_loss_history"].append(round(val_loss, 4)) job["metrics"]["precision"] = round(0.5 + (progress / 200), 3) job["metrics"]["recall"] = round(0.45 + (progress / 200), 3) job["metrics"]["f1_score"] = round(0.47 + (progress / 200), 3) job["metrics"]["accuracy"] = round(0.6 + (progress / 250), 3) # Keep only last 50 history points if len(job["metrics"]["loss_history"]) > 50: job["metrics"]["loss_history"] = job["metrics"]["loss_history"][-50:] job["metrics"]["val_loss_history"] = job["metrics"]["val_loss_history"][-50:] # Estimate completion if progress > 0: elapsed = (datetime.now() - datetime.fromisoformat(job["started_at"])).total_seconds() remaining = (elapsed / progress) * (100 - progress) job["estimated_completion"] = (datetime.now() + timedelta(seconds=remaining)).isoformat() current_step += 1 await asyncio.sleep(0.5) # Simulate work # Mark as completed if job["status"] == TrainingStatus.TRAINING.value: job["status"] = TrainingStatus.COMPLETED.value job["progress"] = 100 job["completed_at"] = datetime.now().isoformat() # Create model version version_id = str(uuid.uuid4()) _state.model_versions[version_id] = { "id": version_id, "job_id": job_id, "version": f"v{len(_state.model_versions) + 1}.0", "model_type": job["model_type"], "created_at": datetime.now().isoformat(), "metrics": job["metrics"], "is_active": True, "size_mb": 
245.7, "bundeslaender": job["config"]["bundeslaender"], } _state.active_job_id = None # ============================================================================ # ROUTER # ============================================================================ router = APIRouter(prefix="/api/v1/admin/training", tags=["Training"]) @router.get("/jobs", response_model=List[dict]) async def list_training_jobs(): """Get all training jobs.""" return list(_state.jobs.values()) @router.get("/jobs/{job_id}", response_model=dict) async def get_training_job(job_id: str): """Get details for a specific training job.""" if job_id not in _state.jobs: raise HTTPException(status_code=404, detail="Job not found") return _state.jobs[job_id] @router.post("/jobs", response_model=dict) async def create_training_job(config: TrainingConfig, background_tasks: BackgroundTasks): """Create and start a new training job.""" global _state # Check if there's already an active job if _state.active_job_id: active_job = _state.jobs.get(_state.active_job_id) if active_job and active_job["status"] in [ TrainingStatus.TRAINING.value, TrainingStatus.PREPARING.value, ]: raise HTTPException( status_code=409, detail="Another training job is already running" ) # Create job job_id = str(uuid.uuid4()) job = { "id": job_id, "name": config.name, "model_type": config.model_type.value, "status": TrainingStatus.QUEUED.value, "progress": 0, "current_epoch": 0, "total_epochs": config.epochs, "loss": 1.0, "val_loss": 1.0, "learning_rate": config.learning_rate, "documents_processed": 0, "total_documents": len(config.bundeslaender) * 50, # Estimate "started_at": None, "estimated_completion": None, "completed_at": None, "error_message": None, "metrics": { "precision": 0.0, "recall": 0.0, "f1_score": 0.0, "accuracy": 0.0, "loss_history": [], "val_loss_history": [], }, "config": config.dict(), } _state.jobs[job_id] = job _state.active_job_id = job_id # Start training in background 
background_tasks.add_task(simulate_training_progress, job_id) return {"id": job_id, "status": "queued", "message": "Training job created"} @router.post("/jobs/{job_id}/pause", response_model=dict) async def pause_training_job(job_id: str): """Pause a running training job.""" if job_id not in _state.jobs: raise HTTPException(status_code=404, detail="Job not found") job = _state.jobs[job_id] if job["status"] != TrainingStatus.TRAINING.value: raise HTTPException(status_code=400, detail="Job is not running") job["status"] = TrainingStatus.PAUSED.value return {"success": True, "message": "Training paused"} @router.post("/jobs/{job_id}/resume", response_model=dict) async def resume_training_job(job_id: str, background_tasks: BackgroundTasks): """Resume a paused training job.""" if job_id not in _state.jobs: raise HTTPException(status_code=404, detail="Job not found") job = _state.jobs[job_id] if job["status"] != TrainingStatus.PAUSED.value: raise HTTPException(status_code=400, detail="Job is not paused") job["status"] = TrainingStatus.TRAINING.value _state.active_job_id = job_id background_tasks.add_task(simulate_training_progress, job_id) return {"success": True, "message": "Training resumed"} @router.post("/jobs/{job_id}/cancel", response_model=dict) async def cancel_training_job(job_id: str): """Cancel a training job.""" if job_id not in _state.jobs: raise HTTPException(status_code=404, detail="Job not found") job = _state.jobs[job_id] job["status"] = TrainingStatus.CANCELLED.value job["completed_at"] = datetime.now().isoformat() if _state.active_job_id == job_id: _state.active_job_id = None return {"success": True, "message": "Training cancelled"} @router.delete("/jobs/{job_id}", response_model=dict) async def delete_training_job(job_id: str): """Delete a training job.""" if job_id not in _state.jobs: raise HTTPException(status_code=404, detail="Job not found") job = _state.jobs[job_id] if job["status"] == TrainingStatus.TRAINING.value: raise 
HTTPException(status_code=400, detail="Cannot delete running job") del _state.jobs[job_id] return {"success": True, "message": "Job deleted"} # ============================================================================ # MODEL VERSIONS # ============================================================================ @router.get("/models", response_model=List[dict]) async def list_model_versions(): """Get all trained model versions.""" return list(_state.model_versions.values()) @router.get("/models/{version_id}", response_model=dict) async def get_model_version(version_id: str): """Get details for a specific model version.""" if version_id not in _state.model_versions: raise HTTPException(status_code=404, detail="Model version not found") return _state.model_versions[version_id] @router.post("/models/{version_id}/activate", response_model=dict) async def activate_model_version(version_id: str): """Set a model version as active.""" if version_id not in _state.model_versions: raise HTTPException(status_code=404, detail="Model version not found") # Deactivate all other versions of same type model = _state.model_versions[version_id] for v in _state.model_versions.values(): if v["model_type"] == model["model_type"]: v["is_active"] = False model["is_active"] = True return {"success": True, "message": "Model activated"} @router.delete("/models/{version_id}", response_model=dict) async def delete_model_version(version_id: str): """Delete a model version.""" if version_id not in _state.model_versions: raise HTTPException(status_code=404, detail="Model version not found") model = _state.model_versions[version_id] if model["is_active"]: raise HTTPException(status_code=400, detail="Cannot delete active model") del _state.model_versions[version_id] return {"success": True, "message": "Model deleted"} # ============================================================================ # DATASET STATS # ============================================================================ 
@router.get("/dataset/stats", response_model=dict)
async def get_dataset_stats():
    """
    Get statistics about the training dataset.

    Document counts come from ``metrics_db.get_zeugnis_stats``; the chunk
    count is an estimate and the per-document-type numbers are fixed values.
    """
    # Get stats from zeugnis sources (imported here rather than at module
    # level, so the dependency is only needed when this endpoint is called)
    from metrics_db import get_zeugnis_stats

    zeugnis_stats = await get_zeugnis_stats()
    total_documents = zeugnis_stats.get("total_documents", 0)

    # NOTE(review): each per_bundesland entry is assumed to carry a
    # "bundesland" key - confirm against metrics_db.
    by_bundesland = {
        entry["bundesland"]: entry.get("doc_count", 0)
        for entry in zeugnis_stats.get("per_bundesland", [])
    }

    return {
        "total_documents": total_documents,
        "total_chunks": total_documents * 12,  # Estimate ~12 chunks per doc
        "training_allowed": zeugnis_stats.get("training_allowed_documents", 0),
        "by_bundesland": by_bundesland,
        "by_doc_type": {
            "verordnung": 150,
            "schulordnung": 80,
            "handreichung": 45,
            "erlass": 30,
        },
    }


# ============================================================================
# TRAINING STATUS
# ============================================================================


@router.get("/status", response_model=dict)
async def get_training_status():
    """Get overall training system status (job and model version counters)."""
    active_job = None
    if _state.active_job_id and _state.active_job_id in _state.jobs:
        active_job = _state.jobs[_state.active_job_id]

    jobs = list(_state.jobs.values())
    is_training = (
        _state.active_job_id is not None
        and active_job is not None
        and active_job["status"] == TrainingStatus.TRAINING.value
    )

    return {
        "is_training": is_training,
        "active_job_id": _state.active_job_id,
        "total_jobs": len(jobs),
        "completed_jobs": sum(
            1 for j in jobs if j["status"] == TrainingStatus.COMPLETED.value
        ),
        "failed_jobs": sum(
            1 for j in jobs if j["status"] == TrainingStatus.FAILED.value
        ),
        "model_versions": len(_state.model_versions),
        "active_models": sum(1 for m in _state.model_versions.values() if m["is_active"]),
    }


# ============================================================================
# SERVER-SENT EVENTS (SSE) ENDPOINTS
# ============================================================================


async def training_metrics_generator(job_id: str, request: Request):
    """
    SSE generator for streaming training metrics.

    Yields JSON-encoded training status updates every 500ms. The stream ends
    when the client disconnects, when the job disappears (after one error
    event) or after the event that reports a completed, failed or cancelled
    job.
    """
    finished_statuses = (
        TrainingStatus.COMPLETED.value,
        TrainingStatus.FAILED.value,
        TrainingStatus.CANCELLED.value,
    )

    while True:
        # Check if client disconnected
        if await request.is_disconnected():
            break

        # Get job status
        if job_id not in _state.jobs:
            yield f"data: {json.dumps({'error': 'Job not found'})}\n\n"
            break

        job = _state.jobs[job_id]

        # Take the same tail of both histories so that each loss value is
        # paired with the validation loss recorded at the same step.
        loss_tail = job["metrics"]["loss_history"][-50:]
        val_loss_tail = job["metrics"]["val_loss_history"][-50:]
        history = [
            {
                "epoch": i + 1,
                "step": (i + 1) * 10,
                "loss": loss,
                "val_loss": val_loss_tail[i] if i < len(val_loss_tail) else None,
                "learning_rate": job["learning_rate"],
                "timestamp": 0
            }
            for i, loss in enumerate(loss_tail)
        ]

        # Build metrics response
        metrics_data = {
            "job_id": job["id"],
            "status": job["status"],
            "progress": job["progress"],
            "current_epoch": job["current_epoch"],
            "total_epochs": job["total_epochs"],
            # progress is a percentage and there are 100 steps per epoch, so
            # progress * total_epochs is the step count reached so far.
            "current_step": int(job["progress"] * job["total_epochs"]),
            "total_steps": job["total_epochs"] * 100,
            "elapsed_time_ms": 0,
            "estimated_remaining_ms": 0,
            "metrics": {
                "loss": job["loss"],
                "val_loss": job["val_loss"],
                "accuracy": job["metrics"]["accuracy"],
                "learning_rate": job["learning_rate"]
            },
            "history": history
        }

        now = datetime.now()

        # Calculate elapsed time
        if job["started_at"]:
            started = datetime.fromisoformat(job["started_at"])
            metrics_data["elapsed_time_ms"] = int((now - started).total_seconds() * 1000)

        # Calculate remaining time
        if job["estimated_completion"]:
            estimated = datetime.fromisoformat(job["estimated_completion"])
            metrics_data["estimated_remaining_ms"] = max(
                0, int((estimated - now).total_seconds() * 1000)
            )

        # Send SSE event
        yield f"data: {json.dumps(metrics_data)}\n\n"

        # Check if job completed
        if job["status"] in finished_statuses:
            break

        # Wait before next update
        await asyncio.sleep(0.5)


@router.get("/metrics/stream")
async def stream_training_metrics(job_id: str, request: Request):
    """
    SSE endpoint for streaming training metrics.

    Streams real-time training progress for a specific job. Responds with
    404 if the job does not exist when the stream is requested.

    Usage:
        const eventSource = new EventSource('/api/v1/admin/training/metrics/stream?job_id=xxx')
        eventSource.onmessage = (event) => {
            const data = JSON.parse(event.data)
            console.log(data.progress, data.metrics.loss)
        }
    """
    if job_id not in _state.jobs:
        raise HTTPException(status_code=404, detail="Job not found")

    return StreamingResponse(
        training_metrics_generator(job_id, request),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no"  # Disable nginx buffering
        }
    )


async def batch_ocr_progress_generator(images_count: int, request: Request):
    """
    SSE generator for batch OCR progress simulation.

    Yields one "progress" event per image with randomly generated result
    data, followed by a single "complete" event. If the client disconnects,
    the generator stops without emitting the completion event.

    In production, this would integrate with actual OCR processing.
    """
    import random

    for i in range(images_count):
        # Stop entirely on disconnect: a "complete" event claiming that every
        # image was processed would be wrong at this point.
        if await request.is_disconnected():
            return

        # Simulate processing time
        await asyncio.sleep(random.uniform(0.3, 0.8))

        done = i + 1
        progress_data = {
            "type": "progress",
            "current": done,
            "total": images_count,
            "progress_percent": (done / images_count) * 100,
            "elapsed_ms": done * 500,
            "estimated_remaining_ms": (images_count - done) * 500,
            "result": {
                "text": f"Sample recognized text for image {done}",
                "confidence": round(random.uniform(0.7, 0.98), 2),
                "processing_time_ms": random.randint(200, 600),
                "from_cache": random.random() < 0.2
            }
        }

        yield f"data: {json.dumps(progress_data)}\n\n"

    # Send completion event
    completion = {
        "type": "complete",
        "total_time_ms": images_count * 500,
        "processed_count": images_count,
    }
    yield f"data: {json.dumps(completion)}\n\n"


@router.get("/ocr/stream")
async def stream_batch_ocr(images_count: int, request: Request):
    """
    SSE endpoint for streaming batch OCR progress.

    Simulates batch OCR processing with progress updates.
    In production, integrate with actual TrOCR batch processing.

    Args:
        images_count: Number of images to process (1 to 100, otherwise 400)

    Usage:
        const eventSource = new EventSource('/api/v1/admin/training/ocr/stream?images_count=10')
        eventSource.onmessage = (event) => {
            const data = JSON.parse(event.data)
            if (data.type === 'progress') {
                console.log(`${data.current}/${data.total}`)
            }
        }
    """
    if images_count < 1 or images_count > 100:
        raise HTTPException(status_code=400, detail="images_count must be between 1 and 100")

    return StreamingResponse(
        batch_ocr_progress_generator(images_count, request),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no"
        }
    )