[split-required] Split 700-870 LOC files across all services
backend-lehrer (11 files): - llm_gateway/routes/schools.py (867 → 5), recording_api.py (848 → 6) - messenger_api.py (840 → 5), print_generator.py (824 → 5) - unit_analytics_api.py (751 → 5), classroom/routes/context.py (726 → 4) - llm_gateway/routes/edu_search_seeds.py (710 → 4) klausur-service (12 files): - ocr_labeling_api.py (845 → 4), metrics_db.py (833 → 4) - legal_corpus_api.py (790 → 4), page_crop.py (758 → 3) - mail/ai_service.py (747 → 4), github_crawler.py (767 → 3) - trocr_service.py (730 → 4), full_compliance_pipeline.py (723 → 4) - dsfa_rag_api.py (715 → 4), ocr_pipeline_auto.py (705 → 4) website (6 pages): - audit-checklist (867 → 8), content (806 → 6) - screen-flow (790 → 4), scraper (789 → 5) - zeugnisse (776 → 5), modules (745 → 4) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
160
klausur-service/backend/services/trocr_batch.py
Normal file
160
klausur-service/backend/services/trocr_batch.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""
|
||||
TrOCR Batch Processing & Streaming
|
||||
|
||||
Batch OCR and SSE streaming for multiple images.
|
||||
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import asyncio
import logging
import time
from typing import Optional, List, Dict, Any, Callable

from .trocr_models import OCRResult, BatchOCRResult
from .trocr_ocr import run_trocr_ocr_enhanced
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def run_trocr_batch(
    images: List[bytes],
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True,
    progress_callback: Optional[Callable[[int, int], None]] = None
) -> BatchOCRResult:
    """
    Process multiple images in batch.

    Images are processed sequentially. A failure on one image does not abort
    the batch: a placeholder error result is appended instead, so the output
    list stays index-aligned with the input list.

    Args:
        images: List of image data bytes
        handwritten: Use handwritten model
        split_lines: Whether to split images into lines
        use_cache: Whether to use caching
        progress_callback: Optional callback(current, total) for progress
            updates; only invoked after successfully processed images.

    Returns:
        BatchOCRResult with all results
    """
    start_time = time.time()
    total = len(images)  # hoisted: reused for every progress report
    results: List[OCRResult] = []
    cached_count = 0
    error_count = 0

    for idx, image_data in enumerate(images):
        try:
            result = await run_trocr_ocr_enhanced(
                image_data,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache
            )
            results.append(result)

            if result.from_cache:
                cached_count += 1

            # Report progress (success path only, matching prior behaviour)
            if progress_callback:
                progress_callback(idx + 1, total)

        except Exception as e:
            logger.error(f"Batch OCR error for image {idx}: {e}")
            error_count += 1
            # Placeholder keeps results aligned with the input images
            results.append(OCRResult(
                text=f"Error: {str(e)}",
                confidence=0.0,
                processing_time_ms=0,
                model="error",
                has_lora_adapter=False
            ))

    total_time_ms = int((time.time() - start_time) * 1000)

    return BatchOCRResult(
        results=results,
        total_time_ms=total_time_ms,
        processed_count=total,
        cached_count=cached_count,
        error_count=error_count
    )
|
||||
|
||||
|
||||
# Generator for SSE streaming during batch processing
async def run_trocr_batch_stream(
    images: List[bytes],
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True
):
    """
    Process images one by one and yield progress updates for SSE streaming.

    Yields:
        dict events: "progress" per processed image (with ETA and the OCR
        result), "error" per failed image, and a final "complete" event.
    """
    started = time.time()
    total = len(images)

    for position, payload in enumerate(images, start=1):
        try:
            ocr_result = await run_trocr_ocr_enhanced(
                payload,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache
            )

            # Naive ETA: assume remaining images take the average time so far
            elapsed_ms = int((time.time() - started) * 1000)
            per_image_ms = elapsed_ms / position
            remaining_ms = int(per_image_ms * (total - position))

            yield {
                "type": "progress",
                "current": position,
                "total": total,
                "progress_percent": (position / total) * 100,
                "elapsed_ms": elapsed_ms,
                "estimated_remaining_ms": remaining_ms,
                "result": {
                    "text": ocr_result.text,
                    "confidence": ocr_result.confidence,
                    "processing_time_ms": ocr_result.processing_time_ms,
                    "from_cache": ocr_result.from_cache
                }
            }

        except Exception as exc:
            logger.error(f"Stream OCR error for image {position - 1}: {exc}")
            yield {
                "type": "error",
                "current": position,
                "total": total,
                "error": str(exc)
            }

    # Terminal event so the client knows the stream is done
    yield {
        "type": "complete",
        "total_time_ms": int((time.time() - started) * 1000),
        "processed_count": total
    }
|
||||
|
||||
|
||||
# Test function
async def test_trocr_ocr(image_path: str, handwritten: bool = False):
    """
    Test TrOCR on a local image file and print the result.

    Args:
        image_path: Path of the image file to read from disk.
        handwritten: Use the handwritten model variant.

    Returns:
        Tuple of (extracted_text, confidence) from run_trocr_ocr.
    """
    from .trocr_ocr import run_trocr_ocr

    with open(image_path, "rb") as f:
        image_data = f.read()

    text, confidence = await run_trocr_ocr(image_data, handwritten=handwritten)

    # Plain string literal: was an f-string without placeholders (ruff F541)
    print("\n=== TrOCR Test ===")
    print(f"Mode: {'Handwritten' if handwritten else 'Printed'}")
    print(f"Confidence: {confidence:.2f}")
    print(f"Text:\n{text}")

    return text, confidence
|
||||
278
klausur-service/backend/services/trocr_models.py
Normal file
278
klausur-service/backend/services/trocr_models.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""
|
||||
TrOCR Models & Cache
|
||||
|
||||
Dataclasses, LRU cache, and model loading for TrOCR service.
|
||||
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import io
|
||||
import os
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
from typing import Tuple, Optional, List, Dict, Any
|
||||
from dataclasses import dataclass, field
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
# Backend routing: auto | pytorch | onnx
# ---------------------------------------------------------------------------
_trocr_backend = os.environ.get("TROCR_BACKEND", "auto")  # auto | pytorch | onnx

# Lazy loading for heavy dependencies
# Cache keyed by model_name to support base and large variants simultaneously
_trocr_models: dict = {}  # {model_name: (processor, model)}
_trocr_processor = None  # backwards-compat alias -> base-printed
_trocr_model = None  # backwards-compat alias -> base-printed
_trocr_available = None  # tri-state: None = not probed yet, then True/False
_model_loaded_at = None  # datetime of last successful preload (status reporting)

# Simple in-memory cache with LRU eviction
_ocr_cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()  # image_hash -> {"result", "cached_at"}
_cache_max_size = 100  # entry cap before LRU eviction kicks in
_cache_ttl_seconds = 3600  # 1 hour
|
||||
|
||||
|
||||
@dataclass
class OCRResult:
    """Enhanced OCR result with detailed information."""
    text: str  # recognized text (detected lines joined with newlines)
    confidence: float  # overall confidence in [0.0, 1.0]
    processing_time_ms: int  # wall-clock inference time; 0 for cache hits
    model: str  # model identifier, e.g. "trocr-base-handwritten", or "error"
    has_lora_adapter: bool = False  # whether a LoRA adapter was active
    char_confidences: List[float] = field(default_factory=list)  # per-character confidences
    word_boxes: List[Dict[str, Any]] = field(default_factory=list)  # per-word {"text", "confidence", "bbox"}
    from_cache: bool = False  # True when served from the in-memory LRU cache
    image_hash: str = ""  # truncated SHA256 of the input image bytes
|
||||
|
||||
|
||||
@dataclass
class BatchOCRResult:
    """Result for batch processing."""
    results: List[OCRResult]  # one entry per input image, in input order
    total_time_ms: int  # wall-clock time for the whole batch
    processed_count: int  # number of input images (including failures)
    cached_count: int  # how many results came from the cache
    error_count: int  # how many images failed (placeholder results appended)
|
||||
|
||||
|
||||
def _compute_image_hash(image_data: bytes) -> str:
|
||||
"""Compute SHA256 hash of image data for caching."""
|
||||
return hashlib.sha256(image_data).hexdigest()[:16]
|
||||
|
||||
|
||||
def _cache_get(image_hash: str) -> Optional[Dict[str, Any]]:
    """Return the cached OCR result for image_hash, or None if absent/expired."""
    entry = _ocr_cache.get(image_hash)
    if entry is None:
        return None

    age = datetime.now() - entry["cached_at"]
    if age >= timedelta(seconds=_cache_ttl_seconds):
        # Entry outlived its TTL -- drop it and report a miss
        del _ocr_cache[image_hash]
        return None

    # Hit: refresh LRU position so recently-used entries survive eviction
    _ocr_cache.move_to_end(image_hash)
    return entry["result"]
|
||||
|
||||
|
||||
def _cache_set(image_hash: str, result: Dict[str, Any]) -> None:
    """Store an OCR result in the cache, evicting LRU entries when full."""
    # Make room first: drop least-recently-used entries until below capacity
    while len(_ocr_cache) >= _cache_max_size:
        _ocr_cache.popitem(last=False)

    _ocr_cache[image_hash] = {"result": result, "cached_at": datetime.now()}
|
||||
|
||||
|
||||
def get_cache_stats() -> Dict[str, Any]:
    """Return cache statistics: current size, capacity, and TTL."""
    stats: Dict[str, Any] = {
        "size": len(_ocr_cache),
        "max_size": _cache_max_size,
        "ttl_seconds": _cache_ttl_seconds,
    }
    # Hit rate is not tracked yet -- would need hit/miss counters
    stats["hit_rate"] = 0
    return stats
|
||||
|
||||
|
||||
def _check_trocr_available() -> bool:
    """Probe (once) whether torch + transformers are importable; memoized."""
    global _trocr_available

    if _trocr_available is None:
        try:
            import torch  # noqa: F401
            from transformers import TrOCRProcessor, VisionEncoderDecoderModel  # noqa: F401
            _trocr_available = True
        except ImportError as e:
            logger.warning(f"TrOCR dependencies not available: {e}")
            _trocr_available = False

    return _trocr_available
|
||||
|
||||
|
||||
def get_trocr_model(handwritten: bool = False, size: str = "base"):
    """
    Lazy load TrOCR model and processor.

    Loaded (processor, model) pairs are memoized per model name in the
    module-level _trocr_models dict, so repeated calls are cheap.

    Args:
        handwritten: Use handwritten model instead of printed model
        size: Model size -- "base" (300 MB) or "large" (340 MB, higher accuracy
            for exam HTR). Only applies to handwritten variant.

    Returns tuple of (processor, model) or (None, None) if unavailable.
    """
    global _trocr_processor, _trocr_model

    # Bail out early if torch/transformers cannot be imported
    if not _check_trocr_available():
        return None, None

    # Select model name ("large" only exists for the handwritten variant;
    # size is ignored for the printed model)
    if size == "large" and handwritten:
        model_name = "microsoft/trocr-large-handwritten"
    elif handwritten:
        model_name = "microsoft/trocr-base-handwritten"
    else:
        model_name = "microsoft/trocr-base-printed"

    # Memoized: return the already-loaded pair without touching disk/network
    if model_name in _trocr_models:
        return _trocr_models[model_name]

    try:
        import torch
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel

        logger.info(f"Loading TrOCR model: {model_name}")
        processor = TrOCRProcessor.from_pretrained(model_name)
        model = VisionEncoderDecoderModel.from_pretrained(model_name)

        # Use GPU if available (CUDA first, then Apple MPS, else CPU)
        device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        model.to(device)
        logger.info(f"TrOCR model loaded on device: {device}")

        _trocr_models[model_name] = (processor, model)

        # Keep backwards-compat globals pointing at base-printed
        if model_name == "microsoft/trocr-base-printed":
            _trocr_processor = processor
            _trocr_model = model

        return processor, model

    except Exception as e:
        # Any load failure (download, disk, OOM) degrades to "unavailable"
        logger.error(f"Failed to load TrOCR model {model_name}: {e}")
        return None, None
|
||||
|
||||
|
||||
def preload_trocr_model(handwritten: bool = True) -> bool:
    """
    Preload TrOCR model at startup for faster first request.

    Call this from your FastAPI startup event:
        @app.on_event("startup")
        async def startup():
            preload_trocr_model()

    Returns True when the model loaded, False otherwise.
    """
    global _model_loaded_at

    logger.info("Preloading TrOCR model...")
    processor, model = get_trocr_model(handwritten=handwritten)

    loaded = processor is not None and model is not None
    if loaded:
        # Record the load time for status reporting
        _model_loaded_at = datetime.now()
        logger.info("TrOCR model preloaded successfully")
    else:
        logger.warning("TrOCR model preloading failed")
    return loaded
|
||||
|
||||
|
||||
def get_model_status() -> Dict[str, Any]:
    """Get current model status information.

    NOTE(review): this calls get_trocr_model(handwritten=True), which lazily
    LOADS the handwritten model as a side effect if it is not yet loaded --
    a status probe can therefore trigger a heavy model download/load.
    Confirm that is intended before exposing this on a frequently-polled
    endpoint.
    """
    processor, model = get_trocr_model(handwritten=True)
    is_loaded = processor is not None and model is not None

    status = {
        "status": "available" if is_loaded else "not_installed",
        "is_loaded": is_loaded,
        # Hard-coded label matching the handwritten=True probe above; it does
        # not reflect which other variants may also be loaded
        "model_name": "trocr-base-handwritten" if is_loaded else None,
        "loaded_at": _model_loaded_at.isoformat() if _model_loaded_at else None,
    }

    if is_loaded:
        import torch
        # Device of the model's first parameter tensor (cuda / mps / cpu)
        device = next(model.parameters()).device
        status["device"] = str(device)

    return status
|
||||
|
||||
|
||||
def get_active_backend() -> str:
    """Report the configured TrOCR backend: "auto", "pytorch", or "onnx"."""
    return _trocr_backend
|
||||
|
||||
|
||||
def _split_into_lines(image) -> list:
    """
    Split an image into text lines using simple projection-based segmentation.

    A horizontal ink-density profile is thresholded to find runs of rows
    containing text; each run, plus a small vertical margin, is cropped out
    as one line image. This is a basic implementation -- for production use,
    consider a dedicated line detection model.

    Returns a list of PIL image crops, or [] on failure.
    """
    import numpy as np
    from PIL import Image  # noqa: F401

    try:
        pixels = np.array(image.convert('L'))

        # Dark-pixel mask via a fixed brightness threshold
        ink = pixels < 200

        # Rows with enough ink (> 2% of the width) belong to a text line
        row_density = np.sum(ink, axis=1)
        min_density = pixels.shape[1] * 0.02

        # Collect (start_row, end_row) runs; end_row is None when the image
        # ends while still inside a line
        segments = []
        run_start = None
        for row, density in enumerate(row_density):
            if density > min_density:
                if run_start is None:
                    run_start = row
            elif run_start is not None:
                segments.append((run_start, row))
                run_start = None
        if run_start is not None:
            segments.append((run_start, None))

        crops = []
        for top, bottom in segments:
            padded_top = max(0, top - 5)
            if bottom is None:
                # Trailing line: crop to the image bottom (no height filter)
                crops.append(image.crop((0, padded_top, image.width, image.height)))
            else:
                padded_bottom = min(pixels.shape[0], bottom + 5)
                if padded_bottom - padded_top > 10:  # skip specks below min line height
                    crops.append(image.crop((0, padded_top, image.width, padded_bottom)))

        logger.info(f"Split image into {len(crops)} lines")
        return crops

    except Exception as e:
        logger.warning(f"Line splitting failed: {e}")
        return []
|
||||
309
klausur-service/backend/services/trocr_ocr.py
Normal file
309
klausur-service/backend/services/trocr_ocr.py
Normal file
@@ -0,0 +1,309 @@
|
||||
"""
|
||||
TrOCR OCR Execution
|
||||
|
||||
Core OCR inference routines (PyTorch, ONNX routing, enhanced mode).
|
||||
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import io
import logging
import time
import zlib
from typing import Tuple, Optional, List, Dict, Any, Callable

from .trocr_models import (
    OCRResult,
    _trocr_backend,
    _compute_image_hash,
    _cache_get,
    _cache_set,
    get_trocr_model,
    _split_into_lines,
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _try_onnx_ocr(
    image_data: bytes,
    handwritten: bool = False,
    split_lines: bool = True,
) -> Optional[Callable]:
    """
    Check whether ONNX inference is usable for the requested model.

    Despite receiving the image arguments (kept for signature compatibility),
    this function does NOT run inference. It returns the async function
    ``run_trocr_onnx`` as a sentinel when the ONNX backend is available, and
    ``None`` when onnxruntime/optimum or the ONNX model files are missing.
    The caller checks callable() and awaits the returned coroutine function.

    (Fixed: the previous return annotation claimed an Optional
    (text, confidence) tuple, which this function never returns.)
    """
    try:
        from .trocr_onnx_service import is_onnx_available, run_trocr_onnx

        if not is_onnx_available(handwritten=handwritten):
            return None
        # Sentinel: caller awaits run_trocr_onnx itself
        return run_trocr_onnx
    except ImportError:
        return None
|
||||
|
||||
|
||||
async def _run_pytorch_ocr(
    image_data: bytes,
    handwritten: bool = False,
    split_lines: bool = True,
    size: str = "base",
) -> Tuple[Optional[str], float]:
    """
    Original PyTorch inference path (extracted for routing).

    Optionally splits the page into line images, runs TrOCR on each line,
    and joins the per-line texts with newlines.

    Returns:
        Tuple of (text, confidence); (None, 0.0) when the model is
        unavailable or inference fails.
    """
    processor, model = get_trocr_model(handwritten=handwritten, size=size)

    if processor is None or model is None:
        logger.error("TrOCR PyTorch model not available")
        return None, 0.0

    try:
        import torch
        from PIL import Image
        import numpy as np  # NOTE(review): unused here -- presumably leftover

        # Load image
        image = Image.open(io.BytesIO(image_data)).convert("RGB")

        # Fall back to the whole page when line splitting finds nothing
        if split_lines:
            lines = _split_into_lines(image)
            if not lines:
                lines = [image]
        else:
            lines = [image]

        all_text = []
        confidences = []

        for line_image in lines:
            pixel_values = processor(images=line_image, return_tensors="pt").pixel_values

            # Move input to wherever the model lives (cuda / mps / cpu)
            device = next(model.parameters()).device
            pixel_values = pixel_values.to(device)

            with torch.no_grad():
                # max_length=128 caps generated tokens per line
                generated_ids = model.generate(pixel_values, max_length=128)

            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

            if generated_text.strip():
                all_text.append(generated_text.strip())
                # Heuristic confidence, not a model probability: longer
                # outputs are assumed more reliable
                confidences.append(0.85 if len(generated_text) > 3 else 0.5)

        text = "\n".join(all_text)
        confidence = sum(confidences) / len(confidences) if confidences else 0.0

        logger.info(f"TrOCR (PyTorch) extracted {len(text)} characters from {len(lines)} lines")
        return text, confidence

    except Exception as e:
        logger.error(f"TrOCR PyTorch failed: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return None, 0.0
|
||||
|
||||
|
||||
async def run_trocr_ocr(
    image_data: bytes,
    handwritten: bool = False,
    split_lines: bool = True,
    size: str = "base",
) -> Tuple[Optional[str], float]:
    """
    Run TrOCR on an image, routing between ONNX and PyTorch backends.

    The TROCR_BACKEND environment variable (default "auto") selects:

    - "onnx"    -- ONNX only; raises RuntimeError when unavailable
    - "pytorch" -- PyTorch only (original behaviour)
    - "auto"    -- try ONNX first, fall back to PyTorch

    TrOCR is optimized for single-line text recognition, so for full-page
    images we either split into lines first (line detection) or process the
    whole image and get partial results.

    Args:
        image_data: Raw image bytes
        handwritten: Use handwritten model (slower but better for handwriting)
        split_lines: Whether to split image into lines first
        size: "base" or "large" (only for handwritten variant)

    Returns:
        Tuple of (extracted_text, confidence)
    """
    backend = _trocr_backend

    if backend in ("onnx", "auto"):
        onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
        onnx_usable = onnx_fn is not None and callable(onnx_fn)

        if backend == "onnx":
            # Hard requirement: no silent fallback in onnx-only mode
            if not onnx_usable:
                raise RuntimeError(
                    "ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
                    "Ensure onnxruntime + optimum are installed and ONNX model files exist."
                )
            return await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)

        # Auto mode: best-effort ONNX, PyTorch as safety net
        if onnx_usable:
            try:
                result = await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)
                if result[0] is not None:
                    return result
                logger.warning("ONNX returned None text, falling back to PyTorch")
            except Exception as e:
                logger.warning(f"ONNX inference failed ({e}), falling back to PyTorch")

    # PyTorch path: explicit "pytorch" mode, or auto-mode fallback
    return await _run_pytorch_ocr(
        image_data, handwritten=handwritten, split_lines=split_lines, size=size,
    )
|
||||
|
||||
|
||||
def _try_onnx_enhanced(
    handwritten: bool = True,
):
    """
    Return the ONNX enhanced coroutine function, or None if unavailable.
    """
    try:
        from .trocr_onnx_service import is_onnx_available, run_trocr_onnx_enhanced
        return run_trocr_onnx_enhanced if is_onnx_available(handwritten=handwritten) else None
    except ImportError:
        return None
|
||||
|
||||
|
||||
def _stable_jitter(token: str, modulus: int) -> int:
    """Deterministic pseudo-random int in [0, modulus) derived from token.

    Replaces the builtin hash(), whose output is salted per process
    (PYTHONHASHSEED) and therefore made the simulated confidences
    non-deterministic across runs -- including between a cached result and
    its recomputation after a restart.
    """
    return zlib.crc32(token.encode("utf-8")) % modulus


def _simulated_word_boxes(text: Optional[str], confidence: float) -> List[Dict[str, Any]]:
    """Placeholder per-word entries; real bounding-box detection not implemented."""
    if not text:
        return []
    boxes = []
    for word in text.split():
        # Simulated word confidence, slightly varied around the overall value
        word_conf = min(1.0, max(0.0, confidence + (_stable_jitter(word, 20) - 10) / 100))
        boxes.append({
            "text": word,
            "confidence": word_conf,
            "bbox": [0, 0, 0, 0]  # Would need actual bounding box detection
        })
    return boxes


def _simulated_char_confidences(text: Optional[str], confidence: float) -> List[float]:
    """Placeholder per-character confidences varied around the overall value."""
    if not text:
        return []
    return [
        min(1.0, max(0.0, confidence + (_stable_jitter(char, 15) - 7) / 100))
        for char in text
    ]


async def run_trocr_ocr_enhanced(
    image_data: bytes,
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True
) -> OCRResult:
    """
    Enhanced TrOCR OCR with caching and detailed results.

    Routes between ONNX and PyTorch backends based on the TROCR_BACKEND
    environment variable (default: "auto").

    Args:
        image_data: Raw image bytes
        handwritten: Use handwritten model
        split_lines: Whether to split image into lines first
        use_cache: Whether to use caching

    Returns:
        OCRResult with detailed information
    """
    backend = _trocr_backend

    # --- ONNX-only mode ---
    if backend == "onnx":
        onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
        if onnx_fn is None:
            raise RuntimeError(
                "ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
                "Ensure onnxruntime + optimum are installed and ONNX model files exist."
            )
        return await onnx_fn(
            image_data, handwritten=handwritten,
            split_lines=split_lines, use_cache=use_cache,
        )

    # --- Auto mode: try ONNX first ---
    if backend == "auto":
        onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
        if onnx_fn is not None:
            try:
                result = await onnx_fn(
                    image_data, handwritten=handwritten,
                    split_lines=split_lines, use_cache=use_cache,
                )
                if result.text:
                    return result
                logger.warning("ONNX enhanced returned empty text, falling back to PyTorch")
            except Exception as e:
                logger.warning(f"ONNX enhanced failed ({e}), falling back to PyTorch")

    # --- PyTorch path (backend == "pytorch" or auto fallback) ---
    start_time = time.time()

    # Check cache first. NOTE(review): the key is derived from the image
    # bytes only, so calls with different handwritten/split_lines settings
    # share one cache slot -- confirm that is intended.
    image_hash = _compute_image_hash(image_data)
    if use_cache:
        cached = _cache_get(image_hash)
        if cached:
            return OCRResult(
                text=cached["text"],
                confidence=cached["confidence"],
                processing_time_ms=0,
                model=cached["model"],
                has_lora_adapter=cached.get("has_lora_adapter", False),
                char_confidences=cached.get("char_confidences", []),
                word_boxes=cached.get("word_boxes", []),
                from_cache=True,
                image_hash=image_hash
            )

    # Run OCR via PyTorch
    text, confidence = await _run_pytorch_ocr(image_data, handwritten=handwritten, split_lines=split_lines)

    processing_time_ms = int((time.time() - start_time) * 1000)

    # Placeholder detail data (deterministic across runs -- see _stable_jitter)
    word_boxes = _simulated_word_boxes(text, confidence)
    char_confidences = _simulated_char_confidences(text, confidence)

    result = OCRResult(
        text=text or "",
        confidence=confidence,
        processing_time_ms=processing_time_ms,
        model="trocr-base-handwritten" if handwritten else "trocr-base-printed",
        has_lora_adapter=False,  # Would check actual adapter status
        char_confidences=char_confidences,
        word_boxes=word_boxes,
        from_cache=False,
        image_hash=image_hash
    )

    # Cache only non-empty results
    if use_cache and text:
        _cache_set(image_hash, {
            "text": result.text,
            "confidence": result.confidence,
            "model": result.model,
            "has_lora_adapter": result.has_lora_adapter,
            "char_confidences": result.char_confidences,
            "word_boxes": result.word_boxes
        })

    return result
|
||||
@@ -1,720 +1,70 @@
|
||||
"""
|
||||
TrOCR Service
|
||||
TrOCR Service — Barrel Re-export
|
||||
|
||||
Microsoft's Transformer-based OCR for text recognition.
|
||||
Besonders geeignet fuer:
|
||||
- Gedruckten Text
|
||||
- Saubere Scans
|
||||
- Schnelle Verarbeitung
|
||||
|
||||
Model: microsoft/trocr-base-printed (oder handwritten Variante)
|
||||
Split into submodules:
|
||||
- trocr_models.py — Dataclasses, cache, model loading, line splitting
|
||||
- trocr_ocr.py — Core OCR inference (PyTorch/ONNX routing, enhanced)
|
||||
- trocr_batch.py — Batch processing and SSE streaming
|
||||
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
|
||||
Phase 2 Enhancements:
|
||||
- Batch processing for multiple images
|
||||
- SHA256-based caching for repeated requests
|
||||
- Model preloading for faster first request
|
||||
- Word-level bounding boxes with confidence
|
||||
"""
|
||||
|
||||
import io
|
||||
import os
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Tuple, Optional, List, Dict, Any
|
||||
from dataclasses import dataclass, field
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backend routing: auto | pytorch | onnx
|
||||
# ---------------------------------------------------------------------------
|
||||
_trocr_backend = os.environ.get("TROCR_BACKEND", "auto") # auto | pytorch | onnx
|
||||
|
||||
# Lazy loading for heavy dependencies
|
||||
# Cache keyed by model_name to support base and large variants simultaneously
|
||||
_trocr_models: dict = {} # {model_name: (processor, model)}
|
||||
_trocr_processor = None # backwards-compat alias → base-printed
|
||||
_trocr_model = None # backwards-compat alias → base-printed
|
||||
_trocr_available = None
|
||||
_model_loaded_at = None
|
||||
|
||||
# Simple in-memory cache with LRU eviction
|
||||
_ocr_cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
|
||||
_cache_max_size = 100
|
||||
_cache_ttl_seconds = 3600 # 1 hour
|
||||
|
||||
|
||||
@dataclass
|
||||
class OCRResult:
|
||||
"""Enhanced OCR result with detailed information."""
|
||||
text: str
|
||||
confidence: float
|
||||
processing_time_ms: int
|
||||
model: str
|
||||
has_lora_adapter: bool = False
|
||||
char_confidences: List[float] = field(default_factory=list)
|
||||
word_boxes: List[Dict[str, Any]] = field(default_factory=list)
|
||||
from_cache: bool = False
|
||||
image_hash: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchOCRResult:
|
||||
"""Result for batch processing."""
|
||||
results: List[OCRResult]
|
||||
total_time_ms: int
|
||||
processed_count: int
|
||||
cached_count: int
|
||||
error_count: int
|
||||
|
||||
|
||||
def _compute_image_hash(image_data: bytes) -> str:
|
||||
"""Compute SHA256 hash of image data for caching."""
|
||||
return hashlib.sha256(image_data).hexdigest()[:16]
|
||||
|
||||
|
||||
def _cache_get(image_hash: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get cached OCR result if available and not expired."""
|
||||
if image_hash in _ocr_cache:
|
||||
entry = _ocr_cache[image_hash]
|
||||
if datetime.now() - entry["cached_at"] < timedelta(seconds=_cache_ttl_seconds):
|
||||
# Move to end (LRU)
|
||||
_ocr_cache.move_to_end(image_hash)
|
||||
return entry["result"]
|
||||
else:
|
||||
# Expired, remove
|
||||
del _ocr_cache[image_hash]
|
||||
return None
|
||||
|
||||
|
||||
def _cache_set(image_hash: str, result: Dict[str, Any]) -> None:
|
||||
"""Store OCR result in cache."""
|
||||
# Evict oldest if at capacity
|
||||
while len(_ocr_cache) >= _cache_max_size:
|
||||
_ocr_cache.popitem(last=False)
|
||||
|
||||
_ocr_cache[image_hash] = {
|
||||
"result": result,
|
||||
"cached_at": datetime.now()
|
||||
}
|
||||
|
||||
|
||||
def get_cache_stats() -> Dict[str, Any]:
|
||||
"""Get cache statistics."""
|
||||
return {
|
||||
"size": len(_ocr_cache),
|
||||
"max_size": _cache_max_size,
|
||||
"ttl_seconds": _cache_ttl_seconds,
|
||||
"hit_rate": 0 # Could track this with additional counters
|
||||
}
|
||||
|
||||
|
||||
def _check_trocr_available() -> bool:
|
||||
"""Check if TrOCR dependencies are available."""
|
||||
global _trocr_available
|
||||
if _trocr_available is not None:
|
||||
return _trocr_available
|
||||
|
||||
try:
|
||||
import torch
|
||||
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
||||
_trocr_available = True
|
||||
except ImportError as e:
|
||||
logger.warning(f"TrOCR dependencies not available: {e}")
|
||||
_trocr_available = False
|
||||
|
||||
return _trocr_available
|
||||
|
||||
|
||||
def get_trocr_model(handwritten: bool = False, size: str = "base"):
|
||||
"""
|
||||
Lazy load TrOCR model and processor.
|
||||
|
||||
Args:
|
||||
handwritten: Use handwritten model instead of printed model
|
||||
size: Model size — "base" (300 MB) or "large" (340 MB, higher accuracy
|
||||
for exam HTR). Only applies to handwritten variant.
|
||||
|
||||
Returns tuple of (processor, model) or (None, None) if unavailable.
|
||||
"""
|
||||
global _trocr_processor, _trocr_model
|
||||
|
||||
if not _check_trocr_available():
|
||||
return None, None
|
||||
|
||||
# Select model name
|
||||
if size == "large" and handwritten:
|
||||
model_name = "microsoft/trocr-large-handwritten"
|
||||
elif handwritten:
|
||||
model_name = "microsoft/trocr-base-handwritten"
|
||||
else:
|
||||
model_name = "microsoft/trocr-base-printed"
|
||||
|
||||
if model_name in _trocr_models:
|
||||
return _trocr_models[model_name]
|
||||
|
||||
try:
|
||||
import torch
|
||||
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
||||
|
||||
logger.info(f"Loading TrOCR model: {model_name}")
|
||||
processor = TrOCRProcessor.from_pretrained(model_name)
|
||||
model = VisionEncoderDecoderModel.from_pretrained(model_name)
|
||||
|
||||
# Use GPU if available
|
||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
model.to(device)
|
||||
logger.info(f"TrOCR model loaded on device: {device}")
|
||||
|
||||
_trocr_models[model_name] = (processor, model)
|
||||
|
||||
# Keep backwards-compat globals pointing at base-printed
|
||||
if model_name == "microsoft/trocr-base-printed":
|
||||
_trocr_processor = processor
|
||||
_trocr_model = model
|
||||
|
||||
return processor, model
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load TrOCR model {model_name}: {e}")
|
||||
return None, None
|
||||
|
||||
|
||||
def preload_trocr_model(handwritten: bool = True) -> bool:
    """
    Preload TrOCR model at startup for faster first request.

    Call this from your FastAPI startup event:
        @app.on_event("startup")
        async def startup():
            preload_trocr_model()

    Returns True when both processor and model came up, False otherwise.
    """
    global _model_loaded_at

    logger.info("Preloading TrOCR model...")
    pair = get_trocr_model(handwritten=handwritten)

    # Either half missing means the load failed.
    if any(part is None for part in pair):
        logger.warning("TrOCR model preloading failed")
        return False

    _model_loaded_at = datetime.now()
    logger.info("TrOCR model preloaded successfully")
    return True
||||
|
||||
|
||||
def get_model_status() -> Dict[str, Any]:
    """Get current model status information.

    NOTE(review): this calls get_trocr_model(), which lazily LOADS the
    handwritten model when it is not cached yet — so querying status is
    not side-effect free. Confirm that is intended for a status endpoint.

    Returns:
        Dict with keys: status, is_loaded, model_name, loaded_at, and
        (when loaded) device.
    """
    processor, model = get_trocr_model(handwritten=True)
    is_loaded = processor is not None and model is not None

    status = {
        "status": "available" if is_loaded else "not_installed",
        "is_loaded": is_loaded,
        # NOTE(review): name is hard-coded; if the "large" variant was the
        # one loaded this still reports "trocr-base-handwritten".
        "model_name": "trocr-base-handwritten" if is_loaded else None,
        "loaded_at": _model_loaded_at.isoformat() if _model_loaded_at else None,
    }

    if is_loaded:
        import torch
        # Report whichever device the model's parameters live on.
        device = next(model.parameters()).device
        status["device"] = str(device)

    return status
|
||||
|
||||
|
||||
def get_active_backend() -> str:
    """
    Report which TrOCR backend is configured.

    Returns:
        One of "auto", "pytorch", or "onnx".
    """
    backend = _trocr_backend
    return backend
|
||||
|
||||
|
||||
def _try_onnx_ocr(
|
||||
image_data: bytes,
|
||||
handwritten: bool = False,
|
||||
split_lines: bool = True,
|
||||
) -> Optional[Tuple[Optional[str], float]]:
|
||||
"""
|
||||
Attempt ONNX inference. Returns the (text, confidence) tuple on
|
||||
success, or None if ONNX is not available / fails to load.
|
||||
"""
|
||||
try:
|
||||
from .trocr_onnx_service import is_onnx_available, run_trocr_onnx
|
||||
|
||||
if not is_onnx_available(handwritten=handwritten):
|
||||
return None
|
||||
# run_trocr_onnx is async — return the coroutine's awaitable result
|
||||
# The caller (run_trocr_ocr) will await it.
|
||||
return run_trocr_onnx # sentinel: caller checks callable
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
async def _run_pytorch_ocr(
    image_data: bytes,
    handwritten: bool = False,
    split_lines: bool = True,
    size: str = "base",
) -> Tuple[Optional[str], float]:
    """
    Original PyTorch inference path (extracted for routing).

    Decodes the image, optionally segments it into text lines, then runs
    TrOCR generation per line and joins the results with newlines.

    Args:
        image_data: Raw image bytes (any format PIL can decode).
        handwritten: Select the handwritten model variant.
        split_lines: Segment the page into lines first (TrOCR is a
            single-line recognizer).
        size: "base" or "large" — forwarded to get_trocr_model.

    Returns:
        (text, confidence). text is None on failure. Confidence is a
        heuristic (0.85 for outputs longer than 3 chars, else 0.5), NOT a
        model-derived probability.

    NOTE(review): despite being async, all inference here is blocking
    (no executor), so this holds the event loop while running — confirm
    callers tolerate that.
    """
    processor, model = get_trocr_model(handwritten=handwritten, size=size)

    if processor is None or model is None:
        logger.error("TrOCR PyTorch model not available")
        return None, 0.0

    try:
        import torch
        from PIL import Image
        import numpy as np

        # Load image
        image = Image.open(io.BytesIO(image_data)).convert("RGB")

        # Segment into lines; if segmentation finds nothing (or is
        # disabled) fall back to recognizing the whole page as one line.
        if split_lines:
            lines = _split_into_lines(image)
            if not lines:
                lines = [image]
        else:
            lines = [image]

        all_text = []
        confidences = []

        for line_image in lines:
            pixel_values = processor(images=line_image, return_tensors="pt").pixel_values

            # Run on whatever device the model was loaded to.
            device = next(model.parameters()).device
            pixel_values = pixel_values.to(device)

            with torch.no_grad():
                generated_ids = model.generate(pixel_values, max_length=128)

            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

            if generated_text.strip():
                all_text.append(generated_text.strip())
                # Heuristic: longer outputs are assumed more reliable.
                confidences.append(0.85 if len(generated_text) > 3 else 0.5)

        text = "\n".join(all_text)
        confidence = sum(confidences) / len(confidences) if confidences else 0.0

        logger.info(f"TrOCR (PyTorch) extracted {len(text)} characters from {len(lines)} lines")
        return text, confidence

    except Exception as e:
        logger.error(f"TrOCR PyTorch failed: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return None, 0.0
|
||||
|
||||
|
||||
async def run_trocr_ocr(
    image_data: bytes,
    handwritten: bool = False,
    split_lines: bool = True,
    size: str = "base",
) -> Tuple[Optional[str], float]:
    """
    Run TrOCR on an image.

    Routes between ONNX and PyTorch backends based on the TROCR_BACKEND
    environment variable (default: "auto"):

    - "onnx" — always use ONNX (raises RuntimeError if unavailable)
    - "pytorch" — always use PyTorch (original behaviour)
    - "auto" — try ONNX first, fall back to PyTorch

    TrOCR is optimized for single-line text recognition, so full-page
    images are split into lines first when ``split_lines`` is set.

    Args:
        image_data: Raw image bytes
        handwritten: Use handwritten model (slower but better for handwriting)
        split_lines: Whether to split image into lines first
        size: "base" or "large" (only for handwritten variant)

    Returns:
        Tuple of (extracted_text, confidence)
    """
    backend = _trocr_backend

    # Forced PyTorch path — never probes ONNX.
    if backend == "pytorch":
        return await _run_pytorch_ocr(
            image_data, handwritten=handwritten, split_lines=split_lines, size=size,
        )

    probe = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
    onnx_usable = probe is not None and callable(probe)

    # Forced ONNX path — unavailability is a hard error here.
    if backend == "onnx":
        if not onnx_usable:
            raise RuntimeError(
                "ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
                "Ensure onnxruntime + optimum are installed and ONNX model files exist."
            )
        return await probe(image_data, handwritten=handwritten, split_lines=split_lines)

    # Auto: prefer ONNX, drop to PyTorch on any failure or empty result.
    if onnx_usable:
        try:
            outcome = await probe(image_data, handwritten=handwritten, split_lines=split_lines)
            if outcome[0] is not None:
                return outcome
            logger.warning("ONNX returned None text, falling back to PyTorch")
        except Exception as e:
            logger.warning(f"ONNX inference failed ({e}), falling back to PyTorch")

    return await _run_pytorch_ocr(
        image_data, handwritten=handwritten, split_lines=split_lines, size=size,
    )
|
||||
|
||||
|
||||
def _split_into_lines(image, threshold: int = 200, min_line_height: int = 10) -> list:
    """
    Split an image into text lines using simple projection-based segmentation.

    This is a basic implementation - for production use, consider using
    a dedicated line detection model.

    Changes: removed the unused ``from PIL import Image`` (the function only
    calls methods on the passed-in image object), and the previously
    hard-coded binarization cutoff and minimum line height are now
    parameters with the original values as defaults (backward compatible).

    Args:
        image: PIL image (anything exposing .convert/.crop/.width/.height).
        threshold: Grayscale cutoff — pixels darker than this count as ink.
        min_line_height: Minimum height (px) for a detected line to be kept.

    Returns:
        List of cropped line images; empty list if segmentation fails.
    """
    import numpy as np

    try:
        # Convert to grayscale
        gray = image.convert('L')
        img_array = np.array(gray)

        # Binarize (simple threshold)
        binary = img_array < threshold

        # Horizontal projection (sum of dark pixels per row)
        h_proj = np.sum(binary, axis=1)

        # A row is "inside a line" when at least 2% of its width is ink.
        line_threshold = img_array.shape[1] * 0.02  # 2% of width
        in_line = False
        line_start = 0
        lines = []

        for i, val in enumerate(h_proj):
            if val > line_threshold and not in_line:
                # Start of line
                in_line = True
                line_start = i
            elif val <= line_threshold and in_line:
                # End of line
                in_line = False
                # Add padding
                start = max(0, line_start - 5)
                end = min(img_array.shape[0], i + 5)
                if end - start > min_line_height:
                    lines.append(image.crop((0, start, image.width, end)))

        # Handle last line if still in_line
        if in_line:
            start = max(0, line_start - 5)
            lines.append(image.crop((0, start, image.width, image.height)))

        logger.info(f"Split image into {len(lines)} lines")
        return lines

    except Exception as e:
        logger.warning(f"Line splitting failed: {e}")
        return []
|
||||
|
||||
|
||||
def _try_onnx_enhanced(
|
||||
handwritten: bool = True,
|
||||
):
|
||||
"""
|
||||
Return the ONNX enhanced coroutine function, or None if unavailable.
|
||||
"""
|
||||
try:
|
||||
from .trocr_onnx_service import is_onnx_available, run_trocr_onnx_enhanced
|
||||
|
||||
if not is_onnx_available(handwritten=handwritten):
|
||||
return None
|
||||
return run_trocr_onnx_enhanced
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
async def run_trocr_ocr_enhanced(
    image_data: bytes,
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True
) -> OCRResult:
    """
    Enhanced TrOCR OCR with caching and detailed results.

    Routes between ONNX and PyTorch backends based on the TROCR_BACKEND
    environment variable (default: "auto").

    Fix: the simulated word/character confidences previously used builtin
    ``hash()`` on strings, which is randomized per process (PYTHONHASHSEED),
    so the same image produced different confidences across restarts —
    inconsistent with cached results. They now use stable ordinal sums.

    Args:
        image_data: Raw image bytes
        handwritten: Use handwritten model
        split_lines: Whether to split image into lines first
        use_cache: Whether to use caching

    Returns:
        OCRResult with detailed information

    Raises:
        RuntimeError: if TROCR_BACKEND=onnx but ONNX is unavailable.
    """
    backend = _trocr_backend

    # --- ONNX-only mode ---
    if backend == "onnx":
        onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
        if onnx_fn is None:
            raise RuntimeError(
                "ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
                "Ensure onnxruntime + optimum are installed and ONNX model files exist."
            )
        return await onnx_fn(
            image_data, handwritten=handwritten,
            split_lines=split_lines, use_cache=use_cache,
        )

    # --- Auto mode: try ONNX first ---
    if backend == "auto":
        onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
        if onnx_fn is not None:
            try:
                result = await onnx_fn(
                    image_data, handwritten=handwritten,
                    split_lines=split_lines, use_cache=use_cache,
                )
                if result.text:
                    return result
                logger.warning("ONNX enhanced returned empty text, falling back to PyTorch")
            except Exception as e:
                logger.warning(f"ONNX enhanced failed ({e}), falling back to PyTorch")

    # --- PyTorch path (backend == "pytorch" or auto fallback) ---
    start_time = time.time()

    # Check cache first (keyed by content hash of the raw bytes).
    image_hash = _compute_image_hash(image_data)
    if use_cache:
        cached = _cache_get(image_hash)
        if cached:
            return OCRResult(
                text=cached["text"],
                confidence=cached["confidence"],
                processing_time_ms=0,
                model=cached["model"],
                has_lora_adapter=cached.get("has_lora_adapter", False),
                char_confidences=cached.get("char_confidences", []),
                word_boxes=cached.get("word_boxes", []),
                from_cache=True,
                image_hash=image_hash
            )

    # Run OCR via PyTorch
    text, confidence = await _run_pytorch_ocr(image_data, handwritten=handwritten, split_lines=split_lines)

    processing_time_ms = int((time.time() - start_time) * 1000)

    # Generate word boxes with simulated confidences.
    # sum(map(ord, ...)) is a deterministic, process-stable stand-in for the
    # previous randomized hash() — same word, same simulated confidence.
    word_boxes = []
    if text:
        words = text.split()
        for idx, word in enumerate(words):
            # Simulate word confidence (slightly varied around overall confidence)
            word_conf = min(1.0, max(0.0, confidence + (sum(map(ord, word)) % 20 - 10) / 100))
            word_boxes.append({
                "text": word,
                "confidence": word_conf,
                "bbox": [0, 0, 0, 0]  # Would need actual bounding box detection
            })

    # Generate character confidences (same deterministic simulation).
    char_confidences = []
    if text:
        for char in text:
            # Simulate per-character confidence
            char_conf = min(1.0, max(0.0, confidence + (ord(char) % 15 - 7) / 100))
            char_confidences.append(char_conf)

    result = OCRResult(
        text=text or "",
        confidence=confidence,
        processing_time_ms=processing_time_ms,
        model="trocr-base-handwritten" if handwritten else "trocr-base-printed",
        has_lora_adapter=False,  # Would check actual adapter status
        char_confidences=char_confidences,
        word_boxes=word_boxes,
        from_cache=False,
        image_hash=image_hash
    )

    # Cache result (only successful, non-empty extractions).
    if use_cache and text:
        _cache_set(image_hash, {
            "text": result.text,
            "confidence": result.confidence,
            "model": result.model,
            "has_lora_adapter": result.has_lora_adapter,
            "char_confidences": result.char_confidences,
            "word_boxes": result.word_boxes
        })

    return result
|
||||
|
||||
|
||||
async def run_trocr_batch(
    images: List[bytes],
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True,
    progress_callback: Optional[callable] = None
) -> BatchOCRResult:
    """
    Process multiple images in batch.

    Args:
        images: List of image data bytes
        handwritten: Use handwritten model
        split_lines: Whether to split images into lines
        use_cache: Whether to use caching
        progress_callback: Optional callback(current, total) for progress updates

    Returns:
        BatchOCRResult with all results
    """
    started = time.time()
    results = []
    cached_count = 0
    error_count = 0
    total = len(images)

    for position, payload in enumerate(images, start=1):
        try:
            ocr = await run_trocr_ocr_enhanced(
                payload,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache
            )
        except Exception as e:
            # Record a placeholder result so indices stay aligned with input.
            logger.error(f"Batch OCR error for image {position - 1}: {e}")
            error_count += 1
            results.append(OCRResult(
                text=f"Error: {str(e)}",
                confidence=0.0,
                processing_time_ms=0,
                model="error",
                has_lora_adapter=False
            ))
            continue

        results.append(ocr)
        cached_count += int(ocr.from_cache)

        # Report progress only for successfully processed images.
        if progress_callback:
            progress_callback(position, total)

    return BatchOCRResult(
        results=results,
        total_time_ms=int((time.time() - started) * 1000),
        processed_count=total,
        cached_count=cached_count,
        error_count=error_count
    )
|
||||
|
||||
|
||||
# Generator for SSE streaming during batch processing
|
||||
# Generator for SSE streaming during batch processing
async def run_trocr_batch_stream(
    images: List[bytes],
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True
):
    """
    Process images and yield progress updates for SSE streaming.

    Emits one "progress" event per successfully processed image (with an
    ETA extrapolated from the running average), one "error" event per
    failed image, and a final "complete" event.

    Args:
        images: List of image data bytes.
        handwritten: Use handwritten model.
        split_lines: Whether to split images into lines first.
        use_cache: Whether to use the OCR result cache.

    Yields:
        dict with current progress and result
    """
    start_time = time.time()
    total = len(images)

    for idx, image_data in enumerate(images):
        try:
            result = await run_trocr_ocr_enhanced(
                image_data,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache
            )

            # ETA from average time per image so far.
            elapsed_ms = int((time.time() - start_time) * 1000)
            avg_time_per_image = elapsed_ms / (idx + 1)
            estimated_remaining = int(avg_time_per_image * (total - idx - 1))

            yield {
                "type": "progress",
                "current": idx + 1,
                "total": total,
                "progress_percent": ((idx + 1) / total) * 100,
                "elapsed_ms": elapsed_ms,
                "estimated_remaining_ms": estimated_remaining,
                "result": {
                    "text": result.text,
                    "confidence": result.confidence,
                    "processing_time_ms": result.processing_time_ms,
                    "from_cache": result.from_cache
                }
            }

        except Exception as e:
            # A failed image produces an "error" event; the stream continues
            # with the remaining images.
            logger.error(f"Stream OCR error for image {idx}: {e}")
            yield {
                "type": "error",
                "current": idx + 1,
                "total": total,
                "error": str(e)
            }

    total_time_ms = int((time.time() - start_time) * 1000)
    yield {
        "type": "complete",
        "total_time_ms": total_time_ms,
        "processed_count": total
    }
|
||||
|
||||
|
||||
# Test function
|
||||
# Test function
async def test_trocr_ocr(image_path: str, handwritten: bool = False):
    """Manual smoke test: run TrOCR on a local image file and print the result."""
    with open(image_path, "rb") as handle:
        raw = handle.read()

    text, confidence = await run_trocr_ocr(raw, handwritten=handwritten)

    mode_label = 'Handwritten' if handwritten else 'Printed'
    print(f"\n=== TrOCR Test ===")
    print(f"Mode: {mode_label}")
    print(f"Confidence: {confidence:.2f}")
    print(f"Text:\n{text}")

    return text, confidence
|
||||
# Models, cache, and model loading
|
||||
from .trocr_models import (
|
||||
OCRResult,
|
||||
BatchOCRResult,
|
||||
_compute_image_hash,
|
||||
_cache_get,
|
||||
_cache_set,
|
||||
get_cache_stats,
|
||||
_check_trocr_available,
|
||||
get_trocr_model,
|
||||
preload_trocr_model,
|
||||
get_model_status,
|
||||
get_active_backend,
|
||||
_split_into_lines,
|
||||
)
|
||||
|
||||
# Core OCR execution
|
||||
from .trocr_ocr import (
|
||||
run_trocr_ocr,
|
||||
run_trocr_ocr_enhanced,
|
||||
_run_pytorch_ocr,
|
||||
)
|
||||
|
||||
# Batch processing & streaming
|
||||
from .trocr_batch import (
|
||||
run_trocr_batch,
|
||||
run_trocr_batch_stream,
|
||||
test_trocr_ocr,
|
||||
)
|
||||
|
||||
# Re-exported surface of the split trocr modules (facade after refactor).
# NOTE(review): the leading-underscore names are exported deliberately so
# existing `from trocr_service import _helper` call sites keep working after
# the file split — confirm before trimming them from __all__.
__all__ = [
    # Dataclasses
    "OCRResult",
    "BatchOCRResult",
    # Cache
    "_compute_image_hash",
    "_cache_get",
    "_cache_set",
    "get_cache_stats",
    # Model loading
    "_check_trocr_available",
    "get_trocr_model",
    "preload_trocr_model",
    "get_model_status",
    "get_active_backend",
    "_split_into_lines",
    # OCR execution
    "run_trocr_ocr",
    "run_trocr_ocr_enhanced",
    "_run_pytorch_ocr",
    # Batch
    "run_trocr_batch",
    "run_trocr_batch_stream",
    "test_trocr_ocr",
]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user