Services: Admin-Lehrer, Backend-Lehrer, Studio v2, Website, Klausur-Service, School-Service, Voice-Service, Geo-Service, BreakPilot Drive, Agent-Core Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
846 lines
26 KiB
Python
846 lines
26 KiB
Python
"""
|
|
OCR Labeling API for Handwriting Training Data Collection
|
|
|
|
DATENSCHUTZ/PRIVACY:
|
|
- Alle Verarbeitung erfolgt lokal (Mac Mini mit Ollama)
|
|
- Keine Daten werden an externe Server gesendet
|
|
- Bilder werden mit SHA256-Hash dedupliziert
|
|
- Export nur für lokales Fine-Tuning (TrOCR, llama3.2-vision)
|
|
|
|
Endpoints:
|
|
- POST /sessions - Create labeling session
|
|
- POST /sessions/{id}/upload - Upload images for labeling
|
|
- GET /queue - Get next items to label
|
|
- POST /confirm - Confirm OCR as correct
|
|
- POST /correct - Save corrected ground truth
|
|
- POST /skip - Skip unusable item
|
|
- GET /stats - Get labeling statistics
|
|
- POST /export - Export training data
|
|
"""
|
|
|
|
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query, BackgroundTasks
|
|
from pydantic import BaseModel
|
|
from typing import Optional, List, Dict, Any
|
|
from datetime import datetime
|
|
import uuid
|
|
import hashlib
|
|
import os
|
|
import base64
|
|
|
|
# Import database functions
|
|
from metrics_db import (
|
|
create_ocr_labeling_session,
|
|
get_ocr_labeling_sessions,
|
|
get_ocr_labeling_session,
|
|
add_ocr_labeling_item,
|
|
get_ocr_labeling_queue,
|
|
get_ocr_labeling_item,
|
|
confirm_ocr_label,
|
|
correct_ocr_label,
|
|
skip_ocr_item,
|
|
get_ocr_labeling_stats,
|
|
export_training_samples,
|
|
get_training_samples,
|
|
)
|
|
|
|
# ---------------------------------------------------------------------------
# Optional service imports.
# Each backend is probed once at import time; the corresponding *_AVAILABLE
# flag gates its use at request time, so missing services degrade gracefully
# (OCR fallbacks, local file storage instead of MinIO).
# ---------------------------------------------------------------------------

# Vision-LLM OCR service (lives in the klausur backend, hence the path insert).
try:
    import sys
    sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend', 'klausur', 'services'))
    from vision_ocr_service import get_vision_ocr_service, VisionOCRService
    VISION_OCR_AVAILABLE = True
except ImportError:
    VISION_OCR_AVAILABLE = False
    print("Warning: Vision OCR service not available")

# PaddleOCR, provided by the hybrid PaddleOCR + LLM extractor.
try:
    from hybrid_vocab_extractor import run_paddle_ocr
    PADDLEOCR_AVAILABLE = True
except ImportError:
    PADDLEOCR_AVAILABLE = False
    print("Warning: PaddleOCR not available")

# Microsoft TrOCR service.
try:
    from services.trocr_service import run_trocr_ocr
    TROCR_AVAILABLE = True
except ImportError:
    TROCR_AVAILABLE = False
    print("Warning: TrOCR service not available")

# Donut (Document Understanding Transformer) service.
try:
    from services.donut_ocr_service import run_donut_ocr
    DONUT_AVAILABLE = True
except ImportError:
    DONUT_AVAILABLE = False
    print("Warning: Donut OCR service not available")

# MinIO object storage; local filesystem is used when unavailable.
try:
    from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET
    MINIO_AVAILABLE = True
except ImportError:
    MINIO_AVAILABLE = False
    print("Warning: MinIO storage not available, using local storage")

# Training-data export service (writes sample files + manifests to disk).
try:
    from training_export_service import (
        TrainingExportService,
        TrainingSample,
        get_training_export_service,
    )
    TRAINING_EXPORT_AVAILABLE = True
except ImportError:
    TRAINING_EXPORT_AVAILABLE = False
    print("Warning: Training export service not available")

router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"])

# Local storage path (fallback if MinIO not available)
LOCAL_STORAGE_PATH = os.getenv("OCR_STORAGE_PATH", "/app/ocr-labeling")
|
|
|
|
|
|
# =============================================================================
|
|
# Pydantic Models
|
|
# =============================================================================
|
|
|
|
class SessionCreate(BaseModel):
    """Request body for POST /sessions: parameters of a new labeling session."""
    name: str
    # Origin of the images: klausur, handwriting_sample, scan
    source_type: str = "klausur"
    description: Optional[str] = None
    # OCR engine used to pre-label images uploaded to this session.
    ocr_model: Optional[str] = "llama3.2-vision:11b"
|
|
|
|
|
|
class SessionResponse(BaseModel):
    """A labeling session together with its per-status progress counters."""
    id: str
    name: str
    source_type: str
    description: Optional[str]
    ocr_model: Optional[str]
    # Progress counters (all zero for a freshly created session).
    total_items: int
    labeled_items: int
    confirmed_items: int
    corrected_items: int
    skipped_items: int
    created_at: datetime
|
|
|
|
|
|
class ItemResponse(BaseModel):
    """A single image awaiting (or having received) a ground-truth label."""
    id: str
    session_id: str
    session_name: str
    # Storage location (local path or MinIO key) and the URL the frontend fetches.
    image_path: str
    image_url: Optional[str]
    # Machine-produced OCR output, if OCR has been run for this item.
    ocr_text: Optional[str]
    ocr_confidence: Optional[float]
    # Human-verified transcription (set by /confirm or /correct).
    ground_truth: Optional[str]
    # One of: pending, confirmed, corrected, skipped.
    status: str
    metadata: Optional[Dict]
    created_at: datetime
|
|
|
|
|
|
class ConfirmRequest(BaseModel):
    """Request body for POST /confirm: accept the OCR text as ground truth."""
    item_id: str
    # Time the annotator spent on this item, if the UI tracks it.
    label_time_seconds: Optional[int] = None
|
|
|
|
|
|
class CorrectRequest(BaseModel):
    """Request body for POST /correct: replace wrong OCR text with a manual transcription."""
    item_id: str
    ground_truth: str
    # Time the annotator spent on this item, if the UI tracks it.
    label_time_seconds: Optional[int] = None
|
|
|
|
|
|
class SkipRequest(BaseModel):
    """Request body for POST /skip: mark an item as unusable for training."""
    item_id: str
|
|
|
|
|
|
class ExportRequest(BaseModel):
    """Request body for POST /export."""
    # Target format: generic, trocr, llama_vision
    export_format: str = "generic"
    # Optional filters: restrict the export to one session and/or one batch.
    session_id: Optional[str] = None
    batch_id: Optional[str] = None
|
|
|
|
|
|
class StatsResponse(BaseModel):
    """Aggregate labeling statistics (global or per-session).

    NOTE(review): GET /stats currently returns the raw dict from
    get_ocr_labeling_stats without going through this model; it documents
    the expected shape.
    """
    total_sessions: Optional[int] = None
    total_items: int
    labeled_items: int
    confirmed_items: int
    corrected_items: int
    pending_items: int
    exportable_items: Optional[int] = None
    accuracy_rate: float
    avg_label_time_seconds: Optional[float] = None
|
|
|
|
|
|
# =============================================================================
|
|
# Helper Functions
|
|
# =============================================================================
|
|
|
|
def compute_image_hash(image_data: bytes) -> str:
    """Return the hex-encoded SHA256 digest of the raw image bytes.

    Used to deduplicate uploaded images.
    """
    digest = hashlib.sha256()
    digest.update(image_data)
    return digest.hexdigest()
|
|
|
|
|
|
async def run_ocr_on_image(image_data: bytes, filename: str, model: str = "llama3.2-vision:11b") -> tuple:
    """
    Run OCR on an image using the specified model.

    Models:
    - llama3.2-vision:11b: Vision LLM (default, best for handwriting)
    - trocr: Microsoft TrOCR (fast for printed text)
    - paddleocr: PaddleOCR + LLM hybrid (4x faster)
    - donut: Document Understanding Transformer (structured documents)

    Returns:
        Tuple of (ocr_text, confidence)
    """
    print(f"Running OCR with model: {model}")

    # Dispatch table: any unrecognized model name falls through to the
    # Vision-LLM backend (llama3.2-vision or similar).
    backends = {
        "paddleocr": run_paddleocr_wrapper,
        "donut": run_donut_wrapper,
        "trocr": run_trocr_wrapper,
    }
    backend = backends.get(model, run_vision_ocr_wrapper)
    return await backend(image_data, filename)
|
|
|
|
|
|
async def run_vision_ocr_wrapper(image_data: bytes, filename: str) -> tuple:
    """Run the local Vision-LLM OCR service.

    Returns (text, confidence), or (None, 0.0) when the service is missing,
    reports itself unavailable, or raises.
    """
    if not VISION_OCR_AVAILABLE:
        print("Vision OCR service not available")
        return None, 0.0

    try:
        ocr = get_vision_ocr_service()
        if not await ocr.is_available():
            print("Vision OCR service not available (is_available check failed)")
            return None, 0.0

        extraction = await ocr.extract_text(
            image_data,
            filename=filename,
            is_handwriting=True,
        )
        return extraction.text, extraction.confidence
    except Exception as exc:
        print(f"Vision OCR failed: {exc}")
        return None, 0.0
|
|
|
|
|
|
async def run_paddleocr_wrapper(image_data: bytes, filename: str) -> tuple:
    """Run PaddleOCR (via hybrid_vocab_extractor).

    Falls back to the Vision-LLM backend when PaddleOCR is missing or raises.
    Returns (text, confidence), or (None, 0.0) when PaddleOCR yields no text.
    """
    if not PADDLEOCR_AVAILABLE:
        print("PaddleOCR not available, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)

    try:
        # run_paddle_ocr returns (regions, raw_text)
        regions, raw_text = run_paddle_ocr(image_data)

        if not raw_text:
            print("PaddleOCR returned empty text")
            return None, 0.0

        # Mean region confidence; 0.5 when no region info is available.
        confidences = [r.confidence for r in regions] if regions else []
        mean_confidence = sum(confidences) / len(confidences) if confidences else 0.5
        return raw_text, mean_confidence
    except Exception as exc:
        print(f"PaddleOCR failed: {exc}, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)
|
|
|
|
|
|
async def run_trocr_wrapper(image_data: bytes, filename: str) -> tuple:
    """Run Microsoft TrOCR; fall back to the Vision-LLM backend on any failure."""
    if TROCR_AVAILABLE:
        try:
            text, confidence = await run_trocr_ocr(image_data)
            return text, confidence
        except Exception as exc:
            print(f"TrOCR failed: {exc}, falling back to Vision OCR")
    else:
        print("TrOCR not available, falling back to Vision OCR")
    return await run_vision_ocr_wrapper(image_data, filename)
|
|
|
|
|
|
async def run_donut_wrapper(image_data: bytes, filename: str) -> tuple:
    """Run Donut OCR; fall back to the Vision-LLM backend on any failure."""
    if DONUT_AVAILABLE:
        try:
            text, confidence = await run_donut_ocr(image_data)
            return text, confidence
        except Exception as exc:
            print(f"Donut OCR failed: {exc}, falling back to Vision OCR")
    else:
        print("Donut not available, falling back to Vision OCR")
    return await run_vision_ocr_wrapper(image_data, filename)
|
|
|
|
|
|
def save_image_locally(session_id: str, item_id: str, image_data: bytes, extension: str = "png") -> str:
    """Write image bytes to LOCAL_STORAGE_PATH/<session_id>/<item_id>.<extension>.

    Creates the session directory on demand and returns the absolute file path.
    """
    target_dir = os.path.join(LOCAL_STORAGE_PATH, session_id)
    os.makedirs(target_dir, exist_ok=True)

    target = os.path.join(target_dir, f"{item_id}.{extension}")
    with open(target, 'wb') as fh:
        fh.write(image_data)
    return target
|
|
|
|
|
|
def get_image_url(image_path: str) -> str:
    """Translate a stored image path into a URL the frontend can fetch.

    Local paths are rewritten to the /images/{path} serving endpoint;
    MinIO paths are returned unchanged (already a key/URL).
    """
    if not image_path.startswith(LOCAL_STORAGE_PATH):
        return image_path
    relative = image_path[len(LOCAL_STORAGE_PATH):].lstrip('/')
    return f"/api/v1/ocr-label/images/{relative}"
|
|
|
|
|
|
# =============================================================================
|
|
# API Endpoints
|
|
# =============================================================================
|
|
|
|
@router.post("/sessions", response_model=SessionResponse)
|
|
async def create_session(session: SessionCreate):
|
|
"""
|
|
Create a new OCR labeling session.
|
|
|
|
A session groups related images for labeling (e.g., all scans from one class).
|
|
"""
|
|
session_id = str(uuid.uuid4())
|
|
|
|
success = await create_ocr_labeling_session(
|
|
session_id=session_id,
|
|
name=session.name,
|
|
source_type=session.source_type,
|
|
description=session.description,
|
|
ocr_model=session.ocr_model,
|
|
)
|
|
|
|
if not success:
|
|
raise HTTPException(status_code=500, detail="Failed to create session")
|
|
|
|
return SessionResponse(
|
|
id=session_id,
|
|
name=session.name,
|
|
source_type=session.source_type,
|
|
description=session.description,
|
|
ocr_model=session.ocr_model,
|
|
total_items=0,
|
|
labeled_items=0,
|
|
confirmed_items=0,
|
|
corrected_items=0,
|
|
skipped_items=0,
|
|
created_at=datetime.utcnow(),
|
|
)
|
|
|
|
|
|
@router.get("/sessions", response_model=List[SessionResponse])
|
|
async def list_sessions(limit: int = Query(50, ge=1, le=100)):
|
|
"""List all OCR labeling sessions."""
|
|
sessions = await get_ocr_labeling_sessions(limit=limit)
|
|
|
|
return [
|
|
SessionResponse(
|
|
id=s['id'],
|
|
name=s['name'],
|
|
source_type=s['source_type'],
|
|
description=s.get('description'),
|
|
ocr_model=s.get('ocr_model'),
|
|
total_items=s.get('total_items', 0),
|
|
labeled_items=s.get('labeled_items', 0),
|
|
confirmed_items=s.get('confirmed_items', 0),
|
|
corrected_items=s.get('corrected_items', 0),
|
|
skipped_items=s.get('skipped_items', 0),
|
|
created_at=s.get('created_at', datetime.utcnow()),
|
|
)
|
|
for s in sessions
|
|
]
|
|
|
|
|
|
@router.get("/sessions/{session_id}", response_model=SessionResponse)
|
|
async def get_session(session_id: str):
|
|
"""Get a specific OCR labeling session."""
|
|
session = await get_ocr_labeling_session(session_id)
|
|
|
|
if not session:
|
|
raise HTTPException(status_code=404, detail="Session not found")
|
|
|
|
return SessionResponse(
|
|
id=session['id'],
|
|
name=session['name'],
|
|
source_type=session['source_type'],
|
|
description=session.get('description'),
|
|
ocr_model=session.get('ocr_model'),
|
|
total_items=session.get('total_items', 0),
|
|
labeled_items=session.get('labeled_items', 0),
|
|
confirmed_items=session.get('confirmed_items', 0),
|
|
corrected_items=session.get('corrected_items', 0),
|
|
skipped_items=session.get('skipped_items', 0),
|
|
created_at=session.get('created_at', datetime.utcnow()),
|
|
)
|
|
|
|
|
|
@router.post("/sessions/{session_id}/upload")
|
|
async def upload_images(
|
|
session_id: str,
|
|
background_tasks: BackgroundTasks,
|
|
files: List[UploadFile] = File(...),
|
|
run_ocr: bool = Form(True),
|
|
metadata: Optional[str] = Form(None), # JSON string
|
|
):
|
|
"""
|
|
Upload images to a labeling session.
|
|
|
|
Args:
|
|
session_id: Session to add images to
|
|
files: Image files to upload (PNG, JPG, PDF)
|
|
run_ocr: Whether to run OCR immediately (default: True)
|
|
metadata: Optional JSON metadata (subject, year, etc.)
|
|
"""
|
|
import json
|
|
|
|
# Verify session exists
|
|
session = await get_ocr_labeling_session(session_id)
|
|
if not session:
|
|
raise HTTPException(status_code=404, detail="Session not found")
|
|
|
|
# Parse metadata
|
|
meta_dict = None
|
|
if metadata:
|
|
try:
|
|
meta_dict = json.loads(metadata)
|
|
except json.JSONDecodeError:
|
|
meta_dict = {"raw": metadata}
|
|
|
|
results = []
|
|
ocr_model = session.get('ocr_model', 'llama3.2-vision:11b')
|
|
|
|
for file in files:
|
|
# Read file content
|
|
content = await file.read()
|
|
|
|
# Compute hash for deduplication
|
|
image_hash = compute_image_hash(content)
|
|
|
|
# Generate item ID
|
|
item_id = str(uuid.uuid4())
|
|
|
|
# Determine file extension
|
|
extension = file.filename.split('.')[-1].lower() if file.filename else 'png'
|
|
if extension not in ['png', 'jpg', 'jpeg', 'pdf']:
|
|
extension = 'png'
|
|
|
|
# Save image
|
|
if MINIO_AVAILABLE:
|
|
# Upload to MinIO
|
|
try:
|
|
image_path = upload_ocr_image(session_id, item_id, content, extension)
|
|
except Exception as e:
|
|
print(f"MinIO upload failed, using local storage: {e}")
|
|
image_path = save_image_locally(session_id, item_id, content, extension)
|
|
else:
|
|
# Save locally
|
|
image_path = save_image_locally(session_id, item_id, content, extension)
|
|
|
|
# Run OCR if requested
|
|
ocr_text = None
|
|
ocr_confidence = None
|
|
|
|
if run_ocr and extension != 'pdf': # Skip OCR for PDFs for now
|
|
ocr_text, ocr_confidence = await run_ocr_on_image(
|
|
content,
|
|
file.filename or f"{item_id}.{extension}",
|
|
model=ocr_model
|
|
)
|
|
|
|
# Add to database
|
|
success = await add_ocr_labeling_item(
|
|
item_id=item_id,
|
|
session_id=session_id,
|
|
image_path=image_path,
|
|
image_hash=image_hash,
|
|
ocr_text=ocr_text,
|
|
ocr_confidence=ocr_confidence,
|
|
ocr_model=ocr_model if ocr_text else None,
|
|
metadata=meta_dict,
|
|
)
|
|
|
|
if success:
|
|
results.append({
|
|
"id": item_id,
|
|
"filename": file.filename,
|
|
"image_path": image_path,
|
|
"image_hash": image_hash,
|
|
"ocr_text": ocr_text,
|
|
"ocr_confidence": ocr_confidence,
|
|
"status": "pending",
|
|
})
|
|
|
|
return {
|
|
"session_id": session_id,
|
|
"uploaded_count": len(results),
|
|
"items": results,
|
|
}
|
|
|
|
|
|
@router.get("/queue", response_model=List[ItemResponse])
|
|
async def get_labeling_queue(
|
|
session_id: Optional[str] = Query(None),
|
|
status: str = Query("pending"),
|
|
limit: int = Query(10, ge=1, le=50),
|
|
):
|
|
"""
|
|
Get items from the labeling queue.
|
|
|
|
Args:
|
|
session_id: Optional filter by session
|
|
status: Filter by status (pending, confirmed, corrected, skipped)
|
|
limit: Number of items to return
|
|
"""
|
|
items = await get_ocr_labeling_queue(
|
|
session_id=session_id,
|
|
status=status,
|
|
limit=limit,
|
|
)
|
|
|
|
return [
|
|
ItemResponse(
|
|
id=item['id'],
|
|
session_id=item['session_id'],
|
|
session_name=item.get('session_name', ''),
|
|
image_path=item['image_path'],
|
|
image_url=get_image_url(item['image_path']),
|
|
ocr_text=item.get('ocr_text'),
|
|
ocr_confidence=item.get('ocr_confidence'),
|
|
ground_truth=item.get('ground_truth'),
|
|
status=item.get('status', 'pending'),
|
|
metadata=item.get('metadata'),
|
|
created_at=item.get('created_at', datetime.utcnow()),
|
|
)
|
|
for item in items
|
|
]
|
|
|
|
|
|
@router.get("/items/{item_id}", response_model=ItemResponse)
|
|
async def get_item(item_id: str):
|
|
"""Get a specific labeling item."""
|
|
item = await get_ocr_labeling_item(item_id)
|
|
|
|
if not item:
|
|
raise HTTPException(status_code=404, detail="Item not found")
|
|
|
|
return ItemResponse(
|
|
id=item['id'],
|
|
session_id=item['session_id'],
|
|
session_name=item.get('session_name', ''),
|
|
image_path=item['image_path'],
|
|
image_url=get_image_url(item['image_path']),
|
|
ocr_text=item.get('ocr_text'),
|
|
ocr_confidence=item.get('ocr_confidence'),
|
|
ground_truth=item.get('ground_truth'),
|
|
status=item.get('status', 'pending'),
|
|
metadata=item.get('metadata'),
|
|
created_at=item.get('created_at', datetime.utcnow()),
|
|
)
|
|
|
|
|
|
@router.post("/confirm")
|
|
async def confirm_item(request: ConfirmRequest):
|
|
"""
|
|
Confirm that OCR text is correct.
|
|
|
|
Sets ground_truth = ocr_text and marks item as confirmed.
|
|
"""
|
|
success = await confirm_ocr_label(
|
|
item_id=request.item_id,
|
|
labeled_by="admin", # TODO: Get from auth
|
|
label_time_seconds=request.label_time_seconds,
|
|
)
|
|
|
|
if not success:
|
|
raise HTTPException(status_code=400, detail="Failed to confirm item")
|
|
|
|
return {"status": "confirmed", "item_id": request.item_id}
|
|
|
|
|
|
@router.post("/correct")
|
|
async def correct_item(request: CorrectRequest):
|
|
"""
|
|
Save corrected ground truth for an item.
|
|
|
|
Use this when OCR text is wrong and needs manual correction.
|
|
"""
|
|
success = await correct_ocr_label(
|
|
item_id=request.item_id,
|
|
ground_truth=request.ground_truth,
|
|
labeled_by="admin", # TODO: Get from auth
|
|
label_time_seconds=request.label_time_seconds,
|
|
)
|
|
|
|
if not success:
|
|
raise HTTPException(status_code=400, detail="Failed to correct item")
|
|
|
|
return {"status": "corrected", "item_id": request.item_id}
|
|
|
|
|
|
@router.post("/skip")
|
|
async def skip_item(request: SkipRequest):
|
|
"""
|
|
Skip an item (unusable image, etc.).
|
|
|
|
Skipped items are not included in training exports.
|
|
"""
|
|
success = await skip_ocr_item(
|
|
item_id=request.item_id,
|
|
labeled_by="admin", # TODO: Get from auth
|
|
)
|
|
|
|
if not success:
|
|
raise HTTPException(status_code=400, detail="Failed to skip item")
|
|
|
|
return {"status": "skipped", "item_id": request.item_id}
|
|
|
|
|
|
@router.get("/stats")
|
|
async def get_stats(session_id: Optional[str] = Query(None)):
|
|
"""
|
|
Get labeling statistics.
|
|
|
|
Args:
|
|
session_id: Optional session ID for session-specific stats
|
|
"""
|
|
stats = await get_ocr_labeling_stats(session_id=session_id)
|
|
|
|
if "error" in stats:
|
|
raise HTTPException(status_code=500, detail=stats["error"])
|
|
|
|
return stats
|
|
|
|
|
|
@router.post("/export")
|
|
async def export_data(request: ExportRequest):
|
|
"""
|
|
Export labeled data for training.
|
|
|
|
Formats:
|
|
- generic: JSONL with image_path and ground_truth
|
|
- trocr: Format for TrOCR/Microsoft Transformer fine-tuning
|
|
- llama_vision: Format for llama3.2-vision fine-tuning
|
|
|
|
Exports are saved to disk at /app/ocr-exports/{format}/{batch_id}/
|
|
"""
|
|
# First, get samples from database
|
|
db_samples = await export_training_samples(
|
|
export_format=request.export_format,
|
|
session_id=request.session_id,
|
|
batch_id=request.batch_id,
|
|
exported_by="admin", # TODO: Get from auth
|
|
)
|
|
|
|
if not db_samples:
|
|
return {
|
|
"export_format": request.export_format,
|
|
"batch_id": request.batch_id,
|
|
"exported_count": 0,
|
|
"samples": [],
|
|
"message": "No labeled samples found to export",
|
|
}
|
|
|
|
# If training export service is available, also write to disk
|
|
export_result = None
|
|
if TRAINING_EXPORT_AVAILABLE:
|
|
try:
|
|
export_service = get_training_export_service()
|
|
|
|
# Convert DB samples to TrainingSample objects
|
|
training_samples = []
|
|
for s in db_samples:
|
|
training_samples.append(TrainingSample(
|
|
id=s.get('id', s.get('item_id', '')),
|
|
image_path=s.get('image_path', ''),
|
|
ground_truth=s.get('ground_truth', ''),
|
|
ocr_text=s.get('ocr_text'),
|
|
ocr_confidence=s.get('ocr_confidence'),
|
|
metadata=s.get('metadata'),
|
|
))
|
|
|
|
# Export to files
|
|
export_result = export_service.export(
|
|
samples=training_samples,
|
|
export_format=request.export_format,
|
|
batch_id=request.batch_id,
|
|
)
|
|
except Exception as e:
|
|
print(f"Training export failed: {e}")
|
|
# Continue without file export
|
|
|
|
response = {
|
|
"export_format": request.export_format,
|
|
"batch_id": request.batch_id or (export_result.batch_id if export_result else None),
|
|
"exported_count": len(db_samples),
|
|
"samples": db_samples,
|
|
}
|
|
|
|
if export_result:
|
|
response["export_path"] = export_result.export_path
|
|
response["manifest_path"] = export_result.manifest_path
|
|
|
|
return response
|
|
|
|
|
|
@router.get("/training-samples")
|
|
async def list_training_samples(
|
|
export_format: Optional[str] = Query(None),
|
|
batch_id: Optional[str] = Query(None),
|
|
limit: int = Query(100, ge=1, le=1000),
|
|
):
|
|
"""Get exported training samples."""
|
|
samples = await get_training_samples(
|
|
export_format=export_format,
|
|
batch_id=batch_id,
|
|
limit=limit,
|
|
)
|
|
|
|
return {
|
|
"count": len(samples),
|
|
"samples": samples,
|
|
}
|
|
|
|
|
|
@router.get("/images/{path:path}")
|
|
async def get_image(path: str):
|
|
"""
|
|
Serve an image from local storage.
|
|
|
|
This endpoint is used when images are stored locally (not in MinIO).
|
|
"""
|
|
from fastapi.responses import FileResponse
|
|
|
|
filepath = os.path.join(LOCAL_STORAGE_PATH, path)
|
|
|
|
if not os.path.exists(filepath):
|
|
raise HTTPException(status_code=404, detail="Image not found")
|
|
|
|
# Determine content type
|
|
extension = filepath.split('.')[-1].lower()
|
|
content_type = {
|
|
'png': 'image/png',
|
|
'jpg': 'image/jpeg',
|
|
'jpeg': 'image/jpeg',
|
|
'pdf': 'application/pdf',
|
|
}.get(extension, 'application/octet-stream')
|
|
|
|
return FileResponse(filepath, media_type=content_type)
|
|
|
|
|
|
@router.post("/run-ocr/{item_id}")
|
|
async def run_ocr_for_item(item_id: str):
|
|
"""
|
|
Run OCR on an existing item.
|
|
|
|
Use this to re-run OCR or run it if it was skipped during upload.
|
|
"""
|
|
item = await get_ocr_labeling_item(item_id)
|
|
|
|
if not item:
|
|
raise HTTPException(status_code=404, detail="Item not found")
|
|
|
|
# Load image
|
|
image_path = item['image_path']
|
|
|
|
if image_path.startswith(LOCAL_STORAGE_PATH):
|
|
# Load from local storage
|
|
if not os.path.exists(image_path):
|
|
raise HTTPException(status_code=404, detail="Image file not found")
|
|
with open(image_path, 'rb') as f:
|
|
image_data = f.read()
|
|
elif MINIO_AVAILABLE:
|
|
# Load from MinIO
|
|
try:
|
|
image_data = get_ocr_image(image_path)
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=f"Failed to load image: {e}")
|
|
else:
|
|
raise HTTPException(status_code=500, detail="Cannot load image")
|
|
|
|
# Get OCR model from session
|
|
session = await get_ocr_labeling_session(item['session_id'])
|
|
ocr_model = session.get('ocr_model', 'llama3.2-vision:11b') if session else 'llama3.2-vision:11b'
|
|
|
|
# Run OCR
|
|
ocr_text, ocr_confidence = await run_ocr_on_image(
|
|
image_data,
|
|
os.path.basename(image_path),
|
|
model=ocr_model
|
|
)
|
|
|
|
if ocr_text is None:
|
|
raise HTTPException(status_code=500, detail="OCR failed")
|
|
|
|
# Update item in database
|
|
from metrics_db import get_pool
|
|
pool = await get_pool()
|
|
if pool:
|
|
async with pool.acquire() as conn:
|
|
await conn.execute(
|
|
"""
|
|
UPDATE ocr_labeling_items
|
|
SET ocr_text = $2, ocr_confidence = $3, ocr_model = $4
|
|
WHERE id = $1
|
|
""",
|
|
item_id, ocr_text, ocr_confidence, ocr_model
|
|
)
|
|
|
|
return {
|
|
"item_id": item_id,
|
|
"ocr_text": ocr_text,
|
|
"ocr_confidence": ocr_confidence,
|
|
"ocr_model": ocr_model,
|
|
}
|
|
|
|
|
|
@router.get("/exports")
|
|
async def list_exports(export_format: Optional[str] = Query(None)):
|
|
"""
|
|
List all available training data exports.
|
|
|
|
Args:
|
|
export_format: Optional filter by format (generic, trocr, llama_vision)
|
|
|
|
Returns:
|
|
List of export manifests with paths and metadata
|
|
"""
|
|
if not TRAINING_EXPORT_AVAILABLE:
|
|
return {
|
|
"exports": [],
|
|
"message": "Training export service not available",
|
|
}
|
|
|
|
try:
|
|
export_service = get_training_export_service()
|
|
exports = export_service.list_exports(export_format=export_format)
|
|
|
|
return {
|
|
"count": len(exports),
|
|
"exports": exports,
|
|
}
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=f"Failed to list exports: {e}")
|