D2: TrOCR ONNX export script (printed + handwritten, int8 quantization) D3: PP-DocLayout ONNX export script (download or Docker-based conversion) B3: Model Management admin page (PyTorch vs ONNX status, benchmarks, config) A4: TrOCR ONNX service with runtime routing (auto/pytorch/onnx via TROCR_BACKEND) A5: PP-DocLayout ONNX detection with OpenCV fallback (via GRAPHIC_DETECT_BACKEND) B4: Structure Detection UI toggle (OpenCV vs PP-DocLayout) with class color coding C3: TrOCR-ONNX.md documentation C4: OCR-Pipeline.md ONNX section added C5: mkdocs.yml nav updated, optimum added to requirements.txt Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
340 lines
13 KiB
Python
340 lines
13 KiB
Python
"""
|
|
Tests for TrOCR ONNX service.
|
|
|
|
All tests use mocking — no actual ONNX model files required.
|
|
"""
|
|
|
|
import os
|
|
import pytest
|
|
from pathlib import Path
|
|
from unittest.mock import patch, MagicMock, PropertyMock
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _services_path():
    """Locate the ``services/`` directory next to the tests package.

    Returns:
        An absolute ``Path`` built by resolving this module's own location,
        climbing two directory levels, and appending ``services``.
    """
    this_module = Path(__file__).resolve()
    grandparent_dir = this_module.parent.parent
    return grandparent_dir.joinpath("services")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Test: is_onnx_available — no models on disk
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestIsOnnxAvailableNoModels:
    """is_onnx_available must return False when models or the runtime are missing."""

    @patch(
        "services.trocr_onnx_service._check_onnx_runtime_available",
        return_value=True,
    )
    @patch(
        "services.trocr_onnx_service._resolve_onnx_model_dir",
        return_value=None,
    )
    def test_is_onnx_available_no_models(self, mock_resolve, mock_runtime):
        """Runtime reported present, but no model dir resolves for either variant."""
        from services.trocr_onnx_service import is_onnx_available

        assert is_onnx_available(handwritten=False) is False
        assert is_onnx_available(handwritten=True) is False

    @patch(
        "services.trocr_onnx_service._check_onnx_runtime_available",
        return_value=False,
    )
    def test_is_onnx_available_no_runtime(self, mock_runtime, tmp_path):
        """Even if model dirs existed, missing runtime → False."""
        from services.trocr_onnx_service import is_onnx_available

        # Make model resolution succeed with a populated directory so the test
        # really covers "model dir exists, runtime missing" as documented above.
        (tmp_path / "encoder_model.onnx").write_bytes(b"fake")
        with patch(
            "services.trocr_onnx_service._resolve_onnx_model_dir",
            return_value=tmp_path,
        ):
            assert is_onnx_available(handwritten=False) is False
            assert is_onnx_available(handwritten=True) is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Test: get_onnx_model_status — not available
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestOnnxModelStatusNotAvailable:
    """Status dict when ONNX is not loaded."""

    @staticmethod
    def _reset_onnx_cache():
        """Clear ONNX models cached in the service module by earlier tests."""
        import services.trocr_onnx_service as mod

        mod._onnx_models.clear()
        mod._onnx_model_loaded_at = None

    @patch(
        "services.trocr_onnx_service._check_onnx_runtime_available",
        return_value=False,
    )
    @patch(
        "services.trocr_onnx_service._resolve_onnx_model_dir",
        return_value=None,
    )
    def test_onnx_model_status_not_available(self, mock_resolve, mock_runtime):
        """No runtime and no model dirs: every availability flag is off."""
        from services.trocr_onnx_service import get_onnx_model_status

        self._reset_onnx_cache()

        status = get_onnx_model_status()

        assert status["backend"] == "onnx"
        assert status["runtime_available"] is False
        assert status["printed"]["available"] is False
        assert status["printed"]["loaded"] is False
        assert status["printed"]["model_dir"] is None
        assert status["handwritten"]["available"] is False
        assert status["handwritten"]["loaded"] is False
        assert status["handwritten"]["model_dir"] is None
        assert status["loaded_at"] is None
        assert status["providers"] == []

    @patch(
        "services.trocr_onnx_service._check_onnx_runtime_available",
        return_value=True,
    )
    def test_onnx_model_status_runtime_but_no_files(self, mock_runtime):
        """Runtime installed but no model files on disk."""
        from services.trocr_onnx_service import get_onnx_model_status

        self._reset_onnx_cache()

        # One configured fake serves both ways the service could reach
        # onnxruntime: a module attribute or a function-local import (which
        # reads sys.modules). Previously the module-attribute patch was an
        # unconfigured, unused MagicMock.
        fake_ort = MagicMock()
        fake_ort.get_available_providers.return_value = ["CPUExecutionProvider"]

        with patch(
            "services.trocr_onnx_service._resolve_onnx_model_dir",
            return_value=None,
        ), patch(
            "services.trocr_onnx_service.onnxruntime",
            fake_ort,
            create=True,
        ), patch.dict("sys.modules", {"onnxruntime": fake_ort}):
            status = get_onnx_model_status()

        assert status["runtime_available"] is True
        assert status["printed"]["available"] is False
        assert status["handwritten"]["available"] is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Test: path resolution logic
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestOnnxModelPaths:
    """Verify the path resolution order."""

    @staticmethod
    def _resolve_without_override(monkeypatch, *, handwritten=False):
        """Call _resolve_onnx_model_dir with TROCR_ONNX_DIR unset.

        Only the override variable is removed (and restored by monkeypatch
        afterwards). The rest of the environment is left intact, because
        wiping it (HOME, PATH, ...) could change how fallback candidates
        resolve.
        """
        from services.trocr_onnx_service import _resolve_onnx_model_dir

        monkeypatch.delenv("TROCR_ONNX_DIR", raising=False)
        return _resolve_onnx_model_dir(handwritten=handwritten)

    def test_env_var_path_takes_precedence(self, tmp_path):
        """TROCR_ONNX_DIR env var should be checked first."""
        from services.trocr_onnx_service import _resolve_onnx_model_dir

        # Fake printed-model dir, recognised via a config.json marker.
        model_dir = tmp_path / "trocr-base-printed"
        model_dir.mkdir(parents=True)
        (model_dir / "config.json").write_text("{}")

        with patch.dict(os.environ, {"TROCR_ONNX_DIR": str(tmp_path)}):
            result = _resolve_onnx_model_dir(handwritten=False)

        assert result is not None
        assert result == model_dir

    def test_env_var_handwritten_variant(self, tmp_path):
        """TROCR_ONNX_DIR works for handwritten variant too."""
        from services.trocr_onnx_service import _resolve_onnx_model_dir

        # Fake handwritten-model dir, recognised via a .onnx file.
        model_dir = tmp_path / "trocr-base-handwritten"
        model_dir.mkdir(parents=True)
        (model_dir / "encoder_model.onnx").write_bytes(b"fake")

        with patch.dict(os.environ, {"TROCR_ONNX_DIR": str(tmp_path)}):
            result = _resolve_onnx_model_dir(handwritten=True)

        assert result is not None
        assert result == model_dir

    def test_returns_none_when_no_dirs_exist(self, monkeypatch):
        """Without TROCR_ONNX_DIR, resolution returns None or a Path without raising."""
        result = self._resolve_without_override(monkeypatch)

        # The Docker and local-dev paths may hold real models on a developer
        # machine, so only the return contract is asserted.
        assert result is None or isinstance(result, Path)

    def test_docker_path_checked(self, monkeypatch):
        """Resolution completes cleanly when the Docker candidate is absent.

        The Docker location (/root/.cache/huggingface/onnx/<variant>) cannot
        be created from a test, so its presence in the candidate list is NOT
        asserted here; only error-free resolution and the return type are.
        """
        result = self._resolve_without_override(monkeypatch)

        assert result is None or isinstance(result, Path)

    def test_local_dev_path_relative_to_backend(self, monkeypatch):
        """Local dev path is models/onnx/<variant>/ relative to backend dir.

        The backend dir is derived from the service module's __file__ and
        cannot easily be redirected, so only the return type is verified.
        """
        result = self._resolve_without_override(monkeypatch)

        # May or may not find models; just verify the return type.
        assert result is None or isinstance(result, Path)

    def test_dir_without_onnx_files_is_skipped(self, tmp_path):
        """A directory that exists but has no .onnx files or config.json is skipped."""
        from services.trocr_onnx_service import _resolve_onnx_model_dir

        empty_dir = tmp_path / "trocr-base-printed"
        empty_dir.mkdir(parents=True)
        # No .onnx files, no config.json.

        with patch.dict(os.environ, {"TROCR_ONNX_DIR": str(tmp_path)}):
            result = _resolve_onnx_model_dir(handwritten=False)

        # The env-var candidate exists but holds no model files, so it must be
        # skipped. Another candidate dir may still legitimately match.
        if result is not None:
            assert result != empty_dir
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Test: fallback to PyTorch
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestOnnxFallbackToPytorch:
    """When ONNX is unavailable, the routing layer in trocr_service falls back."""

    @pytest.mark.asyncio
    async def test_onnx_fallback_to_pytorch(self, monkeypatch):
        """With backend='auto' and ONNX unavailable, PyTorch path is used."""
        import services.trocr_service as svc

        # monkeypatch restores the module-level backend when the test ends.
        monkeypatch.setattr(svc, "_trocr_backend", "auto")

        with patch(
            "services.trocr_service._try_onnx_ocr",
            return_value=None,
        ) as mock_onnx, patch(
            "services.trocr_service._run_pytorch_ocr",
            return_value=("pytorch result", 0.9),
        ) as mock_pytorch:
            text, conf = await svc.run_trocr_ocr(b"fake-image-data")

            mock_onnx.assert_called_once()
            mock_pytorch.assert_called_once()
            assert text == "pytorch result"
            assert conf == 0.9

    @pytest.mark.asyncio
    async def test_onnx_backend_forced(self, monkeypatch):
        """With backend='onnx', failure raises RuntimeError."""
        import services.trocr_service as svc

        monkeypatch.setattr(svc, "_trocr_backend", "onnx")

        with patch(
            "services.trocr_service._try_onnx_ocr",
            return_value=None,
        ), patch(
            "services.trocr_service._run_pytorch_ocr",
        ) as mock_pytorch:
            with pytest.raises(RuntimeError, match="ONNX backend.*unavailable"):
                await svc.run_trocr_ocr(b"fake-image-data")

            # A forced ONNX backend must not silently fall back to PyTorch.
            mock_pytorch.assert_not_called()

    @pytest.mark.asyncio
    async def test_pytorch_backend_skips_onnx(self, monkeypatch):
        """With backend='pytorch', ONNX is never attempted."""
        import services.trocr_service as svc

        monkeypatch.setattr(svc, "_trocr_backend", "pytorch")

        with patch(
            "services.trocr_service._try_onnx_ocr",
        ) as mock_onnx, patch(
            "services.trocr_service._run_pytorch_ocr",
            return_value=("pytorch only", 0.85),
        ) as mock_pytorch:
            text, conf = await svc.run_trocr_ocr(b"fake-image-data")

            mock_onnx.assert_not_called()
            mock_pytorch.assert_called_once()
            assert text == "pytorch only"
            assert conf == 0.85
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Test: TROCR_BACKEND env var handling
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestBackendConfig:
    """TROCR_BACKEND environment variable handling."""

    @staticmethod
    def _active_backend_for(monkeypatch, value):
        """Set svc._trocr_backend to *value* and return get_active_backend().

        monkeypatch restores the original module-level value after the test.
        """
        import services.trocr_service as svc

        monkeypatch.setattr(svc, "_trocr_backend", value)
        return svc.get_active_backend()

    def test_default_backend_is_auto(self, monkeypatch):
        """get_active_backend reports 'auto' when the backend is 'auto'.

        The module reads TROCR_BACKEND once at import time. A fresh re-import
        is avoided here, so this checks how the 'auto' value is reported rather
        than the import-time default itself.
        """
        assert self._active_backend_for(monkeypatch, "auto") == "auto"

    def test_backend_pytorch(self, monkeypatch):
        """TROCR_BACKEND=pytorch is reflected in get_active_backend."""
        assert self._active_backend_for(monkeypatch, "pytorch") == "pytorch"

    def test_backend_onnx(self, monkeypatch):
        """TROCR_BACKEND=onnx is reflected in get_active_backend."""
        assert self._active_backend_for(monkeypatch, "onnx") == "onnx"

    def test_env_var_read_at_import(self):
        """Module reads TROCR_BACKEND from environment."""
        # Re-importing is avoided; verify the import-time value is sane.
        import services.trocr_service as svc

        assert hasattr(svc, "_trocr_backend")
        assert svc._trocr_backend in ("auto", "pytorch", "onnx")
|