Files
breakpilot-lehrer/klausur-service/backend/services/trocr_onnx_service.py
Benjamin Admin be7f5f1872 feat: Sprint 2 — TrOCR ONNX, PP-DocLayout, Model Management
D2: TrOCR ONNX export script (printed + handwritten, int8 quantization)
D3: PP-DocLayout ONNX export script (download or Docker-based conversion)
B3: Model Management admin page (PyTorch vs ONNX status, benchmarks, config)
A4: TrOCR ONNX service with runtime routing (auto/pytorch/onnx via TROCR_BACKEND)
A5: PP-DocLayout ONNX detection with OpenCV fallback (via GRAPHIC_DETECT_BACKEND)
B4: Structure Detection UI toggle (OpenCV vs PP-DocLayout) with class color coding
C3: TrOCR-ONNX.md documentation
C4: OCR-Pipeline.md ONNX section added
C5: mkdocs.yml nav updated, optimum added to requirements.txt

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-23 09:53:02 +01:00

431 lines
13 KiB
Python

"""
TrOCR ONNX Service
ONNX-optimized inference for TrOCR text recognition.
Uses optimum.onnxruntime.ORTModelForVision2Seq for hardware-accelerated
inference without requiring PyTorch at runtime.
Advantages over PyTorch backend:
- 2-4x faster inference on CPU
- Lower memory footprint (~300 MB vs ~600 MB)
- No PyTorch/CUDA dependency at runtime
- Apple Silicon (CoreML) and x86 (OpenVINO) acceleration
Model paths searched (in order):
1. TROCR_ONNX_DIR environment variable
2. /root/.cache/huggingface/onnx/trocr-base-{printed,handwritten}/ (Docker)
3. models/onnx/trocr-base-{printed,handwritten}/ (local dev)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import io
import os
import logging
import time
import asyncio
from pathlib import Path
from typing import Tuple, Optional, List, Dict, Any
from datetime import datetime
logger = logging.getLogger(__name__)
# Re-use shared types and cache from trocr_service
from .trocr_service import (
OCRResult,
_compute_image_hash,
_cache_get,
_cache_set,
_split_into_lines,
)
# ---------------------------------------------------------------------------
# Module-level state
# ---------------------------------------------------------------------------
# Cache of loaded models: {model_key: (processor, model)},
# where model_key = "printed" | "handwritten".
_onnx_models: Dict[str, Any] = {}
# NOTE(review): never assigned anywhere in the visible code — presumably a
# leftover flag or populated elsewhere; confirm before relying on it.
_onnx_available: Optional[bool] = None
# Timestamp of the most recent successful model load (None until first load).
_onnx_model_loaded_at: Optional[datetime] = None
# ---------------------------------------------------------------------------
# Path resolution
# ---------------------------------------------------------------------------
# Local directory names for the two model variants, keyed by the
# ``handwritten`` flag.
_VARIANT_NAMES = {
    False: "trocr-base-printed",
    True: "trocr-base-handwritten",
}
# HuggingFace model IDs (used for processor downloads)
_HF_MODEL_IDS = {
    False: "microsoft/trocr-base-printed",
    True: "microsoft/trocr-base-handwritten",
}


def _resolve_onnx_model_dir(handwritten: bool = False) -> Optional[Path]:
    """
    Locate the directory holding the ONNX files for the requested variant.

    Search order:
      1. TROCR_ONNX_DIR env var (with the variant name appended, then as-is)
      2. /root/.cache/huggingface/onnx/<variant>/ (Docker)
      3. models/onnx/<variant>/ (local dev, relative to this file)

    Returns the first existing directory containing at least one ``*.onnx``
    file or a ``config.json``; ``None`` when nothing matches.
    """
    variant_name = _VARIANT_NAMES[handwritten]

    search_dirs: List[Path] = []
    configured = os.environ.get("TROCR_ONNX_DIR")
    if configured:
        base = Path(configured)
        # The env var may point at a parent directory or directly at the
        # variant directory — try both interpretations, parent first.
        search_dirs.extend([base / variant_name, base])
    search_dirs.append(Path(f"/root/.cache/huggingface/onnx/{variant_name}"))
    # Local dev path (relative to klausur-service/backend/)
    search_dirs.append(
        Path(__file__).resolve().parent.parent / "models" / "onnx" / variant_name
    )

    for directory in search_dirs:
        if not directory.is_dir():
            continue
        # Accept either raw .onnx files or an optimum-style config.json.
        if any(directory.glob("*.onnx")) or (directory / "config.json").exists():
            logger.info(f"ONNX model directory resolved: {directory}")
            return directory
    return None
# ---------------------------------------------------------------------------
# Availability checks
# ---------------------------------------------------------------------------
def _check_onnx_runtime_available() -> bool:
    """Return True when onnxruntime, optimum and transformers are importable."""
    try:
        import onnxruntime  # noqa: F401
        from optimum.onnxruntime import ORTModelForVision2Seq  # noqa: F401
        from transformers import TrOCRProcessor  # noqa: F401
    except ImportError as e:
        logger.debug(f"ONNX runtime dependencies not available: {e}")
        return False
    return True
def is_onnx_available(handwritten: bool = False) -> bool:
    """
    Report whether ONNX inference can be used for the given variant.

    True only when both conditions hold:
    - onnxruntime + optimum are installed
    - a valid model directory with ONNX files exists for the variant
    """
    runtime_ok = _check_onnx_runtime_available()
    return runtime_ok and _resolve_onnx_model_dir(handwritten=handwritten) is not None
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
def _get_onnx_model(handwritten: bool = False):
    """
    Lazily load and cache the ONNX model + processor for one variant.

    Returns:
        Tuple of (processor, model), or (None, None) when the model
        directory is missing, dependencies are absent, or loading fails.
    """
    global _onnx_model_loaded_at

    model_key = "handwritten" if handwritten else "printed"
    cached = _onnx_models.get(model_key)
    if cached is not None:
        return cached

    model_dir = _resolve_onnx_model_dir(handwritten=handwritten)
    if model_dir is None:
        logger.warning(
            f"No ONNX model directory found for variant "
            f"{'handwritten' if handwritten else 'printed'}"
        )
        return None, None
    if not _check_onnx_runtime_available():
        logger.warning("ONNX runtime dependencies not installed")
        return None, None

    try:
        from optimum.onnxruntime import ORTModelForVision2Seq
        from transformers import TrOCRProcessor

        hf_id = _HF_MODEL_IDS[handwritten]
        logger.info(f"Loading ONNX TrOCR model from {model_dir} (processor: {hf_id})")
        started = time.monotonic()
        # Processor (tokenizer + feature extractor) comes from the HF hub;
        # the ONNX weights are read from the local directory.
        processor = TrOCRProcessor.from_pretrained(hf_id)
        model = ORTModelForVision2Seq.from_pretrained(str(model_dir))
        elapsed = time.monotonic() - started
        logger.info(
            f"ONNX TrOCR model loaded in {elapsed:.1f}s "
            f"(variant={model_key}, dir={model_dir})"
        )
    except Exception as e:
        logger.error(f"Failed to load ONNX TrOCR model ({model_key}): {e}")
        import traceback
        logger.error(traceback.format_exc())
        return None, None

    _onnx_models[model_key] = (processor, model)
    _onnx_model_loaded_at = datetime.now()
    return processor, model
def preload_onnx_model(handwritten: bool = True) -> bool:
    """
    Warm up the ONNX model at application startup so the first OCR request
    does not pay the load cost.

    Call from FastAPI startup event:
        @app.on_event("startup")
        async def startup():
            preload_onnx_model()

    Returns:
        True when both processor and model were loaded successfully.
    """
    logger.info(f"Preloading ONNX TrOCR model (handwritten={handwritten})...")
    processor, model = _get_onnx_model(handwritten=handwritten)
    ok = processor is not None and model is not None
    if ok:
        logger.info("ONNX TrOCR model preloaded successfully")
    else:
        logger.warning("ONNX TrOCR model preloading failed")
    return ok
# ---------------------------------------------------------------------------
# Status
# ---------------------------------------------------------------------------
def get_onnx_model_status() -> Dict[str, Any]:
    """Assemble a status report for the ONNX backend (runtime, dirs, load state)."""
    runtime_ok = _check_onnx_runtime_available()

    # Detect ONNX runtime execution providers (best-effort — a detection
    # failure must not break the status endpoint).
    providers: List[str] = []
    if runtime_ok:
        try:
            import onnxruntime
            providers = onnxruntime.get_available_providers()
        except Exception:
            pass

    def _variant_status(handwritten: bool, key: str) -> Dict[str, Any]:
        # Per-variant sub-report: resolved dir, availability, load state.
        model_dir = _resolve_onnx_model_dir(handwritten=handwritten)
        return {
            "model_dir": str(model_dir) if model_dir else None,
            "available": model_dir is not None and runtime_ok,
            "loaded": key in _onnx_models,
        }

    return {
        "backend": "onnx",
        "runtime_available": runtime_ok,
        "providers": providers,
        "printed": _variant_status(False, "printed"),
        "handwritten": _variant_status(True, "handwritten"),
        "loaded_at": _onnx_model_loaded_at.isoformat() if _onnx_model_loaded_at else None,
    }
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
async def run_trocr_onnx(
    image_data: bytes,
    handwritten: bool = False,
    split_lines: bool = True,
) -> Tuple[Optional[str], float]:
    """
    Run TrOCR text recognition via the ONNX backend.

    Mirrors the interface of trocr_service.run_trocr_ocr.

    Args:
        image_data: Raw image bytes (PNG, JPEG, etc.)
        handwritten: Use the handwritten model variant
        split_lines: Split the image into text lines before recognition

    Returns:
        Tuple of (extracted_text, confidence); (None, 0.0) on failure.
    """
    processor, model = _get_onnx_model(handwritten=handwritten)
    if processor is None or model is None:
        logger.error("ONNX TrOCR model not available")
        return None, 0.0

    try:
        from PIL import Image

        page = Image.open(io.BytesIO(image_data)).convert("RGB")
        line_images = _split_into_lines(page) if split_lines else []
        if not line_images:
            # Fall back to whole-page recognition.
            line_images = [page]

        recognized: List[str] = []
        line_confidences: List[float] = []
        for segment in line_images:
            # Processor emits PyTorch tensors; ORTModelForVision2Seq.generate
            # accepts them directly.
            pixel_values = processor(images=segment, return_tensors="pt").pixel_values
            generated_ids = model.generate(pixel_values, max_length=128)
            decoded = processor.batch_decode(
                generated_ids, skip_special_tokens=True
            )[0]
            stripped = decoded.strip()
            if stripped:
                recognized.append(stripped)
                # NOTE(review): heuristic confidence — generate() yields no
                # real scores; length check is on the un-stripped text.
                line_confidences.append(0.85 if len(decoded) > 3 else 0.5)

        text = "\n".join(recognized)
        confidence = (
            sum(line_confidences) / len(line_confidences) if line_confidences else 0.0
        )
        logger.info(
            f"ONNX TrOCR extracted {len(text)} chars from {len(line_images)} lines"
        )
        return text, confidence
    except Exception as e:
        logger.error(f"ONNX TrOCR failed: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return None, 0.0
def _simulated_word_boxes(text: str, confidence: float) -> List[Dict[str, Any]]:
    """
    Build per-word result entries with pseudo-confidences.

    TrOCR generation yields no per-word scores or geometry, so each word's
    confidence is jittered around the overall value and bboxes are zeroed
    placeholders.

    NOTE(review): ``hash(str)`` is salted per process (PYTHONHASHSEED), so
    these jittered values differ between runs — confirm nothing persists or
    compares them across processes.
    """
    boxes: List[Dict[str, Any]] = []
    for word in text.split():
        word_conf = min(1.0, max(0.0, confidence + (hash(word) % 20 - 10) / 100))
        boxes.append({
            "text": word,
            "confidence": word_conf,
            "bbox": [0, 0, 0, 0],  # no layout information available
        })
    return boxes


def _simulated_char_confidences(text: str, confidence: float) -> List[float]:
    """Per-character pseudo-confidences, jittered like _simulated_word_boxes."""
    return [
        min(1.0, max(0.0, confidence + (hash(char) % 15 - 7) / 100))
        for char in text
    ]


async def run_trocr_onnx_enhanced(
    image_data: bytes,
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True,
) -> OCRResult:
    """
    Enhanced ONNX TrOCR with caching and detailed results.

    Mirrors the interface of trocr_service.run_trocr_ocr_enhanced.

    Args:
        image_data: Raw image bytes
        handwritten: Use handwritten model variant
        split_lines: Split image into text lines
        use_cache: Use SHA256-based in-memory cache

    Returns:
        OCRResult with text, confidence, timing, word boxes, etc.
        On inference failure the result carries empty text and confidence 0.0.
    """
    start_time = time.time()
    image_hash = _compute_image_hash(image_data)

    # Serve from cache when possible (processing time reported as 0).
    if use_cache:
        cached = _cache_get(image_hash)
        if cached:
            return OCRResult(
                text=cached["text"],
                confidence=cached["confidence"],
                processing_time_ms=0,
                model=cached["model"],
                has_lora_adapter=cached.get("has_lora_adapter", False),
                char_confidences=cached.get("char_confidences", []),
                word_boxes=cached.get("word_boxes", []),
                from_cache=True,
                image_hash=image_hash,
            )

    # Run ONNX inference (returns (None, 0.0) on failure).
    text, confidence = await run_trocr_onnx(
        image_data, handwritten=handwritten, split_lines=split_lines
    )
    processing_time_ms = int((time.time() - start_time) * 1000)

    word_boxes = _simulated_word_boxes(text, confidence) if text else []
    char_confidences = _simulated_char_confidences(text, confidence) if text else []

    model_name = (
        "trocr-base-handwritten-onnx" if handwritten else "trocr-base-printed-onnx"
    )
    result = OCRResult(
        text=text or "",
        confidence=confidence,
        processing_time_ms=processing_time_ms,
        model=model_name,
        has_lora_adapter=False,
        char_confidences=char_confidences,
        word_boxes=word_boxes,
        from_cache=False,
        image_hash=image_hash,
    )

    # Cache only successful (non-empty) results so failures are retried.
    if use_cache and text:
        _cache_set(image_hash, {
            "text": result.text,
            "confidence": result.confidence,
            "model": result.model,
            "has_lora_adapter": result.has_lora_adapter,
            "char_confidences": result.char_confidences,
            "word_boxes": result.word_boxes,
        })
    return result