backend-lehrer (11 files): - llm_gateway/routes/schools.py (867 → 5), recording_api.py (848 → 6) - messenger_api.py (840 → 5), print_generator.py (824 → 5) - unit_analytics_api.py (751 → 5), classroom/routes/context.py (726 → 4) - llm_gateway/routes/edu_search_seeds.py (710 → 4) klausur-service (12 files): - ocr_labeling_api.py (845 → 4), metrics_db.py (833 → 4) - legal_corpus_api.py (790 → 4), page_crop.py (758 → 3) - mail/ai_service.py (747 → 4), github_crawler.py (767 → 3) - trocr_service.py (730 → 4), full_compliance_pipeline.py (723 → 4) - dsfa_rag_api.py (715 → 4), ocr_pipeline_auto.py (705 → 4) website (6 pages): - audit-checklist (867 → 8), content (806 → 6) - screen-flow (790 → 4), scraper (789 → 5) - zeugnisse (776 → 5), modules (745 → 4) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
279 lines
8.4 KiB
Python
279 lines
8.4 KiB
Python
"""
TrOCR Models & Cache

Dataclasses, LRU cache, and model loading for TrOCR service.

DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
|
|
|
|
import io
|
|
import os
|
|
import hashlib
|
|
import logging
|
|
import time
|
|
from typing import Tuple, Optional, List, Dict, Any
|
|
from dataclasses import dataclass, field
|
|
from collections import OrderedDict
|
|
from datetime import datetime, timedelta
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# ---------------------------------------------------------------------------
# Backend routing: auto | pytorch | onnx
# ---------------------------------------------------------------------------
# Read once at import time from the environment; changing TROCR_BACKEND at
# runtime has no effect.
_trocr_backend = os.environ.get("TROCR_BACKEND", "auto")  # auto | pytorch | onnx

# Lazy loading for heavy dependencies.
# Cache keyed by model_name to support base and large variants simultaneously.
_trocr_models: dict = {}  # {model_name: (processor, model)}
_trocr_processor = None  # backwards-compat alias -> base-printed
_trocr_model = None  # backwards-compat alias -> base-printed
_trocr_available = None  # tri-state: None = not checked yet, else True/False
_model_loaded_at = None  # datetime of last successful preload, or None

# Simple in-memory cache with LRU eviction (OrderedDict keeps insertion/use
# order; oldest entries sit at the front and are popped first).
_ocr_cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
_cache_max_size = 100  # max entries before LRU eviction kicks in
_cache_ttl_seconds = 3600  # 1 hour
|
|
|
|
|
|
@dataclass
class OCRResult:
    """Enhanced OCR result with detailed information."""
    # Recognized text.
    text: str
    # Overall confidence score — presumably in [0, 1], derived by the
    # producing pipeline; TODO confirm the exact scale against callers.
    confidence: float
    # Wall-clock processing time in milliseconds.
    processing_time_ms: int
    # Identifier of the model that produced this result.
    model: str
    # Whether a LoRA adapter was in play for this result (set by the caller).
    has_lora_adapter: bool = False
    # Optional per-character confidence scores (empty when unavailable).
    char_confidences: List[float] = field(default_factory=list)
    # Optional word bounding-box dicts (empty when unavailable).
    word_boxes: List[Dict[str, Any]] = field(default_factory=list)
    # True when this result was served from the in-memory OCR cache.
    from_cache: bool = False
    # Truncated SHA256 of the source image bytes (see _compute_image_hash).
    image_hash: str = ""
|
|
|
|
|
|
@dataclass
class BatchOCRResult:
    """Result for batch processing."""
    # Individual OCR results for the batch.
    results: List[OCRResult]
    # Total wall-clock time for the whole batch, in milliseconds.
    total_time_ms: int
    # Number of images processed (as counted by the producer).
    processed_count: int
    # Number of results served from the cache.
    cached_count: int
    # Number of images that failed to process.
    error_count: int
|
|
|
|
|
|
def _compute_image_hash(image_data: bytes) -> str:
|
|
"""Compute SHA256 hash of image data for caching."""
|
|
return hashlib.sha256(image_data).hexdigest()[:16]
|
|
|
|
|
|
def _cache_get(image_hash: str) -> Optional[Dict[str, Any]]:
    """Return the cached OCR result for *image_hash*, or None.

    A hit refreshes the entry's LRU position; an entry older than the
    configured TTL is evicted on access and treated as a miss.
    """
    entry = _ocr_cache.get(image_hash)
    if entry is None:
        return None

    age = datetime.now() - entry["cached_at"]
    if age >= timedelta(seconds=_cache_ttl_seconds):
        # Expired — evict lazily on lookup.
        del _ocr_cache[image_hash]
        return None

    # Fresh hit: bump to the most-recently-used end.
    _ocr_cache.move_to_end(image_hash)
    return entry["result"]
|
|
|
|
|
|
def _cache_set(image_hash: str, result: Dict[str, Any]) -> None:
    """Store an OCR result in the cache, evicting oldest entries if full."""
    # Drop least-recently-used entries until there is room for one more.
    excess = len(_ocr_cache) - _cache_max_size + 1
    for _ in range(excess):  # range over a non-positive count is a no-op
        _ocr_cache.popitem(last=False)

    _ocr_cache[image_hash] = {"result": result, "cached_at": datetime.now()}
|
|
|
|
|
|
def get_cache_stats() -> Dict[str, Any]:
    """Return size/capacity/TTL statistics for the in-memory OCR cache."""
    stats: Dict[str, Any] = {
        "size": len(_ocr_cache),
        "max_size": _cache_max_size,
        "ttl_seconds": _cache_ttl_seconds,
    }
    # Hit rate is not tracked yet; reported as 0 to keep the key stable.
    stats["hit_rate"] = 0
    return stats
|
|
|
|
|
|
def _check_trocr_available() -> bool:
    """Return True if torch + transformers can be imported.

    The probe runs once; the outcome is memoised in the module-level
    ``_trocr_available`` tri-state (None = not yet checked).
    """
    global _trocr_available

    if _trocr_available is None:
        try:
            import torch  # noqa: F401 — import probe only
            from transformers import TrOCRProcessor, VisionEncoderDecoderModel  # noqa: F401
        except ImportError as e:
            logger.warning(f"TrOCR dependencies not available: {e}")
            _trocr_available = False
        else:
            _trocr_available = True

    return _trocr_available
|
|
|
|
|
|
def get_trocr_model(handwritten: bool = False, size: str = "base"):
    """
    Lazy load TrOCR model and processor.

    Loaded models are memoised in the module-level ``_trocr_models`` dict,
    keyed by Hugging Face model name, so base and large variants can be
    resident simultaneously and repeated calls are cheap.

    Args:
        handwritten: Use handwritten model instead of printed model
        size: Model size -- "base" (300 MB) or "large" (340 MB, higher accuracy
            for exam HTR). Only applies to handwritten variant; for the
            printed model, "base" is always used.

    Returns tuple of (processor, model) or (None, None) if unavailable.
    """
    global _trocr_processor, _trocr_model

    if not _check_trocr_available():
        return None, None

    # Select model name ("large" only exists for the handwritten variant here).
    if size == "large" and handwritten:
        model_name = "microsoft/trocr-large-handwritten"
    elif handwritten:
        model_name = "microsoft/trocr-base-handwritten"
    else:
        model_name = "microsoft/trocr-base-printed"

    # Fast path: already loaded.
    if model_name in _trocr_models:
        return _trocr_models[model_name]

    try:
        # Imports are local so the module can be imported without torch installed.
        import torch
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel

        # NOTE: from_pretrained may download weights on first use.
        logger.info(f"Loading TrOCR model: {model_name}")
        processor = TrOCRProcessor.from_pretrained(model_name)
        model = VisionEncoderDecoderModel.from_pretrained(model_name)

        # Use GPU if available; prefer CUDA, then Apple MPS, then CPU.
        device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        model.to(device)
        logger.info(f"TrOCR model loaded on device: {device}")

        _trocr_models[model_name] = (processor, model)

        # Keep backwards-compat globals pointing at base-printed
        if model_name == "microsoft/trocr-base-printed":
            _trocr_processor = processor
            _trocr_model = model

        return processor, model

    except Exception as e:
        # Broad catch: download, deserialisation or device placement can all
        # fail; callers get (None, None) rather than an exception.
        logger.error(f"Failed to load TrOCR model {model_name}: {e}")
        return None, None
|
|
|
|
|
|
def preload_trocr_model(handwritten: bool = True) -> bool:
    """
    Warm up the TrOCR model so the first request does not pay the load cost.

    Intended to run once at service startup, e.g. from a FastAPI hook:
        @app.on_event("startup")
        async def startup():
            preload_trocr_model()

    Returns True when the model (and processor) loaded successfully.
    """
    global _model_loaded_at

    logger.info("Preloading TrOCR model...")
    processor, model = get_trocr_model(handwritten=handwritten)

    if processor is None or model is None:
        logger.warning("TrOCR model preloading failed")
        return False

    _model_loaded_at = datetime.now()
    logger.info("TrOCR model preloaded successfully")
    return True
|
|
|
|
|
|
def get_model_status() -> Dict[str, Any]:
    """Get current model status information.

    Fix: previously this called ``get_trocr_model(handwritten=True)``, so a
    mere status query could trigger a multi-hundred-MB model download and
    load as a side effect. It now only inspects models that are already
    resident in the ``_trocr_models`` cache. It also reported the name
    "trocr-base-handwritten" even when the large variant was loaded; the
    name is now derived from the actually loaded model.

    Returns:
        Dict with keys:
            status: "available" when torch/transformers are importable,
                else "not_installed".
            is_loaded: True when a handwritten model is already loaded.
            model_name: short name of the loaded handwritten model, or None.
            loaded_at: ISO timestamp of the last successful preload, or None.
            device: (only when loaded) device of the model's parameters.
    """
    # Find an already-loaded handwritten model, if any (no loading here).
    loaded_name = None
    loaded_model = None
    for name, (_processor, model) in _trocr_models.items():
        if "handwritten" in name:
            loaded_name = name
            loaded_model = model
            break

    is_loaded = loaded_model is not None

    status = {
        "status": "available" if _check_trocr_available() else "not_installed",
        "is_loaded": is_loaded,
        # Strip the "microsoft/" org prefix for a short display name.
        "model_name": loaded_name.rsplit("/", 1)[-1] if is_loaded else None,
        "loaded_at": _model_loaded_at.isoformat() if _model_loaded_at else None,
    }

    if is_loaded:
        # Device of the first parameter tensor (cuda / mps / cpu).
        status["device"] = str(next(loaded_model.parameters()).device)

    return status
|
|
|
|
|
|
def get_active_backend() -> str:
    """Report the configured TrOCR backend.

    One of "auto", "pytorch" or "onnx", taken from the TROCR_BACKEND
    environment variable at module import time.
    """
    return _trocr_backend
|
|
|
|
|
|
def _split_into_lines(image) -> list:
|
|
"""
|
|
Split an image into text lines using simple projection-based segmentation.
|
|
|
|
This is a basic implementation - for production use, consider using
|
|
a dedicated line detection model.
|
|
"""
|
|
import numpy as np
|
|
from PIL import Image
|
|
|
|
try:
|
|
# Convert to grayscale
|
|
gray = image.convert('L')
|
|
img_array = np.array(gray)
|
|
|
|
# Binarize (simple threshold)
|
|
threshold = 200
|
|
binary = img_array < threshold
|
|
|
|
# Horizontal projection (sum of dark pixels per row)
|
|
h_proj = np.sum(binary, axis=1)
|
|
|
|
# Find line boundaries (where projection drops below threshold)
|
|
line_threshold = img_array.shape[1] * 0.02 # 2% of width
|
|
in_line = False
|
|
line_start = 0
|
|
lines = []
|
|
|
|
for i, val in enumerate(h_proj):
|
|
if val > line_threshold and not in_line:
|
|
# Start of line
|
|
in_line = True
|
|
line_start = i
|
|
elif val <= line_threshold and in_line:
|
|
# End of line
|
|
in_line = False
|
|
# Add padding
|
|
start = max(0, line_start - 5)
|
|
end = min(img_array.shape[0], i + 5)
|
|
if end - start > 10: # Minimum line height
|
|
lines.append(image.crop((0, start, image.width, end)))
|
|
|
|
# Handle last line if still in_line
|
|
if in_line:
|
|
start = max(0, line_start - 5)
|
|
lines.append(image.crop((0, start, image.width, image.height)))
|
|
|
|
logger.info(f"Split image into {len(lines)} lines")
|
|
return lines
|
|
|
|
except Exception as e:
|
|
logger.warning(f"Line splitting failed: {e}")
|
|
return []
|