"""
OCR Labeling - Session and Labeling Route Handlers

Extracted from ocr_labeling_api.py to keep files under 500 LOC.

Endpoints (all mounted under the ``/api/v1/ocr-label`` prefix):
    - POST /sessions        - Create labeling session
    - GET  /sessions        - List sessions
    - GET  /sessions/{id}   - Get session
    - GET  /queue           - Get labeling queue
    - GET  /items/{id}      - Get item
    - POST /confirm         - Confirm OCR
    - POST /correct         - Correct ground truth
    - POST /skip            - Skip item
    - GET  /stats           - Get statistics
"""

import uuid
from datetime import datetime, timezone
from typing import List, Optional

from fastapi import APIRouter, HTTPException, Query

from metrics_db import (
    create_ocr_labeling_session,
    get_ocr_labeling_sessions,
    get_ocr_labeling_session,
    get_ocr_labeling_queue,
    get_ocr_labeling_item,
    confirm_ocr_label,
    correct_ocr_label,
    skip_ocr_item,
    get_ocr_labeling_stats,
)
from ocr_labeling_models import (
    SessionCreate,
    SessionResponse,
    ItemResponse,
    ConfirmRequest,
    CorrectRequest,
    SkipRequest,
)
from ocr_labeling_helpers import get_image_url


router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"])

# NOTE(review): every labeling action is attributed to this fixed user; no
# auth/request context is consulted in this module.
# TODO: derive the labeler from the authenticated caller.
_DEFAULT_LABELER = "admin"


# =============================================================================
# Helpers
# =============================================================================

def _utcnow() -> datetime:
    """Return the current UTC time as a *naive* datetime.

    Drop-in replacement for ``datetime.utcnow()``, which is deprecated since
    Python 3.12. The tzinfo is stripped so the value (and how it serializes)
    matches what ``utcnow()`` produced before.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _session_to_response(row: dict) -> SessionResponse:
    """Build a ``SessionResponse`` from a session row returned by metrics_db.

    ``id``, ``name`` and ``source_type`` are required and raise ``KeyError``
    if absent. The item counters fall back to 0 and ``created_at`` falls back
    to the current UTC time. The fallback applies both when the key is
    missing and when it is present but ``None``. A NULL aggregate or join
    column is presumably the case that produces ``None``; confirm against
    metrics_db.
    """
    return SessionResponse(
        id=row['id'],
        name=row['name'],
        source_type=row['source_type'],
        description=row.get('description'),
        ocr_model=row.get('ocr_model'),
        total_items=row.get('total_items') or 0,
        labeled_items=row.get('labeled_items') or 0,
        confirmed_items=row.get('confirmed_items') or 0,
        corrected_items=row.get('corrected_items') or 0,
        skipped_items=row.get('skipped_items') or 0,
        created_at=row.get('created_at') or _utcnow(),
    )


def _item_to_response(item: dict) -> ItemResponse:
    """Build an ``ItemResponse`` from a labeling-item row returned by metrics_db.

    ``id``, ``session_id`` and ``image_path`` are required. ``image_url`` is
    derived from ``image_path`` via ``get_image_url``. ``session_name``,
    ``status`` and ``created_at`` fall back to ``''``, ``'pending'`` and the
    current UTC time respectively. The fallback applies when the key is
    missing or its value is ``None``.
    """
    return ItemResponse(
        id=item['id'],
        session_id=item['session_id'],
        session_name=item.get('session_name') or '',
        image_path=item['image_path'],
        image_url=get_image_url(item['image_path']),
        ocr_text=item.get('ocr_text'),
        ocr_confidence=item.get('ocr_confidence'),
        ground_truth=item.get('ground_truth'),
        status=item.get('status') or 'pending',
        metadata=item.get('metadata'),
        created_at=item.get('created_at') or _utcnow(),
    )


# =============================================================================
# Session Endpoints
# =============================================================================

@router.post("/sessions", response_model=SessionResponse)
async def create_session(session: SessionCreate):
    """Create a new OCR labeling session.

    A fresh UUID4 is used as the session id. On success the new session is
    echoed back with every item counter at 0.

    Raises:
        HTTPException: 500 when ``create_ocr_labeling_session`` reports failure.
    """
    session_id = str(uuid.uuid4())
    success = await create_ocr_labeling_session(
        session_id=session_id,
        name=session.name,
        source_type=session.source_type,
        description=session.description,
        ocr_model=session.ocr_model,
    )
    if not success:
        raise HTTPException(status_code=500, detail="Failed to create session")

    return SessionResponse(
        id=session_id,
        name=session.name,
        source_type=session.source_type,
        description=session.description,
        ocr_model=session.ocr_model,
        total_items=0,
        labeled_items=0,
        confirmed_items=0,
        corrected_items=0,
        skipped_items=0,
        created_at=_utcnow(),
    )


@router.get("/sessions", response_model=List[SessionResponse])
async def list_sessions(limit: int = Query(50, ge=1, le=100)):
    """List OCR labeling sessions, at most ``limit`` (1-100, default 50)."""
    sessions = await get_ocr_labeling_sessions(limit=limit)
    return [_session_to_response(s) for s in sessions]


@router.get("/sessions/{session_id}", response_model=SessionResponse)
async def get_session(session_id: str):
    """Get a specific OCR labeling session.

    Raises:
        HTTPException: 404 when no session with ``session_id`` is found.
    """
    session = await get_ocr_labeling_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    return _session_to_response(session)


# =============================================================================
# Queue and Item Endpoints
# =============================================================================

@router.get("/queue", response_model=List[ItemResponse])
async def get_labeling_queue(
    session_id: Optional[str] = Query(None),
    status: str = Query("pending"),
    limit: int = Query(10, ge=1, le=50),
):
    """Get items from the labeling queue.

    ``session_id`` optionally restricts the queue to one session. ``status``
    is forwarded to the DB helper as a filter (default ``"pending"``), and
    ``limit`` caps the result size (1-50, default 10).
    """
    items = await get_ocr_labeling_queue(
        session_id=session_id,
        status=status,
        limit=limit,
    )
    return [_item_to_response(item) for item in items]


@router.get("/items/{item_id}", response_model=ItemResponse)
async def get_item(item_id: str):
    """Get a specific labeling item.

    Raises:
        HTTPException: 404 when no item with ``item_id`` is found.
    """
    item = await get_ocr_labeling_item(item_id)
    if not item:
        raise HTTPException(status_code=404, detail="Item not found")
    return _item_to_response(item)


# =============================================================================
# Labeling Action Endpoints
# =============================================================================

@router.post("/confirm")
async def confirm_item(request: ConfirmRequest):
    """Confirm that OCR text is correct.

    Raises:
        HTTPException: 400 when ``confirm_ocr_label`` reports failure.
    """
    success = await confirm_ocr_label(
        item_id=request.item_id,
        labeled_by=_DEFAULT_LABELER,
        label_time_seconds=request.label_time_seconds,
    )
    if not success:
        raise HTTPException(status_code=400, detail="Failed to confirm item")
    return {"status": "confirmed", "item_id": request.item_id}


@router.post("/correct")
async def correct_item(request: CorrectRequest):
    """Save corrected ground truth for an item.

    Raises:
        HTTPException: 400 when ``correct_ocr_label`` reports failure.
    """
    success = await correct_ocr_label(
        item_id=request.item_id,
        ground_truth=request.ground_truth,
        labeled_by=_DEFAULT_LABELER,
        label_time_seconds=request.label_time_seconds,
    )
    if not success:
        raise HTTPException(status_code=400, detail="Failed to correct item")
    return {"status": "corrected", "item_id": request.item_id}


@router.post("/skip")
async def skip_item(request: SkipRequest):
    """Skip an item (unusable image, etc.).

    Raises:
        HTTPException: 400 when ``skip_ocr_item`` reports failure.
    """
    success = await skip_ocr_item(
        item_id=request.item_id,
        labeled_by=_DEFAULT_LABELER,
    )
    if not success:
        raise HTTPException(status_code=400, detail="Failed to skip item")
    return {"status": "skipped", "item_id": request.item_id}


@router.get("/stats")
async def get_stats(session_id: Optional[str] = Query(None)):
    """Get labeling statistics, optionally restricted to one session.

    The DB helper signals failure in-band by including an ``"error"`` key in
    the returned mapping. That case is surfaced as an HTTP error instead of
    being returned as normal stats.

    Raises:
        HTTPException: 500 with the helper's error message.
    """
    stats = await get_ocr_labeling_stats(session_id=session_id)
    if "error" in stats:
        raise HTTPException(status_code=500, detail=stats["error"])
    return stats