This repository has been archived on 2026-02-15. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
BreakPilot Dev 53219e3eaf feat(klausur-service): Add Tesseract OCR, DSFA RAG, TrOCR, grid detection and vocab session store
New modules:
- tesseract_vocab_extractor.py: Bounding-box OCR with multi-PSM pipeline
- grid_detection_service.py: CV-based grid/table detection for worksheets
- vocab_session_store.py: PostgreSQL persistence for vocab sessions
- trocr_api.py: TrOCR handwriting recognition endpoint
- dsfa_rag_api.py + dsfa_corpus_ingestion.py: DSFA RAG corpus search

Changes:
- Dockerfile: Install tesseract-ocr + deu/eng language packs
- requirements.txt: Add PyMuPDF, pytesseract, Pillow
- main.py: Register new routers, init DB pools + Qdrant collections

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 00:00:19 +01:00

262 lines
9.1 KiB
Python

"""
TrOCR API - REST endpoints for TrOCR handwriting OCR.
Provides:
- /ocr/trocr - Single image OCR
- /ocr/trocr/batch - Batch image processing
- /ocr/trocr/status - Model status
- /ocr/trocr/cache - Cache statistics
"""
from fastapi import APIRouter, UploadFile, File, HTTPException, Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import List, Optional
import json
import logging
from services.trocr_service import (
run_trocr_ocr_enhanced,
run_trocr_batch,
run_trocr_batch_stream,
get_model_status,
get_cache_stats,
preload_trocr_model,
OCRResult,
BatchOCRResult
)
logger = logging.getLogger(__name__)  # module-level logger, stdlib convention
router = APIRouter(prefix="/api/v1/ocr/trocr", tags=["TrOCR"])  # all routes below mount under this prefix
# =============================================================================
# MODELS
# =============================================================================
class TrOCRResponse(BaseModel):
    """Response model for single image OCR.

    Field descriptions feed the generated OpenAPI schema; `confidence`
    is constrained to [0, 1] by the Field validators below.
    """
    text: str = Field(..., description="Extracted text")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Overall confidence")
    processing_time_ms: int = Field(..., ge=0, description="Processing time in milliseconds")
    model: str = Field(..., description="Model used for OCR")
    has_lora_adapter: bool = Field(False, description="Whether LoRA adapter was used")
    from_cache: bool = Field(False, description="Whether result was from cache")
    image_hash: str = Field("", description="SHA256 hash of image (first 16 chars)")
    word_count: int = Field(0, description="Number of words detected")
class BatchOCRResponse(BaseModel):
    """Response model for batch OCR.

    Wraps the per-image `TrOCRResponse` list with batch-level counters.
    """
    results: List[TrOCRResponse] = Field(..., description="Individual OCR results")
    total_time_ms: int = Field(..., ge=0, description="Total processing time")
    processed_count: int = Field(..., ge=0, description="Number of images processed")
    cached_count: int = Field(0, description="Number of results from cache")
    error_count: int = Field(0, description="Number of errors")
class ModelStatusResponse(BaseModel):
    """Response model for model status.

    Optional fields are None while the model has not been loaded yet.
    """
    # Expected values per description: "available" or "not_installed".
    status: str = Field(..., description="Model status: available, not_installed")
    is_loaded: bool = Field(..., description="Whether model is loaded in memory")
    model_name: Optional[str] = Field(None, description="Name of loaded model")
    device: Optional[str] = Field(None, description="Device model is running on")
    loaded_at: Optional[str] = Field(None, description="ISO timestamp when model was loaded")
class CacheStatsResponse(BaseModel):
    """Response model for cache statistics."""
    size: int = Field(..., ge=0, description="Current cache size")
    max_size: int = Field(..., ge=0, description="Maximum cache size")
    ttl_seconds: int = Field(..., ge=0, description="Cache TTL in seconds")
# =============================================================================
# ENDPOINTS
# =============================================================================
@router.get("/status", response_model=ModelStatusResponse)
async def get_trocr_status():
    """
    Get TrOCR model status.
    Returns information about whether the model is loaded and available.
    """
    # Thin delegation: the service layer owns model lifecycle state.
    return get_model_status()
@router.get("/cache", response_model=CacheStatsResponse)
async def get_trocr_cache_stats():
    """
    Get TrOCR cache statistics.
    Returns information about the OCR result cache.
    """
    # Thin delegation: the service layer owns the cache.
    return get_cache_stats()
@router.post("/preload")
async def preload_model(handwritten: bool = Query(True, description="Load handwritten model")):
    """
    Preload TrOCR model into memory.
    This speeds up the first OCR request by loading the model ahead of time.
    """
    # Guard-clause style: surface the failure first, then the success path.
    if not preload_trocr_model(handwritten=handwritten):
        raise HTTPException(status_code=500, detail="Failed to preload model")
    return {"status": "success", "message": "Model preloaded successfully"}
@router.post("", response_model=TrOCRResponse)
async def run_trocr(
    file: UploadFile = File(..., description="Image file to process"),
    handwritten: bool = Query(True, description="Use handwritten model"),
    split_lines: bool = Query(True, description="Split image into lines"),
    use_cache: bool = Query(True, description="Use result caching")
):
    """
    Run TrOCR on a single image.
    Supports PNG, JPG, and other common image formats.

    Raises:
        HTTPException 400: if the upload does not declare an image content type.
        HTTPException 500: if OCR processing fails.
    """
    # Validate the declared content type before reading the body.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")
    try:
        image_data = await file.read()
        result = await run_trocr_ocr_enhanced(
            image_data,
            handwritten=handwritten,
            split_lines=split_lines,
            use_cache=use_cache
        )
        return TrOCRResponse(
            text=result.text,
            confidence=result.confidence,
            processing_time_ms=result.processing_time_ms,
            model=result.model,
            has_lora_adapter=result.has_lora_adapter,
            from_cache=result.from_cache,
            image_hash=result.image_hash,
            # Simple whitespace split; empty/None text yields 0.
            word_count=len(result.text.split()) if result.text else 0
        )
    except HTTPException:
        # Consistency with the batch endpoints: pass deliberate HTTP errors
        # through instead of wrapping them in a generic 500.
        raise
    except Exception as e:
        # logger.exception records the traceback, not just the message.
        logger.exception("TrOCR API error: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/batch", response_model=BatchOCRResponse)
async def run_trocr_batch_endpoint(
    files: List[UploadFile] = File(..., description="Image files to process"),
    handwritten: bool = Query(True, description="Use handwritten model"),
    split_lines: bool = Query(True, description="Split images into lines"),
    use_cache: bool = Query(True, description="Use result caching")
):
    """
    Run TrOCR on multiple images.
    Processes images sequentially and returns all results.
    """
    # Reject empty and oversized batches before touching any upload.
    if not files:
        raise HTTPException(status_code=400, detail="No files provided")
    if len(files) > 50:
        raise HTTPException(status_code=400, detail="Maximum 50 images per batch")

    def to_response(r) -> TrOCRResponse:
        """Map one service-layer OCR result onto the API response model."""
        return TrOCRResponse(
            text=r.text,
            confidence=r.confidence,
            processing_time_ms=r.processing_time_ms,
            model=r.model,
            has_lora_adapter=r.has_lora_adapter,
            from_cache=r.from_cache,
            image_hash=r.image_hash,
            word_count=len(r.text.split()) if r.text else 0
        )

    async def read_image(file: UploadFile) -> bytes:
        """Validate the declared content type and return the raw bytes."""
        if not file.content_type or not file.content_type.startswith("image/"):
            raise HTTPException(status_code=400, detail=f"File {file.filename} is not an image")
        return await file.read()

    try:
        images = [await read_image(f) for f in files]
        batch_result = await run_trocr_batch(
            images,
            handwritten=handwritten,
            split_lines=split_lines,
            use_cache=use_cache
        )
        return BatchOCRResponse(
            results=[to_response(r) for r in batch_result.results],
            total_time_ms=batch_result.total_time_ms,
            processed_count=batch_result.processed_count,
            cached_count=batch_result.cached_count,
            error_count=batch_result.error_count
        )
    except HTTPException:
        # Validation errors pass through untouched.
        raise
    except Exception as e:
        logger.error(f"TrOCR batch API error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/batch/stream")
async def run_trocr_batch_stream_endpoint(
    files: List[UploadFile] = File(..., description="Image files to process"),
    handwritten: bool = Query(True, description="Use handwritten model"),
    split_lines: bool = Query(True, description="Split images into lines"),
    use_cache: bool = Query(True, description="Use result caching")
):
    """
    Run TrOCR on multiple images with Server-Sent Events (SSE) progress updates.
    Returns a stream of progress events as images are processed.

    Raises:
        HTTPException 400: empty batch, >50 files, or a non-image upload.
        HTTPException 500: failure before streaming starts.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files provided")
    if len(files) > 50:
        raise HTTPException(status_code=400, detail="Maximum 50 images per batch")
    try:
        # All uploads are read up front, before the StreamingResponse is
        # returned — presumably so the request body is not consumed after
        # the response has started (TODO confirm against FastAPI semantics).
        images = []
        for file in files:
            if not file.content_type or not file.content_type.startswith("image/"):
                raise HTTPException(status_code=400, detail=f"File {file.filename} is not an image")
            images.append(await file.read())
        # Closure over `images` and the query flags; each progress update
        # from the service stream is emitted as one SSE "data:" frame.
        async def event_generator():
            async for update in run_trocr_batch_stream(
                images,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache
            ):
                yield f"data: {json.dumps(update)}\n\n"
        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            headers={
                # Discourage intermediary buffering of the event stream.
                "Cache-Control": "no-cache",
                "Connection": "keep-alive"
            }
        )
    except HTTPException:
        # Validation errors above pass through untouched.
        raise
    except Exception as e:
        logger.error(f"TrOCR stream API error: {e}")
        raise HTTPException(status_code=500, detail=str(e))