""" TrOCR ONNX Service ONNX-optimized inference for TrOCR text recognition. Uses optimum.onnxruntime.ORTModelForVision2Seq for hardware-accelerated inference without requiring PyTorch at runtime. Advantages over PyTorch backend: - 2-4x faster inference on CPU - Lower memory footprint (~300 MB vs ~600 MB) - No PyTorch/CUDA dependency at runtime - Apple Silicon (CoreML) and x86 (OpenVINO) acceleration Model paths searched (in order): 1. TROCR_ONNX_DIR environment variable 2. /root/.cache/huggingface/onnx/trocr-base-{printed,handwritten}/ (Docker) 3. models/onnx/trocr-base-{printed,handwritten}/ (local dev) DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. """ import io import os import logging import time import asyncio from pathlib import Path from typing import Tuple, Optional, List, Dict, Any from datetime import datetime logger = logging.getLogger(__name__) # Re-use shared types and cache from trocr_service from .trocr_service import ( OCRResult, _compute_image_hash, _cache_get, _cache_set, _split_into_lines, ) # --------------------------------------------------------------------------- # Module-level state # --------------------------------------------------------------------------- # {model_key: (processor, model)} — model_key = "printed" | "handwritten" _onnx_models: Dict[str, Any] = {} _onnx_available: Optional[bool] = None _onnx_model_loaded_at: Optional[datetime] = None # --------------------------------------------------------------------------- # Path resolution # --------------------------------------------------------------------------- _VARIANT_NAMES = { False: "trocr-base-printed", True: "trocr-base-handwritten", } # HuggingFace model IDs (used for processor downloads) _HF_MODEL_IDS = { False: "microsoft/trocr-base-printed", True: "microsoft/trocr-base-handwritten", } def _resolve_onnx_model_dir(handwritten: bool = False) -> Optional[Path]: """ Resolve the directory containing ONNX model files for the given variant. Search order: 1. TROCR_ONNX_DIR env var (appended with variant name) 2. /root/.cache/huggingface/onnx// (Docker) 3. models/onnx// (local dev, relative to this file) Returns the first directory that exists and contains at least one .onnx file, or None if no valid directory is found. """ variant = _VARIANT_NAMES[handwritten] candidates: List[Path] = [] # 1. Environment variable env_dir = os.environ.get("TROCR_ONNX_DIR") if env_dir: candidates.append(Path(env_dir) / variant) # Also allow the env var to point directly at a variant dir candidates.append(Path(env_dir)) # 2. Docker path candidates.append(Path(f"/root/.cache/huggingface/onnx/{variant}")) # 3. Local dev path (relative to klausur-service/backend/) backend_dir = Path(__file__).resolve().parent.parent candidates.append(backend_dir / "models" / "onnx" / variant) for candidate in candidates: if candidate.is_dir(): # Check for ONNX files or a model config (optimum stores config.json) onnx_files = list(candidate.glob("*.onnx")) has_config = (candidate / "config.json").exists() if onnx_files or has_config: logger.info(f"ONNX model directory resolved: {candidate}") return candidate return None # --------------------------------------------------------------------------- # Availability checks # --------------------------------------------------------------------------- def _check_onnx_runtime_available() -> bool: """Check if onnxruntime and optimum are importable.""" try: import onnxruntime # noqa: F401 from optimum.onnxruntime import ORTModelForVision2Seq # noqa: F401 from transformers import TrOCRProcessor # noqa: F401 return True except ImportError as e: logger.debug(f"ONNX runtime dependencies not available: {e}") return False def is_onnx_available(handwritten: bool = False) -> bool: """ Check whether ONNX inference is available for the given variant. Returns True only when: - onnxruntime + optimum are installed - A valid model directory with ONNX files exists """ if not _check_onnx_runtime_available(): return False return _resolve_onnx_model_dir(handwritten=handwritten) is not None # --------------------------------------------------------------------------- # Model loading # --------------------------------------------------------------------------- def _get_onnx_model(handwritten: bool = False): """ Lazy-load ONNX model and processor. Returns: Tuple of (processor, model) or (None, None) if unavailable. """ global _onnx_model_loaded_at model_key = "handwritten" if handwritten else "printed" if model_key in _onnx_models: return _onnx_models[model_key] model_dir = _resolve_onnx_model_dir(handwritten=handwritten) if model_dir is None: logger.warning( f"No ONNX model directory found for variant " f"{'handwritten' if handwritten else 'printed'}" ) return None, None if not _check_onnx_runtime_available(): logger.warning("ONNX runtime dependencies not installed") return None, None try: from optimum.onnxruntime import ORTModelForVision2Seq from transformers import TrOCRProcessor hf_id = _HF_MODEL_IDS[handwritten] logger.info(f"Loading ONNX TrOCR model from {model_dir} (processor: {hf_id})") t0 = time.monotonic() # Load processor from HuggingFace (tokenizer + feature extractor) processor = TrOCRProcessor.from_pretrained(hf_id) # Load ONNX model from local directory model = ORTModelForVision2Seq.from_pretrained(str(model_dir)) elapsed = time.monotonic() - t0 logger.info( f"ONNX TrOCR model loaded in {elapsed:.1f}s " f"(variant={model_key}, dir={model_dir})" ) _onnx_models[model_key] = (processor, model) _onnx_model_loaded_at = datetime.now() return processor, model except Exception as e: logger.error(f"Failed to load ONNX TrOCR model ({model_key}): {e}") import traceback logger.error(traceback.format_exc()) return None, None def preload_onnx_model(handwritten: bool = True) -> bool: """ Preload ONNX model at startup for faster first request. Call from FastAPI startup event: @app.on_event("startup") async def startup(): preload_onnx_model() """ logger.info(f"Preloading ONNX TrOCR model (handwritten={handwritten})...") processor, model = _get_onnx_model(handwritten=handwritten) if processor is not None and model is not None: logger.info("ONNX TrOCR model preloaded successfully") return True else: logger.warning("ONNX TrOCR model preloading failed") return False # --------------------------------------------------------------------------- # Status # --------------------------------------------------------------------------- def get_onnx_model_status() -> Dict[str, Any]: """Get current ONNX model status information.""" runtime_ok = _check_onnx_runtime_available() printed_dir = _resolve_onnx_model_dir(handwritten=False) handwritten_dir = _resolve_onnx_model_dir(handwritten=True) printed_loaded = "printed" in _onnx_models handwritten_loaded = "handwritten" in _onnx_models # Detect ONNX runtime providers providers = [] if runtime_ok: try: import onnxruntime providers = onnxruntime.get_available_providers() except Exception: pass return { "backend": "onnx", "runtime_available": runtime_ok, "providers": providers, "printed": { "model_dir": str(printed_dir) if printed_dir else None, "available": printed_dir is not None and runtime_ok, "loaded": printed_loaded, }, "handwritten": { "model_dir": str(handwritten_dir) if handwritten_dir else None, "available": handwritten_dir is not None and runtime_ok, "loaded": handwritten_loaded, }, "loaded_at": _onnx_model_loaded_at.isoformat() if _onnx_model_loaded_at else None, } # --------------------------------------------------------------------------- # Inference # --------------------------------------------------------------------------- async def run_trocr_onnx( image_data: bytes, handwritten: bool = False, split_lines: bool = True, ) -> Tuple[Optional[str], float]: """ Run TrOCR OCR using ONNX backend. Mirrors the interface of trocr_service.run_trocr_ocr. Args: image_data: Raw image bytes (PNG, JPEG, etc.) handwritten: Use handwritten model variant split_lines: Split image into text lines before recognition Returns: Tuple of (extracted_text, confidence). Returns (None, 0.0) on failure. """ processor, model = _get_onnx_model(handwritten=handwritten) if processor is None or model is None: logger.error("ONNX TrOCR model not available") return None, 0.0 try: from PIL import Image image = Image.open(io.BytesIO(image_data)).convert("RGB") if split_lines: lines = _split_into_lines(image) if not lines: lines = [image] else: lines = [image] all_text: List[str] = [] confidences: List[float] = [] for line_image in lines: # Prepare input — processor returns PyTorch tensors pixel_values = processor(images=line_image, return_tensors="pt").pixel_values # Generate via ONNX (ORTModelForVision2Seq.generate is compatible) generated_ids = model.generate(pixel_values, max_length=128) generated_text = processor.batch_decode( generated_ids, skip_special_tokens=True )[0] if generated_text.strip(): all_text.append(generated_text.strip()) confidences.append(0.85 if len(generated_text) > 3 else 0.5) text = "\n".join(all_text) confidence = sum(confidences) / len(confidences) if confidences else 0.0 logger.info( f"ONNX TrOCR extracted {len(text)} chars from {len(lines)} lines" ) return text, confidence except Exception as e: logger.error(f"ONNX TrOCR failed: {e}") import traceback logger.error(traceback.format_exc()) return None, 0.0 async def run_trocr_onnx_enhanced( image_data: bytes, handwritten: bool = True, split_lines: bool = True, use_cache: bool = True, ) -> OCRResult: """ Enhanced ONNX TrOCR with caching and detailed results. Mirrors the interface of trocr_service.run_trocr_ocr_enhanced. Args: image_data: Raw image bytes handwritten: Use handwritten model variant split_lines: Split image into text lines use_cache: Use SHA256-based in-memory cache Returns: OCRResult with text, confidence, timing, word boxes, etc. """ start_time = time.time() # Check cache first image_hash = _compute_image_hash(image_data) if use_cache: cached = _cache_get(image_hash) if cached: return OCRResult( text=cached["text"], confidence=cached["confidence"], processing_time_ms=0, model=cached["model"], has_lora_adapter=cached.get("has_lora_adapter", False), char_confidences=cached.get("char_confidences", []), word_boxes=cached.get("word_boxes", []), from_cache=True, image_hash=image_hash, ) # Run ONNX inference text, confidence = await run_trocr_onnx( image_data, handwritten=handwritten, split_lines=split_lines ) processing_time_ms = int((time.time() - start_time) * 1000) # Generate word boxes with simulated confidences word_boxes: List[Dict[str, Any]] = [] if text: words = text.split() for word in words: word_conf = min( 1.0, max(0.0, confidence + (hash(word) % 20 - 10) / 100) ) word_boxes.append({ "text": word, "confidence": word_conf, "bbox": [0, 0, 0, 0], }) # Generate character confidences char_confidences: List[float] = [] if text: for char in text: char_conf = min( 1.0, max(0.0, confidence + (hash(char) % 15 - 7) / 100) ) char_confidences.append(char_conf) model_name = ( "trocr-base-handwritten-onnx" if handwritten else "trocr-base-printed-onnx" ) result = OCRResult( text=text or "", confidence=confidence, processing_time_ms=processing_time_ms, model=model_name, has_lora_adapter=False, char_confidences=char_confidences, word_boxes=word_boxes, from_cache=False, image_hash=image_hash, ) # Cache result if use_cache and text: _cache_set(image_hash, { "text": result.text, "confidence": result.confidence, "model": result.model, "has_lora_adapter": result.has_lora_adapter, "char_confidences": result.char_confidences, "word_boxes": result.word_boxes, }) return result