feat: Sprint 2 — TrOCR ONNX, PP-DocLayout, Model Management
D2: TrOCR ONNX export script (printed + handwritten, int8 quantization) D3: PP-DocLayout ONNX export script (download or Docker-based conversion) B3: Model Management admin page (PyTorch vs ONNX status, benchmarks, config) A4: TrOCR ONNX service with runtime routing (auto/pytorch/onnx via TROCR_BACKEND) A5: PP-DocLayout ONNX detection with OpenCV fallback (via GRAPHIC_DETECT_BACKEND) B4: Structure Detection UI toggle (OpenCV vs PP-DocLayout) with class color coding C3: TrOCR-ONNX.md documentation C4: OCR-Pipeline.md ONNX section added C5: mkdocs.yml nav updated, optimum added to requirements.txt Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
394
klausur-service/backend/tests/test_doclayout_detect.py
Normal file
394
klausur-service/backend/tests/test_doclayout_detect.py
Normal file
@@ -0,0 +1,394 @@
|
||||
"""
|
||||
Tests for PP-DocLayout ONNX Document Layout Detection.
|
||||
|
||||
Uses mocking to avoid requiring the actual ONNX model file.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
# We patch the module-level globals before importing to ensure clean state
|
||||
# in tests that check "no model" behaviour.
|
||||
|
||||
import importlib
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _fresh_import():
|
||||
"""Re-import cv_doclayout_detect with reset globals."""
|
||||
import cv_doclayout_detect as mod
|
||||
# Reset module-level caching so each test starts clean
|
||||
mod._onnx_session = None
|
||||
mod._model_path = None
|
||||
mod._load_attempted = False
|
||||
mod._load_error = None
|
||||
return mod
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. is_doclayout_available — no model present
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIsDoclayoutAvailableNoModel:
|
||||
def test_returns_false_when_no_onnx_file(self):
|
||||
mod = _fresh_import()
|
||||
with patch.object(mod, "_find_model_path", return_value=None):
|
||||
assert mod.is_doclayout_available() is False
|
||||
|
||||
def test_returns_false_when_onnxruntime_missing(self):
|
||||
mod = _fresh_import()
|
||||
with patch.object(mod, "_find_model_path", return_value="/fake/model.onnx"):
|
||||
with patch.dict("sys.modules", {"onnxruntime": None}):
|
||||
# Force ImportError by making import fail
|
||||
import builtins
|
||||
real_import = builtins.__import__
|
||||
|
||||
def fake_import(name, *args, **kwargs):
|
||||
if name == "onnxruntime":
|
||||
raise ImportError("no onnxruntime")
|
||||
return real_import(name, *args, **kwargs)
|
||||
|
||||
with patch("builtins.__import__", side_effect=fake_import):
|
||||
assert mod.is_doclayout_available() is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. LayoutRegion dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLayoutRegionDataclass:
|
||||
def test_basic_creation(self):
|
||||
from cv_doclayout_detect import LayoutRegion
|
||||
region = LayoutRegion(
|
||||
x=10, y=20, width=100, height=200,
|
||||
label="figure", confidence=0.95, label_index=1,
|
||||
)
|
||||
assert region.x == 10
|
||||
assert region.y == 20
|
||||
assert region.width == 100
|
||||
assert region.height == 200
|
||||
assert region.label == "figure"
|
||||
assert region.confidence == 0.95
|
||||
assert region.label_index == 1
|
||||
|
||||
def test_all_fields_present(self):
|
||||
from cv_doclayout_detect import LayoutRegion
|
||||
import dataclasses
|
||||
field_names = {f.name for f in dataclasses.fields(LayoutRegion)}
|
||||
expected = {"x", "y", "width", "height", "label", "confidence", "label_index"}
|
||||
assert field_names == expected
|
||||
|
||||
def test_different_labels(self):
|
||||
from cv_doclayout_detect import LayoutRegion, DOCLAYOUT_CLASSES
|
||||
for idx, label in enumerate(DOCLAYOUT_CLASSES):
|
||||
region = LayoutRegion(
|
||||
x=0, y=0, width=50, height=50,
|
||||
label=label, confidence=0.8, label_index=idx,
|
||||
)
|
||||
assert region.label == label
|
||||
assert region.label_index == idx
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. detect_layout_regions — no model available
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDetectLayoutRegionsNoModel:
|
||||
def test_returns_empty_list_when_model_unavailable(self):
|
||||
mod = _fresh_import()
|
||||
with patch.object(mod, "_find_model_path", return_value=None):
|
||||
img = np.zeros((480, 640, 3), dtype=np.uint8)
|
||||
result = mod.detect_layout_regions(img)
|
||||
assert result == []
|
||||
|
||||
def test_returns_empty_list_for_none_image(self):
|
||||
mod = _fresh_import()
|
||||
with patch.object(mod, "_find_model_path", return_value=None):
|
||||
result = mod.detect_layout_regions(None)
|
||||
assert result == []
|
||||
|
||||
def test_returns_empty_list_for_empty_image(self):
|
||||
mod = _fresh_import()
|
||||
with patch.object(mod, "_find_model_path", return_value=None):
|
||||
img = np.array([], dtype=np.uint8)
|
||||
result = mod.detect_layout_regions(img)
|
||||
assert result == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Preprocessing — tensor shape verification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPreprocessingShapes:
|
||||
def test_square_image(self):
|
||||
from cv_doclayout_detect import preprocess_image
|
||||
img = np.random.randint(0, 255, (800, 800, 3), dtype=np.uint8)
|
||||
tensor, scale, pad_x, pad_y = preprocess_image(img)
|
||||
assert tensor.shape == (1, 3, 800, 800)
|
||||
assert tensor.dtype == np.float32
|
||||
assert 0.0 <= tensor.min()
|
||||
assert tensor.max() <= 1.0
|
||||
|
||||
def test_landscape_image(self):
|
||||
from cv_doclayout_detect import preprocess_image
|
||||
img = np.random.randint(0, 255, (600, 1200, 3), dtype=np.uint8)
|
||||
tensor, scale, pad_x, pad_y = preprocess_image(img)
|
||||
assert tensor.shape == (1, 3, 800, 800)
|
||||
# Landscape: scale by width, should have vertical padding
|
||||
expected_scale = 800 / 1200
|
||||
assert abs(scale - expected_scale) < 1e-5
|
||||
assert pad_y > 0 # vertical padding expected
|
||||
|
||||
def test_portrait_image(self):
|
||||
from cv_doclayout_detect import preprocess_image
|
||||
img = np.random.randint(0, 255, (1200, 600, 3), dtype=np.uint8)
|
||||
tensor, scale, pad_x, pad_y = preprocess_image(img)
|
||||
assert tensor.shape == (1, 3, 800, 800)
|
||||
# Portrait: scale by height, should have horizontal padding
|
||||
expected_scale = 800 / 1200
|
||||
assert abs(scale - expected_scale) < 1e-5
|
||||
assert pad_x > 0 # horizontal padding expected
|
||||
|
||||
def test_small_image(self):
|
||||
from cv_doclayout_detect import preprocess_image
|
||||
img = np.random.randint(0, 255, (100, 200, 3), dtype=np.uint8)
|
||||
tensor, scale, pad_x, pad_y = preprocess_image(img)
|
||||
assert tensor.shape == (1, 3, 800, 800)
|
||||
|
||||
def test_typical_scan_a4(self):
|
||||
"""A4 scan at 300dpi: roughly 2480x3508 pixels."""
|
||||
from cv_doclayout_detect import preprocess_image
|
||||
img = np.random.randint(0, 255, (3508, 2480, 3), dtype=np.uint8)
|
||||
tensor, scale, pad_x, pad_y = preprocess_image(img)
|
||||
assert tensor.shape == (1, 3, 800, 800)
|
||||
|
||||
def test_values_normalized(self):
|
||||
from cv_doclayout_detect import preprocess_image
|
||||
# All white image
|
||||
img = np.full((400, 400, 3), 255, dtype=np.uint8)
|
||||
tensor, _, _, _ = preprocess_image(img)
|
||||
# The padded region is 114/255 ≈ 0.447, the image region is 1.0
|
||||
assert tensor.max() <= 1.0
|
||||
assert tensor.min() >= 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. NMS logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNmsLogic:
|
||||
def test_empty_input(self):
|
||||
from cv_doclayout_detect import nms
|
||||
boxes = np.array([]).reshape(0, 4)
|
||||
scores = np.array([])
|
||||
assert nms(boxes, scores) == []
|
||||
|
||||
def test_single_box(self):
|
||||
from cv_doclayout_detect import nms
|
||||
boxes = np.array([[10, 10, 100, 100]], dtype=np.float32)
|
||||
scores = np.array([0.9])
|
||||
kept = nms(boxes, scores, iou_threshold=0.5)
|
||||
assert kept == [0]
|
||||
|
||||
def test_non_overlapping_boxes(self):
|
||||
from cv_doclayout_detect import nms
|
||||
boxes = np.array([
|
||||
[0, 0, 50, 50],
|
||||
[200, 200, 300, 300],
|
||||
[400, 400, 500, 500],
|
||||
], dtype=np.float32)
|
||||
scores = np.array([0.9, 0.8, 0.7])
|
||||
kept = nms(boxes, scores, iou_threshold=0.5)
|
||||
assert len(kept) == 3
|
||||
assert set(kept) == {0, 1, 2}
|
||||
|
||||
def test_overlapping_boxes_suppressed(self):
|
||||
from cv_doclayout_detect import nms
|
||||
# Two boxes that heavily overlap
|
||||
boxes = np.array([
|
||||
[10, 10, 110, 110], # 100x100
|
||||
[15, 15, 115, 115], # 100x100, heavily overlapping with first
|
||||
], dtype=np.float32)
|
||||
scores = np.array([0.95, 0.80])
|
||||
kept = nms(boxes, scores, iou_threshold=0.5)
|
||||
# Only the higher-confidence box should survive
|
||||
assert kept == [0]
|
||||
|
||||
def test_partially_overlapping_boxes_kept(self):
|
||||
from cv_doclayout_detect import nms
|
||||
# Two boxes that overlap ~25% (below 0.5 threshold)
|
||||
boxes = np.array([
|
||||
[0, 0, 100, 100], # 100x100
|
||||
[75, 0, 175, 100], # 100x100, overlap 25x100 = 2500
|
||||
], dtype=np.float32)
|
||||
scores = np.array([0.9, 0.8])
|
||||
# IoU = 2500 / (10000 + 10000 - 2500) = 2500/17500 ≈ 0.143
|
||||
kept = nms(boxes, scores, iou_threshold=0.5)
|
||||
assert len(kept) == 2
|
||||
|
||||
def test_nms_respects_score_ordering(self):
|
||||
from cv_doclayout_detect import nms
|
||||
# Three overlapping boxes — highest confidence should be kept first
|
||||
boxes = np.array([
|
||||
[10, 10, 110, 110],
|
||||
[12, 12, 112, 112],
|
||||
[14, 14, 114, 114],
|
||||
], dtype=np.float32)
|
||||
scores = np.array([0.5, 0.9, 0.7])
|
||||
kept = nms(boxes, scores, iou_threshold=0.5)
|
||||
# Index 1 has highest score → kept first, suppresses 0 and 2
|
||||
assert kept[0] == 1
|
||||
|
||||
def test_iou_computation(self):
|
||||
from cv_doclayout_detect import _compute_iou
|
||||
box_a = np.array([0, 0, 100, 100], dtype=np.float32)
|
||||
box_b = np.array([0, 0, 100, 100], dtype=np.float32)
|
||||
assert abs(_compute_iou(box_a, box_b) - 1.0) < 1e-5
|
||||
|
||||
box_c = np.array([200, 200, 300, 300], dtype=np.float32)
|
||||
assert _compute_iou(box_a, box_c) == 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. DOCLAYOUT_CLASSES verification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDoclayoutClasses:
|
||||
def test_correct_class_list(self):
|
||||
from cv_doclayout_detect import DOCLAYOUT_CLASSES
|
||||
expected = [
|
||||
"table", "figure", "title", "text", "list",
|
||||
"header", "footer", "equation", "reference", "abstract",
|
||||
]
|
||||
assert DOCLAYOUT_CLASSES == expected
|
||||
|
||||
def test_class_count(self):
|
||||
from cv_doclayout_detect import DOCLAYOUT_CLASSES
|
||||
assert len(DOCLAYOUT_CLASSES) == 10
|
||||
|
||||
def test_no_duplicates(self):
|
||||
from cv_doclayout_detect import DOCLAYOUT_CLASSES
|
||||
assert len(DOCLAYOUT_CLASSES) == len(set(DOCLAYOUT_CLASSES))
|
||||
|
||||
def test_all_lowercase(self):
|
||||
from cv_doclayout_detect import DOCLAYOUT_CLASSES
|
||||
for cls in DOCLAYOUT_CLASSES:
|
||||
assert cls == cls.lower(), f"Class '{cls}' should be lowercase"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 7. get_doclayout_status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetDoclayoutStatus:
|
||||
def test_status_when_unavailable(self):
|
||||
mod = _fresh_import()
|
||||
with patch.object(mod, "_find_model_path", return_value=None):
|
||||
status = mod.get_doclayout_status()
|
||||
assert status["available"] is False
|
||||
assert status["model_path"] is None
|
||||
assert status["load_error"] is not None
|
||||
assert status["classes"] == mod.DOCLAYOUT_CLASSES
|
||||
assert status["class_count"] == 10
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 8. Post-processing with mocked ONNX outputs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPostprocessing:
|
||||
def test_single_tensor_format_6cols(self):
|
||||
"""Test parsing of (1, N, 6) output format: x1,y1,x2,y2,score,class."""
|
||||
from cv_doclayout_detect import _postprocess
|
||||
|
||||
# One detection: figure at (100,100)-(300,300) in 800x800 space
|
||||
raw = np.array([[[100, 100, 300, 300, 0.92, 1]]], dtype=np.float32)
|
||||
regions = _postprocess(
|
||||
outputs=[raw],
|
||||
scale=1.0, pad_x=0, pad_y=0,
|
||||
orig_w=800, orig_h=800,
|
||||
confidence_threshold=0.5,
|
||||
max_regions=50,
|
||||
)
|
||||
assert len(regions) == 1
|
||||
assert regions[0].label == "figure"
|
||||
assert regions[0].confidence >= 0.9
|
||||
|
||||
def test_three_tensor_format(self):
|
||||
"""Test parsing of 3-tensor output: boxes, scores, class_ids."""
|
||||
from cv_doclayout_detect import _postprocess
|
||||
|
||||
boxes = np.array([[50, 50, 200, 150]], dtype=np.float32)
|
||||
scores = np.array([0.88], dtype=np.float32)
|
||||
class_ids = np.array([0], dtype=np.float32) # table
|
||||
|
||||
regions = _postprocess(
|
||||
outputs=[boxes, scores, class_ids],
|
||||
scale=1.0, pad_x=0, pad_y=0,
|
||||
orig_w=800, orig_h=800,
|
||||
confidence_threshold=0.5,
|
||||
max_regions=50,
|
||||
)
|
||||
assert len(regions) == 1
|
||||
assert regions[0].label == "table"
|
||||
|
||||
def test_confidence_filtering(self):
|
||||
"""Detections below threshold should be excluded."""
|
||||
from cv_doclayout_detect import _postprocess
|
||||
|
||||
raw = np.array([
|
||||
[100, 100, 200, 200, 0.9, 1], # above threshold
|
||||
[300, 300, 400, 400, 0.3, 2], # below threshold
|
||||
], dtype=np.float32).reshape(1, 2, 6)
|
||||
|
||||
regions = _postprocess(
|
||||
outputs=[raw],
|
||||
scale=1.0, pad_x=0, pad_y=0,
|
||||
orig_w=800, orig_h=800,
|
||||
confidence_threshold=0.5,
|
||||
max_regions=50,
|
||||
)
|
||||
assert len(regions) == 1
|
||||
assert regions[0].label == "figure"
|
||||
|
||||
def test_coordinate_scaling(self):
|
||||
"""Verify coordinates are correctly scaled back to original image."""
|
||||
from cv_doclayout_detect import _postprocess
|
||||
|
||||
# Image was 1600x1200, scaled to fit 800x800 → scale=0.5, pad_y offset
|
||||
scale = 800 / 1600 # 0.5
|
||||
pad_x = 0
|
||||
pad_y = (800 - int(1200 * scale)) // 2 # (800-600)//2 = 100
|
||||
|
||||
# Detection in 800x800 space at (100, 200) to (300, 400)
|
||||
raw = np.array([[[100, 200, 300, 400, 0.95, 0]]], dtype=np.float32)
|
||||
|
||||
regions = _postprocess(
|
||||
outputs=[raw],
|
||||
scale=scale, pad_x=pad_x, pad_y=pad_y,
|
||||
orig_w=1600, orig_h=1200,
|
||||
confidence_threshold=0.5,
|
||||
max_regions=50,
|
||||
)
|
||||
assert len(regions) == 1
|
||||
r = regions[0]
|
||||
# x1 = (100 - 0) / 0.5 = 200
|
||||
assert r.x == 200
|
||||
# y1 = (200 - 100) / 0.5 = 200
|
||||
assert r.y == 200
|
||||
|
||||
def test_empty_output(self):
|
||||
from cv_doclayout_detect import _postprocess
|
||||
raw = np.array([]).reshape(1, 0, 6).astype(np.float32)
|
||||
regions = _postprocess(
|
||||
outputs=[raw],
|
||||
scale=1.0, pad_x=0, pad_y=0,
|
||||
orig_w=800, orig_h=800,
|
||||
confidence_threshold=0.5,
|
||||
max_regions=50,
|
||||
)
|
||||
assert regions == []
|
||||
339
klausur-service/backend/tests/test_trocr_onnx.py
Normal file
339
klausur-service/backend/tests/test_trocr_onnx.py
Normal file
@@ -0,0 +1,339 @@
|
||||
"""
|
||||
Tests for TrOCR ONNX service.
|
||||
|
||||
All tests use mocking — no actual ONNX model files required.
|
||||
"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock, PropertyMock
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _services_path():
|
||||
"""Return absolute path to the services/ directory."""
|
||||
return Path(__file__).resolve().parent.parent / "services"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: is_onnx_available — no models on disk
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIsOnnxAvailableNoModels:
|
||||
"""When no ONNX files exist on disk, is_onnx_available must return False."""
|
||||
|
||||
@patch(
|
||||
"services.trocr_onnx_service._check_onnx_runtime_available",
|
||||
return_value=True,
|
||||
)
|
||||
@patch(
|
||||
"services.trocr_onnx_service._resolve_onnx_model_dir",
|
||||
return_value=None,
|
||||
)
|
||||
def test_is_onnx_available_no_models(self, mock_resolve, mock_runtime):
|
||||
from services.trocr_onnx_service import is_onnx_available
|
||||
|
||||
assert is_onnx_available(handwritten=False) is False
|
||||
assert is_onnx_available(handwritten=True) is False
|
||||
|
||||
@patch(
|
||||
"services.trocr_onnx_service._check_onnx_runtime_available",
|
||||
return_value=False,
|
||||
)
|
||||
def test_is_onnx_available_no_runtime(self, mock_runtime):
|
||||
"""Even if model dirs existed, missing runtime → False."""
|
||||
from services.trocr_onnx_service import is_onnx_available
|
||||
|
||||
assert is_onnx_available(handwritten=False) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: get_onnx_model_status — not available
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestOnnxModelStatusNotAvailable:
|
||||
"""Status dict when ONNX is not loaded."""
|
||||
|
||||
@patch(
|
||||
"services.trocr_onnx_service._check_onnx_runtime_available",
|
||||
return_value=False,
|
||||
)
|
||||
@patch(
|
||||
"services.trocr_onnx_service._resolve_onnx_model_dir",
|
||||
return_value=None,
|
||||
)
|
||||
def test_onnx_model_status_not_available(self, mock_resolve, mock_runtime):
|
||||
from services.trocr_onnx_service import get_onnx_model_status
|
||||
|
||||
# Clear any cached models from prior tests
|
||||
import services.trocr_onnx_service as mod
|
||||
mod._onnx_models.clear()
|
||||
mod._onnx_model_loaded_at = None
|
||||
|
||||
status = get_onnx_model_status()
|
||||
|
||||
assert status["backend"] == "onnx"
|
||||
assert status["runtime_available"] is False
|
||||
assert status["printed"]["available"] is False
|
||||
assert status["printed"]["loaded"] is False
|
||||
assert status["printed"]["model_dir"] is None
|
||||
assert status["handwritten"]["available"] is False
|
||||
assert status["handwritten"]["loaded"] is False
|
||||
assert status["handwritten"]["model_dir"] is None
|
||||
assert status["loaded_at"] is None
|
||||
assert status["providers"] == []
|
||||
|
||||
@patch(
|
||||
"services.trocr_onnx_service._check_onnx_runtime_available",
|
||||
return_value=True,
|
||||
)
|
||||
def test_onnx_model_status_runtime_but_no_files(self, mock_runtime):
|
||||
"""Runtime installed but no model files on disk."""
|
||||
from services.trocr_onnx_service import get_onnx_model_status
|
||||
import services.trocr_onnx_service as mod
|
||||
mod._onnx_models.clear()
|
||||
mod._onnx_model_loaded_at = None
|
||||
|
||||
with patch(
|
||||
"services.trocr_onnx_service._resolve_onnx_model_dir",
|
||||
return_value=None,
|
||||
), patch("services.trocr_onnx_service.onnxruntime", create=True) as mock_ort:
|
||||
# Mock onnxruntime import inside get_onnx_model_status
|
||||
mock_ort_module = MagicMock()
|
||||
mock_ort_module.get_available_providers.return_value = [
|
||||
"CPUExecutionProvider"
|
||||
]
|
||||
with patch.dict("sys.modules", {"onnxruntime": mock_ort_module}):
|
||||
status = get_onnx_model_status()
|
||||
|
||||
assert status["runtime_available"] is True
|
||||
assert status["printed"]["available"] is False
|
||||
assert status["handwritten"]["available"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: path resolution logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestOnnxModelPaths:
|
||||
"""Verify the path resolution order."""
|
||||
|
||||
def test_env_var_path_takes_precedence(self, tmp_path):
|
||||
"""TROCR_ONNX_DIR env var should be checked first."""
|
||||
from services.trocr_onnx_service import _resolve_onnx_model_dir
|
||||
|
||||
# Create a fake model dir with a config.json
|
||||
model_dir = tmp_path / "trocr-base-printed"
|
||||
model_dir.mkdir(parents=True)
|
||||
(model_dir / "config.json").write_text("{}")
|
||||
|
||||
with patch.dict(os.environ, {"TROCR_ONNX_DIR": str(tmp_path)}):
|
||||
result = _resolve_onnx_model_dir(handwritten=False)
|
||||
|
||||
assert result is not None
|
||||
assert result == model_dir
|
||||
|
||||
def test_env_var_handwritten_variant(self, tmp_path):
|
||||
"""TROCR_ONNX_DIR works for handwritten variant too."""
|
||||
from services.trocr_onnx_service import _resolve_onnx_model_dir
|
||||
|
||||
model_dir = tmp_path / "trocr-base-handwritten"
|
||||
model_dir.mkdir(parents=True)
|
||||
(model_dir / "encoder_model.onnx").write_bytes(b"fake")
|
||||
|
||||
with patch.dict(os.environ, {"TROCR_ONNX_DIR": str(tmp_path)}):
|
||||
result = _resolve_onnx_model_dir(handwritten=True)
|
||||
|
||||
assert result is not None
|
||||
assert result == model_dir
|
||||
|
||||
def test_returns_none_when_no_dirs_exist(self):
|
||||
"""When none of the candidate dirs exist, return None."""
|
||||
from services.trocr_onnx_service import _resolve_onnx_model_dir
|
||||
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
# Remove TROCR_ONNX_DIR if set
|
||||
os.environ.pop("TROCR_ONNX_DIR", None)
|
||||
# The Docker and local-dev paths almost certainly don't contain
|
||||
# real ONNX models on the test machine.
|
||||
result = _resolve_onnx_model_dir(handwritten=False)
|
||||
|
||||
# Could be None or a real dir if someone has models locally.
|
||||
# We just verify it doesn't raise.
|
||||
assert result is None or isinstance(result, Path)
|
||||
|
||||
def test_docker_path_checked(self, tmp_path):
|
||||
"""Docker path /root/.cache/huggingface/onnx/ is in candidate list."""
|
||||
from services.trocr_onnx_service import _resolve_onnx_model_dir
|
||||
|
||||
docker_path = Path("/root/.cache/huggingface/onnx/trocr-base-printed")
|
||||
|
||||
# We can't create that path in tests, but we can verify the logic
|
||||
# by checking that when env var points nowhere and docker path
|
||||
# doesn't exist, the function still runs without error.
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
os.environ.pop("TROCR_ONNX_DIR", None)
|
||||
# Just verify it doesn't crash
|
||||
_resolve_onnx_model_dir(handwritten=False)
|
||||
|
||||
def test_local_dev_path_relative_to_backend(self, tmp_path):
|
||||
"""Local dev path is models/onnx/<variant>/ relative to backend dir."""
|
||||
from services.trocr_onnx_service import _resolve_onnx_model_dir
|
||||
|
||||
# The backend dir is derived from __file__, so we can't easily
|
||||
# redirect it. Instead, verify the function signature and return type.
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
os.environ.pop("TROCR_ONNX_DIR", None)
|
||||
result = _resolve_onnx_model_dir(handwritten=False)
|
||||
# May or may not find models — just verify the return type
|
||||
assert result is None or isinstance(result, Path)
|
||||
|
||||
def test_dir_without_onnx_files_is_skipped(self, tmp_path):
|
||||
"""A directory that exists but has no .onnx files or config.json is skipped."""
|
||||
from services.trocr_onnx_service import _resolve_onnx_model_dir
|
||||
|
||||
empty_dir = tmp_path / "trocr-base-printed"
|
||||
empty_dir.mkdir(parents=True)
|
||||
# No .onnx files, no config.json
|
||||
|
||||
with patch.dict(os.environ, {"TROCR_ONNX_DIR": str(tmp_path)}):
|
||||
result = _resolve_onnx_model_dir(handwritten=False)
|
||||
|
||||
# The env-var candidate exists as a dir but has no model files,
|
||||
# so it should be skipped. Result depends on whether other
|
||||
# candidate dirs have models.
|
||||
if result is not None:
|
||||
# If found elsewhere, that's fine — just not the empty dir
|
||||
assert result != empty_dir
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: fallback to PyTorch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestOnnxFallbackToPytorch:
|
||||
"""When ONNX is unavailable, the routing layer in trocr_service falls back."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_onnx_fallback_to_pytorch(self):
|
||||
"""With backend='auto' and ONNX unavailable, PyTorch path is used."""
|
||||
import services.trocr_service as svc
|
||||
|
||||
original_backend = svc._trocr_backend
|
||||
|
||||
try:
|
||||
svc._trocr_backend = "auto"
|
||||
|
||||
with patch(
|
||||
"services.trocr_service._try_onnx_ocr",
|
||||
return_value=None,
|
||||
) as mock_onnx, patch(
|
||||
"services.trocr_service._run_pytorch_ocr",
|
||||
return_value=("pytorch result", 0.9),
|
||||
) as mock_pytorch:
|
||||
text, conf = await svc.run_trocr_ocr(b"fake-image-data")
|
||||
|
||||
mock_onnx.assert_called_once()
|
||||
mock_pytorch.assert_called_once()
|
||||
assert text == "pytorch result"
|
||||
assert conf == 0.9
|
||||
|
||||
finally:
|
||||
svc._trocr_backend = original_backend
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_onnx_backend_forced(self):
|
||||
"""With backend='onnx', failure raises RuntimeError."""
|
||||
import services.trocr_service as svc
|
||||
|
||||
original_backend = svc._trocr_backend
|
||||
|
||||
try:
|
||||
svc._trocr_backend = "onnx"
|
||||
|
||||
with patch(
|
||||
"services.trocr_service._try_onnx_ocr",
|
||||
return_value=None,
|
||||
):
|
||||
with pytest.raises(RuntimeError, match="ONNX backend.*unavailable"):
|
||||
await svc.run_trocr_ocr(b"fake-image-data")
|
||||
|
||||
finally:
|
||||
svc._trocr_backend = original_backend
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pytorch_backend_skips_onnx(self):
|
||||
"""With backend='pytorch', ONNX is never attempted."""
|
||||
import services.trocr_service as svc
|
||||
|
||||
original_backend = svc._trocr_backend
|
||||
|
||||
try:
|
||||
svc._trocr_backend = "pytorch"
|
||||
|
||||
with patch(
|
||||
"services.trocr_service._try_onnx_ocr",
|
||||
) as mock_onnx, patch(
|
||||
"services.trocr_service._run_pytorch_ocr",
|
||||
return_value=("pytorch only", 0.85),
|
||||
) as mock_pytorch:
|
||||
text, conf = await svc.run_trocr_ocr(b"fake-image-data")
|
||||
|
||||
mock_onnx.assert_not_called()
|
||||
mock_pytorch.assert_called_once()
|
||||
assert text == "pytorch only"
|
||||
|
||||
finally:
|
||||
svc._trocr_backend = original_backend
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: TROCR_BACKEND env var handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBackendConfig:
|
||||
"""TROCR_BACKEND environment variable handling."""
|
||||
|
||||
def test_default_backend_is_auto(self):
|
||||
"""Without env var, backend defaults to 'auto'."""
|
||||
import services.trocr_service as svc
|
||||
# The module reads the env var at import time; in a fresh import
|
||||
# with no TROCR_BACKEND set, it should default to "auto".
|
||||
# We test the get_active_backend function instead.
|
||||
original = svc._trocr_backend
|
||||
try:
|
||||
svc._trocr_backend = "auto"
|
||||
assert svc.get_active_backend() == "auto"
|
||||
finally:
|
||||
svc._trocr_backend = original
|
||||
|
||||
def test_backend_pytorch(self):
|
||||
"""TROCR_BACKEND=pytorch is reflected in get_active_backend."""
|
||||
import services.trocr_service as svc
|
||||
original = svc._trocr_backend
|
||||
try:
|
||||
svc._trocr_backend = "pytorch"
|
||||
assert svc.get_active_backend() == "pytorch"
|
||||
finally:
|
||||
svc._trocr_backend = original
|
||||
|
||||
def test_backend_onnx(self):
|
||||
"""TROCR_BACKEND=onnx is reflected in get_active_backend."""
|
||||
import services.trocr_service as svc
|
||||
original = svc._trocr_backend
|
||||
try:
|
||||
svc._trocr_backend = "onnx"
|
||||
assert svc.get_active_backend() == "onnx"
|
||||
finally:
|
||||
svc._trocr_backend = original
|
||||
|
||||
def test_env_var_read_at_import(self):
|
||||
"""Module reads TROCR_BACKEND from environment."""
|
||||
# We can't easily re-import, but we can verify the variable exists
|
||||
import services.trocr_service as svc
|
||||
assert hasattr(svc, "_trocr_backend")
|
||||
assert svc._trocr_backend in ("auto", "pytorch", "onnx")
|
||||
Reference in New Issue
Block a user