Restructure: Move 52 files into 7 domain packages
Some checks failed
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 28s
CI / test-go-edu-search (push) Successful in 28s
CI / test-python-klausur (push) Failing after 2m22s
CI / test-python-agent-core (push) Successful in 21s
CI / test-nodejs-website (push) Successful in 23s
Some checks failed
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 28s
CI / test-go-edu-search (push) Successful in 28s
CI / test-python-klausur (push) Failing after 2m22s
CI / test-python-agent-core (push) Successful in 21s
CI / test-nodejs-website (push) Successful in 23s
korrektur/ zeugnis/ admin/ compliance/ worksheet/ training/ metrics/ 52 shims, relative imports, RAG untouched. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
6
klausur-service/backend/training/__init__.py
Normal file
6
klausur-service/backend/training/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
training package — training API, simulation, export, TrOCR.
|
||||
|
||||
Backward-compatible re-exports: consumers can still use
|
||||
``from training_api import ...`` etc. via the shim files in backend/.
|
||||
"""
|
||||
31
klausur-service/backend/training/api.py
Normal file
31
klausur-service/backend/training/api.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""
|
||||
Training API — barrel re-export.
|
||||
|
||||
The actual code lives in:
|
||||
- training_models.py (enums, Pydantic models, in-memory state)
|
||||
- training_simulation.py (simulate_training_progress, SSE generators)
|
||||
- training_routes.py (FastAPI router + all endpoints)
|
||||
"""
|
||||
|
||||
# Models & enums
|
||||
from .models import ( # noqa: F401
|
||||
TrainingStatus,
|
||||
ModelType,
|
||||
TrainingConfig,
|
||||
TrainingMetrics,
|
||||
TrainingJob,
|
||||
ModelVersion,
|
||||
DatasetStats,
|
||||
TrainingState,
|
||||
_state,
|
||||
)
|
||||
|
||||
# Simulation helpers
|
||||
from .simulation import ( # noqa: F401
|
||||
simulate_training_progress,
|
||||
training_metrics_generator,
|
||||
batch_ocr_progress_generator,
|
||||
)
|
||||
|
||||
# Router
|
||||
from .routes import router # noqa: F401
|
||||
448
klausur-service/backend/training/export_service.py
Normal file
448
klausur-service/backend/training/export_service.py
Normal file
@@ -0,0 +1,448 @@
|
||||
"""
|
||||
Training Export Service for OCR Labeling Data
|
||||
|
||||
Exports labeled OCR data in formats suitable for fine-tuning:
|
||||
- TrOCR (Microsoft's Transformer-based OCR model)
|
||||
- llama3.2-vision (Meta's Vision-Language Model)
|
||||
- Generic JSONL format
|
||||
|
||||
DATENSCHUTZ/PRIVACY:
|
||||
- Alle Daten bleiben lokal auf dem Mac Mini
|
||||
- Keine Cloud-Uploads ohne explizite Zustimmung
|
||||
- Export-Pfade sind konfigurierbar
|
||||
"""
|
||||
|
||||
import base64
import hashlib
import json
import os
import shutil
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional
|
||||
|
||||
# Export directory configuration
|
||||
EXPORT_BASE_PATH = os.getenv("OCR_EXPORT_PATH", "/app/ocr-exports")
|
||||
TROCR_EXPORT_PATH = os.path.join(EXPORT_BASE_PATH, "trocr")
|
||||
LLAMA_VISION_EXPORT_PATH = os.path.join(EXPORT_BASE_PATH, "llama-vision")
|
||||
GENERIC_EXPORT_PATH = os.path.join(EXPORT_BASE_PATH, "generic")
|
||||
|
||||
|
||||
@dataclass
class TrainingSample:
    """A single training sample for OCR fine-tuning."""

    # Unique sample identifier; exporters also use it as the staged image's file stem.
    id: str
    # Source path of the image on disk.
    image_path: str
    # Verified transcription used as the training label.
    ground_truth: str
    # Raw OCR output for this image, if available.
    ocr_text: Optional[str] = None
    # Confidence reported by the OCR engine, if available.
    ocr_confidence: Optional[float] = None
    # Arbitrary extra attributes; passed through unchanged by the generic exporter.
    metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
class ExportResult:
    """Result of a training data export."""

    # One of: "trocr", "llama_vision", "generic".
    export_format: str
    # Directory containing the exported batch (JSONL + images + manifest).
    export_path: str
    # Number of samples written to the batch.
    sample_count: int
    # Identifier of this export batch (also the directory name).
    batch_id: str
    # Timestamp of the export.
    created_at: datetime
    # Path to the batch's manifest.json.
    manifest_path: str
|
||||
|
||||
|
||||
class TrOCRExporter:
    """
    Export training data for TrOCR fine-tuning.

    TrOCR expects:
    - Image files (PNG/JPG)
    - A CSV/TSV file with: image_path, text
    - Or a JSONL file with: {"file_name": "img.png", "text": "ground truth"}

    We use the JSONL format for flexibility.
    """

    def __init__(self, export_path: str = TROCR_EXPORT_PATH):
        self.export_path = export_path
        os.makedirs(export_path, exist_ok=True)

    @staticmethod
    def _stage_image(sample: "TrainingSample", images_path: str, copy_images: bool) -> str:
        """Copy the sample image into the batch dir and return its reference.

        Returns a batch-relative ``images/<id><ext>`` path when the image was
        copied; otherwise falls back to the original source path (copying
        disabled, or the source file is missing).
        """
        if copy_images and os.path.exists(sample.image_path):
            image_filename = f"{sample.id}{Path(sample.image_path).suffix}"
            shutil.copy2(sample.image_path, os.path.join(images_path, image_filename))
            return f"images/{image_filename}"
        return sample.image_path

    def export(
        self,
        samples: List[TrainingSample],
        batch_id: str,
        copy_images: bool = True,
    ) -> ExportResult:
        """
        Export samples in TrOCR format.

        Args:
            samples: List of training samples
            batch_id: Unique batch identifier
            copy_images: Whether to copy images to export directory

        Returns:
            ExportResult with export details
        """
        batch_path = os.path.join(self.export_path, batch_id)
        images_path = os.path.join(batch_path, "images")
        os.makedirs(images_path, exist_ok=True)

        # One JSONL record per sample: {"file_name", "text", "id"}.
        export_data = [
            {
                "file_name": self._stage_image(sample, images_path, copy_images),
                "text": sample.ground_truth,
                "id": sample.id,
            }
            for sample in samples
        ]

        # Write JSONL file (ensure_ascii=False keeps German umlauts readable).
        jsonl_path = os.path.join(batch_path, "train.jsonl")
        with open(jsonl_path, 'w', encoding='utf-8') as f:
            for item in export_data:
                f.write(json.dumps(item, ensure_ascii=False) + '\n')

        # Single timestamp shared by manifest and result (utcnow() is
        # deprecated; use an aware UTC datetime instead).
        created_at = datetime.now(timezone.utc)

        # Write manifest
        manifest = {
            "format": "trocr",
            "version": "1.0",
            "batch_id": batch_id,
            "sample_count": len(samples),
            "created_at": created_at.isoformat(),
            "files": {
                "data": "train.jsonl",
                "images": "images/",
            },
            "model_config": {
                "base_model": "microsoft/trocr-base-handwritten",
                "task": "handwriting-recognition",
            },
        }
        manifest_path = os.path.join(batch_path, "manifest.json")
        # Explicit utf-8: the default encoding is platform-dependent.
        with open(manifest_path, 'w', encoding='utf-8') as f:
            json.dump(manifest, f, indent=2)

        return ExportResult(
            export_format="trocr",
            export_path=batch_path,
            sample_count=len(samples),
            batch_id=batch_id,
            created_at=created_at,
            manifest_path=manifest_path,
        )
|
||||
|
||||
|
||||
class LlamaVisionExporter:
    """
    Export training data for llama3.2-vision fine-tuning.

    Llama Vision fine-tuning expects:
    - JSONL format with base64-encoded images or image URLs
    - Format: {"messages": [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "..."}]}, {"role": "assistant", "content": "..."}]}

    We create a supervised fine-tuning dataset.
    """

    def __init__(self, export_path: str = LLAMA_VISION_EXPORT_PATH):
        self.export_path = export_path
        os.makedirs(export_path, exist_ok=True)

    def _encode_image_base64(self, image_path: str) -> Optional[str]:
        """Encode image to base64; returns None if the file cannot be read (best effort)."""
        try:
            with open(image_path, 'rb') as f:
                return base64.b64encode(f.read()).decode('utf-8')
        except OSError:
            # Narrowed from a blanket Exception: only I/O failures are expected here.
            return None

    @staticmethod
    def _stage_image(sample: "TrainingSample", images_path: str, copy_images: bool) -> str:
        """Copy the sample image into the batch dir and return its reference.

        Returns a batch-relative ``images/<id><ext>`` path when the image was
        copied; otherwise the original source path.
        """
        if copy_images and os.path.exists(sample.image_path):
            image_filename = f"{sample.id}{Path(sample.image_path).suffix}"
            shutil.copy2(sample.image_path, os.path.join(images_path, image_filename))
            return f"images/{image_filename}"
        return sample.image_path

    def export(
        self,
        samples: List[TrainingSample],
        batch_id: str,
        include_base64: bool = False,
        copy_images: bool = True,
    ) -> ExportResult:
        """
        Export samples in Llama Vision fine-tuning format.

        Args:
            samples: List of training samples
            batch_id: Unique batch identifier
            include_base64: Whether to include base64-encoded images in JSONL
            copy_images: Whether to copy images to export directory

        Returns:
            ExportResult with export details
        """
        batch_path = os.path.join(self.export_path, batch_id)
        images_path = os.path.join(batch_path, "images")
        os.makedirs(images_path, exist_ok=True)

        # OCR instruction prompt (German: "You are an OCR expert for German
        # handwriting; transcribe the handwritten text verbatim").
        system_prompt = (
            "Du bist ein OCR-Experte für deutsche Handschrift. "
            "Lies den handgeschriebenen Text im Bild und gib ihn wortgetreu wieder."
        )

        export_data = []
        for sample in samples:
            image_ref = self._stage_image(sample, images_path, copy_images)

            # Chat-format user turn: image first, then the instruction text.
            user_content = [
                {"type": "image_url", "image_url": {"url": image_ref}},
                {"type": "text", "text": "Lies den handgeschriebenen Text in diesem Bild."},
            ]

            # Optionally embed the image as a data: URL instead of a file reference.
            if include_base64:
                b64 = self._encode_image_base64(sample.image_path)
                if b64:
                    ext = Path(sample.image_path).suffix.lower().replace('.', '')
                    mime = {'png': 'image/png', 'jpg': 'image/jpeg', 'jpeg': 'image/jpeg'}.get(ext, 'image/png')
                    user_content[0] = {
                        "type": "image_url",
                        "image_url": {"url": f"data:{mime};base64,{b64}"}
                    }

            export_data.append({
                "id": sample.id,
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_content},
                    {"role": "assistant", "content": sample.ground_truth},
                ],
            })

        # Write JSONL file (ensure_ascii=False keeps German umlauts readable).
        jsonl_path = os.path.join(batch_path, "train.jsonl")
        with open(jsonl_path, 'w', encoding='utf-8') as f:
            for item in export_data:
                f.write(json.dumps(item, ensure_ascii=False) + '\n')

        # Single timestamp shared by manifest and result (utcnow() is deprecated).
        created_at = datetime.now(timezone.utc)

        # Write manifest
        manifest = {
            "format": "llama_vision",
            "version": "1.0",
            "batch_id": batch_id,
            "sample_count": len(samples),
            "created_at": created_at.isoformat(),
            "files": {
                "data": "train.jsonl",
                "images": "images/",
            },
            "model_config": {
                "base_model": "llama3.2-vision:11b",
                "task": "handwriting-ocr",
                "system_prompt": system_prompt,
            },
        }
        manifest_path = os.path.join(batch_path, "manifest.json")
        # Explicit utf-8: the default encoding is platform-dependent.
        with open(manifest_path, 'w', encoding='utf-8') as f:
            json.dump(manifest, f, indent=2)

        return ExportResult(
            export_format="llama_vision",
            export_path=batch_path,
            sample_count=len(samples),
            batch_id=batch_id,
            created_at=created_at,
            manifest_path=manifest_path,
        )
|
||||
|
||||
|
||||
class GenericExporter:
    """
    Export training data in a generic JSONL format.

    This format is compatible with most ML frameworks and can be
    easily converted to other formats.
    """

    def __init__(self, export_path: str = GENERIC_EXPORT_PATH):
        self.export_path = export_path
        os.makedirs(export_path, exist_ok=True)

    @staticmethod
    def _stage_image(sample: "TrainingSample", images_path: str, copy_images: bool) -> str:
        """Copy the sample image into the batch dir and return its reference.

        Returns a batch-relative ``images/<id><ext>`` path when the image was
        copied; otherwise the original source path.
        """
        if copy_images and os.path.exists(sample.image_path):
            image_filename = f"{sample.id}{Path(sample.image_path).suffix}"
            shutil.copy2(sample.image_path, os.path.join(images_path, image_filename))
            return f"images/{image_filename}"
        return sample.image_path

    def export(
        self,
        samples: List[TrainingSample],
        batch_id: str,
        copy_images: bool = True,
    ) -> ExportResult:
        """
        Export samples in generic JSONL format.

        Args:
            samples: List of training samples
            batch_id: Unique batch identifier
            copy_images: Whether to copy images to export directory

        Returns:
            ExportResult with export details
        """
        batch_path = os.path.join(self.export_path, batch_id)
        images_path = os.path.join(batch_path, "images")
        os.makedirs(images_path, exist_ok=True)

        # Full record per sample, including raw OCR output for later analysis.
        export_data = [
            {
                "id": sample.id,
                "image_path": self._stage_image(sample, images_path, copy_images),
                "ground_truth": sample.ground_truth,
                "ocr_text": sample.ocr_text,
                "ocr_confidence": sample.ocr_confidence,
                "metadata": sample.metadata or {},
            }
            for sample in samples
        ]

        # Write JSONL file (ensure_ascii=False keeps German umlauts readable).
        jsonl_path = os.path.join(batch_path, "data.jsonl")
        with open(jsonl_path, 'w', encoding='utf-8') as f:
            for item in export_data:
                f.write(json.dumps(item, ensure_ascii=False) + '\n')

        # Also write as single JSON for convenience
        json_path = os.path.join(batch_path, "data.json")
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(export_data, f, indent=2, ensure_ascii=False)

        # Single timestamp shared by manifest and result (utcnow() is deprecated).
        created_at = datetime.now(timezone.utc)

        # Write manifest
        manifest = {
            "format": "generic",
            "version": "1.0",
            "batch_id": batch_id,
            "sample_count": len(samples),
            "created_at": created_at.isoformat(),
            "files": {
                "data_jsonl": "data.jsonl",
                "data_json": "data.json",
                "images": "images/",
            },
        }
        manifest_path = os.path.join(batch_path, "manifest.json")
        # Explicit utf-8: the default encoding is platform-dependent.
        with open(manifest_path, 'w', encoding='utf-8') as f:
            json.dump(manifest, f, indent=2)

        return ExportResult(
            export_format="generic",
            export_path=batch_path,
            sample_count=len(samples),
            batch_id=batch_id,
            created_at=created_at,
            manifest_path=manifest_path,
        )
|
||||
|
||||
|
||||
class TrainingExportService:
    """
    Main service for exporting OCR labeling data to various training formats.
    """

    def __init__(self):
        self.trocr_exporter = TrOCRExporter()
        self.llama_vision_exporter = LlamaVisionExporter()
        self.generic_exporter = GenericExporter()

    def export(
        self,
        samples: List[TrainingSample],
        export_format: str,
        batch_id: Optional[str] = None,
        **kwargs,
    ) -> ExportResult:
        """
        Export training samples in the specified format.

        Args:
            samples: List of training samples
            export_format: 'trocr', 'llama_vision', or 'generic'
            batch_id: Optional batch ID (generated if not provided)
            **kwargs: Additional format-specific options

        Returns:
            ExportResult with export details

        Raises:
            ValueError: If export_format is not one of the supported formats.
        """
        if not batch_id:
            # Timestamp-based default id; strftime output is identical to the
            # old deprecated utcnow() call.
            batch_id = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")

        # Dispatch table instead of an if/elif chain.
        exporters = {
            "trocr": self.trocr_exporter,
            "llama_vision": self.llama_vision_exporter,
            "generic": self.generic_exporter,
        }
        exporter = exporters.get(export_format)
        if exporter is None:
            raise ValueError(f"Unknown export format: {export_format}")
        return exporter.export(samples, batch_id, **kwargs)

    def list_exports(self, export_format: Optional[str] = None) -> List[Dict]:
        """
        List all available exports.

        Args:
            export_format: Optional filter by format

        Returns:
            List of export manifests, newest first (by manifest "created_at").
        """
        exports = []

        paths_to_check = []
        if export_format is None or export_format == "trocr":
            paths_to_check.append((TROCR_EXPORT_PATH, "trocr"))
        if export_format is None or export_format == "llama_vision":
            paths_to_check.append((LLAMA_VISION_EXPORT_PATH, "llama_vision"))
        if export_format is None or export_format == "generic":
            paths_to_check.append((GENERIC_EXPORT_PATH, "generic"))

        for base_path, _fmt in paths_to_check:
            if not os.path.exists(base_path):
                continue
            for batch_dir in os.listdir(base_path):
                manifest_path = os.path.join(base_path, batch_dir, "manifest.json")
                if not os.path.exists(manifest_path):
                    continue
                try:
                    with open(manifest_path, 'r', encoding='utf-8') as f:
                        manifest = json.load(f)
                except (OSError, json.JSONDecodeError):
                    # One corrupt or unreadable manifest must not break the
                    # whole listing; skip it.
                    continue
                manifest["export_path"] = os.path.join(base_path, batch_dir)
                exports.append(manifest)

        return sorted(exports, key=lambda x: x.get("created_at", ""), reverse=True)
|
||||
|
||||
|
||||
# Singleton instance
_export_service: Optional[TrainingExportService] = None


def get_training_export_service() -> TrainingExportService:
    """Return the process-wide TrainingExportService, creating it on first use."""
    global _export_service
    service = _export_service
    if service is None:
        service = TrainingExportService()
        _export_service = service
    return service
|
||||
118
klausur-service/backend/training/models.py
Normal file
118
klausur-service/backend/training/models.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""
|
||||
Training API — enums, request/response models, and in-memory state.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass, field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ENUMS
|
||||
# ============================================================================
|
||||
|
||||
class TrainingStatus(str, Enum):
    """Lifecycle states of a training job.

    Inherits ``str`` so members serialize as their string values in JSON.
    """

    QUEUED = "queued"          # created, not yet started
    PREPARING = "preparing"    # dataset/setup phase
    TRAINING = "training"      # actively training
    VALIDATING = "validating"  # evaluation phase
    COMPLETED = "completed"    # finished successfully
    FAILED = "failed"          # finished with an error
    PAUSED = "paused"          # suspended; can be resumed
    CANCELLED = "cancelled"    # aborted by the user
|
||||
|
||||
|
||||
class ModelType(str, Enum):
    """Kind of document the model is trained for.

    Inherits ``str`` so members serialize as their string values in JSON.
    """

    ZEUGNIS = "zeugnis"    # school report cards
    KLAUSUR = "klausur"    # exams
    GENERAL = "general"    # not document-type specific
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# REQUEST/RESPONSE MODELS
|
||||
# ============================================================================
|
||||
|
||||
class TrainingConfig(BaseModel):
    """Configuration for a training job.

    Hyperparameter bounds are enforced via the ``ge``/``le`` constraints on
    each Field; requests outside those ranges fail validation.
    """

    name: str = Field(..., description="Name for the training job")
    model_type: ModelType = Field(ModelType.ZEUGNIS, description="Type of model to train")
    bundeslaender: List[str] = Field(..., description="List of Bundesland codes to include")
    batch_size: int = Field(16, ge=1, le=128)
    # 5e-5 default, typical transformer fine-tuning learning rate.
    learning_rate: float = Field(0.00005, ge=0.000001, le=0.1)
    epochs: int = Field(10, ge=1, le=100)
    warmup_steps: int = Field(500, ge=0, le=10000)
    weight_decay: float = Field(0.01, ge=0, le=1)
    # Effective batch size = batch_size * gradient_accumulation.
    gradient_accumulation: int = Field(4, ge=1, le=32)
    mixed_precision: bool = Field(True, description="Use FP16 mixed precision training")
|
||||
|
||||
|
||||
class TrainingMetrics(BaseModel):
    """Metrics from a training job.

    All scores default to 0.0 before any evaluation has run.
    """

    precision: float = 0.0
    recall: float = 0.0
    f1_score: float = 0.0
    accuracy: float = 0.0
    # Pydantic copies mutable defaults per instance, so these lists are safe
    # despite the bare-[] default.
    loss_history: List[float] = []
    val_loss_history: List[float] = []
|
||||
|
||||
|
||||
class TrainingJob(BaseModel):
    """A training job with full details."""

    id: str                                   # job UUID
    name: str                                 # human-readable label from the config
    model_type: ModelType
    status: TrainingStatus
    progress: float                           # overall progress in percent (0-100)
    current_epoch: int
    total_epochs: int
    loss: float                               # latest training loss
    val_loss: float                           # latest validation loss
    learning_rate: float
    documents_processed: int
    total_documents: int
    started_at: Optional[datetime]            # None until the job leaves QUEUED
    estimated_completion: Optional[datetime]
    completed_at: Optional[datetime]          # set on completion/cancellation
    error_message: Optional[str]              # set when status == FAILED
    metrics: TrainingMetrics
    config: TrainingConfig                    # the config the job was created with
|
||||
|
||||
|
||||
class ModelVersion(BaseModel):
    """A trained model version."""

    id: str                      # version UUID
    job_id: str                  # training job that produced this version
    version: str                 # version label
    model_type: ModelType
    created_at: datetime
    metrics: TrainingMetrics     # evaluation metrics of this version
    is_active: bool              # at most one active version per model_type
    size_mb: float               # model artifact size
    bundeslaender: List[str]     # Bundesland codes the model was trained on
|
||||
|
||||
|
||||
class DatasetStats(BaseModel):
    """Statistics about the training dataset."""

    total_documents: int
    total_chunks: int
    # Number of documents whose owners permitted use for training.
    training_allowed: int
    # Document counts keyed by Bundesland code.
    by_bundesland: Dict[str, int]
    # Document counts keyed by document type (e.g. "verordnung").
    by_doc_type: Dict[str, int]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# IN-MEMORY STATE (Replace with database in production)
|
||||
# ============================================================================
|
||||
|
||||
@dataclass
class TrainingState:
    """Global training state.

    In-memory only — all jobs and versions are lost on process restart
    (marked in the module as a placeholder for a real database).
    """

    # Job dicts keyed by job id (plain dicts, not TrainingJob instances).
    jobs: Dict[str, dict] = field(default_factory=dict)
    # Model-version dicts keyed by version id.
    model_versions: Dict[str, dict] = field(default_factory=dict)
    # Id of the currently running/queued job, if any.
    active_job_id: Optional[str] = None


# Module-level singleton shared by routes and simulation helpers.
_state = TrainingState()
|
||||
303
klausur-service/backend/training/routes.py
Normal file
303
klausur-service/backend/training/routes.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""
|
||||
Training API — FastAPI route handlers.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from .models import (
|
||||
TrainingStatus,
|
||||
TrainingConfig,
|
||||
_state,
|
||||
)
|
||||
from .simulation import (
|
||||
simulate_training_progress,
|
||||
training_metrics_generator,
|
||||
batch_ocr_progress_generator,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/admin/training", tags=["Training"])
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TRAINING JOBS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/jobs", response_model=List[dict])
async def list_training_jobs():
    """Return every known training job as a list of job dicts."""
    return [*_state.jobs.values()]
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}", response_model=dict)
async def get_training_job(job_id: str):
    """Return the full record for one training job; 404 if unknown."""
    try:
        return _state.jobs[job_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
|
||||
@router.post("/jobs", response_model=dict)
async def create_training_job(config: TrainingConfig, background_tasks: BackgroundTasks):
    """Create and start a new training job.

    Only one job may be in PREPARING/TRAINING at a time; a second request
    while one is running gets a 409. The new job starts as QUEUED and is
    advanced by the background simulation task.
    """
    # Check if there's already an active job
    if _state.active_job_id:
        active_job = _state.jobs.get(_state.active_job_id)
        if active_job and active_job["status"] in [
            TrainingStatus.TRAINING.value,
            TrainingStatus.PREPARING.value,
        ]:
            raise HTTPException(
                status_code=409,
                detail="Another training job is already running"
            )

    # Create job record (plain dict; see TrainingState note about in-memory storage).
    job_id = str(uuid.uuid4())
    job = {
        "id": job_id,
        "name": config.name,
        "model_type": config.model_type.value,
        "status": TrainingStatus.QUEUED.value,
        "progress": 0,
        "current_epoch": 0,
        "total_epochs": config.epochs,
        # Loss values start at 1.0 and decrease as the simulation runs.
        "loss": 1.0,
        "val_loss": 1.0,
        "learning_rate": config.learning_rate,
        "documents_processed": 0,
        "total_documents": len(config.bundeslaender) * 50,  # Estimate: ~50 docs per Bundesland
        "started_at": None,
        "estimated_completion": None,
        "completed_at": None,
        "error_message": None,
        "metrics": {
            "precision": 0.0,
            "recall": 0.0,
            "f1_score": 0.0,
            "accuracy": 0.0,
            "loss_history": [],
            "val_loss_history": [],
        },
        "config": config.dict(),
    }

    # Register the job and mark it active BEFORE scheduling the task so the
    # simulation sees consistent state when it starts.
    _state.jobs[job_id] = job
    _state.active_job_id = job_id

    # Start training in background (simulation placeholder).
    background_tasks.add_task(simulate_training_progress, job_id)

    return {"id": job_id, "status": "queued", "message": "Training job created"}
|
||||
|
||||
|
||||
@router.post("/jobs/{job_id}/pause", response_model=dict)
async def pause_training_job(job_id: str):
    """Pause a running training job; 404 if unknown, 400 if not running."""
    job = _state.jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    if job["status"] != TrainingStatus.TRAINING.value:
        raise HTTPException(status_code=400, detail="Job is not running")

    # The simulation loop observes this status change and stops advancing.
    job["status"] = TrainingStatus.PAUSED.value
    return {"success": True, "message": "Training paused"}
|
||||
|
||||
|
||||
@router.post("/jobs/{job_id}/resume", response_model=dict)
async def resume_training_job(job_id: str, background_tasks: BackgroundTasks):
    """Resume a paused training job; 404 if unknown, 400 if not paused."""
    job = _state.jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    if job["status"] != TrainingStatus.PAUSED.value:
        raise HTTPException(status_code=400, detail="Job is not paused")

    # Flip back to TRAINING and restart the background progress task.
    job["status"] = TrainingStatus.TRAINING.value
    _state.active_job_id = job_id
    background_tasks.add_task(simulate_training_progress, job_id)

    return {"success": True, "message": "Training resumed"}
|
||||
|
||||
|
||||
@router.post("/jobs/{job_id}/cancel", response_model=dict)
async def cancel_training_job(job_id: str):
    """Cancel a training job in any state; 404 if unknown."""
    job = _state.jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")

    job["status"] = TrainingStatus.CANCELLED.value
    job["completed_at"] = datetime.now().isoformat()

    # Free the active slot if this was the active job.
    if _state.active_job_id == job_id:
        _state.active_job_id = None

    return {"success": True, "message": "Training cancelled"}
|
||||
|
||||
|
||||
@router.delete("/jobs/{job_id}", response_model=dict)
async def delete_training_job(job_id: str):
    """Delete a training job; 404 if unknown, 400 if it is still running."""
    job = _state.jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    if job["status"] == TrainingStatus.TRAINING.value:
        raise HTTPException(status_code=400, detail="Cannot delete running job")

    del _state.jobs[job_id]
    return {"success": True, "message": "Job deleted"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# MODEL VERSIONS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/models", response_model=List[dict])
async def list_model_versions():
    """Return every trained model version as a list of version dicts."""
    return [*_state.model_versions.values()]
|
||||
|
||||
|
||||
@router.get("/models/{version_id}", response_model=dict)
async def get_model_version(version_id: str):
    """Return the record for one model version; 404 if unknown."""
    try:
        return _state.model_versions[version_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Model version not found")
|
||||
|
||||
|
||||
@router.post("/models/{version_id}/activate", response_model=dict)
async def activate_model_version(version_id: str):
    """Mark one version active, deactivating all others of the same model type."""
    model = _state.model_versions.get(version_id)
    if model is None:
        raise HTTPException(status_code=404, detail="Model version not found")

    # Exactly one version per model_type may be active at a time; versions of
    # other types are left untouched.
    target_type = model["model_type"]
    for version in _state.model_versions.values():
        if version["model_type"] == target_type:
            version["is_active"] = version is model

    return {"success": True, "message": "Model activated"}
|
||||
|
||||
|
||||
@router.delete("/models/{version_id}", response_model=dict)
async def delete_model_version(version_id: str):
    """Delete a model version; 404 if unknown, 400 if it is the active one."""
    model = _state.model_versions.get(version_id)
    if model is None:
        raise HTTPException(status_code=404, detail="Model version not found")
    if model["is_active"]:
        raise HTTPException(status_code=400, detail="Cannot delete active model")

    del _state.model_versions[version_id]
    return {"success": True, "message": "Model deleted"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DATASET STATS & STATUS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/dataset/stats", response_model=dict)
async def get_dataset_stats():
    """Get statistics about the training dataset.

    Document counts come from metrics_db; chunk counts are estimated and the
    by_doc_type numbers are hard-coded placeholders.
    """
    # Local import — presumably to avoid an import cycle or deferred
    # dependency; confirm against module layout.
    from metrics_db import get_zeugnis_stats

    zeugnis_stats = await get_zeugnis_stats()

    return {
        "total_documents": zeugnis_stats.get("total_documents", 0),
        # Rough estimate: ~12 chunks per document.
        "total_chunks": zeugnis_stats.get("total_documents", 0) * 12,
        "training_allowed": zeugnis_stats.get("training_allowed_documents", 0),
        "by_bundesland": {
            bl["bundesland"]: bl.get("doc_count", 0)
            for bl in zeugnis_stats.get("per_bundesland", [])
        },
        # NOTE(review): static placeholder numbers, not read from any store.
        "by_doc_type": {
            "verordnung": 150,
            "schulordnung": 80,
            "handreichung": 45,
            "erlass": 30,
        },
    }
|
||||
|
||||
|
||||
@router.get("/status", response_model=dict)
async def get_training_status():
    """Summarize the training system: active job, job counts, model counts."""
    active_id = _state.active_job_id
    active_job = _state.jobs.get(active_id) if active_id is not None else None

    # Tally terminal job states in a single pass.
    completed_jobs = 0
    failed_jobs = 0
    for job in _state.jobs.values():
        if job["status"] == TrainingStatus.COMPLETED.value:
            completed_jobs += 1
        elif job["status"] == TrainingStatus.FAILED.value:
            failed_jobs += 1

    is_training = (
        active_job is not None
        and active_job["status"] == TrainingStatus.TRAINING.value
    )

    return {
        "is_training": is_training,
        "active_job_id": active_id,
        "total_jobs": len(_state.jobs),
        "completed_jobs": completed_jobs,
        "failed_jobs": failed_jobs,
        "model_versions": len(_state.model_versions),
        "active_models": sum(1 for m in _state.model_versions.values() if m["is_active"]),
    }
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SSE ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/metrics/stream")
async def stream_training_metrics(job_id: str, request: Request):
    """
    SSE endpoint for streaming training metrics.

    Streams real-time training progress for a specific job.
    """
    if job_id not in _state.jobs:
        raise HTTPException(status_code=404, detail="Job not found")

    # Standard SSE headers; X-Accel-Buffering disables nginx response buffering.
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no"
    }
    return StreamingResponse(
        training_metrics_generator(job_id, request),
        media_type="text/event-stream",
        headers=sse_headers,
    )
|
||||
|
||||
|
||||
@router.get("/ocr/stream")
async def stream_batch_ocr(images_count: int, request: Request):
    """
    SSE endpoint for streaming batch OCR progress.

    Simulates batch OCR processing with progress updates.
    """
    if not 1 <= images_count <= 100:
        raise HTTPException(status_code=400, detail="images_count must be between 1 and 100")

    # Standard SSE headers; X-Accel-Buffering disables nginx response buffering.
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no"
    }
    return StreamingResponse(
        batch_ocr_progress_generator(images_count, request),
        media_type="text/event-stream",
        headers=sse_headers,
    )
|
||||
190
klausur-service/backend/training/simulation.py
Normal file
190
klausur-service/backend/training/simulation.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""
|
||||
Training API — simulation helper and SSE generators.
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from .models import TrainingStatus, _state
|
||||
|
||||
|
||||
async def simulate_training_progress(job_id: str) -> None:
    """Simulate training progress (replace with actual training logic).

    Mutates the job record in ``_state.jobs[job_id]`` in place: marks it
    TRAINING, advances a synthetic loss/metric curve one step every 0.5s,
    and on completion registers a new entry in ``_state.model_versions``
    and clears ``_state.active_job_id``.

    Args:
        job_id: Key into ``_state.jobs``; unknown ids are ignored silently.
    """
    if job_id not in _state.jobs:
        return

    job = _state.jobs[job_id]
    job["status"] = TrainingStatus.TRAINING.value
    job["started_at"] = datetime.now().isoformat()

    total_steps = job["total_epochs"] * 100  # Simulate 100 steps per epoch
    current_step = 0

    # Loop exits early if another coroutine flips the status away from
    # TRAINING (e.g. a cancel endpoint) — presumably how cancellation is
    # signalled; confirm against the cancel route.
    while current_step < total_steps and job["status"] == TrainingStatus.TRAINING.value:
        # Update progress (percentage, 0..100)
        progress = (current_step / total_steps) * 100
        current_epoch = current_step // 100 + 1

        # Simulate decreasing loss: linear decay from ~0.9 to 0.1 plus a
        # small per-epoch sawtooth so the curve looks organic.
        base_loss = 0.8 * (1 - progress / 100) + 0.1
        loss = base_loss + (0.05 * (0.5 - (current_step % 100) / 100))
        val_loss = loss * 1.1

        # Update job state
        job["progress"] = progress
        job["current_epoch"] = min(current_epoch, job["total_epochs"])
        job["loss"] = round(loss, 4)
        job["val_loss"] = round(val_loss, 4)
        job["documents_processed"] = int((progress / 100) * job["total_documents"])

        # Update metrics — precision/recall/f1/accuracy all improve
        # linearly with progress.
        job["metrics"]["loss_history"].append(round(loss, 4))
        job["metrics"]["val_loss_history"].append(round(val_loss, 4))
        job["metrics"]["precision"] = round(0.5 + (progress / 200), 3)
        job["metrics"]["recall"] = round(0.45 + (progress / 200), 3)
        job["metrics"]["f1_score"] = round(0.47 + (progress / 200), 3)
        job["metrics"]["accuracy"] = round(0.6 + (progress / 250), 3)

        # Keep only last 50 history points (both histories trimmed in
        # lockstep so their indices stay aligned).
        if len(job["metrics"]["loss_history"]) > 50:
            job["metrics"]["loss_history"] = job["metrics"]["loss_history"][-50:]
            job["metrics"]["val_loss_history"] = job["metrics"]["val_loss_history"][-50:]

        # Estimate completion by extrapolating elapsed wall-clock time
        # linearly over the remaining percentage.
        if progress > 0:
            elapsed = (datetime.now() - datetime.fromisoformat(job["started_at"])).total_seconds()
            remaining = (elapsed / progress) * (100 - progress)
            job["estimated_completion"] = (datetime.now() + timedelta(seconds=remaining)).isoformat()

        current_step += 1
        await asyncio.sleep(0.5)  # Simulate work

    # Mark as completed (skipped if the loop exited due to cancel/failure)
    if job["status"] == TrainingStatus.TRAINING.value:
        job["status"] = TrainingStatus.COMPLETED.value
        job["progress"] = 100
        job["completed_at"] = datetime.now().isoformat()

        # Create model version registered as active immediately.
        version_id = str(uuid.uuid4())
        _state.model_versions[version_id] = {
            "id": version_id,
            "job_id": job_id,
            "version": f"v{len(_state.model_versions) + 1}.0",
            "model_type": job["model_type"],
            "created_at": datetime.now().isoformat(),
            "metrics": job["metrics"],
            "is_active": True,
            "size_mb": 245.7,  # placeholder size for the simulated artifact
            "bundeslaender": job["config"]["bundeslaender"],
        }

    # Free the active slot regardless of how the job ended.
    _state.active_job_id = None
|
||||
|
||||
|
||||
async def training_metrics_generator(job_id: str, request):
    """
    SSE generator for streaming training metrics.

    Yields JSON-encoded training status updates every 500ms, formatted as
    ``data: {...}\\n\\n`` SSE events. Terminates when the client
    disconnects, the job disappears from ``_state.jobs``, or the job
    reaches a terminal status (COMPLETED/FAILED/CANCELLED).

    Args:
        job_id: Key into ``_state.jobs``.
        request: Starlette/FastAPI request, used only for
            ``is_disconnected()`` polling.
    """
    while True:
        # Check if client disconnected
        if await request.is_disconnected():
            break

        # Get job status; emit a single error event if it vanished.
        if job_id not in _state.jobs:
            yield f"data: {json.dumps({'error': 'Job not found'})}\n\n"
            break

        job = _state.jobs[job_id]

        # Build metrics response. elapsed/remaining default to 0 and are
        # overwritten below when timestamps are available.
        metrics_data = {
            "job_id": job["id"],
            "status": job["status"],
            "progress": job["progress"],
            "current_epoch": job["current_epoch"],
            "total_epochs": job["total_epochs"],
            # progress is a percentage and total_steps = epochs * 100,
            # so progress * epochs recovers the absolute step count.
            "current_step": int(job["progress"] * job["total_epochs"]),
            "total_steps": job["total_epochs"] * 100,
            "elapsed_time_ms": 0,
            "estimated_remaining_ms": 0,
            "metrics": {
                "loss": job["loss"],
                "val_loss": job["val_loss"],
                "accuracy": job["metrics"]["accuracy"],
                "learning_rate": job["learning_rate"]
            },
            # Last 50 loss points, paired index-wise with val_loss history
            # (both lists are trimmed together by the simulator).
            "history": [
                {
                    "epoch": i + 1,
                    "step": (i + 1) * 10,
                    "loss": loss,
                    "val_loss": job["metrics"]["val_loss_history"][i] if i < len(job["metrics"]["val_loss_history"]) else None,
                    "learning_rate": job["learning_rate"],
                    "timestamp": 0
                }
                for i, loss in enumerate(job["metrics"]["loss_history"][-50:])
            ]
        }

        # Calculate elapsed time from the ISO start timestamp.
        if job["started_at"]:
            started = datetime.fromisoformat(job["started_at"])
            metrics_data["elapsed_time_ms"] = int((datetime.now() - started).total_seconds() * 1000)

        # Calculate remaining time, clamped at 0 once the estimate passes.
        if job["estimated_completion"]:
            estimated = datetime.fromisoformat(job["estimated_completion"])
            metrics_data["estimated_remaining_ms"] = max(0, int((estimated - datetime.now()).total_seconds() * 1000))

        # Send SSE event
        yield f"data: {json.dumps(metrics_data)}\n\n"

        # Check if job completed — the final snapshot above is still sent.
        if job["status"] in [TrainingStatus.COMPLETED.value, TrainingStatus.FAILED.value, TrainingStatus.CANCELLED.value]:
            break

        # Wait before next update
        await asyncio.sleep(0.5)
|
||||
|
||||
|
||||
async def batch_ocr_progress_generator(images_count: int, request):
    """
    SSE generator for batch OCR progress simulation.

    Emits one ``data:`` progress event per image followed by a final
    completion event. In production, this would integrate with actual
    OCR processing.
    """
    import random

    for index in range(images_count):
        # Stop producing per-image events once the client has gone away.
        if await request.is_disconnected():
            break

        # Simulate per-image processing latency.
        await asyncio.sleep(random.uniform(0.3, 0.8))

        done = index + 1
        event = {
            "type": "progress",
            "current": done,
            "total": images_count,
            "progress_percent": (done / images_count) * 100,
            "elapsed_ms": done * 500,
            "estimated_remaining_ms": (images_count - done) * 500,
            # Fabricated per-image OCR result for the UI to render.
            "result": {
                "text": f"Sample recognized text for image {done}",
                "confidence": round(random.uniform(0.7, 0.98), 2),
                "processing_time_ms": random.randint(200, 600),
                "from_cache": random.random() < 0.2
            }
        }

        yield f"data: {json.dumps(event)}\n\n"

    # Send completion event
    yield f"data: {json.dumps({'type': 'complete', 'total_time_ms': images_count * 500, 'processed_count': images_count})}\n\n"
|
||||
261
klausur-service/backend/training/trocr_api.py
Normal file
261
klausur-service/backend/training/trocr_api.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""
|
||||
TrOCR API - REST endpoints for TrOCR handwriting OCR.
|
||||
|
||||
Provides:
|
||||
- /ocr/trocr - Single image OCR
|
||||
- /ocr/trocr/batch - Batch image processing
|
||||
- /ocr/trocr/status - Model status
|
||||
- /ocr/trocr/cache - Cache statistics
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
import json
|
||||
import logging
|
||||
|
||||
from services.trocr_service import (
|
||||
run_trocr_ocr_enhanced,
|
||||
run_trocr_batch,
|
||||
run_trocr_batch_stream,
|
||||
get_model_status,
|
||||
get_cache_stats,
|
||||
preload_trocr_model,
|
||||
OCRResult,
|
||||
BatchOCRResult
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/ocr/trocr", tags=["TrOCR"])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MODELS
|
||||
# =============================================================================
|
||||
|
||||
class TrOCRResponse(BaseModel):
    """Response model for single image OCR.

    Populated by the endpoints below from the service-layer ``OCRResult``;
    ``word_count`` is derived by the endpoint via ``len(text.split())``.
    """
    text: str = Field(..., description="Extracted text")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Overall confidence")
    processing_time_ms: int = Field(..., ge=0, description="Processing time in milliseconds")
    model: str = Field(..., description="Model used for OCR")
    has_lora_adapter: bool = Field(False, description="Whether LoRA adapter was used")
    from_cache: bool = Field(False, description="Whether result was from cache")
    image_hash: str = Field("", description="SHA256 hash of image (first 16 chars)")
    word_count: int = Field(0, description="Number of words detected")
|
||||
|
||||
|
||||
class BatchOCRResponse(BaseModel):
    """Response model for batch OCR.

    Wraps the per-image ``TrOCRResponse`` list with aggregate counters
    mapped from the service-layer batch result.
    """
    results: List[TrOCRResponse] = Field(..., description="Individual OCR results")
    total_time_ms: int = Field(..., ge=0, description="Total processing time")
    processed_count: int = Field(..., ge=0, description="Number of images processed")
    cached_count: int = Field(0, description="Number of results from cache")
    error_count: int = Field(0, description="Number of errors")
|
||||
|
||||
|
||||
class ModelStatusResponse(BaseModel):
    """Response model for model status.

    Optional fields are None when the model is not loaded in memory.
    """
    status: str = Field(..., description="Model status: available, not_installed")
    is_loaded: bool = Field(..., description="Whether model is loaded in memory")
    model_name: Optional[str] = Field(None, description="Name of loaded model")
    device: Optional[str] = Field(None, description="Device model is running on")
    loaded_at: Optional[str] = Field(None, description="ISO timestamp when model was loaded")
|
||||
|
||||
|
||||
class CacheStatsResponse(BaseModel):
    """Response model for cache statistics.

    Shape mirrors the dict returned by the service's ``get_cache_stats``.
    """
    size: int = Field(..., ge=0, description="Current cache size")
    max_size: int = Field(..., ge=0, description="Maximum cache size")
    ttl_seconds: int = Field(..., ge=0, description="Cache TTL in seconds")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# ENDPOINTS
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/status", response_model=ModelStatusResponse)
async def get_trocr_status():
    """
    Get TrOCR model status.

    Returns whether the model is installed and loaded, plus device and
    load-time details (see ``ModelStatusResponse``).
    """
    # Thin delegation to the service layer; FastAPI validates the shape
    # against the declared response_model.
    status_info = get_model_status()
    return status_info
|
||||
|
||||
|
||||
@router.get("/cache", response_model=CacheStatsResponse)
async def get_trocr_cache_stats():
    """
    Get TrOCR cache statistics.

    Returns size, capacity and TTL of the OCR result cache (see
    ``CacheStatsResponse``).
    """
    # Thin delegation to the service layer; response_model does the
    # output validation.
    stats = get_cache_stats()
    return stats
|
||||
|
||||
|
||||
@router.post("/preload")
async def preload_model(handwritten: bool = Query(True, description="Load handwritten model")):
    """
    Preload TrOCR model into memory.

    Loading ahead of time means the first OCR request does not pay the
    model-load cost.
    """
    # Guard clause: surface a 500 when the service reports failure.
    if not preload_trocr_model(handwritten=handwritten):
        raise HTTPException(status_code=500, detail="Failed to preload model")

    return {"status": "success", "message": "Model preloaded successfully"}
|
||||
|
||||
|
||||
@router.post("", response_model=TrOCRResponse)
async def run_trocr(
    file: UploadFile = File(..., description="Image file to process"),
    handwritten: bool = Query(True, description="Use handwritten model"),
    split_lines: bool = Query(True, description="Split image into lines"),
    use_cache: bool = Query(True, description="Use result caching")
):
    """
    Run TrOCR on a single image.

    Supports PNG, JPG, and other common image formats. Responds 400 for
    non-image uploads and 500 for service-level failures.
    """
    # Only image uploads are accepted; a missing content type is rejected
    # too, since "" never starts with "image/".
    content_type = file.content_type or ""
    if not content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        payload = await file.read()

        ocr_result = await run_trocr_ocr_enhanced(
            payload,
            handwritten=handwritten,
            split_lines=split_lines,
            use_cache=use_cache
        )

        recognized = ocr_result.text
        return TrOCRResponse(
            text=recognized,
            confidence=ocr_result.confidence,
            processing_time_ms=ocr_result.processing_time_ms,
            model=ocr_result.model,
            has_lora_adapter=ocr_result.has_lora_adapter,
            from_cache=ocr_result.from_cache,
            image_hash=ocr_result.image_hash,
            # Empty text yields a word count of 0 rather than [""] -> 1.
            word_count=len(recognized.split()) if recognized else 0
        )

    except Exception as e:
        # Convert service-level failures into a 500 for the client.
        logger.error(f"TrOCR API error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/batch", response_model=BatchOCRResponse)
async def run_trocr_batch_endpoint(
    files: List[UploadFile] = File(..., description="Image files to process"),
    handwritten: bool = Query(True, description="Use handwritten model"),
    split_lines: bool = Query(True, description="Split images into lines"),
    use_cache: bool = Query(True, description="Use result caching")
):
    """
    Run TrOCR on multiple images.

    Processes images sequentially and returns all results. Responds 400
    for an empty batch, more than 50 files, or a non-image upload; 500
    for service-level failures.
    """
    # Batch-size guards run before any file is read.
    if not files:
        raise HTTPException(status_code=400, detail="No files provided")

    if len(files) > 50:
        raise HTTPException(status_code=400, detail="Maximum 50 images per batch")

    try:
        # Read every upload up front, rejecting anything non-image.
        images = []
        for upload in files:
            mime = upload.content_type
            if not mime or not mime.startswith("image/"):
                raise HTTPException(status_code=400, detail=f"File {upload.filename} is not an image")
            images.append(await upload.read())

        batch_result = await run_trocr_batch(
            images,
            handwritten=handwritten,
            split_lines=split_lines,
            use_cache=use_cache
        )

        # Map each service-level result onto the API response model.
        mapped = []
        for item in batch_result.results:
            mapped.append(
                TrOCRResponse(
                    text=item.text,
                    confidence=item.confidence,
                    processing_time_ms=item.processing_time_ms,
                    model=item.model,
                    has_lora_adapter=item.has_lora_adapter,
                    from_cache=item.from_cache,
                    image_hash=item.image_hash,
                    word_count=len(item.text.split()) if item.text else 0
                )
            )

        return BatchOCRResponse(
            results=mapped,
            total_time_ms=batch_result.total_time_ms,
            processed_count=batch_result.processed_count,
            cached_count=batch_result.cached_count,
            error_count=batch_result.error_count
        )

    except HTTPException:
        # Re-raise client errors (e.g. non-image upload) untouched.
        raise
    except Exception as e:
        logger.error(f"TrOCR batch API error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/batch/stream")
async def run_trocr_batch_stream_endpoint(
    files: List[UploadFile] = File(..., description="Image files to process"),
    handwritten: bool = Query(True, description="Use handwritten model"),
    split_lines: bool = Query(True, description="Split images into lines"),
    use_cache: bool = Query(True, description="Use result caching")
):
    """
    Run TrOCR on multiple images with Server-Sent Events (SSE) progress updates.

    Returns a stream of progress events as images are processed.

    Raises:
        HTTPException 400: empty batch, more than 50 files, or a
            non-image upload.
        HTTPException 500: unexpected service failure while preparing
            the stream.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files provided")

    if len(files) > 50:
        raise HTTPException(status_code=400, detail="Maximum 50 images per batch")

    try:
        # Read every upload up front, rejecting anything non-image.
        images = []
        for file in files:
            if not file.content_type or not file.content_type.startswith("image/"):
                raise HTTPException(status_code=400, detail=f"File {file.filename} is not an image")
            images.append(await file.read())

        async def event_generator():
            # Forward each service-level progress update as one SSE
            # "data:" event.
            async for update in run_trocr_batch_stream(
                images,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache
            ):
                yield f"data: {json.dumps(update)}\n\n"

        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                # Match the file's other SSE endpoints: disable
                # reverse-proxy buffering so progress events reach the
                # client immediately instead of arriving in one burst.
                "X-Accel-Buffering": "no"
            }
        )

    except HTTPException:
        # Re-raise client errors untouched.
        raise
    except Exception as e:
        logger.error(f"TrOCR stream API error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
Reference in New Issue
Block a user