Some checks failed
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 28s
CI / test-go-edu-search (push) Successful in 28s
CI / test-python-klausur (push) Failing after 2m22s
CI / test-python-agent-core (push) Successful in 21s
CI / test-nodejs-website (push) Successful in 23s
korrektur/ zeugnis/ admin/ compliance/ worksheet/ training/ metrics/ 52 shims, relative imports, RAG untouched. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
119 lines
3.4 KiB
Python
119 lines
3.4 KiB
Python
"""
|
|
Training API — enums, request/response models, and in-memory state.
|
|
"""
|
|
|
|
import uuid
|
|
from datetime import datetime
|
|
from typing import Optional, List, Dict, Any
|
|
from enum import Enum
|
|
from dataclasses import dataclass, field
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
# ============================================================================
|
|
# ENUMS
|
|
# ============================================================================
|
|
|
|
class TrainingStatus(str, Enum):
    """States a training job can report.

    The enum mixes in ``str``, so each member compares equal to its raw
    lowercase value (``TrainingStatus.FAILED == "failed"``) and serializes
    as that plain string.

    NOTE(review): allowed transitions between states are not enforced here.
    """

    # Pre-run phases.
    QUEUED = "queued"
    PREPARING = "preparing"
    # Active phases.
    TRAINING = "training"
    VALIDATING = "validating"
    # Terminal / interrupted phases.
    COMPLETED = "completed"
    FAILED = "failed"
    PAUSED = "paused"
    CANCELLED = "cancelled"
|
|
|
|
|
|
class ModelType(str, Enum):
    """Kinds of model a training job can produce.

    Being a ``str`` enum, members compare equal to (and serialize as) their
    lowercase string values, e.g. ``ModelType.KLAUSUR == "klausur"``.
    """

    ZEUGNIS = "zeugnis"
    KLAUSUR = "klausur"
    GENERAL = "general"
|
|
|
|
|
|
# ============================================================================
|
|
# REQUEST/RESPONSE MODELS
|
|
# ============================================================================
|
|
|
|
class TrainingConfig(BaseModel):
    """Configuration for a training job.

    ``name`` and ``bundeslaender`` are required; every other field has a
    default. Numeric hyperparameters are range-checked by pydantic via the
    ``ge``/``le`` bounds declared on each field.
    """

    # --- Required identification / data selection --------------------------
    name: str = Field(..., description="Name for the training job")
    model_type: ModelType = Field(
        default=ModelType.ZEUGNIS,
        description="Type of model to train",
    )
    bundeslaender: List[str] = Field(
        ...,
        description="List of Bundesland codes to include",
    )

    # --- Optimisation hyperparameters (bounds validated on construction) ---
    batch_size: int = Field(default=16, ge=1, le=128)
    learning_rate: float = Field(default=5e-5, ge=1e-6, le=0.1)
    epochs: int = Field(default=10, ge=1, le=100)
    warmup_steps: int = Field(default=500, ge=0, le=10000)
    weight_decay: float = Field(default=0.01, ge=0, le=1)
    gradient_accumulation: int = Field(default=4, ge=1, le=32)

    # --- Runtime options ----------------------------------------------------
    mixed_precision: bool = Field(
        default=True,
        description="Use FP16 mixed precision training",
    )
|
|
|
|
|
|
class TrainingMetrics(BaseModel):
    """Metrics from a training job.

    Scalar scores start at ``0.0`` until values are reported. The two loss
    histories are built with ``default_factory=list`` so each instance
    explicitly gets its own fresh list, instead of declaring a shared mutable
    ``[]`` literal as the class-level default and relying on pydantic to copy
    it implicitly.
    """

    precision: float = 0.0
    recall: float = 0.0
    f1_score: float = 0.0
    accuracy: float = 0.0
    # Ordered sequences of loss values; NOTE(review): whether entries are
    # per step or per epoch is decided by whoever appends — confirm there.
    loss_history: List[float] = Field(default_factory=list)
    val_loss_history: List[float] = Field(default_factory=list)
|
|
|
|
|
|
class TrainingJob(BaseModel):
    """A training job with full details.

    Aggregates the job's identity, current status/progress counters, loss
    values, the nested ``TrainingMetrics`` and the ``TrainingConfig`` it was
    started with.
    """

    id: str
    name: str
    model_type: ModelType
    status: TrainingStatus

    # Progress counters. NOTE(review): the unit/range of ``progress``
    # (fraction vs. percent) is not fixed here — confirm with the producer.
    progress: float
    current_epoch: int
    total_epochs: int
    loss: float
    val_loss: float
    learning_rate: float
    documents_processed: int
    total_documents: int

    # Nullable fields get an explicit ``None`` default. Without it, pydantic v1
    # silently defaults them to None while pydantic v2 treats a bare
    # ``Optional[X]`` annotation as *required* (nullable but must be passed);
    # the explicit default makes the contract identical on both versions.
    started_at: Optional[datetime] = None
    estimated_completion: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    error_message: Optional[str] = None

    metrics: TrainingMetrics
    config: TrainingConfig
|
|
|
|
|
|
class ModelVersion(BaseModel):
    """A trained model version.

    Links back to the training job that produced it via ``job_id`` and carries
    that run's ``TrainingMetrics``. All fields are required.
    """

    # Identity and provenance.
    id: str
    job_id: str
    version: str
    model_type: ModelType
    created_at: datetime
    # Evaluation results for this version.
    metrics: TrainingMetrics
    # Deployment / artifact info. NOTE(review): how ``is_active`` is kept
    # unique across versions is not enforced in this model.
    is_active: bool
    size_mb: float
    # Bundesland codes covered by this version's training data.
    bundeslaender: List[str]
|
|
|
|
|
|
class DatasetStats(BaseModel):
    """Statistics about the training dataset.

    Holds aggregate counts plus two breakdowns mapping a category label to a
    count. All fields are required.
    """

    # Aggregate counts.
    total_documents: int
    total_chunks: int
    training_allowed: int
    # Breakdowns: label -> count.
    by_bundesland: Dict[str, int]
    by_doc_type: Dict[str, int]
|
|
|
|
|
|
# ============================================================================
|
|
# IN-MEMORY STATE (Replace with database in production)
|
|
# ============================================================================
|
|
|
|
@dataclass
class TrainingState:
    """Process-wide container for training bookkeeping.

    Both mappings start empty and are created per instance through
    ``default_factory`` so separate ``TrainingState`` objects never share
    a dict. ``active_job_id`` is ``None`` when no job is marked active.
    """

    # Raw job records as plain dicts (presumably keyed by job id — verify
    # against callers).
    jobs: Dict[str, dict] = field(default_factory=dict)
    # Raw model-version records as plain dicts (key meaning: see callers).
    model_versions: Dict[str, dict] = field(default_factory=dict)
    # Identifier of the currently active job, or None.
    active_job_id: Optional[str] = None
|
|
|
|
|
|
# Module-level singleton holding all training bookkeeping in process memory;
# nothing in this module persists it.
_state: TrainingState = TrainingState()
|