Restructure: Move 52 files into 7 domain packages
Some checks failed
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 28s
CI / test-go-edu-search (push) Successful in 28s
CI / test-python-klausur (push) Failing after 2m22s
CI / test-python-agent-core (push) Successful in 21s
CI / test-nodejs-website (push) Successful in 23s

korrektur/ zeugnis/ admin/ compliance/ worksheet/ training/ metrics/
52 shims, relative imports, RAG untouched.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-04-25 22:10:48 +02:00
parent 0504d22b8e
commit 165c493d1e
111 changed files with 11859 additions and 11609 deletions

View File

@@ -0,0 +1,6 @@
"""
training package — training API, simulation, export, TrOCR.
Backward-compatible re-exports: consumers can still use
``from training_api import ...`` etc. via the shim files in backend/.
"""

View File

@@ -0,0 +1,31 @@
"""
Training API — barrel re-export.
The actual code lives in:
- models.py (enums, Pydantic models, in-memory state)
- simulation.py (simulate_training_progress, SSE generators)
- routes.py (FastAPI router + all endpoints)
"""
# Models & enums
from .models import ( # noqa: F401
TrainingStatus,
ModelType,
TrainingConfig,
TrainingMetrics,
TrainingJob,
ModelVersion,
DatasetStats,
TrainingState,
_state,
)
# Simulation helpers
from .simulation import ( # noqa: F401
simulate_training_progress,
training_metrics_generator,
batch_ocr_progress_generator,
)
# Router
from .routes import router # noqa: F401

View File

@@ -0,0 +1,448 @@
"""
Training Export Service for OCR Labeling Data
Exports labeled OCR data in formats suitable for fine-tuning:
- TrOCR (Microsoft's Transformer-based OCR model)
- llama3.2-vision (Meta's Vision-Language Model)
- Generic JSONL format
DATENSCHUTZ/PRIVACY:
- Alle Daten bleiben lokal auf dem Mac Mini
- Keine Cloud-Uploads ohne explizite Zustimmung
- Export-Pfade sind konfigurierbar
"""
import os
import json
import base64
import shutil
from pathlib import Path
from typing import List, Dict, Optional, Any
from dataclasses import dataclass
from datetime import datetime
import hashlib
# Export directory configuration.
# The root is overridable via the OCR_EXPORT_PATH env var (default:
# /app/ocr-exports); each export format writes into its own subdirectory.
EXPORT_BASE_PATH = os.getenv("OCR_EXPORT_PATH", "/app/ocr-exports")
TROCR_EXPORT_PATH = os.path.join(EXPORT_BASE_PATH, "trocr")
LLAMA_VISION_EXPORT_PATH = os.path.join(EXPORT_BASE_PATH, "llama-vision")
GENERIC_EXPORT_PATH = os.path.join(EXPORT_BASE_PATH, "generic")
@dataclass
class TrainingSample:
    """A single training sample for OCR fine-tuning."""
    # Stable identifier; also used as the exported image's filename stem.
    id: str
    # Path to the source image on local disk.
    image_path: str
    # Human-verified transcription used as the training label.
    ground_truth: str
    # Raw OCR output for this image, if available (informational only).
    ocr_text: Optional[str] = None
    # Confidence reported by the OCR engine, if available.
    ocr_confidence: Optional[float] = None
    # Arbitrary extra data; carried through by the generic export format.
    metadata: Optional[Dict[str, Any]] = None
@dataclass
class ExportResult:
    """Result of a training data export."""
    # Which exporter produced this: "trocr", "llama_vision", or "generic".
    export_format: str
    # Directory containing the exported batch (data file(s), images/, manifest).
    export_path: str
    # Number of samples written.
    sample_count: int
    # Batch identifier (also the batch directory name under the format root).
    batch_id: str
    # UTC timestamp of the export.
    created_at: datetime
    # Path to the batch's manifest.json.
    manifest_path: str
class TrOCRExporter:
    """
    Export training data for TrOCR fine-tuning.

    TrOCR expects:
    - Image files (PNG/JPG)
    - A CSV/TSV file with: image_path, text
    - Or a JSONL file with: {"file_name": "img.png", "text": "ground truth"}
    We use the JSONL format for flexibility.
    """
    def __init__(self, export_path: Optional[str] = None):
        """
        Args:
            export_path: Target directory. Defaults to TROCR_EXPORT_PATH,
                resolved at call time (rather than baked into the signature
                at import time) so env overrides and tests can redirect output.
        """
        self.export_path = export_path or TROCR_EXPORT_PATH
        os.makedirs(self.export_path, exist_ok=True)

    def export(
        self,
        samples: List[TrainingSample],
        batch_id: str,
        copy_images: bool = True,
    ) -> ExportResult:
        """
        Export samples in TrOCR format.

        Args:
            samples: List of training samples
            batch_id: Unique batch identifier (becomes the batch directory name)
            copy_images: Whether to copy images into the export directory
        Returns:
            ExportResult with export details
        """
        batch_path = os.path.join(self.export_path, batch_id)
        images_path = os.path.join(batch_path, "images")
        os.makedirs(images_path, exist_ok=True)
        # One JSONL record per sample; the image reference becomes relative
        # ("images/<id><ext>") when the image was actually copied, otherwise
        # the original path is kept as-is.
        export_data = []
        for sample in samples:
            if copy_images and os.path.exists(sample.image_path):
                image_filename = f"{sample.id}{Path(sample.image_path).suffix}"
                dest_path = os.path.join(images_path, image_filename)
                shutil.copy2(sample.image_path, dest_path)
                image_ref = f"images/{image_filename}"
            else:
                image_ref = sample.image_path
            export_data.append({
                "file_name": image_ref,
                "text": sample.ground_truth,
                "id": sample.id,
            })
        # Write JSONL file (UTF-8, umlauts unescaped).
        jsonl_path = os.path.join(batch_path, "train.jsonl")
        with open(jsonl_path, 'w', encoding='utf-8') as f:
            for item in export_data:
                f.write(json.dumps(item, ensure_ascii=False) + '\n')
        # Write manifest.
        # NOTE: datetime.utcnow() is deprecated since Python 3.12; kept here
        # so the manifest timestamp format stays unchanged — consider
        # datetime.now(timezone.utc) in a follow-up.
        manifest = {
            "format": "trocr",
            "version": "1.0",
            "batch_id": batch_id,
            "sample_count": len(samples),
            "created_at": datetime.utcnow().isoformat(),
            "files": {
                "data": "train.jsonl",
                "images": "images/",
            },
            "model_config": {
                "base_model": "microsoft/trocr-base-handwritten",
                "task": "handwriting-recognition",
            },
        }
        manifest_path = os.path.join(batch_path, "manifest.json")
        # Fix: explicit UTF-8 + ensure_ascii=False so the manifest encoding no
        # longer depends on the process locale and matches the data file.
        with open(manifest_path, 'w', encoding='utf-8') as f:
            json.dump(manifest, f, indent=2, ensure_ascii=False)
        return ExportResult(
            export_format="trocr",
            export_path=batch_path,
            sample_count=len(samples),
            batch_id=batch_id,
            created_at=datetime.utcnow(),
            manifest_path=manifest_path,
        )
class LlamaVisionExporter:
    """
    Export training data for llama3.2-vision fine-tuning.

    Llama Vision fine-tuning expects:
    - JSONL format with base64-encoded images or image URLs
    - Format: {"messages": [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "..."}]}, {"role": "assistant", "content": "..."}]}
    We create a supervised fine-tuning dataset.
    """
    def __init__(self, export_path: Optional[str] = None):
        """
        Args:
            export_path: Target directory. Defaults to LLAMA_VISION_EXPORT_PATH,
                resolved at call time for env overrides and testability.
        """
        self.export_path = export_path or LLAMA_VISION_EXPORT_PATH
        os.makedirs(self.export_path, exist_ok=True)

    def _encode_image_base64(self, image_path: str) -> Optional[str]:
        """Return the file's base64 text, or None if it cannot be read."""
        try:
            with open(image_path, 'rb') as f:
                return base64.b64encode(f.read()).decode('utf-8')
        except OSError:
            # Fix: only swallow I/O errors (missing/unreadable file); a broad
            # `except Exception` would also hide programming errors.
            return None

    def export(
        self,
        samples: List[TrainingSample],
        batch_id: str,
        include_base64: bool = False,
        copy_images: bool = True,
    ) -> ExportResult:
        """
        Export samples in Llama Vision fine-tuning format.

        Args:
            samples: List of training samples
            batch_id: Unique batch identifier (becomes the batch directory name)
            include_base64: Whether to embed base64 data-URLs in the JSONL
            copy_images: Whether to copy images into the export directory
        Returns:
            ExportResult with export details
        """
        batch_path = os.path.join(self.export_path, batch_id)
        images_path = os.path.join(batch_path, "images")
        os.makedirs(images_path, exist_ok=True)
        # OCR instruction prompt (German: "read the handwriting verbatim").
        system_prompt = (
            "Du bist ein OCR-Experte für deutsche Handschrift. "
            "Lies den handgeschriebenen Text im Bild und gib ihn wortgetreu wieder."
        )
        # Build one chat-style SFT record per sample.
        export_data = []
        for sample in samples:
            if copy_images and os.path.exists(sample.image_path):
                image_filename = f"{sample.id}{Path(sample.image_path).suffix}"
                dest_path = os.path.join(images_path, image_filename)
                shutil.copy2(sample.image_path, dest_path)
                image_ref = f"images/{image_filename}"
            else:
                image_ref = sample.image_path
            # User turn: the image plus the reading instruction.
            user_content = [
                {"type": "image_url", "image_url": {"url": image_ref}},
                {"type": "text", "text": "Lies den handgeschriebenen Text in diesem Bild."},
            ]
            # Optionally inline the image as a data URL (falls back to the
            # path reference if the file can't be read).
            if include_base64:
                b64 = self._encode_image_base64(sample.image_path)
                if b64:
                    ext = Path(sample.image_path).suffix.lower().replace('.', '')
                    mime = {'png': 'image/png', 'jpg': 'image/jpeg', 'jpeg': 'image/jpeg'}.get(ext, 'image/png')
                    user_content[0] = {
                        "type": "image_url",
                        "image_url": {"url": f"data:{mime};base64,{b64}"}
                    }
            export_data.append({
                "id": sample.id,
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_content},
                    {"role": "assistant", "content": sample.ground_truth},
                ],
            })
        # Write JSONL file (UTF-8, umlauts unescaped).
        jsonl_path = os.path.join(batch_path, "train.jsonl")
        with open(jsonl_path, 'w', encoding='utf-8') as f:
            for item in export_data:
                f.write(json.dumps(item, ensure_ascii=False) + '\n')
        # Write manifest. Fix: explicit UTF-8 + ensure_ascii=False — this
        # manifest embeds the German system prompt, so leaving the encoding
        # to the process locale was fragile and inconsistent with the data file.
        manifest = {
            "format": "llama_vision",
            "version": "1.0",
            "batch_id": batch_id,
            "sample_count": len(samples),
            "created_at": datetime.utcnow().isoformat(),
            "files": {
                "data": "train.jsonl",
                "images": "images/",
            },
            "model_config": {
                "base_model": "llama3.2-vision:11b",
                "task": "handwriting-ocr",
                "system_prompt": system_prompt,
            },
        }
        manifest_path = os.path.join(batch_path, "manifest.json")
        with open(manifest_path, 'w', encoding='utf-8') as f:
            json.dump(manifest, f, indent=2, ensure_ascii=False)
        return ExportResult(
            export_format="llama_vision",
            export_path=batch_path,
            sample_count=len(samples),
            batch_id=batch_id,
            created_at=datetime.utcnow(),
            manifest_path=manifest_path,
        )
class GenericExporter:
    """
    Export training data in a generic JSONL format.

    This format is compatible with most ML frameworks and can be
    easily converted to other formats.
    """
    def __init__(self, export_path: Optional[str] = None):
        """
        Args:
            export_path: Target directory. Defaults to GENERIC_EXPORT_PATH,
                resolved at call time for env overrides and testability.
        """
        self.export_path = export_path or GENERIC_EXPORT_PATH
        os.makedirs(self.export_path, exist_ok=True)

    def export(
        self,
        samples: List[TrainingSample],
        batch_id: str,
        copy_images: bool = True,
    ) -> ExportResult:
        """
        Export samples in generic JSONL format.

        Args:
            samples: List of training samples
            batch_id: Unique batch identifier (becomes the batch directory name)
            copy_images: Whether to copy images into the export directory
        Returns:
            ExportResult with export details
        """
        batch_path = os.path.join(self.export_path, batch_id)
        images_path = os.path.join(batch_path, "images")
        os.makedirs(images_path, exist_ok=True)
        # Full sample record, including OCR output and metadata — this is the
        # only format that preserves everything on TrainingSample.
        export_data = []
        for sample in samples:
            if copy_images and os.path.exists(sample.image_path):
                image_filename = f"{sample.id}{Path(sample.image_path).suffix}"
                dest_path = os.path.join(images_path, image_filename)
                shutil.copy2(sample.image_path, dest_path)
                image_ref = f"images/{image_filename}"
            else:
                image_ref = sample.image_path
            export_data.append({
                "id": sample.id,
                "image_path": image_ref,
                "ground_truth": sample.ground_truth,
                "ocr_text": sample.ocr_text,
                "ocr_confidence": sample.ocr_confidence,
                "metadata": sample.metadata or {},
            })
        # Write JSONL file (UTF-8, umlauts unescaped).
        jsonl_path = os.path.join(batch_path, "data.jsonl")
        with open(jsonl_path, 'w', encoding='utf-8') as f:
            for item in export_data:
                f.write(json.dumps(item, ensure_ascii=False) + '\n')
        # Also write as a single JSON array for convenience.
        json_path = os.path.join(batch_path, "data.json")
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(export_data, f, indent=2, ensure_ascii=False)
        # Write manifest. Fix: explicit UTF-8 + ensure_ascii=False for
        # consistency with the data files (no locale-dependent encoding).
        manifest = {
            "format": "generic",
            "version": "1.0",
            "batch_id": batch_id,
            "sample_count": len(samples),
            "created_at": datetime.utcnow().isoformat(),
            "files": {
                "data_jsonl": "data.jsonl",
                "data_json": "data.json",
                "images": "images/",
            },
        }
        manifest_path = os.path.join(batch_path, "manifest.json")
        with open(manifest_path, 'w', encoding='utf-8') as f:
            json.dump(manifest, f, indent=2, ensure_ascii=False)
        return ExportResult(
            export_format="generic",
            export_path=batch_path,
            sample_count=len(samples),
            batch_id=batch_id,
            created_at=datetime.utcnow(),
            manifest_path=manifest_path,
        )
class TrainingExportService:
    """
    Main service for exporting OCR labeling data to various training formats.

    Thin facade over the three format-specific exporters; picks the right
    one by format name and can enumerate previously written batches.
    """
    def __init__(self):
        self.trocr_exporter = TrOCRExporter()
        self.llama_vision_exporter = LlamaVisionExporter()
        self.generic_exporter = GenericExporter()

    def export(
        self,
        samples: List[TrainingSample],
        export_format: str,
        batch_id: Optional[str] = None,
        **kwargs,
    ) -> ExportResult:
        """
        Export training samples in the specified format.

        Args:
            samples: List of training samples
            export_format: 'trocr', 'llama_vision', or 'generic'
            batch_id: Optional batch ID (generated if not provided)
            **kwargs: Additional format-specific options
        Returns:
            ExportResult with export details
        """
        # Default batch id: UTC timestamp, e.g. "20260425_201048".
        batch_id = batch_id or datetime.utcnow().strftime("%Y%m%d_%H%M%S")
        dispatch = {
            "trocr": self.trocr_exporter,
            "llama_vision": self.llama_vision_exporter,
            "generic": self.generic_exporter,
        }
        exporter = dispatch.get(export_format)
        if exporter is None:
            raise ValueError(f"Unknown export format: {export_format}")
        return exporter.export(samples, batch_id, **kwargs)

    def list_exports(self, export_format: Optional[str] = None) -> List[Dict]:
        """
        List all available exports.

        Args:
            export_format: Optional filter by format
        Returns:
            List of export manifests, newest first
        """
        format_roots = [
            ("trocr", TROCR_EXPORT_PATH),
            ("llama_vision", LLAMA_VISION_EXPORT_PATH),
            ("generic", GENERIC_EXPORT_PATH),
        ]
        manifests = []
        for fmt, root in format_roots:
            if export_format is not None and export_format != fmt:
                continue
            if not os.path.exists(root):
                continue
            for batch_dir in os.listdir(root):
                manifest_file = os.path.join(root, batch_dir, "manifest.json")
                if not os.path.exists(manifest_file):
                    continue
                with open(manifest_file, 'r') as fh:
                    manifest = json.load(fh)
                manifest["export_path"] = os.path.join(root, batch_dir)
                manifests.append(manifest)
        manifests.sort(key=lambda m: m.get("created_at", ""), reverse=True)
        return manifests
# Lazily constructed module-level singleton.
_export_service: Optional[TrainingExportService] = None
def get_training_export_service() -> TrainingExportService:
    """Return the shared TrainingExportService, creating it on first use."""
    global _export_service
    service = _export_service
    if service is None:
        service = TrainingExportService()
        _export_service = service
    return service

View File

@@ -0,0 +1,118 @@
"""
Training API — enums, request/response models, and in-memory state.
"""
import uuid
from datetime import datetime
from typing import Optional, List, Dict, Any
from enum import Enum
from dataclasses import dataclass, field
from pydantic import BaseModel, Field
# ============================================================================
# ENUMS
# ============================================================================
class TrainingStatus(str, Enum):
    """Lifecycle states of a training job (stored as plain strings in job dicts)."""
    QUEUED = "queued"          # created, not yet picked up by the background task
    PREPARING = "preparing"    # setup phase before training starts
    TRAINING = "training"      # actively running
    VALIDATING = "validating"
    COMPLETED = "completed"    # terminal: finished successfully
    FAILED = "failed"          # terminal: aborted with an error
    PAUSED = "paused"          # paused by user; resumable
    CANCELLED = "cancelled"    # terminal: cancelled by user
class ModelType(str, Enum):
    """Kind of document model a training job produces."""
    ZEUGNIS = "zeugnis"    # report cards
    KLAUSUR = "klausur"    # exams
    GENERAL = "general"    # not document-type specific
# ============================================================================
# REQUEST/RESPONSE MODELS
# ============================================================================
class TrainingConfig(BaseModel):
    """Configuration for a training job.

    Field constraints (ge/le bounds) are enforced by pydantic at request
    validation time, so out-of-range hyperparameters are rejected with 422
    before a job is created.
    """
    name: str = Field(..., description="Name for the training job")
    model_type: ModelType = Field(ModelType.ZEUGNIS, description="Type of model to train")
    bundeslaender: List[str] = Field(..., description="List of Bundesland codes to include")
    batch_size: int = Field(16, ge=1, le=128)
    learning_rate: float = Field(0.00005, ge=0.000001, le=0.1)
    epochs: int = Field(10, ge=1, le=100)
    warmup_steps: int = Field(500, ge=0, le=10000)
    weight_decay: float = Field(0.01, ge=0, le=1)
    gradient_accumulation: int = Field(4, ge=1, le=32)
    mixed_precision: bool = Field(True, description="Use FP16 mixed precision training")
class TrainingMetrics(BaseModel):
    """Metrics from a training job."""
    precision: float = 0.0
    recall: float = 0.0
    f1_score: float = 0.0
    accuracy: float = 0.0
    # Fix: use default_factory instead of a mutable `= []` default. Pydantic
    # copies defaults per instance, but an explicit factory is the documented,
    # future-proof form and avoids the shared-mutable-default trap entirely.
    loss_history: List[float] = Field(default_factory=list)
    val_loss_history: List[float] = Field(default_factory=list)
class TrainingJob(BaseModel):
    """A training job with full details."""
    id: str
    name: str
    model_type: ModelType
    status: TrainingStatus
    # Percentage in [0, 100].
    progress: float
    current_epoch: int
    total_epochs: int
    loss: float
    val_loss: float
    learning_rate: float
    documents_processed: int
    total_documents: int
    # Timestamps stay None until the corresponding phase occurs.
    started_at: Optional[datetime]
    estimated_completion: Optional[datetime]
    completed_at: Optional[datetime]
    # Populated only when status is FAILED.
    error_message: Optional[str]
    metrics: TrainingMetrics
    config: TrainingConfig
class ModelVersion(BaseModel):
    """A trained model version."""
    id: str
    # Id of the training job that produced this version.
    job_id: str
    # Human-readable version label, e.g. "v3.0".
    version: str
    model_type: ModelType
    created_at: datetime
    metrics: TrainingMetrics
    # Whether this is the currently active version of its model type.
    is_active: bool
    size_mb: float
    # Bundesland codes the model was trained on.
    bundeslaender: List[str]
class DatasetStats(BaseModel):
    """Statistics about the training dataset."""
    total_documents: int
    total_chunks: int
    # Count of documents cleared for training use.
    training_allowed: int
    # Document counts keyed by Bundesland code.
    by_bundesland: Dict[str, int]
    # Document counts keyed by document type (e.g. "verordnung").
    by_doc_type: Dict[str, int]
# ============================================================================
# IN-MEMORY STATE (Replace with database in production)
# ============================================================================
@dataclass
class TrainingState:
    """Global training state."""
    # job_id -> job dict (shape defined where jobs are created in routes.py).
    jobs: Dict[str, dict] = field(default_factory=dict)
    # version_id -> model version dict.
    model_versions: Dict[str, dict] = field(default_factory=dict)
    # Id of the currently queued/running job, or None when idle.
    active_job_id: Optional[str] = None

# Module-level singleton shared by the routes and the simulation loop.
_state = TrainingState()

View File

@@ -0,0 +1,303 @@
"""
Training API — FastAPI route handlers.
"""
import uuid
from datetime import datetime
from typing import List
from fastapi import APIRouter, HTTPException, BackgroundTasks, Request
from fastapi.responses import StreamingResponse
from .models import (
TrainingStatus,
TrainingConfig,
_state,
)
from .simulation import (
simulate_training_progress,
training_metrics_generator,
batch_ocr_progress_generator,
)
# All endpoints below are mounted under /api/v1/admin/training.
router = APIRouter(prefix="/api/v1/admin/training", tags=["Training"])
# ============================================================================
# TRAINING JOBS
# ============================================================================
@router.get("/jobs", response_model=List[dict])
async def list_training_jobs():
    """Return every known training job as a list of job dicts."""
    return [*_state.jobs.values()]
@router.get("/jobs/{job_id}", response_model=dict)
async def get_training_job(job_id: str):
    """Look up a single training job by id; 404 when unknown."""
    job = _state.jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    return job
@router.post("/jobs", response_model=dict)
async def create_training_job(config: TrainingConfig, background_tasks: BackgroundTasks):
    """Create and start a new training job.

    Rejects the request with 409 while another job is preparing or training;
    otherwise registers the job, marks it active, and launches the
    (simulated) training loop as a background task.
    """
    # Guard: only one job may be in a running state at a time.
    running_states = {TrainingStatus.TRAINING.value, TrainingStatus.PREPARING.value}
    current = _state.jobs.get(_state.active_job_id) if _state.active_job_id else None
    if current is not None and current["status"] in running_states:
        raise HTTPException(
            status_code=409,
            detail="Another training job is already running"
        )
    job_id = str(uuid.uuid4())
    record = {
        "id": job_id,
        "name": config.name,
        "model_type": config.model_type.value,
        "status": TrainingStatus.QUEUED.value,
        "progress": 0,
        "current_epoch": 0,
        "total_epochs": config.epochs,
        "loss": 1.0,
        "val_loss": 1.0,
        "learning_rate": config.learning_rate,
        "documents_processed": 0,
        "total_documents": len(config.bundeslaender) * 50,  # Estimate
        "started_at": None,
        "estimated_completion": None,
        "completed_at": None,
        "error_message": None,
        "metrics": {
            "precision": 0.0,
            "recall": 0.0,
            "f1_score": 0.0,
            "accuracy": 0.0,
            "loss_history": [],
            "val_loss_history": [],
        },
        "config": config.dict(),
    }
    _state.jobs[job_id] = record
    _state.active_job_id = job_id
    # Run the training loop without blocking the response.
    background_tasks.add_task(simulate_training_progress, job_id)
    return {"id": job_id, "status": "queued", "message": "Training job created"}
@router.post("/jobs/{job_id}/pause", response_model=dict)
async def pause_training_job(job_id: str):
    """Pause a running training job (400 unless it is currently TRAINING)."""
    job = _state.jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    if job["status"] != TrainingStatus.TRAINING.value:
        raise HTTPException(status_code=400, detail="Job is not running")
    # The simulation loop watches this status and stops on the next step.
    job["status"] = TrainingStatus.PAUSED.value
    return {"success": True, "message": "Training paused"}
@router.post("/jobs/{job_id}/resume", response_model=dict)
async def resume_training_job(job_id: str, background_tasks: BackgroundTasks):
    """Resume a paused training job (400 unless it is currently PAUSED)."""
    job = _state.jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    if job["status"] != TrainingStatus.PAUSED.value:
        raise HTTPException(status_code=400, detail="Job is not paused")
    job["status"] = TrainingStatus.TRAINING.value
    _state.active_job_id = job_id
    # NOTE(review): simulate_training_progress restarts from step 0 and
    # resets started_at, so "resume" replays the simulation — confirm
    # whether real training should continue from the paused step instead.
    background_tasks.add_task(simulate_training_progress, job_id)
    return {"success": True, "message": "Training resumed"}
@router.post("/jobs/{job_id}/cancel", response_model=dict)
async def cancel_training_job(job_id: str):
    """Cancel a training job in any state and release the active slot."""
    job = _state.jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    job["status"] = TrainingStatus.CANCELLED.value
    job["completed_at"] = datetime.now().isoformat()
    if _state.active_job_id == job_id:
        _state.active_job_id = None
    return {"success": True, "message": "Training cancelled"}
@router.delete("/jobs/{job_id}", response_model=dict)
async def delete_training_job(job_id: str):
    """Delete a training job; running jobs must be cancelled/paused first."""
    job = _state.jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    if job["status"] == TrainingStatus.TRAINING.value:
        raise HTTPException(status_code=400, detail="Cannot delete running job")
    _state.jobs.pop(job_id)
    return {"success": True, "message": "Job deleted"}
# ============================================================================
# MODEL VERSIONS
# ============================================================================
@router.get("/models", response_model=List[dict])
async def list_model_versions():
    """Return every trained model version as a list of dicts."""
    return [*_state.model_versions.values()]
@router.get("/models/{version_id}", response_model=dict)
async def get_model_version(version_id: str):
    """Look up a single model version by id; 404 when unknown."""
    version = _state.model_versions.get(version_id)
    if version is None:
        raise HTTPException(status_code=404, detail="Model version not found")
    return version
@router.post("/models/{version_id}/activate", response_model=dict)
async def activate_model_version(version_id: str):
    """Mark a model version active, deactivating siblings of the same type."""
    target = _state.model_versions.get(version_id)
    if target is None:
        raise HTTPException(status_code=404, detail="Model version not found")
    # Invariant: at most one active version per model type — clear the flag
    # on every peer (including the target) before setting it.
    for version in _state.model_versions.values():
        if version["model_type"] == target["model_type"]:
            version["is_active"] = False
    target["is_active"] = True
    return {"success": True, "message": "Model activated"}
@router.delete("/models/{version_id}", response_model=dict)
async def delete_model_version(version_id: str):
    """Delete a model version; the active version cannot be deleted."""
    version = _state.model_versions.get(version_id)
    if version is None:
        raise HTTPException(status_code=404, detail="Model version not found")
    if version["is_active"]:
        raise HTTPException(status_code=400, detail="Cannot delete active model")
    _state.model_versions.pop(version_id)
    return {"success": True, "message": "Model deleted"}
# ============================================================================
# DATASET STATS & STATUS
# ============================================================================
@router.get("/dataset/stats", response_model=dict)
async def get_dataset_stats():
    """Get statistics about the training dataset.

    Document counts come from metrics_db; chunk counts are derived
    (12 chunks per document).
    """
    from metrics_db import get_zeugnis_stats
    stats = await get_zeugnis_stats()
    total_docs = stats.get("total_documents", 0)
    per_bundesland = {}
    for row in stats.get("per_bundesland", []):
        per_bundesland[row["bundesland"]] = row.get("doc_count", 0)
    return {
        "total_documents": total_docs,
        "total_chunks": total_docs * 12,
        "training_allowed": stats.get("training_allowed_documents", 0),
        "by_bundesland": per_bundesland,
        # NOTE(review): doc-type counts are hard-coded placeholders — confirm
        # they should come from metrics_db as well.
        "by_doc_type": {
            "verordnung": 150,
            "schulordnung": 80,
            "handreichung": 45,
            "erlass": 30,
        },
    }
@router.get("/status", response_model=dict)
async def get_training_status():
    """Summarize the training system: active job, job counts, model counts."""
    active = _state.jobs.get(_state.active_job_id) if _state.active_job_id else None
    completed = 0
    failed = 0
    for entry in _state.jobs.values():
        if entry["status"] == TrainingStatus.COMPLETED.value:
            completed += 1
        elif entry["status"] == TrainingStatus.FAILED.value:
            failed += 1
    active_models = 0
    for version in _state.model_versions.values():
        if version["is_active"]:
            active_models += 1
    return {
        "is_training": active is not None and active["status"] == TrainingStatus.TRAINING.value,
        "active_job_id": _state.active_job_id,
        "total_jobs": len(_state.jobs),
        "completed_jobs": completed,
        "failed_jobs": failed,
        "model_versions": len(_state.model_versions),
        "active_models": active_models,
    }
# ============================================================================
# SSE ENDPOINTS
# ============================================================================
@router.get("/metrics/stream")
async def stream_training_metrics(job_id: str, request: Request):
    """
    SSE endpoint for streaming training metrics.
    Streams real-time training progress for a specific job.
    """
    if job_id not in _state.jobs:
        raise HTTPException(status_code=404, detail="Job not found")
    # Standard SSE headers; X-Accel-Buffering disables proxy buffering.
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no"
    }
    return StreamingResponse(
        training_metrics_generator(job_id, request),
        media_type="text/event-stream",
        headers=sse_headers,
    )
@router.get("/ocr/stream")
async def stream_batch_ocr(images_count: int, request: Request):
    """
    SSE endpoint for streaming batch OCR progress.
    Simulates batch OCR processing with progress updates.
    """
    if not 1 <= images_count <= 100:
        raise HTTPException(status_code=400, detail="images_count must be between 1 and 100")
    # Standard SSE headers; X-Accel-Buffering disables proxy buffering.
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no"
    }
    return StreamingResponse(
        batch_ocr_progress_generator(images_count, request),
        media_type="text/event-stream",
        headers=sse_headers,
    )

View File

@@ -0,0 +1,190 @@
"""
Training API — simulation helper and SSE generators.
"""
import json
import uuid
import asyncio
from datetime import datetime, timedelta
from .models import TrainingStatus, _state
async def simulate_training_progress(job_id: str):
    """Simulate training progress (replace with actual training logic).

    Runs as a FastAPI background task and mutates the shared job dict in
    ``_state`` in place, so SSE consumers observe updates live. The loop
    exits early when the job's status is changed externally (pause/cancel);
    on a natural finish it marks the job completed and registers a new,
    immediately-active model version.
    """
    if job_id not in _state.jobs:
        return
    job = _state.jobs[job_id]
    job["status"] = TrainingStatus.TRAINING.value
    job["started_at"] = datetime.now().isoformat()
    total_steps = job["total_epochs"] * 100  # Simulate 100 steps per epoch
    current_step = 0
    # Re-checking status each iteration lets pause/cancel stop the loop.
    while current_step < total_steps and job["status"] == TrainingStatus.TRAINING.value:
        # Update progress (percentage in [0, 100)).
        progress = (current_step / total_steps) * 100
        current_epoch = current_step // 100 + 1
        # Simulate decreasing loss: linear decay from ~0.9 toward 0.1, plus a
        # small sawtooth within each epoch.
        base_loss = 0.8 * (1 - progress / 100) + 0.1
        loss = base_loss + (0.05 * (0.5 - (current_step % 100) / 100))
        val_loss = loss * 1.1
        # Update job state
        job["progress"] = progress
        job["current_epoch"] = min(current_epoch, job["total_epochs"])
        job["loss"] = round(loss, 4)
        job["val_loss"] = round(val_loss, 4)
        job["documents_processed"] = int((progress / 100) * job["total_documents"])
        # Update metrics (synthetic scores improving with progress).
        job["metrics"]["loss_history"].append(round(loss, 4))
        job["metrics"]["val_loss_history"].append(round(val_loss, 4))
        job["metrics"]["precision"] = round(0.5 + (progress / 200), 3)
        job["metrics"]["recall"] = round(0.45 + (progress / 200), 3)
        job["metrics"]["f1_score"] = round(0.47 + (progress / 200), 3)
        job["metrics"]["accuracy"] = round(0.6 + (progress / 250), 3)
        # Keep only last 50 history points (both lists trimmed together).
        if len(job["metrics"]["loss_history"]) > 50:
            job["metrics"]["loss_history"] = job["metrics"]["loss_history"][-50:]
            job["metrics"]["val_loss_history"] = job["metrics"]["val_loss_history"][-50:]
        # Estimate completion by extrapolating elapsed time linearly.
        if progress > 0:
            elapsed = (datetime.now() - datetime.fromisoformat(job["started_at"])).total_seconds()
            remaining = (elapsed / progress) * (100 - progress)
            job["estimated_completion"] = (datetime.now() + timedelta(seconds=remaining)).isoformat()
        current_step += 1
        await asyncio.sleep(0.5)  # Simulate work
    # Mark as completed (only if the loop finished naturally, i.e. status
    # was not flipped to paused/cancelled while running).
    if job["status"] == TrainingStatus.TRAINING.value:
        job["status"] = TrainingStatus.COMPLETED.value
        job["progress"] = 100
        job["completed_at"] = datetime.now().isoformat()
        # Create model version and activate it immediately.
        version_id = str(uuid.uuid4())
        _state.model_versions[version_id] = {
            "id": version_id,
            "job_id": job_id,
            "version": f"v{len(_state.model_versions) + 1}.0",
            "model_type": job["model_type"],
            "created_at": datetime.now().isoformat(),
            "metrics": job["metrics"],
            "is_active": True,
            "size_mb": 245.7,
            "bundeslaender": job["config"]["bundeslaender"],
        }
        _state.active_job_id = None
async def training_metrics_generator(job_id: str, request):
    """
    SSE generator for streaming training metrics.
    Yields JSON-encoded training status updates every 500ms.

    Terminates when the client disconnects, the job disappears from
    ``_state``, or the job reaches a terminal status (after emitting one
    final event for that status).
    """
    while True:
        # Check if client disconnected
        if await request.is_disconnected():
            break
        # Get job status; emit a final error event if the job vanished.
        if job_id not in _state.jobs:
            yield f"data: {json.dumps({'error': 'Job not found'})}\n\n"
            break
        job = _state.jobs[job_id]
        # Build metrics response
        metrics_data = {
            "job_id": job["id"],
            "status": job["status"],
            "progress": job["progress"],
            "current_epoch": job["current_epoch"],
            "total_epochs": job["total_epochs"],
            "current_step": int(job["progress"] * job["total_epochs"]),
            "total_steps": job["total_epochs"] * 100,
            "elapsed_time_ms": 0,
            "estimated_remaining_ms": 0,
            "metrics": {
                "loss": job["loss"],
                "val_loss": job["val_loss"],
                "accuracy": job["metrics"]["accuracy"],
                "learning_rate": job["learning_rate"]
            },
            # NOTE(review): assumes loss_history and val_loss_history stay the
            # same length (the simulation trims them together) so index i of
            # the last-50 slice lines up across both lists — confirm.
            "history": [
                {
                    "epoch": i + 1,
                    "step": (i + 1) * 10,
                    "loss": loss,
                    "val_loss": job["metrics"]["val_loss_history"][i] if i < len(job["metrics"]["val_loss_history"]) else None,
                    "learning_rate": job["learning_rate"],
                    "timestamp": 0
                }
                for i, loss in enumerate(job["metrics"]["loss_history"][-50:])
            ]
        }
        # Calculate elapsed time
        if job["started_at"]:
            started = datetime.fromisoformat(job["started_at"])
            metrics_data["elapsed_time_ms"] = int((datetime.now() - started).total_seconds() * 1000)
        # Calculate remaining time (clamped at 0 once the estimate passes).
        if job["estimated_completion"]:
            estimated = datetime.fromisoformat(job["estimated_completion"])
            metrics_data["estimated_remaining_ms"] = max(0, int((estimated - datetime.now()).total_seconds() * 1000))
        # Send SSE event
        yield f"data: {json.dumps(metrics_data)}\n\n"
        # Check if job completed (the terminal-status event above was the last).
        if job["status"] in [TrainingStatus.COMPLETED.value, TrainingStatus.FAILED.value, TrainingStatus.CANCELLED.value]:
            break
        # Wait before next update
        await asyncio.sleep(0.5)
async def batch_ocr_progress_generator(images_count: int, request):
    """
    SSE generator for batch OCR progress simulation.
    In production, this would integrate with actual OCR processing.

    Emits one progress event per image (with a fake recognition result)
    and always finishes with a single "complete" event, even when the
    client disconnects mid-batch.
    """
    import random
    processed = 0
    while processed < images_count:
        # Stop producing progress events once the client is gone.
        if await request.is_disconnected():
            break
        # Simulate per-image processing latency.
        await asyncio.sleep(random.uniform(0.3, 0.8))
        processed += 1
        payload = {
            "type": "progress",
            "current": processed,
            "total": images_count,
            "progress_percent": (processed / images_count) * 100,
            "elapsed_ms": processed * 500,
            "estimated_remaining_ms": (images_count - processed) * 500,
            "result": {
                "text": f"Sample recognized text for image {processed}",
                "confidence": round(random.uniform(0.7, 0.98), 2),
                "processing_time_ms": random.randint(200, 600),
                "from_cache": random.random() < 0.2
            }
        }
        yield f"data: {json.dumps(payload)}\n\n"
    # Send completion event
    yield f"data: {json.dumps({'type': 'complete', 'total_time_ms': images_count * 500, 'processed_count': images_count})}\n\n"

View File

@@ -0,0 +1,261 @@
"""
TrOCR API - REST endpoints for TrOCR handwriting OCR.
Provides:
- /ocr/trocr - Single image OCR
- /ocr/trocr/batch - Batch image processing
- /ocr/trocr/status - Model status
- /ocr/trocr/cache - Cache statistics
"""
from fastapi import APIRouter, UploadFile, File, HTTPException, Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import List, Optional
import json
import logging
from services.trocr_service import (
run_trocr_ocr_enhanced,
run_trocr_batch,
run_trocr_batch_stream,
get_model_status,
get_cache_stats,
preload_trocr_model,
OCRResult,
BatchOCRResult
)
# Module-level logger for this API module.
logger = logging.getLogger(__name__)
# All endpoints below are mounted under /api/v1/ocr/trocr.
router = APIRouter(prefix="/api/v1/ocr/trocr", tags=["TrOCR"])
# =============================================================================
# MODELS
# =============================================================================
class TrOCRResponse(BaseModel):
    """Response for a single-image OCR request.

    All fields except ``word_count`` are copied from the service-layer
    OCR result; ``word_count`` is derived by the endpoint from the
    recognized text.
    """
    text: str = Field(..., description="Extracted text")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Overall confidence")
    processing_time_ms: int = Field(..., ge=0, description="Processing time in milliseconds")
    model: str = Field(..., description="Model used for OCR")
    has_lora_adapter: bool = Field(False, description="Whether LoRA adapter was used")
    from_cache: bool = Field(False, description="Whether result was from cache")
    image_hash: str = Field("", description="SHA256 hash of image (first 16 chars)")
    word_count: int = Field(0, description="Number of words detected")
class BatchOCRResponse(BaseModel):
    """Response for a batch OCR request.

    ``results`` holds one entry per processed image; the aggregate
    counters are taken from the service-layer batch result.
    """
    results: List[TrOCRResponse] = Field(..., description="Individual OCR results")
    total_time_ms: int = Field(..., ge=0, description="Total processing time")
    processed_count: int = Field(..., ge=0, description="Number of images processed")
    cached_count: int = Field(0, description="Number of results from cache")
    error_count: int = Field(0, description="Number of errors")
class ModelStatusResponse(BaseModel):
    """Response describing TrOCR model availability.

    The Optional fields default to None; presumably they are populated
    only once a model has been loaded — confirm against
    ``services.trocr_service.get_model_status``.
    """
    status: str = Field(..., description="Model status: available, not_installed")
    is_loaded: bool = Field(..., description="Whether model is loaded in memory")
    model_name: Optional[str] = Field(None, description="Name of loaded model")
    device: Optional[str] = Field(None, description="Device model is running on")
    loaded_at: Optional[str] = Field(None, description="ISO timestamp when model was loaded")
class CacheStatsResponse(BaseModel):
    """Response describing the OCR result cache: current size, capacity limit, and entry TTL."""
    size: int = Field(..., ge=0, description="Current cache size")
    max_size: int = Field(..., ge=0, description="Maximum cache size")
    ttl_seconds: int = Field(..., ge=0, description="Cache TTL in seconds")
# =============================================================================
# ENDPOINTS
# =============================================================================
@router.get("/status", response_model=ModelStatusResponse)
async def get_trocr_status():
    """
    Get TrOCR model status.

    Reports whether the model is installed and currently loaded in memory.
    """
    status = get_model_status()
    return status
@router.get("/cache", response_model=CacheStatsResponse)
async def get_trocr_cache_stats():
    """
    Get TrOCR cache statistics.

    Exposes the size, capacity, and TTL of the OCR result cache.
    """
    stats = get_cache_stats()
    return stats
@router.post("/preload")
async def preload_model(handwritten: bool = Query(True, description="Load handwritten model")):
    """
    Preload the TrOCR model into memory.

    Warming the model ahead of time avoids paying the load cost on the
    first OCR request. Responds with HTTP 500 when loading fails.
    """
    # Guard clause: bail out immediately on load failure.
    if not preload_trocr_model(handwritten=handwritten):
        raise HTTPException(status_code=500, detail="Failed to preload model")
    return {"status": "success", "message": "Model preloaded successfully"}
@router.post("", response_model=TrOCRResponse)
async def run_trocr(
    file: UploadFile = File(..., description="Image file to process"),
    handwritten: bool = Query(True, description="Use handwritten model"),
    split_lines: bool = Query(True, description="Split image into lines"),
    use_cache: bool = Query(True, description="Use result caching")
):
    """
    Run TrOCR on one uploaded image and return the recognized text.

    Accepts PNG, JPG and other common image formats; uploads without an
    ``image/*`` content type are rejected with HTTP 400.
    """
    # Reject anything that is not declared as an image.
    content_type = file.content_type
    if not (content_type and content_type.startswith("image/")):
        raise HTTPException(status_code=400, detail="File must be an image")
    try:
        payload = await file.read()
        ocr = await run_trocr_ocr_enhanced(
            payload,
            handwritten=handwritten,
            split_lines=split_lines,
            use_cache=use_cache
        )
        return TrOCRResponse(
            text=ocr.text,
            confidence=ocr.confidence,
            processing_time_ms=ocr.processing_time_ms,
            model=ocr.model,
            has_lora_adapter=ocr.has_lora_adapter,
            from_cache=ocr.from_cache,
            image_hash=ocr.image_hash,
            # Derived field: simple whitespace word count of the output.
            word_count=len(ocr.text.split()) if ocr.text else 0
        )
    except Exception as e:
        logger.error(f"TrOCR API error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/batch", response_model=BatchOCRResponse)
async def run_trocr_batch_endpoint(
    files: List[UploadFile] = File(..., description="Image files to process"),
    handwritten: bool = Query(True, description="Use handwritten model"),
    split_lines: bool = Query(True, description="Split images into lines"),
    use_cache: bool = Query(True, description="Use result caching")
):
    """
    Run TrOCR over several images in one request.

    Images are processed sequentially; the response bundles every
    per-image result together with aggregate timing/cache counters.
    Limited to 50 images per call.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files provided")
    if len(files) > 50:
        raise HTTPException(status_code=400, detail="Maximum 50 images per batch")
    try:
        # Validate and read each upload in order; a bad file aborts the batch.
        images: List[bytes] = []
        for upload in files:
            ctype = upload.content_type
            if not (ctype and ctype.startswith("image/")):
                raise HTTPException(status_code=400, detail=f"File {upload.filename} is not an image")
            images.append(await upload.read())

        batch_result = await run_trocr_batch(
            images,
            handwritten=handwritten,
            split_lines=split_lines,
            use_cache=use_cache
        )

        def _to_response(item) -> TrOCRResponse:
            # Mirror the service result, adding a derived word count.
            return TrOCRResponse(
                text=item.text,
                confidence=item.confidence,
                processing_time_ms=item.processing_time_ms,
                model=item.model,
                has_lora_adapter=item.has_lora_adapter,
                from_cache=item.from_cache,
                image_hash=item.image_hash,
                word_count=len(item.text.split()) if item.text else 0
            )

        return BatchOCRResponse(
            results=[_to_response(r) for r in batch_result.results],
            total_time_ms=batch_result.total_time_ms,
            processed_count=batch_result.processed_count,
            cached_count=batch_result.cached_count,
            error_count=batch_result.error_count
        )
    except HTTPException:
        # Preserve validation errors (400) raised above.
        raise
    except Exception as e:
        logger.error(f"TrOCR batch API error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/batch/stream")
async def run_trocr_batch_stream_endpoint(
    files: List[UploadFile] = File(..., description="Image files to process"),
    handwritten: bool = Query(True, description="Use handwritten model"),
    split_lines: bool = Query(True, description="Split images into lines"),
    use_cache: bool = Query(True, description="Use result caching")
):
    """
    Run TrOCR over several images, streaming progress as Server-Sent Events.

    Each service-layer progress update is forwarded to the client as one
    SSE ``data:`` frame. Limited to 50 images per call.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files provided")
    if len(files) > 50:
        raise HTTPException(status_code=400, detail="Maximum 50 images per batch")
    try:
        # Validate and read each upload in order; a bad file aborts the batch.
        images: List[bytes] = []
        for upload in files:
            ctype = upload.content_type
            if not (ctype and ctype.startswith("image/")):
                raise HTTPException(status_code=400, detail=f"File {upload.filename} is not an image")
            images.append(await upload.read())

        async def sse_events():
            # Serialize every progress update into an SSE frame.
            async for update in run_trocr_batch_stream(
                images,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache
            ):
                yield f"data: {json.dumps(update)}\n\n"

        return StreamingResponse(
            sse_events(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive"
            }
        )
    except HTTPException:
        # Preserve validation errors (400) raised above.
        raise
    except Exception as e:
        logger.error(f"TrOCR stream API error: {e}")
        raise HTTPException(status_code=500, detail=str(e))