Services: Admin-Lehrer, Backend-Lehrer, Studio v2, Website, Klausur-Service, School-Service, Voice-Service, Geo-Service, BreakPilot Drive, Agent-Core Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
262 lines
9.1 KiB
Python
262 lines
9.1 KiB
Python
"""
|
|
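TrOCR API - REST endpoints for TrOCR handwriting OCR.

All routes are registered on a router with the prefix /api/v1/ocr/trocr.

Provides:
- POST /api/v1/ocr/trocr - Single image OCR
- POST /api/v1/ocr/trocr/batch - Batch image processing
- POST /api/v1/ocr/trocr/batch/stream - Batch processing with Server-Sent Events progress
- POST /api/v1/ocr/trocr/preload - Preload the model into memory
- GET /api/v1/ocr/trocr/status - Model status
- GET /api/v1/ocr/trocr/cache - Cache statistics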
|
|
"""
|
|
|
|
from fastapi import APIRouter, UploadFile, File, HTTPException, Query
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import BaseModel, Field
|
|
from typing import List, Optional
|
|
import json
|
|
import logging
|
|
|
|
from services.trocr_service import (
|
|
run_trocr_ocr_enhanced,
|
|
run_trocr_batch,
|
|
run_trocr_batch_stream,
|
|
get_model_status,
|
|
get_cache_stats,
|
|
preload_trocr_model,
|
|
OCRResult,
|
|
BatchOCRResult
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
router = APIRouter(prefix="/api/v1/ocr/trocr", tags=["TrOCR"])
|
|
|
|
|
|
# =============================================================================
|
|
# MODELS
|
|
# =============================================================================
|
|
|
|
class TrOCRResponse(BaseModel):
    """Response model for single image OCR.

    Also used as the element type of ``BatchOCRResponse.results``.
    """

    # Required fields: the recognized text plus quality and timing metadata.
    text: str = Field(..., description="Extracted text")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Overall confidence")
    processing_time_ms: int = Field(..., ge=0, description="Processing time in milliseconds")
    model: str = Field(..., description="Model used for OCR")

    # Optional fields with defaults.
    has_lora_adapter: bool = Field(False, description="Whether LoRA adapter was used")
    from_cache: bool = Field(False, description="Whether result was from cache")
    image_hash: str = Field("", description="SHA256 hash of image (first 16 chars)")
    # A count cannot be negative; bounded like the other numeric fields.
    word_count: int = Field(0, ge=0, description="Number of words detected")
|
|
|
|
|
|
class BatchOCRResponse(BaseModel):
    """Response model for batch OCR.

    Carries one ``TrOCRResponse`` per result plus aggregate counters for
    the whole batch.
    """

    results: List[TrOCRResponse] = Field(..., description="Individual OCR results")
    total_time_ms: int = Field(..., ge=0, description="Total processing time")
    processed_count: int = Field(..., ge=0, description="Number of images processed")
    # Counters cannot be negative; bounded like the required counters above.
    cached_count: int = Field(0, ge=0, description="Number of results from cache")
    error_count: int = Field(0, ge=0, description="Number of errors")
|
|
|
|
|
|
class ModelStatusResponse(BaseModel):
    """Schema describing the availability and load state of the TrOCR model."""

    # Always present.
    status: str = Field(
        ...,
        description="Model status: available, not_installed",
    )
    is_loaded: bool = Field(
        ...,
        description="Whether model is loaded in memory",
    )

    # May be absent; each defaults to None.
    model_name: Optional[str] = Field(
        default=None,
        description="Name of loaded model",
    )
    device: Optional[str] = Field(
        default=None,
        description="Device model is running on",
    )
    loaded_at: Optional[str] = Field(
        default=None,
        description="ISO timestamp when model was loaded",
    )
|
|
|
|
|
|
class CacheStatsResponse(BaseModel):
    """Schema for the OCR result cache statistics.

    All three values are required, non-negative integers.
    """

    size: int = Field(
        ...,
        description="Current cache size",
        ge=0,
    )
    max_size: int = Field(
        ...,
        description="Maximum cache size",
        ge=0,
    )
    ttl_seconds: int = Field(
        ...,
        description="Cache TTL in seconds",
        ge=0,
    )
|
|
|
|
|
|
# =============================================================================
|
|
# ENDPOINTS
|
|
# =============================================================================
|
|
|
|
@router.get("/status", response_model=ModelStatusResponse)
async def get_trocr_status():
    """
    Report the current TrOCR model status.

    Delegates to ``get_model_status()`` and returns its result unchanged;
    ``ModelStatusResponse`` is declared as the route's response model.
    """
    model_status = get_model_status()
    return model_status
|
|
|
|
|
|
@router.get("/cache", response_model=CacheStatsResponse)
async def get_trocr_cache_stats():
    """
    Report statistics about the OCR result cache.

    Delegates to ``get_cache_stats()`` and returns its result unchanged;
    ``CacheStatsResponse`` is declared as the route's response model.
    """
    cache_stats = get_cache_stats()
    return cache_stats
|
|
|
|
|
|
@router.post("/preload")
async def preload_model(handwritten: bool = Query(True, description="Load handwritten model")):
    """
    Load the TrOCR model into memory ahead of the first OCR request.

    Calls ``preload_trocr_model`` with the ``handwritten`` flag. A falsy
    result is reported as HTTP 500; otherwise a success payload is returned.
    """
    # Guard clause: fail fast when the service reports the preload did not work.
    if not preload_trocr_model(handwritten=handwritten):
        raise HTTPException(status_code=500, detail="Failed to preload model")

    return {"status": "success", "message": "Model preloaded successfully"}
|
|
|
|
|
|
@router.post("", response_model=TrOCRResponse)
async def run_trocr(
    file: UploadFile = File(..., description="Image file to process"),
    handwritten: bool = Query(True, description="Use handwritten model"),
    split_lines: bool = Query(True, description="Split image into lines"),
    use_cache: bool = Query(True, description="Use result caching")
):
    """
    Run TrOCR on a single image.

    Supports PNG, JPG, and other common image formats.

    Args:
        file: Uploaded file; its declared content type must start with ``image/``.
        handwritten: Forwarded to ``run_trocr_ocr_enhanced``.
        split_lines: Forwarded to ``run_trocr_ocr_enhanced``.
        use_cache: Forwarded to ``run_trocr_ocr_enhanced``.

    Returns:
        A ``TrOCRResponse`` built from the OCR result. ``word_count`` is
        computed here by whitespace-splitting the recognized text.

    Raises:
        HTTPException: 400 if the upload is not declared as an image,
            500 if reading the upload or running OCR fails.
    """
    # Validate file type
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        image_data = await file.read()

        result = await run_trocr_ocr_enhanced(
            image_data,
            handwritten=handwritten,
            split_lines=split_lines,
            use_cache=use_cache
        )

        return TrOCRResponse(
            text=result.text,
            confidence=result.confidence,
            processing_time_ms=result.processing_time_ms,
            model=result.model,
            has_lora_adapter=result.has_lora_adapter,
            from_cache=result.from_cache,
            image_hash=result.image_hash,
            word_count=len(result.text.split()) if result.text else 0
        )

    except HTTPException:
        # Deliberate HTTP errors keep their status code instead of being
        # rewritten to 500, matching the batch endpoints in this module.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception("TrOCR API error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@router.post("/batch", response_model=BatchOCRResponse)
async def run_trocr_batch_endpoint(
    files: List[UploadFile] = File(..., description="Image files to process"),
    handwritten: bool = Query(True, description="Use handwritten model"),
    split_lines: bool = Query(True, description="Split images into lines"),
    use_cache: bool = Query(True, description="Use result caching")
):
    """
    Run TrOCR on multiple images.

    Processes images sequentially and returns all results.

    Args:
        files: Uploaded files (1 to 50); each declared content type must
            start with ``image/``.
        handwritten: Forwarded to ``run_trocr_batch``.
        split_lines: Forwarded to ``run_trocr_batch``.
        use_cache: Forwarded to ``run_trocr_batch``.

    Returns:
        A ``BatchOCRResponse`` with one ``TrOCRResponse`` per batch result
        and the aggregate counters reported by ``run_trocr_batch``.

    Raises:
        HTTPException: 400 for an empty batch, more than 50 files, or a
            non-image upload; 500 if reading uploads or running OCR fails.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files provided")

    if len(files) > 50:
        raise HTTPException(status_code=400, detail="Maximum 50 images per batch")

    # Check every declared content type before reading any payload, so a bad
    # file late in the list does not cost a full read of the earlier uploads.
    for file in files:
        if not file.content_type or not file.content_type.startswith("image/"):
            raise HTTPException(status_code=400, detail=f"File {file.filename} is not an image")

    try:
        images = [await file.read() for file in files]

        batch_result = await run_trocr_batch(
            images,
            handwritten=handwritten,
            split_lines=split_lines,
            use_cache=use_cache
        )

        return BatchOCRResponse(
            results=[
                TrOCRResponse(
                    text=r.text,
                    confidence=r.confidence,
                    processing_time_ms=r.processing_time_ms,
                    model=r.model,
                    has_lora_adapter=r.has_lora_adapter,
                    from_cache=r.from_cache,
                    image_hash=r.image_hash,
                    word_count=len(r.text.split()) if r.text else 0
                )
                for r in batch_result.results
            ],
            total_time_ms=batch_result.total_time_ms,
            processed_count=batch_result.processed_count,
            cached_count=batch_result.cached_count,
            error_count=batch_result.error_count
        )

    except HTTPException:
        # Deliberate HTTP errors keep their status code.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception("TrOCR batch API error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@router.post("/batch/stream")
async def run_trocr_batch_stream_endpoint(
    files: List[UploadFile] = File(..., description="Image files to process"),
    handwritten: bool = Query(True, description="Use handwritten model"),
    split_lines: bool = Query(True, description="Split images into lines"),
    use_cache: bool = Query(True, description="Use result caching")
):
    """
    Run TrOCR on multiple images with Server-Sent Events (SSE) progress updates.

    Returns a stream of progress events as images are processed. Each update
    yielded by ``run_trocr_batch_stream`` is JSON-encoded and sent as one
    SSE ``data:`` frame.

    Args:
        files: Uploaded files (1 to 50); each declared content type must
            start with ``image/``.
        handwritten: Forwarded to ``run_trocr_batch_stream``.
        split_lines: Forwarded to ``run_trocr_batch_stream``.
        use_cache: Forwarded to ``run_trocr_batch_stream``.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.

    Raises:
        HTTPException: 400 for an empty batch, more than 50 files, or a
            non-image upload; 500 if reading the uploads fails. Failures
            during streaming cannot change the status code (the response
            has already started) and are reported as a final event instead.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files provided")

    if len(files) > 50:
        raise HTTPException(status_code=400, detail="Maximum 50 images per batch")

    # Check every declared content type before reading any payload, so a bad
    # file late in the list does not cost a full read of the earlier uploads.
    for file in files:
        if not file.content_type or not file.content_type.startswith("image/"):
            raise HTTPException(status_code=400, detail=f"File {file.filename} is not an image")

    try:
        images = [await file.read() for file in files]

        async def event_generator():
            # The generator body runs while the response is being sent, i.e.
            # after this handler has returned, so the outer try/except cannot
            # see its failures. Handle them here instead.
            try:
                async for update in run_trocr_batch_stream(
                    images,
                    handwritten=handwritten,
                    split_lines=split_lines,
                    use_cache=use_cache
                ):
                    yield f"data: {json.dumps(update)}\n\n"
            except Exception as e:
                logger.exception("TrOCR stream API error: %s", e)
                # NOTE(review): the schema of regular updates is defined by
                # run_trocr_batch_stream and is not visible here; confirm
                # that clients handle an event with type == "error".
                error_event = {"type": "error", "error": str(e)}
                yield f"data: {json.dumps(error_event)}\n\n"

        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive"
            }
        )

    except HTTPException:
        # Deliberate HTTP errors keep their status code.
        raise
    except Exception as e:
        # Covers failures before streaming starts (e.g. reading the uploads).
        logger.exception("TrOCR stream API error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|