Files
breakpilot-lehrer/klausur-service/backend/tests/test_ocr_labeling.py
Benjamin Boenisch 5a31f52310 Initial commit: breakpilot-lehrer - Lehrer KI Platform
Services: Admin-Lehrer, Backend-Lehrer, Studio v2, Website,
Klausur-Service, School-Service, Voice-Service, Geo-Service,
BreakPilot Drive, Agent-Core

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 23:47:26 +01:00

800 lines
26 KiB
Python

"""
Tests for OCR Labeling API
Tests session management, image upload, labeling workflow, and training export.
BACKLOG: Feature not yet fully integrated - requires external OCR services
See: https://macmini:3002/infrastructure/tests
"""
import hashlib
import io
import json
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

# Every test in this module is marked as an expected failure (backlog item).
# strict=False: a test that unexpectedly passes does not fail the run.
# The mark is assigned after all imports so that no import follows
# executable module-level code.
pytestmark = pytest.mark.xfail(
    reason="ocr_labeling requires external services not available in CI - Backlog item",
    strict=False,
)
# =============================================================================
# Test Fixtures
# =============================================================================
@pytest.fixture
def mock_db_pool():
    """Provide a mocked PostgreSQL connection pool and connection.

    ``metrics_db.get_pool`` is patched for the lifetime of the fixture and
    configured to return the mocked pool.

    Yields:
        tuple: ``(pool, conn)``; ``conn`` is the value bound by
        ``async with pool.acquire() as conn``.
    """
    with patch('metrics_db.get_pool') as mock:
        pool = AsyncMock()
        conn = AsyncMock()
        # Child attributes of an AsyncMock are AsyncMocks too, so calling
        # ``pool.acquire()`` would return a coroutine, which cannot be entered
        # with ``async with``. A MagicMock returns a MagicMock, which supports
        # the async context manager protocol.
        # NOTE(review): assumes metrics_db acquires connections via
        # ``async with pool.acquire()`` - confirm against metrics_db.
        pool.acquire = MagicMock()
        pool.acquire.return_value.__aenter__.return_value = conn
        mock.return_value = pool
        yield pool, conn
@pytest.fixture
def mock_minio():
    """Patch the MinIO-related names of ``ocr_labeling_api``.

    ``MINIO_AVAILABLE`` is forced to ``True`` and the upload/download
    helpers are replaced by mocks with canned return values.

    Yields:
        tuple: ``(upload_mock, get_mock)``.
    """
    with patch('ocr_labeling_api.MINIO_AVAILABLE', True):
        with patch('ocr_labeling_api.upload_ocr_image') as uploader:
            with patch('ocr_labeling_api.get_ocr_image') as downloader:
                # Canned object key and image bytes returned by the helpers.
                uploader.return_value = "ocr-labeling/session-123/item-456.png"
                downloader.return_value = b"\x89PNG fake image data"
                yield uploader, downloader
@pytest.fixture
def mock_vision_ocr():
    """Patch the Vision OCR service factory of ``ocr_labeling_api``.

    ``VISION_OCR_AVAILABLE`` is forced to ``True`` and
    ``get_vision_ocr_service`` returns an ``AsyncMock`` service whose
    ``extract_text`` yields a fixed text/confidence result.

    Yields:
        AsyncMock: the mocked service.
    """
    with patch('ocr_labeling_api.VISION_OCR_AVAILABLE', True):
        with patch('ocr_labeling_api.get_vision_ocr_service') as factory:
            # Fixed recognition result handed back by extract_text.
            ocr_result = MagicMock()
            ocr_result.confidence = 0.87
            ocr_result.text = "Erkannter Text aus dem Bild"

            vision_service = AsyncMock()
            vision_service.extract_text.return_value = ocr_result
            vision_service.is_available.return_value = True

            factory.return_value = vision_service
            yield vision_service
@pytest.fixture
def mock_training_export():
    """Patch the training export service factory of ``ocr_labeling_api``.

    ``TRAINING_EXPORT_AVAILABLE`` is forced to ``True`` and
    ``get_training_export_service`` returns a ``MagicMock`` service with a
    canned ``export`` result and a one-entry ``list_exports`` result.

    Yields:
        MagicMock: the mocked service.
    """
    with patch('ocr_labeling_api.TRAINING_EXPORT_AVAILABLE', True):
        with patch('ocr_labeling_api.get_training_export_service') as factory:
            # Result object returned by service.export(...).
            fake_export = MagicMock()
            fake_export.batch_id = "20260121_120000"
            fake_export.export_path = "/app/ocr-exports/generic/20260121_120000"
            fake_export.manifest_path = "/app/ocr-exports/generic/20260121_120000/manifest.json"

            exporter = MagicMock()
            exporter.list_exports.return_value = [
                {"format": "generic", "batch_id": "20260121_120000", "sample_count": 10}
            ]
            exporter.export.return_value = fake_export

            factory.return_value = exporter
            yield exporter
@pytest.fixture
def sample_session():
    """Return a dict with sample values for one labeling session.

    ``created_at`` is a naive datetime holding the current UTC time.
    """
    from datetime import timezone

    return {
        "id": "session-123",
        "name": "Test Session",
        "source_type": "klausur",
        "description": "Test description",
        "ocr_model": "llama3.2-vision:11b",
        "total_items": 5,
        "labeled_items": 2,
        "confirmed_items": 1,
        "corrected_items": 1,
        "skipped_items": 0,
        # datetime.utcnow() is deprecated since Python 3.12; this produces the
        # same naive UTC timestamp without the deprecation warning.
        "created_at": datetime.now(timezone.utc).replace(tzinfo=None),
    }
@pytest.fixture
def sample_item():
    """Return a dict with sample values for one pending labeling item.

    ``created_at`` is a naive datetime holding the current UTC time.
    """
    from datetime import timezone

    return {
        "id": "item-456",
        "session_id": "session-123",
        "session_name": "Test Session",
        "image_path": "/app/ocr-labeling/session-123/item-456.png",
        "ocr_text": "Erkannter Text",
        "ocr_confidence": 0.87,
        "ground_truth": None,
        "status": "pending",
        "metadata": {"page": 1},
        # datetime.utcnow() is deprecated since Python 3.12; this produces the
        # same naive UTC timestamp without the deprecation warning.
        "created_at": datetime.now(timezone.utc).replace(tzinfo=None),
    }
# =============================================================================
# Session Management Tests
# =============================================================================
class TestSessionCreation:
    """Tests for session creation."""

    @pytest.mark.asyncio
    async def test_create_session_success(self, mock_db_pool):
        """Creating a session requests a connection from the mocked pool."""
        # Only metrics_db is imported here: this test does not use any name
        # from ocr_labeling_api, so it should not depend on that module.
        from metrics_db import create_ocr_labeling_session

        pool, conn = mock_db_pool
        conn.execute.return_value = None

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=pool):
            await create_ocr_labeling_session(
                session_id="session-123",
                name="Test Session",
                source_type="klausur",
                description="Test",
                ocr_model="llama3.2-vision:11b",
            )

        # The check is that the pool was asked for a connection.
        assert pool.acquire.called

    def test_session_create_model_validation(self):
        """SessionCreate keeps the given fields and applies the default model."""
        from ocr_labeling_api import SessionCreate

        # Valid session without an explicit ocr_model
        session = SessionCreate(
            name="Test Session",
            source_type="klausur",
            description="Test description",
        )

        assert session.name == "Test Session"
        assert session.source_type == "klausur"
        assert session.ocr_model == "llama3.2-vision:11b"  # default

    def test_session_create_with_custom_model(self):
        """SessionCreate keeps an explicitly supplied OCR model."""
        from ocr_labeling_api import SessionCreate

        session = SessionCreate(
            name="TrOCR Session",
            source_type="handwriting_sample",
            ocr_model="trocr-base",
        )

        assert session.ocr_model == "trocr-base"
class TestSessionListing:
    """Session lookups when ``get_pool`` yields no pool."""

    @pytest.mark.asyncio
    async def test_get_sessions_empty(self):
        """Listing sessions without a pool gives an empty list."""
        from metrics_db import get_ocr_labeling_sessions

        no_pool = patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=None)
        with no_pool:
            listed = await get_ocr_labeling_sessions()

        assert listed == []

    @pytest.mark.asyncio
    async def test_get_session_not_found(self):
        """Fetching an unknown session id without a pool gives ``None``."""
        from metrics_db import get_ocr_labeling_session

        no_pool = patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=None)
        with no_pool:
            found = await get_ocr_labeling_session("non-existent-id")

        assert found is None
# =============================================================================
# Image Upload Tests
# =============================================================================
class TestImageUpload:
    """Hashing, local storage and URL helpers used by image upload."""

    def test_compute_image_hash(self):
        """Hashing identical bytes twice gives the same 64-character digest."""
        from ocr_labeling_api import compute_image_hash

        payload = b"\x89PNG fake image data"
        first_digest = compute_image_hash(payload)
        second_digest = compute_image_hash(payload)

        # 64 characters is the length of a SHA256 hex digest
        assert len(first_digest) == 64
        assert first_digest == second_digest

    def test_compute_image_hash_different_data(self):
        """Different byte strings give different digests."""
        from ocr_labeling_api import compute_image_hash

        digest_a = compute_image_hash(b"image 1 data")
        digest_b = compute_image_hash(b"image 2 data")

        assert digest_a != digest_b

    def test_save_image_locally(self, tmp_path):
        """An image saved locally ends in ``<item_id>.<extension>`` and exists."""
        import os
        from ocr_labeling_api import save_image_locally, LOCAL_STORAGE_PATH

        payload = b"\x89PNG fake image data"

        # Point the module-level storage path at pytest's temporary directory.
        with patch('ocr_labeling_api.LOCAL_STORAGE_PATH', str(tmp_path)):
            written_path = save_image_locally(
                session_id="session-123",
                item_id="item-456",
                image_data=payload,
                extension="png",
            )

            assert written_path.endswith("item-456.png")
            # The file must be present on disk
            assert os.path.exists(written_path)

    def test_get_image_url_local(self):
        """A path below the local storage root maps to the image API URL."""
        from ocr_labeling_api import get_image_url, LOCAL_STORAGE_PATH

        stored_at = f"{LOCAL_STORAGE_PATH}/session-123/item-456.png"

        assert get_image_url(stored_at) == "/api/v1/ocr-label/images/session-123/item-456.png"

    def test_get_image_url_minio(self):
        """A non-local (MinIO) path is returned unchanged."""
        from ocr_labeling_api import get_image_url

        object_key = "ocr-labeling/session-123/item-456.png"

        # Non-local paths are passed through
        assert get_image_url(object_key) == object_key
# =============================================================================
# Labeling Workflow Tests
# =============================================================================
class TestConfirmLabel:
    """Confirming an OCR result as the label."""

    @pytest.mark.asyncio
    async def test_confirm_label_success(self, mock_db_pool):
        """Confirming a label issues an ``execute`` on the connection."""
        from metrics_db import confirm_ocr_label

        db_pool, db_conn = mock_db_pool
        db_conn.fetchrow.return_value = {"ocr_text": "Test text"}
        db_conn.execute.return_value = None

        pool_patch = patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=db_pool)
        with pool_patch:
            await confirm_ocr_label(
                item_id="item-456",
                labeled_by="admin",
                label_time_seconds=5,
            )

        # The item update is expected to go through conn.execute
        assert db_conn.execute.called

    def test_confirm_request_validation(self):
        """ConfirmRequest keeps the item id and labeling time it was given."""
        from ocr_labeling_api import ConfirmRequest

        payload = ConfirmRequest(item_id="item-456", label_time_seconds=5)

        assert payload.item_id == "item-456"
        assert payload.label_time_seconds == 5
class TestCorrectLabel:
    """Replacing an OCR result with a corrected ground truth."""

    @pytest.mark.asyncio
    async def test_correct_label_success(self, mock_db_pool):
        """Correcting a label issues an ``execute`` on the connection."""
        from metrics_db import correct_ocr_label

        db_pool, db_conn = mock_db_pool
        db_conn.execute.return_value = None

        pool_patch = patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=db_pool)
        with pool_patch:
            await correct_ocr_label(
                item_id="item-456",
                ground_truth="Korrigierter Text",
                labeled_by="admin",
                label_time_seconds=15,
            )

        # The corrected ground truth is expected to be written via conn.execute
        assert db_conn.execute.called

    def test_correct_request_validation(self):
        """CorrectRequest keeps the item id and ground truth it was given."""
        from ocr_labeling_api import CorrectRequest

        payload = CorrectRequest(
            item_id="item-456",
            ground_truth="Korrigierter Text",
            label_time_seconds=15,
        )

        assert payload.item_id == "item-456"
        assert payload.ground_truth == "Korrigierter Text"
class TestSkipItem:
    """Skipping an item during labeling."""

    @pytest.mark.asyncio
    async def test_skip_item_success(self, mock_db_pool):
        """Skipping an item issues an ``execute`` on the connection."""
        from metrics_db import skip_ocr_item

        db_pool, db_conn = mock_db_pool
        db_conn.execute.return_value = None

        pool_patch = patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=db_pool)
        with pool_patch:
            await skip_ocr_item(item_id="item-456", labeled_by="admin")

        # The status change to skipped is expected to go through conn.execute
        assert db_conn.execute.called
# =============================================================================
# Statistics Tests
# =============================================================================
class TestLabelingStats:
    """Labeling statistics query and response model."""

    @pytest.mark.asyncio
    async def test_get_stats_no_db(self):
        """Without a pool, stats report an error or zero total items."""
        from metrics_db import get_ocr_labeling_stats

        no_pool = patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=None)
        with no_pool:
            summary = await get_ocr_labeling_stats()

        # Accept either an explicit error entry or an absent/zero item count.
        if "error" not in summary:
            assert summary.get("total_items", 0) == 0

    def test_stats_response_model(self):
        """StatsResponse exposes the counters and accuracy it was built with."""
        from ocr_labeling_api import StatsResponse

        response = StatsResponse(
            total_items=100,
            labeled_items=50,
            confirmed_items=40,
            corrected_items=10,
            pending_items=50,
            accuracy_rate=0.8,
        )

        assert response.total_items == 100
        assert response.accuracy_rate == 0.8
# =============================================================================
# Export Tests
# =============================================================================
class TestTrainingExport:
    """Export request model and database-side sample export."""

    def test_export_request_validation(self):
        """ExportRequest defaults to ``generic`` and keeps explicit formats."""
        from ocr_labeling_api import ExportRequest

        # No argument: the default format applies
        assert ExportRequest().export_format == "generic"

        # TrOCR format
        trocr_request = ExportRequest(export_format="trocr")
        assert trocr_request.export_format == "trocr"

        # Llama Vision format
        llama_request = ExportRequest(export_format="llama_vision")
        assert llama_request.export_format == "llama_vision"

    @pytest.mark.asyncio
    async def test_export_training_samples(self, mock_db_pool):
        """Exporting samples touches the connection via fetch or execute."""
        from metrics_db import export_training_samples

        db_pool, db_conn = mock_db_pool

        # Two rows handed back by conn.fetch
        first_row = {
            "id": "sample-1",
            "image_path": "/app/ocr-labeling/session-123/item-1.png",
            "ground_truth": "Text 1",
        }
        second_row = {
            "id": "sample-2",
            "image_path": "/app/ocr-labeling/session-123/item-2.png",
            "ground_truth": "Text 2",
        }
        db_conn.fetch.return_value = [first_row, second_row]

        pool_patch = patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=db_pool)
        with pool_patch:
            await export_training_samples(
                export_format="generic",
                exported_by="admin",
            )

        # Only database access is checked; the returned samples are not inspected.
        assert db_conn.fetch.called or db_conn.execute.called
class TestTrainingExportService:
    """Shape checks on hand-written examples of each export format.

    These tests build the example records locally; they do not call the
    export service.
    """

    def test_trocr_export_format(self):
        """The TrOCR example record carries ``file_name`` and ``text``."""
        record = dict(
            file_name="images/sample-1.png",
            text="Ground truth text",
            id="sample-1",
        )

        for required_key in ("file_name", "text"):
            assert required_key in record

    def test_llama_vision_export_format(self):
        """The Llama Vision example is a three-message chat ending with the assistant."""
        system_turn = {"role": "system", "content": "Du bist ein OCR-Experte..."}
        user_turn = {"role": "user", "content": [
            {"type": "image_url", "image_url": {"url": "..."}},
            {"type": "text", "text": "Lies den Text..."},
        ]}
        assistant_turn = {"role": "assistant", "content": "Ground truth text"}
        record = {
            "id": "sample-1",
            "messages": [system_turn, user_turn, assistant_turn],
        }

        assert "messages" in record
        turns = record["messages"]
        assert len(turns) == 3
        assert turns[2]["role"] == "assistant"

    def test_generic_export_format(self):
        """The generic example record carries ``image_path`` and ``ground_truth``."""
        record = dict(
            id="sample-1",
            image_path="images/sample-1.png",
            ground_truth="Ground truth text",
            ocr_text="OCR recognized text",
            ocr_confidence=0.87,
            metadata={},
        )

        for required_key in ("image_path", "ground_truth"):
            assert required_key in record
# =============================================================================
# OCR Processing Tests
# =============================================================================
class TestOCRProcessing:
    """Behaviour of ``run_ocr_on_image`` with and without an OCR backend."""

    @pytest.mark.asyncio
    async def test_run_ocr_on_image_no_service(self):
        """With every backend flag off, the result is ``(None, 0.0)``."""
        from ocr_labeling_api import run_ocr_on_image

        with patch('ocr_labeling_api.VISION_OCR_AVAILABLE', False):
            with patch('ocr_labeling_api.PADDLEOCR_AVAILABLE', False):
                with patch('ocr_labeling_api.TROCR_AVAILABLE', False):
                    with patch('ocr_labeling_api.DONUT_AVAILABLE', False):
                        recognized, score = await run_ocr_on_image(
                            image_data=b"fake image",
                            filename="test.png",
                        )

        assert recognized is None
        assert score == 0.0

    @pytest.mark.asyncio
    async def test_run_ocr_on_image_success(self, mock_vision_ocr):
        """With the mocked Vision service, its text and confidence are returned."""
        from ocr_labeling_api import run_ocr_on_image

        recognized, score = await run_ocr_on_image(
            image_data=b"fake image",
            filename="test.png",
        )

        assert recognized == "Erkannter Text aus dem Bild"
        assert score == 0.87
# =============================================================================
# OCR Model Dispatcher Tests
# =============================================================================
class TestOCRModelDispatcher:
    """Routing of ``run_ocr_on_image`` by its ``model`` argument (v1.1.0)."""

    @pytest.mark.asyncio
    async def test_dispatcher_vision_model_default(self, mock_vision_ocr):
        """The llama vision model name is served by the mocked Vision service."""
        from ocr_labeling_api import run_ocr_on_image

        recognized, score = await run_ocr_on_image(
            image_data=b"fake image",
            filename="test.png",
            model="llama3.2-vision:11b",
        )

        assert recognized == "Erkannter Text aus dem Bild"
        assert score == 0.87

    @pytest.mark.asyncio
    async def test_dispatcher_paddleocr_model(self):
        """``model="paddleocr"`` returns the text from the patched PaddleOCR runner."""
        from ocr_labeling_api import run_ocr_on_image

        # Canned (regions, text) pair returned by the patched runner
        paddle_output = ([], "PaddleOCR erkannter Text")

        with patch('ocr_labeling_api.PADDLEOCR_AVAILABLE', True):
            with patch('ocr_labeling_api.run_paddle_ocr', return_value=paddle_output):
                recognized, _score = await run_ocr_on_image(
                    image_data=b"fake image",
                    filename="test.png",
                    model="paddleocr",
                )

        assert recognized == "PaddleOCR erkannter Text"

    @pytest.mark.asyncio
    async def test_dispatcher_paddleocr_fallback_to_vision(self, mock_vision_ocr):
        """With PaddleOCR flagged unavailable, the Vision result is returned."""
        from ocr_labeling_api import run_ocr_on_image

        with patch('ocr_labeling_api.PADDLEOCR_AVAILABLE', False):
            recognized, score = await run_ocr_on_image(
                image_data=b"fake image",
                filename="test.png",
                model="paddleocr",
            )

        # Expected values are those of the mocked Vision service
        assert recognized == "Erkannter Text aus dem Bild"
        assert score == 0.87

    @pytest.mark.asyncio
    async def test_dispatcher_trocr_model(self):
        """``model="trocr"`` returns the output of the patched TrOCR runner."""
        from ocr_labeling_api import run_ocr_on_image

        async def fake_trocr(image_data):
            return "TrOCR erkannter Text", 0.85

        with patch('ocr_labeling_api.TROCR_AVAILABLE', True):
            with patch('ocr_labeling_api.run_trocr_ocr', fake_trocr):
                recognized, score = await run_ocr_on_image(
                    image_data=b"fake image",
                    filename="test.png",
                    model="trocr",
                )

        assert recognized == "TrOCR erkannter Text"
        assert score == 0.85

    @pytest.mark.asyncio
    async def test_dispatcher_donut_model(self):
        """``model="donut"`` returns the output of the patched Donut runner."""
        from ocr_labeling_api import run_ocr_on_image

        async def fake_donut(image_data):
            return "Donut erkannter Text", 0.80

        with patch('ocr_labeling_api.DONUT_AVAILABLE', True):
            with patch('ocr_labeling_api.run_donut_ocr', fake_donut):
                recognized, score = await run_ocr_on_image(
                    image_data=b"fake image",
                    filename="test.png",
                    model="donut",
                )

        assert recognized == "Donut erkannter Text"
        assert score == 0.80

    @pytest.mark.asyncio
    async def test_dispatcher_unknown_model_uses_vision(self, mock_vision_ocr):
        """An unrecognized model name yields the Vision service result."""
        from ocr_labeling_api import run_ocr_on_image

        recognized, score = await run_ocr_on_image(
            image_data=b"fake image",
            filename="test.png",
            model="unknown-model",
        )

        # Expected values are those of the mocked Vision service
        assert recognized == "Erkannter Text aus dem Bild"
        assert score == 0.87
class TestOCRModelTypes:
    """SessionCreate accepts each OCR model identifier used by the dispatcher."""

    def test_session_with_paddleocr_model(self):
        """A session created with ``paddleocr`` keeps that model name."""
        from ocr_labeling_api import SessionCreate

        fields = {
            "name": "PaddleOCR Session",
            "source_type": "klausur",
            "ocr_model": "paddleocr",
        }

        assert SessionCreate(**fields).ocr_model == "paddleocr"

    def test_session_with_donut_model(self):
        """A session created with ``donut`` keeps that model name."""
        from ocr_labeling_api import SessionCreate

        fields = {
            "name": "Donut Session",
            "source_type": "scan",
            "ocr_model": "donut",
        }

        assert SessionCreate(**fields).ocr_model == "donut"

    def test_session_with_trocr_model(self):
        """A session created with ``trocr`` keeps that model name."""
        from ocr_labeling_api import SessionCreate

        fields = {
            "name": "TrOCR Session",
            "source_type": "handwriting_sample",
            "ocr_model": "trocr",
        }

        assert SessionCreate(**fields).ocr_model == "trocr"
# =============================================================================
# API Response Model Tests
# =============================================================================
class TestResponseModels:
    """Tests for API response models."""

    def test_session_response_model(self):
        """SessionResponse accepts a full field set and exposes id and totals."""
        from datetime import timezone

        from ocr_labeling_api import SessionResponse

        session = SessionResponse(
            id="session-123",
            name="Test Session",
            source_type="klausur",
            description="Test",
            ocr_model="llama3.2-vision:11b",
            total_items=10,
            labeled_items=5,
            confirmed_items=3,
            corrected_items=2,
            skipped_items=0,
            # Naive UTC "now"; replaces datetime.utcnow(), which is
            # deprecated since Python 3.12, with an equivalent value.
            created_at=datetime.now(timezone.utc).replace(tzinfo=None),
        )

        assert session.id == "session-123"
        assert session.total_items == 10

    def test_item_response_model(self):
        """ItemResponse accepts a full field set and exposes id and status."""
        from datetime import timezone

        from ocr_labeling_api import ItemResponse

        item = ItemResponse(
            id="item-456",
            session_id="session-123",
            session_name="Test Session",
            image_path="/app/ocr-labeling/session-123/item-456.png",
            image_url="/api/v1/ocr-label/images/session-123/item-456.png",
            ocr_text="Test OCR text",
            ocr_confidence=0.87,
            ground_truth=None,
            status="pending",
            metadata={"page": 1},
            # Naive UTC "now"; replaces datetime.utcnow(), which is
            # deprecated since Python 3.12, with an equivalent value.
            created_at=datetime.now(timezone.utc).replace(tzinfo=None),
        )

        assert item.id == "item-456"
        assert item.status == "pending"
# =============================================================================
# Deduplication Tests
# =============================================================================
class TestDeduplication:
    """Image hashes as the basis for duplicate detection."""

    def test_hash_based_deduplication(self):
        """Two byte strings with identical content hash to the same value."""
        from ocr_labeling_api import compute_image_hash

        # Equal content, built from two separate literals
        original = b"\x89PNG\x0d\x0a\x1a\x0a test image content"
        duplicate = b"\x89PNG\x0d\x0a\x1a\x0a test image content"

        assert compute_image_hash(original) == compute_image_hash(duplicate)

    def test_unique_images_different_hash(self):
        """Byte strings with different content hash to different values."""
        from ocr_labeling_api import compute_image_hash

        first_image = b"\x89PNG unique content 1"
        second_image = b"\x89PNG unique content 2"

        assert compute_image_hash(first_image) != compute_image_hash(second_image)
# =============================================================================
# Integration Tests (require running services)
# =============================================================================
class TestOCRLabelingIntegration:
    """Placeholder integration tests; all are skipped.

    They are meant to run against live Ollama, MinIO and PostgreSQL
    instances and currently have empty bodies.
    """

    @pytest.mark.skip(reason="Requires running Ollama with llama3.2-vision")
    @pytest.mark.asyncio
    async def test_full_labeling_workflow(self):
        """Placeholder for the complete labeling workflow.

        Planned steps:
            1. Create session
            2. Upload image
            3. Run OCR
            4. Confirm or correct label
            5. Export training data
        """

    @pytest.mark.skip(reason="Requires running PostgreSQL")
    @pytest.mark.asyncio
    async def test_stats_calculation(self):
        """Placeholder for statistics calculation against real data."""

    @pytest.mark.skip(reason="Requires running MinIO")
    @pytest.mark.asyncio
    async def test_image_storage_and_retrieval(self):
        """Placeholder for image upload to and download from MinIO."""
# =============================================================================
# Run Tests
# =============================================================================
if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file as a script
    # reports failures to the caller instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__, "-v"]))