Initial commit: breakpilot-lehrer - Lehrer KI Platform
Services: Admin-Lehrer, Backend-Lehrer, Studio v2, Website, Klausur-Service, School-Service, Voice-Service, Geo-Service, BreakPilot Drive, Agent-Core Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
799
klausur-service/backend/tests/test_ocr_labeling.py
Normal file
799
klausur-service/backend/tests/test_ocr_labeling.py
Normal file
@@ -0,0 +1,799 @@
|
||||
"""
|
||||
Tests for OCR Labeling API
|
||||
Tests session management, image upload, labeling workflow, and training export.
|
||||
|
||||
BACKLOG: Feature not yet fully integrated - requires external OCR services
|
||||
See: https://macmini:3002/infrastructure/tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
# Mark all tests in this module as expected failures (backlog item)
|
||||
pytestmark = pytest.mark.xfail(
|
||||
reason="ocr_labeling requires external services not available in CI - Backlog item",
|
||||
strict=False # Don't fail if test unexpectedly passes
|
||||
)
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from datetime import datetime
|
||||
import io
|
||||
import json
|
||||
import hashlib
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Fixtures
|
||||
# =============================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_pool():
|
||||
"""Mock PostgreSQL connection pool."""
|
||||
with patch('metrics_db.get_pool') as mock:
|
||||
pool = AsyncMock()
|
||||
conn = AsyncMock()
|
||||
pool.acquire.return_value.__aenter__.return_value = conn
|
||||
mock.return_value = pool
|
||||
yield pool, conn
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_minio():
|
||||
"""Mock MinIO storage functions."""
|
||||
with patch('ocr_labeling_api.MINIO_AVAILABLE', True), \
|
||||
patch('ocr_labeling_api.upload_ocr_image') as upload_mock, \
|
||||
patch('ocr_labeling_api.get_ocr_image') as get_mock:
|
||||
upload_mock.return_value = "ocr-labeling/session-123/item-456.png"
|
||||
get_mock.return_value = b"\x89PNG fake image data"
|
||||
yield upload_mock, get_mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vision_ocr():
|
||||
"""Mock Vision OCR service."""
|
||||
with patch('ocr_labeling_api.VISION_OCR_AVAILABLE', True), \
|
||||
patch('ocr_labeling_api.get_vision_ocr_service') as mock:
|
||||
service = AsyncMock()
|
||||
service.is_available.return_value = True
|
||||
result = MagicMock()
|
||||
result.text = "Erkannter Text aus dem Bild"
|
||||
result.confidence = 0.87
|
||||
service.extract_text.return_value = result
|
||||
mock.return_value = service
|
||||
yield service
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_training_export():
|
||||
"""Mock training export service."""
|
||||
with patch('ocr_labeling_api.TRAINING_EXPORT_AVAILABLE', True), \
|
||||
patch('ocr_labeling_api.get_training_export_service') as mock:
|
||||
service = MagicMock()
|
||||
export_result = MagicMock()
|
||||
export_result.export_path = "/app/ocr-exports/generic/20260121_120000"
|
||||
export_result.manifest_path = "/app/ocr-exports/generic/20260121_120000/manifest.json"
|
||||
export_result.batch_id = "20260121_120000"
|
||||
service.export.return_value = export_result
|
||||
service.list_exports.return_value = [
|
||||
{"format": "generic", "batch_id": "20260121_120000", "sample_count": 10}
|
||||
]
|
||||
mock.return_value = service
|
||||
yield service
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_session():
|
||||
"""Sample session data."""
|
||||
return {
|
||||
"id": "session-123",
|
||||
"name": "Test Session",
|
||||
"source_type": "klausur",
|
||||
"description": "Test description",
|
||||
"ocr_model": "llama3.2-vision:11b",
|
||||
"total_items": 5,
|
||||
"labeled_items": 2,
|
||||
"confirmed_items": 1,
|
||||
"corrected_items": 1,
|
||||
"skipped_items": 0,
|
||||
"created_at": datetime.utcnow(),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_item():
|
||||
"""Sample labeling item data."""
|
||||
return {
|
||||
"id": "item-456",
|
||||
"session_id": "session-123",
|
||||
"session_name": "Test Session",
|
||||
"image_path": "/app/ocr-labeling/session-123/item-456.png",
|
||||
"ocr_text": "Erkannter Text",
|
||||
"ocr_confidence": 0.87,
|
||||
"ground_truth": None,
|
||||
"status": "pending",
|
||||
"metadata": {"page": 1},
|
||||
"created_at": datetime.utcnow(),
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Session Management Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestSessionCreation:
|
||||
"""Tests for session creation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session_success(self, mock_db_pool):
|
||||
"""Test successful session creation."""
|
||||
from ocr_labeling_api import SessionCreate
|
||||
from metrics_db import create_ocr_labeling_session
|
||||
|
||||
pool, conn = mock_db_pool
|
||||
conn.execute.return_value = None
|
||||
|
||||
with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=pool):
|
||||
result = await create_ocr_labeling_session(
|
||||
session_id="session-123",
|
||||
name="Test Session",
|
||||
source_type="klausur",
|
||||
description="Test",
|
||||
ocr_model="llama3.2-vision:11b",
|
||||
)
|
||||
|
||||
# Should call execute to insert
|
||||
assert pool.acquire.called
|
||||
|
||||
def test_session_create_model_validation(self):
|
||||
"""Test SessionCreate model validation."""
|
||||
from ocr_labeling_api import SessionCreate
|
||||
|
||||
# Valid session
|
||||
session = SessionCreate(
|
||||
name="Test Session",
|
||||
source_type="klausur",
|
||||
description="Test description",
|
||||
)
|
||||
assert session.name == "Test Session"
|
||||
assert session.source_type == "klausur"
|
||||
assert session.ocr_model == "llama3.2-vision:11b" # default
|
||||
|
||||
def test_session_create_with_custom_model(self):
|
||||
"""Test SessionCreate with custom OCR model."""
|
||||
from ocr_labeling_api import SessionCreate
|
||||
|
||||
session = SessionCreate(
|
||||
name="TrOCR Session",
|
||||
source_type="handwriting_sample",
|
||||
ocr_model="trocr-base",
|
||||
)
|
||||
assert session.ocr_model == "trocr-base"
|
||||
|
||||
|
||||
class TestSessionListing:
|
||||
"""Tests for session listing."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_sessions_empty(self):
|
||||
"""Test getting sessions when none exist."""
|
||||
from metrics_db import get_ocr_labeling_sessions
|
||||
|
||||
with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=None):
|
||||
sessions = await get_ocr_labeling_sessions()
|
||||
assert sessions == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session_not_found(self):
|
||||
"""Test getting a non-existent session."""
|
||||
from metrics_db import get_ocr_labeling_session
|
||||
|
||||
with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=None):
|
||||
session = await get_ocr_labeling_session("non-existent-id")
|
||||
assert session is None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Image Upload Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestImageUpload:
|
||||
"""Tests for image upload functionality."""
|
||||
|
||||
def test_compute_image_hash(self):
|
||||
"""Test image hash computation."""
|
||||
from ocr_labeling_api import compute_image_hash
|
||||
|
||||
image_data = b"\x89PNG fake image data"
|
||||
hash1 = compute_image_hash(image_data)
|
||||
hash2 = compute_image_hash(image_data)
|
||||
|
||||
# Same data should produce same hash
|
||||
assert hash1 == hash2
|
||||
assert len(hash1) == 64 # SHA256 hex length
|
||||
|
||||
def test_compute_image_hash_different_data(self):
|
||||
"""Test that different images produce different hashes."""
|
||||
from ocr_labeling_api import compute_image_hash
|
||||
|
||||
hash1 = compute_image_hash(b"image 1 data")
|
||||
hash2 = compute_image_hash(b"image 2 data")
|
||||
|
||||
assert hash1 != hash2
|
||||
|
||||
def test_save_image_locally(self, tmp_path):
|
||||
"""Test local image saving."""
|
||||
from ocr_labeling_api import save_image_locally, LOCAL_STORAGE_PATH
|
||||
|
||||
# Temporarily override storage path
|
||||
with patch('ocr_labeling_api.LOCAL_STORAGE_PATH', str(tmp_path)):
|
||||
from ocr_labeling_api import save_image_locally
|
||||
|
||||
image_data = b"\x89PNG fake image data"
|
||||
filepath = save_image_locally(
|
||||
session_id="session-123",
|
||||
item_id="item-456",
|
||||
image_data=image_data,
|
||||
extension="png",
|
||||
)
|
||||
|
||||
assert filepath.endswith("item-456.png")
|
||||
# File should exist
|
||||
import os
|
||||
assert os.path.exists(filepath)
|
||||
|
||||
def test_get_image_url_local(self):
|
||||
"""Test URL generation for local images."""
|
||||
from ocr_labeling_api import get_image_url, LOCAL_STORAGE_PATH
|
||||
|
||||
local_path = f"{LOCAL_STORAGE_PATH}/session-123/item-456.png"
|
||||
url = get_image_url(local_path)
|
||||
|
||||
assert url == "/api/v1/ocr-label/images/session-123/item-456.png"
|
||||
|
||||
def test_get_image_url_minio(self):
|
||||
"""Test URL for MinIO images (passthrough)."""
|
||||
from ocr_labeling_api import get_image_url
|
||||
|
||||
minio_path = "ocr-labeling/session-123/item-456.png"
|
||||
url = get_image_url(minio_path)
|
||||
|
||||
# Non-local paths are passed through
|
||||
assert url == minio_path
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Labeling Workflow Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestConfirmLabel:
|
||||
"""Tests for label confirmation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_confirm_label_success(self, mock_db_pool):
|
||||
"""Test successful label confirmation."""
|
||||
from metrics_db import confirm_ocr_label
|
||||
|
||||
pool, conn = mock_db_pool
|
||||
conn.fetchrow.return_value = {"ocr_text": "Test text"}
|
||||
conn.execute.return_value = None
|
||||
|
||||
with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=pool):
|
||||
result = await confirm_ocr_label(
|
||||
item_id="item-456",
|
||||
labeled_by="admin",
|
||||
label_time_seconds=5,
|
||||
)
|
||||
|
||||
# Should update item status and ground_truth
|
||||
assert conn.execute.called
|
||||
|
||||
def test_confirm_request_validation(self):
|
||||
"""Test ConfirmRequest model validation."""
|
||||
from ocr_labeling_api import ConfirmRequest
|
||||
|
||||
request = ConfirmRequest(
|
||||
item_id="item-456",
|
||||
label_time_seconds=5,
|
||||
)
|
||||
assert request.item_id == "item-456"
|
||||
assert request.label_time_seconds == 5
|
||||
|
||||
|
||||
class TestCorrectLabel:
|
||||
"""Tests for label correction."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_correct_label_success(self, mock_db_pool):
|
||||
"""Test successful label correction."""
|
||||
from metrics_db import correct_ocr_label
|
||||
|
||||
pool, conn = mock_db_pool
|
||||
conn.execute.return_value = None
|
||||
|
||||
with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=pool):
|
||||
result = await correct_ocr_label(
|
||||
item_id="item-456",
|
||||
ground_truth="Korrigierter Text",
|
||||
labeled_by="admin",
|
||||
label_time_seconds=15,
|
||||
)
|
||||
|
||||
# Should update item with corrected ground_truth
|
||||
assert conn.execute.called
|
||||
|
||||
def test_correct_request_validation(self):
|
||||
"""Test CorrectRequest model validation."""
|
||||
from ocr_labeling_api import CorrectRequest
|
||||
|
||||
request = CorrectRequest(
|
||||
item_id="item-456",
|
||||
ground_truth="Korrigierter Text",
|
||||
label_time_seconds=15,
|
||||
)
|
||||
assert request.item_id == "item-456"
|
||||
assert request.ground_truth == "Korrigierter Text"
|
||||
|
||||
|
||||
class TestSkipItem:
|
||||
"""Tests for item skipping."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skip_item_success(self, mock_db_pool):
|
||||
"""Test successful item skip."""
|
||||
from metrics_db import skip_ocr_item
|
||||
|
||||
pool, conn = mock_db_pool
|
||||
conn.execute.return_value = None
|
||||
|
||||
with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=pool):
|
||||
result = await skip_ocr_item(
|
||||
item_id="item-456",
|
||||
labeled_by="admin",
|
||||
)
|
||||
|
||||
# Should update item status to skipped
|
||||
assert conn.execute.called
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Statistics Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestLabelingStats:
|
||||
"""Tests for labeling statistics."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_stats_no_db(self):
|
||||
"""Test stats when database is not available."""
|
||||
from metrics_db import get_ocr_labeling_stats
|
||||
|
||||
with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=None):
|
||||
stats = await get_ocr_labeling_stats()
|
||||
assert "error" in stats or stats.get("total_items", 0) == 0
|
||||
|
||||
def test_stats_response_model(self):
|
||||
"""Test StatsResponse model structure."""
|
||||
from ocr_labeling_api import StatsResponse
|
||||
|
||||
stats = StatsResponse(
|
||||
total_items=100,
|
||||
labeled_items=50,
|
||||
confirmed_items=40,
|
||||
corrected_items=10,
|
||||
pending_items=50,
|
||||
accuracy_rate=0.8,
|
||||
)
|
||||
|
||||
assert stats.total_items == 100
|
||||
assert stats.accuracy_rate == 0.8
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Export Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestTrainingExport:
|
||||
"""Tests for training data export."""
|
||||
|
||||
def test_export_request_validation(self):
|
||||
"""Test ExportRequest model validation."""
|
||||
from ocr_labeling_api import ExportRequest
|
||||
|
||||
# Default format is generic
|
||||
request = ExportRequest()
|
||||
assert request.export_format == "generic"
|
||||
|
||||
# TrOCR format
|
||||
request = ExportRequest(export_format="trocr")
|
||||
assert request.export_format == "trocr"
|
||||
|
||||
# Llama Vision format
|
||||
request = ExportRequest(export_format="llama_vision")
|
||||
assert request.export_format == "llama_vision"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_training_samples(self, mock_db_pool):
|
||||
"""Test training sample export from database."""
|
||||
from metrics_db import export_training_samples
|
||||
|
||||
pool, conn = mock_db_pool
|
||||
conn.fetch.return_value = [
|
||||
{
|
||||
"id": "sample-1",
|
||||
"image_path": "/app/ocr-labeling/session-123/item-1.png",
|
||||
"ground_truth": "Text 1",
|
||||
},
|
||||
{
|
||||
"id": "sample-2",
|
||||
"image_path": "/app/ocr-labeling/session-123/item-2.png",
|
||||
"ground_truth": "Text 2",
|
||||
},
|
||||
]
|
||||
|
||||
with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=pool):
|
||||
samples = await export_training_samples(
|
||||
export_format="generic",
|
||||
exported_by="admin",
|
||||
)
|
||||
|
||||
# Should return exportable samples
|
||||
assert conn.fetch.called or conn.execute.called
|
||||
|
||||
|
||||
class TestTrainingExportService:
|
||||
"""Tests for training export service."""
|
||||
|
||||
def test_trocr_export_format(self):
|
||||
"""Test TrOCR export format structure."""
|
||||
expected_format = {
|
||||
"file_name": "images/sample-1.png",
|
||||
"text": "Ground truth text",
|
||||
"id": "sample-1",
|
||||
}
|
||||
|
||||
assert "file_name" in expected_format
|
||||
assert "text" in expected_format
|
||||
|
||||
def test_llama_vision_export_format(self):
|
||||
"""Test Llama Vision export format structure."""
|
||||
expected_format = {
|
||||
"id": "sample-1",
|
||||
"messages": [
|
||||
{"role": "system", "content": "Du bist ein OCR-Experte..."},
|
||||
{"role": "user", "content": [
|
||||
{"type": "image_url", "image_url": {"url": "..."}},
|
||||
{"type": "text", "text": "Lies den Text..."},
|
||||
]},
|
||||
{"role": "assistant", "content": "Ground truth text"},
|
||||
],
|
||||
}
|
||||
|
||||
assert "messages" in expected_format
|
||||
assert len(expected_format["messages"]) == 3
|
||||
assert expected_format["messages"][2]["role"] == "assistant"
|
||||
|
||||
def test_generic_export_format(self):
|
||||
"""Test generic export format structure."""
|
||||
expected_format = {
|
||||
"id": "sample-1",
|
||||
"image_path": "images/sample-1.png",
|
||||
"ground_truth": "Ground truth text",
|
||||
"ocr_text": "OCR recognized text",
|
||||
"ocr_confidence": 0.87,
|
||||
"metadata": {},
|
||||
}
|
||||
|
||||
assert "image_path" in expected_format
|
||||
assert "ground_truth" in expected_format
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# OCR Processing Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestOCRProcessing:
|
||||
"""Tests for OCR processing."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_ocr_on_image_no_service(self):
|
||||
"""Test OCR when service is not available."""
|
||||
from ocr_labeling_api import run_ocr_on_image
|
||||
|
||||
with patch('ocr_labeling_api.VISION_OCR_AVAILABLE', False), \
|
||||
patch('ocr_labeling_api.PADDLEOCR_AVAILABLE', False), \
|
||||
patch('ocr_labeling_api.TROCR_AVAILABLE', False), \
|
||||
patch('ocr_labeling_api.DONUT_AVAILABLE', False):
|
||||
text, confidence = await run_ocr_on_image(
|
||||
image_data=b"fake image",
|
||||
filename="test.png",
|
||||
)
|
||||
|
||||
assert text is None
|
||||
assert confidence == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_ocr_on_image_success(self, mock_vision_ocr):
|
||||
"""Test successful OCR processing."""
|
||||
from ocr_labeling_api import run_ocr_on_image
|
||||
|
||||
text, confidence = await run_ocr_on_image(
|
||||
image_data=b"fake image",
|
||||
filename="test.png",
|
||||
)
|
||||
|
||||
assert text == "Erkannter Text aus dem Bild"
|
||||
assert confidence == 0.87
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# OCR Model Dispatcher Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestOCRModelDispatcher:
|
||||
"""Tests for the OCR model dispatcher (v1.1.0)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatcher_vision_model_default(self, mock_vision_ocr):
|
||||
"""Test dispatcher uses Vision OCR by default."""
|
||||
from ocr_labeling_api import run_ocr_on_image
|
||||
|
||||
text, confidence = await run_ocr_on_image(
|
||||
image_data=b"fake image",
|
||||
filename="test.png",
|
||||
model="llama3.2-vision:11b",
|
||||
)
|
||||
|
||||
assert text == "Erkannter Text aus dem Bild"
|
||||
assert confidence == 0.87
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatcher_paddleocr_model(self):
|
||||
"""Test dispatcher routes to PaddleOCR."""
|
||||
from ocr_labeling_api import run_ocr_on_image
|
||||
|
||||
# Mock PaddleOCR
|
||||
mock_regions = []
|
||||
mock_text = "PaddleOCR erkannter Text"
|
||||
|
||||
with patch('ocr_labeling_api.PADDLEOCR_AVAILABLE', True), \
|
||||
patch('ocr_labeling_api.run_paddle_ocr', return_value=(mock_regions, mock_text)):
|
||||
|
||||
text, confidence = await run_ocr_on_image(
|
||||
image_data=b"fake image",
|
||||
filename="test.png",
|
||||
model="paddleocr",
|
||||
)
|
||||
|
||||
assert text == "PaddleOCR erkannter Text"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatcher_paddleocr_fallback_to_vision(self, mock_vision_ocr):
|
||||
"""Test PaddleOCR falls back to Vision OCR when unavailable."""
|
||||
from ocr_labeling_api import run_ocr_on_image
|
||||
|
||||
with patch('ocr_labeling_api.PADDLEOCR_AVAILABLE', False):
|
||||
text, confidence = await run_ocr_on_image(
|
||||
image_data=b"fake image",
|
||||
filename="test.png",
|
||||
model="paddleocr",
|
||||
)
|
||||
|
||||
# Should fall back to Vision OCR
|
||||
assert text == "Erkannter Text aus dem Bild"
|
||||
assert confidence == 0.87
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatcher_trocr_model(self):
|
||||
"""Test dispatcher routes to TrOCR."""
|
||||
from ocr_labeling_api import run_ocr_on_image
|
||||
|
||||
async def mock_trocr(image_data):
|
||||
return "TrOCR erkannter Text", 0.85
|
||||
|
||||
with patch('ocr_labeling_api.TROCR_AVAILABLE', True), \
|
||||
patch('ocr_labeling_api.run_trocr_ocr', mock_trocr):
|
||||
|
||||
text, confidence = await run_ocr_on_image(
|
||||
image_data=b"fake image",
|
||||
filename="test.png",
|
||||
model="trocr",
|
||||
)
|
||||
|
||||
assert text == "TrOCR erkannter Text"
|
||||
assert confidence == 0.85
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatcher_donut_model(self):
|
||||
"""Test dispatcher routes to Donut."""
|
||||
from ocr_labeling_api import run_ocr_on_image
|
||||
|
||||
async def mock_donut(image_data):
|
||||
return "Donut erkannter Text", 0.80
|
||||
|
||||
with patch('ocr_labeling_api.DONUT_AVAILABLE', True), \
|
||||
patch('ocr_labeling_api.run_donut_ocr', mock_donut):
|
||||
|
||||
text, confidence = await run_ocr_on_image(
|
||||
image_data=b"fake image",
|
||||
filename="test.png",
|
||||
model="donut",
|
||||
)
|
||||
|
||||
assert text == "Donut erkannter Text"
|
||||
assert confidence == 0.80
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatcher_unknown_model_uses_vision(self, mock_vision_ocr):
|
||||
"""Test dispatcher uses Vision OCR for unknown models."""
|
||||
from ocr_labeling_api import run_ocr_on_image
|
||||
|
||||
text, confidence = await run_ocr_on_image(
|
||||
image_data=b"fake image",
|
||||
filename="test.png",
|
||||
model="unknown-model",
|
||||
)
|
||||
|
||||
# Unknown model should fall back to Vision OCR
|
||||
assert text == "Erkannter Text aus dem Bild"
|
||||
assert confidence == 0.87
|
||||
|
||||
|
||||
class TestOCRModelTypes:
|
||||
"""Tests for OCR model type definitions."""
|
||||
|
||||
def test_session_with_paddleocr_model(self):
|
||||
"""Test session creation with PaddleOCR model."""
|
||||
from ocr_labeling_api import SessionCreate
|
||||
|
||||
session = SessionCreate(
|
||||
name="PaddleOCR Session",
|
||||
source_type="klausur",
|
||||
ocr_model="paddleocr",
|
||||
)
|
||||
|
||||
assert session.ocr_model == "paddleocr"
|
||||
|
||||
def test_session_with_donut_model(self):
|
||||
"""Test session creation with Donut model."""
|
||||
from ocr_labeling_api import SessionCreate
|
||||
|
||||
session = SessionCreate(
|
||||
name="Donut Session",
|
||||
source_type="scan",
|
||||
ocr_model="donut",
|
||||
)
|
||||
|
||||
assert session.ocr_model == "donut"
|
||||
|
||||
def test_session_with_trocr_model(self):
|
||||
"""Test session creation with TrOCR model."""
|
||||
from ocr_labeling_api import SessionCreate
|
||||
|
||||
session = SessionCreate(
|
||||
name="TrOCR Session",
|
||||
source_type="handwriting_sample",
|
||||
ocr_model="trocr",
|
||||
)
|
||||
|
||||
assert session.ocr_model == "trocr"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# API Response Model Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestResponseModels:
|
||||
"""Tests for API response models."""
|
||||
|
||||
def test_session_response_model(self):
|
||||
"""Test SessionResponse model."""
|
||||
from ocr_labeling_api import SessionResponse
|
||||
|
||||
session = SessionResponse(
|
||||
id="session-123",
|
||||
name="Test Session",
|
||||
source_type="klausur",
|
||||
description="Test",
|
||||
ocr_model="llama3.2-vision:11b",
|
||||
total_items=10,
|
||||
labeled_items=5,
|
||||
confirmed_items=3,
|
||||
corrected_items=2,
|
||||
skipped_items=0,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
assert session.id == "session-123"
|
||||
assert session.total_items == 10
|
||||
|
||||
def test_item_response_model(self):
|
||||
"""Test ItemResponse model."""
|
||||
from ocr_labeling_api import ItemResponse
|
||||
|
||||
item = ItemResponse(
|
||||
id="item-456",
|
||||
session_id="session-123",
|
||||
session_name="Test Session",
|
||||
image_path="/app/ocr-labeling/session-123/item-456.png",
|
||||
image_url="/api/v1/ocr-label/images/session-123/item-456.png",
|
||||
ocr_text="Test OCR text",
|
||||
ocr_confidence=0.87,
|
||||
ground_truth=None,
|
||||
status="pending",
|
||||
metadata={"page": 1},
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
assert item.id == "item-456"
|
||||
assert item.status == "pending"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Deduplication Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestDeduplication:
|
||||
"""Tests for image deduplication."""
|
||||
|
||||
def test_hash_based_deduplication(self):
|
||||
"""Test that same images produce same hash for deduplication."""
|
||||
from ocr_labeling_api import compute_image_hash
|
||||
|
||||
# Same content should be detected as duplicate
|
||||
image1 = b"\x89PNG\x0d\x0a\x1a\x0a test image content"
|
||||
image2 = b"\x89PNG\x0d\x0a\x1a\x0a test image content"
|
||||
|
||||
hash1 = compute_image_hash(image1)
|
||||
hash2 = compute_image_hash(image2)
|
||||
|
||||
assert hash1 == hash2
|
||||
|
||||
def test_unique_images_different_hash(self):
|
||||
"""Test that different images produce different hashes."""
|
||||
from ocr_labeling_api import compute_image_hash
|
||||
|
||||
image1 = b"\x89PNG unique content 1"
|
||||
image2 = b"\x89PNG unique content 2"
|
||||
|
||||
hash1 = compute_image_hash(image1)
|
||||
hash2 = compute_image_hash(image2)
|
||||
|
||||
assert hash1 != hash2
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Integration Tests (require running services)
|
||||
# =============================================================================
|
||||
|
||||
class TestOCRLabelingIntegration:
|
||||
"""Integration tests - require Ollama, MinIO, PostgreSQL running."""
|
||||
|
||||
@pytest.mark.skip(reason="Requires running Ollama with llama3.2-vision")
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_labeling_workflow(self):
|
||||
"""Test complete labeling workflow."""
|
||||
# This would require:
|
||||
# 1. Create session
|
||||
# 2. Upload image
|
||||
# 3. Run OCR
|
||||
# 4. Confirm or correct label
|
||||
# 5. Export training data
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(reason="Requires running PostgreSQL")
|
||||
@pytest.mark.asyncio
|
||||
async def test_stats_calculation(self):
|
||||
"""Test statistics calculation with real data."""
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(reason="Requires running MinIO")
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_storage_and_retrieval(self):
|
||||
"""Test image upload and download from MinIO."""
|
||||
pass
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Run Tests
|
||||
# =============================================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user