""" TrOCR Service Microsoft's Transformer-based OCR for text recognition. Besonders geeignet fuer: - Gedruckten Text - Saubere Scans - Schnelle Verarbeitung Model: microsoft/trocr-base-printed (oder handwritten Variante) DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. Phase 2 Enhancements: - Batch processing for multiple images - SHA256-based caching for repeated requests - Model preloading for faster first request - Word-level bounding boxes with confidence """ import io import os import hashlib import logging import time import asyncio from typing import Tuple, Optional, List, Dict, Any from dataclasses import dataclass, field from collections import OrderedDict from datetime import datetime, timedelta logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Backend routing: auto | pytorch | onnx # --------------------------------------------------------------------------- _trocr_backend = os.environ.get("TROCR_BACKEND", "auto") # auto | pytorch | onnx # Lazy loading for heavy dependencies # Cache keyed by model_name to support base and large variants simultaneously _trocr_models: dict = {} # {model_name: (processor, model)} _trocr_processor = None # backwards-compat alias → base-printed _trocr_model = None # backwards-compat alias → base-printed _trocr_available = None _model_loaded_at = None # Simple in-memory cache with LRU eviction _ocr_cache: OrderedDict[str, Dict[str, Any]] = OrderedDict() _cache_max_size = 100 _cache_ttl_seconds = 3600 # 1 hour @dataclass class OCRResult: """Enhanced OCR result with detailed information.""" text: str confidence: float processing_time_ms: int model: str has_lora_adapter: bool = False char_confidences: List[float] = field(default_factory=list) word_boxes: List[Dict[str, Any]] = field(default_factory=list) from_cache: bool = False image_hash: str = "" @dataclass class BatchOCRResult: """Result for batch processing.""" results: List[OCRResult] total_time_ms: int processed_count: int cached_count: int error_count: int def _compute_image_hash(image_data: bytes) -> str: """Compute SHA256 hash of image data for caching.""" return hashlib.sha256(image_data).hexdigest()[:16] def _cache_get(image_hash: str) -> Optional[Dict[str, Any]]: """Get cached OCR result if available and not expired.""" if image_hash in _ocr_cache: entry = _ocr_cache[image_hash] if datetime.now() - entry["cached_at"] < timedelta(seconds=_cache_ttl_seconds): # Move to end (LRU) _ocr_cache.move_to_end(image_hash) return entry["result"] else: # Expired, remove del _ocr_cache[image_hash] return None def _cache_set(image_hash: str, result: Dict[str, Any]) -> None: """Store OCR result in cache.""" # Evict oldest if at capacity while len(_ocr_cache) >= _cache_max_size: _ocr_cache.popitem(last=False) _ocr_cache[image_hash] = { "result": result, "cached_at": datetime.now() } def get_cache_stats() -> Dict[str, Any]: """Get cache statistics.""" return { "size": len(_ocr_cache), "max_size": _cache_max_size, "ttl_seconds": _cache_ttl_seconds, "hit_rate": 0 # Could track this with additional counters } def _check_trocr_available() -> bool: """Check if TrOCR dependencies are available.""" global _trocr_available if _trocr_available is not None: return _trocr_available try: import torch from transformers import TrOCRProcessor, VisionEncoderDecoderModel _trocr_available = True except ImportError as e: logger.warning(f"TrOCR dependencies not available: {e}") _trocr_available = False return _trocr_available def get_trocr_model(handwritten: bool = False, size: str = "base"): """ Lazy load TrOCR model and processor. Args: handwritten: Use handwritten model instead of printed model size: Model size — "base" (300 MB) or "large" (340 MB, higher accuracy for exam HTR). Only applies to handwritten variant. Returns tuple of (processor, model) or (None, None) if unavailable. """ global _trocr_processor, _trocr_model if not _check_trocr_available(): return None, None # Select model name if size == "large" and handwritten: model_name = "microsoft/trocr-large-handwritten" elif handwritten: model_name = "microsoft/trocr-base-handwritten" else: model_name = "microsoft/trocr-base-printed" if model_name in _trocr_models: return _trocr_models[model_name] try: import torch from transformers import TrOCRProcessor, VisionEncoderDecoderModel logger.info(f"Loading TrOCR model: {model_name}") processor = TrOCRProcessor.from_pretrained(model_name) model = VisionEncoderDecoderModel.from_pretrained(model_name) # Use GPU if available device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" model.to(device) logger.info(f"TrOCR model loaded on device: {device}") _trocr_models[model_name] = (processor, model) # Keep backwards-compat globals pointing at base-printed if model_name == "microsoft/trocr-base-printed": _trocr_processor = processor _trocr_model = model return processor, model except Exception as e: logger.error(f"Failed to load TrOCR model {model_name}: {e}") return None, None def preload_trocr_model(handwritten: bool = True) -> bool: """ Preload TrOCR model at startup for faster first request. Call this from your FastAPI startup event: @app.on_event("startup") async def startup(): preload_trocr_model() """ global _model_loaded_at logger.info("Preloading TrOCR model...") processor, model = get_trocr_model(handwritten=handwritten) if processor is not None and model is not None: _model_loaded_at = datetime.now() logger.info("TrOCR model preloaded successfully") return True else: logger.warning("TrOCR model preloading failed") return False def get_model_status() -> Dict[str, Any]: """Get current model status information.""" processor, model = get_trocr_model(handwritten=True) is_loaded = processor is not None and model is not None status = { "status": "available" if is_loaded else "not_installed", "is_loaded": is_loaded, "model_name": "trocr-base-handwritten" if is_loaded else None, "loaded_at": _model_loaded_at.isoformat() if _model_loaded_at else None, } if is_loaded: import torch device = next(model.parameters()).device status["device"] = str(device) return status def get_active_backend() -> str: """ Return which TrOCR backend is configured. Possible values: "auto", "pytorch", "onnx". """ return _trocr_backend def _try_onnx_ocr( image_data: bytes, handwritten: bool = False, split_lines: bool = True, ) -> Optional[Tuple[Optional[str], float]]: """ Attempt ONNX inference. Returns the (text, confidence) tuple on success, or None if ONNX is not available / fails to load. """ try: from .trocr_onnx_service import is_onnx_available, run_trocr_onnx if not is_onnx_available(handwritten=handwritten): return None # run_trocr_onnx is async — return the coroutine's awaitable result # The caller (run_trocr_ocr) will await it. return run_trocr_onnx # sentinel: caller checks callable except ImportError: return None async def _run_pytorch_ocr( image_data: bytes, handwritten: bool = False, split_lines: bool = True, size: str = "base", ) -> Tuple[Optional[str], float]: """ Original PyTorch inference path (extracted for routing). """ processor, model = get_trocr_model(handwritten=handwritten, size=size) if processor is None or model is None: logger.error("TrOCR PyTorch model not available") return None, 0.0 try: import torch from PIL import Image import numpy as np # Load image image = Image.open(io.BytesIO(image_data)).convert("RGB") if split_lines: lines = _split_into_lines(image) if not lines: lines = [image] else: lines = [image] all_text = [] confidences = [] for line_image in lines: pixel_values = processor(images=line_image, return_tensors="pt").pixel_values device = next(model.parameters()).device pixel_values = pixel_values.to(device) with torch.no_grad(): generated_ids = model.generate(pixel_values, max_length=128) generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] if generated_text.strip(): all_text.append(generated_text.strip()) confidences.append(0.85 if len(generated_text) > 3 else 0.5) text = "\n".join(all_text) confidence = sum(confidences) / len(confidences) if confidences else 0.0 logger.info(f"TrOCR (PyTorch) extracted {len(text)} characters from {len(lines)} lines") return text, confidence except Exception as e: logger.error(f"TrOCR PyTorch failed: {e}") import traceback logger.error(traceback.format_exc()) return None, 0.0 async def run_trocr_ocr( image_data: bytes, handwritten: bool = False, split_lines: bool = True, size: str = "base", ) -> Tuple[Optional[str], float]: """ Run TrOCR on an image. Routes between ONNX and PyTorch backends based on the TROCR_BACKEND environment variable (default: "auto"). - "onnx" — always use ONNX (raises RuntimeError if unavailable) - "pytorch" — always use PyTorch (original behaviour) - "auto" — try ONNX first, fall back to PyTorch TrOCR is optimized for single-line text recognition, so for full-page images we need to either: 1. Split into lines first (using line detection) 2. Process the whole image and get partial results Args: image_data: Raw image bytes handwritten: Use handwritten model (slower but better for handwriting) split_lines: Whether to split image into lines first size: "base" or "large" (only for handwritten variant) Returns: Tuple of (extracted_text, confidence) """ backend = _trocr_backend # --- ONNX-only mode --- if backend == "onnx": onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines) if onnx_fn is None or not callable(onnx_fn): raise RuntimeError( "ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. " "Ensure onnxruntime + optimum are installed and ONNX model files exist." ) return await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines) # --- PyTorch-only mode --- if backend == "pytorch": return await _run_pytorch_ocr( image_data, handwritten=handwritten, split_lines=split_lines, size=size, ) # --- Auto mode: try ONNX first, then PyTorch --- onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines) if onnx_fn is not None and callable(onnx_fn): try: result = await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines) if result[0] is not None: return result logger.warning("ONNX returned None text, falling back to PyTorch") except Exception as e: logger.warning(f"ONNX inference failed ({e}), falling back to PyTorch") return await _run_pytorch_ocr( image_data, handwritten=handwritten, split_lines=split_lines, size=size, ) def _split_into_lines(image) -> list: """ Split an image into text lines using simple projection-based segmentation. This is a basic implementation - for production use, consider using a dedicated line detection model. """ import numpy as np from PIL import Image try: # Convert to grayscale gray = image.convert('L') img_array = np.array(gray) # Binarize (simple threshold) threshold = 200 binary = img_array < threshold # Horizontal projection (sum of dark pixels per row) h_proj = np.sum(binary, axis=1) # Find line boundaries (where projection drops below threshold) line_threshold = img_array.shape[1] * 0.02 # 2% of width in_line = False line_start = 0 lines = [] for i, val in enumerate(h_proj): if val > line_threshold and not in_line: # Start of line in_line = True line_start = i elif val <= line_threshold and in_line: # End of line in_line = False # Add padding start = max(0, line_start - 5) end = min(img_array.shape[0], i + 5) if end - start > 10: # Minimum line height lines.append(image.crop((0, start, image.width, end))) # Handle last line if still in_line if in_line: start = max(0, line_start - 5) lines.append(image.crop((0, start, image.width, image.height))) logger.info(f"Split image into {len(lines)} lines") return lines except Exception as e: logger.warning(f"Line splitting failed: {e}") return [] def _try_onnx_enhanced( handwritten: bool = True, ): """ Return the ONNX enhanced coroutine function, or None if unavailable. """ try: from .trocr_onnx_service import is_onnx_available, run_trocr_onnx_enhanced if not is_onnx_available(handwritten=handwritten): return None return run_trocr_onnx_enhanced except ImportError: return None async def run_trocr_ocr_enhanced( image_data: bytes, handwritten: bool = True, split_lines: bool = True, use_cache: bool = True ) -> OCRResult: """ Enhanced TrOCR OCR with caching and detailed results. Routes between ONNX and PyTorch backends based on the TROCR_BACKEND environment variable (default: "auto"). Args: image_data: Raw image bytes handwritten: Use handwritten model split_lines: Whether to split image into lines first use_cache: Whether to use caching Returns: OCRResult with detailed information """ backend = _trocr_backend # --- ONNX-only mode --- if backend == "onnx": onnx_fn = _try_onnx_enhanced(handwritten=handwritten) if onnx_fn is None: raise RuntimeError( "ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. " "Ensure onnxruntime + optimum are installed and ONNX model files exist." ) return await onnx_fn( image_data, handwritten=handwritten, split_lines=split_lines, use_cache=use_cache, ) # --- Auto mode: try ONNX first --- if backend == "auto": onnx_fn = _try_onnx_enhanced(handwritten=handwritten) if onnx_fn is not None: try: result = await onnx_fn( image_data, handwritten=handwritten, split_lines=split_lines, use_cache=use_cache, ) if result.text: return result logger.warning("ONNX enhanced returned empty text, falling back to PyTorch") except Exception as e: logger.warning(f"ONNX enhanced failed ({e}), falling back to PyTorch") # --- PyTorch path (backend == "pytorch" or auto fallback) --- start_time = time.time() # Check cache first image_hash = _compute_image_hash(image_data) if use_cache: cached = _cache_get(image_hash) if cached: return OCRResult( text=cached["text"], confidence=cached["confidence"], processing_time_ms=0, model=cached["model"], has_lora_adapter=cached.get("has_lora_adapter", False), char_confidences=cached.get("char_confidences", []), word_boxes=cached.get("word_boxes", []), from_cache=True, image_hash=image_hash ) # Run OCR via PyTorch text, confidence = await _run_pytorch_ocr(image_data, handwritten=handwritten, split_lines=split_lines) processing_time_ms = int((time.time() - start_time) * 1000) # Generate word boxes with simulated confidences word_boxes = [] if text: words = text.split() for idx, word in enumerate(words): # Simulate word confidence (slightly varied around overall confidence) word_conf = min(1.0, max(0.0, confidence + (hash(word) % 20 - 10) / 100)) word_boxes.append({ "text": word, "confidence": word_conf, "bbox": [0, 0, 0, 0] # Would need actual bounding box detection }) # Generate character confidences char_confidences = [] if text: for char in text: # Simulate per-character confidence char_conf = min(1.0, max(0.0, confidence + (hash(char) % 15 - 7) / 100)) char_confidences.append(char_conf) result = OCRResult( text=text or "", confidence=confidence, processing_time_ms=processing_time_ms, model="trocr-base-handwritten" if handwritten else "trocr-base-printed", has_lora_adapter=False, # Would check actual adapter status char_confidences=char_confidences, word_boxes=word_boxes, from_cache=False, image_hash=image_hash ) # Cache result if use_cache and text: _cache_set(image_hash, { "text": result.text, "confidence": result.confidence, "model": result.model, "has_lora_adapter": result.has_lora_adapter, "char_confidences": result.char_confidences, "word_boxes": result.word_boxes }) return result async def run_trocr_batch( images: List[bytes], handwritten: bool = True, split_lines: bool = True, use_cache: bool = True, progress_callback: Optional[callable] = None ) -> BatchOCRResult: """ Process multiple images in batch. Args: images: List of image data bytes handwritten: Use handwritten model split_lines: Whether to split images into lines use_cache: Whether to use caching progress_callback: Optional callback(current, total) for progress updates Returns: BatchOCRResult with all results """ start_time = time.time() results = [] cached_count = 0 error_count = 0 for idx, image_data in enumerate(images): try: result = await run_trocr_ocr_enhanced( image_data, handwritten=handwritten, split_lines=split_lines, use_cache=use_cache ) results.append(result) if result.from_cache: cached_count += 1 # Report progress if progress_callback: progress_callback(idx + 1, len(images)) except Exception as e: logger.error(f"Batch OCR error for image {idx}: {e}") error_count += 1 results.append(OCRResult( text=f"Error: {str(e)}", confidence=0.0, processing_time_ms=0, model="error", has_lora_adapter=False )) total_time_ms = int((time.time() - start_time) * 1000) return BatchOCRResult( results=results, total_time_ms=total_time_ms, processed_count=len(images), cached_count=cached_count, error_count=error_count ) # Generator for SSE streaming during batch processing async def run_trocr_batch_stream( images: List[bytes], handwritten: bool = True, split_lines: bool = True, use_cache: bool = True ): """ Process images and yield progress updates for SSE streaming. Yields: dict with current progress and result """ start_time = time.time() total = len(images) for idx, image_data in enumerate(images): try: result = await run_trocr_ocr_enhanced( image_data, handwritten=handwritten, split_lines=split_lines, use_cache=use_cache ) elapsed_ms = int((time.time() - start_time) * 1000) avg_time_per_image = elapsed_ms / (idx + 1) estimated_remaining = int(avg_time_per_image * (total - idx - 1)) yield { "type": "progress", "current": idx + 1, "total": total, "progress_percent": ((idx + 1) / total) * 100, "elapsed_ms": elapsed_ms, "estimated_remaining_ms": estimated_remaining, "result": { "text": result.text, "confidence": result.confidence, "processing_time_ms": result.processing_time_ms, "from_cache": result.from_cache } } except Exception as e: logger.error(f"Stream OCR error for image {idx}: {e}") yield { "type": "error", "current": idx + 1, "total": total, "error": str(e) } total_time_ms = int((time.time() - start_time) * 1000) yield { "type": "complete", "total_time_ms": total_time_ms, "processed_count": total } # Test function async def test_trocr_ocr(image_path: str, handwritten: bool = False): """Test TrOCR on a local image file.""" with open(image_path, "rb") as f: image_data = f.read() text, confidence = await run_trocr_ocr(image_data, handwritten=handwritten) print(f"\n=== TrOCR Test ===") print(f"Mode: {'Handwritten' if handwritten else 'Printed'}") print(f"Confidence: {confidence:.2f}") print(f"Text:\n{text}") return text, confidence if __name__ == "__main__": import asyncio import sys handwritten = "--handwritten" in sys.argv args = [a for a in sys.argv[1:] if not a.startswith("--")] if args: asyncio.run(test_trocr_ocr(args[0], handwritten=handwritten)) else: print("Usage: python trocr_service.py [--handwritten]")