fix: Restore all files lost during destructive rebase
A previous `git pull --rebase origin main` dropped 177 local commits,
losing 3400+ files across admin-v2, backend, studio-v2, website,
klausur-service, and many other services. The partial restore attempt
(660295e2) only recovered some files.
This commit restores all missing files from pre-rebase ref 98933f5e
while preserving post-rebase additions (night-scheduler, night-mode UI,
NightModeWidget dashboard integration).
Restored features include:
- AI Module Sidebar (FAB), OCR Labeling, OCR Compare
- GPU Dashboard, RAG Pipeline, Magic Help
- Klausur-Korrektur (8 files), Abitur-Archiv (5+ files)
- Companion, Zeugnisse-Crawler, Screen Flow
- Full backend, studio-v2, website, klausur-service
- All compliance SDKs, agent-core, voice-service
- CI/CD configs, documentation, scripts
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
845
klausur-service/backend/ocr_labeling_api.py
Normal file
845
klausur-service/backend/ocr_labeling_api.py
Normal file
@@ -0,0 +1,845 @@
|
||||
"""
|
||||
OCR Labeling API for Handwriting Training Data Collection
|
||||
|
||||
DATENSCHUTZ/PRIVACY:
|
||||
- Alle Verarbeitung erfolgt lokal (Mac Mini mit Ollama)
|
||||
- Keine Daten werden an externe Server gesendet
|
||||
- Bilder werden mit SHA256-Hash dedupliziert
|
||||
- Export nur für lokales Fine-Tuning (TrOCR, llama3.2-vision)
|
||||
|
||||
Endpoints:
|
||||
- POST /sessions - Create labeling session
|
||||
- POST /sessions/{id}/upload - Upload images for labeling
|
||||
- GET /queue - Get next items to label
|
||||
- POST /confirm - Confirm OCR as correct
|
||||
- POST /correct - Save corrected ground truth
|
||||
- POST /skip - Skip unusable item
|
||||
- GET /stats - Get labeling statistics
|
||||
- POST /export - Export training data
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
import hashlib
|
||||
import os
|
||||
import base64
|
||||
|
||||
# Import database functions
|
||||
from metrics_db import (
|
||||
create_ocr_labeling_session,
|
||||
get_ocr_labeling_sessions,
|
||||
get_ocr_labeling_session,
|
||||
add_ocr_labeling_item,
|
||||
get_ocr_labeling_queue,
|
||||
get_ocr_labeling_item,
|
||||
confirm_ocr_label,
|
||||
correct_ocr_label,
|
||||
skip_ocr_item,
|
||||
get_ocr_labeling_stats,
|
||||
export_training_samples,
|
||||
get_training_samples,
|
||||
)
|
||||
|
||||
# Optional-dependency guards: each backend service is imported defensively so
# the API degrades gracefully (feature flags below) instead of failing to boot
# when a service is missing from the deployment.

# Try to import Vision OCR service (lives outside this package, hence the
# manual sys.path insertion -- presumably to reach backend/klausur/services;
# TODO confirm the relative layout).
try:
    import sys
    sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend', 'klausur', 'services'))
    from vision_ocr_service import get_vision_ocr_service, VisionOCRService
    VISION_OCR_AVAILABLE = True
except ImportError:
    VISION_OCR_AVAILABLE = False
    print("Warning: Vision OCR service not available")

# Try to import PaddleOCR from hybrid_vocab_extractor
try:
    from hybrid_vocab_extractor import run_paddle_ocr
    PADDLEOCR_AVAILABLE = True
except ImportError:
    PADDLEOCR_AVAILABLE = False
    print("Warning: PaddleOCR not available")

# Try to import TrOCR service
try:
    from services.trocr_service import run_trocr_ocr
    TROCR_AVAILABLE = True
except ImportError:
    TROCR_AVAILABLE = False
    print("Warning: TrOCR service not available")

# Try to import Donut service
try:
    from services.donut_ocr_service import run_donut_ocr
    DONUT_AVAILABLE = True
except ImportError:
    DONUT_AVAILABLE = False
    print("Warning: Donut OCR service not available")

# Try to import MinIO storage
try:
    from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET
    MINIO_AVAILABLE = True
except ImportError:
    MINIO_AVAILABLE = False
    print("Warning: MinIO storage not available, using local storage")

# Try to import Training Export Service
try:
    from training_export_service import (
        TrainingExportService,
        TrainingSample,
        get_training_export_service,
    )
    TRAINING_EXPORT_AVAILABLE = True
except ImportError:
    TRAINING_EXPORT_AVAILABLE = False
    print("Warning: Training export service not available")
|
||||
|
||||
router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"])
|
||||
|
||||
# Local storage path (fallback if MinIO not available)
|
||||
LOCAL_STORAGE_PATH = os.getenv("OCR_STORAGE_PATH", "/app/ocr-labeling")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pydantic Models
|
||||
# =============================================================================
|
||||
|
||||
class SessionCreate(BaseModel):
    """Request body for creating a new labeling session."""
    name: str
    source_type: str = "klausur"  # klausur, handwriting_sample, scan
    description: Optional[str] = None
    ocr_model: Optional[str] = "llama3.2-vision:11b"
|
||||
|
||||
|
||||
class SessionResponse(BaseModel):
    """A labeling session with its per-status item counters."""
    id: str
    name: str
    source_type: str
    description: Optional[str]
    ocr_model: Optional[str]
    total_items: int
    labeled_items: int
    confirmed_items: int
    corrected_items: int
    skipped_items: int
    created_at: datetime
|
||||
|
||||
|
||||
class ItemResponse(BaseModel):
    """A single labeling item: the image, its OCR output and label state."""
    id: str
    session_id: str
    session_name: str
    image_path: str
    image_url: Optional[str]  # browser-fetchable URL derived from image_path
    ocr_text: Optional[str]
    ocr_confidence: Optional[float]
    ground_truth: Optional[str]
    status: str  # pending, confirmed, corrected, skipped
    metadata: Optional[Dict]
    created_at: datetime
|
||||
|
||||
|
||||
class ConfirmRequest(BaseModel):
    """Request body for confirming an item's OCR text as correct."""
    item_id: str
    label_time_seconds: Optional[int] = None  # time the labeler spent, for stats
|
||||
|
||||
|
||||
class CorrectRequest(BaseModel):
    """Request body for saving a manually corrected ground truth."""
    item_id: str
    ground_truth: str
    label_time_seconds: Optional[int] = None  # time the labeler spent, for stats
|
||||
|
||||
|
||||
class SkipRequest(BaseModel):
    """Request body for skipping an unusable item."""
    item_id: str
|
||||
|
||||
|
||||
class ExportRequest(BaseModel):
    """Request body for exporting labeled samples as training data."""
    export_format: str = "generic"  # generic, trocr, llama_vision
    session_id: Optional[str] = None  # restrict export to one session
    batch_id: Optional[str] = None
|
||||
|
||||
|
||||
class StatsResponse(BaseModel):
    """Aggregated labeling statistics."""
    total_sessions: Optional[int] = None  # optional; not always populated
    total_items: int
    labeled_items: int
    confirmed_items: int
    corrected_items: int
    pending_items: int
    exportable_items: Optional[int] = None  # optional; not always populated
    accuracy_rate: float
    avg_label_time_seconds: Optional[float] = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
def compute_image_hash(image_data: bytes) -> str:
    """Return the SHA256 hex digest of the raw image bytes (used for dedup)."""
    hasher = hashlib.sha256()
    hasher.update(image_data)
    return hasher.hexdigest()
|
||||
|
||||
|
||||
async def run_ocr_on_image(image_data: bytes, filename: str, model: str = "llama3.2-vision:11b") -> tuple:
    """
    Run OCR on an image using the specified model.

    Models:
    - llama3.2-vision:11b: Vision LLM (default, best for handwriting)
    - trocr: Microsoft TrOCR (fast for printed text)
    - paddleocr: PaddleOCR + LLM hybrid (4x faster)
    - donut: Document Understanding Transformer (structured documents)

    Returns:
        Tuple of (ocr_text, confidence)
    """
    print(f"Running OCR with model: {model}")

    # Dispatch table keyed on the model name; anything unrecognized falls
    # through to the Vision LLM wrapper, matching the default model.
    handlers = {
        "paddleocr": run_paddleocr_wrapper,
        "donut": run_donut_wrapper,
        "trocr": run_trocr_wrapper,
    }
    handler = handlers.get(model, run_vision_ocr_wrapper)
    return await handler(image_data, filename)
|
||||
|
||||
|
||||
async def run_vision_ocr_wrapper(image_data: bytes, filename: str) -> tuple:
    """Vision LLM OCR wrapper; returns (text, confidence) or (None, 0.0)."""
    if not VISION_OCR_AVAILABLE:
        print("Vision OCR service not available")
        return None, 0.0

    try:
        ocr = get_vision_ocr_service()
        reachable = await ocr.is_available()
        if not reachable:
            print("Vision OCR service not available (is_available check failed)")
            return None, 0.0

        extraction = await ocr.extract_text(
            image_data,
            filename=filename,
            is_handwriting=True,
        )
        return extraction.text, extraction.confidence
    except Exception as e:
        # Any failure is non-fatal: signal "no result" to the caller.
        print(f"Vision OCR failed: {e}")
        return None, 0.0
|
||||
|
||||
|
||||
async def run_paddleocr_wrapper(image_data: bytes, filename: str) -> tuple:
    """PaddleOCR wrapper (via hybrid_vocab_extractor), falling back to Vision OCR."""
    if not PADDLEOCR_AVAILABLE:
        print("PaddleOCR not available, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)

    try:
        # run_paddle_ocr returns (regions, raw_text)
        regions, raw_text = run_paddle_ocr(image_data)

        if not raw_text:
            print("PaddleOCR returned empty text")
            return None, 0.0

        # Average the per-region confidences; 0.5 when no regions were found.
        avg_confidence = (
            sum(region.confidence for region in regions) / len(regions)
            if regions
            else 0.5
        )
        return raw_text, avg_confidence
    except Exception as e:
        print(f"PaddleOCR failed: {e}, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
|
||||
async def run_trocr_wrapper(image_data: bytes, filename: str) -> tuple:
    """TrOCR wrapper, falling back to Vision OCR on unavailability or error."""
    if not TROCR_AVAILABLE:
        print("TrOCR not available, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)

    try:
        recognized, score = await run_trocr_ocr(image_data)
        return recognized, score
    except Exception as e:
        print(f"TrOCR failed: {e}, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
|
||||
async def run_donut_wrapper(image_data: bytes, filename: str) -> tuple:
    """Donut OCR wrapper, falling back to Vision OCR on unavailability or error."""
    if not DONUT_AVAILABLE:
        print("Donut not available, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)

    try:
        recognized, score = await run_donut_ocr(image_data)
        return recognized, score
    except Exception as e:
        print(f"Donut OCR failed: {e}, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
|
||||
def save_image_locally(session_id: str, item_id: str, image_data: bytes, extension: str = "png") -> str:
    """Write image bytes to LOCAL_STORAGE_PATH/<session_id>/<item_id>.<ext> and return the path."""
    target_dir = os.path.join(LOCAL_STORAGE_PATH, session_id)
    os.makedirs(target_dir, exist_ok=True)

    target_path = os.path.join(target_dir, f"{item_id}.{extension}")
    with open(target_path, 'wb') as out:
        out.write(image_data)

    return target_path
|
||||
|
||||
|
||||
def get_image_url(image_path: str) -> str:
    """Map a stored image path to a URL the frontend can fetch."""
    # MinIO paths are returned unchanged: they are already a URL or object key.
    if not image_path.startswith(LOCAL_STORAGE_PATH):
        return image_path
    # Local files are served through the /images/{path} endpoint below.
    relative = image_path[len(LOCAL_STORAGE_PATH):].lstrip('/')
    return f"/api/v1/ocr-label/images/{relative}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# API Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@router.post("/sessions", response_model=SessionResponse)
|
||||
async def create_session(session: SessionCreate):
|
||||
"""
|
||||
Create a new OCR labeling session.
|
||||
|
||||
A session groups related images for labeling (e.g., all scans from one class).
|
||||
"""
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
success = await create_ocr_labeling_session(
|
||||
session_id=session_id,
|
||||
name=session.name,
|
||||
source_type=session.source_type,
|
||||
description=session.description,
|
||||
ocr_model=session.ocr_model,
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to create session")
|
||||
|
||||
return SessionResponse(
|
||||
id=session_id,
|
||||
name=session.name,
|
||||
source_type=session.source_type,
|
||||
description=session.description,
|
||||
ocr_model=session.ocr_model,
|
||||
total_items=0,
|
||||
labeled_items=0,
|
||||
confirmed_items=0,
|
||||
corrected_items=0,
|
||||
skipped_items=0,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/sessions", response_model=List[SessionResponse])
|
||||
async def list_sessions(limit: int = Query(50, ge=1, le=100)):
|
||||
"""List all OCR labeling sessions."""
|
||||
sessions = await get_ocr_labeling_sessions(limit=limit)
|
||||
|
||||
return [
|
||||
SessionResponse(
|
||||
id=s['id'],
|
||||
name=s['name'],
|
||||
source_type=s['source_type'],
|
||||
description=s.get('description'),
|
||||
ocr_model=s.get('ocr_model'),
|
||||
total_items=s.get('total_items', 0),
|
||||
labeled_items=s.get('labeled_items', 0),
|
||||
confirmed_items=s.get('confirmed_items', 0),
|
||||
corrected_items=s.get('corrected_items', 0),
|
||||
skipped_items=s.get('skipped_items', 0),
|
||||
created_at=s.get('created_at', datetime.utcnow()),
|
||||
)
|
||||
for s in sessions
|
||||
]
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}", response_model=SessionResponse)
|
||||
async def get_session(session_id: str):
|
||||
"""Get a specific OCR labeling session."""
|
||||
session = await get_ocr_labeling_session(session_id)
|
||||
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
return SessionResponse(
|
||||
id=session['id'],
|
||||
name=session['name'],
|
||||
source_type=session['source_type'],
|
||||
description=session.get('description'),
|
||||
ocr_model=session.get('ocr_model'),
|
||||
total_items=session.get('total_items', 0),
|
||||
labeled_items=session.get('labeled_items', 0),
|
||||
confirmed_items=session.get('confirmed_items', 0),
|
||||
corrected_items=session.get('corrected_items', 0),
|
||||
skipped_items=session.get('skipped_items', 0),
|
||||
created_at=session.get('created_at', datetime.utcnow()),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/upload")
|
||||
async def upload_images(
|
||||
session_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
files: List[UploadFile] = File(...),
|
||||
run_ocr: bool = Form(True),
|
||||
metadata: Optional[str] = Form(None), # JSON string
|
||||
):
|
||||
"""
|
||||
Upload images to a labeling session.
|
||||
|
||||
Args:
|
||||
session_id: Session to add images to
|
||||
files: Image files to upload (PNG, JPG, PDF)
|
||||
run_ocr: Whether to run OCR immediately (default: True)
|
||||
metadata: Optional JSON metadata (subject, year, etc.)
|
||||
"""
|
||||
import json
|
||||
|
||||
# Verify session exists
|
||||
session = await get_ocr_labeling_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
# Parse metadata
|
||||
meta_dict = None
|
||||
if metadata:
|
||||
try:
|
||||
meta_dict = json.loads(metadata)
|
||||
except json.JSONDecodeError:
|
||||
meta_dict = {"raw": metadata}
|
||||
|
||||
results = []
|
||||
ocr_model = session.get('ocr_model', 'llama3.2-vision:11b')
|
||||
|
||||
for file in files:
|
||||
# Read file content
|
||||
content = await file.read()
|
||||
|
||||
# Compute hash for deduplication
|
||||
image_hash = compute_image_hash(content)
|
||||
|
||||
# Generate item ID
|
||||
item_id = str(uuid.uuid4())
|
||||
|
||||
# Determine file extension
|
||||
extension = file.filename.split('.')[-1].lower() if file.filename else 'png'
|
||||
if extension not in ['png', 'jpg', 'jpeg', 'pdf']:
|
||||
extension = 'png'
|
||||
|
||||
# Save image
|
||||
if MINIO_AVAILABLE:
|
||||
# Upload to MinIO
|
||||
try:
|
||||
image_path = upload_ocr_image(session_id, item_id, content, extension)
|
||||
except Exception as e:
|
||||
print(f"MinIO upload failed, using local storage: {e}")
|
||||
image_path = save_image_locally(session_id, item_id, content, extension)
|
||||
else:
|
||||
# Save locally
|
||||
image_path = save_image_locally(session_id, item_id, content, extension)
|
||||
|
||||
# Run OCR if requested
|
||||
ocr_text = None
|
||||
ocr_confidence = None
|
||||
|
||||
if run_ocr and extension != 'pdf': # Skip OCR for PDFs for now
|
||||
ocr_text, ocr_confidence = await run_ocr_on_image(
|
||||
content,
|
||||
file.filename or f"{item_id}.{extension}",
|
||||
model=ocr_model
|
||||
)
|
||||
|
||||
# Add to database
|
||||
success = await add_ocr_labeling_item(
|
||||
item_id=item_id,
|
||||
session_id=session_id,
|
||||
image_path=image_path,
|
||||
image_hash=image_hash,
|
||||
ocr_text=ocr_text,
|
||||
ocr_confidence=ocr_confidence,
|
||||
ocr_model=ocr_model if ocr_text else None,
|
||||
metadata=meta_dict,
|
||||
)
|
||||
|
||||
if success:
|
||||
results.append({
|
||||
"id": item_id,
|
||||
"filename": file.filename,
|
||||
"image_path": image_path,
|
||||
"image_hash": image_hash,
|
||||
"ocr_text": ocr_text,
|
||||
"ocr_confidence": ocr_confidence,
|
||||
"status": "pending",
|
||||
})
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"uploaded_count": len(results),
|
||||
"items": results,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/queue", response_model=List[ItemResponse])
|
||||
async def get_labeling_queue(
|
||||
session_id: Optional[str] = Query(None),
|
||||
status: str = Query("pending"),
|
||||
limit: int = Query(10, ge=1, le=50),
|
||||
):
|
||||
"""
|
||||
Get items from the labeling queue.
|
||||
|
||||
Args:
|
||||
session_id: Optional filter by session
|
||||
status: Filter by status (pending, confirmed, corrected, skipped)
|
||||
limit: Number of items to return
|
||||
"""
|
||||
items = await get_ocr_labeling_queue(
|
||||
session_id=session_id,
|
||||
status=status,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return [
|
||||
ItemResponse(
|
||||
id=item['id'],
|
||||
session_id=item['session_id'],
|
||||
session_name=item.get('session_name', ''),
|
||||
image_path=item['image_path'],
|
||||
image_url=get_image_url(item['image_path']),
|
||||
ocr_text=item.get('ocr_text'),
|
||||
ocr_confidence=item.get('ocr_confidence'),
|
||||
ground_truth=item.get('ground_truth'),
|
||||
status=item.get('status', 'pending'),
|
||||
metadata=item.get('metadata'),
|
||||
created_at=item.get('created_at', datetime.utcnow()),
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
|
||||
|
||||
@router.get("/items/{item_id}", response_model=ItemResponse)
|
||||
async def get_item(item_id: str):
|
||||
"""Get a specific labeling item."""
|
||||
item = await get_ocr_labeling_item(item_id)
|
||||
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
|
||||
return ItemResponse(
|
||||
id=item['id'],
|
||||
session_id=item['session_id'],
|
||||
session_name=item.get('session_name', ''),
|
||||
image_path=item['image_path'],
|
||||
image_url=get_image_url(item['image_path']),
|
||||
ocr_text=item.get('ocr_text'),
|
||||
ocr_confidence=item.get('ocr_confidence'),
|
||||
ground_truth=item.get('ground_truth'),
|
||||
status=item.get('status', 'pending'),
|
||||
metadata=item.get('metadata'),
|
||||
created_at=item.get('created_at', datetime.utcnow()),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/confirm")
|
||||
async def confirm_item(request: ConfirmRequest):
|
||||
"""
|
||||
Confirm that OCR text is correct.
|
||||
|
||||
Sets ground_truth = ocr_text and marks item as confirmed.
|
||||
"""
|
||||
success = await confirm_ocr_label(
|
||||
item_id=request.item_id,
|
||||
labeled_by="admin", # TODO: Get from auth
|
||||
label_time_seconds=request.label_time_seconds,
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Failed to confirm item")
|
||||
|
||||
return {"status": "confirmed", "item_id": request.item_id}
|
||||
|
||||
|
||||
@router.post("/correct")
|
||||
async def correct_item(request: CorrectRequest):
|
||||
"""
|
||||
Save corrected ground truth for an item.
|
||||
|
||||
Use this when OCR text is wrong and needs manual correction.
|
||||
"""
|
||||
success = await correct_ocr_label(
|
||||
item_id=request.item_id,
|
||||
ground_truth=request.ground_truth,
|
||||
labeled_by="admin", # TODO: Get from auth
|
||||
label_time_seconds=request.label_time_seconds,
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Failed to correct item")
|
||||
|
||||
return {"status": "corrected", "item_id": request.item_id}
|
||||
|
||||
|
||||
@router.post("/skip")
|
||||
async def skip_item(request: SkipRequest):
|
||||
"""
|
||||
Skip an item (unusable image, etc.).
|
||||
|
||||
Skipped items are not included in training exports.
|
||||
"""
|
||||
success = await skip_ocr_item(
|
||||
item_id=request.item_id,
|
||||
labeled_by="admin", # TODO: Get from auth
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Failed to skip item")
|
||||
|
||||
return {"status": "skipped", "item_id": request.item_id}
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_stats(session_id: Optional[str] = Query(None)):
|
||||
"""
|
||||
Get labeling statistics.
|
||||
|
||||
Args:
|
||||
session_id: Optional session ID for session-specific stats
|
||||
"""
|
||||
stats = await get_ocr_labeling_stats(session_id=session_id)
|
||||
|
||||
if "error" in stats:
|
||||
raise HTTPException(status_code=500, detail=stats["error"])
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
@router.post("/export")
|
||||
async def export_data(request: ExportRequest):
|
||||
"""
|
||||
Export labeled data for training.
|
||||
|
||||
Formats:
|
||||
- generic: JSONL with image_path and ground_truth
|
||||
- trocr: Format for TrOCR/Microsoft Transformer fine-tuning
|
||||
- llama_vision: Format for llama3.2-vision fine-tuning
|
||||
|
||||
Exports are saved to disk at /app/ocr-exports/{format}/{batch_id}/
|
||||
"""
|
||||
# First, get samples from database
|
||||
db_samples = await export_training_samples(
|
||||
export_format=request.export_format,
|
||||
session_id=request.session_id,
|
||||
batch_id=request.batch_id,
|
||||
exported_by="admin", # TODO: Get from auth
|
||||
)
|
||||
|
||||
if not db_samples:
|
||||
return {
|
||||
"export_format": request.export_format,
|
||||
"batch_id": request.batch_id,
|
||||
"exported_count": 0,
|
||||
"samples": [],
|
||||
"message": "No labeled samples found to export",
|
||||
}
|
||||
|
||||
# If training export service is available, also write to disk
|
||||
export_result = None
|
||||
if TRAINING_EXPORT_AVAILABLE:
|
||||
try:
|
||||
export_service = get_training_export_service()
|
||||
|
||||
# Convert DB samples to TrainingSample objects
|
||||
training_samples = []
|
||||
for s in db_samples:
|
||||
training_samples.append(TrainingSample(
|
||||
id=s.get('id', s.get('item_id', '')),
|
||||
image_path=s.get('image_path', ''),
|
||||
ground_truth=s.get('ground_truth', ''),
|
||||
ocr_text=s.get('ocr_text'),
|
||||
ocr_confidence=s.get('ocr_confidence'),
|
||||
metadata=s.get('metadata'),
|
||||
))
|
||||
|
||||
# Export to files
|
||||
export_result = export_service.export(
|
||||
samples=training_samples,
|
||||
export_format=request.export_format,
|
||||
batch_id=request.batch_id,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Training export failed: {e}")
|
||||
# Continue without file export
|
||||
|
||||
response = {
|
||||
"export_format": request.export_format,
|
||||
"batch_id": request.batch_id or (export_result.batch_id if export_result else None),
|
||||
"exported_count": len(db_samples),
|
||||
"samples": db_samples,
|
||||
}
|
||||
|
||||
if export_result:
|
||||
response["export_path"] = export_result.export_path
|
||||
response["manifest_path"] = export_result.manifest_path
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@router.get("/training-samples")
|
||||
async def list_training_samples(
|
||||
export_format: Optional[str] = Query(None),
|
||||
batch_id: Optional[str] = Query(None),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
):
|
||||
"""Get exported training samples."""
|
||||
samples = await get_training_samples(
|
||||
export_format=export_format,
|
||||
batch_id=batch_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return {
|
||||
"count": len(samples),
|
||||
"samples": samples,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/images/{path:path}")
|
||||
async def get_image(path: str):
|
||||
"""
|
||||
Serve an image from local storage.
|
||||
|
||||
This endpoint is used when images are stored locally (not in MinIO).
|
||||
"""
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
filepath = os.path.join(LOCAL_STORAGE_PATH, path)
|
||||
|
||||
if not os.path.exists(filepath):
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
||||
# Determine content type
|
||||
extension = filepath.split('.')[-1].lower()
|
||||
content_type = {
|
||||
'png': 'image/png',
|
||||
'jpg': 'image/jpeg',
|
||||
'jpeg': 'image/jpeg',
|
||||
'pdf': 'application/pdf',
|
||||
}.get(extension, 'application/octet-stream')
|
||||
|
||||
return FileResponse(filepath, media_type=content_type)
|
||||
|
||||
|
||||
@router.post("/run-ocr/{item_id}")
|
||||
async def run_ocr_for_item(item_id: str):
|
||||
"""
|
||||
Run OCR on an existing item.
|
||||
|
||||
Use this to re-run OCR or run it if it was skipped during upload.
|
||||
"""
|
||||
item = await get_ocr_labeling_item(item_id)
|
||||
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
|
||||
# Load image
|
||||
image_path = item['image_path']
|
||||
|
||||
if image_path.startswith(LOCAL_STORAGE_PATH):
|
||||
# Load from local storage
|
||||
if not os.path.exists(image_path):
|
||||
raise HTTPException(status_code=404, detail="Image file not found")
|
||||
with open(image_path, 'rb') as f:
|
||||
image_data = f.read()
|
||||
elif MINIO_AVAILABLE:
|
||||
# Load from MinIO
|
||||
try:
|
||||
image_data = get_ocr_image(image_path)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to load image: {e}")
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Cannot load image")
|
||||
|
||||
# Get OCR model from session
|
||||
session = await get_ocr_labeling_session(item['session_id'])
|
||||
ocr_model = session.get('ocr_model', 'llama3.2-vision:11b') if session else 'llama3.2-vision:11b'
|
||||
|
||||
# Run OCR
|
||||
ocr_text, ocr_confidence = await run_ocr_on_image(
|
||||
image_data,
|
||||
os.path.basename(image_path),
|
||||
model=ocr_model
|
||||
)
|
||||
|
||||
if ocr_text is None:
|
||||
raise HTTPException(status_code=500, detail="OCR failed")
|
||||
|
||||
# Update item in database
|
||||
from metrics_db import get_pool
|
||||
pool = await get_pool()
|
||||
if pool:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE ocr_labeling_items
|
||||
SET ocr_text = $2, ocr_confidence = $3, ocr_model = $4
|
||||
WHERE id = $1
|
||||
""",
|
||||
item_id, ocr_text, ocr_confidence, ocr_model
|
||||
)
|
||||
|
||||
return {
|
||||
"item_id": item_id,
|
||||
"ocr_text": ocr_text,
|
||||
"ocr_confidence": ocr_confidence,
|
||||
"ocr_model": ocr_model,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/exports")
|
||||
async def list_exports(export_format: Optional[str] = Query(None)):
|
||||
"""
|
||||
List all available training data exports.
|
||||
|
||||
Args:
|
||||
export_format: Optional filter by format (generic, trocr, llama_vision)
|
||||
|
||||
Returns:
|
||||
List of export manifests with paths and metadata
|
||||
"""
|
||||
if not TRAINING_EXPORT_AVAILABLE:
|
||||
return {
|
||||
"exports": [],
|
||||
"message": "Training export service not available",
|
||||
}
|
||||
|
||||
try:
|
||||
export_service = get_training_export_service()
|
||||
exports = export_service.list_exports(export_format=export_format)
|
||||
|
||||
return {
|
||||
"count": len(exports),
|
||||
"exports": exports,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to list exports: {e}")
|
||||
Reference in New Issue
Block a user