[split-required] Split 500-850 LOC files (batch 2)
backend-lehrer (10 files): - game/database.py (785 → 5), correction_api.py (683 → 4) - classroom_engine/antizipation.py (676 → 5) - llm_gateway schools/edu_search already done in prior batch klausur-service (12 files): - orientation_crop_api.py (694 → 5), pdf_export.py (677 → 4) - zeugnis_crawler.py (676 → 5), grid_editor_api.py (671 → 5) - eh_templates.py (658 → 5), mail/api.py (651 → 5) - qdrant_service.py (638 → 5), training_api.py (625 → 4) website (6 pages): - middleware (696 → 8), mail (733 → 6), consent (628 → 8) - compliance/risks (622 → 5), export (502 → 5), brandbook (629 → 7) studio-v2 (3 components): - B2BMigrationWizard (848 → 3), CleanupPanel (765 → 2) - dashboard-experimental (739 → 2) admin-lehrer (4 files): - uebersetzungen (769 → 4), manager (670 → 2) - ChunkBrowserQA (675 → 6), dsfa/page (674 → 5) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
303
klausur-service/backend/training_routes.py
Normal file
303
klausur-service/backend/training_routes.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""
|
||||
Training API — FastAPI route handlers.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from training_models import (
|
||||
TrainingStatus,
|
||||
TrainingConfig,
|
||||
_state,
|
||||
)
|
||||
from training_simulation import (
|
||||
simulate_training_progress,
|
||||
training_metrics_generator,
|
||||
batch_ocr_progress_generator,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/admin/training", tags=["Training"])
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TRAINING JOBS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/jobs", response_model=List[dict])
|
||||
async def list_training_jobs():
|
||||
"""Get all training jobs."""
|
||||
return list(_state.jobs.values())
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}", response_model=dict)
|
||||
async def get_training_job(job_id: str):
|
||||
"""Get details for a specific training job."""
|
||||
if job_id not in _state.jobs:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
return _state.jobs[job_id]
|
||||
|
||||
|
||||
@router.post("/jobs", response_model=dict)
|
||||
async def create_training_job(config: TrainingConfig, background_tasks: BackgroundTasks):
|
||||
"""Create and start a new training job."""
|
||||
# Check if there's already an active job
|
||||
if _state.active_job_id:
|
||||
active_job = _state.jobs.get(_state.active_job_id)
|
||||
if active_job and active_job["status"] in [
|
||||
TrainingStatus.TRAINING.value,
|
||||
TrainingStatus.PREPARING.value,
|
||||
]:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Another training job is already running"
|
||||
)
|
||||
|
||||
# Create job
|
||||
job_id = str(uuid.uuid4())
|
||||
job = {
|
||||
"id": job_id,
|
||||
"name": config.name,
|
||||
"model_type": config.model_type.value,
|
||||
"status": TrainingStatus.QUEUED.value,
|
||||
"progress": 0,
|
||||
"current_epoch": 0,
|
||||
"total_epochs": config.epochs,
|
||||
"loss": 1.0,
|
||||
"val_loss": 1.0,
|
||||
"learning_rate": config.learning_rate,
|
||||
"documents_processed": 0,
|
||||
"total_documents": len(config.bundeslaender) * 50, # Estimate
|
||||
"started_at": None,
|
||||
"estimated_completion": None,
|
||||
"completed_at": None,
|
||||
"error_message": None,
|
||||
"metrics": {
|
||||
"precision": 0.0,
|
||||
"recall": 0.0,
|
||||
"f1_score": 0.0,
|
||||
"accuracy": 0.0,
|
||||
"loss_history": [],
|
||||
"val_loss_history": [],
|
||||
},
|
||||
"config": config.dict(),
|
||||
}
|
||||
|
||||
_state.jobs[job_id] = job
|
||||
_state.active_job_id = job_id
|
||||
|
||||
# Start training in background
|
||||
background_tasks.add_task(simulate_training_progress, job_id)
|
||||
|
||||
return {"id": job_id, "status": "queued", "message": "Training job created"}
|
||||
|
||||
|
||||
@router.post("/jobs/{job_id}/pause", response_model=dict)
|
||||
async def pause_training_job(job_id: str):
|
||||
"""Pause a running training job."""
|
||||
if job_id not in _state.jobs:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
job = _state.jobs[job_id]
|
||||
if job["status"] != TrainingStatus.TRAINING.value:
|
||||
raise HTTPException(status_code=400, detail="Job is not running")
|
||||
|
||||
job["status"] = TrainingStatus.PAUSED.value
|
||||
return {"success": True, "message": "Training paused"}
|
||||
|
||||
|
||||
@router.post("/jobs/{job_id}/resume", response_model=dict)
|
||||
async def resume_training_job(job_id: str, background_tasks: BackgroundTasks):
|
||||
"""Resume a paused training job."""
|
||||
if job_id not in _state.jobs:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
job = _state.jobs[job_id]
|
||||
if job["status"] != TrainingStatus.PAUSED.value:
|
||||
raise HTTPException(status_code=400, detail="Job is not paused")
|
||||
|
||||
job["status"] = TrainingStatus.TRAINING.value
|
||||
_state.active_job_id = job_id
|
||||
background_tasks.add_task(simulate_training_progress, job_id)
|
||||
|
||||
return {"success": True, "message": "Training resumed"}
|
||||
|
||||
|
||||
@router.post("/jobs/{job_id}/cancel", response_model=dict)
|
||||
async def cancel_training_job(job_id: str):
|
||||
"""Cancel a training job."""
|
||||
if job_id not in _state.jobs:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
job = _state.jobs[job_id]
|
||||
job["status"] = TrainingStatus.CANCELLED.value
|
||||
job["completed_at"] = datetime.now().isoformat()
|
||||
|
||||
if _state.active_job_id == job_id:
|
||||
_state.active_job_id = None
|
||||
|
||||
return {"success": True, "message": "Training cancelled"}
|
||||
|
||||
|
||||
@router.delete("/jobs/{job_id}", response_model=dict)
|
||||
async def delete_training_job(job_id: str):
|
||||
"""Delete a training job."""
|
||||
if job_id not in _state.jobs:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
job = _state.jobs[job_id]
|
||||
if job["status"] == TrainingStatus.TRAINING.value:
|
||||
raise HTTPException(status_code=400, detail="Cannot delete running job")
|
||||
|
||||
del _state.jobs[job_id]
|
||||
return {"success": True, "message": "Job deleted"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# MODEL VERSIONS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/models", response_model=List[dict])
|
||||
async def list_model_versions():
|
||||
"""Get all trained model versions."""
|
||||
return list(_state.model_versions.values())
|
||||
|
||||
|
||||
@router.get("/models/{version_id}", response_model=dict)
|
||||
async def get_model_version(version_id: str):
|
||||
"""Get details for a specific model version."""
|
||||
if version_id not in _state.model_versions:
|
||||
raise HTTPException(status_code=404, detail="Model version not found")
|
||||
return _state.model_versions[version_id]
|
||||
|
||||
|
||||
@router.post("/models/{version_id}/activate", response_model=dict)
|
||||
async def activate_model_version(version_id: str):
|
||||
"""Set a model version as active."""
|
||||
if version_id not in _state.model_versions:
|
||||
raise HTTPException(status_code=404, detail="Model version not found")
|
||||
|
||||
# Deactivate all other versions of same type
|
||||
model = _state.model_versions[version_id]
|
||||
for v in _state.model_versions.values():
|
||||
if v["model_type"] == model["model_type"]:
|
||||
v["is_active"] = False
|
||||
|
||||
model["is_active"] = True
|
||||
return {"success": True, "message": "Model activated"}
|
||||
|
||||
|
||||
@router.delete("/models/{version_id}", response_model=dict)
|
||||
async def delete_model_version(version_id: str):
|
||||
"""Delete a model version."""
|
||||
if version_id not in _state.model_versions:
|
||||
raise HTTPException(status_code=404, detail="Model version not found")
|
||||
|
||||
model = _state.model_versions[version_id]
|
||||
if model["is_active"]:
|
||||
raise HTTPException(status_code=400, detail="Cannot delete active model")
|
||||
|
||||
del _state.model_versions[version_id]
|
||||
return {"success": True, "message": "Model deleted"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DATASET STATS & STATUS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/dataset/stats", response_model=dict)
|
||||
async def get_dataset_stats():
|
||||
"""Get statistics about the training dataset."""
|
||||
from metrics_db import get_zeugnis_stats
|
||||
|
||||
zeugnis_stats = await get_zeugnis_stats()
|
||||
|
||||
return {
|
||||
"total_documents": zeugnis_stats.get("total_documents", 0),
|
||||
"total_chunks": zeugnis_stats.get("total_documents", 0) * 12,
|
||||
"training_allowed": zeugnis_stats.get("training_allowed_documents", 0),
|
||||
"by_bundesland": {
|
||||
bl["bundesland"]: bl.get("doc_count", 0)
|
||||
for bl in zeugnis_stats.get("per_bundesland", [])
|
||||
},
|
||||
"by_doc_type": {
|
||||
"verordnung": 150,
|
||||
"schulordnung": 80,
|
||||
"handreichung": 45,
|
||||
"erlass": 30,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.get("/status", response_model=dict)
|
||||
async def get_training_status():
|
||||
"""Get overall training system status."""
|
||||
active_job = None
|
||||
if _state.active_job_id and _state.active_job_id in _state.jobs:
|
||||
active_job = _state.jobs[_state.active_job_id]
|
||||
|
||||
return {
|
||||
"is_training": _state.active_job_id is not None and active_job is not None and
|
||||
active_job["status"] == TrainingStatus.TRAINING.value,
|
||||
"active_job_id": _state.active_job_id,
|
||||
"total_jobs": len(_state.jobs),
|
||||
"completed_jobs": sum(
|
||||
1 for j in _state.jobs.values()
|
||||
if j["status"] == TrainingStatus.COMPLETED.value
|
||||
),
|
||||
"failed_jobs": sum(
|
||||
1 for j in _state.jobs.values()
|
||||
if j["status"] == TrainingStatus.FAILED.value
|
||||
),
|
||||
"model_versions": len(_state.model_versions),
|
||||
"active_models": sum(1 for m in _state.model_versions.values() if m["is_active"]),
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SSE ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/metrics/stream")
|
||||
async def stream_training_metrics(job_id: str, request: Request):
|
||||
"""
|
||||
SSE endpoint for streaming training metrics.
|
||||
|
||||
Streams real-time training progress for a specific job.
|
||||
"""
|
||||
if job_id not in _state.jobs:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
return StreamingResponse(
|
||||
training_metrics_generator(job_id, request),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/ocr/stream")
|
||||
async def stream_batch_ocr(images_count: int, request: Request):
|
||||
"""
|
||||
SSE endpoint for streaming batch OCR progress.
|
||||
|
||||
Simulates batch OCR processing with progress updates.
|
||||
"""
|
||||
if images_count < 1 or images_count > 100:
|
||||
raise HTTPException(status_code=400, detail="images_count must be between 1 and 100")
|
||||
|
||||
return StreamingResponse(
|
||||
batch_ocr_progress_generator(images_count, request),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
Reference in New Issue
Block a user