Files
breakpilot-lehrer/klausur-service/backend/services/trocr_models.py
Benjamin Admin 34da9f4cda [split-required] Split 700-870 LOC files across all services
backend-lehrer (11 files):
- llm_gateway/routes/schools.py (867 → 5), recording_api.py (848 → 6)
- messenger_api.py (840 → 5), print_generator.py (824 → 5)
- unit_analytics_api.py (751 → 5), classroom/routes/context.py (726 → 4)
- llm_gateway/routes/edu_search_seeds.py (710 → 4)

klausur-service (12 files):
- ocr_labeling_api.py (845 → 4), metrics_db.py (833 → 4)
- legal_corpus_api.py (790 → 4), page_crop.py (758 → 3)
- mail/ai_service.py (747 → 4), github_crawler.py (767 → 3)
- trocr_service.py (730 → 4), full_compliance_pipeline.py (723 → 4)
- dsfa_rag_api.py (715 → 4), ocr_pipeline_auto.py (705 → 4)

website (6 pages):
- audit-checklist (867 → 8), content (806 → 6)
- screen-flow (790 → 4), scraper (789 → 5)
- zeugnisse (776 → 5), modules (745 → 4)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-25 08:01:18 +02:00

279 lines
8.4 KiB
Python

"""
TrOCR Models & Cache
Dataclasses, LRU cache, and model loading for TrOCR service.
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import io
import os
import hashlib
import logging
import time
from typing import Tuple, Optional, List, Dict, Any
from dataclasses import dataclass, field
from collections import OrderedDict
from datetime import datetime, timedelta
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Backend routing: auto | pytorch | onnx
# ---------------------------------------------------------------------------
# Raw TROCR_BACKEND environment value, read once at import time. It is not
# validated here; the intended values are "auto", "pytorch" and "onnx".
_trocr_backend: str = os.environ.get("TROCR_BACKEND", "auto")  # auto | pytorch | onnx
# Lazily populated model state: torch/transformers are only imported inside
# the functions that need them, not at module import time.
# Cache keyed by model_name so base and large variants can be held at the same time.
_trocr_models: dict = {}  # {model_name: (processor, model)}
_trocr_processor = None  # backwards-compat alias -> base-printed processor
_trocr_model = None  # backwards-compat alias -> base-printed model
# Memoized dependency check: None = not checked yet, afterwards True/False.
_trocr_available: Optional[bool] = None
# Timestamp of the last successful preload; stays None until a preload succeeds.
_model_loaded_at: Optional[datetime] = None
# Simple in-memory OCR result cache with LRU eviction, keyed by image hash.
# Each value has the shape {"result": <result dict>, "cached_at": datetime}.
# NOTE(review): subscripting collections.OrderedDict in a runtime-evaluated
# annotation requires Python 3.9+.
_ocr_cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
_cache_max_size = 100  # maximum number of cached entries
_cache_ttl_seconds = 3600  # entry lifetime: 1 hour
@dataclass
class OCRResult:
    """
    Enhanced OCR result with detailed information.

    Field order defines the positional constructor signature; keep it stable.
    """

    # Recognized text.
    text: str
    # Recognition confidence; presumably in [0, 1] -- confirm with the producer.
    confidence: float
    # Processing time in milliseconds.
    processing_time_ms: int
    # Identifier of the model that produced the result.
    model: str
    # Presumably True when a LoRA adapter was applied -- confirm with the producer.
    has_lora_adapter: bool = False
    # Optional per-character confidences; empty list when not provided.
    char_confidences: List[float] = field(default_factory=list)
    # Optional word-level boxes; the dict schema is defined by the producer.
    word_boxes: List[Dict[str, Any]] = field(default_factory=list)
    # Presumably True when the result was served from the OCR result cache.
    from_cache: bool = False
    # Presumably the cache key of the source image; empty string when unset.
    image_hash: str = ""
@dataclass
class BatchOCRResult:
    """
    Result of a batch OCR run: per-image results plus aggregate counters.

    Field order defines the positional constructor signature; keep it stable.
    """

    # Per-image results (order presumably matches the input batch -- confirm).
    results: List[OCRResult]
    # Total time for the whole batch, in milliseconds.
    total_time_ms: int
    # Presumably the number of images processed.
    processed_count: int
    # Presumably the number of results served from the cache.
    cached_count: int
    # Presumably the number of images that failed.
    error_count: int
def _compute_image_hash(image_data: bytes) -> str:
    """
    Derive a short cache key from raw image bytes.

    Returns the first 16 hex characters (64 bits) of the SHA-256 digest.
    """
    hasher = hashlib.sha256()
    hasher.update(image_data)
    full_hex = hasher.hexdigest()
    return full_hex[:16]
def _cache_get(image_hash: str) -> Optional[Dict[str, Any]]:
    """
    Look up a cached OCR result by image hash.

    Returns the stored result dict, or None when the key is absent or the
    entry is at least ``_cache_ttl_seconds`` old. A hit moves the entry to the
    most-recently-used end of the LRU order; an expired entry is removed.
    """
    if image_hash not in _ocr_cache:
        return None

    entry = _ocr_cache[image_hash]
    age = datetime.now() - entry["cached_at"]
    if age >= timedelta(seconds=_cache_ttl_seconds):
        # Stale: drop it so it no longer occupies a cache slot.
        del _ocr_cache[image_hash]
        return None

    # Fresh hit: refresh its LRU position.
    _ocr_cache.move_to_end(image_hash)
    return entry["result"]
def _cache_set(image_hash: str, result: Dict[str, Any]) -> None:
    """
    Store an OCR result in the LRU cache, stamped with the current time.

    Re-storing an existing key replaces it and moves it to the
    most-recently-used end; it never evicts an unrelated entry. When the cache
    is full, least-recently-used entries are evicted first. A non-positive
    ``_cache_max_size`` disables caching (nothing is stored).
    """
    # Remove any previous entry for this key first. A plain re-assignment would
    # keep its old LRU position, and counting it toward capacity would evict an
    # unrelated entry for nothing.
    _ocr_cache.pop(image_hash, None)
    if _cache_max_size <= 0:
        # Caching disabled; without this guard popitem() would raise on an empty dict.
        return
    while len(_ocr_cache) >= _cache_max_size:
        _ocr_cache.popitem(last=False)  # evict the least recently used entry
    _ocr_cache[image_hash] = {
        "result": result,
        "cached_at": datetime.now(),
    }
def get_cache_stats() -> Dict[str, Any]:
    """
    Summarize the OCR result cache.

    Returns:
        dict with the current entry count (``size``), the configured capacity
        (``max_size``), the entry TTL (``ttl_seconds``), and ``hit_rate``,
        which is always 0 because hits are not counted yet.
    """
    stats: Dict[str, Any] = dict(
        size=len(_ocr_cache),
        max_size=_cache_max_size,
        ttl_seconds=_cache_ttl_seconds,
    )
    # Placeholder until hit/miss counters are tracked.
    stats["hit_rate"] = 0
    return stats
def _check_trocr_available() -> bool:
    """
    Report whether torch and the transformers TrOCR classes can be imported.

    The import is attempted only once; the outcome is memoized in
    ``_trocr_available`` and returned directly on later calls. A failed import
    is logged as a warning.
    """
    global _trocr_available
    if _trocr_available is None:
        try:
            import torch  # noqa: F401  (availability probe only)
            from transformers import TrOCRProcessor, VisionEncoderDecoderModel  # noqa: F401
        except ImportError as e:
            logger.warning(f"TrOCR dependencies not available: {e}")
            _trocr_available = False
        else:
            _trocr_available = True
    return _trocr_available
def _resolve_trocr_model_name(handwritten: bool, size: str) -> str:
    """Map the (handwritten, size) selector to a Hugging Face model id."""
    if handwritten:
        if size == "large":
            return "microsoft/trocr-large-handwritten"
        return "microsoft/trocr-base-handwritten"
    # No large printed variant is wired up; ``size`` is ignored for printed text.
    return "microsoft/trocr-base-printed"


def _select_torch_device(torch_module) -> str:
    """Pick the best available torch device: CUDA, then Apple MPS, then CPU."""
    if torch_module.cuda.is_available():
        return "cuda"
    # ``torch.backends.mps`` does not exist on older torch builds; guard the
    # lookup so those installs fall back to CPU instead of failing the load.
    mps_backend = getattr(torch_module.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


def get_trocr_model(handwritten: bool = False, size: str = "base"):
    """
    Lazily load (and memoize) a TrOCR processor/model pair.

    Args:
        handwritten: Use the handwritten checkpoint instead of the printed one.
        size: "base" or "large". "large" selects
            microsoft/trocr-large-handwritten (a bigger checkpoint intended for
            higher accuracy on exam HTR) and only applies when ``handwritten``
            is True. Any other value, or a printed request, uses the base
            checkpoint.

    Returns:
        ``(processor, model)``, or ``(None, None)`` when the dependencies are
        missing or loading fails. Successful loads are cached in
        ``_trocr_models`` keyed by model id. Failures are not cached, so a
        later call retries the load.
    """
    global _trocr_processor, _trocr_model
    if not _check_trocr_available():
        return None, None

    model_name = _resolve_trocr_model_name(handwritten, size)
    if model_name in _trocr_models:
        return _trocr_models[model_name]

    try:
        import torch
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel

        logger.info(f"Loading TrOCR model: {model_name}")
        processor = TrOCRProcessor.from_pretrained(model_name)
        model = VisionEncoderDecoderModel.from_pretrained(model_name)

        # Use a GPU/accelerator if available, otherwise the CPU.
        device = _select_torch_device(torch)
        model.to(device)
        logger.info(f"TrOCR model loaded on device: {device}")
    except Exception as e:
        # logger.exception keeps the traceback, which is essential for
        # diagnosing download/driver failures.
        logger.exception(f"Failed to load TrOCR model {model_name}: {e}")
        return None, None

    _trocr_models[model_name] = (processor, model)
    # Keep backwards-compat globals pointing at base-printed.
    if model_name == "microsoft/trocr-base-printed":
        _trocr_processor = processor
        _trocr_model = model
    return processor, model
def preload_trocr_model(handwritten: bool = True) -> bool:
    """
    Eagerly load a TrOCR model so the first OCR request does not pay the load cost.

    Intended to be called once at application startup, e.g. from a FastAPI
    startup hook::

        @app.on_event("startup")
        async def startup():
            preload_trocr_model()

    Args:
        handwritten: Preload the handwritten checkpoint (default) instead of
            the printed one.

    Returns:
        True when the model is available afterwards (``_model_loaded_at`` is
        then set to the current time), False otherwise.
    """
    global _model_loaded_at
    logger.info("Preloading TrOCR model...")
    processor, model = get_trocr_model(handwritten=handwritten)
    if processor is None or model is None:
        logger.warning("TrOCR model preloading failed")
        return False
    _model_loaded_at = datetime.now()
    logger.info("TrOCR model preloaded successfully")
    return True
def get_model_status() -> Dict[str, Any]:
    """
    Report the status of the default handwritten TrOCR model.

    NOTE: this calls ``get_trocr_model(handwritten=True)``. If the base
    handwritten model is not loaded yet, asking for status triggers the
    (potentially slow) lazy load.

    Returns:
        dict with ``status`` ("available" / "not_installed"), ``is_loaded``,
        ``model_name``, ``loaded_at`` (ISO timestamp, or None if no preload
        timestamp has been recorded) and, when loaded, ``device``.
    """
    processor, model = get_trocr_model(handwritten=True)
    is_loaded = processor is not None and model is not None
    status: Dict[str, Any] = {
        # NOTE(review): "not_installed" is also reported when the dependencies
        # are installed but loading failed, because get_trocr_model returns
        # (None, None) in both cases.
        "status": "available" if is_loaded else "not_installed",
        "is_loaded": is_loaded,
        "model_name": "trocr-base-handwritten" if is_loaded else None,
        "loaded_at": _model_loaded_at.isoformat() if _model_loaded_at else None,
    }
    if is_loaded:
        # Device of the model's first parameter tensor.
        status["device"] = str(next(model.parameters()).device)
    return status
def get_active_backend() -> str:
    """
    Return which TrOCR backend is configured: "auto", "pytorch" or "onnx".

    The raw ``TROCR_BACKEND`` value is normalized (surrounding whitespace
    stripped, lower-cased), so e.g. "ONNX" selects "onnx". Unrecognized values
    fall back to "auto" and log a warning, so callers never see a backend name
    outside the documented set.
    """
    backend = str(_trocr_backend).strip().lower()
    if backend in ("auto", "pytorch", "onnx"):
        return backend
    logger.warning(
        f"Unknown TROCR_BACKEND value {_trocr_backend!r}; falling back to 'auto'"
    )
    return "auto"
def _split_into_lines(
    image,
    *,
    dark_threshold: int = 200,
    min_row_fraction: float = 0.02,
    padding: int = 5,
    min_line_height: int = 10,
) -> list:
    """
    Split an image into text-line crops using a horizontal projection profile.

    The image is converted to grayscale, pixels darker than a fixed threshold
    count as ink, and runs of rows whose ink count exceeds a fraction of the
    image width become lines. This is a basic heuristic; for production use,
    consider a dedicated line-detection model.

    Args:
        image: A PIL-style image (needs ``convert``, ``crop``, ``width`` and
            ``height``).
        dark_threshold: Gray values strictly below this count as ink.
        min_row_fraction: A row belongs to a line when its ink-pixel count
            exceeds ``width * min_row_fraction``.
        padding: Rows of context added above and below each detected line,
            clamped to the image bounds.
        min_line_height: Padded crops must be taller than this many pixels;
            shorter ones (noise, specks) are dropped. This applies to every
            line, including one that runs into the bottom edge.

    Returns:
        Full-width crops ordered top to bottom; an empty list if splitting
        fails.
    """
    import numpy as np

    try:
        gray = np.asarray(image.convert("L"))
        height, width = gray.shape
        # Horizontal projection: number of ink pixels in each row.
        row_ink = np.sum(gray < dark_threshold, axis=1)
        row_threshold = width * min_row_fraction

        lines = []

        def _add_line(first_row: int, stop_row: int) -> None:
            # Pad the [first_row, stop_row) run and keep it if tall enough.
            top = max(0, first_row - padding)
            bottom = min(height, stop_row + padding)
            if bottom - top > min_line_height:
                lines.append(image.crop((0, top, image.width, bottom)))

        line_start = None
        for row, ink in enumerate(row_ink):
            if ink > row_threshold:
                if line_start is None:
                    line_start = row
            elif line_start is not None:
                _add_line(line_start, row)
                line_start = None
        if line_start is not None:
            # Text runs into the bottom edge: close the final line there.
            _add_line(line_start, height)

        logger.info(f"Split image into {len(lines)} lines")
        return lines

    except Exception as e:
        # Deliberate best-effort: callers get an empty list instead of an error.
        logger.warning(f"Line splitting failed: {e}")
        return []