Files
breakpilot-lehrer/klausur-service/backend/ocr_labeling_api.py
Benjamin Boenisch 5a31f52310 Initial commit: breakpilot-lehrer - Lehrer KI Platform
Services: Admin-Lehrer, Backend-Lehrer, Studio v2, Website,
Klausur-Service, School-Service, Voice-Service, Geo-Service,
BreakPilot Drive, Agent-Core

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 23:47:26 +01:00

846 lines
26 KiB
Python

"""
OCR Labeling API for Handwriting Training Data Collection
DATENSCHUTZ/PRIVACY:
- Alle Verarbeitung erfolgt lokal (Mac Mini mit Ollama)
- Keine Daten werden an externe Server gesendet
- Bilder werden mit SHA256-Hash dedupliziert
- Export nur für lokales Fine-Tuning (TrOCR, llama3.2-vision)
Endpoints:
- POST /sessions - Create labeling session
- POST /sessions/{id}/upload - Upload images for labeling
- GET /queue - Get next items to label
- POST /confirm - Confirm OCR as correct
- POST /correct - Save corrected ground truth
- POST /skip - Skip unusable item
- GET /stats - Get labeling statistics
- POST /export - Export training data
"""
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query, BackgroundTasks
from pydantic import BaseModel
from typing import Optional, List, Dict, Any
from datetime import datetime
import uuid
import hashlib
import os
import base64
# Import database functions
from metrics_db import (
create_ocr_labeling_session,
get_ocr_labeling_sessions,
get_ocr_labeling_session,
add_ocr_labeling_item,
get_ocr_labeling_queue,
get_ocr_labeling_item,
confirm_ocr_label,
correct_ocr_label,
skip_ocr_item,
get_ocr_labeling_stats,
export_training_samples,
get_training_samples,
)
# -----------------------------------------------------------------------------
# Optional service imports.
# Each backend (Vision LLM, PaddleOCR, TrOCR, Donut, MinIO, export service) is
# probed at import time; a matching *_AVAILABLE flag records whether it can be
# used so the endpoints can degrade gracefully instead of crashing.
# -----------------------------------------------------------------------------

# Try to import Vision OCR service
try:
    import sys
    # Make the klausur services directory importable (vision_ocr_service lives there).
    sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend', 'klausur', 'services'))
    from vision_ocr_service import get_vision_ocr_service, VisionOCRService
    VISION_OCR_AVAILABLE = True
except ImportError:
    VISION_OCR_AVAILABLE = False
    print("Warning: Vision OCR service not available")

# Try to import PaddleOCR from hybrid_vocab_extractor
try:
    from hybrid_vocab_extractor import run_paddle_ocr
    PADDLEOCR_AVAILABLE = True
except ImportError:
    PADDLEOCR_AVAILABLE = False
    print("Warning: PaddleOCR not available")

# Try to import TrOCR service
try:
    from services.trocr_service import run_trocr_ocr
    TROCR_AVAILABLE = True
except ImportError:
    TROCR_AVAILABLE = False
    print("Warning: TrOCR service not available")

# Try to import Donut service
try:
    from services.donut_ocr_service import run_donut_ocr
    DONUT_AVAILABLE = True
except ImportError:
    DONUT_AVAILABLE = False
    print("Warning: Donut OCR service not available")

# Try to import MinIO storage
try:
    from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET
    MINIO_AVAILABLE = True
except ImportError:
    MINIO_AVAILABLE = False
    print("Warning: MinIO storage not available, using local storage")

# Try to import Training Export Service
try:
    from training_export_service import (
        TrainingExportService,
        TrainingSample,
        get_training_export_service,
    )
    TRAINING_EXPORT_AVAILABLE = True
except ImportError:
    TRAINING_EXPORT_AVAILABLE = False
    print("Warning: Training export service not available")

# All endpoints in this module are mounted under /api/v1/ocr-label.
router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"])

# Local storage path (fallback if MinIO not available)
LOCAL_STORAGE_PATH = os.getenv("OCR_STORAGE_PATH", "/app/ocr-labeling")
# =============================================================================
# Pydantic Models
# =============================================================================
class SessionCreate(BaseModel):
    """Request body for creating a new labeling session."""
    name: str
    source_type: str = "klausur"  # klausur, handwriting_sample, scan
    description: Optional[str] = None
    # Model used for the initial OCR pass on uploaded images.
    ocr_model: Optional[str] = "llama3.2-vision:11b"
class SessionResponse(BaseModel):
    """A labeling session plus its progress counters (as reported by metrics_db)."""
    id: str
    name: str
    source_type: str
    description: Optional[str]
    ocr_model: Optional[str]
    # Progress counters; a freshly created session reports all zeros.
    total_items: int
    labeled_items: int
    confirmed_items: int
    corrected_items: int
    skipped_items: int
    created_at: datetime
class ItemResponse(BaseModel):
    """A single labeling item: the stored image, OCR output, and label state."""
    id: str
    session_id: str
    session_name: str
    image_path: str
    # URL the frontend can fetch (local items are served via /images/{path}).
    image_url: Optional[str]
    ocr_text: Optional[str]
    ocr_confidence: Optional[float]
    ground_truth: Optional[str]
    status: str  # pending, confirmed, corrected, skipped
    metadata: Optional[Dict]
    created_at: datetime
class ConfirmRequest(BaseModel):
    """Request body for confirming an item's OCR text as correct."""
    item_id: str
    label_time_seconds: Optional[int] = None  # time the human spent on this label
class CorrectRequest(BaseModel):
    """Request body for saving a manually corrected ground truth."""
    item_id: str
    ground_truth: str
    label_time_seconds: Optional[int] = None  # time the human spent on this label
class SkipRequest(BaseModel):
    """Request body for skipping an unusable item."""
    item_id: str
class ExportRequest(BaseModel):
    """Request body for exporting labeled data as training samples."""
    export_format: str = "generic"  # generic, trocr, llama_vision
    session_id: Optional[str] = None  # optional: restrict export to one session
    batch_id: Optional[str] = None    # optional: caller-chosen export batch id
class StatsResponse(BaseModel):
    """Aggregate labeling statistics (global or per session).

    NOTE(review): the /stats endpoint currently returns the raw dict from
    metrics_db rather than this model — verify the fields stay in sync.
    """
    total_sessions: Optional[int] = None
    total_items: int
    labeled_items: int
    confirmed_items: int
    corrected_items: int
    pending_items: int
    exportable_items: Optional[int] = None
    accuracy_rate: float
    avg_label_time_seconds: Optional[float] = None
# =============================================================================
# Helper Functions
# =============================================================================
def compute_image_hash(image_data: bytes) -> str:
    """Return the hex SHA256 digest of the raw image bytes (used for dedup)."""
    digest = hashlib.sha256(image_data)
    return digest.hexdigest()
async def run_ocr_on_image(image_data: bytes, filename: str, model: str = "llama3.2-vision:11b") -> tuple:
    """
    Run OCR on an image using the specified model.

    Models:
        - llama3.2-vision:11b: Vision LLM (default, best for handwriting)
        - trocr: Microsoft TrOCR (fast for printed text)
        - paddleocr: PaddleOCR + LLM hybrid (4x faster)
        - donut: Document Understanding Transformer (structured documents)

    Returns:
        Tuple of (ocr_text, confidence)
    """
    print(f"Running OCR with model: {model}")
    # Dispatch table: known engine names map to their wrappers; any other
    # model string (e.g. a vision-LLM tag) falls through to the Vision wrapper.
    handlers = {
        "paddleocr": run_paddleocr_wrapper,
        "donut": run_donut_wrapper,
        "trocr": run_trocr_wrapper,
    }
    handler = handlers.get(model, run_vision_ocr_wrapper)
    return await handler(image_data, filename)
async def run_vision_ocr_wrapper(image_data: bytes, filename: str) -> tuple:
    """Run OCR through the local vision-LLM service.

    Returns (text, confidence); (None, 0.0) when the service is missing,
    reports itself unavailable, or raises.
    """
    if not VISION_OCR_AVAILABLE:
        print("Vision OCR service not available")
        return None, 0.0
    try:
        ocr_service = get_vision_ocr_service()
        ready = await ocr_service.is_available()
        if not ready:
            print("Vision OCR service not available (is_available check failed)")
            return None, 0.0
        extraction = await ocr_service.extract_text(
            image_data,
            filename=filename,
            is_handwriting=True
        )
        return extraction.text, extraction.confidence
    except Exception as e:
        print(f"Vision OCR failed: {e}")
        return None, 0.0
async def run_paddleocr_wrapper(image_data: bytes, filename: str) -> tuple:
    """Run PaddleOCR (via hybrid_vocab_extractor), falling back to Vision OCR.

    Returns (text, confidence); (None, 0.0) when PaddleOCR yields no text.
    """
    if not PADDLEOCR_AVAILABLE:
        print("PaddleOCR not available, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)
    try:
        # run_paddle_ocr returns (regions, raw_text)
        regions, raw_text = run_paddle_ocr(image_data)
        if not raw_text:
            print("PaddleOCR returned empty text")
            return None, 0.0
        # Average per-region confidence; default to 0.5 when no regions exist.
        avg_confidence = (
            sum(region.confidence for region in regions) / len(regions)
            if regions
            else 0.5
        )
        return raw_text, avg_confidence
    except Exception as e:
        print(f"PaddleOCR failed: {e}, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)
async def run_trocr_wrapper(image_data: bytes, filename: str) -> tuple:
    """Run Microsoft TrOCR; any failure falls back to the Vision OCR wrapper."""
    if not TROCR_AVAILABLE:
        print("TrOCR not available, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)
    try:
        recognized_text, score = await run_trocr_ocr(image_data)
        return recognized_text, score
    except Exception as e:
        print(f"TrOCR failed: {e}, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)
async def run_donut_wrapper(image_data: bytes, filename: str) -> tuple:
    """Run Donut document-OCR; any failure falls back to the Vision OCR wrapper."""
    if not DONUT_AVAILABLE:
        print("Donut not available, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)
    try:
        recognized_text, score = await run_donut_ocr(image_data)
        return recognized_text, score
    except Exception as e:
        print(f"Donut OCR failed: {e}, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)
def save_image_locally(session_id: str, item_id: str, image_data: bytes, extension: str = "png") -> str:
    """Persist image bytes under LOCAL_STORAGE_PATH/<session_id>/<item_id>.<ext>.

    Creates the session directory on demand and returns the written file path.
    """
    target_dir = os.path.join(LOCAL_STORAGE_PATH, session_id)
    os.makedirs(target_dir, exist_ok=True)
    target_path = os.path.join(target_dir, f"{item_id}.{extension}")
    with open(target_path, 'wb') as out_file:
        out_file.write(image_data)
    return target_path
def get_image_url(image_path: str) -> str:
    """Translate a stored image path into a URL the frontend can fetch."""
    if not image_path.startswith(LOCAL_STORAGE_PATH):
        # MinIO-stored images: the path is already a URL or object key.
        return image_path
    # Local images are served through the /images/{path} endpoint below.
    relative = image_path[len(LOCAL_STORAGE_PATH):].lstrip('/')
    return f"/api/v1/ocr-label/images/{relative}"
# =============================================================================
# API Endpoints
# =============================================================================
@router.post("/sessions", response_model=SessionResponse)
async def create_session(session: SessionCreate):
    """
    Create a new OCR labeling session.

    A session groups related images for labeling (e.g., all scans from one class).
    """
    session_id = str(uuid.uuid4())
    success = await create_ocr_labeling_session(
        session_id=session_id,
        name=session.name,
        source_type=session.source_type,
        description=session.description,
        ocr_model=session.ocr_model,
    )
    if not success:
        raise HTTPException(status_code=500, detail="Failed to create session")
    # A new session starts with all counters at zero; created_at is set here
    # rather than read back from the database row.
    return SessionResponse(
        id=session_id,
        name=session.name,
        source_type=session.source_type,
        description=session.description,
        ocr_model=session.ocr_model,
        total_items=0,
        labeled_items=0,
        confirmed_items=0,
        corrected_items=0,
        skipped_items=0,
        created_at=datetime.utcnow(),
    )
@router.get("/sessions", response_model=List[SessionResponse])
async def list_sessions(limit: int = Query(50, ge=1, le=100)):
    """List all OCR labeling sessions."""
    rows = await get_ocr_labeling_sessions(limit=limit)
    responses: List[SessionResponse] = []
    for row in rows:
        # Counter fields may be missing on older rows; default them to 0.
        responses.append(SessionResponse(
            id=row['id'],
            name=row['name'],
            source_type=row['source_type'],
            description=row.get('description'),
            ocr_model=row.get('ocr_model'),
            total_items=row.get('total_items', 0),
            labeled_items=row.get('labeled_items', 0),
            confirmed_items=row.get('confirmed_items', 0),
            corrected_items=row.get('corrected_items', 0),
            skipped_items=row.get('skipped_items', 0),
            created_at=row.get('created_at', datetime.utcnow()),
        ))
    return responses
@router.get("/sessions/{session_id}", response_model=SessionResponse)
async def get_session(session_id: str):
    """Get a specific OCR labeling session (404 if it does not exist)."""
    session = await get_ocr_labeling_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    # Counter fields may be missing in the DB row; default them to 0.
    return SessionResponse(
        id=session['id'],
        name=session['name'],
        source_type=session['source_type'],
        description=session.get('description'),
        ocr_model=session.get('ocr_model'),
        total_items=session.get('total_items', 0),
        labeled_items=session.get('labeled_items', 0),
        confirmed_items=session.get('confirmed_items', 0),
        corrected_items=session.get('corrected_items', 0),
        skipped_items=session.get('skipped_items', 0),
        created_at=session.get('created_at', datetime.utcnow()),
    )
@router.post("/sessions/{session_id}/upload")
async def upload_images(
    session_id: str,
    background_tasks: BackgroundTasks,
    files: List[UploadFile] = File(...),
    run_ocr: bool = Form(True),
    metadata: Optional[str] = Form(None),  # JSON string
):
    """
    Upload images to a labeling session.

    Args:
        session_id: Session to add images to
        background_tasks: FastAPI background-task hook.
            NOTE(review): currently unused — OCR runs inline in this handler.
        files: Image files to upload (PNG, JPG, PDF)
        run_ocr: Whether to run OCR immediately (default: True)
        metadata: Optional JSON metadata (subject, year, etc.)
    """
    import json
    # Verify session exists
    session = await get_ocr_labeling_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    # Parse metadata; non-JSON input is wrapped as {"raw": <string>} instead
    # of rejecting the upload.
    meta_dict = None
    if metadata:
        try:
            meta_dict = json.loads(metadata)
        except json.JSONDecodeError:
            meta_dict = {"raw": metadata}
    results = []
    # OCR model comes from the session; fall back to the module default.
    ocr_model = session.get('ocr_model', 'llama3.2-vision:11b')
    for file in files:
        # Read file content
        content = await file.read()
        # Compute hash for deduplication
        image_hash = compute_image_hash(content)
        # Generate item ID
        item_id = str(uuid.uuid4())
        # Determine file extension; anything outside the whitelist is stored as PNG.
        extension = file.filename.split('.')[-1].lower() if file.filename else 'png'
        if extension not in ['png', 'jpg', 'jpeg', 'pdf']:
            extension = 'png'
        # Save image: prefer MinIO, fall back to local disk on any upload error.
        if MINIO_AVAILABLE:
            try:
                image_path = upload_ocr_image(session_id, item_id, content, extension)
            except Exception as e:
                print(f"MinIO upload failed, using local storage: {e}")
                image_path = save_image_locally(session_id, item_id, content, extension)
        else:
            # Save locally
            image_path = save_image_locally(session_id, item_id, content, extension)
        # Run OCR if requested
        ocr_text = None
        ocr_confidence = None
        if run_ocr and extension != 'pdf':  # Skip OCR for PDFs for now
            ocr_text, ocr_confidence = await run_ocr_on_image(
                content,
                file.filename or f"{item_id}.{extension}",
                model=ocr_model
            )
        # Add to database; ocr_model is only recorded when OCR produced text.
        success = await add_ocr_labeling_item(
            item_id=item_id,
            session_id=session_id,
            image_path=image_path,
            image_hash=image_hash,
            ocr_text=ocr_text,
            ocr_confidence=ocr_confidence,
            ocr_model=ocr_model if ocr_text else None,
            metadata=meta_dict,
        )
        # Items the DB rejected (e.g. insert failure) are silently excluded
        # from the response; uploaded_count reflects successful inserts only.
        if success:
            results.append({
                "id": item_id,
                "filename": file.filename,
                "image_path": image_path,
                "image_hash": image_hash,
                "ocr_text": ocr_text,
                "ocr_confidence": ocr_confidence,
                "status": "pending",
            })
    return {
        "session_id": session_id,
        "uploaded_count": len(results),
        "items": results,
    }
@router.get("/queue", response_model=List[ItemResponse])
async def get_labeling_queue(
    session_id: Optional[str] = Query(None),
    status: str = Query("pending"),
    limit: int = Query(10, ge=1, le=50),
):
    """
    Get items from the labeling queue.

    Args:
        session_id: Optional filter by session
        status: Filter by status (pending, confirmed, corrected, skipped)
        limit: Number of items to return
    """
    items = await get_ocr_labeling_queue(
        session_id=session_id,
        status=status,
        limit=limit,
    )
    # image_url is derived from image_path so the frontend can fetch local
    # files through the /images/{path} endpoint.
    return [
        ItemResponse(
            id=item['id'],
            session_id=item['session_id'],
            session_name=item.get('session_name', ''),
            image_path=item['image_path'],
            image_url=get_image_url(item['image_path']),
            ocr_text=item.get('ocr_text'),
            ocr_confidence=item.get('ocr_confidence'),
            ground_truth=item.get('ground_truth'),
            status=item.get('status', 'pending'),
            metadata=item.get('metadata'),
            created_at=item.get('created_at', datetime.utcnow()),
        )
        for item in items
    ]
@router.get("/items/{item_id}", response_model=ItemResponse)
async def get_item(item_id: str):
    """Get a specific labeling item (404 if it does not exist)."""
    item = await get_ocr_labeling_item(item_id)
    if not item:
        raise HTTPException(status_code=404, detail="Item not found")
    return ItemResponse(
        id=item['id'],
        session_id=item['session_id'],
        session_name=item.get('session_name', ''),
        image_path=item['image_path'],
        # Derived URL so local files are reachable via /images/{path}.
        image_url=get_image_url(item['image_path']),
        ocr_text=item.get('ocr_text'),
        ocr_confidence=item.get('ocr_confidence'),
        ground_truth=item.get('ground_truth'),
        status=item.get('status', 'pending'),
        metadata=item.get('metadata'),
        created_at=item.get('created_at', datetime.utcnow()),
    )
@router.post("/confirm")
async def confirm_item(request: ConfirmRequest):
    """
    Confirm that OCR text is correct.

    Sets ground_truth = ocr_text and marks item as confirmed.
    """
    updated = await confirm_ocr_label(
        item_id=request.item_id,
        labeled_by="admin",  # TODO: Get from auth
        label_time_seconds=request.label_time_seconds,
    )
    if not updated:
        raise HTTPException(status_code=400, detail="Failed to confirm item")
    return {"status": "confirmed", "item_id": request.item_id}
@router.post("/correct")
async def correct_item(request: CorrectRequest):
    """
    Save corrected ground truth for an item.

    Use this when OCR text is wrong and needs manual correction.
    """
    updated = await correct_ocr_label(
        item_id=request.item_id,
        ground_truth=request.ground_truth,
        labeled_by="admin",  # TODO: Get from auth
        label_time_seconds=request.label_time_seconds,
    )
    if not updated:
        raise HTTPException(status_code=400, detail="Failed to correct item")
    return {"status": "corrected", "item_id": request.item_id}
@router.post("/skip")
async def skip_item(request: SkipRequest):
    """
    Skip an item (unusable image, etc.).

    Skipped items are not included in training exports.
    """
    updated = await skip_ocr_item(
        item_id=request.item_id,
        labeled_by="admin",  # TODO: Get from auth
    )
    if not updated:
        raise HTTPException(status_code=400, detail="Failed to skip item")
    return {"status": "skipped", "item_id": request.item_id}
@router.get("/stats")
async def get_stats(session_id: Optional[str] = Query(None)):
    """
    Get labeling statistics.

    Args:
        session_id: Optional session ID for session-specific stats
    """
    stats = await get_ocr_labeling_stats(session_id=session_id)
    # The DB layer reports failure via an "error" key instead of raising.
    if "error" in stats:
        raise HTTPException(status_code=500, detail=stats["error"])
    return stats
@router.post("/export")
async def export_data(request: ExportRequest):
    """
    Export labeled data for training.

    Formats:
        - generic: JSONL with image_path and ground_truth
        - trocr: Format for TrOCR/Microsoft Transformer fine-tuning
        - llama_vision: Format for llama3.2-vision fine-tuning

    Exports are saved to disk at /app/ocr-exports/{format}/{batch_id}/
    """
    # First, get samples from database
    db_samples = await export_training_samples(
        export_format=request.export_format,
        session_id=request.session_id,
        batch_id=request.batch_id,
        exported_by="admin",  # TODO: Get from auth
    )
    if not db_samples:
        return {
            "export_format": request.export_format,
            "batch_id": request.batch_id,
            "exported_count": 0,
            "samples": [],
            "message": "No labeled samples found to export",
        }
    # If training export service is available, also write to disk.
    export_result = None
    if TRAINING_EXPORT_AVAILABLE:
        try:
            export_service = get_training_export_service()
            # Convert DB sample dicts to TrainingSample objects.
            training_samples = []
            for s in db_samples:
                training_samples.append(TrainingSample(
                    # Rows may carry either 'id' or 'item_id'.
                    id=s.get('id', s.get('item_id', '')),
                    image_path=s.get('image_path', ''),
                    ground_truth=s.get('ground_truth', ''),
                    ocr_text=s.get('ocr_text'),
                    ocr_confidence=s.get('ocr_confidence'),
                    metadata=s.get('metadata'),
                ))
            # Export to files
            export_result = export_service.export(
                samples=training_samples,
                export_format=request.export_format,
                batch_id=request.batch_id,
            )
        except Exception as e:
            # File export is best-effort: the DB export above already succeeded.
            print(f"Training export failed: {e}")
            # Continue without file export
    response = {
        "export_format": request.export_format,
        # Prefer the caller-supplied batch id; otherwise the one the export
        # service generated (None when file export did not run).
        "batch_id": request.batch_id or (export_result.batch_id if export_result else None),
        "exported_count": len(db_samples),
        "samples": db_samples,
    }
    if export_result:
        response["export_path"] = export_result.export_path
        response["manifest_path"] = export_result.manifest_path
    return response
@router.get("/training-samples")
async def list_training_samples(
    export_format: Optional[str] = Query(None),
    batch_id: Optional[str] = Query(None),
    limit: int = Query(100, ge=1, le=1000),
):
    """Get exported training samples."""
    rows = await get_training_samples(
        export_format=export_format,
        batch_id=batch_id,
        limit=limit,
    )
    return {"count": len(rows), "samples": rows}
@router.get("/images/{path:path}")
async def get_image(path: str):
    """
    Serve an image from local storage.

    This endpoint is used when images are stored locally (not in MinIO).

    Security: ``path`` is user-controlled (a ``{path:path}`` parameter), so it
    is resolved with realpath and verified to stay inside LOCAL_STORAGE_PATH;
    otherwise ``..`` segments or symlinks could be used to read arbitrary
    files (path traversal).
    """
    from fastapi.responses import FileResponse
    storage_root = os.path.realpath(LOCAL_STORAGE_PATH)
    filepath = os.path.realpath(os.path.join(storage_root, path))
    # Reject any resolved path that escapes the storage root. Respond with the
    # same 404 as a missing file so the check leaks nothing about the filesystem.
    if filepath != storage_root and not filepath.startswith(storage_root + os.sep):
        raise HTTPException(status_code=404, detail="Image not found")
    if not os.path.exists(filepath):
        raise HTTPException(status_code=404, detail="Image not found")
    # Determine content type from the file extension.
    extension = filepath.split('.')[-1].lower()
    content_type = {
        'png': 'image/png',
        'jpg': 'image/jpeg',
        'jpeg': 'image/jpeg',
        'pdf': 'application/pdf',
    }.get(extension, 'application/octet-stream')
    return FileResponse(filepath, media_type=content_type)
@router.post("/run-ocr/{item_id}")
async def run_ocr_for_item(item_id: str):
    """
    Run OCR on an existing item.

    Use this to re-run OCR or run it if it was skipped during upload.
    """
    item = await get_ocr_labeling_item(item_id)
    if not item:
        raise HTTPException(status_code=404, detail="Item not found")
    # Load image bytes: paths under LOCAL_STORAGE_PATH come from disk,
    # everything else is treated as a MinIO key.
    image_path = item['image_path']
    if image_path.startswith(LOCAL_STORAGE_PATH):
        # Load from local storage
        if not os.path.exists(image_path):
            raise HTTPException(status_code=404, detail="Image file not found")
        with open(image_path, 'rb') as f:
            image_data = f.read()
    elif MINIO_AVAILABLE:
        # Load from MinIO
        try:
            image_data = get_ocr_image(image_path)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to load image: {e}")
    else:
        raise HTTPException(status_code=500, detail="Cannot load image")
    # Get OCR model from session; fall back to the module default if the
    # session row is gone or has no model set.
    session = await get_ocr_labeling_session(item['session_id'])
    ocr_model = session.get('ocr_model', 'llama3.2-vision:11b') if session else 'llama3.2-vision:11b'
    # Run OCR
    ocr_text, ocr_confidence = await run_ocr_on_image(
        image_data,
        os.path.basename(image_path),
        model=ocr_model
    )
    if ocr_text is None:
        raise HTTPException(status_code=500, detail="OCR failed")
    # Persist the new OCR result directly via the shared asyncpg pool.
    # NOTE(review): if the pool is unavailable, the result is returned to the
    # caller but NOT saved — confirm this best-effort behavior is intended.
    from metrics_db import get_pool
    pool = await get_pool()
    if pool:
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE ocr_labeling_items
                SET ocr_text = $2, ocr_confidence = $3, ocr_model = $4
                WHERE id = $1
                """,
                item_id, ocr_text, ocr_confidence, ocr_model
            )
    return {
        "item_id": item_id,
        "ocr_text": ocr_text,
        "ocr_confidence": ocr_confidence,
        "ocr_model": ocr_model,
    }
@router.get("/exports")
async def list_exports(export_format: Optional[str] = Query(None)):
    """
    List all available training data exports.

    Args:
        export_format: Optional filter by format (generic, trocr, llama_vision)
    Returns:
        List of export manifests with paths and metadata
    """
    if not TRAINING_EXPORT_AVAILABLE:
        return {
            "exports": [],
            "message": "Training export service not available",
        }
    try:
        service = get_training_export_service()
        manifests = service.list_exports(export_format=export_format)
        return {
            "count": len(manifests),
            "exports": manifests,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to list exports: {e}")