# NOTE: The following changelog text was pasted in from a commit message; it is
# preserved here as comments so the module remains importable.
#
# D2: TrOCR ONNX export script (printed + handwritten, int8 quantization)
# D3: PP-DocLayout ONNX export script (download or Docker-based conversion)
# B3: Model Management admin page (PyTorch vs ONNX status, benchmarks, config)
# A4: TrOCR ONNX service with runtime routing (auto/pytorch/onnx via TROCR_BACKEND)
# A5: PP-DocLayout ONNX detection with OpenCV fallback (via GRAPHIC_DETECT_BACKEND)
# B4: Structure Detection UI toggle (OpenCV vs PP-DocLayout) with class color coding
# C3: TrOCR-ONNX.md documentation
# C4: OCR-Pipeline.md ONNX section added
# C5: mkdocs.yml nav updated, optimum added to requirements.txt
# Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
# (File-stat residue "731 lines / 23 KiB / Python" removed.)
"""
|
|
TrOCR Service
|
|
|
|
Microsoft's Transformer-based OCR for text recognition.
|
|
Besonders geeignet fuer:
|
|
- Gedruckten Text
|
|
- Saubere Scans
|
|
- Schnelle Verarbeitung
|
|
|
|
Model: microsoft/trocr-base-printed (oder handwritten Variante)
|
|
|
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
|
|
Phase 2 Enhancements:
|
|
- Batch processing for multiple images
|
|
- SHA256-based caching for repeated requests
|
|
- Model preloading for faster first request
|
|
- Word-level bounding boxes with confidence
|
|
"""
|
|
|
|
import io
|
|
import os
|
|
import hashlib
|
|
import logging
|
|
import time
|
|
import asyncio
|
|
from typing import Tuple, Optional, List, Dict, Any
|
|
from dataclasses import dataclass, field
|
|
from collections import OrderedDict
|
|
from datetime import datetime, timedelta
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Backend routing: auto | pytorch | onnx
|
|
# ---------------------------------------------------------------------------
|
|
_trocr_backend = os.environ.get("TROCR_BACKEND", "auto") # auto | pytorch | onnx
|
|
|
|
# Lazy loading for heavy dependencies
|
|
# Cache keyed by model_name to support base and large variants simultaneously
|
|
_trocr_models: dict = {} # {model_name: (processor, model)}
|
|
_trocr_processor = None # backwards-compat alias → base-printed
|
|
_trocr_model = None # backwards-compat alias → base-printed
|
|
_trocr_available = None
|
|
_model_loaded_at = None
|
|
|
|
# Simple in-memory cache with LRU eviction
|
|
_ocr_cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
|
|
_cache_max_size = 100
|
|
_cache_ttl_seconds = 3600 # 1 hour
|
|
|
|
|
|
@dataclass
class OCRResult:
    """Enhanced OCR result with detailed information."""

    # Recognized text; may contain newlines when line splitting was used.
    text: str
    # Overall confidence in [0.0, 1.0] (heuristic in the PyTorch path).
    confidence: float
    # Wall-clock processing time; 0 for cache hits.
    processing_time_ms: int
    # Model identifier used (e.g. "trocr-base-printed", or "error" for batch failures).
    model: str
    # Whether a LoRA adapter was active (currently always False in this module).
    has_lora_adapter: bool = False
    # Per-character confidences, same length as ``text`` (simulated in the PyTorch path).
    char_confidences: List[float] = field(default_factory=list)
    # Per-word entries: {"text", "confidence", "bbox": [x1, y1, x2, y2]}.
    word_boxes: List[Dict[str, Any]] = field(default_factory=list)
    # True when the result was served from the in-memory cache.
    from_cache: bool = False
    # Truncated SHA256 of the input image bytes.
    image_hash: str = ""
@dataclass
class BatchOCRResult:
    """Result for batch processing."""

    # Per-image results in input order (failed images get an error placeholder).
    results: List[OCRResult]
    # Total wall-clock time for the whole batch.
    total_time_ms: int
    # Number of images submitted (successes and failures alike).
    processed_count: int
    # How many results were served from cache.
    cached_count: int
    # How many images raised during OCR.
    error_count: int
def _compute_image_hash(image_data: bytes) -> str:
    """Return a short (16 hex chars) SHA256 digest of the image bytes.

    Used as the cache key / result identifier for repeated OCR requests.
    """
    digest = hashlib.sha256(image_data)
    return digest.hexdigest()[:16]
def _cache_get(image_hash: str) -> Optional[Dict[str, Any]]:
    """Return the cached OCR result for *image_hash*, or None if absent or expired."""
    entry = _ocr_cache.get(image_hash)
    if entry is None:
        return None

    age = datetime.now() - entry["cached_at"]
    if age >= timedelta(seconds=_cache_ttl_seconds):
        # Entry outlived its TTL — drop it and report a miss.
        del _ocr_cache[image_hash]
        return None

    # Hit: refresh the LRU position before returning.
    _ocr_cache.move_to_end(image_hash)
    return entry["result"]
def _cache_set(image_hash: str, result: Dict[str, Any]) -> None:
    """Store an OCR result under *image_hash*, evicting LRU entries at capacity."""
    # Make room: pop the oldest entries until we are below the size limit.
    while len(_ocr_cache) >= _cache_max_size:
        _ocr_cache.popitem(last=False)

    entry = {"result": result, "cached_at": datetime.now()}
    _ocr_cache[image_hash] = entry
def get_cache_stats() -> Dict[str, Any]:
    """Report cache occupancy and configuration."""
    stats: Dict[str, Any] = {
        "size": len(_ocr_cache),
        "max_size": _cache_max_size,
        "ttl_seconds": _cache_ttl_seconds,
    }
    # Hit rate is not tracked yet; would require hit/miss counters.
    stats["hit_rate"] = 0
    return stats
def _check_trocr_available() -> bool:
    """Probe (once) whether torch + transformers are importable; caches the answer."""
    global _trocr_available

    if _trocr_available is None:
        try:
            import torch  # noqa: F401 — availability probe only
            from transformers import TrOCRProcessor, VisionEncoderDecoderModel  # noqa: F401
        except ImportError as e:
            logger.warning(f"TrOCR dependencies not available: {e}")
            _trocr_available = False
        else:
            _trocr_available = True

    return _trocr_available
def get_trocr_model(handwritten: bool = False, size: str = "base"):
    """
    Lazy load TrOCR model and processor.

    Args:
        handwritten: Use handwritten model instead of printed model
        size: Model size — "base" (300 MB) or "large" (340 MB, higher accuracy
            for exam HTR). Only applies to handwritten variant; for the
            printed model "base" is always used.

    Returns:
        Tuple of (processor, model) or (None, None) if unavailable.
    """
    global _trocr_processor, _trocr_model

    # Bail out early if torch/transformers are not installed.
    if not _check_trocr_available():
        return None, None

    # Select model name; "large" is only offered for the handwritten variant.
    if size == "large" and handwritten:
        model_name = "microsoft/trocr-large-handwritten"
    elif handwritten:
        model_name = "microsoft/trocr-base-handwritten"
    else:
        model_name = "microsoft/trocr-base-printed"

    # Serve from the per-model cache when already loaded.
    if model_name in _trocr_models:
        return _trocr_models[model_name]

    try:
        import torch
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel

        logger.info(f"Loading TrOCR model: {model_name}")
        processor = TrOCRProcessor.from_pretrained(model_name)
        model = VisionEncoderDecoderModel.from_pretrained(model_name)

        # Use GPU if available: CUDA first, then Apple MPS, else CPU.
        device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        model.to(device)
        logger.info(f"TrOCR model loaded on device: {device}")

        _trocr_models[model_name] = (processor, model)

        # Keep backwards-compat globals pointing at base-printed
        if model_name == "microsoft/trocr-base-printed":
            _trocr_processor = processor
            _trocr_model = model

        return processor, model

    except Exception as e:
        # Download/disk/loading failures degrade to "unavailable" rather than raising.
        logger.error(f"Failed to load TrOCR model {model_name}: {e}")
        return None, None
def preload_trocr_model(handwritten: bool = True) -> bool:
    """
    Preload TrOCR model at startup for faster first request.

    Call this from your FastAPI startup event:
        @app.on_event("startup")
        async def startup():
            preload_trocr_model()

    Returns True when both processor and model loaded successfully.
    """
    global _model_loaded_at

    logger.info("Preloading TrOCR model...")
    processor, model = get_trocr_model(handwritten=handwritten)
    loaded = processor is not None and model is not None

    if loaded:
        _model_loaded_at = datetime.now()
        logger.info("TrOCR model preloaded successfully")
    else:
        logger.warning("TrOCR model preloading failed")

    return loaded
def get_model_status() -> Dict[str, Any]:
    """Get current model status information.

    NOTE(review): this lazily loads the handwritten model as a side effect
    when it is not already cached — a "status" call is not free.
    """
    processor, model = get_trocr_model(handwritten=True)
    is_loaded = processor is not None and model is not None

    status = {
        "status": "available" if is_loaded else "not_installed",
        "is_loaded": is_loaded,
        "model_name": "trocr-base-handwritten" if is_loaded else None,
        # loaded_at is only set by preload_trocr_model(), not by lazy loads.
        "loaded_at": _model_loaded_at.isoformat() if _model_loaded_at else None,
    }

    if is_loaded:
        import torch  # noqa: F401 — imported but unused here; device comes from the model
        device = next(model.parameters()).device
        status["device"] = str(device)

    return status
def get_active_backend() -> str:
    """Return which TrOCR backend is configured: 'auto', 'pytorch', or 'onnx'.

    Reflects the TROCR_BACKEND environment variable captured at import time.
    """
    backend = _trocr_backend
    return backend
def _try_onnx_ocr(
    image_data: bytes,
    handwritten: bool = False,
    split_lines: bool = True,
) -> Optional[Any]:
    """
    Check whether ONNX inference is possible and, if so, return the async
    ``run_trocr_onnx`` coroutine function as a sentinel for the caller to
    invoke and await.

    FIX: the previous return annotation claimed ``Optional[Tuple[Optional[str],
    float]]`` but the function actually returns a callable (or None) — the
    (text, confidence) tuple is produced by awaiting that callable.

    Args:
        image_data: Accepted for signature parity with the caller; not used
            here — the caller passes it to the returned coroutine function.
        handwritten: Which model variant to check ONNX availability for.
        split_lines: Accepted for signature parity; not used here.

    Returns:
        The ``run_trocr_onnx`` coroutine function, or None when the ONNX
        service module (or onnxruntime/optimum) is not installed or the
        model files for the requested variant are missing.
    """
    try:
        from .trocr_onnx_service import is_onnx_available, run_trocr_onnx
    except ImportError:
        # ONNX service module or its dependencies are not installed.
        return None

    if not is_onnx_available(handwritten=handwritten):
        return None

    # Sentinel: the caller checks callable() and awaits the coroutine.
    return run_trocr_onnx
async def _run_pytorch_ocr(
    image_data: bytes,
    handwritten: bool = False,
    split_lines: bool = True,
    size: str = "base",
) -> Tuple[Optional[str], float]:
    """
    Original PyTorch inference path (extracted for routing).

    Args:
        image_data: Raw image bytes in any format PIL can open.
        handwritten: Select the handwritten model variant.
        split_lines: Run projection-based line segmentation before recognition.
        size: "base" or "large" (only meaningful for the handwritten variant).

    Returns:
        (text, confidence) — text is None on failure; confidence is a
        heuristic (0.85 / 0.5 per line averaged), not a model probability.
    """
    processor, model = get_trocr_model(handwritten=handwritten, size=size)

    if processor is None or model is None:
        logger.error("TrOCR PyTorch model not available")
        return None, 0.0

    try:
        import torch
        from PIL import Image
        import numpy as np  # noqa: F401 — imported but unused in this path

        # Load image
        image = Image.open(io.BytesIO(image_data)).convert("RGB")

        # TrOCR is a single-line recognizer — segment first when requested,
        # falling back to the whole page if segmentation finds nothing.
        if split_lines:
            lines = _split_into_lines(image)
            if not lines:
                lines = [image]
        else:
            lines = [image]

        all_text = []
        confidences = []

        for line_image in lines:
            pixel_values = processor(images=line_image, return_tensors="pt").pixel_values

            # Run on whatever device the model was loaded to.
            device = next(model.parameters()).device
            pixel_values = pixel_values.to(device)

            with torch.no_grad():
                generated_ids = model.generate(pixel_values, max_length=128)

            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

            if generated_text.strip():
                all_text.append(generated_text.strip())
                # Heuristic: longer outputs are treated as more reliable.
                confidences.append(0.85 if len(generated_text) > 3 else 0.5)

        text = "\n".join(all_text)
        confidence = sum(confidences) / len(confidences) if confidences else 0.0

        logger.info(f"TrOCR (PyTorch) extracted {len(text)} characters from {len(lines)} lines")
        return text, confidence

    except Exception as e:
        # Any inference failure degrades to (None, 0.0) with a full traceback logged.
        logger.error(f"TrOCR PyTorch failed: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return None, 0.0
async def run_trocr_ocr(
    image_data: bytes,
    handwritten: bool = False,
    split_lines: bool = True,
    size: str = "base",
) -> Tuple[Optional[str], float]:
    """
    Run TrOCR on an image.

    Routes between ONNX and PyTorch backends based on the TROCR_BACKEND
    environment variable (default: "auto").

    - "onnx" — always use ONNX (raises RuntimeError if unavailable)
    - "pytorch" — always use PyTorch (original behaviour)
    - "auto" — try ONNX first, fall back to PyTorch

    TrOCR is optimized for single-line text recognition, so for full-page
    images we need to either:
    1. Split into lines first (using line detection)
    2. Process the whole image and get partial results

    Args:
        image_data: Raw image bytes
        handwritten: Use handwritten model (slower but better for handwriting)
        split_lines: Whether to split image into lines first
        size: "base" or "large" (only for handwritten variant).
            NOTE(review): the ONNX path does not receive ``size`` — only the
            PyTorch path honours it.

    Returns:
        Tuple of (extracted_text, confidence)

    Raises:
        RuntimeError: when TROCR_BACKEND=onnx but ONNX is unavailable.
    """
    backend = _trocr_backend

    # --- ONNX-only mode ---
    if backend == "onnx":
        # _try_onnx_ocr returns the async ONNX entry point (or None).
        onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
        if onnx_fn is None or not callable(onnx_fn):
            raise RuntimeError(
                "ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
                "Ensure onnxruntime + optimum are installed and ONNX model files exist."
            )
        return await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)

    # --- PyTorch-only mode ---
    if backend == "pytorch":
        return await _run_pytorch_ocr(
            image_data, handwritten=handwritten, split_lines=split_lines, size=size,
        )

    # --- Auto mode: try ONNX first, then PyTorch ---
    onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
    if onnx_fn is not None and callable(onnx_fn):
        try:
            result = await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)
            # Fall back when ONNX produced no text at all.
            if result[0] is not None:
                return result
            logger.warning("ONNX returned None text, falling back to PyTorch")
        except Exception as e:
            logger.warning(f"ONNX inference failed ({e}), falling back to PyTorch")

    # Auto-mode fallback: PyTorch.
    return await _run_pytorch_ocr(
        image_data, handwritten=handwritten, split_lines=split_lines, size=size,
    )
def _split_into_lines(image) -> list:
    """
    Split an image into text lines using simple projection-based segmentation.

    This is a basic implementation - for production use, consider using
    a dedicated line detection model.

    FIX: removed the unused function-local ``from PIL import Image`` — the
    function only uses the *image* argument's own methods plus numpy.

    Args:
        image: PIL Image (anything exposing .convert/.crop/.width/.height).

    Returns:
        List of cropped line images; empty list when nothing was detected
        or segmentation failed.
    """
    import numpy as np

    try:
        # Convert to grayscale
        gray = image.convert('L')
        img_array = np.array(gray)

        # Binarize (simple fixed threshold): True = "ink" pixel
        threshold = 200
        binary = img_array < threshold

        # Horizontal projection (count of dark pixels per row)
        h_proj = np.sum(binary, axis=1)

        # A row belongs to a text line when >2% of its pixels are dark.
        line_threshold = img_array.shape[1] * 0.02  # 2% of width
        in_line = False
        line_start = 0
        lines = []

        for i, val in enumerate(h_proj):
            if val > line_threshold and not in_line:
                # Start of line
                in_line = True
                line_start = i
            elif val <= line_threshold and in_line:
                # End of line
                in_line = False
                # Add 5px vertical padding around the detected band
                start = max(0, line_start - 5)
                end = min(img_array.shape[0], i + 5)
                if end - start > 10:  # Minimum line height
                    lines.append(image.crop((0, start, image.width, end)))

        # Handle last line if text runs to the bottom edge
        if in_line:
            start = max(0, line_start - 5)
            lines.append(image.crop((0, start, image.width, image.height)))

        logger.info(f"Split image into {len(lines)} lines")
        return lines

    except Exception as e:
        # Best-effort: caller falls back to whole-page OCR on [].
        logger.warning(f"Line splitting failed: {e}")
        return []
def _try_onnx_enhanced(
    handwritten: bool = True,
):
    """Return the ONNX 'enhanced' coroutine function, or None if unavailable."""
    try:
        from .trocr_onnx_service import is_onnx_available, run_trocr_onnx_enhanced

        if is_onnx_available(handwritten=handwritten):
            return run_trocr_onnx_enhanced
        return None
    except ImportError:
        # ONNX service module / dependencies not installed.
        return None
async def run_trocr_ocr_enhanced(
    image_data: bytes,
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True
) -> OCRResult:
    """
    Enhanced TrOCR OCR with caching and detailed results.

    Routes between ONNX and PyTorch backends based on the TROCR_BACKEND
    environment variable (default: "auto").

    Args:
        image_data: Raw image bytes
        handwritten: Use handwritten model
        split_lines: Whether to split image into lines first
        use_cache: Whether to use caching (PyTorch path; the ONNX service
            receives the flag and handles caching itself)

    Returns:
        OCRResult with detailed information

    Raises:
        RuntimeError: when TROCR_BACKEND=onnx but ONNX is unavailable.
    """
    backend = _trocr_backend

    # --- ONNX-only mode ---
    if backend == "onnx":
        onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
        if onnx_fn is None:
            raise RuntimeError(
                "ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
                "Ensure onnxruntime + optimum are installed and ONNX model files exist."
            )
        return await onnx_fn(
            image_data, handwritten=handwritten,
            split_lines=split_lines, use_cache=use_cache,
        )

    # --- Auto mode: try ONNX first ---
    if backend == "auto":
        onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
        if onnx_fn is not None:
            try:
                result = await onnx_fn(
                    image_data, handwritten=handwritten,
                    split_lines=split_lines, use_cache=use_cache,
                )
                if result.text:
                    return result
                logger.warning("ONNX enhanced returned empty text, falling back to PyTorch")
            except Exception as e:
                logger.warning(f"ONNX enhanced failed ({e}), falling back to PyTorch")

    # --- PyTorch path (backend == "pytorch" or auto fallback) ---
    start_time = time.time()

    image_hash = _compute_image_hash(image_data)
    # BUGFIX: the cache key must distinguish the model variant and split mode.
    # Previously the raw image hash alone was the key, so a printed-model
    # result could be served for a handwritten request on the same image
    # (and split vs. unsplit runs could alias). OCRResult.image_hash stays
    # the raw hash for external identification.
    cache_key = f"{image_hash}|hw={int(handwritten)}|split={int(split_lines)}"
    if use_cache:
        cached = _cache_get(cache_key)
        if cached:
            # Cache hit: reconstruct the result; processing time reported as 0.
            return OCRResult(
                text=cached["text"],
                confidence=cached["confidence"],
                processing_time_ms=0,
                model=cached["model"],
                has_lora_adapter=cached.get("has_lora_adapter", False),
                char_confidences=cached.get("char_confidences", []),
                word_boxes=cached.get("word_boxes", []),
                from_cache=True,
                image_hash=image_hash
            )

    # Run OCR via PyTorch
    text, confidence = await _run_pytorch_ocr(image_data, handwritten=handwritten, split_lines=split_lines)

    processing_time_ms = int((time.time() - start_time) * 1000)

    # Word boxes with simulated confidences: hash-jittered around the overall
    # confidence. Real boxes would require a detection model.
    word_boxes = []
    if text:
        for word in text.split():
            word_conf = min(1.0, max(0.0, confidence + (hash(word) % 20 - 10) / 100))
            word_boxes.append({
                "text": word,
                "confidence": word_conf,
                "bbox": [0, 0, 0, 0]  # Would need actual bounding box detection
            })

    # Simulated per-character confidences (same jitter scheme).
    char_confidences = []
    if text:
        for char in text:
            char_conf = min(1.0, max(0.0, confidence + (hash(char) % 15 - 7) / 100))
            char_confidences.append(char_conf)

    result = OCRResult(
        text=text or "",
        confidence=confidence,
        processing_time_ms=processing_time_ms,
        model="trocr-base-handwritten" if handwritten else "trocr-base-printed",
        has_lora_adapter=False,  # Would check actual adapter status
        char_confidences=char_confidences,
        word_boxes=word_boxes,
        from_cache=False,
        image_hash=image_hash
    )

    # Cache only successful (non-empty) results.
    if use_cache and text:
        _cache_set(cache_key, {
            "text": result.text,
            "confidence": result.confidence,
            "model": result.model,
            "has_lora_adapter": result.has_lora_adapter,
            "char_confidences": result.char_confidences,
            "word_boxes": result.word_boxes
        })

    return result
async def run_trocr_batch(
    images: List[bytes],
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True,
    progress_callback: Optional[callable] = None
) -> BatchOCRResult:
    """
    Process multiple images in batch.

    Args:
        images: List of image data bytes
        handwritten: Use handwritten model
        split_lines: Whether to split images into lines
        use_cache: Whether to use caching
        progress_callback: Optional callback(current, total) for progress updates

    Returns:
        BatchOCRResult with all results
    """
    t0 = time.time()
    results: List[OCRResult] = []
    cached_count = 0
    error_count = 0
    total = len(images)

    for idx, image_data in enumerate(images):
        try:
            ocr = await run_trocr_ocr_enhanced(
                image_data,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache
            )
        except Exception as e:
            # Keep positional alignment: failed images get an error placeholder.
            logger.error(f"Batch OCR error for image {idx}: {e}")
            error_count += 1
            results.append(OCRResult(
                text=f"Error: {str(e)}",
                confidence=0.0,
                processing_time_ms=0,
                model="error",
                has_lora_adapter=False
            ))
            continue

        results.append(ocr)
        if ocr.from_cache:
            cached_count += 1

        # Progress is only reported for successfully processed images.
        if progress_callback:
            progress_callback(idx + 1, total)

    return BatchOCRResult(
        results=results,
        total_time_ms=int((time.time() - t0) * 1000),
        processed_count=total,
        cached_count=cached_count,
        error_count=error_count
    )
# Generator for SSE streaming during batch processing
async def run_trocr_batch_stream(
    images: List[bytes],
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True
):
    """
    Process images and yield progress updates for SSE streaming.

    Yields:
        dict with current progress and result
    """
    t0 = time.time()
    total = len(images)

    for idx, image_data in enumerate(images):
        current = idx + 1
        try:
            ocr = await run_trocr_ocr_enhanced(
                image_data,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache
            )
        except Exception as e:
            logger.error(f"Stream OCR error for image {idx}: {e}")
            yield {
                "type": "error",
                "current": current,
                "total": total,
                "error": str(e)
            }
            continue

        # Estimate remaining time from the running average per image.
        elapsed_ms = int((time.time() - t0) * 1000)
        per_image_ms = elapsed_ms / current
        remaining_ms = int(per_image_ms * (total - current))

        yield {
            "type": "progress",
            "current": current,
            "total": total,
            "progress_percent": (current / total) * 100,
            "elapsed_ms": elapsed_ms,
            "estimated_remaining_ms": remaining_ms,
            "result": {
                "text": ocr.text,
                "confidence": ocr.confidence,
                "processing_time_ms": ocr.processing_time_ms,
                "from_cache": ocr.from_cache
            }
        }

    # Final event after the whole batch.
    yield {
        "type": "complete",
        "total_time_ms": int((time.time() - t0) * 1000),
        "processed_count": total
    }
# Test function
async def test_trocr_ocr(image_path: str, handwritten: bool = False):
    """Test TrOCR on a local image file and print the result."""
    with open(image_path, "rb") as f:
        image_data = f.read()

    text, confidence = await run_trocr_ocr(image_data, handwritten=handwritten)

    mode = 'Handwritten' if handwritten else 'Printed'
    print(f"\n=== TrOCR Test ===")
    print(f"Mode: {mode}")
    print(f"Confidence: {confidence:.2f}")
    print(f"Text:\n{text}")

    return text, confidence
if __name__ == "__main__":
    import asyncio
    import sys

    # Flag arguments select the model; the first positional is the image path.
    handwritten = "--handwritten" in sys.argv
    positional = [a for a in sys.argv[1:] if not a.startswith("--")]

    if not positional:
        print("Usage: python trocr_service.py <image_path> [--handwritten]")
    else:
        asyncio.run(test_trocr_ocr(positional[0], handwritten=handwritten))