"""
Training API — enums, request/response models, and in-memory state.
"""

import uuid
from datetime import datetime
from typing import Optional, List, Dict, Any
from enum import Enum
from dataclasses import dataclass, field

from pydantic import BaseModel, Field


# ============================================================================
# ENUMS
# ============================================================================


class TrainingStatus(str, Enum):
    """Lifecycle state of a training job.

    Mixes in ``str`` so every member compares equal to its lowercase string
    value (e.g. ``TrainingStatus.QUEUED == "queued"``).
    """

    QUEUED = "queued"
    PREPARING = "preparing"
    TRAINING = "training"
    VALIDATING = "validating"
    COMPLETED = "completed"
    FAILED = "failed"
    PAUSED = "paused"
    CANCELLED = "cancelled"


class ModelType(str, Enum):
    """Kind of model a training job produces (string-valued enum)."""

    ZEUGNIS = "zeugnis"
    KLAUSUR = "klausur"
    GENERAL = "general"


# ============================================================================
# REQUEST/RESPONSE MODELS
# ============================================================================


class TrainingConfig(BaseModel):
    """Configuration for a training job.

    Numeric hyper-parameters are range-checked by Pydantic at construction
    time; out-of-range values raise a validation error (a ``ValueError``
    subclass).
    """

    name: str = Field(..., description="Name for the training job")
    model_type: ModelType = Field(ModelType.ZEUGNIS, description="Type of model to train")
    # Required; no length constraint is enforced here.
    bundeslaender: List[str] = Field(..., description="List of Bundesland codes to include")
    batch_size: int = Field(16, ge=1, le=128)
    learning_rate: float = Field(0.00005, ge=0.000001, le=0.1)
    epochs: int = Field(10, ge=1, le=100)
    warmup_steps: int = Field(500, ge=0, le=10000)
    weight_decay: float = Field(0.01, ge=0, le=1)
    gradient_accumulation: int = Field(4, ge=1, le=32)
    mixed_precision: bool = Field(True, description="Use FP16 mixed precision training")


class TrainingMetrics(BaseModel):
    """Metrics from a training job.

    Scalar metrics default to ``0.0``; the loss histories default to empty
    lists.
    """

    precision: float = 0.0
    recall: float = 0.0
    f1_score: float = 0.0
    accuracy: float = 0.0
    # default_factory gives every instance its own list explicitly instead of
    # relying on Pydantic copying a shared mutable ``[]`` default.
    loss_history: List[float] = Field(default_factory=list)
    val_loss_history: List[float] = Field(default_factory=list)


class TrainingJob(BaseModel):
    """A training job with full details.

    Every field is declared without a default.
    """

    id: str
    name: str
    model_type: ModelType
    status: TrainingStatus
    # NOTE(review): scale of ``progress`` (0–1 vs 0–100) is not established
    # here — confirm against the code that updates it.
    progress: float
    current_epoch: int
    total_epochs: int
    loss: float
    val_loss: float
    learning_rate: float
    documents_processed: int
    total_documents: int
    # NOTE(review): these Optional fields have no default. Under Pydantic v2
    # that makes them required-but-nullable (callers must pass None
    # explicitly); under v1 they implicitly default to None — confirm the
    # intended contract.
    started_at: Optional[datetime]
    estimated_completion: Optional[datetime]
    completed_at: Optional[datetime]
    error_message: Optional[str]
    metrics: TrainingMetrics
    config: TrainingConfig


class ModelVersion(BaseModel):
    """A trained model version."""

    id: str
    # Id of the training job that produced this version (presumably a
    # TrainingJob.id — confirm with callers).
    job_id: str
    version: str
    model_type: ModelType
    created_at: datetime
    metrics: TrainingMetrics
    is_active: bool
    size_mb: float
    bundeslaender: List[str]


class DatasetStats(BaseModel):
    """Statistics about the training dataset."""

    total_documents: int
    total_chunks: int
    training_allowed: int
    # Document counts broken down by Bundesland and by document type.
    by_bundesland: Dict[str, int]
    by_doc_type: Dict[str, int]


# ============================================================================
# IN-MEMORY STATE (Replace with database in production)
# ============================================================================


@dataclass
class TrainingState:
    """Global training state.

    Held in plain dicts in process memory: nothing here persists data or
    performs any locking.
    """

    # Job records, presumably keyed by job id — confirm with callers.
    jobs: Dict[str, dict] = field(default_factory=dict)
    # Model-version records, presumably keyed by version id — confirm.
    model_versions: Dict[str, dict] = field(default_factory=dict)
    # Id of the currently active job, or None when no job is active.
    active_job_id: Optional[str] = None


# Module-level instance; every importer of this module shares this object.
_state = TrainingState()