Initial commit: breakpilot-lehrer - Lehrer KI Platform
Services: Admin-Lehrer, Backend-Lehrer, Studio v2, Website, Klausur-Service, School-Service, Voice-Service, Geo-Service, BreakPilot Drive, Agent-Core Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
261
klausur-service/backend/trocr_api.py
Normal file
261
klausur-service/backend/trocr_api.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""
|
||||
TrOCR API - REST endpoints for TrOCR handwriting OCR.
|
||||
|
||||
Provides:
|
||||
- /ocr/trocr - Single image OCR
|
||||
- /ocr/trocr/batch - Batch image processing
|
||||
- /ocr/trocr/status - Model status
|
||||
- /ocr/trocr/cache - Cache statistics
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
import json
|
||||
import logging
|
||||
|
||||
from services.trocr_service import (
|
||||
run_trocr_ocr_enhanced,
|
||||
run_trocr_batch,
|
||||
run_trocr_batch_stream,
|
||||
get_model_status,
|
||||
get_cache_stats,
|
||||
preload_trocr_model,
|
||||
OCRResult,
|
||||
BatchOCRResult
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/ocr/trocr", tags=["TrOCR"])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MODELS
|
||||
# =============================================================================
|
||||
|
||||
class TrOCRResponse(BaseModel):
    """Response model for single image OCR.

    Mirrors the fields of ``services.trocr_service.OCRResult``; the
    ``word_count`` field is computed by the endpoint from the extracted
    text rather than coming from the service result.
    """
    text: str = Field(..., description="Extracted text")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Overall confidence")
    processing_time_ms: int = Field(..., ge=0, description="Processing time in milliseconds")
    model: str = Field(..., description="Model used for OCR")
    has_lora_adapter: bool = Field(False, description="Whether LoRA adapter was used")
    from_cache: bool = Field(False, description="Whether result was from cache")
    image_hash: str = Field("", description="SHA256 hash of image (first 16 chars)")
    word_count: int = Field(0, description="Number of words detected")
|
||||
|
||||
|
||||
class BatchOCRResponse(BaseModel):
    """Response model for batch OCR.

    Aggregates per-image ``TrOCRResponse`` items plus batch-level
    counters copied from ``services.trocr_service.BatchOCRResult``.
    """
    results: List[TrOCRResponse] = Field(..., description="Individual OCR results")
    total_time_ms: int = Field(..., ge=0, description="Total processing time")
    processed_count: int = Field(..., ge=0, description="Number of images processed")
    cached_count: int = Field(0, description="Number of results from cache")
    error_count: int = Field(0, description="Number of errors")
|
||||
|
||||
|
||||
class ModelStatusResponse(BaseModel):
    """Response model for model status.

    Validates the dict returned by ``get_model_status()``; optional
    fields are None when the model has not been loaded yet.
    """
    status: str = Field(..., description="Model status: available, not_installed")
    is_loaded: bool = Field(..., description="Whether model is loaded in memory")
    model_name: Optional[str] = Field(None, description="Name of loaded model")
    device: Optional[str] = Field(None, description="Device model is running on")
    loaded_at: Optional[str] = Field(None, description="ISO timestamp when model was loaded")
|
||||
|
||||
|
||||
class CacheStatsResponse(BaseModel):
    """Response model for cache statistics.

    Validates the dict returned by ``get_cache_stats()``.
    """
    size: int = Field(..., ge=0, description="Current cache size")
    max_size: int = Field(..., ge=0, description="Maximum cache size")
    ttl_seconds: int = Field(..., ge=0, description="Cache TTL in seconds")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# ENDPOINTS
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/status", response_model=ModelStatusResponse)
async def get_trocr_status():
    """Report whether the TrOCR model is installed and loaded in memory.

    Thin wrapper around ``get_model_status()``; FastAPI validates the
    returned dict against ``ModelStatusResponse``.
    """
    status = get_model_status()
    return status
|
||||
|
||||
|
||||
@router.get("/cache", response_model=CacheStatsResponse)
async def get_trocr_cache_stats():
    """Report statistics about the OCR result cache.

    Thin wrapper around ``get_cache_stats()``; FastAPI validates the
    returned dict against ``CacheStatsResponse``.
    """
    stats = get_cache_stats()
    return stats
|
||||
|
||||
|
||||
@router.post("/preload")
async def preload_model(handwritten: bool = Query(True, description="Load handwritten model")):
    """Load the TrOCR model into memory ahead of the first OCR request.

    Returns a success payload when the model loads; raises a 500 error
    when ``preload_trocr_model`` reports failure.
    """
    # Guard clause: fail fast when the service could not load the model.
    if not preload_trocr_model(handwritten=handwritten):
        raise HTTPException(status_code=500, detail="Failed to preload model")
    return {"status": "success", "message": "Model preloaded successfully"}
|
||||
|
||||
|
||||
@router.post("", response_model=TrOCRResponse)
async def run_trocr(
    file: UploadFile = File(..., description="Image file to process"),
    handwritten: bool = Query(True, description="Use handwritten model"),
    split_lines: bool = Query(True, description="Split image into lines"),
    use_cache: bool = Query(True, description="Use result caching")
):
    """
    Run TrOCR on a single image.

    Supports PNG, JPG, and other common image formats.

    Raises:
        HTTPException 400: if the upload's content type is not image/*.
        HTTPException 500: if OCR processing fails.
    """
    # Validate the declared content type before reading the payload.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        image_data = await file.read()

        result = await run_trocr_ocr_enhanced(
            image_data,
            handwritten=handwritten,
            split_lines=split_lines,
            use_cache=use_cache
        )

        return TrOCRResponse(
            text=result.text,
            confidence=result.confidence,
            processing_time_ms=result.processing_time_ms,
            model=result.model,
            has_lora_adapter=result.has_lora_adapter,
            from_cache=result.from_cache,
            image_hash=result.image_hash,
            # Word count is derived here from the extracted text, not by the service.
            word_count=len(result.text.split()) if result.text else 0
        )

    except HTTPException:
        # Pass deliberate HTTP errors through unchanged (consistent with the
        # batch endpoints) instead of collapsing them into a generic 500.
        raise
    except Exception as e:
        # logger.exception keeps the traceback; lazy %-args avoid eager
        # f-string formatting (the original logger.error(f"...") lost both).
        logger.exception("TrOCR API error: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/batch", response_model=BatchOCRResponse)
async def run_trocr_batch_endpoint(
    files: List[UploadFile] = File(..., description="Image files to process"),
    handwritten: bool = Query(True, description="Use handwritten model"),
    split_lines: bool = Query(True, description="Split images into lines"),
    use_cache: bool = Query(True, description="Use result caching")
):
    """
    Run TrOCR on multiple images.

    Processes images sequentially and returns all results.

    Raises:
        HTTPException 400: no files, more than 50 files, or a non-image file.
        HTTPException 500: if batch OCR processing fails.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files provided")

    # Hard cap to keep a single request from monopolizing the worker.
    if len(files) > 50:
        raise HTTPException(status_code=400, detail="Maximum 50 images per batch")

    def _to_response(r) -> TrOCRResponse:
        # Map a service-level OCRResult onto the API response model;
        # word_count is derived here from the extracted text.
        return TrOCRResponse(
            text=r.text,
            confidence=r.confidence,
            processing_time_ms=r.processing_time_ms,
            model=r.model,
            has_lora_adapter=r.has_lora_adapter,
            from_cache=r.from_cache,
            image_hash=r.image_hash,
            word_count=len(r.text.split()) if r.text else 0
        )

    try:
        # Validate every upload's content type, then buffer all payloads.
        images = []
        for file in files:
            if not file.content_type or not file.content_type.startswith("image/"):
                raise HTTPException(status_code=400, detail=f"File {file.filename} is not an image")
            images.append(await file.read())

        batch_result = await run_trocr_batch(
            images,
            handwritten=handwritten,
            split_lines=split_lines,
            use_cache=use_cache
        )

        return BatchOCRResponse(
            results=[_to_response(r) for r in batch_result.results],
            total_time_ms=batch_result.total_time_ms,
            processed_count=batch_result.processed_count,
            cached_count=batch_result.cached_count,
            error_count=batch_result.error_count
        )

    except HTTPException:
        # Re-raise the deliberate 400s from the validation loop above.
        raise
    except Exception as e:
        # logger.exception keeps the traceback; lazy %-args avoid eager
        # f-string formatting (the original logger.error(f"...") lost both).
        logger.exception("TrOCR batch API error: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/batch/stream")
async def run_trocr_batch_stream_endpoint(
    files: List[UploadFile] = File(..., description="Image files to process"),
    handwritten: bool = Query(True, description="Use handwritten model"),
    split_lines: bool = Query(True, description="Split images into lines"),
    use_cache: bool = Query(True, description="Use result caching")
):
    """
    Run TrOCR on multiple images with Server-Sent Events (SSE) progress updates.

    Returns a stream of progress events as images are processed.

    Raises:
        HTTPException 400: no files, more than 50 files, or a non-image file.
        HTTPException 500: if setting up the stream fails.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files provided")

    # Hard cap to keep a single request from monopolizing the worker.
    if len(files) > 50:
        raise HTTPException(status_code=400, detail="Maximum 50 images per batch")

    try:
        # Validate every upload's content type, then buffer all payloads
        # before streaming begins, so 400s still reach the client as errors.
        images = []
        for file in files:
            if not file.content_type or not file.content_type.startswith("image/"):
                raise HTTPException(status_code=400, detail=f"File {file.filename} is not an image")
            images.append(await file.read())

        async def event_generator():
            # NOTE(review): exceptions raised in here occur after the response
            # headers are sent, so the except blocks below cannot turn them
            # into a 500 — the stream simply terminates. Confirm whether the
            # service emits an error update in that case.
            async for update in run_trocr_batch_stream(
                images,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache
            ):
                # SSE framing: one "data:" line per event, blank-line delimited.
                yield f"data: {json.dumps(update)}\n\n"

        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            headers={
                # Prevent intermediaries from buffering the event stream.
                "Cache-Control": "no-cache",
                "Connection": "keep-alive"
            }
        )

    except HTTPException:
        # Re-raise the deliberate 400s from the validation loop above.
        raise
    except Exception as e:
        # logger.exception keeps the traceback; lazy %-args avoid eager
        # f-string formatting (the original logger.error(f"...") lost both).
        logger.exception("TrOCR stream API error: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
Reference in New Issue
Block a user