"""
TrOCR Models & Cache

Dataclasses, LRU cache, and model loading for TrOCR service.

PRIVACY: All processing happens locally.
"""

import io
import os
import hashlib
import logging
import time
from typing import Tuple, Optional, List, Dict, Any
from dataclasses import dataclass, field
from collections import OrderedDict
from datetime import datetime, timedelta

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Backend routing: auto | pytorch | onnx
# ---------------------------------------------------------------------------
_trocr_backend = os.environ.get("TROCR_BACKEND", "auto")  # auto | pytorch | onnx

# Lazy loading for heavy dependencies
# Cache keyed by model_name to support base and large variants simultaneously
_trocr_models: dict = {}  # {model_name: (processor, model)}
_trocr_processor = None   # backwards-compat alias -> base-printed
_trocr_model = None       # backwards-compat alias -> base-printed
_trocr_available = None   # tri-state: None = not probed yet, else True/False
_model_loaded_at = None   # set by preload_trocr_model() on success

# Simple in-memory cache with LRU eviction.
# Each value is {"result": <dict>, "cached_at": <datetime>}.
_ocr_cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
_cache_max_size = 100
_cache_ttl_seconds = 3600  # 1 hour

# Lookup counters backing the "hit_rate" figure in get_cache_stats().
_cache_hits = 0
_cache_misses = 0


@dataclass
class OCRResult:
    """Enhanced OCR result with detailed information."""

    text: str
    confidence: float
    processing_time_ms: int
    model: str
    has_lora_adapter: bool = False
    char_confidences: List[float] = field(default_factory=list)
    word_boxes: List[Dict[str, Any]] = field(default_factory=list)
    from_cache: bool = False
    image_hash: str = ""


@dataclass
class BatchOCRResult:
    """Result for batch processing."""

    results: List[OCRResult]
    total_time_ms: int
    processed_count: int
    cached_count: int
    error_count: int


def _compute_image_hash(image_data: bytes) -> str:
    """Return the first 16 hex characters of the SHA256 digest of *image_data*.

    The truncated digest is used as the cache key for OCR results.
    """
    return hashlib.sha256(image_data).hexdigest()[:16]


def _cache_get(image_hash: str) -> Optional[Dict[str, Any]]:
    """Get cached OCR result if available and not expired.

    A fresh entry is moved to the most-recently-used position and counted as
    a hit. An expired entry is deleted; it and an absent key both count as a
    miss.

    Args:
        image_hash: Cache key, as produced by ``_compute_image_hash``.

    Returns:
        The stored result dict, or ``None`` on a miss.
    """
    global _cache_hits, _cache_misses

    entry = _ocr_cache.get(image_hash)
    if entry is not None:
        age = datetime.now() - entry["cached_at"]
        if age < timedelta(seconds=_cache_ttl_seconds):
            # Move to end (LRU)
            _ocr_cache.move_to_end(image_hash)
            _cache_hits += 1
            return entry["result"]
        # Expired, remove
        del _ocr_cache[image_hash]

    _cache_misses += 1
    return None


def _cache_set(image_hash: str, result: Dict[str, Any]) -> None:
    """Store OCR result in cache, evicting least-recently-used entries.

    Args:
        image_hash: Cache key, as produced by ``_compute_image_hash``.
        result: Result dict to store; it is kept by reference, not copied.
    """
    if _cache_max_size <= 0:
        # Nothing can be stored; also avoids popitem() on an empty cache.
        return

    # Drop any stale entry for this key first. Otherwise an update would
    # (a) evict an unrelated entry although the size does not grow and
    # (b) keep the key at its old LRU position, because assigning to an
    # existing OrderedDict key does not reorder it.
    _ocr_cache.pop(image_hash, None)

    # Evict oldest if at capacity
    while len(_ocr_cache) >= _cache_max_size:
        _ocr_cache.popitem(last=False)

    _ocr_cache[image_hash] = {
        "result": result,
        "cached_at": datetime.now(),
    }


def get_cache_stats() -> Dict[str, Any]:
    """Get cache statistics.

    Returns:
        Dict with ``size``, ``max_size``, ``ttl_seconds`` and ``hit_rate``.
        ``hit_rate`` is hits / (hits + misses) over all ``_cache_get`` calls
        in this process, or ``0`` if no lookup has happened yet.
    """
    lookups = _cache_hits + _cache_misses
    return {
        "size": len(_ocr_cache),
        "max_size": _cache_max_size,
        "ttl_seconds": _cache_ttl_seconds,
        "hit_rate": _cache_hits / lookups if lookups else 0,
    }


def _check_trocr_available() -> bool:
    """Check if TrOCR dependencies are available.

    The import probe runs once; its outcome is memoized in
    ``_trocr_available`` and returned on later calls.
    """
    global _trocr_available
    if _trocr_available is not None:
        return _trocr_available

    try:
        import torch
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel
        _trocr_available = True
    except ImportError as e:
        logger.warning(f"TrOCR dependencies not available: {e}")
        _trocr_available = False

    return _trocr_available


def get_trocr_model(handwritten: bool = False, size: str = "base"):
    """
    Lazy load TrOCR model and processor.

    Loaded pairs are kept in ``_trocr_models`` keyed by model name, so each
    variant is loaded at most once per process.

    Args:
        handwritten: Use handwritten model instead of printed model
        size: Model size -- "base" (300 MB) or "large" (340 MB, higher accuracy
              for exam HTR). Only applies to handwritten variant.

    Returns tuple of (processor, model) or (None, None) if unavailable.
    Load failures are logged and reported as (None, None), not raised.
    """
    global _trocr_processor, _trocr_model

    if not _check_trocr_available():
        return None, None

    # Select model name
    if size == "large" and handwritten:
        model_name = "microsoft/trocr-large-handwritten"
    elif handwritten:
        model_name = "microsoft/trocr-base-handwritten"
    else:
        model_name = "microsoft/trocr-base-printed"

    if model_name in _trocr_models:
        return _trocr_models[model_name]

    try:
        import torch
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel

        logger.info(f"Loading TrOCR model: {model_name}")
        processor = TrOCRProcessor.from_pretrained(model_name)
        model = VisionEncoderDecoderModel.from_pretrained(model_name)

        # Use GPU if available. torch.backends.mps is looked up defensively:
        # if the attribute is missing, an unguarded access would raise
        # AttributeError and turn a CPU-capable setup into a load failure.
        if torch.cuda.is_available():
            device = "cuda"
        else:
            mps_backend = getattr(torch.backends, "mps", None)
            if mps_backend is not None and mps_backend.is_available():
                device = "mps"
            else:
                device = "cpu"
        model.to(device)
        logger.info(f"TrOCR model loaded on device: {device}")

        _trocr_models[model_name] = (processor, model)

        # Keep backwards-compat globals pointing at base-printed
        if model_name == "microsoft/trocr-base-printed":
            _trocr_processor = processor
            _trocr_model = model

        return processor, model

    except Exception as e:
        logger.error(f"Failed to load TrOCR model {model_name}: {e}")
        return None, None


def preload_trocr_model(handwritten: bool = True) -> bool:
    """
    Preload TrOCR model at startup for faster first request.

    Call this from your FastAPI startup event:

        @app.on_event("startup")
        async def startup():
            preload_trocr_model()

    Args:
        handwritten: Passed through to ``get_trocr_model``; the base size is used.

    Returns:
        True if both processor and model were obtained, False otherwise.
    """
    global _model_loaded_at

    logger.info("Preloading TrOCR model...")
    processor, model = get_trocr_model(handwritten=handwritten)

    if processor is not None and model is not None:
        _model_loaded_at = datetime.now()
        logger.info("TrOCR model preloaded successfully")
        return True

    logger.warning("TrOCR model preloading failed")
    return False


def get_model_status() -> Dict[str, Any]:
    """Get current model status information.

    NOTE: this calls ``get_trocr_model(handwritten=True)``, so a status query
    loads the base handwritten model if it is not cached yet.

    Returns:
        Dict with ``status``, ``is_loaded``, ``model_name`` and ``loaded_at``,
        plus ``device`` when the model is loaded. ``status`` is
        "not_installed" whenever the model could not be obtained, whether the
        dependencies are missing or the load itself failed.
    """
    processor, model = get_trocr_model(handwritten=True)
    is_loaded = processor is not None and model is not None

    status = {
        "status": "available" if is_loaded else "not_installed",
        "is_loaded": is_loaded,
        "model_name": "trocr-base-handwritten" if is_loaded else None,
        "loaded_at": _model_loaded_at.isoformat() if _model_loaded_at else None,
    }

    if is_loaded:
        # Report the device of the first parameter.
        device = next(model.parameters()).device
        status["device"] = str(device)

    return status


def get_active_backend() -> str:
    """
    Return which TrOCR backend is configured.

    Possible values: "auto", "pytorch", "onnx".
    The value is read from the TROCR_BACKEND environment variable at import
    time and is returned as-is, without validation.
    """
    return _trocr_backend


def _split_into_lines(image) -> list:
    """
    Split an image into text lines using simple projection-based segmentation.

    This is a basic implementation - for production use, consider using
    a dedicated line detection model.

    Args:
        image: Image object offering ``convert``, ``crop``, ``width`` and
            ``height`` (presumably a PIL image -- confirm against callers).

    Returns:
        List of cropped full-width line images, top to bottom. An empty list
        is returned when no line is found or when any step raises.
    """
    import numpy as np
    from PIL import Image

    try:
        # Convert to grayscale
        gray = image.convert('L')
        img_array = np.array(gray)

        # Binarize (simple threshold): True where the pixel is dark
        threshold = 200
        binary = img_array < threshold

        # Horizontal projection (sum of dark pixels per row)
        h_proj = np.sum(binary, axis=1)

        # Find line boundaries (where projection drops below threshold)
        line_threshold = img_array.shape[1] * 0.02  # 2% of width
        in_line = False
        line_start = 0
        lines = []

        for i, val in enumerate(h_proj):
            if val > line_threshold and not in_line:
                # Start of line
                in_line = True
                line_start = i
            elif val <= line_threshold and in_line:
                # End of line
                in_line = False
                # Add padding
                start = max(0, line_start - 5)
                end = min(img_array.shape[0], i + 5)
                if end - start > 10:  # Minimum line height
                    lines.append(image.crop((0, start, image.width, end)))

        # Handle last line if still in_line.
        # NOTE: no minimum-height filter is applied to this trailing line.
        if in_line:
            start = max(0, line_start - 5)
            lines.append(image.crop((0, start, image.width, image.height)))

        logger.info(f"Split image into {len(lines)} lines")
        return lines

    except Exception as e:
        logger.warning(f"Line splitting failed: {e}")
        return []