"""
Tests for TrOCR ONNX service.

All tests use mocking — no actual ONNX model files required.
"""

import os
import pytest
from pathlib import Path
from unittest.mock import patch, MagicMock, PropertyMock

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _services_path():
    """Return absolute path to the services/ directory."""
    return Path(__file__).resolve().parent.parent / "services"


# ---------------------------------------------------------------------------
# Test: is_onnx_available — no models on disk
# ---------------------------------------------------------------------------


class TestIsOnnxAvailableNoModels:
    """When no ONNX files exist on disk, is_onnx_available must return False."""

    @patch(
        "services.trocr_onnx_service._check_onnx_runtime_available",
        return_value=True,
    )
    @patch(
        "services.trocr_onnx_service._resolve_onnx_model_dir",
        return_value=None,
    )
    def test_is_onnx_available_no_models(self, mock_resolve, mock_runtime):
        """Runtime present but no model dir resolves → False for both variants."""
        from services.trocr_onnx_service import is_onnx_available

        assert is_onnx_available(handwritten=False) is False
        assert is_onnx_available(handwritten=True) is False

    @patch(
        "services.trocr_onnx_service._check_onnx_runtime_available",
        return_value=False,
    )
    def test_is_onnx_available_no_runtime(self, mock_runtime):
        """Even if model dirs existed, missing runtime → False."""
        from services.trocr_onnx_service import is_onnx_available

        assert is_onnx_available(handwritten=False) is False


# ---------------------------------------------------------------------------
# Test: get_onnx_model_status — not available
# ---------------------------------------------------------------------------


class TestOnnxModelStatusNotAvailable:
    """Status dict when ONNX is not loaded.

    Each test empties the service's model cache and unsets its load
    timestamp only for the duration of the test; ``patch.dict`` and
    ``patch.object`` put the original module state back afterwards so
    later tests are not affected.
    """

    @patch(
        "services.trocr_onnx_service._check_onnx_runtime_available",
        return_value=False,
    )
    @patch(
        "services.trocr_onnx_service._resolve_onnx_model_dir",
        return_value=None,
    )
    def test_onnx_model_status_not_available(self, mock_resolve, mock_runtime):
        """No runtime and no model dirs → every availability flag is off."""
        import services.trocr_onnx_service as mod

        # NOTE(review): patch.dict assumes _onnx_models is a dict — confirm.
        with patch.dict(mod._onnx_models, clear=True), patch.object(
            mod, "_onnx_model_loaded_at", None
        ):
            status = mod.get_onnx_model_status()

        assert status["backend"] == "onnx"
        assert status["runtime_available"] is False
        assert status["printed"]["available"] is False
        assert status["printed"]["loaded"] is False
        assert status["printed"]["model_dir"] is None
        assert status["handwritten"]["available"] is False
        assert status["handwritten"]["loaded"] is False
        assert status["handwritten"]["model_dir"] is None
        assert status["loaded_at"] is None
        assert status["providers"] == []

    @patch(
        "services.trocr_onnx_service._check_onnx_runtime_available",
        return_value=True,
    )
    def test_onnx_model_status_runtime_but_no_files(self, mock_runtime):
        """Runtime installed but no model files on disk."""
        import services.trocr_onnx_service as mod

        # A fake onnxruntime module that reports only the CPU provider.
        fake_ort = MagicMock()
        fake_ort.get_available_providers.return_value = ["CPUExecutionProvider"]

        # Route every way the service might reach onnxruntime to the SAME
        # fake: a module-level ``onnxruntime`` attribute (created if absent)
        # and a function-local ``import onnxruntime`` resolved via sys.modules.
        # NOTE(review): patch.dict assumes _onnx_models is a dict — confirm.
        with patch.dict(mod._onnx_models, clear=True), patch.object(
            mod, "_onnx_model_loaded_at", None
        ), patch(
            "services.trocr_onnx_service._resolve_onnx_model_dir",
            return_value=None,
        ), patch(
            "services.trocr_onnx_service.onnxruntime", fake_ort, create=True
        ), patch.dict(
            "sys.modules", {"onnxruntime": fake_ort}
        ):
            status = mod.get_onnx_model_status()

        assert status["runtime_available"] is True
        assert status["printed"]["available"] is False
        assert status["handwritten"]["available"] is False


# ---------------------------------------------------------------------------
# Test: path resolution logic
# ---------------------------------------------------------------------------


class TestOnnxModelPaths:
    """Verify the path resolution order."""

    def test_env_var_path_takes_precedence(self, tmp_path):
        """TROCR_ONNX_DIR env var should be checked first."""
        from services.trocr_onnx_service import _resolve_onnx_model_dir

        # Create a fake model dir with a config.json
        model_dir = tmp_path / "trocr-base-printed"
        model_dir.mkdir(parents=True)
        (model_dir / "config.json").write_text("{}")

        with patch.dict(os.environ, {"TROCR_ONNX_DIR": str(tmp_path)}):
            result = _resolve_onnx_model_dir(handwritten=False)

        assert result is not None
        assert result == model_dir

    def test_env_var_handwritten_variant(self, tmp_path):
        """TROCR_ONNX_DIR works for handwritten variant too."""
        from services.trocr_onnx_service import _resolve_onnx_model_dir

        model_dir = tmp_path / "trocr-base-handwritten"
        model_dir.mkdir(parents=True)
        (model_dir / "encoder_model.onnx").write_bytes(b"fake")

        with patch.dict(os.environ, {"TROCR_ONNX_DIR": str(tmp_path)}):
            result = _resolve_onnx_model_dir(handwritten=True)

        assert result is not None
        assert result == model_dir

    def test_returns_none_when_no_dirs_exist(self):
        """When none of the candidate dirs exist, return None."""
        from services.trocr_onnx_service import _resolve_onnx_model_dir

        # clear=True empties the environment, so TROCR_ONNX_DIR is unset.
        with patch.dict(os.environ, {}, clear=True):
            # The Docker and local-dev paths almost certainly don't contain
            # real ONNX models on the test machine.
            result = _resolve_onnx_model_dir(handwritten=False)

        # Could be None or a real dir if someone has models locally.
        # We just verify it doesn't raise.
        assert result is None or isinstance(result, Path)

    def test_docker_path_checked(self):
        """Resolution runs cleanly with the env var unset.

        The Docker candidate (/root/.cache/huggingface/onnx/trocr-base-printed)
        cannot be created in tests, so this only verifies that walking the
        candidate list without TROCR_ONNX_DIR does not raise.
        """
        from services.trocr_onnx_service import _resolve_onnx_model_dir

        with patch.dict(os.environ, {}, clear=True):
            result = _resolve_onnx_model_dir(handwritten=False)

        assert result is None or isinstance(result, Path)

    def test_local_dev_path_relative_to_backend(self):
        """Local dev path is models/onnx/<variant>/ relative to backend dir."""
        from services.trocr_onnx_service import _resolve_onnx_model_dir

        # The backend dir is derived from __file__, so we can't easily
        # redirect it. Instead, verify the function signature and return type.
        with patch.dict(os.environ, {}, clear=True):
            result = _resolve_onnx_model_dir(handwritten=False)

        # May or may not find models — just verify the return type
        assert result is None or isinstance(result, Path)

    def test_dir_without_onnx_files_is_skipped(self, tmp_path):
        """A directory that exists but has no .onnx files or config.json is skipped."""
        from services.trocr_onnx_service import _resolve_onnx_model_dir

        empty_dir = tmp_path / "trocr-base-printed"
        empty_dir.mkdir(parents=True)
        # No .onnx files, no config.json

        with patch.dict(os.environ, {"TROCR_ONNX_DIR": str(tmp_path)}):
            result = _resolve_onnx_model_dir(handwritten=False)

        # The env-var candidate exists as a dir but has no model files, so it
        # must be skipped. Another candidate (or None) may be returned; either
        # way it must not be the empty dir (None != empty_dir holds too).
        assert result != empty_dir


# ---------------------------------------------------------------------------
# Test: fallback to PyTorch
# ---------------------------------------------------------------------------


class TestOnnxFallbackToPytorch:
    """When ONNX is unavailable, the routing layer in trocr_service falls back.

    ``patch.object`` sets ``_trocr_backend`` for the test and restores the
    previous value on exit, even if the test fails.
    """

    @pytest.mark.asyncio
    async def test_onnx_fallback_to_pytorch(self):
        """With backend='auto' and ONNX unavailable, PyTorch path is used."""
        import services.trocr_service as svc

        with patch.object(svc, "_trocr_backend", "auto"), patch(
            "services.trocr_service._try_onnx_ocr",
            return_value=None,
        ) as mock_onnx, patch(
            "services.trocr_service._run_pytorch_ocr",
            return_value=("pytorch result", 0.9),
        ) as mock_pytorch:
            text, conf = await svc.run_trocr_ocr(b"fake-image-data")

            mock_onnx.assert_called_once()
            mock_pytorch.assert_called_once()
            assert text == "pytorch result"
            assert conf == 0.9

    @pytest.mark.asyncio
    async def test_onnx_backend_forced(self):
        """With backend='onnx', failure raises RuntimeError."""
        import services.trocr_service as svc

        with patch.object(svc, "_trocr_backend", "onnx"), patch(
            "services.trocr_service._try_onnx_ocr",
            return_value=None,
        ):
            with pytest.raises(RuntimeError, match="ONNX backend.*unavailable"):
                await svc.run_trocr_ocr(b"fake-image-data")

    @pytest.mark.asyncio
    async def test_pytorch_backend_skips_onnx(self):
        """With backend='pytorch', ONNX is never attempted."""
        import services.trocr_service as svc

        with patch.object(svc, "_trocr_backend", "pytorch"), patch(
            "services.trocr_service._try_onnx_ocr",
        ) as mock_onnx, patch(
            "services.trocr_service._run_pytorch_ocr",
            return_value=("pytorch only", 0.85),
        ) as mock_pytorch:
            text, conf = await svc.run_trocr_ocr(b"fake-image-data")

            mock_onnx.assert_not_called()
            mock_pytorch.assert_called_once()
            assert text == "pytorch only"


# ---------------------------------------------------------------------------
# Test: TROCR_BACKEND env var handling
# ---------------------------------------------------------------------------


class TestBackendConfig:
    """TROCR_BACKEND environment variable handling."""

    def test_default_backend_is_auto(self):
        """Without env var, backend defaults to 'auto'."""
        import services.trocr_service as svc

        # The module reads the env var at import time; in a fresh import
        # with no TROCR_BACKEND set, it should default to "auto".
        # We test the get_active_backend function instead.
        with patch.object(svc, "_trocr_backend", "auto"):
            assert svc.get_active_backend() == "auto"

    def test_backend_pytorch(self):
        """TROCR_BACKEND=pytorch is reflected in get_active_backend."""
        import services.trocr_service as svc

        with patch.object(svc, "_trocr_backend", "pytorch"):
            assert svc.get_active_backend() == "pytorch"

    def test_backend_onnx(self):
        """TROCR_BACKEND=onnx is reflected in get_active_backend."""
        import services.trocr_service as svc

        with patch.object(svc, "_trocr_backend", "onnx"):
            assert svc.get_active_backend() == "onnx"

    def test_env_var_read_at_import(self):
        """Module reads TROCR_BACKEND from environment."""
        # We can't easily re-import, but we can verify the variable exists
        import services.trocr_service as svc

        assert hasattr(svc, "_trocr_backend")
        assert svc._trocr_backend in ("auto", "pytorch", "onnx")