Restructure: Move ocr_pipeline + labeling + crop into ocr/ package
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 29s
CI / test-go-edu-search (push) Successful in 29s
CI / test-python-klausur (push) Failing after 2m25s
CI / test-python-agent-core (push) Successful in 19s
CI / test-nodejs-website (push) Successful in 20s
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 29s
CI / test-go-edu-search (push) Successful in 29s
CI / test-python-klausur (push) Failing after 2m25s
CI / test-python-agent-core (push) Successful in 19s
CI / test-nodejs-website (push) Successful in 20s
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -47,6 +47,7 @@
|
|||||||
|
|
||||||
# Single SSE generator orchestrating 6 pipeline steps — cannot split generator context
|
# Single SSE generator orchestrating 6 pipeline steps — cannot split generator context
|
||||||
**/ocr_pipeline_auto_steps.py | owner=klausur | reason=run_auto is a single async generator yielding SSE events across 6 steps (528 LOC) | review=2026-10-01
|
**/ocr_pipeline_auto_steps.py | owner=klausur | reason=run_auto is a single async generator yielding SSE events across 6 steps (528 LOC) | review=2026-10-01
|
||||||
|
**/ocr/pipeline/auto_steps.py | owner=klausur | reason=Same file moved to ocr/ package | review=2026-10-01
|
||||||
|
|
||||||
# Legacy — TEMPORAER bis Refactoring abgeschlossen
|
# Legacy — TEMPORAER bis Refactoring abgeschlossen
|
||||||
# Dateien hier werden Phase fuer Phase abgearbeitet und entfernt.
|
# Dateien hier werden Phase fuer Phase abgearbeitet und entfernt.
|
||||||
|
|||||||
@@ -1,290 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/crop_api.py
|
||||||
Crop API endpoints (Step 4 / UI index 3 of OCR Pipeline).
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Auto-crop, manual crop, and skip-crop for scanner/book borders.
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.crop_api")
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from typing import Any, Dict
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
from fastapi import APIRouter, HTTPException
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from page_crop import detect_and_crop_page, detect_page_splits
|
|
||||||
from ocr_pipeline_session_store import get_sub_sessions, update_session_db
|
|
||||||
|
|
||||||
from orientation_crop_helpers import ensure_cached, append_pipeline_log
|
|
||||||
from page_sub_sessions import create_page_sub_sessions
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Step 4 (UI index 3): Crop — runs after deskew + dewarp
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/crop")
async def auto_crop(session_id: str):
    """Auto-detect and crop scanner/book borders.

    Reads the dewarped image (post-deskew + dewarp, so the page is straight).
    Falls back to oriented -> original if earlier steps were skipped.

    If the image is a multi-page spread (e.g. book on scanner), it will
    automatically split into separate sub-sessions per page, crop each
    individually, and return the split info.
    """
    # The docstring above is kept verbatim: FastAPI publishes it as the
    # operation description in the generated OpenAPI schema.
    cached = await ensure_cached(session_id)

    # Source image priority: dewarped, then oriented, then the original.
    img_bgr = None
    for cache_key in ("dewarped_bgr", "oriented_bgr", "original_bgr"):
        candidate = cached.get(cache_key)
        if candidate is not None:
            img_bgr = candidate
            break
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="No image available for cropping")

    started = time.time()

    # --- Sub-sessions that already exist (created by the page-split step) ---
    # When the cached crop result is already flagged as multi-page, report the
    # existing split instead of detecting/splitting again.
    # NOTE(review): if sub-sessions exist but the cached crop_result is not
    # flagged "multi_page", control falls through to the detection below and
    # may create sub-sessions a second time -- confirm this is intended.
    existing_subs = await get_sub_sessions(session_id)
    if existing_subs:
        previous = cached.get("crop_result") or {}
        if previous.get("multi_page"):
            height, width = img_bgr.shape[:2]
            listed_subs = []
            for position, sub in enumerate(existing_subs):
                listed_subs.append({
                    "id": sub["id"],
                    "name": sub.get("name"),
                    "page_index": sub.get("box_index", position),
                })
            return {
                "session_id": session_id,
                **previous,
                "image_width": width,
                "image_height": height,
                "sub_sessions": listed_subs,
                "note": "Page split was already performed; each sub-session runs its own crop.",
            }

    # --- Multi-page detection (fallback for sessions that skipped page-split) ---
    page_splits = detect_page_splits(img_bgr)
    if page_splits and len(page_splits) >= 2:
        # A spread was detected: one sub-session per page.
        sub_sessions = await create_page_sub_sessions(
            session_id, cached, img_bgr, page_splits,
        )
        duration = time.time() - started
        page_count = len(page_splits)

        crop_info: Dict[str, Any] = {
            "crop_applied": True,
            "multi_page": True,
            "page_count": page_count,
            "page_splits": page_splits,
            "duration_seconds": round(duration, 2),
        }
        cached["crop_result"] = crop_info

        # Backward compatibility: the parent session keeps the cropped first
        # page as its own "cropped" image.
        first_page = page_splits[0]
        top, left = first_page["y"], first_page["x"]
        first_bgr = img_bgr[
            top:top + first_page["height"],
            left:left + first_page["width"],
        ].copy()
        first_cropped, _ = detect_and_crop_page(first_bgr)
        cached["cropped_bgr"] = first_cropped

        # An encoding failure is persisted as empty bytes rather than raised.
        ok, png_buf = cv2.imencode(".png", first_cropped)
        first_png = png_buf.tobytes() if ok else b""
        await update_session_db(
            session_id,
            cropped_png=first_png,
            crop_result=crop_info,
            current_step=5,
            status='split',
        )

        logger.info(
            "OCR Pipeline: crop session %s: multi-page split into %d pages in %.2fs",
            session_id, page_count, duration,
        )

        await append_pipeline_log(
            session_id,
            "crop",
            {"multi_page": True, "page_count": page_count},
            duration_ms=int(duration * 1000),
        )

        height, width = first_cropped.shape[:2]
        return {
            "session_id": session_id,
            **crop_info,
            "image_width": width,
            "image_height": height,
            "cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
            "sub_sessions": sub_sessions,
        }

    # --- Single page (normal) ---
    cropped_bgr, crop_info = detect_and_crop_page(img_bgr)

    duration = time.time() - started
    crop_info["duration_seconds"] = round(duration, 2)
    crop_info["multi_page"] = False

    # Encode for persistence; a failed encode is stored as empty bytes.
    success, png_buf = cv2.imencode(".png", cropped_bgr)
    if success:
        cropped_png = png_buf.tobytes()
    else:
        cropped_png = b""

    # Refresh the in-memory cache before writing to the database.
    cached["cropped_bgr"] = cropped_bgr
    cached["crop_result"] = crop_info

    await update_session_db(
        session_id,
        cropped_png=cropped_png,
        crop_result=crop_info,
        current_step=5,
    )

    logger.info(
        "OCR Pipeline: crop session %s: applied=%s format=%s in %.2fs",
        session_id, crop_info["crop_applied"],
        crop_info.get("detected_format", "?"),
        duration,
    )

    log_payload = {
        "crop_applied": crop_info["crop_applied"],
        "detected_format": crop_info.get("detected_format"),
        "format_confidence": crop_info.get("format_confidence"),
    }
    await append_pipeline_log(session_id, "crop", log_payload, duration_ms=int(duration * 1000))

    height, width = cropped_bgr.shape[:2]
    return {
        "session_id": session_id,
        **crop_info,
        "image_width": width,
        "image_height": height,
        "cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
    }
|
|
||||||
|
|
||||||
|
|
||||||
class ManualCropRequest(BaseModel):
    # Request body of the manual crop endpoint: a rectangle expressed in
    # percent of the source image. x/width scale with the image width and
    # y/height with the image height (see manual_crop).
    # Documented with comments rather than a docstring so the generated
    # schema stays unchanged.
    x: float  # percentage 0-100
    y: float  # percentage 0-100
    width: float  # percentage 0-100
    height: float  # percentage 0-100
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/crop/manual")
async def manual_crop(session_id: str, req: ManualCropRequest):
    """Manually crop using percentage coordinates."""
    # Docstring kept verbatim: FastAPI exposes it as the operation description.
    cached = await ensure_cached(session_id)

    # Source image priority: dewarped, then oriented, then the original.
    img_bgr = None
    for cache_key in ("dewarped_bgr", "oriented_bgr", "original_bgr"):
        candidate = cached.get(cache_key)
        if candidate is not None:
            img_bgr = candidate
            break
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="No image available for cropping")

    h, w = img_bgr.shape[:2]

    def _to_pixels(extent, percent):
        # Percent of an image extent -> whole pixels (int() truncates).
        return int(extent * percent / 100.0)

    left = _to_pixels(w, req.x)
    top = _to_pixels(h, req.y)
    box_w = _to_pixels(w, req.width)
    box_h = _to_pixels(h, req.height)

    # Clamp the origin into the image, then force the box to be at least one
    # pixel and to end no later than the image border.
    left = max(0, min(left, w - 1))
    top = max(0, min(top, h - 1))
    box_w = max(1, min(box_w, w - left))
    box_h = max(1, min(box_h, h - top))

    cropped_bgr = img_bgr[top:top + box_h, left:left + box_w].copy()

    # A failed PNG encode is persisted as empty bytes rather than raised.
    success, png_buf = cv2.imencode(".png", cropped_bgr)
    if success:
        cropped_png = png_buf.tobytes()
    else:
        cropped_png = b""

    requested_pct = {
        "x": round(req.x, 2),
        "y": round(req.y, 2),
        "width": round(req.width, 2),
        "height": round(req.height, 2),
    }
    crop_result = {
        "crop_applied": True,
        "crop_rect": {"x": left, "y": top, "width": box_w, "height": box_h},
        "crop_rect_pct": requested_pct,
        "original_size": {"width": w, "height": h},
        "cropped_size": {"width": box_w, "height": box_h},
        "method": "manual",
    }

    cached["cropped_bgr"] = cropped_bgr
    cached["crop_result"] = crop_result

    await update_session_db(
        session_id,
        cropped_png=cropped_png,
        crop_result=crop_result,
        current_step=5,
    )

    out_h, out_w = cropped_bgr.shape[:2]
    return {
        "session_id": session_id,
        **crop_result,
        "image_width": out_w,
        "image_height": out_h,
        "cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
    }
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/crop/skip")
async def skip_crop(session_id: str):
    """Skip cropping -- use dewarped (or oriented/original) image as-is."""
    # Docstring kept verbatim: FastAPI exposes it as the operation description.
    cached = await ensure_cached(session_id)

    # Source image priority: dewarped, then oriented, then the original.
    img_bgr = None
    for cache_key in ("dewarped_bgr", "oriented_bgr", "original_bgr"):
        candidate = cached.get(cache_key)
        if candidate is not None:
            img_bgr = candidate
            break
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="No image available")

    h, w = img_bgr.shape[:2]
    full_size = {"width": w, "height": h}

    # Identity crop: the selected image is stored unchanged as the "cropped"
    # image. A failed PNG encode is persisted as empty bytes.
    success, png_buf = cv2.imencode(".png", img_bgr)
    if success:
        cropped_png = png_buf.tobytes()
    else:
        cropped_png = b""

    crop_result = {
        "crop_applied": False,
        "skipped": True,
        "original_size": full_size,
        "cropped_size": dict(full_size),
    }

    # The cache entry is the same array object as the source image (no copy).
    cached["cropped_bgr"] = img_bgr
    cached["crop_result"] = crop_result

    await update_session_db(
        session_id,
        cropped_png=cropped_png,
        crop_result=crop_result,
        current_step=5,
    )

    return {
        "session_id": session_id,
        **crop_result,
        "image_width": w,
        "image_height": h,
        "cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
    }
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
# Backward-compat shim -- module moved to ocr\/pipeline.py
|
# Backward-compat shim -- module moved to ocr/cv_pipeline.py
|
||||||
import importlib as _importlib
|
import importlib as _importlib
|
||||||
import sys as _sys
|
import sys as _sys
|
||||||
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline")
|
_sys.modules[__name__] = _importlib.import_module("ocr.cv_pipeline")
|
||||||
|
|||||||
@@ -6,4 +6,4 @@ Backward-compatible re-exports: consumers can still use
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from .types import * # noqa: F401,F403
|
from .types import * # noqa: F401,F403
|
||||||
from .pipeline import * # noqa: F401,F403
|
from .cv_pipeline import * # noqa: F401,F403
|
||||||
|
|||||||
@@ -0,0 +1,6 @@
|
|||||||
|
"""
|
||||||
|
OCR Labeling sub-package — labeling API, models, helpers, and route handlers.
|
||||||
|
|
||||||
|
Moved from backend/ flat modules (ocr_labeling_*.py).
|
||||||
|
Backward-compatible shim files remain at the old locations.
|
||||||
|
"""
|
||||||
@@ -0,0 +1,81 @@
|
|||||||
|
"""
|
||||||
|
OCR Labeling API — Barrel Re-export
|
||||||
|
|
||||||
|
Split into:
|
||||||
|
- models.py — Pydantic models and constants
|
||||||
|
- helpers.py — OCR wrappers, image storage, hashing
|
||||||
|
- routes.py — Session/queue/labeling route handlers
|
||||||
|
- upload_routes.py — Upload, run-OCR, export route handlers
|
||||||
|
|
||||||
|
All public names are re-exported here for backward compatibility.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Models
|
||||||
|
from .models import ( # noqa: F401
|
||||||
|
LOCAL_STORAGE_PATH,
|
||||||
|
SessionCreate,
|
||||||
|
SessionResponse,
|
||||||
|
ItemResponse,
|
||||||
|
ConfirmRequest,
|
||||||
|
CorrectRequest,
|
||||||
|
SkipRequest,
|
||||||
|
ExportRequest,
|
||||||
|
StatsResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Helpers
|
||||||
|
from .helpers import ( # noqa: F401
|
||||||
|
VISION_OCR_AVAILABLE,
|
||||||
|
PADDLEOCR_AVAILABLE,
|
||||||
|
TROCR_AVAILABLE,
|
||||||
|
DONUT_AVAILABLE,
|
||||||
|
MINIO_AVAILABLE,
|
||||||
|
TRAINING_EXPORT_AVAILABLE,
|
||||||
|
compute_image_hash,
|
||||||
|
run_ocr_on_image,
|
||||||
|
run_vision_ocr_wrapper,
|
||||||
|
run_paddleocr_wrapper,
|
||||||
|
run_trocr_wrapper,
|
||||||
|
run_donut_wrapper,
|
||||||
|
save_image_locally,
|
||||||
|
get_image_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Conditional re-exports from helpers' optional imports
|
||||||
|
try:
|
||||||
|
from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET # noqa: F401
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
from training_export_service import ( # noqa: F401
|
||||||
|
TrainingExportService,
|
||||||
|
TrainingSample,
|
||||||
|
get_training_export_service,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
from hybrid_vocab_extractor import run_paddle_ocr # noqa: F401
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
from services.trocr_service import run_trocr_ocr # noqa: F401
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
from services.donut_ocr_service import run_donut_ocr # noqa: F401
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
from vision_ocr_service import get_vision_ocr_service, VisionOCRService # noqa: F401
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Routes (router is the main export for app.include_router)
|
||||||
|
from .routes import router # noqa: F401
|
||||||
|
from .upload_routes import router as upload_router # noqa: F401
|
||||||
@@ -0,0 +1,205 @@
|
|||||||
|
"""
|
||||||
|
OCR Labeling - Helper Functions and OCR Wrappers
|
||||||
|
|
||||||
|
Extracted from ocr_labeling_api.py to keep files under 500 LOC.
|
||||||
|
|
||||||
|
DATENSCHUTZ/PRIVACY:
|
||||||
|
- Alle Verarbeitung erfolgt lokal (Mac Mini mit Ollama)
|
||||||
|
- Keine Daten werden an externe Server gesendet
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
from .models import LOCAL_STORAGE_PATH
|
||||||
|
|
||||||
|
# Optional backends. Each section tries one import and records the outcome in
# a module-level *_AVAILABLE flag; a missing backend only prints a warning.

# Vision OCR service
try:
    import sys
    # NOTE(review): this entry is resolved relative to this file's directory.
    # The module was moved into the ocr/labeling sub-package, so
    # "../../backend/klausur/services" may no longer point at the directory it
    # did before the move -- confirm against the repository layout.
    sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend', 'klausur', 'services'))
    from vision_ocr_service import get_vision_ocr_service, VisionOCRService
except ImportError:
    VISION_OCR_AVAILABLE = False
    print("Warning: Vision OCR service not available")
else:
    VISION_OCR_AVAILABLE = True

# PaddleOCR (provided by hybrid_vocab_extractor)
try:
    from hybrid_vocab_extractor import run_paddle_ocr
except ImportError:
    PADDLEOCR_AVAILABLE = False
    print("Warning: PaddleOCR not available")
else:
    PADDLEOCR_AVAILABLE = True

# TrOCR service
try:
    from services.trocr_service import run_trocr_ocr
except ImportError:
    TROCR_AVAILABLE = False
    print("Warning: TrOCR service not available")
else:
    TROCR_AVAILABLE = True

# Donut service
try:
    from services.donut_ocr_service import run_donut_ocr
except ImportError:
    DONUT_AVAILABLE = False
    print("Warning: Donut OCR service not available")
else:
    DONUT_AVAILABLE = True

# MinIO storage
try:
    from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET
except ImportError:
    MINIO_AVAILABLE = False
    print("Warning: MinIO storage not available, using local storage")
else:
    MINIO_AVAILABLE = True

# Training export service
try:
    from training_export_service import (
        TrainingExportService,
        TrainingSample,
        get_training_export_service,
    )
except ImportError:
    TRAINING_EXPORT_AVAILABLE = False
    print("Warning: Training export service not available")
else:
    TRAINING_EXPORT_AVAILABLE = True
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Helper Functions
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def compute_image_hash(image_data: bytes) -> str:
    """Return the SHA-256 digest of ``image_data`` as a lowercase hex string."""
    hasher = hashlib.sha256()
    hasher.update(image_data)
    return hasher.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_ocr_on_image(image_data: bytes, filename: str, model: str = "llama3.2-vision:11b") -> tuple:
    """
    Dispatch an image to the OCR backend selected by ``model``.

    Recognised model names:
    - "paddleocr": PaddleOCR wrapper
    - "donut": Donut wrapper
    - "trocr": TrOCR wrapper
    - anything else (including the default "llama3.2-vision:11b"):
      Vision LLM wrapper

    Args:
        image_data: Raw image bytes.
        filename: Name passed through to the selected wrapper.
        model: Backend selector, compared by exact string match.

    Returns:
        Whatever the selected wrapper returns: a tuple of
        (ocr_text, confidence).
    """
    print(f"Running OCR with model: {model}")

    if model == "paddleocr":
        return await run_paddleocr_wrapper(image_data, filename)
    if model == "donut":
        return await run_donut_wrapper(image_data, filename)
    if model == "trocr":
        return await run_trocr_wrapper(image_data, filename)

    # Every other name is handled by the Vision LLM backend.
    return await run_vision_ocr_wrapper(image_data, filename)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_vision_ocr_wrapper(image_data: bytes, filename: str) -> tuple:
    """Run the Vision LLM OCR service on an image.

    Returns:
        (text, confidence) taken from the service result, or (None, 0.0) when
        the service was not importable, reports itself unavailable, or raises.
    """
    no_result = (None, 0.0)

    if not VISION_OCR_AVAILABLE:
        print("Vision OCR service not available")
        return no_result

    try:
        service = get_vision_ocr_service()
        reachable = await service.is_available()
        if not reachable:
            print("Vision OCR service not available (is_available check failed)")
            return no_result

        # The image is always submitted with the handwriting flag set.
        result = await service.extract_text(image_data, filename=filename, is_handwriting=True)
        return result.text, result.confidence
    except Exception as e:
        # Best effort: any failure is reported on stdout and turned into "no result".
        print(f"Vision OCR failed: {e}")
        return no_result
|
||||||
|
|
||||||
|
|
||||||
|
async def run_paddleocr_wrapper(image_data: bytes, filename: str) -> tuple:
    """Run PaddleOCR (via hybrid_vocab_extractor) on an image.

    Falls back to the Vision OCR wrapper when PaddleOCR was not importable or
    raises.

    Returns:
        (raw_text, mean region confidence); the confidence is 0.5 when no
        regions were reported. (None, 0.0) when PaddleOCR returns no text.
    """
    if not PADDLEOCR_AVAILABLE:
        print("PaddleOCR not available, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)

    try:
        # run_paddle_ocr yields a (regions, raw_text) pair.
        # NOTE(review): this is a plain (non-awaited) call inside a coroutine;
        # presumably it blocks the event loop while OCR runs -- confirm.
        regions, raw_text = run_paddle_ocr(image_data)

        if not raw_text:
            print("PaddleOCR returned empty text")
            return None, 0.0

        if not regions:
            return raw_text, 0.5

        confidence_total = sum(region.confidence for region in regions)
        return raw_text, confidence_total / len(regions)
    except Exception as e:
        print(f"PaddleOCR failed: {e}, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_trocr_wrapper(image_data: bytes, filename: str) -> tuple:
    """Run TrOCR on an image.

    Falls back to the Vision OCR wrapper when TrOCR was not importable or
    raises.

    Returns:
        The (text, confidence) pair produced by run_trocr_ocr, or the
        fallback wrapper's result.
    """
    if TROCR_AVAILABLE:
        try:
            text, confidence = await run_trocr_ocr(image_data)
            return text, confidence
        except Exception as e:
            print(f"TrOCR failed: {e}, falling back to Vision OCR")
    else:
        print("TrOCR not available, falling back to Vision OCR")

    return await run_vision_ocr_wrapper(image_data, filename)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_donut_wrapper(image_data: bytes, filename: str) -> tuple:
    """Run Donut OCR on an image.

    Falls back to the Vision OCR wrapper when Donut was not importable or
    raises.

    Returns:
        The (text, confidence) pair produced by run_donut_ocr, or the
        fallback wrapper's result.
    """
    if DONUT_AVAILABLE:
        try:
            text, confidence = await run_donut_ocr(image_data)
            return text, confidence
        except Exception as e:
            print(f"Donut OCR failed: {e}, falling back to Vision OCR")
    else:
        print("Donut not available, falling back to Vision OCR")

    return await run_vision_ocr_wrapper(image_data, filename)
|
||||||
|
|
||||||
|
|
||||||
|
def save_image_locally(session_id: str, item_id: str, image_data: bytes, extension: str = "png") -> str:
    """Write image bytes to ``<LOCAL_STORAGE_PATH>/<session_id>/<item_id>.<extension>``.

    The session directory is created when missing; an existing file is
    overwritten.

    Returns:
        The path of the written file.
    """
    # NOTE(review): session_id, item_id and extension are joined into the path
    # unvalidated; presumably callers pass generated ids -- confirm that no
    # caller forwards user-controlled values containing path separators.
    target_dir = os.path.join(LOCAL_STORAGE_PATH, session_id)
    os.makedirs(target_dir, exist_ok=True)

    filepath = os.path.join(target_dir, f"{item_id}.{extension}")
    with open(filepath, 'wb') as handle:
        handle.write(image_data)
    return filepath
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_url(image_path: str) -> str:
    """Translate a stored image path into the value handed to the frontend.

    Paths located under LOCAL_STORAGE_PATH become
    ``/api/v1/ocr-label/images/<path relative to the storage root>``.
    Every other value is returned unchanged.

    The storage root is matched on a path-separator boundary, so a sibling
    directory that merely shares the prefix (for example
    ``/app/ocr-labeling-backup`` next to ``/app/ocr-labeling``) is not
    mistaken for local storage.

    Args:
        image_path: Path or key as stored for the item.

    Returns:
        A URL path for local images, otherwise ``image_path`` itself.
    """
    # Normalise so a configured root with or without a trailing slash behaves the same.
    root = LOCAL_STORAGE_PATH.rstrip('/')
    if image_path == root or image_path.startswith(root + '/'):
        relative_path = image_path[len(root):].lstrip('/')
        return f"/api/v1/ocr-label/images/{relative_path}"
    # Not under the local root: the value is passed through untouched.
    return image_path
|
||||||
@@ -0,0 +1,86 @@
|
|||||||
|
"""
|
||||||
|
OCR Labeling - Pydantic Models and Constants
|
||||||
|
|
||||||
|
Extracted from ocr_labeling_api.py to keep files under 500 LOC.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Optional, Dict
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
# Local storage path (fallback if MinIO not available)
|
||||||
|
LOCAL_STORAGE_PATH = os.getenv("OCR_STORAGE_PATH", "/app/ocr-labeling")
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Pydantic Models
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class SessionCreate(BaseModel):
    # Request body used to create a labeling session.
    # Documented with comments rather than a docstring so the generated
    # schema stays unchanged.
    name: str
    source_type: str = "klausur"  # klausur, handwriting_sample, scan
    description: Optional[str] = None
    # Defaults to the Vision LLM model name.
    ocr_model: Optional[str] = "llama3.2-vision:11b"
|
||||||
|
|
||||||
|
|
||||||
|
class SessionResponse(BaseModel):
    # Response shape describing one labeling session, including its
    # per-status item counters.
    # Documented with comments rather than a docstring so the generated
    # schema stays unchanged.
    id: str
    name: str
    source_type: str
    description: Optional[str]
    ocr_model: Optional[str]
    # Item counters
    total_items: int
    labeled_items: int
    confirmed_items: int
    corrected_items: int
    skipped_items: int
    created_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class ItemResponse(BaseModel):
    # Response shape describing one labeling item.
    # Documented with comments rather than a docstring so the generated
    # schema stays unchanged.
    id: str
    session_id: str
    session_name: str
    # Stored path of the image and the URL derived from it for the frontend.
    image_path: str
    image_url: Optional[str]
    # OCR output and, once labeled, the ground-truth text.
    ocr_text: Optional[str]
    ocr_confidence: Optional[float]
    ground_truth: Optional[str]
    status: str
    metadata: Optional[Dict]
    created_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class ConfirmRequest(BaseModel):
    # Request body identifying the item to confirm
    # (presumably for the POST /confirm endpoint -- verify against the routes).
    item_id: str
    # Optional labeling duration in whole seconds.
    label_time_seconds: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
class CorrectRequest(BaseModel):
    # Request body carrying a corrected ground truth for one item
    # (presumably for the POST /correct endpoint -- verify against the routes).
    item_id: str
    ground_truth: str
    # Optional labeling duration in whole seconds.
    label_time_seconds: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
class SkipRequest(BaseModel):
    # Request body identifying the item to skip
    # (presumably for the POST /skip endpoint -- verify against the routes).
    item_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class ExportRequest(BaseModel):
    # Request body for an export. Both selectors are optional.
    # NOTE(review): how session_id and batch_id narrow the export is not
    # visible here -- check the export route handler.
    export_format: str = "generic"  # generic, trocr, llama_vision
    session_id: Optional[str] = None
    batch_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class StatsResponse(BaseModel):
    # Response shape for labeling statistics. The Optional fields default to
    # None and may be omitted by the producer.
    # Documented with comments rather than a docstring so the generated
    # schema stays unchanged.
    total_sessions: Optional[int] = None
    total_items: int
    labeled_items: int
    confirmed_items: int
    corrected_items: int
    pending_items: int
    exportable_items: Optional[int] = None
    accuracy_rate: float
    avg_label_time_seconds: Optional[float] = None
|
||||||
@@ -0,0 +1,241 @@
|
|||||||
|
"""
|
||||||
|
OCR Labeling - Session and Labeling Route Handlers
|
||||||
|
|
||||||
|
Extracted from ocr_labeling_api.py to keep files under 500 LOC.
|
||||||
|
|
||||||
|
Endpoints:
|
||||||
|
- POST /sessions - Create labeling session
|
||||||
|
- GET /sessions - List sessions
|
||||||
|
- GET /sessions/{id} - Get session
|
||||||
|
- GET /queue - Get labeling queue
|
||||||
|
- GET /items/{id} - Get item
|
||||||
|
- POST /confirm - Confirm OCR
|
||||||
|
- POST /correct - Correct ground truth
|
||||||
|
- POST /skip - Skip item
|
||||||
|
- GET /stats - Get statistics
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
|
from typing import Optional, List
|
||||||
|
from datetime import datetime
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from metrics_db import (
|
||||||
|
create_ocr_labeling_session,
|
||||||
|
get_ocr_labeling_sessions,
|
||||||
|
get_ocr_labeling_session,
|
||||||
|
get_ocr_labeling_queue,
|
||||||
|
get_ocr_labeling_item,
|
||||||
|
confirm_ocr_label,
|
||||||
|
correct_ocr_label,
|
||||||
|
skip_ocr_item,
|
||||||
|
get_ocr_labeling_stats,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .models import (
|
||||||
|
SessionCreate, SessionResponse, ItemResponse,
|
||||||
|
ConfirmRequest, CorrectRequest, SkipRequest,
|
||||||
|
)
|
||||||
|
from .helpers import get_image_url
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"])
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Session Endpoints
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@router.post("/sessions", response_model=SessionResponse)
async def create_session(session: SessionCreate):
    """Create a new OCR labeling session."""
    # Docstring kept verbatim: FastAPI exposes it as the operation description.
    new_id = str(uuid.uuid4())

    # Fields copied from the request into both the DB row and the response.
    details = {
        "name": session.name,
        "source_type": session.source_type,
        "description": session.description,
        "ocr_model": session.ocr_model,
    }

    created = await create_ocr_labeling_session(session_id=new_id, **details)
    if not created:
        raise HTTPException(status_code=500, detail="Failed to create session")

    # A new session has no items yet. created_at is the current UTC time taken
    # here (naive datetime), not a value read back from the database.
    return SessionResponse(
        id=new_id,
        total_items=0,
        labeled_items=0,
        confirmed_items=0,
        corrected_items=0,
        skipped_items=0,
        created_at=datetime.utcnow(),
        **details,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sessions", response_model=List[SessionResponse])
async def list_sessions(limit: int = Query(50, ge=1, le=100)):
    """List all OCR labeling sessions."""
    # Docstring kept verbatim: FastAPI exposes it as the operation description.
    rows = await get_ocr_labeling_sessions(limit=limit)

    responses = []
    for row in rows:
        # id, name and source_type are required; missing counters default to 0
        # and a missing created_at falls back to the current UTC time.
        responses.append(
            SessionResponse(
                id=row['id'],
                name=row['name'],
                source_type=row['source_type'],
                description=row.get('description'),
                ocr_model=row.get('ocr_model'),
                total_items=row.get('total_items', 0),
                labeled_items=row.get('labeled_items', 0),
                confirmed_items=row.get('confirmed_items', 0),
                corrected_items=row.get('corrected_items', 0),
                skipped_items=row.get('skipped_items', 0),
                created_at=row.get('created_at', datetime.utcnow()),
            )
        )
    return responses
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}", response_model=SessionResponse)
async def get_session(session_id: str):
    """Get a specific OCR labeling session."""
    # Docstring kept verbatim: FastAPI exposes it as the operation description.
    record = await get_ocr_labeling_session(session_id)
    if not record:
        raise HTTPException(status_code=404, detail="Session not found")

    # id, name and source_type are required; missing counters default to 0
    # and a missing created_at falls back to the current UTC time.
    fields = dict(
        id=record['id'],
        name=record['name'],
        source_type=record['source_type'],
        description=record.get('description'),
        ocr_model=record.get('ocr_model'),
        total_items=record.get('total_items', 0),
        labeled_items=record.get('labeled_items', 0),
        confirmed_items=record.get('confirmed_items', 0),
        corrected_items=record.get('corrected_items', 0),
        skipped_items=record.get('skipped_items', 0),
        created_at=record.get('created_at', datetime.utcnow()),
    )
    return SessionResponse(**fields)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Queue and Item Endpoints
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@router.get("/queue", response_model=List[ItemResponse])
async def get_labeling_queue(
    session_id: Optional[str] = Query(None),
    status: str = Query("pending"),
    limit: int = Query(10, ge=1, le=50),
):
    """Return items from the labeling queue.

    Args:
        session_id: Optional session filter, forwarded to the queue lookup.
        status: Item status to fetch (default ``"pending"``).
        limit: Maximum number of items (1-50, default 10).

    Returns:
        One ``ItemResponse`` per queue row, with ``image_url`` derived from
        the stored ``image_path`` via ``get_image_url``.
    """
    queue_rows = await get_ocr_labeling_queue(
        session_id=session_id,
        status=status,
        limit=limit,
    )

    responses = []
    for row in queue_rows:
        responses.append(
            ItemResponse(
                id=row['id'],
                session_id=row['session_id'],
                session_name=row.get('session_name', ''),
                image_path=row['image_path'],
                image_url=get_image_url(row['image_path']),
                ocr_text=row.get('ocr_text'),
                ocr_confidence=row.get('ocr_confidence'),
                ground_truth=row.get('ground_truth'),
                status=row.get('status', 'pending'),
                metadata=row.get('metadata'),
                created_at=row.get('created_at', datetime.utcnow()),
            )
        )
    return responses
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/items/{item_id}", response_model=ItemResponse)
async def get_item(item_id: str):
    """Return a single labeling item.

    Raises:
        HTTPException: 404 if no item with ``item_id`` is found.
    """
    record = await get_ocr_labeling_item(item_id)

    if record:
        return ItemResponse(
            id=record['id'],
            session_id=record['session_id'],
            session_name=record.get('session_name', ''),
            image_path=record['image_path'],
            # The public URL is derived from the stored path.
            image_url=get_image_url(record['image_path']),
            ocr_text=record.get('ocr_text'),
            ocr_confidence=record.get('ocr_confidence'),
            ground_truth=record.get('ground_truth'),
            status=record.get('status', 'pending'),
            metadata=record.get('metadata'),
            created_at=record.get('created_at', datetime.utcnow()),
        )

    raise HTTPException(status_code=404, detail="Item not found")
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Labeling Action Endpoints
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@router.post("/confirm")
async def confirm_item(request: ConfirmRequest):
    """Mark an item's OCR text as correct.

    The label is always attributed to the fixed user ``"admin"``.

    Raises:
        HTTPException: 400 if ``confirm_ocr_label`` reports failure.
    """
    confirmed = await confirm_ocr_label(
        item_id=request.item_id,
        labeled_by="admin",
        label_time_seconds=request.label_time_seconds,
    )

    if confirmed:
        return {"status": "confirmed", "item_id": request.item_id}

    raise HTTPException(status_code=400, detail="Failed to confirm item")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/correct")
async def correct_item(request: CorrectRequest):
    """Store a corrected ground-truth text for an item.

    The label is always attributed to the fixed user ``"admin"``.

    Raises:
        HTTPException: 400 if ``correct_ocr_label`` reports failure.
    """
    corrected = await correct_ocr_label(
        item_id=request.item_id,
        ground_truth=request.ground_truth,
        labeled_by="admin",
        label_time_seconds=request.label_time_seconds,
    )

    if corrected:
        return {"status": "corrected", "item_id": request.item_id}

    raise HTTPException(status_code=400, detail="Failed to correct item")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/skip")
async def skip_item(request: SkipRequest):
    """Skip an item (for example because the image is unusable).

    The action is always attributed to the fixed user ``"admin"``.

    Raises:
        HTTPException: 400 if ``skip_ocr_item`` reports failure.
    """
    skipped = await skip_ocr_item(
        item_id=request.item_id,
        labeled_by="admin",
    )

    if skipped:
        return {"status": "skipped", "item_id": request.item_id}

    raise HTTPException(status_code=400, detail="Failed to skip item")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/stats")
async def get_stats(session_id: Optional[str] = Query(None)):
    """Return labeling statistics, optionally restricted to one session.

    Raises:
        HTTPException: 500 if the statistics payload contains an ``"error"``
            key; its value is used as the error detail.
    """
    stats = await get_ocr_labeling_stats(session_id=session_id)

    if "error" not in stats:
        return stats

    raise HTTPException(status_code=500, detail=stats["error"])
|
||||||
@@ -0,0 +1,313 @@
|
|||||||
|
"""
|
||||||
|
OCR Labeling - Upload, Run-OCR, and Export Route Handlers
|
||||||
|
|
||||||
|
Extracted from ocr_labeling_routes.py to keep files under 500 LOC.
|
||||||
|
|
||||||
|
Endpoints:
|
||||||
|
- POST /sessions/{id}/upload - Upload images for labeling
|
||||||
|
- POST /run-ocr/{item_id} - Run OCR on existing item
|
||||||
|
- POST /export - Export training data
|
||||||
|
- GET /training-samples - List training samples
|
||||||
|
- GET /images/{path} - Serve images from local storage
|
||||||
|
- GET /exports - List exports
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
|
||||||
|
from typing import Optional, List
|
||||||
|
import uuid
|
||||||
|
import os
|
||||||
|
|
||||||
|
from metrics_db import (
|
||||||
|
get_ocr_labeling_session,
|
||||||
|
add_ocr_labeling_item,
|
||||||
|
get_ocr_labeling_item,
|
||||||
|
export_training_samples,
|
||||||
|
get_training_samples,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .models import (
|
||||||
|
ExportRequest,
|
||||||
|
LOCAL_STORAGE_PATH,
|
||||||
|
)
|
||||||
|
from .helpers import (
|
||||||
|
compute_image_hash, run_ocr_on_image,
|
||||||
|
save_image_locally,
|
||||||
|
MINIO_AVAILABLE, TRAINING_EXPORT_AVAILABLE,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Conditional imports
|
||||||
|
try:
|
||||||
|
from minio_storage import upload_ocr_image, get_ocr_image
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
from training_export_service import TrainingSample, get_training_export_service
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/upload")
async def upload_images(
    session_id: str,
    files: List[UploadFile] = File(...),
    run_ocr: bool = Form(True),
    metadata: Optional[str] = Form(None),
):
    """Upload images into a labeling session.

    Each file is hashed, stored (MinIO when available, local storage
    otherwise or on MinIO failure), optionally OCR'd, and registered as a
    labeling item. Only items that were registered successfully appear in
    the response.

    Args:
        session_id: Session to add images to.
        files: Image files to upload (PNG, JPG, PDF).
        run_ocr: Whether to run OCR immediately (default: True). PDFs are
            never OCR'd here.
        metadata: Optional JSON metadata (subject, year, etc.). Text that
            is not valid JSON is kept as ``{"raw": <text>}``.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    import json

    session = await get_ocr_labeling_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    parsed_meta = None
    if metadata:
        try:
            parsed_meta = json.loads(metadata)
        except json.JSONDecodeError:
            parsed_meta = {"raw": metadata}

    uploaded = []
    model_name = session.get('ocr_model', 'llama3.2-vision:11b')
    known_extensions = ['png', 'jpg', 'jpeg', 'pdf']

    for upload in files:
        raw = await upload.read()
        digest = compute_image_hash(raw)
        new_item_id = str(uuid.uuid4())

        # Derive the extension from the file name; anything unknown is
        # treated as PNG.
        if upload.filename:
            ext = upload.filename.split('.')[-1].lower()
        else:
            ext = 'png'
        if ext not in known_extensions:
            ext = 'png'

        if MINIO_AVAILABLE:
            try:
                stored_path = upload_ocr_image(session_id, new_item_id, raw, ext)
            except Exception as e:
                print(f"MinIO upload failed, using local storage: {e}")
                stored_path = save_image_locally(session_id, new_item_id, raw, ext)
        else:
            stored_path = save_image_locally(session_id, new_item_id, raw, ext)

        text, confidence = None, None
        if run_ocr and ext != 'pdf':
            text, confidence = await run_ocr_on_image(
                raw,
                upload.filename or f"{new_item_id}.{ext}",
                model=model_name
            )

        registered = await add_ocr_labeling_item(
            item_id=new_item_id,
            session_id=session_id,
            image_path=stored_path,
            image_hash=digest,
            ocr_text=text,
            ocr_confidence=confidence,
            # The model is only recorded when OCR actually produced text.
            ocr_model=model_name if text else None,
            metadata=parsed_meta,
        )
        if not registered:
            continue

        uploaded.append({
            "id": new_item_id,
            "filename": upload.filename,
            "image_path": stored_path,
            "image_hash": digest,
            "ocr_text": text,
            "ocr_confidence": confidence,
            "status": "pending",
        })

    return {
        "session_id": session_id,
        "uploaded_count": len(uploaded),
        "items": uploaded,
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/export")
async def export_data(request: ExportRequest):
    """Export labeled data for training.

    Samples are first selected via ``export_training_samples``. When the
    training export service is available they are additionally handed to
    it; a failure there is only printed and does not fail the request.

    Returns:
        A dict with the export format, batch id, sample count and samples,
        plus ``export_path`` / ``manifest_path`` when the export service
        produced a result. If nothing is exportable, a dict with an empty
        sample list and an explanatory message.
    """
    db_samples = await export_training_samples(
        export_format=request.export_format,
        session_id=request.session_id,
        batch_id=request.batch_id,
        exported_by="admin",
    )

    if not db_samples:
        return {
            "export_format": request.export_format,
            "batch_id": request.batch_id,
            "exported_count": 0,
            "samples": [],
            "message": "No labeled samples found to export",
        }

    export_result = None
    if TRAINING_EXPORT_AVAILABLE:
        try:
            service = get_training_export_service()
            converted = [
                TrainingSample(
                    id=row.get('id', row.get('item_id', '')),
                    image_path=row.get('image_path', ''),
                    ground_truth=row.get('ground_truth', ''),
                    ocr_text=row.get('ocr_text'),
                    ocr_confidence=row.get('ocr_confidence'),
                    metadata=row.get('metadata'),
                )
                for row in db_samples
            ]
            export_result = service.export(
                samples=converted,
                export_format=request.export_format,
                batch_id=request.batch_id,
            )
        except Exception as e:
            print(f"Training export failed: {e}")

    # Prefer the caller's batch id; otherwise use the one from the export
    # service, if it ran.
    if request.batch_id:
        effective_batch_id = request.batch_id
    elif export_result:
        effective_batch_id = export_result.batch_id
    else:
        effective_batch_id = None

    response = {
        "export_format": request.export_format,
        "batch_id": effective_batch_id,
        "exported_count": len(db_samples),
        "samples": db_samples,
    }

    if export_result:
        response["export_path"] = export_result.export_path
        response["manifest_path"] = export_result.manifest_path

    return response
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/training-samples")
async def list_training_samples(
    export_format: Optional[str] = Query(None),
    batch_id: Optional[str] = Query(None),
    limit: int = Query(100, ge=1, le=1000),
):
    """Return previously exported training samples.

    Args:
        export_format: Optional export-format filter.
        batch_id: Optional batch filter.
        limit: Maximum number of samples (1-1000, default 100).

    Returns:
        A dict with the number of samples and the samples themselves.
    """
    found = await get_training_samples(
        export_format=export_format,
        batch_id=batch_id,
        limit=limit,
    )
    return {"count": len(found), "samples": found}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/images/{path:path}")
async def get_image(path: str):
    """Serve an image from local storage.

    ``path`` comes straight from the URL, so it is resolved against
    ``LOCAL_STORAGE_PATH`` and rejected when the result lies outside that
    directory (``..`` segments, absolute paths, or symlinks pointing out of
    the storage root). Rejections use the same 404 as a missing file so the
    endpoint does not reveal which paths exist elsewhere on disk.

    Args:
        path: Path of the image relative to the local storage root.

    Returns:
        A ``FileResponse`` whose media type is chosen from the file
        extension (PNG, JPEG, PDF; anything else is served as
        ``application/octet-stream``).

    Raises:
        HTTPException: 404 if the path escapes the storage root or does not
            point to an existing regular file.
    """
    from fastapi.responses import FileResponse

    storage_root = os.path.realpath(LOCAL_STORAGE_PATH)
    filepath = os.path.realpath(os.path.join(storage_root, path))

    # Containment check: os.path.join(root, "") yields the root with exactly
    # one trailing separator, so "/data/ocr-evil" cannot pass for "/data/ocr".
    if not filepath.startswith(os.path.join(storage_root, "")):
        raise HTTPException(status_code=404, detail="Image not found")

    # isfile (not exists): a directory must not be handed to FileResponse.
    if not os.path.isfile(filepath):
        raise HTTPException(status_code=404, detail="Image not found")

    extension = filepath.split('.')[-1].lower()
    content_type = {
        'png': 'image/png',
        'jpg': 'image/jpeg',
        'jpeg': 'image/jpeg',
        'pdf': 'application/pdf',
    }.get(extension, 'application/octet-stream')

    return FileResponse(filepath, media_type=content_type)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/run-ocr/{item_id}")
async def run_ocr_for_item(item_id: str):
    """Run OCR on an already stored item and save the result.

    The image is read from disk when its path starts with
    ``LOCAL_STORAGE_PATH``, otherwise it is fetched from MinIO if that is
    available. The OCR model comes from the item's session, falling back to
    ``llama3.2-vision:11b``. The result is written to ``ocr_labeling_items``
    when a database pool is available.

    Raises:
        HTTPException: 404 if the item or its local image file is missing;
            500 if the image cannot be loaded or OCR returns no text.
    """
    item = await get_ocr_labeling_item(item_id)
    if not item:
        raise HTTPException(status_code=404, detail="Item not found")

    stored_path = item['image_path']
    is_local = stored_path.startswith(LOCAL_STORAGE_PATH)

    if not is_local and not MINIO_AVAILABLE:
        raise HTTPException(status_code=500, detail="Cannot load image")

    if is_local:
        if not os.path.exists(stored_path):
            raise HTTPException(status_code=404, detail="Image file not found")
        with open(stored_path, 'rb') as handle:
            raw = handle.read()
    else:
        try:
            raw = get_ocr_image(stored_path)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to load image: {e}")

    fallback_model = 'llama3.2-vision:11b'
    session = await get_ocr_labeling_session(item['session_id'])
    if session:
        model_name = session.get('ocr_model', fallback_model)
    else:
        model_name = fallback_model

    text, confidence = await run_ocr_on_image(
        raw,
        os.path.basename(stored_path),
        model=model_name
    )
    if text is None:
        raise HTTPException(status_code=500, detail="OCR failed")

    # Persist the fresh OCR result; without a pool the result is only
    # returned to the caller.
    from metrics_db import get_pool
    pool = await get_pool()
    if pool:
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE ocr_labeling_items
                SET ocr_text = $2, ocr_confidence = $3, ocr_model = $4
                WHERE id = $1
                """,
                item_id, text, confidence, model_name
            )

    return {
        "item_id": item_id,
        "ocr_text": text,
        "ocr_confidence": confidence,
        "ocr_model": model_name,
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/exports")
async def list_exports(export_format: Optional[str] = Query(None)):
    """List all available training data exports.

    Returns:
        A dict with the export count and the exports, or an empty list plus
        a message when the training export service is not available.

    Raises:
        HTTPException: 500 if the export service fails while listing.
    """
    if TRAINING_EXPORT_AVAILABLE:
        try:
            service = get_training_export_service()
            found = service.list_exports(export_format=export_format)
            return {"count": len(found), "exports": found}
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to list exports: {e}")

    return {
        "exports": [],
        "message": "Training export service not available",
    }
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline sub-package — API endpoints, session management, overlays,
|
||||||
|
geometry steps, word detection, regression testing, and related utilities.
|
||||||
|
|
||||||
|
Moved from backend/ flat modules (ocr_pipeline_*.py, page_crop*.py,
|
||||||
|
orientation_*.py, crop_api.py, etc.).
|
||||||
|
Backward-compatible shim files remain at the old locations.
|
||||||
|
"""
|
||||||
@@ -0,0 +1,63 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline API - Schrittweise Seitenrekonstruktion.
|
||||||
|
|
||||||
|
Thin wrapper that assembles all sub-module routers into a single
|
||||||
|
composite router. Backward-compatible: main.py and tests can still
|
||||||
|
import ``router``, ``_cache``, and helper functions from here.
|
||||||
|
|
||||||
|
Sub-modules (each < 1 000 lines):
|
||||||
|
    common      – shared state, cache, Pydantic models, helpers
    sessions    – session CRUD, image serving, doc-type
    geometry    – deskew, dewarp, structure, columns
    rows        – row detection, box-overlay helper
    words       – word detection (SSE), paddle-direct, word GT
    ocr_merge   – paddle/tesseract merge helpers, kombi endpoints
    postprocess – LLM review, reconstruction, export, validation
    auto        – auto-mode orchestrator, reprocess
    regression  – regression-testing router
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Shared state (imported by main.py and orientation_crop_api.py)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
from .common import ( # noqa: F401 – re-exported
|
||||||
|
_cache,
|
||||||
|
_BORDER_GHOST_CHARS,
|
||||||
|
_filter_border_ghost_words,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Sub-module routers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
from .sessions import router as _sessions_router
|
||||||
|
from .geometry import router as _geometry_router
|
||||||
|
from .rows import router as _rows_router
|
||||||
|
from .words import router as _words_router
|
||||||
|
from .ocr_merge import (
|
||||||
|
router as _ocr_merge_router,
|
||||||
|
# Re-export for test backward compatibility
|
||||||
|
_split_paddle_multi_words, # noqa: F401
|
||||||
|
_group_words_into_rows, # noqa: F401
|
||||||
|
_merge_row_sequences, # noqa: F401
|
||||||
|
_merge_paddle_tesseract, # noqa: F401
|
||||||
|
)
|
||||||
|
from .postprocess import router as _postprocess_router
|
||||||
|
from .auto import router as _auto_router
|
||||||
|
from .regression import router as _regression_router
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Composite router (used by main.py)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
router = APIRouter()

# Mount every sub-module router on the composite router, in a fixed order.
for _sub_router in (
    _sessions_router,
    _geometry_router,
    _rows_router,
    _words_router,
    _ocr_merge_router,
    _postprocess_router,
    _auto_router,
    _regression_router,
):
    router.include_router(_sub_router)
del _sub_router  # keep the module namespace free of the loop variable
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Auto-Mode Orchestrator and Reprocess Endpoints — Barrel Re-export.
|
||||||
|
|
||||||
|
Split into submodules:
|
||||||
|
  - reprocess.py  — POST /sessions/{id}/reprocess
  - auto_steps.py — POST /sessions/{id}/run-auto (VLM helper: auto_helpers.py)
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from .reprocess import router as _reprocess_router
|
||||||
|
from .auto_steps import router as _steps_router
|
||||||
|
|
||||||
|
# Merge the reprocess and run-auto sub-routers under the shared pipeline
# prefix so this module keeps exposing a single `router`, which the composite
# API module imports as `from .auto import router as _auto_router`.
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
for _part in (_reprocess_router, _steps_router):
    router.include_router(_part)
del _part  # keep the module namespace free of the loop variable

__all__ = ["router"]
|
||||||
@@ -0,0 +1,84 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Auto-Mode Helpers.
|
||||||
|
|
||||||
|
VLM shear detection, SSE event formatting, and request models.
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RunAutoRequest(BaseModel):
    """Request body for the auto-mode pipeline run (``run_auto`` endpoint)."""

    # First pipeline step to execute; the endpoint rejects values outside 1-6.
    from_step: int = 1  # 1=deskew, 2=dewarp, 3=columns, 4=rows, 5=words, 6=llm-review
    ocr_engine: str = "auto"  # "auto" | "rapid" | "tesseract"
    # NOTE(review): presumably selects the phonetic variant; the consumer is
    # not visible here — confirm.
    pronunciation: str = "british"
    # NOTE(review): presumably disables step 6 (LLM review) — confirm in run_auto.
    skip_llm_review: bool = False
    # The endpoint rejects any value other than the three listed.
    dewarp_method: str = "ensemble"  # "ensemble" | "vlm" | "cv"
|
||||||
|
|
||||||
|
|
||||||
|
async def auto_sse_event(step: str, status: str, data: Dict[str, Any]) -> str:
    """Serialise one progress update as a Server-Sent-Events ``data:`` frame.

    Args:
        step: Name of the pipeline step the event refers to.
        status: Status of that step.
        data: Extra fields merged into the payload; keys named ``step`` or
            ``status`` override the two positional values.

    Returns:
        ``"data: <json>"`` terminated by a blank line.
    """
    event: Dict[str, Any] = {"step": step, "status": status}
    event.update(data)
    encoded = json.dumps(event)
    return "data: " + encoded + "\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
async def detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]:
    """Estimate the vertical shear of a scanned page with a vision LLM.

    The image is sent base64-encoded to the Ollama ``/api/generate``
    endpoint (base URL from ``OLLAMA_BASE_URL``, model from
    ``OLLAMA_HTR_MODEL``, default ``qwen2.5vl:32b``) with a prompt asking for
    a small JSON object. The first ``{...}`` found in the answer is parsed.

    Args:
        image_bytes: Encoded image data to analyse.

    Returns:
        A dict with ``method``, ``shear_degrees`` (clamped to [-3, 3],
        3 decimals) and ``confidence`` (clamped to [0, 1], 2 decimals).
        Shear and confidence are both 0.0 when the request fails, the answer
        contains no JSON object, or parsing raises.
    """
    import httpx
    import base64

    base_url = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
    model_name = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")

    prompt = (
        "This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. "
        "Are they perfectly vertical, or do they tilt slightly? "
        "If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). "
        "Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} "
        "Use confidence 0.0-1.0 based on how clearly you can see the tilt. "
        "If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}"
    )

    request_body = {
        "model": model_name,
        "prompt": prompt,
        "images": [base64.b64encode(image_bytes).decode("utf-8")],
        "stream": False,
    }

    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            reply = await client.post(f"{base_url}/api/generate", json=request_body)
            reply.raise_for_status()
            answer = reply.json().get("response", "")

        # The model may wrap the JSON in prose, so take the first flat {...}.
        found = re.search(r'\{[^}]+\}', answer)
        if found:
            parsed = json.loads(found.group(0))
            shear = float(parsed.get("shear_degrees", 0.0))
            conf = float(parsed.get("confidence", 0.0))
            # Clamp to a plausible range before reporting.
            shear = max(-3.0, min(3.0, shear))
            conf = max(0.0, min(1.0, conf))
            return {
                "method": "vlm_qwen2.5vl",
                "shear_degrees": round(shear, 3),
                "confidence": round(conf, 2),
            }
    except Exception as e:
        logger.warning(f"VLM dewarp failed: {e}")

    # Fallback: no usable answer, report "no shear" with zero confidence.
    return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0}
|
||||||
@@ -0,0 +1,528 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Auto-Mode Orchestrator.
|
||||||
|
|
||||||
|
POST /sessions/{session_id}/run-auto -- full auto-mode with SSE streaming.
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import asdict
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
|
from cv_vocab_pipeline import (
|
||||||
|
OLLAMA_REVIEW_MODEL,
|
||||||
|
PageRegion,
|
||||||
|
RowGeometry,
|
||||||
|
_cells_to_vocab_entries,
|
||||||
|
_detect_header_footer_gaps,
|
||||||
|
_detect_sub_columns,
|
||||||
|
_fix_character_confusion,
|
||||||
|
_fix_phonetic_brackets,
|
||||||
|
fix_cell_phonetics,
|
||||||
|
analyze_layout,
|
||||||
|
build_cell_grid,
|
||||||
|
classify_column_types,
|
||||||
|
create_layout_image,
|
||||||
|
create_ocr_image,
|
||||||
|
deskew_image,
|
||||||
|
deskew_image_by_word_alignment,
|
||||||
|
detect_column_geometry,
|
||||||
|
detect_row_geometry,
|
||||||
|
_apply_shear,
|
||||||
|
dewarp_image,
|
||||||
|
llm_review_entries,
|
||||||
|
)
|
||||||
|
from .common import (
|
||||||
|
_cache,
|
||||||
|
_load_session_to_cache,
|
||||||
|
_get_cached,
|
||||||
|
)
|
||||||
|
from .session_store import (
|
||||||
|
get_session_db,
|
||||||
|
update_session_db,
|
||||||
|
)
|
||||||
|
from .auto_helpers import (
|
||||||
|
RunAutoRequest,
|
||||||
|
auto_sse_event as _auto_sse_event,
|
||||||
|
detect_shear_with_vlm as _detect_shear_with_vlm,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(tags=["ocr-pipeline"])
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/run-auto")
async def run_auto(session_id: str, req: RunAutoRequest, request: Request):
    """Run the OCR pipeline from a given step onward, streaming SSE progress.

    Steps:
      1. Deskew -- straighten the scan
      2. Dewarp -- correct vertical shear (ensemble CV or VLM)
      3. Columns -- detect column layout
      4. Rows -- detect row layout
      5. Words -- OCR each cell
      6. LLM review -- correct OCR errors (optional)

    Every step whose number is >= ``req.from_step`` is (re)run; steps with a
    lower number are reported as "skipped" and their earlier results are read
    from the in-memory cache / session record. There is no separate
    "already completed" check.

    Steps 1-5 are fatal on failure: a ``{"step": <name>, "status": "error"}``
    event is followed by ``{"step": "complete", "status": "error"}`` and the
    stream ends. A failure in step 6 is reported with ``"fatal": False`` and
    the stream still finishes with the normal final event.

    Yields SSE events of the form:
      data: {"step": "deskew", "status": "start"|"done"|"skipped"|"error", ...}

    Final event:
      data: {"step": "complete", "status": "done", "steps_run": [...], "steps_skipped": [...]}

    Args:
        session_id: Pipeline session to process.
        req: Options read here: ``from_step``, ``dewarp_method``,
            ``ocr_engine``, ``pronunciation`` and ``skip_llm_review``.
        request: Incoming request object; not referenced in this function body.

    Raises:
        HTTPException: 400 when ``from_step`` is outside 1-6 or
            ``dewarp_method`` is not one of ensemble / vlm / cv.
    """
    if req.from_step < 1 or req.from_step > 6:
        raise HTTPException(status_code=400, detail="from_step must be 1-6")
    if req.dewarp_method not in ("ensemble", "vlm", "cv"):
        raise HTTPException(status_code=400, detail="dewarp_method must be: ensemble, vlm, cv")

    # Make sure the session's images are in the in-memory cache before streaming starts.
    if session_id not in _cache:
        await _load_session_to_cache(session_id)

    async def _generate():
        """Async generator that runs the steps in order and yields one SSE event per transition."""
        steps_run: List[str] = []
        steps_skipped: List[str] = []
        error_step: Optional[str] = None

        session = await get_session_db(session_id)
        if not session:
            yield await _auto_sse_event("error", "error", {"message": f"Session {session_id} not found"})
            return

        cached = _get_cached(session_id)

        # Step 1: Deskew
        if req.from_step <= 1:
            yield await _auto_sse_event("deskew", "start", {})
            try:
                t0 = time.time()
                orig_bgr = cached.get("original_bgr")
                if orig_bgr is None:
                    raise ValueError("Original image not loaded")

                # Candidate A: deskew_image on a copy; any failure degrades to
                # "no rotation" instead of aborting the step.
                try:
                    deskewed_hough, angle_hough = deskew_image(orig_bgr.copy())
                except Exception:
                    deskewed_hough, angle_hough = orig_bgr, 0.0

                # Candidate B: word-alignment deskew, which takes PNG bytes rather
                # than an array; same degrade-to-identity handling on failure.
                success_enc, png_orig = cv2.imencode(".png", orig_bgr)
                orig_bytes = png_orig.tobytes() if success_enc else b""
                try:
                    deskewed_wa_bytes, angle_wa = deskew_image_by_word_alignment(orig_bytes)
                except Exception:
                    deskewed_wa_bytes, angle_wa = orig_bytes, 0.0

                # Use word alignment when its angle is at least as large as the
                # other candidate's, or when that candidate found (almost) nothing.
                if abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1:
                    method_used = "word_alignment"
                    angle_applied = angle_wa
                    wa_arr = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8)
                    deskewed_bgr = cv2.imdecode(wa_arr, cv2.IMREAD_COLOR)
                    if deskewed_bgr is None:
                        # Bytes could not be decoded: fall back to the other candidate.
                        deskewed_bgr = deskewed_hough
                        method_used = "hough"
                        angle_applied = angle_hough
                else:
                    method_used = "hough"
                    angle_applied = angle_hough
                    deskewed_bgr = deskewed_hough

                # NOTE(review): if encoding fails, empty bytes are persisted as deskewed_png.
                success, png_buf = cv2.imencode(".png", deskewed_bgr)
                deskewed_png = png_buf.tobytes() if success else b""

                deskew_result = {
                    "method_used": method_used,
                    "rotation_degrees": round(float(angle_applied), 3),
                    "duration_seconds": round(time.time() - t0, 2),
                }

                cached["deskewed_bgr"] = deskewed_bgr
                cached["deskew_result"] = deskew_result
                await update_session_db(
                    session_id,
                    deskewed_png=deskewed_png,
                    deskew_result=deskew_result,
                    auto_rotation_degrees=float(angle_applied),
                    current_step=3,
                )
                session = await get_session_db(session_id)

                steps_run.append("deskew")
                yield await _auto_sse_event("deskew", "done", deskew_result)
            except Exception as e:
                logger.error(f"Auto-mode deskew failed for {session_id}: {e}")
                error_step = "deskew"
                yield await _auto_sse_event("deskew", "error", {"message": str(e)})
                yield await _auto_sse_event("complete", "error", {"error_step": error_step})
                return
        else:
            steps_skipped.append("deskew")
            yield await _auto_sse_event("deskew", "skipped", {"reason": "from_step > 1"})

        # Step 2: Dewarp
        if req.from_step <= 2:
            yield await _auto_sse_event("dewarp", "start", {"method": req.dewarp_method})
            try:
                t0 = time.time()
                deskewed_bgr = cached.get("deskewed_bgr")
                if deskewed_bgr is None:
                    raise ValueError("Deskewed image not available")

                if req.dewarp_method == "vlm":
                    success_enc, png_buf = cv2.imencode(".png", deskewed_bgr)
                    img_bytes = png_buf.tobytes() if success_enc else b""
                    vlm_det = await _detect_shear_with_vlm(img_bytes)
                    shear_deg = vlm_det["shear_degrees"]
                    # Only correct when the detected shear is non-trivial and the
                    # detection is reasonably confident; otherwise keep the image as is.
                    if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3:
                        dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg)
                    else:
                        dewarped_bgr = deskewed_bgr
                    dewarp_info = {
                        "method": vlm_det["method"],
                        "shear_degrees": shear_deg,
                        "confidence": vlm_det["confidence"],
                        "detections": [vlm_det],
                    }
                else:
                    # "ensemble" and "cv" both take this path.
                    dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr)

                success_enc, png_buf = cv2.imencode(".png", dewarped_bgr)
                dewarped_png = png_buf.tobytes() if success_enc else b""

                dewarp_result = {
                    "method_used": dewarp_info["method"],
                    "shear_degrees": dewarp_info["shear_degrees"],
                    "confidence": dewarp_info["confidence"],
                    "duration_seconds": round(time.time() - t0, 2),
                    "detections": dewarp_info.get("detections", []),
                }

                cached["dewarped_bgr"] = dewarped_bgr
                cached["dewarp_result"] = dewarp_result
                await update_session_db(
                    session_id,
                    dewarped_png=dewarped_png,
                    dewarp_result=dewarp_result,
                    auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
                    current_step=4,
                )
                session = await get_session_db(session_id)

                steps_run.append("dewarp")
                yield await _auto_sse_event("dewarp", "done", dewarp_result)
            except Exception as e:
                logger.error(f"Auto-mode dewarp failed for {session_id}: {e}")
                error_step = "dewarp"
                yield await _auto_sse_event("dewarp", "error", {"message": str(e)})
                yield await _auto_sse_event("complete", "error", {"error_step": error_step})
                return
        else:
            steps_skipped.append("dewarp")
            yield await _auto_sse_event("dewarp", "skipped", {"reason": "from_step > 2"})

        # Step 3: Columns
        if req.from_step <= 3:
            yield await _auto_sse_event("columns", "start", {})
            try:
                t0 = time.time()
                # Prefer the cropped image when one is cached, else the dewarped one.
                col_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
                if col_img is None:
                    raise ValueError("Cropped/dewarped image not available")

                ocr_img = create_ocr_image(col_img)
                h, w = ocr_img.shape[:2]

                geo_result = detect_column_geometry(ocr_img, col_img)
                if geo_result is None:
                    # Geometry detection gave nothing: fall back to layout analysis
                    # and clear the intermediates so step 4 recomputes them.
                    layout_img = create_layout_image(col_img)
                    regions = analyze_layout(layout_img, ocr_img)
                    cached["_word_dicts"] = None
                    cached["_inv"] = None
                    cached["_content_bounds"] = None
                else:
                    geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
                    content_w = right_x - left_x
                    # Keep the intermediates so the rows/words steps can reuse them.
                    cached["_word_dicts"] = word_dicts
                    cached["_inv"] = inv
                    cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)

                    header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None)
                    geometries = _detect_sub_columns(geometries, content_w, left_x=left_x,
                                                     top_y=top_y, header_y=header_y, footer_y=footer_y)
                    regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y,
                                                    left_x=left_x, right_x=right_x, inv=inv)

                columns = [asdict(r) for r in regions]
                column_result = {
                    "columns": columns,
                    "classification_methods": list({c.get("classification_method", "") for c in columns if c.get("classification_method")}),
                    "duration_seconds": round(time.time() - t0, 2),
                }

                cached["column_result"] = column_result
                # New columns invalidate any previously stored row/word results.
                await update_session_db(session_id, column_result=column_result,
                                        row_result=None, word_result=None, current_step=6)
                session = await get_session_db(session_id)

                steps_run.append("columns")
                yield await _auto_sse_event("columns", "done", {
                    "column_count": len(columns),
                    "duration_seconds": column_result["duration_seconds"],
                })
            except Exception as e:
                logger.error(f"Auto-mode columns failed for {session_id}: {e}")
                error_step = "columns"
                yield await _auto_sse_event("columns", "error", {"message": str(e)})
                yield await _auto_sse_event("complete", "error", {"error_step": error_step})
                return
        else:
            steps_skipped.append("columns")
            yield await _auto_sse_event("columns", "skipped", {"reason": "from_step > 3"})

        # Step 4: Rows
        if req.from_step <= 4:
            yield await _auto_sse_event("rows", "start", {})
            try:
                t0 = time.time()
                row_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
                session = await get_session_db(session_id)
                column_result = session.get("column_result") or cached.get("column_result")
                if not column_result or not column_result.get("columns"):
                    raise ValueError("Column detection must complete first")

                # NOTE(review): col_regions is built here but not referenced again in this step.
                col_regions = [
                    PageRegion(
                        type=c["type"], x=c["x"], y=c["y"],
                        width=c["width"], height=c["height"],
                        classification_confidence=c.get("classification_confidence", 1.0),
                        classification_method=c.get("classification_method", ""),
                    )
                    for c in column_result["columns"]
                ]

                word_dicts = cached.get("_word_dicts")
                inv = cached.get("_inv")
                content_bounds = cached.get("_content_bounds")

                # Intermediates are missing when step 3 was skipped or used the
                # layout fallback; recompute them from the image in that case.
                if word_dicts is None or inv is None or content_bounds is None:
                    ocr_img_tmp = create_ocr_image(row_img)
                    geo_result = detect_column_geometry(ocr_img_tmp, row_img)
                    if geo_result is None:
                        raise ValueError("Column geometry detection failed -- cannot detect rows")
                    _g, lx, rx, ty, by, word_dicts, inv = geo_result
                    cached["_word_dicts"] = word_dicts
                    cached["_inv"] = inv
                    cached["_content_bounds"] = (lx, rx, ty, by)
                    content_bounds = (lx, rx, ty, by)

                left_x, right_x, top_y, bottom_y = content_bounds
                row_geoms = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)

                # Plain-dict form of each row for persistence and the SSE payload.
                row_list = [
                    {
                        "index": r.index, "x": r.x, "y": r.y,
                        "width": r.width, "height": r.height,
                        "word_count": r.word_count,
                        "row_type": r.row_type,
                        "gap_before": r.gap_before,
                    }
                    for r in row_geoms
                ]
                row_result = {
                    "rows": row_list,
                    "row_count": len(row_list),
                    "content_rows": len([r for r in row_geoms if r.row_type == "content"]),
                    "duration_seconds": round(time.time() - t0, 2),
                }

                cached["row_result"] = row_result
                await update_session_db(session_id, row_result=row_result, current_step=7)
                session = await get_session_db(session_id)

                steps_run.append("rows")
                yield await _auto_sse_event("rows", "done", {
                    "row_count": len(row_list),
                    "content_rows": row_result["content_rows"],
                    "duration_seconds": row_result["duration_seconds"],
                })
            except Exception as e:
                logger.error(f"Auto-mode rows failed for {session_id}: {e}")
                error_step = "rows"
                yield await _auto_sse_event("rows", "error", {"message": str(e)})
                yield await _auto_sse_event("complete", "error", {"error_step": error_step})
                return
        else:
            steps_skipped.append("rows")
            yield await _auto_sse_event("rows", "skipped", {"reason": "from_step > 4"})

        # Step 5: Words (OCR)
        if req.from_step <= 5:
            yield await _auto_sse_event("words", "start", {"engine": req.ocr_engine})
            try:
                t0 = time.time()
                word_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
                session = await get_session_db(session_id)

                # NOTE(review): neither result is checked for None here; a missing
                # one surfaces as a TypeError caught by the handler below.
                column_result = session.get("column_result") or cached.get("column_result")
                row_result = session.get("row_result") or cached.get("row_result")

                # Rebuild the typed column/row objects from their stored dict form.
                col_regions = [
                    PageRegion(
                        type=c["type"], x=c["x"], y=c["y"],
                        width=c["width"], height=c["height"],
                        classification_confidence=c.get("classification_confidence", 1.0),
                        classification_method=c.get("classification_method", ""),
                    )
                    for c in column_result["columns"]
                ]
                row_geoms = [
                    RowGeometry(
                        index=r["index"], x=r["x"], y=r["y"],
                        width=r["width"], height=r["height"],
                        word_count=r.get("word_count", 0), words=[],
                        row_type=r.get("row_type", "content"),
                        gap_before=r.get("gap_before", 0),
                    )
                    for r in row_result["rows"]
                ]

                word_dicts = cached.get("_word_dicts")
                if word_dicts is not None:
                    content_bounds = cached.get("_content_bounds")
                    top_y = content_bounds[2] if content_bounds else min(r.y for r in row_geoms)
                    # Attach each cached word to the row that contains its vertical
                    # centre. Row y is shifted by top_y before comparing with the
                    # word's 'top' -- presumably word coordinates are relative to
                    # the content area; TODO confirm.
                    for row in row_geoms:
                        row_y_rel = row.y - top_y
                        row_bottom_rel = row_y_rel + row.height
                        row.words = [
                            w for w in word_dicts
                            if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
                        ]
                        row.word_count = len(row.words)

                ocr_img = create_ocr_image(word_img)
                img_h, img_w = word_img.shape[:2]

                cells, columns_meta = build_cell_grid(
                    ocr_img, col_regions, row_geoms, img_w, img_h,
                    ocr_engine=req.ocr_engine, img_bgr=word_img,
                )
                duration = time.time() - t0

                col_types = {c['type'] for c in columns_meta}
                is_vocab = bool(col_types & {'column_en', 'column_de'})
                n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
                # Report the engine recorded on the first cell, if there is one.
                used_engine = cells[0].get("ocr_engine", "tesseract") if cells else req.ocr_engine

                fix_cell_phonetics(cells, pronunciation=req.pronunciation)

                word_result_data = {
                    "cells": cells,
                    "grid_shape": {
                        "rows": n_content_rows,
                        "cols": len(columns_meta),
                        "total_cells": len(cells),
                    },
                    "columns_used": columns_meta,
                    "layout": "vocab" if is_vocab else "generic",
                    "image_width": img_w,
                    "image_height": img_h,
                    "duration_seconds": round(duration, 2),
                    "ocr_engine": used_engine,
                    "summary": {
                        "total_cells": len(cells),
                        "non_empty_cells": sum(1 for c in cells if c.get("text")),
                        "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
                    },
                }

                # Vocabulary entries are derived for vocab layouts and for
                # layouts containing a plain text column.
                has_text_col = 'column_text' in col_types
                if is_vocab or has_text_col:
                    entries = _cells_to_vocab_entries(cells, columns_meta)
                    entries = _fix_character_confusion(entries)
                    entries = _fix_phonetic_brackets(entries, pronunciation=req.pronunciation)
                    # The same list is stored under both keys.
                    word_result_data["vocab_entries"] = entries
                    word_result_data["entries"] = entries
                    word_result_data["entry_count"] = len(entries)
                    word_result_data["summary"]["total_entries"] = len(entries)

                await update_session_db(session_id, word_result=word_result_data, current_step=8)
                cached["word_result"] = word_result_data
                session = await get_session_db(session_id)

                steps_run.append("words")
                yield await _auto_sse_event("words", "done", {
                    "total_cells": len(cells),
                    "layout": word_result_data["layout"],
                    "duration_seconds": round(duration, 2),
                    "ocr_engine": used_engine,
                    "summary": word_result_data["summary"],
                })
            except Exception as e:
                logger.error(f"Auto-mode words failed for {session_id}: {e}")
                error_step = "words"
                yield await _auto_sse_event("words", "error", {"message": str(e)})
                yield await _auto_sse_event("complete", "error", {"error_step": error_step})
                return
        else:
            steps_skipped.append("words")
            yield await _auto_sse_event("words", "skipped", {"reason": "from_step > 5"})

        # Step 6: LLM Review (optional)
        if req.from_step <= 6 and not req.skip_llm_review:
            yield await _auto_sse_event("llm_review", "start", {"model": OLLAMA_REVIEW_MODEL})
            try:
                session = await get_session_db(session_id)
                word_result = session.get("word_result") or cached.get("word_result")
                entries = word_result.get("entries") or word_result.get("vocab_entries") or []

                if not entries:
                    yield await _auto_sse_event("llm_review", "skipped", {"reason": "no entries"})
                    steps_skipped.append("llm_review")
                else:
                    reviewed = await llm_review_entries(entries)

                    # Re-read the session so the reviewed entries are merged into
                    # the word_result currently stored in the DB.
                    session = await get_session_db(session_id)
                    word_result_updated = dict(session.get("word_result") or {})
                    word_result_updated["entries"] = reviewed
                    word_result_updated["vocab_entries"] = reviewed
                    word_result_updated["llm_reviewed"] = True
                    word_result_updated["llm_model"] = OLLAMA_REVIEW_MODEL

                    await update_session_db(session_id, word_result=word_result_updated, current_step=9)
                    cached["word_result"] = word_result_updated

                    steps_run.append("llm_review")
                    yield await _auto_sse_event("llm_review", "done", {
                        "entries_reviewed": len(reviewed),
                        "model": OLLAMA_REVIEW_MODEL,
                    })
            except Exception as e:
                # Review is best-effort: report the failure but still finish the run normally.
                logger.warning(f"Auto-mode llm_review failed for {session_id} (non-fatal): {e}")
                yield await _auto_sse_event("llm_review", "error", {"message": str(e), "fatal": False})
                steps_skipped.append("llm_review")
        else:
            steps_skipped.append("llm_review")
            reason = "skipped by request" if req.skip_llm_review else "from_step > 6"
            yield await _auto_sse_event("llm_review", "skipped", {"reason": reason})

        # Final event
        yield await _auto_sse_event("complete", "done", {
            "steps_run": steps_run,
            "steps_skipped": steps_skipped,
        })

    return StreamingResponse(
        _generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Response header commonly used to turn off proxy buffering (nginx) for streamed events.
            "X-Accel-Buffering": "no",
        },
    )
|
||||||
@@ -0,0 +1,293 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Column Detection Endpoints (Step 5)
|
||||||
|
|
||||||
|
Detect invisible columns, manual column override, and ground truth.
|
||||||
|
Extracted from ocr_pipeline_geometry.py for file-size compliance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import asdict
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
|
||||||
|
from cv_vocab_pipeline import (
|
||||||
|
_detect_header_footer_gaps,
|
||||||
|
_detect_sub_columns,
|
||||||
|
classify_column_types,
|
||||||
|
create_layout_image,
|
||||||
|
create_ocr_image,
|
||||||
|
analyze_layout,
|
||||||
|
detect_column_geometry_zoned,
|
||||||
|
expand_narrow_columns,
|
||||||
|
)
|
||||||
|
from .session_store import (
|
||||||
|
get_session_db,
|
||||||
|
update_session_db,
|
||||||
|
)
|
||||||
|
from .common import (
|
||||||
|
_cache,
|
||||||
|
_load_session_to_cache,
|
||||||
|
_get_cached,
|
||||||
|
_append_pipeline_log,
|
||||||
|
ManualColumnsRequest,
|
||||||
|
ColumnGroundTruthRequest,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_tesseract_conf(raw) -> int:
    """Convert one pytesseract ``conf`` value to an int, or -1 when unparseable.

    Accepts ints, integer strings and float-formatted values such as
    ``"96.5"`` (truncated towards zero), so that words are not discarded
    merely because the confidence arrived as a float.
    """
    try:
        return int(float(raw))
    except (TypeError, ValueError, OverflowError):
        return -1


def _tesseract_word_boxes(img_bgr, session_id: str) -> List[Dict]:
    """Run Tesseract on a sub-session crop and return the word boxes with conf >= 30.

    Returns an empty list when Tesseract (or one of its imports) fails; the
    failure is logged as a warning and never propagated.
    """
    try:
        # Imported lazily so the module still loads without these packages.
        from PIL import Image as PILImage
        import pytesseract

        pil_img = PILImage.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
        data = pytesseract.image_to_data(pil_img, lang='eng+deu', output_type=pytesseract.Output.DICT)

        word_dicts: List[Dict] = []
        low_conf = []   # (text, conf) pairs with 0 <= conf < 30, kept for debugging only
        all_count = 0   # number of entries with non-empty text
        for i, raw_text in enumerate(data['text']):
            text = str(raw_text).strip()
            if not text:
                continue
            all_count += 1
            conf = _parse_tesseract_conf(data['conf'][i])
            if conf < 30:
                if conf >= 0:
                    low_conf.append((text, conf))
                continue
            word_dicts.append({
                'text': text, 'conf': conf,
                'left': int(data['left'][i]),
                'top': int(data['top'][i]),
                'width': int(data['width'][i]),
                'height': int(data['height'][i]),
            })

        if low_conf:
            logger.info(f"OCR Pipeline: sub-session {session_id}: {len(low_conf)} words below conf 30: {low_conf[:20]}")
        logger.info(f"OCR Pipeline: sub-session {session_id}: Tesseract found {len(word_dicts)}/{all_count} words (conf>=30)")
        return word_dicts
    except Exception as e:
        logger.warning(f"OCR Pipeline: sub-session {session_id}: Tesseract failed: {e}")
        return []


@router.post("/sessions/{session_id}/columns")
async def detect_columns(session_id: str):
    """Run column detection on the cropped (or dewarped) image.

    Sub-sessions (sessions with a ``parent_session_id``) skip detection and
    get a single pseudo-column spanning the whole image. All other sessions
    run zone-aware geometry detection, falling back to projection-based
    layout analysis when that yields nothing.

    The result is persisted as ``column_result``; stored row and word results
    are invalidated in both the DB and the in-memory cache.

    Returns:
        ``{"session_id": ..., **column_result}``.

    Raises:
        HTTPException: 400 when neither a cropped nor a dewarped image is cached.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    img_bgr = cached.get("cropped_bgr")
    if img_bgr is None:
        img_bgr = cached.get("dewarped_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before column detection")

    # -----------------------------------------------------------------------
    # Sub-sessions (box crops): skip column detection entirely.
    # Instead, create a single pseudo-column spanning the full image width.
    # Also run Tesseract + binarization here so that the row detection step
    # can reuse the cached intermediates (_word_dicts, _inv, _content_bounds)
    # instead of falling back to detect_column_geometry() which may fail
    # on small box images with < 5 words.
    # -----------------------------------------------------------------------
    session = await get_session_db(session_id)
    if session and session.get("parent_session_id"):
        h, w = img_bgr.shape[:2]

        # Binarize + invert for row detection (horizontal projection profile)
        ocr_img = create_ocr_image(img_bgr)
        inv = cv2.bitwise_not(ocr_img)

        word_dicts = _tesseract_word_boxes(img_bgr, session_id)

        # Cache intermediates for row detection (detect_rows reuses these)
        cached["_word_dicts"] = word_dicts
        cached["_inv"] = inv
        cached["_content_bounds"] = (0, w, 0, h)

        column_result = {
            "columns": [{
                "type": "column_text",
                "x": 0, "y": 0,
                "width": w, "height": h,
            }],
            "zones": None,
            "boxes_detected": 0,
            "duration_seconds": 0,
            "method": "sub_session_pseudo_column",
        }
        await update_session_db(
            session_id,
            column_result=column_result,
            row_result=None,
            word_result=None,
            current_step=6,
        )
        cached["column_result"] = column_result
        cached.pop("row_result", None)
        cached.pop("word_result", None)
        logger.info(f"OCR Pipeline: sub-session {session_id}: pseudo-column {w}x{h}px")
        return {"session_id": session_id, **column_result}

    t0 = time.time()

    # Binarized image for layout analysis
    ocr_img = create_ocr_image(img_bgr)
    h, w = ocr_img.shape[:2]

    # Phase A: Zone-aware geometry detection
    zoned_result = detect_column_geometry_zoned(ocr_img, img_bgr)

    boxes_detected = 0
    if zoned_result is None:
        # No geometry intermediates exist for this run. Clear any left over
        # from an earlier run so row detection cannot reuse stale data.
        zones_data = None
        cached["_word_dicts"] = None
        cached["_inv"] = None
        cached["_content_bounds"] = None
        cached["_zones_data"] = zones_data
        cached["_boxes_detected"] = boxes_detected

        # Fallback to projection-based layout
        layout_img = create_layout_image(img_bgr)
        regions = analyze_layout(layout_img, ocr_img)
    else:
        geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv, zones_data, boxes = zoned_result
        content_w = right_x - left_x
        boxes_detected = len(boxes)

        # Cache intermediates for row detection (avoids second Tesseract run)
        cached["_word_dicts"] = word_dicts
        cached["_inv"] = inv
        cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
        cached["_zones_data"] = zones_data
        cached["_boxes_detected"] = boxes_detected

        # Detect header/footer early so sub-column clustering ignores them
        header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None)

        # Split sub-columns (e.g. page references) before classification
        geometries = _detect_sub_columns(geometries, content_w, left_x=left_x,
                                         top_y=top_y, header_y=header_y, footer_y=footer_y)

        # Expand narrow columns (sub-columns are often very narrow)
        geometries = expand_narrow_columns(geometries, content_w, left_x, word_dicts)

        # Phase B: Content-based classification
        regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y,
                                        left_x=left_x, right_x=right_x, inv=inv)

    duration = time.time() - t0

    columns = [asdict(r) for r in regions]

    # Distinct classification methods, sorted so the response is deterministic.
    methods = sorted({
        c.get("classification_method", "") for c in columns
        if c.get("classification_method")
    })

    column_result = {
        "columns": columns,
        "classification_methods": methods,
        "duration_seconds": round(duration, 2),
        "boxes_detected": boxes_detected,
    }

    # Add zone data when boxes are present
    if zones_data and boxes_detected > 0:
        column_result["zones"] = zones_data

    # Persist to DB -- also invalidate downstream results (rows, words)
    await update_session_db(
        session_id,
        column_result=column_result,
        row_result=None,
        word_result=None,
        current_step=6,
    )

    # Update cache
    cached["column_result"] = column_result
    cached.pop("row_result", None)
    cached.pop("word_result", None)

    col_count = sum(1 for c in columns if c["type"].startswith("column"))
    logger.info(f"OCR Pipeline: columns session {session_id}: "
                f"{col_count} columns detected, {boxes_detected} box(es) ({duration:.2f}s)")

    img_w = img_bgr.shape[1]
    await _append_pipeline_log(session_id, "columns", {
        "total_columns": len(columns),
        "column_widths_pct": [round(c["width"] / img_w * 100, 1) for c in columns],
        "column_types": [c["type"] for c in columns],
        "boxes_detected": boxes_detected,
    }, duration_ms=int(duration * 1000))

    return {
        "session_id": session_id,
        **column_result,
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/columns/manual")
async def set_manual_columns(session_id: str, req: ManualColumnsRequest):
    """Override detected columns with manual definitions.

    Stores ``req.columns`` as the session's ``column_result`` (method
    ``"manual"``) and invalidates the stored row and word results, in the DB
    and in the in-memory cache when the session is cached.

    Returns:
        ``{"session_id": ..., **column_result}``.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    # Reject unknown sessions up front, as the ground-truth endpoints do,
    # instead of reporting success for a session that was never stored.
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    column_result = {
        "columns": req.columns,
        "duration_seconds": 0,
        "method": "manual",
    }

    # New columns make previously computed rows and words obsolete.
    await update_session_db(session_id, column_result=column_result,
                            row_result=None, word_result=None)

    if session_id in _cache:
        _cache[session_id]["column_result"] = column_result
        _cache[session_id].pop("row_result", None)
        _cache[session_id].pop("word_result", None)

    logger.info(f"OCR Pipeline: manual columns session {session_id}: "
                f"{len(req.columns)} columns set")

    return {"session_id": session_id, **column_result}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/ground-truth/columns")
async def save_column_ground_truth(session_id: str, req: ColumnGroundTruthRequest):
    """Store reviewer feedback on column detection under ``ground_truth["columns"]``.

    The saved record holds the reviewer's verdict, corrected columns and
    notes, a UTC timestamp, and a snapshot of the session's current
    ``column_result``.

    Returns:
        ``{"session_id": ..., "ground_truth": <the saved record>}``.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    record = await get_session_db(session_id)
    if not record:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # Reuse the stored ground-truth dict so feedback for other steps is kept.
    all_feedback = record.get("ground_truth") or {}
    column_feedback = {
        "is_correct": req.is_correct,
        "corrected_columns": req.corrected_columns,
        "notes": req.notes,
        "saved_at": datetime.utcnow().isoformat(),
        "column_result": record.get("column_result"),
    }
    all_feedback["columns"] = column_feedback

    await update_session_db(session_id, ground_truth=all_feedback)

    # Keep the in-memory copy in step with the DB when the session is cached.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = all_feedback

    return {"session_id": session_id, "ground_truth": column_feedback}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}/ground-truth/columns")
async def get_column_ground_truth(session_id: str):
    """Return the saved column ground truth next to the auto-detected columns.

    Returns:
        ``{"session_id", "columns_gt", "columns_auto"}``, where
        ``columns_auto`` is the session's current ``column_result``.

    Raises:
        HTTPException: 404 when the session does not exist or has no column
            ground truth saved.
    """
    record = await get_session_db(session_id)
    if not record:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    saved_feedback = (record.get("ground_truth") or {}).get("columns")
    if not saved_feedback:
        raise HTTPException(status_code=404, detail="No column ground truth saved")

    response = {
        "session_id": session_id,
        "columns_gt": saved_feedback,
        "columns_auto": record.get("column_result"),
    }
    return response
|
||||||
@@ -0,0 +1,354 @@
|
|||||||
|
"""
|
||||||
|
Shared common module for the OCR pipeline.
|
||||||
|
|
||||||
|
Contains in-memory cache, helper functions, Pydantic request models,
|
||||||
|
pipeline logging, and border-ghost word filtering used by the pipeline
|
||||||
|
API endpoints and related modules.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from .session_store import get_session_db, get_session_image, update_session_db
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Cache
|
||||||
|
"_cache",
|
||||||
|
# Helper functions
|
||||||
|
"_get_base_image_png",
|
||||||
|
"_load_session_to_cache",
|
||||||
|
"_get_cached",
|
||||||
|
# Pydantic models
|
||||||
|
"ManualDeskewRequest",
|
||||||
|
"DeskewGroundTruthRequest",
|
||||||
|
"ManualDewarpRequest",
|
||||||
|
"CombinedAdjustRequest",
|
||||||
|
"DewarpGroundTruthRequest",
|
||||||
|
"VALID_DOCUMENT_CATEGORIES",
|
||||||
|
"UpdateSessionRequest",
|
||||||
|
"ManualColumnsRequest",
|
||||||
|
"ColumnGroundTruthRequest",
|
||||||
|
"ManualRowsRequest",
|
||||||
|
"RowGroundTruthRequest",
|
||||||
|
"RemoveHandwritingRequest",
|
||||||
|
# Pipeline log
|
||||||
|
"_append_pipeline_log",
|
||||||
|
# Border-ghost filter
|
||||||
|
"_BORDER_GHOST_CHARS",
|
||||||
|
"_filter_border_ghost_words",
|
||||||
|
]
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# In-memory cache for active sessions (BGR numpy arrays for processing)
|
||||||
|
# DB is source of truth, cache holds BGR arrays during active processing.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_cache: Dict[str, Dict[str, Any]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_base_image_png(session_id: str) -> Optional[bytes]:
    """Return the PNG bytes of the most processed image stored for a session.

    Image types are tried in the order cropped, dewarped, original; the first
    one with data wins. Returns ``None`` when none of them is available.
    """
    preference = ("cropped", "dewarped", "original")
    for image_type in preference:
        if png_bytes := await get_session_image(session_id, image_type):
            return png_bytes
    return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _load_session_to_cache(session_id: str) -> Dict[str, Any]:
    """Return the cache entry for a session, building it from the DB if needed.

    A new entry contains the session row plus one ``<stage>_bgr`` slot per
    stored image stage; each slot holds the PNG decoded by OpenCV, or
    ``None`` when no image is stored for that stage.

    Raises:
        HTTPException: 404 when the session does not exist in the DB. This
            check runs even when the session is already cached.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    try:
        return _cache[session_id]
    except KeyError:
        pass

    stages = ("original", "oriented", "cropped", "deskewed", "dewarped")

    entry: Dict[str, Any] = {"id": session_id, **session}
    for stage in stages:
        entry[f"{stage}_bgr"] = None

    # Decode every stored PNG into an array for in-memory processing.
    for stage in stages:
        png_bytes = await get_session_image(session_id, stage)
        if not png_bytes:
            continue
        raw = np.frombuffer(png_bytes, dtype=np.uint8)
        entry[f"{stage}_bgr"] = cv2.imdecode(raw, cv2.IMREAD_COLOR)

    # Sub-sessions: the original image IS the cropped box region, so expose
    # it as the cropped image unless a later stage already exists.
    is_sub_session = bool(session.get("parent_session_id"))
    nothing_later = entry["cropped_bgr"] is None and entry["dewarped_bgr"] is None
    if is_sub_session and entry["original_bgr"] is not None and nothing_later:
        entry["cropped_bgr"] = entry["original_bgr"]

    _cache[session_id] = entry
    return entry
|
||||||
|
|
||||||
|
|
||||||
|
def _get_cached(session_id: str) -> Dict[str, Any]:
    """Return the in-memory entry for a session.

    Raises:
        HTTPException: 404 when the session has no (non-empty) cache entry.
    """
    cached_entry = _cache.get(session_id)
    if cached_entry:
        return cached_entry
    raise HTTPException(status_code=404, detail=f"Session {session_id} not in cache — reload first")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pydantic Models
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class ManualDeskewRequest(BaseModel):
    """Request body for a user-chosen deskew rotation."""

    # Rotation angle in degrees; the manual deskew endpoint clamps it to [-5, 5].
    angle: float
|
||||||
|
|
||||||
|
|
||||||
|
class DeskewGroundTruthRequest(BaseModel):
    """Reviewer feedback on the automatic deskew result."""

    # Whether the automatic result was judged correct.
    is_correct: bool
    # Optional corrected angle supplied by the reviewer.
    corrected_angle: Optional[float] = None
    # Optional free-text remarks.
    notes: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ManualDewarpRequest(BaseModel):
    """Request body for a user-chosen dewarp shear."""

    # Shear to apply, in degrees (per the field name).
    shear_degrees: float
|
||||||
|
|
||||||
|
|
||||||
|
class CombinedAdjustRequest(BaseModel):
    """Request body combining a rotation and a shear; both default to zero."""

    # Rotation component, in degrees (per the field name).
    rotation_degrees: float = 0.0
    # Shear component, in degrees (per the field name).
    shear_degrees: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class DewarpGroundTruthRequest(BaseModel):
    """Reviewer feedback on the automatic dewarp result."""

    # Whether the automatic result was judged correct.
    is_correct: bool
    # Optional corrected shear supplied by the reviewer.
    corrected_shear: Optional[float] = None
    # Optional free-text remarks.
    notes: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
# Known document category identifiers (German slugs).
# NOTE(review): presumably used to validate UpdateSessionRequest.document_category
# in the session endpoints — confirm against the caller.
VALID_DOCUMENT_CATEGORIES = {
    'vokabelseite',
    'woerterbuch',
    'buchseite',
    'arbeitsblatt',
    'klausurseite',
    'mathearbeit',
    'statistik',
    'zeitung',
    'formular',
    'handschrift',
    'sonstiges',
}
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateSessionRequest(BaseModel):
    """Partial update of session metadata; both fields default to ``None``."""

    # New display name for the session.
    name: Optional[str] = None
    # New document category.
    # NOTE(review): presumably one of VALID_DOCUMENT_CATEGORIES — not enforced here.
    document_category: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ManualColumnsRequest(BaseModel):
    """Request body carrying user-defined column definitions."""

    # One free-form mapping per column; the schema is not constrained here.
    columns: List[Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
class ColumnGroundTruthRequest(BaseModel):
    """Reviewer feedback on the automatic column detection result."""

    # Whether the automatic result was judged correct.
    is_correct: bool
    # Optional corrected column definitions (free-form mappings).
    corrected_columns: Optional[List[Dict[str, Any]]] = None
    # Optional free-text remarks.
    notes: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ManualRowsRequest(BaseModel):
    """Request body carrying user-defined row definitions."""

    # One free-form mapping per row; the schema is not constrained here.
    rows: List[Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
class RowGroundTruthRequest(BaseModel):
    """Reviewer feedback on the automatic row detection result."""

    # Whether the automatic result was judged correct.
    is_correct: bool
    # Optional corrected row definitions (free-form mappings).
    corrected_rows: Optional[List[Dict[str, Any]]] = None
    # Optional free-text remarks.
    notes: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class RemoveHandwritingRequest(BaseModel):
    """Options for the handwriting removal step.

    The allowed values listed next to each field are documentation only;
    this model does not validate them.
    """

    # "auto" | "telea" | "ns"
    method: str = "auto"
    # "all" | "colored" | "pencil"
    target_ink: str = "all"
    # mask dilation iterations (0-5)
    dilation: int = 2
    # "original" | "deskewed" | "auto"
    use_source: str = "auto"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pipeline Log Helper
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _append_pipeline_log(
    session_id: str,
    step_name: str,
    metrics: Dict[str, Any],
    success: bool = True,
    duration_ms: Optional[int] = None,
):
    """Add one step record to the session's ``pipeline_log``.

    Does nothing when the session cannot be loaded. A missing, empty or
    non-dict log is replaced by a fresh ``{"steps": []}`` structure before
    the record is appended.

    Args:
        session_id: Session whose log is extended.
        step_name: Pipeline step identifier stored as ``"step"``.
        metrics: Step-specific values stored as ``"metrics"``.
        success: Stored as ``"success"``.
        duration_ms: Stored as ``"duration_ms"`` only when not ``None``.
    """
    session = await get_session_db(session_id)
    if not session:
        return

    stored_log = session.get("pipeline_log")
    usable = bool(stored_log) and isinstance(stored_log, dict)
    pipeline_log = stored_log if usable else {"steps": []}

    record = {
        "step": step_name,
        "completed_at": datetime.utcnow().isoformat(),
        "success": success,
        "metrics": metrics,
    }
    if duration_ms is not None:
        record["duration_ms"] = duration_ms

    pipeline_log.setdefault("steps", []).append(record)
    await update_session_db(session_id, pipeline_log=pipeline_log)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Border-ghost word filter
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Characters that OCR produces when reading box-border lines.
|
||||||
|
_BORDER_GHOST_CHARS = set("|1lI![](){}iíì/\\-—–_~.,;:'\"")
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_border_ghost_words(
|
||||||
|
word_result: Dict,
|
||||||
|
boxes: List,
|
||||||
|
) -> int:
|
||||||
|
"""Remove OCR words that are actually box border lines.
|
||||||
|
|
||||||
|
A word is considered a border ghost when it sits on a known box edge
|
||||||
|
(left, right, top, or bottom) and looks like a line artefact (narrow
|
||||||
|
aspect ratio or text consists only of line-like characters).
|
||||||
|
|
||||||
|
After removing ghost cells, columns that have become empty are also
|
||||||
|
removed from ``columns_used`` so the grid no longer shows phantom
|
||||||
|
columns.
|
||||||
|
|
||||||
|
Modifies *word_result* in-place and returns the number of removed cells.
|
||||||
|
"""
|
||||||
|
if not boxes or not word_result:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
cells = word_result.get("cells")
|
||||||
|
if not cells:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Build border bands — vertical (X) and horizontal (Y)
|
||||||
|
x_bands = [] # list of (x_lo, x_hi)
|
||||||
|
y_bands = [] # list of (y_lo, y_hi)
|
||||||
|
for b in boxes:
|
||||||
|
bx = b.x if hasattr(b, "x") else b.get("x", 0)
|
||||||
|
by = b.y if hasattr(b, "y") else b.get("y", 0)
|
||||||
|
bw = b.width if hasattr(b, "width") else b.get("w", b.get("width", 0))
|
||||||
|
bh = b.height if hasattr(b, "height") else b.get("h", b.get("height", 0))
|
||||||
|
bt = b.border_thickness if hasattr(b, "border_thickness") else b.get("border_thickness", 3)
|
||||||
|
margin = max(bt * 2, 10) + 6 # generous margin
|
||||||
|
|
||||||
|
# Vertical edges (left / right)
|
||||||
|
x_bands.append((bx - margin, bx + margin))
|
||||||
|
x_bands.append((bx + bw - margin, bx + bw + margin))
|
||||||
|
# Horizontal edges (top / bottom)
|
||||||
|
y_bands.append((by - margin, by + margin))
|
||||||
|
y_bands.append((by + bh - margin, by + bh + margin))
|
||||||
|
|
||||||
|
img_w = word_result.get("image_width", 1)
|
||||||
|
img_h = word_result.get("image_height", 1)
|
||||||
|
|
||||||
|
def _is_ghost(cell: Dict) -> bool:
|
||||||
|
text = (cell.get("text") or "").strip()
|
||||||
|
if not text:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Compute absolute pixel position
|
||||||
|
if cell.get("bbox_px"):
|
||||||
|
px = cell["bbox_px"]
|
||||||
|
cx = px["x"] + px["w"] / 2
|
||||||
|
cy = px["y"] + px["h"] / 2
|
||||||
|
cw = px["w"]
|
||||||
|
ch = px["h"]
|
||||||
|
elif cell.get("bbox_pct"):
|
||||||
|
pct = cell["bbox_pct"]
|
||||||
|
cx = (pct["x"] / 100) * img_w + (pct["w"] / 100) * img_w / 2
|
||||||
|
cy = (pct["y"] / 100) * img_h + (pct["h"] / 100) * img_h / 2
|
||||||
|
cw = (pct["w"] / 100) * img_w
|
||||||
|
ch = (pct["h"] / 100) * img_h
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if center sits on a vertical or horizontal border
|
||||||
|
on_vertical = any(lo <= cx <= hi for lo, hi in x_bands)
|
||||||
|
on_horizontal = any(lo <= cy <= hi for lo, hi in y_bands)
|
||||||
|
if not on_vertical and not on_horizontal:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Very short text (1-2 chars) on a border → very likely ghost
|
||||||
|
if len(text) <= 2:
|
||||||
|
# Narrow vertically (line-like) or narrow horizontally (dash-like)?
|
||||||
|
if ch > 0 and cw / ch < 0.5:
|
||||||
|
return True
|
||||||
|
if cw > 0 and ch / cw < 0.5:
|
||||||
|
return True
|
||||||
|
# Text is only border-ghost characters?
|
||||||
|
if all(c in _BORDER_GHOST_CHARS for c in text):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Longer text but still only ghost chars and very narrow
|
||||||
|
if all(c in _BORDER_GHOST_CHARS for c in text):
|
||||||
|
if ch > 0 and cw / ch < 0.35:
|
||||||
|
return True
|
||||||
|
if cw > 0 and ch / cw < 0.35:
|
||||||
|
return True
|
||||||
|
return True # all ghost chars on a border → remove
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
before = len(cells)
|
||||||
|
word_result["cells"] = [c for c in cells if not _is_ghost(c)]
|
||||||
|
removed = before - len(word_result["cells"])
|
||||||
|
|
||||||
|
# --- Remove empty columns from columns_used ---
|
||||||
|
columns_used = word_result.get("columns_used")
|
||||||
|
if removed and columns_used and len(columns_used) > 1:
|
||||||
|
remaining_cells = word_result["cells"]
|
||||||
|
occupied_cols = {c.get("col_index") for c in remaining_cells}
|
||||||
|
before_cols = len(columns_used)
|
||||||
|
columns_used = [col for col in columns_used if col.get("index") in occupied_cols]
|
||||||
|
|
||||||
|
# Re-index columns and remap cell col_index values
|
||||||
|
if len(columns_used) < before_cols:
|
||||||
|
old_to_new = {}
|
||||||
|
for new_i, col in enumerate(columns_used):
|
||||||
|
old_to_new[col["index"]] = new_i
|
||||||
|
col["index"] = new_i
|
||||||
|
for cell in remaining_cells:
|
||||||
|
old_ci = cell.get("col_index")
|
||||||
|
if old_ci in old_to_new:
|
||||||
|
cell["col_index"] = old_to_new[old_ci]
|
||||||
|
word_result["columns_used"] = columns_used
|
||||||
|
logger.info("border-ghost: removed %d empty column(s), %d remaining",
|
||||||
|
before_cols - len(columns_used), len(columns_used))
|
||||||
|
|
||||||
|
if removed:
|
||||||
|
# Update summary counts
|
||||||
|
summary = word_result.get("summary", {})
|
||||||
|
summary["total_cells"] = len(word_result["cells"])
|
||||||
|
summary["non_empty_cells"] = sum(1 for c in word_result["cells"] if c.get("text"))
|
||||||
|
word_result["summary"] = summary
|
||||||
|
gs = word_result.get("grid_shape", {})
|
||||||
|
gs["total_cells"] = len(word_result["cells"])
|
||||||
|
if columns_used is not None:
|
||||||
|
gs["cols"] = len(columns_used)
|
||||||
|
word_result["grid_shape"] = gs
|
||||||
|
|
||||||
|
return removed
|
||||||
@@ -0,0 +1,290 @@
|
|||||||
|
"""
|
||||||
|
Crop API endpoints (Step 4 / UI index 3 of OCR Pipeline).
|
||||||
|
|
||||||
|
Auto-crop, manual crop, and skip-crop for scanner/book borders.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from .page_crop import detect_and_crop_page, detect_page_splits
|
||||||
|
from .session_store import get_sub_sessions, update_session_db
|
||||||
|
|
||||||
|
from .orientation_crop_helpers import ensure_cached, append_pipeline_log
|
||||||
|
from .page_sub_sessions import create_page_sub_sessions
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Step 4 (UI index 3): Crop — runs after deskew + dewarp
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/crop")
async def auto_crop(session_id: str):
    """Detect scanner/book borders automatically and crop them away.

    The dewarped image is the preferred source; when earlier steps were
    skipped the oriented image and then the original are used instead.

    When the image turns out to be a multi-page spread, one sub-session per
    page is created, the first page is additionally cropped and stored as
    this session's cropped image, and the split info is returned. If
    sub-sessions already exist and the stored crop result is marked
    ``multi_page``, that result is returned without detecting again.

    Raises:
        HTTPException: 400 when no source image is available.
    """
    cached = await ensure_cached(session_id)

    # Use dewarped (preferred), fall back to oriented, then original
    source_bgr = None
    for key in ("dewarped_bgr", "oriented_bgr", "original_bgr"):
        source_bgr = cached.get(key)
        if source_bgr is not None:
            break
    if source_bgr is None:
        raise HTTPException(status_code=400, detail="No image available for cropping")

    started = time.time()
    cropped_url = f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped"

    # --- Existing sub-sessions (from the page-split step) ---
    # A session that was already split just reports its pages; each
    # sub-session runs its own crop independently.
    existing_subs = await get_sub_sessions(session_id)
    previous = (cached.get("crop_result") or {}) if existing_subs else {}
    if previous.get("multi_page"):
        height, width = source_bgr.shape[:2]
        response = {"session_id": session_id}
        response.update(previous)
        response["image_width"] = width
        response["image_height"] = height
        response["sub_sessions"] = [
            {"id": sub["id"], "name": sub.get("name"), "page_index": sub.get("box_index", pos)}
            for pos, sub in enumerate(existing_subs)
        ]
        response["note"] = "Page split was already performed; each sub-session runs its own crop."
        return response

    # --- Multi-page detection (fallback for sessions that skipped page-split) ---
    page_splits = detect_page_splits(source_bgr)

    if page_splits and len(page_splits) >= 2:
        sub_sessions = await create_page_sub_sessions(
            session_id, cached, source_bgr, page_splits,
        )
        elapsed = time.time() - started
        page_count = len(page_splits)

        crop_info: Dict[str, Any] = {
            "crop_applied": True,
            "multi_page": True,
            "page_count": page_count,
            "page_splits": page_splits,
            "duration_seconds": round(elapsed, 2),
        }
        cached["crop_result"] = crop_info

        # Keep the first page as this session's cropped image for backward compat
        first = page_splits[0]
        top, left = first["y"], first["x"]
        first_page_bgr = source_bgr[
            top:top + first["height"],
            left:left + first["width"],
        ].copy()
        first_cropped, _ = detect_and_crop_page(first_page_bgr)
        cached["cropped_bgr"] = first_cropped

        encoded, png_buf = cv2.imencode(".png", first_cropped)
        await update_session_db(
            session_id,
            cropped_png=png_buf.tobytes() if encoded else b"",
            crop_result=crop_info,
            current_step=5,
            status='split',
        )

        logger.info(
            "OCR Pipeline: crop session %s: multi-page split into %d pages in %.2fs",
            session_id, page_count, elapsed,
        )

        await append_pipeline_log(
            session_id,
            "crop",
            {"multi_page": True, "page_count": page_count},
            duration_ms=int(elapsed * 1000),
        )

        height, width = first_cropped.shape[:2]
        return {
            "session_id": session_id,
            **crop_info,
            "image_width": width,
            "image_height": height,
            "cropped_image_url": cropped_url,
            "sub_sessions": sub_sessions,
        }

    # --- Single page (normal) ---
    cropped_bgr, crop_info = detect_and_crop_page(source_bgr)

    elapsed = time.time() - started
    crop_info["duration_seconds"] = round(elapsed, 2)
    crop_info["multi_page"] = False

    encoded, png_buf = cv2.imencode(".png", cropped_bgr)
    cropped_png = png_buf.tobytes() if encoded else b""

    # Cache first, then DB
    cached["cropped_bgr"] = cropped_bgr
    cached["crop_result"] = crop_info
    await update_session_db(
        session_id,
        cropped_png=cropped_png,
        crop_result=crop_info,
        current_step=5,
    )

    logger.info(
        "OCR Pipeline: crop session %s: applied=%s format=%s in %.2fs",
        session_id, crop_info["crop_applied"],
        crop_info.get("detected_format", "?"),
        elapsed,
    )

    step_metrics = {
        "crop_applied": crop_info["crop_applied"],
        "detected_format": crop_info.get("detected_format"),
        "format_confidence": crop_info.get("format_confidence"),
    }
    await append_pipeline_log(session_id, "crop", step_metrics, duration_ms=int(elapsed * 1000))

    height, width = cropped_bgr.shape[:2]
    return {
        "session_id": session_id,
        **crop_info,
        "image_width": width,
        "image_height": height,
        "cropped_image_url": cropped_url,
    }
|
||||||
|
|
||||||
|
|
||||||
|
class ManualCropRequest(BaseModel):
    """Crop rectangle given as percentages of the source image size.

    The model does not validate ranges; the manual crop endpoint clamps the
    resulting pixel rectangle to the image.
    """

    # Left edge, percentage 0-100 of the image width.
    x: float
    # Top edge, percentage 0-100 of the image height.
    y: float
    # Rectangle width, percentage 0-100 of the image width.
    width: float
    # Rectangle height, percentage 0-100 of the image height.
    height: float
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/crop/manual")
async def manual_crop(session_id: str, req: ManualCropRequest):
    """Crop the page to a user-supplied rectangle given in percent.

    The rectangle is converted to pixels of the source image (dewarped,
    else oriented, else original) and clamped so it lies inside the image
    and is at least 1x1 pixel.

    Raises:
        HTTPException: 400 when no source image is available.
    """
    cached = await ensure_cached(session_id)

    candidates = (cached.get(key) for key in ("dewarped_bgr", "oriented_bgr", "original_bgr"))
    source_bgr = next((img for img in candidates if img is not None), None)
    if source_bgr is None:
        raise HTTPException(status_code=400, detail="No image available for cropping")

    src_h, src_w = source_bgr.shape[:2]

    def _to_px(percent: float, extent: int) -> int:
        """Convert a percentage of *extent* into whole pixels."""
        return int(extent * percent / 100.0)

    # Origin first: the allowed size depends on the clamped origin.
    left = max(0, min(_to_px(req.x, src_w), src_w - 1))
    top = max(0, min(_to_px(req.y, src_h), src_h - 1))
    crop_w = max(1, min(_to_px(req.width, src_w), src_w - left))
    crop_h = max(1, min(_to_px(req.height, src_h), src_h - top))

    cropped_bgr = source_bgr[top:top + crop_h, left:left + crop_w].copy()

    encoded, png_buf = cv2.imencode(".png", cropped_bgr)
    cropped_png = png_buf.tobytes() if encoded else b""

    requested_pct = {
        "x": round(req.x, 2),
        "y": round(req.y, 2),
        "width": round(req.width, 2),
        "height": round(req.height, 2),
    }
    crop_result = {
        "crop_applied": True,
        "crop_rect": {"x": left, "y": top, "width": crop_w, "height": crop_h},
        "crop_rect_pct": requested_pct,
        "original_size": {"width": src_w, "height": src_h},
        "cropped_size": {"width": crop_w, "height": crop_h},
        "method": "manual",
    }

    cached["cropped_bgr"] = cropped_bgr
    cached["crop_result"] = crop_result

    await update_session_db(
        session_id,
        cropped_png=cropped_png,
        crop_result=crop_result,
        current_step=5,
    )

    out_h, out_w = cropped_bgr.shape[:2]
    return {
        "session_id": session_id,
        **crop_result,
        "image_width": out_w,
        "image_height": out_h,
        "cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/crop/skip")
async def skip_crop(session_id: str):
    """Skip cropping and carry the current image forward unchanged.

    The dewarped image (else oriented, else original) is stored as the
    session's cropped image, and the crop result is marked as skipped.

    Raises:
        HTTPException: 400 when no source image is available.
    """
    cached = await ensure_cached(session_id)

    candidates = (cached.get(key) for key in ("dewarped_bgr", "oriented_bgr", "original_bgr"))
    source_bgr = next((img for img in candidates if img is not None), None)
    if source_bgr is None:
        raise HTTPException(status_code=400, detail="No image available")

    height, width = source_bgr.shape[:2]
    full_size = {"width": width, "height": height}

    # Identity crop: the source image is stored as the cropped image
    encoded, png_buf = cv2.imencode(".png", source_bgr)
    cropped_png = png_buf.tobytes() if encoded else b""

    crop_result = {
        "crop_applied": False,
        "skipped": True,
        "original_size": dict(full_size),
        "cropped_size": dict(full_size),
    }

    cached["cropped_bgr"] = source_bgr
    cached["crop_result"] = crop_result

    await update_session_db(
        session_id,
        cropped_png=cropped_png,
        crop_result=crop_result,
        current_step=5,
    )

    response = {"session_id": session_id}
    response.update(crop_result)
    response["image_width"] = width
    response["image_height"] = height
    response["cropped_image_url"] = f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped"
    return response
|
||||||
@@ -0,0 +1,236 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Deskew Endpoints (Step 2)
|
||||||
|
|
||||||
|
Auto deskew, manual deskew, and ground truth for the deskew step.
|
||||||
|
Extracted from ocr_pipeline_geometry.py for file-size compliance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
|
||||||
|
from cv_vocab_pipeline import (
|
||||||
|
create_ocr_image,
|
||||||
|
deskew_image,
|
||||||
|
deskew_image_by_word_alignment,
|
||||||
|
deskew_two_pass,
|
||||||
|
)
|
||||||
|
from .session_store import (
|
||||||
|
get_session_db,
|
||||||
|
update_session_db,
|
||||||
|
)
|
||||||
|
from .common import (
|
||||||
|
_cache,
|
||||||
|
_load_session_to_cache,
|
||||||
|
_get_cached,
|
||||||
|
_append_pipeline_log,
|
||||||
|
ManualDeskewRequest,
|
||||||
|
DeskewGroundTruthRequest,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/deskew")
async def auto_deskew(session_id: str):
    """Deskew the page automatically and persist the result.

    The applied rotation comes from ``deskew_two_pass``, whose debug output
    reports an iterative pass, a word-alignment residual pass and a
    text-line pass. Hough and word-alignment estimates are computed as well,
    but only for reporting; if either fails, its angle is reported as 0.0
    and the failure is logged at debug level.

    The deskewed image and, when binarization succeeds, a binarized version
    are stored in the cache and the DB, and a ``deskew`` entry is appended
    to the pipeline log.

    Raises:
        HTTPException: 404 when the session cannot be loaded, 400 when
            neither an oriented nor an original image is available.
    """
    # Ensure session is in cache
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    # Deskew runs right after orientation -- use oriented image, fall back to original
    img_bgr = next((v for k in ("oriented_bgr", "original_bgr")
                    if (v := cached.get(k)) is not None), None)
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="No image available for deskewing")

    t0 = time.time()

    # Authoritative result
    deskewed_bgr, angle_applied, two_pass_debug = deskew_two_pass(img_bgr.copy())

    # Individual methods for reporting only (non-authoritative). Failures
    # must not abort the step, but they are logged so that a reported 0.0
    # can be told apart from an estimator that crashed.
    try:
        _, angle_hough = deskew_image(img_bgr.copy())
    except Exception as exc:
        logger.debug("Hough deskew estimate failed for session %s: %s", session_id, exc)
        angle_hough = 0.0

    success_enc, png_orig = cv2.imencode(".png", img_bgr)
    orig_bytes = png_orig.tobytes() if success_enc else b""
    try:
        _, angle_wa = deskew_image_by_word_alignment(orig_bytes)
    except Exception as exc:
        logger.debug("Word-alignment deskew estimate failed for session %s: %s", session_id, exc)
        angle_wa = 0.0

    angle_iterative = two_pass_debug.get("pass1_angle", 0.0)
    angle_residual = two_pass_debug.get("pass2_angle", 0.0)
    angle_textline = two_pass_debug.get("pass3_angle", 0.0)

    duration = time.time() - t0

    # Name the deepest pass that contributed a non-negligible angle.
    method_used = "three_pass" if abs(angle_textline) >= 0.01 else (
        "two_pass" if abs(angle_residual) >= 0.01 else "iterative"
    )

    # Encode as PNG
    success, deskewed_png_buf = cv2.imencode(".png", deskewed_bgr)
    deskewed_png = deskewed_png_buf.tobytes() if success else b""

    # Create binarized version (best effort)
    binarized_png = None
    try:
        binarized = create_ocr_image(deskewed_bgr)
        success_bin, bin_buf = cv2.imencode(".png", binarized)
        binarized_png = bin_buf.tobytes() if success_bin else None
    except Exception as e:
        logger.warning(f"Binarization failed: {e}")

    # Heuristic: 1.0 at 0 degrees, falling linearly, never below 0.5.
    confidence = max(0.5, 1.0 - abs(angle_applied) / 5.0)

    deskew_result = {
        "angle_hough": round(angle_hough, 3),
        "angle_word_alignment": round(angle_wa, 3),
        "angle_iterative": round(angle_iterative, 3),
        "angle_residual": round(angle_residual, 3),
        "angle_textline": round(angle_textline, 3),
        "angle_applied": round(angle_applied, 3),
        "method_used": method_used,
        "confidence": round(confidence, 2),
        "duration_seconds": round(duration, 2),
        "two_pass_debug": two_pass_debug,
    }

    # Update cache
    cached["deskewed_bgr"] = deskewed_bgr
    cached["binarized_png"] = binarized_png
    cached["deskew_result"] = deskew_result

    # Persist to DB; the binarized image is only written when it exists.
    db_update = {
        "deskewed_png": deskewed_png,
        "deskew_result": deskew_result,
        "current_step": 3,
    }
    if binarized_png:
        db_update["binarized_png"] = binarized_png
    await update_session_db(session_id, **db_update)

    logger.info(f"OCR Pipeline: deskew session {session_id}: "
                f"hough={angle_hough:.2f} wa={angle_wa:.2f} "
                f"iter={angle_iterative:.2f} residual={angle_residual:.2f} "
                f"textline={angle_textline:.2f} "
                f"-> {method_used} total={angle_applied:.2f}")

    await _append_pipeline_log(session_id, "deskew", {
        "angle_applied": round(angle_applied, 3),
        "angle_iterative": round(angle_iterative, 3),
        "angle_residual": round(angle_residual, 3),
        "angle_textline": round(angle_textline, 3),
        "confidence": round(confidence, 2),
        "method": method_used,
    }, duration_ms=int(duration * 1000))

    return {
        "session_id": session_id,
        **deskew_result,
        "deskewed_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed",
        "binarized_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/binarized",
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/deskew/manual")
async def manual_deskew(session_id: str, req: ManualDeskewRequest):
    """Apply a manual rotation angle to the oriented image.

    The requested angle is clamped to [-5.0, 5.0] and applied to the cached
    ``oriented_bgr`` image (falling back to ``original_bgr``). The rotated
    image and, when binarization succeeds, its binarized variant are written
    to the session cache and persisted to the database.

    Args:
        session_id: Identifier of the OCR pipeline session.
        req: Request body carrying the desired rotation ``angle``.

    Returns:
        Dict with the session id, the applied (clamped) angle, the method
        marker ``"manual"`` and the URL of the deskewed image.

    Raises:
        HTTPException: 400 if neither an oriented nor an original image is
            available for the session.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    # Prefer the orientation-corrected image; fall back to the original one.
    img_bgr = next((v for k in ("oriented_bgr", "original_bgr")
                    if (v := cached.get(k)) is not None), None)
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="No image available for deskewing")

    angle = max(-5.0, min(5.0, req.angle))

    h, w = img_bgr.shape[:2]
    center = (w // 2, h // 2)
    M = cv2.getRotationMatrix2D(center, angle, 1.0)
    rotated = cv2.warpAffine(img_bgr, M, (w, h),
                             flags=cv2.INTER_LINEAR,
                             borderMode=cv2.BORDER_REPLICATE)

    success, png_buf = cv2.imencode(".png", rotated)
    deskewed_png = png_buf.tobytes() if success else b""

    # Binarize (best effort: a failure must not abort the manual deskew,
    # but it is logged instead of being silently discarded).
    binarized_png = None
    try:
        binarized = create_ocr_image(rotated)
        success_bin, bin_buf = cv2.imencode(".png", binarized)
        binarized_png = bin_buf.tobytes() if success_bin else None
    except Exception as e:
        logger.warning(f"OCR Pipeline: manual deskew binarization failed for session "
                       f"{session_id}: {type(e).__name__}: {e}")

    # Keep the fields of any earlier deskew result, overriding angle and method.
    deskew_result = {
        **(cached.get("deskew_result") or {}),
        "angle_applied": round(angle, 3),
        "method_used": "manual",
    }

    # Update cache
    cached["deskewed_bgr"] = rotated
    cached["binarized_png"] = binarized_png
    cached["deskew_result"] = deskew_result

    # Persist to DB (the binarized image only when it was produced)
    db_update = {
        "deskewed_png": deskewed_png,
        "deskew_result": deskew_result,
    }
    if binarized_png:
        db_update["binarized_png"] = binarized_png
    await update_session_db(session_id, **db_update)

    logger.info(f"OCR Pipeline: manual deskew session {session_id}: {angle:.2f}")

    return {
        "session_id": session_id,
        "angle_applied": round(angle, 3),
        "method_used": "manual",
        "deskewed_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed",
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/ground-truth/deskew")
async def save_deskew_ground_truth(session_id: str, req: DeskewGroundTruthRequest):
    """Save ground truth feedback for the deskew step.

    The reviewer's verdict is stored, together with a snapshot of the
    session's current ``deskew_result``, under the ``"deskew"`` key of the
    session's ground truth. The database is updated first; the in-memory
    cache is refreshed only when the session is already cached.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    all_feedback = session.get("ground_truth") or {}
    feedback = {
        "is_correct": req.is_correct,
        "corrected_angle": req.corrected_angle,
        "notes": req.notes,
        "saved_at": datetime.utcnow().isoformat(),
        "deskew_result": session.get("deskew_result"),
    }
    all_feedback["deskew"] = feedback

    await update_session_db(session_id, ground_truth=all_feedback)

    # Mirror the change into the cache entry, if there is one.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = all_feedback

    summary = f"correct={req.is_correct}, corrected_angle={req.corrected_angle}"
    logger.info(f"OCR Pipeline: ground truth deskew session {session_id}: " + summary)

    return {"session_id": session_id, "ground_truth": feedback}
|
||||||
@@ -0,0 +1,346 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Dewarp Endpoints
|
||||||
|
|
||||||
|
Auto dewarp (with VLM/CV ensemble), manual dewarp, combined
|
||||||
|
rotation+shear adjustment, and ground truth.
|
||||||
|
Extracted from ocr_pipeline_geometry.py for file-size compliance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
|
|
||||||
|
from cv_vocab_pipeline import (
|
||||||
|
_apply_shear,
|
||||||
|
create_ocr_image,
|
||||||
|
dewarp_image,
|
||||||
|
dewarp_image_manual,
|
||||||
|
)
|
||||||
|
from .session_store import (
|
||||||
|
get_session_db,
|
||||||
|
update_session_db,
|
||||||
|
)
|
||||||
|
from .common import (
|
||||||
|
_cache,
|
||||||
|
_load_session_to_cache,
|
||||||
|
_get_cached,
|
||||||
|
_append_pipeline_log,
|
||||||
|
ManualDewarpRequest,
|
||||||
|
CombinedAdjustRequest,
|
||||||
|
DewarpGroundTruthRequest,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||||
|
|
||||||
|
|
||||||
|
async def _detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]:
    """Ask qwen2.5vl:32b to estimate the vertical shear angle of a scanned page.

    The VLM is shown the image and asked: are the column/table borders tilted?
    If yes, by how many degrees?

    Args:
        image_bytes: Encoded image sent (base64) to the Ollama ``/api/generate``
            endpoint. Empty input is not sent to the model at all.

    Returns:
        Dict with ``method``, ``shear_degrees`` (clamped to [-3.0, 3.0]) and
        ``confidence`` (clamped to [0.0, 1.0]). Confidence is 0.0 if the image
        is empty, Ollama is unavailable, or the reply cannot be parsed into
        finite numbers.
    """
    # Single fallback shared by every failure path.
    no_detection = {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0}

    # Callers pass b"" when PNG encoding failed; asking the model about an
    # empty image could only produce a made-up angle.
    if not image_bytes:
        return no_detection

    import base64
    import math

    import httpx

    ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
    model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")

    prompt = (
        "This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. "
        "Are they perfectly vertical, or do they tilt slightly? "
        "If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). "
        "Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} "
        "Use confidence 0.0-1.0 based on how clearly you can see the tilt. "
        "If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}"
    )

    img_b64 = base64.b64encode(image_bytes).decode("utf-8")
    payload = {
        "model": model,
        "prompt": prompt,
        "images": [img_b64],
        "stream": False,
    }

    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.post(f"{ollama_base}/api/generate", json=payload)
            resp.raise_for_status()
            text = resp.json().get("response", "")

        # Parse JSON from response (may have surrounding text)
        match = re.search(r'\{[^}]+\}', text)
        if match:
            data = json.loads(match.group(0))
            shear = float(data.get("shear_degrees", 0.0))
            conf = float(data.get("confidence", 0.0))
            # json.loads accepts NaN/Infinity; the min/max clamp below would
            # turn NaN into the extreme values (3.0 / 1.0), so reject it first.
            if math.isfinite(shear) and math.isfinite(conf):
                # Clamp to reasonable range
                shear = max(-3.0, min(3.0, shear))
                conf = max(0.0, min(1.0, conf))
                return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)}
            logger.warning(f"VLM dewarp failed: non-finite values in response {match.group(0)!r}")
    except Exception as e:
        logger.warning(f"VLM dewarp failed: {e}")

    return no_detection
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/dewarp")
async def auto_dewarp(
    session_id: str,
    method: str = Query("ensemble", description="Detection method: ensemble | vlm | cv"),
):
    """Detect and correct vertical shear on the deskewed image.

    Methods:
    - **ensemble** (default): 3-method CV ensemble (vertical edges + projection + Hough)
    - **cv**: CV ensemble only (same as ensemble)
    - **vlm**: Ask qwen2.5vl:32b to estimate the shear angle visually

    The result is stored in the session cache, persisted to the database
    (advancing ``current_step`` to 4) and appended to the pipeline log.

    Raises:
        HTTPException: 400 for an unknown ``method`` or when the session has
            no deskewed image yet.
    """
    allowed_methods = ("ensemble", "cv", "vlm")
    if method not in allowed_methods:
        raise HTTPException(status_code=400, detail="method must be one of: ensemble, cv, vlm")

    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    session_cache = _get_cached(session_id)

    source_bgr = session_cache.get("deskewed_bgr")
    if source_bgr is None:
        raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")

    started = time.time()

    if method != "vlm":
        # "ensemble" and "cv" both use the CV detector.
        dewarped_bgr, dewarp_info = dewarp_image(source_bgr)
    else:
        # Encode deskewed image to PNG for VLM
        encoded_ok, encoded = cv2.imencode(".png", source_bgr)
        vlm_det = await _detect_shear_with_vlm(encoded.tobytes() if encoded_ok else b"")
        shear_deg = vlm_det["shear_degrees"]
        # Correct only a noticeable shear that the model is reasonably sure of.
        worth_applying = abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3
        dewarped_bgr = _apply_shear(source_bgr, -shear_deg) if worth_applying else source_bgr
        dewarp_info = {
            "method": vlm_det["method"],
            "shear_degrees": shear_deg,
            "confidence": vlm_det["confidence"],
            "detections": [vlm_det],
        }

    duration = time.time() - started

    # Encode as PNG
    encoded_ok, encoded = cv2.imencode(".png", dewarped_bgr)
    dewarped_png = encoded.tobytes() if encoded_ok else b""

    detections = dewarp_info.get("detections", [])
    dewarp_result = {
        "method_used": dewarp_info["method"],
        "shear_degrees": dewarp_info["shear_degrees"],
        "confidence": dewarp_info["confidence"],
        "duration_seconds": round(duration, 2),
        "detections": detections,
    }

    # Update cache
    session_cache["dewarped_bgr"] = dewarped_bgr
    session_cache["dewarp_result"] = dewarp_result

    # Persist to DB
    await update_session_db(
        session_id,
        dewarped_png=dewarped_png,
        dewarp_result=dewarp_result,
        auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
        current_step=4,
    )

    logger.info(
        f"OCR Pipeline: dewarp session {session_id}: "
        f"method={dewarp_info['method']} shear={dewarp_info['shear_degrees']:.3f} "
        f"conf={dewarp_info['confidence']:.2f} ({duration:.2f}s)"
    )

    log_entry = {
        "shear_degrees": dewarp_info["shear_degrees"],
        "confidence": dewarp_info["confidence"],
        "method": dewarp_info["method"],
        "ensemble_methods": [det.get("method", "") for det in detections],
    }
    await _append_pipeline_log(session_id, "dewarp", log_entry, duration_ms=int(duration * 1000))

    return {
        "session_id": session_id,
        **dewarp_result,
        "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/dewarp/manual")
async def manual_dewarp(session_id: str, req: ManualDewarpRequest):
    """Apply shear correction with a manual angle.

    The requested shear is clamped to [-2.0, 2.0]; values with a magnitude
    below 0.001 leave the deskewed image untouched. The outcome is written to
    the session cache and persisted to the database.

    Raises:
        HTTPException: 400 if the session has no deskewed image yet.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    session_cache = _get_cached(session_id)

    source_bgr = session_cache.get("deskewed_bgr")
    if source_bgr is None:
        raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")

    shear_deg = max(-2.0, min(2.0, req.shear_degrees))

    # A (near) zero shear is a no-op: reuse the deskewed image as is.
    is_noop = abs(shear_deg) < 0.001
    dewarped_bgr = source_bgr if is_noop else dewarp_image_manual(source_bgr, shear_deg)

    encoded_ok, encoded = cv2.imencode(".png", dewarped_bgr)
    dewarped_png = encoded.tobytes() if encoded_ok else b""

    # Keep the fields of any earlier dewarp result, overriding method and shear.
    previous_result = session_cache.get("dewarp_result") or {}
    dewarp_result = {
        **previous_result,
        "method_used": "manual",
        "shear_degrees": round(shear_deg, 3),
    }

    # Update cache
    session_cache["dewarped_bgr"] = dewarped_bgr
    session_cache["dewarp_result"] = dewarp_result

    # Persist to DB
    await update_session_db(
        session_id,
        dewarped_png=dewarped_png,
        dewarp_result=dewarp_result,
    )

    logger.info(f"OCR Pipeline: manual dewarp session {session_id}: shear={shear_deg:.3f}")

    return {
        "session_id": session_id,
        "shear_degrees": round(shear_deg, 3),
        "method_used": "manual",
        "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/adjust-combined")
async def adjust_combined(session_id: str, req: CombinedAdjustRequest):
    """Apply rotation + shear combined to the original image.

    Used by the fine-tuning sliders to preview arbitrary rotation/shear
    combinations without re-running the full deskew/dewarp pipeline.

    Rotation is clamped to [-15.0, 15.0] and shear to [-5.0, 5.0]; either step
    is skipped when its magnitude is below 0.001. The resulting image replaces
    both the deskewed and the dewarped image of the session, in the cache and
    in the database.

    Args:
        session_id: Identifier of the OCR pipeline session.
        req: Request body carrying ``rotation_degrees`` and ``shear_degrees``.

    Returns:
        Dict with the session id, the applied (clamped) rotation and shear,
        the method marker ``"manual_combined"`` and the dewarped image URL.

    Raises:
        HTTPException: 400 if the original image is not available.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    img_bgr = cached.get("original_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Original image not available")

    rotation = max(-15.0, min(15.0, req.rotation_degrees))
    shear_deg = max(-5.0, min(5.0, req.shear_degrees))

    h, w = img_bgr.shape[:2]
    result_bgr = img_bgr

    # Step 1: Apply rotation
    if abs(rotation) >= 0.001:
        center = (w // 2, h // 2)
        M = cv2.getRotationMatrix2D(center, rotation, 1.0)
        result_bgr = cv2.warpAffine(result_bgr, M, (w, h),
                                    flags=cv2.INTER_LINEAR,
                                    borderMode=cv2.BORDER_REPLICATE)

    # Step 2: Apply shear
    if abs(shear_deg) >= 0.001:
        result_bgr = dewarp_image_manual(result_bgr, shear_deg)

    # Encode
    success, png_buf = cv2.imencode(".png", result_bgr)
    dewarped_png = png_buf.tobytes() if success else b""

    # Binarize (best effort: a failure must not abort the adjustment, but it
    # is logged instead of being silently discarded).
    binarized_png = None
    try:
        binarized = create_ocr_image(result_bgr)
        success_bin, bin_buf = cv2.imencode(".png", binarized)
        binarized_png = bin_buf.tobytes() if success_bin else None
    except Exception as e:
        logger.warning(f"OCR Pipeline: combined adjust binarization failed for session "
                       f"{session_id}: {type(e).__name__}: {e}")

    # Build combined result dicts
    deskew_result = {
        **(cached.get("deskew_result") or {}),
        "angle_applied": round(rotation, 3),
        "method_used": "manual_combined",
    }
    dewarp_result = {
        **(cached.get("dewarp_result") or {}),
        "method_used": "manual_combined",
        "shear_degrees": round(shear_deg, 3),
    }

    # Update cache: the adjusted image stands in for both pipeline stages.
    cached["deskewed_bgr"] = result_bgr
    cached["dewarped_bgr"] = result_bgr
    cached["deskew_result"] = deskew_result
    cached["dewarp_result"] = dewarp_result

    # Persist to DB. The deskewed image is always written so the stored state
    # matches the cache even when binarization failed.
    db_update = {
        "dewarped_png": dewarped_png,
        "deskewed_png": dewarped_png,
        "deskew_result": deskew_result,
        "dewarp_result": dewarp_result,
    }
    if binarized_png:
        db_update["binarized_png"] = binarized_png
    await update_session_db(session_id, **db_update)

    logger.info(f"OCR Pipeline: combined adjust session {session_id}: "
                f"rotation={rotation:.3f} shear={shear_deg:.3f}")

    return {
        "session_id": session_id,
        "rotation_degrees": round(rotation, 3),
        "shear_degrees": round(shear_deg, 3),
        "method_used": "manual_combined",
        "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/ground-truth/dewarp")
async def save_dewarp_ground_truth(session_id: str, req: DewarpGroundTruthRequest):
    """Save ground truth feedback for the dewarp step.

    The reviewer's verdict is stored, together with a snapshot of the
    session's current ``dewarp_result``, under the ``"dewarp"`` key of the
    session's ground truth. The database is updated first; the in-memory
    cache is refreshed only when the session is already cached.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    all_feedback = session.get("ground_truth") or {}
    feedback = {
        "is_correct": req.is_correct,
        "corrected_shear": req.corrected_shear,
        "notes": req.notes,
        "saved_at": datetime.utcnow().isoformat(),
        "dewarp_result": session.get("dewarp_result"),
    }
    all_feedback["dewarp"] = feedback

    await update_session_db(session_id, ground_truth=all_feedback)

    # Mirror the change into the cache entry, if there is one.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = all_feedback

    summary = f"correct={req.is_correct}, corrected_shear={req.corrected_shear}"
    logger.info(f"OCR Pipeline: ground truth dewarp session {session_id}: " + summary)

    return {"session_id": session_id, "ground_truth": feedback}
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Geometry API (barrel re-export)
|
||||||
|
|
||||||
|
This module was split into:
|
||||||
|
- ocr_pipeline_deskew.py (Deskew endpoints)
|
||||||
|
- ocr_pipeline_dewarp.py (Dewarp endpoints)
|
||||||
|
- ocr_pipeline_structure.py (Structure detection + exclude regions)
|
||||||
|
- ocr_pipeline_columns.py (Column detection + ground truth)
|
||||||
|
|
||||||
|
The `router` object is assembled here by including all sub-routers.
|
||||||
|
Importers that did `from ocr_pipeline_geometry import router` continue to work.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from .deskew import router as _deskew_router
|
||||||
|
from .dewarp import router as _dewarp_router
|
||||||
|
from .structure import router as _structure_router
|
||||||
|
from .columns import router as _columns_router
|
||||||
|
|
||||||
|
# Build the combined router from the four sub-routers, in this order.
# Every sub-router already carries prefix="/api/v1/ocr-pipeline", so none is
# given an additional prefix here.
router = APIRouter()
for _sub_router in (_deskew_router, _dewarp_router, _structure_router, _columns_router):
    router.include_router(_sub_router)
del _sub_router  # do not leave the loop variable behind as a module attribute
|
||||||
@@ -0,0 +1,209 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline LLM Review — LLM-based correction endpoints.
|
||||||
|
|
||||||
|
Extracted from ocr_pipeline_postprocess.py.
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
|
from cv_vocab_pipeline import (
|
||||||
|
OLLAMA_REVIEW_MODEL,
|
||||||
|
llm_review_entries,
|
||||||
|
llm_review_entries_streaming,
|
||||||
|
)
|
||||||
|
from .session_store import (
|
||||||
|
get_session_db,
|
||||||
|
update_session_db,
|
||||||
|
)
|
||||||
|
from .common import (
|
||||||
|
_cache,
|
||||||
|
_append_pipeline_log,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Step 8: LLM Review
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/llm-review")
async def run_llm_review(session_id: str, request: Request, stream: bool = False):
    """Run LLM-based correction on vocab entries from Step 5.

    Query params:
        stream: false (default) for JSON response, true for SSE streaming

    The request body is optional; when it is a JSON object with a ``model``
    key, that model is used instead of ``OLLAMA_REVIEW_MODEL``. A missing,
    invalid or non-object body is treated as "no override".

    Returns:
        A ``StreamingResponse`` of SSE events when ``stream`` is true,
        otherwise a dict with the proposed changes and run statistics.

    Raises:
        HTTPException: 404 if the session does not exist, 400 if it has no
            word result or no vocab entries, 502 if the LLM review fails.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found — run Step 5 first")

    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    if not entries:
        raise HTTPException(status_code=400, detail="No vocab entries found — run Step 5 first")

    # Optional model override from request body
    try:
        body = await request.json()
    except Exception:
        # Deliberate best effort: an empty or unparsable body means
        # "use the default model".
        body = {}
    if not isinstance(body, dict):
        # Valid JSON that is not an object (null, list, string, ...) has no
        # "model" key either; without this guard body.get would raise.
        body = {}
    model = body.get("model") or OLLAMA_REVIEW_MODEL

    if stream:
        return StreamingResponse(
            _llm_review_stream_generator(session_id, entries, word_result, model, request),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
        )

    # Non-streaming path
    try:
        result = await llm_review_entries(entries, model=model)
    except Exception as e:
        # logger.exception appends the traceback of the active exception.
        logger.exception(f"LLM review failed for session {session_id}: {type(e).__name__}: {e}")
        raise HTTPException(status_code=502, detail=f"LLM review failed ({type(e).__name__}): {e}") from e

    # Store result inside word_result as a sub-key
    word_result["llm_review"] = {
        "changes": result["changes"],
        "model_used": result["model_used"],
        "duration_ms": result["duration_ms"],
        "entries_corrected": result["entries_corrected"],
    }
    await update_session_db(session_id, word_result=word_result, current_step=9)

    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    logger.info(f"LLM review session {session_id}: {len(result['changes'])} changes, "
                f"{result['duration_ms']}ms, model={result['model_used']}")

    await _append_pipeline_log(session_id, "correction", {
        "engine": "llm",
        "model": result["model_used"],
        "total_entries": len(entries),
        "corrections_proposed": len(result["changes"]),
    }, duration_ms=result["duration_ms"])

    return {
        "session_id": session_id,
        "changes": result["changes"],
        "model_used": result["model_used"],
        "duration_ms": result["duration_ms"],
        "total_entries": len(entries),
        "corrections_found": len(result["changes"]),
    }
|
||||||
|
|
||||||
|
|
||||||
|
async def _llm_review_stream_generator(
    session_id: str,
    entries: List[Dict],
    word_result: Dict,
    model: str,
    request: Request,
):
    """SSE generator that yields batch-by-batch LLM review progress.

    Every event produced by ``llm_review_entries_streaming`` is forwarded as
    one ``data: ...`` SSE frame. When a ``"complete"`` event arrives, its
    review summary is stored in ``word_result["llm_review"]`` and persisted.
    The generator stops early once the client has disconnected; any exception
    is logged and reported to the client as a final ``"error"`` event.
    """
    review_keys = ("changes", "model_used", "duration_ms", "entries_corrected")
    try:
        async for event in llm_review_entries_streaming(entries, model=model):
            if await request.is_disconnected():
                logger.info(f"SSE: client disconnected during LLM review for {session_id}")
                return

            frame = json.dumps(event, ensure_ascii=False)
            yield f"data: {frame}\n\n"

            if event.get("type") != "complete":
                continue

            # On complete: persist to DB
            word_result["llm_review"] = {key: event[key] for key in review_keys}
            await update_session_db(session_id, word_result=word_result, current_step=9)
            if session_id in _cache:
                _cache[session_id]["word_result"] = word_result

            logger.info(
                f"LLM review SSE session {session_id}: {event['corrections_found']} changes, "
                f"{event['duration_ms']}ms, skipped={event['skipped']}, model={event['model_used']}"
            )

    except Exception as e:
        import traceback

        error_name = type(e).__name__
        logger.error(f"LLM review SSE failed for {session_id}: {error_name}: {e}\n{traceback.format_exc()}")
        error_frame = json.dumps({"type": "error", "detail": f"{error_name}: {e}"})
        yield f"data: {error_frame}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/llm-review/apply")
async def apply_llm_corrections(session_id: str, request: Request):
    """Apply selected LLM corrections to vocab entries.

    The JSON request body lists ``accepted_indices`` (positions in the stored
    ``llm_review["changes"]``). Each accepted change overwrites the matching
    field (``english``, ``german`` or ``example``) of the entry with the same
    ``row_index`` and marks that entry as ``llm_corrected``.

    Returns:
        Dict with the session id, the number of accepted changes and the
        total number of proposed changes.

    Raises:
        HTTPException: 404 if the session does not exist, 400 if it has no
            word result or no LLM review.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    llm_review = word_result.get("llm_review")
    if not llm_review:
        raise HTTPException(status_code=400, detail="No LLM review found — run /llm-review first")

    body = await request.json()
    accepted_indices = set(body.get("accepted_indices", []))  # indices into changes[]

    changes = llm_review.get("changes", [])
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []

    # (row_index, field) -> new value for every accepted change; when several
    # accepted changes target the same cell, the last one wins.
    accepted_changes = [change for idx, change in enumerate(changes) if idx in accepted_indices]
    corrections = {
        (change["row_index"], change["field"]): change["new"]
        for change in accepted_changes
    }
    applied_count = len(accepted_changes)

    # Apply corrections to entries
    correctable_fields = ("english", "german", "example")
    for entry in entries:
        row_idx = entry.get("row_index", -1)
        for field_name in correctable_fields:
            cell = (row_idx, field_name)
            if cell not in corrections:
                continue
            entry[field_name] = corrections[cell]
            entry["llm_corrected"] = True

    # Update word_result (both entry keys point at the corrected list)
    word_result["vocab_entries"] = entries
    word_result["entries"] = entries
    word_result["llm_review"]["applied_count"] = applied_count
    word_result["llm_review"]["applied_at"] = datetime.utcnow().isoformat()

    await update_session_db(session_id, word_result=word_result)

    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    logger.info(f"Applied {applied_count}/{len(changes)} LLM corrections for session {session_id}")

    return {
        "session_id": session_id,
        "applied_count": applied_count,
        "total_changes": len(changes),
    }
|
||||||
@@ -0,0 +1,272 @@
|
|||||||
|
"""
|
||||||
|
OCR Merge Helpers — functions for combining PaddleOCR/RapidOCR with Tesseract results.
|
||||||
|
|
||||||
|
Extracted from ocr_pipeline_ocr_merge.py.
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _split_paddle_multi_words(words: list) -> list:
    """Break multi-word PaddleOCR boxes into one box per word.

    PaddleOCR frequently reports a whole phrase (for example
    "More than 200 singers took part in the") as a single bounding box. Such a
    box is cut into word boxes whose widths are proportional to the character
    count of each word, separated by a gap of 2% of the original box width.
    A leading "!" is split off ("!Betonung" -> "!", "Betonung"), and so is an
    IPA bracket ("badge[bxd3]" -> "badge", "[bxd3]").

    Boxes with empty text are dropped. Boxes that yield a single token are
    passed through unchanged; split boxes carry only text, left, top, width,
    height and conf.
    """
    import re

    # Token boundaries: whitespace runs, the position before "[" (IPA), and
    # the position between "!" and a following letter.
    boundary = re.compile(r'\s+|(?=\[)|(?<=!)(?=[A-Za-z\u00c0-\u024f])')

    split_words = []
    for box in words:
        stripped = box.get("text", "").strip()
        if not stripped:
            continue

        pieces = [piece for piece in boundary.split(stripped) if piece]
        if len(pieces) < 2:
            split_words.append(box)
            continue

        # Distribute the box width over the pieces by character count.
        char_total = sum(map(len, pieces))
        gap = box["width"] * 0.02
        text_width = box["width"] - gap * (len(pieces) - 1)
        x = box["left"]
        for piece in pieces:
            piece_width = max(1, text_width * len(piece) / char_total)
            split_words.append({
                "text": piece,
                "left": round(x),
                "top": box["top"],
                "width": round(piece_width),
                "height": box["height"],
                "conf": box.get("conf", 0),
            })
            x += piece_width + gap
    return split_words
|
||||||
|
|
||||||
|
|
||||||
|
def _group_words_into_rows(words: list, row_gap: int = 12) -> list:
    """Cluster words into text rows by their vertical centers.

    Words are visited in order of vertical center. A word joins the current
    row while its center lies within `row_gap` pixels of the center of the
    word that opened that row; otherwise it opens a new row.

    Returns a list of rows, each being a list of words sorted left-to-right.
    An empty input yields an empty list.
    """
    if not words:
        return []

    def center_y(word):
        return word["top"] + word.get("height", 0) / 2

    def left_edge(word):
        return word["left"]

    rows: list = []
    bucket: list = []
    anchor_cy = None  # center of the word that opened the current row
    for word in sorted(words, key=center_y):
        cy = center_y(word)
        if bucket and not abs(cy - anchor_cy) <= row_gap:
            # Too far from the row's anchor: close the row, ordered by x.
            rows.append(sorted(bucket, key=left_edge))
            bucket = []
        if not bucket:
            anchor_cy = cy
        bucket.append(word)

    rows.append(sorted(bucket, key=left_edge))
    return rows
|
||||||
|
|
||||||
|
|
||||||
|
def _row_center_y(row: list) -> float:
    """Return the mean vertical center of the words in `row` (0.0 if empty).

    A word's center is its ``top`` plus half its ``height``; a missing height
    counts as 0.
    """
    if not row:
        return 0.0
    centers = [word["top"] + word.get("height", 0) / 2 for word in row]
    return sum(centers) / len(row)
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_row_sequences(paddle_row: list, tess_row: list) -> list:
|
||||||
|
"""Merge two word sequences from the same row using sequence alignment.
|
||||||
|
|
||||||
|
Both sequences are sorted left-to-right. Walk through both simultaneously:
|
||||||
|
- If words match (same/similar text): take Paddle text with averaged coords
|
||||||
|
- If they don't match: the extra word is unique to one engine, include it
|
||||||
|
"""
|
||||||
|
merged = []
|
||||||
|
pi, ti = 0, 0
|
||||||
|
|
||||||
|
while pi < len(paddle_row) and ti < len(tess_row):
|
||||||
|
pw = paddle_row[pi]
|
||||||
|
tw = tess_row[ti]
|
||||||
|
|
||||||
|
pt = pw.get("text", "").lower().strip()
|
||||||
|
tt = tw.get("text", "").lower().strip()
|
||||||
|
|
||||||
|
is_same = (pt == tt) or (len(pt) > 1 and len(tt) > 1 and (pt in tt or tt in pt))
|
||||||
|
|
||||||
|
# Spatial overlap check
|
||||||
|
spatial_match = False
|
||||||
|
if not is_same:
|
||||||
|
overlap_left = max(pw["left"], tw["left"])
|
||||||
|
overlap_right = min(
|
||||||
|
pw["left"] + pw.get("width", 0),
|
||||||
|
tw["left"] + tw.get("width", 0),
|
||||||
|
)
|
||||||
|
overlap_w = max(0, overlap_right - overlap_left)
|
||||||
|
min_w = min(pw.get("width", 1), tw.get("width", 1))
|
||||||
|
if min_w > 0 and overlap_w / min_w >= 0.4:
|
||||||
|
is_same = True
|
||||||
|
spatial_match = True
|
||||||
|
|
||||||
|
if is_same:
|
||||||
|
pc = pw.get("conf", 80)
|
||||||
|
tc = tw.get("conf", 50)
|
||||||
|
total = pc + tc
|
||||||
|
if total == 0:
|
||||||
|
total = 1
|
||||||
|
if spatial_match and pc < tc:
|
||||||
|
best_text = tw["text"]
|
||||||
|
else:
|
||||||
|
best_text = pw["text"]
|
||||||
|
merged.append({
|
||||||
|
"text": best_text,
|
||||||
|
"left": round((pw["left"] * pc + tw["left"] * tc) / total),
|
||||||
|
"top": round((pw["top"] * pc + tw["top"] * tc) / total),
|
||||||
|
"width": round((pw["width"] * pc + tw["width"] * tc) / total),
|
||||||
|
"height": round((pw["height"] * pc + tw["height"] * tc) / total),
|
||||||
|
"conf": max(pc, tc),
|
||||||
|
})
|
||||||
|
pi += 1
|
||||||
|
ti += 1
|
||||||
|
else:
|
||||||
|
paddle_ahead = any(
|
||||||
|
tess_row[t].get("text", "").lower().strip() == pt
|
||||||
|
for t in range(ti + 1, min(ti + 4, len(tess_row)))
|
||||||
|
)
|
||||||
|
tess_ahead = any(
|
||||||
|
paddle_row[p].get("text", "").lower().strip() == tt
|
||||||
|
for p in range(pi + 1, min(pi + 4, len(paddle_row)))
|
||||||
|
)
|
||||||
|
|
||||||
|
if paddle_ahead and not tess_ahead:
|
||||||
|
if tw.get("conf", 0) >= 30:
|
||||||
|
merged.append(tw)
|
||||||
|
ti += 1
|
||||||
|
elif tess_ahead and not paddle_ahead:
|
||||||
|
merged.append(pw)
|
||||||
|
pi += 1
|
||||||
|
else:
|
||||||
|
if pw["left"] <= tw["left"]:
|
||||||
|
merged.append(pw)
|
||||||
|
pi += 1
|
||||||
|
else:
|
||||||
|
if tw.get("conf", 0) >= 30:
|
||||||
|
merged.append(tw)
|
||||||
|
ti += 1
|
||||||
|
|
||||||
|
while pi < len(paddle_row):
|
||||||
|
merged.append(paddle_row[pi])
|
||||||
|
pi += 1
|
||||||
|
while ti < len(tess_row):
|
||||||
|
tw = tess_row[ti]
|
||||||
|
if tw.get("conf", 0) >= 30:
|
||||||
|
merged.append(tw)
|
||||||
|
ti += 1
|
||||||
|
|
||||||
|
return merged
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_paddle_tesseract(paddle_words: list, tess_words: list) -> list:
|
||||||
|
"""Merge word boxes from PaddleOCR and Tesseract using row-based sequence alignment."""
|
||||||
|
if not paddle_words and not tess_words:
|
||||||
|
return []
|
||||||
|
if not paddle_words:
|
||||||
|
return [w for w in tess_words if w.get("conf", 0) >= 40]
|
||||||
|
if not tess_words:
|
||||||
|
return list(paddle_words)
|
||||||
|
|
||||||
|
paddle_rows = _group_words_into_rows(paddle_words)
|
||||||
|
tess_rows = _group_words_into_rows(tess_words)
|
||||||
|
|
||||||
|
used_tess_rows: set = set()
|
||||||
|
merged_all: list = []
|
||||||
|
|
||||||
|
for pr in paddle_rows:
|
||||||
|
pr_cy = _row_center_y(pr)
|
||||||
|
best_dist, best_tri = float("inf"), -1
|
||||||
|
for tri, tr in enumerate(tess_rows):
|
||||||
|
if tri in used_tess_rows:
|
||||||
|
continue
|
||||||
|
tr_cy = _row_center_y(tr)
|
||||||
|
dist = abs(pr_cy - tr_cy)
|
||||||
|
if dist < best_dist:
|
||||||
|
best_dist, best_tri = dist, tri
|
||||||
|
|
||||||
|
max_row_dist = max(
|
||||||
|
max((w.get("height", 20) for w in pr), default=20),
|
||||||
|
15,
|
||||||
|
)
|
||||||
|
|
||||||
|
if best_tri >= 0 and best_dist <= max_row_dist:
|
||||||
|
tr = tess_rows[best_tri]
|
||||||
|
used_tess_rows.add(best_tri)
|
||||||
|
merged_all.extend(_merge_row_sequences(pr, tr))
|
||||||
|
else:
|
||||||
|
merged_all.extend(pr)
|
||||||
|
|
||||||
|
for tri, tr in enumerate(tess_rows):
|
||||||
|
if tri not in used_tess_rows:
|
||||||
|
for tw in tr:
|
||||||
|
if tw.get("conf", 0) >= 40:
|
||||||
|
merged_all.append(tw)
|
||||||
|
|
||||||
|
return merged_all
|
||||||
|
|
||||||
|
|
||||||
|
def _deduplicate_words(words: list) -> list:
|
||||||
|
"""Remove duplicate words with same text at overlapping positions."""
|
||||||
|
if not words:
|
||||||
|
return words
|
||||||
|
|
||||||
|
result: list = []
|
||||||
|
for w in words:
|
||||||
|
wt = w.get("text", "").lower().strip()
|
||||||
|
if not wt:
|
||||||
|
continue
|
||||||
|
is_dup = False
|
||||||
|
w_right = w["left"] + w.get("width", 0)
|
||||||
|
w_bottom = w["top"] + w.get("height", 0)
|
||||||
|
for existing in result:
|
||||||
|
et = existing.get("text", "").lower().strip()
|
||||||
|
if wt != et:
|
||||||
|
continue
|
||||||
|
ox_l = max(w["left"], existing["left"])
|
||||||
|
ox_r = min(w_right, existing["left"] + existing.get("width", 0))
|
||||||
|
ox = max(0, ox_r - ox_l)
|
||||||
|
min_w = min(w.get("width", 1), existing.get("width", 1))
|
||||||
|
if min_w <= 0 or ox / min_w < 0.5:
|
||||||
|
continue
|
||||||
|
oy_t = max(w["top"], existing["top"])
|
||||||
|
oy_b = min(w_bottom, existing["top"] + existing.get("height", 0))
|
||||||
|
oy = max(0, oy_b - oy_t)
|
||||||
|
min_h = min(w.get("height", 1), existing.get("height", 1))
|
||||||
|
if min_h > 0 and oy / min_h >= 0.5:
|
||||||
|
is_dup = True
|
||||||
|
break
|
||||||
|
if not is_dup:
|
||||||
|
result.append(w)
|
||||||
|
|
||||||
|
removed = len(words) - len(result)
|
||||||
|
if removed:
|
||||||
|
logger.info("dedup: removed %d duplicate words", removed)
|
||||||
|
return result
|
||||||
@@ -0,0 +1,266 @@
|
|||||||
|
"""
|
||||||
|
OCR Merge Kombi Endpoints — paddle-kombi and rapid-kombi endpoints.
|
||||||
|
|
||||||
|
Merge helper functions live in merge_helpers.py (imported below as .merge_helpers).
|
||||||
|
This module re-exports them for backward compatibility.
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
|
||||||
|
from cv_words_first import build_grid_from_words
|
||||||
|
from .common import _cache, _append_pipeline_log
|
||||||
|
from .session_store import get_session_image, update_session_db
|
||||||
|
|
||||||
|
# Re-export merge helpers for backward compatibility
|
||||||
|
from .merge_helpers import ( # noqa: F401
|
||||||
|
_split_paddle_multi_words,
|
||||||
|
_group_words_into_rows,
|
||||||
|
_row_center_y,
|
||||||
|
_merge_row_sequences,
|
||||||
|
_merge_paddle_tesseract,
|
||||||
|
_deduplicate_words,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||||
|
|
||||||
|
|
||||||
|
def _run_tesseract_words(img_bgr) -> list:
    """Run Tesseract on a BGR image and return its word boxes.

    The image is converted to RGB and passed to ``pytesseract.image_to_data``
    with languages ``eng+deu`` and ``--psm 6 --oem 3``.

    Args:
        img_bgr: Image array in BGR channel order (as used by OpenCV).

    Returns:
        A list of dicts with keys ``text``, ``left``, ``top``, ``width``,
        ``height`` and ``conf``. Entries with empty text or a confidence
        below 20 are omitted.
    """
    from PIL import Image
    import pytesseract

    pil_img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    data = pytesseract.image_to_data(
        pil_img, lang="eng+deu",
        config="--psm 6 --oem 3",
        output_type=pytesseract.Output.DICT,
    )

    tess_words = []
    columns = zip(
        data["text"], data["conf"],
        data["left"], data["top"], data["width"], data["height"],
    )
    for text_raw, conf_raw, left, top, width, height in columns:
        text = str(text_raw).strip()
        # Confidence may arrive as an int, an integer string ("-1") or a
        # fractional string ("96.4"); parse via float so fractional values
        # are not mistaken for "no confidence" and dropped.
        try:
            conf = int(float(conf_raw))
        except (TypeError, ValueError, OverflowError):
            conf = -1
        if not text or conf < 20:
            continue
        tess_words.append({
            "text": text,
            "left": left,
            "top": top,
            "width": width,
            "height": height,
            "conf": conf,
        })
    return tess_words
|
||||||
|
|
||||||
|
|
||||||
|
def _build_kombi_word_result(
|
||||||
|
cells: list,
|
||||||
|
columns_meta: list,
|
||||||
|
img_w: int,
|
||||||
|
img_h: int,
|
||||||
|
duration: float,
|
||||||
|
engine_name: str,
|
||||||
|
raw_engine_words: list,
|
||||||
|
raw_engine_words_split: list,
|
||||||
|
tess_words: list,
|
||||||
|
merged_words: list,
|
||||||
|
raw_engine_key: str = "raw_paddle_words",
|
||||||
|
raw_split_key: str = "raw_paddle_words_split",
|
||||||
|
) -> dict:
|
||||||
|
"""Build the word_result dict for kombi endpoints."""
|
||||||
|
n_rows = len(set(c["row_index"] for c in cells)) if cells else 0
|
||||||
|
n_cols = len(columns_meta)
|
||||||
|
col_types = {c.get("type") for c in columns_meta}
|
||||||
|
is_vocab = bool(col_types & {"column_en", "column_de"})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"cells": cells,
|
||||||
|
"grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
|
||||||
|
"columns_used": columns_meta,
|
||||||
|
"layout": "vocab" if is_vocab else "generic",
|
||||||
|
"image_width": img_w,
|
||||||
|
"image_height": img_h,
|
||||||
|
"duration_seconds": round(duration, 2),
|
||||||
|
"ocr_engine": engine_name,
|
||||||
|
"grid_method": engine_name,
|
||||||
|
raw_engine_key: raw_engine_words,
|
||||||
|
raw_split_key: raw_engine_words_split,
|
||||||
|
"raw_tesseract_words": tess_words,
|
||||||
|
"summary": {
|
||||||
|
"total_cells": len(cells),
|
||||||
|
"non_empty_cells": sum(1 for c in cells if c.get("text")),
|
||||||
|
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
|
||||||
|
raw_engine_key.replace("raw_", "").replace("_words", "_words"): len(raw_engine_words),
|
||||||
|
raw_split_key.replace("raw_", "").replace("_words_split", "_words_split"): len(raw_engine_words_split),
|
||||||
|
"tesseract_words": len(tess_words),
|
||||||
|
"merged_words": len(merged_words),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def _load_session_image(session_id: str):
    """Fetch and decode the session image used by the kombi endpoints.

    The image variants are tried in the order ``cropped``, ``dewarped``,
    ``original``; the first non-empty one wins.

    Returns:
        A tuple ``(img_png, img_bgr)``: the stored image bytes and the
        decoded color image.

    Raises:
        HTTPException: 404 when no variant is stored, 400 when the bytes
            cannot be decoded.
    """
    img_png = None
    for variant in ("cropped", "dewarped", "original"):
        img_png = await get_session_image(session_id, variant)
        if img_png:
            break
    if not img_png:
        raise HTTPException(status_code=404, detail="No image found for this session")

    raw = np.frombuffer(img_png, dtype=np.uint8)
    img_bgr = cv2.imdecode(raw, cv2.IMREAD_COLOR)
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Failed to decode image")

    return img_png, img_bgr
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Kombi endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/paddle-kombi")
async def paddle_kombi(session_id: str):
    """Run PaddleOCR and Tesseract on the session image and merge the words.

    The merged, de-duplicated words are turned into a grid of cells. The
    result is persisted on the session (together with the image bytes and
    ``current_step=8``), mirrored into the in-memory cache when the session
    is cached, and appended to the pipeline log.

    Raises:
        HTTPException: 400 when neither engine produced any word; errors
            from image loading propagate unchanged.
    """
    img_png, img_bgr = await _load_session_image(session_id)
    img_h, img_w = img_bgr.shape[:2]

    from cv_ocr_engines import ocr_region_paddle

    started = time.time()

    raw_paddle = (await ocr_region_paddle(img_bgr, region=None)) or []
    raw_tess = _run_tesseract_words(img_bgr)

    split_paddle = _split_paddle_multi_words(raw_paddle)
    logger.info(
        "paddle_kombi: split %d paddle boxes -> %d individual words",
        len(raw_paddle), len(split_paddle),
    )

    if not (split_paddle or raw_tess):
        raise HTTPException(status_code=400, detail="Both OCR engines returned no words")

    merged = _deduplicate_words(_merge_paddle_tesseract(split_paddle, raw_tess))

    cells, columns_meta = build_grid_from_words(merged, img_w, img_h)
    elapsed = time.time() - started

    for cell in cells:
        cell["ocr_engine"] = "kombi"

    word_result = _build_kombi_word_result(
        cells, columns_meta, img_w, img_h, elapsed, "kombi",
        raw_paddle, split_paddle, raw_tess, merged,
        "raw_paddle_words", "raw_paddle_words_split",
    )

    await update_session_db(
        session_id, word_result=word_result, cropped_png=img_png, current_step=8,
    )
    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    shape = word_result["grid_shape"]
    logger.info(
        "paddle_kombi session %s: %d cells (%d rows, %d cols) in %.2fs "
        "[paddle=%d, tess=%d, merged=%d]",
        session_id, len(cells), shape["rows"], shape["cols"], elapsed,
        len(raw_paddle), len(raw_tess), len(merged),
    )

    log_metrics = {
        "total_cells": len(cells),
        "non_empty_cells": word_result["summary"]["non_empty_cells"],
        "paddle_words": len(raw_paddle),
        "tesseract_words": len(raw_tess),
        "merged_words": len(merged),
        "ocr_engine": "kombi",
    }
    await _append_pipeline_log(
        session_id, "paddle_kombi", log_metrics, duration_ms=int(elapsed * 1000),
    )

    return {"session_id": session_id, **word_result}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/rapid-kombi")
async def rapid_kombi(session_id: str):
    """Run RapidOCR and Tesseract on the session image and merge the words.

    RapidOCR is run on a region spanning the whole image. The merged,
    de-duplicated words are turned into a grid of cells. The result is
    persisted on the session (together with the image bytes and
    ``current_step=8``), mirrored into the in-memory cache when the session
    is cached, and appended to the pipeline log.

    Raises:
        HTTPException: 400 when neither engine produced any word; errors
            from image loading propagate unchanged.
    """
    img_png, img_bgr = await _load_session_image(session_id)
    img_h, img_w = img_bgr.shape[:2]

    from cv_ocr_engines import ocr_region_rapid
    from cv_vocab_types import PageRegion

    started = time.time()

    whole_page = PageRegion(
        type="full_page", x=0, y=0, width=img_w, height=img_h,
    )
    raw_rapid = ocr_region_rapid(img_bgr, whole_page) or []
    raw_tess = _run_tesseract_words(img_bgr)

    split_rapid = _split_paddle_multi_words(raw_rapid)
    logger.info(
        "rapid_kombi: split %d rapid boxes -> %d individual words",
        len(raw_rapid), len(split_rapid),
    )

    if not (split_rapid or raw_tess):
        raise HTTPException(status_code=400, detail="Both OCR engines returned no words")

    merged = _deduplicate_words(_merge_paddle_tesseract(split_rapid, raw_tess))

    cells, columns_meta = build_grid_from_words(merged, img_w, img_h)
    elapsed = time.time() - started

    for cell in cells:
        cell["ocr_engine"] = "rapid_kombi"

    word_result = _build_kombi_word_result(
        cells, columns_meta, img_w, img_h, elapsed, "rapid_kombi",
        raw_rapid, split_rapid, raw_tess, merged,
        "raw_rapid_words", "raw_rapid_words_split",
    )

    await update_session_db(
        session_id, word_result=word_result, cropped_png=img_png, current_step=8,
    )
    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    shape = word_result["grid_shape"]
    logger.info(
        "rapid_kombi session %s: %d cells (%d rows, %d cols) in %.2fs "
        "[rapid=%d, tess=%d, merged=%d]",
        session_id, len(cells), shape["rows"], shape["cols"], elapsed,
        len(raw_rapid), len(raw_tess), len(merged),
    )

    log_metrics = {
        "total_cells": len(cells),
        "non_empty_cells": word_result["summary"]["non_empty_cells"],
        "rapid_words": len(raw_rapid),
        "tesseract_words": len(raw_tess),
        "merged_words": len(merged),
        "ocr_engine": "rapid_kombi",
    }
    await _append_pipeline_log(
        session_id, "rapid_kombi", log_metrics, duration_ms=int(elapsed * 1000),
    )

    return {"session_id": session_id, **word_result}
|
||||||
@@ -0,0 +1,188 @@
|
|||||||
|
"""
|
||||||
|
Orientation & Page-Split API endpoints (Steps 1 and 1b of OCR Pipeline).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
|
||||||
|
from cv_vocab_pipeline import detect_and_fix_orientation
|
||||||
|
from .page_crop import detect_page_splits
|
||||||
|
from .session_store import update_session_db
|
||||||
|
|
||||||
|
from .orientation_crop_helpers import ensure_cached, append_pipeline_log
|
||||||
|
from .page_sub_sessions import create_page_sub_sessions_full
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Step 1: Orientation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/orientation")
async def detect_orientation(session_id: str):
    """Detect and correct the page orientation of the original image.

    Runs ``detect_and_fix_orientation`` on a copy of the cached original
    image, stores the corrected image in the cache and on the session
    (``oriented_png``), sets ``current_step=2`` and appends an entry to the
    pipeline log.

    Returns:
        A dict with the session id, the detected rotation in degrees,
        whether a correction was applied, the duration, the dimensions of
        the oriented image and the URL of the oriented image.

    Raises:
        HTTPException: 400 when the original image is not available, 500
            when the oriented image cannot be encoded as PNG.
    """
    cached = await ensure_cached(session_id)

    img_bgr = cached.get("original_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Original image not available")

    t0 = time.time()

    # Work on a copy so the cached original stays untouched.
    oriented_bgr, orientation_deg = detect_and_fix_orientation(img_bgr.copy())

    duration = time.time() - t0

    # Encode first and fail loudly: persisting empty bytes would advance the
    # session to the next step with an unusable oriented image.
    success, png_buf = cv2.imencode(".png", oriented_bgr)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode oriented image")
    oriented_png = png_buf.tobytes()

    orientation_result = {
        "orientation_degrees": orientation_deg,
        "corrected": orientation_deg != 0,
        "duration_seconds": round(duration, 2),
    }

    # Update cache
    cached["oriented_bgr"] = oriented_bgr
    cached["orientation_result"] = orientation_result

    # Persist to DB
    await update_session_db(
        session_id,
        oriented_png=oriented_png,
        orientation_result=orientation_result,
        current_step=2,
    )

    logger.info(
        "OCR Pipeline: orientation session %s: %d° (%s) in %.2fs",
        session_id, orientation_deg,
        "corrected" if orientation_deg else "no change",
        duration,
    )

    await append_pipeline_log(session_id, "orientation", {
        "orientation_degrees": orientation_deg,
        "corrected": orientation_deg != 0,
    }, duration_ms=int(duration * 1000))

    h, w = oriented_bgr.shape[:2]
    return {
        "session_id": session_id,
        **orientation_result,
        "image_width": w,
        "image_height": h,
        "oriented_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/oriented",
    }
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Step 1b: Page-split detection — runs AFTER orientation, BEFORE deskew
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/page-split")
async def detect_page_split(session_id: str):
    """Detect a double-page spread and split it into per-page sub-sessions.

    Intended to run **after orientation** (step 1) and **before deskew**
    (step 2). Detection runs on the oriented image when present, otherwise
    on the original. If no spread is found on the oriented image, the
    original image is tried as well, because orientation correction may have
    rotated a landscape spread into portrait.

    When a spread is found, sub-sessions are created via
    ``create_page_sub_sessions_full`` and the parent session is stored with
    ``status='split'`` and the split info as ``crop_result``.

    Returns:
        ``{"session_id", "multi_page": False, "duration_seconds"}`` for a
        single page; otherwise the split info plus image dimensions and the
        created sub-sessions.

    Raises:
        HTTPException: 400 when the session has no usable image.
    """
    cached = await ensure_cached(session_id)

    # Prefer the oriented image, fall back to the original.
    img_bgr = cached.get("oriented_bgr")
    if img_bgr is None:
        img_bgr = cached.get("original_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="No image available for page-split detection")

    def _is_spread(splits) -> bool:
        # A spread needs at least two detected page regions.
        return bool(splits) and len(splits) >= 2

    t0 = time.time()
    page_splits = detect_page_splits(img_bgr)
    used_original = False

    if not _is_spread(page_splits):
        # Retry on the pre-orientation image, unless that is what we just used.
        orig_bgr = cached.get("original_bgr")
        if orig_bgr is not None and orig_bgr is not img_bgr:
            splits_on_original = detect_page_splits(orig_bgr)
            if _is_spread(splits_on_original):
                logger.info(
                    "OCR Pipeline: page-split session %s: spread detected on "
                    "ORIGINAL (orientation rotated it away)",
                    session_id,
                )
                img_bgr = orig_bgr
                page_splits = splits_on_original
                used_original = True

    if not _is_spread(page_splits):
        duration = time.time() - t0
        logger.info(
            "OCR Pipeline: page-split session %s: single page (%.2fs)",
            session_id, duration,
        )
        return {
            "session_id": session_id,
            "multi_page": False,
            "duration_seconds": round(duration, 2),
        }

    # Pages cut from the original still need orientation (start_step=1);
    # pages cut from the oriented image continue with deskew (start_step=2).
    start_step = 1 if used_original else 2
    sub_sessions = await create_page_sub_sessions_full(
        session_id, cached, img_bgr, page_splits, start_step=start_step,
    )
    duration = time.time() - t0
    page_count = len(page_splits)

    split_info: Dict[str, Any] = {
        "multi_page": True,
        "page_count": page_count,
        "page_splits": page_splits,
        "used_original": used_original,
        "duration_seconds": round(duration, 2),
    }

    # Record the split on the parent session and mark it as split.
    await update_session_db(session_id, crop_result=split_info, status='split')
    cached["crop_result"] = split_info

    await append_pipeline_log(
        session_id,
        "page_split",
        {"multi_page": True, "page_count": page_count},
        duration_ms=int(duration * 1000),
    )

    logger.info(
        "OCR Pipeline: page-split session %s: %d pages detected in %.2fs",
        session_id, page_count, duration,
    )

    height, width = img_bgr.shape[:2]
    return {
        "session_id": session_id,
        **split_info,
        "image_width": width,
        "image_height": height,
        "sub_sessions": sub_sessions,
    }
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
"""
|
||||||
|
Orientation & Crop API - Steps 1 and 4 of the OCR Pipeline.
|
||||||
|
|
||||||
|
Barrel re-export: merges routers from orientation_api and crop_api,
|
||||||
|
and re-exports set_cache_ref for main.py.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from .orientation_crop_helpers import set_cache_ref # noqa: F401
|
||||||
|
from .orientation_api import router as _orientation_router
|
||||||
|
from .crop_api import router as _crop_router
|
||||||
|
|
||||||
|
# Combined router: exposes the orientation and crop endpoints as one unit.
router = APIRouter()
for _sub_router in (_orientation_router, _crop_router):
    router.include_router(_sub_router)
|
||||||
@@ -0,0 +1,86 @@
|
|||||||
|
"""
|
||||||
|
Orientation & Crop shared helpers - cache management and pipeline logging.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from .session_store import (
|
||||||
|
get_session_db,
|
||||||
|
get_session_image,
|
||||||
|
update_session_db,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Reference to the shared cache from ocr_pipeline_api (set in main.py)
|
||||||
|
_cache: Dict[str, Dict[str, Any]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def set_cache_ref(cache: Dict[str, Dict[str, Any]]) -> None:
    """Rebind this module's ``_cache`` to the given shared cache dict.

    Args:
        cache: Mapping of session id to the cached session entry.
    """
    global _cache
    _cache = cache
|
||||||
|
|
||||||
|
|
||||||
|
def get_cache_ref() -> Dict[str, Dict[str, Any]]:
    """Return the cache dict this module currently points at."""
    return _cache
|
||||||
|
|
||||||
|
|
||||||
|
async def ensure_cached(session_id: str) -> Dict[str, Any]:
    """Return the cache entry for a session, building it from the DB if absent.

    A freshly built entry contains the session's DB fields plus one decoded
    image slot per pipeline stage (``original_bgr``, ``oriented_bgr``,
    ``cropped_bgr``, ``deskewed_bgr``, ``dewarped_bgr``); slots without a
    stored image stay ``None``.

    Raises:
        HTTPException: 404 when the session does not exist in the DB.
    """
    try:
        return _cache[session_id]
    except KeyError:
        pass  # not cached yet; build the entry below

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    stages = ("original", "oriented", "cropped", "deskewed", "dewarped")

    entry: Dict[str, Any] = {"id": session_id, **session}
    # Image slots always start empty, even if the DB row has such a key.
    entry.update({f"{stage}_bgr": None for stage in stages})

    for stage in stages:
        png_data = await get_session_image(session_id, stage)
        if png_data:
            buf = np.frombuffer(png_data, dtype=np.uint8)
            entry[f"{stage}_bgr"] = cv2.imdecode(buf, cv2.IMREAD_COLOR)

    _cache[session_id] = entry
    return entry
|
||||||
|
|
||||||
|
|
||||||
|
async def append_pipeline_log(session_id: str, step: str, metrics: dict, duration_ms: int):
    """Append a completed-step entry to the session's pipeline log.

    Does nothing when the session cannot be found.

    Args:
        session_id: Session whose log is extended.
        step: Name of the pipeline step that finished.
        metrics: Step-specific metrics stored with the entry.
        duration_ms: Step duration in milliseconds.
    """
    from datetime import datetime, timezone

    session = await get_session_db(session_id)
    if not session:
        return

    pipeline_log = session.get("pipeline_log") or {"steps": []}
    # ``datetime.utcnow()`` is deprecated; take an aware UTC time and drop the
    # tzinfo so the stored ISO string keeps its naive-UTC format.
    completed_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    # Tolerate a stored log that has no "steps" list yet.
    pipeline_log.setdefault("steps", []).append({
        "step": step,
        "completed_at": completed_at,
        "success": True,
        "duration_ms": duration_ms,
        "metrics": metrics,
    })
    await update_session_db(session_id, pipeline_log=pipeline_log)
|
||||||
@@ -0,0 +1,333 @@
|
|||||||
|
"""
|
||||||
|
Overlay rendering for columns, rows, and words (grid-based overlays).
|
||||||
|
|
||||||
|
Extracted from ocr_pipeline_overlays.py for modularity.
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from fastapi.responses import Response
|
||||||
|
|
||||||
|
from .common import _get_base_image_png
|
||||||
|
from .session_store import get_session_db
|
||||||
|
from .rows import _draw_box_exclusion_overlay
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_columns_overlay(session_id: str) -> Response:
    """Render the session's base image with the detected columns drawn on it.

    Each column gets a translucent fill (20% opacity), a solid border and a
    text label (with the classification confidence when it is below 100%).
    Zones of type ``box`` are outlined with a dashed rectangle and then
    passed to ``_draw_box_exclusion_overlay``.

    Returns:
        A PNG ``Response`` containing the annotated image.

    Raises:
        HTTPException: 404 when the session, its column data or its base
            image is missing; 500 when decoding or encoding fails.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    column_result = session.get("column_result")
    columns = column_result.get("columns") if column_result else None
    if not columns:
        raise HTTPException(status_code=404, detail="No column data available")

    base_png = await _get_base_image_png(session_id)
    if not base_png:
        raise HTTPException(status_code=404, detail="No base image available")

    img = cv2.imdecode(np.frombuffer(base_png, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    # Region type -> BGR drawing color.
    gray = (128, 128, 128)
    dark_gray = (100, 100, 100)
    palette = {
        "column_en": (255, 180, 0),
        "column_de": (0, 200, 0),
        "column_example": (0, 140, 255),
        "column_text": (200, 200, 0),
        "page_ref": (200, 0, 200),
        "column_marker": (0, 0, 220),
        "column_ignore": (180, 180, 180),
        "header": gray,
        "footer": gray,
        "margin_top": dark_gray,
        "margin_bottom": dark_gray,
    }
    fallback_color = (200, 200, 200)

    # Fills go onto a separate layer so they can be blended in afterwards.
    fill_layer = img.copy()
    for col in columns:
        x, y = col["x"], col["y"]
        w, h = col["width"], col["height"]
        color = palette.get(col.get("type", ""), fallback_color)
        corner_a, corner_b = (x, y), (x + w, y + h)

        cv2.rectangle(fill_layer, corner_a, corner_b, color, -1)
        cv2.rectangle(img, corner_a, corner_b, color, 3)

        label = col.get("type", "unknown").replace("column_", "").upper()
        conf = col.get("classification_confidence")
        if conf is not None and conf < 1.0:
            label = f"{label} {int(conf * 100)}%"
        cv2.putText(img, label, (x + 10, y + 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2)

    cv2.addWeighted(fill_layer, 0.2, img, 0.8, 0, img)

    dash_len = 15

    def _dashes(start: int, stop: int):
        # Yield (begin, end) of each dash: dash_len drawn, dash_len skipped.
        for begin in range(start, stop, dash_len * 2):
            yield begin, min(begin + dash_len, stop)

    zones = column_result.get("zones") or []
    for zone in zones:
        if zone.get("zone_type") != "box" or not zone.get("box"):
            continue
        box = zone["box"]
        bx, by = box["x"], box["y"]
        bw, bh = box["width"], box["height"]
        box_color = (0, 200, 255)
        right, bottom = bx + bw, by + bh

        for x0, x1 in _dashes(bx, right):
            cv2.line(img, (x0, by), (x1, by), box_color, 2)
            cv2.line(img, (x0, bottom), (x1, bottom), box_color, 2)
        for y0, y1 in _dashes(by, bottom):
            cv2.line(img, (bx, y0), (bx, y1), box_color, 2)
            cv2.line(img, (right, y0), (right, y1), box_color, 2)
        cv2.putText(img, "BOX", (bx + 10, bottom - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, box_color, 2)

    _draw_box_exclusion_overlay(img, zones)

    success, result_png = cv2.imencode(".png", img)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")

    return Response(content=result_png.tobytes(), media_type="image/png")
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_rows_overlay(session_id: str) -> Response:
    """Render the session's base image with the detected row bands on top.

    Each row from ``session["row_result"]["rows"]`` is drawn as a tinted band
    (15% opacity) with a solid border and an ``R<index> <TYPE>`` caption.
    Box zones from ``session["column_result"]["zones"]`` get dashed separator
    lines at their top and bottom, followed by the shared box-exclusion
    overlay.

    Args:
        session_id: Identifier of the OCR pipeline session.

    Returns:
        A ``Response`` carrying the annotated image as ``image/png``.

    Raises:
        HTTPException: 404 if the session, its row data, or a base image is
            missing; 500 if the image cannot be decoded or re-encoded.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    row_result = session.get("row_result")
    rows = row_result.get("rows") if row_result else None
    if not rows:
        raise HTTPException(status_code=404, detail="No row data available")

    # Base image lookup is delegated to the shared helper.
    base_png = await _get_base_image_png(session_id)
    if not base_png:
        raise HTTPException(status_code=404, detail="No base image available")

    canvas = cv2.imdecode(np.frombuffer(base_png, dtype=np.uint8), cv2.IMREAD_COLOR)
    if canvas is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    # Row-type palette (BGR).
    gray = (128, 128, 128)
    dark_gray = (100, 100, 100)
    palette = {
        "content": (255, 180, 0),  # Blue
        "header": gray,
        "footer": gray,
        "margin_top": dark_gray,
        "margin_bottom": dark_gray,
    }
    fallback_color = (200, 200, 200)

    # Fills go onto a separate layer so they can be alpha-blended afterwards;
    # borders and captions are drawn directly onto the canvas.
    fill_layer = canvas.copy()
    for row in rows:
        left, top = row["x"], row["y"]
        right, bottom = left + row["width"], top + row["height"]
        kind = row.get("row_type", "content")
        tint = palette.get(kind, fallback_color)

        cv2.rectangle(fill_layer, (left, top), (right, bottom), tint, -1)
        cv2.rectangle(canvas, (left, top), (right, bottom), tint, 2)

        caption = f"R{row.get('index', 0)} {kind.upper()}"
        word_count = row.get("word_count", 0)
        if word_count:
            caption += f" ({word_count}w)"
        cv2.putText(canvas, caption, (left + 5, top + 18),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, tint, 1)

    # Blend the fill layer in at 15% opacity.
    cv2.addWeighted(fill_layer, 0.15, canvas, 0.85, 0, canvas)

    # Dashed full-width separators above and below every box zone.
    zones = (session.get("column_result") or {}).get("zones") or []
    if zones:
        canvas_width = canvas.shape[1]
        separator_color = (0, 200, 255)  # Yellow (BGR)
        dash = 20
        dashes = [
            (start, min(start + dash, canvas_width))
            for start in range(0, canvas_width, dash * 2)
        ]
        for zone in zones:
            if zone.get("zone_type") != "box":
                continue
            zone_top = zone["y"]
            zone_bottom = zone_top + zone["height"]
            for line_y in (zone_top, zone_bottom):
                for start, stop in dashes:
                    cv2.line(canvas, (start, line_y), (stop, line_y), separator_color, 2)

    # Red semi-transparent overlay for box zones.
    _draw_box_exclusion_overlay(canvas, zones)

    encoded_ok, encoded = cv2.imencode(".png", canvas)
    if not encoded_ok:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")

    return Response(content=encoded.tobytes(), media_type="image/png")
|
||||||
|
|
||||||
|
|
||||||
|
def _confidence_text_color(conf):
    """Return the BGR text color for a confidence value (>=70 green, >=50 amber, else red)."""
    if conf >= 70:
        return (0, 180, 0)
    if conf >= 50:
        return (0, 180, 220)
    return (0, 0, 220)


async def _get_words_overlay(session_id: str) -> Response:
    """Generate cropped (or dewarped) image with the cell grid drawn on it.

    Two data layouts in ``session["word_result"]`` are supported:

    * ``cells``: every cell with a positive-size ``bbox_px`` is outlined and
      tinted by ``col_index``; its ``cell_id`` is written in the top-left
      corner and the recognised text along the bottom edge.
    * ``entries`` (legacy sessions): the grid is rebuilt from the typed
      columns in ``column_result`` and the content rows in ``row_result``;
      each entry's ``english`` / ``german`` / ``example`` text is written
      into the matching column.

    Text is colored by confidence via ``_confidence_text_color``. Fills are
    blended in at 10% opacity, and the shared box-exclusion overlay is applied
    last.

    Args:
        session_id: Identifier of the OCR pipeline session.

    Returns:
        A ``Response`` carrying the annotated image as ``image/png``.

    Raises:
        HTTPException: 404 if the session, word data, or base image is
            missing; 500 if the image cannot be decoded or re-encoded.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=404, detail="No word data available")

    # Support both new cell-based and legacy entry-based formats
    cells = word_result.get("cells")
    if not cells and not word_result.get("entries"):
        raise HTTPException(status_code=404, detail="No word data available")

    # Base image lookup is delegated to the shared helper.
    base_png = await _get_base_image_png(session_id)
    if not base_png:
        raise HTTPException(status_code=404, detail="No base image available")

    arr = np.frombuffer(base_png, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    # Fills are painted on a copy and alpha-blended into ``img`` at the end.
    overlay = img.copy()

    if cells:
        # New cell-based overlay: color by column index
        col_palette = [
            (255, 180, 0),    # Blue (BGR)
            (0, 200, 0),      # Green
            (0, 140, 255),    # Orange
            (200, 100, 200),  # Purple
            (200, 200, 0),    # Cyan
            (100, 200, 200),  # Yellow-ish
        ]

        for cell in cells:
            bbox = cell.get("bbox_px", {})
            cx = bbox.get("x", 0)
            cy = bbox.get("y", 0)
            cw = bbox.get("w", 0)
            ch = bbox.get("h", 0)
            if cw <= 0 or ch <= 0:
                # Nothing to draw for cells without a usable bounding box.
                continue

            col_idx = cell.get("col_index", 0)
            color = col_palette[col_idx % len(col_palette)]

            # Cell rectangle border
            cv2.rectangle(img, (cx, cy), (cx + cw, cy + ch), color, 1)
            # Semi-transparent fill
            cv2.rectangle(overlay, (cx, cy), (cx + cw, cy + ch), color, -1)

            # Cell-ID label (top-left corner)
            cell_id = cell.get("cell_id", "")
            cv2.putText(img, cell_id, (cx + 2, cy + 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.28, color, 1)

            # Text label (bottom of cell), truncated to 30 characters
            text = cell.get("text", "")
            if text:
                text_color = _confidence_text_color(cell.get("confidence", 0))
                label = text.replace('\n', ' ')[:30]
                cv2.putText(img, label, (cx + 3, cy + ch - 4),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)
    else:
        # Legacy fallback: entry-based overlay (for old sessions)
        column_result = session.get("column_result")
        row_result = session.get("row_result")
        col_colors = {
            "column_en": (255, 180, 0),
            "column_de": (0, 200, 0),
            "column_example": (0, 140, 255),
        }
        # Entry field holding the text for each typed column.
        field_by_col_type = {
            "column_en": "english",
            "column_de": "german",
            "column_example": "example",
        }

        columns = []
        if column_result and column_result.get("columns"):
            columns = [c for c in column_result["columns"]
                       if c.get("type", "").startswith("column_")]

        content_rows_data = []
        if row_result and row_result.get("rows"):
            content_rows_data = [r for r in row_result["rows"]
                                 if r.get("row_type") == "content"]

        # Grid: one rectangle per (typed column, content row) pair.
        for col in columns:
            col_type = col.get("type", "")
            color = col_colors.get(col_type, (200, 200, 200))
            cx, cw = col["x"], col["width"]
            for row in content_rows_data:
                ry, rh = row["y"], row["height"]
                cv2.rectangle(img, (cx, ry), (cx + cw, ry + rh), color, 1)
                cv2.rectangle(overlay, (cx, ry), (cx + cw, ry + rh), color, -1)

        # Entries are matched to content rows by position in the content-row list.
        entries = word_result["entries"]
        entry_by_row: Dict[int, Dict] = {}
        for entry in entries:
            entry_by_row[entry.get("row_index", -1)] = entry

        for row_idx, row in enumerate(content_rows_data):
            entry = entry_by_row.get(row_idx)
            if not entry:
                continue
            text_color = _confidence_text_color(entry.get("confidence", 0))
            ry, rh = row["y"], row["height"]
            for col in columns:
                col_type = col.get("type", "")
                cx, cw = col["x"], col["width"]
                field = field_by_col_type.get(col_type, "")
                text = entry.get(field, "") if field else ""
                if text:
                    label = text.replace('\n', ' ')[:30]
                    cv2.putText(img, label, (cx + 3, ry + rh - 4),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)

    # Blend overlay at 10% opacity
    cv2.addWeighted(overlay, 0.1, img, 0.9, 0, img)

    # Red semi-transparent overlay for box zones
    column_result = session.get("column_result") or {}
    zones = column_result.get("zones") or []
    _draw_box_exclusion_overlay(img, zones)

    success, result_png = cv2.imencode(".png", img)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")

    return Response(content=result_png.tobytes(), media_type="image/png")
|
||||||
@@ -0,0 +1,205 @@
|
|||||||
|
"""
|
||||||
|
Overlay rendering for structure detection (boxes, zones, colors, graphics).
|
||||||
|
|
||||||
|
Extracted from ocr_pipeline_overlays.py for modularity.
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from fastapi.responses import Response
|
||||||
|
|
||||||
|
from .common import _get_base_image_png
|
||||||
|
from .session_store import get_session_db
|
||||||
|
from cv_color_detect import _COLOR_HEX, _COLOR_RANGES
|
||||||
|
from cv_box_detect import detect_boxes, split_page_into_zones
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_structure_overlay(session_id: str) -> Response:
    """Generate overlay image showing detected boxes, zones, and color regions.

    The structure data is read from ``session["structure_result"]``. When it
    is missing, box and zone detection runs on the fly inside a content area
    inset by 3% of the shorter image side. The overlay then draws, in order:
    dashed top/bottom zone boundaries with labels, filled and bordered boxes
    with a caption, contours of every HSV color range in ``_COLOR_RANGES``,
    and dashed bounding boxes with captions for graphic elements.

    Args:
        session_id: Identifier of the OCR pipeline session.

    Returns:
        A ``Response`` carrying the annotated image as ``image/png``.

    Raises:
        HTTPException: 404 if no base image is available; 500 if the image
            cannot be decoded or the result cannot be PNG-encoded.
    """
    base_png = await _get_base_image_png(session_id)
    if not base_png:
        raise HTTPException(status_code=404, detail="No base image available")

    arr = np.frombuffer(base_png, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    h, w = img.shape[:2]

    # Take the HSV view now, while ``img`` is still the undrawn scan, so the
    # color masks below are not polluted by annotations (and the PNG does not
    # have to be decoded a second time).
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)

    # Get structure result (run detection if not cached)
    session = await get_session_db(session_id)
    structure = (session or {}).get("structure_result")

    if not structure:
        # Run detection on-the-fly
        margin = int(min(w, h) * 0.03)
        content_x, content_y = margin, margin
        content_w_px = w - 2 * margin
        content_h_px = h - 2 * margin
        boxes = detect_boxes(img, content_x, content_w_px, content_y, content_h_px)
        zones = split_page_into_zones(content_x, content_y, content_w_px, content_h_px, boxes)
        structure = {
            "boxes": [
                {"x": b.x, "y": b.y, "w": b.width, "h": b.height,
                 "confidence": b.confidence, "border_thickness": b.border_thickness}
                for b in boxes
            ],
            "zones": [
                {"index": z.index, "zone_type": z.zone_type,
                 "y": z.y, "h": z.height, "x": z.x, "w": z.width}
                for z in zones
            ],
        }

    # Box fills go onto a copy that is alpha-blended into ``img`` later.
    overlay = img.copy()

    # --- Draw zone boundaries ---
    zone_colors = {
        "content": (200, 200, 200),  # light gray
        "box": (255, 180, 0),        # blue-ish (BGR)
    }
    for zone in structure.get("zones", []):
        zx = zone["x"]
        zy = zone["y"]
        zw = zone["w"]
        zh = zone["h"]
        color = zone_colors.get(zone["zone_type"], (200, 200, 200))

        # Dashed top and bottom edges (short segments every second dash length)
        dash_len = 12
        for edge_x in range(zx, zx + zw, dash_len * 2):
            end_x = min(edge_x + dash_len, zx + zw)
            cv2.line(img, (edge_x, zy), (end_x, zy), color, 1)
            cv2.line(img, (edge_x, zy + zh), (end_x, zy + zh), color, 1)

        # Zone label
        zone_label = f"Zone {zone['index']} ({zone['zone_type']})"
        cv2.putText(img, zone_label, (zx + 5, zy + 15),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.45, color, 1)

    # --- Draw detected boxes ---
    # Color map for box backgrounds (hex -> BGR)
    bg_hex_to_bgr = {
        "#dc2626": (38, 38, 220),    # red
        "#2563eb": (235, 99, 37),    # blue
        "#16a34a": (74, 163, 22),    # green
        "#ea580c": (12, 88, 234),    # orange
        "#9333ea": (234, 51, 147),   # purple
        "#ca8a04": (4, 138, 202),    # yellow
        "#6b7280": (128, 114, 107),  # gray
    }

    for box_data in structure.get("boxes", []):
        bx = box_data["x"]
        by = box_data["y"]
        bw = box_data["w"]
        bh = box_data["h"]
        conf = box_data.get("confidence", 0)
        thickness = box_data.get("border_thickness", 0)
        bg_hex = box_data.get("bg_color_hex", "#6b7280")
        bg_name = box_data.get("bg_color_name", "")

        # Box fill color (gray when the hex value is not in the map)
        fill_bgr = bg_hex_to_bgr.get(bg_hex, (128, 114, 107))

        # Semi-transparent fill
        cv2.rectangle(overlay, (bx, by), (bx + bw, by + bh), fill_bgr, -1)

        # Solid border
        border_color = fill_bgr
        cv2.rectangle(img, (bx, by), (bx + bw, by + bh), border_color, 3)

        # Label
        label = "BOX"
        if bg_name and bg_name not in ("unknown", "white"):
            label += f" ({bg_name})"
        if thickness > 0:
            label += f" border={thickness}px"
        label += f" {int(conf * 100)}%"
        # Thick white pass first, then the colored text on top, for legibility.
        cv2.putText(img, label, (bx + 8, by + 22),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 2)
        cv2.putText(img, label, (bx + 8, by + 22),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.55, border_color, 1)

    # Blend overlay at 15% opacity
    cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img)

    # --- Draw color regions (HSV masks) ---
    color_bgr_map = {
        "red": (0, 0, 255),
        "orange": (0, 140, 255),
        "yellow": (0, 200, 255),
        "green": (0, 200, 0),
        "blue": (255, 150, 0),
        "purple": (200, 0, 200),
    }
    for color_name, ranges in _COLOR_RANGES.items():
        mask = np.zeros((h, w), dtype=np.uint8)
        for lower, upper in ranges:
            mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper))
        # Only draw if there are significant colored pixels
        if np.sum(mask > 0) < 100:
            continue
        # Draw colored contours, ignoring specks below 20 px of area
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        draw_color = color_bgr_map.get(color_name, (200, 200, 200))
        for cnt in contours:
            area = cv2.contourArea(cnt)
            if area < 20:
                continue
            cv2.drawContours(img, [cnt], -1, draw_color, 2)

    # --- Draw graphic elements ---
    graphics_data = structure.get("graphics", [])
    shape_icons = {
        "image": "IMAGE",
        "illustration": "ILLUST",
    }
    for gfx in graphics_data:
        gx, gy = gfx["x"], gfx["y"]
        gw, gh = gfx["w"], gfx["h"]
        shape = gfx.get("shape", "icon")
        color_hex = gfx.get("color_hex", "#6b7280")
        conf = gfx.get("confidence", 0)

        # Pick draw color based on element color (BGR)
        gfx_bgr = bg_hex_to_bgr.get(color_hex, (128, 114, 107))

        # Draw bounding box (dashed style via short segments)
        dash = 6
        for seg_x in range(gx, gx + gw, dash * 2):
            end_x = min(seg_x + dash, gx + gw)
            cv2.line(img, (seg_x, gy), (end_x, gy), gfx_bgr, 2)
            cv2.line(img, (seg_x, gy + gh), (end_x, gy + gh), gfx_bgr, 2)
        for seg_y in range(gy, gy + gh, dash * 2):
            end_y = min(seg_y + dash, gy + gh)
            cv2.line(img, (gx, seg_y), (gx, end_y), gfx_bgr, 2)
            cv2.line(img, (gx + gw, seg_y), (gx + gw, end_y), gfx_bgr, 2)

        # Label: known shapes get a fixed tag, others the first 5 letters upper-cased
        icon = shape_icons.get(shape, shape.upper()[:5])
        label = f"{icon} {int(conf * 100)}%"
        # White background for readability; keep the label inside the image
        # when the element touches the top edge.
        (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)
        lx = gx + 2
        ly = max(gy - 4, th + 4)
        cv2.rectangle(img, (lx - 1, ly - th - 2), (lx + tw + 2, ly + 3), (255, 255, 255), -1)
        cv2.putText(img, label, (lx, ly), cv2.FONT_HERSHEY_SIMPLEX, 0.4, gfx_bgr, 1)

    # Encode result; surface encoder failures as a 500 instead of crashing on
    # a missing buffer.
    success, png_buf = cv2.imencode(".png", img)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")
    return Response(content=png_buf.tobytes(), media_type="image/png")
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
"""
|
||||||
|
Overlay image rendering for OCR pipeline — barrel re-export.
|
||||||
|
|
||||||
|
All implementation split into:
|
||||||
|
overlay_structure — structure overlay (boxes, zones, colors, graphics)
|
||||||
|
overlay_grid — columns, rows, words overlays
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from fastapi.responses import Response
|
||||||
|
|
||||||
|
from .overlay_structure import _get_structure_overlay # noqa: F401
|
||||||
|
from .overlay_grid import ( # noqa: F401
|
||||||
|
_get_columns_overlay,
|
||||||
|
_get_rows_overlay,
|
||||||
|
_get_words_overlay,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def render_overlay(overlay_type: str, session_id: str) -> Response:
    """Route an overlay request to the renderer registered for its type.

    Args:
        overlay_type: One of ``"structure"``, ``"columns"``, ``"rows"`` or
            ``"words"``.
        session_id: Identifier of the OCR pipeline session to render.

    Returns:
        The ``Response`` produced by the selected renderer.

    Raises:
        HTTPException: 400 when ``overlay_type`` is not a known overlay.
    """
    renderers = {
        "structure": _get_structure_overlay,
        "columns": _get_columns_overlay,
        "rows": _get_rows_overlay,
        "words": _get_words_overlay,
    }
    renderer = renderers.get(overlay_type)
    if renderer is None:
        raise HTTPException(status_code=400, detail=f"Unknown overlay type: {overlay_type}")
    return await renderer(session_id)
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
"""
|
||||||
|
Page Crop — Barrel Re-export
|
||||||
|
|
||||||
|
Content-based crop for scanned pages and book scans.
|
||||||
|
|
||||||
|
Split into:
|
||||||
|
- page_crop_edges.py — Edge detection (spine shadow, gutter, projection)
|
||||||
|
- page_crop_core.py — Main crop algorithm and format detection
|
||||||
|
|
||||||
|
All public names are re-exported here for backward compatibility.
|
||||||
|
License: Apache 2.0
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Core: main crop functions and format detection
|
||||||
|
from .page_crop_core import ( # noqa: F401
|
||||||
|
PAPER_FORMATS,
|
||||||
|
detect_page_splits,
|
||||||
|
detect_and_crop_page,
|
||||||
|
_detect_format,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Edge detection helpers
|
||||||
|
from .page_crop_edges import ( # noqa: F401
|
||||||
|
_INK_THRESHOLD,
|
||||||
|
_MIN_RUN_FRAC,
|
||||||
|
_detect_spine_shadow,
|
||||||
|
_detect_gutter_continuity,
|
||||||
|
_detect_left_edge_shadow,
|
||||||
|
_detect_right_edge_shadow,
|
||||||
|
_detect_top_bottom_edges,
|
||||||
|
_detect_edge_projection,
|
||||||
|
_filter_narrow_runs,
|
||||||
|
)
|
||||||
@@ -0,0 +1,342 @@
|
|||||||
|
"""
|
||||||
|
Page Crop - Core Crop and Format Detection
|
||||||
|
|
||||||
|
Content-based crop for scanned pages and book scans. Detects the content
|
||||||
|
boundary by analysing ink density projections and (for book scans) the
|
||||||
|
spine shadow gradient.
|
||||||
|
|
||||||
|
Extracted from page_crop.py to keep files under 500 LOC.
|
||||||
|
License: Apache 2.0
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Dict, Any, Tuple
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .page_crop_edges import (
|
||||||
|
_detect_left_edge_shadow,
|
||||||
|
_detect_right_edge_shadow,
|
||||||
|
_detect_top_bottom_edges,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Known paper format aspect ratios (height / width, portrait orientation).
# NOTE: A3, A4 and A5 all sit at roughly 1.414-1.419, so a comparison on
# aspect ratio alone can barely tell them apart; only Letter and Legal are
# clearly separated from the A-series.
PAPER_FORMATS = {
    "A4": 297.0 / 210.0,       # 1.4143
    "A5": 210.0 / 148.0,       # 1.4189
    "Letter": 11.0 / 8.5,      # 1.2941
    "Legal": 14.0 / 8.5,       # 1.6471
    "A3": 420.0 / 297.0,       # 1.4141
}
|
||||||
|
|
||||||
|
|
||||||
|
def detect_page_splits(
    img_bgr: np.ndarray,
) -> list:
    """Find the spine of a two-page spread and return one rectangle per page.

    The search works on column-mean **brightness** rather than ink density:
    a spine shows up as a darker vertical strip between two bright paper
    areas. Only images at least 1.15x wider than tall are examined, and the
    spine is searched for in the central 30-70% of the width.

    Args:
        img_bgr: BGR image of the scan.

    Returns:
        A list of page dicts ``{x, y, width, height, page_index}``, or an
        empty list when no split is detected.
    """
    height, width = img_bgr.shape[:2]

    # Portrait-ish scans are treated as single pages.
    if width < height * 1.15:
        return []

    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

    # Mean brightness (0-255) of every pixel column.
    column_means = np.mean(gray, axis=0).astype(np.float64)

    # Wide odd-sized box filter so individual text lines do not register.
    window = max(11, width // 50)
    if window % 2 == 0:
        window += 1
    smoothed = np.convolve(column_means, np.ones(window) / window, mode="same")

    # The brightest smoothed column is taken as the paper level.
    paper_level = float(np.max(smoothed))
    if paper_level < 100:
        return []  # Very dark image, skip

    # Anything below 88% of the paper level counts as "dark".
    spine_thresh = paper_level * 0.88

    # Central search window (30-70% of width).
    window_lo = int(width * 0.30)
    window_hi = int(width * 0.70)
    central = smoothed[window_lo:window_hi]
    darkest_val = float(np.min(central))

    if darkest_val >= spine_thresh:
        logger.debug("No spine detected: min brightness %.0f >= threshold %.0f",
                     darkest_val, spine_thresh)
        return []

    # Collect every contiguous dark run as a half-open (start, end) pair.
    dark_mask = central < spine_thresh
    dark_runs: list = []
    open_at = -1
    for pos, is_dark in enumerate(dark_mask):
        if is_dark:
            if open_at < 0:
                open_at = pos
        elif open_at >= 0:
            dark_runs.append((open_at, pos))
            open_at = -1
    if open_at >= 0:
        dark_runs.append((open_at, len(dark_mask)))

    # Drop runs narrower than 1% of the image width.
    min_spine_px = int(width * 0.01)
    dark_runs = [run for run in dark_runs if run[1] - run[0] >= min_spine_px]

    if not dark_runs:
        logger.debug("No dark runs wider than %dpx in center region", min_spine_px)
        return []

    # Score each run: Gaussian preference for the image centre, times how far
    # below the threshold it is, times a bonus for being narrow.
    window_len = window_hi - window_lo
    target = (width * 0.5 - window_lo)
    sigma = window_len * 0.15
    best_score = -1.0
    best_start, best_end = dark_runs[0]

    for run_lo, run_hi in dark_runs:
        run_width = run_hi - run_lo
        run_mid = (run_lo + run_hi) / 2.0

        offset = abs(run_mid - target)
        center_factor = float(np.exp(-0.5 * (offset / sigma) ** 2))

        run_level = float(np.mean(central[run_lo:run_hi]))
        darkness_factor = max(0.0, (spine_thresh - run_level) / spine_thresh)

        # Full bonus up to 5% of the width, fading linearly to zero at 15%.
        width_frac = run_width / width
        if width_frac > 0.15:
            narrowness_bonus = 0.0
        elif width_frac > 0.05:
            narrowness_bonus = 1.0 - (width_frac - 0.05) / 0.10
        else:
            narrowness_bonus = 1.0

        score = center_factor * darkness_factor * (0.3 + 0.7 * narrowness_bonus)

        logger.debug(
            "Dark run x=%d..%d (w=%d): center_f=%.3f dark_f=%.3f narrow_b=%.3f -> score=%.4f",
            window_lo + run_lo, window_lo + run_hi, run_width,
            center_factor, darkness_factor, narrowness_bonus, score,
        )

        if score > best_score:
            best_score = score
            best_start, best_end = run_lo, run_hi

    # Translate the winning run back into full-image coordinates.
    spine_w = best_end - best_start
    spine_x = window_lo + best_start
    spine_center = spine_x + spine_w // 2

    logger.debug(
        "Best spine candidate: x=%d..%d (w=%d), score=%.4f",
        spine_x, spine_x + spine_w, spine_w, best_score,
    )

    # A real spine has bright paper on BOTH sides (checked over width/10 columns).
    flank = width // 10
    left_brightness = float(np.mean(smoothed[max(0, spine_x - flank):spine_x]))
    right_end = window_lo + best_end
    right_brightness = float(np.mean(smoothed[right_end:min(width, right_end + flank)]))

    if left_brightness < spine_thresh or right_brightness < spine_thresh:
        logger.debug("No bright paper flanking spine: left=%.0f right=%.0f thresh=%.0f",
                     left_brightness, right_brightness, spine_thresh)
        return []

    logger.info(
        "Spine detected: x=%d..%d (w=%d), brightness=%.0f vs paper=%.0f, "
        "left_paper=%.0f, right_paper=%.0f",
        spine_x, right_end, spine_w, darkest_val, paper_level,
        left_brightness, right_brightness,
    )

    # Split at the spine center; consecutive boundaries delimit the pages.
    split_points = [spine_center]
    boundaries = [0, *split_points, width]
    pages: list = [
        {"x": left, "y": 0, "width": right - left,
         "height": height, "page_index": index}
        for index, (left, right) in enumerate(zip(boundaries, boundaries[1:]))
    ]

    # Discard slivers (< 15% of total width); a split needs two real pages.
    pages = [page for page in pages if page["width"] >= width * 0.15]
    if len(pages) < 2:
        return []

    # Re-index after filtering.
    for index, page in enumerate(pages):
        page["page_index"] = index

    logger.info(
        "Page split detected: %d pages, spine_w=%d, split_points=%s",
        len(pages), spine_w, split_points,
    )
    return pages
|
||||||
|
|
||||||
|
|
||||||
|
def detect_and_crop_page(
    img_bgr: np.ndarray,
    margin_frac: float = 0.01,
) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Detect content boundary and crop scanner/book borders.

    Algorithm (4-edge detection):
    1. Adaptive threshold -> binary (text=255, bg=0)
    2. Left edge: ``_detect_left_edge_shadow`` on the grayscale + binary image
    3. Right edge: ``_detect_right_edge_shadow`` on the grayscale + binary image
    4. Top/bottom edges: ``_detect_top_bottom_edges`` on the binary image
    5. Sanity checks, then crop with configurable margin

    The image is returned uncropped (``crop_applied`` stays ``False``) when
    every border is below 2% of its dimension, or when the crop rectangle is
    degenerate or covers less than 40% of the original area.

    Args:
        img_bgr: Input BGR image (should already be deskewed/dewarped)
        margin_frac: Extra margin around content (fraction of dimension, default 1%)

    Returns:
        Tuple of (cropped_image, result_dict). ``result_dict`` holds
        ``crop_applied``, ``crop_rect``, ``crop_rect_pct``, ``original_size``,
        ``cropped_size``, ``detected_format``, ``format_confidence``,
        ``aspect_ratio`` and ``border_fractions``.
    """
    h, w = img_bgr.shape[:2]
    total_area = h * w

    # Defaults describe the "no crop" outcome; overwritten on success.
    result: Dict[str, Any] = {
        "crop_applied": False,
        "crop_rect": None,
        "crop_rect_pct": None,
        "original_size": {"width": w, "height": h},
        "cropped_size": {"width": w, "height": h},
        "detected_format": None,
        "format_confidence": 0.0,
        "aspect_ratio": round(max(h, w) / max(min(h, w), 1), 4),
        "border_fractions": {"top": 0.0, "bottom": 0.0, "left": 0.0, "right": 0.0},
    }

    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

    # --- Binarise with adaptive threshold (inverted: ink becomes 255) ---
    binary = cv2.adaptiveThreshold(
        gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV, blockSize=51, C=15,
    )

    # --- Edge detection ---
    left_edge = _detect_left_edge_shadow(gray, binary, w, h)
    right_edge = _detect_right_edge_shadow(gray, binary, w, h)
    top_edge, bottom_edge = _detect_top_bottom_edges(binary, w, h)

    # Compute border fractions
    border_top = top_edge / h
    border_bottom = (h - bottom_edge) / h
    border_left = left_edge / w
    border_right = (w - right_edge) / w

    result["border_fractions"] = {
        "top": round(border_top, 4),
        "bottom": round(border_bottom, 4),
        "left": round(border_left, 4),
        "right": round(border_right, 4),
    }

    # Sanity: only crop if at least one edge has > 2% border
    min_border = 0.02
    if all(f < min_border for f in [border_top, border_bottom, border_left, border_right]):
        logger.info("All borders < %.0f%% — no crop needed", min_border * 100)
        result["detected_format"], result["format_confidence"] = _detect_format(w, h)
        return img_bgr, result

    # Add margin, clamped to the image bounds
    margin_x = int(w * margin_frac)
    margin_y = int(h * margin_frac)

    crop_x = max(0, left_edge - margin_x)
    crop_y = max(0, top_edge - margin_y)
    crop_x2 = min(w, right_edge + margin_x)
    crop_y2 = min(h, bottom_edge + margin_y)

    crop_w = crop_x2 - crop_x
    crop_h = crop_y2 - crop_y

    # Sanity: cropped area must be >= 40% of original. Negative dimensions
    # (edges detected in inverted order) are clamped to zero first, otherwise
    # two negative sides would multiply into a positive "area" and let an
    # empty crop through.
    crop_area = max(crop_w, 0) * max(crop_h, 0)
    if crop_area < 0.40 * total_area:
        logger.warning("Cropped area too small (%.0f%%) — skipping crop",
                       100.0 * crop_area / total_area)
        result["detected_format"], result["format_confidence"] = _detect_format(w, h)
        return img_bgr, result

    # Copy so the result does not alias the caller's array.
    cropped = img_bgr[crop_y:crop_y2, crop_x:crop_x2].copy()

    detected_format, format_confidence = _detect_format(crop_w, crop_h)

    result["crop_applied"] = True
    result["crop_rect"] = {"x": crop_x, "y": crop_y, "width": crop_w, "height": crop_h}
    result["crop_rect_pct"] = {
        "x": round(100.0 * crop_x / w, 2),
        "y": round(100.0 * crop_y / h, 2),
        "width": round(100.0 * crop_w / w, 2),
        "height": round(100.0 * crop_h / h, 2),
    }
    result["cropped_size"] = {"width": crop_w, "height": crop_h}
    result["detected_format"] = detected_format
    result["format_confidence"] = format_confidence
    result["aspect_ratio"] = round(max(crop_w, crop_h) / max(min(crop_w, crop_h), 1), 4)

    logger.info(
        "Page cropped: %dx%d -> %dx%d, format=%s (%.0f%%), "
        "borders: T=%.1f%% B=%.1f%% L=%.1f%% R=%.1f%%",
        w, h, crop_w, crop_h, detected_format, format_confidence * 100,
        border_top * 100, border_bottom * 100,
        border_left * 100, border_right * 100,
    )

    return cropped, result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Format detection (kept as optional metadata)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _detect_format(width: int, height: int) -> Tuple[str, float]:
|
||||||
|
"""Detect paper format from dimensions by comparing aspect ratios."""
|
||||||
|
if width <= 0 or height <= 0:
|
||||||
|
return "unknown", 0.0
|
||||||
|
|
||||||
|
aspect = max(width, height) / min(width, height)
|
||||||
|
|
||||||
|
best_format = "unknown"
|
||||||
|
best_diff = float("inf")
|
||||||
|
|
||||||
|
for fmt, expected_ratio in PAPER_FORMATS.items():
|
||||||
|
diff = abs(aspect - expected_ratio)
|
||||||
|
if diff < best_diff:
|
||||||
|
best_diff = diff
|
||||||
|
best_format = fmt
|
||||||
|
|
||||||
|
confidence = max(0.0, 1.0 - best_diff * 5.0)
|
||||||
|
|
||||||
|
if confidence < 0.3:
|
||||||
|
return "unknown", 0.0
|
||||||
|
|
||||||
|
return best_format, round(confidence, 3)
|
||||||
@@ -0,0 +1,388 @@
|
|||||||
|
"""
|
||||||
|
Page Crop - Edge Detection Helpers
|
||||||
|
|
||||||
|
Spine shadow detection, gutter continuity analysis, projection-based
|
||||||
|
edge detection, and narrow-run filtering for content cropping.
|
||||||
|
|
||||||
|
Extracted from page_crop.py to keep files under 500 LOC.
|
||||||
|
License: Apache 2.0
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Minimum ink density (fraction of pixels) to count a row/column as "content"
|
||||||
|
_INK_THRESHOLD = 0.003 # 0.3%
|
||||||
|
|
||||||
|
# Minimum run length (fraction of dimension) to keep — shorter runs are noise
|
||||||
|
_MIN_RUN_FRAC = 0.005 # 0.5%
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_spine_shadow(
|
||||||
|
gray: np.ndarray,
|
||||||
|
search_region: np.ndarray,
|
||||||
|
offset_x: int,
|
||||||
|
w: int,
|
||||||
|
side: str,
|
||||||
|
) -> Optional[int]:
|
||||||
|
"""Find the book spine center (darkest point) in a scanner shadow.
|
||||||
|
|
||||||
|
The scanner produces a gray strip where the book spine presses against
|
||||||
|
the glass. The darkest column in that strip is the spine center —
|
||||||
|
that's where we crop.
|
||||||
|
|
||||||
|
Distinguishes real spine shadows from text content by checking:
|
||||||
|
1. Strong brightness range (> 40 levels)
|
||||||
|
2. Darkest point is genuinely dark (< 180 mean brightness)
|
||||||
|
3. The dark area is a NARROW valley, not a text-content plateau
|
||||||
|
4. Brightness rises significantly toward the page content side
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gray: Full grayscale image (for context).
|
||||||
|
search_region: Column slice of the grayscale image to search in.
|
||||||
|
offset_x: X offset of search_region relative to full image.
|
||||||
|
w: Full image width.
|
||||||
|
side: 'left' or 'right' (for logging).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
X coordinate (in full image) of the spine center, or None.
|
||||||
|
"""
|
||||||
|
region_w = search_region.shape[1]
|
||||||
|
if region_w < 10:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Column-mean brightness in the search region
|
||||||
|
col_means = np.mean(search_region, axis=0).astype(np.float64)
|
||||||
|
|
||||||
|
# Smooth with boxcar kernel (width = 1% of image width, min 5)
|
||||||
|
kernel_size = max(5, w // 100)
|
||||||
|
if kernel_size % 2 == 0:
|
||||||
|
kernel_size += 1
|
||||||
|
kernel = np.ones(kernel_size) / kernel_size
|
||||||
|
smoothed_raw = np.convolve(col_means, kernel, mode="same")
|
||||||
|
|
||||||
|
# Trim convolution edge artifacts (edges are zero-padded -> artificially low)
|
||||||
|
margin = kernel_size // 2
|
||||||
|
if region_w <= 2 * margin + 10:
|
||||||
|
return None
|
||||||
|
smoothed = smoothed_raw[margin:region_w - margin]
|
||||||
|
trim_offset = margin # offset of smoothed[0] relative to search_region
|
||||||
|
|
||||||
|
val_min = float(np.min(smoothed))
|
||||||
|
val_max = float(np.max(smoothed))
|
||||||
|
shadow_range = val_max - val_min
|
||||||
|
|
||||||
|
# --- Check 1: Strong brightness gradient ---
|
||||||
|
if shadow_range <= 40:
|
||||||
|
logger.debug(
|
||||||
|
"%s edge: no spine (range=%.0f <= 40)", side.capitalize(), shadow_range,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# --- Check 2: Darkest point must be genuinely dark ---
|
||||||
|
if val_min > 180:
|
||||||
|
logger.debug(
|
||||||
|
"%s edge: no spine (darkest=%.0f > 180, likely text)", side.capitalize(), val_min,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
spine_idx = int(np.argmin(smoothed)) # index in trimmed array
|
||||||
|
spine_local = spine_idx + trim_offset # index in search_region
|
||||||
|
trimmed_len = len(smoothed)
|
||||||
|
|
||||||
|
# --- Check 3: Valley width (spine is narrow, text plateau is wide) ---
|
||||||
|
valley_thresh = val_min + shadow_range * 0.20
|
||||||
|
valley_mask = smoothed < valley_thresh
|
||||||
|
valley_width = int(np.sum(valley_mask))
|
||||||
|
max_valley_frac = 0.50
|
||||||
|
if valley_width > trimmed_len * max_valley_frac:
|
||||||
|
logger.debug(
|
||||||
|
"%s edge: no spine (valley too wide: %d/%d = %.0f%%)",
|
||||||
|
side.capitalize(), valley_width, trimmed_len,
|
||||||
|
100.0 * valley_width / trimmed_len,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# --- Check 4: Brightness must rise toward page content ---
|
||||||
|
rise_check_w = max(5, trimmed_len // 5)
|
||||||
|
if side == "left":
|
||||||
|
right_start = min(spine_idx + 5, trimmed_len - 1)
|
||||||
|
right_end = min(right_start + rise_check_w, trimmed_len)
|
||||||
|
if right_end > right_start:
|
||||||
|
rise_brightness = float(np.mean(smoothed[right_start:right_end]))
|
||||||
|
rise = rise_brightness - val_min
|
||||||
|
if rise < shadow_range * 0.3:
|
||||||
|
logger.debug(
|
||||||
|
"%s edge: no spine (insufficient rise: %.0f, need %.0f)",
|
||||||
|
side.capitalize(), rise, shadow_range * 0.3,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
else: # right
|
||||||
|
left_end = max(spine_idx - 5, 0)
|
||||||
|
left_start = max(left_end - rise_check_w, 0)
|
||||||
|
if left_end > left_start:
|
||||||
|
rise_brightness = float(np.mean(smoothed[left_start:left_end]))
|
||||||
|
rise = rise_brightness - val_min
|
||||||
|
if rise < shadow_range * 0.3:
|
||||||
|
logger.debug(
|
||||||
|
"%s edge: no spine (insufficient rise: %.0f, need %.0f)",
|
||||||
|
side.capitalize(), rise, shadow_range * 0.3,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
spine_x = offset_x + spine_local
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"%s edge: spine center at x=%d (brightness=%.0f, range=%.0f, valley=%dpx)",
|
||||||
|
side.capitalize(), spine_x, val_min, shadow_range, valley_width,
|
||||||
|
)
|
||||||
|
return spine_x
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_gutter_continuity(
|
||||||
|
gray: np.ndarray,
|
||||||
|
search_region: np.ndarray,
|
||||||
|
offset_x: int,
|
||||||
|
w: int,
|
||||||
|
side: str,
|
||||||
|
) -> Optional[int]:
|
||||||
|
"""Detect gutter shadow via vertical continuity analysis.
|
||||||
|
|
||||||
|
Camera book scans produce a subtle brightness gradient at the gutter
|
||||||
|
that is too faint for scanner-shadow detection (range < 40). However,
|
||||||
|
the gutter shadow has a unique property: it runs **continuously from
|
||||||
|
top to bottom** without interruption.
|
||||||
|
|
||||||
|
Algorithm:
|
||||||
|
1. Divide image into N horizontal strips (~60px each)
|
||||||
|
2. For each column, compute what fraction of strips are darker than
|
||||||
|
the page median (from the center 50% of the full image)
|
||||||
|
3. A "gutter column" has >= 75% of strips darker than page_median - d
|
||||||
|
4. Smooth the dark-fraction profile and find the transition point
|
||||||
|
5. Validate: gutter band must be 0.5%-10% of image width
|
||||||
|
"""
|
||||||
|
region_h, region_w = search_region.shape[:2]
|
||||||
|
if region_w < 20 or region_h < 100:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# --- 1. Divide into horizontal strips ---
|
||||||
|
strip_target_h = 60
|
||||||
|
n_strips = max(10, region_h // strip_target_h)
|
||||||
|
strip_h = region_h // n_strips
|
||||||
|
|
||||||
|
strip_means = np.zeros((n_strips, region_w), dtype=np.float64)
|
||||||
|
for s in range(n_strips):
|
||||||
|
y0 = s * strip_h
|
||||||
|
y1 = min((s + 1) * strip_h, region_h)
|
||||||
|
strip_means[s] = np.mean(search_region[y0:y1, :], axis=0)
|
||||||
|
|
||||||
|
# --- 2. Page median from center 50% of full image ---
|
||||||
|
center_lo = w // 4
|
||||||
|
center_hi = 3 * w // 4
|
||||||
|
page_median = float(np.median(gray[:, center_lo:center_hi]))
|
||||||
|
|
||||||
|
dark_thresh = page_median - 5.0
|
||||||
|
|
||||||
|
if page_median < 180:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# --- 3. Per-column dark fraction ---
|
||||||
|
dark_count = np.sum(strip_means < dark_thresh, axis=0).astype(np.float64)
|
||||||
|
dark_frac = dark_count / n_strips
|
||||||
|
|
||||||
|
# --- 4. Smooth and find transition ---
|
||||||
|
smooth_w = max(5, w // 100)
|
||||||
|
if smooth_w % 2 == 0:
|
||||||
|
smooth_w += 1
|
||||||
|
kernel = np.ones(smooth_w) / smooth_w
|
||||||
|
frac_smooth = np.convolve(dark_frac, kernel, mode="same")
|
||||||
|
|
||||||
|
margin = smooth_w // 2
|
||||||
|
if region_w <= 2 * margin + 10:
|
||||||
|
return None
|
||||||
|
|
||||||
|
transition_thresh = 0.50
|
||||||
|
peak_frac = float(np.max(frac_smooth[margin:region_w - margin]))
|
||||||
|
|
||||||
|
if peak_frac < 0.70:
|
||||||
|
logger.debug(
|
||||||
|
"%s gutter: peak dark fraction %.2f < 0.70", side.capitalize(), peak_frac,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
peak_x = int(np.argmax(frac_smooth[margin:region_w - margin])) + margin
|
||||||
|
gutter_inner = None
|
||||||
|
|
||||||
|
if side == "right":
|
||||||
|
for x in range(peak_x, margin, -1):
|
||||||
|
if frac_smooth[x] < transition_thresh:
|
||||||
|
gutter_inner = x + 1
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
for x in range(peak_x, region_w - margin):
|
||||||
|
if frac_smooth[x] < transition_thresh:
|
||||||
|
gutter_inner = x - 1
|
||||||
|
break
|
||||||
|
|
||||||
|
if gutter_inner is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# --- 5. Validate gutter width ---
|
||||||
|
if side == "right":
|
||||||
|
gutter_width = region_w - gutter_inner
|
||||||
|
else:
|
||||||
|
gutter_width = gutter_inner
|
||||||
|
|
||||||
|
min_gutter = max(3, int(w * 0.005))
|
||||||
|
max_gutter = int(w * 0.10)
|
||||||
|
|
||||||
|
if gutter_width < min_gutter:
|
||||||
|
logger.debug(
|
||||||
|
"%s gutter: too narrow (%dpx < %dpx)", side.capitalize(),
|
||||||
|
gutter_width, min_gutter,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if gutter_width > max_gutter:
|
||||||
|
logger.debug(
|
||||||
|
"%s gutter: too wide (%dpx > %dpx)", side.capitalize(),
|
||||||
|
gutter_width, max_gutter,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if side == "right":
|
||||||
|
gutter_brightness = float(np.mean(strip_means[:, gutter_inner:]))
|
||||||
|
else:
|
||||||
|
gutter_brightness = float(np.mean(strip_means[:, :gutter_inner]))
|
||||||
|
|
||||||
|
brightness_drop = page_median - gutter_brightness
|
||||||
|
if brightness_drop < 3:
|
||||||
|
logger.debug(
|
||||||
|
"%s gutter: insufficient brightness drop (%.1f levels)",
|
||||||
|
side.capitalize(), brightness_drop,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
gutter_x = offset_x + gutter_inner
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"%s gutter (continuity): x=%d, width=%dpx (%.1f%%), "
|
||||||
|
"brightness=%.0f vs page=%.0f (drop=%.0f), frac@edge=%.2f",
|
||||||
|
side.capitalize(), gutter_x, gutter_width,
|
||||||
|
100.0 * gutter_width / w, gutter_brightness, page_median,
|
||||||
|
brightness_drop, float(frac_smooth[gutter_inner]),
|
||||||
|
)
|
||||||
|
return gutter_x
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_left_edge_shadow(
    gray: np.ndarray,
    binary: np.ndarray,
    w: int,
    h: int,
) -> int:
    """Return the x position of the left content edge.

    The leftmost quarter of ``gray`` is handed first to the spine-shadow
    detector and then to the gutter-continuity detector; the first one that
    yields a position wins. If neither does, the result of the column
    projection of ``binary`` (scanned from the start) is returned.

    Args:
        gray: Full grayscale image.
        binary: Binarised image used for the projection fallback.
        w: Image width.
        h: Image height; not used here.
    """
    left_quarter = gray[:, :max(1, w // 4)]

    for detector in (_detect_spine_shadow, _detect_gutter_continuity):
        found = detector(gray, left_quarter, 0, w, "left")
        if found is not None:
            return found

    return _detect_edge_projection(binary, axis=0, from_start=True, dim=w)
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_right_edge_shadow(
    gray: np.ndarray,
    binary: np.ndarray,
    w: int,
    h: int,
) -> int:
    """Return the x position of the right content edge.

    The rightmost quarter of ``gray`` is handed first to the spine-shadow
    detector and then to the gutter-continuity detector; the first one that
    yields a position wins. If neither does, the result of the column
    projection of ``binary`` (scanned from the end) is returned.

    Args:
        gray: Full grayscale image.
        binary: Binarised image used for the projection fallback.
        w: Image width.
        h: Image height; not used here.
    """
    quarter_w = max(1, w // 4)
    x0 = w - quarter_w  # where the searched strip starts in the full image
    right_quarter = gray[:, x0:]

    for detector in (_detect_spine_shadow, _detect_gutter_continuity):
        found = detector(gray, right_quarter, x0, w, "right")
        if found is not None:
            return found

    return _detect_edge_projection(binary, axis=0, from_start=False, dim=w)
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_top_bottom_edges(binary: np.ndarray, w: int, h: int) -> Tuple[int, int]:
    """Return ``(top, bottom)`` content edges from the row projection of ``binary``.

    ``w`` is accepted for signature symmetry but not used.
    """
    first_row, last_row = (
        _detect_edge_projection(binary, axis=1, from_start=leading, dim=h)
        for leading in (True, False)
    )
    return first_row, last_row
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_edge_projection(
    binary: np.ndarray,
    axis: int,
    from_start: bool,
    dim: int,
) -> int:
    """Return the first or last position whose ink density reaches the threshold.

    Args:
        binary: Binarised image; densities are its mean along ``axis``
            divided by 255 (assumes ink pixels are stored as 255 — confirm
            against the caller).
        axis: 0 averages over rows and yields an x position; 1 averages
            over columns and yields a y position.
        from_start: ``True`` returns the first qualifying index, ``False``
            the last one.
        dim: Length of the scanned dimension; sets the minimum run length
            and is the fallback when scanning from the end.

    Returns:
        Index of the first/last qualifying position after runs shorter than
        ``_MIN_RUN_FRAC`` of ``dim`` have been removed. When nothing
        qualifies: 0 if ``from_start`` else ``dim``.
    """
    density = np.mean(binary, axis=axis) / 255.0

    shortest_run = max(1, int(dim * _MIN_RUN_FRAC))
    has_ink = _filter_narrow_runs(density >= _INK_THRESHOLD, shortest_run)

    hits = np.flatnonzero(has_ink)
    if hits.size == 0:
        return 0 if from_start else dim

    return int(hits[0] if from_start else hits[-1])
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_narrow_runs(mask: np.ndarray, min_run: int) -> np.ndarray:
|
||||||
|
"""Remove True-runs shorter than min_run pixels."""
|
||||||
|
if min_run <= 1:
|
||||||
|
return mask
|
||||||
|
|
||||||
|
result = mask.copy()
|
||||||
|
n = len(result)
|
||||||
|
i = 0
|
||||||
|
while i < n:
|
||||||
|
if result[i]:
|
||||||
|
start = i
|
||||||
|
while i < n and result[i]:
|
||||||
|
i += 1
|
||||||
|
if i - start < min_run:
|
||||||
|
result[start:i] = False
|
||||||
|
else:
|
||||||
|
i += 1
|
||||||
|
return result
|
||||||
@@ -0,0 +1,189 @@
|
|||||||
|
"""
|
||||||
|
Sub-session creation for multi-page spreads.
|
||||||
|
|
||||||
|
Used by both the page-split and crop steps when a double-page scan is detected.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid as uuid_mod
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .page_crop import detect_and_crop_page
|
||||||
|
from .session_store import (
|
||||||
|
create_session_db,
|
||||||
|
get_sub_sessions,
|
||||||
|
update_session_db,
|
||||||
|
)
|
||||||
|
from .orientation_crop_helpers import get_cache_ref
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def create_page_sub_sessions(
    parent_session_id: str,
    parent_cached: dict,
    full_img_bgr: np.ndarray,
    page_splits: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Create one pre-cropped sub-session per page of a multi-page spread.

    If the parent already has sub-sessions, a summary of those is returned
    and nothing is created. Otherwise each region in ``page_splits`` is cut
    out of ``full_img_bgr``, cropped with ``detect_and_crop_page``, encoded
    as PNG and stored as a new session whose original and cropped image are
    both the cropped page, with ``current_step`` set to 5.

    Args:
        parent_session_id: Session whose sub-sessions are looked up.
        parent_cached: Parent session data; ``name`` and ``filename`` are
            read (defaults ``"Scan"`` / ``"scan.png"``).
        full_img_bgr: Image the page regions are sliced from.
        page_splits: One dict per page with ``page_index``, ``x``, ``y``,
            ``width`` and ``height``.

    Returns:
        For existing sub-sessions: dicts with ``id``, ``name`` and
        ``page_index``. For new ones: dicts that additionally carry
        ``source_rect``, ``cropped_size`` and ``detected_format``.
    """
    # Idempotency: never create a second set of sub-sessions.
    already_there = await get_sub_sessions(parent_session_id)
    if already_there:
        summaries: List[Dict[str, Any]] = []
        for position, row in enumerate(already_there):
            summaries.append({
                "id": row["id"],
                "name": row["name"],
                "page_index": row.get("box_index", position),
            })
        return summaries

    base_name = parent_cached.get("name", "Scan")
    base_filename = parent_cached.get("filename", "scan.png")

    created: List[Dict[str, Any]] = []
    for split in page_splits:
        index = split["page_index"]
        left, top = split["x"], split["y"]
        width, height = split["width"], split["height"]

        # Cut the page out of the spread, then strip that page's own borders.
        region = full_img_bgr[top:top + height, left:left + width].copy()
        trimmed, crop_info = detect_and_crop_page(region)

        # An encoding failure is stored as empty bytes rather than raised.
        encoded_ok, buffer = cv2.imencode(".png", trimmed)
        png_bytes = b""
        if encoded_ok:
            png_bytes = buffer.tobytes()

        new_id = str(uuid_mod.uuid4())
        title = f"{base_name} — Seite {index + 1}"

        await create_session_db(
            session_id=new_id,
            name=title,
            filename=base_filename,
            original_png=png_bytes,
        )

        # The stored original is already cropped, so the cropped image is the same PNG.
        await update_session_db(
            new_id,
            cropped_png=png_bytes,
            crop_result=crop_info,
            current_step=5,
        )

        out_h, out_w = trimmed.shape[:2]
        created.append({
            "id": new_id,
            "name": title,
            "page_index": index,
            "source_rect": split,
            "cropped_size": {"width": out_w, "height": out_h},
            "detected_format": crop_info.get("detected_format"),
        })

        logger.info(
            "Page sub-session %s: page %d, region x=%d w=%d -> cropped %dx%d",
            new_id, index + 1, left, width, out_w, out_h,
        )

    return created
|
||||||
|
|
||||||
|
|
||||||
|
async def create_page_sub_sessions_full(
    parent_session_id: str,
    parent_cached: dict,
    full_img_bgr: np.ndarray,
    page_splits: List[Dict[str, Any]],
    start_step: int = 2,
) -> List[Dict[str, Any]]:
    """Create one sub-session per page holding the RAW (uncropped) page region.

    In contrast to ``create_page_sub_sessions``, no per-page cropping is
    done here: the page region is stored as-is, the session's
    ``current_step`` is set to ``start_step`` and a fresh cache entry with
    the BGR image is registered so processing can start without reloading.
    If the parent already has sub-sessions, a summary of those is returned
    and nothing is created.

    Args:
        parent_session_id: Session whose sub-sessions are looked up.
        parent_cached: Parent session data; ``name`` and ``filename`` are
            read (defaults ``"Scan"`` / ``"scan.png"``).
        full_img_bgr: Image the page regions are sliced from.
        page_splits: One dict per page with ``page_index``, ``x``, ``y``,
            ``width`` and ``height``.
        start_step: Step stored for each new sub-session. Per the original
            notes, 2 means ready for deskew (orientation already done on the
            spread) and 1 means the page still needs its own orientation.

    Returns:
        For existing sub-sessions: dicts with ``id``, ``name`` and
        ``page_index``. For new ones: dicts that additionally carry
        ``source_rect`` and ``image_size``.
    """
    session_cache = get_cache_ref()

    # Idempotency: hand back what already exists instead of duplicating it.
    already_there = await get_sub_sessions(parent_session_id)
    if already_there:
        summaries: List[Dict[str, Any]] = []
        for position, row in enumerate(already_there):
            summaries.append({
                "id": row["id"],
                "name": row["name"],
                "page_index": row.get("box_index", position),
            })
        return summaries

    base_name = parent_cached.get("name", "Scan")
    base_filename = parent_cached.get("filename", "scan.png")

    created: List[Dict[str, Any]] = []
    for split in page_splits:
        index = split["page_index"]
        left, top = split["x"], split["y"]
        width, height = split["width"], split["height"]

        # Raw region only; each sub-session runs its own crop step later.
        region = full_img_bgr[top:top + height, left:left + width].copy()

        # An encoding failure is stored as empty bytes rather than raised.
        encoded_ok, buffer = cv2.imencode(".png", region)
        png_bytes = b""
        if encoded_ok:
            png_bytes = buffer.tobytes()

        new_id = str(uuid_mod.uuid4())
        title = f"{base_name} — Seite {index + 1}"

        await create_session_db(
            session_id=new_id,
            name=title,
            filename=base_filename,
            original_png=png_bytes,
        )
        await update_session_db(new_id, current_step=start_step)

        # Seed the in-memory cache: only the original image is known so far,
        # every derived image and result starts out empty.
        cache_entry: Dict[str, Any] = {
            "id": new_id,
            "filename": base_filename,
            "name": title,
            "original_bgr": region,
        }
        for pending_key in (
            "oriented_bgr",
            "cropped_bgr",
            "deskewed_bgr",
            "dewarped_bgr",
            "orientation_result",
            "crop_result",
            "deskew_result",
            "dewarp_result",
        ):
            cache_entry[pending_key] = None
        cache_entry["ground_truth"] = {}
        cache_entry["current_step"] = start_step
        session_cache[new_id] = cache_entry

        raw_h, raw_w = region.shape[:2]
        created.append({
            "id": new_id,
            "name": title,
            "page_index": index,
            "source_rect": split,
            "image_size": {"width": raw_w, "height": raw_h},
        })

        logger.info(
            "Page sub-session %s (full pipeline): page %d, region x=%d w=%d -> %dx%d",
            new_id, index + 1, left, width, raw_w, raw_h,
        )

    return created
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Postprocessing API — composite router assembling LLM review,
|
||||||
|
reconstruction, export, validation, image detection/generation, and
|
||||||
|
handwriting removal endpoints.
|
||||||
|
|
||||||
|
Split into sub-modules (imported relatively below):
    llm_review — LLM review + apply corrections
    reconstruction — reconstruction save, Fabric JSON, merged entries, PDF/DOCX
    validation — image detection, generation, validation, handwriting removal
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from .llm_review import router as _llm_review_router
|
||||||
|
from .reconstruction import router as _reconstruction_router
|
||||||
|
from .validation import router as _validation_router
|
||||||
|
|
||||||
|
# Composite router: a drop-in stand-in for the former monolithic router.
# The original note says ocr_pipeline_api.py pulls it in via
# ``from ocr_pipeline_postprocess import router`` (that path may predate the
# move into the ocr/ package; verify against the importer).
router = APIRouter()
for _sub_router in (_llm_review_router, _reconstruction_router, _validation_router):
    router.include_router(_sub_router)
del _sub_router  # keep the loop variable out of the module namespace
|
||||||
@@ -0,0 +1,362 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Reconstruction — save edits, Fabric JSON export, merged entries, PDF/DOCX export.
|
||||||
|
|
||||||
|
Extracted from ocr_pipeline_postprocess.py.
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
|
from .session_store import (
|
||||||
|
get_session_db,
|
||||||
|
get_sub_sessions,
|
||||||
|
update_session_db,
|
||||||
|
)
|
||||||
|
from .common import _cache
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Step 9: Reconstruction + Fabric JSON export
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/reconstruction")
async def save_reconstruction(session_id: str, request: Request):
    """Persist edited cell texts from the reconstruction step.

    The request body is JSON with a ``cells`` list of
    ``{"cell_id": ..., "text": ...}`` objects. Ids of the form
    ``box{N}_{original_id}`` are routed to the sub-session whose
    ``box_index`` is ``N``; every other id is applied to this session's
    ``word_result["cells"]``. Matching vocab entries (``vocab_entries`` or
    ``entries``) are updated from the main-session edits as well, using the
    ids ``R{row:02d}_C{col}`` / ``R{row}_C{col}`` for the columns
    ``english``, ``german`` and ``example``.

    Args:
        session_id: Id of the session being edited.
        request: Incoming request carrying the JSON body.

    Returns:
        ``{"session_id", "updated": 0}`` when no cells were submitted,
        otherwise a dict with ``session_id``, ``updated`` (total),
        ``main_updated`` and ``sub_updated``.

    Raises:
        HTTPException: 404 if the session does not exist, 400 if it has no
            word result.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    body = await request.json()
    cell_updates = body.get("cells", [])

    if not cell_updates:
        # Nothing to apply; still advance the session to step 10.
        await update_session_db(session_id, current_step=10)
        return {"session_id": session_id, "updated": 0}

    # Build update map: cell_id -> new text
    update_map = {c["cell_id"]: c["text"] for c in cell_updates}

    # Separate sub-session updates (cell_ids prefixed with "box{N}_")
    box_id_pattern = re.compile(r'^box(\d+)_(.+)$')
    sub_updates: Dict[int, Dict[str, str]] = {}  # box_index -> {original_cell_id: text}
    main_updates: Dict[str, str] = {}
    for cell_id, text in update_map.items():
        m = box_id_pattern.match(cell_id)
        if m:
            sub_updates.setdefault(int(m.group(1)), {})[m.group(2)] = text
        else:
            main_updates[cell_id] = text

    # Update main session cells
    cells = word_result.get("cells", [])
    updated_count = 0
    for cell in cells:
        if cell["cell_id"] in main_updates:
            cell["text"] = main_updates[cell["cell_id"]]
            cell["status"] = "edited"
            updated_count += 1

    word_result["cells"] = cells

    # Also update vocab_entries if present
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    if entries:
        for entry in entries:
            row_idx = entry.get("row_index", -1)
            for col_idx, field_name in enumerate(["english", "german", "example"]):
                cell_id = f"R{row_idx:02d}_C{col_idx}"
                cell_id_alt = f"R{row_idx}_C{col_idx}"
                # Fall back to the unpadded id only when the padded one is
                # absent (None). Using ``or`` here would discard an edit that
                # sets the text to "" and leave the vocab entry out of sync
                # with the cell it mirrors.
                new_text = main_updates.get(cell_id)
                if new_text is None:
                    new_text = main_updates.get(cell_id_alt)
                if new_text is not None:
                    entry[field_name] = new_text

        word_result["vocab_entries"] = entries
        if "entries" in word_result:
            word_result["entries"] = entries

    await update_session_db(session_id, word_result=word_result, current_step=10)

    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    # Route sub-session updates
    sub_updated = 0
    if sub_updates:
        subs = await get_sub_sessions(session_id)
        sub_by_index = {s.get("box_index"): s["id"] for s in subs}
        for bi, updates in sub_updates.items():
            sub_id = sub_by_index.get(bi)
            if not sub_id:
                continue
            sub_session = await get_session_db(sub_id)
            if not sub_session:
                continue
            sub_word = sub_session.get("word_result")
            if not sub_word:
                continue
            sub_cells = sub_word.get("cells", [])
            for cell in sub_cells:
                if cell["cell_id"] in updates:
                    cell["text"] = updates[cell["cell_id"]]
                    cell["status"] = "edited"
                    sub_updated += 1
            sub_word["cells"] = sub_cells
            await update_session_db(sub_id, word_result=sub_word)
            if sub_id in _cache:
                _cache[sub_id]["word_result"] = sub_word

    total_updated = updated_count + sub_updated
    logger.info(
        "Reconstruction saved for session %s: %s main + %s sub-session cells updated",
        session_id, updated_count, sub_updated,
    )

    return {
        "session_id": session_id,
        "updated": total_updated,
        "main_updated": updated_count,
        "sub_updated": sub_updated,
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}/reconstruction/fabric-json")
async def get_fabric_json(session_id: str):
    """Return the session's cell grid as Fabric.js-compatible JSON.

    Cells of every sub-session that has a word result are appended to the
    main session's cells. Their ids are prefixed with ``box{N}_``, a
    ``source`` of ``box_{N}`` is set, and ``bbox_px`` is shifted by the
    position of the matching box zone from ``column_result`` (no shift when
    there is no such zone). The combined list is converted with
    ``cells_to_fabric_json``.

    Raises:
        HTTPException: 404 if the session does not exist, 400 if it has no
            word result.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    # Work on a copy of the list so the stored word result is not extended.
    all_cells = list(word_result.get("cells", []))
    canvas_w = word_result.get("image_width", 800)
    canvas_h = word_result.get("image_height", 600)

    children = await get_sub_sessions(session_id)
    if children:
        zones = (session.get("column_result") or {}).get("zones") or []
        box_zones = [
            zone for zone in zones
            if zone.get("zone_type") == "box" and zone.get("box")
        ]

        for child in children:
            child_session = await get_session_db(child["id"])
            if not child_session:
                continue
            child_word = child_session.get("word_result")
            if not child_word or not child_word.get("cells"):
                continue

            box_no = child.get("box_index", 0)
            # Offset of the box within the parent image; (0, 0) if unknown.
            dy, dx = 0, 0
            if box_no < len(box_zones):
                rect = box_zones[box_no]["box"]
                dy, dx = rect["y"], rect["x"]

            for cell in child_word["cells"]:
                merged = {
                    **cell,
                    "cell_id": f"box{box_no}_{cell.get('cell_id', '')}",
                    "source": f"box_{box_no}",
                }
                rect_px = merged.get("bbox_px", {})
                if rect_px:
                    merged["bbox_px"] = {
                        **rect_px,
                        "x": rect_px.get("x", 0) + dx,
                        "y": rect_px.get("y", 0) + dy,
                    }
                all_cells.append(merged)

    from services.layout_reconstruction_service import cells_to_fabric_json

    return cells_to_fabric_json(all_cells, canvas_w, canvas_h)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Vocab entries merged + PDF/DOCX export
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}/vocab-entries/merged")
async def get_merged_vocab_entries(session_id: str):
    """Return vocab entries of the main session merged with all sub-sessions.

    Main-session entries are tagged ``source="main"`` unless they already
    carry a ``source``. Sub-session entries are tagged ``source="box_<index>"``
    and get ``source_y`` set to the Y offset of the matching box zone (0 when
    no box zone matches). The merged list is sorted with ``_sort_key`` below.

    Returns:
        Dict with ``session_id``, ``entries``, ``total`` and the sorted list
        of distinct ``sources``.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result") or {}
    # Copy every entry so tagging "source" never mutates the stored session data.
    entries = [
        dict(e)
        for e in (word_result.get("vocab_entries") or word_result.get("entries") or [])
    ]
    for e in entries:
        e.setdefault("source", "main")

    subs = await get_sub_sessions(session_id)
    if subs:
        column_result = session.get("column_result") or {}
        zones = column_result.get("zones") or []
        box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]

        for sub in subs:
            sub_session = await get_session_db(sub["id"])
            if not sub_session:
                continue
            sub_word = sub_session.get("word_result") or {}
            sub_entries = sub_word.get("vocab_entries") or sub_word.get("entries") or []

            # A stored null box_index is treated like a missing one.
            bi = sub.get("box_index") or 0
            box_y = 0
            if bi < len(box_zones):
                box_y = box_zones[bi]["box"].get("y", 0)

            for e in sub_entries:
                e_copy = dict(e)
                e_copy["source"] = f"box_{bi}"
                e_copy["source_y"] = box_y
                entries.append(e_copy)

    def _sort_key(e):
        # NOTE(review): main entries are keyed by row_index * 100 while box
        # entries use source_y * 100 + row_index, i.e. row indices and box Y
        # offsets share one scale -- confirm this ordering is intended.
        row_index = e.get("row_index") or 0
        if e.get("source", "main") == "main":
            return row_index * 100
        return (e.get("source_y") or 0) * 100 + row_index

    entries.sort(key=_sort_key)

    return {
        "session_id": session_id,
        "entries": entries,
        "total": len(entries),
        # Sorted so the response is deterministic across processes.
        "sources": sorted({e.get("source", "main") for e in entries}),
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}/reconstruction/export/pdf")
async def export_reconstruction_pdf(session_id: str):
    """Export the reconstructed cell grid as a PDF table.

    The table has one header row (column labels from ``columns_used``, falling
    back to ``Col <i>``) followed by ``grid_shape.rows`` data rows. Cells are
    looked up by their ``R<row:02d>_C<col>`` id; missing cells render empty.

    Raises:
        HTTPException: 404 if the session does not exist, 400 if it has no
            word result or nothing to export, 501 if reportlab is missing.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    cells = word_result.get("cells", [])
    columns_used = word_result.get("columns_used", [])
    grid_shape = word_result.get("grid_shape", {})
    n_rows = grid_shape.get("rows", 0)
    n_cols = grid_shape.get("cols", 0)

    # Index cells once instead of scanning the list for every grid position.
    # setdefault keeps the first cell per id, like the former linear search.
    cell_by_id = {}
    for c in cells:
        cell_by_id.setdefault(c.get("cell_id"), c)

    # Header and data rows must share one width, otherwise reportlab's Table
    # rejects the data; pad the header with generic labels up to that width.
    header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
    width = max(len(header), n_cols)
    header.extend(f"Col {i}" for i in range(len(header), width))

    # Build table data: rows x columns
    table_data: list[list[str]] = [header]
    for r in range(n_rows):
        row_texts = []
        for ci in range(width):
            cell = cell_by_id.get(f"R{r:02d}_C{ci}")
            text = cell.get("text", "") if cell else ""
            row_texts.append("" if text is None else text)
        table_data.append(row_texts)

    try:
        from reportlab.lib.pagesizes import A4
        from reportlab.lib import colors
        from reportlab.platypus import SimpleDocTemplate, Table, TableStyle
        import io as _io

        buf = _io.BytesIO()
        doc = SimpleDocTemplate(buf, pagesize=A4)
        if not table_data or not table_data[0]:
            raise HTTPException(status_code=400, detail="No data to export")

        t = Table(table_data)
        t.setStyle(TableStyle([
            ('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#0d9488')),
            ('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
            ('FONTSIZE', (0, 0), (-1, -1), 9),
            ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
            ('VALIGN', (0, 0), (-1, -1), 'TOP'),
            ('WORDWRAP', (0, 0), (-1, -1), True),
        ]))
        doc.build([t])
        buf.seek(0)

        return StreamingResponse(
            buf,
            media_type="application/pdf",
            headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.pdf"'},
        )
    except ImportError:
        raise HTTPException(status_code=501, detail="reportlab not installed")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}/reconstruction/export/docx")
async def export_reconstruction_docx(session_id: str):
    """Export the reconstructed cell grid as a DOCX table.

    The document contains a heading and one table: a header row (column
    labels from ``columns_used``, falling back to ``Col <i>``) followed by
    ``grid_shape.rows`` data rows. Cells are looked up by their
    ``R<row:02d>_C<col>`` id; missing cells stay empty.

    Raises:
        HTTPException: 404 if the session does not exist, 400 if it has no
            word result, 501 if python-docx is missing.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    cells = word_result.get("cells", [])
    columns_used = word_result.get("columns_used", [])
    grid_shape = word_result.get("grid_shape", {})
    n_rows = grid_shape.get("rows", 0)
    n_cols = grid_shape.get("cols", 0)

    # Index cells once instead of scanning the list for every grid position.
    # setdefault keeps the first cell per id, like the former linear search.
    cell_by_id = {}
    for c in cells:
        cell_by_id.setdefault(c.get("cell_id"), c)

    try:
        from docx import Document
        import io as _io

        doc = Document()
        doc.add_heading(f'Rekonstruktion -- Session {session_id[:8]}', level=1)

        header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
        data_width = max(len(header), n_cols)
        header.extend(f"Col {i}" for i in range(len(header), data_width))

        # The table must be wide enough for every header label; sizing it from
        # n_cols alone raised IndexError when columns_used had more entries.
        table = doc.add_table(rows=1 + n_rows, cols=max(data_width, 1))
        table.style = 'Table Grid'

        header_cells = table.rows[0].cells
        for ci, h in enumerate(header):
            header_cells[ci].text = "" if h is None else str(h)

        for r in range(n_rows):
            # Fetch the row's cells once per row rather than once per column.
            row_cells = table.rows[r + 1].cells
            for ci in range(data_width):
                cell = cell_by_id.get(f"R{r:02d}_C{ci}")
                text = cell.get("text", "") if cell else ""
                row_cells[ci].text = "" if text is None else str(text)

        buf = _io.BytesIO()
        doc.save(buf)
        buf.seek(0)

        return StreamingResponse(
            buf,
            media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.docx"'},
        )
    except ImportError:
        raise HTTPException(status_code=501, detail="python-docx not installed")
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Regression Tests — barrel re-export.
|
||||||
|
|
||||||
|
All implementation split into:
|
||||||
|
regression_helpers — DB persistence, snapshot, comparison
|
||||||
|
regression_endpoints — FastAPI routes
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Helpers (used by grid_editor_api_grid.py)
|
||||||
|
from .regression_helpers import ( # noqa: F401
|
||||||
|
_init_regression_table,
|
||||||
|
_persist_regression_run,
|
||||||
|
_extract_cells_for_comparison,
|
||||||
|
_build_reference_snapshot,
|
||||||
|
compare_grids,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Endpoints (router used by ocr_pipeline_api.py)
|
||||||
|
from .regression_endpoints import router # noqa: F401
|
||||||
@@ -0,0 +1,421 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Regression Endpoints — FastAPI routes for ground truth and regression.
|
||||||
|
|
||||||
|
Extracted from ocr_pipeline_regression.py for modularity.
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
|
|
||||||
|
from grid_editor_api import _build_grid_core
|
||||||
|
from .session_store import (
|
||||||
|
get_session_db,
|
||||||
|
list_ground_truth_sessions_db,
|
||||||
|
update_session_db,
|
||||||
|
)
|
||||||
|
from .regression_helpers import (
|
||||||
|
_build_reference_snapshot,
|
||||||
|
_init_regression_table,
|
||||||
|
_persist_regression_run,
|
||||||
|
compare_grids,
|
||||||
|
get_pool,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["regression"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/mark-ground-truth")
async def mark_ground_truth(
    session_id: str,
    pipeline: Optional[str] = Query(None, description="Pipeline used: kombi, pipeline, paddle-direct"),
):
    """Store the session's current build-grid result as its ground-truth reference.

    When ``pipeline`` is not given it is derived from ``word_result.ocr_engine``.
    The reference snapshot is written into the session's ``ground_truth`` under
    ``build_grid_reference``; if an ``auto_grid_snapshot`` is present, the diff
    between it and the new reference is returned as ``correction_diff``.

    Raises:
        HTTPException: 404 if the session does not exist, 400 if it has no
            grid editor result with zones.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    grid_result = session.get("grid_editor_result")
    has_zones = bool(grid_result) and bool(grid_result.get("zones"))
    if not has_zones:
        raise HTTPException(
            status_code=400,
            detail="No grid_editor_result found. Run build-grid first.",
        )

    if not pipeline:
        # Derive the pipeline name from the OCR engine recorded in word_result.
        engine_to_pipeline = {
            "kombi": "kombi",
            "rapid_kombi": "kombi",
            "paddle_direct": "paddle-direct",
        }
        engine = (session.get("word_result") or {}).get("ocr_engine", "")
        pipeline = engine_to_pipeline.get(engine, "pipeline")

    reference = _build_reference_snapshot(grid_result, pipeline=pipeline)

    # Merge the reference into the existing ground_truth JSONB and persist it.
    ground_truth = session.get("ground_truth") or {}
    ground_truth["build_grid_reference"] = reference
    await update_session_db(session_id, ground_truth=ground_truth, current_step=11)

    # The diff against the automatic snapshot shows what the user corrected.
    auto_snapshot = ground_truth.get("auto_grid_snapshot")
    correction_diff = compare_grids(auto_snapshot, reference) if auto_snapshot else None

    saved_cells = len(reference["cells"])
    corrections = correction_diff["summary"] if correction_diff else "no auto-snapshot"
    logger.info(
        "Ground truth marked for session %s: %d cells (corrections: %s)",
        session_id,
        saved_cells,
        corrections,
    )

    return {
        "status": "ok",
        "session_id": session_id,
        "cells_saved": saved_cells,
        "summary": reference["summary"],
        "correction_diff": correction_diff,
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/sessions/{session_id}/mark-ground-truth")
async def unmark_ground_truth(session_id: str):
    """Delete the ``build_grid_reference`` entry from a session's ground truth.

    Raises:
        HTTPException: 404 if the session does not exist or has no
            ground-truth reference.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    if "build_grid_reference" in ground_truth:
        ground_truth.pop("build_grid_reference")
    else:
        raise HTTPException(status_code=404, detail="No ground truth reference found")

    await update_session_db(session_id, ground_truth=ground_truth)
    logger.info("Ground truth removed for session %s", session_id)

    return {"status": "ok", "session_id": session_id}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}/correction-diff")
async def get_correction_diff(session_id: str):
    """Compare the automatic OCR grid with the manually corrected ground truth.

    Returns the ``compare_grids`` diff between the stored
    ``auto_grid_snapshot`` and ``build_grid_reference``, enriched with
    ``col_type_breakdown``: per col_type the number of reference cells
    (``total``), the number of ``text_change`` diffs (``corrected``) and the
    resulting ``accuracy_pct``.

    Raises:
        HTTPException: 404 if the session, the auto snapshot or the
            ground-truth reference is missing.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    gt = session.get("ground_truth") or {}
    auto_snapshot = gt.get("auto_grid_snapshot")
    reference = gt.get("build_grid_reference")

    if not auto_snapshot:
        raise HTTPException(
            status_code=404,
            detail="No auto_grid_snapshot found. Re-run build-grid to create one.",
        )
    if not reference:
        raise HTTPException(
            status_code=404,
            detail="No ground truth reference found. Mark as ground truth first.",
        )

    diff = compare_grids(auto_snapshot, reference)

    reference_cells = reference.get("cells", [])

    # Map cell_id -> col_type once instead of scanning all reference cells for
    # every diff. setdefault keeps the first cell per id, like the former scan.
    col_type_by_id: Dict[Any, str] = {}
    for cell in reference_cells:
        col_type_by_id.setdefault(cell.get("cell_id"), cell.get("col_type", "unknown"))

    # Enrich with per-col_type breakdown
    col_type_stats: Dict[str, Dict[str, int]] = {}
    for cell_diff in diff.get("cell_diffs", []):
        if cell_diff["type"] != "text_change":
            continue
        ct = col_type_by_id.get(cell_diff["cell_id"], "unknown")
        col_type_stats.setdefault(ct, {"total": 0, "corrected": 0})["corrected"] += 1

    # Count total cells per col_type from reference
    for cell in reference_cells:
        ct = cell.get("col_type", "unknown")
        col_type_stats.setdefault(ct, {"total": 0, "corrected": 0})["total"] += 1

    # Calculate accuracy per col_type; a type without cells counts as 100 %.
    for stats in col_type_stats.values():
        total = stats["total"]
        corrected = stats["corrected"]
        stats["accuracy_pct"] = round((total - corrected) / total * 100, 1) if total > 0 else 100.0

    diff["col_type_breakdown"] = col_type_stats

    return diff
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/ground-truth-sessions")
async def list_ground_truth_sessions():
    """Return a compact listing of the sessions from ``list_ground_truth_sessions_db``.

    Each item carries the session's id, name, filename and document category
    plus the pipeline, save time and summary of its ``build_grid_reference``.
    """
    def _describe(s):
        # Reduce one stored session to the fields exposed in the listing.
        ref = (s.get("ground_truth") or {}).get("build_grid_reference", {})
        return {
            "session_id": s["id"],
            "name": s.get("name", ""),
            "filename": s.get("filename", ""),
            "document_category": s.get("document_category"),
            "pipeline": ref.get("pipeline"),
            "saved_at": ref.get("saved_at"),
            "summary": ref.get("summary", {}),
        }

    stored = await list_ground_truth_sessions_db()
    listing = [_describe(s) for s in stored]

    return {"sessions": listing, "count": len(listing)}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/regression/run")
async def run_single_regression(session_id: str):
    """Re-run build_grid for a single session and compare it to its ground truth.

    A failure inside ``_build_grid_core`` is not raised; it is logged and
    reported as ``{"status": "error", "error": ...}`` in the response body.

    Raises:
        HTTPException: 404 if the session does not exist, 400 if it has no
            ground-truth reference.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    gt = session.get("ground_truth") or {}
    reference = gt.get("build_grid_reference")
    if not reference:
        raise HTTPException(
            status_code=400,
            detail="No ground truth reference found for this session",
        )

    # Re-compute grid without persisting
    try:
        new_result = await _build_grid_core(session_id, session)
    except Exception as e:
        # Best-effort endpoint: report the failure to the caller, but leave a
        # trace in the log instead of swallowing it silently.
        logger.warning("Regression rebuild failed for session %s: %s", session_id, e)
        return {
            "session_id": session_id,
            "name": session.get("name", ""),
            "status": "error",
            "error": str(e),
        }

    new_snapshot = _build_reference_snapshot(new_result)
    diff = compare_grids(reference, new_snapshot)

    logger.info(
        "Regression test session %s: %s (%d structural, %d cell diffs)",
        session_id, diff["status"],
        diff["summary"]["structural_changes"],
        sum(v for k, v in diff["summary"].items() if k != "structural_changes"),
    )

    return {
        "session_id": session_id,
        "name": session.get("name", ""),
        "status": diff["status"],
        "diff": diff,
        "reference_summary": reference.get("summary", {}),
        "current_summary": new_snapshot.get("summary", {}),
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/regression/run")
async def run_all_regressions(
    triggered_by: str = Query("manual", description="Who triggered: manual, script, ci"),
):
    """Re-run build_grid for ALL ground-truth sessions and compare each result.

    Sessions are processed one after another. A session that cannot be
    re-loaded or whose rebuild raises is recorded with ``status="error"``
    instead of aborting the suite. The overall status is ``"pass"`` only when
    no session failed or errored. The run is handed to
    ``_persist_regression_run`` and its id is returned as ``run_id``.
    """
    start_time = time.monotonic()
    sessions = await list_ground_truth_sessions_db()

    if not sessions:
        return {
            "status": "pass",
            "message": "No ground truth sessions found",
            "results": [],
            "summary": {"total": 0, "passed": 0, "failed": 0, "errors": 0},
        }

    results = []
    passed = 0
    failed = 0
    errors = 0

    for s in sessions:
        session_id = s["id"]
        gt = s.get("ground_truth") or {}
        reference = gt.get("build_grid_reference")
        if not reference:
            continue

        # Re-load full session (list query may not include all JSONB fields)
        full_session = await get_session_db(session_id)
        if not full_session:
            results.append({
                "session_id": session_id,
                "name": s.get("name", ""),
                "status": "error",
                "error": "Session not found during re-load",
            })
            errors += 1
            continue

        try:
            new_result = await _build_grid_core(session_id, full_session)
        except Exception as e:
            # Keep the suite running, but leave a trace in the log instead of
            # swallowing the failure silently.
            logger.warning("Regression rebuild failed for session %s: %s", session_id, e)
            results.append({
                "session_id": session_id,
                "name": s.get("name", ""),
                "status": "error",
                "error": str(e),
            })
            errors += 1
            continue

        new_snapshot = _build_reference_snapshot(new_result)
        diff = compare_grids(reference, new_snapshot)

        entry = {
            "session_id": session_id,
            "name": s.get("name", ""),
            "status": diff["status"],
            "diff_summary": diff["summary"],
            "reference_summary": reference.get("summary", {}),
            "current_summary": new_snapshot.get("summary", {}),
        }

        # Include full diffs only for failures (keep response compact)
        if diff["status"] == "fail":
            entry["structural_diffs"] = diff["structural_diffs"]
            entry["cell_diffs"] = diff["cell_diffs"]
            failed += 1
        else:
            passed += 1

        results.append(entry)

    overall = "pass" if failed == 0 and errors == 0 else "fail"
    duration_ms = int((time.monotonic() - start_time) * 1000)

    summary = {
        "total": len(results),
        "passed": passed,
        "failed": failed,
        "errors": errors,
    }

    logger.info(
        "Regression suite: %s — %d passed, %d failed, %d errors (of %d) in %dms",
        overall, passed, failed, errors, len(results), duration_ms,
    )

    # Persist to DB
    run_id = await _persist_regression_run(
        status=overall,
        summary=summary,
        results=results,
        duration_ms=duration_ms,
        triggered_by=triggered_by,
    )

    return {
        "status": overall,
        "run_id": run_id,
        "duration_ms": duration_ms,
        "results": results,
        "summary": summary,
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/regression/history")
async def get_regression_history(
    limit: int = Query(20, ge=1, le=100),
):
    """Return the most recent regression runs, newest first.

    Any failure while reading the database is logged and reported in the
    response body as an empty run list with an ``error`` message.
    """
    plain_columns = (
        "status", "total", "passed", "failed", "errors",
        "duration_ms", "triggered_by",
    )

    def _serialize(row):
        # id and run_at need conversion; the remaining columns pass through.
        run_at = row["run_at"]
        item = {
            "id": str(row["id"]),
            "run_at": run_at.isoformat() if run_at else None,
        }
        for column in plain_columns:
            item[column] = row[column]
        return item

    try:
        await _init_regression_table()
        pool = await get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT id, run_at, status, total, passed, failed, errors,
                       duration_ms, triggered_by
                FROM regression_runs
                ORDER BY run_at DESC
                LIMIT $1
                """,
                limit,
            )
        return {"runs": [_serialize(row) for row in rows], "count": len(rows)}
    except Exception as e:
        logger.warning("Failed to fetch regression history: %s", e)
        return {"runs": [], "count": 0, "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/regression/history/{run_id}")
async def get_regression_run_detail(run_id: str):
    """Return one stored regression run including its decoded ``results``.

    Raises:
        HTTPException: 404 if no run has this id, 500 (with the error text as
            detail) for any other failure.
    """
    try:
        await _init_regression_table()
        pool = await get_pool()
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                "SELECT * FROM regression_runs WHERE id = $1",
                run_id,
            )
        if not row:
            raise HTTPException(status_code=404, detail="Run not found")

        run_at = row["run_at"]
        raw_results = row["results"]
        detail = {
            "id": str(row["id"]),
            "run_at": run_at.isoformat() if run_at else None,
        }
        for column in ("status", "total", "passed", "failed", "errors",
                       "duration_ms", "triggered_by"):
            detail[column] = row[column]
        detail["results"] = json.loads(raw_results) if raw_results else []
        return detail
    except HTTPException:
        # Pass the 404 through untouched instead of wrapping it in a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||||
@@ -0,0 +1,207 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Regression Helpers — DB persistence, snapshot building, comparison.
|
||||||
|
|
||||||
|
Extracted from ocr_pipeline_regression.py for modularity.
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from .session_store import get_pool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# DB persistence for regression runs
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _init_regression_table():
    """Ensure the regression_runs table exists by executing its migration SQL.

    Looks for ``migrations/008_regression_runs.sql`` next to this module first
    and then in up to three ancestor directories, so the migration is still
    found when this module lives inside a sub-package. If the file cannot be
    found, a warning is logged and nothing is executed.

    NOTE(review): the migration is executed on every call, so it is expected
    to be idempotent -- confirm against the SQL file.
    """
    relative_path = os.path.join("migrations", "008_regression_runs.sql")
    base_dir = os.path.dirname(os.path.abspath(__file__))

    migration_path = None
    for _ in range(4):  # this directory plus three ancestors
        candidate = os.path.join(base_dir, relative_path)
        if os.path.exists(candidate):
            migration_path = candidate
            break
        parent_dir = os.path.dirname(base_dir)
        if parent_dir == base_dir:  # reached the filesystem root
            break
        base_dir = parent_dir

    if migration_path is None:
        logger.warning(
            "Regression migration %s not found; regression_runs table not initialised",
            relative_path,
        )
        return

    # Read the file before acquiring a connection so none is held during I/O.
    with open(migration_path, "r", encoding="utf-8") as f:
        sql = f.read()

    pool = await get_pool()
    async with pool.acquire() as conn:
        await conn.execute(sql)
|
||||||
|
|
||||||
|
|
||||||
|
async def _persist_regression_run(
    status: str,
    summary: dict,
    results: list,
    duration_ms: int,
    triggered_by: str = "manual",
) -> str:
    """Insert one regression run into ``regression_runs``.

    Returns:
        The generated run id, or an empty string if anything failed (the
        failure is logged as a warning and not raised).
    """
    insert_sql = """
                INSERT INTO regression_runs
                    (id, status, total, passed, failed, errors, duration_ms, results, triggered_by)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8::jsonb, $9)
                """
    try:
        await _init_regression_table()
        pool = await get_pool()
        run_id = str(uuid.uuid4())
        # Counters come from the summary dict; absent keys are stored as 0.
        counters = [summary.get(key, 0) for key in ("total", "passed", "failed", "errors")]
        async with pool.acquire() as conn:
            await conn.execute(
                insert_sql,
                run_id,
                status,
                *counters,
                duration_ms,
                json.dumps(results),
                triggered_by,
            )
    except Exception as e:
        logger.warning("Failed to persist regression run: %s", e)
        return ""

    logger.info("Regression run %s persisted: %s", run_id, status)
    return run_id
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _extract_cells_for_comparison(grid_result: dict) -> List[Dict[str, Any]]:
|
||||||
|
"""Extract a flat list of cells from a grid_editor_result for comparison.
|
||||||
|
|
||||||
|
Only keeps fields relevant for comparison: cell_id, row_index, col_index,
|
||||||
|
col_type, text. Ignores confidence, bbox, word_boxes, duration, is_bold.
|
||||||
|
"""
|
||||||
|
cells = []
|
||||||
|
for zone in grid_result.get("zones", []):
|
||||||
|
for cell in zone.get("cells", []):
|
||||||
|
cells.append({
|
||||||
|
"cell_id": cell.get("cell_id", ""),
|
||||||
|
"row_index": cell.get("row_index"),
|
||||||
|
"col_index": cell.get("col_index"),
|
||||||
|
"col_type": cell.get("col_type", ""),
|
||||||
|
"text": cell.get("text", ""),
|
||||||
|
})
|
||||||
|
return cells
|
||||||
|
|
||||||
|
|
||||||
|
def _build_reference_snapshot(
    grid_result: dict,
    pipeline: Optional[str] = None,
) -> dict:
    """Build a reference snapshot from a grid_editor_result.

    The snapshot records the current UTC time, a format version, the pipeline
    name, the flattened comparison cells and a summary with the total number
    of zones, columns, rows and cells.
    """
    zones = grid_result.get("zones", [])
    flat_cells = _extract_cells_for_comparison(grid_result)

    summary = {
        "total_zones": len(zones),
        "total_columns": sum(len(zone.get("columns", [])) for zone in zones),
        "total_rows": sum(len(zone.get("rows", [])) for zone in zones),
        "total_cells": len(flat_cells),
    }

    return {
        "saved_at": datetime.now(timezone.utc).isoformat(),
        "version": 1,
        "pipeline": pipeline,
        "summary": summary,
        "cells": flat_cells,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def compare_grids(reference: dict, current: dict) -> dict:
    """Diff a reference grid snapshot against a newly computed one.

    Returns:
        Dict with
        - status: "pass" when nothing differs, otherwise "fail"
        - structural_diffs: summary counters whose values differ
        - cell_diffs: missing cells, added cells, then text / col_type
          changes of cells present in both snapshots
        - summary: number of diffs per category
    """
    ref_summary = reference.get("summary", {})
    cur_summary = current.get("summary", {})

    structural_diffs = []
    for field in ("total_zones", "total_columns", "total_rows", "total_cells"):
        expected, actual = ref_summary.get(field, 0), cur_summary.get(field, 0)
        if expected == actual:
            continue
        structural_diffs.append({"field": field, "reference": expected, "current": actual})

    # Index both snapshots by cell_id.
    ref_cells = {cell["cell_id"]: cell for cell in reference.get("cells", [])}
    cur_cells = {cell["cell_id"]: cell for cell in current.get("cells", [])}

    # Cells only in the reference are missing; cells only in current are added.
    cell_diffs: List[Dict[str, Any]] = [
        {"type": "cell_missing", "cell_id": cid, "reference_text": cell.get("text", "")}
        for cid, cell in ref_cells.items()
        if cid not in cur_cells
    ]
    cell_diffs.extend(
        {"type": "cell_added", "cell_id": cid, "current_text": cell.get("text", "")}
        for cid, cell in cur_cells.items()
        if cid not in ref_cells
    )

    # For cells present in both snapshots, compare text first, then col_type.
    compared_fields = (("text_change", "text"), ("col_type_change", "col_type"))
    for cid, before in ref_cells.items():
        if cid not in cur_cells:
            continue
        after = cur_cells[cid]
        for diff_type, field in compared_fields:
            old_value, new_value = before.get(field, ""), after.get(field, "")
            if old_value != new_value:
                cell_diffs.append({
                    "type": diff_type,
                    "cell_id": cid,
                    "reference": old_value,
                    "current": new_value,
                })

    # Tally the cell diffs per type in a single pass.
    tally = {"cell_missing": 0, "cell_added": 0, "text_change": 0, "col_type_change": 0}
    for entry in cell_diffs:
        tally[entry["type"]] += 1

    return {
        "status": "fail" if structural_diffs or cell_diffs else "pass",
        "structural_diffs": structural_diffs,
        "cell_diffs": cell_diffs,
        "summary": {
            "structural_changes": len(structural_diffs),
            "cells_missing": tally["cell_missing"],
            "cells_added": tally["cell_added"],
            "text_changes": tally["text_change"],
            "col_type_changes": tally["col_type_change"],
        },
    }
|
||||||
@@ -0,0 +1,94 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Reprocess Endpoint.
|
||||||
|
|
||||||
|
POST /sessions/{session_id}/reprocess — clear downstream + restart from step.
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
|
||||||
|
from .common import _cache
|
||||||
|
from .session_store import get_session_db, update_session_db
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(tags=["ocr-pipeline"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/reprocess")
async def reprocess_session(session_id: str, request: Request):
    """Rewind a session to ``from_step`` and clear the results that depend on it.

    Body: ``{"from_step": <int 1..10>}`` (defaults to 1 when the key is absent).

    What the thresholds below actually clear:
    - from_step <= 1: orientation_result
    - from_step <= 2: deskew_result
    - from_step <= 3: dewarp_result
    - from_step <= 4: crop_result
    - from_step <= 6: column_result
    - from_step <= 7: row_result
    - from_step <= 8: word_result (set to NULL)
    - from_step == 9: only ``llm_review`` / ``llm_corrections`` are removed
      from word_result
    - from_step == 10: nothing besides ``current_step``

    NOTE(review): the thresholds mix two step numberings. The image steps
    (orientation..crop) use Orientation=1..Crop=4, while the column/row/word
    thresholds follow a numbering shifted by one (Columns=6, Rows=7, Words=8,
    LLM=9). Confirm against the frontend which numbering ``from_step`` uses
    before changing any threshold.

    Returns:
        Dict with ``session_id``, ``from_step`` and ``cleared`` (names of the
        result fields that were written).

    Raises:
        HTTPException 404: the session does not exist.
        HTTPException 400: the body is not a JSON object or ``from_step`` is
            not an integer in 1..10.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # A malformed or empty body raises a ValueError subclass inside
    # request.json(); report it as a client error instead of a 500.
    try:
        body = await request.json()
    except ValueError:
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from None
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    from_step = body.get("from_step", 1)
    # bool is a subclass of int: without the explicit check, JSON `true` would
    # be accepted as step 1 and wipe every result.
    if isinstance(from_step, bool) or not isinstance(from_step, int) or not 1 <= from_step <= 10:
        raise HTTPException(status_code=400, detail="from_step must be between 1 and 10")

    update_kwargs: Dict[str, Any] = {"current_step": from_step}

    # Clear downstream data based on from_step (see NOTE in the docstring
    # about the step numbering).
    if from_step <= 8:
        update_kwargs["word_result"] = None
    elif from_step == 9:
        # Keep the OCR words, drop only the LLM review artefacts.
        word_result = session.get("word_result")
        if word_result:
            word_result.pop("llm_review", None)
            word_result.pop("llm_corrections", None)
            update_kwargs["word_result"] = word_result

    if from_step <= 7:
        update_kwargs["row_result"] = None
    if from_step <= 6:
        update_kwargs["column_result"] = None
    if from_step <= 4:
        update_kwargs["crop_result"] = None
    if from_step <= 3:
        update_kwargs["dewarp_result"] = None
    if from_step <= 2:
        update_kwargs["deskew_result"] = None
    if from_step <= 1:
        update_kwargs["orientation_result"] = None

    await update_session_db(session_id, **update_kwargs)

    # Mirror the same values (including current_step) into the in-memory cache.
    if session_id in _cache:
        for key, value in update_kwargs.items():
            _cache[session_id][key] = value

    logger.info(f"Session {session_id} reprocessing from step {from_step}")

    return {
        "session_id": session_id,
        "from_step": from_step,
        "cleared": [k for k in update_kwargs if k != "current_step"],
    }
|
||||||
@@ -0,0 +1,348 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline - Row Detection Endpoints.
|
||||||
|
|
||||||
|
Extracted from ocr_pipeline_api.py.
|
||||||
|
Handles row detection (auto + manual) and row ground truth.
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
|
||||||
|
from cv_vocab_pipeline import (
|
||||||
|
create_ocr_image,
|
||||||
|
detect_column_geometry,
|
||||||
|
detect_row_geometry,
|
||||||
|
)
|
||||||
|
from .common import (
|
||||||
|
_cache,
|
||||||
|
_load_session_to_cache,
|
||||||
|
_get_cached,
|
||||||
|
_append_pipeline_log,
|
||||||
|
ManualRowsRequest,
|
||||||
|
RowGroundTruthRequest,
|
||||||
|
)
|
||||||
|
from .session_store import (
|
||||||
|
get_session_db,
|
||||||
|
update_session_db,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helper: Box-exclusion overlay (used by rows overlay and columns overlay)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _draw_box_exclusion_overlay(
|
||||||
|
img: np.ndarray,
|
||||||
|
zones: List[Dict],
|
||||||
|
*,
|
||||||
|
label: str = "BOX — separat verarbeitet",
|
||||||
|
) -> None:
|
||||||
|
"""Draw red semi-transparent rectangles over box zones (in-place).
|
||||||
|
|
||||||
|
Reusable for columns, rows, and words overlays.
|
||||||
|
"""
|
||||||
|
for zone in zones:
|
||||||
|
if zone.get("zone_type") != "box" or not zone.get("box"):
|
||||||
|
continue
|
||||||
|
box = zone["box"]
|
||||||
|
bx, by = box["x"], box["y"]
|
||||||
|
bw, bh = box["width"], box["height"]
|
||||||
|
|
||||||
|
# Red semi-transparent fill (~25 %)
|
||||||
|
box_overlay = img.copy()
|
||||||
|
cv2.rectangle(box_overlay, (bx, by), (bx + bw, by + bh), (0, 0, 200), -1)
|
||||||
|
cv2.addWeighted(box_overlay, 0.25, img, 0.75, 0, img)
|
||||||
|
|
||||||
|
# Border
|
||||||
|
cv2.rectangle(img, (bx, by), (bx + bw, by + bh), (0, 0, 200), 2)
|
||||||
|
|
||||||
|
# Label
|
||||||
|
cv2.putText(img, label, (bx + 10, by + bh - 10),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Row Detection Endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/rows")
async def detect_rows(session_id: str):
    """Run row detection on the cropped (or dewarped) image.

    Flow:
    1. Load the session into the cache if needed and pick the cropped image,
       falling back to the dewarped one.
    2. Reuse the word list / inverted image / content bounds cached by column
       detection, or recompute them via ``detect_column_geometry``.
    3. Sub-sessions (those with a ``parent_session_id``) use word grouping.
       Otherwise rows come from ``detect_row_geometry``; if the column result
       contains box zones, the box interiors are cut out of the inverted
       image first and the detected rows are mapped back to page coordinates.
    4. Persist ``row_result`` (and reset ``word_result``), update the cache
       and append a pipeline log entry.

    Returns:
        Dict with ``session_id`` plus ``rows``, ``summary``, ``total_rows``
        and ``duration_seconds``.

    Raises:
        HTTPException 400: no cropped/dewarped image is available, or column
            geometry detection failed.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    # Prefer the cropped image; fall back to the dewarped one.
    dewarped_bgr = cached.get("cropped_bgr")
    if dewarped_bgr is None:
        dewarped_bgr = cached.get("dewarped_bgr")
    if dewarped_bgr is None:
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before row detection")

    t0 = time.time()

    # Try to reuse cached word_dicts and inv from column detection
    word_dicts = cached.get("_word_dicts")
    inv = cached.get("_inv")
    content_bounds = cached.get("_content_bounds")

    if word_dicts is None or inv is None or content_bounds is None:
        # Not cached — run column geometry to get intermediates
        ocr_img = create_ocr_image(dewarped_bgr)
        geo_result = detect_column_geometry(ocr_img, dewarped_bgr)
        if geo_result is None:
            raise HTTPException(status_code=400, detail="Column geometry detection failed — cannot detect rows")
        _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
        cached["_word_dicts"] = word_dicts
        cached["_inv"] = inv
        cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
    else:
        left_x, right_x, top_y, bottom_y = content_bounds

    # Read zones from column_result to exclude box regions
    session = await get_session_db(session_id)
    column_result = (session or {}).get("column_result") or {}
    is_sub_session = bool((session or {}).get("parent_session_id"))

    # Sub-sessions (box crops): use word-grouping instead of gap-based
    # row detection. Box images are small with complex internal layouts
    # (headings, sub-columns) where the horizontal projection approach
    # merges rows. Word-grouping directly clusters words by Y proximity,
    # which is more robust for these cases.
    if is_sub_session and word_dicts:
        from cv_layout import _build_rows_from_word_grouping
        rows = _build_rows_from_word_grouping(
            word_dicts, left_x, right_x, top_y, bottom_y,
            right_x - left_x, bottom_y - top_y,
        )
        logger.info(f"OCR Pipeline: sub-session {session_id}: word-grouping found {len(rows)} rows")
    else:
        zones = column_result.get("zones") or []  # zones can be None for sub-sessions

        # Collect the inner y-range of every box zone. The range is shrunk by
        # the border thickness (at least 5 px) so that text rows touching the
        # box frame from above/below are not clipped.
        box_ranges_inner = []  # [(y_start + border, y_end - border)]
        for zone in zones:
            if zone.get("zone_type") == "box" and zone.get("box"):
                box = zone["box"]
                bt = max(box.get("border_thickness", 0), 5)
                box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt))

        if box_ranges_inner and inv is not None:
            # Combined-image approach: strip box regions from inv image,
            # run row detection on the combined image, then remap y-coords back.
            content_strips = []  # [(y_start, y_end)] in absolute coords
            strip_start = top_y
            for by_start, by_end in sorted(box_ranges_inner, key=lambda r: r[0]):
                if by_start > strip_start:
                    content_strips.append((strip_start, by_start))
                strip_start = max(strip_start, by_end)
            if strip_start < bottom_y:
                content_strips.append((strip_start, bottom_y))

            # Filter to strips with meaningful height
            content_strips = [(ys, ye) for ys, ye in content_strips if ye - ys >= 20]

            if content_strips:
                # Stack content strips vertically
                combined_inv = np.vstack([inv[ys:ye, :] for ys, ye in content_strips])

                # Keep only words whose vertical centre lies in a strip and
                # shift their `top` into combined-image coordinates.
                combined_words = []
                cum_y = 0
                strip_offsets = []  # (combined_y_start, strip_height, abs_y_start)
                for ys, ye in content_strips:
                    h = ye - ys
                    strip_offsets.append((cum_y, h, ys))
                    for w in word_dicts:
                        w_abs_y = w['top'] + top_y  # word y is relative to content top
                        w_center = w_abs_y + w['height'] / 2
                        if ys <= w_center < ye:
                            w_copy = dict(w)
                            w_copy['top'] = cum_y + (w_abs_y - ys)
                            combined_words.append(w_copy)
                    cum_y += h

                # Run row detection on combined image
                combined_h = combined_inv.shape[0]
                rows = detect_row_geometry(
                    combined_inv, combined_words, left_x, right_x, 0, combined_h,
                )

                def _combined_y_to_abs(cy: int, *, is_end: bool = False) -> int:
                    """Map a combined-image y back to page coordinates.

                    A row *end* that lands exactly on a strip boundary belongs
                    to the strip it closes, not to the next strip; otherwise
                    the row height would span the excluded box in between.
                    """
                    for c_start, s_h, abs_start in strip_offsets:
                        c_end = c_start + s_h
                        if cy < c_end or (is_end and cy == c_end):
                            return abs_start + (cy - c_start)
                    _last_c, last_h, last_abs = strip_offsets[-1]
                    return last_abs + last_h

                for r in rows:
                    abs_y = _combined_y_to_abs(r.y)
                    abs_y_end = _combined_y_to_abs(r.y + r.height, is_end=True)
                    r.y = abs_y
                    r.height = abs_y_end - abs_y
            else:
                rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
        else:
            # No boxes — standard row detection
            rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)

    duration = time.time() - t0

    # Assign zone_index based on which content zone each row falls in
    zones = column_result.get("zones") or []
    content_zones = [(i, z) for i, z in enumerate(zones) if z.get("zone_type") == "content"]

    # Build serializable result (exclude words to keep payload small)
    rows_data = []
    for r in rows:
        # A row belongs to the first content zone containing its vertical
        # centre; rows outside every content zone default to index 0.
        zone_idx = 0
        row_center_y = r.y + r.height / 2
        for zi, zone in content_zones:
            zy = zone["y"]
            zh = zone["height"]
            if zy <= row_center_y < zy + zh:
                zone_idx = zi
                break

        rows_data.append({
            "index": r.index,
            "x": r.x,
            "y": r.y,
            "width": r.width,
            "height": r.height,
            "word_count": r.word_count,
            "row_type": r.row_type,
            "gap_before": r.gap_before,
            "zone_index": zone_idx,
        })

    type_counts = {}
    for r in rows:
        type_counts[r.row_type] = type_counts.get(r.row_type, 0) + 1

    row_result = {
        "rows": rows_data,
        "summary": type_counts,
        "total_rows": len(rows),
        "duration_seconds": round(duration, 2),
    }

    # Persist to DB — also invalidate word_result since rows changed
    await update_session_db(
        session_id,
        row_result=row_result,
        word_result=None,
        current_step=7,
    )

    cached["row_result"] = row_result
    cached.pop("word_result", None)

    logger.info(f"OCR Pipeline: rows session {session_id}: "
                f"{len(rows)} rows detected ({duration:.2f}s): {type_counts}")

    content_rows = sum(1 for r in rows if r.row_type == "content")
    avg_height = round(sum(r.height for r in rows) / len(rows)) if rows else 0
    await _append_pipeline_log(session_id, "rows", {
        "total_rows": len(rows),
        "content_rows": content_rows,
        "artifact_rows_removed": type_counts.get("header", 0) + type_counts.get("footer", 0),
        "avg_row_height_px": avg_height,
    }, duration_ms=int(duration * 1000))

    return {
        "session_id": session_id,
        **row_result,
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/rows/manual")
async def set_manual_rows(session_id: str, req: ManualRowsRequest):
    """Override detected rows with manual definitions.

    Stores ``req.rows`` as the session's ``row_result`` (marked with
    ``"method": "manual"``) and resets ``word_result``, both in the database
    and in the in-memory cache.

    Returns:
        Dict with ``session_id`` plus the stored row result.

    Raises:
        HTTPException 404: no session row was updated, i.e. the session does
            not exist.
    """
    row_result = {
        "rows": req.rows,
        "total_rows": len(req.rows),
        "duration_seconds": 0,
        "method": "manual",
    }

    # update_session_db returns None when the UPDATE matched no row; without
    # this check the endpoint would report success for an unknown session.
    updated = await update_session_db(session_id, row_result=row_result, word_result=None)
    if updated is None:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    if session_id in _cache:
        _cache[session_id]["row_result"] = row_result
        _cache[session_id].pop("word_result", None)

    logger.info(f"OCR Pipeline: manual rows session {session_id}: "
                f"{len(req.rows)} rows set")

    return {"session_id": session_id, **row_result}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/ground-truth/rows")
async def save_row_ground_truth(session_id: str, req: RowGroundTruthRequest):
    """Store reviewer feedback for the row detection step.

    The feedback is saved under the ``rows`` key of the session's
    ``ground_truth`` document together with a snapshot of the current
    ``row_result`` and a UTC timestamp.

    Returns:
        Dict with ``session_id`` and the stored ``ground_truth`` entry.

    Raises:
        HTTPException 404: the session does not exist.
    """
    record = await get_session_db(session_id)
    if not record:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    all_ground_truth = record.get("ground_truth") or {}
    rows_entry = dict(
        is_correct=req.is_correct,
        corrected_rows=req.corrected_rows,
        notes=req.notes,
        saved_at=datetime.utcnow().isoformat(),
        row_result=record.get("row_result"),
    )
    all_ground_truth["rows"] = rows_entry

    await update_session_db(session_id, ground_truth=all_ground_truth)

    # Keep the cached session in sync with what was persisted.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = all_ground_truth

    return dict(session_id=session_id, ground_truth=rows_entry)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}/ground-truth/rows")
async def get_row_ground_truth(session_id: str):
    """Return the stored row ground truth next to the detected rows.

    Returns:
        Dict with ``session_id``, ``rows_gt`` (saved feedback) and
        ``rows_auto`` (the session's current ``row_result``).

    Raises:
        HTTPException 404: the session does not exist, or it has no row
            ground truth yet.
    """
    record = await get_session_db(session_id)
    if not record:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    saved_rows = (record.get("ground_truth") or {}).get("rows")
    if not saved_rows:
        raise HTTPException(status_code=404, detail="No row ground truth saved")

    return dict(
        session_id=session_id,
        rows_gt=saved_rows,
        rows_auto=record.get("row_result"),
    )
|
||||||
@@ -0,0 +1,102 @@
|
|||||||
|
"""
|
||||||
|
Scan Quality Assessment — Measures image quality before OCR.
|
||||||
|
|
||||||
|
Computes blur score, contrast score, and an overall quality rating.
|
||||||
|
Used to gate enhancement steps and warn users about degraded scans.
|
||||||
|
|
||||||
|
All operations use OpenCV (Apache-2.0), no additional dependencies.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, asdict
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Thresholds (empirically tuned on textbook scans)
|
||||||
|
BLUR_THRESHOLD = 100.0 # Laplacian variance below this = blurry
|
||||||
|
CONTRAST_THRESHOLD = 40.0 # Grayscale stddev below this = low contrast
|
||||||
|
CONFIDENCE_GOOD = 40 # OCR min confidence for good scans
|
||||||
|
CONFIDENCE_DEGRADED = 30 # OCR min confidence for degraded scans
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ScanQualityReport:
    """Scores and flags produced by a scan quality assessment."""

    # Variance of the Laplacian; larger means sharper.
    blur_score: float
    # Standard deviation of the grayscale image; larger means more contrast.
    contrast_score: float
    # Mean grayscale value (0-255).
    brightness: float
    is_blurry: bool
    is_low_contrast: bool
    # True when at least one quality issue was detected.
    is_degraded: bool
    # Overall quality estimate, 0-100.
    quality_pct: int
    # OCR confidence threshold recommended for this scan.
    recommended_min_conf: int

    def to_dict(self) -> Dict[str, Any]:
        """Return the report as a plain dict keyed by field name."""
        as_mapping = asdict(self)
        return as_mapping
|
||||||
|
|
||||||
|
|
||||||
|
def score_scan_quality(img_bgr: np.ndarray) -> ScanQualityReport:
    """Score sharpness, contrast and brightness of a scanned page.

    The image is converted to grayscale, then:
    - sharpness is the variance of its Laplacian,
    - contrast is the standard deviation of the gray values,
    - brightness is their mean.

    A scan counts as degraded when sharpness is below ``BLUR_THRESHOLD`` or
    contrast is below ``CONTRAST_THRESHOLD``. Each metric contributes
    ``value / threshold * 50`` points to the overall percentage, which is
    capped at 100.

    Args:
        img_bgr: BGR image (numpy array from OpenCV).

    Returns:
        ScanQualityReport with scores rounded to one decimal. The flags are
        derived from the unrounded values.
    """
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

    sharpness = float(cv2.Laplacian(gray, cv2.CV_64F).var())
    contrast = float(np.std(gray))
    mean_gray = float(np.mean(gray))

    blurry = sharpness < BLUR_THRESHOLD
    flat = contrast < CONTRAST_THRESHOLD
    degraded = blurry or flat

    # Reaching a threshold exactly is worth 50 points per metric.
    sharpness_points = min(100, sharpness / BLUR_THRESHOLD * 50)
    contrast_points = min(100, contrast / CONTRAST_THRESHOLD * 50)
    overall = int(min(100, sharpness_points + contrast_points))

    # Degraded scans get the lower OCR confidence threshold.
    if degraded:
        min_conf = CONFIDENCE_DEGRADED
    else:
        min_conf = CONFIDENCE_GOOD

    report = ScanQualityReport(
        blur_score=round(sharpness, 1),
        contrast_score=round(contrast, 1),
        brightness=round(mean_gray, 1),
        is_blurry=blurry,
        is_low_contrast=flat,
        is_degraded=degraded,
        quality_pct=overall,
        recommended_min_conf=min_conf,
    )

    logger.info(
        f"Scan quality: blur={report.blur_score} "
        f"contrast={report.contrast_score} "
        f"quality={report.quality_pct}% "
        f"degraded={report.is_degraded} "
        f"min_conf={report.recommended_min_conf}"
    )

    return report
|
||||||
@@ -0,0 +1,388 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Session Store - PostgreSQL persistence for OCR pipeline sessions.
|
||||||
|
|
||||||
|
Replaces in-memory storage with database persistence.
|
||||||
|
See migrations/002_ocr_pipeline_sessions.sql for schema.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
|
||||||
|
import asyncpg
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Database configuration (same as vocab_session_store)
|
||||||
|
DATABASE_URL = os.getenv(
|
||||||
|
"DATABASE_URL",
|
||||||
|
"postgresql://breakpilot:breakpilot@postgres:5432/breakpilot_db"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Connection pool (initialized lazily)
|
||||||
|
_pool: Optional[asyncpg.Pool] = None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_pool() -> asyncpg.Pool:
    """Return the module-wide connection pool, creating it on first use.

    NOTE(review): the pool is assigned only after ``create_pool`` has been
    awaited, so two coroutines making the very first call concurrently can
    each create a pool — confirm whether startup guarantees a single caller.
    """
    global _pool
    if _pool is not None:
        return _pool
    _pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
    return _pool
|
||||||
|
|
||||||
|
|
||||||
|
async def init_ocr_pipeline_tables():
    """Make sure the OCR pipeline table, its newer columns and index exist.

    If ``ocr_pipeline_sessions`` is missing, the SQL migration located next
    to this module is executed (a warning is logged when the file is
    absent). Afterwards the later-added columns and the document-group index
    are created idempotently.
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        already_present = await conn.fetchval("""
            SELECT EXISTS (
                SELECT FROM information_schema.tables
                WHERE table_name = 'ocr_pipeline_sessions'
            )
        """)

        if already_present:
            logger.debug("OCR pipeline tables already exist")
        else:
            logger.info("Creating OCR pipeline tables...")
            module_dir = os.path.dirname(__file__)
            migration_path = os.path.join(
                module_dir,
                "migrations/002_ocr_pipeline_sessions.sql"
            )
            if not os.path.exists(migration_path):
                logger.warning(f"Migration file not found: {migration_path}")
            else:
                with open(migration_path, "r") as migration_file:
                    migration_sql = migration_file.read()
                await conn.execute(migration_sql)
                logger.info("OCR pipeline tables created successfully")

        # Ensure new columns exist (idempotent ALTER TABLE). This runs even
        # when the migration file was not found above.
        await conn.execute("""
            ALTER TABLE ocr_pipeline_sessions
            ADD COLUMN IF NOT EXISTS clean_png BYTEA,
            ADD COLUMN IF NOT EXISTS handwriting_removal_meta JSONB,
            ADD COLUMN IF NOT EXISTS doc_type VARCHAR(50),
            ADD COLUMN IF NOT EXISTS doc_type_result JSONB,
            ADD COLUMN IF NOT EXISTS document_category VARCHAR(50),
            ADD COLUMN IF NOT EXISTS pipeline_log JSONB,
            ADD COLUMN IF NOT EXISTS oriented_png BYTEA,
            ADD COLUMN IF NOT EXISTS cropped_png BYTEA,
            ADD COLUMN IF NOT EXISTS orientation_result JSONB,
            ADD COLUMN IF NOT EXISTS crop_result JSONB,
            ADD COLUMN IF NOT EXISTS parent_session_id UUID REFERENCES ocr_pipeline_sessions(id) ON DELETE CASCADE,
            ADD COLUMN IF NOT EXISTS box_index INT,
            ADD COLUMN IF NOT EXISTS grid_editor_result JSONB,
            ADD COLUMN IF NOT EXISTS structure_result JSONB,
            ADD COLUMN IF NOT EXISTS document_group_id UUID,
            ADD COLUMN IF NOT EXISTS page_number INT
        """)

        # Index for document group lookups
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_ocr_sessions_document_group
            ON ocr_pipeline_sessions (document_group_id)
            WHERE document_group_id IS NOT NULL
        """)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# SESSION CRUD
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
async def create_session_db(
    session_id: str,
    name: str,
    filename: str,
    original_png: bytes,
    parent_session_id: Optional[str] = None,
    box_index: Optional[int] = None,
    document_group_id: Optional[str] = None,
    page_number: Optional[int] = None,
) -> Dict[str, Any]:
    """Insert a new OCR pipeline session and return its metadata.

    The row is created with status ``'active'`` and ``current_step`` 1.

    Args:
        session_id: UUID string used as the primary key.
        name: Display name of the session.
        filename: Name of the uploaded file.
        original_png: Original page image bytes.
        parent_session_id: If set, this is a sub-session for a box region.
        box_index: 0-based index of the box this sub-session represents.
        document_group_id: Groups multi-page uploads into one document.
        page_number: 1-based page index within the document group.

    Returns:
        The inserted row converted by ``_row_to_dict`` (no image columns).
    """
    pool = await get_pool()

    # Optional UUID strings become UUID objects; missing ones stay NULL.
    parent_uuid = None
    if parent_session_id:
        parent_uuid = uuid.UUID(parent_session_id)
    group_uuid = None
    if document_group_id:
        group_uuid = uuid.UUID(document_group_id)

    async with pool.acquire() as conn:
        inserted = await conn.fetchrow("""
            INSERT INTO ocr_pipeline_sessions (
                id, name, filename, original_png, status, current_step,
                parent_session_id, box_index, document_group_id, page_number
            ) VALUES ($1, $2, $3, $4, 'active', 1, $5, $6, $7, $8)
            RETURNING id, name, filename, status, current_step,
                      orientation_result, crop_result,
                      deskew_result, dewarp_result, column_result, row_result,
                      word_result, ground_truth, auto_shear_degrees,
                      doc_type, doc_type_result,
                      document_category, pipeline_log,
                      grid_editor_result, structure_result,
                      parent_session_id, box_index,
                      document_group_id, page_number,
                      created_at, updated_at
        """, uuid.UUID(session_id), name, filename, original_png,
            parent_uuid, box_index, group_uuid, page_number)

    return _row_to_dict(inserted)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_session_db(session_id: str) -> Optional[Dict[str, Any]]:
    """Fetch one session's metadata (image columns are not selected).

    Returns:
        The row converted by ``_row_to_dict``, or None when no session has
        the given id.
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        found = await conn.fetchrow("""
            SELECT id, name, filename, status, current_step,
                   orientation_result, crop_result,
                   deskew_result, dewarp_result, column_result, row_result,
                   word_result, ground_truth, auto_shear_degrees,
                   doc_type, doc_type_result,
                   document_category, pipeline_log,
                   grid_editor_result, structure_result,
                   parent_session_id, box_index,
                   document_group_id, page_number,
                   created_at, updated_at
            FROM ocr_pipeline_sessions WHERE id = $1
        """, uuid.UUID(session_id))

    return _row_to_dict(found) if found else None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_session_image(session_id: str, image_type: str) -> Optional[bytes]:
    """Load one stored image of a session.

    Args:
        session_id: UUID string of the session.
        image_type: One of ``original``, ``oriented``, ``cropped``,
            ``deskewed``, ``binarized``, ``dewarped`` or ``clean``.

    Returns:
        The value of the matching BYTEA column, or None when ``image_type``
        is not one of the known kinds (no query is issued in that case).
    """
    known_kinds = (
        "original", "oriented", "cropped", "deskewed",
        "binarized", "dewarped", "clean",
    )
    if image_type not in known_kinds:
        return None

    # The column name comes from the fixed list above, never from raw input,
    # so interpolating it into the statement is safe.
    column = f"{image_type}_png"

    pool = await get_pool()
    async with pool.acquire() as conn:
        image_bytes = await conn.fetchval(
            f"SELECT {column} FROM ocr_pipeline_sessions WHERE id = $1",
            uuid.UUID(session_id)
        )
    return image_bytes
|
||||||
|
|
||||||
|
|
||||||
|
async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any]]:
    """Update the given session columns and return the refreshed metadata.

    Only keyword names on the column whitelist are written; any other
    keyword is ignored. Values for JSONB columns are serialized with
    ``json.dumps`` unless they are None or already a string. ``updated_at``
    is always set to ``NOW()``.

    Returns:
        The updated row converted by ``_row_to_dict``; the current session
        (via ``get_session_db``) when no whitelisted field was passed; None
        when no row matches ``session_id``.
    """
    pool = await get_pool()

    writable_columns = {
        'name', 'filename', 'status', 'current_step',
        'original_png', 'oriented_png', 'cropped_png',
        'deskewed_png', 'binarized_png', 'dewarped_png',
        'clean_png', 'handwriting_removal_meta',
        'orientation_result', 'crop_result',
        'deskew_result', 'dewarp_result', 'column_result', 'row_result',
        'word_result', 'ground_truth', 'auto_shear_degrees',
        'doc_type', 'doc_type_result',
        'document_category', 'pipeline_log',
        'grid_editor_result', 'structure_result',
        'parent_session_id', 'box_index',
        'document_group_id', 'page_number',
    }
    json_columns = {
        'orientation_result', 'crop_result', 'deskew_result',
        'dewarp_result', 'column_result', 'row_result', 'word_result',
        'ground_truth', 'handwriting_removal_meta', 'doc_type_result',
        'pipeline_log', 'grid_editor_result', 'structure_result',
    }

    assignments = []
    params = []
    for column, raw in kwargs.items():
        if column not in writable_columns:
            continue
        # Placeholders are numbered by position in the parameter list.
        assignments.append(f"{column} = ${len(params) + 1}")
        needs_encoding = (
            column in json_columns
            and raw is not None
            and not isinstance(raw, str)
        )
        params.append(json.dumps(raw) if needs_encoding else raw)

    if not assignments:
        return await get_session_db(session_id)

    # Always update updated_at
    assignments.append("updated_at = NOW()")

    # The session id is the last parameter and feeds the WHERE clause.
    params.append(uuid.UUID(session_id))
    id_placeholder = len(params)

    async with pool.acquire() as conn:
        updated = await conn.fetchrow(f"""
            UPDATE ocr_pipeline_sessions
            SET {', '.join(assignments)}
            WHERE id = ${id_placeholder}
            RETURNING id, name, filename, status, current_step,
                      orientation_result, crop_result,
                      deskew_result, dewarp_result, column_result, row_result,
                      word_result, ground_truth, auto_shear_degrees,
                      doc_type, doc_type_result,
                      document_category, pipeline_log,
                      grid_editor_result, structure_result,
                      parent_session_id, box_index,
                      document_group_id, page_number,
                      created_at, updated_at
        """, *params)

    if not updated:
        return None
    return _row_to_dict(updated)
|
||||||
|
|
||||||
|
|
||||||
|
async def list_sessions_db(
    limit: int = 50,
    include_sub_sessions: bool = False,
) -> List[Dict[str, Any]]:
    """Return the newest sessions as metadata dicts (no image columns).

    Unless ``include_sub_sessions`` is true, rows that have a
    ``parent_session_id`` and rows whose status is ``'split'`` are left out.
    Each returned dict carries an ``is_ground_truth`` flag derived from the
    ``ground_truth`` column; the column itself is removed from the result.
    """
    pool = await get_pool()

    if include_sub_sessions:
        where = ""
    else:
        where = "WHERE parent_session_id IS NULL AND (status IS NULL OR status != 'split')"

    # ``where`` is one of two fixed strings, never caller-supplied text.
    query = f"""
        SELECT id, name, filename, status, current_step,
               document_category, doc_type,
               parent_session_id, box_index,
               document_group_id, page_number,
               created_at, updated_at,
               ground_truth
        FROM ocr_pipeline_sessions
        {where}
        ORDER BY created_at DESC
        LIMIT $1
    """
    async with pool.acquire() as conn:
        rows = await conn.fetch(query, limit)

    sessions = []
    for record in rows:
        entry = _row_to_dict(record)
        # Reduce the heavy JSONB payload to a single boolean for list views.
        reference = entry.pop("ground_truth", None) or {}
        entry["is_ground_truth"] = bool(reference.get("build_grid_reference"))
        sessions.append(entry)
    return sessions
|
||||||
|
|
||||||
|
|
||||||
|
async def get_sub_sessions(parent_session_id: str) -> List[Dict[str, Any]]:
    """Return metadata for every session whose parent is ``parent_session_id``.

    Rows are sorted by ``box_index`` ascending; image columns are not selected.
    """
    query = """
        SELECT id, name, filename, status, current_step,
               document_category, doc_type,
               parent_session_id, box_index,
               document_group_id, page_number,
               created_at, updated_at
        FROM ocr_pipeline_sessions
        WHERE parent_session_id = $1
        ORDER BY box_index ASC
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        records = await conn.fetch(query, uuid.UUID(parent_session_id))

    children = []
    for record in records:
        children.append(_row_to_dict(record))
    return children
|
||||||
|
|
||||||
|
|
||||||
|
async def get_document_group_sessions(document_group_id: str) -> List[Dict[str, Any]]:
    """Return metadata for every session in a document group.

    Rows are sorted by ``page_number`` ascending; image columns are not
    selected.
    """
    query = """
        SELECT id, name, filename, status, current_step,
               document_category, doc_type,
               parent_session_id, box_index,
               document_group_id, page_number,
               created_at, updated_at
        FROM ocr_pipeline_sessions
        WHERE document_group_id = $1
        ORDER BY page_number ASC
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        records = await conn.fetch(query, uuid.UUID(document_group_id))

    pages = []
    for record in records:
        pages.append(_row_to_dict(record))
    return pages
|
||||||
|
|
||||||
|
|
||||||
|
async def list_ground_truth_sessions_db() -> List[Dict[str, Any]]:
    """Return top-level sessions whose ground truth mentions a grid reference.

    The filter is a text match for ``build_grid_reference`` anywhere in the
    serialised ``ground_truth`` column, restricted to rows without a parent
    session.  Results are ordered newest first and include ``ground_truth``.
    """
    query = """
        SELECT id, name, filename, status, current_step,
               document_category, doc_type,
               ground_truth,
               parent_session_id, box_index,
               created_at, updated_at
        FROM ocr_pipeline_sessions
        WHERE ground_truth IS NOT NULL
          AND ground_truth::text LIKE '%build_grid_reference%'
          AND parent_session_id IS NULL
        ORDER BY created_at DESC
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        records = await conn.fetch(query)

    sessions = []
    for record in records:
        sessions.append(_row_to_dict(record))
    return sessions
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_session_db(session_id: str) -> bool:
    """Delete one session row.

    Returns ``True`` only when the database reports that exactly one row was
    removed (status string ``"DELETE 1"``).
    """
    statement = """
        DELETE FROM ocr_pipeline_sessions WHERE id = $1
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        status_tag = await conn.execute(statement, uuid.UUID(session_id))
    removed_exactly_one = status_tag == "DELETE 1"
    return removed_exactly_one
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_all_sessions_db() -> int:
    """Remove every row from ``ocr_pipeline_sessions``.

    Returns the number of deleted rows, parsed from the status string
    returned by ``execute`` (e.g. ``"DELETE 5"``), or ``0`` if that string
    cannot be parsed.
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        status_tag = await conn.execute("DELETE FROM ocr_pipeline_sessions")

    # The row count is the last whitespace-separated token of the status.
    try:
        tokens = status_tag.split()
        return int(tokens[-1])
    except (ValueError, IndexError):
        return 0
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# HELPER
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def _row_to_dict(row: asyncpg.Record) -> Dict[str, Any]:
    """Turn a database record into a JSON-serialisable dict.

    ``None`` yields an empty dict.  For the columns that are present and not
    ``None``: UUID columns are stringified, timestamp columns are converted
    with ``isoformat()``, and result columns that arrive as strings are
    parsed with ``json.loads``.
    """
    if row is None:
        return {}

    data = dict(row)

    uuid_columns = ('id', 'session_id', 'parent_session_id', 'document_group_id')
    timestamp_columns = ('created_at', 'updated_at')
    json_columns = (
        'orientation_result', 'crop_result', 'deskew_result', 'dewarp_result',
        'column_result', 'row_result', 'word_result', 'ground_truth',
        'doc_type_result', 'pipeline_log', 'grid_editor_result',
        'structure_result',
    )

    # UUID → string
    for column in uuid_columns:
        if data.get(column) is not None:
            data[column] = str(data[column])

    # datetime → ISO string
    for column in timestamp_columns:
        if data.get(column) is not None:
            data[column] = data[column].isoformat()

    # JSONB → parsed; values that are not strings are passed through untouched
    for column in json_columns:
        raw = data.get(column)
        if isinstance(raw, str):
            data[column] = json.loads(raw)

    return data
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Sessions API — barrel re-export.
|
||||||
|
|
||||||
|
All implementation split into:
|
||||||
|
sessions_crud — session CRUD, box sessions
|
||||||
|
sessions_images — image serving, thumbnails, doc-type detection
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from .sessions_crud import router as _crud_router # noqa: F401
|
||||||
|
from .sessions_images import router as _images_router # noqa: F401
|
||||||
|
|
||||||
|
# Composite router (used by ocr_pipeline_api.py): bundles the CRUD and image
# sub-routers, in that order, behind a single object.
router = APIRouter()
for _sub_router in (_crud_router, _images_router):
    router.include_router(_sub_router)
del _sub_router  # keep the module namespace identical to the unrolled form
|
||||||
@@ -0,0 +1,449 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Sessions CRUD — session create, read, update, delete, box sessions.
|
||||||
|
|
||||||
|
Extracted from ocr_pipeline_sessions.py for modularity.
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
|
||||||
|
|
||||||
|
from cv_vocab_pipeline import render_image_high_res, render_pdf_high_res
|
||||||
|
from .common import (
|
||||||
|
VALID_DOCUMENT_CATEGORIES,
|
||||||
|
UpdateSessionRequest,
|
||||||
|
_cache,
|
||||||
|
)
|
||||||
|
from .session_store import (
|
||||||
|
create_session_db,
|
||||||
|
delete_all_sessions_db,
|
||||||
|
delete_session_db,
|
||||||
|
get_session_db,
|
||||||
|
get_session_image,
|
||||||
|
get_sub_sessions,
|
||||||
|
list_sessions_db,
|
||||||
|
update_session_db,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Session Management Endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.get("/sessions")
async def list_sessions(include_sub_sessions: bool = False):
    """Return all OCR pipeline sessions under the ``sessions`` key.

    Sub-sessions (box regions) are hidden unless the caller passes
    ``?include_sub_sessions=true``.
    """
    return {
        "sessions": await list_sessions_db(include_sub_sessions=include_sub_sessions),
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions")
async def create_session(
    file: UploadFile = File(...),
    name: Optional[str] = Form(None),
):
    """Create a pipeline session from an uploaded PDF or image.

    A PDF with more than one page is delegated to
    ``_create_multi_page_sessions`` (one session per page, grouped under a
    ``document_group_id``; the response then carries a ``pages`` array).
    Otherwise the single page/image is rendered, stored as PNG in the
    database and kept as a BGR array in the in-memory cache.

    Raises:
        HTTPException 400: the PDF cannot be opened or the file cannot be
            rendered.
        HTTPException 500: the rendered image cannot be PNG-encoded.
    """
    payload = await file.read()
    filename = file.filename or "upload"
    mime = file.content_type or ""

    looks_like_pdf = mime == "application/pdf" or filename.lower().endswith(".pdf")
    session_name = name or filename

    # --- Multi-page PDF handling ---
    if looks_like_pdf:
        try:
            import fitz  # PyMuPDF

            document = fitz.open(stream=payload, filetype="pdf")
            page_count = document.page_count
            document.close()
        except Exception as exc:
            raise HTTPException(status_code=400, detail=f"Could not read PDF: {exc}")

        if page_count > 1:
            return await _create_multi_page_sessions(
                payload, filename, session_name, page_count,
            )

    # --- Single page (image or 1-page PDF) ---
    session_id = str(uuid.uuid4())

    try:
        img_bgr = (
            render_pdf_high_res(payload, page_number=0, zoom=3.0)
            if looks_like_pdf
            else render_image_high_res(payload)
        )
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Could not process file: {exc}")

    # Encode original as PNG bytes
    encoded_ok, png_buf = cv2.imencode(".png", img_bgr)
    if not encoded_ok:
        raise HTTPException(status_code=500, detail="Failed to encode image")

    # Persist to DB
    await create_session_db(
        session_id=session_id,
        name=session_name,
        filename=filename,
        original_png=png_buf.tobytes(),
    )

    # Cache BGR array for immediate processing; all later-step slots start empty
    cache_entry = {
        "id": session_id,
        "filename": filename,
        "name": session_name,
        "original_bgr": img_bgr,
    }
    cache_entry.update(dict.fromkeys((
        "oriented_bgr", "cropped_bgr", "deskewed_bgr", "dewarped_bgr",
        "orientation_result", "crop_result", "deskew_result", "dewarp_result",
    )))
    cache_entry["ground_truth"] = {}
    cache_entry["current_step"] = 1
    _cache[session_id] = cache_entry

    height, width = img_bgr.shape[0], img_bgr.shape[1]
    logger.info(f"OCR Pipeline: created session {session_id} from {filename} "
                f"({width}x{height})")

    return {
        "session_id": session_id,
        "filename": filename,
        "name": session_name,
        "image_width": width,
        "image_height": height,
        "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
    }
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_multi_page_sessions(
    pdf_data: bytes,
    filename: str,
    base_name: str,
    page_count: int,
) -> dict:
    """Create one session per PDF page, grouped by a fresh document_group_id.

    Each page is rendered, stored as PNG in the database and cached as a BGR
    array.  Pages that fail to render or to encode are logged and skipped, so
    ``pages`` may be shorter than ``page_count``.

    Args:
        pdf_data: Raw bytes of the uploaded PDF.
        filename: Original upload filename, stored on every page session.
        base_name: Display name; each page gets it plus a page suffix.
        page_count: Number of pages in the PDF.

    Returns:
        A dict with ``document_group_id``, ``filename``, ``name``,
        ``page_count`` (pages in the PDF, not pages created), ``pages`` (one
        entry per created session) and ``session_id`` (first created page,
        or ``None`` when no page could be processed).
    """
    document_group_id = str(uuid.uuid4())
    pages = []

    for page_idx in range(page_count):
        page_number = page_idx + 1
        session_id = str(uuid.uuid4())
        page_name = f"{base_name} — Seite {page_number}"

        try:
            img_bgr = render_pdf_high_res(pdf_data, page_number=page_idx, zoom=3.0)
        except Exception as e:
            logger.warning(f"Failed to render PDF page {page_number}: {e}")
            continue

        ok, png_buf = cv2.imencode(".png", img_bgr)
        if not ok:
            # Previously skipped without a trace; log it like the render
            # failure above so missing pages can be diagnosed.
            logger.warning(f"Failed to encode PDF page {page_number} of {filename} as PNG")
            continue
        page_png = png_buf.tobytes()

        await create_session_db(
            session_id=session_id,
            name=page_name,
            filename=filename,
            original_png=page_png,
            document_group_id=document_group_id,
            page_number=page_number,
        )

        # Cache BGR array for immediate processing
        _cache[session_id] = {
            "id": session_id,
            "filename": filename,
            "name": page_name,
            "original_bgr": img_bgr,
            "oriented_bgr": None,
            "cropped_bgr": None,
            "deskewed_bgr": None,
            "dewarped_bgr": None,
            "orientation_result": None,
            "crop_result": None,
            "deskew_result": None,
            "dewarp_result": None,
            "ground_truth": {},
            "current_step": 1,
        }

        h, w = img_bgr.shape[:2]
        pages.append({
            "session_id": session_id,
            "name": page_name,
            "page_number": page_number,
            "image_width": w,
            "image_height": h,
            "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
        })

        logger.info(
            f"OCR Pipeline: created page session {session_id} "
            f"(page {page_number}/{page_count}) from {filename} ({w}x{h})"
        )

    # Include session_id pointing to first page for backwards compatibility
    # (frontends that expect a single session_id will navigate to page 1)
    first_session_id = pages[0]["session_id"] if pages else None

    return {
        "session_id": first_session_id,
        "document_group_id": document_group_id,
        "filename": filename,
        "name": base_name,
        "page_count": page_count,
        "pages": pages,
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}")
async def get_session_info(session_id: str):
    """Get session info including step results for step navigation.

    The response always contains the identifying fields, the dimensions of
    the original image (``0``/``0`` when the image is missing or cannot be
    decoded), ``current_step``, ``document_category`` and ``doc_type``.
    Step results are added only when they are set.  ``grid_editor_result`` is
    reduced to a summary.  A sub-session reports its parent and box index; a
    top-level session lists its sub-sessions, if any.

    Raises:
        HTTPException 404: no session with this id exists.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # Get image dimensions from original PNG.  The dimensions default to 0 and
    # are only overwritten after a successful decode: the former one-line
    # form `w, h = img.shape[1], img.shape[0] if img is not None else (0, 0)`
    # applied the conditional to the second element only and crashed on None.
    img_w, img_h = 0, 0
    original_png = await get_session_image(session_id, "original")
    if original_png:
        arr = np.frombuffer(original_png, dtype=np.uint8)
        img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
        if img is not None:
            img_h, img_w = img.shape[:2]

    result = {
        "session_id": session["id"],
        "filename": session.get("filename", ""),
        "name": session.get("name", ""),
        "image_width": img_w,
        "image_height": img_h,
        "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
        "current_step": session.get("current_step", 1),
        "document_category": session.get("document_category"),
        "doc_type": session.get("doc_type"),
    }

    # Step results are passed through verbatim, but only when present.
    passthrough_keys = (
        "orientation_result", "crop_result", "deskew_result", "dewarp_result",
        "column_result", "row_result", "word_result", "doc_type_result",
        "structure_result",
    )
    for key in passthrough_keys:
        if session.get(key):
            result[key] = session[key]

    if session.get("grid_editor_result"):
        # Include summary only to keep response small
        gr = session["grid_editor_result"]
        result["grid_editor_result"] = {
            "summary": gr.get("summary", {}),
            "zones_count": len(gr.get("zones", [])),
            "edited": gr.get("edited", False),
        }
    if session.get("ground_truth"):
        result["ground_truth"] = session["ground_truth"]

    # Box sub-session info (zone_type='box' from column detection — NOT page-split)
    if session.get("parent_session_id"):
        result["parent_session_id"] = session["parent_session_id"]
        result["box_index"] = session.get("box_index")
    else:
        # Check for box sub-sessions (column detection creates these)
        subs = await get_sub_sessions(session_id)
        if subs:
            result["sub_sessions"] = [
                {"id": s["id"], "name": s.get("name"), "box_index": s.get("box_index")}
                for s in subs
            ]

    return result
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/sessions/{session_id}")
async def update_session(session_id: str, req: UpdateSessionRequest):
    """Rename a session and/or change its document category.

    Raises:
        HTTPException 400: the category is not in
            ``VALID_DOCUMENT_CATEGORIES``, or the request changes nothing.
        HTTPException 404: the update did not return a session.
    """
    changes: Dict[str, Any] = {}

    if req.name is not None:
        changes["name"] = req.name

    category = req.document_category
    if category is not None:
        if category not in VALID_DOCUMENT_CATEGORIES:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid category '{category}'. Valid: {sorted(VALID_DOCUMENT_CATEGORIES)}",
            )
        changes["document_category"] = category

    if not changes:
        raise HTTPException(status_code=400, detail="Nothing to update")

    if not await update_session_db(session_id, **changes):
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    response = {"session_id": session_id}
    response.update(changes)
    return response
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/sessions/{session_id}")
async def delete_session(session_id: str):
    """Drop a session from the in-memory cache and from the database.

    Raises:
        HTTPException 404: the database did not delete a row for this id.
    """
    # Evict first so the cache entry is gone whatever the database reports.
    _cache.pop(session_id, None)

    if await delete_session_db(session_id):
        return {"session_id": session_id, "deleted": True}
    raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/sessions")
async def delete_all_sessions():
    """Clear the in-memory cache and delete every session (cleanup).

    Returns the number of deleted database rows under ``deleted_count``.
    """
    _cache.clear()
    return {"deleted_count": await delete_all_sessions_db()}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/create-box-sessions")
async def create_box_sessions(session_id: str):
    """Create sub-sessions for each detected box region.

    Crops every ``zone_type == "box"`` zone of the column result out of the
    cropped image (falling back to the dewarped image) and stores each crop
    as an independent sub-session that can be processed through the pipeline.
    If sub-sessions already exist they are returned and nothing is created.
    Boxes that yield an empty crop or cannot be encoded are logged and
    skipped; ``box_index`` always refers to the position in the list of box
    zones.

    Raises:
        HTTPException 404: the session does not exist.
        HTTPException 400: column detection has not run, or no base image is
            available.
        HTTPException 500: the base image cannot be decoded.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    column_result = session.get("column_result")
    if not column_result:
        raise HTTPException(status_code=400, detail="Column detection must be completed first")

    zones = column_result.get("zones") or []
    box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
    if not box_zones:
        return {"session_id": session_id, "sub_sessions": [], "message": "No boxes detected"}

    # Check for existing sub-sessions
    existing = await get_sub_sessions(session_id)
    if existing:
        return {
            "session_id": session_id,
            "sub_sessions": [{"id": s["id"], "box_index": s.get("box_index")} for s in existing],
            "message": f"{len(existing)} sub-session(s) already exist",
        }

    # Load base image
    base_png = await get_session_image(session_id, "cropped")
    if not base_png:
        base_png = await get_session_image(session_id, "dewarped")
    if not base_png:
        raise HTTPException(status_code=400, detail="No base image available")

    arr = np.frombuffer(base_png, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    img_h, img_w = img.shape[:2]
    parent_name = session.get("name", "Session")
    sub_filename = session.get("filename", "box-crop.png")
    created = []

    for i, zone in enumerate(box_zones):
        box = zone["box"]
        # The box comes from stored JSON; coerce to int because array slicing
        # rejects float indices (integer values are unchanged).
        bx, by = int(box["x"]), int(box["y"])
        bw, bh = int(box["width"]), int(box["height"])

        # Crop box region with small padding, clamped to the image bounds
        pad = 5
        y1 = max(0, by - pad)
        y2 = min(img_h, by + bh + pad)
        x1 = max(0, bx - pad)
        x2 = min(img_w, bx + bw + pad)

        # A box outside the image gives an empty crop, which cv2.imencode
        # rejects with an exception rather than a False flag; skip it.
        if y2 <= y1 or x2 <= x1:
            logger.warning(f"Box {i} of session {session_id} lies outside the image; skipping")
            continue
        crop = img[y1:y2, x1:x2]

        # Encode as PNG
        success, png_buf = cv2.imencode(".png", crop)
        if not success:
            logger.warning(f"Failed to encode box {i} crop for session {session_id}")
            continue

        sub_id = str(uuid.uuid4())
        sub_name = f"{parent_name} — Box {i + 1}"

        await create_session_db(
            session_id=sub_id,
            name=sub_name,
            filename=sub_filename,
            original_png=png_buf.tobytes(),
            parent_session_id=session_id,
            box_index=i,
        )

        # Cache the BGR for immediate processing
        # Promote original to cropped so column/row/word detection finds it
        box_bgr = crop.copy()  # detach from the parent image buffer
        _cache[sub_id] = {
            "id": sub_id,
            "filename": sub_filename,
            "name": sub_name,
            "parent_session_id": session_id,
            "original_bgr": box_bgr,
            "oriented_bgr": None,
            "cropped_bgr": box_bgr,
            "deskewed_bgr": None,
            "dewarped_bgr": None,
            "orientation_result": None,
            "crop_result": None,
            "deskew_result": None,
            "dewarp_result": None,
            "ground_truth": {},
            "current_step": 1,
        }

        created.append({
            "id": sub_id,
            "name": sub_name,
            "box_index": i,
            "box": box,
            "image_width": crop.shape[1],
            "image_height": crop.shape[0],
        })

        logger.info(f"Created box sub-session {sub_id} for session {session_id} "
                    f"(box {i}, {crop.shape[1]}x{crop.shape[0]})")

    return {
        "session_id": session_id,
        "sub_sessions": created,
        "total": len(created),
    }
|
||||||
@@ -0,0 +1,176 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Sessions Images — image serving, thumbnails, pipeline log,
|
||||||
|
categories, and document type detection.
|
||||||
|
|
||||||
|
Extracted from ocr_pipeline_sessions.py for modularity.
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
|
from fastapi.responses import Response
|
||||||
|
|
||||||
|
from cv_vocab_pipeline import create_ocr_image, detect_document_type
|
||||||
|
from .common import (
|
||||||
|
VALID_DOCUMENT_CATEGORIES,
|
||||||
|
_append_pipeline_log,
|
||||||
|
_cache,
|
||||||
|
_get_base_image_png,
|
||||||
|
_get_cached,
|
||||||
|
_load_session_to_cache,
|
||||||
|
)
|
||||||
|
from .overlays import render_overlay
|
||||||
|
from .session_store import (
|
||||||
|
get_session_db,
|
||||||
|
get_session_image,
|
||||||
|
update_session_db,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Thumbnail & Log Endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}/thumbnail")
async def get_session_thumbnail(session_id: str, size: int = Query(default=80, ge=16, le=400)):
    """Return a small PNG thumbnail of the original image.

    The image is scaled so that its longer side equals ``size`` pixels
    (aspect ratio preserved, each side at least 1 pixel).  The response is
    marked cacheable for one hour.

    Raises:
        HTTPException 404: the session has no original image.
        HTTPException 500: the stored image cannot be decoded, or the
            thumbnail cannot be encoded.
    """
    original_png = await get_session_image(session_id, "original")
    if not original_png:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found or no image")

    arr = np.frombuffer(original_png, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    h, w = img.shape[:2]
    scale = size / max(h, w)
    # For very elongated images the shorter side can truncate to 0, which
    # cv2.resize rejects; clamp both sides to at least one pixel.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    thumb = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

    ok, png_bytes = cv2.imencode(".png", thumb)
    if not ok:
        raise HTTPException(status_code=500, detail="Failed to encode image")

    return Response(content=png_bytes.tobytes(), media_type="image/png",
                    headers={"Cache-Control": "public, max-age=3600"})
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}/pipeline-log")
async def get_pipeline_log(session_id: str):
    """Return the stored pipeline execution log of a session.

    When the session has no log, an empty ``{"steps": []}`` structure is
    returned instead.

    Raises:
        HTTPException 404: no session with this id exists.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    log = session.get("pipeline_log")
    if not log:
        log = {"steps": []}
    return {"session_id": session_id, "pipeline_log": log}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/categories")
async def list_categories():
    """Return the valid document categories in sorted order."""
    categories = sorted(VALID_DOCUMENT_CATEGORIES)
    return {"categories": categories}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Image Endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}/image/{image_type}")
async def get_image(session_id: str, image_type: str):
    """Serve a session image as PNG.

    Accepted ``image_type`` values: ``original``, ``oriented``, ``cropped``,
    ``deskewed``, ``dewarped``, ``binarized``, ``clean`` and the overlays
    ``structure-overlay``, ``columns-overlay``, ``rows-overlay``,
    ``words-overlay``.  Overlays are delegated to ``render_overlay``.  A
    binarized PNG held in the in-memory cache is served directly; everything
    else is loaded from the database.

    Raises:
        HTTPException 400: unknown ``image_type``.
        HTTPException 404: the requested image is not available yet.
    """
    valid_types = {"original", "oriented", "cropped", "deskewed", "dewarped", "binarized", "structure-overlay", "columns-overlay", "rows-overlay", "words-overlay", "clean"}
    if image_type not in valid_types:
        raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")

    # Overlay types map onto the overlay renderer's own names.
    overlay_kinds = {
        "structure-overlay": "structure",
        "columns-overlay": "columns",
        "rows-overlay": "rows",
        "words-overlay": "words",
    }
    if image_type in overlay_kinds:
        return await render_overlay(overlay_kinds[image_type], session_id)

    # Try cache first for fast serving.  Only the binarized PNG is served
    # from the cache here (the unused png_key/bgr_key locals were removed).
    cached = _cache.get(session_id)
    if cached and image_type == "binarized" and cached.get("binarized_png"):
        return Response(content=cached["binarized_png"], media_type="image/png")

    # Load from DB — for cropped/dewarped, fall back through the chain
    if image_type in ("cropped", "dewarped"):
        data = await _get_base_image_png(session_id)
    else:
        data = await get_session_image(session_id, image_type)
    if not data:
        raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet")

    return Response(content=data, media_type="image/png")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Document Type Detection (between Dewarp and Columns)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/detect-type")
async def detect_type(session_id: str):
    """Detect document type (vocab_table, full_text, generic_table).

    Intended to run after the crop step; when no cropped image is cached the
    dewarped image is used instead. The detection result is written to the
    session (database and in-memory cache) so the frontend can decide how the
    pipeline continues, and a ``detect_type`` entry is appended to the
    pipeline log.

    Raises:
        HTTPException: 400 when neither a cropped nor a dewarped image exists.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    entry = _get_cached(session_id)

    # Prefer the cropped page, fall back to the dewarped one.
    page_bgr = entry.get("cropped_bgr")
    if page_bgr is None:
        page_bgr = entry.get("dewarped_bgr")
    if page_bgr is None:
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first")

    started = time.time()
    detection = detect_document_type(create_ocr_image(page_bgr), page_bgr)
    duration = time.time() - started

    result_dict = {
        "doc_type": detection.doc_type,
        "confidence": detection.confidence,
        "pipeline": detection.pipeline,
        "skip_steps": detection.skip_steps,
        "features": detection.features,
        "duration_seconds": round(duration, 2),
    }

    # Persist to DB, then mirror into the cache entry.
    await update_session_db(
        session_id,
        doc_type=detection.doc_type,
        doc_type_result=result_dict,
    )
    entry["doc_type_result"] = result_dict

    logger.info(f"OCR Pipeline: detect-type session {session_id}: "
                f"{detection.doc_type} (confidence={detection.confidence}, {duration:.2f}s)")

    # Only scalar feature values go into the pipeline log.
    scalar_features = {
        name: value
        for name, value in (detection.features or {}).items()
        if isinstance(value, (int, float, str, bool))
    }
    log_payload = {
        "doc_type": detection.doc_type,
        "pipeline": detection.pipeline,
        "confidence": detection.confidence,
        **scalar_features,
    }
    await _append_pipeline_log(
        session_id, "detect_type", log_payload, duration_ms=int(duration * 1000),
    )

    return {"session_id": session_id, **result_dict}
|
||||||
@@ -0,0 +1,299 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Structure Detection and Exclude Regions
|
||||||
|
|
||||||
|
Detect document structure (boxes, zones, color regions, graphics)
|
||||||
|
and manage user-drawn exclude regions.
|
||||||
|
Extracted from ocr_pipeline_geometry.py for file-size compliance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from cv_box_detect import detect_boxes
|
||||||
|
from cv_color_detect import _COLOR_RANGES, _COLOR_HEX
|
||||||
|
from cv_graphic_detect import detect_graphic_elements
|
||||||
|
from .session_store import (
|
||||||
|
get_session_db,
|
||||||
|
update_session_db,
|
||||||
|
)
|
||||||
|
from .common import (
|
||||||
|
_cache,
|
||||||
|
_load_session_to_cache,
|
||||||
|
_get_cached,
|
||||||
|
_filter_border_ghost_words,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Structure Detection Endpoint
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/detect-structure")
async def detect_structure(session_id: str):
    """Detect document structure: boxes, zones, colors and graphic elements.

    Works on the cached cropped image (or the dewarped one when no crop
    exists). Content bounds come from the OCR word boxes when a word result
    is available, otherwise from the image size minus a 3% margin. The
    result is stored as ``structure_result`` in the session (database and
    cache); exclude regions from a previous run are carried over.

    Raises:
        HTTPException: 400 when neither a cropped nor a dewarped image exists.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    entry = _get_cached(session_id)

    page = entry.get("cropped_bgr")
    if page is None:
        page = entry.get("dewarped_bgr")
    if page is None:
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first")

    started = time.time()
    h, w = page.shape[:2]

    # --- Content bounds from word result (if available) or full image ---
    word_result = entry.get("word_result")
    words: List[Dict] = []
    if word_result and word_result.get("cells"):
        words = [
            wb
            for cell in word_result["cells"]
            for wb in (cell.get("word_boxes") or [])
        ]
    # Fallback: use raw OCR words if cell word_boxes are empty
    if not words and word_result:
        for source_key in ("raw_paddle_words_split", "raw_tesseract_words", "raw_paddle_words"):
            candidates = word_result.get(source_key, [])
            if candidates:
                words = candidates
                logger.info("detect-structure: using %d words from %s (no cell word_boxes)", len(words), source_key)
                break

    if words:
        content_x = max(0, min(int(wb["left"]) for wb in words))
        content_y = max(0, min(int(wb["top"]) for wb in words))
        right_edge = min(w, max(int(wb["left"] + wb["width"]) for wb in words))
        bottom_edge = min(h, max(int(wb["top"] + wb["height"]) for wb in words))
        content_w_px = right_edge - content_x
        content_h_px = bottom_edge - content_y
    else:
        # No words yet: use image dimensions with a small margin
        margin = int(min(w, h) * 0.03)
        content_x = content_y = margin
        content_w_px = w - 2 * margin
        content_h_px = h - 2 * margin

    # --- Box detection ---
    boxes = detect_boxes(
        page,
        content_x=content_x,
        content_w=content_w_px,
        content_y=content_y,
        content_h=content_h_px,
    )

    # --- Zone splitting ---
    from cv_box_detect import split_page_into_zones as _split_zones
    zones = _split_zones(content_x, content_y, content_w_px, content_h_px, boxes)

    # --- Color region sampling ---
    # The background shade of each box is the median HSV of its central area.
    hsv = cv2.cvtColor(page, cv2.COLOR_BGR2HSV)
    box_colors = []
    for box in boxes:
        y_from = max(0, min(box.y + box.height // 4, h - 1))
        y_to = max(0, min(box.y + 3 * box.height // 4, h - 1))
        x_from = max(0, min(box.x + box.width // 4, w - 1))
        x_to = max(0, min(box.x + 3 * box.width // 4, w - 1))

        if y_to <= y_from or x_to <= x_from:
            # Degenerate sample window
            box_colors.append({"color_name": "unknown", "color_hex": "#6b7280"})
            continue

        patch = hsv[y_from:y_to, x_from:x_to]
        med_h = float(np.median(patch[:, :, 0]))
        med_s = float(np.median(patch[:, :, 1]))
        med_v = float(np.median(patch[:, :, 2]))
        if med_s > 15:
            from cv_color_detect import _hue_to_color_name
            shade_name = _hue_to_color_name(med_h)
            shade_hex = _COLOR_HEX.get(shade_name, "#6b7280")
        else:
            shade_name = "gray" if med_v < 220 else "white"
            shade_hex = "#6b7280" if shade_name == "gray" else "#ffffff"
        box_colors.append({"color_name": shade_name, "color_hex": shade_hex})

    # --- Color text detection overview ---
    # Page-wide pixel count per named color range; tiny counts are dropped.
    color_summary: Dict[str, int] = {}
    for color_name, ranges in _COLOR_RANGES.items():
        mask = np.zeros((h, w), dtype=np.uint8)
        for lower, upper in ranges:
            mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper))
        hits = int(np.sum(mask > 0))
        if hits > 50:  # minimum threshold
            color_summary[color_name] = hits

    # --- Graphic element detection ---
    graphics = detect_graphic_elements(
        page, words,
        detected_boxes=[
            {"x": b.x, "y": b.y, "w": b.width, "h": b.height}
            for b in boxes
        ],
    )

    # --- Filter border-ghost words from OCR result ---
    ghost_count = 0
    if boxes and word_result:
        ghost_count = _filter_border_ghost_words(word_result, boxes)
        if ghost_count:
            logger.info("detect-structure: removed %d border-ghost words", ghost_count)
            await update_session_db(session_id, word_result=word_result)
            entry["word_result"] = word_result

    duration = time.time() - started

    # Preserve user-drawn exclude regions from previous run
    previous = entry.get("structure_result") or {}
    kept_exclude = previous.get("exclude_regions", [])

    result_dict = {
        "image_width": w,
        "image_height": h,
        "content_bounds": {
            "x": content_x, "y": content_y,
            "w": content_w_px, "h": content_h_px,
        },
        "boxes": [
            {
                "x": b.x, "y": b.y, "w": b.width, "h": b.height,
                "confidence": b.confidence,
                "border_thickness": b.border_thickness,
                "bg_color_name": shade["color_name"],
                "bg_color_hex": shade["color_hex"],
            }
            for b, shade in zip(boxes, box_colors)
        ],
        "zones": [
            {
                "index": z.index,
                "zone_type": z.zone_type,
                "y": z.y, "h": z.height,
                "x": z.x, "w": z.width,
            }
            for z in zones
        ],
        "graphics": [
            {
                "x": g.x, "y": g.y, "w": g.width, "h": g.height,
                "area": g.area,
                "shape": g.shape,
                "color_name": g.color_name,
                "color_hex": g.color_hex,
                "confidence": round(g.confidence, 2),
            }
            for g in graphics
        ],
        "exclude_regions": kept_exclude,
        "color_pixel_counts": color_summary,
        "has_words": len(words) > 0,
        "word_count": len(words),
        "border_ghosts_removed": ghost_count,
        "duration_seconds": round(duration, 2),
    }

    # Persist to session
    await update_session_db(session_id, structure_result=result_dict)
    entry["structure_result"] = result_dict

    logger.info("detect-structure session %s: %d boxes, %d zones, %d graphics, %.2fs",
                session_id, len(boxes), len(zones), len(graphics), duration)

    return {"session_id": session_id, **result_dict}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Exclude Regions -- user-drawn rectangles to exclude from OCR results
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class _ExcludeRegionIn(BaseModel):
    """Request model for one user-drawn rectangle to exclude from OCR results."""

    # Rectangle origin and size.
    # NOTE(review): presumably image pixel coordinates - confirm with the frontend.
    x: int
    y: int
    w: int
    h: int
    # Optional free-text label; defaults to an empty string.
    label: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class _ExcludeRegionsBatchIn(BaseModel):
    """Request body carrying the complete list of exclude regions for a session."""

    # Full replacement list; the PUT endpoint overwrites whatever was stored.
    regions: list[_ExcludeRegionIn]
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/sessions/{session_id}/exclude-regions")
async def set_exclude_regions(session_id: str, body: _ExcludeRegionsBatchIn):
    """Replace all exclude regions for a session.

    The regions live inside ``structure_result.exclude_regions``. Saving them
    also clears ``grid_editor_result`` (database and cache) so the grid is
    rebuilt with the new regions.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    structure = session.get("structure_result") or {}
    structure["exclude_regions"] = [region.model_dump() for region in body.regions]

    # Invalidate grid so it rebuilds with new exclude regions
    await update_session_db(session_id, structure_result=structure, grid_editor_result=None)

    # Keep the in-memory copy in step with the database
    if session_id in _cache:
        cache_entry = _cache[session_id]
        cache_entry["structure_result"] = structure
        cache_entry.pop("grid_editor_result", None)

    stored = structure["exclude_regions"]
    return {
        "session_id": session_id,
        "exclude_regions": stored,
        "count": len(stored),
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/sessions/{session_id}/exclude-regions/{region_index}")
async def delete_exclude_region(session_id: str, region_index: int):
    """Remove a single exclude region by its list index.

    Like the batch update, this clears ``grid_editor_result`` (database and
    cache) so the grid is rebuilt without the removed region.

    Raises:
        HTTPException: 404 when the session does not exist or the index is
            outside the stored region list.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    structure = session.get("structure_result") or {}
    remaining = structure.get("exclude_regions", [])

    if not 0 <= region_index < len(remaining):
        raise HTTPException(status_code=404, detail="Region index out of range")

    removed = remaining.pop(region_index)
    structure["exclude_regions"] = remaining

    # Invalidate grid so it rebuilds with new exclude regions
    await update_session_db(session_id, structure_result=structure, grid_editor_result=None)

    if session_id in _cache:
        cache_entry = _cache[session_id]
        cache_entry["structure_result"] = structure
        cache_entry.pop("grid_editor_result", None)

    return {
        "session_id": session_id,
        "removed": removed,
        "remaining": len(remaining),
    }
|
||||||
@@ -0,0 +1,362 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Validation — image detection, generation, validation save,
|
||||||
|
and handwriting removal endpoints.
|
||||||
|
|
||||||
|
Extracted from ocr_pipeline_postprocess.py.
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from .session_store import (
|
||||||
|
get_session_db,
|
||||||
|
get_session_image,
|
||||||
|
update_session_db,
|
||||||
|
)
|
||||||
|
from .common import (
|
||||||
|
_cache,
|
||||||
|
RemoveHandwritingRequest,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pydantic Models
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Maps a requested image style to the text appended to the user's prompt
# before it is sent to the image generator. Unknown styles fall back to
# "educational" at the call site.
STYLE_SUFFIXES = {
    "educational": "educational illustration, textbook style, clear, colorful",
    "cartoon": "cartoon, child-friendly, simple shapes",
    "sketch": "pencil sketch, hand-drawn, black and white",
    "clipart": "clipart, flat vector style, simple",
    "realistic": "photorealistic, high detail",
}
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationRequest(BaseModel):
    """Request body for saving the final validation of a session."""

    # Optional reviewer notes; stored as given.
    notes: Optional[str] = None
    # Optional score; no range is enforced here.
    score: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateImageRequest(BaseModel):
    """Request body for generating a replacement image for one detected region."""

    # Index into the session's stored ``image_regions`` list.
    region_index: int
    # Prompt text; a style suffix is appended before generation.
    prompt: str
    # Key into STYLE_SUFFIXES; defaults to "educational".
    style: str = "educational"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Image detection + generation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _first_json_array(text: str) -> list:
    """Return the first well-formed JSON array embedded in *text*.

    Every ``[`` is tried as a possible array start and decoded with a real
    JSON parser, so arrays whose strings or nested values contain ``]`` are
    read completely.

    Returns:
        The decoded list, or ``[]`` when *text* contains no ``[`` at all.

    Raises:
        ValueError: when *text* contains ``[`` but no decodable JSON array.
    """
    start = text.find("[")
    if start == -1:
        return []
    decoder = json.JSONDecoder()
    while start != -1:
        try:
            value, _consumed = decoder.raw_decode(text[start:])
        except json.JSONDecodeError:
            value = None
        if isinstance(value, list):
            return value
        start = text.find("[", start + 1)
    raise ValueError("VLM response contains no valid JSON array")


@router.post("/sessions/{session_id}/reconstruction/detect-images")
async def detect_image_regions(session_id: str):
    """Detect illustration/image regions in the original scan using VLM.

    Sends the original scan (base64 PNG) with a prompt to the Ollama
    ``/api/generate`` endpoint and parses the JSON array from the reply.
    Each detected region is clamped to percentage bounds. When vocabulary
    entries are available, the region prompt is enriched with nearby words.
    The regions are stored under ``ground_truth["validation"]["image_regions"]``
    in the database and mirrored into the cache.

    Returns:
        ``{"regions": [...], "count": n}`` on success. When the VLM is
        unreachable or anything else fails, an empty result with an
        ``"error"`` message is returned instead of raising.

    Raises:
        HTTPException: 404 when the session does not exist, 400 when it has
            no original image.
    """
    import base64
    import httpx

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    original_png = await get_session_image(session_id, "original")
    if not original_png:
        raise HTTPException(status_code=400, detail="No original image found")

    word_result = session.get("word_result") or {}
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    vocab_context = ""
    if entries:
        # A handful of word pairs is enough to hint at the page's topic.
        sample = entries[:10]
        words = [f"{e.get('english', '')} / {e.get('german', '')}" for e in sample if e.get('english')]
        if words:
            vocab_context = f"\nContext: This is a vocabulary page with words like: {', '.join(words)}"

    ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
    model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")

    prompt = (
        "Analyze this scanned page. Find ALL illustration/image/picture regions "
        "(NOT text, NOT table cells, NOT blank areas). "
        "For each image region found, return its bounding box as percentage of page dimensions "
        "and a short English description of what the image shows. "
        "Reply with ONLY a JSON array like: "
        '[{"x": 10, "y": 20, "w": 30, "h": 25, "description": "drawing of a cat"}] '
        "where x, y, w, h are percentages (0-100) of the page width/height. "
        "If there are NO images on the page, return an empty array: []"
        f"{vocab_context}"
    )

    img_b64 = base64.b64encode(original_png).decode("utf-8")
    payload = {
        "model": model,
        "prompt": prompt,
        "images": [img_b64],
        "stream": False,
    }

    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            resp = await client.post(f"{ollama_base}/api/generate", json=payload)
            resp.raise_for_status()
            text = resp.json().get("response", "")

        # Parse with a JSON decoder, not a non-greedy regex: the regex cut
        # the array at the first "]" (e.g. one inside a description).
        raw_regions = _first_json_array(text)

        regions = []
        for r in raw_regions:
            if not isinstance(r, dict):
                # Skip a stray scalar rather than failing the whole reply.
                continue
            regions.append({
                "bbox_pct": {
                    "x": max(0, min(100, float(r.get("x", 0)))),
                    "y": max(0, min(100, float(r.get("y", 0)))),
                    "w": max(1, min(100, float(r.get("w", 10)))),
                    "h": max(1, min(100, float(r.get("h", 10)))),
                },
                "description": r.get("description", ""),
                "prompt": r.get("description", ""),
                "image_b64": None,
                "style": "educational",
            })

        # Enrich prompts with nearby vocab context
        if entries:
            for region in regions:
                ry = region["bbox_pct"]["y"]
                rh = region["bbox_pct"]["h"]
                nearby = [
                    e for e in entries
                    if e.get("bbox") and abs(e["bbox"].get("y", 0) - ry) < rh + 10
                ]
                if nearby:
                    en_words = [e.get("english", "") for e in nearby if e.get("english")]
                    de_words = [e.get("german", "") for e in nearby if e.get("german")]
                    if en_words or de_words:
                        context = f" (vocabulary context: {', '.join(en_words[:5])}"
                        if de_words:
                            context += f" / {', '.join(de_words[:5])}"
                        context += ")"
                        region["prompt"] = region["description"] + context

        ground_truth = session.get("ground_truth") or {}
        validation = ground_truth.get("validation") or {}
        validation["image_regions"] = regions
        validation["detected_at"] = datetime.utcnow().isoformat()
        ground_truth["validation"] = validation
        await update_session_db(session_id, ground_truth=ground_truth)

        if session_id in _cache:
            _cache[session_id]["ground_truth"] = ground_truth

        logger.info(f"Detected {len(regions)} image regions for session {session_id}")

        return {"regions": regions, "count": len(regions)}

    except httpx.ConnectError:
        logger.warning(f"VLM not available at {ollama_base} for image detection")
        return {"regions": [], "count": 0, "error": "VLM not available"}
    except Exception as e:
        # Best-effort endpoint: report the failure to the caller, do not raise.
        logger.error(f"Image detection failed for {session_id}: {e}")
        return {"regions": [], "count": 0, "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/reconstruction/generate-image")
async def generate_image_for_region(session_id: str, req: GenerateImageRequest):
    """Generate a replacement image for a detected region using mflux.

    The output size follows the region's aspect ratio (landscape, portrait
    or square). On success the image, prompt and style are stored on the
    region inside ``ground_truth["validation"]["image_regions"]``.

    Returns:
        ``{"image_b64": ..., "success": True}`` on success, otherwise a dict
        with ``success`` False and an ``error`` message.

    Raises:
        HTTPException: 404 for an unknown session, 400 for an invalid
            ``region_index``.
    """
    import httpx

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    validation = ground_truth.get("validation") or {}
    regions = validation.get("image_regions") or []

    if not 0 <= req.region_index < len(regions):
        raise HTTPException(status_code=400, detail=f"Invalid region_index {req.region_index}, have {len(regions)} regions")

    mflux_url = os.getenv("MFLUX_URL", "http://host.docker.internal:8095")
    suffix = STYLE_SUFFIXES.get(req.style, STYLE_SUFFIXES["educational"])
    full_prompt = f"{req.prompt}, {suffix}"

    target = regions[req.region_index]
    box = target["bbox_pct"]
    aspect = box["w"] / max(box["h"], 1)
    # Wide regions get a landscape canvas, tall ones portrait, the rest square.
    if aspect > 1.3:
        size = (768, 512)
    elif aspect < 0.7:
        size = (512, 768)
    else:
        size = (512, 512)
    width, height = size

    request_body = {
        "prompt": full_prompt,
        "width": width,
        "height": height,
        "steps": 4,
    }

    try:
        async with httpx.AsyncClient(timeout=300.0) as client:
            resp = await client.post(f"{mflux_url}/generate", json=request_body)
            resp.raise_for_status()
            image_b64 = resp.json().get("image_b64")

        if not image_b64:
            return {"image_b64": None, "success": False, "error": "No image returned"}

        target["image_b64"] = image_b64
        target["prompt"] = req.prompt
        target["style"] = req.style
        validation["image_regions"] = regions
        ground_truth["validation"] = validation
        await update_session_db(session_id, ground_truth=ground_truth)

        if session_id in _cache:
            _cache[session_id]["ground_truth"] = ground_truth

        logger.info(f"Generated image for session {session_id} region {req.region_index}")
        return {"image_b64": image_b64, "success": True}

    except httpx.ConnectError:
        logger.warning(f"mflux-service not available at {mflux_url}")
        return {"image_b64": None, "success": False, "error": f"mflux-service not available at {mflux_url}"}
    except Exception as e:
        logger.error(f"Image generation failed for {session_id}: {e}")
        return {"image_b64": None, "success": False, "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Validation save/get
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/reconstruction/validate")
async def save_validation(session_id: str, req: ValidationRequest):
    """Save final validation results for step 8.

    Writes the timestamp, notes and score into ``ground_truth["validation"]``,
    leaving other keys there untouched, and sets the session's
    ``current_step`` to 11.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    validation = ground_truth.get("validation") or {}
    validation.update(
        validated_at=datetime.utcnow().isoformat(),
        notes=req.notes,
        score=req.score,
    )
    ground_truth["validation"] = validation

    await update_session_db(session_id, ground_truth=ground_truth, current_step=11)

    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    logger.info(f"Validation saved for session {session_id}: score={req.score}")

    return {"session_id": session_id, "validation": validation}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}/reconstruction/validation")
async def get_validation(session_id: str):
    """Retrieve saved validation data for step 8.

    Returns the stored ``validation`` entry (``None`` when nothing has been
    saved yet) together with the session's ``word_result``.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    stored_truth = session.get("ground_truth") or {}

    return {
        "session_id": session_id,
        "validation": stored_truth.get("validation"),
        "word_result": session.get("word_result"),
    }
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Remove handwriting
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/remove-handwriting")
async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingRequest):
    """Remove handwriting from a session image using inpainting.

    Steps: pick the source image, detect a handwriting mask, optionally
    dilate the mask, inpaint, then store the result as the session's
    ``clean_png`` together with processing metadata.

    Returns:
        The processing metadata plus ``image_url`` (where the clean image is
        served) and ``session_id``.

    Raises:
        HTTPException: 404 when the session or the source image is missing,
            500 when inpainting reports failure.
    """
    import time as _time

    from services.handwriting_detection import detect_handwriting
    from services.inpainting_service import inpaint_image, dilate_mask as _dilate_mask, InpaintingMethod, image_to_png

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    t0 = _time.monotonic()

    # 1. Determine source image
    source = req.use_source
    image_bytes = None
    if source == "auto":
        # Prefer the deskewed image and reuse these bytes; the original code
        # loaded the same image from storage a second time.
        image_bytes = await get_session_image(session_id, "deskewed")
        source = "deskewed" if image_bytes else "original"

    if not image_bytes:
        image_bytes = await get_session_image(session_id, source)
    if not image_bytes:
        raise HTTPException(status_code=404, detail=f"Source image '{source}' not available")

    # 2. Detect handwriting mask
    detection = detect_handwriting(image_bytes, target_ink=req.target_ink)

    # 3. Convert mask to PNG bytes and dilate
    import io
    from PIL import Image as _PILImage
    mask_img = _PILImage.fromarray(detection.mask)
    mask_buf = io.BytesIO()
    mask_img.save(mask_buf, format="PNG")
    mask_bytes = mask_buf.getvalue()

    if req.dilation > 0:
        mask_bytes = _dilate_mask(mask_bytes, iterations=req.dilation)

    # 4. Inpaint (unknown method names fall back to AUTO)
    method_map = {
        "telea": InpaintingMethod.OPENCV_TELEA,
        "ns": InpaintingMethod.OPENCV_NS,
        "auto": InpaintingMethod.AUTO,
    }
    inpaint_method = method_map.get(req.method, InpaintingMethod.AUTO)

    result = inpaint_image(image_bytes, mask_bytes, method=inpaint_method)
    if not result.success:
        raise HTTPException(status_code=500, detail="Inpainting failed")

    elapsed_ms = int((_time.monotonic() - t0) * 1000)

    meta = {
        "method_used": result.method_used.value if hasattr(result.method_used, "value") else str(result.method_used),
        "handwriting_ratio": round(detection.handwriting_ratio, 4),
        "detection_confidence": round(detection.confidence, 4),
        "target_ink": req.target_ink,
        "dilation": req.dilation,
        "source_image": source,
        "processing_time_ms": elapsed_ms,
    }

    # 5. Persist clean image
    clean_png_bytes = image_to_png(result.image)
    await update_session_db(session_id, clean_png=clean_png_bytes, handwriting_removal_meta=meta)

    return {
        **meta,
        "image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean",
        "session_id": session_id,
    }
|
||||||
@@ -0,0 +1,261 @@
|
|||||||
|
"""
|
||||||
|
Vision-LLM OCR Fusion — Combines traditional OCR positions with Vision-LLM reading.
|
||||||
|
|
||||||
|
Sends the scan image + OCR word coordinates + document type to Qwen2.5-VL.
|
||||||
|
The LLM can read degraded text using context understanding and visual inspection,
|
||||||
|
while OCR coordinates provide structural hints (where text is, column positions).
|
||||||
|
|
||||||
|
Uses Ollama API (same pattern as handwriting_htr_api.py).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import httpx
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Ollama endpoint and vision model, both overridable via environment variables.
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
# NOTE(review): the module docstring mentions Qwen2.5-VL, but the default
# model here is llama3.2-vision:11b - confirm which one is intended.
VISION_FUSION_MODEL = os.getenv("VISION_FUSION_MODEL", "llama3.2-vision:11b")

# Document category -> prompt context. "label" names the document kind and
# "columns" describes the expected column layout; both are inserted into the
# (German) LLM prompt. Unknown categories fall back to "buchseite".
CATEGORY_PROMPTS: Dict[str, Dict[str, str]] = {
    "vokabelseite": {
        "label": "Vokabelseite eines Schulbuchs (Englisch-Deutsch)",
        "columns": "Die Tabelle hat typischerweise 3 Spalten: Englisch, Deutsch, Beispielsatz.",
    },
    "woerterbuch": {
        "label": "Woerterbuchseite",
        "columns": "Die Eintraege haben: Stichwort, Lautschrift, Uebersetzung(en), Beispielsaetze.",
    },
    "arbeitsblatt": {
        "label": "Arbeitsblatt",
        "columns": "Erkenne die Spaltenstruktur aus dem Layout.",
    },
    "buchseite": {
        "label": "Schulbuchseite",
        "columns": "Erkenne die Spaltenstruktur aus dem Layout.",
    },
}
|
||||||
|
|
||||||
|
|
||||||
|
def _group_words_into_lines(
|
||||||
|
words: List[Dict], y_tolerance: float = 15.0,
|
||||||
|
) -> List[List[Dict]]:
|
||||||
|
"""Group OCR words into lines by Y-proximity."""
|
||||||
|
if not words:
|
||||||
|
return []
|
||||||
|
sorted_w = sorted(words, key=lambda w: w.get("top", 0))
|
||||||
|
lines: List[List[Dict]] = [[sorted_w[0]]]
|
||||||
|
for w in sorted_w[1:]:
|
||||||
|
last_line = lines[-1]
|
||||||
|
avg_y = sum(ww["top"] for ww in last_line) / len(last_line)
|
||||||
|
if abs(w["top"] - avg_y) <= y_tolerance:
|
||||||
|
last_line.append(w)
|
||||||
|
else:
|
||||||
|
lines.append([w])
|
||||||
|
# Sort words within each line by X
|
||||||
|
for line in lines:
|
||||||
|
line.sort(key=lambda w: w.get("left", 0))
|
||||||
|
return lines
|
||||||
|
|
||||||
|
|
||||||
|
def _build_ocr_context(words: List[Dict], img_h: int) -> str:
    """Build a text description of OCR words with positions for the prompt.

    Words are grouped into lines via ``_group_words_into_lines``; each line
    becomes ``Zeile <n> (y~<mean top>): x=<left> "<text>", ...``. Words whose
    ``conf`` is below 50 get a trailing `` (?)`` marker.

    Args:
        words: OCR word dicts; ``text``, ``left``, ``top`` and ``conf`` are read.
            Missing or ``None`` values fall back to ``""`` / ``0`` instead of
            raising.
        img_h: Image height. Not used by the current implementation; kept so
            the signature stays stable for callers.

    Returns:
        One description per line joined with newlines ("" for no words).
    """
    context_parts = []
    for line_no, line in enumerate(_group_words_into_lines(words), start=1):
        word_descs = []
        for w in line:
            text = (w.get("text") or "").strip()
            x = w.get("left", 0)
            conf = w.get("conf") or 0
            marker = " (?)" if conf < 50 else ""
            word_descs.append(f'x={x} "{text}"{marker}')
        # Use the same "missing top counts as 0" rule as the line grouping.
        avg_y = int(sum(w.get("top", 0) for w in line) / len(line))
        context_parts.append(f"Zeile {line_no} (y~{avg_y}): {', '.join(word_descs)}")
    return "\n".join(context_parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_prompt(
    ocr_context: str, category: str, img_w: int, img_h: int,
) -> str:
    """Assemble the (German) Vision-LLM prompt.

    The prompt names the document type, lists the OCR words with positions,
    states the image size and asks for a JSON array of rows with the keys
    ``row``, ``english``, ``german`` and ``example``.

    Args:
        ocr_context: Line-wise OCR description embedded verbatim in the prompt.
        category: Key into ``CATEGORY_PROMPTS``; unknown keys use "buchseite".
        img_w: Image width in pixels.
        img_h: Image height in pixels.

    Returns:
        The complete prompt text.
    """
    fallback_info = CATEGORY_PROMPTS["buchseite"]
    cat_info = CATEGORY_PROMPTS.get(category, fallback_info)
    doc_label = cat_info['label']
    column_hint = cat_info['columns']

    return f"""Du siehst eine eingescannte {doc_label}.
{column_hint}

Die OCR-Software hat folgende Woerter an diesen Positionen erkannt.
Woerter mit (?) haben niedrige Erkennungssicherheit und sind wahrscheinlich falsch:

{ocr_context}

Bildgroesse: {img_w} x {img_h} Pixel.

AUFGABE: Schau dir das Bild genau an und erstelle die korrekte Tabelle.
- Korrigiere falsch erkannte Woerter anhand dessen was du im Bild siehst
- Fasse Fortsetzungszeilen zusammen (wenn eine Spalte in der naechsten Zeile leer ist,
gehoert der Text zur Zeile darueber — der Autor hat nur einen Zeilenumbruch innerhalb der Zelle gemacht)
- Behalte die Reihenfolge bei

Antworte NUR mit einem JSON-Array, keine Erklaerungen:
[
{{"row": 1, "english": "...", "german": "...", "example": "..."}},
{{"row": 2, "english": "...", "german": "...", "example": "..."}}
]"""
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_llm_response(response_text: str) -> Optional[List[Dict]]:
|
||||||
|
"""Parse the LLM JSON response, handling markdown code blocks."""
|
||||||
|
text = response_text.strip()
|
||||||
|
|
||||||
|
# Strip markdown code block if present
|
||||||
|
if text.startswith("```"):
|
||||||
|
text = re.sub(r"^```(?:json)?\s*", "", text)
|
||||||
|
text = re.sub(r"\s*```\s*$", "", text)
|
||||||
|
|
||||||
|
# Try to find JSON array
|
||||||
|
match = re.search(r"\[[\s\S]*\]", text)
|
||||||
|
if not match:
|
||||||
|
logger.warning("vision_fuse_ocr: no JSON array found in LLM response")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(match.group())
|
||||||
|
if not isinstance(data, list):
|
||||||
|
return None
|
||||||
|
return data
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.warning(f"vision_fuse_ocr: JSON parse error: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _vocab_rows_to_words(
    rows: List[Dict], img_w: int, img_h: int,
) -> List[Dict]:
    """Convert LLM vocab rows back to word dicts for grid building.

    Distributes words across estimated column positions so the
    existing grid builder can process them normally. Positions are synthetic:
    every row gets a Y derived from its index and every field a fixed X band.

    Args:
        rows: Decoded LLM rows, expected to be dicts with the optional keys
            ``row``, ``english``, ``german`` and ``example``. Because this is
            model output, entries that are not dicts are skipped and
            non-string cell values are converted with ``str`` instead of
            raising.
        img_w: Image width in pixels, used to scale the column bands.
        img_h: Image height in pixels, used to derive row height and spacing.

    Returns:
        One word dict per non-empty cell with ``text``, ``left``, ``top``,
        ``width``, ``height``, ``conf`` plus the private markers ``_source``,
        ``_row`` and ``_col_type``.
    """
    words = []
    # Estimate column positions (3-column vocab layout): (field, type suffix, x band)
    column_layout = [
        ("english", "en", (0.02, 0.28)),       # EN: 2%-28% of width
        ("german", "de", (0.30, 0.55)),        # DE: 30%-55%
        ("example", "example", (0.57, 0.98)),  # Example: 57%-98%
    ]

    median_h = max(15, img_h // (len(rows) * 3)) if rows else 20
    y_step = max(median_h + 5, img_h // max(len(rows), 1))

    for i, row in enumerate(rows):
        if not isinstance(row, dict):
            logger.warning("vision_fuse_ocr: skipping non-object LLM row %d", i + 1)
            continue
        y = int(i * y_step + 20)
        row_num = row.get("row", i + 1)

        for field, type_suffix, (x_start_pct, x_end_pct) in column_layout:
            raw = row.get(field) or ""
            text = (raw if isinstance(raw, str) else str(raw)).strip()
            if not text:
                continue
            x = int(x_start_pct * img_w)
            w = int((x_end_pct - x_start_pct) * img_w)
            words.append({
                "text": text,
                "left": x,
                "top": y,
                "width": w,
                "height": median_h,
                "conf": 95,  # LLM-corrected → high confidence
                "_source": "vision_llm",
                "_row": row_num,
                "_col_type": f"column_{type_suffix}",
            })

    logger.info(f"vision_fuse_ocr: converted {len(rows)} LLM rows → {len(words)} words")
    return words
|
||||||
|
|
||||||
|
|
||||||
|
async def vision_fuse_ocr(
    img_bgr: np.ndarray,
    ocr_words: List[Dict],
    document_category: str = "vokabelseite",
) -> List[Dict]:
    """Fuse traditional OCR results with Vision-LLM reading.

    Sends the image + OCR word positions to the vision model configured in
    ``VISION_FUSION_MODEL`` (via the Ollama ``/api/generate`` endpoint), which can:
    - Read degraded text that traditional OCR cannot
    - Use document context (knows what a vocab table looks like)
    - Merge continuation rows (understands table structure)

    Args:
        img_bgr: The cropped/dewarped scan image (BGR)
        ocr_words: Traditional OCR word list with positions
        document_category: Type of document being scanned

    Returns:
        Corrected word list in same format as input, ready for grid building.
        Falls back to the original ``ocr_words`` when the image cannot be
        encoded, the Ollama call fails, the answer is empty or unparseable,
        or the parsed rows contain no usable text.
    """
    img_h, img_w = img_bgr.shape[:2]

    # Build OCR context string and prompt
    ocr_context = _build_ocr_context(ocr_words, img_h)
    prompt = _build_prompt(ocr_context, document_category, img_w, img_h)

    # Encode image as base64. imencode reports failure through its first
    # return value (and may raise cv2.error), so treat both as "fall back".
    try:
        encoded_ok, img_encoded = cv2.imencode(".png", img_bgr)
    except cv2.error as e:
        logger.error(f"vision_fuse_ocr: PNG encoding failed: {e}")
        return ocr_words
    if not encoded_ok:
        logger.error("vision_fuse_ocr: PNG encoding failed")
        return ocr_words
    img_b64 = base64.b64encode(img_encoded.tobytes()).decode("utf-8")

    # Call the vision model via Ollama. Any failure here is deliberately
    # swallowed: fusion is best-effort and must not break the OCR pipeline.
    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            resp = await client.post(
                f"{OLLAMA_BASE_URL}/api/generate",
                json={
                    "model": VISION_FUSION_MODEL,
                    "prompt": prompt,
                    "images": [img_b64],
                    "stream": False,
                    "options": {"temperature": 0.1, "num_predict": 4096},
                },
            )
            resp.raise_for_status()
            data = resp.json()
            response_text = data.get("response", "").strip()
    except Exception as e:
        logger.error(f"vision_fuse_ocr: Ollama call failed: {e}")
        return ocr_words  # Fallback to original

    if not response_text:
        logger.warning("vision_fuse_ocr: empty LLM response")
        return ocr_words

    # Parse JSON response
    rows = _parse_llm_response(response_text)
    if not rows:
        logger.warning(
            "vision_fuse_ocr: could not parse LLM response, "
            "first 200 chars: %s", response_text[:200],
        )
        return ocr_words

    logger.info(
        f"vision_fuse_ocr: LLM returned {len(rows)} vocab rows "
        f"(from {len(ocr_words)} OCR words)"
    )

    # Convert back to word format for grid building
    fused_words = _vocab_rows_to_words(rows, img_w, img_h)
    if not fused_words:
        # Rows without any text would silently discard the OCR result.
        logger.warning("vision_fuse_ocr: LLM rows contained no usable text, keeping OCR words")
        return ocr_words
    return fused_words
|
||||||
@@ -0,0 +1,185 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Words — composite router for word detection, PaddleOCR direct,
|
||||||
|
and ground truth endpoints.
|
||||||
|
|
||||||
|
Split into sub-modules:
|
||||||
|
ocr_pipeline_words_detect — main detect_words endpoint (Step 7)
|
||||||
|
ocr_pipeline_words_stream — SSE streaming generators
|
||||||
|
|
||||||
|
This barrel module contains the PaddleOCR direct endpoint and ground truth
|
||||||
|
endpoints, and assembles all word-related routers.
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from cv_words_first import build_grid_from_words
|
||||||
|
from .session_store import (
|
||||||
|
get_session_db,
|
||||||
|
get_session_image,
|
||||||
|
update_session_db,
|
||||||
|
)
|
||||||
|
from .common import (
|
||||||
|
_cache,
|
||||||
|
_append_pipeline_log,
|
||||||
|
)
|
||||||
|
from .words_detect import router as _detect_router
|
||||||
|
|
||||||
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router for the endpoints defined in this module (PaddleOCR direct and word
# ground truth). It is merged into the public ``router`` at the end of the file.
_local_router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pydantic models
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class WordGroundTruthRequest(BaseModel):
    """Request body of the ``save_word_ground_truth`` endpoint.

    All three fields are stored unchanged in the session's
    ``ground_truth["words"]`` entry.
    """

    # Reviewer verdict on the word recognition result — presumably True means
    # the automatic result was accepted as-is (TODO confirm with the frontend).
    is_correct: bool
    # Optional corrected entries; their schema is not validated here.
    corrected_entries: Optional[List[Dict[str, Any]]] = None
    # Optional free-text remarks.
    notes: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# PaddleOCR Direct Endpoint
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@_local_router.post("/sessions/{session_id}/paddle-direct")
async def paddle_direct(session_id: str):
    """Run PaddleOCR on the session image and build a word grid directly.

    The most processed image variant available is used (cropped, then
    dewarped, then original). The resulting grid is stored as the session's
    ``word_result`` together with the image as ``cropped_png`` and
    ``current_step=8``, and a pipeline log entry is appended.

    Raises:
        HTTPException: 404 if the session has no image; 400 if the image
            cannot be decoded or PaddleOCR returns no words.
    """
    img_png = None
    for variant in ("cropped", "dewarped", "original"):
        img_png = await get_session_image(session_id, variant)
        if img_png:
            break
    if not img_png:
        raise HTTPException(status_code=404, detail="No image found for this session")

    img_bgr = cv2.imdecode(np.frombuffer(img_png, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Failed to decode original image")

    img_h, img_w = img_bgr.shape[:2]

    from cv_ocr_engines import ocr_region_paddle

    started = time.time()
    word_dicts = await ocr_region_paddle(img_bgr, region=None)
    if not word_dicts:
        raise HTTPException(status_code=400, detail="PaddleOCR returned no words")

    cells, columns_meta = build_grid_from_words(word_dicts, img_w, img_h)
    duration = time.time() - started

    # Tag every cell with the engine that produced it.
    for cell in cells:
        cell["ocr_engine"] = "paddle_direct"

    n_rows = len({c["row_index"] for c in cells})
    n_cols = len(columns_meta)
    col_types = {c.get("type") for c in columns_meta}
    layout = "vocab" if col_types & {"column_en", "column_de"} else "generic"
    total_cells = len(cells)
    non_empty_cells = sum(1 for c in cells if c.get("text"))
    low_confidence = sum(1 for c in cells if 0 < c.get("confidence", 0) < 50)

    word_result = {
        "cells": cells,
        "grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": total_cells},
        "columns_used": columns_meta,
        "layout": layout,
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": "paddle_direct",
        "grid_method": "paddle_direct",
        "summary": {
            "total_cells": total_cells,
            "non_empty_cells": non_empty_cells,
            "low_confidence": low_confidence,
        },
    }

    await update_session_db(
        session_id,
        word_result=word_result,
        cropped_png=img_png,
        current_step=8,
    )

    logger.info(
        "paddle_direct session %s: %d cells (%d rows, %d cols) in %.2fs",
        session_id, total_cells, n_rows, n_cols, duration,
    )

    log_payload = {
        "total_cells": total_cells,
        "non_empty_cells": non_empty_cells,
        "ocr_engine": "paddle_direct",
    }
    await _append_pipeline_log(
        session_id, "paddle_direct", log_payload, duration_ms=int(duration * 1000),
    )

    return {"session_id": session_id, **word_result}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Ground Truth Words Endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@_local_router.post("/sessions/{session_id}/ground-truth/words")
async def save_word_ground_truth(session_id: str, req: WordGroundTruthRequest):
    """Store reviewer feedback for the word recognition step.

    The feedback is saved under ``ground_truth["words"]`` of the session,
    together with a UTC timestamp and a snapshot of the current
    ``word_result``. A cached copy of the session is updated as well.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    words_entry = {
        "is_correct": req.is_correct,
        "corrected_entries": req.corrected_entries,
        "notes": req.notes,
        "saved_at": datetime.utcnow().isoformat(),
        "word_result": session.get("word_result"),
    }
    ground_truth = session.get("ground_truth") or {}
    ground_truth["words"] = words_entry

    await update_session_db(session_id, ground_truth=ground_truth)

    # Keep the in-memory cache in step with the database.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    return {"session_id": session_id, "ground_truth": words_entry}
|
||||||
|
|
||||||
|
|
||||||
|
@_local_router.get("/sessions/{session_id}/ground-truth/words")
async def get_word_ground_truth(session_id: str):
    """Return the saved word ground truth next to the automatic result.

    Raises:
        HTTPException: 404 if the session does not exist or has no word
            ground truth saved.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    saved_words = (session.get("ground_truth") or {}).get("words")
    if not saved_words:
        raise HTTPException(status_code=404, detail="No word ground truth saved")

    response = {
        "session_id": session_id,
        "words_gt": saved_words,
        "words_auto": session.get("word_result"),
    }
    return response
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Composite router
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Public router of this module: combines the detect_words router imported from
# .words_detect with the endpoints defined locally above. Both sub-routers
# already carry the "/api/v1/ocr-pipeline" prefix, so none is set here.
router = APIRouter()
router.include_router(_detect_router)
router.include_router(_local_router)
|
||||||
@@ -0,0 +1,393 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Words Detect — main word detection endpoint (Step 7).
|
||||||
|
|
||||||
|
Extracted from ocr_pipeline_words.py. Contains the ``detect_words``
|
||||||
|
endpoint which handles both v2 and words_first grid methods.
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
|
from cv_vocab_pipeline import (
|
||||||
|
PageRegion,
|
||||||
|
RowGeometry,
|
||||||
|
_cells_to_vocab_entries,
|
||||||
|
_fix_phonetic_brackets,
|
||||||
|
fix_cell_phonetics,
|
||||||
|
build_cell_grid_v2,
|
||||||
|
create_ocr_image,
|
||||||
|
detect_column_geometry,
|
||||||
|
)
|
||||||
|
from cv_words_first import build_grid_from_words
|
||||||
|
from .session_store import (
|
||||||
|
get_session_db,
|
||||||
|
update_session_db,
|
||||||
|
)
|
||||||
|
from .common import (
|
||||||
|
_cache,
|
||||||
|
_load_session_to_cache,
|
||||||
|
_get_cached,
|
||||||
|
_append_pipeline_log,
|
||||||
|
)
|
||||||
|
from .words_stream import (
|
||||||
|
_word_batch_stream_generator,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router carrying the word detection endpoint defined below.
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Word Detection Endpoint (Step 7)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/words")
async def detect_words(
    session_id: str,
    request: Request,
    engine: str = "auto",
    pronunciation: str = "british",
    stream: bool = False,
    skip_heal_gaps: bool = False,
    grid_method: str = "v2",
):
    """Build word grid from columns x rows, OCR each cell.

    Query params:
        engine: 'auto' (default), 'tesseract', 'rapid', or 'paddle'
        pronunciation: 'british' (default) or 'american'
        stream: false (default) for JSON response, true for SSE streaming
        skip_heal_gaps: false (default). When true, cells keep exact row geometry.
        grid_method: 'v2' (default) or 'words_first'

    Dispatches to the words-first path, the SSE streaming generator, or the
    non-streaming v2 path. ``engine='paddle'`` always uses ``words_first``.

    Raises:
        HTTPException: 400 if no cropped/dewarped image is cached, or if row
            detection is missing for a grid method other than 'words_first';
            404 if the session does not exist.
    """
    # PaddleOCR is full-page remote OCR -> force words_first grid method
    if engine == "paddle" and grid_method != "words_first":
        logger.info("detect_words: engine=paddle requires words_first, overriding grid_method=%s", grid_method)
        grid_method = "words_first"

    if session_id not in _cache:
        logger.info("detect_words: session %s not in cache, loading from DB", session_id)
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    # Prefer the cropped image; fall back to the dewarped one.
    dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
    if dewarped_bgr is None:
        logger.warning("detect_words: no cropped/dewarped image for session %s (cache keys: %s)",
                       session_id, [k for k in cached.keys() if k.endswith('_bgr')])
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before word detection")

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    column_result = session.get("column_result")
    row_result = session.get("row_result")
    if not column_result or not column_result.get("columns"):
        # No column detection available: treat the whole page as one text column.
        img_h_tmp, img_w_tmp = dewarped_bgr.shape[:2]
        column_result = {
            "columns": [{
                "type": "column_text",
                "x": 0, "y": 0,
                "width": img_w_tmp, "height": img_h_tmp,
                "classification_confidence": 1.0,
                "classification_method": "full_page_fallback",
            }],
            "zones": [],
            "duration_seconds": 0,
        }
        logger.info("detect_words: no column_result -- using full-page pseudo-column %dx%d", img_w_tmp, img_h_tmp)
    if grid_method != "words_first" and (not row_result or not row_result.get("rows")):
        raise HTTPException(status_code=400, detail="Row detection must be completed first")

    # Convert column dicts back to PageRegion objects
    col_regions = [
        PageRegion(
            type=c["type"],
            x=c["x"], y=c["y"],
            width=c["width"], height=c["height"],
            classification_confidence=c.get("classification_confidence", 1.0),
            classification_method=c.get("classification_method", ""),
        )
        for c in column_result["columns"]
    ]

    # Convert row dicts back to RowGeometry objects.
    # The guard above does not apply to words_first, so row_result may be
    # missing or empty here; subscripting it directly crashed that path.
    row_dicts = (row_result or {}).get("rows") or []
    row_geoms = [
        RowGeometry(
            index=r["index"],
            x=r["x"], y=r["y"],
            width=r["width"], height=r["height"],
            word_count=r.get("word_count", 0),
            words=[],
            row_type=r.get("row_type", "content"),
            gap_before=r.get("gap_before", 0),
        )
        for r in row_dicts
    ]

    # Populate word counts from cached words
    word_dicts = cached.get("_word_dicts")
    if word_dicts is None:
        ocr_img_tmp = create_ocr_image(dewarped_bgr)
        geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr)
        if geo_result is not None:
            _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
            cached["_word_dicts"] = word_dicts
            cached["_inv"] = inv
            cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)

    if word_dicts:
        content_bounds = cached.get("_content_bounds")
        if content_bounds:
            _lx, _rx, top_y, _by = content_bounds
        else:
            top_y = min(r.y for r in row_geoms) if row_geoms else 0

        for row in row_geoms:
            # Row coordinates are shifted by top_y before being compared with
            # the word positions; a word belongs to the row containing its
            # vertical centre.
            row_y_rel = row.y - top_y
            row_bottom_rel = row_y_rel + row.height
            row.words = [
                w for w in word_dicts
                if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
            ]
            row.word_count = len(row.words)

    # Exclude rows that fall within box zones
    zones = column_result.get("zones") or []
    box_ranges_inner = []
    for zone in zones:
        if zone.get("zone_type") == "box" and zone.get("box"):
            box = zone["box"]
            # Shrink the box by its border (at least 5 px) on both sides.
            bt = max(box.get("border_thickness", 0), 5)
            box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt))

    if box_ranges_inner:
        def _row_in_box(r):
            center_y = r.y + r.height / 2
            return any(by_s <= center_y < by_e for by_s, by_e in box_ranges_inner)

        before_count = len(row_geoms)
        row_geoms = [r for r in row_geoms if not _row_in_box(r)]
        excluded = before_count - len(row_geoms)
        if excluded:
            logger.info(f"detect_words: excluded {excluded} rows inside box zones")

    # --- Words-First path ---
    if grid_method == "words_first":
        return await _words_first_path(
            session_id, cached, dewarped_bgr, engine, pronunciation, zones,
        )

    if stream:
        return StreamingResponse(
            _word_batch_stream_generator(
                session_id, cached, col_regions, row_geoms,
                dewarped_bgr, engine, pronunciation, request,
                skip_heal_gaps=skip_heal_gaps,
            ),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    # --- Non-streaming path (grid_method=v2) ---
    return await _v2_path(
        session_id, cached, col_regions, row_geoms,
        dewarped_bgr, engine, pronunciation, skip_heal_gaps,
    )
|
||||||
|
|
||||||
|
|
||||||
|
async def _words_first_path(
    session_id: str,
    cached: Dict[str, Any],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    zones: list,
) -> dict:
    """Build the word grid bottom-up from detected words (words-first method).

    With ``engine == "paddle"`` the words come from ``ocr_region_paddle``;
    otherwise the cached ``_word_dicts`` are used, computed through
    ``detect_column_geometry`` when not cached yet. The result is stored as
    the session's ``word_result`` (DB and cache) and logged to the pipeline log.

    Raises:
        HTTPException: 400 if no words are available.
    """
    started = time.time()
    img_h, img_w = dewarped_bgr.shape[:2]
    is_paddle = engine == "paddle"

    if is_paddle:
        from cv_ocr_engines import ocr_region_paddle
        wf_word_dicts = await ocr_region_paddle(dewarped_bgr, region=None)
        cached["_paddle_word_dicts"] = wf_word_dicts
    else:
        wf_word_dicts = cached.get("_word_dicts")
        if wf_word_dicts is None:
            geo_result = detect_column_geometry(create_ocr_image(dewarped_bgr), dewarped_bgr)
            if geo_result is not None:
                _geoms, left_x, right_x, top_y, bottom_y, wf_word_dicts, inv = geo_result
                cached["_word_dicts"] = wf_word_dicts
                cached["_inv"] = inv
                cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)

    if not wf_word_dicts:
        raise HTTPException(status_code=400, detail="No words detected -- cannot build words-first grid")

    # Convert word coordinates to absolute if needed: non-paddle words are
    # shifted by the left/top content bounds when those are cached.
    if not is_paddle:
        content_bounds = cached.get("_content_bounds")
        if content_bounds:
            lx, _rx, ty, _by = content_bounds
            wf_word_dicts = [
                {**word, 'left': word['left'] + lx, 'top': word['top'] + ty}
                for word in wf_word_dicts
            ]

    box_rects = [
        zone["box"]
        for zone in zones
        if zone.get("zone_type") == "box" and zone.get("box")
    ]

    cells, columns_meta = build_grid_from_words(
        wf_word_dicts, img_w, img_h, box_rects=box_rects or None,
    )
    duration = time.time() - started

    fix_cell_phonetics(cells, pronunciation=pronunciation)
    for cell in cells:
        cell.setdefault("zone_index", 0)

    col_types = {c['type'] for c in columns_meta}
    is_vocab = bool(col_types & {'column_en', 'column_de'})
    n_rows = len({c['row_index'] for c in cells})
    n_cols = len(columns_meta)
    used_engine = "paddle" if is_paddle else "words_first"

    summary = {
        "total_cells": len(cells),
        "non_empty_cells": sum(1 for c in cells if c.get("text")),
        "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
    }
    word_result = {
        "cells": cells,
        "grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "grid_method": "words_first",
        "summary": summary,
    }

    # Vocab and plain-text layouts additionally get row-wise entries.
    if is_vocab or 'column_text' in col_types:
        entries = _cells_to_vocab_entries(cells, columns_meta)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        summary["total_entries"] = len(entries)
        summary["with_english"] = sum(1 for e in entries if e.get("english"))
        summary["with_german"] = sum(1 for e in entries if e.get("german"))

    await update_session_db(session_id, word_result=word_result, current_step=8)
    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline: words-first session {session_id}: "
                f"{len(cells)} cells ({duration:.2f}s), {n_rows} rows, {n_cols} cols")

    log_payload = {
        "grid_method": "words_first",
        "total_cells": len(cells),
        "non_empty_cells": summary["non_empty_cells"],
        "ocr_engine": used_engine,
        "layout": word_result["layout"],
    }
    await _append_pipeline_log(session_id, "words", log_payload, duration_ms=int(duration * 1000))

    return {"session_id": session_id, **word_result}
|
||||||
|
|
||||||
|
|
||||||
|
async def _v2_path(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    skip_heal_gaps: bool,
) -> dict:
    """Run the non-streaming cell-first (v2) grid OCR and persist the result.

    Calls ``build_cell_grid_v2`` on the given columns and rows, post-processes
    phonetics, stores the ``word_result`` in the session (DB and cache) and
    appends a pipeline log entry.

    Returns:
        The ``word_result`` dict with ``session_id`` added.
    """
    started = time.time()
    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    cells, columns_meta = build_cell_grid_v2(
        ocr_img, col_regions, row_geoms, img_w, img_h,
        ocr_engine=engine, img_bgr=dewarped_bgr,
        skip_heal_gaps=skip_heal_gaps,
    )
    duration = time.time() - started

    for cell in cells:
        cell.setdefault("zone_index", 0)

    col_types = {c['type'] for c in columns_meta}
    is_vocab = bool(col_types & {'column_en', 'column_de'})
    n_content_rows = sum(1 for r in row_geoms if r.row_type == 'content')
    n_cols = len(columns_meta)
    # Report the engine recorded on the first cell; fall back to the request.
    if cells:
        used_engine = cells[0].get("ocr_engine", "tesseract")
    else:
        used_engine = engine

    fix_cell_phonetics(cells, pronunciation=pronunciation)

    summary = {
        "total_cells": len(cells),
        "non_empty_cells": sum(1 for c in cells if c.get("text")),
        "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
    }
    word_result = {
        "cells": cells,
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": summary,
    }

    # Vocab and plain-text layouts additionally get row-wise entries.
    if is_vocab or 'column_text' in col_types:
        entries = _cells_to_vocab_entries(cells, columns_meta)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        summary["total_entries"] = len(entries)
        summary["with_english"] = sum(1 for e in entries if e.get("english"))
        summary["with_german"] = sum(1 for e in entries if e.get("german"))

    await update_session_db(session_id, word_result=word_result, current_step=8)
    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline: words session {session_id}: "
                f"layout={word_result['layout']}, "
                f"{len(cells)} cells ({duration:.2f}s), summary: {word_result['summary']}")

    log_payload = {
        "total_cells": len(cells),
        "non_empty_cells": summary["non_empty_cells"],
        "low_confidence_count": summary["low_confidence"],
        "ocr_engine": used_engine,
        "layout": word_result["layout"],
        "entry_count": word_result.get("entry_count", 0),
    }
    await _append_pipeline_log(session_id, "words", log_payload, duration_ms=int(duration * 1000))

    return {"session_id": session_id, **word_result}
|
||||||
@@ -0,0 +1,303 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Words Stream — SSE streaming generators for word detection.
|
||||||
|
|
||||||
|
Extracted from ocr_pipeline_words.py.
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from cv_vocab_pipeline import (
|
||||||
|
PageRegion,
|
||||||
|
RowGeometry,
|
||||||
|
_cells_to_vocab_entries,
|
||||||
|
_fix_character_confusion,
|
||||||
|
_fix_phonetic_brackets,
|
||||||
|
fix_cell_phonetics,
|
||||||
|
build_cell_grid_v2,
|
||||||
|
build_cell_grid_v2_streaming,
|
||||||
|
create_ocr_image,
|
||||||
|
)
|
||||||
|
from .session_store import update_session_db
|
||||||
|
from .common import _cache
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def _word_batch_stream_generator(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    request: Request,
    skip_heal_gaps: bool = False,
):
    """Run batch OCR off the event loop, then stream the results as SSE events.

    ``build_cell_grid_v2`` is executed in the event loop's default executor
    while this generator emits a ``keepalive`` event every 5 seconds. Once
    OCR has finished, every cell is sent as its own event, the assembled
    result is persisted and a final ``complete`` event is emitted.

    Event order: ``meta``, ``preparing``, zero or more ``keepalive``,
    ``columns`` (only when column metadata is non-empty), one ``cell`` per
    cell, ``complete``.

    Args:
        session_id: Session whose word result is stored via ``update_session_db``.
        cached: Mutable per-session dict; ``cached["word_result"]`` is set on success.
        col_regions: Column regions; skip types (ignore/header/footer/margins)
            are excluded when computing the grid shape and layout.
        row_geoms: Row geometries; only rows with ``row_type == 'content'``
            count towards the grid shape.
        dewarped_bgr: Image array; its first two shape entries are used as
            height and width, and it is passed on to the OCR call as ``img_bgr``.
        engine: OCR engine name forwarded as ``ocr_engine``.
        pronunciation: Forwarded to the phonetic fix-up helpers.
        request: Polled with ``is_disconnected()`` to stop early.
        skip_heal_gaps: Forwarded unchanged to ``build_cell_grid_v2``.

    Yields:
        SSE frames: ``data: <json>`` followed by a blank line.
    """
    import asyncio

    t0 = time.time()
    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
    n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
    n_cols = len([c for c in col_regions if c.type not in _skip_types])
    col_types = {c.type for c in col_regions if c.type not in _skip_types}
    # A page counts as "vocab" as soon as an English or German column exists.
    is_vocab = bool(col_types & {'column_en', 'column_de'})
    total_cells = n_content_rows * n_cols

    # 1. Send meta event immediately
    meta_event = {
        "type": "meta",
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
        "layout": "vocab" if is_vocab else "generic",
    }
    yield f"data: {json.dumps(meta_event)}\n\n"

    # 2. Send preparing event (keepalive for proxy)
    yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR laeuft parallel...'})}\n\n"

    # 3. Run batch OCR in the default executor with periodic keepalive events.
    # get_running_loop() is the supported way to obtain the loop from inside
    # a coroutine; get_event_loop() is deprecated for this use.
    loop = asyncio.get_running_loop()
    ocr_future = loop.run_in_executor(
        None,
        lambda: build_cell_grid_v2(
            ocr_img, col_regions, row_geoms, img_w, img_h,
            ocr_engine=engine, img_bgr=dewarped_bgr,
            skip_heal_gaps=skip_heal_gaps,
        ),
    )

    # Send keepalive events every 5 seconds while OCR runs
    while not ocr_future.done():
        try:
            # shield() keeps the timeout from cancelling the OCR future itself.
            cells, columns_meta = await asyncio.wait_for(
                asyncio.shield(ocr_future), timeout=5.0,
            )
            break  # OCR finished
        except asyncio.TimeoutError:
            elapsed = int(time.time() - t0)
            yield f"data: {json.dumps({'type': 'keepalive', 'elapsed': elapsed, 'message': f'OCR laeuft... ({elapsed}s)'})}\n\n"
            if await request.is_disconnected():
                logger.info(f"SSE batch: client disconnected during OCR for {session_id}")
                # NOTE(review): cancel() cannot interrupt executor work that
                # has already started; it only stops us from waiting on it.
                ocr_future.cancel()
                return
    else:
        # The future completed between a timeout and the next loop check.
        cells, columns_meta = ocr_future.result()

    if await request.is_disconnected():
        logger.info(f"SSE batch: client disconnected after OCR for {session_id}")
        return

    # 4. Apply IPA phonetic fixes
    fix_cell_phonetics(cells, pronunciation=pronunciation)

    # 5. Send columns meta
    if columns_meta:
        yield f"data: {json.dumps({'type': 'columns', 'columns_used': columns_meta})}\n\n"

    # 6. Stream all cells
    for idx, cell in enumerate(cells):
        cell_event = {
            "type": "cell",
            "cell": cell,
            "progress": {"current": idx + 1, "total": len(cells)},
        }
        yield f"data: {json.dumps(cell_event)}\n\n"

    # 7. Build final result and persist
    duration = time.time() - t0
    # Report the engine recorded on the first cell; fall back to the requested one.
    used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine

    word_result = {
        "cells": cells,
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            # Confidence 0 is excluded from the low-confidence count.
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
        },
    }

    vocab_entries = None
    has_text_col = 'column_text' in col_types
    if is_vocab or has_text_col:
        entries = _cells_to_vocab_entries(cells, columns_meta)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        # The same list is stored under both keys.
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        word_result["summary"]["total_entries"] = len(entries)
        word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
        word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
        vocab_entries = entries

    await update_session_db(session_id, word_result=word_result, current_step=8)
    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline SSE batch: words session {session_id}: "
                f"layout={word_result['layout']}, {len(cells)} cells ({duration:.2f}s)")

    # 8. Send complete event
    complete_event = {
        "type": "complete",
        "summary": word_result["summary"],
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
    }
    if vocab_entries is not None:
        complete_event["vocab_entries"] = vocab_entries
    yield f"data: {json.dumps(complete_event)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
async def _word_stream_generator(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    request: Request,
):
    """Stream cell-by-cell OCR progress as SSE events, then persist the result.

    Each cell produced by ``build_cell_grid_v2_streaming`` is forwarded as a
    ``cell`` event as soon as it arrives. After the last cell, rows without
    any text are dropped, phonetic fixes are applied, the result is stored
    and a ``complete`` event is emitted. The streamed ``cell`` events
    therefore carry the cells as they were before this post-processing.

    Event order: ``meta``, ``preparing``, ``columns`` (once, with the first
    cell), one ``cell`` per cell, ``complete``.

    Args:
        session_id: Session whose word result is stored via ``update_session_db``.
        cached: Mutable per-session dict; ``cached["word_result"]`` is set on success.
        col_regions: Column regions; skip types (ignore/header/footer/margins)
            are excluded when computing the grid shape and layout.
        row_geoms: Row geometries; only rows with ``row_type == 'content'``
            count towards the grid shape.
        dewarped_bgr: Image array; its first two shape entries are used as
            height and width, and it is passed on to the OCR call as ``img_bgr``.
        engine: OCR engine name forwarded as ``ocr_engine``.
        pronunciation: Forwarded to the phonetic fix-up helpers.
        request: Polled with ``is_disconnected()`` before each cell is emitted.

    Yields:
        SSE frames: ``data: <json>`` followed by a blank line.
    """
    t0 = time.time()

    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
    _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
    n_cols = len([c for c in col_regions if c.type not in _skip_types])

    col_types = {c.type for c in col_regions if c.type not in _skip_types}
    # A page counts as "vocab" as soon as an English or German column exists.
    is_vocab = bool(col_types & {'column_en', 'column_de'})

    columns_meta = None
    total_cells = n_content_rows * n_cols

    meta_event = {
        "type": "meta",
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
        "layout": "vocab" if is_vocab else "generic",
    }
    yield f"data: {json.dumps(meta_event)}\n\n"

    yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR wird initialisiert...'})}\n\n"

    all_cells: List[Dict[str, Any]] = []
    cell_idx = 0

    # NOTE(review): this is a plain (synchronous) iteration inside a
    # coroutine, so each step presumably blocks the event loop while the
    # next cell is being recognised -- confirm against the helper.
    for cell, cols_meta, total in build_cell_grid_v2_streaming(
        ocr_img, col_regions, row_geoms, img_w, img_h,
        ocr_engine=engine, img_bgr=dewarped_bgr,
    ):
        if await request.is_disconnected():
            logger.info(f"SSE: client disconnected during streaming for {session_id}")
            return

        # Column metadata is taken from the first yielded item and sent once.
        if columns_meta is None:
            columns_meta = cols_meta
            meta_update = {"type": "columns", "columns_used": cols_meta}
            yield f"data: {json.dumps(meta_update)}\n\n"

        all_cells.append(cell)
        cell_idx += 1

        cell_event = {
            "type": "cell",
            "cell": cell,
            "progress": {"current": cell_idx, "total": total},
        }
        yield f"data: {json.dumps(cell_event)}\n\n"

    # All cells done
    duration = time.time() - t0
    if columns_meta is None:
        # The helper yielded nothing, so no column metadata was ever seen.
        columns_meta = []

    # Remove all-empty rows
    rows_with_text: set = set()
    for c in all_cells:
        if c.get("text", "").strip():
            rows_with_text.add(c["row_index"])
    before_filter = len(all_cells)
    all_cells = [c for c in all_cells if c["row_index"] in rows_with_text]
    # Convert the number of dropped cells into a row count for the log line.
    empty_rows_removed = (before_filter - len(all_cells)) // max(n_cols, 1)
    if empty_rows_removed > 0:
        logger.info(f"SSE: removed {empty_rows_removed} all-empty rows after OCR")

    # Report the engine recorded on the first cell; fall back to the requested one.
    used_engine = all_cells[0].get("ocr_engine", "tesseract") if all_cells else engine

    fix_cell_phonetics(all_cells, pronunciation=pronunciation)

    word_result = {
        "cells": all_cells,
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(all_cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": {
            "total_cells": len(all_cells),
            "non_empty_cells": sum(1 for c in all_cells if c.get("text")),
            # Confidence 0 is excluded from the low-confidence count.
            "low_confidence": sum(1 for c in all_cells if 0 < c.get("confidence", 0) < 50),
        },
    }

    vocab_entries = None
    has_text_col = 'column_text' in col_types
    if is_vocab or has_text_col:
        entries = _cells_to_vocab_entries(all_cells, columns_meta)
        entries = _fix_character_confusion(entries)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        # The same list is stored under both keys.
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        word_result["summary"]["total_entries"] = len(entries)
        word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
        word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
        vocab_entries = entries

    await update_session_db(session_id, word_result=word_result, current_step=8)
    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline SSE: words session {session_id}: "
                f"layout={word_result['layout']}, "
                f"{len(all_cells)} cells ({duration:.2f}s)")

    complete_event = {
        "type": "complete",
        "summary": word_result["summary"],
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
    }
    if vocab_entries is not None:
        complete_event["vocab_entries"] = vocab_entries
    yield f"data: {json.dumps(complete_event)}\n\n"
|
||||||
@@ -1,81 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/labeling/api.py
|
||||||
OCR Labeling API — Barrel Re-export
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Split into:
|
_sys.modules[__name__] = _importlib.import_module("ocr.labeling.api")
|
||||||
- ocr_labeling_models.py — Pydantic models and constants
|
|
||||||
- ocr_labeling_helpers.py — OCR wrappers, image storage, hashing
|
|
||||||
- ocr_labeling_routes.py — Session/queue/labeling route handlers
|
|
||||||
- ocr_labeling_upload_routes.py — Upload, run-OCR, export route handlers
|
|
||||||
|
|
||||||
All public names are re-exported here for backward compatibility.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Models
|
|
||||||
from ocr_labeling_models import ( # noqa: F401
|
|
||||||
LOCAL_STORAGE_PATH,
|
|
||||||
SessionCreate,
|
|
||||||
SessionResponse,
|
|
||||||
ItemResponse,
|
|
||||||
ConfirmRequest,
|
|
||||||
CorrectRequest,
|
|
||||||
SkipRequest,
|
|
||||||
ExportRequest,
|
|
||||||
StatsResponse,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Helpers
|
|
||||||
from ocr_labeling_helpers import ( # noqa: F401
|
|
||||||
VISION_OCR_AVAILABLE,
|
|
||||||
PADDLEOCR_AVAILABLE,
|
|
||||||
TROCR_AVAILABLE,
|
|
||||||
DONUT_AVAILABLE,
|
|
||||||
MINIO_AVAILABLE,
|
|
||||||
TRAINING_EXPORT_AVAILABLE,
|
|
||||||
compute_image_hash,
|
|
||||||
run_ocr_on_image,
|
|
||||||
run_vision_ocr_wrapper,
|
|
||||||
run_paddleocr_wrapper,
|
|
||||||
run_trocr_wrapper,
|
|
||||||
run_donut_wrapper,
|
|
||||||
save_image_locally,
|
|
||||||
get_image_url,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Conditional re-exports from helpers' optional imports
|
|
||||||
try:
|
|
||||||
from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET # noqa: F401
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
from training_export_service import ( # noqa: F401
|
|
||||||
TrainingExportService,
|
|
||||||
TrainingSample,
|
|
||||||
get_training_export_service,
|
|
||||||
)
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
from hybrid_vocab_extractor import run_paddle_ocr # noqa: F401
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
from services.trocr_service import run_trocr_ocr # noqa: F401
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
from services.donut_ocr_service import run_donut_ocr # noqa: F401
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
from vision_ocr_service import get_vision_ocr_service, VisionOCRService # noqa: F401
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Routes (router is the main export for app.include_router)
|
|
||||||
from ocr_labeling_routes import router # noqa: F401
|
|
||||||
from ocr_labeling_upload_routes import router as upload_router # noqa: F401
|
|
||||||
|
|||||||
@@ -1,205 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/labeling/helpers.py
|
||||||
OCR Labeling - Helper Functions and OCR Wrappers
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Extracted from ocr_labeling_api.py to keep files under 500 LOC.
|
_sys.modules[__name__] = _importlib.import_module("ocr.labeling.helpers")
|
||||||
|
|
||||||
DATENSCHUTZ/PRIVACY:
|
|
||||||
- Alle Verarbeitung erfolgt lokal (Mac Mini mit Ollama)
|
|
||||||
- Keine Daten werden an externe Server gesendet
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import hashlib
|
|
||||||
|
|
||||||
from ocr_labeling_models import LOCAL_STORAGE_PATH
|
|
||||||
|
|
||||||
# Try to import Vision OCR service
|
|
||||||
try:
|
|
||||||
import sys
|
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend', 'klausur', 'services'))
|
|
||||||
from vision_ocr_service import get_vision_ocr_service, VisionOCRService
|
|
||||||
VISION_OCR_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
VISION_OCR_AVAILABLE = False
|
|
||||||
print("Warning: Vision OCR service not available")
|
|
||||||
|
|
||||||
# Try to import PaddleOCR from hybrid_vocab_extractor
|
|
||||||
try:
|
|
||||||
from hybrid_vocab_extractor import run_paddle_ocr
|
|
||||||
PADDLEOCR_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
PADDLEOCR_AVAILABLE = False
|
|
||||||
print("Warning: PaddleOCR not available")
|
|
||||||
|
|
||||||
# Try to import TrOCR service
|
|
||||||
try:
|
|
||||||
from services.trocr_service import run_trocr_ocr
|
|
||||||
TROCR_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
TROCR_AVAILABLE = False
|
|
||||||
print("Warning: TrOCR service not available")
|
|
||||||
|
|
||||||
# Try to import Donut service
|
|
||||||
try:
|
|
||||||
from services.donut_ocr_service import run_donut_ocr
|
|
||||||
DONUT_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
DONUT_AVAILABLE = False
|
|
||||||
print("Warning: Donut OCR service not available")
|
|
||||||
|
|
||||||
# Try to import MinIO storage
|
|
||||||
try:
|
|
||||||
from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET
|
|
||||||
MINIO_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
MINIO_AVAILABLE = False
|
|
||||||
print("Warning: MinIO storage not available, using local storage")
|
|
||||||
|
|
||||||
# Try to import Training Export Service
|
|
||||||
try:
|
|
||||||
from training_export_service import (
|
|
||||||
TrainingExportService,
|
|
||||||
TrainingSample,
|
|
||||||
get_training_export_service,
|
|
||||||
)
|
|
||||||
TRAINING_EXPORT_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
TRAINING_EXPORT_AVAILABLE = False
|
|
||||||
print("Warning: Training export service not available")
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Helper Functions
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
def compute_image_hash(image_data: bytes) -> str:
    """Return the hexadecimal SHA-256 digest of the raw image bytes."""
    hasher = hashlib.sha256()
    hasher.update(image_data)
    return hasher.hexdigest()
|
|
||||||
|
|
||||||
|
|
||||||
async def run_ocr_on_image(image_data: bytes, filename: str, model: str = "llama3.2-vision:11b") -> tuple:
    """
    Dispatch an image to the OCR backend selected by ``model``.

    Recognised model names:
    - paddleocr: handled by ``run_paddleocr_wrapper``
    - donut: handled by ``run_donut_wrapper``
    - trocr: handled by ``run_trocr_wrapper``
    - anything else (including the default llama3.2-vision:11b): handled by
      ``run_vision_ocr_wrapper``

    Returns:
        Tuple of (ocr_text, confidence) as produced by the chosen wrapper.
    """
    print(f"Running OCR with model: {model}")

    # Pick the wrapper first, then make a single call at the end.
    if model == "paddleocr":
        ocr_wrapper = run_paddleocr_wrapper
    elif model == "donut":
        ocr_wrapper = run_donut_wrapper
    elif model == "trocr":
        ocr_wrapper = run_trocr_wrapper
    else:
        # Every unrecognised name is routed to the vision LLM backend.
        ocr_wrapper = run_vision_ocr_wrapper

    return await ocr_wrapper(image_data, filename)
|
|
||||||
|
|
||||||
|
|
||||||
async def run_vision_ocr_wrapper(image_data: bytes, filename: str) -> tuple:
    """Run the vision OCR service on an image.

    Returns ``(text, confidence)`` on success and ``(None, 0.0)`` when the
    service is not importable, reports itself unavailable, or raises.
    """
    if not VISION_OCR_AVAILABLE:
        print("Vision OCR service not available")
        return None, 0.0

    try:
        ocr_service = get_vision_ocr_service()
        if await ocr_service.is_available():
            extraction = await ocr_service.extract_text(
                image_data,
                filename=filename,
                is_handwriting=True
            )
            return extraction.text, extraction.confidence

        print("Vision OCR service not available (is_available check failed)")
        return None, 0.0
    except Exception as exc:
        # Best effort: any failure is reported and turned into an empty result.
        print(f"Vision OCR failed: {exc}")
        return None, 0.0
|
|
||||||
|
|
||||||
|
|
||||||
async def run_paddleocr_wrapper(image_data: bytes, filename: str) -> tuple:
    """Run PaddleOCR on an image, falling back to vision OCR on failure.

    Returns ``(raw_text, average_region_confidence)``. Empty OCR text yields
    ``(None, 0.0)``. If PaddleOCR is unavailable or raises, the result of
    ``run_vision_ocr_wrapper`` is returned instead.
    """
    if not PADDLEOCR_AVAILABLE:
        print("PaddleOCR not available, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)

    try:
        # run_paddle_ocr returns (regions, raw_text)
        regions, raw_text = run_paddle_ocr(image_data)

        if not raw_text:
            print("PaddleOCR returned empty text")
            return None, 0.0

        # Mean of the per-region confidences; 0.5 when there are no regions.
        avg_confidence = (
            sum(region.confidence for region in regions) / len(regions)
            if regions
            else 0.5
        )
        return raw_text, avg_confidence
    except Exception as exc:
        print(f"PaddleOCR failed: {exc}, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)
|
|
||||||
|
|
||||||
|
|
||||||
async def run_trocr_wrapper(image_data: bytes, filename: str) -> tuple:
    """Run TrOCR on an image, falling back to vision OCR if it is unavailable or fails."""
    if not TROCR_AVAILABLE:
        print("TrOCR not available, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)

    try:
        text, confidence = await run_trocr_ocr(image_data)
    except Exception as exc:
        print(f"TrOCR failed: {exc}, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)
    else:
        return text, confidence
|
|
||||||
|
|
||||||
|
|
||||||
async def run_donut_wrapper(image_data: bytes, filename: str) -> tuple:
    """Run Donut OCR on an image, falling back to vision OCR if it is unavailable or fails."""
    if not DONUT_AVAILABLE:
        print("Donut not available, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)

    try:
        text, confidence = await run_donut_ocr(image_data)
    except Exception as exc:
        print(f"Donut OCR failed: {exc}, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)
    else:
        return text, confidence
|
|
||||||
|
|
||||||
|
|
||||||
def save_image_locally(session_id: str, item_id: str, image_data: bytes, extension: str = "png") -> str:
    """Write the image bytes to ``<LOCAL_STORAGE_PATH>/<session_id>/<item_id>.<extension>``.

    The session directory is created if it does not exist yet.

    Returns:
        The path of the written file.
    """
    target_dir = os.path.join(LOCAL_STORAGE_PATH, session_id)
    os.makedirs(target_dir, exist_ok=True)

    target_path = os.path.join(target_dir, f"{item_id}.{extension}")
    with open(target_path, 'wb') as out_file:
        out_file.write(image_data)

    return target_path
|
|
||||||
|
|
||||||
|
|
||||||
def get_image_url(image_path: str) -> str:
    """Map a stored image path to the value handed to the frontend.

    Paths that start with ``LOCAL_STORAGE_PATH`` become a relative
    ``/api/v1/ocr-label/images/...`` URL; every other value is returned
    unchanged.
    """
    if not image_path.startswith(LOCAL_STORAGE_PATH):
        # Not a local file (e.g. a MinIO URL or key): pass through as-is.
        return image_path

    relative_path = image_path[len(LOCAL_STORAGE_PATH):].lstrip('/')
    return f"/api/v1/ocr-label/images/{relative_path}"
|
|
||||||
|
|||||||
@@ -1,86 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/labeling/models.py
|
||||||
OCR Labeling - Pydantic Models and Constants
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Extracted from ocr_labeling_api.py to keep files under 500 LOC.
|
_sys.modules[__name__] = _importlib.import_module("ocr.labeling.models")
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from typing import Optional, Dict
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
|
|
||||||
# Local storage path (fallback if MinIO not available)
|
|
||||||
LOCAL_STORAGE_PATH = os.getenv("OCR_STORAGE_PATH", "/app/ocr-labeling")
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Pydantic Models
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
# Request body of the create-session route (POST /sessions).
class SessionCreate(BaseModel):
    name: str
    source_type: str = "klausur"  # klausur, handwriting_sample, scan
    description: Optional[str] = None
    # OCR model identifier; defaults to the vision LLM model name.
    ocr_model: Optional[str] = "llama3.2-vision:11b"
|
|
||||||
|
|
||||||
|
|
||||||
# Response model of the session routes (create, list, get).
class SessionResponse(BaseModel):
    id: str
    name: str
    source_type: str
    description: Optional[str]
    ocr_model: Optional[str]
    # Per-session item counters; the routes fill missing values with 0.
    total_items: int
    labeled_items: int
    confirmed_items: int
    corrected_items: int
    skipped_items: int
    created_at: datetime
|
|
||||||
|
|
||||||
|
|
||||||
# Response model of the queue and item routes.
class ItemResponse(BaseModel):
    id: str
    session_id: str
    session_name: str
    # Stored path of the image; image_url is derived from it via get_image_url.
    image_path: str
    image_url: Optional[str]
    ocr_text: Optional[str]
    ocr_confidence: Optional[float]
    ground_truth: Optional[str]
    # The routes default this to "pending" when the stored item has no status.
    status: str
    metadata: Optional[Dict]
    created_at: datetime
|
|
||||||
|
|
||||||
|
|
||||||
# Presumably the request body of the confirm endpoint (POST /confirm) -- verify against the route.
class ConfirmRequest(BaseModel):
    item_id: str
    # Optional labeling duration in whole seconds.
    label_time_seconds: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
# Presumably the request body of the correct endpoint (POST /correct) -- verify against the route.
class CorrectRequest(BaseModel):
    item_id: str
    # Corrected text supplied by the caller.
    ground_truth: str
    # Optional labeling duration in whole seconds.
    label_time_seconds: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
# Presumably the request body of the skip endpoint (POST /skip) -- verify against the route.
class SkipRequest(BaseModel):
    item_id: str
|
|
||||||
|
|
||||||
|
|
||||||
# Parameters of a training-data export request.
# NOTE(review): the consuming route is not visible here -- confirm how
# session_id and batch_id are interpreted when left as None.
class ExportRequest(BaseModel):
    export_format: str = "generic"  # generic, trocr, llama_vision
    session_id: Optional[str] = None
    batch_id: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
# Labeling statistics payload.
# NOTE(review): presumably returned by the stats endpoint (GET /stats) -- confirm.
class StatsResponse(BaseModel):
    total_sessions: Optional[int] = None
    total_items: int
    labeled_items: int
    confirmed_items: int
    corrected_items: int
    pending_items: int
    exportable_items: Optional[int] = None
    # NOTE(review): scale (fraction vs. percent) is not visible here -- confirm.
    accuracy_rate: float
    avg_label_time_seconds: Optional[float] = None
|
|
||||||
|
|||||||
@@ -1,241 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/labeling/routes.py
|
||||||
OCR Labeling - Session and Labeling Route Handlers
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Extracted from ocr_labeling_api.py to keep files under 500 LOC.
|
_sys.modules[__name__] = _importlib.import_module("ocr.labeling.routes")
|
||||||
|
|
||||||
Endpoints:
|
|
||||||
- POST /sessions - Create labeling session
|
|
||||||
- GET /sessions - List sessions
|
|
||||||
- GET /sessions/{id} - Get session
|
|
||||||
- GET /queue - Get labeling queue
|
|
||||||
- GET /items/{id} - Get item
|
|
||||||
- POST /confirm - Confirm OCR
|
|
||||||
- POST /correct - Correct ground truth
|
|
||||||
- POST /skip - Skip item
|
|
||||||
- GET /stats - Get statistics
|
|
||||||
"""
|
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Query
|
|
||||||
from typing import Optional, List
|
|
||||||
from datetime import datetime
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from metrics_db import (
|
|
||||||
create_ocr_labeling_session,
|
|
||||||
get_ocr_labeling_sessions,
|
|
||||||
get_ocr_labeling_session,
|
|
||||||
get_ocr_labeling_queue,
|
|
||||||
get_ocr_labeling_item,
|
|
||||||
confirm_ocr_label,
|
|
||||||
correct_ocr_label,
|
|
||||||
skip_ocr_item,
|
|
||||||
get_ocr_labeling_stats,
|
|
||||||
)
|
|
||||||
|
|
||||||
from ocr_labeling_models import (
|
|
||||||
SessionCreate, SessionResponse, ItemResponse,
|
|
||||||
ConfirmRequest, CorrectRequest, SkipRequest,
|
|
||||||
)
|
|
||||||
from ocr_labeling_helpers import get_image_url
|
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"])
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Session Endpoints
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
@router.post("/sessions", response_model=SessionResponse)
async def create_session(session: SessionCreate):
    """Create a new OCR labeling session."""
    # A fresh UUID4 string identifies the session.
    new_session_id = str(uuid.uuid4())

    # Client-supplied fields, shared by the database insert and the response.
    client_fields = {
        "name": session.name,
        "source_type": session.source_type,
        "description": session.description,
        "ocr_model": session.ocr_model,
    }

    created = await create_ocr_labeling_session(
        session_id=new_session_id,
        **client_fields,
    )
    if not created:
        raise HTTPException(status_code=500, detail="Failed to create session")

    # A new session has no items yet, so every counter starts at zero.
    return SessionResponse(
        id=new_session_id,
        **client_fields,
        total_items=0,
        labeled_items=0,
        confirmed_items=0,
        corrected_items=0,
        skipped_items=0,
        created_at=datetime.utcnow(),
    )
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/sessions", response_model=List[SessionResponse])
async def list_sessions(limit: int = Query(50, ge=1, le=100)):
    """List all OCR labeling sessions."""
    stored_sessions = await get_ocr_labeling_sessions(limit=limit)

    # Counter fields that fall back to 0 when missing from a stored row.
    counter_keys = (
        'total_items',
        'labeled_items',
        'confirmed_items',
        'corrected_items',
        'skipped_items',
    )

    responses = []
    for row in stored_sessions:
        responses.append(
            SessionResponse(
                id=row['id'],
                name=row['name'],
                source_type=row['source_type'],
                description=row.get('description'),
                ocr_model=row.get('ocr_model'),
                # Rows without a timestamp are stamped with the current UTC time.
                created_at=row.get('created_at', datetime.utcnow()),
                **{key: row.get(key, 0) for key in counter_keys},
            )
        )
    return responses
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/sessions/{session_id}", response_model=SessionResponse)
async def get_session(session_id: str):
    """Get a specific OCR labeling session."""
    row = await get_ocr_labeling_session(session_id)
    if not row:
        raise HTTPException(status_code=404, detail="Session not found")

    # Counter fields that fall back to 0 when missing from the stored row.
    counter_keys = (
        'total_items',
        'labeled_items',
        'confirmed_items',
        'corrected_items',
        'skipped_items',
    )

    return SessionResponse(
        id=row['id'],
        name=row['name'],
        source_type=row['source_type'],
        description=row.get('description'),
        ocr_model=row.get('ocr_model'),
        # A row without a timestamp is stamped with the current UTC time.
        created_at=row.get('created_at', datetime.utcnow()),
        **{key: row.get(key, 0) for key in counter_keys},
    )
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Queue and Item Endpoints
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
@router.get("/queue", response_model=List[ItemResponse])
async def get_labeling_queue(
    session_id: Optional[str] = Query(None),
    status: str = Query("pending"),
    limit: int = Query(10, ge=1, le=50),
):
    """Return up to ``limit`` queue items, optionally filtered by session and status."""

    def _as_response(row):
        # image_url is derived from the stored image_path via get_image_url.
        return ItemResponse(
            id=row['id'],
            session_id=row['session_id'],
            session_name=row.get('session_name', ''),
            image_path=row['image_path'],
            image_url=get_image_url(row['image_path']),
            ocr_text=row.get('ocr_text'),
            ocr_confidence=row.get('ocr_confidence'),
            ground_truth=row.get('ground_truth'),
            status=row.get('status', 'pending'),
            metadata=row.get('metadata'),
            created_at=row.get('created_at', datetime.utcnow()),
        )

    rows = await get_ocr_labeling_queue(
        session_id=session_id,
        status=status,
        limit=limit,
    )
    return [_as_response(row) for row in rows]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/items/{item_id}", response_model=ItemResponse)
async def get_item(item_id: str):
    """Return a single labeling item by id.

    Raises:
        HTTPException: 404 when ``get_ocr_labeling_item`` yields nothing.
    """
    record = await get_ocr_labeling_item(item_id)
    if not record:
        raise HTTPException(status_code=404, detail="Item not found")

    stored_path = record['image_path']
    fields = dict(
        id=record['id'],
        session_id=record['session_id'],
        session_name=record.get('session_name', ''),
        image_path=stored_path,
        image_url=get_image_url(stored_path),
        ocr_text=record.get('ocr_text'),
        ocr_confidence=record.get('ocr_confidence'),
        ground_truth=record.get('ground_truth'),
        status=record.get('status', 'pending'),
        metadata=record.get('metadata'),
        created_at=record.get('created_at', datetime.utcnow()),
    )
    return ItemResponse(**fields)
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Labeling Action Endpoints
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
@router.post("/confirm")
async def confirm_item(request: ConfirmRequest):
    """Mark an item's OCR text as correct.

    The label is always recorded with ``labeled_by="admin"``.

    Raises:
        HTTPException: 400 when ``confirm_ocr_label`` reports failure.
    """
    confirmed = await confirm_ocr_label(
        item_id=request.item_id,
        labeled_by="admin",
        label_time_seconds=request.label_time_seconds,
    )
    if confirmed:
        return {"status": "confirmed", "item_id": request.item_id}

    raise HTTPException(status_code=400, detail="Failed to confirm item")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/correct")
async def correct_item(request: CorrectRequest):
    """Store a corrected ground-truth text for an item.

    The label is always recorded with ``labeled_by="admin"``.

    Raises:
        HTTPException: 400 when ``correct_ocr_label`` reports failure.
    """
    corrected = await correct_ocr_label(
        item_id=request.item_id,
        ground_truth=request.ground_truth,
        labeled_by="admin",
        label_time_seconds=request.label_time_seconds,
    )
    if corrected:
        return {"status": "corrected", "item_id": request.item_id}

    raise HTTPException(status_code=400, detail="Failed to correct item")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/skip")
async def skip_item(request: SkipRequest):
    """Mark an item as skipped (for example, an unusable image).

    Raises:
        HTTPException: 400 when ``skip_ocr_item`` reports failure.
    """
    skipped = await skip_ocr_item(item_id=request.item_id, labeled_by="admin")
    if skipped:
        return {"status": "skipped", "item_id": request.item_id}

    raise HTTPException(status_code=400, detail="Failed to skip item")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/stats")
async def get_stats(session_id: Optional[str] = Query(None)):
    """Return labeling statistics, optionally restricted to one session.

    Raises:
        HTTPException: 500 when the stats payload contains an ``"error"`` key.
    """
    payload = await get_ocr_labeling_stats(session_id=session_id)
    if "error" not in payload:
        return payload

    raise HTTPException(status_code=500, detail=payload["error"])
|
|
||||||
|
|||||||
@@ -1,313 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/labeling/upload_routes.py
|
||||||
OCR Labeling - Upload, Run-OCR, and Export Route Handlers
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Extracted from ocr_labeling_routes.py to keep files under 500 LOC.
|
_sys.modules[__name__] = _importlib.import_module("ocr.labeling.upload_routes")
|
||||||
|
|
||||||
Endpoints:
|
|
||||||
- POST /sessions/{id}/upload - Upload images for labeling
|
|
||||||
- POST /run-ocr/{item_id} - Run OCR on existing item
|
|
||||||
- POST /export - Export training data
|
|
||||||
- GET /training-samples - List training samples
|
|
||||||
- GET /images/{path} - Serve images from local storage
|
|
||||||
- GET /exports - List exports
|
|
||||||
"""
|
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
|
|
||||||
from typing import Optional, List
|
|
||||||
import uuid
|
|
||||||
import os
|
|
||||||
|
|
||||||
from metrics_db import (
|
|
||||||
get_ocr_labeling_session,
|
|
||||||
add_ocr_labeling_item,
|
|
||||||
get_ocr_labeling_item,
|
|
||||||
export_training_samples,
|
|
||||||
get_training_samples,
|
|
||||||
)
|
|
||||||
|
|
||||||
from ocr_labeling_models import (
|
|
||||||
ExportRequest,
|
|
||||||
LOCAL_STORAGE_PATH,
|
|
||||||
)
|
|
||||||
from ocr_labeling_helpers import (
|
|
||||||
compute_image_hash, run_ocr_on_image,
|
|
||||||
save_image_locally,
|
|
||||||
MINIO_AVAILABLE, TRAINING_EXPORT_AVAILABLE,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Conditional imports
|
|
||||||
try:
|
|
||||||
from minio_storage import upload_ocr_image, get_ocr_image
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
from training_export_service import TrainingSample, get_training_export_service
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"])
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/upload")
async def upload_images(
    session_id: str,
    files: List[UploadFile] = File(...),
    run_ocr: bool = Form(True),
    metadata: Optional[str] = Form(None),
):
    """Store uploaded files as items of a labeling session.

    Args:
        session_id: Session the new items are attached to.
        files: Uploaded files; extensions other than png/jpg/jpeg/pdf are
            treated as ``png``.
        run_ocr: Run OCR right away (PDF uploads are never OCR'd here).
        metadata: Optional JSON string; if it does not parse, it is kept
            verbatim under the ``"raw"`` key.

    Returns:
        Dict with the session id, the number of stored items and their details.
        Items for which ``add_ocr_labeling_item`` reports failure are omitted.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    import json

    session = await get_ocr_labeling_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    if metadata:
        try:
            meta_dict = json.loads(metadata)
        except json.JSONDecodeError:
            meta_dict = {"raw": metadata}
    else:
        meta_dict = None

    ocr_model = session.get('ocr_model', 'llama3.2-vision:11b')
    allowed_extensions = ('png', 'jpg', 'jpeg', 'pdf')
    uploaded = []

    for upload in files:
        content = await upload.read()
        image_hash = compute_image_hash(content)
        item_id = str(uuid.uuid4())

        # Unknown or missing extensions fall back to png.
        extension = 'png'
        if upload.filename:
            candidate = upload.filename.split('.')[-1].lower()
            if candidate in allowed_extensions:
                extension = candidate

        # Prefer MinIO when available; any upload error falls back to local disk.
        use_local = not MINIO_AVAILABLE
        if MINIO_AVAILABLE:
            try:
                image_path = upload_ocr_image(session_id, item_id, content, extension)
            except Exception as e:
                print(f"MinIO upload failed, using local storage: {e}")
                use_local = True
        if use_local:
            image_path = save_image_locally(session_id, item_id, content, extension)

        ocr_text = ocr_confidence = None
        if run_ocr and extension != 'pdf':
            ocr_text, ocr_confidence = await run_ocr_on_image(
                content,
                upload.filename or f"{item_id}.{extension}",
                model=ocr_model,
            )

        stored = await add_ocr_labeling_item(
            item_id=item_id,
            session_id=session_id,
            image_path=image_path,
            image_hash=image_hash,
            ocr_text=ocr_text,
            ocr_confidence=ocr_confidence,
            # The model is only recorded when OCR actually produced text.
            ocr_model=ocr_model if ocr_text else None,
            metadata=meta_dict,
        )
        if not stored:
            continue

        uploaded.append({
            "id": item_id,
            "filename": upload.filename,
            "image_path": image_path,
            "image_hash": image_hash,
            "ocr_text": ocr_text,
            "ocr_confidence": ocr_confidence,
            "status": "pending",
        })

    return {
        "session_id": session_id,
        "uploaded_count": len(uploaded),
        "items": uploaded,
    }
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/export")
async def export_data(request: ExportRequest):
    """Export labeled samples for training.

    The samples are always fetched via ``export_training_samples``. When the
    training export service is available, they are additionally handed to it;
    a failure there is printed and otherwise ignored (best effort).

    Returns:
        Dict with format, batch id, sample count and samples, plus
        ``export_path``/``manifest_path`` when the export service succeeded.
    """
    rows = await export_training_samples(
        export_format=request.export_format,
        session_id=request.session_id,
        batch_id=request.batch_id,
        exported_by="admin",
    )

    if not rows:
        return {
            "export_format": request.export_format,
            "batch_id": request.batch_id,
            "exported_count": 0,
            "samples": [],
            "message": "No labeled samples found to export",
        }

    export_result = None
    if TRAINING_EXPORT_AVAILABLE:
        try:
            service = get_training_export_service()
            converted = [
                TrainingSample(
                    id=row.get('id', row.get('item_id', '')),
                    image_path=row.get('image_path', ''),
                    ground_truth=row.get('ground_truth', ''),
                    ocr_text=row.get('ocr_text'),
                    ocr_confidence=row.get('ocr_confidence'),
                    metadata=row.get('metadata'),
                )
                for row in rows
            ]
            export_result = service.export(
                samples=converted,
                export_format=request.export_format,
                batch_id=request.batch_id,
            )
        except Exception as e:
            print(f"Training export failed: {e}")

    # An explicit batch id from the request wins over the one the service chose.
    batch_id = request.batch_id
    if not batch_id:
        batch_id = export_result.batch_id if export_result else None

    payload = {
        "export_format": request.export_format,
        "batch_id": batch_id,
        "exported_count": len(rows),
        "samples": rows,
    }
    if export_result:
        payload["export_path"] = export_result.export_path
        payload["manifest_path"] = export_result.manifest_path
    return payload
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/training-samples")
async def list_training_samples(
    export_format: Optional[str] = Query(None),
    batch_id: Optional[str] = Query(None),
    limit: int = Query(100, ge=1, le=1000),
):
    """Return previously exported training samples together with their count."""
    found = await get_training_samples(
        export_format=export_format,
        batch_id=batch_id,
        limit=limit,
    )
    return dict(count=len(found), samples=found)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/images/{path:path}")
async def get_image(path: str):
    """Serve a file from the local OCR storage directory.

    ``path`` comes straight from the URL, so it is resolved and checked to
    lie inside ``LOCAL_STORAGE_PATH`` before anything is read. Requests that
    escape the storage root (``../`` segments, absolute paths, symlinks
    pointing elsewhere) are answered like a missing file.

    Args:
        path: Path of the file relative to the storage root.

    Raises:
        HTTPException: 404 when the path is outside the storage root or does
            not name an existing regular file.
    """
    from fastapi.responses import FileResponse

    storage_root = os.path.realpath(LOCAL_STORAGE_PATH)
    filepath = os.path.realpath(os.path.join(storage_root, path))

    try:
        inside_root = os.path.commonpath([storage_root, filepath]) == storage_root
    except ValueError:
        # Raised e.g. for paths on different drives; never inside the root.
        inside_root = False

    # isfile (not exists): a directory cannot be served as a file response.
    if not inside_root or not os.path.isfile(filepath):
        raise HTTPException(status_code=404, detail="Image not found")

    extension = os.path.splitext(filepath)[1].lstrip('.').lower()
    content_type = {
        'png': 'image/png',
        'jpg': 'image/jpeg',
        'jpeg': 'image/jpeg',
        'pdf': 'application/pdf',
    }.get(extension, 'application/octet-stream')

    return FileResponse(filepath, media_type=content_type)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/run-ocr/{item_id}")
async def run_ocr_for_item(item_id: str):
    """Run OCR for an already stored item and persist the result.

    The image is read from disk when its path starts with
    ``LOCAL_STORAGE_PATH``; otherwise it is fetched via MinIO if available.
    The OCR model is taken from the item's session, defaulting to
    ``llama3.2-vision:11b``.

    Raises:
        HTTPException: 404 when the item or its local image file is missing;
            500 when the image cannot be loaded or OCR returns no text.
    """
    record = await get_ocr_labeling_item(item_id)
    if not record:
        raise HTTPException(status_code=404, detail="Item not found")

    stored_path = record['image_path']
    is_local = stored_path.startswith(LOCAL_STORAGE_PATH)

    if not is_local and not MINIO_AVAILABLE:
        raise HTTPException(status_code=500, detail="Cannot load image")

    if is_local:
        if not os.path.exists(stored_path):
            raise HTTPException(status_code=404, detail="Image file not found")
        with open(stored_path, 'rb') as handle:
            raw_image = handle.read()
    else:
        try:
            raw_image = get_ocr_image(stored_path)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to load image: {e}")

    session = await get_ocr_labeling_session(record['session_id'])
    fallback_model = 'llama3.2-vision:11b'
    ocr_model = session.get('ocr_model', fallback_model) if session else fallback_model

    ocr_text, ocr_confidence = await run_ocr_on_image(
        raw_image,
        os.path.basename(stored_path),
        model=ocr_model,
    )
    if ocr_text is None:
        raise HTTPException(status_code=500, detail="OCR failed")

    from metrics_db import get_pool

    pool = await get_pool()
    # Without a pool the result is still returned, just not persisted.
    if pool:
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE ocr_labeling_items
                SET ocr_text = $2, ocr_confidence = $3, ocr_model = $4
                WHERE id = $1
                """,
                item_id, ocr_text, ocr_confidence, ocr_model
            )

    return dict(
        item_id=item_id,
        ocr_text=ocr_text,
        ocr_confidence=ocr_confidence,
        ocr_model=ocr_model,
    )
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/exports")
async def list_exports(export_format: Optional[str] = Query(None)):
    """List the training-data exports known to the export service.

    Returns an empty list with an explanatory message when the export service
    is not available.

    Raises:
        HTTPException: 500 when the export service raises.
    """
    if TRAINING_EXPORT_AVAILABLE:
        try:
            service = get_training_export_service()
            found = service.list_exports(export_format=export_format)
            return {"count": len(found), "exports": found}
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to list exports: {e}")

    return {
        "exports": [],
        "message": "Training export service not available",
    }
|
|
||||||
|
|||||||
@@ -1,272 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/merge_helpers.py
|
||||||
OCR Merge Helpers — functions for combining PaddleOCR/RapidOCR with Tesseract results.
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Extracted from ocr_pipeline_ocr_merge.py.
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.merge_helpers")
|
||||||
|
|
||||||
Lizenz: Apache 2.0
|
|
||||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def _split_paddle_multi_words(words: list) -> list:
|
|
||||||
"""Split PaddleOCR multi-word boxes into individual word boxes.
|
|
||||||
|
|
||||||
PaddleOCR often returns entire phrases as a single box, e.g.
|
|
||||||
"More than 200 singers took part in the" with one bounding box.
|
|
||||||
This splits them into individual words with proportional widths.
|
|
||||||
Also handles leading "!" (e.g. "!Betonung" -> ["!", "Betonung"])
|
|
||||||
and IPA brackets (e.g. "badge[bxd3]" -> ["badge", "[bxd3]"]).
|
|
||||||
"""
|
|
||||||
import re
|
|
||||||
|
|
||||||
result = []
|
|
||||||
for w in words:
|
|
||||||
raw_text = w.get("text", "").strip()
|
|
||||||
if not raw_text:
|
|
||||||
continue
|
|
||||||
# Split on whitespace, before "[" (IPA), and after "!" before letter
|
|
||||||
tokens = re.split(
|
|
||||||
r'\s+|(?=\[)|(?<=!)(?=[A-Za-z\u00c0-\u024f])', raw_text
|
|
||||||
)
|
|
||||||
tokens = [t for t in tokens if t]
|
|
||||||
|
|
||||||
if len(tokens) <= 1:
|
|
||||||
result.append(w)
|
|
||||||
else:
|
|
||||||
# Split proportionally by character count
|
|
||||||
total_chars = sum(len(t) for t in tokens)
|
|
||||||
if total_chars == 0:
|
|
||||||
continue
|
|
||||||
n_gaps = len(tokens) - 1
|
|
||||||
gap_px = w["width"] * 0.02
|
|
||||||
usable_w = w["width"] - gap_px * n_gaps
|
|
||||||
cursor = w["left"]
|
|
||||||
for t in tokens:
|
|
||||||
token_w = max(1, usable_w * len(t) / total_chars)
|
|
||||||
result.append({
|
|
||||||
"text": t,
|
|
||||||
"left": round(cursor),
|
|
||||||
"top": w["top"],
|
|
||||||
"width": round(token_w),
|
|
||||||
"height": w["height"],
|
|
||||||
"conf": w.get("conf", 0),
|
|
||||||
})
|
|
||||||
cursor += token_w + gap_px
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def _group_words_into_rows(words: list, row_gap: int = 12) -> list:
|
|
||||||
"""Group words into rows by Y-position clustering.
|
|
||||||
|
|
||||||
Words whose vertical centers are within `row_gap` pixels are on the same row.
|
|
||||||
Returns list of rows, each row is a list of words sorted left-to-right.
|
|
||||||
"""
|
|
||||||
if not words:
|
|
||||||
return []
|
|
||||||
# Sort by vertical center
|
|
||||||
sorted_words = sorted(words, key=lambda w: w["top"] + w.get("height", 0) / 2)
|
|
||||||
rows: list = []
|
|
||||||
current_row: list = [sorted_words[0]]
|
|
||||||
current_cy = sorted_words[0]["top"] + sorted_words[0].get("height", 0) / 2
|
|
||||||
|
|
||||||
for w in sorted_words[1:]:
|
|
||||||
cy = w["top"] + w.get("height", 0) / 2
|
|
||||||
if abs(cy - current_cy) <= row_gap:
|
|
||||||
current_row.append(w)
|
|
||||||
else:
|
|
||||||
# Sort current row left-to-right before saving
|
|
||||||
rows.append(sorted(current_row, key=lambda w: w["left"]))
|
|
||||||
current_row = [w]
|
|
||||||
current_cy = cy
|
|
||||||
if current_row:
|
|
||||||
rows.append(sorted(current_row, key=lambda w: w["left"]))
|
|
||||||
return rows
|
|
||||||
|
|
||||||
|
|
||||||
def _row_center_y(row: list) -> float:
|
|
||||||
"""Average vertical center of a row of words."""
|
|
||||||
if not row:
|
|
||||||
return 0.0
|
|
||||||
return sum(w["top"] + w.get("height", 0) / 2 for w in row) / len(row)
|
|
||||||
|
|
||||||
|
|
||||||
def _merge_row_sequences(paddle_row: list, tess_row: list) -> list:
|
|
||||||
"""Merge two word sequences from the same row using sequence alignment.
|
|
||||||
|
|
||||||
Both sequences are sorted left-to-right. Walk through both simultaneously:
|
|
||||||
- If words match (same/similar text): take Paddle text with averaged coords
|
|
||||||
- If they don't match: the extra word is unique to one engine, include it
|
|
||||||
"""
|
|
||||||
merged = []
|
|
||||||
pi, ti = 0, 0
|
|
||||||
|
|
||||||
while pi < len(paddle_row) and ti < len(tess_row):
|
|
||||||
pw = paddle_row[pi]
|
|
||||||
tw = tess_row[ti]
|
|
||||||
|
|
||||||
pt = pw.get("text", "").lower().strip()
|
|
||||||
tt = tw.get("text", "").lower().strip()
|
|
||||||
|
|
||||||
is_same = (pt == tt) or (len(pt) > 1 and len(tt) > 1 and (pt in tt or tt in pt))
|
|
||||||
|
|
||||||
# Spatial overlap check
|
|
||||||
spatial_match = False
|
|
||||||
if not is_same:
|
|
||||||
overlap_left = max(pw["left"], tw["left"])
|
|
||||||
overlap_right = min(
|
|
||||||
pw["left"] + pw.get("width", 0),
|
|
||||||
tw["left"] + tw.get("width", 0),
|
|
||||||
)
|
|
||||||
overlap_w = max(0, overlap_right - overlap_left)
|
|
||||||
min_w = min(pw.get("width", 1), tw.get("width", 1))
|
|
||||||
if min_w > 0 and overlap_w / min_w >= 0.4:
|
|
||||||
is_same = True
|
|
||||||
spatial_match = True
|
|
||||||
|
|
||||||
if is_same:
|
|
||||||
pc = pw.get("conf", 80)
|
|
||||||
tc = tw.get("conf", 50)
|
|
||||||
total = pc + tc
|
|
||||||
if total == 0:
|
|
||||||
total = 1
|
|
||||||
if spatial_match and pc < tc:
|
|
||||||
best_text = tw["text"]
|
|
||||||
else:
|
|
||||||
best_text = pw["text"]
|
|
||||||
merged.append({
|
|
||||||
"text": best_text,
|
|
||||||
"left": round((pw["left"] * pc + tw["left"] * tc) / total),
|
|
||||||
"top": round((pw["top"] * pc + tw["top"] * tc) / total),
|
|
||||||
"width": round((pw["width"] * pc + tw["width"] * tc) / total),
|
|
||||||
"height": round((pw["height"] * pc + tw["height"] * tc) / total),
|
|
||||||
"conf": max(pc, tc),
|
|
||||||
})
|
|
||||||
pi += 1
|
|
||||||
ti += 1
|
|
||||||
else:
|
|
||||||
paddle_ahead = any(
|
|
||||||
tess_row[t].get("text", "").lower().strip() == pt
|
|
||||||
for t in range(ti + 1, min(ti + 4, len(tess_row)))
|
|
||||||
)
|
|
||||||
tess_ahead = any(
|
|
||||||
paddle_row[p].get("text", "").lower().strip() == tt
|
|
||||||
for p in range(pi + 1, min(pi + 4, len(paddle_row)))
|
|
||||||
)
|
|
||||||
|
|
||||||
if paddle_ahead and not tess_ahead:
|
|
||||||
if tw.get("conf", 0) >= 30:
|
|
||||||
merged.append(tw)
|
|
||||||
ti += 1
|
|
||||||
elif tess_ahead and not paddle_ahead:
|
|
||||||
merged.append(pw)
|
|
||||||
pi += 1
|
|
||||||
else:
|
|
||||||
if pw["left"] <= tw["left"]:
|
|
||||||
merged.append(pw)
|
|
||||||
pi += 1
|
|
||||||
else:
|
|
||||||
if tw.get("conf", 0) >= 30:
|
|
||||||
merged.append(tw)
|
|
||||||
ti += 1
|
|
||||||
|
|
||||||
while pi < len(paddle_row):
|
|
||||||
merged.append(paddle_row[pi])
|
|
||||||
pi += 1
|
|
||||||
while ti < len(tess_row):
|
|
||||||
tw = tess_row[ti]
|
|
||||||
if tw.get("conf", 0) >= 30:
|
|
||||||
merged.append(tw)
|
|
||||||
ti += 1
|
|
||||||
|
|
||||||
return merged
|
|
||||||
|
|
||||||
|
|
||||||
def _merge_paddle_tesseract(paddle_words: list, tess_words: list) -> list:
    """Combine PaddleOCR and Tesseract word boxes row by row.

    Both inputs are grouped into rows. Every Paddle row is paired with the
    nearest not-yet-used Tesseract row (by average vertical center), provided
    the distance does not exceed the tallest word of the Paddle row (at least
    15 px); paired rows go through ``_merge_row_sequences``. Unpaired Paddle
    rows are kept as they are. Words of unpaired Tesseract rows are kept only
    with confidence >= 40, which is also the filter applied when there are no
    Paddle words at all.
    """
    if not paddle_words:
        # Covers "both empty" as well: the filter then yields [].
        return [tw for tw in tess_words if tw.get("conf", 0) >= 40]
    if not tess_words:
        return list(paddle_words)

    paddle_rows = _group_words_into_rows(paddle_words)
    tess_rows = _group_words_into_rows(tess_words)

    claimed: set = set()
    combined: list = []

    for paddle_row in paddle_rows:
        paddle_cy = _row_center_y(paddle_row)

        # (distance, index) pairs; min() picks the nearest row, and the
        # lowest index among equally distant rows.
        candidates = [
            (abs(paddle_cy - _row_center_y(tess_row)), idx)
            for idx, tess_row in enumerate(tess_rows)
            if idx not in claimed
        ]
        nearest_dist, nearest_idx = min(candidates) if candidates else (float("inf"), -1)

        tallest = max((w.get("height", 20) for w in paddle_row), default=20)
        tolerance = max(tallest, 15)

        if nearest_idx < 0 or nearest_dist > tolerance:
            combined.extend(paddle_row)
            continue

        claimed.add(nearest_idx)
        combined.extend(_merge_row_sequences(paddle_row, tess_rows[nearest_idx]))

    for idx, tess_row in enumerate(tess_rows):
        if idx in claimed:
            continue
        combined.extend(tw for tw in tess_row if tw.get("conf", 0) >= 40)

    return combined
|
|
||||||
|
|
||||||
|
|
||||||
def _deduplicate_words(words: list) -> list:
    """Drop words that repeat an earlier word at (roughly) the same place.

    A word counts as a duplicate of an already kept word when both have the
    same lower-cased, stripped text and their boxes overlap by at least 50 %
    of the smaller width and 50 % of the smaller height. Words with empty
    text are dropped too. The number of removed words is logged.
    """
    if not words:
        return words

    def _axis_overlap(a, b, pos_key, size_key):
        """Return (overlap length, smaller extent) of two boxes along one axis."""
        start = max(a[pos_key], b[pos_key])
        end = min(a[pos_key] + a.get(size_key, 0), b[pos_key] + b.get(size_key, 0))
        smaller = min(a.get(size_key, 1), b.get(size_key, 1))
        return max(0, end - start), smaller

    def _same_spot(candidate, kept_word):
        overlap_x, smaller_w = _axis_overlap(candidate, kept_word, "left", "width")
        if smaller_w <= 0 or overlap_x / smaller_w < 0.5:
            return False
        overlap_y, smaller_h = _axis_overlap(candidate, kept_word, "top", "height")
        return smaller_h > 0 and overlap_y / smaller_h >= 0.5

    kept: list = []
    for word in words:
        text = word.get("text", "").lower().strip()
        if not text:
            continue
        duplicate = any(
            kept_word.get("text", "").lower().strip() == text
            and _same_spot(word, kept_word)
            for kept_word in kept
        )
        if not duplicate:
            kept.append(word)

    removed = len(words) - len(kept)
    if removed:
        logger.info("dedup: removed %d duplicate words", removed)
    return kept
|
|
||||||
|
|||||||
@@ -1,63 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/api.py
|
||||||
OCR Pipeline API - Schrittweise Seitenrekonstruktion.
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Thin wrapper that assembles all sub-module routers into a single
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.api")
|
||||||
composite router. Backward-compatible: main.py and tests can still
|
|
||||||
import ``router``, ``_cache``, and helper functions from here.
|
|
||||||
|
|
||||||
Sub-modules (each < 1 000 lines):
|
|
||||||
ocr_pipeline_common – shared state, cache, Pydantic models, helpers
|
|
||||||
ocr_pipeline_sessions – session CRUD, image serving, doc-type
|
|
||||||
ocr_pipeline_geometry – deskew, dewarp, structure, columns
|
|
||||||
ocr_pipeline_rows – row detection, box-overlay helper
|
|
||||||
ocr_pipeline_words – word detection (SSE), paddle-direct, word GT
|
|
||||||
ocr_pipeline_ocr_merge – paddle/tesseract merge helpers, kombi endpoints
|
|
||||||
ocr_pipeline_postprocess – LLM review, reconstruction, export, validation
|
|
||||||
ocr_pipeline_auto – auto-mode orchestrator, reprocess
|
|
||||||
|
|
||||||
Lizenz: Apache 2.0
|
|
||||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from fastapi import APIRouter
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Shared state (imported by main.py and orientation_crop_api.py)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
from ocr_pipeline_common import ( # noqa: F401 – re-exported
|
|
||||||
_cache,
|
|
||||||
_BORDER_GHOST_CHARS,
|
|
||||||
_filter_border_ghost_words,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Sub-module routers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
from ocr_pipeline_sessions import router as _sessions_router
|
|
||||||
from ocr_pipeline_geometry import router as _geometry_router
|
|
||||||
from ocr_pipeline_rows import router as _rows_router
|
|
||||||
from ocr_pipeline_words import router as _words_router
|
|
||||||
from ocr_pipeline_ocr_merge import (
|
|
||||||
router as _ocr_merge_router,
|
|
||||||
# Re-export for test backward compatibility
|
|
||||||
_split_paddle_multi_words, # noqa: F401
|
|
||||||
_group_words_into_rows, # noqa: F401
|
|
||||||
_merge_row_sequences, # noqa: F401
|
|
||||||
_merge_paddle_tesseract, # noqa: F401
|
|
||||||
)
|
|
||||||
from ocr_pipeline_postprocess import router as _postprocess_router
|
|
||||||
from ocr_pipeline_auto import router as _auto_router
|
|
||||||
from ocr_pipeline_regression import router as _regression_router
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Composite router (used by main.py)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Composite router: every sub-module router is mounted, in this order.
router = APIRouter()
for _sub_router in (
    _sessions_router,
    _geometry_router,
    _rows_router,
    _words_router,
    _ocr_merge_router,
    _postprocess_router,
    _auto_router,
    _regression_router,
):
    router.include_router(_sub_router)
del _sub_router  # keep the module namespace as it was
|
|
||||||
|
|||||||
@@ -1,23 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/auto.py
|
||||||
OCR Pipeline Auto-Mode Orchestrator and Reprocess Endpoints — Barrel Re-export.
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Split into submodules:
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.auto")
|
||||||
- ocr_pipeline_reprocess.py — POST /sessions/{id}/reprocess
|
|
||||||
- ocr_pipeline_auto_steps.py — POST /sessions/{id}/run-auto + VLM helper
|
|
||||||
|
|
||||||
Lizenz: Apache 2.0
|
|
||||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from fastapi import APIRouter
|
|
||||||
|
|
||||||
from ocr_pipeline_reprocess import router as _reprocess_router
|
|
||||||
from ocr_pipeline_auto_steps import router as _steps_router
|
|
||||||
|
|
||||||
# Combine both sub-routers into a single router for backwards compatibility.
|
|
||||||
# The consumer imports `from ocr_pipeline_auto import router as _auto_router`.
|
|
||||||
# One router carrying both the reprocess and the run-auto endpoints.
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
for _sub_router in (_reprocess_router, _steps_router):
    router.include_router(_sub_router)
del _sub_router  # keep the module namespace as it was
|
|
||||||
|
|
||||||
__all__ = ["router"]
|
|
||||||
|
|||||||
@@ -1,84 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/auto_helpers.py
|
||||||
OCR Pipeline Auto-Mode Helpers.
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
VLM shear detection, SSE event formatting, and request models.
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.auto_helpers")
|
||||||
|
|
||||||
Lizenz: Apache 2.0
|
|
||||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
from typing import Any, Dict
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class RunAutoRequest(BaseModel):
    """Request body controlling an automatic pipeline run."""

    # First step to execute:
    # 1=deskew, 2=dewarp, 3=columns, 4=rows, 5=words, 6=llm-review
    from_step: int = 1
    # One of "auto", "rapid", "tesseract"
    ocr_engine: str = "auto"
    pronunciation: str = "british"
    skip_llm_review: bool = False
    # One of "ensemble", "vlm", "cv"
    dewarp_method: str = "ensemble"
|
|
||||||
|
|
||||||
|
|
||||||
async def auto_sse_event(step: str, status: str, data: Dict[str, Any]) -> str:
    """Serialize one pipeline progress update as a server-sent-event frame.

    The JSON payload starts with ``step`` and ``status``; the entries of
    ``data`` follow and override those two keys if they use the same names.
    """
    event: Dict[str, Any] = dict(step=step, status=status)
    event.update(data)
    return "data: " + json.dumps(event) + "\n\n"
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_vlm_shear_reply(text: str):
    """Extract ``(shear_degrees, confidence)`` from the model's free-text reply.

    Returns a tuple of two floats, clamped to [-3.0, 3.0] and [0.0, 1.0], or
    ``None`` when the reply contains no ``{...}`` object or when either value
    is not a finite number.

    Raises:
        ValueError / TypeError: if the matched text is not valid JSON or a
            value cannot be converted to float (handled by the caller).
    """
    import math

    # The reply may carry prose around the JSON object, so search for it.
    found = re.search(r'\{[^}]+\}', text)
    if not found:
        return None

    data = json.loads(found.group(0))
    shear = float(data.get("shear_degrees", 0.0))
    conf = float(data.get("confidence", 0.0))

    # json.loads accepts NaN/Infinity literals. Clamping a NaN with min()/max()
    # would yield the limits themselves (3.0 degrees at confidence 1.0), so
    # non-finite values are treated as an unusable answer instead.
    if not (math.isfinite(shear) and math.isfinite(conf)):
        return None

    # Clamp to reasonable range
    return max(-3.0, min(3.0, shear)), max(0.0, min(1.0, conf))


async def detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]:
    """Ask a vision model (default qwen2.5vl:32b) for the page's vertical shear.

    The image is sent base64-encoded to the Ollama ``/api/generate`` endpoint
    together with a prompt asking whether the column/table borders are tilted
    and by how many degrees. Base URL and model come from the
    ``OLLAMA_BASE_URL`` and ``OLLAMA_HTR_MODEL`` environment variables.

    Args:
        image_bytes: Encoded image data to show to the model.

    Returns:
        A dict with ``method``, ``shear_degrees`` (clamped to +/-3.0, rounded
        to 3 places) and ``confidence`` (clamped to 0.0-1.0, rounded to 2
        places). ``shear_degrees`` and ``confidence`` are both 0.0 when the
        request fails, the reply cannot be parsed, or the reply contains
        non-finite numbers.
    """
    import httpx
    import base64

    ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
    model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")

    prompt = (
        "This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. "
        "Are they perfectly vertical, or do they tilt slightly? "
        "If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). "
        "Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} "
        "Use confidence 0.0-1.0 based on how clearly you can see the tilt. "
        "If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}"
    )

    img_b64 = base64.b64encode(image_bytes).decode("utf-8")
    payload = {
        "model": model,
        "prompt": prompt,
        "images": [img_b64],
        "stream": False,
    }

    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.post(f"{ollama_base}/api/generate", json=payload)
            resp.raise_for_status()
            text = resp.json().get("response", "")

        parsed = _parse_vlm_shear_reply(text)
        if parsed is not None:
            shear, conf = parsed
            return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)}
    except Exception as e:
        # Best-effort: any transport or parsing problem degrades to "no opinion".
        logger.warning("VLM dewarp failed: %s", e)

    return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0}
|
|
||||||
|
|||||||
@@ -1,528 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/auto_steps.py
|
||||||
OCR Pipeline Auto-Mode Orchestrator.
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
POST /sessions/{session_id}/run-auto -- full auto-mode with SSE streaming.
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.auto_steps")
|
||||||
|
|
||||||
Lizenz: Apache 2.0
|
|
||||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from dataclasses import asdict
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
from fastapi import APIRouter, HTTPException, Request
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
|
|
||||||
from cv_vocab_pipeline import (
|
|
||||||
OLLAMA_REVIEW_MODEL,
|
|
||||||
PageRegion,
|
|
||||||
RowGeometry,
|
|
||||||
_cells_to_vocab_entries,
|
|
||||||
_detect_header_footer_gaps,
|
|
||||||
_detect_sub_columns,
|
|
||||||
_fix_character_confusion,
|
|
||||||
_fix_phonetic_brackets,
|
|
||||||
fix_cell_phonetics,
|
|
||||||
analyze_layout,
|
|
||||||
build_cell_grid,
|
|
||||||
classify_column_types,
|
|
||||||
create_layout_image,
|
|
||||||
create_ocr_image,
|
|
||||||
deskew_image,
|
|
||||||
deskew_image_by_word_alignment,
|
|
||||||
detect_column_geometry,
|
|
||||||
detect_row_geometry,
|
|
||||||
_apply_shear,
|
|
||||||
dewarp_image,
|
|
||||||
llm_review_entries,
|
|
||||||
)
|
|
||||||
from ocr_pipeline_common import (
|
|
||||||
_cache,
|
|
||||||
_load_session_to_cache,
|
|
||||||
_get_cached,
|
|
||||||
)
|
|
||||||
from ocr_pipeline_session_store import (
|
|
||||||
get_session_db,
|
|
||||||
update_session_db,
|
|
||||||
)
|
|
||||||
from ocr_pipeline_auto_helpers import (
|
|
||||||
RunAutoRequest,
|
|
||||||
auto_sse_event as _auto_sse_event,
|
|
||||||
detect_shear_with_vlm as _detect_shear_with_vlm,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(tags=["ocr-pipeline"])
|
|
||||||
|
|
||||||
def _auto_encode_png(img_bgr) -> bytes:
    """PNG-encode an image with OpenCV; returns b"" when encoding reports failure."""
    ok, buf = cv2.imencode(".png", img_bgr)
    return buf.tobytes() if ok else b""


def _auto_pipeline_image(cached: Dict[str, Any]):
    """Return the cached cropped image, else the dewarped one (None if neither exists)."""
    cropped = cached.get("cropped_bgr")
    return cropped if cropped is not None else cached.get("dewarped_bgr")


def _auto_page_regions(columns: List[Dict[str, Any]]) -> List[PageRegion]:
    """Rebuild PageRegion objects from persisted column dicts."""
    return [
        PageRegion(
            type=c["type"], x=c["x"], y=c["y"],
            width=c["width"], height=c["height"],
            classification_confidence=c.get("classification_confidence", 1.0),
            classification_method=c.get("classification_method", ""),
        )
        for c in columns
    ]


@router.post("/sessions/{session_id}/run-auto")
async def run_auto(session_id: str, req: RunAutoRequest, request: Request):
    """Run the full OCR pipeline automatically from a given step, streaming SSE progress.

    Steps:
      1. Deskew -- straighten the scan
      2. Dewarp -- correct vertical shear (ensemble CV or VLM)
      3. Columns -- detect column layout
      4. Rows -- detect row layout
      5. Words -- OCR each cell
      6. LLM review -- correct OCR errors (optional)

    Steps numbered below `from_step` are reported as skipped; every step from
    `from_step` onward is (re)run. A failure in steps 1-5 ends the stream with
    a `complete`/`error` event; a failure in step 6 is reported as non-fatal.

    Yields SSE events of the form:
      data: {"step": "deskew", "status": "start"|"done"|"skipped"|"error", ...}

    Final event:
      data: {"step": "complete", "status": "done", "steps_run": [...], "steps_skipped": [...]}

    Raises:
        HTTPException: 400 when `from_step` is outside 1-6 or `dewarp_method`
            is not one of "ensemble", "vlm", "cv".
    """
    if req.from_step < 1 or req.from_step > 6:
        raise HTTPException(status_code=400, detail="from_step must be 1-6")
    if req.dewarp_method not in ("ensemble", "vlm", "cv"):
        raise HTTPException(status_code=400, detail="dewarp_method must be: ensemble, vlm, cv")

    if session_id not in _cache:
        await _load_session_to_cache(session_id)

    async def _generate():
        steps_run: List[str] = []
        steps_skipped: List[str] = []

        session = await get_session_db(session_id)
        if not session:
            yield await _auto_sse_event("error", "error", {"message": f"Session {session_id} not found"})
            return

        cached = _get_cached(session_id)

        # Step 1: Deskew
        if req.from_step <= 1:
            yield await _auto_sse_event("deskew", "start", {})
            try:
                t0 = time.time()
                orig_bgr = cached.get("original_bgr")
                if orig_bgr is None:
                    raise ValueError("Original image not loaded")

                # Both deskew methods are best-effort: a failing method
                # contributes the unrotated image with angle 0.
                try:
                    deskewed_hough, angle_hough = deskew_image(orig_bgr.copy())
                except Exception as hough_err:
                    logger.warning("Auto-mode deskew: Hough method failed for %s: %s", session_id, hough_err)
                    deskewed_hough, angle_hough = orig_bgr, 0.0

                orig_bytes = _auto_encode_png(orig_bgr)
                try:
                    deskewed_wa_bytes, angle_wa = deskew_image_by_word_alignment(orig_bytes)
                except Exception as wa_err:
                    logger.warning("Auto-mode deskew: word-alignment method failed for %s: %s", session_id, wa_err)
                    deskewed_wa_bytes, angle_wa = orig_bytes, 0.0

                # Prefer word alignment when it found at least as much rotation
                # as Hough, or when the Hough angle is negligible.
                if abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1:
                    method_used = "word_alignment"
                    angle_applied = angle_wa
                    wa_arr = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8)
                    deskewed_bgr = cv2.imdecode(wa_arr, cv2.IMREAD_COLOR)
                    if deskewed_bgr is None:
                        # Undecodable word-alignment output: use the Hough result.
                        deskewed_bgr = deskewed_hough
                        method_used = "hough"
                        angle_applied = angle_hough
                else:
                    method_used = "hough"
                    angle_applied = angle_hough
                    deskewed_bgr = deskewed_hough

                deskewed_png = _auto_encode_png(deskewed_bgr)

                deskew_result = {
                    "method_used": method_used,
                    "rotation_degrees": round(float(angle_applied), 3),
                    "duration_seconds": round(time.time() - t0, 2),
                }

                cached["deskewed_bgr"] = deskewed_bgr
                cached["deskew_result"] = deskew_result
                await update_session_db(
                    session_id,
                    deskewed_png=deskewed_png,
                    deskew_result=deskew_result,
                    auto_rotation_degrees=float(angle_applied),
                    current_step=3,
                )
                session = await get_session_db(session_id)

                steps_run.append("deskew")
                yield await _auto_sse_event("deskew", "done", deskew_result)
            except Exception as e:
                logger.exception("Auto-mode deskew failed for %s: %s", session_id, e)
                yield await _auto_sse_event("deskew", "error", {"message": str(e)})
                yield await _auto_sse_event("complete", "error", {"error_step": "deskew"})
                return
        else:
            steps_skipped.append("deskew")
            yield await _auto_sse_event("deskew", "skipped", {"reason": "from_step > 1"})

        # Step 2: Dewarp
        if req.from_step <= 2:
            yield await _auto_sse_event("dewarp", "start", {"method": req.dewarp_method})
            try:
                t0 = time.time()
                deskewed_bgr = cached.get("deskewed_bgr")
                if deskewed_bgr is None:
                    raise ValueError("Deskewed image not available")

                if req.dewarp_method == "vlm":
                    img_bytes = _auto_encode_png(deskewed_bgr)
                    vlm_det = await _detect_shear_with_vlm(img_bytes)
                    shear_deg = vlm_det["shear_degrees"]
                    # Only correct when the model reports a noticeable tilt
                    # with at least modest confidence.
                    if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3:
                        dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg)
                    else:
                        dewarped_bgr = deskewed_bgr
                    dewarp_info = {
                        "method": vlm_det["method"],
                        "shear_degrees": shear_deg,
                        "confidence": vlm_det["confidence"],
                        "detections": [vlm_det],
                    }
                else:
                    # NOTE: "ensemble" and "cv" take the same path here.
                    dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr)

                dewarped_png = _auto_encode_png(dewarped_bgr)

                dewarp_result = {
                    "method_used": dewarp_info["method"],
                    "shear_degrees": dewarp_info["shear_degrees"],
                    "confidence": dewarp_info["confidence"],
                    "duration_seconds": round(time.time() - t0, 2),
                    "detections": dewarp_info.get("detections", []),
                }

                cached["dewarped_bgr"] = dewarped_bgr
                cached["dewarp_result"] = dewarp_result
                await update_session_db(
                    session_id,
                    dewarped_png=dewarped_png,
                    dewarp_result=dewarp_result,
                    auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
                    current_step=4,
                )
                session = await get_session_db(session_id)

                steps_run.append("dewarp")
                yield await _auto_sse_event("dewarp", "done", dewarp_result)
            except Exception as e:
                logger.exception("Auto-mode dewarp failed for %s: %s", session_id, e)
                yield await _auto_sse_event("dewarp", "error", {"message": str(e)})
                yield await _auto_sse_event("complete", "error", {"error_step": "dewarp"})
                return
        else:
            steps_skipped.append("dewarp")
            yield await _auto_sse_event("dewarp", "skipped", {"reason": "from_step > 2"})

        # Step 3: Columns
        if req.from_step <= 3:
            yield await _auto_sse_event("columns", "start", {})
            try:
                t0 = time.time()
                col_img = _auto_pipeline_image(cached)
                if col_img is None:
                    raise ValueError("Cropped/dewarped image not available")

                ocr_img = create_ocr_image(col_img)
                h, w = ocr_img.shape[:2]

                geo_result = detect_column_geometry(ocr_img, col_img)
                if geo_result is None:
                    # Geometry detection gave nothing: use layout analysis and
                    # drop intermediates so row detection recomputes them.
                    layout_img = create_layout_image(col_img)
                    regions = analyze_layout(layout_img, ocr_img)
                    cached["_word_dicts"] = None
                    cached["_inv"] = None
                    cached["_content_bounds"] = None
                else:
                    geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
                    content_w = right_x - left_x
                    cached["_word_dicts"] = word_dicts
                    cached["_inv"] = inv
                    cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)

                    header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None)
                    geometries = _detect_sub_columns(geometries, content_w, left_x=left_x,
                                                     top_y=top_y, header_y=header_y, footer_y=footer_y)
                    regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y,
                                                    left_x=left_x, right_x=right_x, inv=inv)

                columns = [asdict(r) for r in regions]
                column_result = {
                    "columns": columns,
                    "classification_methods": list({c.get("classification_method", "") for c in columns if c.get("classification_method")}),
                    "duration_seconds": round(time.time() - t0, 2),
                }

                cached["column_result"] = column_result
                # New columns invalidate any stored row and word results.
                await update_session_db(session_id, column_result=column_result,
                                        row_result=None, word_result=None, current_step=6)
                session = await get_session_db(session_id)

                steps_run.append("columns")
                yield await _auto_sse_event("columns", "done", {
                    "column_count": len(columns),
                    "duration_seconds": column_result["duration_seconds"],
                })
            except Exception as e:
                logger.exception("Auto-mode columns failed for %s: %s", session_id, e)
                yield await _auto_sse_event("columns", "error", {"message": str(e)})
                yield await _auto_sse_event("complete", "error", {"error_step": "columns"})
                return
        else:
            steps_skipped.append("columns")
            yield await _auto_sse_event("columns", "skipped", {"reason": "from_step > 3"})

        # Step 4: Rows
        if req.from_step <= 4:
            yield await _auto_sse_event("rows", "start", {})
            try:
                t0 = time.time()
                row_img = _auto_pipeline_image(cached)
                session = await get_session_db(session_id)
                column_result = session.get("column_result") or cached.get("column_result")
                if not column_result or not column_result.get("columns"):
                    raise ValueError("Column detection must complete first")

                # NOTE(review): col_regions is not consumed by row detection below;
                # it is still built so malformed column dicts fail in this step.
                col_regions = _auto_page_regions(column_result["columns"])

                word_dicts = cached.get("_word_dicts")
                inv = cached.get("_inv")
                content_bounds = cached.get("_content_bounds")

                if word_dicts is None or inv is None or content_bounds is None:
                    # Intermediates missing (e.g. resumed run): recompute them.
                    ocr_img_tmp = create_ocr_image(row_img)
                    geo_result = detect_column_geometry(ocr_img_tmp, row_img)
                    if geo_result is None:
                        raise ValueError("Column geometry detection failed -- cannot detect rows")
                    _g, lx, rx, ty, by, word_dicts, inv = geo_result
                    cached["_word_dicts"] = word_dicts
                    cached["_inv"] = inv
                    cached["_content_bounds"] = (lx, rx, ty, by)
                    content_bounds = (lx, rx, ty, by)

                left_x, right_x, top_y, bottom_y = content_bounds
                row_geoms = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)

                row_list = [
                    {
                        "index": r.index, "x": r.x, "y": r.y,
                        "width": r.width, "height": r.height,
                        "word_count": r.word_count,
                        "row_type": r.row_type,
                        "gap_before": r.gap_before,
                    }
                    for r in row_geoms
                ]
                row_result = {
                    "rows": row_list,
                    "row_count": len(row_list),
                    "content_rows": sum(1 for r in row_geoms if r.row_type == "content"),
                    "duration_seconds": round(time.time() - t0, 2),
                }

                cached["row_result"] = row_result
                await update_session_db(session_id, row_result=row_result, current_step=7)
                session = await get_session_db(session_id)

                steps_run.append("rows")
                yield await _auto_sse_event("rows", "done", {
                    "row_count": len(row_list),
                    "content_rows": row_result["content_rows"],
                    "duration_seconds": row_result["duration_seconds"],
                })
            except Exception as e:
                logger.exception("Auto-mode rows failed for %s: %s", session_id, e)
                yield await _auto_sse_event("rows", "error", {"message": str(e)})
                yield await _auto_sse_event("complete", "error", {"error_step": "rows"})
                return
        else:
            steps_skipped.append("rows")
            yield await _auto_sse_event("rows", "skipped", {"reason": "from_step > 4"})

        # Step 5: Words (OCR)
        if req.from_step <= 5:
            yield await _auto_sse_event("words", "start", {"engine": req.ocr_engine})
            try:
                t0 = time.time()
                word_img = _auto_pipeline_image(cached)
                session = await get_session_db(session_id)

                column_result = session.get("column_result") or cached.get("column_result")
                row_result = session.get("row_result") or cached.get("row_result")

                # Report missing prerequisites with a readable message instead of
                # the TypeError/KeyError/AttributeError they would trigger below.
                if not column_result or "columns" not in column_result:
                    raise ValueError("Column detection must complete first")
                if not row_result or "rows" not in row_result:
                    raise ValueError("Row detection must complete first")
                if word_img is None:
                    raise ValueError("Cropped/dewarped image not available")

                col_regions = _auto_page_regions(column_result["columns"])
                row_geoms = [
                    RowGeometry(
                        index=r["index"], x=r["x"], y=r["y"],
                        width=r["width"], height=r["height"],
                        word_count=r.get("word_count", 0), words=[],
                        row_type=r.get("row_type", "content"),
                        gap_before=r.get("gap_before", 0),
                    )
                    for r in row_result["rows"]
                ]

                word_dicts = cached.get("_word_dicts")
                if word_dicts is not None:
                    # Assign each cached word to the row containing its vertical
                    # centre; word tops are compared relative to top_y.
                    content_bounds = cached.get("_content_bounds")
                    top_y = content_bounds[2] if content_bounds else min(r.y for r in row_geoms)
                    for row in row_geoms:
                        row_y_rel = row.y - top_y
                        row_bottom_rel = row_y_rel + row.height
                        row.words = [
                            wd for wd in word_dicts
                            if row_y_rel <= wd['top'] + wd['height'] / 2 < row_bottom_rel
                        ]
                        row.word_count = len(row.words)

                ocr_img = create_ocr_image(word_img)
                img_h, img_w = word_img.shape[:2]

                cells, columns_meta = build_cell_grid(
                    ocr_img, col_regions, row_geoms, img_w, img_h,
                    ocr_engine=req.ocr_engine, img_bgr=word_img,
                )
                duration = time.time() - t0

                col_types = {c['type'] for c in columns_meta}
                is_vocab = bool(col_types & {'column_en', 'column_de'})
                n_content_rows = sum(1 for r in row_geoms if r.row_type == 'content')
                used_engine = cells[0].get("ocr_engine", "tesseract") if cells else req.ocr_engine

                fix_cell_phonetics(cells, pronunciation=req.pronunciation)

                word_result_data = {
                    "cells": cells,
                    "grid_shape": {
                        "rows": n_content_rows,
                        "cols": len(columns_meta),
                        "total_cells": len(cells),
                    },
                    "columns_used": columns_meta,
                    "layout": "vocab" if is_vocab else "generic",
                    "image_width": img_w,
                    "image_height": img_h,
                    "duration_seconds": round(duration, 2),
                    "ocr_engine": used_engine,
                    "summary": {
                        "total_cells": len(cells),
                        "non_empty_cells": sum(1 for c in cells if c.get("text")),
                        "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
                    },
                }

                has_text_col = 'column_text' in col_types
                if is_vocab or has_text_col:
                    entries = _cells_to_vocab_entries(cells, columns_meta)
                    entries = _fix_character_confusion(entries)
                    entries = _fix_phonetic_brackets(entries, pronunciation=req.pronunciation)
                    word_result_data["vocab_entries"] = entries
                    word_result_data["entries"] = entries
                    word_result_data["entry_count"] = len(entries)
                    word_result_data["summary"]["total_entries"] = len(entries)

                await update_session_db(session_id, word_result=word_result_data, current_step=8)
                cached["word_result"] = word_result_data
                session = await get_session_db(session_id)

                steps_run.append("words")
                yield await _auto_sse_event("words", "done", {
                    "total_cells": len(cells),
                    "layout": word_result_data["layout"],
                    "duration_seconds": round(duration, 2),
                    "ocr_engine": used_engine,
                    "summary": word_result_data["summary"],
                })
            except Exception as e:
                logger.exception("Auto-mode words failed for %s: %s", session_id, e)
                yield await _auto_sse_event("words", "error", {"message": str(e)})
                yield await _auto_sse_event("complete", "error", {"error_step": "words"})
                return
        else:
            steps_skipped.append("words")
            yield await _auto_sse_event("words", "skipped", {"reason": "from_step > 5"})

        # Step 6: LLM Review (optional)
        if req.from_step <= 6 and not req.skip_llm_review:
            yield await _auto_sse_event("llm_review", "start", {"model": OLLAMA_REVIEW_MODEL})
            try:
                session = await get_session_db(session_id)
                word_result = session.get("word_result") or cached.get("word_result")
                entries = word_result.get("entries") or word_result.get("vocab_entries") or []

                if not entries:
                    yield await _auto_sse_event("llm_review", "skipped", {"reason": "no entries"})
                    steps_skipped.append("llm_review")
                else:
                    reviewed = await llm_review_entries(entries)

                    # Re-read the session so the update builds on the stored word result.
                    session = await get_session_db(session_id)
                    word_result_updated = dict(session.get("word_result") or {})
                    word_result_updated["entries"] = reviewed
                    word_result_updated["vocab_entries"] = reviewed
                    word_result_updated["llm_reviewed"] = True
                    word_result_updated["llm_model"] = OLLAMA_REVIEW_MODEL

                    await update_session_db(session_id, word_result=word_result_updated, current_step=9)
                    cached["word_result"] = word_result_updated

                    steps_run.append("llm_review")
                    yield await _auto_sse_event("llm_review", "done", {
                        "entries_reviewed": len(reviewed),
                        "model": OLLAMA_REVIEW_MODEL,
                    })
            except Exception as e:
                # The review is optional: report the failure but finish the run.
                logger.warning("Auto-mode llm_review failed for %s (non-fatal): %s", session_id, e)
                yield await _auto_sse_event("llm_review", "error", {"message": str(e), "fatal": False})
                steps_skipped.append("llm_review")
        else:
            steps_skipped.append("llm_review")
            reason = "skipped by request" if req.skip_llm_review else "from_step > 6"
            yield await _auto_sse_event("llm_review", "skipped", {"reason": reason})

        # Final event
        yield await _auto_sse_event("complete", "done", {
            "steps_run": steps_run,
            "steps_skipped": steps_skipped,
        })

    return StreamingResponse(
        _generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
|
|
||||||
|
|||||||
@@ -1,293 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/columns.py
|
||||||
OCR Pipeline Column Detection Endpoints (Step 5)
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Detect invisible columns, manual column override, and ground truth.
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.columns")
|
||||||
Extracted from ocr_pipeline_geometry.py for file-size compliance.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from dataclasses import asdict
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
from fastapi import APIRouter, HTTPException
|
|
||||||
|
|
||||||
from cv_vocab_pipeline import (
|
|
||||||
_detect_header_footer_gaps,
|
|
||||||
_detect_sub_columns,
|
|
||||||
classify_column_types,
|
|
||||||
create_layout_image,
|
|
||||||
create_ocr_image,
|
|
||||||
analyze_layout,
|
|
||||||
detect_column_geometry_zoned,
|
|
||||||
expand_narrow_columns,
|
|
||||||
)
|
|
||||||
from ocr_pipeline_session_store import (
|
|
||||||
get_session_db,
|
|
||||||
update_session_db,
|
|
||||||
)
|
|
||||||
from ocr_pipeline_common import (
|
|
||||||
_cache,
|
|
||||||
_load_session_to_cache,
|
|
||||||
_get_cached,
|
|
||||||
_append_pipeline_log,
|
|
||||||
ManualColumnsRequest,
|
|
||||||
ColumnGroundTruthRequest,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
|
||||||
|
|
||||||
|
|
||||||
def _tesseract_conf(raw) -> int:
    """Parse one Tesseract confidence cell as an int; -1 when it is not numeric.

    Accepts ints, integer strings and float-formatted values such as ``96.0``
    or ``"96.5"`` (truncated towards zero).
    """
    try:
        return int(float(raw))
    except (TypeError, ValueError, OverflowError):
        return -1


def _sub_session_word_dicts(img_bgr, session_id: str) -> List[Dict]:
    """Run Tesseract on a sub-session image and return word boxes with conf >= 30.

    Best-effort: any failure (missing library, OCR error) is logged as a
    warning and an empty list is returned.
    """
    try:
        from PIL import Image as PILImage
        import pytesseract

        pil_img = PILImage.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
        data = pytesseract.image_to_data(pil_img, lang='eng+deu', output_type=pytesseract.Output.DICT)

        word_dicts = []
        low_conf = []   # (text, conf) pairs with 0 <= conf < 30, for debugging
        all_count = 0   # every non-empty text cell, whatever its confidence
        for i, raw_text in enumerate(data['text']):
            text = str(raw_text).strip()
            if not text:
                continue
            all_count += 1
            conf = _tesseract_conf(data['conf'][i])
            if conf < 30:
                if conf >= 0:
                    low_conf.append((text, conf))
                continue
            word_dicts.append({
                'text': text, 'conf': conf,
                'left': int(data['left'][i]),
                'top': int(data['top'][i]),
                'width': int(data['width'][i]),
                'height': int(data['height'][i]),
            })

        # Log all words including low-confidence ones for debugging
        if low_conf:
            logger.info(f"OCR Pipeline: sub-session {session_id}: {len(low_conf)} words below conf 30: {low_conf[:20]}")
        logger.info(f"OCR Pipeline: sub-session {session_id}: Tesseract found {len(word_dicts)}/{all_count} words (conf>=30)")
        return word_dicts
    except Exception as e:
        logger.warning(f"OCR Pipeline: sub-session {session_id}: Tesseract failed: {e}")
        return []


@router.post("/sessions/{session_id}/columns")
async def detect_columns(session_id: str):
    """Run column detection on the cropped (or dewarped) image.

    Sub-sessions (sessions with a ``parent_session_id``) skip detection and get
    a single pseudo-column spanning the whole image. Otherwise zone-aware
    geometry detection runs, with layout analysis as fallback. The result is
    persisted, stored row/word results are cleared, and the column result is
    returned together with the session id.

    Raises:
        HTTPException: 400 when neither a cropped nor a dewarped image is cached.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    img_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before column detection")

    # -----------------------------------------------------------------------
    # Sub-sessions (box crops): skip column detection entirely.
    # Instead, create a single pseudo-column spanning the full image width.
    # Also run Tesseract + binarization here so that the row detection step
    # can reuse the cached intermediates (_word_dicts, _inv, _content_bounds)
    # instead of falling back to detect_column_geometry() which may fail
    # on small box images with < 5 words.
    # -----------------------------------------------------------------------
    session = await get_session_db(session_id)
    if session and session.get("parent_session_id"):
        h, w = img_bgr.shape[:2]

        # Binarize + invert for row detection (horizontal projection profile)
        ocr_img = create_ocr_image(img_bgr)
        inv = cv2.bitwise_not(ocr_img)

        # Run Tesseract to get word bounding boxes.
        word_dicts = _sub_session_word_dicts(img_bgr, session_id)

        # Cache intermediates for row detection (detect_rows reuses these)
        cached["_word_dicts"] = word_dicts
        cached["_inv"] = inv
        cached["_content_bounds"] = (0, w, 0, h)

        column_result = {
            "columns": [{
                "type": "column_text",
                "x": 0, "y": 0,
                "width": w, "height": h,
            }],
            "zones": None,
            "boxes_detected": 0,
            "duration_seconds": 0,
            "method": "sub_session_pseudo_column",
        }
        await update_session_db(
            session_id,
            column_result=column_result,
            row_result=None,
            word_result=None,
            current_step=6,
        )
        cached["column_result"] = column_result
        cached.pop("row_result", None)
        cached.pop("word_result", None)
        logger.info(f"OCR Pipeline: sub-session {session_id}: pseudo-column {w}x{h}px")
        return {"session_id": session_id, **column_result}

    t0 = time.time()

    # Binarized image for layout analysis
    ocr_img = create_ocr_image(img_bgr)
    h, w = ocr_img.shape[:2]

    # Phase A: Zone-aware geometry detection
    zoned_result = detect_column_geometry_zoned(ocr_img, img_bgr)

    boxes_detected = 0
    if zoned_result is None:
        # Fallback to projection-based layout
        layout_img = create_layout_image(img_bgr)
        regions = analyze_layout(layout_img, ocr_img)
        zones_data = None

        # No fresh intermediates exist in this branch. Drop whatever an earlier
        # run cached so row detection cannot reuse data from a different image.
        cached["_word_dicts"] = None
        cached["_inv"] = None
        cached["_content_bounds"] = None
        cached.pop("_zones_data", None)
        cached.pop("_boxes_detected", None)
    else:
        geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv, zones_data, boxes = zoned_result
        content_w = right_x - left_x
        boxes_detected = len(boxes)

        # Cache intermediates for row detection (avoids second Tesseract run)
        cached["_word_dicts"] = word_dicts
        cached["_inv"] = inv
        cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
        cached["_zones_data"] = zones_data
        cached["_boxes_detected"] = boxes_detected

        # Detect header/footer early so sub-column clustering ignores them
        header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None)

        # Split sub-columns (e.g. page references) before classification
        geometries = _detect_sub_columns(geometries, content_w, left_x=left_x,
                                         top_y=top_y, header_y=header_y, footer_y=footer_y)

        # Expand narrow columns (sub-columns are often very narrow)
        geometries = expand_narrow_columns(geometries, content_w, left_x, word_dicts)

        # Phase B: Content-based classification
        regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y,
                                        left_x=left_x, right_x=right_x, inv=inv)

    duration = time.time() - t0

    columns = [asdict(r) for r in regions]

    # Determine classification methods used
    methods = list({
        c.get("classification_method", "") for c in columns
        if c.get("classification_method")
    })

    column_result = {
        "columns": columns,
        "classification_methods": methods,
        "duration_seconds": round(duration, 2),
        "boxes_detected": boxes_detected,
    }

    # Add zone data when boxes are present
    if zones_data and boxes_detected > 0:
        column_result["zones"] = zones_data

    # Persist to DB -- also invalidate downstream results (rows, words)
    await update_session_db(
        session_id,
        column_result=column_result,
        row_result=None,
        word_result=None,
        current_step=6,
    )

    # Update cache
    cached["column_result"] = column_result
    cached.pop("row_result", None)
    cached.pop("word_result", None)

    col_count = sum(1 for c in columns if c["type"].startswith("column"))
    logger.info(f"OCR Pipeline: columns session {session_id}: "
                f"{col_count} columns detected, {boxes_detected} box(es) ({duration:.2f}s)")

    img_w = img_bgr.shape[1]
    await _append_pipeline_log(session_id, "columns", {
        "total_columns": len(columns),
        "column_widths_pct": [round(c["width"] / img_w * 100, 1) for c in columns],
        "column_types": [c["type"] for c in columns],
        "boxes_detected": boxes_detected,
    }, duration_ms=int(duration * 1000))

    return {
        "session_id": session_id,
        **column_result,
    }
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/columns/manual")
async def set_manual_columns(session_id: str, req: ManualColumnsRequest):
    """Store caller-supplied column definitions as the session's column result.

    The manual result replaces any previous ``column_result`` and clears the
    row and word results that were derived from the old columns, both in the
    database and (when the session is cached) in the in-memory cache.

    Returns:
        The session id merged with the stored column result.
    """
    manual_columns = req.columns
    column_result = dict(
        columns=manual_columns,
        duration_seconds=0,
        method="manual",
    )

    # Downstream results no longer match the new columns, so reset them.
    await update_session_db(
        session_id,
        column_result=column_result,
        row_result=None,
        word_result=None,
    )

    if session_id in _cache:
        entry = _cache[session_id]
        entry["column_result"] = column_result
        for stale_key in ("row_result", "word_result"):
            entry.pop(stale_key, None)

    logger.info(f"OCR Pipeline: manual columns session {session_id}: "
                f"{len(manual_columns)} columns set")

    response = {"session_id": session_id}
    response.update(column_result)
    return response
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/ground-truth/columns")
async def save_column_ground_truth(session_id: str, req: ColumnGroundTruthRequest):
    """Record reviewer feedback on the column detection step.

    The feedback is stored under the ``"columns"`` key of the session's
    ``ground_truth`` dict together with a snapshot of the current
    ``column_result`` and a UTC timestamp.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    record = await get_session_db(session_id)
    if not record:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    all_feedback = record.get("ground_truth") or {}
    column_feedback = dict(
        is_correct=req.is_correct,
        corrected_columns=req.corrected_columns,
        notes=req.notes,
        saved_at=datetime.utcnow().isoformat(),
        column_result=record.get("column_result"),
    )
    all_feedback["columns"] = column_feedback

    await update_session_db(session_id, ground_truth=all_feedback)

    # Keep the in-memory copy in sync when the session is cached.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = all_feedback

    return dict(session_id=session_id, ground_truth=column_feedback)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/sessions/{session_id}/ground-truth/columns")
async def get_column_ground_truth(session_id: str):
    """Return the saved column ground truth next to the automatic result.

    Both values are returned side by side; no diff is computed here.

    Raises:
        HTTPException: 404 when the session does not exist or has no column
            ground truth saved.
    """
    record = await get_session_db(session_id)
    if not record:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    saved_feedback = (record.get("ground_truth") or {}).get("columns")
    if not saved_feedback:
        raise HTTPException(status_code=404, detail="No column ground truth saved")

    return dict(
        session_id=session_id,
        columns_gt=saved_feedback,
        columns_auto=record.get("column_result"),
    )
|
|
||||||
|
|||||||
@@ -1,354 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/common.py
|
||||||
Shared common module for the OCR pipeline.
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Contains in-memory cache, helper functions, Pydantic request models,
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.common")
|
||||||
pipeline logging, and border-ghost word filtering used by the pipeline
|
|
||||||
API endpoints and related modules.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
import time
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
from fastapi import HTTPException
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from ocr_pipeline_session_store import get_session_db, get_session_image, update_session_db
|
|
||||||
|
|
||||||
# Explicit export list. Most of these names start with an underscore, so a
# star-import would skip them unless they are listed here.
__all__ = [
    # Cache
    "_cache",
    # Helper functions
    "_get_base_image_png",
    "_load_session_to_cache",
    "_get_cached",
    # Pydantic models
    "ManualDeskewRequest",
    "DeskewGroundTruthRequest",
    "ManualDewarpRequest",
    "CombinedAdjustRequest",
    "DewarpGroundTruthRequest",
    "VALID_DOCUMENT_CATEGORIES",
    "UpdateSessionRequest",
    "ManualColumnsRequest",
    "ColumnGroundTruthRequest",
    "ManualRowsRequest",
    "RowGroundTruthRequest",
    "RemoveHandwritingRequest",
    # Pipeline log
    "_append_pipeline_log",
    # Border-ghost filter
    "_BORDER_GHOST_CHARS",
    "_filter_border_ghost_words",
]
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
# In-memory cache for active sessions (BGR numpy arrays for processing)
# DB is source of truth, cache holds BGR arrays during active processing.
# ---------------------------------------------------------------------------

# Maps session_id -> cache entry dict. Entries are created by
# _load_session_to_cache() and read through _get_cached().
_cache: Dict[str, Dict[str, Any]] = {}
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_base_image_png(session_id: str) -> Optional[bytes]:
    """Return the PNG bytes of the most processed image stored for a session.

    Stored images are tried in the order cropped, dewarped, original; the
    first non-empty one wins. Returns ``None`` when none is stored.
    """
    preference_order = ("cropped", "dewarped", "original")
    for variant in preference_order:
        blob = await get_session_image(session_id, variant)
        if not blob:
            continue
        return blob
    return None
|
|
||||||
|
|
||||||
|
|
||||||
async def _load_session_to_cache(session_id: str) -> Dict[str, Any]:
    """Return the cache entry for a session, building it from the DB if needed.

    The session row is always fetched first so that a missing session yields
    a 404 even when a cache entry exists. When no entry is cached, a new one
    is built from the row plus one ``<kind>_bgr`` array per stored PNG
    (``None`` where no image is stored).

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    record = await get_session_db(session_id)
    if not record:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    if session_id in _cache:
        return _cache[session_id]

    image_kinds = ("original", "oriented", "cropped", "deskewed", "dewarped")

    entry: Dict[str, Any] = {"id": session_id}
    entry.update(record)
    for kind in image_kinds:
        entry[f"{kind}_bgr"] = None

    # Decode every stored PNG into a BGR numpy array.
    for kind in image_kinds:
        png_bytes = await get_session_image(session_id, kind)
        if not png_bytes:
            continue
        raw = np.frombuffer(png_bytes, dtype=np.uint8)
        entry[f"{kind}_bgr"] = cv2.imdecode(raw, cv2.IMREAD_COLOR)

    # Sub-sessions: the original image already is the cropped box region, so
    # expose it as cropped_bgr unless a cropped/dewarped image exists.
    is_sub_session = bool(record.get("parent_session_id"))
    has_original = entry["original_bgr"] is not None
    lacks_processed = entry["cropped_bgr"] is None and entry["dewarped_bgr"] is None
    if is_sub_session and has_original and lacks_processed:
        entry["cropped_bgr"] = entry["original_bgr"]

    _cache[session_id] = entry
    return entry
|
|
||||||
|
|
||||||
|
|
||||||
def _get_cached(session_id: str) -> Dict[str, Any]:
    """Return the cache entry for *session_id*.

    Raises:
        HTTPException: 404 when the session has no (non-empty) cache entry.
    """
    cached_entry = _cache.get(session_id)
    if cached_entry:
        return cached_entry
    raise HTTPException(status_code=404, detail=f"Session {session_id} not in cache — reload first")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Pydantic Models
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
class ManualDeskewRequest(BaseModel):
    # Request body for the manual deskew endpoint.
    # Rotation angle to apply; manual_deskew clamps it to [-5.0, 5.0] before use.
    angle: float
|
|
||||||
|
|
||||||
|
|
||||||
class DeskewGroundTruthRequest(BaseModel):
    # Reviewer feedback on the deskew step; stored under ground_truth["deskew"].
    is_correct: bool  # whether the reviewer accepted the deskew result
    corrected_angle: Optional[float] = None  # reviewer-supplied angle, if any
    notes: Optional[str] = None  # free-text remarks
|
|
||||||
|
|
||||||
|
|
||||||
class ManualDewarpRequest(BaseModel):
    # Request body carrying a manual shear value.
    # NOTE(review): presumably consumed by a manual dewarp endpoint — not visible here.
    shear_degrees: float
|
|
||||||
|
|
||||||
|
|
||||||
class CombinedAdjustRequest(BaseModel):
    # Request body combining a rotation and a shear value; both default to 0.0.
    # NOTE(review): the consuming endpoint is not visible here — confirm semantics there.
    rotation_degrees: float = 0.0
    shear_degrees: float = 0.0
|
|
||||||
|
|
||||||
|
|
||||||
class DewarpGroundTruthRequest(BaseModel):
    # Reviewer feedback on the dewarp step (same shape as the deskew feedback).
    is_correct: bool  # whether the reviewer accepted the dewarp result
    corrected_shear: Optional[float] = None  # reviewer-supplied shear, if any
    notes: Optional[str] = None  # free-text remarks
|
|
||||||
|
|
||||||
|
|
||||||
# Known document category identifiers (German slugs, e.g. "vokabelseite" =
# vocabulary page, "sonstiges" = other).
# NOTE(review): presumably the allowed values for
# UpdateSessionRequest.document_category — the validation is not visible here.
VALID_DOCUMENT_CATEGORIES = {
    'vokabelseite', 'woerterbuch', 'buchseite', 'arbeitsblatt', 'klausurseite',
    'mathearbeit', 'statistik', 'zeitung', 'formular', 'handschrift', 'sonstiges',
}
|
|
||||||
|
|
||||||
|
|
||||||
class UpdateSessionRequest(BaseModel):
    # Partial update of session metadata; both fields are optional and
    # default to None.
    name: Optional[str] = None
    document_category: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ManualColumnsRequest(BaseModel):
    # Request body for the manual columns endpoint.
    # Column definitions as free-form dicts; stored as column_result["columns"].
    columns: List[Dict[str, Any]]
|
|
||||||
|
|
||||||
|
|
||||||
class ColumnGroundTruthRequest(BaseModel):
    # Reviewer feedback on column detection; stored under ground_truth["columns"].
    is_correct: bool  # whether the reviewer accepted the detected columns
    corrected_columns: Optional[List[Dict[str, Any]]] = None  # reviewer-supplied columns, if any
    notes: Optional[str] = None  # free-text remarks
|
|
||||||
|
|
||||||
|
|
||||||
class ManualRowsRequest(BaseModel):
    # Request body carrying manual row definitions as free-form dicts.
    # NOTE(review): the consuming endpoint is not visible here.
    rows: List[Dict[str, Any]]
|
|
||||||
|
|
||||||
|
|
||||||
class RowGroundTruthRequest(BaseModel):
    # Reviewer feedback on row detection (same shape as the column feedback).
    is_correct: bool  # whether the reviewer accepted the detected rows
    corrected_rows: Optional[List[Dict[str, Any]]] = None  # reviewer-supplied rows, if any
    notes: Optional[str] = None  # free-text remarks
|
|
||||||
|
|
||||||
|
|
||||||
class RemoveHandwritingRequest(BaseModel):
    # Options for a handwriting-removal request. The values listed in the
    # inline comments are documented only; this model does not enforce them.
    method: str = "auto"            # "auto" | "telea" | "ns"
    target_ink: str = "all"         # "all" | "colored" | "pencil"
    dilation: int = 2               # mask dilation iterations (0-5)
    use_source: str = "auto"        # "original" | "deskewed" | "auto"
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Pipeline Log Helper
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
async def _append_pipeline_log(
    session_id: str,
    step_name: str,
    metrics: Dict[str, Any],
    success: bool = True,
    duration_ms: Optional[int] = None,
):
    """Append one step record to the session's ``pipeline_log``.

    Does nothing when the session does not exist. A missing, empty or
    non-dict log is replaced by a fresh ``{"steps": []}`` before appending.

    Args:
        session_id: Session whose log is extended.
        step_name: Stored as the entry's ``"step"``.
        metrics: Stored as-is under ``"metrics"``.
        success: Stored under ``"success"``.
        duration_ms: Stored under ``"duration_ms"`` only when not ``None``.
    """
    record = await get_session_db(session_id)
    if not record:
        return

    stored_log = record.get("pipeline_log")
    if stored_log and isinstance(stored_log, dict):
        pipeline_log = stored_log
    else:
        pipeline_log = {"steps": []}

    step_entry = dict(
        step=step_name,
        completed_at=datetime.utcnow().isoformat(),
        success=success,
        metrics=metrics,
    )
    if duration_ms is not None:
        step_entry["duration_ms"] = duration_ms

    steps = pipeline_log.setdefault("steps", [])
    steps.append(step_entry)
    await update_session_db(session_id, pipeline_log=pipeline_log)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Border-ghost word filter
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
# Characters that OCR produces when reading box-border lines.
# _filter_border_ghost_words() treats a word made up solely of these
# characters, and centred on a box edge, as a border artefact.
_BORDER_GHOST_CHARS = set("|1lI![](){}iíì/\\-—–_~.,;:'\"")
|
|
||||||
|
|
||||||
|
|
||||||
def _filter_border_ghost_words(
|
|
||||||
word_result: Dict,
|
|
||||||
boxes: List,
|
|
||||||
) -> int:
|
|
||||||
"""Remove OCR words that are actually box border lines.
|
|
||||||
|
|
||||||
A word is considered a border ghost when it sits on a known box edge
|
|
||||||
(left, right, top, or bottom) and looks like a line artefact (narrow
|
|
||||||
aspect ratio or text consists only of line-like characters).
|
|
||||||
|
|
||||||
After removing ghost cells, columns that have become empty are also
|
|
||||||
removed from ``columns_used`` so the grid no longer shows phantom
|
|
||||||
columns.
|
|
||||||
|
|
||||||
Modifies *word_result* in-place and returns the number of removed cells.
|
|
||||||
"""
|
|
||||||
if not boxes or not word_result:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
cells = word_result.get("cells")
|
|
||||||
if not cells:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
# Build border bands — vertical (X) and horizontal (Y)
|
|
||||||
x_bands = [] # list of (x_lo, x_hi)
|
|
||||||
y_bands = [] # list of (y_lo, y_hi)
|
|
||||||
for b in boxes:
|
|
||||||
bx = b.x if hasattr(b, "x") else b.get("x", 0)
|
|
||||||
by = b.y if hasattr(b, "y") else b.get("y", 0)
|
|
||||||
bw = b.width if hasattr(b, "width") else b.get("w", b.get("width", 0))
|
|
||||||
bh = b.height if hasattr(b, "height") else b.get("h", b.get("height", 0))
|
|
||||||
bt = b.border_thickness if hasattr(b, "border_thickness") else b.get("border_thickness", 3)
|
|
||||||
margin = max(bt * 2, 10) + 6 # generous margin
|
|
||||||
|
|
||||||
# Vertical edges (left / right)
|
|
||||||
x_bands.append((bx - margin, bx + margin))
|
|
||||||
x_bands.append((bx + bw - margin, bx + bw + margin))
|
|
||||||
# Horizontal edges (top / bottom)
|
|
||||||
y_bands.append((by - margin, by + margin))
|
|
||||||
y_bands.append((by + bh - margin, by + bh + margin))
|
|
||||||
|
|
||||||
img_w = word_result.get("image_width", 1)
|
|
||||||
img_h = word_result.get("image_height", 1)
|
|
||||||
|
|
||||||
def _is_ghost(cell: Dict) -> bool:
|
|
||||||
text = (cell.get("text") or "").strip()
|
|
||||||
if not text:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Compute absolute pixel position
|
|
||||||
if cell.get("bbox_px"):
|
|
||||||
px = cell["bbox_px"]
|
|
||||||
cx = px["x"] + px["w"] / 2
|
|
||||||
cy = px["y"] + px["h"] / 2
|
|
||||||
cw = px["w"]
|
|
||||||
ch = px["h"]
|
|
||||||
elif cell.get("bbox_pct"):
|
|
||||||
pct = cell["bbox_pct"]
|
|
||||||
cx = (pct["x"] / 100) * img_w + (pct["w"] / 100) * img_w / 2
|
|
||||||
cy = (pct["y"] / 100) * img_h + (pct["h"] / 100) * img_h / 2
|
|
||||||
cw = (pct["w"] / 100) * img_w
|
|
||||||
ch = (pct["h"] / 100) * img_h
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check if center sits on a vertical or horizontal border
|
|
||||||
on_vertical = any(lo <= cx <= hi for lo, hi in x_bands)
|
|
||||||
on_horizontal = any(lo <= cy <= hi for lo, hi in y_bands)
|
|
||||||
if not on_vertical and not on_horizontal:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Very short text (1-2 chars) on a border → very likely ghost
|
|
||||||
if len(text) <= 2:
|
|
||||||
# Narrow vertically (line-like) or narrow horizontally (dash-like)?
|
|
||||||
if ch > 0 and cw / ch < 0.5:
|
|
||||||
return True
|
|
||||||
if cw > 0 and ch / cw < 0.5:
|
|
||||||
return True
|
|
||||||
# Text is only border-ghost characters?
|
|
||||||
if all(c in _BORDER_GHOST_CHARS for c in text):
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Longer text but still only ghost chars and very narrow
|
|
||||||
if all(c in _BORDER_GHOST_CHARS for c in text):
|
|
||||||
if ch > 0 and cw / ch < 0.35:
|
|
||||||
return True
|
|
||||||
if cw > 0 and ch / cw < 0.35:
|
|
||||||
return True
|
|
||||||
return True # all ghost chars on a border → remove
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
before = len(cells)
|
|
||||||
word_result["cells"] = [c for c in cells if not _is_ghost(c)]
|
|
||||||
removed = before - len(word_result["cells"])
|
|
||||||
|
|
||||||
# --- Remove empty columns from columns_used ---
|
|
||||||
columns_used = word_result.get("columns_used")
|
|
||||||
if removed and columns_used and len(columns_used) > 1:
|
|
||||||
remaining_cells = word_result["cells"]
|
|
||||||
occupied_cols = {c.get("col_index") for c in remaining_cells}
|
|
||||||
before_cols = len(columns_used)
|
|
||||||
columns_used = [col for col in columns_used if col.get("index") in occupied_cols]
|
|
||||||
|
|
||||||
# Re-index columns and remap cell col_index values
|
|
||||||
if len(columns_used) < before_cols:
|
|
||||||
old_to_new = {}
|
|
||||||
for new_i, col in enumerate(columns_used):
|
|
||||||
old_to_new[col["index"]] = new_i
|
|
||||||
col["index"] = new_i
|
|
||||||
for cell in remaining_cells:
|
|
||||||
old_ci = cell.get("col_index")
|
|
||||||
if old_ci in old_to_new:
|
|
||||||
cell["col_index"] = old_to_new[old_ci]
|
|
||||||
word_result["columns_used"] = columns_used
|
|
||||||
logger.info("border-ghost: removed %d empty column(s), %d remaining",
|
|
||||||
before_cols - len(columns_used), len(columns_used))
|
|
||||||
|
|
||||||
if removed:
|
|
||||||
# Update summary counts
|
|
||||||
summary = word_result.get("summary", {})
|
|
||||||
summary["total_cells"] = len(word_result["cells"])
|
|
||||||
summary["non_empty_cells"] = sum(1 for c in word_result["cells"] if c.get("text"))
|
|
||||||
word_result["summary"] = summary
|
|
||||||
gs = word_result.get("grid_shape", {})
|
|
||||||
gs["total_cells"] = len(word_result["cells"])
|
|
||||||
if columns_used is not None:
|
|
||||||
gs["cols"] = len(columns_used)
|
|
||||||
word_result["grid_shape"] = gs
|
|
||||||
|
|
||||||
return removed
|
|
||||||
|
|||||||
@@ -1,236 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/deskew.py
|
||||||
OCR Pipeline Deskew Endpoints (Step 2)
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Auto deskew, manual deskew, and ground truth for the deskew step.
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.deskew")
|
||||||
Extracted from ocr_pipeline_geometry.py for file-size compliance.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
from fastapi import APIRouter, HTTPException
|
|
||||||
|
|
||||||
from cv_vocab_pipeline import (
|
|
||||||
create_ocr_image,
|
|
||||||
deskew_image,
|
|
||||||
deskew_image_by_word_alignment,
|
|
||||||
deskew_two_pass,
|
|
||||||
)
|
|
||||||
from ocr_pipeline_session_store import (
|
|
||||||
get_session_db,
|
|
||||||
update_session_db,
|
|
||||||
)
|
|
||||||
from ocr_pipeline_common import (
|
|
||||||
_cache,
|
|
||||||
_load_session_to_cache,
|
|
||||||
_get_cached,
|
|
||||||
_append_pipeline_log,
|
|
||||||
ManualDeskewRequest,
|
|
||||||
DeskewGroundTruthRequest,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/deskew")
async def auto_deskew(session_id: str):
    """Automatically deskew the session image and persist the result.

    The oriented image is used when available, otherwise the original. The
    correction applied comes from ``deskew_two_pass``; the Hough and
    word-alignment angles are computed only for reporting and fall back to
    0.0 when their detector raises. The deskewed image (and, when
    binarization succeeds, a binarized copy) is cached and stored in the DB,
    and a pipeline-log entry is appended.

    Raises:
        HTTPException: 400 when no source image is cached for the session.
    """
    # Make sure the session is cached.
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    # Deskew runs right after orientation: prefer the oriented image.
    source_bgr = cached.get("oriented_bgr")
    if source_bgr is None:
        source_bgr = cached.get("original_bgr")
    if source_bgr is None:
        raise HTTPException(status_code=400, detail="No image available for deskewing")

    started = time.time()

    # Authoritative correction.
    deskewed_bgr, angle_applied, two_pass_debug = deskew_two_pass(source_bgr.copy())

    # Individual detectors, for reporting only.
    try:
        _, angle_hough = deskew_image(source_bgr.copy())
    except Exception:
        angle_hough = 0.0

    encoded_ok, source_buf = cv2.imencode(".png", source_bgr)
    source_png = source_buf.tobytes() if encoded_ok else b""
    try:
        _, angle_wa = deskew_image_by_word_alignment(source_png)
    except Exception:
        angle_wa = 0.0

    angle_iterative = two_pass_debug.get("pass1_angle", 0.0)
    angle_residual = two_pass_debug.get("pass2_angle", 0.0)
    angle_textline = two_pass_debug.get("pass3_angle", 0.0)

    duration = time.time() - started

    # Name the deepest pass that contributed a noticeable angle.
    if abs(angle_textline) >= 0.01:
        method_used = "three_pass"
    elif abs(angle_residual) >= 0.01:
        method_used = "two_pass"
    else:
        method_used = "iterative"

    # PNG-encode the deskewed image.
    deskew_ok, deskewed_buf = cv2.imencode(".png", deskewed_bgr)
    deskewed_png = deskewed_buf.tobytes() if deskew_ok else b""

    # Binarized copy is best effort.
    binarized_png = None
    try:
        binarized = create_ocr_image(deskewed_bgr)
        bin_ok, bin_buf = cv2.imencode(".png", binarized)
        binarized_png = bin_buf.tobytes() if bin_ok else None
    except Exception as e:
        logger.warning(f"Binarization failed: {e}")

    confidence = max(0.5, 1.0 - abs(angle_applied) / 5.0)

    applied_rounded = round(angle_applied, 3)
    iterative_rounded = round(angle_iterative, 3)
    residual_rounded = round(angle_residual, 3)
    textline_rounded = round(angle_textline, 3)
    confidence_rounded = round(confidence, 2)

    deskew_result = {
        "angle_hough": round(angle_hough, 3),
        "angle_word_alignment": round(angle_wa, 3),
        "angle_iterative": iterative_rounded,
        "angle_residual": residual_rounded,
        "angle_textline": textline_rounded,
        "angle_applied": applied_rounded,
        "method_used": method_used,
        "confidence": confidence_rounded,
        "duration_seconds": round(duration, 2),
        "two_pass_debug": two_pass_debug,
    }

    # Update the cache.
    cached.update(
        deskewed_bgr=deskewed_bgr,
        binarized_png=binarized_png,
        deskew_result=deskew_result,
    )

    # Persist to the DB; the binarized image only when one was produced.
    db_fields = dict(
        deskewed_png=deskewed_png,
        deskew_result=deskew_result,
        current_step=3,
    )
    if binarized_png:
        db_fields["binarized_png"] = binarized_png
    await update_session_db(session_id, **db_fields)

    logger.info(f"OCR Pipeline: deskew session {session_id}: "
                f"hough={angle_hough:.2f} wa={angle_wa:.2f} "
                f"iter={angle_iterative:.2f} residual={angle_residual:.2f} "
                f"textline={angle_textline:.2f} "
                f"-> {method_used} total={angle_applied:.2f}")

    log_metrics = {
        "angle_applied": applied_rounded,
        "angle_iterative": iterative_rounded,
        "angle_residual": residual_rounded,
        "angle_textline": textline_rounded,
        "confidence": confidence_rounded,
        "method": method_used,
    }
    await _append_pipeline_log(session_id, "deskew", log_metrics,
                               duration_ms=int(duration * 1000))

    response = {"session_id": session_id}
    response.update(deskew_result)
    response["deskewed_image_url"] = f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed"
    response["binarized_image_url"] = f"/api/v1/ocr-pipeline/sessions/{session_id}/image/binarized"
    return response
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/deskew/manual")
async def manual_deskew(session_id: str, req: ManualDeskewRequest):
    """Apply a manual rotation angle to the oriented image.

    The requested angle is clamped to [-5.0, 5.0] and the oriented image (or
    the original when no oriented image is cached) is rotated about its
    centre. The rotated image replaces the cached deskewed image and is
    stored in the DB together with an updated ``deskew_result``. Values from
    a previous ``deskew_result`` are kept, with ``angle_applied`` and
    ``method_used`` overwritten.

    Binarization is best effort: when it fails the request still succeeds,
    no binarized image is stored, and the failure is logged.

    Raises:
        HTTPException: 400 when no source image is cached for the session.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    img_bgr = next((v for k in ("oriented_bgr", "original_bgr")
                    if (v := cached.get(k)) is not None), None)
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="No image available for deskewing")

    # Keep manual corrections within the supported +-5 range.
    angle = max(-5.0, min(5.0, req.angle))

    h, w = img_bgr.shape[:2]
    center = (w // 2, h // 2)
    M = cv2.getRotationMatrix2D(center, angle, 1.0)
    rotated = cv2.warpAffine(img_bgr, M, (w, h),
                             flags=cv2.INTER_LINEAR,
                             borderMode=cv2.BORDER_REPLICATE)

    success, png_buf = cv2.imencode(".png", rotated)
    deskewed_png = png_buf.tobytes() if success else b""

    # Binarize (best effort). Failures are logged instead of silently
    # swallowed, matching auto_deskew.
    binarized_png = None
    try:
        binarized = create_ocr_image(rotated)
        success_bin, bin_buf = cv2.imencode(".png", binarized)
        binarized_png = bin_buf.tobytes() if success_bin else None
    except Exception as e:
        logger.warning(f"Binarization failed: {e}")

    deskew_result = {
        **(cached.get("deskew_result") or {}),
        "angle_applied": round(angle, 3),
        "method_used": "manual",
    }

    # Update cache
    cached["deskewed_bgr"] = rotated
    cached["binarized_png"] = binarized_png
    cached["deskew_result"] = deskew_result

    # Persist to DB; the binarized image only when one was produced.
    db_update = {
        "deskewed_png": deskewed_png,
        "deskew_result": deskew_result,
    }
    if binarized_png:
        db_update["binarized_png"] = binarized_png
    await update_session_db(session_id, **db_update)

    logger.info(f"OCR Pipeline: manual deskew session {session_id}: {angle:.2f}")

    return {
        "session_id": session_id,
        "angle_applied": round(angle, 3),
        "method_used": "manual",
        "deskewed_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed",
    }
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/ground-truth/deskew")
async def save_deskew_ground_truth(session_id: str, req: DeskewGroundTruthRequest):
    """Record reviewer feedback on the deskew step.

    The feedback is stored under the ``"deskew"`` key of the session's
    ``ground_truth`` dict together with a snapshot of the current
    ``deskew_result`` and a UTC timestamp.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    record = await get_session_db(session_id)
    if not record:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    all_feedback = record.get("ground_truth") or {}
    deskew_feedback = dict(
        is_correct=req.is_correct,
        corrected_angle=req.corrected_angle,
        notes=req.notes,
        saved_at=datetime.utcnow().isoformat(),
        deskew_result=record.get("deskew_result"),
    )
    all_feedback["deskew"] = deskew_feedback

    await update_session_db(session_id, ground_truth=all_feedback)

    # Keep the in-memory copy in sync when the session is cached.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = all_feedback

    logger.info(f"OCR Pipeline: ground truth deskew session {session_id}: "
                f"correct={req.is_correct}, corrected_angle={req.corrected_angle}")

    return dict(session_id=session_id, ground_truth=deskew_feedback)
|
|
||||||
|
|||||||
@@ -1,346 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/dewarp.py
|
||||||
OCR Pipeline Dewarp Endpoints
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Auto dewarp (with VLM/CV ensemble), manual dewarp, combined
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.dewarp")
|
||||||
rotation+shear adjustment, and ground truth.
|
|
||||||
Extracted from ocr_pipeline_geometry.py for file-size compliance.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import time
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any, Dict
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
from fastapi import APIRouter, HTTPException, Query
|
|
||||||
|
|
||||||
from cv_vocab_pipeline import (
|
|
||||||
_apply_shear,
|
|
||||||
create_ocr_image,
|
|
||||||
dewarp_image,
|
|
||||||
dewarp_image_manual,
|
|
||||||
)
|
|
||||||
from ocr_pipeline_session_store import (
|
|
||||||
get_session_db,
|
|
||||||
update_session_db,
|
|
||||||
)
|
|
||||||
from ocr_pipeline_common import (
|
|
||||||
_cache,
|
|
||||||
_load_session_to_cache,
|
|
||||||
_get_cached,
|
|
||||||
_append_pipeline_log,
|
|
||||||
ManualDewarpRequest,
|
|
||||||
CombinedAdjustRequest,
|
|
||||||
DewarpGroundTruthRequest,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
|
||||||
|
|
||||||
|
|
||||||
async def _detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]:
    """Ask a vision-language model for the vertical shear of a scanned page.

    The image is sent base64-encoded to the Ollama ``/api/generate`` endpoint
    (base URL from ``OLLAMA_BASE_URL``, model from ``OLLAMA_HTR_MODEL``) with
    a prompt requesting a small JSON object. The first ``{...}`` group in the
    reply is parsed; shear is clamped to [-3.0, 3.0] and confidence to
    [0.0, 1.0].

    Returns:
        Dict with ``method``, ``shear_degrees`` and ``confidence``. Shear and
        confidence are both 0.0 when the request fails, the reply contains no
        JSON object, or parsing raises.
    """
    import httpx
    import base64

    base_url = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
    model_name = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")

    prompt_parts = [
        "This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. ",
        "Are they perfectly vertical, or do they tilt slightly? ",
        "If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). ",
        'Reply with ONLY a JSON object like: {"shear_degrees": 1.2, "confidence": 0.8} ',
        "Use confidence 0.0-1.0 based on how clearly you can see the tilt. ",
        'If the columns look straight, return {"shear_degrees": 0.0, "confidence": 0.9}',
    ]

    request_body = {
        "model": model_name,
        "prompt": "".join(prompt_parts),
        "images": [base64.b64encode(image_bytes).decode("utf-8")],
        "stream": False,
    }

    no_detection = {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0}

    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            reply = await client.post(f"{base_url}/api/generate", json=request_body)
            reply.raise_for_status()
            reply_text = reply.json().get("response", "")

        # The model may wrap the JSON in extra text; take the first {...} group.
        found = re.search(r'\{[^}]+\}', reply_text)
        if not found:
            return no_detection

        parsed = json.loads(found.group(0))
        shear = float(parsed.get("shear_degrees", 0.0))
        conf = float(parsed.get("confidence", 0.0))
        # Clamp to a reasonable range.
        shear = max(-3.0, min(3.0, shear))
        conf = max(0.0, min(1.0, conf))
        return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)}
    except Exception as e:
        logger.warning(f"VLM dewarp failed: {e}")
        return no_detection
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/dewarp")
async def auto_dewarp(
    session_id: str,
    method: str = Query("ensemble", description="Detection method: ensemble | vlm | cv"),
):
    """Estimate the vertical shear of the deskewed page and correct it.

    Methods:
    - **ensemble** (default): 3-method CV ensemble (vertical edges + projection + Hough)
    - **cv**: CV ensemble only (same code path as ensemble)
    - **vlm**: ask the vision model for a shear angle; the estimate is applied
      only when it is at least 0.05 degrees with a confidence of at least 0.3

    The corrected image and a result summary are stored in the session cache
    and the database, and the step is appended to the pipeline log.

    Raises:
        HTTPException: 400 for an unknown ``method`` or when the session has
            no deskewed image yet.
    """
    allowed_methods = ("ensemble", "cv", "vlm")
    if method not in allowed_methods:
        raise HTTPException(status_code=400, detail="method must be one of: ensemble, cv, vlm")

    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    session_cache = _get_cached(session_id)

    source_bgr = session_cache.get("deskewed_bgr")
    if source_bgr is None:
        raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")

    started = time.time()

    if method != "vlm":
        # "ensemble" and "cv" share the CV detector.
        dewarped_bgr, dewarp_info = dewarp_image(source_bgr)
    else:
        # The vision model receives the deskewed page as PNG bytes.
        encoded_ok, encoded = cv2.imencode(".png", source_bgr)
        vlm_det = await _detect_shear_with_vlm(encoded.tobytes() if encoded_ok else b"")
        shear_deg = vlm_det["shear_degrees"]
        # Ignore tiny or low-confidence estimates and keep the image untouched.
        trust_estimate = abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3
        dewarped_bgr = _apply_shear(source_bgr, -shear_deg) if trust_estimate else source_bgr
        dewarp_info = {
            "method": vlm_det["method"],
            "shear_degrees": shear_deg,
            "confidence": vlm_det["confidence"],
            "detections": [vlm_det],
        }

    duration = time.time() - started

    # Serialize the corrected image for persistence.
    success, png_buf = cv2.imencode(".png", dewarped_bgr)
    dewarped_png = png_buf.tobytes() if success else b""

    detections = dewarp_info.get("detections", [])
    dewarp_result = {
        "method_used": dewarp_info["method"],
        "shear_degrees": dewarp_info["shear_degrees"],
        "confidence": dewarp_info["confidence"],
        "duration_seconds": round(duration, 2),
        "detections": detections,
    }

    # Keep the in-memory session in sync with what is persisted below.
    session_cache["dewarped_bgr"] = dewarped_bgr
    session_cache["dewarp_result"] = dewarp_result

    await update_session_db(
        session_id,
        dewarped_png=dewarped_png,
        dewarp_result=dewarp_result,
        auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
        current_step=4,
    )

    logger.info(f"OCR Pipeline: dewarp session {session_id}: "
                f"method={dewarp_info['method']} shear={dewarp_info['shear_degrees']:.3f} "
                f"conf={dewarp_info['confidence']:.2f} ({duration:.2f}s)")

    log_payload = {
        "shear_degrees": dewarp_info["shear_degrees"],
        "confidence": dewarp_info["confidence"],
        "method": dewarp_info["method"],
        "ensemble_methods": [d.get("method", "") for d in dewarp_info.get("detections", [])],
    }
    await _append_pipeline_log(session_id, "dewarp", log_payload, duration_ms=int(duration * 1000))

    return {
        "session_id": session_id,
        **dewarp_result,
        "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
    }
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/dewarp/manual")
async def manual_dewarp(session_id: str, req: ManualDewarpRequest):
    """Correct the deskewed image with a user-supplied shear angle.

    The requested angle is clamped to [-2.0, 2.0] degrees. Angles smaller
    than 0.001 degrees in magnitude leave the image unchanged. The result is
    stored in the session cache and the database.

    Raises:
        HTTPException: 400 when the session has no deskewed image yet.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    session_cache = _get_cached(session_id)

    source_bgr = session_cache.get("deskewed_bgr")
    if source_bgr is None:
        raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")

    shear_deg = max(-2.0, min(2.0, req.shear_degrees))
    applied_shear = round(shear_deg, 3)

    # A near-zero angle is treated as "no correction".
    dewarped_bgr = source_bgr if abs(shear_deg) < 0.001 else dewarp_image_manual(source_bgr, shear_deg)

    success, png_buf = cv2.imencode(".png", dewarped_bgr)
    dewarped_png = png_buf.tobytes() if success else b""

    # Start from any previous result so its other fields survive, then
    # overwrite the method and angle.
    previous_result = session_cache.get("dewarp_result") or {}
    dewarp_result = {
        **previous_result,
        "method_used": "manual",
        "shear_degrees": applied_shear,
    }

    session_cache["dewarped_bgr"] = dewarped_bgr
    session_cache["dewarp_result"] = dewarp_result

    await update_session_db(
        session_id,
        dewarped_png=dewarped_png,
        dewarp_result=dewarp_result,
    )

    logger.info(f"OCR Pipeline: manual dewarp session {session_id}: shear={shear_deg:.3f}")

    return {
        "session_id": session_id,
        "shear_degrees": applied_shear,
        "method_used": "manual",
        "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
    }
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/adjust-combined")
async def adjust_combined(session_id: str, req: CombinedAdjustRequest):
    """Apply rotation + shear combined to the original image.

    Used by the fine-tuning sliders to preview arbitrary rotation/shear
    combinations without re-running the full deskew/dewarp pipeline.

    The requested rotation is clamped to [-15, 15] degrees and the shear to
    [-5, 5] degrees. The resulting image replaces both the deskewed and the
    dewarped image in the session cache and in the database. A binarized
    version is persisted as well when ``create_ocr_image`` succeeds.

    Raises:
        HTTPException: 400 when the session has no original image.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    img_bgr = cached.get("original_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Original image not available")

    rotation = max(-15.0, min(15.0, req.rotation_degrees))
    shear_deg = max(-5.0, min(5.0, req.shear_degrees))

    h, w = img_bgr.shape[:2]
    result_bgr = img_bgr

    # Step 1: rotate about the image centre, keeping the original canvas size.
    if abs(rotation) >= 0.001:
        center = (w // 2, h // 2)
        M = cv2.getRotationMatrix2D(center, rotation, 1.0)
        result_bgr = cv2.warpAffine(result_bgr, M, (w, h),
                                    flags=cv2.INTER_LINEAR,
                                    borderMode=cv2.BORDER_REPLICATE)

    # Step 2: shear the (possibly rotated) image.
    if abs(shear_deg) >= 0.001:
        result_bgr = dewarp_image_manual(result_bgr, shear_deg)

    success, png_buf = cv2.imencode(".png", result_bgr)
    dewarped_png = png_buf.tobytes() if success else b""

    # Binarization is best-effort: a failure must not block the adjustment,
    # but it is logged so a stale binarized image can be diagnosed.
    binarized_png = None
    try:
        binarized = create_ocr_image(result_bgr)
        success_bin, bin_buf = cv2.imencode(".png", binarized)
        binarized_png = bin_buf.tobytes() if success_bin else None
    except Exception as e:
        logger.warning(f"OCR Pipeline: binarization failed for session {session_id}: {e}")

    # Carry over any previous result fields and overwrite the adjusted ones.
    deskew_result = {
        **(cached.get("deskew_result") or {}),
        "angle_applied": round(rotation, 3),
        "method_used": "manual_combined",
    }
    dewarp_result = {
        **(cached.get("dewarp_result") or {}),
        "method_used": "manual_combined",
        "shear_degrees": round(shear_deg, 3),
    }

    # The combined image stands in for both intermediate pipeline images.
    cached["deskewed_bgr"] = result_bgr
    cached["dewarped_bgr"] = result_bgr
    cached["deskew_result"] = deskew_result
    cached["dewarp_result"] = dewarp_result

    # Persist the same image for both steps so the database mirrors the cache
    # update above, whether or not binarization succeeded.
    db_update = {
        "dewarped_png": dewarped_png,
        "deskewed_png": dewarped_png,
        "deskew_result": deskew_result,
        "dewarp_result": dewarp_result,
    }
    if binarized_png:
        db_update["binarized_png"] = binarized_png
    await update_session_db(session_id, **db_update)

    logger.info(f"OCR Pipeline: combined adjust session {session_id}: "
                f"rotation={rotation:.3f} shear={shear_deg:.3f}")

    return {
        "session_id": session_id,
        "rotation_degrees": round(rotation, 3),
        "shear_degrees": round(shear_deg, 3),
        "method_used": "manual_combined",
        "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
    }
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/ground-truth/dewarp")
async def save_dewarp_ground_truth(session_id: str, req: DewarpGroundTruthRequest):
    """Store reviewer feedback for the dewarp step.

    The feedback is saved under the ``"dewarp"`` key of the session's
    ground-truth record, together with a timestamp and the dewarp result it
    refers to. A cached copy of the session is updated as well.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # Feedback for other steps stays untouched; only the "dewarp" entry is replaced.
    all_ground_truth = session.get("ground_truth") or {}
    dewarp_feedback = {
        "is_correct": req.is_correct,
        "corrected_shear": req.corrected_shear,
        "notes": req.notes,
        "saved_at": datetime.utcnow().isoformat(),
        "dewarp_result": session.get("dewarp_result"),
    }
    all_ground_truth["dewarp"] = dewarp_feedback

    await update_session_db(session_id, ground_truth=all_ground_truth)

    if session_id in _cache:
        _cache[session_id]["ground_truth"] = all_ground_truth

    logger.info(f"OCR Pipeline: ground truth dewarp session {session_id}: "
                f"correct={req.is_correct}, corrected_shear={req.corrected_shear}")

    return {"session_id": session_id, "ground_truth": dewarp_feedback}
|
|
||||||
|
|||||||
@@ -1,27 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/geometry.py
|
||||||
OCR Pipeline Geometry API (barrel re-export)
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
This module was split into:
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.geometry")
|
||||||
- ocr_pipeline_deskew.py (Deskew endpoints)
|
|
||||||
- ocr_pipeline_dewarp.py (Dewarp endpoints)
|
|
||||||
- ocr_pipeline_structure.py (Structure detection + exclude regions)
|
|
||||||
- ocr_pipeline_columns.py (Column detection + ground truth)
|
|
||||||
|
|
||||||
The `router` object is assembled here by including all sub-routers.
|
|
||||||
Importers that did `from ocr_pipeline_geometry import router` continue to work.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from fastapi import APIRouter
|
|
||||||
|
|
||||||
from ocr_pipeline_deskew import router as _deskew_router
|
|
||||||
from ocr_pipeline_dewarp import router as _dewarp_router
|
|
||||||
from ocr_pipeline_structure import router as _structure_router
|
|
||||||
from ocr_pipeline_columns import router as _columns_router
|
|
||||||
|
|
||||||
# Build the combined router from the four geometry sub-routers.
# Every sub-router already uses prefix="/api/v1/ocr-pipeline", so none is added here.
router = APIRouter()
for _sub_router in (_deskew_router, _dewarp_router, _structure_router, _columns_router):
    router.include_router(_sub_router)
# Do not leave the loop variable behind in the module namespace.
del _sub_router
|
|
||||||
|
|||||||
@@ -1,209 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/llm_review.py
|
||||||
OCR Pipeline LLM Review — LLM-based correction endpoints.
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Extracted from ocr_pipeline_postprocess.py.
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.llm_review")
|
||||||
|
|
||||||
Lizenz: Apache 2.0
|
|
||||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Request
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
|
|
||||||
from cv_vocab_pipeline import (
|
|
||||||
OLLAMA_REVIEW_MODEL,
|
|
||||||
llm_review_entries,
|
|
||||||
llm_review_entries_streaming,
|
|
||||||
)
|
|
||||||
from ocr_pipeline_session_store import (
|
|
||||||
get_session_db,
|
|
||||||
update_session_db,
|
|
||||||
)
|
|
||||||
from ocr_pipeline_common import (
|
|
||||||
_cache,
|
|
||||||
_append_pipeline_log,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Step 8: LLM Review
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/llm-review")
async def run_llm_review(session_id: str, request: Request, stream: bool = False):
    """Run LLM-based correction on vocab entries from Step 5.

    Query params:
        stream: false (default) for JSON response, true for SSE streaming

    The request body is optional. When it is a JSON object with a ``model``
    key, that model is used instead of ``OLLAMA_REVIEW_MODEL``.

    Raises:
        HTTPException: 404 when the session does not exist, 400 when it has
            no word result or no vocab entries, 502 when the non-streaming
            review call fails.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found — run Step 5 first")

    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    if not entries:
        raise HTTPException(status_code=400, detail="No vocab entries found — run Step 5 first")

    # Optional model override from the request body. A missing or unparsable
    # body deliberately falls back to the default model.
    try:
        body = await request.json()
    except Exception:
        body = {}
    # Valid JSON that is not an object (null, list, string, number) carries
    # no override; without this guard ``body.get`` raises AttributeError.
    if not isinstance(body, dict):
        body = {}
    model = body.get("model") or OLLAMA_REVIEW_MODEL

    if stream:
        return StreamingResponse(
            _llm_review_stream_generator(session_id, entries, word_result, model, request),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
        )

    # Non-streaming path
    try:
        result = await llm_review_entries(entries, model=model)
    except Exception as e:
        import traceback
        logger.error(f"LLM review failed for session {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}")
        # Chain the cause so the original failure stays visible in tracebacks.
        raise HTTPException(status_code=502, detail=f"LLM review failed ({type(e).__name__}): {e}") from e

    # Store the review inside word_result as a sub-key.
    word_result["llm_review"] = {
        "changes": result["changes"],
        "model_used": result["model_used"],
        "duration_ms": result["duration_ms"],
        "entries_corrected": result["entries_corrected"],
    }
    await update_session_db(session_id, word_result=word_result, current_step=9)

    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    logger.info(f"LLM review session {session_id}: {len(result['changes'])} changes, "
                f"{result['duration_ms']}ms, model={result['model_used']}")

    await _append_pipeline_log(session_id, "correction", {
        "engine": "llm",
        "model": result["model_used"],
        "total_entries": len(entries),
        "corrections_proposed": len(result["changes"]),
    }, duration_ms=result["duration_ms"])

    return {
        "session_id": session_id,
        "changes": result["changes"],
        "model_used": result["model_used"],
        "duration_ms": result["duration_ms"],
        "total_entries": len(entries),
        "corrections_found": len(result["changes"]),
    }
|
|
||||||
|
|
||||||
|
|
||||||
async def _llm_review_stream_generator(
    session_id: str,
    entries: List[Dict],
    word_result: Dict,
    model: str,
    request: Request,
):
    """Yield LLM review progress as server-sent events.

    Every event from ``llm_review_entries_streaming`` is forwarded as one
    ``data:`` frame. When the ``complete`` event arrives, the review is
    stored in ``word_result`` and persisted. Streaming stops silently if the
    client disconnects; any exception is reported as a final ``error`` event.
    """
    try:
        async for event in llm_review_entries_streaming(entries, model=model):
            if await request.is_disconnected():
                logger.info(f"SSE: client disconnected during LLM review for {session_id}")
                return

            yield f"data: {json.dumps(event, ensure_ascii=False)}\n\n"

            if event.get("type") != "complete":
                continue

            # Final event: persist the review before reporting it in the log.
            review_summary = {
                "changes": event["changes"],
                "model_used": event["model_used"],
                "duration_ms": event["duration_ms"],
                "entries_corrected": event["entries_corrected"],
            }
            word_result["llm_review"] = review_summary
            await update_session_db(session_id, word_result=word_result, current_step=9)
            if session_id in _cache:
                _cache[session_id]["word_result"] = word_result

            logger.info(f"LLM review SSE session {session_id}: {event['corrections_found']} changes, "
                        f"{event['duration_ms']}ms, skipped={event['skipped']}, model={event['model_used']}")

    except Exception as e:
        import traceback
        logger.error(f"LLM review SSE failed for {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}")
        failure = {"type": "error", "detail": f"{type(e).__name__}: {e}"}
        yield f"data: {json.dumps(failure)}\n\n"
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/llm-review/apply")
async def apply_llm_corrections(session_id: str, request: Request):
    """Write the accepted LLM corrections into the session's vocab entries.

    The request body lists ``accepted_indices``, which are positions in the
    stored ``llm_review["changes"]`` list. Each accepted change replaces the
    matching field of the entry with the same ``row_index``.

    Raises:
        HTTPException: 404 when the session does not exist, 400 when it has
            no word result or no stored LLM review.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    llm_review = word_result.get("llm_review")
    if not llm_review:
        raise HTTPException(status_code=400, detail="No LLM review found — run /llm-review first")

    body = await request.json()
    accepted_indices = set(body.get("accepted_indices", []))  # indices into changes[]

    changes = llm_review.get("changes", [])
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []

    # Lookup (row_index, field) -> new value for the accepted changes only.
    # If two accepted changes target the same cell, the later one wins.
    accepted_changes = [change for idx, change in enumerate(changes) if idx in accepted_indices]
    corrections = {
        (change["row_index"], change["field"]): change["new"]
        for change in accepted_changes
    }
    applied_count = len(accepted_changes)

    correctable_fields = ("english", "german", "example")
    for entry in entries:
        row_idx = entry.get("row_index", -1)
        for field_name in correctable_fields:
            lookup_key = (row_idx, field_name)
            if lookup_key not in corrections:
                continue
            entry[field_name] = corrections[lookup_key]
            entry["llm_corrected"] = True

    # Both entry keys are written.
    word_result["vocab_entries"] = entries
    word_result["entries"] = entries
    word_result["llm_review"]["applied_count"] = applied_count
    word_result["llm_review"]["applied_at"] = datetime.utcnow().isoformat()

    await update_session_db(session_id, word_result=word_result)

    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    logger.info(f"Applied {applied_count}/{len(changes)} LLM corrections for session {session_id}")

    return {
        "session_id": session_id,
        "applied_count": applied_count,
        "total_changes": len(changes),
    }
|
|
||||||
|
|||||||
@@ -1,266 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/ocr_merge.py
|
||||||
OCR Merge Kombi Endpoints — paddle-kombi and rapid-kombi endpoints.
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Merge helper functions live in ocr_merge_helpers.py.
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.ocr_merge")
|
||||||
This module re-exports them for backward compatibility.
|
|
||||||
|
|
||||||
Lizenz: Apache 2.0
|
|
||||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
from fastapi import APIRouter, HTTPException
|
|
||||||
|
|
||||||
from cv_words_first import build_grid_from_words
|
|
||||||
from ocr_pipeline_common import _cache, _append_pipeline_log
|
|
||||||
from ocr_pipeline_session_store import get_session_image, update_session_db
|
|
||||||
|
|
||||||
# Re-export merge helpers for backward compatibility
|
|
||||||
from ocr_merge_helpers import ( # noqa: F401
|
|
||||||
_split_paddle_multi_words,
|
|
||||||
_group_words_into_rows,
|
|
||||||
_row_center_y,
|
|
||||||
_merge_row_sequences,
|
|
||||||
_merge_paddle_tesseract,
|
|
||||||
_deduplicate_words,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
|
||||||
|
|
||||||
|
|
||||||
def _run_tesseract_words(img_bgr) -> list:
    """Run Tesseract OCR on a BGR image and return word dicts.

    Each returned dict has the keys ``text``, ``left``, ``top``, ``width``,
    ``height`` and ``conf``. Rows with empty text or a confidence below 20
    are dropped.

    The confidence is parsed via ``float`` so that integer values, fractional
    values and numeric strings are all accepted. A digit-only check would
    turn a fractional confidence such as ``"96.5"`` into -1 and drop the word.
    """
    from PIL import Image
    import pytesseract

    pil_img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    data = pytesseract.image_to_data(
        pil_img, lang="eng+deu",
        config="--psm 6 --oem 3",
        output_type=pytesseract.Output.DICT,
    )

    tess_words = []
    for i, raw_text in enumerate(data["text"]):
        text = str(raw_text).strip()
        try:
            conf = int(float(data["conf"][i]))
        except (TypeError, ValueError, OverflowError):
            # Non-numeric, NaN or infinite confidence: treat as "no confidence".
            conf = -1
        if not text or conf < 20:
            continue
        tess_words.append({
            "text": text,
            "left": data["left"][i],
            "top": data["top"][i],
            "width": data["width"][i],
            "height": data["height"][i],
            "conf": conf,
        })
    return tess_words
|
|
||||||
|
|
||||||
|
|
||||||
def _build_kombi_word_result(
    cells: list,
    columns_meta: list,
    img_w: int,
    img_h: int,
    duration: float,
    engine_name: str,
    raw_engine_words: list,
    raw_engine_words_split: list,
    tess_words: list,
    merged_words: list,
    raw_engine_key: str = "raw_paddle_words",
    raw_split_key: str = "raw_paddle_words_split",
) -> dict:
    """Build the word_result dict for kombi endpoints.

    Args:
        cells: Grid cells; each must have ``row_index`` and may have ``text``
            and ``confidence``.
        columns_meta: Column descriptors; a ``type`` of ``column_en`` or
            ``column_de`` marks the layout as ``"vocab"``, otherwise it is
            ``"generic"``.
        img_w, img_h: Image size stored in the result.
        duration: Processing time in seconds, rounded to two decimals.
        engine_name: Stored as both ``ocr_engine`` and ``grid_method``.
        raw_engine_words, raw_engine_words_split: Primary engine output
            before and after splitting, stored under ``raw_engine_key`` and
            ``raw_split_key``.
        tess_words: Tesseract words, stored under ``raw_tesseract_words``.
        merged_words: Merged words; only their count enters the summary.
        raw_engine_key, raw_split_key: Result keys for the raw word lists.
            The summary uses the same keys with ``"raw_"`` removed.

    Returns:
        The assembled result dict including ``grid_shape`` and ``summary``.
    """
    total_cells = len(cells)
    n_rows = len({c["row_index"] for c in cells})
    n_cols = len(columns_meta)
    col_types = {c.get("type") for c in columns_meta}
    is_vocab = bool(col_types & {"column_en", "column_de"})

    # Summary counters reuse the raw keys without the "raw_" marker,
    # e.g. "raw_paddle_words" -> "paddle_words".
    summary_engine_key = raw_engine_key.replace("raw_", "")
    summary_split_key = raw_split_key.replace("raw_", "")

    return {
        "cells": cells,
        "grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": total_cells},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": engine_name,
        "grid_method": engine_name,
        raw_engine_key: raw_engine_words,
        raw_split_key: raw_engine_words_split,
        "raw_tesseract_words": tess_words,
        "summary": {
            "total_cells": total_cells,
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            # A confidence of 0 is not counted as "low".
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
            summary_engine_key: len(raw_engine_words),
            summary_split_key: len(raw_engine_words_split),
            "tesseract_words": len(tess_words),
            "merged_words": len(merged_words),
        },
    }
|
|
||||||
|
|
||||||
|
|
||||||
async def _load_session_image(session_id: str):
    """Load and decode the best available session image for kombi endpoints.

    The stored images are tried in the order ``cropped``, ``dewarped``,
    ``original``; the first non-empty one is used.

    Returns:
        Tuple ``(png_bytes, decoded_image)``, the image being decoded with
        ``cv2.IMREAD_COLOR``.

    Raises:
        HTTPException: 404 when no image is stored, 400 when the stored
            bytes cannot be decoded.
    """
    img_png = None
    for image_kind in ("cropped", "dewarped", "original"):
        img_png = await get_session_image(session_id, image_kind)
        if img_png:
            break
    if not img_png:
        raise HTTPException(status_code=404, detail="No image found for this session")

    img_bgr = cv2.imdecode(np.frombuffer(img_png, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Failed to decode image")

    return img_png, img_bgr
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Kombi endpoints
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/paddle-kombi")
async def paddle_kombi(session_id: str):
    """Run PaddleOCR and Tesseract on the session image and merge the words.

    The PaddleOCR boxes are split into individual words, merged with the
    Tesseract words, de-duplicated and arranged into a grid. The resulting
    word_result is persisted, cached and returned.

    Raises:
        HTTPException: 400 when neither engine produced any words, plus
            whatever ``_load_session_image`` raises.
    """
    img_png, img_bgr = await _load_session_image(session_id)
    img_h, img_w = img_bgr.shape[:2]

    from cv_ocr_engines import ocr_region_paddle

    started = time.time()

    # A falsy engine result is normalized to an empty list.
    paddle_words = (await ocr_region_paddle(img_bgr, region=None)) or []
    tess_words = _run_tesseract_words(img_bgr)

    paddle_words_split = _split_paddle_multi_words(paddle_words)
    logger.info(
        "paddle_kombi: split %d paddle boxes -> %d individual words",
        len(paddle_words), len(paddle_words_split),
    )

    if not (paddle_words_split or tess_words):
        raise HTTPException(status_code=400, detail="Both OCR engines returned no words")

    merged_words = _deduplicate_words(_merge_paddle_tesseract(paddle_words_split, tess_words))

    cells, columns_meta = build_grid_from_words(merged_words, img_w, img_h)
    duration = time.time() - started

    # Tag every cell with the combined engine name.
    for cell in cells:
        cell["ocr_engine"] = "kombi"

    word_result = _build_kombi_word_result(
        cells, columns_meta, img_w, img_h, duration, "kombi",
        paddle_words, paddle_words_split, tess_words, merged_words,
        raw_engine_key="raw_paddle_words",
        raw_split_key="raw_paddle_words_split",
    )

    await update_session_db(
        session_id, word_result=word_result, cropped_png=img_png, current_step=8,
    )
    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    grid_shape = word_result["grid_shape"]
    logger.info(
        "paddle_kombi session %s: %d cells (%d rows, %d cols) in %.2fs "
        "[paddle=%d, tess=%d, merged=%d]",
        session_id, len(cells), grid_shape["rows"],
        grid_shape["cols"], duration,
        len(paddle_words), len(tess_words), len(merged_words),
    )

    log_payload = {
        "total_cells": len(cells),
        "non_empty_cells": word_result["summary"]["non_empty_cells"],
        "paddle_words": len(paddle_words),
        "tesseract_words": len(tess_words),
        "merged_words": len(merged_words),
        "ocr_engine": "kombi",
    }
    await _append_pipeline_log(session_id, "paddle_kombi", log_payload, duration_ms=int(duration * 1000))

    return {"session_id": session_id, **word_result}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/rapid-kombi")
async def rapid_kombi(session_id: str):
    """Run RapidOCR and Tesseract on the session image and merge the words.

    RapidOCR is run on a single region covering the whole image. Its boxes
    are split into individual words, merged with the Tesseract words,
    de-duplicated and arranged into a grid. The resulting word_result is
    persisted, cached and returned.

    Raises:
        HTTPException: 400 when neither engine produced any words, plus
            whatever ``_load_session_image`` raises.
    """
    img_png, img_bgr = await _load_session_image(session_id)
    img_h, img_w = img_bgr.shape[:2]

    from cv_ocr_engines import ocr_region_rapid
    from cv_vocab_types import PageRegion

    started = time.time()

    whole_page = PageRegion(
        type="full_page", x=0, y=0, width=img_w, height=img_h,
    )
    # A falsy engine result is normalized to an empty list.
    rapid_words = ocr_region_rapid(img_bgr, whole_page) or []
    tess_words = _run_tesseract_words(img_bgr)

    rapid_words_split = _split_paddle_multi_words(rapid_words)
    logger.info(
        "rapid_kombi: split %d rapid boxes -> %d individual words",
        len(rapid_words), len(rapid_words_split),
    )

    if not (rapid_words_split or tess_words):
        raise HTTPException(status_code=400, detail="Both OCR engines returned no words")

    merged_words = _deduplicate_words(_merge_paddle_tesseract(rapid_words_split, tess_words))

    cells, columns_meta = build_grid_from_words(merged_words, img_w, img_h)
    duration = time.time() - started

    # Tag every cell with the combined engine name.
    for cell in cells:
        cell["ocr_engine"] = "rapid_kombi"

    word_result = _build_kombi_word_result(
        cells, columns_meta, img_w, img_h, duration, "rapid_kombi",
        rapid_words, rapid_words_split, tess_words, merged_words,
        raw_engine_key="raw_rapid_words",
        raw_split_key="raw_rapid_words_split",
    )

    await update_session_db(
        session_id, word_result=word_result, cropped_png=img_png, current_step=8,
    )
    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    grid_shape = word_result["grid_shape"]
    logger.info(
        "rapid_kombi session %s: %d cells (%d rows, %d cols) in %.2fs "
        "[rapid=%d, tess=%d, merged=%d]",
        session_id, len(cells), grid_shape["rows"],
        grid_shape["cols"], duration,
        len(rapid_words), len(tess_words), len(merged_words),
    )

    log_payload = {
        "total_cells": len(cells),
        "non_empty_cells": word_result["summary"]["non_empty_cells"],
        "rapid_words": len(rapid_words),
        "tesseract_words": len(tess_words),
        "merged_words": len(merged_words),
        "ocr_engine": "rapid_kombi",
    }
    await _append_pipeline_log(session_id, "rapid_kombi", log_payload, duration_ms=int(duration * 1000))

    return {"session_id": session_id, **word_result}
|
|
||||||
|
|||||||
@@ -1,333 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/overlay_grid.py
|
||||||
Overlay rendering for columns, rows, and words (grid-based overlays).
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Extracted from ocr_pipeline_overlays.py for modularity.
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.overlay_grid")
|
||||||
|
|
||||||
Lizenz: Apache 2.0
|
|
||||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any, Dict, List
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
from fastapi import HTTPException
|
|
||||||
from fastapi.responses import Response
|
|
||||||
|
|
||||||
from ocr_pipeline_common import _get_base_image_png
|
|
||||||
from ocr_pipeline_session_store import get_session_db
|
|
||||||
from ocr_pipeline_rows import _draw_box_exclusion_overlay
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_columns_overlay(session_id: str) -> Response:
    """Render the session's base image with the detected column regions drawn on it.

    Each column gets a tinted fill (blended in at 20% opacity), a solid
    border and a caption built from its type plus, when the classification
    confidence is below 1.0, that confidence as a percentage. Zones of type
    ``box`` are outlined with dashed rectangles and finally handed to
    ``_draw_box_exclusion_overlay``.

    Args:
        session_id: Identifier of the OCR pipeline session.

    Returns:
        A PNG ``Response`` containing the annotated image.

    Raises:
        HTTPException: 404 if the session, its column data or a base image
            is missing; 500 if the image cannot be decoded or encoded.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    column_result = session.get("column_result")
    if not (column_result and column_result.get("columns")):
        raise HTTPException(status_code=404, detail="No column data available")

    # Best available base image (cropped > dewarped > original)
    base_png = await _get_base_image_png(session_id)
    if not base_png:
        raise HTTPException(status_code=404, detail="No base image available")

    img = cv2.imdecode(np.frombuffer(base_png, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    # Region type -> BGR colour
    palette = {
        "column_en": (255, 180, 0),         # Blue
        "column_de": (0, 200, 0),           # Green
        "column_example": (0, 140, 255),    # Orange
        "column_text": (200, 200, 0),       # Cyan/Turquoise
        "page_ref": (200, 0, 200),          # Purple
        "column_marker": (0, 0, 220),       # Red
        "column_ignore": (180, 180, 180),   # Light Gray
        "header": (128, 128, 128),          # Gray
        "footer": (128, 128, 128),          # Gray
        "margin_top": (100, 100, 100),      # Dark Gray
        "margin_bottom": (100, 100, 100),   # Dark Gray
    }
    fallback_color = (200, 200, 200)
    font = cv2.FONT_HERSHEY_SIMPLEX

    # Fills go onto a separate layer so they can be blended translucently;
    # borders and captions are drawn straight onto the image.
    fill_layer = img.copy()
    for column in column_result["columns"]:
        left, top = column["x"], column["y"]
        right = left + column["width"]
        bottom = top + column["height"]
        tint = palette.get(column.get("type", ""), fallback_color)

        cv2.rectangle(fill_layer, (left, top), (right, bottom), tint, -1)
        cv2.rectangle(img, (left, top), (right, bottom), tint, 3)

        caption = column.get("type", "unknown").replace("column_", "").upper()
        confidence = column.get("classification_confidence")
        if confidence is not None and confidence < 1.0:
            caption = f"{caption} {int(confidence * 100)}%"
        cv2.putText(img, caption, (left + 10, top + 30), font, 0.8, tint, 2)

    # Blend the fill layer in at 20% opacity
    cv2.addWeighted(fill_layer, 0.2, img, 0.8, 0, img)

    # Dashed outlines around detected boxes, built from short line segments
    zones = column_result.get("zones") or []
    dash = 15
    outline = (0, 200, 255)  # Yellow (BGR)
    for zone in zones:
        if zone.get("zone_type") != "box" or not zone.get("box"):
            continue
        rect = zone["box"]
        left, top = rect["x"], rect["y"]
        right = left + rect["width"]
        bottom = top + rect["height"]

        for start in range(left, right, dash * 2):
            stop = min(start + dash, right)
            cv2.line(img, (start, top), (stop, top), outline, 2)
            cv2.line(img, (start, bottom), (stop, bottom), outline, 2)
        for start in range(top, bottom, dash * 2):
            stop = min(start + dash, bottom)
            cv2.line(img, (left, start), (left, stop), outline, 2)
            cv2.line(img, (right, start), (right, stop), outline, 2)

        cv2.putText(img, "BOX", (left + 10, bottom - 10), font, 0.7, outline, 2)

    # Red semi-transparent overlay for box zones
    _draw_box_exclusion_overlay(img, zones)

    encoded_ok, png = cv2.imencode(".png", img)
    if not encoded_ok:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")

    return Response(content=png.tobytes(), media_type="image/png")
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_rows_overlay(session_id: str) -> Response:
    """Render the session's base image with the detected row bands drawn on it.

    Each row gets a tinted fill (blended in at 15% opacity), a border and a
    caption of the form ``R<index> <TYPE>`` with the word count appended
    when it is non-zero. For zones of type ``box`` a dashed horizontal line
    is drawn across the full image width at the zone's top and bottom edge,
    and all zones are handed to ``_draw_box_exclusion_overlay``.

    Args:
        session_id: Identifier of the OCR pipeline session.

    Returns:
        A PNG ``Response`` containing the annotated image.

    Raises:
        HTTPException: 404 if the session, its row data or a base image is
            missing; 500 if the image cannot be decoded or encoded.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    row_result = session.get("row_result")
    if not (row_result and row_result.get("rows")):
        raise HTTPException(status_code=404, detail="No row data available")

    # Best available base image (cropped > dewarped > original)
    base_png = await _get_base_image_png(session_id)
    if not base_png:
        raise HTTPException(status_code=404, detail="No base image available")

    img = cv2.imdecode(np.frombuffer(base_png, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    # Row type -> BGR colour
    palette = {
        "content": (255, 180, 0),           # Blue
        "header": (128, 128, 128),          # Gray
        "footer": (128, 128, 128),          # Gray
        "margin_top": (100, 100, 100),      # Dark Gray
        "margin_bottom": (100, 100, 100),   # Dark Gray
    }
    fallback_color = (200, 200, 200)

    fill_layer = img.copy()
    for band in row_result["rows"]:
        left, top = band["x"], band["y"]
        right = left + band["width"]
        bottom = top + band["height"]
        kind = band.get("row_type", "content")
        tint = palette.get(kind, fallback_color)

        cv2.rectangle(fill_layer, (left, top), (right, bottom), tint, -1)
        cv2.rectangle(img, (left, top), (right, bottom), tint, 2)

        caption = f"R{band.get('index', 0)} {kind.upper()}"
        word_count = band.get("word_count", 0)
        if word_count:
            caption = f"{caption} ({word_count}w)"
        cv2.putText(img, caption, (left + 5, top + 18),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, tint, 1)

    # Blend the fill layer in at 15% opacity
    cv2.addWeighted(fill_layer, 0.15, img, 0.85, 0, img)

    # Dashed separator lines at the top and bottom of every box zone
    zones = (session.get("column_result") or {}).get("zones") or []
    image_width = img.shape[1]
    separator_color = (0, 200, 255)  # Yellow (BGR)
    dash = 20
    for zone in zones:
        if zone.get("zone_type") != "box":
            continue
        zone_top = zone["y"]
        zone_bottom = zone_top + zone["height"]
        for line_y in (zone_top, zone_bottom):
            for start in range(0, image_width, dash * 2):
                stop = min(start + dash, image_width)
                cv2.line(img, (start, line_y), (stop, line_y), separator_color, 2)

    # Red semi-transparent overlay for box zones
    _draw_box_exclusion_overlay(img, zones)

    encoded_ok, png = cv2.imencode(".png", img)
    if not encoded_ok:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")

    return Response(content=png.tobytes(), media_type="image/png")
|
|
||||||
|
|
||||||
|
|
||||||
def _confidence_text_color(conf):
    """Return the BGR text colour for a confidence value (>=70 green, >=50 amber, else red)."""
    if conf >= 70:
        return (0, 180, 0)
    if conf >= 50:
        return (0, 180, 220)
    return (0, 0, 220)


async def _get_words_overlay(session_id: str) -> Response:
    """Render the session's base image with the recognised cell grid drawn on it.

    Two data formats are supported:

    * Cell-based (``word_result["cells"]``): every cell with a positive
      width and height is outlined and tinted by column index, and labelled
      with its cell id and the first 30 characters of its text.
    * Legacy entry-based (``word_result["entries"]``): the grid is derived
      from the stored column and content-row results, and the entry fields
      ``english`` / ``german`` / ``example`` are written into the matching
      columns.

    Fills are blended in at 10% opacity and box zones are handed to
    ``_draw_box_exclusion_overlay``.

    Args:
        session_id: Identifier of the OCR pipeline session.

    Returns:
        A PNG ``Response`` containing the annotated image.

    Raises:
        HTTPException: 404 if the session, its word data or a base image is
            missing; 500 if the image cannot be decoded or encoded.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=404, detail="No word data available")

    # Support both new cell-based and legacy entry-based formats
    cells = word_result.get("cells")
    if not cells and not word_result.get("entries"):
        raise HTTPException(status_code=404, detail="No word data available")

    # Load best available base image (cropped > dewarped > original)
    base_png = await _get_base_image_png(session_id)
    if not base_png:
        raise HTTPException(status_code=404, detail="No base image available")

    arr = np.frombuffer(base_png, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    overlay = img.copy()
    font = cv2.FONT_HERSHEY_SIMPLEX

    if cells:
        # New cell-based overlay: color by column index
        col_palette = [
            (255, 180, 0),    # Blue (BGR)
            (0, 200, 0),      # Green
            (0, 140, 255),    # Orange
            (200, 100, 200),  # Purple
            (200, 200, 0),    # Cyan
            (100, 200, 200),  # Yellow-ish
        ]

        for cell in cells:
            bbox = cell.get("bbox_px", {})
            cx = bbox.get("x", 0)
            cy = bbox.get("y", 0)
            cw = bbox.get("w", 0)
            ch = bbox.get("h", 0)
            if cw <= 0 or ch <= 0:
                continue

            col_idx = cell.get("col_index", 0)
            color = col_palette[col_idx % len(col_palette)]

            # Cell rectangle border
            cv2.rectangle(img, (cx, cy), (cx + cw, cy + ch), color, 1)
            # Semi-transparent fill
            cv2.rectangle(overlay, (cx, cy), (cx + cw, cy + ch), color, -1)

            # Cell-ID label (top-left corner)
            cell_id = cell.get("cell_id", "")
            cv2.putText(img, cell_id, (cx + 2, cy + 10), font, 0.28, color, 1)

            # Text label (bottom of cell)
            text = cell.get("text", "")
            if text:
                text_color = _confidence_text_color(cell.get("confidence", 0))
                label = text.replace('\n', ' ')[:30]
                cv2.putText(img, label, (cx + 3, cy + ch - 4), font, 0.35, text_color, 1)
    else:
        # Legacy fallback: entry-based overlay (for old sessions)
        column_result = session.get("column_result")
        row_result = session.get("row_result")
        col_colors = {
            "column_en": (255, 180, 0),
            "column_de": (0, 200, 0),
            "column_example": (0, 140, 255),
        }
        # Column type -> entry field holding that column's text
        # (built once instead of once per cell).
        field_by_col_type = {
            "column_en": "english",
            "column_de": "german",
            "column_example": "example",
        }

        columns = []
        if column_result and column_result.get("columns"):
            columns = [c for c in column_result["columns"]
                       if c.get("type", "").startswith("column_")]

        content_rows_data = []
        if row_result and row_result.get("rows"):
            content_rows_data = [r for r in row_result["rows"]
                                 if r.get("row_type") == "content"]

        for col in columns:
            col_type = col.get("type", "")
            color = col_colors.get(col_type, (200, 200, 200))
            cx, cw = col["x"], col["width"]
            for row in content_rows_data:
                ry, rh = row["y"], row["height"]
                cv2.rectangle(img, (cx, ry), (cx + cw, ry + rh), color, 1)
                cv2.rectangle(overlay, (cx, ry), (cx + cw, ry + rh), color, -1)

        entries = word_result["entries"]
        entry_by_row: Dict[int, Dict] = {}
        for entry in entries:
            entry_by_row[entry.get("row_index", -1)] = entry

        for row_idx, row in enumerate(content_rows_data):
            entry = entry_by_row.get(row_idx)
            if not entry:
                continue
            text_color = _confidence_text_color(entry.get("confidence", 0))
            ry, rh = row["y"], row["height"]
            for col in columns:
                cx = col["x"]
                field = field_by_col_type.get(col.get("type", ""), "")
                text = entry.get(field, "") if field else ""
                if text:
                    label = text.replace('\n', ' ')[:30]
                    cv2.putText(img, label, (cx + 3, ry + rh - 4), font, 0.35, text_color, 1)

    # Blend overlay at 10% opacity
    cv2.addWeighted(overlay, 0.1, img, 0.9, 0, img)

    # Red semi-transparent overlay for box zones
    column_result = session.get("column_result") or {}
    zones = column_result.get("zones") or []
    _draw_box_exclusion_overlay(img, zones)

    success, result_png = cv2.imencode(".png", img)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")

    return Response(content=result_png.tobytes(), media_type="image/png")
|
|
||||||
|
|||||||
@@ -1,205 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/overlay_structure.py
|
||||||
Overlay rendering for structure detection (boxes, zones, colors, graphics).
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Extracted from ocr_pipeline_overlays.py for modularity.
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.overlay_structure")
|
||||||
|
|
||||||
Lizenz: Apache 2.0
|
|
||||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any, Dict, List
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
from fastapi import HTTPException
|
|
||||||
from fastapi.responses import Response
|
|
||||||
|
|
||||||
from ocr_pipeline_common import _get_base_image_png
|
|
||||||
from ocr_pipeline_session_store import get_session_db
|
|
||||||
from cv_color_detect import _COLOR_HEX, _COLOR_RANGES
|
|
||||||
from cv_box_detect import detect_boxes, split_page_into_zones
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_structure_overlay(session_id: str) -> Response:
    """Render the base image annotated with detected boxes, zones, colour regions and graphics.

    The stored ``structure_result`` of the session is used when present;
    otherwise box and zone detection is run on the fly inside a 3% margin
    of the image. Colour regions are always recomputed from the undrawn
    image using the HSV ranges in ``_COLOR_RANGES``.

    Args:
        session_id: Identifier of the OCR pipeline session.

    Returns:
        A PNG ``Response`` containing the annotated image.

    Raises:
        HTTPException: 404 if no base image is available; 500 if the image
            cannot be decoded or the result cannot be encoded.
    """
    base_png = await _get_base_image_png(session_id)
    if not base_png:
        raise HTTPException(status_code=404, detail="No base image available")

    arr = np.frombuffer(base_png, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    h, w = img.shape[:2]

    # Take the HSV snapshot now, before anything is drawn on ``img``, so the
    # colour masks further down see the pristine pixels without decoding the
    # PNG a second time.
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)

    # Get structure result (run detection if not cached)
    session = await get_session_db(session_id)
    structure = (session or {}).get("structure_result")

    if not structure:
        # Run detection on-the-fly
        margin = int(min(w, h) * 0.03)
        content_x, content_y = margin, margin
        content_w_px = w - 2 * margin
        content_h_px = h - 2 * margin
        boxes = detect_boxes(img, content_x, content_w_px, content_y, content_h_px)
        zones = split_page_into_zones(content_x, content_y, content_w_px, content_h_px, boxes)
        structure = {
            "boxes": [
                {"x": b.x, "y": b.y, "w": b.width, "h": b.height,
                 "confidence": b.confidence, "border_thickness": b.border_thickness}
                for b in boxes
            ],
            "zones": [
                {"index": z.index, "zone_type": z.zone_type,
                 "y": z.y, "h": z.height, "x": z.x, "w": z.width}
                for z in zones
            ],
        }

    overlay = img.copy()
    font = cv2.FONT_HERSHEY_SIMPLEX

    # --- Draw zone boundaries ---
    zone_colors = {
        "content": (200, 200, 200),   # light gray
        "box": (255, 180, 0),         # blue-ish (BGR)
    }
    for zone in structure.get("zones", []):
        zx = zone["x"]
        zy = zone["y"]
        zw = zone["w"]
        zh = zone["h"]
        color = zone_colors.get(zone["zone_type"], (200, 200, 200))

        # Dashed top and bottom edge of the zone
        dash_len = 12
        for edge_x in range(zx, zx + zw, dash_len * 2):
            end_x = min(edge_x + dash_len, zx + zw)
            cv2.line(img, (edge_x, zy), (end_x, zy), color, 1)
            cv2.line(img, (edge_x, zy + zh), (end_x, zy + zh), color, 1)

        # Zone label
        zone_label = f"Zone {zone['index']} ({zone['zone_type']})"
        cv2.putText(img, zone_label, (zx + 5, zy + 15), font, 0.45, color, 1)

    # --- Draw detected boxes ---
    # Color map for box backgrounds (hex -> BGR)
    bg_hex_to_bgr = {
        "#dc2626": (38, 38, 220),     # red
        "#2563eb": (235, 99, 37),     # blue
        "#16a34a": (74, 163, 22),     # green
        "#ea580c": (12, 88, 234),     # orange
        "#9333ea": (234, 51, 147),    # purple
        "#ca8a04": (4, 138, 202),     # yellow
        "#6b7280": (128, 114, 107),   # gray
    }

    for box_data in structure.get("boxes", []):
        bx = box_data["x"]
        by = box_data["y"]
        bw = box_data["w"]
        bh = box_data["h"]
        conf = box_data.get("confidence", 0)
        thickness = box_data.get("border_thickness", 0)
        bg_hex = box_data.get("bg_color_hex", "#6b7280")
        bg_name = box_data.get("bg_color_name", "")

        # Box fill color (also used for border and label)
        fill_bgr = bg_hex_to_bgr.get(bg_hex, (128, 114, 107))

        # Semi-transparent fill
        cv2.rectangle(overlay, (bx, by), (bx + bw, by + bh), fill_bgr, -1)

        # Solid border
        border_color = fill_bgr
        cv2.rectangle(img, (bx, by), (bx + bw, by + bh), border_color, 3)

        # Label
        label = "BOX"
        if bg_name and bg_name not in ("unknown", "white"):
            label += f" ({bg_name})"
        if thickness > 0:
            label += f" border={thickness}px"
        label += f" {int(conf * 100)}%"
        # Thick white pass first, then the coloured text on top of it
        cv2.putText(img, label, (bx + 8, by + 22), font, 0.55, (255, 255, 255), 2)
        cv2.putText(img, label, (bx + 8, by + 22), font, 0.55, border_color, 1)

    # Blend overlay at 15% opacity
    cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img)

    # --- Draw color regions (HSV masks) ---
    color_bgr_map = {
        "red": (0, 0, 255),
        "orange": (0, 140, 255),
        "yellow": (0, 200, 255),
        "green": (0, 200, 0),
        "blue": (255, 150, 0),
        "purple": (200, 0, 200),
    }
    for color_name, ranges in _COLOR_RANGES.items():
        mask = np.zeros((h, w), dtype=np.uint8)
        for lower, upper in ranges:
            mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper))
        # Only draw if there are significant colored pixels
        if np.sum(mask > 0) < 100:
            continue
        # Draw colored contours, skipping tiny specks
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        draw_color = color_bgr_map.get(color_name, (200, 200, 200))
        for cnt in contours:
            area = cv2.contourArea(cnt)
            if area < 20:
                continue
            cv2.drawContours(img, [cnt], -1, draw_color, 2)

    # --- Draw graphic elements ---
    graphics_data = structure.get("graphics", [])
    shape_icons = {
        "image": "IMAGE",
        "illustration": "ILLUST",
    }
    for gfx in graphics_data:
        gx, gy = gfx["x"], gfx["y"]
        gw, gh = gfx["w"], gfx["h"]
        shape = gfx.get("shape", "icon")
        color_hex = gfx.get("color_hex", "#6b7280")
        conf = gfx.get("confidence", 0)

        # Pick draw color based on element color (BGR)
        gfx_bgr = bg_hex_to_bgr.get(color_hex, (128, 114, 107))

        # Draw bounding box (dashed style via short segments)
        dash = 6
        for seg_x in range(gx, gx + gw, dash * 2):
            end_x = min(seg_x + dash, gx + gw)
            cv2.line(img, (seg_x, gy), (end_x, gy), gfx_bgr, 2)
            cv2.line(img, (seg_x, gy + gh), (end_x, gy + gh), gfx_bgr, 2)
        for seg_y in range(gy, gy + gh, dash * 2):
            end_y = min(seg_y + dash, gy + gh)
            cv2.line(img, (gx, seg_y), (gx, end_y), gfx_bgr, 2)
            cv2.line(img, (gx + gw, seg_y), (gx + gw, end_y), gfx_bgr, 2)

        # Label
        icon = shape_icons.get(shape, shape.upper()[:5])
        label = f"{icon} {int(conf * 100)}%"
        # White background for readability
        (tw, th), _ = cv2.getTextSize(label, font, 0.4, 1)
        lx = gx + 2
        # Keep the label inside the image when the element touches the top edge
        ly = max(gy - 4, th + 4)
        cv2.rectangle(img, (lx - 1, ly - th - 2), (lx + tw + 2, ly + 3), (255, 255, 255), -1)
        cv2.putText(img, label, (lx, ly), font, 0.4, gfx_bgr, 1)

    # Encode result; report a failed encode the same way the other overlay
    # renderers do instead of returning whatever buffer came back.
    success, png_buf = cv2.imencode(".png", img)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")

    return Response(content=png_buf.tobytes(), media_type="image/png")
|
|
||||||
|
|||||||
@@ -1,34 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/overlays.py
|
||||||
Overlay image rendering for OCR pipeline — barrel re-export.
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
All implementation split into:
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.overlays")
|
||||||
ocr_pipeline_overlay_structure — structure overlay (boxes, zones, colors, graphics)
|
|
||||||
ocr_pipeline_overlay_grid — columns, rows, words overlays
|
|
||||||
|
|
||||||
Lizenz: Apache 2.0
|
|
||||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from fastapi import HTTPException
|
|
||||||
from fastapi.responses import Response
|
|
||||||
|
|
||||||
from ocr_pipeline_overlay_structure import _get_structure_overlay # noqa: F401
|
|
||||||
from ocr_pipeline_overlay_grid import ( # noqa: F401
|
|
||||||
_get_columns_overlay,
|
|
||||||
_get_rows_overlay,
|
|
||||||
_get_words_overlay,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def render_overlay(overlay_type: str, session_id: str) -> Response:
    """Render the overlay named by ``overlay_type`` for the given session.

    Supported types are ``structure``, ``columns``, ``rows`` and ``words``.

    Raises:
        HTTPException: 400 when ``overlay_type`` is not one of the
            supported names.
    """
    renderers = {
        "structure": _get_structure_overlay,
        "columns": _get_columns_overlay,
        "rows": _get_rows_overlay,
        "words": _get_words_overlay,
    }
    renderer = renderers.get(overlay_type)
    if renderer is None:
        raise HTTPException(status_code=400, detail=f"Unknown overlay type: {overlay_type}")
    return await renderer(session_id)
|
|
||||||
|
|||||||
@@ -1,26 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/postprocess.py
|
||||||
OCR Pipeline Postprocessing API — composite router assembling LLM review,
|
import importlib as _importlib
|
||||||
reconstruction, export, validation, image detection/generation, and
|
import sys as _sys
|
||||||
handwriting removal endpoints.
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.postprocess")
|
||||||
|
|
||||||
Split into sub-modules:
|
|
||||||
ocr_pipeline_llm_review — LLM review + apply corrections
|
|
||||||
ocr_pipeline_reconstruction — reconstruction save, Fabric JSON, merged entries, PDF/DOCX
|
|
||||||
ocr_pipeline_validation — image detection, generation, validation, handwriting removal
|
|
||||||
|
|
||||||
Lizenz: Apache 2.0
|
|
||||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from fastapi import APIRouter
|
|
||||||
|
|
||||||
from ocr_pipeline_llm_review import router as _llm_review_router
|
|
||||||
from ocr_pipeline_reconstruction import router as _reconstruction_router
|
|
||||||
from ocr_pipeline_validation import router as _validation_router
|
|
||||||
|
|
||||||
# Composite router: a drop-in replacement for the former monolithic router.
# ocr_pipeline_api.py imports ``from ocr_pipeline_postprocess import router``.
router = APIRouter()
for _sub_router in (_llm_review_router, _reconstruction_router, _validation_router):
    router.include_router(_sub_router)
del _sub_router  # keep the module namespace identical to the unrolled form
|
|
||||||
|
|||||||
@@ -1,362 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/reconstruction.py
|
||||||
OCR Pipeline Reconstruction — save edits, Fabric JSON export, merged entries, PDF/DOCX export.
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Extracted from ocr_pipeline_postprocess.py.
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.reconstruction")
|
||||||
|
|
||||||
Lizenz: Apache 2.0
|
|
||||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Request
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
|
|
||||||
from ocr_pipeline_session_store import (
|
|
||||||
get_session_db,
|
|
||||||
get_sub_sessions,
|
|
||||||
update_session_db,
|
|
||||||
)
|
|
||||||
from ocr_pipeline_common import _cache
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Step 9: Reconstruction + Fabric JSON export
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/reconstruction")
async def save_reconstruction(session_id: str, request: Request):
    """Save edited cell texts from the reconstruction step.

    The JSON body is expected to carry ``{"cells": [{"cell_id": ..., "text": ...}, ...]}``.
    Cell ids of the form ``box<N>_<id>`` are routed to the sub-session whose
    ``box_index`` is ``N``; all other ids update the main session's cells
    and, where present, the matching vocab entries. The session is advanced
    to step 10 in every successful case.

    Args:
        session_id: Identifier of the main OCR pipeline session.
        request: Incoming request whose JSON body holds the cell updates.

    Returns:
        A dict with the session id and update counts. When no updates were
        sent, only ``session_id`` and ``updated`` (0) are returned.

    Raises:
        HTTPException: 404 if the session does not exist; 400 if it has no
            word result or the payload is malformed.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    body = await request.json()
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    cell_updates = body.get("cells", [])

    if not cell_updates:
        await update_session_db(session_id, current_step=10)
        return {"session_id": session_id, "updated": 0}

    # Build update map: cell_id -> new text. A malformed entry is a client
    # error, not a server fault.
    try:
        update_map = {c["cell_id"]: c["text"] for c in cell_updates}
    except (KeyError, TypeError) as exc:
        raise HTTPException(
            status_code=400,
            detail="Each cell update needs 'cell_id' and 'text'",
        ) from exc

    # Separate sub-session updates (cell_ids prefixed with "box{N}_")
    sub_updates: Dict[int, Dict[str, str]] = {}  # box_index -> {original_cell_id: text}
    main_updates: Dict[str, str] = {}
    for cell_id, text in update_map.items():
        m = re.match(r'^box(\d+)_(.+)$', cell_id)
        if m:
            bi = int(m.group(1))
            original_id = m.group(2)
            sub_updates.setdefault(bi, {})[original_id] = text
        else:
            main_updates[cell_id] = text

    # Update main session cells
    cells = word_result.get("cells", [])
    updated_count = 0
    for cell in cells:
        if cell["cell_id"] in main_updates:
            cell["text"] = main_updates[cell["cell_id"]]
            cell["status"] = "edited"
            updated_count += 1

    word_result["cells"] = cells

    # Also update vocab_entries if present
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    if entries:
        for entry in entries:
            row_idx = entry.get("row_index", -1)
            for col_idx, field_name in enumerate(("english", "german", "example")):
                # Cell ids appear both zero-padded ("R05_C0") and unpadded
                # ("R5_C0"). Test membership instead of truthiness so that
                # an edit which clears a cell to "" still reaches the entry.
                for candidate in (f"R{row_idx:02d}_C{col_idx}", f"R{row_idx}_C{col_idx}"):
                    if candidate in main_updates:
                        new_text = main_updates[candidate]
                        if new_text is not None:
                            entry[field_name] = new_text
                        break

        word_result["vocab_entries"] = entries
        if "entries" in word_result:
            word_result["entries"] = entries

    await update_session_db(session_id, word_result=word_result, current_step=10)

    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    # Route sub-session updates
    sub_updated = 0
    if sub_updates:
        subs = await get_sub_sessions(session_id)
        sub_by_index = {s.get("box_index"): s["id"] for s in subs}
        for bi, updates in sub_updates.items():
            sub_id = sub_by_index.get(bi)
            if not sub_id:
                continue
            sub_session = await get_session_db(sub_id)
            if not sub_session:
                continue
            sub_word = sub_session.get("word_result")
            if not sub_word:
                continue
            sub_cells = sub_word.get("cells", [])
            for cell in sub_cells:
                if cell["cell_id"] in updates:
                    cell["text"] = updates[cell["cell_id"]]
                    cell["status"] = "edited"
                    sub_updated += 1
            sub_word["cells"] = sub_cells
            await update_session_db(sub_id, word_result=sub_word)
            if sub_id in _cache:
                _cache[sub_id]["word_result"] = sub_word

    total_updated = updated_count + sub_updated
    logger.info(
        "Reconstruction saved for session %s: %d main + %d sub-session cells updated",
        session_id, updated_count, sub_updated,
    )

    return {
        "session_id": session_id,
        "updated": total_updated,
        "main_updated": updated_count,
        "sub_updated": sub_updated,
    }
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/sessions/{session_id}/reconstruction/fabric-json")
async def get_fabric_json(session_id: str):
    """Build the Fabric.js canvas payload for a session's cell grid.

    The main session's cells are combined with the cells of every
    sub-session. Sub-session cells get a ``box{index}_`` prefix on their
    ``cell_id``, a ``source`` tag of ``box_{index}``, and their ``bbox_px``
    origin shifted by the x/y of the matching box zone.

    Raises:
        HTTPException: 404 if the session does not exist, 400 if it has no
            word result.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    # Work on a copy of the list so sub-session cells are not appended to
    # the session's own cell list.
    merged_cells = list(word_result.get("cells", []))

    sub_sessions = await get_sub_sessions(session_id)
    if sub_sessions:
        zones = (session.get("column_result") or {}).get("zones") or []
        box_zones = [
            zone for zone in zones
            if zone.get("zone_type") == "box" and zone.get("box")
        ]

        for sub in sub_sessions:
            sub_session = await get_session_db(sub["id"])
            sub_word = sub_session.get("word_result") if sub_session else None
            if not sub_word or not sub_word.get("cells"):
                continue

            box_index = sub.get("box_index", 0)
            # Without a matching box zone the cells stay at their own coordinates.
            offset_y, offset_x = 0, 0
            if box_index < len(box_zones):
                box = box_zones[box_index]["box"]
                offset_y, offset_x = box["y"], box["x"]

            for sub_cell in sub_word["cells"]:
                placed = dict(sub_cell)
                placed["cell_id"] = f"box{box_index}_{placed.get('cell_id', '')}"
                placed["source"] = f"box_{box_index}"
                bbox = placed.get("bbox_px", {})
                if bbox:
                    # Copy before shifting so the sub-session's bbox is untouched.
                    shifted = dict(bbox)
                    shifted["x"] = shifted.get("x", 0) + offset_x
                    shifted["y"] = shifted.get("y", 0) + offset_y
                    placed["bbox_px"] = shifted
                merged_cells.append(placed)

    from services.layout_reconstruction_service import cells_to_fabric_json

    return cells_to_fabric_json(
        merged_cells,
        word_result.get("image_width", 800),
        word_result.get("image_height", 600),
    )
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Vocab entries merged + PDF/DOCX export
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@router.get("/sessions/{session_id}/vocab-entries/merged")
async def get_merged_vocab_entries(session_id: str):
    """Return the vocab entries of the main session and all its sub-sessions.

    Main-session entries are tagged ``source="main"`` (unless they already
    carry a source). Sub-session entries are tagged ``source="box_{index}"``
    and receive ``source_y``, the y coordinate of the matching box zone
    (0 when no zone matches). The merged list is ordered by a key built
    from ``row_index`` for main entries and from ``source_y`` plus
    ``row_index`` for sub-session entries.

    Returns:
        A dict with ``session_id``, the merged ``entries``, their ``total``
        count and the distinct ``sources`` in a stable order.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result") or {}
    stored_entries = word_result.get("vocab_entries") or word_result.get("entries") or []

    # Tag copies rather than the stored dicts, so building the response
    # never writes a "source" key back into the session data.
    entries = []
    for stored in stored_entries:
        tagged = dict(stored)
        tagged.setdefault("source", "main")
        entries.append(tagged)

    subs = await get_sub_sessions(session_id)
    if subs:
        column_result = session.get("column_result") or {}
        zones = column_result.get("zones") or []
        box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]

        for sub in subs:
            sub_session = await get_session_db(sub["id"])
            if not sub_session:
                continue
            sub_word = sub_session.get("word_result") or {}
            sub_entries = sub_word.get("vocab_entries") or sub_word.get("entries") or []

            bi = sub.get("box_index", 0)
            box_y = 0
            # A missing/null or out-of-range index must not crash the merge;
            # such entries simply keep a vertical offset of 0.
            if isinstance(bi, int) and 0 <= bi < len(box_zones):
                box_y = box_zones[bi]["box"].get("y", 0)

            for e in sub_entries:
                e_copy = dict(e)
                e_copy["source"] = f"box_{bi}"
                e_copy["source_y"] = box_y
                entries.append(e_copy)

    def _sort_key(e):
        if e.get("source", "main") == "main":
            return e.get("row_index", 0) * 100
        return e.get("source_y", 0) * 100 + e.get("row_index", 0)

    entries.sort(key=_sort_key)

    return {
        "session_id": session_id,
        "entries": entries,
        "total": len(entries),
        # Sorted so the response is stable across requests and processes
        # (plain set iteration order is not); key=str tolerates odd values.
        "sources": sorted({e.get("source", "main") for e in entries}, key=str),
    }
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/sessions/{session_id}/reconstruction/export/pdf")
async def export_reconstruction_pdf(session_id: str):
    """Export the reconstructed cell grid as a PDF table.

    The first table row is a header taken from ``columns_used`` (label,
    falling back to type, then ``Col {i}``). The remaining rows are filled
    from the cells whose ids follow the ``R{row:02d}_C{col}`` pattern, for
    the row/column counts stored in ``grid_shape``.

    Raises:
        HTTPException: 404 if the session does not exist, 400 if there is no
            word result or nothing to export, 501 if reportlab is missing.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    cells = word_result.get("cells", [])
    columns_used = word_result.get("columns_used", [])
    grid_shape = word_result.get("grid_shape", {})
    n_rows = grid_shape.get("rows", 0)
    n_cols = grid_shape.get("cols", 0)

    # Index cells once instead of scanning the whole list for every grid
    # position. setdefault keeps the first cell for a duplicated id, which
    # is what the former linear search returned. Only string ids can match
    # the generated "Rxx_Cy" keys.
    cell_by_id = {}
    for candidate in cells:
        candidate_id = candidate.get("cell_id")
        if isinstance(candidate_id, str):
            cell_by_id.setdefault(candidate_id, candidate)

    # Build table data: rows x columns
    table_data: list[list[str]] = []
    header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
    if not header:
        header = [f"Col {i}" for i in range(n_cols)]
    table_data.append(header)

    for r in range(n_rows):
        row_texts = []
        for ci in range(n_cols):
            cell = cell_by_id.get(f"R{r:02d}_C{ci}")
            row_texts.append(cell.get("text", "") if cell else "")
        table_data.append(row_texts)

    # Keep the try body to the imports: only a missing package should be
    # reported as 501, not an ImportError raised while rendering.
    try:
        from reportlab.lib.pagesizes import A4
        from reportlab.lib import colors
        from reportlab.platypus import SimpleDocTemplate, Table, TableStyle
    except ImportError as err:
        raise HTTPException(status_code=501, detail="reportlab not installed") from err

    import io as _io

    buf = _io.BytesIO()
    doc = SimpleDocTemplate(buf, pagesize=A4)
    if not table_data or not table_data[0]:
        raise HTTPException(status_code=400, detail="No data to export")

    t = Table(table_data)
    t.setStyle(TableStyle([
        ('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#0d9488')),
        ('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
        ('FONTSIZE', (0, 0), (-1, -1), 9),
        ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
        ('VALIGN', (0, 0), (-1, -1), 'TOP'),
        ('WORDWRAP', (0, 0), (-1, -1), True),
    ]))
    doc.build([t])
    buf.seek(0)

    return StreamingResponse(
        buf,
        media_type="application/pdf",
        headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.pdf"'},
    )
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/sessions/{session_id}/reconstruction/export/docx")
async def export_reconstruction_docx(session_id: str):
    """Export the reconstructed cell grid as a DOCX table.

    The document holds a heading with the first 8 characters of the session
    id and one table: a header row taken from ``columns_used`` (label,
    falling back to type, then ``Col {i}``) followed by one row per grid
    row, filled from the cells whose ids follow ``R{row:02d}_C{col}``.

    Raises:
        HTTPException: 404 if the session does not exist, 400 if there is no
            word result, 501 if python-docx is missing.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    cells = word_result.get("cells", [])
    columns_used = word_result.get("columns_used", [])
    grid_shape = word_result.get("grid_shape", {})
    n_rows = grid_shape.get("rows", 0)
    n_cols = grid_shape.get("cols", 0)

    # Keep the try body to the import: only a missing package should be
    # reported as 501, not an ImportError raised while building the file.
    try:
        from docx import Document
    except ImportError as err:
        raise HTTPException(status_code=501, detail="python-docx not installed") from err

    import io as _io

    doc = Document()
    doc.add_heading(f'Rekonstruktion -- Session {session_id[:8]}', level=1)

    header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
    if not header:
        header = [f"Col {i}" for i in range(n_cols)]

    # Size the table for the header as well: columns_used can list more
    # columns than grid_shape reports (or grid_shape can be missing), and
    # writing header cell `ci` would otherwise raise IndexError.
    table = doc.add_table(rows=1 + n_rows, cols=max(n_cols, len(header), 1))
    table.style = 'Table Grid'

    for ci, h in enumerate(header):
        table.rows[0].cells[ci].text = h

    # Index cells once instead of scanning the whole list for every grid
    # position. setdefault keeps the first cell for a duplicated id, which
    # is what the former linear search returned.
    cell_by_id = {}
    for candidate in cells:
        candidate_id = candidate.get("cell_id")
        if isinstance(candidate_id, str):
            cell_by_id.setdefault(candidate_id, candidate)

    for r in range(n_rows):
        for ci in range(n_cols):
            cell = cell_by_id.get(f"R{r:02d}_C{ci}")
            # `or ""` also covers a cell whose stored text is null.
            table.rows[r + 1].cells[ci].text = (cell.get("text") or "") if cell else ""

    buf = _io.BytesIO()
    doc.save(buf)
    buf.seek(0)

    return StreamingResponse(
        buf,
        media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.docx"'},
    )
|
|
||||||
|
|||||||
@@ -1,22 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/regression.py
|
||||||
OCR Pipeline Regression Tests — barrel re-export.
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
All implementation split into:
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.regression")
|
||||||
ocr_pipeline_regression_helpers — DB persistence, snapshot, comparison
|
|
||||||
ocr_pipeline_regression_endpoints — FastAPI routes
|
|
||||||
|
|
||||||
Lizenz: Apache 2.0
|
|
||||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Helpers (used by grid_editor_api_grid.py)
|
|
||||||
from ocr_pipeline_regression_helpers import ( # noqa: F401
|
|
||||||
_init_regression_table,
|
|
||||||
_persist_regression_run,
|
|
||||||
_extract_cells_for_comparison,
|
|
||||||
_build_reference_snapshot,
|
|
||||||
compare_grids,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Endpoints (router used by ocr_pipeline_api.py)
|
|
||||||
from ocr_pipeline_regression_endpoints import router # noqa: F401
|
|
||||||
|
|||||||
@@ -1,421 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/regression_endpoints.py
|
||||||
OCR Pipeline Regression Endpoints — FastAPI routes for ground truth and regression.
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Extracted from ocr_pipeline_regression.py for modularity.
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.regression_endpoints")
|
||||||
|
|
||||||
Lizenz: Apache 2.0
|
|
||||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Query
|
|
||||||
|
|
||||||
from grid_editor_api import _build_grid_core
|
|
||||||
from ocr_pipeline_session_store import (
|
|
||||||
get_session_db,
|
|
||||||
list_ground_truth_sessions_db,
|
|
||||||
update_session_db,
|
|
||||||
)
|
|
||||||
from ocr_pipeline_regression_helpers import (
|
|
||||||
_build_reference_snapshot,
|
|
||||||
_init_regression_table,
|
|
||||||
_persist_regression_run,
|
|
||||||
compare_grids,
|
|
||||||
get_pool,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["regression"])
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Endpoints
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/mark-ground-truth")
async def mark_ground_truth(
    session_id: str,
    pipeline: Optional[str] = Query(None, description="Pipeline used: kombi, pipeline, paddle-direct"),
):
    """Store the session's current build-grid result as its ground truth.

    A reference snapshot of ``grid_editor_result`` is written to
    ``ground_truth["build_grid_reference"]`` and the session is moved to
    step 11. When no ``pipeline`` is supplied it is derived from the
    ``ocr_engine`` recorded in the word result. If an ``auto_grid_snapshot``
    exists, the response also carries the diff between it and the new
    reference.

    Raises:
        HTTPException: 404 if the session does not exist, 400 if it has no
            grid result with zones.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    grid_result = session.get("grid_editor_result")
    if not (grid_result and grid_result.get("zones")):
        raise HTTPException(
            status_code=400,
            detail="No grid_editor_result found. Run build-grid first.",
        )

    if not pipeline:
        # Fall back to the OCR engine recorded in the word result.
        engine = (session.get("word_result") or {}).get("ocr_engine", "")
        if engine in ("kombi", "rapid_kombi"):
            pipeline = "kombi"
        else:
            pipeline = "paddle-direct" if engine == "paddle_direct" else "pipeline"

    reference = _build_reference_snapshot(grid_result, pipeline=pipeline)

    # Add the reference to whatever ground-truth data is already stored.
    ground_truth = session.get("ground_truth") or {}
    ground_truth["build_grid_reference"] = reference
    await update_session_db(session_id, ground_truth=ground_truth, current_step=11)

    # The diff against the automatic snapshot shows what the user corrected.
    auto_snapshot = ground_truth.get("auto_grid_snapshot")
    correction_diff = compare_grids(auto_snapshot, reference) if auto_snapshot else None

    saved_cells = len(reference["cells"])
    logger.info(
        "Ground truth marked for session %s: %d cells (corrections: %s)",
        session_id,
        saved_cells,
        correction_diff["summary"] if correction_diff else "no auto-snapshot",
    )

    return {
        "status": "ok",
        "session_id": session_id,
        "cells_saved": saved_cells,
        "summary": reference["summary"],
        "correction_diff": correction_diff,
    }
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/sessions/{session_id}/mark-ground-truth")
async def unmark_ground_truth(session_id: str):
    """Delete the stored ground-truth reference of a session.

    Other keys inside the session's ``ground_truth`` data are kept.

    Raises:
        HTTPException: 404 if the session does not exist or has no
            ``build_grid_reference``.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    if "build_grid_reference" not in ground_truth:
        raise HTTPException(status_code=404, detail="No ground truth reference found")

    ground_truth.pop("build_grid_reference")
    await update_session_db(session_id, ground_truth=ground_truth)

    logger.info("Ground truth removed for session %s", session_id)
    return {"status": "ok", "session_id": session_id}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/sessions/{session_id}/correction-diff")
async def get_correction_diff(session_id: str):
    """Compare automatic OCR grid with manually corrected ground truth.

    Returns the diff produced by ``compare_grids`` for the stored
    ``auto_grid_snapshot`` and ``build_grid_reference``, extended with a
    ``col_type_breakdown`` mapping: per col_type the number of reference
    cells (``total``), the number of ``text_change`` diffs (``corrected``)
    and the resulting ``accuracy_pct``.

    Raises:
        HTTPException: 404 if the session, the auto snapshot or the ground
            truth reference is missing.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    gt = session.get("ground_truth") or {}
    auto_snapshot = gt.get("auto_grid_snapshot")
    reference = gt.get("build_grid_reference")

    if not auto_snapshot:
        raise HTTPException(
            status_code=404,
            detail="No auto_grid_snapshot found. Re-run build-grid to create one.",
        )
    if not reference:
        raise HTTPException(
            status_code=404,
            detail="No ground truth reference found. Mark as ground truth first.",
        )

    diff = compare_grids(auto_snapshot, reference)

    reference_cells = reference.get("cells", [])
    # Build the id -> cell lookup once instead of rescanning the reference
    # list for every diff entry. setdefault keeps the first cell for a
    # duplicated id, matching the former first-match search.
    reference_by_id: Dict[Any, Dict[str, Any]] = {}
    for ref in reference_cells:
        reference_by_id.setdefault(ref.get("cell_id"), ref)

    # Enrich with per-col_type breakdown
    col_type_stats: Dict[str, Dict[str, int]] = {}
    for cell_diff in diff.get("cell_diffs", []):
        if cell_diff["type"] != "text_change":
            continue
        ref_cell = reference_by_id.get(cell_diff["cell_id"])
        ct = ref_cell.get("col_type", "unknown") if ref_cell else "unknown"
        col_type_stats.setdefault(ct, {"total": 0, "corrected": 0})["corrected"] += 1

    # Count total cells per col_type from reference
    for cell in reference_cells:
        ct = cell.get("col_type", "unknown")
        col_type_stats.setdefault(ct, {"total": 0, "corrected": 0})["total"] += 1

    # Share of cells per col_type that needed no text correction; a type
    # without any reference cells counts as fully accurate.
    for stats in col_type_stats.values():
        total = stats["total"]
        corrected = stats["corrected"]
        stats["accuracy_pct"] = round((total - corrected) / total * 100, 1) if total > 0 else 100.0

    diff["col_type_breakdown"] = col_type_stats

    return diff
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/ground-truth-sessions")
async def list_ground_truth_sessions():
    """Return an overview of every session that carries a ground-truth reference.

    Each item exposes the session's id, name, filename and document category
    together with the pipeline, save time and summary of its stored
    ``build_grid_reference``.
    """
    sessions = await list_ground_truth_sessions_db()

    def _overview(row):
        reference = (row.get("ground_truth") or {}).get("build_grid_reference", {})
        return {
            "session_id": row["id"],
            "name": row.get("name", ""),
            "filename": row.get("filename", ""),
            "document_category": row.get("document_category"),
            "pipeline": reference.get("pipeline"),
            "saved_at": reference.get("saved_at"),
            "summary": reference.get("summary", {}),
        }

    overview = [_overview(row) for row in sessions]
    return {"sessions": overview, "count": len(overview)}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/regression/run")
async def run_single_regression(session_id: str):
    """Rebuild the grid of one session and diff it against its ground truth.

    The freshly computed grid is only compared, not stored by this
    endpoint. A failure while rebuilding is reported in the response body
    (``status="error"``) instead of as an HTTP error.

    Raises:
        HTTPException: 404 if the session does not exist, 400 if it has no
            ground truth reference.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    reference = (session.get("ground_truth") or {}).get("build_grid_reference")
    if not reference:
        raise HTTPException(
            status_code=400,
            detail="No ground truth reference found for this session",
        )

    session_name = session.get("name", "")

    # Recompute the grid; the result is only used for the comparison below.
    try:
        rebuilt = await _build_grid_core(session_id, session)
    except Exception as e:
        return {
            "session_id": session_id,
            "name": session_name,
            "status": "error",
            "error": str(e),
        }

    current = _build_reference_snapshot(rebuilt)
    diff = compare_grids(reference, current)

    diff_summary = diff["summary"]
    structural = diff_summary["structural_changes"]
    # Every other summary counter is treated as a cell-level difference.
    cell_level = sum(
        count for key, count in diff_summary.items() if key != "structural_changes"
    )
    logger.info(
        "Regression test session %s: %s (%d structural, %d cell diffs)",
        session_id, diff["status"],
        structural,
        cell_level,
    )

    return {
        "session_id": session_id,
        "name": session_name,
        "status": diff["status"],
        "diff": diff,
        "reference_summary": reference.get("summary", {}),
        "current_summary": current.get("summary", {}),
    }
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/regression/run")
async def run_all_regressions(
    triggered_by: str = Query("manual", description="Who triggered: manual, script, ci"),
):
    """Rebuild the grid of every ground-truth session and compare the results.

    Each session yields one result entry with status ``pass``, ``fail`` or
    ``error``; full structural and cell diffs are attached to failures only.
    The overall status is ``pass`` only when nothing failed or errored. The
    run is handed to ``_persist_regression_run`` and the returned id is
    included in the response.
    """
    started = time.monotonic()
    gt_sessions = await list_ground_truth_sessions_db()

    if not gt_sessions:
        return {
            "status": "pass",
            "message": "No ground truth sessions found",
            "results": [],
            "summary": {"total": 0, "passed": 0, "failed": 0, "errors": 0},
        }

    def _error_entry(sid, name, message):
        return {
            "session_id": sid,
            "name": name,
            "status": "error",
            "error": message,
        }

    results = []
    tally = {"passed": 0, "failed": 0, "errors": 0}

    for listed in gt_sessions:
        session_id = listed["id"]
        reference = (listed.get("ground_truth") or {}).get("build_grid_reference")
        if not reference:
            continue
        name = listed.get("name", "")

        # The list query may omit JSONB fields, so fetch the full session.
        full_session = await get_session_db(session_id)
        if not full_session:
            results.append(_error_entry(session_id, name, "Session not found during re-load"))
            tally["errors"] += 1
            continue

        try:
            rebuilt = await _build_grid_core(session_id, full_session)
        except Exception as e:
            results.append(_error_entry(session_id, name, str(e)))
            tally["errors"] += 1
            continue

        current = _build_reference_snapshot(rebuilt)
        diff = compare_grids(reference, current)

        entry = {
            "session_id": session_id,
            "name": name,
            "status": diff["status"],
            "diff_summary": diff["summary"],
            "reference_summary": reference.get("summary", {}),
            "current_summary": current.get("summary", {}),
        }

        if diff["status"] == "fail":
            # Full diffs are attached to failures only, keeping the response compact.
            entry["structural_diffs"] = diff["structural_diffs"]
            entry["cell_diffs"] = diff["cell_diffs"]
            tally["failed"] += 1
        else:
            tally["passed"] += 1

        results.append(entry)

    clean = tally["failed"] == 0 and tally["errors"] == 0
    overall = "pass" if clean else "fail"
    duration_ms = int((time.monotonic() - started) * 1000)

    summary = {"total": len(results), **tally}

    logger.info(
        "Regression suite: %s — %d passed, %d failed, %d errors (of %d) in %dms",
        overall, tally["passed"], tally["failed"], tally["errors"], len(results), duration_ms,
    )

    # Persist to DB
    run_id = await _persist_regression_run(
        status=overall,
        summary=summary,
        results=results,
        duration_ms=duration_ms,
        triggered_by=triggered_by,
    )

    return {
        "status": overall,
        "run_id": run_id,
        "duration_ms": duration_ms,
        "results": results,
        "summary": summary,
    }
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/regression/history")
async def get_regression_history(
    limit: int = Query(20, ge=1, le=100),
):
    """Return the most recent regression runs, newest first.

    Any failure while reading the history is logged and reported inside the
    response (empty ``runs`` plus an ``error`` message) instead of raising.
    """
    plain_fields = (
        "status", "total", "passed", "failed", "errors",
        "duration_ms", "triggered_by",
    )
    try:
        await _init_regression_table()
        pool = await get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT id, run_at, status, total, passed, failed, errors,
                       duration_ms, triggered_by
                FROM regression_runs
                ORDER BY run_at DESC
                LIMIT $1
                """,
                limit,
            )

        runs = []
        for row in rows:
            run_at = row["run_at"]
            # id and run_at are converted to strings; the rest is passed through.
            run = {
                "id": str(row["id"]),
                "run_at": run_at.isoformat() if run_at else None,
            }
            for field in plain_fields:
                run[field] = row[field]
            runs.append(run)

        return {"runs": runs, "count": len(rows)}
    except Exception as e:
        logger.warning("Failed to fetch regression history: %s", e)
        return {"runs": [], "count": 0, "error": str(e)}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/regression/history/{run_id}")
async def get_regression_run_detail(run_id: str):
    """Return one stored regression run including its per-session results.

    Raises:
        HTTPException: 404 if no run with ``run_id`` exists, 500 (carrying
            the error text) for any other failure.
    """
    plain_fields = (
        "status", "total", "passed", "failed", "errors",
        "duration_ms", "triggered_by",
    )
    try:
        await _init_regression_table()
        pool = await get_pool()
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                "SELECT * FROM regression_runs WHERE id = $1",
                run_id,
            )

        if not row:
            raise HTTPException(status_code=404, detail="Run not found")

        run_at = row["run_at"]
        detail = {
            "id": str(row["id"]),
            "run_at": run_at.isoformat() if run_at else None,
        }
        for field in plain_fields:
            detail[field] = row[field]
        # The stored results value is decoded from its JSON text.
        stored_results = row["results"]
        detail["results"] = json.loads(stored_results) if stored_results else []
        return detail
    except HTTPException:
        # Let the 404 above through untouched.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
||||||
|
|||||||
@@ -1,207 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/regression_helpers.py
|
||||||
OCR Pipeline Regression Helpers — DB persistence, snapshot building, comparison.
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Extracted from ocr_pipeline_regression.py for modularity.
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.regression_helpers")
|
||||||
|
|
||||||
Lizenz: Apache 2.0
|
|
||||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
from ocr_pipeline_session_store import get_pool
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# DB persistence for regression runs
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
async def _init_regression_table():
    """Ensure regression_runs table exists (idempotent).

    Executes ``migrations/008_regression_runs.sql``, resolved relative to
    this module's directory. If that file is missing, nothing is executed
    and a warning is logged.
    """
    migration_path = os.path.join(
        os.path.dirname(__file__),
        "migrations/008_regression_runs.sql",
    )
    if not os.path.exists(migration_path):
        # Make the no-op visible: later queries on the table will otherwise
        # fail with no hint that the migration was never applied.
        logger.warning("Regression migration not found: %s", migration_path)
        return

    # Read the SQL before borrowing a connection so no pooled connection is
    # held during file I/O; the explicit encoding avoids locale-dependent
    # decoding of the migration file.
    with open(migration_path, "r", encoding="utf-8") as f:
        sql = f.read()

    pool = await get_pool()
    async with pool.acquire() as conn:
        await conn.execute(sql)
|
|
||||||
|
|
||||||
|
|
||||||
async def _persist_regression_run(
    status: str,
    summary: dict,
    results: list,
    duration_ms: int,
    triggered_by: str = "manual",
) -> str:
    """Store one regression run and return its newly generated id.

    Persistence is best-effort: any failure is logged as a warning and an
    empty string is returned instead of raising.

    Args:
        status: Overall run status.
        summary: Counters; ``total``, ``passed``, ``failed`` and ``errors``
            are read, each defaulting to 0.
        results: Per-session results, stored as JSON.
        duration_ms: Run duration in milliseconds.
        triggered_by: Origin of the run.
    """
    insert_sql = """
                INSERT INTO regression_runs
                (id, status, total, passed, failed, errors, duration_ms, results, triggered_by)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8::jsonb, $9)
                """
    try:
        await _init_regression_table()
        pool = await get_pool()
        run_id = str(uuid.uuid4())
        counters = [summary.get(key, 0) for key in ("total", "passed", "failed", "errors")]
        async with pool.acquire() as conn:
            await conn.execute(
                insert_sql,
                run_id,
                status,
                *counters,
                duration_ms,
                json.dumps(results),
                triggered_by,
            )
        logger.info("Regression run %s persisted: %s", run_id, status)
        return run_id
    except Exception as e:
        logger.warning("Failed to persist regression run: %s", e)
        return ""
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _extract_cells_for_comparison(grid_result: dict) -> List[Dict[str, Any]]:
|
|
||||||
"""Extract a flat list of cells from a grid_editor_result for comparison.
|
|
||||||
|
|
||||||
Only keeps fields relevant for comparison: cell_id, row_index, col_index,
|
|
||||||
col_type, text. Ignores confidence, bbox, word_boxes, duration, is_bold.
|
|
||||||
"""
|
|
||||||
cells = []
|
|
||||||
for zone in grid_result.get("zones", []):
|
|
||||||
for cell in zone.get("cells", []):
|
|
||||||
cells.append({
|
|
||||||
"cell_id": cell.get("cell_id", ""),
|
|
||||||
"row_index": cell.get("row_index"),
|
|
||||||
"col_index": cell.get("col_index"),
|
|
||||||
"col_type": cell.get("col_type", ""),
|
|
||||||
"text": cell.get("text", ""),
|
|
||||||
})
|
|
||||||
return cells
|
|
||||||
|
|
||||||
|
|
||||||
def _build_reference_snapshot(
    grid_result: dict,
    pipeline: Optional[str] = None,
) -> dict:
    """Create a reference snapshot of a grid_editor_result.

    The snapshot records the current UTC time, a format version, the given
    pipeline name, summary counts (zones, columns, rows, cells) and the
    comparison cells produced by ``_extract_cells_for_comparison``.
    """
    cells = _extract_cells_for_comparison(grid_result)
    zones = grid_result.get("zones", [])

    # Columns and rows are counted across all zones.
    column_count = 0
    row_count = 0
    for zone in zones:
        column_count += len(zone.get("columns", []))
        row_count += len(zone.get("rows", []))

    summary = {
        "total_zones": len(zones),
        "total_columns": column_count,
        "total_rows": row_count,
        "total_cells": len(cells),
    }
    return {
        "saved_at": datetime.now(timezone.utc).isoformat(),
        "version": 1,
        "pipeline": pipeline,
        "summary": summary,
        "cells": cells,
    }
|
|
||||||
|
|
||||||
|
|
||||||
def compare_grids(reference: dict, current: dict) -> dict:
    """Diff a reference grid snapshot against a freshly computed snapshot.

    The report contains:
      - status: "pass" when no differences were found, otherwise "fail"
      - structural_diffs: summary counters (zones/columns/rows/cells) that differ
      - cell_diffs: missing cells first, then added cells, then per-cell
        text and col_type changes for cells present on both sides
      - summary: number of entries per difference category

    Cells are matched by their "cell_id".
    """
    expected_summary = reference.get("summary", {})
    actual_summary = current.get("summary", {})

    structural_diffs = [
        {
            "field": counter,
            "reference": expected_summary.get(counter, 0),
            "current": actual_summary.get(counter, 0),
        }
        for counter in ("total_zones", "total_columns", "total_rows", "total_cells")
        if expected_summary.get(counter, 0) != actual_summary.get(counter, 0)
    ]

    # Index both sides by cell_id for direct lookups.
    expected = {cell["cell_id"]: cell for cell in reference.get("cells", [])}
    actual = {cell["cell_id"]: cell for cell in current.get("cells", [])}

    cell_diffs: List[Dict[str, Any]] = []

    # Present in the reference but gone from the current result.
    cell_diffs.extend(
        {
            "type": "cell_missing",
            "cell_id": cid,
            "reference_text": before.get("text", ""),
        }
        for cid, before in expected.items()
        if cid not in actual
    )

    # Present in the current result but unknown to the reference.
    cell_diffs.extend(
        {
            "type": "cell_added",
            "cell_id": cid,
            "current_text": after.get("text", ""),
        }
        for cid, after in actual.items()
        if cid not in expected
    )

    # Cells on both sides: compare text and col_type.
    for cid, before in expected.items():
        if cid not in actual:
            continue
        after = actual[cid]
        for field, diff_type in (("text", "text_change"), ("col_type", "col_type_change")):
            old_value = before.get(field, "")
            new_value = after.get(field, "")
            if old_value != new_value:
                cell_diffs.append({
                    "type": diff_type,
                    "cell_id": cid,
                    "reference": old_value,
                    "current": new_value,
                })

    tally = {"cell_missing": 0, "cell_added": 0, "text_change": 0, "col_type_change": 0}
    for diff in cell_diffs:
        tally[diff["type"]] += 1

    status = "fail" if (structural_diffs or cell_diffs) else "pass"

    return {
        "status": status,
        "structural_diffs": structural_diffs,
        "cell_diffs": cell_diffs,
        "summary": {
            "structural_changes": len(structural_diffs),
            "cells_missing": tally["cell_missing"],
            "cells_added": tally["cell_added"],
            "text_changes": tally["text_change"],
            "col_type_changes": tally["col_type_change"],
        },
    }
|
|
||||||
|
|||||||
@@ -1,94 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/reprocess.py
|
||||||
OCR Pipeline Reprocess Endpoint.
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
POST /sessions/{session_id}/reprocess — clear downstream + restart from step.
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.reprocess")
|
||||||
|
|
||||||
Lizenz: Apache 2.0
|
|
||||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any, Dict
|
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Request
|
|
||||||
|
|
||||||
from ocr_pipeline_common import _cache
|
|
||||||
from ocr_pipeline_session_store import get_session_db, update_session_db
|
|
||||||
|
|
||||||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router that collects this module's endpoints; no URL prefix is set here.
router = APIRouter(tags=["ocr-pipeline"])
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/reprocess")
async def reprocess_session(session_id: str, request: Request):
    """Reset a session to an earlier pipeline step and clear downstream results.

    Body: {"from_step": 5}  (integer from 1 to 10; defaults to 1 when omitted)

    Results cleared, as implemented below:
    - from_step <= 8: word_result
    - from_step == 9: only llm_review / llm_corrections inside word_result
    - from_step <= 7: row_result
    - from_step <= 6: column_result
    - from_step <= 4: crop_result
    - from_step <= 3: dewarp_result
    - from_step <= 2: deskew_result
    - from_step <= 1: orientation_result

    NOTE(review): two step numberings exist for this pipeline. One is
    Orientation(1) -> Deskew(2) -> Dewarp(3) -> Crop(4) -> Columns(5) ->
    Rows(6) -> Words(7) -> LLM-Review(8) -> Reconstruction(9) -> Validation(10);
    the other starts at Orient(2) and ends at GT(11). The thresholds above
    appear to follow the second numbering for word/row/column results and the
    first for crop/dewarp/deskew/orientation -- confirm which one is intended.

    Returns:
        Dict with session_id, from_step and the list of cleared field names.

    Raises:
        HTTPException: 404 if the session does not exist; 400 if the body is
            not a JSON object or from_step is not an integer in 1..10.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # A malformed or non-object body is a client error, not a server error.
    try:
        body = await request.json()
    except ValueError:
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from None
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    from_step = body.get("from_step", 1)
    # bool is a subclass of int; reject it so `true` is not silently read as step 1.
    if (
        isinstance(from_step, bool)
        or not isinstance(from_step, int)
        or from_step < 1
        or from_step > 10
    ):
        raise HTTPException(status_code=400, detail="from_step must be between 1 and 10")

    update_kwargs: Dict[str, Any] = {"current_step": from_step}

    # Clear downstream data based on from_step
    # New pipeline order: Orient(2) -> Deskew(3) -> Dewarp(4) -> Crop(5) ->
    # Columns(6) -> Rows(7) -> Words(8) -> LLM(9) -> Recon(10) -> GT(11)
    if from_step <= 8:
        update_kwargs["word_result"] = None
    elif from_step == 9:
        # Only clear LLM review from word_result
        word_result = session.get("word_result")
        if word_result:
            word_result.pop("llm_review", None)
            word_result.pop("llm_corrections", None)
            update_kwargs["word_result"] = word_result

    if from_step <= 7:
        update_kwargs["row_result"] = None
    if from_step <= 6:
        update_kwargs["column_result"] = None
    if from_step <= 4:
        update_kwargs["crop_result"] = None
    if from_step <= 3:
        update_kwargs["dewarp_result"] = None
    if from_step <= 2:
        update_kwargs["deskew_result"] = None
    if from_step <= 1:
        update_kwargs["orientation_result"] = None

    await update_session_db(session_id, **update_kwargs)

    # Mirror the same changes into the in-memory cache entry, if one exists.
    if session_id in _cache:
        for key in list(update_kwargs.keys()):
            if key != "current_step":
                _cache[session_id][key] = update_kwargs[key]
        _cache[session_id]["current_step"] = from_step

    logger.info(f"Session {session_id} reprocessing from step {from_step}")

    return {
        "session_id": session_id,
        "from_step": from_step,
        "cleared": [k for k in update_kwargs if k != "current_step"],
    }
|
|
||||||
|
|||||||
@@ -1,348 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/rows.py
|
||||||
OCR Pipeline - Row Detection Endpoints.
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Extracted from ocr_pipeline_api.py.
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.rows")
|
||||||
Handles row detection (auto + manual) and row ground truth.
|
|
||||||
|
|
||||||
Lizenz: Apache 2.0
|
|
||||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
from fastapi import APIRouter, HTTPException
|
|
||||||
|
|
||||||
from cv_vocab_pipeline import (
|
|
||||||
create_ocr_image,
|
|
||||||
detect_column_geometry,
|
|
||||||
detect_row_geometry,
|
|
||||||
)
|
|
||||||
from ocr_pipeline_common import (
|
|
||||||
_cache,
|
|
||||||
_load_session_to_cache,
|
|
||||||
_get_cached,
|
|
||||||
_append_pipeline_log,
|
|
||||||
ManualRowsRequest,
|
|
||||||
RowGroundTruthRequest,
|
|
||||||
)
|
|
||||||
from ocr_pipeline_session_store import (
|
|
||||||
get_session_db,
|
|
||||||
update_session_db,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router for the row-detection endpoints, mounted under /api/v1/ocr-pipeline.
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Helper: Box-exclusion overlay (used by rows overlay and columns overlay)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _draw_box_exclusion_overlay(
    img: np.ndarray,
    zones: List[Dict],
    *,
    label: str = "BOX — separat verarbeitet",
) -> None:
    """Mark every box zone on ``img`` with a tinted rectangle (in-place).

    Zones whose ``zone_type`` is not "box", or that carry no ``box`` entry,
    are skipped. Each remaining box gets a fill blended at 25 % opacity,
    a 2 px border in the same colour and the ``label`` text near its
    bottom-left corner.
    """
    fill_colour = (0, 0, 200)
    text_colour = (255, 255, 255)

    for zone in zones:
        is_box = zone.get("zone_type") == "box" and zone.get("box")
        if not is_box:
            continue

        rect = zone["box"]
        left, top = rect["x"], rect["y"]
        right = left + rect["width"]
        bottom = top + rect["height"]

        # Paint a solid rectangle on a copy, then blend it back at ~25 %.
        tinted = img.copy()
        cv2.rectangle(tinted, (left, top), (right, bottom), fill_colour, -1)
        cv2.addWeighted(tinted, 0.25, img, 0.75, 0, img)

        # Outline
        cv2.rectangle(img, (left, top), (right, bottom), fill_colour, 2)

        # Caption just inside the lower-left corner
        cv2.putText(img, label, (left + 10, bottom - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, text_colour, 2)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Row Detection Endpoints
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/rows")
async def detect_rows(session_id: str):
    """Run row detection on the cropped (or dewarped) image using horizontal gap analysis.

    Flow:
      1. Load the session into the cache if it is not there yet and pick the
         cropped image, falling back to the dewarped image.
      2. Reuse the cached ``_word_dicts`` / ``_inv`` / ``_content_bounds``
         intermediates, or compute them via ``detect_column_geometry`` and
         store them in the cache entry.
      3. Sub-sessions (sessions with a ``parent_session_id``) that have words
         use ``_build_rows_from_word_grouping``; all other sessions use
         ``detect_row_geometry``, with box zones from ``column_result`` cut
         out of the image beforehand when present.
      4. Persist ``row_result``, reset ``word_result`` to None and set
         ``current_step`` to 7, then append a pipeline-log entry.

    Returns:
        Dict with ``session_id`` plus the fields of ``row_result`` (rows,
        summary, total_rows, duration_seconds).

    Raises:
        HTTPException: 400 if neither a cropped nor a dewarped image is
            cached, or if column geometry detection returns None.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    # Prefer the cropped image; fall back to the dewarped one.
    dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
    if dewarped_bgr is None:
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before row detection")

    t0 = time.time()

    # Try to reuse cached word_dicts and inv from column detection
    word_dicts = cached.get("_word_dicts")
    inv = cached.get("_inv")
    content_bounds = cached.get("_content_bounds")

    if word_dicts is None or inv is None or content_bounds is None:
        # Not cached — run column geometry to get intermediates
        ocr_img = create_ocr_image(dewarped_bgr)
        geo_result = detect_column_geometry(ocr_img, dewarped_bgr)
        if geo_result is None:
            raise HTTPException(status_code=400, detail="Column geometry detection failed — cannot detect rows")
        _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
        cached["_word_dicts"] = word_dicts
        cached["_inv"] = inv
        cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
    else:
        left_x, right_x, top_y, bottom_y = content_bounds

    # Read zones from column_result to exclude box regions
    session = await get_session_db(session_id)
    column_result = (session or {}).get("column_result") or {}
    is_sub_session = bool((session or {}).get("parent_session_id"))

    # Sub-sessions (box crops): use word-grouping instead of gap-based
    # row detection. Box images are small with complex internal layouts
    # (headings, sub-columns) where the horizontal projection approach
    # merges rows. Word-grouping directly clusters words by Y proximity,
    # which is more robust for these cases.
    if is_sub_session and word_dicts:
        from cv_layout import _build_rows_from_word_grouping
        rows = _build_rows_from_word_grouping(
            word_dicts, left_x, right_x, top_y, bottom_y,
            right_x - left_x, bottom_y - top_y,
        )
        logger.info(f"OCR Pipeline: sub-session {session_id}: word-grouping found {len(rows)} rows")
    else:
        zones = column_result.get("zones") or []  # zones can be None for sub-sessions

        # Collect box y-ranges for filtering.
        # Use border_thickness to shrink the exclusion zone: the border pixels
        # belong visually to the box frame, but text rows above/below the box
        # may overlap with the border area and must not be clipped.
        box_ranges = []  # [(y_start, y_end)]
        box_ranges_inner = []  # [(y_start + border, y_end - border)] for row filtering
        for zone in zones:
            if zone.get("zone_type") == "box" and zone.get("box"):
                box = zone["box"]
                bt = max(box.get("border_thickness", 0), 5)  # minimum 5px margin
                box_ranges.append((box["y"], box["y"] + box["height"]))
                # Inner range: shrink by border thickness so boundary rows aren't excluded
                box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt))

        if box_ranges and inv is not None:
            # Combined-image approach: strip box regions from inv image,
            # run row detection on the combined image, then remap y-coords back.
            content_strips = []  # [(y_start, y_end)] in absolute coords
            # Build content strips by subtracting box inner ranges from [top_y, bottom_y].
            # Using inner ranges means the border area is included in the content
            # strips, so the last row above a box isn't clipped by the border.
            sorted_boxes = sorted(box_ranges_inner, key=lambda r: r[0])
            strip_start = top_y
            for by_start, by_end in sorted_boxes:
                if by_start > strip_start:
                    content_strips.append((strip_start, by_start))
                strip_start = max(strip_start, by_end)
            if strip_start < bottom_y:
                content_strips.append((strip_start, bottom_y))

            # Filter to strips with meaningful height
            content_strips = [(ys, ye) for ys, ye in content_strips if ye - ys >= 20]

            if content_strips:
                # Stack content strips vertically
                inv_strips = [inv[ys:ye, :] for ys, ye in content_strips]
                combined_inv = np.vstack(inv_strips)

                # Filter word_dicts to only include words from content strips
                combined_words = []
                cum_y = 0
                strip_offsets = []  # (combined_y_start, strip_height, abs_y_start)
                for ys, ye in content_strips:
                    h = ye - ys
                    strip_offsets.append((cum_y, h, ys))
                    for w in word_dicts:
                        w_abs_y = w['top'] + top_y  # word y is relative to content top
                        w_center = w_abs_y + w['height'] / 2
                        if ys <= w_center < ye:
                            # Remap to combined coordinates
                            w_copy = dict(w)
                            w_copy['top'] = cum_y + (w_abs_y - ys)
                            combined_words.append(w_copy)
                    cum_y += h

                # Run row detection on combined image
                combined_h = combined_inv.shape[0]
                rows = detect_row_geometry(
                    combined_inv, combined_words, left_x, right_x, 0, combined_h,
                )

                # Remap y-coordinates back to absolute page coords
                # NOTE(review): a value equal to the end of a strip
                # (cy == c_start + s_h) is mapped to the start of the *next*
                # strip, so a row ending exactly on a strip boundary gets a
                # height that spans the removed box region — confirm intended.
                def _combined_y_to_abs(cy: int) -> int:
                    for c_start, s_h, abs_start in strip_offsets:
                        if cy < c_start + s_h:
                            return abs_start + (cy - c_start)
                    # Past the last strip: clamp to the last strip's end.
                    last_c, last_h, last_abs = strip_offsets[-1]
                    return last_abs + last_h

                for r in rows:
                    abs_y = _combined_y_to_abs(r.y)
                    abs_y_end = _combined_y_to_abs(r.y + r.height)
                    r.y = abs_y
                    r.height = abs_y_end - abs_y
            else:
                # Every strip was shorter than 20 px: detect on the full image.
                rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
        else:
            # No boxes — standard row detection
            rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)

    duration = time.time() - t0

    # Assign zone_index based on which content zone each row falls in
    # Build content zone list with indices
    zones = column_result.get("zones") or []
    content_zones = [(i, z) for i, z in enumerate(zones) if z.get("zone_type") == "content"] if zones else []

    # Build serializable result (exclude words to keep payload small)
    rows_data = []
    for r in rows:
        # Determine zone_index (stays 0 when no content zone contains the row centre)
        zone_idx = 0
        row_center_y = r.y + r.height / 2
        for zi, zone in content_zones:
            zy = zone["y"]
            zh = zone["height"]
            if zy <= row_center_y < zy + zh:
                zone_idx = zi
                break

        rd = {
            "index": r.index,
            "x": r.x,
            "y": r.y,
            "width": r.width,
            "height": r.height,
            "word_count": r.word_count,
            "row_type": r.row_type,
            "gap_before": r.gap_before,
            "zone_index": zone_idx,
        }
        rows_data.append(rd)

    # Number of rows per row_type.
    type_counts = {}
    for r in rows:
        type_counts[r.row_type] = type_counts.get(r.row_type, 0) + 1

    row_result = {
        "rows": rows_data,
        "summary": type_counts,
        "total_rows": len(rows),
        "duration_seconds": round(duration, 2),
    }

    # Persist to DB — also invalidate word_result since rows changed
    await update_session_db(
        session_id,
        row_result=row_result,
        word_result=None,
        current_step=7,
    )

    cached["row_result"] = row_result
    cached.pop("word_result", None)

    logger.info(f"OCR Pipeline: rows session {session_id}: "
                f"{len(rows)} rows detected ({duration:.2f}s): {type_counts}")

    content_rows = sum(1 for r in rows if r.row_type == "content")
    avg_height = round(sum(r.height for r in rows) / len(rows)) if rows else 0
    await _append_pipeline_log(session_id, "rows", {
        "total_rows": len(rows),
        "content_rows": content_rows,
        "artifact_rows_removed": type_counts.get("header", 0) + type_counts.get("footer", 0),
        "avg_row_height_px": avg_height,
    }, duration_ms=int(duration * 1000))

    return {
        "session_id": session_id,
        **row_result,
    }
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/rows/manual")
async def set_manual_rows(session_id: str, req: ManualRowsRequest):
    """Replace the session's row result with the rows supplied by the caller.

    The new result is tagged ``method="manual"`` with a duration of 0. The
    word result is reset to None in the database and dropped from the cache
    entry, if the session is cached.
    """
    manual_rows = req.rows
    row_result = {
        "rows": manual_rows,
        "total_rows": len(manual_rows),
        "duration_seconds": 0,
        "method": "manual",
    }

    await update_session_db(session_id, row_result=row_result, word_result=None)

    if session_id in _cache:
        entry = _cache[session_id]
        entry["row_result"] = row_result
        entry.pop("word_result", None)

    logger.info(f"OCR Pipeline: manual rows session {session_id}: "
                f"{len(req.rows)} rows set")

    response = {"session_id": session_id}
    response.update(row_result)
    return response
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/ground-truth/rows")
async def save_row_ground_truth(session_id: str, req: RowGroundTruthRequest):
    """Save ground truth feedback for the row detection step.

    The feedback is stored under the "rows" key of the session's
    ``ground_truth`` dict, together with the current ``row_result`` and a
    timezone-aware UTC timestamp. The cache entry is updated when the
    session is cached.

    Returns:
        Dict with ``session_id`` and the stored ground-truth entry.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    # Local import: only `datetime` is imported at module level.
    from datetime import timezone

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    gt = {
        "is_correct": req.is_correct,
        "corrected_rows": req.corrected_rows,
        "notes": req.notes,
        # datetime.utcnow() is deprecated and returns a naive value; store an
        # explicit UTC offset so consumers do not read it as local time.
        "saved_at": datetime.now(timezone.utc).isoformat(),
        "row_result": session.get("row_result"),
    }
    ground_truth["rows"] = gt

    await update_session_db(session_id, ground_truth=ground_truth)

    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    return {"session_id": session_id, "ground_truth": gt}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/sessions/{session_id}/ground-truth/rows")
async def get_row_ground_truth(session_id: str):
    """Return the stored row ground truth next to the session's row result.

    Responds with 404 when the session is unknown or when no row ground
    truth has been saved for it.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    rows_gt = (session.get("ground_truth") or {}).get("rows")
    if not rows_gt:
        raise HTTPException(status_code=404, detail="No row ground truth saved")

    response = {"session_id": session_id}
    response["rows_gt"] = rows_gt
    response["rows_auto"] = session.get("row_result")
    return response
|
|
||||||
|
|||||||
@@ -1,388 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/session_store.py
|
||||||
OCR Pipeline Session Store - PostgreSQL persistence for OCR pipeline sessions.
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Replaces in-memory storage with database persistence.
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.session_store")
|
||||||
See migrations/002_ocr_pipeline_sessions.sql for schema.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import uuid
|
|
||||||
import logging
|
|
||||||
import json
|
|
||||||
from typing import Optional, List, Dict, Any
|
|
||||||
|
|
||||||
import asyncpg
|
|
||||||
|
|
||||||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Database configuration (same as vocab_session_store)
# Read from the DATABASE_URL environment variable, with a built-in fallback.
# NOTE(review): the fallback URL embeds a username and password — confirm it
# is only meant for local development setups.
DATABASE_URL = os.getenv(
    "DATABASE_URL",
    "postgresql://breakpilot:breakpilot@postgres:5432/breakpilot_db"
)

# Connection pool (initialized lazily)
_pool: Optional[asyncpg.Pool] = None
|
|
||||||
|
|
||||||
|
|
||||||
async def get_pool() -> asyncpg.Pool:
    """Return the shared connection pool, creating it on first use."""
    global _pool
    if _pool is not None:
        return _pool
    _pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
    return _pool
|
|
||||||
|
|
||||||
|
|
||||||
async def init_ocr_pipeline_tables():
    """Initialize OCR pipeline tables if they don't exist.

    Steps:
      1. If ``ocr_pipeline_sessions`` is absent from information_schema,
         execute ``migrations/002_ocr_pipeline_sessions.sql`` (resolved
         relative to this module); a missing file is logged as a warning.
      2. Add later-introduced columns with ``ADD COLUMN IF NOT EXISTS``.
      3. Create the partial index on ``document_group_id`` if missing.

    NOTE(review): when the table is absent and the migration file is
    missing, step 2 still runs against the non-existent table — confirm that
    failing at this point is the intended behaviour.
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        tables_exist = await conn.fetchval("""
            SELECT EXISTS (
                SELECT FROM information_schema.tables
                WHERE table_name = 'ocr_pipeline_sessions'
            )
        """)

        if not tables_exist:
            logger.info("Creating OCR pipeline tables...")
            migration_path = os.path.join(
                os.path.dirname(__file__),
                "migrations/002_ocr_pipeline_sessions.sql"
            )
            if os.path.exists(migration_path):
                # Explicit encoding: the default depends on the host locale.
                with open(migration_path, "r", encoding="utf-8") as f:
                    sql = f.read()
                await conn.execute(sql)
                logger.info("OCR pipeline tables created successfully")
            else:
                logger.warning(f"Migration file not found: {migration_path}")
        else:
            logger.debug("OCR pipeline tables already exist")

        # Ensure new columns exist (idempotent ALTER TABLE)
        await conn.execute("""
            ALTER TABLE ocr_pipeline_sessions
            ADD COLUMN IF NOT EXISTS clean_png BYTEA,
            ADD COLUMN IF NOT EXISTS handwriting_removal_meta JSONB,
            ADD COLUMN IF NOT EXISTS doc_type VARCHAR(50),
            ADD COLUMN IF NOT EXISTS doc_type_result JSONB,
            ADD COLUMN IF NOT EXISTS document_category VARCHAR(50),
            ADD COLUMN IF NOT EXISTS pipeline_log JSONB,
            ADD COLUMN IF NOT EXISTS oriented_png BYTEA,
            ADD COLUMN IF NOT EXISTS cropped_png BYTEA,
            ADD COLUMN IF NOT EXISTS orientation_result JSONB,
            ADD COLUMN IF NOT EXISTS crop_result JSONB,
            ADD COLUMN IF NOT EXISTS parent_session_id UUID REFERENCES ocr_pipeline_sessions(id) ON DELETE CASCADE,
            ADD COLUMN IF NOT EXISTS box_index INT,
            ADD COLUMN IF NOT EXISTS grid_editor_result JSONB,
            ADD COLUMN IF NOT EXISTS structure_result JSONB,
            ADD COLUMN IF NOT EXISTS document_group_id UUID,
            ADD COLUMN IF NOT EXISTS page_number INT
        """)

        # Index for document group lookups
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_ocr_sessions_document_group
            ON ocr_pipeline_sessions (document_group_id)
            WHERE document_group_id IS NOT NULL
        """)
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# SESSION CRUD
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
async def create_session_db(
    session_id: str,
    name: str,
    filename: str,
    original_png: bytes,
    parent_session_id: Optional[str] = None,
    box_index: Optional[int] = None,
    document_group_id: Optional[str] = None,
    page_number: Optional[int] = None,
) -> Dict[str, Any]:
    """Insert a new OCR pipeline session row and return it as a dict.

    The row starts with status 'active' and current_step 1.

    Args:
        session_id: UUID string used as the primary key.
        name: Session name.
        filename: Name of the uploaded file.
        original_png: Image bytes stored in the original_png column.
        parent_session_id: If set, this is a sub-session for a box region.
        box_index: 0-based index of the box this sub-session represents.
        document_group_id: Groups multi-page uploads into one document.
        page_number: 1-based page index within the document group.
    """
    pool = await get_pool()

    # Optional UUID strings are converted up front; empty/None stays None.
    parent_uuid = None
    if parent_session_id:
        parent_uuid = uuid.UUID(parent_session_id)
    group_uuid = None
    if document_group_id:
        group_uuid = uuid.UUID(document_group_id)

    insert_sql = """
            INSERT INTO ocr_pipeline_sessions (
                id, name, filename, original_png, status, current_step,
                parent_session_id, box_index, document_group_id, page_number
            ) VALUES ($1, $2, $3, $4, 'active', 1, $5, $6, $7, $8)
            RETURNING id, name, filename, status, current_step,
                      orientation_result, crop_result,
                      deskew_result, dewarp_result, column_result, row_result,
                      word_result, ground_truth, auto_shear_degrees,
                      doc_type, doc_type_result,
                      document_category, pipeline_log,
                      grid_editor_result, structure_result,
                      parent_session_id, box_index,
                      document_group_id, page_number,
                      created_at, updated_at
        """

    async with pool.acquire() as conn:
        row = await conn.fetchrow(
            insert_sql,
            uuid.UUID(session_id),
            name,
            filename,
            original_png,
            parent_uuid,
            box_index,
            group_uuid,
            page_number,
        )

    return _row_to_dict(row)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_session_db(session_id: str) -> Optional[Dict[str, Any]]:
    """Fetch one session's metadata and result columns (no image columns).

    Returns the row converted by ``_row_to_dict``, or None when no row
    with that id exists.
    """
    select_sql = """
            SELECT id, name, filename, status, current_step,
                   orientation_result, crop_result,
                   deskew_result, dewarp_result, column_result, row_result,
                   word_result, ground_truth, auto_shear_degrees,
                   doc_type, doc_type_result,
                   document_category, pipeline_log,
                   grid_editor_result, structure_result,
                   parent_session_id, box_index,
                   document_group_id, page_number,
                   created_at, updated_at
            FROM ocr_pipeline_sessions WHERE id = $1
        """
    pool = await get_pool()
    async with pool.acquire() as conn:
        row = await conn.fetchrow(select_sql, uuid.UUID(session_id))

    return _row_to_dict(row) if row else None
|
|
||||||
|
|
||||||
|
|
||||||
async def get_session_image(session_id: str, image_type: str) -> Optional[bytes]:
    """Fetch one image column (BYTEA) of a session.

    ``image_type`` is translated to a column name through a fixed table;
    unknown types return None without touching the database.
    """
    column_map = {
        "original": "original_png",
        "oriented": "oriented_png",
        "cropped": "cropped_png",
        "deskewed": "deskewed_png",
        "binarized": "binarized_png",
        "dewarped": "dewarped_png",
        "clean": "clean_png",
    }
    if image_type not in column_map:
        return None
    column = column_map[image_type]

    pool = await get_pool()
    async with pool.acquire() as conn:
        # The column name comes from the fixed table above, never from the caller.
        query = f"SELECT {column} FROM ocr_pipeline_sessions WHERE id = $1"
        return await conn.fetchval(query, uuid.UUID(session_id))
|
|
||||||
|
|
||||||
|
|
||||||
async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any]]:
    """Write the given keyword arguments to the session row.

    Only keys in the allow-list become SET clauses; other keys are ignored.
    Values for JSONB columns are serialised with ``json.dumps`` unless they
    are None or already a string. ``updated_at`` is always set to NOW().
    When no allowed key is given, the session is returned unchanged.

    Returns:
        The updated row as a dict, or None if no row matched ``session_id``.
    """
    pool = await get_pool()

    allowed_fields = {
        'name', 'filename', 'status', 'current_step',
        'original_png', 'oriented_png', 'cropped_png',
        'deskewed_png', 'binarized_png', 'dewarped_png',
        'clean_png', 'handwriting_removal_meta',
        'orientation_result', 'crop_result',
        'deskew_result', 'dewarp_result', 'column_result', 'row_result',
        'word_result', 'ground_truth', 'auto_shear_degrees',
        'doc_type', 'doc_type_result',
        'document_category', 'pipeline_log',
        'grid_editor_result', 'structure_result',
        'parent_session_id', 'box_index',
        'document_group_id', 'page_number',
    }

    jsonb_fields = {
        'orientation_result', 'crop_result', 'deskew_result', 'dewarp_result',
        'column_result', 'row_result', 'word_result', 'ground_truth',
        'handwriting_removal_meta', 'doc_type_result', 'pipeline_log',
        'grid_editor_result', 'structure_result',
    }

    assignments = []
    params = []
    for column, raw in kwargs.items():
        if column not in allowed_fields:
            continue
        needs_dump = column in jsonb_fields and raw is not None and not isinstance(raw, str)
        params.append(json.dumps(raw) if needs_dump else raw)
        # Placeholder number equals the parameter's 1-based position.
        assignments.append(f"{column} = ${len(params)}")

    if not assignments:
        return await get_session_db(session_id)

    # Always update updated_at
    assignments.append("updated_at = NOW()")

    # The session id is the last positional parameter.
    params.append(uuid.UUID(session_id))
    id_placeholder = len(params)

    async with pool.acquire() as conn:
        row = await conn.fetchrow(f"""
            UPDATE ocr_pipeline_sessions
            SET {', '.join(assignments)}
            WHERE id = ${id_placeholder}
            RETURNING id, name, filename, status, current_step,
                      orientation_result, crop_result,
                      deskew_result, dewarp_result, column_result, row_result,
                      word_result, ground_truth, auto_shear_degrees,
                      doc_type, doc_type_result,
                      document_category, pipeline_log,
                      grid_editor_result, structure_result,
                      parent_session_id, box_index,
                      document_group_id, page_number,
                      created_at, updated_at
        """, *params)

    return _row_to_dict(row) if row else None
|
|
||||||
|
|
||||||
|
|
||||||
async def list_sessions_db(
    limit: int = 50,
    include_sub_sessions: bool = False,
) -> List[Dict[str, Any]]:
    """Return session metadata rows, newest first (no image columns).

    Unless ``include_sub_sessions`` is true, rows that have a
    ``parent_session_id`` and rows whose status is ``'split'`` are filtered
    out.

    Args:
        limit: Maximum number of rows to return.
        include_sub_sessions: When true, no filter is applied.

    Returns:
        One dict per row. The ``ground_truth`` column is removed from each
        dict and replaced by a boolean ``is_ground_truth`` that is true when
        the ground truth holds a truthy ``build_grid_reference`` entry.
    """
    if include_sub_sessions:
        where = ""
    else:
        where = "WHERE parent_session_id IS NULL AND (status IS NULL OR status != 'split')"

    pool = await get_pool()
    async with pool.acquire() as conn:
        rows = await conn.fetch(f"""
            SELECT id, name, filename, status, current_step,
                   document_category, doc_type,
                   parent_session_id, box_index,
                   document_group_id, page_number,
                   created_at, updated_at,
                   ground_truth
            FROM ocr_pipeline_sessions
            {where}
            ORDER BY created_at DESC
            LIMIT $1
        """, limit)

    sessions = []
    for record in rows:
        entry = _row_to_dict(record)
        # ground_truth is only needed to derive the flag; drop the heavy field.
        ground_truth = entry.pop("ground_truth", None) or {}
        entry["is_ground_truth"] = bool(ground_truth.get("build_grid_reference"))
        sessions.append(entry)
    return sessions
|
|
||||||
|
|
||||||
|
|
||||||
async def get_sub_sessions(parent_session_id: str) -> List[Dict[str, Any]]:
    """Fetch the metadata of every session whose parent is ``parent_session_id``.

    Args:
        parent_session_id: Parent session UUID in string form.

    Returns:
        Rows converted by ``_row_to_dict``, sorted by ascending ``box_index``.
    """
    parent_uuid = uuid.UUID(parent_session_id)
    pool = await get_pool()
    async with pool.acquire() as conn:
        rows = await conn.fetch("""
            SELECT id, name, filename, status, current_step,
                   document_category, doc_type,
                   parent_session_id, box_index,
                   document_group_id, page_number,
                   created_at, updated_at
            FROM ocr_pipeline_sessions
            WHERE parent_session_id = $1
            ORDER BY box_index ASC
        """, parent_uuid)

    return list(map(_row_to_dict, rows))
|
|
||||||
|
|
||||||
|
|
||||||
async def get_document_group_sessions(document_group_id: str) -> List[Dict[str, Any]]:
    """Fetch the metadata of every session belonging to one document group.

    Args:
        document_group_id: Group UUID in string form.

    Returns:
        Rows converted by ``_row_to_dict``, sorted by ascending ``page_number``.
    """
    group_uuid = uuid.UUID(document_group_id)
    pool = await get_pool()
    async with pool.acquire() as conn:
        rows = await conn.fetch("""
            SELECT id, name, filename, status, current_step,
                   document_category, doc_type,
                   parent_session_id, box_index,
                   document_group_id, page_number,
                   created_at, updated_at
            FROM ocr_pipeline_sessions
            WHERE document_group_id = $1
            ORDER BY page_number ASC
        """, group_uuid)

    return list(map(_row_to_dict, rows))
|
|
||||||
|
|
||||||
|
|
||||||
async def list_ground_truth_sessions_db() -> List[Dict[str, Any]]:
    """Return top-level sessions whose ground truth mentions ``build_grid_reference``.

    The match is a textual ``LIKE`` on the ``ground_truth`` column cast to
    text; sessions with a ``parent_session_id`` are excluded.

    Returns:
        Rows converted by ``_row_to_dict``, newest first.
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        records = await conn.fetch("""
            SELECT id, name, filename, status, current_step,
                   document_category, doc_type,
                   ground_truth,
                   parent_session_id, box_index,
                   created_at, updated_at
            FROM ocr_pipeline_sessions
            WHERE ground_truth IS NOT NULL
              AND ground_truth::text LIKE '%build_grid_reference%'
              AND parent_session_id IS NULL
            ORDER BY created_at DESC
        """)

    converted = []
    for record in records:
        converted.append(_row_to_dict(record))
    return converted
|
|
||||||
|
|
||||||
|
|
||||||
async def delete_session_db(session_id: str) -> bool:
    """Delete the session row with the given id.

    Args:
        session_id: Session UUID in string form.

    Returns:
        ``True`` when the command status reported by the driver is exactly
        ``"DELETE 1"``, otherwise ``False``.
    """
    target = uuid.UUID(session_id)
    pool = await get_pool()
    async with pool.acquire() as conn:
        status = await conn.execute("""
            DELETE FROM ocr_pipeline_sessions WHERE id = $1
        """, target)
    deleted_one = status == "DELETE 1"
    return deleted_one
|
|
||||||
|
|
||||||
|
|
||||||
async def delete_all_sessions_db() -> int:
    """Remove every row from ``ocr_pipeline_sessions``.

    Returns:
        The row count parsed from the trailing token of the command status
        (e.g. ``"DELETE 5"`` -> ``5``), or ``0`` when that token is missing
        or not an integer.
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        status = await conn.execute("DELETE FROM ocr_pipeline_sessions")

    tokens = status.split()
    try:
        deleted = int(tokens[-1])
    except (ValueError, IndexError):
        deleted = 0
    return deleted
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# HELPER
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
def _row_to_dict(row: asyncpg.Record) -> Dict[str, Any]:
|
|
||||||
"""Convert asyncpg Record to JSON-serializable dict."""
|
|
||||||
if row is None:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
result = dict(row)
|
|
||||||
|
|
||||||
# UUID → string
|
|
||||||
for key in ['id', 'session_id', 'parent_session_id', 'document_group_id']:
|
|
||||||
if key in result and result[key] is not None:
|
|
||||||
result[key] = str(result[key])
|
|
||||||
|
|
||||||
# datetime → ISO string
|
|
||||||
for key in ['created_at', 'updated_at']:
|
|
||||||
if key in result and result[key] is not None:
|
|
||||||
result[key] = result[key].isoformat()
|
|
||||||
|
|
||||||
# JSONB → parsed (asyncpg returns str for JSONB)
|
|
||||||
for key in ['orientation_result', 'crop_result', 'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'doc_type_result', 'pipeline_log', 'grid_editor_result', 'structure_result']:
|
|
||||||
if key in result and result[key] is not None:
|
|
||||||
if isinstance(result[key], str):
|
|
||||||
result[key] = json.loads(result[key])
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|||||||
@@ -1,20 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/sessions.py
|
||||||
OCR Pipeline Sessions API — barrel re-export.
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
All implementation split into:
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.sessions")
|
||||||
ocr_pipeline_sessions_crud — session CRUD, box sessions
|
|
||||||
ocr_pipeline_sessions_images — image serving, thumbnails, doc-type detection
|
|
||||||
|
|
||||||
Lizenz: Apache 2.0
|
|
||||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from fastapi import APIRouter
|
|
||||||
|
|
||||||
from ocr_pipeline_sessions_crud import router as _crud_router # noqa: F401
|
|
||||||
from ocr_pipeline_sessions_images import router as _images_router # noqa: F401
|
|
||||||
|
|
||||||
# Composite router (used by ocr_pipeline_api.py)
|
|
||||||
router = APIRouter()
|
|
||||||
router.include_router(_crud_router)
|
|
||||||
router.include_router(_images_router)
|
|
||||||
|
|||||||
@@ -1,449 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/sessions_crud.py
|
||||||
OCR Pipeline Sessions CRUD — session create, read, update, delete, box sessions.
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Extracted from ocr_pipeline_sessions.py for modularity.
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.sessions_crud")
|
||||||
|
|
||||||
Lizenz: Apache 2.0
|
|
||||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import uuid
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
|
|
||||||
|
|
||||||
from cv_vocab_pipeline import render_image_high_res, render_pdf_high_res
|
|
||||||
from ocr_pipeline_common import (
|
|
||||||
VALID_DOCUMENT_CATEGORIES,
|
|
||||||
UpdateSessionRequest,
|
|
||||||
_cache,
|
|
||||||
)
|
|
||||||
from ocr_pipeline_session_store import (
|
|
||||||
create_session_db,
|
|
||||||
delete_all_sessions_db,
|
|
||||||
delete_session_db,
|
|
||||||
get_session_db,
|
|
||||||
get_session_image,
|
|
||||||
get_sub_sessions,
|
|
||||||
list_sessions_db,
|
|
||||||
update_session_db,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Session Management Endpoints
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@router.get("/sessions")
async def list_sessions(include_sub_sessions: bool = False):
    """Return the OCR pipeline sessions as ``{"sessions": [...]}``.

    Sub-sessions (box regions) are left out unless the caller passes
    ``?include_sub_sessions=true``.
    """
    found = await list_sessions_db(include_sub_sessions=include_sub_sessions)
    payload = {"sessions": found}
    return payload
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sessions")
async def create_session(
    file: UploadFile = File(...),
    name: Optional[str] = Form(None),
):
    """Create a pipeline session from an uploaded PDF or image.

    A PDF with more than one page is delegated to
    ``_create_multi_page_sessions``, which creates one session per page under
    a shared ``document_group_id`` and returns a ``pages`` array. Everything
    else (images, single-page PDFs) yields exactly one session.

    Args:
        file: The uploaded file.
        name: Optional display name; defaults to the uploaded filename.

    Returns:
        Session id, names, image dimensions and the URL of the original image
        (or the multi-page response described above).

    Raises:
        HTTPException: 400 when the PDF cannot be opened or the file cannot
            be rendered, 500 when PNG encoding fails.
    """
    payload = await file.read()
    filename = file.filename or "upload"
    mime_type = file.content_type or ""
    session_name = name or filename

    # A file counts as PDF by declared content type or by extension.
    is_pdf = mime_type == "application/pdf" or filename.lower().endswith(".pdf")

    if is_pdf:
        # Only the page count is needed here, to pick single vs multi-page.
        try:
            import fitz  # PyMuPDF

            document = fitz.open(stream=payload, filetype="pdf")
            page_count = document.page_count
            document.close()
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Could not read PDF: {e}")

        if page_count > 1:
            return await _create_multi_page_sessions(
                payload, filename, session_name, page_count,
            )

    # Single page: an image or a one-page PDF.
    session_id = str(uuid.uuid4())

    try:
        image = (
            render_pdf_high_res(payload, page_number=0, zoom=3.0)
            if is_pdf
            else render_image_high_res(payload)
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Could not process file: {e}")

    encoded_ok, encoded = cv2.imencode(".png", image)
    if not encoded_ok:
        raise HTTPException(status_code=500, detail="Failed to encode image")

    await create_session_db(
        session_id=session_id,
        name=session_name,
        filename=filename,
        original_png=encoded.tobytes(),
    )

    # Keep the decoded array in memory so later steps need not re-decode it.
    _cache[session_id] = {
        "id": session_id,
        "filename": filename,
        "name": session_name,
        "original_bgr": image,
        "oriented_bgr": None,
        "cropped_bgr": None,
        "deskewed_bgr": None,
        "dewarped_bgr": None,
        "orientation_result": None,
        "crop_result": None,
        "deskew_result": None,
        "dewarp_result": None,
        "ground_truth": {},
        "current_step": 1,
    }

    height, width = image.shape[0], image.shape[1]
    logger.info(f"OCR Pipeline: created session {session_id} from {filename} "
                f"({width}x{height})")

    return {
        "session_id": session_id,
        "filename": filename,
        "name": session_name,
        "image_width": width,
        "image_height": height,
        "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
    }
|
|
||||||
|
|
||||||
|
|
||||||
async def _create_multi_page_sessions(
    pdf_data: bytes,
    filename: str,
    base_name: str,
    page_count: int,
) -> dict:
    """Create one session per PDF page under a fresh ``document_group_id``.

    Pages that fail to render or to encode as PNG are skipped, so ``pages``
    may be shorter than ``page_count``.

    Args:
        pdf_data: Raw bytes of the PDF.
        filename: Uploaded filename, stored on every page session.
        base_name: Display name; each page gets ``"<base_name> — Seite <n>"``.
        page_count: Number of pages in the PDF.

    Returns:
        Dict with the group id, the per-page entries and, for clients that
        expect a single session, ``session_id`` of the first created page
        (``None`` when no page could be created).
    """
    document_group_id = str(uuid.uuid4())
    pages = []

    for page_number in range(1, page_count + 1):
        session_id = str(uuid.uuid4())
        page_name = f"{base_name} — Seite {page_number}"

        try:
            page_image = render_pdf_high_res(pdf_data, page_number=page_number - 1, zoom=3.0)
        except Exception as e:
            logger.warning(f"Failed to render PDF page {page_number}: {e}")
            continue

        encoded_ok, encoded = cv2.imencode(".png", page_image)
        if not encoded_ok:
            continue

        await create_session_db(
            session_id=session_id,
            name=page_name,
            filename=filename,
            original_png=encoded.tobytes(),
            document_group_id=document_group_id,
            page_number=page_number,
        )

        # In-memory entry so the first pipeline steps can run without a DB read.
        cache_entry = {
            "id": session_id,
            "filename": filename,
            "name": page_name,
            "original_bgr": page_image,
        }
        cache_entry.update(dict.fromkeys(
            (
                "oriented_bgr", "cropped_bgr", "deskewed_bgr", "dewarped_bgr",
                "orientation_result", "crop_result", "deskew_result",
                "dewarp_result",
            ),
            None,
        ))
        cache_entry["ground_truth"] = {}
        cache_entry["current_step"] = 1
        _cache[session_id] = cache_entry

        h, w = page_image.shape[:2]
        pages.append({
            "session_id": session_id,
            "name": page_name,
            "page_number": page_number,
            "image_width": w,
            "image_height": h,
            "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
        })

        logger.info(
            f"OCR Pipeline: created page session {session_id} "
            f"(page {page_number}/{page_count}) from {filename} ({w}x{h})"
        )

    # Backwards compatibility: frontends expecting one session_id land on page 1.
    if pages:
        first_session_id = pages[0]["session_id"]
    else:
        first_session_id = None

    return {
        "session_id": first_session_id,
        "document_group_id": document_group_id,
        "filename": filename,
        "name": base_name,
        "page_count": page_count,
        "pages": pages,
    }
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/sessions/{session_id}")
async def get_session_info(session_id: str):
    """Return session metadata plus the step results needed for navigation.

    The response always carries the basic fields (names, image size, current
    step, category, doc type). Step results are added only when the session
    has a truthy value for them. ``grid_editor_result`` is reduced to a small
    summary to keep the response small.

    Args:
        session_id: Session UUID in string form.

    Returns:
        The session info dict. Image width/height are ``0`` when the original
        image is missing or cannot be decoded.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # Image dimensions come from the stored original PNG. cv2.imdecode returns
    # None for undecodable data, so both dimensions must be guarded together
    # (the previous one-line conditional only protected the height and raised
    # AttributeError on a None image).
    img_w, img_h = 0, 0
    original_png = await get_session_image(session_id, "original")
    if original_png:
        arr = np.frombuffer(original_png, dtype=np.uint8)
        img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
        if img is not None:
            img_h, img_w = img.shape[:2]

    result = {
        "session_id": session["id"],
        "filename": session.get("filename", ""),
        "name": session.get("name", ""),
        "image_width": img_w,
        "image_height": img_h,
        "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
        "current_step": session.get("current_step", 1),
        "document_category": session.get("document_category"),
        "doc_type": session.get("doc_type"),
    }

    # Step results that are copied through unchanged when present.
    passthrough_keys = (
        "orientation_result",
        "crop_result",
        "deskew_result",
        "dewarp_result",
        "column_result",
        "row_result",
        "word_result",
        "doc_type_result",
        "structure_result",
    )
    for key in passthrough_keys:
        if session.get(key):
            result[key] = session[key]

    if session.get("grid_editor_result"):
        # Include summary only to keep response small
        gr = session["grid_editor_result"]
        result["grid_editor_result"] = {
            "summary": gr.get("summary", {}),
            "zones_count": len(gr.get("zones", [])),
            "edited": gr.get("edited", False),
        }

    if session.get("ground_truth"):
        result["ground_truth"] = session["ground_truth"]

    # Box sub-session info (zone_type='box' from column detection — NOT page-split)
    if session.get("parent_session_id"):
        result["parent_session_id"] = session["parent_session_id"]
        result["box_index"] = session.get("box_index")
    else:
        # Check for box sub-sessions (column detection creates these)
        subs = await get_sub_sessions(session_id)
        if subs:
            result["sub_sessions"] = [
                {"id": s["id"], "name": s.get("name"), "box_index": s.get("box_index")}
                for s in subs
            ]

    return result
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/sessions/{session_id}")
async def update_session(session_id: str, req: UpdateSessionRequest):
    """Change a session's name and/or document category.

    Only fields that are not ``None`` on the request are written.

    Raises:
        HTTPException: 400 for a category outside
            ``VALID_DOCUMENT_CATEGORIES`` or an empty request, 404 when the
            session does not exist.
    """
    changes: Dict[str, Any] = {}

    if req.name is not None:
        changes["name"] = req.name

    if req.document_category is not None:
        category_is_valid = req.document_category in VALID_DOCUMENT_CATEGORIES
        if not category_is_valid:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid category '{req.document_category}'. Valid: {sorted(VALID_DOCUMENT_CATEGORIES)}",
            )
        changes["document_category"] = req.document_category

    if not changes:
        raise HTTPException(status_code=400, detail="Nothing to update")

    stored = await update_session_db(session_id, **changes)
    if not stored:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    response = {"session_id": session_id}
    response.update(changes)
    return response
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/sessions/{session_id}")
async def delete_session(session_id: str):
    """Drop a session from the in-memory cache and from the database.

    Raises:
        HTTPException: 404 when the database reports that no row was deleted.
    """
    # Evict first; a missing cache entry is not an error.
    _cache.pop(session_id, None)

    was_deleted = await delete_session_db(session_id)
    if was_deleted:
        return {"session_id": session_id, "deleted": True}
    raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/sessions")
async def delete_all_sessions():
    """Clear the in-memory cache and delete every session row (cleanup).

    Returns:
        ``{"deleted_count": <number of rows removed>}``.
    """
    _cache.clear()
    removed = await delete_all_sessions_db()
    return {"deleted_count": removed}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/create-box-sessions")
async def create_box_sessions(session_id: str):
    """Create one sub-session per box zone found by column detection.

    Each box is cut (with a small padding) out of the session's cropped
    image, or the dewarped image when no cropped one is stored, and saved as
    an independent sub-session that can be run through the pipeline. If
    sub-sessions already exist for this session they are returned unchanged.

    Args:
        session_id: Parent session UUID in string form.

    Returns:
        Dict with the parent ``session_id`` and the ``sub_sessions`` list
        (plus ``total`` for newly created ones, or a ``message`` when nothing
        was created).

    Raises:
        HTTPException: 404 for an unknown session, 400 when column detection
            has not run or no base image is stored, 500 when the base image
            cannot be decoded.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    column_result = session.get("column_result")
    if not column_result:
        raise HTTPException(status_code=400, detail="Column detection must be completed first")

    zones = column_result.get("zones") or []
    box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
    if not box_zones:
        return {"session_id": session_id, "sub_sessions": [], "message": "No boxes detected"}

    # Check for existing sub-sessions
    existing = await get_sub_sessions(session_id)
    if existing:
        return {
            "session_id": session_id,
            "sub_sessions": [{"id": s["id"], "box_index": s.get("box_index")} for s in existing],
            "message": f"{len(existing)} sub-session(s) already exist",
        }

    # Load base image
    base_png = await get_session_image(session_id, "cropped")
    if not base_png:
        base_png = await get_session_image(session_id, "dewarped")
    if not base_png:
        raise HTTPException(status_code=400, detail="No base image available")

    arr = np.frombuffer(base_png, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    parent_name = session.get("name", "Session")
    sub_filename = session.get("filename", "box-crop.png")
    img_h, img_w = img.shape[:2]
    pad = 5  # pixels of context kept around every box
    created = []

    for i, zone in enumerate(box_zones):
        box = zone["box"]
        bx, by = box["x"], box["y"]
        bw, bh = box["width"], box["height"]

        # Crop box region with small padding, clamped to the image bounds
        y1 = max(0, by - pad)
        y2 = min(img_h, by + bh + pad)
        x1 = max(0, bx - pad)
        x2 = min(img_w, bx + bw + pad)
        crop = img[y1:y2, x1:x2]

        # A box outside the image (or of zero size) gives an empty crop, which
        # cv2.imencode rejects with an exception. Skip it so one bad box does
        # not abort the whole request.
        if crop.size == 0:
            logger.warning(f"Skipping empty box {i} crop for session {session_id}")
            continue

        # Encode as PNG
        success, png_buf = cv2.imencode(".png", crop)
        if not success:
            logger.warning(f"Failed to encode box {i} crop for session {session_id}")
            continue

        sub_id = str(uuid.uuid4())
        sub_name = f"{parent_name} — Box {i + 1}"

        await create_session_db(
            session_id=sub_id,
            name=sub_name,
            filename=sub_filename,
            original_png=png_buf.tobytes(),
            parent_session_id=session_id,
            box_index=i,
        )

        # Cache the BGR for immediate processing
        # Promote original to cropped so column/row/word detection finds it
        box_bgr = crop.copy()
        _cache[sub_id] = {
            "id": sub_id,
            "filename": sub_filename,
            "name": sub_name,
            "parent_session_id": session_id,
            "original_bgr": box_bgr,
            "oriented_bgr": None,
            "cropped_bgr": box_bgr,
            "deskewed_bgr": None,
            "dewarped_bgr": None,
            "orientation_result": None,
            "crop_result": None,
            "deskew_result": None,
            "dewarp_result": None,
            "ground_truth": {},
            "current_step": 1,
        }

        created.append({
            "id": sub_id,
            "name": sub_name,
            "box_index": i,
            "box": box,
            "image_width": crop.shape[1],
            "image_height": crop.shape[0],
        })

        logger.info(f"Created box sub-session {sub_id} for session {session_id} "
                    f"(box {i}, {crop.shape[1]}x{crop.shape[0]})")

    return {
        "session_id": session_id,
        "sub_sessions": created,
        "total": len(created),
    }
|
|
||||||
|
|||||||
@@ -1,176 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/sessions_images.py
|
||||||
OCR Pipeline Sessions Images — image serving, thumbnails, pipeline log,
|
import importlib as _importlib
|
||||||
categories, and document type detection.
|
import sys as _sys
|
||||||
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.sessions_images")
|
||||||
Extracted from ocr_pipeline_sessions.py for modularity.
|
|
||||||
|
|
||||||
Lizenz: Apache 2.0
|
|
||||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from typing import Any, Dict
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
from fastapi import APIRouter, HTTPException, Query
|
|
||||||
from fastapi.responses import Response
|
|
||||||
|
|
||||||
from cv_vocab_pipeline import create_ocr_image, detect_document_type
|
|
||||||
from ocr_pipeline_common import (
|
|
||||||
VALID_DOCUMENT_CATEGORIES,
|
|
||||||
_append_pipeline_log,
|
|
||||||
_cache,
|
|
||||||
_get_base_image_png,
|
|
||||||
_get_cached,
|
|
||||||
_load_session_to_cache,
|
|
||||||
)
|
|
||||||
from ocr_pipeline_overlays import render_overlay
|
|
||||||
from ocr_pipeline_session_store import (
|
|
||||||
get_session_db,
|
|
||||||
get_session_image,
|
|
||||||
update_session_db,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Thumbnail & Log Endpoints
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@router.get("/sessions/{session_id}/thumbnail")
async def get_session_thumbnail(session_id: str, size: int = Query(default=80, ge=16, le=400)):
    """Return a PNG thumbnail of the session's original image.

    The image is scaled so that its longer side equals ``size`` pixels while
    keeping the aspect ratio.

    Args:
        session_id: Session UUID in string form.
        size: Target length of the longer side in pixels (16-400).

    Returns:
        A PNG ``Response`` with a one-hour public cache header.

    Raises:
        HTTPException: 404 when no original image is stored, 500 when the
            image cannot be decoded or the thumbnail cannot be encoded.
    """
    original_png = await get_session_image(session_id, "original")
    if not original_png:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found or no image")

    arr = np.frombuffer(original_png, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    h, w = img.shape[:2]
    scale = size / max(h, w)
    # For very elongated images the shorter side can truncate to 0, which
    # cv2.resize rejects; keep every dimension at least one pixel.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    thumb = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

    encoded_ok, png_bytes = cv2.imencode(".png", thumb)
    if not encoded_ok:
        raise HTTPException(status_code=500, detail="Failed to encode image")

    return Response(content=png_bytes.tobytes(), media_type="image/png",
                    headers={"Cache-Control": "public, max-age=3600"})
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/sessions/{session_id}/pipeline-log")
async def get_pipeline_log(session_id: str):
    """Return the stored pipeline execution log of a session.

    A session without a log yields ``{"steps": []}``.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if session:
        log = session.get("pipeline_log") or {"steps": []}
        return {"session_id": session_id, "pipeline_log": log}
    raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/categories")
async def list_categories():
    """Return the valid document categories in sorted order."""
    ordered = sorted(VALID_DOCUMENT_CATEGORIES)
    return {"categories": ordered}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Image Endpoints
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@router.get("/sessions/{session_id}/image/{image_type}")
async def get_image(session_id: str, image_type: str):
    """Serve one image of a session as PNG.

    Accepted ``image_type`` values:

    * stored images: ``original``, ``oriented``, ``cropped``, ``deskewed``,
      ``dewarped``, ``binarized``, ``clean``
    * rendered overlays: ``structure-overlay``, ``columns-overlay``,
      ``rows-overlay``, ``words-overlay`` (delegated to ``render_overlay``)

    Raises:
        HTTPException: 400 for an unknown image type, 404 when the requested
            image is not available.
    """
    # Overlay image types and the overlay kind passed to render_overlay.
    overlay_kinds = {
        "structure-overlay": "structure",
        "columns-overlay": "columns",
        "rows-overlay": "rows",
        "words-overlay": "words",
    }
    stored_types = {"original", "oriented", "cropped", "deskewed", "dewarped", "binarized", "clean"}

    if image_type not in stored_types and image_type not in overlay_kinds:
        raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")

    if image_type in overlay_kinds:
        return await render_overlay(overlay_kinds[image_type], session_id)

    # Try cache first for fast serving. Only the binarized PNG is served from
    # the cache here; every other type is read from the database below.
    cached = _cache.get(session_id)
    if cached and image_type == "binarized" and cached.get("binarized_png"):
        return Response(content=cached["binarized_png"], media_type="image/png")

    # Load from DB — for cropped/dewarped, fall back through the chain
    if image_type in ("cropped", "dewarped"):
        data = await _get_base_image_png(session_id)
    else:
        data = await get_session_image(session_id, image_type)

    if not data:
        raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet")

    return Response(content=data, media_type="image/png")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Document Type Detection (between Dewarp and Columns)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/detect-type")
async def detect_type(session_id: str):
    """Classify the document layout of a session (vocab_table, full_text, generic_table).

    The classifier runs on the cached cropped image; when no cropped image is
    cached it uses the dewarped image instead. The outcome is written to the
    session record, mirrored into the in-memory cache entry, appended to the
    pipeline log and returned to the caller together with the session id.

    Raises:
        HTTPException: 400 when neither a cropped nor a dewarped image is cached.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    entry = _get_cached(session_id)

    # Prefer the cropped page; only fall back when it is genuinely absent.
    page_bgr = entry.get("cropped_bgr")
    if page_bgr is None:
        page_bgr = entry.get("dewarped_bgr")
    if page_bgr is None:
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first")

    started = time.time()
    prepared = create_ocr_image(page_bgr)
    detection = detect_document_type(prepared, page_bgr)
    elapsed = time.time() - started

    summary = {
        "doc_type": detection.doc_type,
        "confidence": detection.confidence,
        "pipeline": detection.pipeline,
        "skip_steps": detection.skip_steps,
        "features": detection.features,
        "duration_seconds": round(elapsed, 2),
    }

    # Database first, then the cache entry, so both hold the same summary.
    await update_session_db(
        session_id,
        doc_type=detection.doc_type,
        doc_type_result=summary,
    )
    entry["doc_type_result"] = summary

    logger.info(f"OCR Pipeline: detect-type session {session_id}: "
                f"{detection.doc_type} (confidence={detection.confidence}, {elapsed:.2f}s)")

    # Only scalar feature values are copied into the pipeline log entry.
    log_entry = {
        "doc_type": detection.doc_type,
        "pipeline": detection.pipeline,
        "confidence": detection.confidence,
    }
    for feature_name, feature_value in (detection.features or {}).items():
        if isinstance(feature_value, (int, float, str, bool)):
            log_entry[feature_name] = feature_value
    await _append_pipeline_log(session_id, "detect_type", log_entry, duration_ms=int(elapsed * 1000))

    return {"session_id": session_id, **summary}
|
|
||||||
|
|||||||
@@ -1,299 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/structure.py
|
||||||
OCR Pipeline Structure Detection and Exclude Regions
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Detect document structure (boxes, zones, color regions, graphics)
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.structure")
|
||||||
and manage user-drawn exclude regions.
|
|
||||||
Extracted from ocr_pipeline_geometry.py for file-size compliance.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from typing import Any, Dict, List
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
from fastapi import APIRouter, HTTPException
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from cv_box_detect import detect_boxes
|
|
||||||
from cv_color_detect import _COLOR_RANGES, _COLOR_HEX
|
|
||||||
from cv_graphic_detect import detect_graphic_elements
|
|
||||||
from ocr_pipeline_session_store import (
|
|
||||||
get_session_db,
|
|
||||||
update_session_db,
|
|
||||||
)
|
|
||||||
from ocr_pipeline_common import (
|
|
||||||
_cache,
|
|
||||||
_load_session_to_cache,
|
|
||||||
_get_cached,
|
|
||||||
_filter_border_ghost_words,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Structure Detection Endpoint
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/detect-structure")
async def detect_structure(session_id: str):
    """Detect document structure: boxes, zones, and color regions.

    Runs box detection (line + shading) and color analysis on the cropped
    image (or the dewarped image when no cropped image is cached). Returns
    structured JSON with all detected elements for the structure
    visualization step.

    Side effects visible in this function:
      * ``structure_result`` is written to the session DB and to the cache
        entry; ``exclude_regions`` from a previously cached result are
        carried over unchanged.
      * When ``_filter_border_ghost_words`` reports a non-zero count, the
        (in-place modified) ``word_result`` is persisted as well.

    Raises:
        HTTPException: 400 when neither a cropped nor a dewarped image is cached.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    # Explicit ``is not None`` test: the cached value is presumably a NumPy
    # image, for which plain truthiness would be ambiguous -- TODO confirm.
    img_bgr = (
        cached.get("cropped_bgr")
        if cached.get("cropped_bgr") is not None
        else cached.get("dewarped_bgr")
    )
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first")

    t0 = time.time()
    h, w = img_bgr.shape[:2]

    # --- Content bounds from word result (if available) or full image ---
    word_result = cached.get("word_result")
    words: List[Dict] = []
    if word_result and word_result.get("cells"):
        # Flatten the per-cell word boxes into one list.
        for cell in word_result["cells"]:
            for wb in (cell.get("word_boxes") or []):
                words.append(wb)
    # Fallback: use raw OCR words if cell word_boxes are empty
    if not words and word_result:
        # The first non-empty raw word list wins, in this priority order.
        for key in ("raw_paddle_words_split", "raw_tesseract_words", "raw_paddle_words"):
            raw = word_result.get(key, [])
            if raw:
                words = raw
                logger.info("detect-structure: using %d words from %s (no cell word_boxes)", len(words), key)
                break
    # If no words yet, use image dimensions with small margin
    if words:
        # Bounding rectangle of all word boxes, clamped to the image size.
        content_x = max(0, min(int(wb["left"]) for wb in words))
        content_y = max(0, min(int(wb["top"]) for wb in words))
        content_r = min(w, max(int(wb["left"] + wb["width"]) for wb in words))
        content_b = min(h, max(int(wb["top"] + wb["height"]) for wb in words))
        content_w_px = content_r - content_x
        content_h_px = content_b - content_y
    else:
        # 3% of the shorter image side is trimmed on every edge.
        margin = int(min(w, h) * 0.03)
        content_x, content_y = margin, margin
        content_w_px = w - 2 * margin
        content_h_px = h - 2 * margin

    # --- Box detection ---
    boxes = detect_boxes(
        img_bgr,
        content_x=content_x,
        content_w=content_w_px,
        content_y=content_y,
        content_h=content_h_px,
    )

    # --- Zone splitting ---
    from cv_box_detect import split_page_into_zones as _split_zones
    zones = _split_zones(content_x, content_y, content_w_px, content_h_px, boxes)

    # --- Color region sampling ---
    # Sample background shading in each detected box
    hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)
    box_colors = []
    for box in boxes:
        # Sample the center region of each box
        # (the middle half in both directions, clamped to the image).
        cy1 = box.y + box.height // 4
        cy2 = box.y + 3 * box.height // 4
        cx1 = box.x + box.width // 4
        cx2 = box.x + 3 * box.width // 4
        cy1 = max(0, min(cy1, h - 1))
        cy2 = max(0, min(cy2, h - 1))
        cx1 = max(0, min(cx1, w - 1))
        cx2 = max(0, min(cx2, w - 1))
        if cy2 > cy1 and cx2 > cx1:
            roi_hsv = hsv[cy1:cy2, cx1:cx2]
            # Median per HSV channel of the sampled region.
            med_h = float(np.median(roi_hsv[:, :, 0]))
            med_s = float(np.median(roi_hsv[:, :, 1]))
            med_v = float(np.median(roi_hsv[:, :, 2]))
            if med_s > 15:
                # Saturated enough to name a hue.
                from cv_color_detect import _hue_to_color_name
                bg_name = _hue_to_color_name(med_h)
                bg_hex = _COLOR_HEX.get(bg_name, "#6b7280")
            else:
                # Low saturation: classify by brightness only.
                bg_name = "gray" if med_v < 220 else "white"
                bg_hex = "#6b7280" if bg_name == "gray" else "#ffffff"
        else:
            # Degenerate sample window (box too small after clamping).
            bg_name = "unknown"
            bg_hex = "#6b7280"
        box_colors.append({"color_name": bg_name, "color_hex": bg_hex})

    # --- Color text detection overview ---
    # Quick scan for colored text regions across the page
    color_summary: Dict[str, int] = {}
    for color_name, ranges in _COLOR_RANGES.items():
        # Union of all HSV ranges registered for this color name.
        mask = np.zeros((h, w), dtype=np.uint8)
        for lower, upper in ranges:
            mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper))
        pixel_count = int(np.sum(mask > 0))
        if pixel_count > 50:  # minimum threshold
            color_summary[color_name] = pixel_count

    # --- Graphic element detection ---
    box_dicts = [
        {"x": b.x, "y": b.y, "w": b.width, "h": b.height}
        for b in boxes
    ]
    graphics = detect_graphic_elements(
        img_bgr, words,
        detected_boxes=box_dicts,
    )

    # --- Filter border-ghost words from OCR result ---
    ghost_count = 0
    if boxes and word_result:
        # NOTE(review): the helper appears to modify word_result in place and
        # return the number of removed words -- confirm against its definition.
        ghost_count = _filter_border_ghost_words(word_result, boxes)
        if ghost_count:
            logger.info("detect-structure: removed %d border-ghost words", ghost_count)
            await update_session_db(session_id, word_result=word_result)
            cached["word_result"] = word_result

    duration = time.time() - t0

    # Preserve user-drawn exclude regions from previous run
    prev_sr = cached.get("structure_result") or {}
    prev_exclude = prev_sr.get("exclude_regions", [])

    # ``words`` was collected before the ghost filter ran, so has_words /
    # word_count below describe the pre-filter word list.
    result_dict = {
        "image_width": w,
        "image_height": h,
        "content_bounds": {
            "x": content_x, "y": content_y,
            "w": content_w_px, "h": content_h_px,
        },
        "boxes": [
            {
                "x": b.x, "y": b.y, "w": b.width, "h": b.height,
                "confidence": b.confidence,
                "border_thickness": b.border_thickness,
                "bg_color_name": box_colors[i]["color_name"],
                "bg_color_hex": box_colors[i]["color_hex"],
            }
            for i, b in enumerate(boxes)
        ],
        "zones": [
            {
                "index": z.index,
                "zone_type": z.zone_type,
                "y": z.y, "h": z.height,
                "x": z.x, "w": z.width,
            }
            for z in zones
        ],
        "graphics": [
            {
                "x": g.x, "y": g.y, "w": g.width, "h": g.height,
                "area": g.area,
                "shape": g.shape,
                "color_name": g.color_name,
                "color_hex": g.color_hex,
                "confidence": round(g.confidence, 2),
            }
            for g in graphics
        ],
        "exclude_regions": prev_exclude,
        "color_pixel_counts": color_summary,
        "has_words": len(words) > 0,
        "word_count": len(words),
        "border_ghosts_removed": ghost_count,
        "duration_seconds": round(duration, 2),
    }

    # Persist to session
    await update_session_db(session_id, structure_result=result_dict)
    cached["structure_result"] = result_dict

    logger.info("detect-structure session %s: %d boxes, %d zones, %d graphics, %.2fs",
                session_id, len(boxes), len(zones), len(graphics), duration)

    return {"session_id": session_id, **result_dict}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Exclude Regions -- user-drawn rectangles to exclude from OCR results
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
class _ExcludeRegionIn(BaseModel):
    # Request model for a single user-drawn exclude rectangle. Documented with
    # comments rather than a docstring so the generated schema is unchanged.
    # NOTE(review): x/y/w/h look like an origin + size rectangle; the unit
    # (pixels vs. percent) is not established in this module -- confirm.
    x: int
    y: int
    w: int
    h: int
    # Optional free-text label; defaults to the empty string.
    label: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
class _ExcludeRegionsBatchIn(BaseModel):
    # Request body of the PUT exclude-regions endpoint: the complete list of
    # regions that replaces whatever is currently stored for the session.
    regions: list[_ExcludeRegionIn]
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/sessions/{session_id}/exclude-regions")
async def set_exclude_regions(session_id: str, body: _ExcludeRegionsBatchIn):
    """Overwrite the session's exclude regions with the submitted list.

    The regions live under ``structure_result.exclude_regions``. Saving them
    also clears ``grid_editor_result`` (in the DB and in the cache entry) so
    the grid is rebuilt against the new regions.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    structure = session.get("structure_result") or {}
    new_regions = [region.model_dump() for region in body.regions]
    structure["exclude_regions"] = new_regions

    # Resetting grid_editor_result forces a rebuild with the new regions.
    await update_session_db(session_id, structure_result=structure, grid_editor_result=None)

    # Keep an already-loaded cache entry in step with the database.
    if session_id in _cache:
        cache_entry = _cache[session_id]
        cache_entry["structure_result"] = structure
        cache_entry.pop("grid_editor_result", None)

    return {
        "session_id": session_id,
        "exclude_regions": new_regions,
        "count": len(new_regions),
    }
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/sessions/{session_id}/exclude-regions/{region_index}")
async def delete_exclude_region(session_id: str, region_index: int):
    """Delete one exclude region, addressed by its position in the stored list.

    Like the batch update, this clears ``grid_editor_result`` in the DB and in
    the cache entry so the grid is rebuilt.

    Raises:
        HTTPException: 404 when the session does not exist or the index is
            outside the stored list.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    structure = session.get("structure_result") or {}
    stored = structure.get("exclude_regions", [])

    if not 0 <= region_index < len(stored):
        raise HTTPException(status_code=404, detail="Region index out of range")

    dropped = stored.pop(region_index)
    structure["exclude_regions"] = stored

    # Resetting grid_editor_result forces a rebuild without the removed region.
    await update_session_db(session_id, structure_result=structure, grid_editor_result=None)

    if session_id in _cache:
        cache_entry = _cache[session_id]
        cache_entry["structure_result"] = structure
        cache_entry.pop("grid_editor_result", None)

    return {
        "session_id": session_id,
        "removed": dropped,
        "remaining": len(stored),
    }
|
|
||||||
|
|||||||
@@ -1,362 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/validation.py
|
||||||
OCR Pipeline Validation — image detection, generation, validation save,
|
import importlib as _importlib
|
||||||
and handwriting removal endpoints.
|
import sys as _sys
|
||||||
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.validation")
|
||||||
Extracted from ocr_pipeline_postprocess.py.
|
|
||||||
|
|
||||||
Lizenz: Apache 2.0
|
|
||||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from ocr_pipeline_session_store import (
|
|
||||||
get_session_db,
|
|
||||||
get_session_image,
|
|
||||||
update_session_db,
|
|
||||||
)
|
|
||||||
from ocr_pipeline_common import (
|
|
||||||
_cache,
|
|
||||||
RemoveHandwritingRequest,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Style presets and Pydantic models
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
# Maps a style key to the text appended (after ", ") to the user prompt when a
# replacement image is generated. Unknown keys fall back to "educational" at
# the lookup site in generate_image_for_region.
STYLE_SUFFIXES = {
    "educational": "educational illustration, textbook style, clear, colorful",
    "cartoon": "cartoon, child-friendly, simple shapes",
    "sketch": "pencil sketch, hand-drawn, black and white",
    "clipart": "clipart, flat vector style, simple",
    "realistic": "photorealistic, high detail",
}
|
|
||||||
|
|
||||||
|
|
||||||
class ValidationRequest(BaseModel):
    # Request body of the reconstruction/validate endpoint. Both fields are
    # optional and are stored as given under ground_truth.validation.
    notes: Optional[str] = None
    # NOTE(review): no range is enforced on the score here -- confirm the
    # scale expected by the frontend.
    score: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
class GenerateImageRequest(BaseModel):
    # Request body of the reconstruction/generate-image endpoint.
    # Position of the target region in the stored image_regions list.
    region_index: int
    # Prompt text; the style suffix is appended to it before generation.
    prompt: str
    # Key into STYLE_SUFFIXES; unknown keys fall back to "educational".
    style: str = "educational"
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Image detection + generation
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/reconstruction/detect-images")
async def detect_image_regions(session_id: str):
    """Detect illustration/image regions in the original scan using a VLM.

    The original page image is sent to the Ollama ``/api/generate`` endpoint
    (base URL from ``OLLAMA_BASE_URL``, model from ``OLLAMA_HTR_MODEL``) with a
    prompt asking for a JSON array of percentage bounding boxes plus a short
    description. The parsed regions are clamped, optionally enriched with
    vocabulary words located near the region, stored under
    ``ground_truth.validation.image_regions`` and returned.

    Failures while talking to the VLM or parsing its reply do not raise; they
    are reported as ``{"regions": [], "count": 0, "error": ...}``.

    Raises:
        HTTPException: 404 when the session does not exist, 400 when it has
            no original image.
    """
    import base64
    import httpx

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    original_png = await get_session_image(session_id, "original")
    if not original_png:
        raise HTTPException(status_code=400, detail="No original image found")

    # A handful of vocabulary pairs gives the model context about the page.
    word_result = session.get("word_result") or {}
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    vocab_context = ""
    if entries:
        sample = entries[:10]
        words = [f"{e.get('english', '')} / {e.get('german', '')}" for e in sample if e.get('english')]
        if words:
            vocab_context = f"\nContext: This is a vocabulary page with words like: {', '.join(words)}"

    ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
    model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")

    prompt = (
        "Analyze this scanned page. Find ALL illustration/image/picture regions "
        "(NOT text, NOT table cells, NOT blank areas). "
        "For each image region found, return its bounding box as percentage of page dimensions "
        "and a short English description of what the image shows. "
        "Reply with ONLY a JSON array like: "
        '[{"x": 10, "y": 20, "w": 30, "h": 25, "description": "drawing of a cat"}] '
        "where x, y, w, h are percentages (0-100) of the page width/height. "
        "If there are NO images on the page, return an empty array: []"
        f"{vocab_context}"
    )

    img_b64 = base64.b64encode(original_png).decode("utf-8")
    payload = {
        "model": model,
        "prompt": prompt,
        "images": [img_b64],
        "stream": False,
    }

    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            resp = await client.post(f"{ollama_base}/api/generate", json=payload)
            resp.raise_for_status()
            text = resp.json().get("response", "")

        # A real JSON decoder is used instead of a non-greedy regex, which cut
        # the array at the first "]" (e.g. inside a description) and made the
        # whole detection fail.
        raw_regions = _extract_region_array(text)

        regions = []
        for r in raw_regions:
            regions.append({
                "bbox_pct": {
                    "x": max(0, min(100, float(r.get("x", 0)))),
                    "y": max(0, min(100, float(r.get("y", 0)))),
                    "w": max(1, min(100, float(r.get("w", 10)))),
                    "h": max(1, min(100, float(r.get("h", 10)))),
                },
                "description": r.get("description", ""),
                "prompt": r.get("description", ""),
                "image_b64": None,
                "style": "educational",
            })

        # Enrich prompts with nearby vocab context
        if entries:
            for region in regions:
                ry = region["bbox_pct"]["y"]
                rh = region["bbox_pct"]["h"]
                # "Nearby" = entry y within (region height + 10) of region y.
                nearby = [
                    e for e in entries
                    if e.get("bbox") and abs(e["bbox"].get("y", 0) - ry) < rh + 10
                ]
                if nearby:
                    en_words = [e.get("english", "") for e in nearby if e.get("english")]
                    de_words = [e.get("german", "") for e in nearby if e.get("german")]
                    if en_words or de_words:
                        context = f" (vocabulary context: {', '.join(en_words[:5])}"
                        if de_words:
                            context += f" / {', '.join(de_words[:5])}"
                        context += ")"
                        region["prompt"] = region["description"] + context

        ground_truth = session.get("ground_truth") or {}
        validation = ground_truth.get("validation") or {}
        validation["image_regions"] = regions
        validation["detected_at"] = datetime.utcnow().isoformat()
        ground_truth["validation"] = validation
        await update_session_db(session_id, ground_truth=ground_truth)

        if session_id in _cache:
            _cache[session_id]["ground_truth"] = ground_truth

        logger.info(f"Detected {len(regions)} image regions for session {session_id}")

        return {"regions": regions, "count": len(regions)}

    except httpx.ConnectError:
        logger.warning(f"VLM not available at {ollama_base} for image detection")
        return {"regions": [], "count": 0, "error": "VLM not available"}
    except Exception as e:
        # Deliberate best-effort boundary: detection problems are reported in
        # the response body instead of failing the request.
        logger.error(f"Image detection failed for {session_id}: {e}")
        return {"regions": [], "count": 0, "error": str(e)}


def _extract_region_array(text: str) -> list:
    """Return the first JSON array of objects embedded in *text*.

    Every ``[`` is tried as a possible array start, so surrounding prose,
    code fences and brackets inside string values are tolerated. Returns an
    empty list when *text* contains no ``[`` at all.

    Raises:
        ValueError: when *text* contains ``[`` but no array of JSON objects
            can be decoded from any of them.
    """
    decoder = json.JSONDecoder()
    saw_candidate = False
    start = text.find("[")
    while start != -1:
        saw_candidate = True
        try:
            value, _end = decoder.raw_decode(text[start:])
        except json.JSONDecodeError:
            value = None
        if isinstance(value, list) and all(isinstance(item, dict) for item in value):
            return value
        start = text.find("[", start + 1)
    if saw_candidate:
        raise ValueError("VLM response contained no parsable JSON array of region objects")
    return []
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/reconstruction/generate-image")
async def generate_image_for_region(session_id: str, req: GenerateImageRequest):
    """Ask the mflux service for a replacement image for one detected region.

    The request prompt is combined with the suffix of the chosen style and
    posted to ``{MFLUX_URL}/generate``. The output size follows the region's
    aspect ratio (landscape, portrait or square). On success the image, prompt
    and style are stored on the region inside ``ground_truth.validation``.
    Service failures are reported in the response body with
    ``success: False`` rather than raised.

    Raises:
        HTTPException: 404 when the session does not exist, 400 when
            ``region_index`` is outside the stored region list.
    """
    import httpx

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    validation = ground_truth.get("validation") or {}
    regions = validation.get("image_regions") or []

    idx = req.region_index
    if not 0 <= idx < len(regions):
        raise HTTPException(status_code=400, detail=f"Invalid region_index {req.region_index}, have {len(regions)} regions")

    mflux_url = os.getenv("MFLUX_URL", "http://host.docker.internal:8095")
    suffix = STYLE_SUFFIXES.get(req.style, STYLE_SUFFIXES["educational"])
    full_prompt = f"{req.prompt}, {suffix}"

    target = regions[idx]
    box = target["bbox_pct"]
    ratio = box["w"] / max(box["h"], 1)
    # Wide regions get a landscape canvas, tall ones portrait, the rest square.
    if ratio > 1.3:
        canvas = (768, 512)
    elif ratio < 0.7:
        canvas = (512, 768)
    else:
        canvas = (512, 512)
    width, height = canvas

    try:
        request_body = {
            "prompt": full_prompt,
            "width": width,
            "height": height,
            "steps": 4,
        }
        async with httpx.AsyncClient(timeout=300.0) as client:
            resp = await client.post(f"{mflux_url}/generate", json=request_body)
            resp.raise_for_status()
            image_b64 = resp.json().get("image_b64")

        if not image_b64:
            return {"image_b64": None, "success": False, "error": "No image returned"}

        # The stored region records the caller's prompt, not the suffixed one.
        target["image_b64"] = image_b64
        target["prompt"] = req.prompt
        target["style"] = req.style
        validation["image_regions"] = regions
        ground_truth["validation"] = validation
        await update_session_db(session_id, ground_truth=ground_truth)

        if session_id in _cache:
            _cache[session_id]["ground_truth"] = ground_truth

        logger.info(f"Generated image for session {session_id} region {req.region_index}")
        return {"image_b64": image_b64, "success": True}

    except httpx.ConnectError:
        logger.warning(f"mflux-service not available at {mflux_url}")
        return {"image_b64": None, "success": False, "error": f"mflux-service not available at {mflux_url}"}
    except Exception as e:
        logger.error(f"Image generation failed for {session_id}: {e}")
        return {"image_b64": None, "success": False, "error": str(e)}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Validation save/get
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/reconstruction/validate")
async def save_validation(session_id: str, req: ValidationRequest):
    """Store the reviewer's final validation (notes and score) for a session.

    The values are merged into ``ground_truth.validation`` together with a UTC
    timestamp, and the session's ``current_step`` is set to 11.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    validation = ground_truth.get("validation") or {}
    validation.update(
        validated_at=datetime.utcnow().isoformat(),
        notes=req.notes,
        score=req.score,
    )
    ground_truth["validation"] = validation

    await update_session_db(session_id, ground_truth=ground_truth, current_step=11)

    # Mirror the change into an already-loaded cache entry.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    logger.info(f"Validation saved for session {session_id}: score={req.score}")

    return {"session_id": session_id, "validation": validation}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/sessions/{session_id}/reconstruction/validation")
async def get_validation(session_id: str):
    """Return the stored validation data of a session along with its word result.

    ``validation`` is ``None`` in the response when nothing has been saved yet.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    stored_validation = (session.get("ground_truth") or {}).get("validation")

    response = {"session_id": session_id}
    response["validation"] = stored_validation
    response["word_result"] = session.get("word_result")
    return response
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Remove handwriting
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/remove-handwriting")
async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingRequest):
    """Inpaint detected handwriting out of a session image and store the result.

    Steps: pick the source image ("auto" prefers the deskewed image and falls
    back to the original), detect a handwriting mask, optionally dilate it,
    inpaint, then persist the cleaned PNG plus processing metadata.

    Raises:
        HTTPException: 404 when the session or the chosen source image is
            missing, 500 when inpainting reports failure.
    """
    import time as _time

    from services.handwriting_detection import detect_handwriting
    from services.inpainting_service import inpaint_image, dilate_mask as _dilate_mask, InpaintingMethod, image_to_png

    if not await get_session_db(session_id):
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    started = _time.monotonic()

    # 1. Resolve which stored image to work on
    source_name = req.use_source
    if source_name == "auto":
        deskewed_png = await get_session_image(session_id, "deskewed")
        source_name = "deskewed" if deskewed_png else "original"

    source_png = await get_session_image(session_id, source_name)
    if not source_png:
        raise HTTPException(status_code=404, detail=f"Source image '{source_name}' not available")

    # 2. Build the handwriting mask
    detection = detect_handwriting(source_png, target_ink=req.target_ink)

    # 3. Encode the mask as PNG, then grow it if requested
    import io
    from PIL import Image as _PILImage
    png_buffer = io.BytesIO()
    _PILImage.fromarray(detection.mask).save(png_buffer, format="PNG")
    mask_png = png_buffer.getvalue()

    if req.dilation > 0:
        mask_png = _dilate_mask(mask_png, iterations=req.dilation)

    # 4. Inpaint with the requested method (unknown names fall back to AUTO)
    methods_by_name = {
        "telea": InpaintingMethod.OPENCV_TELEA,
        "ns": InpaintingMethod.OPENCV_NS,
        "auto": InpaintingMethod.AUTO,
    }
    chosen_method = methods_by_name.get(req.method, InpaintingMethod.AUTO)

    outcome = inpaint_image(source_png, mask_png, method=chosen_method)
    if not outcome.success:
        raise HTTPException(status_code=500, detail="Inpainting failed")

    elapsed_ms = int((_time.monotonic() - started) * 1000)

    used = outcome.method_used
    meta = {
        "method_used": used.value if hasattr(used, "value") else str(used),
        "handwriting_ratio": round(detection.handwriting_ratio, 4),
        "detection_confidence": round(detection.confidence, 4),
        "target_ink": req.target_ink,
        "dilation": req.dilation,
        "source_image": source_name,
        "processing_time_ms": elapsed_ms,
    }

    # 5. Store the cleaned image and its metadata on the session
    cleaned_png = image_to_png(outcome.image)
    await update_session_db(session_id, clean_png=cleaned_png, handwriting_removal_meta=meta)

    response = dict(meta)
    response["image_url"] = f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean"
    response["session_id"] = session_id
    return response
|
|
||||||
|
|||||||
@@ -1,185 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/words.py
|
||||||
OCR Pipeline Words — composite router for word detection, PaddleOCR direct,
|
import importlib as _importlib
|
||||||
and ground truth endpoints.
|
import sys as _sys
|
||||||
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.words")
|
||||||
Split into sub-modules:
|
|
||||||
ocr_pipeline_words_detect — main detect_words endpoint (Step 7)
|
|
||||||
ocr_pipeline_words_stream — SSE streaming generators
|
|
||||||
|
|
||||||
This barrel module contains the PaddleOCR direct endpoint and ground truth
|
|
||||||
endpoints, and assembles all word-related routers.
|
|
||||||
|
|
||||||
Lizenz: Apache 2.0
|
|
||||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
from fastapi import APIRouter, HTTPException
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from cv_words_first import build_grid_from_words
|
|
||||||
from ocr_pipeline_session_store import (
|
|
||||||
get_session_db,
|
|
||||||
get_session_image,
|
|
||||||
update_session_db,
|
|
||||||
)
|
|
||||||
from ocr_pipeline_common import (
|
|
||||||
_cache,
|
|
||||||
_append_pipeline_log,
|
|
||||||
)
|
|
||||||
from ocr_pipeline_words_detect import router as _detect_router
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_local_router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Pydantic models
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
class WordGroundTruthRequest(BaseModel):
    """Request body accepted by the word ground-truth save endpoint.

    All three fields are copied unchanged into the ``words`` entry of the
    session's ``ground_truth`` record.
    """

    # Verdict flag; required, stored as ``is_correct``.
    is_correct: bool
    # Optional list of replacement entries; stored as given (no schema is
    # enforced here beyond "list of dicts").
    corrected_entries: Optional[List[Dict[str, Any]]] = None
    # Optional free-text remark.
    notes: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# PaddleOCR Direct Endpoint
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@_local_router.post("/sessions/{session_id}/paddle-direct")
async def paddle_direct(session_id: str):
    """Run PaddleOCR on a stored session image and build a word grid directly.

    The first image found in the order cropped -> dewarped -> original is
    decoded and handed as a whole page to ``ocr_region_paddle``; the returned
    words are arranged into cells by ``build_grid_from_words``.  The result is
    written to the session via ``update_session_db`` (together with the image
    bytes as ``cropped_png`` and ``current_step=8``), mirrored into the
    in-memory cache when the session is cached, and appended to the pipeline
    log.

    Args:
        session_id: Identifier of the OCR pipeline session.

    Returns:
        A dict containing ``session_id`` plus every key of the word result
        (cells, grid shape, columns, layout, image size, timing, summary).

    Raises:
        HTTPException: 404 if the session has no image at all; 400 if the
            image bytes cannot be decoded or PaddleOCR returns no words.
    """
    # Prefer the most processed image and fall back step by step.
    img_png = None
    for image_kind in ("cropped", "dewarped", "original"):
        img_png = await get_session_image(session_id, image_kind)
        if img_png:
            break
    if not img_png:
        raise HTTPException(status_code=404, detail="No image found for this session")

    img_bgr = cv2.imdecode(np.frombuffer(img_png, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Failed to decode original image")

    img_h, img_w = img_bgr.shape[:2]

    # Imported at call time, exactly as before, so the OCR engine module is
    # only loaded when this endpoint is actually used.
    from cv_ocr_engines import ocr_region_paddle

    t0 = time.time()
    word_dicts = await ocr_region_paddle(img_bgr, region=None)
    if not word_dicts:
        raise HTTPException(status_code=400, detail="PaddleOCR returned no words")

    cells, columns_meta = build_grid_from_words(word_dicts, img_w, img_h)
    duration = time.time() - t0

    for cell in cells:
        cell["ocr_engine"] = "paddle_direct"

    n_rows = len({c["row_index"] for c in cells}) if cells else 0
    n_cols = len(columns_meta)
    col_types = {c.get("type") for c in columns_meta}
    is_vocab = bool(col_types & {"column_en", "column_de"})

    word_result = {
        "cells": cells,
        "grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": "paddle_direct",
        "grid_method": "paddle_direct",
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            # Confidence 0 is excluded from the low-confidence count.
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
        },
    }

    await update_session_db(
        session_id,
        word_result=word_result,
        cropped_png=img_png,
        current_step=8,
    )

    # Keep the in-memory session cache in step with the DB (the same pattern
    # save_word_ground_truth uses); otherwise cached readers would keep
    # serving the previous word result.
    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    logger.info(
        "paddle_direct session %s: %d cells (%d rows, %d cols) in %.2fs",
        session_id, len(cells), n_rows, n_cols, duration,
    )

    await _append_pipeline_log(session_id, "paddle_direct", {
        "total_cells": len(cells),
        "non_empty_cells": word_result["summary"]["non_empty_cells"],
        "ocr_engine": "paddle_direct",
    }, duration_ms=int(duration * 1000))

    return {"session_id": session_id, **word_result}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Ground Truth Words Endpoints
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@_local_router.post("/sessions/{session_id}/ground-truth/words")
async def save_word_ground_truth(session_id: str, req: WordGroundTruthRequest):
    """Save ground truth feedback for the word recognition step.

    The feedback is stored under the ``words`` key of the session's
    ``ground_truth`` record together with a UTC timestamp and a snapshot of
    the session's current ``word_result``.  Other ground-truth entries of the
    session are left untouched.

    Args:
        session_id: Identifier of the OCR pipeline session.
        req: Feedback payload (verdict, optional corrections, optional notes).

    Returns:
        A dict with ``session_id`` and the stored ``ground_truth`` entry.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    # Local import: only ``datetime`` is imported at module level.
    from datetime import timezone

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}

    # datetime.utcnow() is deprecated since Python 3.12.  Take an aware UTC
    # timestamp and drop tzinfo so the stored ISO string keeps the exact
    # naive format that existing records use.
    saved_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    gt = {
        "is_correct": req.is_correct,
        "corrected_entries": req.corrected_entries,
        "notes": req.notes,
        "saved_at": saved_at,
        "word_result": session.get("word_result"),
    }
    ground_truth["words"] = gt

    await update_session_db(session_id, ground_truth=ground_truth)

    # Mirror the change into the in-memory cache when the session is cached.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    return {"session_id": session_id, "ground_truth": gt}
|
|
||||||
|
|
||||||
|
|
||||||
@_local_router.get("/sessions/{session_id}/ground-truth/words")
async def get_word_ground_truth(session_id: str):
    """Return the stored word ground truth next to the automatic result.

    Args:
        session_id: Identifier of the OCR pipeline session.

    Returns:
        A dict with ``session_id``, the saved entry as ``words_gt`` and the
        session's current ``word_result`` as ``words_auto``.

    Raises:
        HTTPException: 404 if the session does not exist or no word ground
            truth has been saved for it.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # A missing or empty ground_truth record is treated like "nothing saved".
    saved_words = (session.get("ground_truth") or {}).get("words")
    if not saved_words:
        raise HTTPException(status_code=404, detail="No word ground truth saved")

    response = {"session_id": session_id}
    response["words_gt"] = saved_words
    response["words_auto"] = session.get("word_result")
    return response
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Composite router
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
# Public router of this module: the detect endpoints first, then the
# endpoints defined locally above (PaddleOCR direct and ground truth).
router = APIRouter()
router.include_router(_detect_router)
router.include_router(_local_router)
|
|
||||||
|
|||||||
@@ -1,393 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/words_detect.py
|
||||||
OCR Pipeline Words Detect — main word detection endpoint (Step 7).
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Extracted from ocr_pipeline_words.py. Contains the ``detect_words``
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.words_detect")
|
||||||
endpoint which handles both v2 and words_first grid methods.
|
|
||||||
|
|
||||||
Lizenz: Apache 2.0
|
|
||||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from typing import Any, Dict, List
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from fastapi import APIRouter, HTTPException, Request
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
|
|
||||||
from cv_vocab_pipeline import (
|
|
||||||
PageRegion,
|
|
||||||
RowGeometry,
|
|
||||||
_cells_to_vocab_entries,
|
|
||||||
_fix_phonetic_brackets,
|
|
||||||
fix_cell_phonetics,
|
|
||||||
build_cell_grid_v2,
|
|
||||||
create_ocr_image,
|
|
||||||
detect_column_geometry,
|
|
||||||
)
|
|
||||||
from cv_words_first import build_grid_from_words
|
|
||||||
from ocr_pipeline_session_store import (
|
|
||||||
get_session_db,
|
|
||||||
update_session_db,
|
|
||||||
)
|
|
||||||
from ocr_pipeline_common import (
|
|
||||||
_cache,
|
|
||||||
_load_session_to_cache,
|
|
||||||
_get_cached,
|
|
||||||
_append_pipeline_log,
|
|
||||||
)
|
|
||||||
from ocr_pipeline_words_stream import (
|
|
||||||
_word_batch_stream_generator,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Word Detection Endpoint (Step 7)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/words")
async def detect_words(
    session_id: str,
    request: Request,
    engine: str = "auto",
    pronunciation: str = "british",
    stream: bool = False,
    skip_heal_gaps: bool = False,
    grid_method: str = "v2",
):
    """Build word grid from columns x rows, OCR each cell.

    Query params:
        engine: 'auto' (default), 'tesseract', 'rapid', or 'paddle'
        pronunciation: 'british' (default) or 'american'
        stream: false (default) for JSON response, true for SSE streaming
        skip_heal_gaps: false (default). When true, cells keep exact row geometry.
        grid_method: 'v2' (default) or 'words_first'

    Row detection is only mandatory for the 'v2' grid method; with
    'words_first' a session without detected rows is processed with an empty
    row list instead of failing.

    Raises:
        HTTPException: 400 if neither a cropped nor a dewarped image is
            cached, or if rows are missing for the 'v2' grid method; 404 if
            the session does not exist.
    """
    # PaddleOCR is full-page remote OCR -> force words_first grid method
    if engine == "paddle" and grid_method != "words_first":
        logger.info("detect_words: engine=paddle requires words_first, overriding grid_method=%s", grid_method)
        grid_method = "words_first"

    if session_id not in _cache:
        logger.info("detect_words: session %s not in cache, loading from DB", session_id)
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    # Prefer the cropped image; fall back to the dewarped one.
    cropped_bgr = cached.get("cropped_bgr")
    dewarped_bgr = cropped_bgr if cropped_bgr is not None else cached.get("dewarped_bgr")
    if dewarped_bgr is None:
        logger.warning("detect_words: no cropped/dewarped image for session %s (cache keys: %s)",
                       session_id, [k for k in cached if k.endswith('_bgr')])
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before word detection")

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    column_result = session.get("column_result")
    row_result = session.get("row_result")
    if not column_result or not column_result.get("columns"):
        # No column detection available: treat the whole page as one text column.
        img_h_tmp, img_w_tmp = dewarped_bgr.shape[:2]
        column_result = {
            "columns": [{
                "type": "column_text",
                "x": 0, "y": 0,
                "width": img_w_tmp, "height": img_h_tmp,
                "classification_confidence": 1.0,
                "classification_method": "full_page_fallback",
            }],
            "zones": [],
            "duration_seconds": 0,
        }
        logger.info("detect_words: no column_result -- using full-page pseudo-column %dx%d", img_w_tmp, img_h_tmp)

    # Rows are optional for words_first.  Normalise to a list so that the
    # conversion below cannot fail on a missing/empty row_result (previously
    # ``row_result["rows"]`` raised for words_first sessions without rows).
    rows = (row_result or {}).get("rows") or []
    if grid_method != "words_first" and not rows:
        raise HTTPException(status_code=400, detail="Row detection must be completed first")

    # Convert column dicts back to PageRegion objects
    col_regions = [
        PageRegion(
            type=c["type"],
            x=c["x"], y=c["y"],
            width=c["width"], height=c["height"],
            classification_confidence=c.get("classification_confidence", 1.0),
            classification_method=c.get("classification_method", ""),
        )
        for c in column_result["columns"]
    ]

    # Convert row dicts back to RowGeometry objects
    row_geoms = [
        RowGeometry(
            index=r["index"],
            x=r["x"], y=r["y"],
            width=r["width"], height=r["height"],
            word_count=r.get("word_count", 0),
            words=[],
            row_type=r.get("row_type", "content"),
            gap_before=r.get("gap_before", 0),
        )
        for r in rows
    ]

    # Populate word counts from cached words (computing and caching them once
    # if this session has none yet).
    word_dicts = cached.get("_word_dicts")
    if word_dicts is None:
        ocr_img_tmp = create_ocr_image(dewarped_bgr)
        geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr)
        if geo_result is not None:
            _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
            cached["_word_dicts"] = word_dicts
            cached["_inv"] = inv
            cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)

    if word_dicts:
        content_bounds = cached.get("_content_bounds")
        if content_bounds:
            _lx, _rx, top_y, _by = content_bounds
        else:
            top_y = min(r.y for r in row_geoms) if row_geoms else 0

        for row in row_geoms:
            # Row coordinates are shifted by top_y before being compared with
            # the word positions; a word belongs to the row that contains its
            # vertical centre.
            row_y_rel = row.y - top_y
            row_bottom_rel = row_y_rel + row.height
            row.words = [
                w for w in word_dicts
                if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
            ]
            row.word_count = len(row.words)

    # Exclude rows that fall within box zones
    zones = column_result.get("zones") or []
    box_ranges_inner = []
    for zone in zones:
        if zone.get("zone_type") == "box" and zone.get("box"):
            box = zone["box"]
            # Shrink the box by its border (at least 5 px) on top and bottom.
            bt = max(box.get("border_thickness", 0), 5)
            box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt))

    if box_ranges_inner:
        before_count = len(row_geoms)
        row_geoms = [
            r for r in row_geoms
            if not any(by_s <= r.y + r.height / 2 < by_e for by_s, by_e in box_ranges_inner)
        ]
        excluded = before_count - len(row_geoms)
        if excluded:
            logger.info("detect_words: excluded %d rows inside box zones", excluded)

    # --- Words-First path ---
    if grid_method == "words_first":
        return await _words_first_path(
            session_id, cached, dewarped_bgr, engine, pronunciation, zones,
        )

    if stream:
        return StreamingResponse(
            _word_batch_stream_generator(
                session_id, cached, col_regions, row_geoms,
                dewarped_bgr, engine, pronunciation, request,
                skip_heal_gaps=skip_heal_gaps,
            ),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    # --- Non-streaming path (grid_method=v2) ---
    return await _v2_path(
        session_id, cached, col_regions, row_geoms,
        dewarped_bgr, engine, pronunciation, skip_heal_gaps,
    )
|
|
||||||
|
|
||||||
|
|
||||||
async def _words_first_path(
    session_id: str,
    cached: Dict[str, Any],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    zones: list,
) -> dict:
    """Build the word grid bottom-up from detected words and persist it.

    Words come from ``ocr_region_paddle`` when ``engine`` is 'paddle',
    otherwise from the session cache (computed via ``detect_column_geometry``
    and cached on first use).  The finished result is stored with
    ``update_session_db`` (``current_step=8``), placed in ``cached`` and
    appended to the pipeline log.

    Returns:
        A dict with ``session_id`` plus every key of the word result.

    Raises:
        HTTPException: 400 if no words are available.
    """
    t0 = time.time()
    img_h, img_w = dewarped_bgr.shape[:2]
    use_paddle = engine == "paddle"

    if use_paddle:
        from cv_ocr_engines import ocr_region_paddle

        words = await ocr_region_paddle(dewarped_bgr, region=None)
        cached["_paddle_word_dicts"] = words
    else:
        words = cached.get("_word_dicts")
        if words is None:
            geometry = detect_column_geometry(create_ocr_image(dewarped_bgr), dewarped_bgr)
            if geometry is not None:
                _geoms, left_x, right_x, top_y, bottom_y, words, inv = geometry
                cached["_word_dicts"] = words
                cached["_inv"] = inv
                cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)

    if not words:
        raise HTTPException(status_code=400, detail="No words detected -- cannot build words-first grid")

    # Non-paddle words are shifted by the cached content-bounds origin; the
    # cached dicts themselves are left untouched (copies are made).
    if not use_paddle:
        bounds = cached.get("_content_bounds")
        if bounds:
            off_x, _rx, off_y, _by = bounds
            words = [
                {**w, 'left': w['left'] + off_x, 'top': w['top'] + off_y}
                for w in words
            ]

    box_rects = [
        zone["box"]
        for zone in zones
        if zone.get("zone_type") == "box" and zone.get("box")
    ]

    cells, columns_meta = build_grid_from_words(
        words, img_w, img_h, box_rects=box_rects or None,
    )
    # Timing covers word acquisition and grid building only.
    duration = time.time() - t0

    fix_cell_phonetics(cells, pronunciation=pronunciation)
    for cell in cells:
        cell.setdefault("zone_index", 0)

    col_types = {meta['type'] for meta in columns_meta}
    is_vocab = bool(col_types & {'column_en', 'column_de'})
    n_rows = len({cell['row_index'] for cell in cells}) if cells else 0
    n_cols = len(columns_meta)
    used_engine = "paddle" if use_paddle else "words_first"

    summary = {
        "total_cells": len(cells),
        "non_empty_cells": sum(1 for cell in cells if cell.get("text")),
        "low_confidence": sum(1 for cell in cells if 0 < cell.get("confidence", 0) < 50),
    }
    word_result = {
        "cells": cells,
        "grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "grid_method": "words_first",
        "summary": summary,
    }

    # Vocabulary entries are derived for vocab layouts and plain text columns.
    if is_vocab or 'column_text' in col_types:
        entries = _fix_phonetic_brackets(
            _cells_to_vocab_entries(cells, columns_meta),
            pronunciation=pronunciation,
        )
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        summary["total_entries"] = len(entries)
        summary["with_english"] = sum(1 for e in entries if e.get("english"))
        summary["with_german"] = sum(1 for e in entries if e.get("german"))

    await update_session_db(session_id, word_result=word_result, current_step=8)
    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline: words-first session {session_id}: "
                f"{len(cells)} cells ({duration:.2f}s), {n_rows} rows, {n_cols} cols")

    await _append_pipeline_log(session_id, "words", {
        "grid_method": "words_first",
        "total_cells": len(cells),
        "non_empty_cells": summary["non_empty_cells"],
        "ocr_engine": used_engine,
        "layout": word_result["layout"],
    }, duration_ms=int(duration * 1000))

    return {"session_id": session_id, **word_result}
|
|
||||||
|
|
||||||
|
|
||||||
async def _v2_path(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    skip_heal_gaps: bool,
) -> dict:
    """Cell-First OCR v2 non-streaming path.

    Runs ``build_cell_grid_v2`` for the given columns and rows, applies the
    phonetic fixes, stores the result with ``update_session_db``
    (``current_step=8``), places it in ``cached`` and appends a pipeline log
    entry.

    The grid build is a blocking call, so it is executed in the event loop's
    default executor; running it inline would stall every other request for
    the duration of the OCR.

    Returns:
        A dict with ``session_id`` plus every key of the word result.
    """
    import asyncio

    t0 = time.time()
    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    loop = asyncio.get_running_loop()
    cells, columns_meta = await loop.run_in_executor(
        None,
        lambda: build_cell_grid_v2(
            ocr_img, col_regions, row_geoms, img_w, img_h,
            ocr_engine=engine, img_bgr=dewarped_bgr,
            skip_heal_gaps=skip_heal_gaps,
        ),
    )
    duration = time.time() - t0

    for cell in cells:
        cell.setdefault("zone_index", 0)

    col_types = {c['type'] for c in columns_meta}
    is_vocab = bool(col_types & {'column_en', 'column_de'})
    n_content_rows = sum(1 for r in row_geoms if r.row_type == 'content')
    n_cols = len(columns_meta)
    # Report the engine recorded on the first cell; fall back to the
    # requested engine when no cells were produced.
    used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine

    fix_cell_phonetics(cells, pronunciation=pronunciation)

    word_result = {
        "cells": cells,
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
        },
    }

    # Vocabulary entries are derived for vocab layouts and plain text columns.
    has_text_col = 'column_text' in col_types
    if is_vocab or has_text_col:
        entries = _cells_to_vocab_entries(cells, columns_meta)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        word_result["summary"]["total_entries"] = len(entries)
        word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
        word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))

    await update_session_db(session_id, word_result=word_result, current_step=8)
    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline: words session {session_id}: "
                f"layout={word_result['layout']}, "
                f"{len(cells)} cells ({duration:.2f}s), summary: {word_result['summary']}")

    await _append_pipeline_log(session_id, "words", {
        "total_cells": len(cells),
        "non_empty_cells": word_result["summary"]["non_empty_cells"],
        "low_confidence_count": word_result["summary"]["low_confidence"],
        "ocr_engine": used_engine,
        "layout": word_result["layout"],
        "entry_count": word_result.get("entry_count", 0),
    }, duration_ms=int(duration * 1000))

    return {"session_id": session_id, **word_result}
|
|
||||||
|
|||||||
@@ -1,303 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/words_stream.py
|
||||||
OCR Pipeline Words Stream — SSE streaming generators for word detection.
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Extracted from ocr_pipeline_words.py.
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.words_stream")
|
||||||
|
|
||||||
Lizenz: Apache 2.0
|
|
||||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from typing import Any, Dict, List
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from fastapi import Request
|
|
||||||
|
|
||||||
from cv_vocab_pipeline import (
|
|
||||||
PageRegion,
|
|
||||||
RowGeometry,
|
|
||||||
_cells_to_vocab_entries,
|
|
||||||
_fix_character_confusion,
|
|
||||||
_fix_phonetic_brackets,
|
|
||||||
fix_cell_phonetics,
|
|
||||||
build_cell_grid_v2,
|
|
||||||
build_cell_grid_v2_streaming,
|
|
||||||
create_ocr_image,
|
|
||||||
)
|
|
||||||
from ocr_pipeline_session_store import update_session_db
|
|
||||||
from ocr_pipeline_common import _cache
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
async def _word_batch_stream_generator(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    request: Request,
    skip_heal_gaps: bool = False,
):
    """SSE generator that runs batch OCR in an executor, then streams results.

    Event sequence: ``meta`` -> ``preparing`` -> zero or more ``keepalive``
    (one per 5 s while OCR is running) -> optional ``columns`` -> one ``cell``
    per cell -> ``complete``.  Before ``complete`` is sent the result is
    stored with ``update_session_db`` (``current_step=8``) and placed in
    ``cached``.

    The generator ends early, without persisting anything, when the client
    is found to be disconnected during or right after the OCR run.

    Yields:
        SSE-formatted strings (``data: <json>`` followed by a blank line).
    """
    import asyncio

    t0 = time.time()
    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
    n_content_rows = sum(1 for r in row_geoms if r.row_type == 'content')
    col_types = {c.type for c in col_regions if c.type not in _skip_types}
    n_cols = sum(1 for c in col_regions if c.type not in _skip_types)
    is_vocab = bool(col_types & {'column_en', 'column_de'})
    total_cells = n_content_rows * n_cols

    # 1. Send meta event immediately (expected grid size, before any OCR).
    meta_event = {
        "type": "meta",
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
        "layout": "vocab" if is_vocab else "generic",
    }
    yield f"data: {json.dumps(meta_event)}\n\n"

    # 2. Send preparing event (keepalive for proxy)
    yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR laeuft parallel...'})}\n\n"

    # 3. Run batch OCR in the default executor with periodic keepalive events.
    loop = asyncio.get_running_loop()
    ocr_future = loop.run_in_executor(
        None,
        lambda: build_cell_grid_v2(
            ocr_img, col_regions, row_geoms, img_w, img_h,
            ocr_engine=engine, img_bgr=dewarped_bgr,
            skip_heal_gaps=skip_heal_gaps,
        ),
    )

    # Wait in 5-second slices.  shield() keeps a timeout from cancelling the
    # OCR future itself; an already finished future returns immediately.
    while True:
        try:
            cells, columns_meta = await asyncio.wait_for(
                asyncio.shield(ocr_future), timeout=5.0,
            )
            break  # OCR finished
        except asyncio.TimeoutError:
            elapsed = int(time.time() - t0)
            keepalive_event = {
                'type': 'keepalive',
                'elapsed': elapsed,
                'message': f'OCR laeuft... ({elapsed}s)',
            }
            yield f"data: {json.dumps(keepalive_event)}\n\n"
            if await request.is_disconnected():
                logger.info(f"SSE batch: client disconnected during OCR for {session_id}")
                ocr_future.cancel()
                return

    if await request.is_disconnected():
        logger.info(f"SSE batch: client disconnected after OCR for {session_id}")
        return

    # 4. Apply IPA phonetic fixes
    fix_cell_phonetics(cells, pronunciation=pronunciation)

    # 5. Send columns meta
    if columns_meta:
        yield f"data: {json.dumps({'type': 'columns', 'columns_used': columns_meta})}\n\n"

    # 6. Stream all cells
    for idx, cell in enumerate(cells):
        cell_event = {
            "type": "cell",
            "cell": cell,
            "progress": {"current": idx + 1, "total": len(cells)},
        }
        yield f"data: {json.dumps(cell_event)}\n\n"

    # 7. Build final result and persist
    duration = time.time() - t0
    # Report the engine recorded on the first cell; fall back to the
    # requested engine when no cells were produced.
    used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine

    word_result = {
        "cells": cells,
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
        },
    }

    vocab_entries = None
    has_text_col = 'column_text' in col_types
    if is_vocab or has_text_col:
        entries = _cells_to_vocab_entries(cells, columns_meta)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        word_result["summary"]["total_entries"] = len(entries)
        word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
        word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
        vocab_entries = entries

    await update_session_db(session_id, word_result=word_result, current_step=8)
    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline SSE batch: words session {session_id}: "
                f"layout={word_result['layout']}, {len(cells)} cells ({duration:.2f}s)")

    # 8. Send complete event
    complete_event = {
        "type": "complete",
        "summary": word_result["summary"],
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
    }
    if vocab_entries is not None:
        complete_event["vocab_entries"] = vocab_entries
    yield f"data: {json.dumps(complete_event)}\n\n"
|
|
||||||
|
|
||||||
|
|
||||||
async def _word_stream_generator(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    request: Request,
):
    """SSE generator that yields cell-by-cell OCR progress.

    Events are emitted in this order:

    1. ``meta``      - expected grid shape and layout, computed up front.
    2. ``preparing`` - status message sent before OCR starts.
    3. ``columns``   - sent once, with the column metadata reported by the
       cell builder alongside the first cell.
    4. ``cell``      - one per cell, with a running progress counter.
    5. ``complete``  - summary, duration, engine and (if built) vocab entries.

    Before ``complete`` is sent, the assembled word result is persisted with
    ``update_session_db(..., current_step=8)`` and stored in
    ``cached["word_result"]``. If the client disconnects while cells are being
    streamed, the generator returns immediately and nothing is persisted.

    Args:
        session_id: Session whose word result is updated.
        cached: Cache entry of the session; receives ``word_result``.
        col_regions: Detected column regions.
        row_geoms: Detected row geometries.
        dewarped_bgr: Image the OCR runs on.
        engine: Requested OCR engine name.
        pronunciation: Passed through to the phonetics fix-up helpers.
        request: Incoming request, polled for client disconnects.
    """
    t0 = time.time()

    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    # Region types that never produce OCR cells.
    _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
    usable_cols = [c for c in col_regions if c.type not in _skip_types]

    n_content_rows = sum(1 for r in row_geoms if r.row_type == 'content')
    n_cols = len(usable_cols)

    col_types = {c.type for c in usable_cols}
    is_vocab = bool(col_types & {'column_en', 'column_de'})

    columns_meta = None
    total_cells = n_content_rows * n_cols

    meta_event = {
        "type": "meta",
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
        "layout": "vocab" if is_vocab else "generic",
    }
    yield f"data: {json.dumps(meta_event)}\n\n"

    yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR wird initialisiert...'})}\n\n"

    all_cells: List[Dict[str, Any]] = []
    cell_idx = 0

    for cell, cols_meta, total in build_cell_grid_v2_streaming(
        ocr_img, col_regions, row_geoms, img_w, img_h,
        ocr_engine=engine, img_bgr=dewarped_bgr,
    ):
        if await request.is_disconnected():
            logger.info(f"SSE: client disconnected during streaming for {session_id}")
            return

        # Column metadata arrives with every cell; forward it only once.
        if columns_meta is None:
            columns_meta = cols_meta
            meta_update = {"type": "columns", "columns_used": cols_meta}
            yield f"data: {json.dumps(meta_update)}\n\n"

        all_cells.append(cell)
        cell_idx += 1

        cell_event = {
            "type": "cell",
            "cell": cell,
            "progress": {"current": cell_idx, "total": total},
        }
        yield f"data: {json.dumps(cell_event)}\n\n"

    # All cells done
    duration = time.time() - t0
    if columns_meta is None:
        columns_meta = []

    # Remove all-empty rows: keep only cells whose row has text somewhere.
    rows_with_text = {c["row_index"] for c in all_cells if c.get("text", "").strip()}
    before_filter = len(all_cells)
    all_cells = [c for c in all_cells if c["row_index"] in rows_with_text]
    empty_rows_removed = (before_filter - len(all_cells)) // max(n_cols, 1)
    if empty_rows_removed > 0:
        logger.info(f"SSE: removed {empty_rows_removed} all-empty rows after OCR")

    used_engine = all_cells[0].get("ocr_engine", "tesseract") if all_cells else engine

    fix_cell_phonetics(all_cells, pronunciation=pronunciation)

    word_result = {
        "cells": all_cells,
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(all_cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": {
            "total_cells": len(all_cells),
            "non_empty_cells": sum(1 for c in all_cells if c.get("text")),
            "low_confidence": sum(1 for c in all_cells if 0 < c.get("confidence", 0) < 50),
        },
    }

    vocab_entries = None
    has_text_col = 'column_text' in col_types
    if is_vocab or has_text_col:
        entries = _cells_to_vocab_entries(all_cells, columns_meta)
        entries = _fix_character_confusion(entries)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        word_result["summary"]["total_entries"] = len(entries)
        word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
        word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
        vocab_entries = entries

    await update_session_db(session_id, word_result=word_result, current_step=8)
    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline SSE: words session {session_id}: "
                f"layout={word_result['layout']}, "
                f"{len(all_cells)} cells ({duration:.2f}s)")

    complete_event = {
        "type": "complete",
        "summary": word_result["summary"],
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
    }
    if vocab_entries is not None:
        complete_event["vocab_entries"] = vocab_entries
    yield f"data: {json.dumps(complete_event)}\n\n"
|
|
||||||
|
|||||||
@@ -1,188 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/orientation_api.py
|
||||||
Orientation & Page-Split API endpoints (Steps 1 and 1b of OCR Pipeline).
|
import importlib as _importlib
|
||||||
"""
|
import sys as _sys
|
||||||
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.orientation_api")
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from typing import Any, Dict
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
from fastapi import APIRouter, HTTPException
|
|
||||||
|
|
||||||
from cv_vocab_pipeline import detect_and_fix_orientation
|
|
||||||
from page_crop import detect_page_splits
|
|
||||||
from ocr_pipeline_session_store import update_session_db
|
|
||||||
|
|
||||||
from orientation_crop_helpers import ensure_cached, append_pipeline_log
|
|
||||||
from page_sub_sessions import create_page_sub_sessions_full
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Step 1: Orientation
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/orientation")
async def detect_orientation(session_id: str):
    """Detect and fix 90/180/270 degree rotations from scanners.

    Reads the cached original image, runs ``detect_and_fix_orientation`` on a
    copy of it, stores the corrected image in the session cache and persists
    it (PNG-encoded) together with the orientation result, advancing the
    session to step 2.

    Raises:
        HTTPException: 400 if the session has no original image.

    Returns:
        Dict with the session id, the orientation result, the oriented image
        dimensions and the URL of the oriented image.
    """
    cached = await ensure_cached(session_id)

    img_bgr = cached.get("original_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Original image not available")

    t0 = time.time()

    # Detect and fix orientation (on a copy, so the cached original stays intact)
    oriented_bgr, orientation_deg = detect_and_fix_orientation(img_bgr.copy())

    duration = time.time() - t0

    orientation_result = {
        "orientation_degrees": orientation_deg,
        "corrected": orientation_deg != 0,
        "duration_seconds": round(duration, 2),
    }

    # Encode oriented image. An encoder failure used to be swallowed silently;
    # the empty-bytes fallback is kept, but the failure is now logged.
    success, png_buf = cv2.imencode(".png", oriented_bgr)
    if success:
        oriented_png = png_buf.tobytes()
    else:
        logger.error(
            "OCR Pipeline: orientation session %s: PNG encoding of oriented image failed",
            session_id,
        )
        oriented_png = b""

    # Update cache
    cached["oriented_bgr"] = oriented_bgr
    cached["orientation_result"] = orientation_result

    # Persist to DB
    await update_session_db(
        session_id,
        oriented_png=oriented_png,
        orientation_result=orientation_result,
        current_step=2,
    )

    logger.info(
        "OCR Pipeline: orientation session %s: %d° (%s) in %.2fs",
        session_id, orientation_deg,
        "corrected" if orientation_deg else "no change",
        duration,
    )

    await append_pipeline_log(session_id, "orientation", {
        "orientation_degrees": orientation_deg,
        "corrected": orientation_deg != 0,
    }, duration_ms=int(duration * 1000))

    h, w = oriented_bgr.shape[:2]
    return {
        "session_id": session_id,
        **orientation_result,
        "image_width": w,
        "image_height": h,
        "oriented_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/oriented",
    }
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Step 1b: Page-split detection — runs AFTER orientation, BEFORE deskew
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/page-split")
async def detect_page_split(session_id: str):
    """Split a double-page book spread into per-page sub-sessions.

    Intended to run **after orientation** (step 1) and **before deskew**
    (step 2). Detection runs on the oriented image when one is cached and on
    the original otherwise. If no spread is found on the oriented image, the
    original image is tried as well, because orientation correction may have
    rotated a landscape spread to portrait.

    When a spread is found, sub-sessions are created through
    ``create_page_sub_sessions_full`` and the parent session is persisted with
    ``status='split'``.

    Raises:
        HTTPException: 400 if neither an oriented nor an original image exists.

    Returns:
        ``{"session_id", "multi_page": False, "duration_seconds"}`` for a
        single page; otherwise the split info plus image size and the list of
        created sub-sessions.
    """
    def _is_spread(candidate) -> bool:
        # A spread needs at least two page rectangles.
        return bool(candidate) and len(candidate) >= 2

    cached = await ensure_cached(session_id)

    # Oriented image is preferred; the original is the fallback.
    img_bgr = cached.get("oriented_bgr")
    if img_bgr is None:
        img_bgr = cached.get("original_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="No image available for page-split detection")

    t0 = time.time()
    used_original = False
    page_splits = detect_page_splits(img_bgr)

    if not _is_spread(page_splits):
        # Second chance on the pre-orientation image (only if it is a
        # different object than the one already analysed).
        orig_bgr = cached.get("original_bgr")
        if orig_bgr is not None and orig_bgr is not img_bgr:
            page_splits_orig = detect_page_splits(orig_bgr)
            if _is_spread(page_splits_orig):
                logger.info(
                    "OCR Pipeline: page-split session %s: spread detected on "
                    "ORIGINAL (orientation rotated it away)",
                    session_id,
                )
                img_bgr, page_splits, used_original = orig_bgr, page_splits_orig, True

    if not _is_spread(page_splits):
        duration = time.time() - t0
        logger.info(
            "OCR Pipeline: page-split session %s: single page (%.2fs)",
            session_id, duration,
        )
        return {
            "session_id": session_id,
            "multi_page": False,
            "duration_seconds": round(duration, 2),
        }

    # Pages cut from the original still need orientation (start at step 1);
    # pages cut from the oriented image are ready for deskew (step 2).
    if used_original:
        start_step = 1
    else:
        start_step = 2
    sub_sessions = await create_page_sub_sessions_full(
        session_id, cached, img_bgr, page_splits, start_step=start_step,
    )
    duration = time.time() - t0
    page_count = len(page_splits)

    split_info: Dict[str, Any] = {
        "multi_page": True,
        "page_count": page_count,
        "page_splits": page_splits,
        "used_original": used_original,
        "duration_seconds": round(duration, 2),
    }

    # Persist the split marker on the parent session and mirror it in the cache.
    await update_session_db(session_id, crop_result=split_info, status='split')
    cached["crop_result"] = split_info

    await append_pipeline_log(session_id, "page_split", {
        "multi_page": True,
        "page_count": page_count,
    }, duration_ms=int(duration * 1000))

    logger.info(
        "OCR Pipeline: page-split session %s: %d pages detected in %.2fs",
        session_id, page_count, duration,
    )

    height, width = img_bgr.shape[:2]
    return {
        "session_id": session_id,
        **split_info,
        "image_width": width,
        "image_height": height,
        "sub_sessions": sub_sessions,
    }
|
|
||||||
|
|||||||
@@ -1,16 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/orientation_crop_api.py
|
||||||
Orientation & Crop API - Steps 1 and 4 of the OCR Pipeline.
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Barrel re-export: merges routers from orientation_api and crop_api,
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.orientation_crop_api")
|
||||||
and re-exports set_cache_ref for main.py.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from fastapi import APIRouter
|
|
||||||
|
|
||||||
from orientation_crop_helpers import set_cache_ref # noqa: F401
|
|
||||||
from orientation_api import router as _orientation_router
|
|
||||||
from crop_api import router as _crop_router
|
|
||||||
|
|
||||||
router = APIRouter()
|
|
||||||
router.include_router(_orientation_router)
|
|
||||||
router.include_router(_crop_router)
|
|
||||||
|
|||||||
@@ -1,86 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/orientation_crop_helpers.py
|
||||||
Orientation & Crop shared helpers - cache management and pipeline logging.
|
import importlib as _importlib
|
||||||
"""
|
import sys as _sys
|
||||||
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.orientation_crop_helpers")
|
||||||
import logging
|
|
||||||
from typing import Any, Dict
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
from fastapi import HTTPException
|
|
||||||
|
|
||||||
from ocr_pipeline_session_store import (
|
|
||||||
get_session_db,
|
|
||||||
get_session_image,
|
|
||||||
update_session_db,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# Reference to the shared cache from ocr_pipeline_api (set in main.py)
|
|
||||||
_cache: Dict[str, Dict[str, Any]] = {}
|
|
||||||
|
|
||||||
|
|
||||||
def set_cache_ref(cache: Dict[str, Dict[str, Any]]):
    """Rebind the module-level cache to *cache*.

    The given dict is stored by reference (not copied), so later changes made
    through either name are visible through both.
    """
    global _cache
    _cache = cache
|
|
||||||
|
|
||||||
|
|
||||||
def get_cache_ref() -> Dict[str, Dict[str, Any]]:
    """Return the module-level cache object itself (not a copy)."""
    return _cache
|
|
||||||
|
|
||||||
|
|
||||||
async def ensure_cached(session_id: str) -> Dict[str, Any]:
    """Return the cache entry for *session_id*, building it from the DB on a miss.

    On a miss the session row is loaded, every pipeline image stage
    (original, oriented, cropped, deskewed, dewarped) is fetched and decoded,
    and the resulting entry is stored in the module cache. Stages without
    stored image data keep ``None`` under their ``<stage>_bgr`` key.

    Raises:
        HTTPException: 404 if the session is not found in the database.
    """
    if session_id in _cache:
        return _cache[session_id]

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    stages = ("original", "oriented", "cropped", "deskewed", "dewarped")

    # Session fields first, then reset all decoded-image slots to None.
    entry: Dict[str, Any] = {"id": session_id, **session}
    entry.update({f"{stage}_bgr": None for stage in stages})

    for stage in stages:
        png_data = await get_session_image(session_id, stage)
        if not png_data:
            continue
        raw = np.frombuffer(png_data, dtype=np.uint8)
        entry[f"{stage}_bgr"] = cv2.imdecode(raw, cv2.IMREAD_COLOR)

    _cache[session_id] = entry
    return entry
|
|
||||||
|
|
||||||
|
|
||||||
async def append_pipeline_log(session_id: str, step: str, metrics: dict, duration_ms: int):
    """Append a step entry to the session's pipeline log and persist it.

    Does nothing if the session is not found. A missing or empty
    ``pipeline_log``, or one without a ``"steps"`` key, is treated as an
    empty log.

    Args:
        session_id: Session whose log is extended.
        step: Name of the completed pipeline step.
        metrics: Step-specific metrics stored verbatim in the entry.
        duration_ms: Step duration in milliseconds.
    """
    from datetime import datetime, timezone

    session = await get_session_db(session_id)
    if not session:
        return

    pipeline_log = session.get("pipeline_log") or {}

    # datetime.utcnow() is deprecated since Python 3.12. Dropping tzinfo from
    # an aware UTC timestamp yields the same naive ISO string as before.
    completed_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    # setdefault tolerates an existing log dict that lacks the "steps" list.
    pipeline_log.setdefault("steps", []).append({
        "step": step,
        "completed_at": completed_at,
        "success": True,
        "duration_ms": duration_ms,
        "metrics": metrics,
    })
    await update_session_db(session_id, pipeline_log=pipeline_log)
|
|
||||||
|
|||||||
@@ -1,33 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/page_crop.py
|
||||||
Page Crop — Barrel Re-export
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Content-based crop for scanned pages and book scans.
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.page_crop")
|
||||||
|
|
||||||
Split into:
|
|
||||||
- page_crop_edges.py — Edge detection (spine shadow, gutter, projection)
|
|
||||||
- page_crop_core.py — Main crop algorithm and format detection
|
|
||||||
|
|
||||||
All public names are re-exported here for backward compatibility.
|
|
||||||
License: Apache 2.0
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Core: main crop functions and format detection
|
|
||||||
from page_crop_core import ( # noqa: F401
|
|
||||||
PAPER_FORMATS,
|
|
||||||
detect_page_splits,
|
|
||||||
detect_and_crop_page,
|
|
||||||
_detect_format,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Edge detection helpers
|
|
||||||
from page_crop_edges import ( # noqa: F401
|
|
||||||
_INK_THRESHOLD,
|
|
||||||
_MIN_RUN_FRAC,
|
|
||||||
_detect_spine_shadow,
|
|
||||||
_detect_gutter_continuity,
|
|
||||||
_detect_left_edge_shadow,
|
|
||||||
_detect_right_edge_shadow,
|
|
||||||
_detect_top_bottom_edges,
|
|
||||||
_detect_edge_projection,
|
|
||||||
_filter_narrow_runs,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -1,342 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/page_crop_core.py
|
||||||
Page Crop - Core Crop and Format Detection
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Content-based crop for scanned pages and book scans. Detects the content
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.page_crop_core")
|
||||||
boundary by analysing ink density projections and (for book scans) the
|
|
||||||
spine shadow gradient.
|
|
||||||
|
|
||||||
Extracted from page_crop.py to keep files under 500 LOC.
|
|
||||||
License: Apache 2.0
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Dict, Any, Tuple
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from page_crop_edges import (
|
|
||||||
_detect_left_edge_shadow,
|
|
||||||
_detect_right_edge_shadow,
|
|
||||||
_detect_top_bottom_edges,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Aspect ratios (long side / short side, i.e. height / width in portrait)
# of the paper formats that format detection compares against.
PAPER_FORMATS = {
    "A4": 297.0 / 210.0,      # ~1.4143
    "A5": 210.0 / 148.0,      # ~1.4189
    "Letter": 11.0 / 8.5,     # ~1.2941
    "Legal": 14.0 / 8.5,      # ~1.6471
    "A3": 420.0 / 297.0,      # ~1.4141
}
|
|
||||||
|
|
||||||
|
|
||||||
def detect_page_splits(
    img_bgr: np.ndarray,
) -> list:
    """Find the spine of a double-page spread and return one rectangle per page.

    Detection works on **brightness** rather than ink density: column-mean
    brightness is smoothed heavily, and a dark valley is searched for in the
    central 30-70% of the width. Candidate valleys are scored by how centered,
    how dark and how narrow they are. The best one is accepted only if the
    smoothed brightness on both sides of it reaches the spine threshold.

    Returns:
        A list of page dicts ``{x, y, width, height, page_index}``, or an
        empty list when the image is not treated as a spread (not
        landscape-ish, too dark, no qualifying valley, or a resulting page
        narrower than 15% of the width).
    """
    h, w = img_bgr.shape[:2]

    # Only landscape-ish images (width >= height * 1.15) can be spreads.
    if w < h * 1.15:
        return []

    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

    # Per-column mean brightness (0-255).
    col_brightness = np.mean(gray, axis=0).astype(np.float64)

    # Odd-sized box filter, wide enough to wash out individual text lines.
    kern = max(11, w // 50) | 1
    box = np.ones(kern) / kern
    brightness_smooth = np.convolve(col_brightness, box, mode="same")

    # The brightest smoothed column is taken as the paper brightness.
    page_brightness = float(np.max(brightness_smooth))
    if page_brightness < 100:
        return []  # Very dark image, skip

    # Columns below this value count as "dark" (spine candidates).
    spine_thresh = page_brightness * 0.88

    # Search window: central 30-70% of the width.
    center_lo = int(w * 0.30)
    center_hi = int(w * 0.70)
    center_brightness = brightness_smooth[center_lo:center_hi]
    darkest_val = float(np.min(center_brightness))

    if darkest_val >= spine_thresh:
        logger.debug("No spine detected: min brightness %.0f >= threshold %.0f",
                     darkest_val, spine_thresh)
        return []

    # Collect every contiguous dark run as a half-open (start, end) pair,
    # in coordinates relative to the search window.
    is_dark = center_brightness < spine_thresh
    dark_runs: list = []
    open_at = None
    for pos, dark in enumerate(is_dark):
        if dark:
            if open_at is None:
                open_at = pos
        elif open_at is not None:
            dark_runs.append((open_at, pos))
            open_at = None
    if open_at is not None:
        dark_runs.append((open_at, len(is_dark)))

    # Drop runs narrower than 1% of the image width.
    min_spine_px = int(w * 0.01)
    dark_runs = [run for run in dark_runs if run[1] - run[0] >= min_spine_px]

    if not dark_runs:
        logger.debug("No dark runs wider than %dpx in center region", min_spine_px)
        return []

    # Score each run: centered, dark and narrow valleys win.
    center_region_len = center_hi - center_lo
    image_center_in_region = (w * 0.5 - center_lo)
    sigma = center_region_len * 0.15
    best_score = -1.0
    best_start, best_end = dark_runs[0]

    for run_lo, run_hi in dark_runs:
        run_width = run_hi - run_lo
        run_center = (run_lo + run_hi) / 2.0

        # Gaussian falloff with distance from the image center.
        dist = abs(run_center - image_center_in_region)
        center_factor = float(np.exp(-0.5 * (dist / sigma) ** 2))

        # Relative depth of the valley below the threshold.
        run_brightness = float(np.mean(center_brightness[run_lo:run_hi]))
        darkness_factor = max(0.0, (spine_thresh - run_brightness) / spine_thresh)

        # Full bonus up to 5% width, fading linearly to zero at 15%.
        width_frac = run_width / w
        if width_frac > 0.15:
            narrowness_bonus = 0.0
        elif width_frac > 0.05:
            narrowness_bonus = 1.0 - (width_frac - 0.05) / 0.10
        else:
            narrowness_bonus = 1.0

        score = center_factor * darkness_factor * (0.3 + 0.7 * narrowness_bonus)

        logger.debug(
            "Dark run x=%d..%d (w=%d): center_f=%.3f dark_f=%.3f narrow_b=%.3f -> score=%.4f",
            center_lo + run_lo, center_lo + run_hi, run_width,
            center_factor, darkness_factor, narrowness_bonus, score,
        )

        if score > best_score:
            best_score = score
            best_start, best_end = run_lo, run_hi

    # Translate the winning run back to image coordinates.
    spine_w = best_end - best_start
    spine_x = center_lo + best_start
    spine_center = spine_x + spine_w // 2
    right_end = center_lo + best_end

    logger.debug(
        "Best spine candidate: x=%d..%d (w=%d), score=%.4f",
        spine_x, spine_x + spine_w, spine_w, best_score,
    )

    # Verify: the strips (10% of width) on BOTH sides must be bright paper.
    flank = w // 10
    left_brightness = float(np.mean(brightness_smooth[max(0, spine_x - flank):spine_x]))
    right_brightness = float(np.mean(brightness_smooth[right_end:min(w, right_end + flank)]))

    if left_brightness < spine_thresh or right_brightness < spine_thresh:
        logger.debug("No bright paper flanking spine: left=%.0f right=%.0f thresh=%.0f",
                     left_brightness, right_brightness, spine_thresh)
        return []

    logger.info(
        "Spine detected: x=%d..%d (w=%d), brightness=%.0f vs paper=%.0f, "
        "left_paper=%.0f, right_paper=%.0f",
        spine_x, right_end, spine_w, darkest_val, page_brightness,
        left_brightness, right_brightness,
    )

    # Split at the spine center and build one rectangle per interval.
    split_points = [spine_center]
    bounds = [0, *split_points, w]
    pages: list = [
        {"x": left, "y": 0, "width": right - left, "height": h, "page_index": idx}
        for idx, (left, right) in enumerate(zip(bounds, bounds[1:]))
    ]

    # Discard pages narrower than 15% of the total width.
    pages = [page for page in pages if page["width"] >= w * 0.15]
    if len(pages) < 2:
        return []

    # Re-index the surviving pages consecutively.
    for idx, page in enumerate(pages):
        page["page_index"] = idx

    logger.info(
        "Page split detected: %d pages, spine_w=%d, split_points=%s",
        len(pages), spine_w, split_points,
    )
    return pages
|
|
||||||
|
|
||||||
|
|
||||||
def detect_and_crop_page(
    img_bgr: np.ndarray,
    margin_frac: float = 0.01,
) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Locate the content boundary and crop away scanner/book borders.

    Steps:
        1. Binarise the grayscale image with an inverted adaptive threshold.
        2. Find the left, right, top and bottom content edges via the
           edge-detection helpers.
        3. Skip cropping when every border is below 2% of its dimension, or
           when the cropped area would be less than 40% of the original.
        4. Otherwise crop with ``margin_frac`` extra margin on each side.

    Args:
        img_bgr: Input BGR image.
        margin_frac: Extra margin kept around the content, as a fraction of
            the respective image dimension (default 1%).

    Returns:
        ``(image, result)``. ``image`` is the cropped copy, or the unchanged
        input when no crop was applied. ``result`` reports whether a crop was
        applied, the crop rectangle (pixels and percent), sizes, border
        fractions, aspect ratio and the detected paper format.
    """
    h, w = img_bgr.shape[:2]
    total_area = h * w
    full_size = {"width": w, "height": h}

    result: Dict[str, Any] = {
        "crop_applied": False,
        "crop_rect": None,
        "crop_rect_pct": None,
        "original_size": dict(full_size),
        "cropped_size": dict(full_size),
        "detected_format": None,
        "format_confidence": 0.0,
        "aspect_ratio": round(max(h, w) / max(min(h, w), 1), 4),
        "border_fractions": {"top": 0.0, "bottom": 0.0, "left": 0.0, "right": 0.0},
    }

    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

    # --- Binarise with adaptive threshold (inverted) ---
    binary = cv2.adaptiveThreshold(
        gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV, blockSize=51, C=15,
    )

    # --- Edge detection ---
    left_edge = _detect_left_edge_shadow(gray, binary, w, h)
    right_edge = _detect_right_edge_shadow(gray, binary, w, h)
    top_edge, bottom_edge = _detect_top_bottom_edges(binary, w, h)

    # Border widths as fractions of the matching image dimension.
    border_top = top_edge / h
    border_bottom = (h - bottom_edge) / h
    border_left = left_edge / w
    border_right = (w - right_edge) / w

    result["border_fractions"] = {
        "top": round(border_top, 4),
        "bottom": round(border_bottom, 4),
        "left": round(border_left, 4),
        "right": round(border_right, 4),
    }

    # Sanity: only crop if at least one edge has a border of 2% or more.
    min_border = 0.02
    borders = (border_top, border_bottom, border_left, border_right)
    if all(frac < min_border for frac in borders):
        logger.info("All borders < %.0f%% — no crop needed", min_border * 100)
        result["detected_format"], result["format_confidence"] = _detect_format(w, h)
        return img_bgr, result

    # Widen the content box by the margin, clamped to the image bounds.
    margin_x = int(w * margin_frac)
    margin_y = int(h * margin_frac)

    crop_x = max(0, left_edge - margin_x)
    crop_x2 = min(w, right_edge + margin_x)
    crop_y = max(0, top_edge - margin_y)
    crop_y2 = min(h, bottom_edge + margin_y)

    crop_w = crop_x2 - crop_x
    crop_h = crop_y2 - crop_y

    # Sanity: cropped area must be >= 40% of original.
    if crop_w * crop_h < 0.40 * total_area:
        logger.warning("Cropped area too small (%.0f%%) — skipping crop",
                       100.0 * crop_w * crop_h / total_area)
        result["detected_format"], result["format_confidence"] = _detect_format(w, h)
        return img_bgr, result

    cropped = img_bgr[crop_y:crop_y2, crop_x:crop_x2].copy()

    detected_format, format_confidence = _detect_format(crop_w, crop_h)

    result.update({
        "crop_applied": True,
        "crop_rect": {"x": crop_x, "y": crop_y, "width": crop_w, "height": crop_h},
        "crop_rect_pct": {
            "x": round(100.0 * crop_x / w, 2),
            "y": round(100.0 * crop_y / h, 2),
            "width": round(100.0 * crop_w / w, 2),
            "height": round(100.0 * crop_h / h, 2),
        },
        "cropped_size": {"width": crop_w, "height": crop_h},
        "detected_format": detected_format,
        "format_confidence": format_confidence,
        "aspect_ratio": round(max(crop_w, crop_h) / max(min(crop_w, crop_h), 1), 4),
    })

    logger.info(
        "Page cropped: %dx%d -> %dx%d, format=%s (%.0f%%), "
        "borders: T=%.1f%% B=%.1f%% L=%.1f%% R=%.1f%%",
        w, h, crop_w, crop_h, detected_format, format_confidence * 100,
        border_top * 100, border_bottom * 100,
        border_left * 100, border_right * 100,
    )

    return cropped, result
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Format detection (kept as optional metadata)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _detect_format(width: int, height: int) -> Tuple[str, float]:
|
|
||||||
"""Detect paper format from dimensions by comparing aspect ratios."""
|
|
||||||
if width <= 0 or height <= 0:
|
|
||||||
return "unknown", 0.0
|
|
||||||
|
|
||||||
aspect = max(width, height) / min(width, height)
|
|
||||||
|
|
||||||
best_format = "unknown"
|
|
||||||
best_diff = float("inf")
|
|
||||||
|
|
||||||
for fmt, expected_ratio in PAPER_FORMATS.items():
|
|
||||||
diff = abs(aspect - expected_ratio)
|
|
||||||
if diff < best_diff:
|
|
||||||
best_diff = diff
|
|
||||||
best_format = fmt
|
|
||||||
|
|
||||||
confidence = max(0.0, 1.0 - best_diff * 5.0)
|
|
||||||
|
|
||||||
if confidence < 0.3:
|
|
||||||
return "unknown", 0.0
|
|
||||||
|
|
||||||
return best_format, round(confidence, 3)
|
|
||||||
|
|||||||
@@ -1,388 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/page_crop_edges.py
|
||||||
Page Crop - Edge Detection Helpers
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Spine shadow detection, gutter continuity analysis, projection-based
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.page_crop_edges")
|
||||||
edge detection, and narrow-run filtering for content cropping.
|
|
||||||
|
|
||||||
Extracted from page_crop.py to keep files under 500 LOC.
|
|
||||||
License: Apache 2.0
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Minimum ink density (fraction of pixels) to count a row/column as "content"
|
|
||||||
_INK_THRESHOLD = 0.003 # 0.3%
|
|
||||||
|
|
||||||
# Minimum run length (fraction of dimension) to keep — shorter runs are noise
|
|
||||||
_MIN_RUN_FRAC = 0.005 # 0.5%
|
|
||||||
|
|
||||||
|
|
||||||
def _detect_spine_shadow(
|
|
||||||
gray: np.ndarray,
|
|
||||||
search_region: np.ndarray,
|
|
||||||
offset_x: int,
|
|
||||||
w: int,
|
|
||||||
side: str,
|
|
||||||
) -> Optional[int]:
|
|
||||||
"""Find the book spine center (darkest point) in a scanner shadow.
|
|
||||||
|
|
||||||
The scanner produces a gray strip where the book spine presses against
|
|
||||||
the glass. The darkest column in that strip is the spine center —
|
|
||||||
that's where we crop.
|
|
||||||
|
|
||||||
Distinguishes real spine shadows from text content by checking:
|
|
||||||
1. Strong brightness range (> 40 levels)
|
|
||||||
2. Darkest point is genuinely dark (< 180 mean brightness)
|
|
||||||
3. The dark area is a NARROW valley, not a text-content plateau
|
|
||||||
4. Brightness rises significantly toward the page content side
|
|
||||||
|
|
||||||
Args:
|
|
||||||
gray: Full grayscale image (for context).
|
|
||||||
search_region: Column slice of the grayscale image to search in.
|
|
||||||
offset_x: X offset of search_region relative to full image.
|
|
||||||
w: Full image width.
|
|
||||||
side: 'left' or 'right' (for logging).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
X coordinate (in full image) of the spine center, or None.
|
|
||||||
"""
|
|
||||||
region_w = search_region.shape[1]
|
|
||||||
if region_w < 10:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Column-mean brightness in the search region
|
|
||||||
col_means = np.mean(search_region, axis=0).astype(np.float64)
|
|
||||||
|
|
||||||
# Smooth with boxcar kernel (width = 1% of image width, min 5)
|
|
||||||
kernel_size = max(5, w // 100)
|
|
||||||
if kernel_size % 2 == 0:
|
|
||||||
kernel_size += 1
|
|
||||||
kernel = np.ones(kernel_size) / kernel_size
|
|
||||||
smoothed_raw = np.convolve(col_means, kernel, mode="same")
|
|
||||||
|
|
||||||
# Trim convolution edge artifacts (edges are zero-padded -> artificially low)
|
|
||||||
margin = kernel_size // 2
|
|
||||||
if region_w <= 2 * margin + 10:
|
|
||||||
return None
|
|
||||||
smoothed = smoothed_raw[margin:region_w - margin]
|
|
||||||
trim_offset = margin # offset of smoothed[0] relative to search_region
|
|
||||||
|
|
||||||
val_min = float(np.min(smoothed))
|
|
||||||
val_max = float(np.max(smoothed))
|
|
||||||
shadow_range = val_max - val_min
|
|
||||||
|
|
||||||
# --- Check 1: Strong brightness gradient ---
|
|
||||||
if shadow_range <= 40:
|
|
||||||
logger.debug(
|
|
||||||
"%s edge: no spine (range=%.0f <= 40)", side.capitalize(), shadow_range,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
# --- Check 2: Darkest point must be genuinely dark ---
|
|
||||||
if val_min > 180:
|
|
||||||
logger.debug(
|
|
||||||
"%s edge: no spine (darkest=%.0f > 180, likely text)", side.capitalize(), val_min,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
spine_idx = int(np.argmin(smoothed)) # index in trimmed array
|
|
||||||
spine_local = spine_idx + trim_offset # index in search_region
|
|
||||||
trimmed_len = len(smoothed)
|
|
||||||
|
|
||||||
# --- Check 3: Valley width (spine is narrow, text plateau is wide) ---
|
|
||||||
valley_thresh = val_min + shadow_range * 0.20
|
|
||||||
valley_mask = smoothed < valley_thresh
|
|
||||||
valley_width = int(np.sum(valley_mask))
|
|
||||||
max_valley_frac = 0.50
|
|
||||||
if valley_width > trimmed_len * max_valley_frac:
|
|
||||||
logger.debug(
|
|
||||||
"%s edge: no spine (valley too wide: %d/%d = %.0f%%)",
|
|
||||||
side.capitalize(), valley_width, trimmed_len,
|
|
||||||
100.0 * valley_width / trimmed_len,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
# --- Check 4: Brightness must rise toward page content ---
|
|
||||||
rise_check_w = max(5, trimmed_len // 5)
|
|
||||||
if side == "left":
|
|
||||||
right_start = min(spine_idx + 5, trimmed_len - 1)
|
|
||||||
right_end = min(right_start + rise_check_w, trimmed_len)
|
|
||||||
if right_end > right_start:
|
|
||||||
rise_brightness = float(np.mean(smoothed[right_start:right_end]))
|
|
||||||
rise = rise_brightness - val_min
|
|
||||||
if rise < shadow_range * 0.3:
|
|
||||||
logger.debug(
|
|
||||||
"%s edge: no spine (insufficient rise: %.0f, need %.0f)",
|
|
||||||
side.capitalize(), rise, shadow_range * 0.3,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
else: # right
|
|
||||||
left_end = max(spine_idx - 5, 0)
|
|
||||||
left_start = max(left_end - rise_check_w, 0)
|
|
||||||
if left_end > left_start:
|
|
||||||
rise_brightness = float(np.mean(smoothed[left_start:left_end]))
|
|
||||||
rise = rise_brightness - val_min
|
|
||||||
if rise < shadow_range * 0.3:
|
|
||||||
logger.debug(
|
|
||||||
"%s edge: no spine (insufficient rise: %.0f, need %.0f)",
|
|
||||||
side.capitalize(), rise, shadow_range * 0.3,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
spine_x = offset_x + spine_local
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"%s edge: spine center at x=%d (brightness=%.0f, range=%.0f, valley=%dpx)",
|
|
||||||
side.capitalize(), spine_x, val_min, shadow_range, valley_width,
|
|
||||||
)
|
|
||||||
return spine_x
|
|
||||||
|
|
||||||
|
|
||||||
def _detect_gutter_continuity(
|
|
||||||
gray: np.ndarray,
|
|
||||||
search_region: np.ndarray,
|
|
||||||
offset_x: int,
|
|
||||||
w: int,
|
|
||||||
side: str,
|
|
||||||
) -> Optional[int]:
|
|
||||||
"""Detect gutter shadow via vertical continuity analysis.
|
|
||||||
|
|
||||||
Camera book scans produce a subtle brightness gradient at the gutter
|
|
||||||
that is too faint for scanner-shadow detection (range < 40). However,
|
|
||||||
the gutter shadow has a unique property: it runs **continuously from
|
|
||||||
top to bottom** without interruption.
|
|
||||||
|
|
||||||
Algorithm:
|
|
||||||
1. Divide image into N horizontal strips (~60px each)
|
|
||||||
2. For each column, compute what fraction of strips are darker than
|
|
||||||
the page median (from the center 50% of the full image)
|
|
||||||
3. A "gutter column" has >= 75% of strips darker than page_median - d
|
|
||||||
4. Smooth the dark-fraction profile and find the transition point
|
|
||||||
5. Validate: gutter band must be 0.5%-10% of image width
|
|
||||||
"""
|
|
||||||
region_h, region_w = search_region.shape[:2]
|
|
||||||
if region_w < 20 or region_h < 100:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# --- 1. Divide into horizontal strips ---
|
|
||||||
strip_target_h = 60
|
|
||||||
n_strips = max(10, region_h // strip_target_h)
|
|
||||||
strip_h = region_h // n_strips
|
|
||||||
|
|
||||||
strip_means = np.zeros((n_strips, region_w), dtype=np.float64)
|
|
||||||
for s in range(n_strips):
|
|
||||||
y0 = s * strip_h
|
|
||||||
y1 = min((s + 1) * strip_h, region_h)
|
|
||||||
strip_means[s] = np.mean(search_region[y0:y1, :], axis=0)
|
|
||||||
|
|
||||||
# --- 2. Page median from center 50% of full image ---
|
|
||||||
center_lo = w // 4
|
|
||||||
center_hi = 3 * w // 4
|
|
||||||
page_median = float(np.median(gray[:, center_lo:center_hi]))
|
|
||||||
|
|
||||||
dark_thresh = page_median - 5.0
|
|
||||||
|
|
||||||
if page_median < 180:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# --- 3. Per-column dark fraction ---
|
|
||||||
dark_count = np.sum(strip_means < dark_thresh, axis=0).astype(np.float64)
|
|
||||||
dark_frac = dark_count / n_strips
|
|
||||||
|
|
||||||
# --- 4. Smooth and find transition ---
|
|
||||||
smooth_w = max(5, w // 100)
|
|
||||||
if smooth_w % 2 == 0:
|
|
||||||
smooth_w += 1
|
|
||||||
kernel = np.ones(smooth_w) / smooth_w
|
|
||||||
frac_smooth = np.convolve(dark_frac, kernel, mode="same")
|
|
||||||
|
|
||||||
margin = smooth_w // 2
|
|
||||||
if region_w <= 2 * margin + 10:
|
|
||||||
return None
|
|
||||||
|
|
||||||
transition_thresh = 0.50
|
|
||||||
peak_frac = float(np.max(frac_smooth[margin:region_w - margin]))
|
|
||||||
|
|
||||||
if peak_frac < 0.70:
|
|
||||||
logger.debug(
|
|
||||||
"%s gutter: peak dark fraction %.2f < 0.70", side.capitalize(), peak_frac,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
peak_x = int(np.argmax(frac_smooth[margin:region_w - margin])) + margin
|
|
||||||
gutter_inner = None
|
|
||||||
|
|
||||||
if side == "right":
|
|
||||||
for x in range(peak_x, margin, -1):
|
|
||||||
if frac_smooth[x] < transition_thresh:
|
|
||||||
gutter_inner = x + 1
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
for x in range(peak_x, region_w - margin):
|
|
||||||
if frac_smooth[x] < transition_thresh:
|
|
||||||
gutter_inner = x - 1
|
|
||||||
break
|
|
||||||
|
|
||||||
if gutter_inner is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# --- 5. Validate gutter width ---
|
|
||||||
if side == "right":
|
|
||||||
gutter_width = region_w - gutter_inner
|
|
||||||
else:
|
|
||||||
gutter_width = gutter_inner
|
|
||||||
|
|
||||||
min_gutter = max(3, int(w * 0.005))
|
|
||||||
max_gutter = int(w * 0.10)
|
|
||||||
|
|
||||||
if gutter_width < min_gutter:
|
|
||||||
logger.debug(
|
|
||||||
"%s gutter: too narrow (%dpx < %dpx)", side.capitalize(),
|
|
||||||
gutter_width, min_gutter,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
if gutter_width > max_gutter:
|
|
||||||
logger.debug(
|
|
||||||
"%s gutter: too wide (%dpx > %dpx)", side.capitalize(),
|
|
||||||
gutter_width, max_gutter,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
if side == "right":
|
|
||||||
gutter_brightness = float(np.mean(strip_means[:, gutter_inner:]))
|
|
||||||
else:
|
|
||||||
gutter_brightness = float(np.mean(strip_means[:, :gutter_inner]))
|
|
||||||
|
|
||||||
brightness_drop = page_median - gutter_brightness
|
|
||||||
if brightness_drop < 3:
|
|
||||||
logger.debug(
|
|
||||||
"%s gutter: insufficient brightness drop (%.1f levels)",
|
|
||||||
side.capitalize(), brightness_drop,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
gutter_x = offset_x + gutter_inner
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"%s gutter (continuity): x=%d, width=%dpx (%.1f%%), "
|
|
||||||
"brightness=%.0f vs page=%.0f (drop=%.0f), frac@edge=%.2f",
|
|
||||||
side.capitalize(), gutter_x, gutter_width,
|
|
||||||
100.0 * gutter_width / w, gutter_brightness, page_median,
|
|
||||||
brightness_drop, float(frac_smooth[gutter_inner]),
|
|
||||||
)
|
|
||||||
return gutter_x
|
|
||||||
|
|
||||||
|
|
||||||
def _detect_left_edge_shadow(
    gray: np.ndarray,
    binary: np.ndarray,
    w: int,
    h: int,
) -> int:
    """Return the x position of the left content edge.

    The leftmost quarter of ``gray`` is tried with the spine-shadow detector
    first, then with the gutter-continuity detector; if neither reports an
    edge, the binary column projection decides.

    Args:
        gray: Full grayscale image.
        binary: Binarised image used by the projection fallback.
        w: Image width.
        h: Image height (not used here).
    """
    strip = gray[:, :max(1, w // 4)]

    for detector in (_detect_spine_shadow, _detect_gutter_continuity):
        edge = detector(gray, strip, 0, w, "left")
        if edge is not None:
            return edge

    return _detect_edge_projection(binary, axis=0, from_start=True, dim=w)
|
|
||||||
|
|
||||||
|
|
||||||
def _detect_right_edge_shadow(
    gray: np.ndarray,
    binary: np.ndarray,
    w: int,
    h: int,
) -> int:
    """Return the x position of the right content edge.

    The rightmost quarter of ``gray`` is tried with the spine-shadow detector
    first, then with the gutter-continuity detector; if neither reports an
    edge, the binary column projection decides.

    Args:
        gray: Full grayscale image.
        binary: Binarised image used by the projection fallback.
        w: Image width.
        h: Image height (not used here).
    """
    strip_x0 = w - max(1, w // 4)
    strip = gray[:, strip_x0:]

    for detector in (_detect_spine_shadow, _detect_gutter_continuity):
        edge = detector(gray, strip, strip_x0, w, "right")
        if edge is not None:
            return edge

    return _detect_edge_projection(binary, axis=0, from_start=False, dim=w)
|
|
||||||
|
|
||||||
|
|
||||||
def _detect_top_bottom_edges(binary: np.ndarray, w: int, h: int) -> Tuple[int, int]:
    """Return ``(top, bottom)`` content edges from the binary row projection.

    ``w`` is accepted for signature symmetry but not used.
    """
    top, bottom = (
        _detect_edge_projection(binary, axis=1, from_start=from_start, dim=h)
        for from_start in (True, False)
    )
    return top, bottom
|
|
||||||
|
|
||||||
|
|
||||||
def _detect_edge_projection(
    binary: np.ndarray,
    axis: int,
    from_start: bool,
    dim: int,
) -> int:
    """Locate the first or last row/column whose ink density reaches the threshold.

    Args:
        binary: Binarised image; the mean along ``axis`` is divided by 255
            to obtain a density.
        axis: 0 yields per-column densities (result is an x position),
            1 yields per-row densities (result is a y position).
        from_start: ``True`` returns the first qualifying index, ``False``
            the last.
        dim: Length of the scanned dimension; scales the minimum run length
            and is returned when scanning from the end finds nothing.

    Returns:
        Index of the edge. With no qualifying position: ``0`` when
        ``from_start`` is true, otherwise ``dim``.
    """
    density = np.mean(binary, axis=axis) / 255.0

    # Runs shorter than _MIN_RUN_FRAC of the dimension count as noise.
    shortest_run = max(1, int(dim * _MIN_RUN_FRAC))
    has_ink = _filter_narrow_runs(density >= _INK_THRESHOLD, shortest_run)

    hits = np.flatnonzero(has_ink)
    if hits.size == 0:
        return 0 if from_start else dim
    return int(hits[0] if from_start else hits[-1])
|
|
||||||
|
|
||||||
|
|
||||||
def _filter_narrow_runs(mask: np.ndarray, min_run: int) -> np.ndarray:
|
|
||||||
"""Remove True-runs shorter than min_run pixels."""
|
|
||||||
if min_run <= 1:
|
|
||||||
return mask
|
|
||||||
|
|
||||||
result = mask.copy()
|
|
||||||
n = len(result)
|
|
||||||
i = 0
|
|
||||||
while i < n:
|
|
||||||
if result[i]:
|
|
||||||
start = i
|
|
||||||
while i < n and result[i]:
|
|
||||||
i += 1
|
|
||||||
if i - start < min_run:
|
|
||||||
result[start:i] = False
|
|
||||||
else:
|
|
||||||
i += 1
|
|
||||||
return result
|
|
||||||
|
|||||||
@@ -1,189 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/page_sub_sessions.py
|
||||||
Sub-session creation for multi-page spreads.
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Used by both the page-split and crop steps when a double-page scan is detected.
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.page_sub_sessions")
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import uuid as uuid_mod
|
|
||||||
from typing import Any, Dict, List
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from page_crop import detect_and_crop_page
|
|
||||||
from ocr_pipeline_session_store import (
|
|
||||||
create_session_db,
|
|
||||||
get_sub_sessions,
|
|
||||||
update_session_db,
|
|
||||||
)
|
|
||||||
from orientation_crop_helpers import get_cache_ref
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
async def create_page_sub_sessions(
    parent_session_id: str,
    parent_cached: dict,
    full_img_bgr: np.ndarray,
    page_splits: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Create one pre-cropped sub-session per detected page of a spread.

    Each page region is cut out of ``full_img_bgr``, cropped individually
    with ``detect_and_crop_page`` and stored as a new session whose original
    and cropped images are both the cropped PNG.

    The call is idempotent: if the parent already has sub-sessions, a short
    summary of those is returned and nothing is created.

    Args:
        parent_session_id: Session whose sub-sessions are looked up.
        parent_cached: Parent session data; ``"name"`` and ``"filename"``
            are read, with ``"Scan"`` / ``"scan.png"`` as fallbacks.
        full_img_bgr: Image containing all pages.
        page_splits: One dict per page with ``page_index``, ``x``, ``y``,
            ``width`` and ``height``.

    Returns:
        For new sub-sessions: dicts with ``id``, ``name``, ``page_index``,
        ``source_rect``, ``cropped_size`` and ``detected_format``.
        For pre-existing ones: dicts with ``id``, ``name`` and ``page_index``.
    """
    # Idempotency: reuse whatever sub-sessions already exist.
    existing = await get_sub_sessions(parent_session_id)
    if existing:
        return [
            {"id": s["id"], "name": s["name"], "page_index": s.get("box_index", i)}
            for i, s in enumerate(existing)
        ]

    parent_name = parent_cached.get("name", "Scan")
    parent_filename = parent_cached.get("filename", "scan.png")

    sub_sessions: List[Dict[str, Any]] = []

    for page in page_splits:
        pi = page["page_index"]
        px, py = page["x"], page["y"]
        pw, ph = page["width"], page["height"]

        # Extract the page region, then remove that page's own borders.
        page_bgr = full_img_bgr[py:py + ph, px:px + pw].copy()
        cropped_page, page_crop_info = detect_and_crop_page(page_bgr)

        ok, png_buf = cv2.imencode(".png", cropped_page)
        if ok:
            page_png = png_buf.tobytes()
        else:
            # Still best-effort (empty payload), but no longer silent.
            logger.warning(
                "PNG encoding failed for page %d of session %s; storing empty image",
                pi + 1, parent_session_id,
            )
            page_png = b""

        sub_id = str(uuid_mod.uuid4())
        sub_name = f"{parent_name} — Seite {pi + 1}"

        await create_session_db(
            session_id=sub_id,
            name=sub_name,
            filename=parent_filename,
            original_png=page_png,
        )

        # Pre-populate: cropped image equals the (already cropped) original.
        await update_session_db(
            sub_id,
            cropped_png=page_png,
            crop_result=page_crop_info,
            current_step=5,
        )

        ch, cw = cropped_page.shape[:2]
        sub_sessions.append({
            "id": sub_id,
            "name": sub_name,
            "page_index": pi,
            "source_rect": page,
            "cropped_size": {"width": cw, "height": ch},
            "detected_format": page_crop_info.get("detected_format"),
        })

        logger.info(
            "Page sub-session %s: page %d, region x=%d w=%d -> cropped %dx%d",
            sub_id, pi + 1, px, pw, cw, ch,
        )

    return sub_sessions
|
|
||||||
|
|
||||||
|
|
||||||
async def create_page_sub_sessions_full(
    parent_session_id: str,
    parent_cached: dict,
    full_img_bgr: np.ndarray,
    page_splits: List[Dict[str, Any]],
    start_step: int = 2,
) -> List[Dict[str, Any]]:
    """Create one sub-session per page holding the RAW (uncropped) region.

    Unlike ``create_page_sub_sessions``, no per-page cropping happens here:
    each sub-session stores the plain page region, is set to ``start_step``
    and gets an in-memory cache entry with its BGR image so processing can
    start without reloading.

    The call is idempotent: if the parent already has sub-sessions, a short
    summary of those is returned and nothing is created.

    Args:
        parent_session_id: Session whose sub-sessions are looked up.
        parent_cached: Parent session data; ``"name"`` and ``"filename"``
            are read, with ``"Scan"`` / ``"scan.png"`` as fallbacks.
        full_img_bgr: Image containing all pages.
        page_splits: One dict per page with ``page_index``, ``x``, ``y``,
            ``width`` and ``height``.
        start_step: Value written to ``current_step``. Per the original
            notes: 2 = ready for deskew (orientation already done on the
            spread), 1 = the page still needs its own orientation.

    Returns:
        For new sub-sessions: dicts with ``id``, ``name``, ``page_index``,
        ``source_rect`` and ``image_size``.
        For pre-existing ones: dicts with ``id``, ``name`` and ``page_index``.
    """
    _cache = get_cache_ref()

    # Idempotency: reuse whatever sub-sessions already exist.
    existing = await get_sub_sessions(parent_session_id)
    if existing:
        return [
            {"id": s["id"], "name": s["name"], "page_index": s.get("box_index", i)}
            for i, s in enumerate(existing)
        ]

    parent_name = parent_cached.get("name", "Scan")
    parent_filename = parent_cached.get("filename", "scan.png")

    sub_sessions: List[Dict[str, Any]] = []

    for page in page_splits:
        pi = page["page_index"]
        px, py = page["x"], page["y"]
        pw, ph = page["width"], page["height"]

        # RAW page region: cropping is left to each sub-session's own run.
        page_bgr = full_img_bgr[py:py + ph, px:px + pw].copy()

        ok, png_buf = cv2.imencode(".png", page_bgr)
        if ok:
            page_png = png_buf.tobytes()
        else:
            # Still best-effort (empty payload), but no longer silent.
            logger.warning(
                "PNG encoding failed for page %d of session %s; storing empty image",
                pi + 1, parent_session_id,
            )
            page_png = b""

        sub_id = str(uuid_mod.uuid4())
        sub_name = f"{parent_name} — Seite {pi + 1}"

        await create_session_db(
            session_id=sub_id,
            name=sub_name,
            filename=parent_filename,
            original_png=page_png,
        )

        await update_session_db(sub_id, current_step=start_step)

        # Cache the BGR image so the pipeline can start immediately.
        _cache[sub_id] = {
            "id": sub_id,
            "filename": parent_filename,
            "name": sub_name,
            "original_bgr": page_bgr,
            "oriented_bgr": None,
            "cropped_bgr": None,
            "deskewed_bgr": None,
            "dewarped_bgr": None,
            "orientation_result": None,
            "crop_result": None,
            "deskew_result": None,
            "dewarp_result": None,
            "ground_truth": {},
            "current_step": start_step,
        }

        rh, rw = page_bgr.shape[:2]
        sub_sessions.append({
            "id": sub_id,
            "name": sub_name,
            "page_index": pi,
            "source_rect": page,
            "image_size": {"width": rw, "height": rh},
        })

        logger.info(
            "Page sub-session %s (full pipeline): page %d, region x=%d w=%d -> %dx%d",
            sub_id, pi + 1, px, pw, rw, rh,
        )

    return sub_sessions
|
|
||||||
|
|||||||
@@ -1,102 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/scan_quality.py
|
||||||
Scan Quality Assessment — Measures image quality before OCR.
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Computes blur score, contrast score, and an overall quality rating.
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.scan_quality")
|
||||||
Used to gate enhancement steps and warn users about degraded scans.
|
|
||||||
|
|
||||||
All operations use OpenCV (Apache-2.0), no additional dependencies.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass, asdict
|
|
||||||
from typing import Dict, Any
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Thresholds (empirically tuned on textbook scans)
|
|
||||||
BLUR_THRESHOLD = 100.0 # Laplacian variance below this = blurry
|
|
||||||
CONTRAST_THRESHOLD = 40.0 # Grayscale stddev below this = low contrast
|
|
||||||
CONFIDENCE_GOOD = 40 # OCR min confidence for good scans
|
|
||||||
CONFIDENCE_DEGRADED = 30 # OCR min confidence for degraded scans
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class ScanQualityReport:
    """Outcome of a scan quality assessment.

    Attributes:
        blur_score: Laplacian variance (higher = sharper).
        contrast_score: Grayscale standard deviation (higher = more contrast).
        brightness: Mean grayscale value (0-255).
        is_blurry: Whether the blur score fell below the blur threshold.
        is_low_contrast: Whether the contrast score fell below its threshold.
        is_degraded: True if any quality issue was detected.
        quality_pct: Overall quality estimate, 0-100.
        recommended_min_conf: Recommended OCR confidence threshold.
    """

    blur_score: float
    contrast_score: float
    brightness: float
    is_blurry: bool
    is_low_contrast: bool
    is_degraded: bool
    quality_pct: int
    recommended_min_conf: int

    def to_dict(self) -> Dict[str, Any]:
        """Return the report as a plain dict keyed by field name."""
        return asdict(self)
|
|
||||||
|
|
||||||
|
|
||||||
def score_scan_quality(img_bgr: np.ndarray) -> ScanQualityReport:
    """Assess the quality of a scanned image.

    Measures:
        - blur: variance of the Laplacian of the grayscale image,
        - contrast: standard deviation of the grayscale image,
        - brightness: mean of the grayscale image.

    Blur and contrast each contribute ``score / threshold * 50`` (capped at
    100) to ``quality_pct``, and the sum is capped at 100. The image counts
    as degraded when it is blurry or low in contrast, which selects the
    lower recommended OCR confidence.

    Args:
        img_bgr: BGR image (numpy array from OpenCV).

    Returns:
        ScanQualityReport with scores rounded to one decimal, quality flags
        and the recommended OCR confidence threshold.
    """
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

    # Higher Laplacian variance = sharper edges.
    blur_score = float(cv2.Laplacian(gray, cv2.CV_64F).var())
    contrast_score = float(np.std(gray))
    brightness = float(np.mean(gray))

    is_blurry = blur_score < BLUR_THRESHOLD
    is_low_contrast = contrast_score < CONTRAST_THRESHOLD
    is_degraded = is_blurry or is_low_contrast

    # Reaching a threshold is worth 50 points per component.
    blur_pct = min(100, blur_score / BLUR_THRESHOLD * 50)
    contrast_pct = min(100, contrast_score / CONTRAST_THRESHOLD * 50)
    quality_pct = int(min(100, blur_pct + contrast_pct))

    recommended_min_conf = CONFIDENCE_DEGRADED if is_degraded else CONFIDENCE_GOOD

    report = ScanQualityReport(
        blur_score=round(blur_score, 1),
        contrast_score=round(contrast_score, 1),
        brightness=round(brightness, 1),
        is_blurry=is_blurry,
        is_low_contrast=is_low_contrast,
        is_degraded=is_degraded,
        quality_pct=quality_pct,
        recommended_min_conf=recommended_min_conf,
    )

    # Lazy %-args: the message is only formatted if INFO is enabled.
    logger.info(
        "Scan quality: blur=%s contrast=%s quality=%s%% degraded=%s min_conf=%s",
        report.blur_score,
        report.contrast_score,
        report.quality_pct,
        report.is_degraded,
        report.recommended_min_conf,
    )

    return report
|
|
||||||
|
|||||||
@@ -1,261 +1,4 @@
|
|||||||
"""
|
# Backward-compat shim -- module moved to ocr/pipeline/vision_fusion.py
|
||||||
Vision-LLM OCR Fusion — Combines traditional OCR positions with Vision-LLM reading.
|
import importlib as _importlib
|
||||||
|
import sys as _sys
|
||||||
Sends the scan image + OCR word coordinates + document type to Qwen2.5-VL.
|
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.vision_fusion")
|
||||||
The LLM can read degraded text using context understanding and visual inspection,
|
|
||||||
while OCR coordinates provide structural hints (where text is, column positions).
|
|
||||||
|
|
||||||
Uses Ollama API (same pattern as handwriting_htr_api.py).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import httpx
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
|
|
||||||
VISION_FUSION_MODEL = os.getenv("VISION_FUSION_MODEL", "llama3.2-vision:11b")
|
|
||||||
|
|
||||||
# Document category → prompt context
|
|
||||||
CATEGORY_PROMPTS: Dict[str, Dict[str, str]] = {
|
|
||||||
"vokabelseite": {
|
|
||||||
"label": "Vokabelseite eines Schulbuchs (Englisch-Deutsch)",
|
|
||||||
"columns": "Die Tabelle hat typischerweise 3 Spalten: Englisch, Deutsch, Beispielsatz.",
|
|
||||||
},
|
|
||||||
"woerterbuch": {
|
|
||||||
"label": "Woerterbuchseite",
|
|
||||||
"columns": "Die Eintraege haben: Stichwort, Lautschrift, Uebersetzung(en), Beispielsaetze.",
|
|
||||||
},
|
|
||||||
"arbeitsblatt": {
|
|
||||||
"label": "Arbeitsblatt",
|
|
||||||
"columns": "Erkenne die Spaltenstruktur aus dem Layout.",
|
|
||||||
},
|
|
||||||
"buchseite": {
|
|
||||||
"label": "Schulbuchseite",
|
|
||||||
"columns": "Erkenne die Spaltenstruktur aus dem Layout.",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _group_words_into_lines(
|
|
||||||
words: List[Dict], y_tolerance: float = 15.0,
|
|
||||||
) -> List[List[Dict]]:
|
|
||||||
"""Group OCR words into lines by Y-proximity."""
|
|
||||||
if not words:
|
|
||||||
return []
|
|
||||||
sorted_w = sorted(words, key=lambda w: w.get("top", 0))
|
|
||||||
lines: List[List[Dict]] = [[sorted_w[0]]]
|
|
||||||
for w in sorted_w[1:]:
|
|
||||||
last_line = lines[-1]
|
|
||||||
avg_y = sum(ww["top"] for ww in last_line) / len(last_line)
|
|
||||||
if abs(w["top"] - avg_y) <= y_tolerance:
|
|
||||||
last_line.append(w)
|
|
||||||
else:
|
|
||||||
lines.append([w])
|
|
||||||
# Sort words within each line by X
|
|
||||||
for line in lines:
|
|
||||||
line.sort(key=lambda w: w.get("left", 0))
|
|
||||||
return lines
|
|
||||||
|
|
||||||
|
|
||||||
def _build_ocr_context(words: List[Dict], img_h: int) -> str:
    """Describe OCR words and their positions as text for the prompt.

    Words are grouped into lines with ``_group_words_into_lines``. Each line
    becomes ``Zeile N (y~Y): x=X "text", ...`` where ``Y`` is the truncated
    mean ``top`` of the line. Words with ``conf`` below 50 get a `` (?)``
    marker.

    Missing ``text``, ``left``, ``conf`` and ``top`` default to ``""``, 0, 0
    and 0 respectively.

    Args:
        words: OCR word dicts.
        img_h: Image height (not used by this function).

    Returns:
        One description per line, joined with newlines; empty string when
        there are no words.
    """
    context_parts = []
    for line_no, line in enumerate(_group_words_into_lines(words), start=1):
        word_descs = []
        for w in line:
            text = w.get("text", "").strip()
            x = w.get("left", 0)
            marker = " (?)" if w.get("conf", 0) < 50 else ""
            word_descs.append(f'x={x} "{text}"{marker}')
        # .get keeps this consistent with the other tolerant field lookups.
        avg_y = int(sum(w.get("top", 0) for w in line) / len(line))
        context_parts.append(f"Zeile {line_no} (y~{avg_y}): {', '.join(word_descs)}")
    return "\n".join(context_parts)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_prompt(
    ocr_context: str, category: str, img_w: int, img_h: int,
) -> str:
    """Assemble the German Vision-LLM prompt for one scanned page.

    Args:
        ocr_context: Textual OCR word/position listing to embed.
        category: Key into ``CATEGORY_PROMPTS``; unknown keys fall back to
            the ``"buchseite"`` entry.
        img_w: Image width in pixels, quoted in the prompt.
        img_h: Image height in pixels, quoted in the prompt.

    Returns:
        The complete prompt text, ending with the expected JSON-array shape.
    """
    cat_info = CATEGORY_PROMPTS.get(category, CATEGORY_PROMPTS["buchseite"])
    label = cat_info['label']
    columns = cat_info['columns']

    # NOTE(review): leading whitespace inside this literal could not be
    # recovered from the diff view this was taken from — confirm against
    # the original file before applying.
    return f"""Du siehst eine eingescannte {label}.
{columns}

Die OCR-Software hat folgende Woerter an diesen Positionen erkannt.
Woerter mit (?) haben niedrige Erkennungssicherheit und sind wahrscheinlich falsch:

{ocr_context}

Bildgroesse: {img_w} x {img_h} Pixel.

AUFGABE: Schau dir das Bild genau an und erstelle die korrekte Tabelle.
- Korrigiere falsch erkannte Woerter anhand dessen was du im Bild siehst
- Fasse Fortsetzungszeilen zusammen (wenn eine Spalte in der naechsten Zeile leer ist,
gehoert der Text zur Zeile darueber — der Autor hat nur einen Zeilenumbruch innerhalb der Zelle gemacht)
- Behalte die Reihenfolge bei

Antworte NUR mit einem JSON-Array, keine Erklaerungen:
[
{{"row": 1, "english": "...", "german": "...", "example": "..."}},
{{"row": 2, "english": "...", "german": "...", "example": "..."}}
]"""
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_llm_response(response_text: str) -> Optional[List[Dict]]:
|
|
||||||
"""Parse the LLM JSON response, handling markdown code blocks."""
|
|
||||||
text = response_text.strip()
|
|
||||||
|
|
||||||
# Strip markdown code block if present
|
|
||||||
if text.startswith("```"):
|
|
||||||
text = re.sub(r"^```(?:json)?\s*", "", text)
|
|
||||||
text = re.sub(r"\s*```\s*$", "", text)
|
|
||||||
|
|
||||||
# Try to find JSON array
|
|
||||||
match = re.search(r"\[[\s\S]*\]", text)
|
|
||||||
if not match:
|
|
||||||
logger.warning("vision_fuse_ocr: no JSON array found in LLM response")
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
data = json.loads(match.group())
|
|
||||||
if not isinstance(data, list):
|
|
||||||
return None
|
|
||||||
return data
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.warning(f"vision_fuse_ocr: JSON parse error: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _vocab_rows_to_words(
|
|
||||||
rows: List[Dict], img_w: int, img_h: int,
|
|
||||||
) -> List[Dict]:
|
|
||||||
"""Convert LLM vocab rows back to word dicts for grid building.
|
|
||||||
|
|
||||||
Distributes words across estimated column positions so the
|
|
||||||
existing grid builder can process them normally.
|
|
||||||
"""
|
|
||||||
words = []
|
|
||||||
# Estimate column positions (3-column vocab layout)
|
|
||||||
col_positions = [
|
|
||||||
(0.02, 0.28), # EN: 2%-28% of width
|
|
||||||
(0.30, 0.55), # DE: 30%-55%
|
|
||||||
(0.57, 0.98), # Example: 57%-98%
|
|
||||||
]
|
|
||||||
|
|
||||||
median_h = max(15, img_h // (len(rows) * 3)) if rows else 20
|
|
||||||
y_step = max(median_h + 5, img_h // max(len(rows), 1))
|
|
||||||
|
|
||||||
for i, row in enumerate(rows):
|
|
||||||
y = int(i * y_step + 20)
|
|
||||||
row_num = row.get("row", i + 1)
|
|
||||||
|
|
||||||
for col_idx, (field, (x_start_pct, x_end_pct)) in enumerate([
|
|
||||||
("english", col_positions[0]),
|
|
||||||
("german", col_positions[1]),
|
|
||||||
("example", col_positions[2]),
|
|
||||||
]):
|
|
||||||
text = (row.get(field) or "").strip()
|
|
||||||
if not text:
|
|
||||||
continue
|
|
||||||
x = int(x_start_pct * img_w)
|
|
||||||
w = int((x_end_pct - x_start_pct) * img_w)
|
|
||||||
words.append({
|
|
||||||
"text": text,
|
|
||||||
"left": x,
|
|
||||||
"top": y,
|
|
||||||
"width": w,
|
|
||||||
"height": median_h,
|
|
||||||
"conf": 95, # LLM-corrected → high confidence
|
|
||||||
"_source": "vision_llm",
|
|
||||||
"_row": row_num,
|
|
||||||
"_col_type": f"column_{['en', 'de', 'example'][col_idx]}",
|
|
||||||
})
|
|
||||||
|
|
||||||
logger.info(f"vision_fuse_ocr: converted {len(rows)} LLM rows → {len(words)} words")
|
|
||||||
return words
|
|
||||||
|
|
||||||
|
|
||||||
async def vision_fuse_ocr(
    img_bgr: np.ndarray,
    ocr_words: List[Dict],
    document_category: str = "vokabelseite",
) -> List[Dict]:
    """Fuse traditional OCR results with Vision-LLM reading.

    Sends the image plus the OCR word positions to the vision model
    configured in ``VISION_FUSION_MODEL`` (Qwen2.5-VL according to the
    original author's notes) through Ollama's ``/api/generate`` endpoint.
    The model is asked to:
    - Read degraded text that traditional OCR cannot
    - Use document context (knows what a vocab table looks like)
    - Merge continuation rows (understands table structure)

    Args:
        img_bgr: The cropped/dewarped scan image (BGR)
        ocr_words: Traditional OCR word list with positions
        document_category: Type of document being scanned

    Returns:
        Corrected word list in same format as input, ready for grid building.
        Falls back to the original ``ocr_words`` when the image cannot be
        encoded, the Ollama call fails, the answer is empty or unparseable,
        or the parsed rows contain no usable text.
    """
    img_h, img_w = img_bgr.shape[:2]

    # Build OCR context string
    ocr_context = _build_ocr_context(ocr_words, img_h)

    # Build prompt
    prompt = _build_prompt(ocr_context, document_category, img_w, img_h)

    # Encode image as base64. imencode reports failure through its first
    # return value and may also raise cv2.error (e.g. on an empty image).
    try:
        encoded_ok, img_encoded = cv2.imencode(".png", img_bgr)
    except cv2.error as e:
        logger.error("vision_fuse_ocr: PNG encoding failed: %s", e)
        return ocr_words
    if not encoded_ok:
        logger.error("vision_fuse_ocr: PNG encoding failed")
        return ocr_words
    img_b64 = base64.b64encode(img_encoded.tobytes()).decode("utf-8")

    # Call the vision model via Ollama
    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            resp = await client.post(
                f"{OLLAMA_BASE_URL}/api/generate",
                json={
                    "model": VISION_FUSION_MODEL,
                    "prompt": prompt,
                    "images": [img_b64],
                    "stream": False,
                    "options": {"temperature": 0.1, "num_predict": 4096},
                },
            )
            resp.raise_for_status()
            data = resp.json()
            # "response" may be missing or null; treat both as empty.
            response_text = (data.get("response") or "").strip()
    except Exception as e:
        # Deliberately broad: this is a best-effort enhancement step and any
        # transport, HTTP or payload problem must degrade to plain OCR.
        logger.error("vision_fuse_ocr: Ollama call failed: %s", e)
        return ocr_words  # Fallback to original

    if not response_text:
        logger.warning("vision_fuse_ocr: empty LLM response")
        return ocr_words

    # Parse JSON response
    rows = _parse_llm_response(response_text)
    if not rows:
        logger.warning(
            "vision_fuse_ocr: could not parse LLM response, "
            "first 200 chars: %s", response_text[:200],
        )
        return ocr_words

    logger.info(
        "vision_fuse_ocr: LLM returned %d vocab rows (from %d OCR words)",
        len(rows), len(ocr_words),
    )

    # Convert back to word format for grid building
    fused_words = _vocab_rows_to_words(rows, img_w, img_h)
    if not fused_words:
        # Rows without any text would otherwise wipe out the OCR result.
        logger.warning(
            "vision_fuse_ocr: LLM rows contained no usable text, "
            "keeping original OCR words"
        )
        return ocr_words
    return fused_words
|
|
||||||
|
|||||||
Reference in New Issue
Block a user