"""
OCR Labeling API for Handwriting Training Data Collection

PRIVACY:
- All processing happens locally (Mac Mini with Ollama)
- No data is sent to external servers
- Images are hashed with SHA256 (the hash is stored with each item for deduplication)
- Export is intended for local fine-tuning only (TrOCR, llama3.2-vision)

Endpoints:
- POST /sessions - Create labeling session
- POST /sessions/{id}/upload - Upload images for labeling
- GET /queue - Get next items to label
- POST /confirm - Confirm OCR as correct
- POST /correct - Save corrected ground truth
- POST /skip - Skip unusable item
- GET /stats - Get labeling statistics
- POST /export - Export training data
"""

from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query, BackgroundTasks
from pydantic import BaseModel
from typing import Optional, List, Dict, Any
from datetime import datetime
import uuid
import hashlib
import os
import base64

# Import database functions
from metrics_db import (
    create_ocr_labeling_session,
    get_ocr_labeling_sessions,
    get_ocr_labeling_session,
    add_ocr_labeling_item,
    get_ocr_labeling_queue,
    get_ocr_labeling_item,
    confirm_ocr_label,
    correct_ocr_label,
    skip_ocr_item,
    get_ocr_labeling_stats,
    export_training_samples,
    get_training_samples,
)

# Each optional backend below is probed at import time. A failed import only
# flips the matching *_AVAILABLE flag so the module still loads without it.

# Try to import Vision OCR service
try:
    import sys
    sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend', 'klausur', 'services'))
    from vision_ocr_service import get_vision_ocr_service, VisionOCRService
    VISION_OCR_AVAILABLE = True
except ImportError:
    VISION_OCR_AVAILABLE = False
    print("Warning: Vision OCR service not available")

# Try to import PaddleOCR from hybrid_vocab_extractor
try:
    from hybrid_vocab_extractor import run_paddle_ocr
    PADDLEOCR_AVAILABLE = True
except ImportError:
    PADDLEOCR_AVAILABLE = False
    print("Warning: PaddleOCR not available")

# Try to import TrOCR service
try:
    from services.trocr_service import run_trocr_ocr
    TROCR_AVAILABLE = True
except ImportError:
    TROCR_AVAILABLE = False
    print("Warning: TrOCR service not available")

# Try to import Donut service
try:
    from services.donut_ocr_service import run_donut_ocr
    DONUT_AVAILABLE = True
except ImportError:
    DONUT_AVAILABLE = False
    print("Warning: Donut OCR service not available")

# Try to import MinIO storage
try:
    from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET
    MINIO_AVAILABLE = True
except ImportError:
    MINIO_AVAILABLE = False
    print("Warning: MinIO storage not available, using local storage")

# Try to import Training Export Service
try:
    from training_export_service import (
        TrainingExportService,
        TrainingSample,
        get_training_export_service,
    )
    TRAINING_EXPORT_AVAILABLE = True
except ImportError:
    TRAINING_EXPORT_AVAILABLE = False
    print("Warning: Training export service not available")

router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"])

# Local storage path (fallback if MinIO not available)
LOCAL_STORAGE_PATH = os.getenv("OCR_STORAGE_PATH", "/app/ocr-labeling")


# =============================================================================
# Pydantic Models
# =============================================================================

class SessionCreate(BaseModel):
    """Request body for creating a labeling session."""
    name: str
    source_type: str = "klausur"  # klausur, handwriting_sample, scan
    description: Optional[str] = None
    ocr_model: Optional[str] = "llama3.2-vision:11b"


class SessionResponse(BaseModel):
    """A labeling session together with its per-status item counters."""
    id: str
    name: str
    source_type: str
    description: Optional[str]
    ocr_model: Optional[str]
    total_items: int
    labeled_items: int
    confirmed_items: int
    corrected_items: int
    skipped_items: int
    created_at: datetime


class ItemResponse(BaseModel):
    """A single image to label, with its OCR result and labeling state."""
    id: str
    session_id: str
    session_name: str
    image_path: str
    image_url: Optional[str]
    ocr_text: Optional[str]
    ocr_confidence: Optional[float]
    ground_truth: Optional[str]
    status: str
    metadata: Optional[Dict]
    created_at: datetime


class ConfirmRequest(BaseModel):
    """Request body for POST /confirm."""
    item_id: str
    label_time_seconds: Optional[int] = None


class CorrectRequest(BaseModel):
    """Request body for POST /correct."""
    item_id: str
    ground_truth: str
    label_time_seconds: Optional[int] = None


class SkipRequest(BaseModel):
    """Request body for POST /skip."""
    item_id: str


class ExportRequest(BaseModel):
    """Request body for POST /export."""
    export_format: str = "generic"  # generic, trocr, llama_vision
    session_id: Optional[str] = None
    batch_id: Optional[str] = None


class StatsResponse(BaseModel):
    """Shape of labeling statistics (global or per session)."""
    total_sessions: Optional[int] = None
    total_items: int
    labeled_items: int
    confirmed_items: int
    corrected_items: int
    pending_items: int
    exportable_items: Optional[int] = None
    accuracy_rate: float
    avg_label_time_seconds: Optional[float] = None


# =============================================================================
# Helper Functions
# =============================================================================

def compute_image_hash(image_data: bytes) -> str:
    """Compute SHA256 hash of image data."""
    return hashlib.sha256(image_data).hexdigest()


async def run_ocr_on_image(image_data: bytes, filename: str, model: str = "llama3.2-vision:11b") -> tuple:
    """
    Run OCR on an image using the specified model.

    Models:
    - llama3.2-vision:11b: Vision LLM (default, best for handwriting)
    - trocr: Microsoft TrOCR (fast for printed text)
    - paddleocr: PaddleOCR + LLM hybrid (4x faster)
    - donut: Document Understanding Transformer (structured documents)

    Any model name other than "paddleocr", "donut" or "trocr" is routed to
    the Vision LLM wrapper.

    Returns:
        Tuple of (ocr_text, confidence)
    """
    print(f"Running OCR with model: {model}")

    # Route to appropriate OCR service based on model
    if model == "paddleocr":
        return await run_paddleocr_wrapper(image_data, filename)
    elif model == "donut":
        return await run_donut_wrapper(image_data, filename)
    elif model == "trocr":
        return await run_trocr_wrapper(image_data, filename)
    else:
        # Default: Vision LLM (llama3.2-vision or similar)
        return await run_vision_ocr_wrapper(image_data, filename)


async def run_vision_ocr_wrapper(image_data: bytes, filename: str) -> tuple:
    """
    Vision LLM OCR wrapper.

    Returns:
        Tuple of (text, confidence) from the vision service, or (None, 0.0)
        when the service is not importable, reports itself unavailable, or
        raises.
    """
    if not VISION_OCR_AVAILABLE:
        print("Vision OCR service not available")
        return None, 0.0

    try:
        service = get_vision_ocr_service()
        if not await service.is_available():
            print("Vision OCR service not available (is_available check failed)")
            return None, 0.0

        result = await service.extract_text(
            image_data,
            filename=filename,
            is_handwriting=True
        )
        return result.text, result.confidence
    except Exception as e:
        # Best effort: this wrapper is the last fallback, so failures are
        # reported as "no text" instead of propagating.
        print(f"Vision OCR failed: {e}")
        return None, 0.0


async def run_paddleocr_wrapper(image_data: bytes, filename: str) -> tuple:
    """
    PaddleOCR wrapper - uses hybrid_vocab_extractor.

    Falls back to the Vision LLM wrapper when PaddleOCR is not importable or
    raises. Empty PaddleOCR output yields (None, 0.0) without a fallback.

    Returns:
        Tuple of (raw_text, average region confidence). The confidence is
        0.5 when text was found but no regions were returned.
    """
    if not PADDLEOCR_AVAILABLE:
        print("PaddleOCR not available, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)

    try:
        # run_paddle_ocr returns (regions, raw_text)
        # NOTE(review): this is a plain (non-awaited) call inside a coroutine;
        # if it is CPU-heavy it will block the event loop - confirm.
        regions, raw_text = run_paddle_ocr(image_data)

        if not raw_text:
            print("PaddleOCR returned empty text")
            return None, 0.0

        # Calculate average confidence from regions
        if regions:
            avg_confidence = sum(r.confidence for r in regions) / len(regions)
        else:
            avg_confidence = 0.5

        return raw_text, avg_confidence
    except Exception as e:
        print(f"PaddleOCR failed: {e}, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)


async def run_trocr_wrapper(image_data: bytes, filename: str) -> tuple:
    """
    TrOCR wrapper.

    Falls back to the Vision LLM wrapper when TrOCR is not importable or
    raises.

    Returns:
        Tuple of (text, confidence)
    """
    if not TROCR_AVAILABLE:
        print("TrOCR not available, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)

    try:
        text, confidence = await run_trocr_ocr(image_data)
        return text, confidence
    except Exception as e:
        print(f"TrOCR failed: {e}, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)


async def run_donut_wrapper(image_data: bytes, filename: str) -> tuple:
    """
    Donut OCR wrapper.

    Falls back to the Vision LLM wrapper when Donut is not importable or
    raises.

    Returns:
        Tuple of (text, confidence)
    """
    if not DONUT_AVAILABLE:
        print("Donut not available, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)

    try:
        text, confidence = await run_donut_ocr(image_data)
        return text, confidence
    except Exception as e:
        print(f"Donut OCR failed: {e}, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)


def save_image_locally(session_id: str, item_id: str, image_data: bytes, extension: str = "png") -> str:
    """
    Save image to local storage.

    The file is written to LOCAL_STORAGE_PATH/<session_id>/<item_id>.<extension>;
    the session directory is created if it does not exist.

    Returns:
        The path of the written file.
    """
    session_dir = os.path.join(LOCAL_STORAGE_PATH, session_id)
    os.makedirs(session_dir, exist_ok=True)

    filename = f"{item_id}.{extension}"
    filepath = os.path.join(session_dir, filename)

    with open(filepath, 'wb') as f:
        f.write(image_data)

    return filepath


def get_image_url(image_path: str) -> str:
    """
    Get URL for an image.

    Paths located under LOCAL_STORAGE_PATH are mapped to the image-serving
    endpoint of this router; every other value is returned unchanged.
    """
    # Compare against the root without its trailing slash and require a path
    # separator (or end of string) right after it. A bare startswith() would
    # also match sibling directories such as "<root>-archive/...".
    root = LOCAL_STORAGE_PATH.rstrip('/')
    if image_path.startswith(root):
        remainder = image_path[len(root):]
        if remainder == '' or remainder[0] in ('/', os.sep):
            # For local images, return a relative path that the frontend can use
            relative_path = remainder.lstrip('/')
            return f"/api/v1/ocr-label/images/{relative_path}"
    # For MinIO images, the path is already a URL or key
    return image_path


def _session_to_response(row: Any) -> SessionResponse:
    """Build a SessionResponse from a session row returned by the database layer."""
    # "or" instead of a .get() default: a key that is present but None would
    # otherwise bypass the default and fail response validation.
    return SessionResponse(
        id=row['id'],
        name=row['name'],
        source_type=row['source_type'],
        description=row.get('description'),
        ocr_model=row.get('ocr_model'),
        total_items=row.get('total_items') or 0,
        labeled_items=row.get('labeled_items') or 0,
        confirmed_items=row.get('confirmed_items') or 0,
        corrected_items=row.get('corrected_items') or 0,
        skipped_items=row.get('skipped_items') or 0,
        created_at=row.get('created_at') or datetime.utcnow(),
    )


# =============================================================================
# API Endpoints
# =============================================================================

@router.post("/sessions", response_model=SessionResponse)
async def create_session(session: SessionCreate):
    """
    Create a new OCR labeling session.

    A session groups related images for labeling (e.g., all scans from one class).
    """
    session_id = str(uuid.uuid4())

    success = await create_ocr_labeling_session(
        session_id=session_id,
        name=session.name,
        source_type=session.source_type,
        description=session.description,
        ocr_model=session.ocr_model,
    )

    if not success:
        raise HTTPException(status_code=500, detail="Failed to create session")

    # The response is built from the request, not re-read from the database,
    # so created_at is the API server's current UTC time.
    return SessionResponse(
        id=session_id,
        name=session.name,
        source_type=session.source_type,
        description=session.description,
        ocr_model=session.ocr_model,
        total_items=0,
        labeled_items=0,
        confirmed_items=0,
        corrected_items=0,
        skipped_items=0,
        created_at=datetime.utcnow(),
    )


@router.get("/sessions", response_model=List[SessionResponse])
async def list_sessions(limit: int = Query(50, ge=1, le=100)):
    """List all OCR labeling sessions."""
    sessions = await get_ocr_labeling_sessions(limit=limit)
    return [_session_to_response(s) for s in sessions]


@router.get("/sessions/{session_id}", response_model=SessionResponse)
async def get_session(session_id: str):
    """Get a specific OCR labeling session."""
    session = await get_ocr_labeling_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    return _session_to_response(session)
# Model used when a session has no (or a null) ocr_model configured.
_DEFAULT_OCR_MODEL = "llama3.2-vision:11b"

# Extensions kept as-is on upload; anything else is stored as "png".
_ALLOWED_UPLOAD_EXTENSIONS = ('png', 'jpg', 'jpeg', 'pdf')


def _parse_upload_metadata(metadata: Optional[str]) -> Optional[Dict[str, Any]]:
    """Parse the optional JSON metadata form field of an upload into a dict."""
    import json

    if not metadata:
        return None
    try:
        parsed = json.loads(metadata)
    except json.JSONDecodeError:
        return {"raw": metadata}
    # ItemResponse.metadata is a dict; valid JSON that is not an object
    # (list, number, string) is kept verbatim under "raw", like invalid JSON.
    if not isinstance(parsed, dict):
        return {"raw": metadata}
    return parsed


def _store_image(session_id: str, item_id: str, content: bytes, extension: str) -> str:
    """Store an uploaded image in MinIO when available, else on local disk."""
    if MINIO_AVAILABLE:
        try:
            return upload_ocr_image(session_id, item_id, content, extension)
        except Exception as e:
            # Best effort: a MinIO failure must not lose the upload.
            print(f"MinIO upload failed, using local storage: {e}")
    return save_image_locally(session_id, item_id, content, extension)


def _is_local_image_path(image_path: str) -> bool:
    """Return True if image_path lies under LOCAL_STORAGE_PATH."""
    # Require a separator (or end of string) after the root so that sibling
    # directories like "<root>-archive" are not mistaken for local storage.
    root = LOCAL_STORAGE_PATH.rstrip('/')
    if not image_path.startswith(root):
        return False
    remainder = image_path[len(root):]
    return remainder == '' or remainder[0] in ('/', os.sep)


def _item_to_response(item: Any) -> ItemResponse:
    """Build an ItemResponse from an item row returned by the database layer."""
    # "or" instead of a .get() default: a key that is present but None would
    # otherwise bypass the default and fail response validation.
    return ItemResponse(
        id=item['id'],
        session_id=item['session_id'],
        session_name=item.get('session_name') or '',
        image_path=item['image_path'],
        image_url=get_image_url(item['image_path']),
        ocr_text=item.get('ocr_text'),
        ocr_confidence=item.get('ocr_confidence'),
        ground_truth=item.get('ground_truth'),
        status=item.get('status') or 'pending',
        metadata=item.get('metadata'),
        created_at=item.get('created_at') or datetime.utcnow(),
    )


@router.post("/sessions/{session_id}/upload")
async def upload_images(
    session_id: str,
    background_tasks: BackgroundTasks,
    files: List[UploadFile] = File(...),
    run_ocr: bool = Form(True),
    metadata: Optional[str] = Form(None),  # JSON string
):
    """
    Upload images to a labeling session.

    Args:
        session_id: Session to add images to
        files: Image files to upload (PNG, JPG, PDF)
        run_ocr: Whether to run OCR immediately (default: True)
        metadata: Optional JSON metadata (subject, year, etc.)
    """
    # Verify session exists
    session = await get_ocr_labeling_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    # Parse metadata
    meta_dict = _parse_upload_metadata(metadata)

    results = []
    # "or" also covers a session whose ocr_model is stored as NULL, so the
    # model recorded on the item matches the one that actually ran.
    ocr_model = session.get('ocr_model') or _DEFAULT_OCR_MODEL

    for file in files:
        # Read file content
        content = await file.read()

        # Compute hash for deduplication
        image_hash = compute_image_hash(content)

        # Generate item ID
        item_id = str(uuid.uuid4())

        # Determine file extension (unknown or missing extensions become png)
        extension = file.filename.split('.')[-1].lower() if file.filename else 'png'
        if extension not in _ALLOWED_UPLOAD_EXTENSIONS:
            extension = 'png'

        # Save image
        image_path = _store_image(session_id, item_id, content, extension)

        # Run OCR if requested
        ocr_text = None
        ocr_confidence = None

        if run_ocr and extension != 'pdf':  # Skip OCR for PDFs for now
            ocr_text, ocr_confidence = await run_ocr_on_image(
                content,
                file.filename or f"{item_id}.{extension}",
                model=ocr_model
            )

        # Add to database
        success = await add_ocr_labeling_item(
            item_id=item_id,
            session_id=session_id,
            image_path=image_path,
            image_hash=image_hash,
            ocr_text=ocr_text,
            ocr_confidence=ocr_confidence,
            ocr_model=ocr_model if ocr_text else None,
            metadata=meta_dict,
        )

        # Items whose database insert failed are left out of the response.
        if success:
            results.append({
                "id": item_id,
                "filename": file.filename,
                "image_path": image_path,
                "image_hash": image_hash,
                "ocr_text": ocr_text,
                "ocr_confidence": ocr_confidence,
                "status": "pending",
            })

    return {
        "session_id": session_id,
        "uploaded_count": len(results),
        "items": results,
    }


@router.get("/queue", response_model=List[ItemResponse])
async def get_labeling_queue(
    session_id: Optional[str] = Query(None),
    status: str = Query("pending"),
    limit: int = Query(10, ge=1, le=50),
):
    """
    Get items from the labeling queue.

    Args:
        session_id: Optional filter by session
        status: Filter by status (pending, confirmed, corrected, skipped)
        limit: Number of items to return
    """
    items = await get_ocr_labeling_queue(
        session_id=session_id,
        status=status,
        limit=limit,
    )

    return [_item_to_response(item) for item in items]


@router.get("/items/{item_id}", response_model=ItemResponse)
async def get_item(item_id: str):
    """Get a specific labeling item."""
    item = await get_ocr_labeling_item(item_id)
    if not item:
        raise HTTPException(status_code=404, detail="Item not found")

    return _item_to_response(item)


@router.post("/confirm")
async def confirm_item(request: ConfirmRequest):
    """
    Confirm that OCR text is correct.

    Sets ground_truth = ocr_text and marks item as confirmed.
    """
    success = await confirm_ocr_label(
        item_id=request.item_id,
        labeled_by="admin",  # TODO: Get from auth
        label_time_seconds=request.label_time_seconds,
    )

    if not success:
        raise HTTPException(status_code=400, detail="Failed to confirm item")

    return {"status": "confirmed", "item_id": request.item_id}


@router.post("/correct")
async def correct_item(request: CorrectRequest):
    """
    Save corrected ground truth for an item.

    Use this when OCR text is wrong and needs manual correction.
    """
    success = await correct_ocr_label(
        item_id=request.item_id,
        ground_truth=request.ground_truth,
        labeled_by="admin",  # TODO: Get from auth
        label_time_seconds=request.label_time_seconds,
    )

    if not success:
        raise HTTPException(status_code=400, detail="Failed to correct item")

    return {"status": "corrected", "item_id": request.item_id}


@router.post("/skip")
async def skip_item(request: SkipRequest):
    """
    Skip an item (unusable image, etc.).

    Skipped items are not included in training exports.
    """
    success = await skip_ocr_item(
        item_id=request.item_id,
        labeled_by="admin",  # TODO: Get from auth
    )

    if not success:
        raise HTTPException(status_code=400, detail="Failed to skip item")

    return {"status": "skipped", "item_id": request.item_id}


@router.get("/stats")
async def get_stats(session_id: Optional[str] = Query(None)):
    """
    Get labeling statistics.

    Args:
        session_id: Optional session ID for session-specific stats
    """
    stats = await get_ocr_labeling_stats(session_id=session_id)

    # The database layer signals failure through an "error" key.
    if "error" in stats:
        raise HTTPException(status_code=500, detail=stats["error"])

    return stats


@router.post("/export")
async def export_data(request: ExportRequest):
    """
    Export labeled data for training.

    Formats:
    - generic: JSONL with image_path and ground_truth
    - trocr: Format for TrOCR/Microsoft Transformer fine-tuning
    - llama_vision: Format for llama3.2-vision fine-tuning

    Exports are saved to disk at /app/ocr-exports/{format}/{batch_id}/
    """
    # First, get samples from database
    db_samples = await export_training_samples(
        export_format=request.export_format,
        session_id=request.session_id,
        batch_id=request.batch_id,
        exported_by="admin",  # TODO: Get from auth
    )

    if not db_samples:
        return {
            "export_format": request.export_format,
            "batch_id": request.batch_id,
            "exported_count": 0,
            "samples": [],
            "message": "No labeled samples found to export",
        }

    # If training export service is available, also write to disk
    export_result = None
    if TRAINING_EXPORT_AVAILABLE:
        try:
            export_service = get_training_export_service()

            # Convert DB samples to TrainingSample objects
            training_samples = [
                TrainingSample(
                    id=s.get('id', s.get('item_id', '')),
                    image_path=s.get('image_path', ''),
                    ground_truth=s.get('ground_truth', ''),
                    ocr_text=s.get('ocr_text'),
                    ocr_confidence=s.get('ocr_confidence'),
                    metadata=s.get('metadata'),
                )
                for s in db_samples
            ]

            # Export to files
            export_result = export_service.export(
                samples=training_samples,
                export_format=request.export_format,
                batch_id=request.batch_id,
            )
        except Exception as e:
            # Best effort: the database export already succeeded, so a failed
            # file export only drops the path fields from the response.
            print(f"Training export failed: {e}")
            # Continue without file export

    response = {
        "export_format": request.export_format,
        "batch_id": request.batch_id or (export_result.batch_id if export_result else None),
        "exported_count": len(db_samples),
        "samples": db_samples,
    }

    if export_result:
        response["export_path"] = export_result.export_path
        response["manifest_path"] = export_result.manifest_path

    return response


@router.get("/training-samples")
async def list_training_samples(
    export_format: Optional[str] = Query(None),
    batch_id: Optional[str] = Query(None),
    limit: int = Query(100, ge=1, le=1000),
):
    """Get exported training samples."""
    samples = await get_training_samples(
        export_format=export_format,
        batch_id=batch_id,
        limit=limit,
    )

    return {
        "count": len(samples),
        "samples": samples,
    }


@router.get("/images/{path:path}")
async def get_image(path: str):
    """
    Serve an image from local storage.

    This endpoint is used when images are stored locally (not in MinIO).
    Requests that resolve outside the storage directory, or to something
    that is not a regular file, are answered with 404.
    """
    from fastapi.responses import FileResponse

    # `path` is untrusted and may contain ".." segments or be absolute.
    # Resolve it and only serve the result if it is still inside the
    # storage root.
    try:
        storage_root = os.path.realpath(LOCAL_STORAGE_PATH)
        filepath = os.path.realpath(os.path.join(storage_root, path))
        inside_root = os.path.commonpath([storage_root, filepath]) == storage_root
    except ValueError:
        # e.g. embedded NUL byte, or paths on different drives
        inside_root = False

    # isfile (not exists): a directory must not be handed to FileResponse.
    if not inside_root or not os.path.isfile(filepath):
        raise HTTPException(status_code=404, detail="Image not found")

    # Determine content type
    extension = filepath.split('.')[-1].lower()
    content_type = {
        'png': 'image/png',
        'jpg': 'image/jpeg',
        'jpeg': 'image/jpeg',
        'pdf': 'application/pdf',
    }.get(extension, 'application/octet-stream')

    return FileResponse(filepath, media_type=content_type)


@router.post("/run-ocr/{item_id}")
async def run_ocr_for_item(item_id: str):
    """
    Run OCR on an existing item.

    Use this to re-run OCR or run it if it was skipped during upload.
    """
    item = await get_ocr_labeling_item(item_id)
    if not item:
        raise HTTPException(status_code=404, detail="Item not found")

    # Load image
    image_path = item['image_path']

    if _is_local_image_path(image_path):
        # Load from local storage
        if not os.path.exists(image_path):
            raise HTTPException(status_code=404, detail="Image file not found")
        with open(image_path, 'rb') as f:
            image_data = f.read()
    elif MINIO_AVAILABLE:
        # Load from MinIO
        try:
            image_data = get_ocr_image(image_path)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to load image: {e}")
    else:
        raise HTTPException(status_code=500, detail="Cannot load image")

    # Get OCR model from session ("or" also covers a NULL ocr_model, so the
    # model written back to the item is the one that actually ran)
    session = await get_ocr_labeling_session(item['session_id'])
    ocr_model = (session.get('ocr_model') if session else None) or _DEFAULT_OCR_MODEL

    # Run OCR
    ocr_text, ocr_confidence = await run_ocr_on_image(
        image_data,
        os.path.basename(image_path),
        model=ocr_model
    )

    if ocr_text is None:
        raise HTTPException(status_code=500, detail="OCR failed")

    # Update item in database
    from metrics_db import get_pool
    pool = await get_pool()
    if pool:
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE ocr_labeling_items
                SET ocr_text = $2, ocr_confidence = $3, ocr_model = $4
                WHERE id = $1
                """,
                item_id, ocr_text, ocr_confidence, ocr_model
            )
    else:
        # The OCR result is still returned to the caller, but make the
        # missing persistence visible instead of skipping it silently.
        print(f"Warning: database pool not available, OCR result for item {item_id} not saved")

    return {
        "item_id": item_id,
        "ocr_text": ocr_text,
        "ocr_confidence": ocr_confidence,
        "ocr_model": ocr_model,
    }


@router.get("/exports")
async def list_exports(export_format: Optional[str] = Query(None)):
    """
    List all available training data exports.

    Args:
        export_format: Optional filter by format (generic, trocr, llama_vision)

    Returns:
        List of export manifests with paths and metadata
    """
    if not TRAINING_EXPORT_AVAILABLE:
        return {
            "exports": [],
            "message": "Training export service not available",
        }

    try:
        export_service = get_training_export_service()
        exports = export_service.list_exports(export_format=export_format)

        return {
            "count": len(exports),
            "exports": exports,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to list exports: {e}")