Files
breakpilot-lehrer/klausur-service/backend/services/trocr_service.py
Benjamin Admin be7f5f1872 feat: Sprint 2 — TrOCR ONNX, PP-DocLayout, Model Management
D2: TrOCR ONNX export script (printed + handwritten, int8 quantization)
D3: PP-DocLayout ONNX export script (download or Docker-based conversion)
B3: Model Management admin page (PyTorch vs ONNX status, benchmarks, config)
A4: TrOCR ONNX service with runtime routing (auto/pytorch/onnx via TROCR_BACKEND)
A5: PP-DocLayout ONNX detection with OpenCV fallback (via GRAPHIC_DETECT_BACKEND)
B4: Structure Detection UI toggle (OpenCV vs PP-DocLayout) with class color coding
C3: TrOCR-ONNX.md documentation
C4: OCR-Pipeline.md ONNX section added
C5: mkdocs.yml nav updated, optimum added to requirements.txt

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-23 09:53:02 +01:00

731 lines
23 KiB
Python

"""
TrOCR Service
Microsoft's Transformer-based OCR for text recognition.
Besonders geeignet fuer:
- Gedruckten Text
- Saubere Scans
- Schnelle Verarbeitung
Model: microsoft/trocr-base-printed (oder handwritten Variante)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
Phase 2 Enhancements:
- Batch processing for multiple images
- SHA256-based caching for repeated requests
- Model preloading for faster first request
- Word-level bounding boxes with confidence
"""
import io
import os
import hashlib
import logging
import time
import asyncio
from typing import Tuple, Optional, List, Dict, Any
from dataclasses import dataclass, field
from collections import OrderedDict
from datetime import datetime, timedelta
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Backend routing: auto | pytorch | onnx
# Read once at import time; consumed by run_trocr_ocr()/run_trocr_ocr_enhanced().
# ---------------------------------------------------------------------------
_trocr_backend = os.environ.get("TROCR_BACKEND", "auto") # auto | pytorch | onnx
# Lazy loading for heavy dependencies.
# Cache keyed by model_name to support base and large variants simultaneously.
_trocr_models: dict = {} # {model_name: (processor, model)}
_trocr_processor = None # backwards-compat alias → base-printed
_trocr_model = None # backwards-compat alias → base-printed
_trocr_available = None  # tri-state: None = not probed yet, True/False = cached probe result
_model_loaded_at = None  # datetime of last successful preload, or None
# Simple in-memory cache with LRU eviction (entries: {"result": ..., "cached_at": datetime})
_ocr_cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
_cache_max_size = 100
_cache_ttl_seconds = 3600 # 1 hour
@dataclass
class OCRResult:
    """Enhanced OCR result with detailed information."""
    text: str                # recognized text ("" when OCR failed)
    confidence: float        # overall confidence estimate
    processing_time_ms: int  # wall-clock inference time; 0 on cache hits
    model: str               # model identifier used (e.g. "trocr-base-printed")
    has_lora_adapter: bool = False  # currently always False in the PyTorch path
    # Simulated per-character confidences (derived from the overall confidence).
    char_confidences: List[float] = field(default_factory=list)
    # Per-word dicts: {"text", "confidence", "bbox"}; bbox is a [0,0,0,0] placeholder.
    word_boxes: List[Dict[str, Any]] = field(default_factory=list)
    from_cache: bool = False  # True when served from the in-memory cache
    image_hash: str = ""      # short SHA256 of the input image (cache key)
@dataclass
class BatchOCRResult:
    """Result for batch processing."""
    results: List[OCRResult]  # one entry per input image (failures get text="Error: ...")
    total_time_ms: int        # wall-clock time for the whole batch
    processed_count: int      # total number of input images
    cached_count: int         # how many results came from the cache
    error_count: int          # how many images raised during OCR
def _compute_image_hash(image_data: bytes) -> str:
"""Compute SHA256 hash of image data for caching."""
return hashlib.sha256(image_data).hexdigest()[:16]
def _cache_get(image_hash: str) -> Optional[Dict[str, Any]]:
    """Return the cached OCR result for *image_hash*, or None if absent or expired."""
    entry = _ocr_cache.get(image_hash)
    if entry is None:
        return None
    age = datetime.now() - entry["cached_at"]
    if age < timedelta(seconds=_cache_ttl_seconds):
        # Fresh hit — refresh its LRU position before returning.
        _ocr_cache.move_to_end(image_hash)
        return entry["result"]
    # Entry outlived its TTL — evict it.
    del _ocr_cache[image_hash]
    return None
def _cache_set(image_hash: str, result: Dict[str, Any]) -> None:
    """Store *result* in the cache, evicting least-recently-used entries at capacity."""
    # Make room first: pop from the front (oldest) until under the limit.
    while len(_ocr_cache) >= _cache_max_size:
        _ocr_cache.popitem(last=False)
    _ocr_cache[image_hash] = {"result": result, "cached_at": datetime.now()}
def get_cache_stats() -> Dict[str, Any]:
    """Return cache size/limit/TTL; hit_rate is a placeholder (no counters tracked)."""
    stats: Dict[str, Any] = {
        "size": len(_ocr_cache),
        "max_size": _cache_max_size,
        "ttl_seconds": _cache_ttl_seconds,
    }
    stats["hit_rate"] = 0  # Could track this with additional counters
    return stats
def _check_trocr_available() -> bool:
    """Probe for torch + transformers once and cache the answer in _trocr_available."""
    global _trocr_available
    if _trocr_available is None:
        try:
            import torch  # noqa: F401 — import probe only
            from transformers import TrOCRProcessor, VisionEncoderDecoderModel  # noqa: F401
        except ImportError as e:
            logger.warning(f"TrOCR dependencies not available: {e}")
            _trocr_available = False
        else:
            _trocr_available = True
    return _trocr_available
def get_trocr_model(handwritten: bool = False, size: str = "base"):
    """
    Lazily load and cache a TrOCR (processor, model) pair.

    Args:
        handwritten: Use the handwritten model instead of the printed one.
        size: "base" (300 MB) or "large" (340 MB, higher accuracy for exam
            HTR). Only applies to the handwritten variant.

    Returns:
        Tuple of (processor, model), or (None, None) when dependencies are
        missing or loading fails.
    """
    global _trocr_processor, _trocr_model
    if not _check_trocr_available():
        return None, None

    # Resolve the HuggingFace model id: "large" only exists for handwritten.
    if handwritten:
        variant = "large" if size == "large" else "base"
        model_name = f"microsoft/trocr-{variant}-handwritten"
    else:
        model_name = "microsoft/trocr-base-printed"

    cached = _trocr_models.get(model_name)
    if cached is not None:
        return cached

    try:
        import torch
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel

        logger.info(f"Loading TrOCR model: {model_name}")
        processor = TrOCRProcessor.from_pretrained(model_name)
        model = VisionEncoderDecoderModel.from_pretrained(model_name)

        # Device preference: CUDA, then Apple MPS, then CPU.
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"
        model.to(device)
        logger.info(f"TrOCR model loaded on device: {device}")

        _trocr_models[model_name] = (processor, model)
        if model_name == "microsoft/trocr-base-printed":
            # Keep backwards-compat globals pointing at base-printed.
            _trocr_processor = processor
            _trocr_model = model
        return processor, model
    except Exception as e:
        logger.error(f"Failed to load TrOCR model {model_name}: {e}")
        return None, None
def preload_trocr_model(handwritten: bool = True) -> bool:
    """
    Warm up the TrOCR model at startup so the first request is fast.

    Intended for a FastAPI startup hook:

        @app.on_event("startup")
        async def startup():
            preload_trocr_model()

    Returns:
        True when both processor and model loaded successfully.
    """
    global _model_loaded_at
    logger.info("Preloading TrOCR model...")
    processor, model = get_trocr_model(handwritten=handwritten)
    loaded = processor is not None and model is not None
    if loaded:
        _model_loaded_at = datetime.now()
        logger.info("TrOCR model preloaded successfully")
    else:
        logger.warning("TrOCR model preloading failed")
    return loaded
def get_model_status() -> Dict[str, Any]:
    """
    Report the current TrOCR model status.

    NOTE: this calls get_trocr_model(), so the first invocation triggers a
    full (cached thereafter) model load and may be slow.

    Returns:
        Dict with "status" ("available"/"not_installed"), "is_loaded",
        "model_name", "loaded_at" (ISO timestamp of last preload, or None)
        and, when loaded, "device" (the torch device the model lives on).
    """
    processor, model = get_trocr_model(handwritten=True)
    is_loaded = processor is not None and model is not None
    status = {
        "status": "available" if is_loaded else "not_installed",
        "is_loaded": is_loaded,
        "model_name": "trocr-base-handwritten" if is_loaded else None,
        "loaded_at": _model_loaded_at.isoformat() if _model_loaded_at else None,
    }
    if is_loaded:
        # Device is read straight off the model parameters; the original's
        # `import torch` here was dead code and has been removed.
        status["device"] = str(next(model.parameters()).device)
    return status
def get_active_backend() -> str:
    """Return the configured TrOCR backend: "auto", "pytorch" or "onnx"."""
    backend = _trocr_backend
    return backend
def _try_onnx_ocr(
image_data: bytes,
handwritten: bool = False,
split_lines: bool = True,
) -> Optional[Tuple[Optional[str], float]]:
"""
Attempt ONNX inference. Returns the (text, confidence) tuple on
success, or None if ONNX is not available / fails to load.
"""
try:
from .trocr_onnx_service import is_onnx_available, run_trocr_onnx
if not is_onnx_available(handwritten=handwritten):
return None
# run_trocr_onnx is async — return the coroutine's awaitable result
# The caller (run_trocr_ocr) will await it.
return run_trocr_onnx # sentinel: caller checks callable
except ImportError:
return None
async def _run_pytorch_ocr(
    image_data: bytes,
    handwritten: bool = False,
    split_lines: bool = True,
    size: str = "base",
) -> Tuple[Optional[str], float]:
    """
    PyTorch inference path for TrOCR (extracted so the router can select it).

    Returns:
        (text, confidence); (None, 0.0) when the model is unavailable or
        inference raises.
    """
    processor, model = get_trocr_model(handwritten=handwritten, size=size)
    if processor is None or model is None:
        logger.error("TrOCR PyTorch model not available")
        return None, 0.0
    try:
        import torch
        from PIL import Image
        import numpy as np

        image = Image.open(io.BytesIO(image_data)).convert("RGB")

        # TrOCR recognizes one line at a time; optionally segment the page,
        # falling back to the whole image if segmentation yields nothing.
        segments = _split_into_lines(image) if split_lines else []
        if not segments:
            segments = [image]

        device = next(model.parameters()).device
        texts = []
        scores = []
        for segment in segments:
            pixels = processor(images=segment, return_tensors="pt").pixel_values
            pixels = pixels.to(device)
            with torch.no_grad():
                ids = model.generate(pixels, max_length=128)
            raw = processor.batch_decode(ids, skip_special_tokens=True)[0]
            stripped = raw.strip()
            if stripped:
                texts.append(stripped)
                # Heuristic confidence: very short decodes are less trustworthy.
                scores.append(0.85 if len(raw) > 3 else 0.5)

        text = "\n".join(texts)
        confidence = sum(scores) / len(scores) if scores else 0.0
        logger.info(f"TrOCR (PyTorch) extracted {len(text)} characters from {len(segments)} lines")
        return text, confidence
    except Exception as e:
        logger.error(f"TrOCR PyTorch failed: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return None, 0.0
async def run_trocr_ocr(
    image_data: bytes,
    handwritten: bool = False,
    split_lines: bool = True,
    size: str = "base",
) -> Tuple[Optional[str], float]:
    """
    Run TrOCR on an image, routing between ONNX and PyTorch backends.

    The backend comes from the TROCR_BACKEND environment variable (read at
    import time; default "auto"):
      - "onnx"    — ONNX only; raises RuntimeError when unavailable
      - "pytorch" — PyTorch only (original behaviour)
      - "auto"    — try ONNX first, fall back to PyTorch

    TrOCR is optimised for single-line text recognition, so for full-page
    images we either split into lines first (line detection) or process the
    whole image and accept partial results.

    Args:
        image_data: Raw image bytes.
        handwritten: Use the handwritten model (slower, better for handwriting).
        split_lines: Whether to split the image into lines first.
        size: "base" or "large" (handwritten variant, PyTorch path only).

    Returns:
        Tuple of (extracted_text, confidence); text is None on failure.
    """
    backend = _trocr_backend

    # Explicit PyTorch-only mode — skip the ONNX probe entirely.
    if backend == "pytorch":
        return await _run_pytorch_ocr(
            image_data, handwritten=handwritten, split_lines=split_lines, size=size,
        )

    onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
    onnx_usable = onnx_fn is not None and callable(onnx_fn)

    # ONNX-only mode: hard failure when the backend is missing.
    if backend == "onnx":
        if not onnx_usable:
            raise RuntimeError(
                "ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
                "Ensure onnxruntime + optimum are installed and ONNX model files exist."
            )
        return await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)

    # Auto mode: prefer ONNX, fall back to PyTorch on error or None text.
    if onnx_usable:
        try:
            result = await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)
            if result[0] is not None:
                return result
            logger.warning("ONNX returned None text, falling back to PyTorch")
        except Exception as e:
            logger.warning(f"ONNX inference failed ({e}), falling back to PyTorch")
    return await _run_pytorch_ocr(
        image_data, handwritten=handwritten, split_lines=split_lines, size=size,
    )
def _split_into_lines(image) -> list:
    """
    Split a page image into text-line crops via horizontal projection.

    Rows whose dark-pixel count exceeds 2% of the image width are treated as
    text; contiguous runs become lines, padded 5px vertically, with a 10px
    minimum height. This is a basic heuristic — for production, consider a
    dedicated line-detection model.

    Args:
        image: A PIL.Image (any mode; converted to grayscale internally).

    Returns:
        List of PIL.Image line crops, or [] when splitting fails.
    """
    # Fix: the original also did `from PIL import Image` here, but the name
    # was never used — only numpy and methods of the passed image object are.
    import numpy as np
    try:
        # Convert to grayscale and binarize with a fixed threshold.
        img_array = np.array(image.convert('L'))
        threshold = 200
        binary = img_array < threshold

        # Horizontal projection: count of dark pixels per row.
        h_proj = np.sum(binary, axis=1)
        line_threshold = img_array.shape[1] * 0.02  # 2% of width

        in_line = False
        line_start = 0
        lines = []
        for i, val in enumerate(h_proj):
            if val > line_threshold and not in_line:
                # Start of a text line.
                in_line = True
                line_start = i
            elif val <= line_threshold and in_line:
                # End of a text line — crop with vertical padding.
                in_line = False
                start = max(0, line_start - 5)
                end = min(img_array.shape[0], i + 5)
                if end - start > 10:  # minimum line height
                    lines.append(image.crop((0, start, image.width, end)))
        if in_line:
            # Text runs to the bottom edge — close out the final line.
            start = max(0, line_start - 5)
            lines.append(image.crop((0, start, image.width, image.height)))

        logger.info(f"Split image into {len(lines)} lines")
        return lines
    except Exception as e:
        logger.warning(f"Line splitting failed: {e}")
        return []
def _try_onnx_enhanced(
handwritten: bool = True,
):
"""
Return the ONNX enhanced coroutine function, or None if unavailable.
"""
try:
from .trocr_onnx_service import is_onnx_available, run_trocr_onnx_enhanced
if not is_onnx_available(handwritten=handwritten):
return None
return run_trocr_onnx_enhanced
except ImportError:
return None
async def run_trocr_ocr_enhanced(
    image_data: bytes,
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True
) -> OCRResult:
    """
    Enhanced TrOCR OCR with caching and detailed results.

    Routes between ONNX and PyTorch backends based on the TROCR_BACKEND
    environment variable (default: "auto"):
      - "onnx": ONNX only; raises RuntimeError when unavailable.
      - "auto": try ONNX first, fall back to PyTorch on error or empty text.
      - any other value (e.g. "pytorch"): PyTorch path below.

    Args:
        image_data: Raw image bytes
        handwritten: Use handwritten model
        split_lines: Whether to split image into lines first
        use_cache: Whether to use the SHA256-keyed cache (PyTorch path only)

    Returns:
        OCRResult with detailed information. Note that word bboxes are
        [0,0,0,0] placeholders and per-word/char confidences are simulated
        from the overall confidence.
    """
    backend = _trocr_backend
    # --- ONNX-only mode ---
    if backend == "onnx":
        onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
        if onnx_fn is None:
            raise RuntimeError(
                "ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
                "Ensure onnxruntime + optimum are installed and ONNX model files exist."
            )
        return await onnx_fn(
            image_data, handwritten=handwritten,
            split_lines=split_lines, use_cache=use_cache,
        )
    # --- Auto mode: try ONNX first ---
    if backend == "auto":
        onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
        if onnx_fn is not None:
            try:
                result = await onnx_fn(
                    image_data, handwritten=handwritten,
                    split_lines=split_lines, use_cache=use_cache,
                )
                if result.text:
                    return result
                logger.warning("ONNX enhanced returned empty text, falling back to PyTorch")
            except Exception as e:
                logger.warning(f"ONNX enhanced failed ({e}), falling back to PyTorch")
    # --- PyTorch path (backend == "pytorch" or auto fallback) ---
    start_time = time.time()
    # Check cache first (cache key: short SHA256 of the raw image bytes).
    image_hash = _compute_image_hash(image_data)
    if use_cache:
        cached = _cache_get(image_hash)
        if cached:
            # Cache hit: rebuild an OCRResult; processing_time_ms is 0 by design.
            return OCRResult(
                text=cached["text"],
                confidence=cached["confidence"],
                processing_time_ms=0,
                model=cached["model"],
                has_lora_adapter=cached.get("has_lora_adapter", False),
                char_confidences=cached.get("char_confidences", []),
                word_boxes=cached.get("word_boxes", []),
                from_cache=True,
                image_hash=image_hash
            )
    # Run OCR via PyTorch
    text, confidence = await _run_pytorch_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
    processing_time_ms = int((time.time() - start_time) * 1000)
    # Generate word boxes with simulated confidences.
    # NOTE(review): hash() is salted per process (PYTHONHASHSEED), so these
    # simulated confidences are not reproducible across runs — confirm whether
    # deterministic values are required downstream.
    word_boxes = []
    if text:
        words = text.split()
        for idx, word in enumerate(words):
            # Simulate word confidence (slightly varied around overall confidence)
            word_conf = min(1.0, max(0.0, confidence + (hash(word) % 20 - 10) / 100))
            word_boxes.append({
                "text": word,
                "confidence": word_conf,
                "bbox": [0, 0, 0, 0] # Would need actual bounding box detection
            })
    # Generate character confidences
    char_confidences = []
    if text:
        for char in text:
            # Simulate per-character confidence
            char_conf = min(1.0, max(0.0, confidence + (hash(char) % 15 - 7) / 100))
            char_confidences.append(char_conf)
    result = OCRResult(
        text=text or "",
        confidence=confidence,
        processing_time_ms=processing_time_ms,
        model="trocr-base-handwritten" if handwritten else "trocr-base-printed",
        has_lora_adapter=False, # Would check actual adapter status
        char_confidences=char_confidences,
        word_boxes=word_boxes,
        from_cache=False,
        image_hash=image_hash
    )
    # Cache result (only successful, non-empty OCR runs are cached).
    if use_cache and text:
        _cache_set(image_hash, {
            "text": result.text,
            "confidence": result.confidence,
            "model": result.model,
            "has_lora_adapter": result.has_lora_adapter,
            "char_confidences": result.char_confidences,
            "word_boxes": result.word_boxes
        })
    return result
async def run_trocr_batch(
    images: List[bytes],
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True,
    progress_callback: Optional[callable] = None
) -> BatchOCRResult:
    """
    OCR a list of images sequentially and aggregate the results.

    Args:
        images: List of image data bytes.
        handwritten: Use handwritten model.
        split_lines: Whether to split images into lines.
        use_cache: Whether to use caching.
        progress_callback: Optional callback(current, total), invoked after
            each successfully processed image (1-based).

    Returns:
        BatchOCRResult with per-image results plus cache/error counters.
        Failed images yield an OCRResult with text "Error: <message>".
    """
    started = time.time()
    total = len(images)
    results: List[OCRResult] = []
    cached_count = 0
    error_count = 0

    for position, payload in enumerate(images):
        try:
            ocr = await run_trocr_ocr_enhanced(
                payload,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache
            )
        except Exception as e:
            logger.error(f"Batch OCR error for image {position}: {e}")
            error_count += 1
            results.append(OCRResult(
                text=f"Error: {str(e)}",
                confidence=0.0,
                processing_time_ms=0,
                model="error",
                has_lora_adapter=False
            ))
            continue

        results.append(ocr)
        if ocr.from_cache:
            cached_count += 1
        if progress_callback:
            progress_callback(position + 1, total)

    return BatchOCRResult(
        results=results,
        total_time_ms=int((time.time() - started) * 1000),
        processed_count=total,
        cached_count=cached_count,
        error_count=error_count
    )
# Generator for SSE streaming during batch processing
async def run_trocr_batch_stream(
    images: List[bytes],
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True
):
    """
    OCR a list of images and yield progress updates for SSE streaming.

    Yields:
        One {"type": "progress", ...} dict per processed image (with running
        ETA and a compact result), {"type": "error", ...} per failure, then a
        trailing {"type": "complete", ...} summary.
    """
    started = time.time()
    total = len(images)

    for position, payload in enumerate(images):
        current = position + 1
        try:
            ocr = await run_trocr_ocr_enhanced(
                payload,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache
            )
        except Exception as e:
            logger.error(f"Stream OCR error for image {position}: {e}")
            yield {
                "type": "error",
                "current": current,
                "total": total,
                "error": str(e)
            }
            continue

        elapsed_ms = int((time.time() - started) * 1000)
        # Naive ETA: average per-image time so far times the images remaining.
        per_image_ms = elapsed_ms / current
        yield {
            "type": "progress",
            "current": current,
            "total": total,
            "progress_percent": (current / total) * 100,
            "elapsed_ms": elapsed_ms,
            "estimated_remaining_ms": int(per_image_ms * (total - current)),
            "result": {
                "text": ocr.text,
                "confidence": ocr.confidence,
                "processing_time_ms": ocr.processing_time_ms,
                "from_cache": ocr.from_cache
            }
        }

    yield {
        "type": "complete",
        "total_time_ms": int((time.time() - started) * 1000),
        "processed_count": total
    }
# Test function
async def test_trocr_ocr(image_path: str, handwritten: bool = False):
    """Run TrOCR on a local image file and print the result (manual smoke test)."""
    with open(image_path, "rb") as f:
        image_data = f.read()
    text, confidence = await run_trocr_ocr(image_data, handwritten=handwritten)
    mode = 'Handwritten' if handwritten else 'Printed'
    print(f"\n=== TrOCR Test ===")
    print(f"Mode: {mode}")
    print(f"Confidence: {confidence:.2f}")
    print(f"Text:\n{text}")
    return text, confidence
if __name__ == "__main__":
    # CLI entry point: python trocr_service.py <image_path> [--handwritten]
    # Fix: dropped the redundant `import asyncio` — it is already imported at
    # module level.
    import sys

    handwritten = "--handwritten" in sys.argv
    positional = [a for a in sys.argv[1:] if not a.startswith("--")]
    if positional:
        asyncio.run(test_trocr_ocr(positional[0], handwritten=handwritten))
    else:
        print("Usage: python trocr_service.py <image_path> [--handwritten]")