Files
breakpilot-lehrer/klausur-service/backend/trocr_api.py
Benjamin Boenisch 5a31f52310 Initial commit: breakpilot-lehrer - Lehrer KI Platform
Services: Admin-Lehrer, Backend-Lehrer, Studio v2, Website,
Klausur-Service, School-Service, Voice-Service, Geo-Service,
BreakPilot Drive, Agent-Core

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 23:47:26 +01:00

262 lines
9.1 KiB
Python

"""
TrOCR API - REST endpoints for TrOCR handwriting OCR.
Provides:
- /ocr/trocr - Single image OCR
- /ocr/trocr/batch - Batch image processing
- /ocr/trocr/status - Model status
- /ocr/trocr/cache - Cache statistics
"""
import json
import logging
from typing import List, Optional

from fastapi import APIRouter, UploadFile, File, HTTPException, Query
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field

from services.trocr_service import (
    run_trocr_ocr_enhanced,
    run_trocr_batch,
    run_trocr_batch_stream,
    get_model_status,
    get_cache_stats,
    preload_trocr_model,
    OCRResult,
    BatchOCRResult,
)
# Module logger; the endpoint handlers below report unexpected failures here.
logger = logging.getLogger(__name__)
# Every route in this module is registered under the /api/v1/ocr/trocr prefix.
router = APIRouter(prefix="/api/v1/ocr/trocr", tags=["TrOCR"])
# =============================================================================
# MODELS
# =============================================================================
class TrOCRResponse(BaseModel):
    """
    Response model for single image OCR.

    Also used for each entry of ``BatchOCRResponse.results``.
    """
    text: str = Field(..., description="Extracted text")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Overall confidence")
    processing_time_ms: int = Field(..., ge=0, description="Processing time in milliseconds")
    model: str = Field(..., description="Model used for OCR")
    has_lora_adapter: bool = Field(False, description="Whether LoRA adapter was used")
    from_cache: bool = Field(False, description="Whether result was from cache")
    image_hash: str = Field("", description="SHA256 hash of image (first 16 chars)")
    # The endpoints derive this from len(text.split()), so it can never be
    # negative; ge=0 records that bound in the schema like the other counters.
    word_count: int = Field(0, ge=0, description="Number of words detected")
class BatchOCRResponse(BaseModel):
    """
    Response model for batch OCR.

    ``results`` holds one ``TrOCRResponse`` per processed image. All
    timing and counter fields are non-negative.
    """
    results: List[TrOCRResponse] = Field(..., description="Individual OCR results")
    total_time_ms: int = Field(..., ge=0, description="Total processing time")
    processed_count: int = Field(..., ge=0, description="Number of images processed")
    # ge=0 added for consistency with processed_count: a count cannot be negative.
    cached_count: int = Field(0, ge=0, description="Number of results from cache")
    error_count: int = Field(0, ge=0, description="Number of errors")
class ModelStatusResponse(BaseModel):
    """
    Payload returned by the model-status endpoint.

    ``status`` and ``is_loaded`` are mandatory; the remaining details are
    optional and fall back to ``None`` when not reported.
    """

    # --- always present -----------------------------------------------------
    status: str = Field(default=..., description="Model status: available, not_installed")
    is_loaded: bool = Field(default=..., description="Whether model is loaded in memory")

    # --- optional details ---------------------------------------------------
    # NOTE(review): under pydantic v2 a field named `model_name` falls in the
    # protected `model_` namespace and triggers a warning at class creation.
    model_name: Optional[str] = Field(default=None, description="Name of loaded model")
    device: Optional[str] = Field(default=None, description="Device model is running on")
    loaded_at: Optional[str] = Field(default=None, description="ISO timestamp when model was loaded")
class CacheStatsResponse(BaseModel):
    """
    Payload returned by the cache-statistics endpoint.

    Every field is required and must be a non-negative integer.
    """

    size: int = Field(default=..., ge=0, description="Current cache size")
    max_size: int = Field(default=..., ge=0, description="Maximum cache size")
    ttl_seconds: int = Field(default=..., ge=0, description="Cache TTL in seconds")
# =============================================================================
# ENDPOINTS
# =============================================================================
@router.get("/status", response_model=ModelStatusResponse)
async def get_trocr_status():
"""
Get TrOCR model status.
Returns information about whether the model is loaded and available.
"""
return get_model_status()
@router.get("/cache", response_model=CacheStatsResponse)
async def get_trocr_cache_stats():
"""
Get TrOCR cache statistics.
Returns information about the OCR result cache.
"""
return get_cache_stats()
@router.post("/preload")
async def preload_model(handwritten: bool = Query(True, description="Load handwritten model")):
"""
Preload TrOCR model into memory.
This speeds up the first OCR request by loading the model ahead of time.
"""
success = preload_trocr_model(handwritten=handwritten)
if success:
return {"status": "success", "message": "Model preloaded successfully"}
else:
raise HTTPException(status_code=500, detail="Failed to preload model")
@router.post("", response_model=TrOCRResponse)
async def run_trocr(
file: UploadFile = File(..., description="Image file to process"),
handwritten: bool = Query(True, description="Use handwritten model"),
split_lines: bool = Query(True, description="Split image into lines"),
use_cache: bool = Query(True, description="Use result caching")
):
"""
Run TrOCR on a single image.
Supports PNG, JPG, and other common image formats.
"""
# Validate file type
if not file.content_type or not file.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail="File must be an image")
try:
image_data = await file.read()
result = await run_trocr_ocr_enhanced(
image_data,
handwritten=handwritten,
split_lines=split_lines,
use_cache=use_cache
)
return TrOCRResponse(
text=result.text,
confidence=result.confidence,
processing_time_ms=result.processing_time_ms,
model=result.model,
has_lora_adapter=result.has_lora_adapter,
from_cache=result.from_cache,
image_hash=result.image_hash,
word_count=len(result.text.split()) if result.text else 0
)
except Exception as e:
logger.error(f"TrOCR API error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/batch", response_model=BatchOCRResponse)
async def run_trocr_batch_endpoint(
files: List[UploadFile] = File(..., description="Image files to process"),
handwritten: bool = Query(True, description="Use handwritten model"),
split_lines: bool = Query(True, description="Split images into lines"),
use_cache: bool = Query(True, description="Use result caching")
):
"""
Run TrOCR on multiple images.
Processes images sequentially and returns all results.
"""
if not files:
raise HTTPException(status_code=400, detail="No files provided")
if len(files) > 50:
raise HTTPException(status_code=400, detail="Maximum 50 images per batch")
try:
images = []
for file in files:
if not file.content_type or not file.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail=f"File {file.filename} is not an image")
images.append(await file.read())
batch_result = await run_trocr_batch(
images,
handwritten=handwritten,
split_lines=split_lines,
use_cache=use_cache
)
return BatchOCRResponse(
results=[
TrOCRResponse(
text=r.text,
confidence=r.confidence,
processing_time_ms=r.processing_time_ms,
model=r.model,
has_lora_adapter=r.has_lora_adapter,
from_cache=r.from_cache,
image_hash=r.image_hash,
word_count=len(r.text.split()) if r.text else 0
)
for r in batch_result.results
],
total_time_ms=batch_result.total_time_ms,
processed_count=batch_result.processed_count,
cached_count=batch_result.cached_count,
error_count=batch_result.error_count
)
except HTTPException:
raise
except Exception as e:
logger.error(f"TrOCR batch API error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/batch/stream")
async def run_trocr_batch_stream_endpoint(
files: List[UploadFile] = File(..., description="Image files to process"),
handwritten: bool = Query(True, description="Use handwritten model"),
split_lines: bool = Query(True, description="Split images into lines"),
use_cache: bool = Query(True, description="Use result caching")
):
"""
Run TrOCR on multiple images with Server-Sent Events (SSE) progress updates.
Returns a stream of progress events as images are processed.
"""
if not files:
raise HTTPException(status_code=400, detail="No files provided")
if len(files) > 50:
raise HTTPException(status_code=400, detail="Maximum 50 images per batch")
try:
images = []
for file in files:
if not file.content_type or not file.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail=f"File {file.filename} is not an image")
images.append(await file.read())
async def event_generator():
async for update in run_trocr_batch_stream(
images,
handwritten=handwritten,
split_lines=split_lines,
use_cache=use_cache
):
yield f"data: {json.dumps(update)}\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive"
}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"TrOCR stream API error: {e}")
raise HTTPException(status_code=500, detail=str(e))