feat: Sprint 2 — TrOCR ONNX, PP-DocLayout, Model Management

D2: TrOCR ONNX export script (printed + handwritten, int8 quantization)
D3: PP-DocLayout ONNX export script (download or Docker-based conversion)
B3: Model Management admin page (PyTorch vs ONNX status, benchmarks, config)
A4: TrOCR ONNX service with runtime routing (auto/pytorch/onnx via TROCR_BACKEND)
A5: PP-DocLayout ONNX detection with OpenCV fallback (via GRAPHIC_DETECT_BACKEND)
B4: Structure Detection UI toggle (OpenCV vs PP-DocLayout) with class color coding
C3: TrOCR-ONNX.md documentation
C4: OCR-Pipeline.md ONNX section added
C5: mkdocs.yml nav updated, optimum added to requirements.txt

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-03-23 09:53:02 +01:00
parent c695b659fb
commit be7f5f1872
16 changed files with 3616 additions and 60 deletions

View File

@@ -19,6 +19,7 @@ Phase 2 Enhancements:
"""
import io
import os
import hashlib
import logging
import time
@@ -30,6 +31,11 @@ from datetime import datetime, timedelta
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Backend routing: auto | pytorch | onnx
# ---------------------------------------------------------------------------
_trocr_backend = os.environ.get("TROCR_BACKEND", "auto") # auto | pytorch | onnx
# Lazy loading for heavy dependencies
# Cache keyed by model_name to support base and large variants simultaneously
_trocr_models: dict = {} # {model_name: (processor, model)}
@@ -221,6 +227,97 @@ def get_model_status() -> Dict[str, Any]:
return status
def get_active_backend() -> str:
"""
Return which TrOCR backend is configured.
Possible values: "auto", "pytorch", "onnx".
"""
return _trocr_backend
def _try_onnx_ocr(
image_data: bytes,
handwritten: bool = False,
split_lines: bool = True,
) -> Optional[Tuple[Optional[str], float]]:
"""
Attempt ONNX inference. Returns the (text, confidence) tuple on
success, or None if ONNX is not available / fails to load.
"""
try:
from .trocr_onnx_service import is_onnx_available, run_trocr_onnx
if not is_onnx_available(handwritten=handwritten):
return None
# run_trocr_onnx is async — return the coroutine's awaitable result
# The caller (run_trocr_ocr) will await it.
return run_trocr_onnx # sentinel: caller checks callable
except ImportError:
return None
async def _run_pytorch_ocr(
image_data: bytes,
handwritten: bool = False,
split_lines: bool = True,
size: str = "base",
) -> Tuple[Optional[str], float]:
"""
Original PyTorch inference path (extracted for routing).
"""
processor, model = get_trocr_model(handwritten=handwritten, size=size)
if processor is None or model is None:
logger.error("TrOCR PyTorch model not available")
return None, 0.0
try:
import torch
from PIL import Image
import numpy as np
# Load image
image = Image.open(io.BytesIO(image_data)).convert("RGB")
if split_lines:
lines = _split_into_lines(image)
if not lines:
lines = [image]
else:
lines = [image]
all_text = []
confidences = []
for line_image in lines:
pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
device = next(model.parameters()).device
pixel_values = pixel_values.to(device)
with torch.no_grad():
generated_ids = model.generate(pixel_values, max_length=128)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
if generated_text.strip():
all_text.append(generated_text.strip())
confidences.append(0.85 if len(generated_text) > 3 else 0.5)
text = "\n".join(all_text)
confidence = sum(confidences) / len(confidences) if confidences else 0.0
logger.info(f"TrOCR (PyTorch) extracted {len(text)} characters from {len(lines)} lines")
return text, confidence
except Exception as e:
logger.error(f"TrOCR PyTorch failed: {e}")
import traceback
logger.error(traceback.format_exc())
return None, 0.0
async def run_trocr_ocr(
image_data: bytes,
handwritten: bool = False,
@@ -230,6 +327,13 @@ async def run_trocr_ocr(
"""
Run TrOCR on an image.
Routes between ONNX and PyTorch backends based on the TROCR_BACKEND
environment variable (default: "auto").
- "onnx" — always use ONNX (raises RuntimeError if unavailable)
- "pytorch" — always use PyTorch (original behaviour)
- "auto" — try ONNX first, fall back to PyTorch
TrOCR is optimized for single-line text recognition, so for full-page
images we need to either:
1. Split into lines first (using line detection)
@@ -244,65 +348,38 @@ async def run_trocr_ocr(
Returns:
Tuple of (extracted_text, confidence)
"""
processor, model = get_trocr_model(handwritten=handwritten, size=size)
backend = _trocr_backend
if processor is None or model is None:
logger.error("TrOCR model not available")
return None, 0.0
# --- ONNX-only mode ---
if backend == "onnx":
onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
if onnx_fn is None or not callable(onnx_fn):
raise RuntimeError(
"ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
"Ensure onnxruntime + optimum are installed and ONNX model files exist."
)
return await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)
try:
import torch
from PIL import Image
import numpy as np
# --- PyTorch-only mode ---
if backend == "pytorch":
return await _run_pytorch_ocr(
image_data, handwritten=handwritten, split_lines=split_lines, size=size,
)
# Load image
image = Image.open(io.BytesIO(image_data)).convert("RGB")
# --- Auto mode: try ONNX first, then PyTorch ---
onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
if onnx_fn is not None and callable(onnx_fn):
try:
result = await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)
if result[0] is not None:
return result
logger.warning("ONNX returned None text, falling back to PyTorch")
except Exception as e:
logger.warning(f"ONNX inference failed ({e}), falling back to PyTorch")
if split_lines:
# Split image into lines and process each
lines = _split_into_lines(image)
if not lines:
lines = [image] # Fallback to full image
else:
lines = [image]
all_text = []
confidences = []
for line_image in lines:
# Prepare input
pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
# Move to same device as model
device = next(model.parameters()).device
pixel_values = pixel_values.to(device)
# Generate
with torch.no_grad():
generated_ids = model.generate(pixel_values, max_length=128)
# Decode
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
if generated_text.strip():
all_text.append(generated_text.strip())
# TrOCR doesn't provide confidence, estimate based on output
confidences.append(0.85 if len(generated_text) > 3 else 0.5)
# Combine results
text = "\n".join(all_text)
# Average confidence
confidence = sum(confidences) / len(confidences) if confidences else 0.0
logger.info(f"TrOCR extracted {len(text)} characters from {len(lines)} lines")
return text, confidence
except Exception as e:
logger.error(f"TrOCR failed: {e}")
import traceback
logger.error(traceback.format_exc())
return None, 0.0
return await _run_pytorch_ocr(
image_data, handwritten=handwritten, split_lines=split_lines, size=size,
)
def _split_into_lines(image) -> list:
@@ -360,6 +437,22 @@ def _split_into_lines(image) -> list:
return []
def _try_onnx_enhanced(
handwritten: bool = True,
):
"""
Return the ONNX enhanced coroutine function, or None if unavailable.
"""
try:
from .trocr_onnx_service import is_onnx_available, run_trocr_onnx_enhanced
if not is_onnx_available(handwritten=handwritten):
return None
return run_trocr_onnx_enhanced
except ImportError:
return None
async def run_trocr_ocr_enhanced(
image_data: bytes,
handwritten: bool = True,
@@ -369,6 +462,9 @@ async def run_trocr_ocr_enhanced(
"""
Enhanced TrOCR OCR with caching and detailed results.
Routes between ONNX and PyTorch backends based on the TROCR_BACKEND
environment variable (default: "auto").
Args:
image_data: Raw image bytes
handwritten: Use handwritten model
@@ -378,6 +474,37 @@ async def run_trocr_ocr_enhanced(
Returns:
OCRResult with detailed information
"""
backend = _trocr_backend
# --- ONNX-only mode ---
if backend == "onnx":
onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
if onnx_fn is None:
raise RuntimeError(
"ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
"Ensure onnxruntime + optimum are installed and ONNX model files exist."
)
return await onnx_fn(
image_data, handwritten=handwritten,
split_lines=split_lines, use_cache=use_cache,
)
# --- Auto mode: try ONNX first ---
if backend == "auto":
onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
if onnx_fn is not None:
try:
result = await onnx_fn(
image_data, handwritten=handwritten,
split_lines=split_lines, use_cache=use_cache,
)
if result.text:
return result
logger.warning("ONNX enhanced returned empty text, falling back to PyTorch")
except Exception as e:
logger.warning(f"ONNX enhanced failed ({e}), falling back to PyTorch")
# --- PyTorch path (backend == "pytorch" or auto fallback) ---
start_time = time.time()
# Check cache first
@@ -397,8 +524,8 @@ async def run_trocr_ocr_enhanced(
image_hash=image_hash
)
# Run OCR
text, confidence = await run_trocr_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
# Run OCR via PyTorch
text, confidence = await _run_pytorch_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
processing_time_ms = int((time.time() - start_time) * 1000)