""" Training API — FastAPI route handlers. """ import uuid from datetime import datetime from typing import List from fastapi import APIRouter, HTTPException, BackgroundTasks, Request from fastapi.responses import StreamingResponse from training_models import ( TrainingStatus, TrainingConfig, _state, ) from training_simulation import ( simulate_training_progress, training_metrics_generator, batch_ocr_progress_generator, ) router = APIRouter(prefix="/api/v1/admin/training", tags=["Training"]) # ============================================================================ # TRAINING JOBS # ============================================================================ @router.get("/jobs", response_model=List[dict]) async def list_training_jobs(): """Get all training jobs.""" return list(_state.jobs.values()) @router.get("/jobs/{job_id}", response_model=dict) async def get_training_job(job_id: str): """Get details for a specific training job.""" if job_id not in _state.jobs: raise HTTPException(status_code=404, detail="Job not found") return _state.jobs[job_id] @router.post("/jobs", response_model=dict) async def create_training_job(config: TrainingConfig, background_tasks: BackgroundTasks): """Create and start a new training job.""" # Check if there's already an active job if _state.active_job_id: active_job = _state.jobs.get(_state.active_job_id) if active_job and active_job["status"] in [ TrainingStatus.TRAINING.value, TrainingStatus.PREPARING.value, ]: raise HTTPException( status_code=409, detail="Another training job is already running" ) # Create job job_id = str(uuid.uuid4()) job = { "id": job_id, "name": config.name, "model_type": config.model_type.value, "status": TrainingStatus.QUEUED.value, "progress": 0, "current_epoch": 0, "total_epochs": config.epochs, "loss": 1.0, "val_loss": 1.0, "learning_rate": config.learning_rate, "documents_processed": 0, "total_documents": len(config.bundeslaender) * 50, # Estimate "started_at": None, "estimated_completion": None, "completed_at": None, "error_message": None, "metrics": { "precision": 0.0, "recall": 0.0, "f1_score": 0.0, "accuracy": 0.0, "loss_history": [], "val_loss_history": [], }, "config": config.dict(), } _state.jobs[job_id] = job _state.active_job_id = job_id # Start training in background background_tasks.add_task(simulate_training_progress, job_id) return {"id": job_id, "status": "queued", "message": "Training job created"} @router.post("/jobs/{job_id}/pause", response_model=dict) async def pause_training_job(job_id: str): """Pause a running training job.""" if job_id not in _state.jobs: raise HTTPException(status_code=404, detail="Job not found") job = _state.jobs[job_id] if job["status"] != TrainingStatus.TRAINING.value: raise HTTPException(status_code=400, detail="Job is not running") job["status"] = TrainingStatus.PAUSED.value return {"success": True, "message": "Training paused"} @router.post("/jobs/{job_id}/resume", response_model=dict) async def resume_training_job(job_id: str, background_tasks: BackgroundTasks): """Resume a paused training job.""" if job_id not in _state.jobs: raise HTTPException(status_code=404, detail="Job not found") job = _state.jobs[job_id] if job["status"] != TrainingStatus.PAUSED.value: raise HTTPException(status_code=400, detail="Job is not paused") job["status"] = TrainingStatus.TRAINING.value _state.active_job_id = job_id background_tasks.add_task(simulate_training_progress, job_id) return {"success": True, "message": "Training resumed"} @router.post("/jobs/{job_id}/cancel", response_model=dict) async def cancel_training_job(job_id: str): """Cancel a training job.""" if job_id not in _state.jobs: raise HTTPException(status_code=404, detail="Job not found") job = _state.jobs[job_id] job["status"] = TrainingStatus.CANCELLED.value job["completed_at"] = datetime.now().isoformat() if _state.active_job_id == job_id: _state.active_job_id = None return {"success": True, "message": "Training cancelled"} @router.delete("/jobs/{job_id}", response_model=dict) async def delete_training_job(job_id: str): """Delete a training job.""" if job_id not in _state.jobs: raise HTTPException(status_code=404, detail="Job not found") job = _state.jobs[job_id] if job["status"] == TrainingStatus.TRAINING.value: raise HTTPException(status_code=400, detail="Cannot delete running job") del _state.jobs[job_id] return {"success": True, "message": "Job deleted"} # ============================================================================ # MODEL VERSIONS # ============================================================================ @router.get("/models", response_model=List[dict]) async def list_model_versions(): """Get all trained model versions.""" return list(_state.model_versions.values()) @router.get("/models/{version_id}", response_model=dict) async def get_model_version(version_id: str): """Get details for a specific model version.""" if version_id not in _state.model_versions: raise HTTPException(status_code=404, detail="Model version not found") return _state.model_versions[version_id] @router.post("/models/{version_id}/activate", response_model=dict) async def activate_model_version(version_id: str): """Set a model version as active.""" if version_id not in _state.model_versions: raise HTTPException(status_code=404, detail="Model version not found") # Deactivate all other versions of same type model = _state.model_versions[version_id] for v in _state.model_versions.values(): if v["model_type"] == model["model_type"]: v["is_active"] = False model["is_active"] = True return {"success": True, "message": "Model activated"} @router.delete("/models/{version_id}", response_model=dict) async def delete_model_version(version_id: str): """Delete a model version.""" if version_id not in _state.model_versions: raise HTTPException(status_code=404, detail="Model version not found") model = _state.model_versions[version_id] if model["is_active"]: raise HTTPException(status_code=400, detail="Cannot delete active model") del _state.model_versions[version_id] return {"success": True, "message": "Model deleted"} # ============================================================================ # DATASET STATS & STATUS # ============================================================================ @router.get("/dataset/stats", response_model=dict) async def get_dataset_stats(): """Get statistics about the training dataset.""" from metrics_db import get_zeugnis_stats zeugnis_stats = await get_zeugnis_stats() return { "total_documents": zeugnis_stats.get("total_documents", 0), "total_chunks": zeugnis_stats.get("total_documents", 0) * 12, "training_allowed": zeugnis_stats.get("training_allowed_documents", 0), "by_bundesland": { bl["bundesland"]: bl.get("doc_count", 0) for bl in zeugnis_stats.get("per_bundesland", []) }, "by_doc_type": { "verordnung": 150, "schulordnung": 80, "handreichung": 45, "erlass": 30, }, } @router.get("/status", response_model=dict) async def get_training_status(): """Get overall training system status.""" active_job = None if _state.active_job_id and _state.active_job_id in _state.jobs: active_job = _state.jobs[_state.active_job_id] return { "is_training": _state.active_job_id is not None and active_job is not None and active_job["status"] == TrainingStatus.TRAINING.value, "active_job_id": _state.active_job_id, "total_jobs": len(_state.jobs), "completed_jobs": sum( 1 for j in _state.jobs.values() if j["status"] == TrainingStatus.COMPLETED.value ), "failed_jobs": sum( 1 for j in _state.jobs.values() if j["status"] == TrainingStatus.FAILED.value ), "model_versions": len(_state.model_versions), "active_models": sum(1 for m in _state.model_versions.values() if m["is_active"]), } # ============================================================================ # SSE ENDPOINTS # ============================================================================ @router.get("/metrics/stream") async def stream_training_metrics(job_id: str, request: Request): """ SSE endpoint for streaming training metrics. Streams real-time training progress for a specific job. """ if job_id not in _state.jobs: raise HTTPException(status_code=404, detail="Job not found") return StreamingResponse( training_metrics_generator(job_id, request), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no" } ) @router.get("/ocr/stream") async def stream_batch_ocr(images_count: int, request: Request): """ SSE endpoint for streaming batch OCR progress. Simulates batch OCR processing with progress updates. """ if images_count < 1 or images_count > 100: raise HTTPException(status_code=400, detail="images_count must be between 1 and 100") return StreamingResponse( batch_ocr_progress_generator(images_count, request), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no" } )