"""
Tests for OCR Labeling API

Tests session management, image upload, labeling workflow, and training export.

BACKLOG: Feature not yet fully integrated - requires external OCR services
See: https://macmini:3002/infrastructure/tests
"""

import hashlib
import io
import json
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

# Mark all tests in this module as expected failures (backlog item)
pytestmark = pytest.mark.xfail(
    reason="ocr_labeling requires external services not available in CI - Backlog item",
    strict=False,  # Don't fail if test unexpectedly passes
)


# =============================================================================
# Test Fixtures
# =============================================================================

@pytest.fixture
def mock_db_pool():
    """Mock the PostgreSQL connection pool returned by ``metrics_db.get_pool``.

    ``metrics_db.get_pool`` is patched with an ``AsyncMock`` (the tests in this
    module treat it as awaited) that resolves to ``pool``. ``pool.acquire()``
    is a plain synchronous call returning an async context manager whose
    ``__aenter__`` yields ``conn``. This assumes ``metrics_db`` uses the
    ``async with pool.acquire() as conn`` pattern — confirm if that changes.

    Yields:
        tuple: ``(pool, conn)`` so tests can configure and inspect ``conn``.
    """
    with patch('metrics_db.get_pool', new_callable=AsyncMock) as get_pool_mock:
        pool = AsyncMock()
        # acquire() must not be an AsyncMock: calling one returns a coroutine,
        # which cannot be used with ``async with``. A MagicMock's return value
        # supports __aenter__/__aexit__ (Python 3.8+).
        pool.acquire = MagicMock()
        conn = AsyncMock()
        pool.acquire.return_value.__aenter__.return_value = conn
        get_pool_mock.return_value = pool
        yield pool, conn


@pytest.fixture
def mock_minio():
    """Mock MinIO storage functions.

    Forces ``MINIO_AVAILABLE`` on and stubs upload/download. Yields
    ``(upload_mock, get_mock)``. NOTE: not referenced by any test in this
    module at present.
    """
    with patch('ocr_labeling_api.MINIO_AVAILABLE', True), \
            patch('ocr_labeling_api.upload_ocr_image') as upload_mock, \
            patch('ocr_labeling_api.get_ocr_image') as get_mock:
        upload_mock.return_value = "ocr-labeling/session-123/item-456.png"
        get_mock.return_value = b"\x89PNG fake image data"
        yield upload_mock, get_mock


@pytest.fixture
def mock_vision_ocr():
    """Mock the Vision OCR service.

    ``extract_text`` resolves to a result with text
    "Erkannter Text aus dem Bild" and confidence 0.87; yields the service mock.
    """
    with patch('ocr_labeling_api.VISION_OCR_AVAILABLE', True), \
            patch('ocr_labeling_api.get_vision_ocr_service') as mock:
        service = AsyncMock()
        service.is_available.return_value = True
        result = MagicMock()
        result.text = "Erkannter Text aus dem Bild"
        result.confidence = 0.87
        service.extract_text.return_value = result
        mock.return_value = service
        yield service


@pytest.fixture
def mock_training_export():
    """Mock the training export service.

    ``export`` returns a result pointing at a fixed 20260121_120000 batch and
    ``list_exports`` returns one generic export. NOTE: not referenced by any
    test in this module at present.
    """
    with patch('ocr_labeling_api.TRAINING_EXPORT_AVAILABLE', True), \
            patch('ocr_labeling_api.get_training_export_service') as mock:
        service = MagicMock()
        export_result = MagicMock()
        export_result.export_path = "/app/ocr-exports/generic/20260121_120000"
        export_result.manifest_path = "/app/ocr-exports/generic/20260121_120000/manifest.json"
        export_result.batch_id = "20260121_120000"
        service.export.return_value = export_result
        service.list_exports.return_value = [
            {"format": "generic", "batch_id": "20260121_120000", "sample_count": 10}
        ]
        mock.return_value = service
        yield service


@pytest.fixture
def sample_session():
    """Sample session data (deterministic timestamp)."""
    return {
        "id": "session-123",
        "name": "Test Session",
        "source_type": "klausur",
        "description": "Test description",
        "ocr_model": "llama3.2-vision:11b",
        "total_items": 5,
        "labeled_items": 2,
        "confirmed_items": 1,
        "corrected_items": 1,
        "skipped_items": 0,
        # Fixed naive timestamp instead of the deprecated datetime.utcnow():
        # keeps fixture data reproducible across runs.
        "created_at": datetime(2026, 1, 21, 12, 0, 0),
    }


@pytest.fixture
def sample_item():
    """Sample labeling item data (deterministic timestamp)."""
    return {
        "id": "item-456",
        "session_id": "session-123",
        "session_name": "Test Session",
        "image_path": "/app/ocr-labeling/session-123/item-456.png",
        "ocr_text": "Erkannter Text",
        "ocr_confidence": 0.87,
        "ground_truth": None,
        "status": "pending",
        "metadata": {"page": 1},
        "created_at": datetime(2026, 1, 21, 12, 0, 0),
    }


# =============================================================================
# Session Management Tests
# =============================================================================

class TestSessionCreation:
    """Tests for session creation."""

    @pytest.mark.asyncio
    async def test_create_session_success(self, mock_db_pool):
        """Test successful session creation."""
        from metrics_db import create_ocr_labeling_session

        pool, conn = mock_db_pool
        conn.execute.return_value = None

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=pool):
            await create_ocr_labeling_session(
                session_id="session-123",
                name="Test Session",
                source_type="klausur",
                description="Test",
                ocr_model="llama3.2-vision:11b",
            )

        # The insert must go through a connection acquired from the pool.
        assert pool.acquire.called

    def test_session_create_model_validation(self):
        """Test SessionCreate model validation."""
        from ocr_labeling_api import SessionCreate

        # Valid session
        session = SessionCreate(
            name="Test Session",
            source_type="klausur",
            description="Test description",
        )
        assert session.name == "Test Session"
        assert session.source_type == "klausur"
        assert session.ocr_model == "llama3.2-vision:11b"  # default

    def test_session_create_with_custom_model(self):
        """Test SessionCreate with custom OCR model."""
        from ocr_labeling_api import SessionCreate

        session = SessionCreate(
            name="TrOCR Session",
            source_type="handwriting_sample",
            ocr_model="trocr-base",
        )
        assert session.ocr_model == "trocr-base"


class TestSessionListing:
    """Tests for session listing."""

    @pytest.mark.asyncio
    async def test_get_sessions_empty(self):
        """Test getting sessions when none exist (no DB pool)."""
        from metrics_db import get_ocr_labeling_sessions

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=None):
            sessions = await get_ocr_labeling_sessions()
        assert sessions == []

    @pytest.mark.asyncio
    async def test_get_session_not_found(self):
        """Test getting a non-existent session (no DB pool)."""
        from metrics_db import get_ocr_labeling_session

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=None):
            session = await get_ocr_labeling_session("non-existent-id")
        assert session is None


# =============================================================================
# Image Upload Tests
# =============================================================================

class TestImageUpload:
    """Tests for image upload functionality."""

    def test_compute_image_hash(self):
        """Test image hash computation."""
        from ocr_labeling_api import compute_image_hash

        image_data = b"\x89PNG fake image data"
        hash1 = compute_image_hash(image_data)
        hash2 = compute_image_hash(image_data)

        # Same data should produce same hash
        assert hash1 == hash2
        assert len(hash1) == 64  # SHA256 hex length

    def test_compute_image_hash_different_data(self):
        """Test that different images produce different hashes."""
        from ocr_labeling_api import compute_image_hash

        hash1 = compute_image_hash(b"image 1 data")
        hash2 = compute_image_hash(b"image 2 data")
        assert hash1 != hash2

    def test_save_image_locally(self, tmp_path):
        """Test local image saving writes the given bytes to disk."""
        import os

        from ocr_labeling_api import save_image_locally

        image_data = b"\x89PNG fake image data"
        # Redirect storage into pytest's per-test temp dir.
        with patch('ocr_labeling_api.LOCAL_STORAGE_PATH', str(tmp_path)):
            filepath = save_image_locally(
                session_id="session-123",
                item_id="item-456",
                image_data=image_data,
                extension="png",
            )

            assert filepath.endswith("item-456.png")
            # File should exist and hold exactly the uploaded bytes
            assert os.path.exists(filepath)
            with open(filepath, "rb") as fh:
                assert fh.read() == image_data

    def test_get_image_url_local(self):
        """Test URL generation for local images."""
        from ocr_labeling_api import get_image_url, LOCAL_STORAGE_PATH

        local_path = f"{LOCAL_STORAGE_PATH}/session-123/item-456.png"
        url = get_image_url(local_path)
        assert url == "/api/v1/ocr-label/images/session-123/item-456.png"

    def test_get_image_url_minio(self):
        """Test URL for MinIO images (passthrough)."""
        from ocr_labeling_api import get_image_url

        minio_path = "ocr-labeling/session-123/item-456.png"
        url = get_image_url(minio_path)
        # Non-local paths are passed through
        assert url == minio_path


# =============================================================================
# Labeling Workflow Tests
# =============================================================================

class TestConfirmLabel:
    """Tests for label confirmation."""

    @pytest.mark.asyncio
    async def test_confirm_label_success(self, mock_db_pool):
        """Test successful label confirmation."""
        from metrics_db import confirm_ocr_label

        pool, conn = mock_db_pool
        # Confirmation copies the OCR text into ground_truth, so it is read first.
        conn.fetchrow.return_value = {"ocr_text": "Test text"}
        conn.execute.return_value = None

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=pool):
            await confirm_ocr_label(
                item_id="item-456",
                labeled_by="admin",
                label_time_seconds=5,
            )

        # Should update item status and ground_truth
        assert conn.execute.called

    def test_confirm_request_validation(self):
        """Test ConfirmRequest model validation."""
        from ocr_labeling_api import ConfirmRequest

        request = ConfirmRequest(
            item_id="item-456",
            label_time_seconds=5,
        )
        assert request.item_id == "item-456"
        assert request.label_time_seconds == 5


class TestCorrectLabel:
    """Tests for label correction."""

    @pytest.mark.asyncio
    async def test_correct_label_success(self, mock_db_pool):
        """Test successful label correction."""
        from metrics_db import correct_ocr_label

        pool, conn = mock_db_pool
        conn.execute.return_value = None

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=pool):
            await correct_ocr_label(
                item_id="item-456",
                ground_truth="Korrigierter Text",
                labeled_by="admin",
                label_time_seconds=15,
            )

        # Should update item with corrected ground_truth
        assert conn.execute.called

    def test_correct_request_validation(self):
        """Test CorrectRequest model validation."""
        from ocr_labeling_api import CorrectRequest

        request = CorrectRequest(
            item_id="item-456",
            ground_truth="Korrigierter Text",
            label_time_seconds=15,
        )
        assert request.item_id == "item-456"
        assert request.ground_truth == "Korrigierter Text"


class TestSkipItem:
    """Tests for item skipping."""

    @pytest.mark.asyncio
    async def test_skip_item_success(self, mock_db_pool):
        """Test successful item skip."""
        from metrics_db import skip_ocr_item

        pool, conn = mock_db_pool
        conn.execute.return_value = None

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=pool):
            await skip_ocr_item(
                item_id="item-456",
                labeled_by="admin",
            )

        # Should update item status to skipped
        assert conn.execute.called


# =============================================================================
# Statistics Tests
# =============================================================================


class TestLabelingStats:
    """Tests for labeling statistics."""

    @pytest.mark.asyncio
    async def test_get_stats_no_db(self):
        """Test stats when database is not available."""
        from metrics_db import get_ocr_labeling_stats

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=None):
            stats = await get_ocr_labeling_stats()

        # Without a pool the result must carry an error or report zero items.
        assert "error" in stats or stats.get("total_items", 0) == 0

    def test_stats_response_model(self):
        """Test StatsResponse model structure."""
        from ocr_labeling_api import StatsResponse

        stats = StatsResponse(
            total_items=100,
            labeled_items=50,
            confirmed_items=40,
            corrected_items=10,
            pending_items=50,
            accuracy_rate=0.8,
        )

        assert stats.total_items == 100
        assert stats.accuracy_rate == 0.8


# =============================================================================
# Export Tests
# =============================================================================

class TestTrainingExport:
    """Tests for training data export."""

    def test_export_request_validation(self):
        """Test ExportRequest model validation for each supported format."""
        from ocr_labeling_api import ExportRequest

        # Default format is generic
        request = ExportRequest()
        assert request.export_format == "generic"

        # TrOCR format
        request = ExportRequest(export_format="trocr")
        assert request.export_format == "trocr"

        # Llama Vision format
        request = ExportRequest(export_format="llama_vision")
        assert request.export_format == "llama_vision"

    @pytest.mark.asyncio
    async def test_export_training_samples(self, mock_db_pool):
        """Test training sample export from database."""
        from metrics_db import export_training_samples

        pool, conn = mock_db_pool
        conn.fetch.return_value = [
            {
                "id": "sample-1",
                "image_path": "/app/ocr-labeling/session-123/item-1.png",
                "ground_truth": "Text 1",
            },
            {
                "id": "sample-2",
                "image_path": "/app/ocr-labeling/session-123/item-2.png",
                "ground_truth": "Text 2",
            },
        ]

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=pool):
            await export_training_samples(
                export_format="generic",
                exported_by="admin",
            )

        # The export must have queried the database through the pooled connection.
        assert conn.fetch.called or conn.execute.called


class TestTrainingExportService:
    """Tests for training export service record layouts.

    NOTE(review): these tests only assert on literal dicts describing the
    expected record layout; they do not call the export service itself, so
    they cannot catch regressions in it.
    """

    def test_trocr_export_format(self):
        """Test TrOCR export format structure."""
        expected_format = {
            "file_name": "images/sample-1.png",
            "text": "Ground truth text",
            "id": "sample-1",
        }

        assert "file_name" in expected_format
        assert "text" in expected_format

    def test_llama_vision_export_format(self):
        """Test Llama Vision export format structure (system/user/assistant)."""
        expected_format = {
            "id": "sample-1",
            "messages": [
                {"role": "system", "content": "Du bist ein OCR-Experte..."},
                {"role": "user", "content": [
                    {"type": "image_url", "image_url": {"url": "..."}},
                    {"type": "text", "text": "Lies den Text..."},
                ]},
                {"role": "assistant", "content": "Ground truth text"},
            ],
        }

        assert "messages" in expected_format
        assert len(expected_format["messages"]) == 3
        assert expected_format["messages"][2]["role"] == "assistant"

    def test_generic_export_format(self):
        """Test generic export format structure."""
        expected_format = {
            "id": "sample-1",
            "image_path": "images/sample-1.png",
            "ground_truth": "Ground truth text",
            "ocr_text": "OCR recognized text",
            "ocr_confidence": 0.87,
            "metadata": {},
        }

        assert "image_path" in expected_format
        assert "ground_truth" in expected_format


# =============================================================================
# OCR Processing Tests
# =============================================================================

class TestOCRProcessing:
    """Tests for OCR processing."""

    @pytest.mark.asyncio
    async def test_run_ocr_on_image_no_service(self):
        """Test OCR when no OCR backend is available."""
        from ocr_labeling_api import run_ocr_on_image

        with patch('ocr_labeling_api.VISION_OCR_AVAILABLE', False), \
                patch('ocr_labeling_api.PADDLEOCR_AVAILABLE', False), \
                patch('ocr_labeling_api.TROCR_AVAILABLE', False), \
                patch('ocr_labeling_api.DONUT_AVAILABLE', False):
            text, confidence = await run_ocr_on_image(
                image_data=b"fake image",
                filename="test.png",
            )

        assert text is None
        assert confidence == 0.0

    @pytest.mark.asyncio
    async def test_run_ocr_on_image_success(self, mock_vision_ocr):
        """Test successful OCR processing."""
        from ocr_labeling_api import run_ocr_on_image

        text, confidence = await run_ocr_on_image(
            image_data=b"fake image",
            filename="test.png",
        )

        assert text == "Erkannter Text aus dem Bild"
        assert confidence == 0.87


# =============================================================================
# OCR Model Dispatcher Tests
# =============================================================================

class TestOCRModelDispatcher:
    """Tests for the OCR model dispatcher (v1.1.0)."""

    @pytest.mark.asyncio
    async def test_dispatcher_vision_model_default(self, mock_vision_ocr):
        """Test dispatcher uses Vision OCR by default."""
        from ocr_labeling_api import run_ocr_on_image

        text, confidence = await run_ocr_on_image(
            image_data=b"fake image",
            filename="test.png",
            model="llama3.2-vision:11b",
        )

        assert text == "Erkannter Text aus dem Bild"
        assert confidence == 0.87

    @pytest.mark.asyncio
    async def test_dispatcher_paddleocr_model(self):
        """Test dispatcher routes to PaddleOCR."""
        from ocr_labeling_api import run_ocr_on_image

        # Mock PaddleOCR
        mock_regions = []
        mock_text = "PaddleOCR erkannter Text"

        # NOTE(review): run_paddle_ocr is patched with a plain MagicMock, so
        # this assumes the dispatcher calls it synchronously (the TrOCR/Donut
        # tests below patch with coroutine functions instead) — confirm.
        with patch('ocr_labeling_api.PADDLEOCR_AVAILABLE', True), \
                patch('ocr_labeling_api.run_paddle_ocr', return_value=(mock_regions, mock_text)):
            text, confidence = await run_ocr_on_image(
                image_data=b"fake image",
                filename="test.png",
                model="paddleocr",
            )

        assert text == "PaddleOCR erkannter Text"

    @pytest.mark.asyncio
    async def test_dispatcher_paddleocr_fallback_to_vision(self, mock_vision_ocr):
        """Test PaddleOCR falls back to Vision OCR when unavailable."""
        from ocr_labeling_api import run_ocr_on_image

        with patch('ocr_labeling_api.PADDLEOCR_AVAILABLE', False):
            text, confidence = await run_ocr_on_image(
                image_data=b"fake image",
                filename="test.png",
                model="paddleocr",
            )

        # Should fall back to Vision OCR
        assert text == "Erkannter Text aus dem Bild"
        assert confidence == 0.87

    @pytest.mark.asyncio
    async def test_dispatcher_trocr_model(self):
        """Test dispatcher routes to TrOCR."""
        from ocr_labeling_api import run_ocr_on_image

        async def mock_trocr(image_data):
            return "TrOCR erkannter Text", 0.85

        with patch('ocr_labeling_api.TROCR_AVAILABLE', True), \
                patch('ocr_labeling_api.run_trocr_ocr', mock_trocr):
            text, confidence = await run_ocr_on_image(
                image_data=b"fake image",
                filename="test.png",
                model="trocr",
            )

        assert text == "TrOCR erkannter Text"
        assert confidence == 0.85

    @pytest.mark.asyncio
    async def test_dispatcher_donut_model(self):
        """Test dispatcher routes to Donut."""
        from ocr_labeling_api import run_ocr_on_image

        async def mock_donut(image_data):
            return "Donut erkannter Text", 0.80

        with patch('ocr_labeling_api.DONUT_AVAILABLE', True), \
                patch('ocr_labeling_api.run_donut_ocr', mock_donut):
            text, confidence = await run_ocr_on_image(
                image_data=b"fake image",
                filename="test.png",
                model="donut",
            )

        assert text == "Donut erkannter Text"
        assert confidence == 0.80

    @pytest.mark.asyncio
    async def test_dispatcher_unknown_model_uses_vision(self, mock_vision_ocr):
        """Test dispatcher uses Vision OCR for unknown models."""
        from ocr_labeling_api import run_ocr_on_image

        text, confidence = await run_ocr_on_image(
            image_data=b"fake image",
            filename="test.png",
            model="unknown-model",
        )

        # Unknown model should fall back to Vision OCR
        assert text == "Erkannter Text aus dem Bild"
        assert confidence == 0.87


class TestOCRModelTypes:
    """Tests for OCR model type definitions."""

    def test_session_with_paddleocr_model(self):
        """Test session creation with PaddleOCR model."""
        from ocr_labeling_api import SessionCreate

        session = SessionCreate(
            name="PaddleOCR Session",
            source_type="klausur",
            ocr_model="paddleocr",
        )
        assert session.ocr_model == "paddleocr"

    def test_session_with_donut_model(self):
        """Test session creation with Donut model."""
        from ocr_labeling_api import SessionCreate

        session = SessionCreate(
            name="Donut Session",
            source_type="scan",
            ocr_model="donut",
        )
        assert session.ocr_model == "donut"

    def test_session_with_trocr_model(self):
        """Test session creation with TrOCR model."""
        from ocr_labeling_api import SessionCreate

        session = SessionCreate(
            name="TrOCR Session",
            source_type="handwriting_sample",
            ocr_model="trocr",
        )
        assert session.ocr_model == "trocr"


# =============================================================================
# API Response Model Tests
# =============================================================================

class TestResponseModels:
    """Tests for API response models."""

    def test_session_response_model(self):
        """Test SessionResponse model."""
        from ocr_labeling_api import SessionResponse

        session = SessionResponse(
            id="session-123",
            name="Test Session",
            source_type="klausur",
            description="Test",
            ocr_model="llama3.2-vision:11b",
            total_items=10,
            labeled_items=5,
            confirmed_items=3,
            corrected_items=2,
            skipped_items=0,
            # Fixed naive timestamp instead of the deprecated datetime.utcnow().
            created_at=datetime(2026, 1, 21, 12, 0, 0),
        )

        assert session.id == "session-123"
        assert session.total_items == 10

    def test_item_response_model(self):
        """Test ItemResponse model."""
        from ocr_labeling_api import ItemResponse

        item = ItemResponse(
            id="item-456",
            session_id="session-123",
            session_name="Test Session",
            image_path="/app/ocr-labeling/session-123/item-456.png",
            image_url="/api/v1/ocr-label/images/session-123/item-456.png",
            ocr_text="Test OCR text",
            ocr_confidence=0.87,
            ground_truth=None,
            status="pending",
            metadata={"page": 1},
            created_at=datetime(2026, 1, 21, 12, 0, 0),
        )

        assert item.id == "item-456"
        assert item.status == "pending"


# =============================================================================
# Deduplication Tests
# =============================================================================

class TestDeduplication:
    """Tests for image deduplication."""

    def test_hash_based_deduplication(self):
        """Test that same images produce same hash for deduplication."""
        from ocr_labeling_api import compute_image_hash

        # Same content should be detected as duplicate
        image1 = b"\x89PNG\x0d\x0a\x1a\x0a test image content"
        image2 = b"\x89PNG\x0d\x0a\x1a\x0a test image content"

        hash1 = compute_image_hash(image1)
        hash2 = compute_image_hash(image2)

        assert hash1 == hash2

    def test_unique_images_different_hash(self):
        """Test that different images produce different hashes."""
        from ocr_labeling_api import compute_image_hash

        image1 = b"\x89PNG unique content 1"
        image2 = b"\x89PNG unique content 2"

        hash1 = compute_image_hash(image1)
        hash2 = compute_image_hash(image2)

        assert hash1 != hash2


# =============================================================================
# Integration Tests (require running services)
# =============================================================================

class TestOCRLabelingIntegration:
    """Integration tests - require Ollama, MinIO, PostgreSQL running."""

    @pytest.mark.skip(reason="Requires running Ollama with llama3.2-vision")
    @pytest.mark.asyncio
    async def test_full_labeling_workflow(self):
        """Test complete labeling workflow."""
        # This would require:
        # 1. Create session
        # 2. Upload image
        # 3. Run OCR
        # 4. Confirm or correct label
        # 5. Export training data
        pass

    @pytest.mark.skip(reason="Requires running PostgreSQL")
    @pytest.mark.asyncio
    async def test_stats_calculation(self):
        """Test statistics calculation with real data."""
        pass

    @pytest.mark.skip(reason="Requires running MinIO")
    @pytest.mark.asyncio
    async def test_image_storage_and_retrieval(self):
        """Test image upload and download from MinIO."""
        pass


# =============================================================================
# Run Tests
# =============================================================================

if __name__ == "__main__":
    # pytest.main returns the session exit code; propagate it so a direct
    # run exits non-zero when tests fail (it was previously discarded).
    raise SystemExit(pytest.main([__file__, "-v"]))