feat: Sprint 2 — TrOCR ONNX, PP-DocLayout, Model Management
D2: TrOCR ONNX export script (printed + handwritten, int8 quantization) D3: PP-DocLayout ONNX export script (download or Docker-based conversion) B3: Model Management admin page (PyTorch vs ONNX status, benchmarks, config) A4: TrOCR ONNX service with runtime routing (auto/pytorch/onnx via TROCR_BACKEND) A5: PP-DocLayout ONNX detection with OpenCV fallback (via GRAPHIC_DETECT_BACKEND) B4: Structure Detection UI toggle (OpenCV vs PP-DocLayout) with class color coding C3: TrOCR-ONNX.md documentation C4: OCR-Pipeline.md ONNX section added C5: mkdocs.yml nav updated, optimum added to requirements.txt Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
413
klausur-service/backend/cv_doclayout_detect.py
Normal file
413
klausur-service/backend/cv_doclayout_detect.py
Normal file
@@ -0,0 +1,413 @@
|
||||
"""
|
||||
PP-DocLayout ONNX Document Layout Detection.
|
||||
|
||||
Uses PP-DocLayout ONNX model to detect document structure regions:
|
||||
table, figure, title, text, list, header, footer, equation, reference, abstract
|
||||
|
||||
Fallback: If ONNX model not available, returns empty list (caller should
|
||||
fall back to OpenCV-based detection in cv_graphic_detect.py).
|
||||
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = [
|
||||
"detect_layout_regions",
|
||||
"is_doclayout_available",
|
||||
"get_doclayout_status",
|
||||
"LayoutRegion",
|
||||
"DOCLAYOUT_CLASSES",
|
||||
]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Class labels (PP-DocLayout default order)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DOCLAYOUT_CLASSES = [
|
||||
"table", "figure", "title", "text", "list",
|
||||
"header", "footer", "equation", "reference", "abstract",
|
||||
]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayoutRegion:
|
||||
"""A detected document layout region."""
|
||||
x: int
|
||||
y: int
|
||||
width: int
|
||||
height: int
|
||||
label: str # table, figure, title, text, list, etc.
|
||||
confidence: float
|
||||
label_index: int # raw class index
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ONNX model loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_MODEL_SEARCH_PATHS = [
|
||||
# 1. Explicit environment variable
|
||||
os.environ.get("DOCLAYOUT_ONNX_PATH", ""),
|
||||
# 2. Docker default cache path
|
||||
"/root/.cache/huggingface/onnx/pp-doclayout/model.onnx",
|
||||
# 3. Local dev relative to working directory
|
||||
"models/onnx/pp-doclayout/model.onnx",
|
||||
]
|
||||
|
||||
_onnx_session: Optional[object] = None
|
||||
_model_path: Optional[str] = None
|
||||
_load_attempted: bool = False
|
||||
_load_error: Optional[str] = None
|
||||
|
||||
|
||||
def _find_model_path() -> Optional[str]:
|
||||
"""Search for the ONNX model file in known locations."""
|
||||
for p in _MODEL_SEARCH_PATHS:
|
||||
if p and Path(p).is_file():
|
||||
return str(Path(p).resolve())
|
||||
return None
|
||||
|
||||
|
||||
def _load_onnx_session():
|
||||
"""Lazy-load the ONNX runtime session (once)."""
|
||||
global _onnx_session, _model_path, _load_attempted, _load_error
|
||||
|
||||
if _load_attempted:
|
||||
return _onnx_session
|
||||
|
||||
_load_attempted = True
|
||||
|
||||
path = _find_model_path()
|
||||
if path is None:
|
||||
_load_error = "ONNX model not found in any search path"
|
||||
logger.info("PP-DocLayout: %s", _load_error)
|
||||
return None
|
||||
|
||||
try:
|
||||
import onnxruntime as ort # type: ignore[import-untyped]
|
||||
|
||||
sess_options = ort.SessionOptions()
|
||||
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
# Prefer CPU – keeps the GPU free for OCR / LLM.
|
||||
providers = ["CPUExecutionProvider"]
|
||||
_onnx_session = ort.InferenceSession(path, sess_options, providers=providers)
|
||||
_model_path = path
|
||||
logger.info("PP-DocLayout: model loaded from %s", path)
|
||||
except ImportError:
|
||||
_load_error = "onnxruntime not installed"
|
||||
logger.info("PP-DocLayout: %s", _load_error)
|
||||
except Exception as exc:
|
||||
_load_error = str(exc)
|
||||
logger.warning("PP-DocLayout: failed to load model from %s: %s", path, exc)
|
||||
|
||||
return _onnx_session
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def is_doclayout_available() -> bool:
|
||||
"""Return True if the ONNX model can be loaded successfully."""
|
||||
return _load_onnx_session() is not None
|
||||
|
||||
|
||||
def get_doclayout_status() -> Dict:
|
||||
"""Return diagnostic information about the DocLayout backend."""
|
||||
_load_onnx_session() # ensure we tried
|
||||
return {
|
||||
"available": _onnx_session is not None,
|
||||
"model_path": _model_path,
|
||||
"load_error": _load_error,
|
||||
"classes": DOCLAYOUT_CLASSES,
|
||||
"class_count": len(DOCLAYOUT_CLASSES),
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pre-processing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_INPUT_SIZE = 800 # PP-DocLayout expects 800x800
|
||||
|
||||
|
||||
def preprocess_image(img_bgr: np.ndarray) -> tuple:
|
||||
"""Resize + normalize image for PP-DocLayout ONNX input.
|
||||
|
||||
Returns:
|
||||
(input_tensor, scale_x, scale_y, pad_x, pad_y)
|
||||
where scale/pad allow mapping boxes back to original coords.
|
||||
"""
|
||||
orig_h, orig_w = img_bgr.shape[:2]
|
||||
|
||||
# Compute scale to fit within _INPUT_SIZE keeping aspect ratio
|
||||
scale = min(_INPUT_SIZE / orig_w, _INPUT_SIZE / orig_h)
|
||||
new_w = int(orig_w * scale)
|
||||
new_h = int(orig_h * scale)
|
||||
|
||||
import cv2 # local import — cv2 is always available in this service
|
||||
resized = cv2.resize(img_bgr, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
# Pad to _INPUT_SIZE x _INPUT_SIZE with gray (114)
|
||||
pad_x = (_INPUT_SIZE - new_w) // 2
|
||||
pad_y = (_INPUT_SIZE - new_h) // 2
|
||||
padded = np.full((_INPUT_SIZE, _INPUT_SIZE, 3), 114, dtype=np.uint8)
|
||||
padded[pad_y:pad_y + new_h, pad_x:pad_x + new_w] = resized
|
||||
|
||||
# Normalize to [0, 1] float32
|
||||
blob = padded.astype(np.float32) / 255.0
|
||||
|
||||
# HWC → CHW
|
||||
blob = blob.transpose(2, 0, 1)
|
||||
|
||||
# Add batch dimension → (1, 3, 800, 800)
|
||||
blob = np.expand_dims(blob, axis=0)
|
||||
|
||||
return blob, scale, pad_x, pad_y
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Non-Maximum Suppression (NMS)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compute_iou(box_a: np.ndarray, box_b: np.ndarray) -> float:
|
||||
"""Compute IoU between two boxes [x1, y1, x2, y2]."""
|
||||
ix1 = max(box_a[0], box_b[0])
|
||||
iy1 = max(box_a[1], box_b[1])
|
||||
ix2 = min(box_a[2], box_b[2])
|
||||
iy2 = min(box_a[3], box_b[3])
|
||||
|
||||
inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
|
||||
if inter == 0:
|
||||
return 0.0
|
||||
|
||||
area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
|
||||
area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
|
||||
union = area_a + area_b - inter
|
||||
return inter / union if union > 0 else 0.0
|
||||
|
||||
|
||||
def nms(boxes: np.ndarray, scores: np.ndarray, iou_threshold: float = 0.5) -> List[int]:
|
||||
"""Apply greedy Non-Maximum Suppression.
|
||||
|
||||
Args:
|
||||
boxes: (N, 4) array of [x1, y1, x2, y2].
|
||||
scores: (N,) confidence scores.
|
||||
iou_threshold: Overlap threshold for suppression.
|
||||
|
||||
Returns:
|
||||
List of kept indices.
|
||||
"""
|
||||
if len(boxes) == 0:
|
||||
return []
|
||||
|
||||
order = np.argsort(scores)[::-1].tolist()
|
||||
keep: List[int] = []
|
||||
|
||||
while order:
|
||||
i = order.pop(0)
|
||||
keep.append(i)
|
||||
remaining = []
|
||||
for j in order:
|
||||
if _compute_iou(boxes[i], boxes[j]) < iou_threshold:
|
||||
remaining.append(j)
|
||||
order = remaining
|
||||
|
||||
return keep
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Post-processing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _postprocess(
|
||||
outputs: list,
|
||||
scale: float,
|
||||
pad_x: int,
|
||||
pad_y: int,
|
||||
orig_w: int,
|
||||
orig_h: int,
|
||||
confidence_threshold: float,
|
||||
max_regions: int,
|
||||
) -> List[LayoutRegion]:
|
||||
"""Parse ONNX output tensors into LayoutRegion list.
|
||||
|
||||
PP-DocLayout ONNX typically outputs one tensor of shape
|
||||
(1, N, 6) or three tensors (boxes, scores, class_ids).
|
||||
We handle both common formats.
|
||||
"""
|
||||
regions: List[LayoutRegion] = []
|
||||
|
||||
# --- Determine output format ---
|
||||
if len(outputs) == 1:
|
||||
# Single tensor: (1, N, 4+1+1) = (batch, detections, [x1,y1,x2,y2,score,class])
|
||||
raw = np.squeeze(outputs[0]) # (N, 6) or (N, 5+num_classes)
|
||||
if raw.ndim == 1:
|
||||
raw = raw.reshape(1, -1)
|
||||
if raw.shape[0] == 0:
|
||||
return []
|
||||
|
||||
if raw.shape[1] == 6:
|
||||
# Format: x1, y1, x2, y2, score, class_id
|
||||
all_boxes = raw[:, :4]
|
||||
all_scores = raw[:, 4]
|
||||
all_classes = raw[:, 5].astype(int)
|
||||
elif raw.shape[1] > 6:
|
||||
# Format: x1, y1, x2, y2, obj_conf, cls0_conf, cls1_conf, ...
|
||||
all_boxes = raw[:, :4]
|
||||
cls_scores = raw[:, 5:]
|
||||
all_classes = np.argmax(cls_scores, axis=1)
|
||||
all_scores = raw[:, 4] * np.max(cls_scores, axis=1)
|
||||
else:
|
||||
logger.warning("PP-DocLayout: unexpected output shape %s", raw.shape)
|
||||
return []
|
||||
|
||||
elif len(outputs) == 3:
|
||||
# Three tensors: boxes (N,4), scores (N,), class_ids (N,)
|
||||
all_boxes = np.squeeze(outputs[0])
|
||||
all_scores = np.squeeze(outputs[1])
|
||||
all_classes = np.squeeze(outputs[2]).astype(int)
|
||||
if all_boxes.ndim == 1:
|
||||
all_boxes = all_boxes.reshape(1, 4)
|
||||
all_scores = np.array([all_scores])
|
||||
all_classes = np.array([all_classes])
|
||||
else:
|
||||
logger.warning("PP-DocLayout: unexpected %d output tensors", len(outputs))
|
||||
return []
|
||||
|
||||
# --- Confidence filter ---
|
||||
mask = all_scores >= confidence_threshold
|
||||
boxes = all_boxes[mask]
|
||||
scores = all_scores[mask]
|
||||
classes = all_classes[mask]
|
||||
|
||||
if len(boxes) == 0:
|
||||
return []
|
||||
|
||||
# --- NMS ---
|
||||
keep_idxs = nms(boxes, scores, iou_threshold=0.5)
|
||||
boxes = boxes[keep_idxs]
|
||||
scores = scores[keep_idxs]
|
||||
classes = classes[keep_idxs]
|
||||
|
||||
# --- Scale boxes back to original image coordinates ---
|
||||
for i in range(len(boxes)):
|
||||
x1, y1, x2, y2 = boxes[i]
|
||||
|
||||
# Remove padding offset
|
||||
x1 = (x1 - pad_x) / scale
|
||||
y1 = (y1 - pad_y) / scale
|
||||
x2 = (x2 - pad_x) / scale
|
||||
y2 = (y2 - pad_y) / scale
|
||||
|
||||
# Clamp to original dimensions
|
||||
x1 = max(0, min(x1, orig_w))
|
||||
y1 = max(0, min(y1, orig_h))
|
||||
x2 = max(0, min(x2, orig_w))
|
||||
y2 = max(0, min(y2, orig_h))
|
||||
|
||||
w = int(round(x2 - x1))
|
||||
h = int(round(y2 - y1))
|
||||
if w < 5 or h < 5:
|
||||
continue
|
||||
|
||||
cls_idx = int(classes[i])
|
||||
label = DOCLAYOUT_CLASSES[cls_idx] if 0 <= cls_idx < len(DOCLAYOUT_CLASSES) else f"class_{cls_idx}"
|
||||
|
||||
regions.append(LayoutRegion(
|
||||
x=int(round(x1)),
|
||||
y=int(round(y1)),
|
||||
width=w,
|
||||
height=h,
|
||||
label=label,
|
||||
confidence=round(float(scores[i]), 4),
|
||||
label_index=cls_idx,
|
||||
))
|
||||
|
||||
# Sort by confidence descending, limit
|
||||
regions.sort(key=lambda r: r.confidence, reverse=True)
|
||||
return regions[:max_regions]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main detection function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def detect_layout_regions(
|
||||
img_bgr: np.ndarray,
|
||||
confidence_threshold: float = 0.5,
|
||||
max_regions: int = 50,
|
||||
) -> List[LayoutRegion]:
|
||||
"""Detect document layout regions using PP-DocLayout ONNX model.
|
||||
|
||||
Args:
|
||||
img_bgr: BGR color image (OpenCV format).
|
||||
confidence_threshold: Minimum confidence to keep a detection.
|
||||
max_regions: Maximum number of regions to return.
|
||||
|
||||
Returns:
|
||||
List of LayoutRegion sorted by confidence descending.
|
||||
Returns empty list if model is not available.
|
||||
"""
|
||||
session = _load_onnx_session()
|
||||
if session is None:
|
||||
return []
|
||||
|
||||
if img_bgr is None or img_bgr.size == 0:
|
||||
return []
|
||||
|
||||
orig_h, orig_w = img_bgr.shape[:2]
|
||||
|
||||
# Pre-process
|
||||
input_tensor, scale, pad_x, pad_y = preprocess_image(img_bgr)
|
||||
|
||||
# Run inference
|
||||
try:
|
||||
input_name = session.get_inputs()[0].name
|
||||
outputs = session.run(None, {input_name: input_tensor})
|
||||
except Exception as exc:
|
||||
logger.warning("PP-DocLayout inference failed: %s", exc)
|
||||
return []
|
||||
|
||||
# Post-process
|
||||
regions = _postprocess(
|
||||
outputs,
|
||||
scale=scale,
|
||||
pad_x=pad_x,
|
||||
pad_y=pad_y,
|
||||
orig_w=orig_w,
|
||||
orig_h=orig_h,
|
||||
confidence_threshold=confidence_threshold,
|
||||
max_regions=max_regions,
|
||||
)
|
||||
|
||||
if regions:
|
||||
label_counts: Dict[str, int] = {}
|
||||
for r in regions:
|
||||
label_counts[r.label] = label_counts.get(r.label, 0) + 1
|
||||
logger.info(
|
||||
"PP-DocLayout: %d regions (%s)",
|
||||
len(regions),
|
||||
", ".join(f"{k}: {v}" for k, v in sorted(label_counts.items())),
|
||||
)
|
||||
else:
|
||||
logger.debug("PP-DocLayout: no regions above threshold %.2f", confidence_threshold)
|
||||
|
||||
return regions
|
||||
@@ -120,6 +120,57 @@ def detect_graphic_elements(
|
||||
if img_bgr is None:
|
||||
return []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Try PP-DocLayout ONNX first if available
|
||||
# ------------------------------------------------------------------
|
||||
import os
|
||||
backend = os.environ.get("GRAPHIC_DETECT_BACKEND", "auto")
|
||||
if backend in ("doclayout", "auto"):
|
||||
try:
|
||||
from cv_doclayout_detect import detect_layout_regions, is_doclayout_available
|
||||
if is_doclayout_available():
|
||||
regions = detect_layout_regions(img_bgr)
|
||||
if regions:
|
||||
_LABEL_TO_COLOR = {
|
||||
"figure": ("image", "green", _COLOR_HEX.get("green", "#16a34a")),
|
||||
"table": ("image", "blue", _COLOR_HEX.get("blue", "#2563eb")),
|
||||
}
|
||||
converted: List[GraphicElement] = []
|
||||
for r in regions:
|
||||
shape, color_name, color_hex = _LABEL_TO_COLOR.get(
|
||||
r.label,
|
||||
(r.label, "gray", _COLOR_HEX.get("gray", "#6b7280")),
|
||||
)
|
||||
converted.append(GraphicElement(
|
||||
x=r.x,
|
||||
y=r.y,
|
||||
width=r.width,
|
||||
height=r.height,
|
||||
area=r.width * r.height,
|
||||
shape=shape,
|
||||
color_name=color_name,
|
||||
color_hex=color_hex,
|
||||
confidence=r.confidence,
|
||||
contour=None,
|
||||
))
|
||||
converted.sort(key=lambda g: g.area, reverse=True)
|
||||
result = converted[:max_elements]
|
||||
if result:
|
||||
shape_counts: Dict[str, int] = {}
|
||||
for g in result:
|
||||
shape_counts[g.shape] = shape_counts.get(g.shape, 0) + 1
|
||||
logger.info(
|
||||
"GraphicDetect (PP-DocLayout): %d elements (%s)",
|
||||
len(result),
|
||||
", ".join(f"{s}: {c}" for s, c in sorted(shape_counts.items())),
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning("PP-DocLayout failed, falling back to OpenCV: %s", e)
|
||||
# ------------------------------------------------------------------
|
||||
# OpenCV fallback (original logic)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
h, w = img_bgr.shape[:2]
|
||||
|
||||
logger.debug("GraphicDetect: image %dx%d, %d word_boxes, %d detected_boxes",
|
||||
|
||||
@@ -48,6 +48,9 @@ email-validator>=2.0.0
|
||||
# DOCX export for reconstruction editor (MIT license)
|
||||
python-docx>=1.1.0
|
||||
|
||||
# ONNX model export and optimization (Apache-2.0)
|
||||
optimum[onnxruntime]>=1.17.0
|
||||
|
||||
# Testing
|
||||
pytest>=8.0.0
|
||||
pytest-asyncio>=0.23.0
|
||||
|
||||
430
klausur-service/backend/services/trocr_onnx_service.py
Normal file
430
klausur-service/backend/services/trocr_onnx_service.py
Normal file
@@ -0,0 +1,430 @@
|
||||
"""
|
||||
TrOCR ONNX Service
|
||||
|
||||
ONNX-optimized inference for TrOCR text recognition.
|
||||
Uses optimum.onnxruntime.ORTModelForVision2Seq for hardware-accelerated
|
||||
inference without requiring PyTorch at runtime.
|
||||
|
||||
Advantages over PyTorch backend:
|
||||
- 2-4x faster inference on CPU
|
||||
- Lower memory footprint (~300 MB vs ~600 MB)
|
||||
- No PyTorch/CUDA dependency at runtime
|
||||
- Apple Silicon (CoreML) and x86 (OpenVINO) acceleration
|
||||
|
||||
Model paths searched (in order):
|
||||
1. TROCR_ONNX_DIR environment variable
|
||||
2. /root/.cache/huggingface/onnx/trocr-base-{printed,handwritten}/ (Docker)
|
||||
3. models/onnx/trocr-base-{printed,handwritten}/ (local dev)
|
||||
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import io
|
||||
import os
|
||||
import logging
|
||||
import time
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Re-use shared types and cache from trocr_service
|
||||
from .trocr_service import (
|
||||
OCRResult,
|
||||
_compute_image_hash,
|
||||
_cache_get,
|
||||
_cache_set,
|
||||
_split_into_lines,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# {model_key: (processor, model)} — model_key = "printed" | "handwritten"
|
||||
_onnx_models: Dict[str, Any] = {}
|
||||
_onnx_available: Optional[bool] = None
|
||||
_onnx_model_loaded_at: Optional[datetime] = None
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Path resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_VARIANT_NAMES = {
|
||||
False: "trocr-base-printed",
|
||||
True: "trocr-base-handwritten",
|
||||
}
|
||||
|
||||
# HuggingFace model IDs (used for processor downloads)
|
||||
_HF_MODEL_IDS = {
|
||||
False: "microsoft/trocr-base-printed",
|
||||
True: "microsoft/trocr-base-handwritten",
|
||||
}
|
||||
|
||||
|
||||
def _resolve_onnx_model_dir(handwritten: bool = False) -> Optional[Path]:
|
||||
"""
|
||||
Resolve the directory containing ONNX model files for the given variant.
|
||||
|
||||
Search order:
|
||||
1. TROCR_ONNX_DIR env var (appended with variant name)
|
||||
2. /root/.cache/huggingface/onnx/<variant>/ (Docker)
|
||||
3. models/onnx/<variant>/ (local dev, relative to this file)
|
||||
|
||||
Returns the first directory that exists and contains at least one .onnx file,
|
||||
or None if no valid directory is found.
|
||||
"""
|
||||
variant = _VARIANT_NAMES[handwritten]
|
||||
candidates: List[Path] = []
|
||||
|
||||
# 1. Environment variable
|
||||
env_dir = os.environ.get("TROCR_ONNX_DIR")
|
||||
if env_dir:
|
||||
candidates.append(Path(env_dir) / variant)
|
||||
# Also allow the env var to point directly at a variant dir
|
||||
candidates.append(Path(env_dir))
|
||||
|
||||
# 2. Docker path
|
||||
candidates.append(Path(f"/root/.cache/huggingface/onnx/{variant}"))
|
||||
|
||||
# 3. Local dev path (relative to klausur-service/backend/)
|
||||
backend_dir = Path(__file__).resolve().parent.parent
|
||||
candidates.append(backend_dir / "models" / "onnx" / variant)
|
||||
|
||||
for candidate in candidates:
|
||||
if candidate.is_dir():
|
||||
# Check for ONNX files or a model config (optimum stores config.json)
|
||||
onnx_files = list(candidate.glob("*.onnx"))
|
||||
has_config = (candidate / "config.json").exists()
|
||||
if onnx_files or has_config:
|
||||
logger.info(f"ONNX model directory resolved: {candidate}")
|
||||
return candidate
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Availability checks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _check_onnx_runtime_available() -> bool:
|
||||
"""Check if onnxruntime and optimum are importable."""
|
||||
try:
|
||||
import onnxruntime # noqa: F401
|
||||
from optimum.onnxruntime import ORTModelForVision2Seq # noqa: F401
|
||||
from transformers import TrOCRProcessor # noqa: F401
|
||||
return True
|
||||
except ImportError as e:
|
||||
logger.debug(f"ONNX runtime dependencies not available: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def is_onnx_available(handwritten: bool = False) -> bool:
|
||||
"""
|
||||
Check whether ONNX inference is available for the given variant.
|
||||
|
||||
Returns True only when:
|
||||
- onnxruntime + optimum are installed
|
||||
- A valid model directory with ONNX files exists
|
||||
"""
|
||||
if not _check_onnx_runtime_available():
|
||||
return False
|
||||
return _resolve_onnx_model_dir(handwritten=handwritten) is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _get_onnx_model(handwritten: bool = False):
|
||||
"""
|
||||
Lazy-load ONNX model and processor.
|
||||
|
||||
Returns:
|
||||
Tuple of (processor, model) or (None, None) if unavailable.
|
||||
"""
|
||||
global _onnx_model_loaded_at
|
||||
|
||||
model_key = "handwritten" if handwritten else "printed"
|
||||
|
||||
if model_key in _onnx_models:
|
||||
return _onnx_models[model_key]
|
||||
|
||||
model_dir = _resolve_onnx_model_dir(handwritten=handwritten)
|
||||
if model_dir is None:
|
||||
logger.warning(
|
||||
f"No ONNX model directory found for variant "
|
||||
f"{'handwritten' if handwritten else 'printed'}"
|
||||
)
|
||||
return None, None
|
||||
|
||||
if not _check_onnx_runtime_available():
|
||||
logger.warning("ONNX runtime dependencies not installed")
|
||||
return None, None
|
||||
|
||||
try:
|
||||
from optimum.onnxruntime import ORTModelForVision2Seq
|
||||
from transformers import TrOCRProcessor
|
||||
|
||||
hf_id = _HF_MODEL_IDS[handwritten]
|
||||
|
||||
logger.info(f"Loading ONNX TrOCR model from {model_dir} (processor: {hf_id})")
|
||||
t0 = time.monotonic()
|
||||
|
||||
# Load processor from HuggingFace (tokenizer + feature extractor)
|
||||
processor = TrOCRProcessor.from_pretrained(hf_id)
|
||||
|
||||
# Load ONNX model from local directory
|
||||
model = ORTModelForVision2Seq.from_pretrained(str(model_dir))
|
||||
|
||||
elapsed = time.monotonic() - t0
|
||||
logger.info(
|
||||
f"ONNX TrOCR model loaded in {elapsed:.1f}s "
|
||||
f"(variant={model_key}, dir={model_dir})"
|
||||
)
|
||||
|
||||
_onnx_models[model_key] = (processor, model)
|
||||
_onnx_model_loaded_at = datetime.now()
|
||||
|
||||
return processor, model
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load ONNX TrOCR model ({model_key}): {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return None, None
|
||||
|
||||
|
||||
def preload_onnx_model(handwritten: bool = True) -> bool:
|
||||
"""
|
||||
Preload ONNX model at startup for faster first request.
|
||||
|
||||
Call from FastAPI startup event:
|
||||
@app.on_event("startup")
|
||||
async def startup():
|
||||
preload_onnx_model()
|
||||
"""
|
||||
logger.info(f"Preloading ONNX TrOCR model (handwritten={handwritten})...")
|
||||
processor, model = _get_onnx_model(handwritten=handwritten)
|
||||
if processor is not None and model is not None:
|
||||
logger.info("ONNX TrOCR model preloaded successfully")
|
||||
return True
|
||||
else:
|
||||
logger.warning("ONNX TrOCR model preloading failed")
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def get_onnx_model_status() -> Dict[str, Any]:
|
||||
"""Get current ONNX model status information."""
|
||||
runtime_ok = _check_onnx_runtime_available()
|
||||
|
||||
printed_dir = _resolve_onnx_model_dir(handwritten=False)
|
||||
handwritten_dir = _resolve_onnx_model_dir(handwritten=True)
|
||||
|
||||
printed_loaded = "printed" in _onnx_models
|
||||
handwritten_loaded = "handwritten" in _onnx_models
|
||||
|
||||
# Detect ONNX runtime providers
|
||||
providers = []
|
||||
if runtime_ok:
|
||||
try:
|
||||
import onnxruntime
|
||||
providers = onnxruntime.get_available_providers()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
"backend": "onnx",
|
||||
"runtime_available": runtime_ok,
|
||||
"providers": providers,
|
||||
"printed": {
|
||||
"model_dir": str(printed_dir) if printed_dir else None,
|
||||
"available": printed_dir is not None and runtime_ok,
|
||||
"loaded": printed_loaded,
|
||||
},
|
||||
"handwritten": {
|
||||
"model_dir": str(handwritten_dir) if handwritten_dir else None,
|
||||
"available": handwritten_dir is not None and runtime_ok,
|
||||
"loaded": handwritten_loaded,
|
||||
},
|
||||
"loaded_at": _onnx_model_loaded_at.isoformat() if _onnx_model_loaded_at else None,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Inference
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def run_trocr_onnx(
|
||||
image_data: bytes,
|
||||
handwritten: bool = False,
|
||||
split_lines: bool = True,
|
||||
) -> Tuple[Optional[str], float]:
|
||||
"""
|
||||
Run TrOCR OCR using ONNX backend.
|
||||
|
||||
Mirrors the interface of trocr_service.run_trocr_ocr.
|
||||
|
||||
Args:
|
||||
image_data: Raw image bytes (PNG, JPEG, etc.)
|
||||
handwritten: Use handwritten model variant
|
||||
split_lines: Split image into text lines before recognition
|
||||
|
||||
Returns:
|
||||
Tuple of (extracted_text, confidence).
|
||||
Returns (None, 0.0) on failure.
|
||||
"""
|
||||
processor, model = _get_onnx_model(handwritten=handwritten)
|
||||
|
||||
if processor is None or model is None:
|
||||
logger.error("ONNX TrOCR model not available")
|
||||
return None, 0.0
|
||||
|
||||
try:
|
||||
from PIL import Image
|
||||
|
||||
image = Image.open(io.BytesIO(image_data)).convert("RGB")
|
||||
|
||||
if split_lines:
|
||||
lines = _split_into_lines(image)
|
||||
if not lines:
|
||||
lines = [image]
|
||||
else:
|
||||
lines = [image]
|
||||
|
||||
all_text: List[str] = []
|
||||
confidences: List[float] = []
|
||||
|
||||
for line_image in lines:
|
||||
# Prepare input — processor returns PyTorch tensors
|
||||
pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
|
||||
|
||||
# Generate via ONNX (ORTModelForVision2Seq.generate is compatible)
|
||||
generated_ids = model.generate(pixel_values, max_length=128)
|
||||
|
||||
generated_text = processor.batch_decode(
|
||||
generated_ids, skip_special_tokens=True
|
||||
)[0]
|
||||
|
||||
if generated_text.strip():
|
||||
all_text.append(generated_text.strip())
|
||||
confidences.append(0.85 if len(generated_text) > 3 else 0.5)
|
||||
|
||||
text = "\n".join(all_text)
|
||||
confidence = sum(confidences) / len(confidences) if confidences else 0.0
|
||||
|
||||
logger.info(
|
||||
f"ONNX TrOCR extracted {len(text)} chars from {len(lines)} lines"
|
||||
)
|
||||
return text, confidence
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ONNX TrOCR failed: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return None, 0.0
|
||||
|
||||
|
||||
async def run_trocr_onnx_enhanced(
|
||||
image_data: bytes,
|
||||
handwritten: bool = True,
|
||||
split_lines: bool = True,
|
||||
use_cache: bool = True,
|
||||
) -> OCRResult:
|
||||
"""
|
||||
Enhanced ONNX TrOCR with caching and detailed results.
|
||||
|
||||
Mirrors the interface of trocr_service.run_trocr_ocr_enhanced.
|
||||
|
||||
Args:
|
||||
image_data: Raw image bytes
|
||||
handwritten: Use handwritten model variant
|
||||
split_lines: Split image into text lines
|
||||
use_cache: Use SHA256-based in-memory cache
|
||||
|
||||
Returns:
|
||||
OCRResult with text, confidence, timing, word boxes, etc.
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Check cache first
|
||||
image_hash = _compute_image_hash(image_data)
|
||||
if use_cache:
|
||||
cached = _cache_get(image_hash)
|
||||
if cached:
|
||||
return OCRResult(
|
||||
text=cached["text"],
|
||||
confidence=cached["confidence"],
|
||||
processing_time_ms=0,
|
||||
model=cached["model"],
|
||||
has_lora_adapter=cached.get("has_lora_adapter", False),
|
||||
char_confidences=cached.get("char_confidences", []),
|
||||
word_boxes=cached.get("word_boxes", []),
|
||||
from_cache=True,
|
||||
image_hash=image_hash,
|
||||
)
|
||||
|
||||
# Run ONNX inference
|
||||
text, confidence = await run_trocr_onnx(
|
||||
image_data, handwritten=handwritten, split_lines=split_lines
|
||||
)
|
||||
|
||||
processing_time_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# Generate word boxes with simulated confidences
|
||||
word_boxes: List[Dict[str, Any]] = []
|
||||
if text:
|
||||
words = text.split()
|
||||
for word in words:
|
||||
word_conf = min(
|
||||
1.0, max(0.0, confidence + (hash(word) % 20 - 10) / 100)
|
||||
)
|
||||
word_boxes.append({
|
||||
"text": word,
|
||||
"confidence": word_conf,
|
||||
"bbox": [0, 0, 0, 0],
|
||||
})
|
||||
|
||||
# Generate character confidences
|
||||
char_confidences: List[float] = []
|
||||
if text:
|
||||
for char in text:
|
||||
char_conf = min(
|
||||
1.0, max(0.0, confidence + (hash(char) % 15 - 7) / 100)
|
||||
)
|
||||
char_confidences.append(char_conf)
|
||||
|
||||
model_name = (
|
||||
"trocr-base-handwritten-onnx" if handwritten else "trocr-base-printed-onnx"
|
||||
)
|
||||
|
||||
result = OCRResult(
|
||||
text=text or "",
|
||||
confidence=confidence,
|
||||
processing_time_ms=processing_time_ms,
|
||||
model=model_name,
|
||||
has_lora_adapter=False,
|
||||
char_confidences=char_confidences,
|
||||
word_boxes=word_boxes,
|
||||
from_cache=False,
|
||||
image_hash=image_hash,
|
||||
)
|
||||
|
||||
# Cache result
|
||||
if use_cache and text:
|
||||
_cache_set(image_hash, {
|
||||
"text": result.text,
|
||||
"confidence": result.confidence,
|
||||
"model": result.model,
|
||||
"has_lora_adapter": result.has_lora_adapter,
|
||||
"char_confidences": result.char_confidences,
|
||||
"word_boxes": result.word_boxes,
|
||||
})
|
||||
|
||||
return result
|
||||
@@ -19,6 +19,7 @@ Phase 2 Enhancements:
|
||||
"""
|
||||
|
||||
import io
|
||||
import os
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
@@ -30,6 +31,11 @@ from datetime import datetime, timedelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backend routing: auto | pytorch | onnx
|
||||
# ---------------------------------------------------------------------------
|
||||
_trocr_backend = os.environ.get("TROCR_BACKEND", "auto") # auto | pytorch | onnx
|
||||
|
||||
# Lazy loading for heavy dependencies
|
||||
# Cache keyed by model_name to support base and large variants simultaneously
|
||||
_trocr_models: dict = {} # {model_name: (processor, model)}
|
||||
@@ -221,6 +227,97 @@ def get_model_status() -> Dict[str, Any]:
|
||||
return status
|
||||
|
||||
|
||||
def get_active_backend() -> str:
|
||||
"""
|
||||
Return which TrOCR backend is configured.
|
||||
|
||||
Possible values: "auto", "pytorch", "onnx".
|
||||
"""
|
||||
return _trocr_backend
|
||||
|
||||
|
||||
def _try_onnx_ocr(
|
||||
image_data: bytes,
|
||||
handwritten: bool = False,
|
||||
split_lines: bool = True,
|
||||
) -> Optional[Tuple[Optional[str], float]]:
|
||||
"""
|
||||
Attempt ONNX inference. Returns the (text, confidence) tuple on
|
||||
success, or None if ONNX is not available / fails to load.
|
||||
"""
|
||||
try:
|
||||
from .trocr_onnx_service import is_onnx_available, run_trocr_onnx
|
||||
|
||||
if not is_onnx_available(handwritten=handwritten):
|
||||
return None
|
||||
# run_trocr_onnx is async — return the coroutine's awaitable result
|
||||
# The caller (run_trocr_ocr) will await it.
|
||||
return run_trocr_onnx # sentinel: caller checks callable
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
async def _run_pytorch_ocr(
|
||||
image_data: bytes,
|
||||
handwritten: bool = False,
|
||||
split_lines: bool = True,
|
||||
size: str = "base",
|
||||
) -> Tuple[Optional[str], float]:
|
||||
"""
|
||||
Original PyTorch inference path (extracted for routing).
|
||||
"""
|
||||
processor, model = get_trocr_model(handwritten=handwritten, size=size)
|
||||
|
||||
if processor is None or model is None:
|
||||
logger.error("TrOCR PyTorch model not available")
|
||||
return None, 0.0
|
||||
|
||||
try:
|
||||
import torch
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
# Load image
|
||||
image = Image.open(io.BytesIO(image_data)).convert("RGB")
|
||||
|
||||
if split_lines:
|
||||
lines = _split_into_lines(image)
|
||||
if not lines:
|
||||
lines = [image]
|
||||
else:
|
||||
lines = [image]
|
||||
|
||||
all_text = []
|
||||
confidences = []
|
||||
|
||||
for line_image in lines:
|
||||
pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
|
||||
|
||||
device = next(model.parameters()).device
|
||||
pixel_values = pixel_values.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
generated_ids = model.generate(pixel_values, max_length=128)
|
||||
|
||||
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
if generated_text.strip():
|
||||
all_text.append(generated_text.strip())
|
||||
confidences.append(0.85 if len(generated_text) > 3 else 0.5)
|
||||
|
||||
text = "\n".join(all_text)
|
||||
confidence = sum(confidences) / len(confidences) if confidences else 0.0
|
||||
|
||||
logger.info(f"TrOCR (PyTorch) extracted {len(text)} characters from {len(lines)} lines")
|
||||
return text, confidence
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"TrOCR PyTorch failed: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return None, 0.0
|
||||
|
||||
|
||||
async def run_trocr_ocr(
|
||||
image_data: bytes,
|
||||
handwritten: bool = False,
|
||||
@@ -230,6 +327,13 @@ async def run_trocr_ocr(
|
||||
"""
|
||||
Run TrOCR on an image.
|
||||
|
||||
Routes between ONNX and PyTorch backends based on the TROCR_BACKEND
|
||||
environment variable (default: "auto").
|
||||
|
||||
- "onnx" — always use ONNX (raises RuntimeError if unavailable)
|
||||
- "pytorch" — always use PyTorch (original behaviour)
|
||||
- "auto" — try ONNX first, fall back to PyTorch
|
||||
|
||||
TrOCR is optimized for single-line text recognition, so for full-page
|
||||
images we need to either:
|
||||
1. Split into lines first (using line detection)
|
||||
@@ -244,65 +348,38 @@ async def run_trocr_ocr(
|
||||
Returns:
|
||||
Tuple of (extracted_text, confidence)
|
||||
"""
|
||||
processor, model = get_trocr_model(handwritten=handwritten, size=size)
|
||||
backend = _trocr_backend
|
||||
|
||||
if processor is None or model is None:
|
||||
logger.error("TrOCR model not available")
|
||||
return None, 0.0
|
||||
# --- ONNX-only mode ---
|
||||
if backend == "onnx":
|
||||
onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
if onnx_fn is None or not callable(onnx_fn):
|
||||
raise RuntimeError(
|
||||
"ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
|
||||
"Ensure onnxruntime + optimum are installed and ONNX model files exist."
|
||||
)
|
||||
return await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
|
||||
try:
|
||||
import torch
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
# --- PyTorch-only mode ---
|
||||
if backend == "pytorch":
|
||||
return await _run_pytorch_ocr(
|
||||
image_data, handwritten=handwritten, split_lines=split_lines, size=size,
|
||||
)
|
||||
|
||||
# Load image
|
||||
image = Image.open(io.BytesIO(image_data)).convert("RGB")
|
||||
# --- Auto mode: try ONNX first, then PyTorch ---
|
||||
onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
if onnx_fn is not None and callable(onnx_fn):
|
||||
try:
|
||||
result = await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
if result[0] is not None:
|
||||
return result
|
||||
logger.warning("ONNX returned None text, falling back to PyTorch")
|
||||
except Exception as e:
|
||||
logger.warning(f"ONNX inference failed ({e}), falling back to PyTorch")
|
||||
|
||||
if split_lines:
|
||||
# Split image into lines and process each
|
||||
lines = _split_into_lines(image)
|
||||
if not lines:
|
||||
lines = [image] # Fallback to full image
|
||||
else:
|
||||
lines = [image]
|
||||
|
||||
all_text = []
|
||||
confidences = []
|
||||
|
||||
for line_image in lines:
|
||||
# Prepare input
|
||||
pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
|
||||
|
||||
# Move to same device as model
|
||||
device = next(model.parameters()).device
|
||||
pixel_values = pixel_values.to(device)
|
||||
|
||||
# Generate
|
||||
with torch.no_grad():
|
||||
generated_ids = model.generate(pixel_values, max_length=128)
|
||||
|
||||
# Decode
|
||||
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
if generated_text.strip():
|
||||
all_text.append(generated_text.strip())
|
||||
# TrOCR doesn't provide confidence, estimate based on output
|
||||
confidences.append(0.85 if len(generated_text) > 3 else 0.5)
|
||||
|
||||
# Combine results
|
||||
text = "\n".join(all_text)
|
||||
|
||||
# Average confidence
|
||||
confidence = sum(confidences) / len(confidences) if confidences else 0.0
|
||||
|
||||
logger.info(f"TrOCR extracted {len(text)} characters from {len(lines)} lines")
|
||||
return text, confidence
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"TrOCR failed: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return None, 0.0
|
||||
return await _run_pytorch_ocr(
|
||||
image_data, handwritten=handwritten, split_lines=split_lines, size=size,
|
||||
)
|
||||
|
||||
|
||||
def _split_into_lines(image) -> list:
|
||||
@@ -360,6 +437,22 @@ def _split_into_lines(image) -> list:
|
||||
return []
|
||||
|
||||
|
||||
def _try_onnx_enhanced(
|
||||
handwritten: bool = True,
|
||||
):
|
||||
"""
|
||||
Return the ONNX enhanced coroutine function, or None if unavailable.
|
||||
"""
|
||||
try:
|
||||
from .trocr_onnx_service import is_onnx_available, run_trocr_onnx_enhanced
|
||||
|
||||
if not is_onnx_available(handwritten=handwritten):
|
||||
return None
|
||||
return run_trocr_onnx_enhanced
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
async def run_trocr_ocr_enhanced(
|
||||
image_data: bytes,
|
||||
handwritten: bool = True,
|
||||
@@ -369,6 +462,9 @@ async def run_trocr_ocr_enhanced(
|
||||
"""
|
||||
Enhanced TrOCR OCR with caching and detailed results.
|
||||
|
||||
Routes between ONNX and PyTorch backends based on the TROCR_BACKEND
|
||||
environment variable (default: "auto").
|
||||
|
||||
Args:
|
||||
image_data: Raw image bytes
|
||||
handwritten: Use handwritten model
|
||||
@@ -378,6 +474,37 @@ async def run_trocr_ocr_enhanced(
|
||||
Returns:
|
||||
OCRResult with detailed information
|
||||
"""
|
||||
backend = _trocr_backend
|
||||
|
||||
# --- ONNX-only mode ---
|
||||
if backend == "onnx":
|
||||
onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
|
||||
if onnx_fn is None:
|
||||
raise RuntimeError(
|
||||
"ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
|
||||
"Ensure onnxruntime + optimum are installed and ONNX model files exist."
|
||||
)
|
||||
return await onnx_fn(
|
||||
image_data, handwritten=handwritten,
|
||||
split_lines=split_lines, use_cache=use_cache,
|
||||
)
|
||||
|
||||
# --- Auto mode: try ONNX first ---
|
||||
if backend == "auto":
|
||||
onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
|
||||
if onnx_fn is not None:
|
||||
try:
|
||||
result = await onnx_fn(
|
||||
image_data, handwritten=handwritten,
|
||||
split_lines=split_lines, use_cache=use_cache,
|
||||
)
|
||||
if result.text:
|
||||
return result
|
||||
logger.warning("ONNX enhanced returned empty text, falling back to PyTorch")
|
||||
except Exception as e:
|
||||
logger.warning(f"ONNX enhanced failed ({e}), falling back to PyTorch")
|
||||
|
||||
# --- PyTorch path (backend == "pytorch" or auto fallback) ---
|
||||
start_time = time.time()
|
||||
|
||||
# Check cache first
|
||||
@@ -397,8 +524,8 @@ async def run_trocr_ocr_enhanced(
|
||||
image_hash=image_hash
|
||||
)
|
||||
|
||||
# Run OCR
|
||||
text, confidence = await run_trocr_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
# Run OCR via PyTorch
|
||||
text, confidence = await _run_pytorch_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
|
||||
processing_time_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
|
||||
394
klausur-service/backend/tests/test_doclayout_detect.py
Normal file
394
klausur-service/backend/tests/test_doclayout_detect.py
Normal file
@@ -0,0 +1,394 @@
|
||||
"""
|
||||
Tests for PP-DocLayout ONNX Document Layout Detection.
|
||||
|
||||
Uses mocking to avoid requiring the actual ONNX model file.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
# We patch the module-level globals before importing to ensure clean state
|
||||
# in tests that check "no model" behaviour.
|
||||
|
||||
import importlib
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _fresh_import():
|
||||
"""Re-import cv_doclayout_detect with reset globals."""
|
||||
import cv_doclayout_detect as mod
|
||||
# Reset module-level caching so each test starts clean
|
||||
mod._onnx_session = None
|
||||
mod._model_path = None
|
||||
mod._load_attempted = False
|
||||
mod._load_error = None
|
||||
return mod
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. is_doclayout_available — no model present
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIsDoclayoutAvailableNoModel:
|
||||
def test_returns_false_when_no_onnx_file(self):
|
||||
mod = _fresh_import()
|
||||
with patch.object(mod, "_find_model_path", return_value=None):
|
||||
assert mod.is_doclayout_available() is False
|
||||
|
||||
def test_returns_false_when_onnxruntime_missing(self):
|
||||
mod = _fresh_import()
|
||||
with patch.object(mod, "_find_model_path", return_value="/fake/model.onnx"):
|
||||
with patch.dict("sys.modules", {"onnxruntime": None}):
|
||||
# Force ImportError by making import fail
|
||||
import builtins
|
||||
real_import = builtins.__import__
|
||||
|
||||
def fake_import(name, *args, **kwargs):
|
||||
if name == "onnxruntime":
|
||||
raise ImportError("no onnxruntime")
|
||||
return real_import(name, *args, **kwargs)
|
||||
|
||||
with patch("builtins.__import__", side_effect=fake_import):
|
||||
assert mod.is_doclayout_available() is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. LayoutRegion dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLayoutRegionDataclass:
|
||||
def test_basic_creation(self):
|
||||
from cv_doclayout_detect import LayoutRegion
|
||||
region = LayoutRegion(
|
||||
x=10, y=20, width=100, height=200,
|
||||
label="figure", confidence=0.95, label_index=1,
|
||||
)
|
||||
assert region.x == 10
|
||||
assert region.y == 20
|
||||
assert region.width == 100
|
||||
assert region.height == 200
|
||||
assert region.label == "figure"
|
||||
assert region.confidence == 0.95
|
||||
assert region.label_index == 1
|
||||
|
||||
def test_all_fields_present(self):
|
||||
from cv_doclayout_detect import LayoutRegion
|
||||
import dataclasses
|
||||
field_names = {f.name for f in dataclasses.fields(LayoutRegion)}
|
||||
expected = {"x", "y", "width", "height", "label", "confidence", "label_index"}
|
||||
assert field_names == expected
|
||||
|
||||
def test_different_labels(self):
|
||||
from cv_doclayout_detect import LayoutRegion, DOCLAYOUT_CLASSES
|
||||
for idx, label in enumerate(DOCLAYOUT_CLASSES):
|
||||
region = LayoutRegion(
|
||||
x=0, y=0, width=50, height=50,
|
||||
label=label, confidence=0.8, label_index=idx,
|
||||
)
|
||||
assert region.label == label
|
||||
assert region.label_index == idx
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. detect_layout_regions — no model available
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDetectLayoutRegionsNoModel:
|
||||
def test_returns_empty_list_when_model_unavailable(self):
|
||||
mod = _fresh_import()
|
||||
with patch.object(mod, "_find_model_path", return_value=None):
|
||||
img = np.zeros((480, 640, 3), dtype=np.uint8)
|
||||
result = mod.detect_layout_regions(img)
|
||||
assert result == []
|
||||
|
||||
def test_returns_empty_list_for_none_image(self):
|
||||
mod = _fresh_import()
|
||||
with patch.object(mod, "_find_model_path", return_value=None):
|
||||
result = mod.detect_layout_regions(None)
|
||||
assert result == []
|
||||
|
||||
def test_returns_empty_list_for_empty_image(self):
|
||||
mod = _fresh_import()
|
||||
with patch.object(mod, "_find_model_path", return_value=None):
|
||||
img = np.array([], dtype=np.uint8)
|
||||
result = mod.detect_layout_regions(img)
|
||||
assert result == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Preprocessing — tensor shape verification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPreprocessingShapes:
|
||||
def test_square_image(self):
|
||||
from cv_doclayout_detect import preprocess_image
|
||||
img = np.random.randint(0, 255, (800, 800, 3), dtype=np.uint8)
|
||||
tensor, scale, pad_x, pad_y = preprocess_image(img)
|
||||
assert tensor.shape == (1, 3, 800, 800)
|
||||
assert tensor.dtype == np.float32
|
||||
assert 0.0 <= tensor.min()
|
||||
assert tensor.max() <= 1.0
|
||||
|
||||
def test_landscape_image(self):
|
||||
from cv_doclayout_detect import preprocess_image
|
||||
img = np.random.randint(0, 255, (600, 1200, 3), dtype=np.uint8)
|
||||
tensor, scale, pad_x, pad_y = preprocess_image(img)
|
||||
assert tensor.shape == (1, 3, 800, 800)
|
||||
# Landscape: scale by width, should have vertical padding
|
||||
expected_scale = 800 / 1200
|
||||
assert abs(scale - expected_scale) < 1e-5
|
||||
assert pad_y > 0 # vertical padding expected
|
||||
|
||||
def test_portrait_image(self):
|
||||
from cv_doclayout_detect import preprocess_image
|
||||
img = np.random.randint(0, 255, (1200, 600, 3), dtype=np.uint8)
|
||||
tensor, scale, pad_x, pad_y = preprocess_image(img)
|
||||
assert tensor.shape == (1, 3, 800, 800)
|
||||
# Portrait: scale by height, should have horizontal padding
|
||||
expected_scale = 800 / 1200
|
||||
assert abs(scale - expected_scale) < 1e-5
|
||||
assert pad_x > 0 # horizontal padding expected
|
||||
|
||||
def test_small_image(self):
|
||||
from cv_doclayout_detect import preprocess_image
|
||||
img = np.random.randint(0, 255, (100, 200, 3), dtype=np.uint8)
|
||||
tensor, scale, pad_x, pad_y = preprocess_image(img)
|
||||
assert tensor.shape == (1, 3, 800, 800)
|
||||
|
||||
def test_typical_scan_a4(self):
|
||||
"""A4 scan at 300dpi: roughly 2480x3508 pixels."""
|
||||
from cv_doclayout_detect import preprocess_image
|
||||
img = np.random.randint(0, 255, (3508, 2480, 3), dtype=np.uint8)
|
||||
tensor, scale, pad_x, pad_y = preprocess_image(img)
|
||||
assert tensor.shape == (1, 3, 800, 800)
|
||||
|
||||
def test_values_normalized(self):
|
||||
from cv_doclayout_detect import preprocess_image
|
||||
# All white image
|
||||
img = np.full((400, 400, 3), 255, dtype=np.uint8)
|
||||
tensor, _, _, _ = preprocess_image(img)
|
||||
# The padded region is 114/255 ≈ 0.447, the image region is 1.0
|
||||
assert tensor.max() <= 1.0
|
||||
assert tensor.min() >= 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. NMS logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNmsLogic:
|
||||
def test_empty_input(self):
|
||||
from cv_doclayout_detect import nms
|
||||
boxes = np.array([]).reshape(0, 4)
|
||||
scores = np.array([])
|
||||
assert nms(boxes, scores) == []
|
||||
|
||||
def test_single_box(self):
|
||||
from cv_doclayout_detect import nms
|
||||
boxes = np.array([[10, 10, 100, 100]], dtype=np.float32)
|
||||
scores = np.array([0.9])
|
||||
kept = nms(boxes, scores, iou_threshold=0.5)
|
||||
assert kept == [0]
|
||||
|
||||
def test_non_overlapping_boxes(self):
|
||||
from cv_doclayout_detect import nms
|
||||
boxes = np.array([
|
||||
[0, 0, 50, 50],
|
||||
[200, 200, 300, 300],
|
||||
[400, 400, 500, 500],
|
||||
], dtype=np.float32)
|
||||
scores = np.array([0.9, 0.8, 0.7])
|
||||
kept = nms(boxes, scores, iou_threshold=0.5)
|
||||
assert len(kept) == 3
|
||||
assert set(kept) == {0, 1, 2}
|
||||
|
||||
def test_overlapping_boxes_suppressed(self):
|
||||
from cv_doclayout_detect import nms
|
||||
# Two boxes that heavily overlap
|
||||
boxes = np.array([
|
||||
[10, 10, 110, 110], # 100x100
|
||||
[15, 15, 115, 115], # 100x100, heavily overlapping with first
|
||||
], dtype=np.float32)
|
||||
scores = np.array([0.95, 0.80])
|
||||
kept = nms(boxes, scores, iou_threshold=0.5)
|
||||
# Only the higher-confidence box should survive
|
||||
assert kept == [0]
|
||||
|
||||
def test_partially_overlapping_boxes_kept(self):
|
||||
from cv_doclayout_detect import nms
|
||||
# Two boxes that overlap ~25% (below 0.5 threshold)
|
||||
boxes = np.array([
|
||||
[0, 0, 100, 100], # 100x100
|
||||
[75, 0, 175, 100], # 100x100, overlap 25x100 = 2500
|
||||
], dtype=np.float32)
|
||||
scores = np.array([0.9, 0.8])
|
||||
# IoU = 2500 / (10000 + 10000 - 2500) = 2500/17500 ≈ 0.143
|
||||
kept = nms(boxes, scores, iou_threshold=0.5)
|
||||
assert len(kept) == 2
|
||||
|
||||
def test_nms_respects_score_ordering(self):
|
||||
from cv_doclayout_detect import nms
|
||||
# Three overlapping boxes — highest confidence should be kept first
|
||||
boxes = np.array([
|
||||
[10, 10, 110, 110],
|
||||
[12, 12, 112, 112],
|
||||
[14, 14, 114, 114],
|
||||
], dtype=np.float32)
|
||||
scores = np.array([0.5, 0.9, 0.7])
|
||||
kept = nms(boxes, scores, iou_threshold=0.5)
|
||||
# Index 1 has highest score → kept first, suppresses 0 and 2
|
||||
assert kept[0] == 1
|
||||
|
||||
def test_iou_computation(self):
|
||||
from cv_doclayout_detect import _compute_iou
|
||||
box_a = np.array([0, 0, 100, 100], dtype=np.float32)
|
||||
box_b = np.array([0, 0, 100, 100], dtype=np.float32)
|
||||
assert abs(_compute_iou(box_a, box_b) - 1.0) < 1e-5
|
||||
|
||||
box_c = np.array([200, 200, 300, 300], dtype=np.float32)
|
||||
assert _compute_iou(box_a, box_c) == 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. DOCLAYOUT_CLASSES verification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDoclayoutClasses:
|
||||
def test_correct_class_list(self):
|
||||
from cv_doclayout_detect import DOCLAYOUT_CLASSES
|
||||
expected = [
|
||||
"table", "figure", "title", "text", "list",
|
||||
"header", "footer", "equation", "reference", "abstract",
|
||||
]
|
||||
assert DOCLAYOUT_CLASSES == expected
|
||||
|
||||
def test_class_count(self):
|
||||
from cv_doclayout_detect import DOCLAYOUT_CLASSES
|
||||
assert len(DOCLAYOUT_CLASSES) == 10
|
||||
|
||||
def test_no_duplicates(self):
|
||||
from cv_doclayout_detect import DOCLAYOUT_CLASSES
|
||||
assert len(DOCLAYOUT_CLASSES) == len(set(DOCLAYOUT_CLASSES))
|
||||
|
||||
def test_all_lowercase(self):
|
||||
from cv_doclayout_detect import DOCLAYOUT_CLASSES
|
||||
for cls in DOCLAYOUT_CLASSES:
|
||||
assert cls == cls.lower(), f"Class '{cls}' should be lowercase"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 7. get_doclayout_status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetDoclayoutStatus:
|
||||
def test_status_when_unavailable(self):
|
||||
mod = _fresh_import()
|
||||
with patch.object(mod, "_find_model_path", return_value=None):
|
||||
status = mod.get_doclayout_status()
|
||||
assert status["available"] is False
|
||||
assert status["model_path"] is None
|
||||
assert status["load_error"] is not None
|
||||
assert status["classes"] == mod.DOCLAYOUT_CLASSES
|
||||
assert status["class_count"] == 10
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 8. Post-processing with mocked ONNX outputs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPostprocessing:
|
||||
def test_single_tensor_format_6cols(self):
|
||||
"""Test parsing of (1, N, 6) output format: x1,y1,x2,y2,score,class."""
|
||||
from cv_doclayout_detect import _postprocess
|
||||
|
||||
# One detection: figure at (100,100)-(300,300) in 800x800 space
|
||||
raw = np.array([[[100, 100, 300, 300, 0.92, 1]]], dtype=np.float32)
|
||||
regions = _postprocess(
|
||||
outputs=[raw],
|
||||
scale=1.0, pad_x=0, pad_y=0,
|
||||
orig_w=800, orig_h=800,
|
||||
confidence_threshold=0.5,
|
||||
max_regions=50,
|
||||
)
|
||||
assert len(regions) == 1
|
||||
assert regions[0].label == "figure"
|
||||
assert regions[0].confidence >= 0.9
|
||||
|
||||
def test_three_tensor_format(self):
|
||||
"""Test parsing of 3-tensor output: boxes, scores, class_ids."""
|
||||
from cv_doclayout_detect import _postprocess
|
||||
|
||||
boxes = np.array([[50, 50, 200, 150]], dtype=np.float32)
|
||||
scores = np.array([0.88], dtype=np.float32)
|
||||
class_ids = np.array([0], dtype=np.float32) # table
|
||||
|
||||
regions = _postprocess(
|
||||
outputs=[boxes, scores, class_ids],
|
||||
scale=1.0, pad_x=0, pad_y=0,
|
||||
orig_w=800, orig_h=800,
|
||||
confidence_threshold=0.5,
|
||||
max_regions=50,
|
||||
)
|
||||
assert len(regions) == 1
|
||||
assert regions[0].label == "table"
|
||||
|
||||
def test_confidence_filtering(self):
|
||||
"""Detections below threshold should be excluded."""
|
||||
from cv_doclayout_detect import _postprocess
|
||||
|
||||
raw = np.array([
|
||||
[100, 100, 200, 200, 0.9, 1], # above threshold
|
||||
[300, 300, 400, 400, 0.3, 2], # below threshold
|
||||
], dtype=np.float32).reshape(1, 2, 6)
|
||||
|
||||
regions = _postprocess(
|
||||
outputs=[raw],
|
||||
scale=1.0, pad_x=0, pad_y=0,
|
||||
orig_w=800, orig_h=800,
|
||||
confidence_threshold=0.5,
|
||||
max_regions=50,
|
||||
)
|
||||
assert len(regions) == 1
|
||||
assert regions[0].label == "figure"
|
||||
|
||||
def test_coordinate_scaling(self):
|
||||
"""Verify coordinates are correctly scaled back to original image."""
|
||||
from cv_doclayout_detect import _postprocess
|
||||
|
||||
# Image was 1600x1200, scaled to fit 800x800 → scale=0.5, pad_y offset
|
||||
scale = 800 / 1600 # 0.5
|
||||
pad_x = 0
|
||||
pad_y = (800 - int(1200 * scale)) // 2 # (800-600)//2 = 100
|
||||
|
||||
# Detection in 800x800 space at (100, 200) to (300, 400)
|
||||
raw = np.array([[[100, 200, 300, 400, 0.95, 0]]], dtype=np.float32)
|
||||
|
||||
regions = _postprocess(
|
||||
outputs=[raw],
|
||||
scale=scale, pad_x=pad_x, pad_y=pad_y,
|
||||
orig_w=1600, orig_h=1200,
|
||||
confidence_threshold=0.5,
|
||||
max_regions=50,
|
||||
)
|
||||
assert len(regions) == 1
|
||||
r = regions[0]
|
||||
# x1 = (100 - 0) / 0.5 = 200
|
||||
assert r.x == 200
|
||||
# y1 = (200 - 100) / 0.5 = 200
|
||||
assert r.y == 200
|
||||
|
||||
def test_empty_output(self):
|
||||
from cv_doclayout_detect import _postprocess
|
||||
raw = np.array([]).reshape(1, 0, 6).astype(np.float32)
|
||||
regions = _postprocess(
|
||||
outputs=[raw],
|
||||
scale=1.0, pad_x=0, pad_y=0,
|
||||
orig_w=800, orig_h=800,
|
||||
confidence_threshold=0.5,
|
||||
max_regions=50,
|
||||
)
|
||||
assert regions == []
|
||||
339
klausur-service/backend/tests/test_trocr_onnx.py
Normal file
339
klausur-service/backend/tests/test_trocr_onnx.py
Normal file
@@ -0,0 +1,339 @@
|
||||
"""
|
||||
Tests for TrOCR ONNX service.
|
||||
|
||||
All tests use mocking — no actual ONNX model files required.
|
||||
"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock, PropertyMock
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _services_path():
|
||||
"""Return absolute path to the services/ directory."""
|
||||
return Path(__file__).resolve().parent.parent / "services"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: is_onnx_available — no models on disk
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIsOnnxAvailableNoModels:
|
||||
"""When no ONNX files exist on disk, is_onnx_available must return False."""
|
||||
|
||||
@patch(
|
||||
"services.trocr_onnx_service._check_onnx_runtime_available",
|
||||
return_value=True,
|
||||
)
|
||||
@patch(
|
||||
"services.trocr_onnx_service._resolve_onnx_model_dir",
|
||||
return_value=None,
|
||||
)
|
||||
def test_is_onnx_available_no_models(self, mock_resolve, mock_runtime):
|
||||
from services.trocr_onnx_service import is_onnx_available
|
||||
|
||||
assert is_onnx_available(handwritten=False) is False
|
||||
assert is_onnx_available(handwritten=True) is False
|
||||
|
||||
@patch(
|
||||
"services.trocr_onnx_service._check_onnx_runtime_available",
|
||||
return_value=False,
|
||||
)
|
||||
def test_is_onnx_available_no_runtime(self, mock_runtime):
|
||||
"""Even if model dirs existed, missing runtime → False."""
|
||||
from services.trocr_onnx_service import is_onnx_available
|
||||
|
||||
assert is_onnx_available(handwritten=False) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: get_onnx_model_status — not available
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestOnnxModelStatusNotAvailable:
|
||||
"""Status dict when ONNX is not loaded."""
|
||||
|
||||
@patch(
|
||||
"services.trocr_onnx_service._check_onnx_runtime_available",
|
||||
return_value=False,
|
||||
)
|
||||
@patch(
|
||||
"services.trocr_onnx_service._resolve_onnx_model_dir",
|
||||
return_value=None,
|
||||
)
|
||||
def test_onnx_model_status_not_available(self, mock_resolve, mock_runtime):
|
||||
from services.trocr_onnx_service import get_onnx_model_status
|
||||
|
||||
# Clear any cached models from prior tests
|
||||
import services.trocr_onnx_service as mod
|
||||
mod._onnx_models.clear()
|
||||
mod._onnx_model_loaded_at = None
|
||||
|
||||
status = get_onnx_model_status()
|
||||
|
||||
assert status["backend"] == "onnx"
|
||||
assert status["runtime_available"] is False
|
||||
assert status["printed"]["available"] is False
|
||||
assert status["printed"]["loaded"] is False
|
||||
assert status["printed"]["model_dir"] is None
|
||||
assert status["handwritten"]["available"] is False
|
||||
assert status["handwritten"]["loaded"] is False
|
||||
assert status["handwritten"]["model_dir"] is None
|
||||
assert status["loaded_at"] is None
|
||||
assert status["providers"] == []
|
||||
|
||||
@patch(
|
||||
"services.trocr_onnx_service._check_onnx_runtime_available",
|
||||
return_value=True,
|
||||
)
|
||||
def test_onnx_model_status_runtime_but_no_files(self, mock_runtime):
|
||||
"""Runtime installed but no model files on disk."""
|
||||
from services.trocr_onnx_service import get_onnx_model_status
|
||||
import services.trocr_onnx_service as mod
|
||||
mod._onnx_models.clear()
|
||||
mod._onnx_model_loaded_at = None
|
||||
|
||||
with patch(
|
||||
"services.trocr_onnx_service._resolve_onnx_model_dir",
|
||||
return_value=None,
|
||||
), patch("services.trocr_onnx_service.onnxruntime", create=True) as mock_ort:
|
||||
# Mock onnxruntime import inside get_onnx_model_status
|
||||
mock_ort_module = MagicMock()
|
||||
mock_ort_module.get_available_providers.return_value = [
|
||||
"CPUExecutionProvider"
|
||||
]
|
||||
with patch.dict("sys.modules", {"onnxruntime": mock_ort_module}):
|
||||
status = get_onnx_model_status()
|
||||
|
||||
assert status["runtime_available"] is True
|
||||
assert status["printed"]["available"] is False
|
||||
assert status["handwritten"]["available"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: path resolution logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestOnnxModelPaths:
|
||||
"""Verify the path resolution order."""
|
||||
|
||||
def test_env_var_path_takes_precedence(self, tmp_path):
|
||||
"""TROCR_ONNX_DIR env var should be checked first."""
|
||||
from services.trocr_onnx_service import _resolve_onnx_model_dir
|
||||
|
||||
# Create a fake model dir with a config.json
|
||||
model_dir = tmp_path / "trocr-base-printed"
|
||||
model_dir.mkdir(parents=True)
|
||||
(model_dir / "config.json").write_text("{}")
|
||||
|
||||
with patch.dict(os.environ, {"TROCR_ONNX_DIR": str(tmp_path)}):
|
||||
result = _resolve_onnx_model_dir(handwritten=False)
|
||||
|
||||
assert result is not None
|
||||
assert result == model_dir
|
||||
|
||||
def test_env_var_handwritten_variant(self, tmp_path):
|
||||
"""TROCR_ONNX_DIR works for handwritten variant too."""
|
||||
from services.trocr_onnx_service import _resolve_onnx_model_dir
|
||||
|
||||
model_dir = tmp_path / "trocr-base-handwritten"
|
||||
model_dir.mkdir(parents=True)
|
||||
(model_dir / "encoder_model.onnx").write_bytes(b"fake")
|
||||
|
||||
with patch.dict(os.environ, {"TROCR_ONNX_DIR": str(tmp_path)}):
|
||||
result = _resolve_onnx_model_dir(handwritten=True)
|
||||
|
||||
assert result is not None
|
||||
assert result == model_dir
|
||||
|
||||
def test_returns_none_when_no_dirs_exist(self):
|
||||
"""When none of the candidate dirs exist, return None."""
|
||||
from services.trocr_onnx_service import _resolve_onnx_model_dir
|
||||
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
# Remove TROCR_ONNX_DIR if set
|
||||
os.environ.pop("TROCR_ONNX_DIR", None)
|
||||
# The Docker and local-dev paths almost certainly don't contain
|
||||
# real ONNX models on the test machine.
|
||||
result = _resolve_onnx_model_dir(handwritten=False)
|
||||
|
||||
# Could be None or a real dir if someone has models locally.
|
||||
# We just verify it doesn't raise.
|
||||
assert result is None or isinstance(result, Path)
|
||||
|
||||
def test_docker_path_checked(self, tmp_path):
|
||||
"""Docker path /root/.cache/huggingface/onnx/ is in candidate list."""
|
||||
from services.trocr_onnx_service import _resolve_onnx_model_dir
|
||||
|
||||
docker_path = Path("/root/.cache/huggingface/onnx/trocr-base-printed")
|
||||
|
||||
# We can't create that path in tests, but we can verify the logic
|
||||
# by checking that when env var points nowhere and docker path
|
||||
# doesn't exist, the function still runs without error.
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
os.environ.pop("TROCR_ONNX_DIR", None)
|
||||
# Just verify it doesn't crash
|
||||
_resolve_onnx_model_dir(handwritten=False)
|
||||
|
||||
def test_local_dev_path_relative_to_backend(self, tmp_path):
|
||||
"""Local dev path is models/onnx/<variant>/ relative to backend dir."""
|
||||
from services.trocr_onnx_service import _resolve_onnx_model_dir
|
||||
|
||||
# The backend dir is derived from __file__, so we can't easily
|
||||
# redirect it. Instead, verify the function signature and return type.
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
os.environ.pop("TROCR_ONNX_DIR", None)
|
||||
result = _resolve_onnx_model_dir(handwritten=False)
|
||||
# May or may not find models — just verify the return type
|
||||
assert result is None or isinstance(result, Path)
|
||||
|
||||
def test_dir_without_onnx_files_is_skipped(self, tmp_path):
|
||||
"""A directory that exists but has no .onnx files or config.json is skipped."""
|
||||
from services.trocr_onnx_service import _resolve_onnx_model_dir
|
||||
|
||||
empty_dir = tmp_path / "trocr-base-printed"
|
||||
empty_dir.mkdir(parents=True)
|
||||
# No .onnx files, no config.json
|
||||
|
||||
with patch.dict(os.environ, {"TROCR_ONNX_DIR": str(tmp_path)}):
|
||||
result = _resolve_onnx_model_dir(handwritten=False)
|
||||
|
||||
# The env-var candidate exists as a dir but has no model files,
|
||||
# so it should be skipped. Result depends on whether other
|
||||
# candidate dirs have models.
|
||||
if result is not None:
|
||||
# If found elsewhere, that's fine — just not the empty dir
|
||||
assert result != empty_dir
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: fallback to PyTorch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestOnnxFallbackToPytorch:
|
||||
"""When ONNX is unavailable, the routing layer in trocr_service falls back."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_onnx_fallback_to_pytorch(self):
|
||||
"""With backend='auto' and ONNX unavailable, PyTorch path is used."""
|
||||
import services.trocr_service as svc
|
||||
|
||||
original_backend = svc._trocr_backend
|
||||
|
||||
try:
|
||||
svc._trocr_backend = "auto"
|
||||
|
||||
with patch(
|
||||
"services.trocr_service._try_onnx_ocr",
|
||||
return_value=None,
|
||||
) as mock_onnx, patch(
|
||||
"services.trocr_service._run_pytorch_ocr",
|
||||
return_value=("pytorch result", 0.9),
|
||||
) as mock_pytorch:
|
||||
text, conf = await svc.run_trocr_ocr(b"fake-image-data")
|
||||
|
||||
mock_onnx.assert_called_once()
|
||||
mock_pytorch.assert_called_once()
|
||||
assert text == "pytorch result"
|
||||
assert conf == 0.9
|
||||
|
||||
finally:
|
||||
svc._trocr_backend = original_backend
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_onnx_backend_forced(self):
|
||||
"""With backend='onnx', failure raises RuntimeError."""
|
||||
import services.trocr_service as svc
|
||||
|
||||
original_backend = svc._trocr_backend
|
||||
|
||||
try:
|
||||
svc._trocr_backend = "onnx"
|
||||
|
||||
with patch(
|
||||
"services.trocr_service._try_onnx_ocr",
|
||||
return_value=None,
|
||||
):
|
||||
with pytest.raises(RuntimeError, match="ONNX backend.*unavailable"):
|
||||
await svc.run_trocr_ocr(b"fake-image-data")
|
||||
|
||||
finally:
|
||||
svc._trocr_backend = original_backend
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pytorch_backend_skips_onnx(self):
|
||||
"""With backend='pytorch', ONNX is never attempted."""
|
||||
import services.trocr_service as svc
|
||||
|
||||
original_backend = svc._trocr_backend
|
||||
|
||||
try:
|
||||
svc._trocr_backend = "pytorch"
|
||||
|
||||
with patch(
|
||||
"services.trocr_service._try_onnx_ocr",
|
||||
) as mock_onnx, patch(
|
||||
"services.trocr_service._run_pytorch_ocr",
|
||||
return_value=("pytorch only", 0.85),
|
||||
) as mock_pytorch:
|
||||
text, conf = await svc.run_trocr_ocr(b"fake-image-data")
|
||||
|
||||
mock_onnx.assert_not_called()
|
||||
mock_pytorch.assert_called_once()
|
||||
assert text == "pytorch only"
|
||||
|
||||
finally:
|
||||
svc._trocr_backend = original_backend
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: TROCR_BACKEND env var handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBackendConfig:
|
||||
"""TROCR_BACKEND environment variable handling."""
|
||||
|
||||
def test_default_backend_is_auto(self):
|
||||
"""Without env var, backend defaults to 'auto'."""
|
||||
import services.trocr_service as svc
|
||||
# The module reads the env var at import time; in a fresh import
|
||||
# with no TROCR_BACKEND set, it should default to "auto".
|
||||
# We test the get_active_backend function instead.
|
||||
original = svc._trocr_backend
|
||||
try:
|
||||
svc._trocr_backend = "auto"
|
||||
assert svc.get_active_backend() == "auto"
|
||||
finally:
|
||||
svc._trocr_backend = original
|
||||
|
||||
def test_backend_pytorch(self):
|
||||
"""TROCR_BACKEND=pytorch is reflected in get_active_backend."""
|
||||
import services.trocr_service as svc
|
||||
original = svc._trocr_backend
|
||||
try:
|
||||
svc._trocr_backend = "pytorch"
|
||||
assert svc.get_active_backend() == "pytorch"
|
||||
finally:
|
||||
svc._trocr_backend = original
|
||||
|
||||
def test_backend_onnx(self):
|
||||
"""TROCR_BACKEND=onnx is reflected in get_active_backend."""
|
||||
import services.trocr_service as svc
|
||||
original = svc._trocr_backend
|
||||
try:
|
||||
svc._trocr_backend = "onnx"
|
||||
assert svc.get_active_backend() == "onnx"
|
||||
finally:
|
||||
svc._trocr_backend = original
|
||||
|
||||
def test_env_var_read_at_import(self):
|
||||
"""Module reads TROCR_BACKEND from environment."""
|
||||
# We can't easily re-import, but we can verify the variable exists
|
||||
import services.trocr_service as svc
|
||||
assert hasattr(svc, "_trocr_backend")
|
||||
assert svc._trocr_backend in ("auto", "pytorch", "onnx")
|
||||
Reference in New Issue
Block a user