Restructure: Move ocr_pipeline + labeling + crop into ocr/ package
Some checks failed
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 29s
CI / test-go-edu-search (push) Successful in 29s
CI / test-python-klausur (push) Failing after 2m25s
CI / test-python-agent-core (push) Successful in 19s
CI / test-nodejs-website (push) Successful in 20s

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-04-25 21:51:43 +02:00
parent 59c400b9aa
commit 0504d22b8e
98 changed files with 10351 additions and 10152 deletions

View File

@@ -47,6 +47,7 @@
# Single SSE generator orchestrating 6 pipeline steps — cannot split generator context
**/ocr_pipeline_auto_steps.py | owner=klausur | reason=run_auto is a single async generator yielding SSE events across 6 steps (528 LOC) | review=2026-10-01
**/ocr/pipeline/auto_steps.py | owner=klausur | reason=Same file moved to ocr/ package | review=2026-10-01
# Legacy — TEMPORAER bis Refactoring abgeschlossen
# Dateien hier werden Phase fuer Phase abgearbeitet und entfernt.

View File

@@ -1,290 +1,4 @@
"""
Crop API endpoints (Step 4 / UI index 3 of OCR Pipeline).
Auto-crop, manual crop, and skip-crop for scanner/book borders.
"""
import logging
import time
from typing import Any, Dict
import cv2
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from page_crop import detect_and_crop_page, detect_page_splits
from ocr_pipeline_session_store import get_sub_sessions, update_session_db
from orientation_crop_helpers import ensure_cached, append_pipeline_log
from page_sub_sessions import create_page_sub_sessions
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Step 4 (UI index 3): Crop — runs after deskew + dewarp
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/crop")
async def auto_crop(session_id: str):
"""Auto-detect and crop scanner/book borders.
Reads the dewarped image (post-deskew + dewarp, so the page is straight).
Falls back to oriented -> original if earlier steps were skipped.
If the image is a multi-page spread (e.g. book on scanner), it will
automatically split into separate sub-sessions per page, crop each
individually, and return the split info.
"""
cached = await ensure_cached(session_id)
# Use dewarped (preferred), fall back to oriented, then original
img_bgr = next(
(v for k in ("dewarped_bgr", "oriented_bgr", "original_bgr")
if (v := cached.get(k)) is not None),
None,
)
if img_bgr is None:
raise HTTPException(status_code=400, detail="No image available for cropping")
t0 = time.time()
# --- Check for existing sub-sessions (from page-split step) ---
# If page-split already created sub-sessions, skip multi-page detection
# in the crop step. Each sub-session runs its own crop independently.
existing_subs = await get_sub_sessions(session_id)
if existing_subs:
crop_result = cached.get("crop_result") or {}
if crop_result.get("multi_page"):
# Already split -- just return the existing info
duration = time.time() - t0
h, w = img_bgr.shape[:2]
return {
"session_id": session_id,
**crop_result,
"image_width": w,
"image_height": h,
"sub_sessions": [
{"id": s["id"], "name": s.get("name"), "page_index": s.get("box_index", i)}
for i, s in enumerate(existing_subs)
],
"note": "Page split was already performed; each sub-session runs its own crop.",
}
# --- Multi-page detection (fallback for sessions that skipped page-split) ---
page_splits = detect_page_splits(img_bgr)
if page_splits and len(page_splits) >= 2:
# Multi-page spread detected -- create sub-sessions
sub_sessions = await create_page_sub_sessions(
session_id, cached, img_bgr, page_splits,
)
duration = time.time() - t0
crop_info: Dict[str, Any] = {
"crop_applied": True,
"multi_page": True,
"page_count": len(page_splits),
"page_splits": page_splits,
"duration_seconds": round(duration, 2),
}
cached["crop_result"] = crop_info
# Store the first page as the main cropped image for backward compat
first_page = page_splits[0]
first_bgr = img_bgr[
first_page["y"]:first_page["y"] + first_page["height"],
first_page["x"]:first_page["x"] + first_page["width"],
].copy()
first_cropped, _ = detect_and_crop_page(first_bgr)
cached["cropped_bgr"] = first_cropped
ok, png_buf = cv2.imencode(".png", first_cropped)
await update_session_db(
session_id,
cropped_png=png_buf.tobytes() if ok else b"",
crop_result=crop_info,
current_step=5,
status='split',
)
logger.info(
"OCR Pipeline: crop session %s: multi-page split into %d pages in %.2fs",
session_id, len(page_splits), duration,
)
await append_pipeline_log(session_id, "crop", {
"multi_page": True,
"page_count": len(page_splits),
}, duration_ms=int(duration * 1000))
h, w = first_cropped.shape[:2]
return {
"session_id": session_id,
**crop_info,
"image_width": w,
"image_height": h,
"cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
"sub_sessions": sub_sessions,
}
# --- Single page (normal) ---
cropped_bgr, crop_info = detect_and_crop_page(img_bgr)
duration = time.time() - t0
crop_info["duration_seconds"] = round(duration, 2)
crop_info["multi_page"] = False
# Encode cropped image
success, png_buf = cv2.imencode(".png", cropped_bgr)
cropped_png = png_buf.tobytes() if success else b""
# Update cache
cached["cropped_bgr"] = cropped_bgr
cached["crop_result"] = crop_info
# Persist to DB
await update_session_db(
session_id,
cropped_png=cropped_png,
crop_result=crop_info,
current_step=5,
)
logger.info(
"OCR Pipeline: crop session %s: applied=%s format=%s in %.2fs",
session_id, crop_info["crop_applied"],
crop_info.get("detected_format", "?"),
duration,
)
await append_pipeline_log(session_id, "crop", {
"crop_applied": crop_info["crop_applied"],
"detected_format": crop_info.get("detected_format"),
"format_confidence": crop_info.get("format_confidence"),
}, duration_ms=int(duration * 1000))
h, w = cropped_bgr.shape[:2]
return {
"session_id": session_id,
**crop_info,
"image_width": w,
"image_height": h,
"cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
}
class ManualCropRequest(BaseModel):
x: float # percentage 0-100
y: float # percentage 0-100
width: float # percentage 0-100
height: float # percentage 0-100
@router.post("/sessions/{session_id}/crop/manual")
async def manual_crop(session_id: str, req: ManualCropRequest):
"""Manually crop using percentage coordinates."""
cached = await ensure_cached(session_id)
img_bgr = next(
(v for k in ("dewarped_bgr", "oriented_bgr", "original_bgr")
if (v := cached.get(k)) is not None),
None,
)
if img_bgr is None:
raise HTTPException(status_code=400, detail="No image available for cropping")
h, w = img_bgr.shape[:2]
# Convert percentages to pixels
px_x = int(w * req.x / 100.0)
px_y = int(h * req.y / 100.0)
px_w = int(w * req.width / 100.0)
px_h = int(h * req.height / 100.0)
# Clamp
px_x = max(0, min(px_x, w - 1))
px_y = max(0, min(px_y, h - 1))
px_w = max(1, min(px_w, w - px_x))
px_h = max(1, min(px_h, h - px_y))
cropped_bgr = img_bgr[px_y:px_y + px_h, px_x:px_x + px_w].copy()
success, png_buf = cv2.imencode(".png", cropped_bgr)
cropped_png = png_buf.tobytes() if success else b""
crop_result = {
"crop_applied": True,
"crop_rect": {"x": px_x, "y": px_y, "width": px_w, "height": px_h},
"crop_rect_pct": {"x": round(req.x, 2), "y": round(req.y, 2),
"width": round(req.width, 2), "height": round(req.height, 2)},
"original_size": {"width": w, "height": h},
"cropped_size": {"width": px_w, "height": px_h},
"method": "manual",
}
cached["cropped_bgr"] = cropped_bgr
cached["crop_result"] = crop_result
await update_session_db(
session_id,
cropped_png=cropped_png,
crop_result=crop_result,
current_step=5,
)
ch, cw = cropped_bgr.shape[:2]
return {
"session_id": session_id,
**crop_result,
"image_width": cw,
"image_height": ch,
"cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
}
@router.post("/sessions/{session_id}/crop/skip")
async def skip_crop(session_id: str):
"""Skip cropping -- use dewarped (or oriented/original) image as-is."""
cached = await ensure_cached(session_id)
img_bgr = next(
(v for k in ("dewarped_bgr", "oriented_bgr", "original_bgr")
if (v := cached.get(k)) is not None),
None,
)
if img_bgr is None:
raise HTTPException(status_code=400, detail="No image available")
h, w = img_bgr.shape[:2]
# Store the dewarped image as cropped (identity crop)
success, png_buf = cv2.imencode(".png", img_bgr)
cropped_png = png_buf.tobytes() if success else b""
crop_result = {
"crop_applied": False,
"skipped": True,
"original_size": {"width": w, "height": h},
"cropped_size": {"width": w, "height": h},
}
cached["cropped_bgr"] = img_bgr
cached["crop_result"] = crop_result
await update_session_db(
session_id,
cropped_png=cropped_png,
crop_result=crop_result,
current_step=5,
)
return {
"session_id": session_id,
**crop_result,
"image_width": w,
"image_height": h,
"cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
}
# Backward-compat shim -- module moved to ocr/pipeline/crop_api.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.crop_api")

View File

@@ -1,4 +1,4 @@
# Backward-compat shim -- module moved to ocr\/pipeline.py
# Backward-compat shim -- module moved to ocr/cv_pipeline.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline")
_sys.modules[__name__] = _importlib.import_module("ocr.cv_pipeline")

View File

@@ -6,4 +6,4 @@ Backward-compatible re-exports: consumers can still use
"""
from .types import * # noqa: F401,F403
from .pipeline import * # noqa: F401,F403
from .cv_pipeline import * # noqa: F401,F403

View File

@@ -0,0 +1,6 @@
"""
OCR Labeling sub-package — labeling API, models, helpers, and route handlers.
Moved from backend/ flat modules (ocr_labeling_*.py).
Backward-compatible shim files remain at the old locations.
"""

View File

@@ -0,0 +1,81 @@
"""
OCR Labeling API — Barrel Re-export
Split into:
- ocr_labeling_models.py — Pydantic models and constants
- ocr_labeling_helpers.py — OCR wrappers, image storage, hashing
- ocr_labeling_routes.py — Session/queue/labeling route handlers
- ocr_labeling_upload_routes.py — Upload, run-OCR, export route handlers
All public names are re-exported here for backward compatibility.
"""
# Models
from .models import ( # noqa: F401
LOCAL_STORAGE_PATH,
SessionCreate,
SessionResponse,
ItemResponse,
ConfirmRequest,
CorrectRequest,
SkipRequest,
ExportRequest,
StatsResponse,
)
# Helpers
from .helpers import ( # noqa: F401
VISION_OCR_AVAILABLE,
PADDLEOCR_AVAILABLE,
TROCR_AVAILABLE,
DONUT_AVAILABLE,
MINIO_AVAILABLE,
TRAINING_EXPORT_AVAILABLE,
compute_image_hash,
run_ocr_on_image,
run_vision_ocr_wrapper,
run_paddleocr_wrapper,
run_trocr_wrapper,
run_donut_wrapper,
save_image_locally,
get_image_url,
)
# Conditional re-exports from helpers' optional imports
try:
from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET # noqa: F401
except ImportError:
pass
try:
from training_export_service import ( # noqa: F401
TrainingExportService,
TrainingSample,
get_training_export_service,
)
except ImportError:
pass
try:
from hybrid_vocab_extractor import run_paddle_ocr # noqa: F401
except ImportError:
pass
try:
from services.trocr_service import run_trocr_ocr # noqa: F401
except ImportError:
pass
try:
from services.donut_ocr_service import run_donut_ocr # noqa: F401
except ImportError:
pass
try:
from vision_ocr_service import get_vision_ocr_service, VisionOCRService # noqa: F401
except ImportError:
pass
# Routes (router is the main export for app.include_router)
from .routes import router # noqa: F401
from .upload_routes import router as upload_router # noqa: F401

View File

@@ -0,0 +1,205 @@
"""
OCR Labeling - Helper Functions and OCR Wrappers
Extracted from ocr_labeling_api.py to keep files under 500 LOC.
DATENSCHUTZ/PRIVACY:
- Alle Verarbeitung erfolgt lokal (Mac Mini mit Ollama)
- Keine Daten werden an externe Server gesendet
"""
import os
import hashlib
from .models import LOCAL_STORAGE_PATH
# Try to import Vision OCR service
try:
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend', 'klausur', 'services'))
from vision_ocr_service import get_vision_ocr_service, VisionOCRService
VISION_OCR_AVAILABLE = True
except ImportError:
VISION_OCR_AVAILABLE = False
print("Warning: Vision OCR service not available")
# Try to import PaddleOCR from hybrid_vocab_extractor
try:
from hybrid_vocab_extractor import run_paddle_ocr
PADDLEOCR_AVAILABLE = True
except ImportError:
PADDLEOCR_AVAILABLE = False
print("Warning: PaddleOCR not available")
# Try to import TrOCR service
try:
from services.trocr_service import run_trocr_ocr
TROCR_AVAILABLE = True
except ImportError:
TROCR_AVAILABLE = False
print("Warning: TrOCR service not available")
# Try to import Donut service
try:
from services.donut_ocr_service import run_donut_ocr
DONUT_AVAILABLE = True
except ImportError:
DONUT_AVAILABLE = False
print("Warning: Donut OCR service not available")
# Try to import MinIO storage
try:
from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET
MINIO_AVAILABLE = True
except ImportError:
MINIO_AVAILABLE = False
print("Warning: MinIO storage not available, using local storage")
# Try to import Training Export Service
try:
from training_export_service import (
TrainingExportService,
TrainingSample,
get_training_export_service,
)
TRAINING_EXPORT_AVAILABLE = True
except ImportError:
TRAINING_EXPORT_AVAILABLE = False
print("Warning: Training export service not available")
# =============================================================================
# Helper Functions
# =============================================================================
def compute_image_hash(image_data: bytes) -> str:
"""Compute SHA256 hash of image data."""
return hashlib.sha256(image_data).hexdigest()
async def run_ocr_on_image(image_data: bytes, filename: str, model: str = "llama3.2-vision:11b") -> tuple:
"""
Run OCR on an image using the specified model.
Models:
- llama3.2-vision:11b: Vision LLM (default, best for handwriting)
- trocr: Microsoft TrOCR (fast for printed text)
- paddleocr: PaddleOCR + LLM hybrid (4x faster)
- donut: Document Understanding Transformer (structured documents)
Returns:
Tuple of (ocr_text, confidence)
"""
print(f"Running OCR with model: {model}")
# Route to appropriate OCR service based on model
if model == "paddleocr":
return await run_paddleocr_wrapper(image_data, filename)
elif model == "donut":
return await run_donut_wrapper(image_data, filename)
elif model == "trocr":
return await run_trocr_wrapper(image_data, filename)
else:
# Default: Vision LLM (llama3.2-vision or similar)
return await run_vision_ocr_wrapper(image_data, filename)
async def run_vision_ocr_wrapper(image_data: bytes, filename: str) -> tuple:
"""Vision LLM OCR wrapper."""
if not VISION_OCR_AVAILABLE:
print("Vision OCR service not available")
return None, 0.0
try:
service = get_vision_ocr_service()
if not await service.is_available():
print("Vision OCR service not available (is_available check failed)")
return None, 0.0
result = await service.extract_text(
image_data,
filename=filename,
is_handwriting=True
)
return result.text, result.confidence
except Exception as e:
print(f"Vision OCR failed: {e}")
return None, 0.0
async def run_paddleocr_wrapper(image_data: bytes, filename: str) -> tuple:
"""PaddleOCR wrapper - uses hybrid_vocab_extractor."""
if not PADDLEOCR_AVAILABLE:
print("PaddleOCR not available, falling back to Vision OCR")
return await run_vision_ocr_wrapper(image_data, filename)
try:
# run_paddle_ocr returns (regions, raw_text)
regions, raw_text = run_paddle_ocr(image_data)
if not raw_text:
print("PaddleOCR returned empty text")
return None, 0.0
# Calculate average confidence from regions
if regions:
avg_confidence = sum(r.confidence for r in regions) / len(regions)
else:
avg_confidence = 0.5
return raw_text, avg_confidence
except Exception as e:
print(f"PaddleOCR failed: {e}, falling back to Vision OCR")
return await run_vision_ocr_wrapper(image_data, filename)
async def run_trocr_wrapper(image_data: bytes, filename: str) -> tuple:
"""TrOCR wrapper."""
if not TROCR_AVAILABLE:
print("TrOCR not available, falling back to Vision OCR")
return await run_vision_ocr_wrapper(image_data, filename)
try:
text, confidence = await run_trocr_ocr(image_data)
return text, confidence
except Exception as e:
print(f"TrOCR failed: {e}, falling back to Vision OCR")
return await run_vision_ocr_wrapper(image_data, filename)
async def run_donut_wrapper(image_data: bytes, filename: str) -> tuple:
"""Donut OCR wrapper."""
if not DONUT_AVAILABLE:
print("Donut not available, falling back to Vision OCR")
return await run_vision_ocr_wrapper(image_data, filename)
try:
text, confidence = await run_donut_ocr(image_data)
return text, confidence
except Exception as e:
print(f"Donut OCR failed: {e}, falling back to Vision OCR")
return await run_vision_ocr_wrapper(image_data, filename)
def save_image_locally(session_id: str, item_id: str, image_data: bytes, extension: str = "png") -> str:
"""Save image to local storage."""
session_dir = os.path.join(LOCAL_STORAGE_PATH, session_id)
os.makedirs(session_dir, exist_ok=True)
filename = f"{item_id}.{extension}"
filepath = os.path.join(session_dir, filename)
with open(filepath, 'wb') as f:
f.write(image_data)
return filepath
def get_image_url(image_path: str) -> str:
"""Get URL for an image."""
# For local images, return a relative path that the frontend can use
if image_path.startswith(LOCAL_STORAGE_PATH):
relative_path = image_path[len(LOCAL_STORAGE_PATH):].lstrip('/')
return f"/api/v1/ocr-label/images/{relative_path}"
# For MinIO images, the path is already a URL or key
return image_path

View File

@@ -0,0 +1,86 @@
"""
OCR Labeling - Pydantic Models and Constants
Extracted from ocr_labeling_api.py to keep files under 500 LOC.
"""
import os
from pydantic import BaseModel
from typing import Optional, Dict
from datetime import datetime
# Local storage path (fallback if MinIO not available)
LOCAL_STORAGE_PATH = os.getenv("OCR_STORAGE_PATH", "/app/ocr-labeling")
# =============================================================================
# Pydantic Models
# =============================================================================
class SessionCreate(BaseModel):
name: str
source_type: str = "klausur" # klausur, handwriting_sample, scan
description: Optional[str] = None
ocr_model: Optional[str] = "llama3.2-vision:11b"
class SessionResponse(BaseModel):
id: str
name: str
source_type: str
description: Optional[str]
ocr_model: Optional[str]
total_items: int
labeled_items: int
confirmed_items: int
corrected_items: int
skipped_items: int
created_at: datetime
class ItemResponse(BaseModel):
id: str
session_id: str
session_name: str
image_path: str
image_url: Optional[str]
ocr_text: Optional[str]
ocr_confidence: Optional[float]
ground_truth: Optional[str]
status: str
metadata: Optional[Dict]
created_at: datetime
class ConfirmRequest(BaseModel):
item_id: str
label_time_seconds: Optional[int] = None
class CorrectRequest(BaseModel):
item_id: str
ground_truth: str
label_time_seconds: Optional[int] = None
class SkipRequest(BaseModel):
item_id: str
class ExportRequest(BaseModel):
export_format: str = "generic" # generic, trocr, llama_vision
session_id: Optional[str] = None
batch_id: Optional[str] = None
class StatsResponse(BaseModel):
total_sessions: Optional[int] = None
total_items: int
labeled_items: int
confirmed_items: int
corrected_items: int
pending_items: int
exportable_items: Optional[int] = None
accuracy_rate: float
avg_label_time_seconds: Optional[float] = None

View File

@@ -0,0 +1,241 @@
"""
OCR Labeling - Session and Labeling Route Handlers
Extracted from ocr_labeling_api.py to keep files under 500 LOC.
Endpoints:
- POST /sessions - Create labeling session
- GET /sessions - List sessions
- GET /sessions/{id} - Get session
- GET /queue - Get labeling queue
- GET /items/{id} - Get item
- POST /confirm - Confirm OCR
- POST /correct - Correct ground truth
- POST /skip - Skip item
- GET /stats - Get statistics
"""
from fastapi import APIRouter, HTTPException, Query
from typing import Optional, List
from datetime import datetime
import uuid
from metrics_db import (
create_ocr_labeling_session,
get_ocr_labeling_sessions,
get_ocr_labeling_session,
get_ocr_labeling_queue,
get_ocr_labeling_item,
confirm_ocr_label,
correct_ocr_label,
skip_ocr_item,
get_ocr_labeling_stats,
)
from .models import (
SessionCreate, SessionResponse, ItemResponse,
ConfirmRequest, CorrectRequest, SkipRequest,
)
from .helpers import get_image_url
router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"])
# =============================================================================
# Session Endpoints
# =============================================================================
@router.post("/sessions", response_model=SessionResponse)
async def create_session(session: SessionCreate):
"""Create a new OCR labeling session."""
session_id = str(uuid.uuid4())
success = await create_ocr_labeling_session(
session_id=session_id,
name=session.name,
source_type=session.source_type,
description=session.description,
ocr_model=session.ocr_model,
)
if not success:
raise HTTPException(status_code=500, detail="Failed to create session")
return SessionResponse(
id=session_id,
name=session.name,
source_type=session.source_type,
description=session.description,
ocr_model=session.ocr_model,
total_items=0,
labeled_items=0,
confirmed_items=0,
corrected_items=0,
skipped_items=0,
created_at=datetime.utcnow(),
)
@router.get("/sessions", response_model=List[SessionResponse])
async def list_sessions(limit: int = Query(50, ge=1, le=100)):
"""List all OCR labeling sessions."""
sessions = await get_ocr_labeling_sessions(limit=limit)
return [
SessionResponse(
id=s['id'],
name=s['name'],
source_type=s['source_type'],
description=s.get('description'),
ocr_model=s.get('ocr_model'),
total_items=s.get('total_items', 0),
labeled_items=s.get('labeled_items', 0),
confirmed_items=s.get('confirmed_items', 0),
corrected_items=s.get('corrected_items', 0),
skipped_items=s.get('skipped_items', 0),
created_at=s.get('created_at', datetime.utcnow()),
)
for s in sessions
]
@router.get("/sessions/{session_id}", response_model=SessionResponse)
async def get_session(session_id: str):
"""Get a specific OCR labeling session."""
session = await get_ocr_labeling_session(session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
return SessionResponse(
id=session['id'],
name=session['name'],
source_type=session['source_type'],
description=session.get('description'),
ocr_model=session.get('ocr_model'),
total_items=session.get('total_items', 0),
labeled_items=session.get('labeled_items', 0),
confirmed_items=session.get('confirmed_items', 0),
corrected_items=session.get('corrected_items', 0),
skipped_items=session.get('skipped_items', 0),
created_at=session.get('created_at', datetime.utcnow()),
)
# =============================================================================
# Queue and Item Endpoints
# =============================================================================
@router.get("/queue", response_model=List[ItemResponse])
async def get_labeling_queue(
session_id: Optional[str] = Query(None),
status: str = Query("pending"),
limit: int = Query(10, ge=1, le=50),
):
"""Get items from the labeling queue."""
items = await get_ocr_labeling_queue(
session_id=session_id,
status=status,
limit=limit,
)
return [
ItemResponse(
id=item['id'],
session_id=item['session_id'],
session_name=item.get('session_name', ''),
image_path=item['image_path'],
image_url=get_image_url(item['image_path']),
ocr_text=item.get('ocr_text'),
ocr_confidence=item.get('ocr_confidence'),
ground_truth=item.get('ground_truth'),
status=item.get('status', 'pending'),
metadata=item.get('metadata'),
created_at=item.get('created_at', datetime.utcnow()),
)
for item in items
]
@router.get("/items/{item_id}", response_model=ItemResponse)
async def get_item(item_id: str):
"""Get a specific labeling item."""
item = await get_ocr_labeling_item(item_id)
if not item:
raise HTTPException(status_code=404, detail="Item not found")
return ItemResponse(
id=item['id'],
session_id=item['session_id'],
session_name=item.get('session_name', ''),
image_path=item['image_path'],
image_url=get_image_url(item['image_path']),
ocr_text=item.get('ocr_text'),
ocr_confidence=item.get('ocr_confidence'),
ground_truth=item.get('ground_truth'),
status=item.get('status', 'pending'),
metadata=item.get('metadata'),
created_at=item.get('created_at', datetime.utcnow()),
)
# =============================================================================
# Labeling Action Endpoints
# =============================================================================
@router.post("/confirm")
async def confirm_item(request: ConfirmRequest):
"""Confirm that OCR text is correct."""
success = await confirm_ocr_label(
item_id=request.item_id,
labeled_by="admin",
label_time_seconds=request.label_time_seconds,
)
if not success:
raise HTTPException(status_code=400, detail="Failed to confirm item")
return {"status": "confirmed", "item_id": request.item_id}
@router.post("/correct")
async def correct_item(request: CorrectRequest):
"""Save corrected ground truth for an item."""
success = await correct_ocr_label(
item_id=request.item_id,
ground_truth=request.ground_truth,
labeled_by="admin",
label_time_seconds=request.label_time_seconds,
)
if not success:
raise HTTPException(status_code=400, detail="Failed to correct item")
return {"status": "corrected", "item_id": request.item_id}
@router.post("/skip")
async def skip_item(request: SkipRequest):
"""Skip an item (unusable image, etc.)."""
success = await skip_ocr_item(
item_id=request.item_id,
labeled_by="admin",
)
if not success:
raise HTTPException(status_code=400, detail="Failed to skip item")
return {"status": "skipped", "item_id": request.item_id}
@router.get("/stats")
async def get_stats(session_id: Optional[str] = Query(None)):
"""Get labeling statistics."""
stats = await get_ocr_labeling_stats(session_id=session_id)
if "error" in stats:
raise HTTPException(status_code=500, detail=stats["error"])
return stats

View File

@@ -0,0 +1,313 @@
"""
OCR Labeling - Upload, Run-OCR, and Export Route Handlers
Extracted from ocr_labeling_routes.py to keep files under 500 LOC.
Endpoints:
- POST /sessions/{id}/upload - Upload images for labeling
- POST /run-ocr/{item_id} - Run OCR on existing item
- POST /export - Export training data
- GET /training-samples - List training samples
- GET /images/{path} - Serve images from local storage
- GET /exports - List exports
"""
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
from typing import Optional, List
import uuid
import os
from metrics_db import (
get_ocr_labeling_session,
add_ocr_labeling_item,
get_ocr_labeling_item,
export_training_samples,
get_training_samples,
)
from .models import (
ExportRequest,
LOCAL_STORAGE_PATH,
)
from .helpers import (
compute_image_hash, run_ocr_on_image,
save_image_locally,
MINIO_AVAILABLE, TRAINING_EXPORT_AVAILABLE,
)
# Conditional imports
try:
from minio_storage import upload_ocr_image, get_ocr_image
except ImportError:
pass
try:
from training_export_service import TrainingSample, get_training_export_service
except ImportError:
pass
router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"])
@router.post("/sessions/{session_id}/upload")
async def upload_images(
session_id: str,
files: List[UploadFile] = File(...),
run_ocr: bool = Form(True),
metadata: Optional[str] = Form(None),
):
"""
Upload images to a labeling session.
Args:
session_id: Session to add images to
files: Image files to upload (PNG, JPG, PDF)
run_ocr: Whether to run OCR immediately (default: True)
metadata: Optional JSON metadata (subject, year, etc.)
"""
import json
session = await get_ocr_labeling_session(session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
meta_dict = None
if metadata:
try:
meta_dict = json.loads(metadata)
except json.JSONDecodeError:
meta_dict = {"raw": metadata}
results = []
ocr_model = session.get('ocr_model', 'llama3.2-vision:11b')
for file in files:
content = await file.read()
image_hash = compute_image_hash(content)
item_id = str(uuid.uuid4())
extension = file.filename.split('.')[-1].lower() if file.filename else 'png'
if extension not in ['png', 'jpg', 'jpeg', 'pdf']:
extension = 'png'
if MINIO_AVAILABLE:
try:
image_path = upload_ocr_image(session_id, item_id, content, extension)
except Exception as e:
print(f"MinIO upload failed, using local storage: {e}")
image_path = save_image_locally(session_id, item_id, content, extension)
else:
image_path = save_image_locally(session_id, item_id, content, extension)
ocr_text = None
ocr_confidence = None
if run_ocr and extension != 'pdf':
ocr_text, ocr_confidence = await run_ocr_on_image(
content,
file.filename or f"{item_id}.{extension}",
model=ocr_model
)
success = await add_ocr_labeling_item(
item_id=item_id,
session_id=session_id,
image_path=image_path,
image_hash=image_hash,
ocr_text=ocr_text,
ocr_confidence=ocr_confidence,
ocr_model=ocr_model if ocr_text else None,
metadata=meta_dict,
)
if success:
results.append({
"id": item_id,
"filename": file.filename,
"image_path": image_path,
"image_hash": image_hash,
"ocr_text": ocr_text,
"ocr_confidence": ocr_confidence,
"status": "pending",
})
return {
"session_id": session_id,
"uploaded_count": len(results),
"items": results,
}
@router.post("/export")
async def export_data(request: ExportRequest):
"""Export labeled data for training."""
db_samples = await export_training_samples(
export_format=request.export_format,
session_id=request.session_id,
batch_id=request.batch_id,
exported_by="admin",
)
if not db_samples:
return {
"export_format": request.export_format,
"batch_id": request.batch_id,
"exported_count": 0,
"samples": [],
"message": "No labeled samples found to export",
}
export_result = None
if TRAINING_EXPORT_AVAILABLE:
try:
export_service = get_training_export_service()
training_samples = []
for s in db_samples:
training_samples.append(TrainingSample(
id=s.get('id', s.get('item_id', '')),
image_path=s.get('image_path', ''),
ground_truth=s.get('ground_truth', ''),
ocr_text=s.get('ocr_text'),
ocr_confidence=s.get('ocr_confidence'),
metadata=s.get('metadata'),
))
export_result = export_service.export(
samples=training_samples,
export_format=request.export_format,
batch_id=request.batch_id,
)
except Exception as e:
print(f"Training export failed: {e}")
response = {
"export_format": request.export_format,
"batch_id": request.batch_id or (export_result.batch_id if export_result else None),
"exported_count": len(db_samples),
"samples": db_samples,
}
if export_result:
response["export_path"] = export_result.export_path
response["manifest_path"] = export_result.manifest_path
return response
@router.get("/training-samples")
async def list_training_samples(
export_format: Optional[str] = Query(None),
batch_id: Optional[str] = Query(None),
limit: int = Query(100, ge=1, le=1000),
):
"""Get exported training samples."""
samples = await get_training_samples(
export_format=export_format,
batch_id=batch_id,
limit=limit,
)
return {
"count": len(samples),
"samples": samples,
}
@router.get("/images/{path:path}")
async def get_image(path: str):
"""Serve an image from local storage."""
from fastapi.responses import FileResponse
filepath = os.path.join(LOCAL_STORAGE_PATH, path)
if not os.path.exists(filepath):
raise HTTPException(status_code=404, detail="Image not found")
extension = filepath.split('.')[-1].lower()
content_type = {
'png': 'image/png',
'jpg': 'image/jpeg',
'jpeg': 'image/jpeg',
'pdf': 'application/pdf',
}.get(extension, 'application/octet-stream')
return FileResponse(filepath, media_type=content_type)
@router.post("/run-ocr/{item_id}")
async def run_ocr_for_item(item_id: str):
"""Run OCR on an existing item."""
item = await get_ocr_labeling_item(item_id)
if not item:
raise HTTPException(status_code=404, detail="Item not found")
image_path = item['image_path']
if image_path.startswith(LOCAL_STORAGE_PATH):
if not os.path.exists(image_path):
raise HTTPException(status_code=404, detail="Image file not found")
with open(image_path, 'rb') as f:
image_data = f.read()
elif MINIO_AVAILABLE:
try:
image_data = get_ocr_image(image_path)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to load image: {e}")
else:
raise HTTPException(status_code=500, detail="Cannot load image")
session = await get_ocr_labeling_session(item['session_id'])
ocr_model = session.get('ocr_model', 'llama3.2-vision:11b') if session else 'llama3.2-vision:11b'
ocr_text, ocr_confidence = await run_ocr_on_image(
image_data,
os.path.basename(image_path),
model=ocr_model
)
if ocr_text is None:
raise HTTPException(status_code=500, detail="OCR failed")
from metrics_db import get_pool
pool = await get_pool()
if pool:
async with pool.acquire() as conn:
await conn.execute(
"""
UPDATE ocr_labeling_items
SET ocr_text = $2, ocr_confidence = $3, ocr_model = $4
WHERE id = $1
""",
item_id, ocr_text, ocr_confidence, ocr_model
)
return {
"item_id": item_id,
"ocr_text": ocr_text,
"ocr_confidence": ocr_confidence,
"ocr_model": ocr_model,
}
@router.get("/exports")
async def list_exports(export_format: Optional[str] = Query(None)):
"""List all available training data exports."""
if not TRAINING_EXPORT_AVAILABLE:
return {
"exports": [],
"message": "Training export service not available",
}
try:
export_service = get_training_export_service()
exports = export_service.list_exports(export_format=export_format)
return {
"count": len(exports),
"exports": exports,
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to list exports: {e}")

View File

@@ -0,0 +1,8 @@
"""
OCR Pipeline sub-package — API endpoints, session management, overlays,
geometry steps, word detection, regression testing, and related utilities.
Moved from backend/ flat modules (ocr_pipeline_*.py, page_crop*.py,
orientation_*.py, crop_api.py, etc.).
Backward-compatible shim files remain at the old locations.
"""

View File

@@ -0,0 +1,63 @@
"""
OCR Pipeline API - Schrittweise Seitenrekonstruktion.
Thin wrapper that assembles all sub-module routers into a single
composite router. Backward-compatible: main.py and tests can still
import ``router``, ``_cache``, and helper functions from here.
Sub-modules (each < 1 000 lines):
ocr_pipeline_common shared state, cache, Pydantic models, helpers
ocr_pipeline_sessions session CRUD, image serving, doc-type
ocr_pipeline_geometry deskew, dewarp, structure, columns
ocr_pipeline_rows row detection, box-overlay helper
ocr_pipeline_words word detection (SSE), paddle-direct, word GT
ocr_pipeline_ocr_merge paddle/tesseract merge helpers, kombi endpoints
ocr_pipeline_postprocess LLM review, reconstruction, export, validation
ocr_pipeline_auto auto-mode orchestrator, reprocess
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
from fastapi import APIRouter
# ---------------------------------------------------------------------------
# Shared state (imported by main.py and orientation_crop_api.py)
# ---------------------------------------------------------------------------
from .common import ( # noqa: F401 re-exported
_cache,
_BORDER_GHOST_CHARS,
_filter_border_ghost_words,
)
# ---------------------------------------------------------------------------
# Sub-module routers
# ---------------------------------------------------------------------------
from .sessions import router as _sessions_router
from .geometry import router as _geometry_router
from .rows import router as _rows_router
from .words import router as _words_router
from .ocr_merge import (
router as _ocr_merge_router,
# Re-export for test backward compatibility
_split_paddle_multi_words, # noqa: F401
_group_words_into_rows, # noqa: F401
_merge_row_sequences, # noqa: F401
_merge_paddle_tesseract, # noqa: F401
)
from .postprocess import router as _postprocess_router
from .auto import router as _auto_router
from .regression import router as _regression_router
# ---------------------------------------------------------------------------
# Composite router (used by main.py)
# ---------------------------------------------------------------------------
router = APIRouter()
router.include_router(_sessions_router)
router.include_router(_geometry_router)
router.include_router(_rows_router)
router.include_router(_words_router)
router.include_router(_ocr_merge_router)
router.include_router(_postprocess_router)
router.include_router(_auto_router)
router.include_router(_regression_router)

View File

@@ -0,0 +1,23 @@
"""
OCR Pipeline Auto-Mode Orchestrator and Reprocess Endpoints — Barrel Re-export.
Split into submodules:
- ocr_pipeline_reprocess.py — POST /sessions/{id}/reprocess
- ocr_pipeline_auto_steps.py — POST /sessions/{id}/run-auto + VLM helper
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
from fastapi import APIRouter
from .reprocess import router as _reprocess_router
from .auto_steps import router as _steps_router
# Combine both sub-routers into a single router for backwards compatibility.
# The consumer imports `from ocr_pipeline_auto import router as _auto_router`.
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
router.include_router(_reprocess_router)
router.include_router(_steps_router)
__all__ = ["router"]

View File

@@ -0,0 +1,84 @@
"""
OCR Pipeline Auto-Mode Helpers.
VLM shear detection, SSE event formatting, and request models.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import os
import re
from typing import Any, Dict
from pydantic import BaseModel
logger = logging.getLogger(__name__)
class RunAutoRequest(BaseModel):
from_step: int = 1 # 1=deskew, 2=dewarp, 3=columns, 4=rows, 5=words, 6=llm-review
ocr_engine: str = "auto" # "auto" | "rapid" | "tesseract"
pronunciation: str = "british"
skip_llm_review: bool = False
dewarp_method: str = "ensemble" # "ensemble" | "vlm" | "cv"
async def auto_sse_event(step: str, status: str, data: Dict[str, Any]) -> str:
"""Format a single SSE event line."""
payload = {"step": step, "status": status, **data}
return f"data: {json.dumps(payload)}\n\n"
async def detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]:
"""Ask qwen2.5vl:32b to estimate the vertical shear angle of a scanned page.
The VLM is shown the image and asked: are the column/table borders tilted?
If yes, by how many degrees? Returns a dict with shear_degrees and confidence.
Confidence is 0.0 if Ollama is unavailable or parsing fails.
"""
import httpx
import base64
ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
prompt = (
"This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. "
"Are they perfectly vertical, or do they tilt slightly? "
"If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). "
"Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} "
"Use confidence 0.0-1.0 based on how clearly you can see the tilt. "
"If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}"
)
img_b64 = base64.b64encode(image_bytes).decode("utf-8")
payload = {
"model": model,
"prompt": prompt,
"images": [img_b64],
"stream": False,
}
try:
async with httpx.AsyncClient(timeout=60.0) as client:
resp = await client.post(f"{ollama_base}/api/generate", json=payload)
resp.raise_for_status()
text = resp.json().get("response", "")
# Parse JSON from response (may have surrounding text)
match = re.search(r'\{[^}]+\}', text)
if match:
data = json.loads(match.group(0))
shear = float(data.get("shear_degrees", 0.0))
conf = float(data.get("confidence", 0.0))
# Clamp to reasonable range
shear = max(-3.0, min(3.0, shear))
conf = max(0.0, min(1.0, conf))
return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)}
except Exception as e:
logger.warning(f"VLM dewarp failed: {e}")
return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0}

View File

@@ -0,0 +1,528 @@
"""
OCR Pipeline Auto-Mode Orchestrator.
POST /sessions/{session_id}/run-auto -- full auto-mode with SSE streaming.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import time
from dataclasses import asdict
from typing import Any, Dict, List, Optional
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from cv_vocab_pipeline import (
OLLAMA_REVIEW_MODEL,
PageRegion,
RowGeometry,
_cells_to_vocab_entries,
_detect_header_footer_gaps,
_detect_sub_columns,
_fix_character_confusion,
_fix_phonetic_brackets,
fix_cell_phonetics,
analyze_layout,
build_cell_grid,
classify_column_types,
create_layout_image,
create_ocr_image,
deskew_image,
deskew_image_by_word_alignment,
detect_column_geometry,
detect_row_geometry,
_apply_shear,
dewarp_image,
llm_review_entries,
)
from .common import (
_cache,
_load_session_to_cache,
_get_cached,
)
from .session_store import (
get_session_db,
update_session_db,
)
from .auto_helpers import (
RunAutoRequest,
auto_sse_event as _auto_sse_event,
detect_shear_with_vlm as _detect_shear_with_vlm,
)
logger = logging.getLogger(__name__)
router = APIRouter(tags=["ocr-pipeline"])
@router.post("/sessions/{session_id}/run-auto")
async def run_auto(session_id: str, req: RunAutoRequest, request: Request):
"""Run the full OCR pipeline automatically from a given step, streaming SSE progress.
Steps:
1. Deskew -- straighten the scan
2. Dewarp -- correct vertical shear (ensemble CV or VLM)
3. Columns -- detect column layout
4. Rows -- detect row layout
5. Words -- OCR each cell
6. LLM review -- correct OCR errors (optional)
Already-completed steps are skipped unless `from_step` forces a rerun.
Yields SSE events of the form:
data: {"step": "deskew", "status": "start"|"done"|"skipped"|"error", ...}
Final event:
data: {"step": "complete", "status": "done", "steps_run": [...], "steps_skipped": [...]}
"""
if req.from_step < 1 or req.from_step > 6:
raise HTTPException(status_code=400, detail="from_step must be 1-6")
if req.dewarp_method not in ("ensemble", "vlm", "cv"):
raise HTTPException(status_code=400, detail="dewarp_method must be: ensemble, vlm, cv")
if session_id not in _cache:
await _load_session_to_cache(session_id)
async def _generate():
steps_run: List[str] = []
steps_skipped: List[str] = []
error_step: Optional[str] = None
session = await get_session_db(session_id)
if not session:
yield await _auto_sse_event("error", "error", {"message": f"Session {session_id} not found"})
return
cached = _get_cached(session_id)
# Step 1: Deskew
if req.from_step <= 1:
yield await _auto_sse_event("deskew", "start", {})
try:
t0 = time.time()
orig_bgr = cached.get("original_bgr")
if orig_bgr is None:
raise ValueError("Original image not loaded")
try:
deskewed_hough, angle_hough = deskew_image(orig_bgr.copy())
except Exception:
deskewed_hough, angle_hough = orig_bgr, 0.0
success_enc, png_orig = cv2.imencode(".png", orig_bgr)
orig_bytes = png_orig.tobytes() if success_enc else b""
try:
deskewed_wa_bytes, angle_wa = deskew_image_by_word_alignment(orig_bytes)
except Exception:
deskewed_wa_bytes, angle_wa = orig_bytes, 0.0
if abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1:
method_used = "word_alignment"
angle_applied = angle_wa
wa_arr = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8)
deskewed_bgr = cv2.imdecode(wa_arr, cv2.IMREAD_COLOR)
if deskewed_bgr is None:
deskewed_bgr = deskewed_hough
method_used = "hough"
angle_applied = angle_hough
else:
method_used = "hough"
angle_applied = angle_hough
deskewed_bgr = deskewed_hough
success, png_buf = cv2.imencode(".png", deskewed_bgr)
deskewed_png = png_buf.tobytes() if success else b""
deskew_result = {
"method_used": method_used,
"rotation_degrees": round(float(angle_applied), 3),
"duration_seconds": round(time.time() - t0, 2),
}
cached["deskewed_bgr"] = deskewed_bgr
cached["deskew_result"] = deskew_result
await update_session_db(
session_id,
deskewed_png=deskewed_png,
deskew_result=deskew_result,
auto_rotation_degrees=float(angle_applied),
current_step=3,
)
session = await get_session_db(session_id)
steps_run.append("deskew")
yield await _auto_sse_event("deskew", "done", deskew_result)
except Exception as e:
logger.error(f"Auto-mode deskew failed for {session_id}: {e}")
error_step = "deskew"
yield await _auto_sse_event("deskew", "error", {"message": str(e)})
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
return
else:
steps_skipped.append("deskew")
yield await _auto_sse_event("deskew", "skipped", {"reason": "from_step > 1"})
# Step 2: Dewarp
if req.from_step <= 2:
yield await _auto_sse_event("dewarp", "start", {"method": req.dewarp_method})
try:
t0 = time.time()
deskewed_bgr = cached.get("deskewed_bgr")
if deskewed_bgr is None:
raise ValueError("Deskewed image not available")
if req.dewarp_method == "vlm":
success_enc, png_buf = cv2.imencode(".png", deskewed_bgr)
img_bytes = png_buf.tobytes() if success_enc else b""
vlm_det = await _detect_shear_with_vlm(img_bytes)
shear_deg = vlm_det["shear_degrees"]
if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3:
dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg)
else:
dewarped_bgr = deskewed_bgr
dewarp_info = {
"method": vlm_det["method"],
"shear_degrees": shear_deg,
"confidence": vlm_det["confidence"],
"detections": [vlm_det],
}
else:
dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr)
success_enc, png_buf = cv2.imencode(".png", dewarped_bgr)
dewarped_png = png_buf.tobytes() if success_enc else b""
dewarp_result = {
"method_used": dewarp_info["method"],
"shear_degrees": dewarp_info["shear_degrees"],
"confidence": dewarp_info["confidence"],
"duration_seconds": round(time.time() - t0, 2),
"detections": dewarp_info.get("detections", []),
}
cached["dewarped_bgr"] = dewarped_bgr
cached["dewarp_result"] = dewarp_result
await update_session_db(
session_id,
dewarped_png=dewarped_png,
dewarp_result=dewarp_result,
auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
current_step=4,
)
session = await get_session_db(session_id)
steps_run.append("dewarp")
yield await _auto_sse_event("dewarp", "done", dewarp_result)
except Exception as e:
logger.error(f"Auto-mode dewarp failed for {session_id}: {e}")
error_step = "dewarp"
yield await _auto_sse_event("dewarp", "error", {"message": str(e)})
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
return
else:
steps_skipped.append("dewarp")
yield await _auto_sse_event("dewarp", "skipped", {"reason": "from_step > 2"})
# Step 3: Columns
if req.from_step <= 3:
yield await _auto_sse_event("columns", "start", {})
try:
t0 = time.time()
col_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
if col_img is None:
raise ValueError("Cropped/dewarped image not available")
ocr_img = create_ocr_image(col_img)
h, w = ocr_img.shape[:2]
geo_result = detect_column_geometry(ocr_img, col_img)
if geo_result is None:
layout_img = create_layout_image(col_img)
regions = analyze_layout(layout_img, ocr_img)
cached["_word_dicts"] = None
cached["_inv"] = None
cached["_content_bounds"] = None
else:
geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
content_w = right_x - left_x
cached["_word_dicts"] = word_dicts
cached["_inv"] = inv
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None)
geometries = _detect_sub_columns(geometries, content_w, left_x=left_x,
top_y=top_y, header_y=header_y, footer_y=footer_y)
regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y,
left_x=left_x, right_x=right_x, inv=inv)
columns = [asdict(r) for r in regions]
column_result = {
"columns": columns,
"classification_methods": list({c.get("classification_method", "") for c in columns if c.get("classification_method")}),
"duration_seconds": round(time.time() - t0, 2),
}
cached["column_result"] = column_result
await update_session_db(session_id, column_result=column_result,
row_result=None, word_result=None, current_step=6)
session = await get_session_db(session_id)
steps_run.append("columns")
yield await _auto_sse_event("columns", "done", {
"column_count": len(columns),
"duration_seconds": column_result["duration_seconds"],
})
except Exception as e:
logger.error(f"Auto-mode columns failed for {session_id}: {e}")
error_step = "columns"
yield await _auto_sse_event("columns", "error", {"message": str(e)})
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
return
else:
steps_skipped.append("columns")
yield await _auto_sse_event("columns", "skipped", {"reason": "from_step > 3"})
# Step 4: Rows
if req.from_step <= 4:
yield await _auto_sse_event("rows", "start", {})
try:
t0 = time.time()
row_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
session = await get_session_db(session_id)
column_result = session.get("column_result") or cached.get("column_result")
if not column_result or not column_result.get("columns"):
raise ValueError("Column detection must complete first")
col_regions = [
PageRegion(
type=c["type"], x=c["x"], y=c["y"],
width=c["width"], height=c["height"],
classification_confidence=c.get("classification_confidence", 1.0),
classification_method=c.get("classification_method", ""),
)
for c in column_result["columns"]
]
word_dicts = cached.get("_word_dicts")
inv = cached.get("_inv")
content_bounds = cached.get("_content_bounds")
if word_dicts is None or inv is None or content_bounds is None:
ocr_img_tmp = create_ocr_image(row_img)
geo_result = detect_column_geometry(ocr_img_tmp, row_img)
if geo_result is None:
raise ValueError("Column geometry detection failed -- cannot detect rows")
_g, lx, rx, ty, by, word_dicts, inv = geo_result
cached["_word_dicts"] = word_dicts
cached["_inv"] = inv
cached["_content_bounds"] = (lx, rx, ty, by)
content_bounds = (lx, rx, ty, by)
left_x, right_x, top_y, bottom_y = content_bounds
row_geoms = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
row_list = [
{
"index": r.index, "x": r.x, "y": r.y,
"width": r.width, "height": r.height,
"word_count": r.word_count,
"row_type": r.row_type,
"gap_before": r.gap_before,
}
for r in row_geoms
]
row_result = {
"rows": row_list,
"row_count": len(row_list),
"content_rows": len([r for r in row_geoms if r.row_type == "content"]),
"duration_seconds": round(time.time() - t0, 2),
}
cached["row_result"] = row_result
await update_session_db(session_id, row_result=row_result, current_step=7)
session = await get_session_db(session_id)
steps_run.append("rows")
yield await _auto_sse_event("rows", "done", {
"row_count": len(row_list),
"content_rows": row_result["content_rows"],
"duration_seconds": row_result["duration_seconds"],
})
except Exception as e:
logger.error(f"Auto-mode rows failed for {session_id}: {e}")
error_step = "rows"
yield await _auto_sse_event("rows", "error", {"message": str(e)})
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
return
else:
steps_skipped.append("rows")
yield await _auto_sse_event("rows", "skipped", {"reason": "from_step > 4"})
# Step 5: Words (OCR)
if req.from_step <= 5:
yield await _auto_sse_event("words", "start", {"engine": req.ocr_engine})
try:
t0 = time.time()
word_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
session = await get_session_db(session_id)
column_result = session.get("column_result") or cached.get("column_result")
row_result = session.get("row_result") or cached.get("row_result")
col_regions = [
PageRegion(
type=c["type"], x=c["x"], y=c["y"],
width=c["width"], height=c["height"],
classification_confidence=c.get("classification_confidence", 1.0),
classification_method=c.get("classification_method", ""),
)
for c in column_result["columns"]
]
row_geoms = [
RowGeometry(
index=r["index"], x=r["x"], y=r["y"],
width=r["width"], height=r["height"],
word_count=r.get("word_count", 0), words=[],
row_type=r.get("row_type", "content"),
gap_before=r.get("gap_before", 0),
)
for r in row_result["rows"]
]
word_dicts = cached.get("_word_dicts")
if word_dicts is not None:
content_bounds = cached.get("_content_bounds")
top_y = content_bounds[2] if content_bounds else min(r.y for r in row_geoms)
for row in row_geoms:
row_y_rel = row.y - top_y
row_bottom_rel = row_y_rel + row.height
row.words = [
w for w in word_dicts
if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
]
row.word_count = len(row.words)
ocr_img = create_ocr_image(word_img)
img_h, img_w = word_img.shape[:2]
cells, columns_meta = build_cell_grid(
ocr_img, col_regions, row_geoms, img_w, img_h,
ocr_engine=req.ocr_engine, img_bgr=word_img,
)
duration = time.time() - t0
col_types = {c['type'] for c in columns_meta}
is_vocab = bool(col_types & {'column_en', 'column_de'})
n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
used_engine = cells[0].get("ocr_engine", "tesseract") if cells else req.ocr_engine
fix_cell_phonetics(cells, pronunciation=req.pronunciation)
word_result_data = {
"cells": cells,
"grid_shape": {
"rows": n_content_rows,
"cols": len(columns_meta),
"total_cells": len(cells),
},
"columns_used": columns_meta,
"layout": "vocab" if is_vocab else "generic",
"image_width": img_w,
"image_height": img_h,
"duration_seconds": round(duration, 2),
"ocr_engine": used_engine,
"summary": {
"total_cells": len(cells),
"non_empty_cells": sum(1 for c in cells if c.get("text")),
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
},
}
has_text_col = 'column_text' in col_types
if is_vocab or has_text_col:
entries = _cells_to_vocab_entries(cells, columns_meta)
entries = _fix_character_confusion(entries)
entries = _fix_phonetic_brackets(entries, pronunciation=req.pronunciation)
word_result_data["vocab_entries"] = entries
word_result_data["entries"] = entries
word_result_data["entry_count"] = len(entries)
word_result_data["summary"]["total_entries"] = len(entries)
await update_session_db(session_id, word_result=word_result_data, current_step=8)
cached["word_result"] = word_result_data
session = await get_session_db(session_id)
steps_run.append("words")
yield await _auto_sse_event("words", "done", {
"total_cells": len(cells),
"layout": word_result_data["layout"],
"duration_seconds": round(duration, 2),
"ocr_engine": used_engine,
"summary": word_result_data["summary"],
})
except Exception as e:
logger.error(f"Auto-mode words failed for {session_id}: {e}")
error_step = "words"
yield await _auto_sse_event("words", "error", {"message": str(e)})
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
return
else:
steps_skipped.append("words")
yield await _auto_sse_event("words", "skipped", {"reason": "from_step > 5"})
# Step 6: LLM Review (optional)
if req.from_step <= 6 and not req.skip_llm_review:
yield await _auto_sse_event("llm_review", "start", {"model": OLLAMA_REVIEW_MODEL})
try:
session = await get_session_db(session_id)
word_result = session.get("word_result") or cached.get("word_result")
entries = word_result.get("entries") or word_result.get("vocab_entries") or []
if not entries:
yield await _auto_sse_event("llm_review", "skipped", {"reason": "no entries"})
steps_skipped.append("llm_review")
else:
reviewed = await llm_review_entries(entries)
session = await get_session_db(session_id)
word_result_updated = dict(session.get("word_result") or {})
word_result_updated["entries"] = reviewed
word_result_updated["vocab_entries"] = reviewed
word_result_updated["llm_reviewed"] = True
word_result_updated["llm_model"] = OLLAMA_REVIEW_MODEL
await update_session_db(session_id, word_result=word_result_updated, current_step=9)
cached["word_result"] = word_result_updated
steps_run.append("llm_review")
yield await _auto_sse_event("llm_review", "done", {
"entries_reviewed": len(reviewed),
"model": OLLAMA_REVIEW_MODEL,
})
except Exception as e:
logger.warning(f"Auto-mode llm_review failed for {session_id} (non-fatal): {e}")
yield await _auto_sse_event("llm_review", "error", {"message": str(e), "fatal": False})
steps_skipped.append("llm_review")
else:
steps_skipped.append("llm_review")
reason = "skipped by request" if req.skip_llm_review else "from_step > 6"
yield await _auto_sse_event("llm_review", "skipped", {"reason": reason})
# Final event
yield await _auto_sse_event("complete", "done", {
"steps_run": steps_run,
"steps_skipped": steps_skipped,
})
return StreamingResponse(
_generate(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)

View File

@@ -0,0 +1,293 @@
"""
OCR Pipeline Column Detection Endpoints (Step 5)
Detect invisible columns, manual column override, and ground truth.
Extracted from ocr_pipeline_geometry.py for file-size compliance.
"""
import logging
import time
from dataclasses import asdict
from datetime import datetime
from typing import Dict, List
import cv2
from fastapi import APIRouter, HTTPException
from cv_vocab_pipeline import (
_detect_header_footer_gaps,
_detect_sub_columns,
classify_column_types,
create_layout_image,
create_ocr_image,
analyze_layout,
detect_column_geometry_zoned,
expand_narrow_columns,
)
from .session_store import (
get_session_db,
update_session_db,
)
from .common import (
_cache,
_load_session_to_cache,
_get_cached,
_append_pipeline_log,
ManualColumnsRequest,
ColumnGroundTruthRequest,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
@router.post("/sessions/{session_id}/columns")
async def detect_columns(session_id: str):
"""Run column detection on the cropped (or dewarped) image."""
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
img_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
if img_bgr is None:
raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before column detection")
# -----------------------------------------------------------------------
# Sub-sessions (box crops): skip column detection entirely.
# Instead, create a single pseudo-column spanning the full image width.
# Also run Tesseract + binarization here so that the row detection step
# can reuse the cached intermediates (_word_dicts, _inv, _content_bounds)
# instead of falling back to detect_column_geometry() which may fail
# on small box images with < 5 words.
# -----------------------------------------------------------------------
session = await get_session_db(session_id)
if session and session.get("parent_session_id"):
h, w = img_bgr.shape[:2]
# Binarize + invert for row detection (horizontal projection profile)
ocr_img = create_ocr_image(img_bgr)
inv = cv2.bitwise_not(ocr_img)
# Run Tesseract to get word bounding boxes.
try:
from PIL import Image as PILImage
pil_img = PILImage.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
import pytesseract
data = pytesseract.image_to_data(pil_img, lang='eng+deu', output_type=pytesseract.Output.DICT)
word_dicts = []
for i in range(len(data['text'])):
conf = int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1
text = str(data['text'][i]).strip()
if conf < 30 or not text:
continue
word_dicts.append({
'text': text, 'conf': conf,
'left': int(data['left'][i]),
'top': int(data['top'][i]),
'width': int(data['width'][i]),
'height': int(data['height'][i]),
})
# Log all words including low-confidence ones for debugging
all_count = sum(1 for i in range(len(data['text']))
if str(data['text'][i]).strip())
low_conf = [(str(data['text'][i]).strip(), int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1)
for i in range(len(data['text']))
if str(data['text'][i]).strip()
and (int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1) < 30
and (int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1) >= 0]
if low_conf:
logger.info(f"OCR Pipeline: sub-session {session_id}: {len(low_conf)} words below conf 30: {low_conf[:20]}")
logger.info(f"OCR Pipeline: sub-session {session_id}: Tesseract found {len(word_dicts)}/{all_count} words (conf>=30)")
except Exception as e:
logger.warning(f"OCR Pipeline: sub-session {session_id}: Tesseract failed: {e}")
word_dicts = []
# Cache intermediates for row detection (detect_rows reuses these)
cached["_word_dicts"] = word_dicts
cached["_inv"] = inv
cached["_content_bounds"] = (0, w, 0, h)
column_result = {
"columns": [{
"type": "column_text",
"x": 0, "y": 0,
"width": w, "height": h,
}],
"zones": None,
"boxes_detected": 0,
"duration_seconds": 0,
"method": "sub_session_pseudo_column",
}
await update_session_db(
session_id,
column_result=column_result,
row_result=None,
word_result=None,
current_step=6,
)
cached["column_result"] = column_result
cached.pop("row_result", None)
cached.pop("word_result", None)
logger.info(f"OCR Pipeline: sub-session {session_id}: pseudo-column {w}x{h}px")
return {"session_id": session_id, **column_result}
t0 = time.time()
# Binarized image for layout analysis
ocr_img = create_ocr_image(img_bgr)
h, w = ocr_img.shape[:2]
# Phase A: Zone-aware geometry detection
zoned_result = detect_column_geometry_zoned(ocr_img, img_bgr)
boxes_detected = 0
if zoned_result is None:
# Fallback to projection-based layout
layout_img = create_layout_image(img_bgr)
regions = analyze_layout(layout_img, ocr_img)
zones_data = None
else:
geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv, zones_data, boxes = zoned_result
content_w = right_x - left_x
boxes_detected = len(boxes)
# Cache intermediates for row detection (avoids second Tesseract run)
cached["_word_dicts"] = word_dicts
cached["_inv"] = inv
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
cached["_zones_data"] = zones_data
cached["_boxes_detected"] = boxes_detected
# Detect header/footer early so sub-column clustering ignores them
header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None)
# Split sub-columns (e.g. page references) before classification
geometries = _detect_sub_columns(geometries, content_w, left_x=left_x,
top_y=top_y, header_y=header_y, footer_y=footer_y)
# Expand narrow columns (sub-columns are often very narrow)
geometries = expand_narrow_columns(geometries, content_w, left_x, word_dicts)
# Phase B: Content-based classification
regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y,
left_x=left_x, right_x=right_x, inv=inv)
duration = time.time() - t0
columns = [asdict(r) for r in regions]
# Determine classification methods used
methods = list(set(
c.get("classification_method", "") for c in columns
if c.get("classification_method")
))
column_result = {
"columns": columns,
"classification_methods": methods,
"duration_seconds": round(duration, 2),
"boxes_detected": boxes_detected,
}
# Add zone data when boxes are present
if zones_data and boxes_detected > 0:
column_result["zones"] = zones_data
# Persist to DB -- also invalidate downstream results (rows, words)
await update_session_db(
session_id,
column_result=column_result,
row_result=None,
word_result=None,
current_step=6,
)
# Update cache
cached["column_result"] = column_result
cached.pop("row_result", None)
cached.pop("word_result", None)
col_count = len([c for c in columns if c["type"].startswith("column")])
logger.info(f"OCR Pipeline: columns session {session_id}: "
f"{col_count} columns detected, {boxes_detected} box(es) ({duration:.2f}s)")
img_w = img_bgr.shape[1]
await _append_pipeline_log(session_id, "columns", {
"total_columns": len(columns),
"column_widths_pct": [round(c["width"] / img_w * 100, 1) for c in columns],
"column_types": [c["type"] for c in columns],
"boxes_detected": boxes_detected,
}, duration_ms=int(duration * 1000))
return {
"session_id": session_id,
**column_result,
}
@router.post("/sessions/{session_id}/columns/manual")
async def set_manual_columns(session_id: str, req: ManualColumnsRequest):
"""Override detected columns with manual definitions."""
column_result = {
"columns": req.columns,
"duration_seconds": 0,
"method": "manual",
}
await update_session_db(session_id, column_result=column_result,
row_result=None, word_result=None)
if session_id in _cache:
_cache[session_id]["column_result"] = column_result
_cache[session_id].pop("row_result", None)
_cache[session_id].pop("word_result", None)
logger.info(f"OCR Pipeline: manual columns session {session_id}: "
f"{len(req.columns)} columns set")
return {"session_id": session_id, **column_result}
@router.post("/sessions/{session_id}/ground-truth/columns")
async def save_column_ground_truth(session_id: str, req: ColumnGroundTruthRequest):
"""Save ground truth feedback for the column detection step."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
gt = {
"is_correct": req.is_correct,
"corrected_columns": req.corrected_columns,
"notes": req.notes,
"saved_at": datetime.utcnow().isoformat(),
"column_result": session.get("column_result"),
}
ground_truth["columns"] = gt
await update_session_db(session_id, ground_truth=ground_truth)
if session_id in _cache:
_cache[session_id]["ground_truth"] = ground_truth
return {"session_id": session_id, "ground_truth": gt}
@router.get("/sessions/{session_id}/ground-truth/columns")
async def get_column_ground_truth(session_id: str):
"""Retrieve saved ground truth for column detection, including auto vs GT diff."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
columns_gt = ground_truth.get("columns")
if not columns_gt:
raise HTTPException(status_code=404, detail="No column ground truth saved")
return {
"session_id": session_id,
"columns_gt": columns_gt,
"columns_auto": session.get("column_result"),
}

View File

@@ -0,0 +1,354 @@
"""
Shared common module for the OCR pipeline.
Contains in-memory cache, helper functions, Pydantic request models,
pipeline logging, and border-ghost word filtering used by the pipeline
API endpoints and related modules.
"""
import logging
import re
import time
from datetime import datetime
from typing import Any, Dict, List, Optional
import cv2
import numpy as np
from fastapi import HTTPException
from pydantic import BaseModel
from .session_store import get_session_db, get_session_image, update_session_db
__all__ = [
# Cache
"_cache",
# Helper functions
"_get_base_image_png",
"_load_session_to_cache",
"_get_cached",
# Pydantic models
"ManualDeskewRequest",
"DeskewGroundTruthRequest",
"ManualDewarpRequest",
"CombinedAdjustRequest",
"DewarpGroundTruthRequest",
"VALID_DOCUMENT_CATEGORIES",
"UpdateSessionRequest",
"ManualColumnsRequest",
"ColumnGroundTruthRequest",
"ManualRowsRequest",
"RowGroundTruthRequest",
"RemoveHandwritingRequest",
# Pipeline log
"_append_pipeline_log",
# Border-ghost filter
"_BORDER_GHOST_CHARS",
"_filter_border_ghost_words",
]
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# In-memory cache for active sessions (BGR numpy arrays for processing)
# DB is source of truth, cache holds BGR arrays during active processing.
# ---------------------------------------------------------------------------
_cache: Dict[str, Dict[str, Any]] = {}
async def _get_base_image_png(session_id: str) -> Optional[bytes]:
"""Get the best available base image for a session (cropped > dewarped > original)."""
for img_type in ("cropped", "dewarped", "original"):
png_data = await get_session_image(session_id, img_type)
if png_data:
return png_data
return None
async def _load_session_to_cache(session_id: str) -> Dict[str, Any]:
"""Load session from DB into cache, decoding PNGs to BGR arrays."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
if session_id in _cache:
return _cache[session_id]
cache_entry: Dict[str, Any] = {
"id": session_id,
**session,
"original_bgr": None,
"oriented_bgr": None,
"cropped_bgr": None,
"deskewed_bgr": None,
"dewarped_bgr": None,
}
# Decode images from DB into BGR numpy arrays
for img_type, bgr_key in [
("original", "original_bgr"),
("oriented", "oriented_bgr"),
("cropped", "cropped_bgr"),
("deskewed", "deskewed_bgr"),
("dewarped", "dewarped_bgr"),
]:
png_data = await get_session_image(session_id, img_type)
if png_data:
arr = np.frombuffer(png_data, dtype=np.uint8)
bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR)
cache_entry[bgr_key] = bgr
# Sub-sessions: original image IS the cropped box region.
# Promote original_bgr to cropped_bgr so downstream steps find it.
if session.get("parent_session_id") and cache_entry["original_bgr"] is not None:
if cache_entry["cropped_bgr"] is None and cache_entry["dewarped_bgr"] is None:
cache_entry["cropped_bgr"] = cache_entry["original_bgr"]
_cache[session_id] = cache_entry
return cache_entry
def _get_cached(session_id: str) -> Dict[str, Any]:
"""Get from cache or raise 404."""
entry = _cache.get(session_id)
if not entry:
raise HTTPException(status_code=404, detail=f"Session {session_id} not in cache — reload first")
return entry
# ---------------------------------------------------------------------------
# Pydantic Models
# ---------------------------------------------------------------------------
class ManualDeskewRequest(BaseModel):
angle: float
class DeskewGroundTruthRequest(BaseModel):
is_correct: bool
corrected_angle: Optional[float] = None
notes: Optional[str] = None
class ManualDewarpRequest(BaseModel):
shear_degrees: float
class CombinedAdjustRequest(BaseModel):
rotation_degrees: float = 0.0
shear_degrees: float = 0.0
class DewarpGroundTruthRequest(BaseModel):
is_correct: bool
corrected_shear: Optional[float] = None
notes: Optional[str] = None
VALID_DOCUMENT_CATEGORIES = {
'vokabelseite', 'woerterbuch', 'buchseite', 'arbeitsblatt', 'klausurseite',
'mathearbeit', 'statistik', 'zeitung', 'formular', 'handschrift', 'sonstiges',
}
class UpdateSessionRequest(BaseModel):
name: Optional[str] = None
document_category: Optional[str] = None
class ManualColumnsRequest(BaseModel):
columns: List[Dict[str, Any]]
class ColumnGroundTruthRequest(BaseModel):
is_correct: bool
corrected_columns: Optional[List[Dict[str, Any]]] = None
notes: Optional[str] = None
class ManualRowsRequest(BaseModel):
rows: List[Dict[str, Any]]
class RowGroundTruthRequest(BaseModel):
is_correct: bool
corrected_rows: Optional[List[Dict[str, Any]]] = None
notes: Optional[str] = None
class RemoveHandwritingRequest(BaseModel):
method: str = "auto" # "auto" | "telea" | "ns"
target_ink: str = "all" # "all" | "colored" | "pencil"
dilation: int = 2 # mask dilation iterations (0-5)
use_source: str = "auto" # "original" | "deskewed" | "auto"
# ---------------------------------------------------------------------------
# Pipeline Log Helper
# ---------------------------------------------------------------------------
async def _append_pipeline_log(
session_id: str,
step_name: str,
metrics: Dict[str, Any],
success: bool = True,
duration_ms: Optional[int] = None,
):
"""Append a step entry to the session's pipeline_log JSONB."""
session = await get_session_db(session_id)
if not session:
return
log = session.get("pipeline_log") or {"steps": []}
if not isinstance(log, dict):
log = {"steps": []}
entry = {
"step": step_name,
"completed_at": datetime.utcnow().isoformat(),
"success": success,
"metrics": metrics,
}
if duration_ms is not None:
entry["duration_ms"] = duration_ms
log.setdefault("steps", []).append(entry)
await update_session_db(session_id, pipeline_log=log)
# ---------------------------------------------------------------------------
# Border-ghost word filter
# ---------------------------------------------------------------------------
# Characters that OCR produces when reading box-border lines.
_BORDER_GHOST_CHARS = set("|1lI![](){}iíì/\\-—_~.,;:'\"")
def _filter_border_ghost_words(
word_result: Dict,
boxes: List,
) -> int:
"""Remove OCR words that are actually box border lines.
A word is considered a border ghost when it sits on a known box edge
(left, right, top, or bottom) and looks like a line artefact (narrow
aspect ratio or text consists only of line-like characters).
After removing ghost cells, columns that have become empty are also
removed from ``columns_used`` so the grid no longer shows phantom
columns.
Modifies *word_result* in-place and returns the number of removed cells.
"""
if not boxes or not word_result:
return 0
cells = word_result.get("cells")
if not cells:
return 0
# Build border bands — vertical (X) and horizontal (Y)
x_bands = [] # list of (x_lo, x_hi)
y_bands = [] # list of (y_lo, y_hi)
for b in boxes:
bx = b.x if hasattr(b, "x") else b.get("x", 0)
by = b.y if hasattr(b, "y") else b.get("y", 0)
bw = b.width if hasattr(b, "width") else b.get("w", b.get("width", 0))
bh = b.height if hasattr(b, "height") else b.get("h", b.get("height", 0))
bt = b.border_thickness if hasattr(b, "border_thickness") else b.get("border_thickness", 3)
margin = max(bt * 2, 10) + 6 # generous margin
# Vertical edges (left / right)
x_bands.append((bx - margin, bx + margin))
x_bands.append((bx + bw - margin, bx + bw + margin))
# Horizontal edges (top / bottom)
y_bands.append((by - margin, by + margin))
y_bands.append((by + bh - margin, by + bh + margin))
img_w = word_result.get("image_width", 1)
img_h = word_result.get("image_height", 1)
def _is_ghost(cell: Dict) -> bool:
text = (cell.get("text") or "").strip()
if not text:
return False
# Compute absolute pixel position
if cell.get("bbox_px"):
px = cell["bbox_px"]
cx = px["x"] + px["w"] / 2
cy = px["y"] + px["h"] / 2
cw = px["w"]
ch = px["h"]
elif cell.get("bbox_pct"):
pct = cell["bbox_pct"]
cx = (pct["x"] / 100) * img_w + (pct["w"] / 100) * img_w / 2
cy = (pct["y"] / 100) * img_h + (pct["h"] / 100) * img_h / 2
cw = (pct["w"] / 100) * img_w
ch = (pct["h"] / 100) * img_h
else:
return False
# Check if center sits on a vertical or horizontal border
on_vertical = any(lo <= cx <= hi for lo, hi in x_bands)
on_horizontal = any(lo <= cy <= hi for lo, hi in y_bands)
if not on_vertical and not on_horizontal:
return False
# Very short text (1-2 chars) on a border → very likely ghost
if len(text) <= 2:
# Narrow vertically (line-like) or narrow horizontally (dash-like)?
if ch > 0 and cw / ch < 0.5:
return True
if cw > 0 and ch / cw < 0.5:
return True
# Text is only border-ghost characters?
if all(c in _BORDER_GHOST_CHARS for c in text):
return True
# Longer text but still only ghost chars and very narrow
if all(c in _BORDER_GHOST_CHARS for c in text):
if ch > 0 and cw / ch < 0.35:
return True
if cw > 0 and ch / cw < 0.35:
return True
return True # all ghost chars on a border → remove
return False
before = len(cells)
word_result["cells"] = [c for c in cells if not _is_ghost(c)]
removed = before - len(word_result["cells"])
# --- Remove empty columns from columns_used ---
columns_used = word_result.get("columns_used")
if removed and columns_used and len(columns_used) > 1:
remaining_cells = word_result["cells"]
occupied_cols = {c.get("col_index") for c in remaining_cells}
before_cols = len(columns_used)
columns_used = [col for col in columns_used if col.get("index") in occupied_cols]
# Re-index columns and remap cell col_index values
if len(columns_used) < before_cols:
old_to_new = {}
for new_i, col in enumerate(columns_used):
old_to_new[col["index"]] = new_i
col["index"] = new_i
for cell in remaining_cells:
old_ci = cell.get("col_index")
if old_ci in old_to_new:
cell["col_index"] = old_to_new[old_ci]
word_result["columns_used"] = columns_used
logger.info("border-ghost: removed %d empty column(s), %d remaining",
before_cols - len(columns_used), len(columns_used))
if removed:
# Update summary counts
summary = word_result.get("summary", {})
summary["total_cells"] = len(word_result["cells"])
summary["non_empty_cells"] = sum(1 for c in word_result["cells"] if c.get("text"))
word_result["summary"] = summary
gs = word_result.get("grid_shape", {})
gs["total_cells"] = len(word_result["cells"])
if columns_used is not None:
gs["cols"] = len(columns_used)
word_result["grid_shape"] = gs
return removed

View File

@@ -0,0 +1,290 @@
"""
Crop API endpoints (Step 4 / UI index 3 of OCR Pipeline).
Auto-crop, manual crop, and skip-crop for scanner/book borders.
"""
import logging
import time
from typing import Any, Dict
import cv2
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from .page_crop import detect_and_crop_page, detect_page_splits
from .session_store import get_sub_sessions, update_session_db
from .orientation_crop_helpers import ensure_cached, append_pipeline_log
from .page_sub_sessions import create_page_sub_sessions
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Step 4 (UI index 3): Crop — runs after deskew + dewarp
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/crop")
async def auto_crop(session_id: str):
"""Auto-detect and crop scanner/book borders.
Reads the dewarped image (post-deskew + dewarp, so the page is straight).
Falls back to oriented -> original if earlier steps were skipped.
If the image is a multi-page spread (e.g. book on scanner), it will
automatically split into separate sub-sessions per page, crop each
individually, and return the split info.
"""
cached = await ensure_cached(session_id)
# Use dewarped (preferred), fall back to oriented, then original
img_bgr = next(
(v for k in ("dewarped_bgr", "oriented_bgr", "original_bgr")
if (v := cached.get(k)) is not None),
None,
)
if img_bgr is None:
raise HTTPException(status_code=400, detail="No image available for cropping")
t0 = time.time()
# --- Check for existing sub-sessions (from page-split step) ---
# If page-split already created sub-sessions, skip multi-page detection
# in the crop step. Each sub-session runs its own crop independently.
existing_subs = await get_sub_sessions(session_id)
if existing_subs:
crop_result = cached.get("crop_result") or {}
if crop_result.get("multi_page"):
# Already split -- just return the existing info
duration = time.time() - t0
h, w = img_bgr.shape[:2]
return {
"session_id": session_id,
**crop_result,
"image_width": w,
"image_height": h,
"sub_sessions": [
{"id": s["id"], "name": s.get("name"), "page_index": s.get("box_index", i)}
for i, s in enumerate(existing_subs)
],
"note": "Page split was already performed; each sub-session runs its own crop.",
}
# --- Multi-page detection (fallback for sessions that skipped page-split) ---
page_splits = detect_page_splits(img_bgr)
if page_splits and len(page_splits) >= 2:
# Multi-page spread detected -- create sub-sessions
sub_sessions = await create_page_sub_sessions(
session_id, cached, img_bgr, page_splits,
)
duration = time.time() - t0
crop_info: Dict[str, Any] = {
"crop_applied": True,
"multi_page": True,
"page_count": len(page_splits),
"page_splits": page_splits,
"duration_seconds": round(duration, 2),
}
cached["crop_result"] = crop_info
# Store the first page as the main cropped image for backward compat
first_page = page_splits[0]
first_bgr = img_bgr[
first_page["y"]:first_page["y"] + first_page["height"],
first_page["x"]:first_page["x"] + first_page["width"],
].copy()
first_cropped, _ = detect_and_crop_page(first_bgr)
cached["cropped_bgr"] = first_cropped
ok, png_buf = cv2.imencode(".png", first_cropped)
await update_session_db(
session_id,
cropped_png=png_buf.tobytes() if ok else b"",
crop_result=crop_info,
current_step=5,
status='split',
)
logger.info(
"OCR Pipeline: crop session %s: multi-page split into %d pages in %.2fs",
session_id, len(page_splits), duration,
)
await append_pipeline_log(session_id, "crop", {
"multi_page": True,
"page_count": len(page_splits),
}, duration_ms=int(duration * 1000))
h, w = first_cropped.shape[:2]
return {
"session_id": session_id,
**crop_info,
"image_width": w,
"image_height": h,
"cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
"sub_sessions": sub_sessions,
}
# --- Single page (normal) ---
cropped_bgr, crop_info = detect_and_crop_page(img_bgr)
duration = time.time() - t0
crop_info["duration_seconds"] = round(duration, 2)
crop_info["multi_page"] = False
# Encode cropped image
success, png_buf = cv2.imencode(".png", cropped_bgr)
cropped_png = png_buf.tobytes() if success else b""
# Update cache
cached["cropped_bgr"] = cropped_bgr
cached["crop_result"] = crop_info
# Persist to DB
await update_session_db(
session_id,
cropped_png=cropped_png,
crop_result=crop_info,
current_step=5,
)
logger.info(
"OCR Pipeline: crop session %s: applied=%s format=%s in %.2fs",
session_id, crop_info["crop_applied"],
crop_info.get("detected_format", "?"),
duration,
)
await append_pipeline_log(session_id, "crop", {
"crop_applied": crop_info["crop_applied"],
"detected_format": crop_info.get("detected_format"),
"format_confidence": crop_info.get("format_confidence"),
}, duration_ms=int(duration * 1000))
h, w = cropped_bgr.shape[:2]
return {
"session_id": session_id,
**crop_info,
"image_width": w,
"image_height": h,
"cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
}
class ManualCropRequest(BaseModel):
x: float # percentage 0-100
y: float # percentage 0-100
width: float # percentage 0-100
height: float # percentage 0-100
@router.post("/sessions/{session_id}/crop/manual")
async def manual_crop(session_id: str, req: ManualCropRequest):
"""Manually crop using percentage coordinates."""
cached = await ensure_cached(session_id)
img_bgr = next(
(v for k in ("dewarped_bgr", "oriented_bgr", "original_bgr")
if (v := cached.get(k)) is not None),
None,
)
if img_bgr is None:
raise HTTPException(status_code=400, detail="No image available for cropping")
h, w = img_bgr.shape[:2]
# Convert percentages to pixels
px_x = int(w * req.x / 100.0)
px_y = int(h * req.y / 100.0)
px_w = int(w * req.width / 100.0)
px_h = int(h * req.height / 100.0)
# Clamp
px_x = max(0, min(px_x, w - 1))
px_y = max(0, min(px_y, h - 1))
px_w = max(1, min(px_w, w - px_x))
px_h = max(1, min(px_h, h - px_y))
cropped_bgr = img_bgr[px_y:px_y + px_h, px_x:px_x + px_w].copy()
success, png_buf = cv2.imencode(".png", cropped_bgr)
cropped_png = png_buf.tobytes() if success else b""
crop_result = {
"crop_applied": True,
"crop_rect": {"x": px_x, "y": px_y, "width": px_w, "height": px_h},
"crop_rect_pct": {"x": round(req.x, 2), "y": round(req.y, 2),
"width": round(req.width, 2), "height": round(req.height, 2)},
"original_size": {"width": w, "height": h},
"cropped_size": {"width": px_w, "height": px_h},
"method": "manual",
}
cached["cropped_bgr"] = cropped_bgr
cached["crop_result"] = crop_result
await update_session_db(
session_id,
cropped_png=cropped_png,
crop_result=crop_result,
current_step=5,
)
ch, cw = cropped_bgr.shape[:2]
return {
"session_id": session_id,
**crop_result,
"image_width": cw,
"image_height": ch,
"cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
}
@router.post("/sessions/{session_id}/crop/skip")
async def skip_crop(session_id: str):
"""Skip cropping -- use dewarped (or oriented/original) image as-is."""
cached = await ensure_cached(session_id)
img_bgr = next(
(v for k in ("dewarped_bgr", "oriented_bgr", "original_bgr")
if (v := cached.get(k)) is not None),
None,
)
if img_bgr is None:
raise HTTPException(status_code=400, detail="No image available")
h, w = img_bgr.shape[:2]
# Store the dewarped image as cropped (identity crop)
success, png_buf = cv2.imencode(".png", img_bgr)
cropped_png = png_buf.tobytes() if success else b""
crop_result = {
"crop_applied": False,
"skipped": True,
"original_size": {"width": w, "height": h},
"cropped_size": {"width": w, "height": h},
}
cached["cropped_bgr"] = img_bgr
cached["crop_result"] = crop_result
await update_session_db(
session_id,
cropped_png=cropped_png,
crop_result=crop_result,
current_step=5,
)
return {
"session_id": session_id,
**crop_result,
"image_width": w,
"image_height": h,
"cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
}

View File

@@ -0,0 +1,236 @@
"""
OCR Pipeline Deskew Endpoints (Step 2)
Auto deskew, manual deskew, and ground truth for the deskew step.
Extracted from ocr_pipeline_geometry.py for file-size compliance.
"""
import logging
import time
from datetime import datetime
import cv2
from fastapi import APIRouter, HTTPException
from cv_vocab_pipeline import (
create_ocr_image,
deskew_image,
deskew_image_by_word_alignment,
deskew_two_pass,
)
from .session_store import (
get_session_db,
update_session_db,
)
from .common import (
_cache,
_load_session_to_cache,
_get_cached,
_append_pipeline_log,
ManualDeskewRequest,
DeskewGroundTruthRequest,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
@router.post("/sessions/{session_id}/deskew")
async def auto_deskew(session_id: str):
"""Two-pass deskew: iterative projection (wide range) + word-alignment residual."""
# Ensure session is in cache
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
# Deskew runs right after orientation -- use oriented image, fall back to original
img_bgr = next((v for k in ("oriented_bgr", "original_bgr")
if (v := cached.get(k)) is not None), None)
if img_bgr is None:
raise HTTPException(status_code=400, detail="No image available for deskewing")
t0 = time.time()
# Two-pass deskew: iterative (+-5 deg) + word-alignment residual check
deskewed_bgr, angle_applied, two_pass_debug = deskew_two_pass(img_bgr.copy())
# Also run individual methods for reporting (non-authoritative)
try:
_, angle_hough = deskew_image(img_bgr.copy())
except Exception:
angle_hough = 0.0
success_enc, png_orig = cv2.imencode(".png", img_bgr)
orig_bytes = png_orig.tobytes() if success_enc else b""
try:
_, angle_wa = deskew_image_by_word_alignment(orig_bytes)
except Exception:
angle_wa = 0.0
angle_iterative = two_pass_debug.get("pass1_angle", 0.0)
angle_residual = two_pass_debug.get("pass2_angle", 0.0)
angle_textline = two_pass_debug.get("pass3_angle", 0.0)
duration = time.time() - t0
method_used = "three_pass" if abs(angle_textline) >= 0.01 else (
"two_pass" if abs(angle_residual) >= 0.01 else "iterative"
)
# Encode as PNG
success, deskewed_png_buf = cv2.imencode(".png", deskewed_bgr)
deskewed_png = deskewed_png_buf.tobytes() if success else b""
# Create binarized version
binarized_png = None
try:
binarized = create_ocr_image(deskewed_bgr)
success_bin, bin_buf = cv2.imencode(".png", binarized)
binarized_png = bin_buf.tobytes() if success_bin else None
except Exception as e:
logger.warning(f"Binarization failed: {e}")
confidence = max(0.5, 1.0 - abs(angle_applied) / 5.0)
deskew_result = {
"angle_hough": round(angle_hough, 3),
"angle_word_alignment": round(angle_wa, 3),
"angle_iterative": round(angle_iterative, 3),
"angle_residual": round(angle_residual, 3),
"angle_textline": round(angle_textline, 3),
"angle_applied": round(angle_applied, 3),
"method_used": method_used,
"confidence": round(confidence, 2),
"duration_seconds": round(duration, 2),
"two_pass_debug": two_pass_debug,
}
# Update cache
cached["deskewed_bgr"] = deskewed_bgr
cached["binarized_png"] = binarized_png
cached["deskew_result"] = deskew_result
# Persist to DB
db_update = {
"deskewed_png": deskewed_png,
"deskew_result": deskew_result,
"current_step": 3,
}
if binarized_png:
db_update["binarized_png"] = binarized_png
await update_session_db(session_id, **db_update)
logger.info(f"OCR Pipeline: deskew session {session_id}: "
f"hough={angle_hough:.2f} wa={angle_wa:.2f} "
f"iter={angle_iterative:.2f} residual={angle_residual:.2f} "
f"textline={angle_textline:.2f} "
f"-> {method_used} total={angle_applied:.2f}")
await _append_pipeline_log(session_id, "deskew", {
"angle_applied": round(angle_applied, 3),
"angle_iterative": round(angle_iterative, 3),
"angle_residual": round(angle_residual, 3),
"angle_textline": round(angle_textline, 3),
"confidence": round(confidence, 2),
"method": method_used,
}, duration_ms=int(duration * 1000))
return {
"session_id": session_id,
**deskew_result,
"deskewed_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed",
"binarized_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/binarized",
}
@router.post("/sessions/{session_id}/deskew/manual")
async def manual_deskew(session_id: str, req: ManualDeskewRequest):
"""Apply a manual rotation angle to the oriented image."""
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
img_bgr = next((v for k in ("oriented_bgr", "original_bgr")
if (v := cached.get(k)) is not None), None)
if img_bgr is None:
raise HTTPException(status_code=400, detail="No image available for deskewing")
angle = max(-5.0, min(5.0, req.angle))
h, w = img_bgr.shape[:2]
center = (w // 2, h // 2)
M = cv2.getRotationMatrix2D(center, angle, 1.0)
rotated = cv2.warpAffine(img_bgr, M, (w, h),
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_REPLICATE)
success, png_buf = cv2.imencode(".png", rotated)
deskewed_png = png_buf.tobytes() if success else b""
# Binarize
binarized_png = None
try:
binarized = create_ocr_image(rotated)
success_bin, bin_buf = cv2.imencode(".png", binarized)
binarized_png = bin_buf.tobytes() if success_bin else None
except Exception:
pass
deskew_result = {
**(cached.get("deskew_result") or {}),
"angle_applied": round(angle, 3),
"method_used": "manual",
}
# Update cache
cached["deskewed_bgr"] = rotated
cached["binarized_png"] = binarized_png
cached["deskew_result"] = deskew_result
# Persist to DB
db_update = {
"deskewed_png": deskewed_png,
"deskew_result": deskew_result,
}
if binarized_png:
db_update["binarized_png"] = binarized_png
await update_session_db(session_id, **db_update)
logger.info(f"OCR Pipeline: manual deskew session {session_id}: {angle:.2f}")
return {
"session_id": session_id,
"angle_applied": round(angle, 3),
"method_used": "manual",
"deskewed_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed",
}
@router.post("/sessions/{session_id}/ground-truth/deskew")
async def save_deskew_ground_truth(session_id: str, req: DeskewGroundTruthRequest):
"""Save ground truth feedback for the deskew step."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
gt = {
"is_correct": req.is_correct,
"corrected_angle": req.corrected_angle,
"notes": req.notes,
"saved_at": datetime.utcnow().isoformat(),
"deskew_result": session.get("deskew_result"),
}
ground_truth["deskew"] = gt
await update_session_db(session_id, ground_truth=ground_truth)
# Update cache
if session_id in _cache:
_cache[session_id]["ground_truth"] = ground_truth
logger.info(f"OCR Pipeline: ground truth deskew session {session_id}: "
f"correct={req.is_correct}, corrected_angle={req.corrected_angle}")
return {"session_id": session_id, "ground_truth": gt}

View File

@@ -0,0 +1,346 @@
"""
OCR Pipeline Dewarp Endpoints
Auto dewarp (with VLM/CV ensemble), manual dewarp, combined
rotation+shear adjustment, and ground truth.
Extracted from ocr_pipeline_geometry.py for file-size compliance.
"""
import json
import logging
import os
import re
import time
from datetime import datetime
from typing import Any, Dict
import cv2
from fastapi import APIRouter, HTTPException, Query
from cv_vocab_pipeline import (
_apply_shear,
create_ocr_image,
dewarp_image,
dewarp_image_manual,
)
from .session_store import (
get_session_db,
update_session_db,
)
from .common import (
_cache,
_load_session_to_cache,
_get_cached,
_append_pipeline_log,
ManualDewarpRequest,
CombinedAdjustRequest,
DewarpGroundTruthRequest,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
async def _detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]:
"""Ask qwen2.5vl:32b to estimate the vertical shear angle of a scanned page.
The VLM is shown the image and asked: are the column/table borders tilted?
If yes, by how many degrees? Returns a dict with shear_degrees and confidence.
Confidence is 0.0 if Ollama is unavailable or parsing fails.
"""
import httpx
import base64
ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
prompt = (
"This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. "
"Are they perfectly vertical, or do they tilt slightly? "
"If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). "
"Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} "
"Use confidence 0.0-1.0 based on how clearly you can see the tilt. "
"If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}"
)
img_b64 = base64.b64encode(image_bytes).decode("utf-8")
payload = {
"model": model,
"prompt": prompt,
"images": [img_b64],
"stream": False,
}
try:
async with httpx.AsyncClient(timeout=60.0) as client:
resp = await client.post(f"{ollama_base}/api/generate", json=payload)
resp.raise_for_status()
text = resp.json().get("response", "")
# Parse JSON from response (may have surrounding text)
match = re.search(r'\{[^}]+\}', text)
if match:
data = json.loads(match.group(0))
shear = float(data.get("shear_degrees", 0.0))
conf = float(data.get("confidence", 0.0))
# Clamp to reasonable range
shear = max(-3.0, min(3.0, shear))
conf = max(0.0, min(1.0, conf))
return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)}
except Exception as e:
logger.warning(f"VLM dewarp failed: {e}")
return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0}
@router.post("/sessions/{session_id}/dewarp")
async def auto_dewarp(
session_id: str,
method: str = Query("ensemble", description="Detection method: ensemble | vlm | cv"),
):
"""Detect and correct vertical shear on the deskewed image.
Methods:
- **ensemble** (default): 3-method CV ensemble (vertical edges + projection + Hough)
- **cv**: CV ensemble only (same as ensemble)
- **vlm**: Ask qwen2.5vl:32b to estimate the shear angle visually
"""
if method not in ("ensemble", "cv", "vlm"):
raise HTTPException(status_code=400, detail="method must be one of: ensemble, cv, vlm")
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
deskewed_bgr = cached.get("deskewed_bgr")
if deskewed_bgr is None:
raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")
t0 = time.time()
if method == "vlm":
# Encode deskewed image to PNG for VLM
success, png_buf = cv2.imencode(".png", deskewed_bgr)
img_bytes = png_buf.tobytes() if success else b""
vlm_det = await _detect_shear_with_vlm(img_bytes)
shear_deg = vlm_det["shear_degrees"]
if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3:
dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg)
else:
dewarped_bgr = deskewed_bgr
dewarp_info = {
"method": vlm_det["method"],
"shear_degrees": shear_deg,
"confidence": vlm_det["confidence"],
"detections": [vlm_det],
}
else:
dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr)
duration = time.time() - t0
# Encode as PNG
success, png_buf = cv2.imencode(".png", dewarped_bgr)
dewarped_png = png_buf.tobytes() if success else b""
dewarp_result = {
"method_used": dewarp_info["method"],
"shear_degrees": dewarp_info["shear_degrees"],
"confidence": dewarp_info["confidence"],
"duration_seconds": round(duration, 2),
"detections": dewarp_info.get("detections", []),
}
# Update cache
cached["dewarped_bgr"] = dewarped_bgr
cached["dewarp_result"] = dewarp_result
# Persist to DB
await update_session_db(
session_id,
dewarped_png=dewarped_png,
dewarp_result=dewarp_result,
auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
current_step=4,
)
logger.info(f"OCR Pipeline: dewarp session {session_id}: "
f"method={dewarp_info['method']} shear={dewarp_info['shear_degrees']:.3f} "
f"conf={dewarp_info['confidence']:.2f} ({duration:.2f}s)")
await _append_pipeline_log(session_id, "dewarp", {
"shear_degrees": dewarp_info["shear_degrees"],
"confidence": dewarp_info["confidence"],
"method": dewarp_info["method"],
"ensemble_methods": [d.get("method", "") for d in dewarp_info.get("detections", [])],
}, duration_ms=int(duration * 1000))
return {
"session_id": session_id,
**dewarp_result,
"dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
}
@router.post("/sessions/{session_id}/dewarp/manual")
async def manual_dewarp(session_id: str, req: ManualDewarpRequest):
"""Apply shear correction with a manual angle."""
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
deskewed_bgr = cached.get("deskewed_bgr")
if deskewed_bgr is None:
raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")
shear_deg = max(-2.0, min(2.0, req.shear_degrees))
if abs(shear_deg) < 0.001:
dewarped_bgr = deskewed_bgr
else:
dewarped_bgr = dewarp_image_manual(deskewed_bgr, shear_deg)
success, png_buf = cv2.imencode(".png", dewarped_bgr)
dewarped_png = png_buf.tobytes() if success else b""
dewarp_result = {
**(cached.get("dewarp_result") or {}),
"method_used": "manual",
"shear_degrees": round(shear_deg, 3),
}
# Update cache
cached["dewarped_bgr"] = dewarped_bgr
cached["dewarp_result"] = dewarp_result
# Persist to DB
await update_session_db(
session_id,
dewarped_png=dewarped_png,
dewarp_result=dewarp_result,
)
logger.info(f"OCR Pipeline: manual dewarp session {session_id}: shear={shear_deg:.3f}")
return {
"session_id": session_id,
"shear_degrees": round(shear_deg, 3),
"method_used": "manual",
"dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
}
@router.post("/sessions/{session_id}/adjust-combined")
async def adjust_combined(session_id: str, req: CombinedAdjustRequest):
"""Apply rotation + shear combined to the original image.
Used by the fine-tuning sliders to preview arbitrary rotation/shear
combinations without re-running the full deskew/dewarp pipeline.
"""
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
img_bgr = cached.get("original_bgr")
if img_bgr is None:
raise HTTPException(status_code=400, detail="Original image not available")
rotation = max(-15.0, min(15.0, req.rotation_degrees))
shear_deg = max(-5.0, min(5.0, req.shear_degrees))
h, w = img_bgr.shape[:2]
result_bgr = img_bgr
# Step 1: Apply rotation
if abs(rotation) >= 0.001:
center = (w // 2, h // 2)
M = cv2.getRotationMatrix2D(center, rotation, 1.0)
result_bgr = cv2.warpAffine(result_bgr, M, (w, h),
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_REPLICATE)
# Step 2: Apply shear
if abs(shear_deg) >= 0.001:
result_bgr = dewarp_image_manual(result_bgr, shear_deg)
# Encode
success, png_buf = cv2.imencode(".png", result_bgr)
dewarped_png = png_buf.tobytes() if success else b""
# Binarize
binarized_png = None
try:
binarized = create_ocr_image(result_bgr)
success_bin, bin_buf = cv2.imencode(".png", binarized)
binarized_png = bin_buf.tobytes() if success_bin else None
except Exception:
pass
# Build combined result dicts
deskew_result = {
**(cached.get("deskew_result") or {}),
"angle_applied": round(rotation, 3),
"method_used": "manual_combined",
}
dewarp_result = {
**(cached.get("dewarp_result") or {}),
"method_used": "manual_combined",
"shear_degrees": round(shear_deg, 3),
}
# Update cache
cached["deskewed_bgr"] = result_bgr
cached["dewarped_bgr"] = result_bgr
cached["deskew_result"] = deskew_result
cached["dewarp_result"] = dewarp_result
# Persist to DB
db_update = {
"dewarped_png": dewarped_png,
"deskew_result": deskew_result,
"dewarp_result": dewarp_result,
}
if binarized_png:
db_update["binarized_png"] = binarized_png
db_update["deskewed_png"] = dewarped_png
await update_session_db(session_id, **db_update)
logger.info(f"OCR Pipeline: combined adjust session {session_id}: "
f"rotation={rotation:.3f} shear={shear_deg:.3f}")
return {
"session_id": session_id,
"rotation_degrees": round(rotation, 3),
"shear_degrees": round(shear_deg, 3),
"method_used": "manual_combined",
"dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
}
@router.post("/sessions/{session_id}/ground-truth/dewarp")
async def save_dewarp_ground_truth(session_id: str, req: DewarpGroundTruthRequest):
"""Save ground truth feedback for the dewarp step."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
gt = {
"is_correct": req.is_correct,
"corrected_shear": req.corrected_shear,
"notes": req.notes,
"saved_at": datetime.utcnow().isoformat(),
"dewarp_result": session.get("dewarp_result"),
}
ground_truth["dewarp"] = gt
await update_session_db(session_id, ground_truth=ground_truth)
if session_id in _cache:
_cache[session_id]["ground_truth"] = ground_truth
logger.info(f"OCR Pipeline: ground truth dewarp session {session_id}: "
f"correct={req.is_correct}, corrected_shear={req.corrected_shear}")
return {"session_id": session_id, "ground_truth": gt}

View File

@@ -0,0 +1,27 @@
"""
OCR Pipeline Geometry API (barrel re-export)
This module was split into:
- ocr_pipeline_deskew.py (Deskew endpoints)
- ocr_pipeline_dewarp.py (Dewarp endpoints)
- ocr_pipeline_structure.py (Structure detection + exclude regions)
- ocr_pipeline_columns.py (Column detection + ground truth)
The `router` object is assembled here by including all sub-routers.
Importers that did `from ocr_pipeline_geometry import router` continue to work.
"""
from fastapi import APIRouter
from .deskew import router as _deskew_router
from .dewarp import router as _dewarp_router
from .structure import router as _structure_router
from .columns import router as _columns_router
# Assemble the combined router.
# All sub-routers use prefix="/api/v1/ocr-pipeline", so include without extra prefix.
router = APIRouter()
router.include_router(_deskew_router)
router.include_router(_dewarp_router)
router.include_router(_structure_router)
router.include_router(_columns_router)

View File

@@ -0,0 +1,209 @@
"""
OCR Pipeline LLM Review — LLM-based correction endpoints.
Extracted from ocr_pipeline_postprocess.py.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
from datetime import datetime
from typing import Dict, List
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from cv_vocab_pipeline import (
OLLAMA_REVIEW_MODEL,
llm_review_entries,
llm_review_entries_streaming,
)
from .session_store import (
get_session_db,
update_session_db,
)
from .common import (
_cache,
_append_pipeline_log,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Step 8: LLM Review
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/llm-review")
async def run_llm_review(session_id: str, request: Request, stream: bool = False):
"""Run LLM-based correction on vocab entries from Step 5.
Query params:
stream: false (default) for JSON response, true for SSE streaming
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
word_result = session.get("word_result")
if not word_result:
raise HTTPException(status_code=400, detail="No word result found — run Step 5 first")
entries = word_result.get("vocab_entries") or word_result.get("entries") or []
if not entries:
raise HTTPException(status_code=400, detail="No vocab entries found — run Step 5 first")
# Optional model override from request body
body = {}
try:
body = await request.json()
except Exception:
pass
model = body.get("model") or OLLAMA_REVIEW_MODEL
if stream:
return StreamingResponse(
_llm_review_stream_generator(session_id, entries, word_result, model, request),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
)
# Non-streaming path
try:
result = await llm_review_entries(entries, model=model)
except Exception as e:
import traceback
logger.error(f"LLM review failed for session {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}")
raise HTTPException(status_code=502, detail=f"LLM review failed ({type(e).__name__}): {e}")
# Store result inside word_result as a sub-key
word_result["llm_review"] = {
"changes": result["changes"],
"model_used": result["model_used"],
"duration_ms": result["duration_ms"],
"entries_corrected": result["entries_corrected"],
}
await update_session_db(session_id, word_result=word_result, current_step=9)
if session_id in _cache:
_cache[session_id]["word_result"] = word_result
logger.info(f"LLM review session {session_id}: {len(result['changes'])} changes, "
f"{result['duration_ms']}ms, model={result['model_used']}")
await _append_pipeline_log(session_id, "correction", {
"engine": "llm",
"model": result["model_used"],
"total_entries": len(entries),
"corrections_proposed": len(result["changes"]),
}, duration_ms=result["duration_ms"])
return {
"session_id": session_id,
"changes": result["changes"],
"model_used": result["model_used"],
"duration_ms": result["duration_ms"],
"total_entries": len(entries),
"corrections_found": len(result["changes"]),
}
async def _llm_review_stream_generator(
session_id: str,
entries: List[Dict],
word_result: Dict,
model: str,
request: Request,
):
"""SSE generator that yields batch-by-batch LLM review progress."""
try:
async for event in llm_review_entries_streaming(entries, model=model):
if await request.is_disconnected():
logger.info(f"SSE: client disconnected during LLM review for {session_id}")
return
yield f"data: {json.dumps(event, ensure_ascii=False)}\n\n"
# On complete: persist to DB
if event.get("type") == "complete":
word_result["llm_review"] = {
"changes": event["changes"],
"model_used": event["model_used"],
"duration_ms": event["duration_ms"],
"entries_corrected": event["entries_corrected"],
}
await update_session_db(session_id, word_result=word_result, current_step=9)
if session_id in _cache:
_cache[session_id]["word_result"] = word_result
logger.info(f"LLM review SSE session {session_id}: {event['corrections_found']} changes, "
f"{event['duration_ms']}ms, skipped={event['skipped']}, model={event['model_used']}")
except Exception as e:
import traceback
logger.error(f"LLM review SSE failed for {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}")
error_event = {"type": "error", "detail": f"{type(e).__name__}: {e}"}
yield f"data: {json.dumps(error_event)}\n\n"
@router.post("/sessions/{session_id}/llm-review/apply")
async def apply_llm_corrections(session_id: str, request: Request):
"""Apply selected LLM corrections to vocab entries."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
word_result = session.get("word_result")
if not word_result:
raise HTTPException(status_code=400, detail="No word result found")
llm_review = word_result.get("llm_review")
if not llm_review:
raise HTTPException(status_code=400, detail="No LLM review found — run /llm-review first")
body = await request.json()
accepted_indices = set(body.get("accepted_indices", [])) # indices into changes[]
changes = llm_review.get("changes", [])
entries = word_result.get("vocab_entries") or word_result.get("entries") or []
# Build a lookup: (row_index, field) -> new_value for accepted changes
corrections = {}
applied_count = 0
for idx, change in enumerate(changes):
if idx in accepted_indices:
key = (change["row_index"], change["field"])
corrections[key] = change["new"]
applied_count += 1
# Apply corrections to entries
for entry in entries:
row_idx = entry.get("row_index", -1)
for field_name in ("english", "german", "example"):
key = (row_idx, field_name)
if key in corrections:
entry[field_name] = corrections[key]
entry["llm_corrected"] = True
# Update word_result
word_result["vocab_entries"] = entries
word_result["entries"] = entries
word_result["llm_review"]["applied_count"] = applied_count
word_result["llm_review"]["applied_at"] = datetime.utcnow().isoformat()
await update_session_db(session_id, word_result=word_result)
if session_id in _cache:
_cache[session_id]["word_result"] = word_result
logger.info(f"Applied {applied_count}/{len(changes)} LLM corrections for session {session_id}")
return {
"session_id": session_id,
"applied_count": applied_count,
"total_changes": len(changes),
}

View File

@@ -0,0 +1,272 @@
"""
OCR Merge Helpers — functions for combining PaddleOCR/RapidOCR with Tesseract results.
Extracted from ocr_pipeline_ocr_merge.py.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
from typing import List
logger = logging.getLogger(__name__)
def _split_paddle_multi_words(words: list) -> list:
"""Split PaddleOCR multi-word boxes into individual word boxes.
PaddleOCR often returns entire phrases as a single box, e.g.
"More than 200 singers took part in the" with one bounding box.
This splits them into individual words with proportional widths.
Also handles leading "!" (e.g. "!Betonung" -> ["!", "Betonung"])
and IPA brackets (e.g. "badge[bxd3]" -> ["badge", "[bxd3]"]).
"""
import re
result = []
for w in words:
raw_text = w.get("text", "").strip()
if not raw_text:
continue
# Split on whitespace, before "[" (IPA), and after "!" before letter
tokens = re.split(
r'\s+|(?=\[)|(?<=!)(?=[A-Za-z\u00c0-\u024f])', raw_text
)
tokens = [t for t in tokens if t]
if len(tokens) <= 1:
result.append(w)
else:
# Split proportionally by character count
total_chars = sum(len(t) for t in tokens)
if total_chars == 0:
continue
n_gaps = len(tokens) - 1
gap_px = w["width"] * 0.02
usable_w = w["width"] - gap_px * n_gaps
cursor = w["left"]
for t in tokens:
token_w = max(1, usable_w * len(t) / total_chars)
result.append({
"text": t,
"left": round(cursor),
"top": w["top"],
"width": round(token_w),
"height": w["height"],
"conf": w.get("conf", 0),
})
cursor += token_w + gap_px
return result
def _group_words_into_rows(words: list, row_gap: int = 12) -> list:
"""Group words into rows by Y-position clustering.
Words whose vertical centers are within `row_gap` pixels are on the same row.
Returns list of rows, each row is a list of words sorted left-to-right.
"""
if not words:
return []
# Sort by vertical center
sorted_words = sorted(words, key=lambda w: w["top"] + w.get("height", 0) / 2)
rows: list = []
current_row: list = [sorted_words[0]]
current_cy = sorted_words[0]["top"] + sorted_words[0].get("height", 0) / 2
for w in sorted_words[1:]:
cy = w["top"] + w.get("height", 0) / 2
if abs(cy - current_cy) <= row_gap:
current_row.append(w)
else:
# Sort current row left-to-right before saving
rows.append(sorted(current_row, key=lambda w: w["left"]))
current_row = [w]
current_cy = cy
if current_row:
rows.append(sorted(current_row, key=lambda w: w["left"]))
return rows
def _row_center_y(row: list) -> float:
"""Average vertical center of a row of words."""
if not row:
return 0.0
return sum(w["top"] + w.get("height", 0) / 2 for w in row) / len(row)
def _merge_row_sequences(paddle_row: list, tess_row: list) -> list:
"""Merge two word sequences from the same row using sequence alignment.
Both sequences are sorted left-to-right. Walk through both simultaneously:
- If words match (same/similar text): take Paddle text with averaged coords
- If they don't match: the extra word is unique to one engine, include it
"""
merged = []
pi, ti = 0, 0
while pi < len(paddle_row) and ti < len(tess_row):
pw = paddle_row[pi]
tw = tess_row[ti]
pt = pw.get("text", "").lower().strip()
tt = tw.get("text", "").lower().strip()
is_same = (pt == tt) or (len(pt) > 1 and len(tt) > 1 and (pt in tt or tt in pt))
# Spatial overlap check
spatial_match = False
if not is_same:
overlap_left = max(pw["left"], tw["left"])
overlap_right = min(
pw["left"] + pw.get("width", 0),
tw["left"] + tw.get("width", 0),
)
overlap_w = max(0, overlap_right - overlap_left)
min_w = min(pw.get("width", 1), tw.get("width", 1))
if min_w > 0 and overlap_w / min_w >= 0.4:
is_same = True
spatial_match = True
if is_same:
pc = pw.get("conf", 80)
tc = tw.get("conf", 50)
total = pc + tc
if total == 0:
total = 1
if spatial_match and pc < tc:
best_text = tw["text"]
else:
best_text = pw["text"]
merged.append({
"text": best_text,
"left": round((pw["left"] * pc + tw["left"] * tc) / total),
"top": round((pw["top"] * pc + tw["top"] * tc) / total),
"width": round((pw["width"] * pc + tw["width"] * tc) / total),
"height": round((pw["height"] * pc + tw["height"] * tc) / total),
"conf": max(pc, tc),
})
pi += 1
ti += 1
else:
paddle_ahead = any(
tess_row[t].get("text", "").lower().strip() == pt
for t in range(ti + 1, min(ti + 4, len(tess_row)))
)
tess_ahead = any(
paddle_row[p].get("text", "").lower().strip() == tt
for p in range(pi + 1, min(pi + 4, len(paddle_row)))
)
if paddle_ahead and not tess_ahead:
if tw.get("conf", 0) >= 30:
merged.append(tw)
ti += 1
elif tess_ahead and not paddle_ahead:
merged.append(pw)
pi += 1
else:
if pw["left"] <= tw["left"]:
merged.append(pw)
pi += 1
else:
if tw.get("conf", 0) >= 30:
merged.append(tw)
ti += 1
while pi < len(paddle_row):
merged.append(paddle_row[pi])
pi += 1
while ti < len(tess_row):
tw = tess_row[ti]
if tw.get("conf", 0) >= 30:
merged.append(tw)
ti += 1
return merged
def _merge_paddle_tesseract(paddle_words: list, tess_words: list) -> list:
"""Merge word boxes from PaddleOCR and Tesseract using row-based sequence alignment."""
if not paddle_words and not tess_words:
return []
if not paddle_words:
return [w for w in tess_words if w.get("conf", 0) >= 40]
if not tess_words:
return list(paddle_words)
paddle_rows = _group_words_into_rows(paddle_words)
tess_rows = _group_words_into_rows(tess_words)
used_tess_rows: set = set()
merged_all: list = []
for pr in paddle_rows:
pr_cy = _row_center_y(pr)
best_dist, best_tri = float("inf"), -1
for tri, tr in enumerate(tess_rows):
if tri in used_tess_rows:
continue
tr_cy = _row_center_y(tr)
dist = abs(pr_cy - tr_cy)
if dist < best_dist:
best_dist, best_tri = dist, tri
max_row_dist = max(
max((w.get("height", 20) for w in pr), default=20),
15,
)
if best_tri >= 0 and best_dist <= max_row_dist:
tr = tess_rows[best_tri]
used_tess_rows.add(best_tri)
merged_all.extend(_merge_row_sequences(pr, tr))
else:
merged_all.extend(pr)
for tri, tr in enumerate(tess_rows):
if tri not in used_tess_rows:
for tw in tr:
if tw.get("conf", 0) >= 40:
merged_all.append(tw)
return merged_all
def _deduplicate_words(words: list) -> list:
"""Remove duplicate words with same text at overlapping positions."""
if not words:
return words
result: list = []
for w in words:
wt = w.get("text", "").lower().strip()
if not wt:
continue
is_dup = False
w_right = w["left"] + w.get("width", 0)
w_bottom = w["top"] + w.get("height", 0)
for existing in result:
et = existing.get("text", "").lower().strip()
if wt != et:
continue
ox_l = max(w["left"], existing["left"])
ox_r = min(w_right, existing["left"] + existing.get("width", 0))
ox = max(0, ox_r - ox_l)
min_w = min(w.get("width", 1), existing.get("width", 1))
if min_w <= 0 or ox / min_w < 0.5:
continue
oy_t = max(w["top"], existing["top"])
oy_b = min(w_bottom, existing["top"] + existing.get("height", 0))
oy = max(0, oy_b - oy_t)
min_h = min(w.get("height", 1), existing.get("height", 1))
if min_h > 0 and oy / min_h >= 0.5:
is_dup = True
break
if not is_dup:
result.append(w)
removed = len(words) - len(result)
if removed:
logger.info("dedup: removed %d duplicate words", removed)
return result

View File

@@ -0,0 +1,266 @@
"""
OCR Merge Kombi Endpoints — paddle-kombi and rapid-kombi endpoints.
Merge helper functions live in ocr_merge_helpers.py.
This module re-exports them for backward compatibility.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import time
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException
from cv_words_first import build_grid_from_words
from .common import _cache, _append_pipeline_log
from .session_store import get_session_image, update_session_db
# Re-export merge helpers for backward compatibility
from .merge_helpers import ( # noqa: F401
_split_paddle_multi_words,
_group_words_into_rows,
_row_center_y,
_merge_row_sequences,
_merge_paddle_tesseract,
_deduplicate_words,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
def _run_tesseract_words(img_bgr) -> list:
"""Run Tesseract OCR on an image and return word dicts."""
from PIL import Image
import pytesseract
pil_img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
data = pytesseract.image_to_data(
pil_img, lang="eng+deu",
config="--psm 6 --oem 3",
output_type=pytesseract.Output.DICT,
)
tess_words = []
for i in range(len(data["text"])):
text = str(data["text"][i]).strip()
conf_raw = str(data["conf"][i])
conf = int(conf_raw) if conf_raw.lstrip("-").isdigit() else -1
if not text or conf < 20:
continue
tess_words.append({
"text": text,
"left": data["left"][i],
"top": data["top"][i],
"width": data["width"][i],
"height": data["height"][i],
"conf": conf,
})
return tess_words
def _build_kombi_word_result(
cells: list,
columns_meta: list,
img_w: int,
img_h: int,
duration: float,
engine_name: str,
raw_engine_words: list,
raw_engine_words_split: list,
tess_words: list,
merged_words: list,
raw_engine_key: str = "raw_paddle_words",
raw_split_key: str = "raw_paddle_words_split",
) -> dict:
"""Build the word_result dict for kombi endpoints."""
n_rows = len(set(c["row_index"] for c in cells)) if cells else 0
n_cols = len(columns_meta)
col_types = {c.get("type") for c in columns_meta}
is_vocab = bool(col_types & {"column_en", "column_de"})
return {
"cells": cells,
"grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
"columns_used": columns_meta,
"layout": "vocab" if is_vocab else "generic",
"image_width": img_w,
"image_height": img_h,
"duration_seconds": round(duration, 2),
"ocr_engine": engine_name,
"grid_method": engine_name,
raw_engine_key: raw_engine_words,
raw_split_key: raw_engine_words_split,
"raw_tesseract_words": tess_words,
"summary": {
"total_cells": len(cells),
"non_empty_cells": sum(1 for c in cells if c.get("text")),
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
raw_engine_key.replace("raw_", "").replace("_words", "_words"): len(raw_engine_words),
raw_split_key.replace("raw_", "").replace("_words_split", "_words_split"): len(raw_engine_words_split),
"tesseract_words": len(tess_words),
"merged_words": len(merged_words),
},
}
async def _load_session_image(session_id: str):
"""Load preprocessed image for kombi endpoints."""
img_png = await get_session_image(session_id, "cropped")
if not img_png:
img_png = await get_session_image(session_id, "dewarped")
if not img_png:
img_png = await get_session_image(session_id, "original")
if not img_png:
raise HTTPException(status_code=404, detail="No image found for this session")
img_arr = np.frombuffer(img_png, dtype=np.uint8)
img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
if img_bgr is None:
raise HTTPException(status_code=400, detail="Failed to decode image")
return img_png, img_bgr
# ---------------------------------------------------------------------------
# Kombi endpoints
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/paddle-kombi")
async def paddle_kombi(session_id: str):
"""Run PaddleOCR + Tesseract on the preprocessed image and merge results."""
img_png, img_bgr = await _load_session_image(session_id)
img_h, img_w = img_bgr.shape[:2]
from cv_ocr_engines import ocr_region_paddle
t0 = time.time()
paddle_words = await ocr_region_paddle(img_bgr, region=None)
if not paddle_words:
paddle_words = []
tess_words = _run_tesseract_words(img_bgr)
paddle_words_split = _split_paddle_multi_words(paddle_words)
logger.info(
"paddle_kombi: split %d paddle boxes -> %d individual words",
len(paddle_words), len(paddle_words_split),
)
if not paddle_words_split and not tess_words:
raise HTTPException(status_code=400, detail="Both OCR engines returned no words")
merged_words = _merge_paddle_tesseract(paddle_words_split, tess_words)
merged_words = _deduplicate_words(merged_words)
cells, columns_meta = build_grid_from_words(merged_words, img_w, img_h)
duration = time.time() - t0
for cell in cells:
cell["ocr_engine"] = "kombi"
word_result = _build_kombi_word_result(
cells, columns_meta, img_w, img_h, duration, "kombi",
paddle_words, paddle_words_split, tess_words, merged_words,
"raw_paddle_words", "raw_paddle_words_split",
)
await update_session_db(
session_id, word_result=word_result, cropped_png=img_png, current_step=8,
)
if session_id in _cache:
_cache[session_id]["word_result"] = word_result
logger.info(
"paddle_kombi session %s: %d cells (%d rows, %d cols) in %.2fs "
"[paddle=%d, tess=%d, merged=%d]",
session_id, len(cells), word_result["grid_shape"]["rows"],
word_result["grid_shape"]["cols"], duration,
len(paddle_words), len(tess_words), len(merged_words),
)
await _append_pipeline_log(session_id, "paddle_kombi", {
"total_cells": len(cells),
"non_empty_cells": word_result["summary"]["non_empty_cells"],
"paddle_words": len(paddle_words),
"tesseract_words": len(tess_words),
"merged_words": len(merged_words),
"ocr_engine": "kombi",
}, duration_ms=int(duration * 1000))
return {"session_id": session_id, **word_result}
@router.post("/sessions/{session_id}/rapid-kombi")
async def rapid_kombi(session_id: str):
"""Run RapidOCR + Tesseract on the preprocessed image and merge results."""
img_png, img_bgr = await _load_session_image(session_id)
img_h, img_w = img_bgr.shape[:2]
from cv_ocr_engines import ocr_region_rapid
from cv_vocab_types import PageRegion
t0 = time.time()
full_region = PageRegion(
type="full_page", x=0, y=0, width=img_w, height=img_h,
)
rapid_words = ocr_region_rapid(img_bgr, full_region)
if not rapid_words:
rapid_words = []
tess_words = _run_tesseract_words(img_bgr)
rapid_words_split = _split_paddle_multi_words(rapid_words)
logger.info(
"rapid_kombi: split %d rapid boxes -> %d individual words",
len(rapid_words), len(rapid_words_split),
)
if not rapid_words_split and not tess_words:
raise HTTPException(status_code=400, detail="Both OCR engines returned no words")
merged_words = _merge_paddle_tesseract(rapid_words_split, tess_words)
merged_words = _deduplicate_words(merged_words)
cells, columns_meta = build_grid_from_words(merged_words, img_w, img_h)
duration = time.time() - t0
for cell in cells:
cell["ocr_engine"] = "rapid_kombi"
word_result = _build_kombi_word_result(
cells, columns_meta, img_w, img_h, duration, "rapid_kombi",
rapid_words, rapid_words_split, tess_words, merged_words,
"raw_rapid_words", "raw_rapid_words_split",
)
await update_session_db(
session_id, word_result=word_result, cropped_png=img_png, current_step=8,
)
if session_id in _cache:
_cache[session_id]["word_result"] = word_result
logger.info(
"rapid_kombi session %s: %d cells (%d rows, %d cols) in %.2fs "
"[rapid=%d, tess=%d, merged=%d]",
session_id, len(cells), word_result["grid_shape"]["rows"],
word_result["grid_shape"]["cols"], duration,
len(rapid_words), len(tess_words), len(merged_words),
)
await _append_pipeline_log(session_id, "rapid_kombi", {
"total_cells": len(cells),
"non_empty_cells": word_result["summary"]["non_empty_cells"],
"rapid_words": len(rapid_words),
"tesseract_words": len(tess_words),
"merged_words": len(merged_words),
"ocr_engine": "rapid_kombi",
}, duration_ms=int(duration * 1000))
return {"session_id": session_id, **word_result}

View File

@@ -0,0 +1,188 @@
"""
Orientation & Page-Split API endpoints (Steps 1 and 1b of OCR Pipeline).
"""
import logging
import time
from typing import Any, Dict
import cv2
from fastapi import APIRouter, HTTPException
from cv_vocab_pipeline import detect_and_fix_orientation
from .page_crop import detect_page_splits
from .session_store import update_session_db
from .orientation_crop_helpers import ensure_cached, append_pipeline_log
from .page_sub_sessions import create_page_sub_sessions_full
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Step 1: Orientation
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/orientation")
async def detect_orientation(session_id: str):
"""Detect and fix 90/180/270 degree rotations from scanners.
Reads the original image, applies orientation correction,
stores the result as oriented_png.
"""
cached = await ensure_cached(session_id)
img_bgr = cached.get("original_bgr")
if img_bgr is None:
raise HTTPException(status_code=400, detail="Original image not available")
t0 = time.time()
# Detect and fix orientation
oriented_bgr, orientation_deg = detect_and_fix_orientation(img_bgr.copy())
duration = time.time() - t0
orientation_result = {
"orientation_degrees": orientation_deg,
"corrected": orientation_deg != 0,
"duration_seconds": round(duration, 2),
}
# Encode oriented image
success, png_buf = cv2.imencode(".png", oriented_bgr)
oriented_png = png_buf.tobytes() if success else b""
# Update cache
cached["oriented_bgr"] = oriented_bgr
cached["orientation_result"] = orientation_result
# Persist to DB
await update_session_db(
session_id,
oriented_png=oriented_png,
orientation_result=orientation_result,
current_step=2,
)
logger.info(
"OCR Pipeline: orientation session %s: %d° (%s) in %.2fs",
session_id, orientation_deg,
"corrected" if orientation_deg else "no change",
duration,
)
await append_pipeline_log(session_id, "orientation", {
"orientation_degrees": orientation_deg,
"corrected": orientation_deg != 0,
}, duration_ms=int(duration * 1000))
h, w = oriented_bgr.shape[:2]
return {
"session_id": session_id,
**orientation_result,
"image_width": w,
"image_height": h,
"oriented_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/oriented",
}
# ---------------------------------------------------------------------------
# Step 1b: Page-split detection — runs AFTER orientation, BEFORE deskew
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/page-split")
async def detect_page_split(session_id: str):
"""Detect if the image is a double-page book spread and split into sub-sessions.
Must be called **after orientation** (step 1) and **before deskew** (step 2).
Each sub-session receives the raw page region and goes through the full
pipeline (deskew -> dewarp -> crop -> columns -> rows -> words -> grid)
independently, so each page gets its own deskew correction.
Returns ``{"multi_page": false}`` if only one page is detected.
"""
cached = await ensure_cached(session_id)
# Use oriented (preferred), fall back to original
img_bgr = next(
(v for k in ("oriented_bgr", "original_bgr")
if (v := cached.get(k)) is not None),
None,
)
if img_bgr is None:
raise HTTPException(status_code=400, detail="No image available for page-split detection")
t0 = time.time()
page_splits = detect_page_splits(img_bgr)
used_original = False
if not page_splits or len(page_splits) < 2:
# Orientation may have rotated a landscape double-page spread to
# portrait. Try the original (pre-orientation) image as fallback.
orig_bgr = cached.get("original_bgr")
if orig_bgr is not None and orig_bgr is not img_bgr:
page_splits_orig = detect_page_splits(orig_bgr)
if page_splits_orig and len(page_splits_orig) >= 2:
logger.info(
"OCR Pipeline: page-split session %s: spread detected on "
"ORIGINAL (orientation rotated it away)",
session_id,
)
img_bgr = orig_bgr
page_splits = page_splits_orig
used_original = True
if not page_splits or len(page_splits) < 2:
duration = time.time() - t0
logger.info(
"OCR Pipeline: page-split session %s: single page (%.2fs)",
session_id, duration,
)
return {
"session_id": session_id,
"multi_page": False,
"duration_seconds": round(duration, 2),
}
# Multi-page spread detected — create sub-sessions for full pipeline.
# start_step=2 means "ready for deskew" (orientation already applied).
# start_step=1 means "needs orientation too" (split from original image).
start_step = 1 if used_original else 2
sub_sessions = await create_page_sub_sessions_full(
session_id, cached, img_bgr, page_splits, start_step=start_step,
)
duration = time.time() - t0
split_info: Dict[str, Any] = {
"multi_page": True,
"page_count": len(page_splits),
"page_splits": page_splits,
"used_original": used_original,
"duration_seconds": round(duration, 2),
}
# Mark parent session as split and hidden from session list
await update_session_db(session_id, crop_result=split_info, status='split')
cached["crop_result"] = split_info
await append_pipeline_log(session_id, "page_split", {
"multi_page": True,
"page_count": len(page_splits),
}, duration_ms=int(duration * 1000))
logger.info(
"OCR Pipeline: page-split session %s: %d pages detected in %.2fs",
session_id, len(page_splits), duration,
)
h, w = img_bgr.shape[:2]
return {
"session_id": session_id,
**split_info,
"image_width": w,
"image_height": h,
"sub_sessions": sub_sessions,
}

View File

@@ -0,0 +1,16 @@
"""
Orientation & Crop API - Steps 1 and 4 of the OCR Pipeline.
Barrel re-export: merges routers from orientation_api and crop_api,
and re-exports set_cache_ref for main.py.
"""
from fastapi import APIRouter
from .orientation_crop_helpers import set_cache_ref # noqa: F401
from .orientation_api import router as _orientation_router
from .crop_api import router as _crop_router
router = APIRouter()
router.include_router(_orientation_router)
router.include_router(_crop_router)

View File

@@ -0,0 +1,86 @@
"""
Orientation & Crop shared helpers - cache management and pipeline logging.
"""
import logging
from typing import Any, Dict
import cv2
import numpy as np
from fastapi import HTTPException
from .session_store import (
get_session_db,
get_session_image,
update_session_db,
)
logger = logging.getLogger(__name__)
# Reference to the shared cache from ocr_pipeline_api (set in main.py)
_cache: Dict[str, Dict[str, Any]] = {}
def set_cache_ref(cache: Dict[str, Dict[str, Any]]):
"""Set reference to the shared cache from ocr_pipeline_api."""
global _cache
_cache = cache
def get_cache_ref() -> Dict[str, Dict[str, Any]]:
"""Get reference to the shared cache."""
return _cache
async def ensure_cached(session_id: str) -> Dict[str, Any]:
"""Ensure session is in cache, loading from DB if needed."""
if session_id in _cache:
return _cache[session_id]
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
cache_entry: Dict[str, Any] = {
"id": session_id,
**session,
"original_bgr": None,
"oriented_bgr": None,
"cropped_bgr": None,
"deskewed_bgr": None,
"dewarped_bgr": None,
}
for img_type, bgr_key in [
("original", "original_bgr"),
("oriented", "oriented_bgr"),
("cropped", "cropped_bgr"),
("deskewed", "deskewed_bgr"),
("dewarped", "dewarped_bgr"),
]:
png_data = await get_session_image(session_id, img_type)
if png_data:
arr = np.frombuffer(png_data, dtype=np.uint8)
bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR)
cache_entry[bgr_key] = bgr
_cache[session_id] = cache_entry
return cache_entry
async def append_pipeline_log(session_id: str, step: str, metrics: dict, duration_ms: int):
"""Append a step entry to the pipeline log."""
from datetime import datetime
session = await get_session_db(session_id)
if not session:
return
pipeline_log = session.get("pipeline_log") or {"steps": []}
pipeline_log["steps"].append({
"step": step,
"completed_at": datetime.utcnow().isoformat(),
"success": True,
"duration_ms": duration_ms,
"metrics": metrics,
})
await update_session_db(session_id, pipeline_log=pipeline_log)

View File

@@ -0,0 +1,333 @@
"""
Overlay rendering for columns, rows, and words (grid-based overlays).
Extracted from ocr_pipeline_overlays.py for modularity.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
from typing import Any, Dict, List
import cv2
import numpy as np
from fastapi import HTTPException
from fastapi.responses import Response
from .common import _get_base_image_png
from .session_store import get_session_db
from .rows import _draw_box_exclusion_overlay
logger = logging.getLogger(__name__)
async def _get_columns_overlay(session_id: str) -> Response:
"""Generate cropped (or dewarped) image with column borders drawn on it."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
column_result = session.get("column_result")
if not column_result or not column_result.get("columns"):
raise HTTPException(status_code=404, detail="No column data available")
# Load best available base image (cropped > dewarped > original)
base_png = await _get_base_image_png(session_id)
if not base_png:
raise HTTPException(status_code=404, detail="No base image available")
arr = np.frombuffer(base_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=500, detail="Failed to decode image")
# Color map for region types (BGR)
colors = {
"column_en": (255, 180, 0), # Blue
"column_de": (0, 200, 0), # Green
"column_example": (0, 140, 255), # Orange
"column_text": (200, 200, 0), # Cyan/Turquoise
"page_ref": (200, 0, 200), # Purple
"column_marker": (0, 0, 220), # Red
"column_ignore": (180, 180, 180), # Light Gray
"header": (128, 128, 128), # Gray
"footer": (128, 128, 128), # Gray
"margin_top": (100, 100, 100), # Dark Gray
"margin_bottom": (100, 100, 100), # Dark Gray
}
overlay = img.copy()
for col in column_result["columns"]:
x, y = col["x"], col["y"]
w, h = col["width"], col["height"]
color = colors.get(col.get("type", ""), (200, 200, 200))
# Semi-transparent fill
cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1)
# Solid border
cv2.rectangle(img, (x, y), (x + w, y + h), color, 3)
# Label with confidence
label = col.get("type", "unknown").replace("column_", "").upper()
conf = col.get("classification_confidence")
if conf is not None and conf < 1.0:
label = f"{label} {int(conf * 100)}%"
cv2.putText(img, label, (x + 10, y + 30),
cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2)
# Blend overlay at 20% opacity
cv2.addWeighted(overlay, 0.2, img, 0.8, 0, img)
# Draw detected box boundaries as dashed rectangles
zones = column_result.get("zones") or []
for zone in zones:
if zone.get("zone_type") == "box" and zone.get("box"):
box = zone["box"]
bx, by = box["x"], box["y"]
bw, bh = box["width"], box["height"]
box_color = (0, 200, 255) # Yellow (BGR)
# Draw dashed rectangle by drawing short line segments
dash_len = 15
for edge_x in range(bx, bx + bw, dash_len * 2):
end_x = min(edge_x + dash_len, bx + bw)
cv2.line(img, (edge_x, by), (end_x, by), box_color, 2)
cv2.line(img, (edge_x, by + bh), (end_x, by + bh), box_color, 2)
for edge_y in range(by, by + bh, dash_len * 2):
end_y = min(edge_y + dash_len, by + bh)
cv2.line(img, (bx, edge_y), (bx, end_y), box_color, 2)
cv2.line(img, (bx + bw, edge_y), (bx + bw, end_y), box_color, 2)
cv2.putText(img, "BOX", (bx + 10, by + bh - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, box_color, 2)
# Red semi-transparent overlay for box zones
_draw_box_exclusion_overlay(img, zones)
success, result_png = cv2.imencode(".png", img)
if not success:
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
return Response(content=result_png.tobytes(), media_type="image/png")
async def _get_rows_overlay(session_id: str) -> Response:
"""Generate cropped (or dewarped) image with row bands drawn on it."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
row_result = session.get("row_result")
if not row_result or not row_result.get("rows"):
raise HTTPException(status_code=404, detail="No row data available")
# Load best available base image (cropped > dewarped > original)
base_png = await _get_base_image_png(session_id)
if not base_png:
raise HTTPException(status_code=404, detail="No base image available")
arr = np.frombuffer(base_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=500, detail="Failed to decode image")
# Color map for row types (BGR)
row_colors = {
"content": (255, 180, 0), # Blue
"header": (128, 128, 128), # Gray
"footer": (128, 128, 128), # Gray
"margin_top": (100, 100, 100), # Dark Gray
"margin_bottom": (100, 100, 100), # Dark Gray
}
overlay = img.copy()
for row in row_result["rows"]:
x, y = row["x"], row["y"]
w, h = row["width"], row["height"]
row_type = row.get("row_type", "content")
color = row_colors.get(row_type, (200, 200, 200))
# Semi-transparent fill
cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1)
# Solid border
cv2.rectangle(img, (x, y), (x + w, y + h), color, 2)
# Label
idx = row.get("index", 0)
label = f"R{idx} {row_type.upper()}"
wc = row.get("word_count", 0)
if wc:
label = f"{label} ({wc}w)"
cv2.putText(img, label, (x + 5, y + 18),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
# Blend overlay at 15% opacity
cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img)
# Draw zone separator lines if zones exist
column_result = session.get("column_result") or {}
zones = column_result.get("zones") or []
if zones:
img_w_px = img.shape[1]
zone_color = (0, 200, 255) # Yellow (BGR)
dash_len = 20
for zone in zones:
if zone.get("zone_type") == "box":
zy = zone["y"]
zh = zone["height"]
for line_y in [zy, zy + zh]:
for sx in range(0, img_w_px, dash_len * 2):
ex = min(sx + dash_len, img_w_px)
cv2.line(img, (sx, line_y), (ex, line_y), zone_color, 2)
# Red semi-transparent overlay for box zones
_draw_box_exclusion_overlay(img, zones)
success, result_png = cv2.imencode(".png", img)
if not success:
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
return Response(content=result_png.tobytes(), media_type="image/png")
async def _get_words_overlay(session_id: str) -> Response:
"""Generate cropped (or dewarped) image with cell grid drawn on it."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
word_result = session.get("word_result")
if not word_result:
raise HTTPException(status_code=404, detail="No word data available")
# Support both new cell-based and legacy entry-based formats
cells = word_result.get("cells")
if not cells and not word_result.get("entries"):
raise HTTPException(status_code=404, detail="No word data available")
# Load best available base image (cropped > dewarped > original)
base_png = await _get_base_image_png(session_id)
if not base_png:
raise HTTPException(status_code=404, detail="No base image available")
arr = np.frombuffer(base_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=500, detail="Failed to decode image")
img_h, img_w = img.shape[:2]
overlay = img.copy()
if cells:
# New cell-based overlay: color by column index
col_palette = [
(255, 180, 0), # Blue (BGR)
(0, 200, 0), # Green
(0, 140, 255), # Orange
(200, 100, 200), # Purple
(200, 200, 0), # Cyan
(100, 200, 200), # Yellow-ish
]
for cell in cells:
bbox = cell.get("bbox_px", {})
cx = bbox.get("x", 0)
cy = bbox.get("y", 0)
cw = bbox.get("w", 0)
ch = bbox.get("h", 0)
if cw <= 0 or ch <= 0:
continue
col_idx = cell.get("col_index", 0)
color = col_palette[col_idx % len(col_palette)]
# Cell rectangle border
cv2.rectangle(img, (cx, cy), (cx + cw, cy + ch), color, 1)
# Semi-transparent fill
cv2.rectangle(overlay, (cx, cy), (cx + cw, cy + ch), color, -1)
# Cell-ID label (top-left corner)
cell_id = cell.get("cell_id", "")
cv2.putText(img, cell_id, (cx + 2, cy + 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.28, color, 1)
# Text label (bottom of cell)
text = cell.get("text", "")
if text:
conf = cell.get("confidence", 0)
if conf >= 70:
text_color = (0, 180, 0)
elif conf >= 50:
text_color = (0, 180, 220)
else:
text_color = (0, 0, 220)
label = text.replace('\n', ' ')[:30]
cv2.putText(img, label, (cx + 3, cy + ch - 4),
cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)
else:
# Legacy fallback: entry-based overlay (for old sessions)
column_result = session.get("column_result")
row_result = session.get("row_result")
col_colors = {
"column_en": (255, 180, 0),
"column_de": (0, 200, 0),
"column_example": (0, 140, 255),
}
columns = []
if column_result and column_result.get("columns"):
columns = [c for c in column_result["columns"]
if c.get("type", "").startswith("column_")]
content_rows_data = []
if row_result and row_result.get("rows"):
content_rows_data = [r for r in row_result["rows"]
if r.get("row_type") == "content"]
for col in columns:
col_type = col.get("type", "")
color = col_colors.get(col_type, (200, 200, 200))
cx, cw = col["x"], col["width"]
for row in content_rows_data:
ry, rh = row["y"], row["height"]
cv2.rectangle(img, (cx, ry), (cx + cw, ry + rh), color, 1)
cv2.rectangle(overlay, (cx, ry), (cx + cw, ry + rh), color, -1)
entries = word_result["entries"]
entry_by_row: Dict[int, Dict] = {}
for entry in entries:
entry_by_row[entry.get("row_index", -1)] = entry
for row_idx, row in enumerate(content_rows_data):
entry = entry_by_row.get(row_idx)
if not entry:
continue
conf = entry.get("confidence", 0)
text_color = (0, 180, 0) if conf >= 70 else (0, 180, 220) if conf >= 50 else (0, 0, 220)
ry, rh = row["y"], row["height"]
for col in columns:
col_type = col.get("type", "")
cx, cw = col["x"], col["width"]
field = {"column_en": "english", "column_de": "german", "column_example": "example"}.get(col_type, "")
text = entry.get(field, "") if field else ""
if text:
label = text.replace('\n', ' ')[:30]
cv2.putText(img, label, (cx + 3, ry + rh - 4),
cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)
# Blend overlay at 10% opacity
cv2.addWeighted(overlay, 0.1, img, 0.9, 0, img)
# Red semi-transparent overlay for box zones
column_result = session.get("column_result") or {}
zones = column_result.get("zones") or []
_draw_box_exclusion_overlay(img, zones)
success, result_png = cv2.imencode(".png", img)
if not success:
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
return Response(content=result_png.tobytes(), media_type="image/png")

View File

@@ -0,0 +1,205 @@
"""
Overlay rendering for structure detection (boxes, zones, colors, graphics).
Extracted from ocr_pipeline_overlays.py for modularity.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
from typing import Any, Dict, List
import cv2
import numpy as np
from fastapi import HTTPException
from fastapi.responses import Response
from .common import _get_base_image_png
from .session_store import get_session_db
from cv_color_detect import _COLOR_HEX, _COLOR_RANGES
from cv_box_detect import detect_boxes, split_page_into_zones
logger = logging.getLogger(__name__)
async def _get_structure_overlay(session_id: str) -> Response:
"""Generate overlay image showing detected boxes, zones, and color regions."""
base_png = await _get_base_image_png(session_id)
if not base_png:
raise HTTPException(status_code=404, detail="No base image available")
arr = np.frombuffer(base_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=500, detail="Failed to decode image")
h, w = img.shape[:2]
# Get structure result (run detection if not cached)
session = await get_session_db(session_id)
structure = (session or {}).get("structure_result")
if not structure:
# Run detection on-the-fly
margin = int(min(w, h) * 0.03)
content_x, content_y = margin, margin
content_w_px = w - 2 * margin
content_h_px = h - 2 * margin
boxes = detect_boxes(img, content_x, content_w_px, content_y, content_h_px)
zones = split_page_into_zones(content_x, content_y, content_w_px, content_h_px, boxes)
structure = {
"boxes": [
{"x": b.x, "y": b.y, "w": b.width, "h": b.height,
"confidence": b.confidence, "border_thickness": b.border_thickness}
for b in boxes
],
"zones": [
{"index": z.index, "zone_type": z.zone_type,
"y": z.y, "h": z.height, "x": z.x, "w": z.width}
for z in zones
],
}
overlay = img.copy()
# --- Draw zone boundaries ---
zone_colors = {
"content": (200, 200, 200), # light gray
"box": (255, 180, 0), # blue-ish (BGR)
}
for zone in structure.get("zones", []):
zx = zone["x"]
zy = zone["y"]
zw = zone["w"]
zh = zone["h"]
color = zone_colors.get(zone["zone_type"], (200, 200, 200))
# Draw zone boundary as dashed line
dash_len = 12
for edge_x in range(zx, zx + zw, dash_len * 2):
end_x = min(edge_x + dash_len, zx + zw)
cv2.line(img, (edge_x, zy), (end_x, zy), color, 1)
cv2.line(img, (edge_x, zy + zh), (end_x, zy + zh), color, 1)
# Zone label
zone_label = f"Zone {zone['index']} ({zone['zone_type']})"
cv2.putText(img, zone_label, (zx + 5, zy + 15),
cv2.FONT_HERSHEY_SIMPLEX, 0.45, color, 1)
# --- Draw detected boxes ---
# Color map for box backgrounds (BGR)
bg_hex_to_bgr = {
"#dc2626": (38, 38, 220), # red
"#2563eb": (235, 99, 37), # blue
"#16a34a": (74, 163, 22), # green
"#ea580c": (12, 88, 234), # orange
"#9333ea": (234, 51, 147), # purple
"#ca8a04": (4, 138, 202), # yellow
"#6b7280": (128, 114, 107), # gray
}
for box_data in structure.get("boxes", []):
bx = box_data["x"]
by = box_data["y"]
bw = box_data["w"]
bh = box_data["h"]
conf = box_data.get("confidence", 0)
thickness = box_data.get("border_thickness", 0)
bg_hex = box_data.get("bg_color_hex", "#6b7280")
bg_name = box_data.get("bg_color_name", "")
# Box fill color
fill_bgr = bg_hex_to_bgr.get(bg_hex, (128, 114, 107))
# Semi-transparent fill
cv2.rectangle(overlay, (bx, by), (bx + bw, by + bh), fill_bgr, -1)
# Solid border
border_color = fill_bgr
cv2.rectangle(img, (bx, by), (bx + bw, by + bh), border_color, 3)
# Label
label = f"BOX"
if bg_name and bg_name not in ("unknown", "white"):
label += f" ({bg_name})"
if thickness > 0:
label += f" border={thickness}px"
label += f" {int(conf * 100)}%"
cv2.putText(img, label, (bx + 8, by + 22),
cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 2)
cv2.putText(img, label, (bx + 8, by + 22),
cv2.FONT_HERSHEY_SIMPLEX, 0.55, border_color, 1)
# Blend overlay at 15% opacity
cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img)
# --- Draw color regions (HSV masks) ---
hsv = cv2.cvtColor(
cv2.imdecode(np.frombuffer(base_png, dtype=np.uint8), cv2.IMREAD_COLOR),
cv2.COLOR_BGR2HSV,
)
color_bgr_map = {
"red": (0, 0, 255),
"orange": (0, 140, 255),
"yellow": (0, 200, 255),
"green": (0, 200, 0),
"blue": (255, 150, 0),
"purple": (200, 0, 200),
}
for color_name, ranges in _COLOR_RANGES.items():
mask = np.zeros((h, w), dtype=np.uint8)
for lower, upper in ranges:
mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper))
# Only draw if there are significant colored pixels
if np.sum(mask > 0) < 100:
continue
# Draw colored contours
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
draw_color = color_bgr_map.get(color_name, (200, 200, 200))
for cnt in contours:
area = cv2.contourArea(cnt)
if area < 20:
continue
cv2.drawContours(img, [cnt], -1, draw_color, 2)
# --- Draw graphic elements ---
graphics_data = structure.get("graphics", [])
shape_icons = {
"image": "IMAGE",
"illustration": "ILLUST",
}
for gfx in graphics_data:
gx, gy = gfx["x"], gfx["y"]
gw, gh = gfx["w"], gfx["h"]
shape = gfx.get("shape", "icon")
color_hex = gfx.get("color_hex", "#6b7280")
conf = gfx.get("confidence", 0)
# Pick draw color based on element color (BGR)
gfx_bgr = bg_hex_to_bgr.get(color_hex, (128, 114, 107))
# Draw bounding box (dashed style via short segments)
dash = 6
for seg_x in range(gx, gx + gw, dash * 2):
end_x = min(seg_x + dash, gx + gw)
cv2.line(img, (seg_x, gy), (end_x, gy), gfx_bgr, 2)
cv2.line(img, (seg_x, gy + gh), (end_x, gy + gh), gfx_bgr, 2)
for seg_y in range(gy, gy + gh, dash * 2):
end_y = min(seg_y + dash, gy + gh)
cv2.line(img, (gx, seg_y), (gx, end_y), gfx_bgr, 2)
cv2.line(img, (gx + gw, seg_y), (gx + gw, end_y), gfx_bgr, 2)
# Label
icon = shape_icons.get(shape, shape.upper()[:5])
label = f"{icon} {int(conf * 100)}%"
# White background for readability
(tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)
lx = gx + 2
ly = max(gy - 4, th + 4)
cv2.rectangle(img, (lx - 1, ly - th - 2), (lx + tw + 2, ly + 3), (255, 255, 255), -1)
cv2.putText(img, label, (lx, ly), cv2.FONT_HERSHEY_SIMPLEX, 0.4, gfx_bgr, 1)
# Encode result
_, png_buf = cv2.imencode(".png", img)
return Response(content=png_buf.tobytes(), media_type="image/png")

View File

@@ -0,0 +1,34 @@
"""
Overlay image rendering for OCR pipeline — barrel re-export.
All implementation split into:
ocr_pipeline_overlay_structure — structure overlay (boxes, zones, colors, graphics)
ocr_pipeline_overlay_grid — columns, rows, words overlays
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
from fastapi import HTTPException
from fastapi.responses import Response
from .overlay_structure import _get_structure_overlay # noqa: F401
from .overlay_grid import ( # noqa: F401
_get_columns_overlay,
_get_rows_overlay,
_get_words_overlay,
)
async def render_overlay(overlay_type: str, session_id: str) -> Response:
"""Dispatch to the appropriate overlay renderer."""
if overlay_type == "structure":
return await _get_structure_overlay(session_id)
elif overlay_type == "columns":
return await _get_columns_overlay(session_id)
elif overlay_type == "rows":
return await _get_rows_overlay(session_id)
elif overlay_type == "words":
return await _get_words_overlay(session_id)
else:
raise HTTPException(status_code=400, detail=f"Unknown overlay type: {overlay_type}")

View File

@@ -0,0 +1,33 @@
"""
Page Crop — Barrel Re-export
Content-based crop for scanned pages and book scans.
Split into:
- page_crop_edges.py — Edge detection (spine shadow, gutter, projection)
- page_crop_core.py — Main crop algorithm and format detection
All public names are re-exported here for backward compatibility.
License: Apache 2.0
"""
# Core: main crop functions and format detection
from .page_crop_core import ( # noqa: F401
PAPER_FORMATS,
detect_page_splits,
detect_and_crop_page,
_detect_format,
)
# Edge detection helpers
from .page_crop_edges import ( # noqa: F401
_INK_THRESHOLD,
_MIN_RUN_FRAC,
_detect_spine_shadow,
_detect_gutter_continuity,
_detect_left_edge_shadow,
_detect_right_edge_shadow,
_detect_top_bottom_edges,
_detect_edge_projection,
_filter_narrow_runs,
)

View File

@@ -0,0 +1,342 @@
"""
Page Crop - Core Crop and Format Detection
Content-based crop for scanned pages and book scans. Detects the content
boundary by analysing ink density projections and (for book scans) the
spine shadow gradient.
Extracted from page_crop.py to keep files under 500 LOC.
License: Apache 2.0
"""
import logging
from typing import Dict, Any, Tuple
import cv2
import numpy as np
from .page_crop_edges import (
_detect_left_edge_shadow,
_detect_right_edge_shadow,
_detect_top_bottom_edges,
)
logger = logging.getLogger(__name__)
# Known paper format aspect ratios (height / width, portrait orientation)
PAPER_FORMATS = {
"A4": 297.0 / 210.0, # 1.4143
"A5": 210.0 / 148.0, # 1.4189
"Letter": 11.0 / 8.5, # 1.2941
"Legal": 14.0 / 8.5, # 1.6471
"A3": 420.0 / 297.0, # 1.4141
}
def detect_page_splits(
img_bgr: np.ndarray,
) -> list:
"""Detect if the image is a multi-page spread and return split rectangles.
Uses **brightness** (not ink density) to find the spine area:
the scanner bed produces a characteristic gray strip where pages meet,
which is darker than the white paper on either side.
Returns a list of page dicts ``{x, y, width, height, page_index}``
or an empty list if only one page is detected.
"""
h, w = img_bgr.shape[:2]
# Only check landscape-ish images (width > height * 1.15)
if w < h * 1.15:
return []
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
# Column-mean brightness (0-255) — the spine is darker (gray scanner bed)
col_brightness = np.mean(gray, axis=0).astype(np.float64)
# Heavy smoothing to ignore individual text lines
kern = max(11, w // 50)
if kern % 2 == 0:
kern += 1
brightness_smooth = np.convolve(col_brightness, np.ones(kern) / kern, mode="same")
# Page paper is bright (typically > 200), spine/scanner bed is darker
page_brightness = float(np.max(brightness_smooth))
if page_brightness < 100:
return [] # Very dark image, skip
# Spine threshold: significantly darker than the page
spine_thresh = page_brightness * 0.88
# Search in center region (30-70% of width)
center_lo = int(w * 0.30)
center_hi = int(w * 0.70)
# Find the darkest valley in the center region
center_brightness = brightness_smooth[center_lo:center_hi]
darkest_val = float(np.min(center_brightness))
if darkest_val >= spine_thresh:
logger.debug("No spine detected: min brightness %.0f >= threshold %.0f",
darkest_val, spine_thresh)
return []
# Find ALL contiguous dark runs in the center region
is_dark = center_brightness < spine_thresh
dark_runs: list = []
run_start = -1
for i in range(len(is_dark)):
if is_dark[i]:
if run_start < 0:
run_start = i
else:
if run_start >= 0:
dark_runs.append((run_start, i))
run_start = -1
if run_start >= 0:
dark_runs.append((run_start, len(is_dark)))
# Filter out runs that are too narrow (< 1% of image width)
min_spine_px = int(w * 0.01)
dark_runs = [(s, e) for s, e in dark_runs if e - s >= min_spine_px]
if not dark_runs:
logger.debug("No dark runs wider than %dpx in center region", min_spine_px)
return []
# Score each dark run: prefer centered, dark, narrow valleys
center_region_len = center_hi - center_lo
image_center_in_region = (w * 0.5 - center_lo)
best_score = -1.0
best_start, best_end = dark_runs[0]
for rs, re in dark_runs:
run_width = re - rs
run_center = (rs + re) / 2.0
sigma = center_region_len * 0.15
dist = abs(run_center - image_center_in_region)
center_factor = float(np.exp(-0.5 * (dist / sigma) ** 2))
run_brightness = float(np.mean(center_brightness[rs:re]))
darkness_factor = max(0.0, (spine_thresh - run_brightness) / spine_thresh)
width_frac = run_width / w
if width_frac <= 0.05:
narrowness_bonus = 1.0
elif width_frac <= 0.15:
narrowness_bonus = 1.0 - (width_frac - 0.05) / 0.10
else:
narrowness_bonus = 0.0
score = center_factor * darkness_factor * (0.3 + 0.7 * narrowness_bonus)
logger.debug(
"Dark run x=%d..%d (w=%d): center_f=%.3f dark_f=%.3f narrow_b=%.3f -> score=%.4f",
center_lo + rs, center_lo + re, run_width,
center_factor, darkness_factor, narrowness_bonus, score,
)
if score > best_score:
best_score = score
best_start, best_end = rs, re
spine_w = best_end - best_start
spine_x = center_lo + best_start
spine_center = spine_x + spine_w // 2
logger.debug(
"Best spine candidate: x=%d..%d (w=%d), score=%.4f",
spine_x, spine_x + spine_w, spine_w, best_score,
)
# Verify: must have bright (paper) content on BOTH sides
left_brightness = float(np.mean(brightness_smooth[max(0, spine_x - w // 10):spine_x]))
right_end = center_lo + best_end
right_brightness = float(np.mean(brightness_smooth[right_end:min(w, right_end + w // 10)]))
if left_brightness < spine_thresh or right_brightness < spine_thresh:
logger.debug("No bright paper flanking spine: left=%.0f right=%.0f thresh=%.0f",
left_brightness, right_brightness, spine_thresh)
return []
logger.info(
"Spine detected: x=%d..%d (w=%d), brightness=%.0f vs paper=%.0f, "
"left_paper=%.0f, right_paper=%.0f",
spine_x, right_end, spine_w, darkest_val, page_brightness,
left_brightness, right_brightness,
)
# Split at the spine center
split_points = [spine_center]
# Build page rectangles
pages: list = []
prev_x = 0
for i, sx in enumerate(split_points):
pages.append({"x": prev_x, "y": 0, "width": sx - prev_x,
"height": h, "page_index": i})
prev_x = sx
pages.append({"x": prev_x, "y": 0, "width": w - prev_x,
"height": h, "page_index": len(split_points)})
# Filter out tiny pages (< 15% of total width)
pages = [p for p in pages if p["width"] >= w * 0.15]
if len(pages) < 2:
return []
# Re-index
for i, p in enumerate(pages):
p["page_index"] = i
logger.info(
"Page split detected: %d pages, spine_w=%d, split_points=%s",
len(pages), spine_w, split_points,
)
return pages
def detect_and_crop_page(
img_bgr: np.ndarray,
margin_frac: float = 0.01,
) -> Tuple[np.ndarray, Dict[str, Any]]:
"""Detect content boundary and crop scanner/book borders.
Algorithm (4-edge detection):
1. Adaptive threshold -> binary (text=255, bg=0)
2. Left edge: spine-shadow detection via grayscale column means,
fallback to binary vertical projection
3. Right edge: binary vertical projection (last ink column)
4. Top/bottom edges: binary horizontal projection
5. Sanity checks, then crop with configurable margin
Args:
img_bgr: Input BGR image (should already be deskewed/dewarped)
margin_frac: Extra margin around content (fraction of dimension, default 1%)
Returns:
Tuple of (cropped_image, result_dict)
"""
h, w = img_bgr.shape[:2]
total_area = h * w
result: Dict[str, Any] = {
"crop_applied": False,
"crop_rect": None,
"crop_rect_pct": None,
"original_size": {"width": w, "height": h},
"cropped_size": {"width": w, "height": h},
"detected_format": None,
"format_confidence": 0.0,
"aspect_ratio": round(max(h, w) / max(min(h, w), 1), 4),
"border_fractions": {"top": 0.0, "bottom": 0.0, "left": 0.0, "right": 0.0},
}
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
# --- Binarise with adaptive threshold ---
binary = cv2.adaptiveThreshold(
gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
cv2.THRESH_BINARY_INV, blockSize=51, C=15,
)
# --- Edge detection ---
left_edge = _detect_left_edge_shadow(gray, binary, w, h)
right_edge = _detect_right_edge_shadow(gray, binary, w, h)
top_edge, bottom_edge = _detect_top_bottom_edges(binary, w, h)
# Compute border fractions
border_top = top_edge / h
border_bottom = (h - bottom_edge) / h
border_left = left_edge / w
border_right = (w - right_edge) / w
result["border_fractions"] = {
"top": round(border_top, 4),
"bottom": round(border_bottom, 4),
"left": round(border_left, 4),
"right": round(border_right, 4),
}
# Sanity: only crop if at least one edge has > 2% border
min_border = 0.02
if all(f < min_border for f in [border_top, border_bottom, border_left, border_right]):
logger.info("All borders < %.0f%% — no crop needed", min_border * 100)
result["detected_format"], result["format_confidence"] = _detect_format(w, h)
return img_bgr, result
# Add margin
margin_x = int(w * margin_frac)
margin_y = int(h * margin_frac)
crop_x = max(0, left_edge - margin_x)
crop_y = max(0, top_edge - margin_y)
crop_x2 = min(w, right_edge + margin_x)
crop_y2 = min(h, bottom_edge + margin_y)
crop_w = crop_x2 - crop_x
crop_h = crop_y2 - crop_y
# Sanity: cropped area must be >= 40% of original
if crop_w * crop_h < 0.40 * total_area:
logger.warning("Cropped area too small (%.0f%%) — skipping crop",
100.0 * crop_w * crop_h / total_area)
result["detected_format"], result["format_confidence"] = _detect_format(w, h)
return img_bgr, result
cropped = img_bgr[crop_y:crop_y2, crop_x:crop_x2].copy()
detected_format, format_confidence = _detect_format(crop_w, crop_h)
result["crop_applied"] = True
result["crop_rect"] = {"x": crop_x, "y": crop_y, "width": crop_w, "height": crop_h}
result["crop_rect_pct"] = {
"x": round(100.0 * crop_x / w, 2),
"y": round(100.0 * crop_y / h, 2),
"width": round(100.0 * crop_w / w, 2),
"height": round(100.0 * crop_h / h, 2),
}
result["cropped_size"] = {"width": crop_w, "height": crop_h}
result["detected_format"] = detected_format
result["format_confidence"] = format_confidence
result["aspect_ratio"] = round(max(crop_w, crop_h) / max(min(crop_w, crop_h), 1), 4)
logger.info(
"Page cropped: %dx%d -> %dx%d, format=%s (%.0f%%), "
"borders: T=%.1f%% B=%.1f%% L=%.1f%% R=%.1f%%",
w, h, crop_w, crop_h, detected_format, format_confidence * 100,
border_top * 100, border_bottom * 100,
border_left * 100, border_right * 100,
)
return cropped, result
# ---------------------------------------------------------------------------
# Format detection (kept as optional metadata)
# ---------------------------------------------------------------------------
def _detect_format(width: int, height: int) -> Tuple[str, float]:
"""Detect paper format from dimensions by comparing aspect ratios."""
if width <= 0 or height <= 0:
return "unknown", 0.0
aspect = max(width, height) / min(width, height)
best_format = "unknown"
best_diff = float("inf")
for fmt, expected_ratio in PAPER_FORMATS.items():
diff = abs(aspect - expected_ratio)
if diff < best_diff:
best_diff = diff
best_format = fmt
confidence = max(0.0, 1.0 - best_diff * 5.0)
if confidence < 0.3:
return "unknown", 0.0
return best_format, round(confidence, 3)

View File

@@ -0,0 +1,388 @@
"""
Page Crop - Edge Detection Helpers
Spine shadow detection, gutter continuity analysis, projection-based
edge detection, and narrow-run filtering for content cropping.
Extracted from page_crop.py to keep files under 500 LOC.
License: Apache 2.0
"""
import logging
from typing import Optional, Tuple
import cv2
import numpy as np
logger = logging.getLogger(__name__)
# Minimum ink density (fraction of pixels) to count a row/column as "content"
_INK_THRESHOLD = 0.003 # 0.3%
# Minimum run length (fraction of dimension) to keep — shorter runs are noise
_MIN_RUN_FRAC = 0.005 # 0.5%
def _detect_spine_shadow(
gray: np.ndarray,
search_region: np.ndarray,
offset_x: int,
w: int,
side: str,
) -> Optional[int]:
"""Find the book spine center (darkest point) in a scanner shadow.
The scanner produces a gray strip where the book spine presses against
the glass. The darkest column in that strip is the spine center —
that's where we crop.
Distinguishes real spine shadows from text content by checking:
1. Strong brightness range (> 40 levels)
2. Darkest point is genuinely dark (< 180 mean brightness)
3. The dark area is a NARROW valley, not a text-content plateau
4. Brightness rises significantly toward the page content side
Args:
gray: Full grayscale image (for context).
search_region: Column slice of the grayscale image to search in.
offset_x: X offset of search_region relative to full image.
w: Full image width.
side: 'left' or 'right' (for logging).
Returns:
X coordinate (in full image) of the spine center, or None.
"""
region_w = search_region.shape[1]
if region_w < 10:
return None
# Column-mean brightness in the search region
col_means = np.mean(search_region, axis=0).astype(np.float64)
# Smooth with boxcar kernel (width = 1% of image width, min 5)
kernel_size = max(5, w // 100)
if kernel_size % 2 == 0:
kernel_size += 1
kernel = np.ones(kernel_size) / kernel_size
smoothed_raw = np.convolve(col_means, kernel, mode="same")
# Trim convolution edge artifacts (edges are zero-padded -> artificially low)
margin = kernel_size // 2
if region_w <= 2 * margin + 10:
return None
smoothed = smoothed_raw[margin:region_w - margin]
trim_offset = margin # offset of smoothed[0] relative to search_region
val_min = float(np.min(smoothed))
val_max = float(np.max(smoothed))
shadow_range = val_max - val_min
# --- Check 1: Strong brightness gradient ---
if shadow_range <= 40:
logger.debug(
"%s edge: no spine (range=%.0f <= 40)", side.capitalize(), shadow_range,
)
return None
# --- Check 2: Darkest point must be genuinely dark ---
if val_min > 180:
logger.debug(
"%s edge: no spine (darkest=%.0f > 180, likely text)", side.capitalize(), val_min,
)
return None
spine_idx = int(np.argmin(smoothed)) # index in trimmed array
spine_local = spine_idx + trim_offset # index in search_region
trimmed_len = len(smoothed)
# --- Check 3: Valley width (spine is narrow, text plateau is wide) ---
valley_thresh = val_min + shadow_range * 0.20
valley_mask = smoothed < valley_thresh
valley_width = int(np.sum(valley_mask))
max_valley_frac = 0.50
if valley_width > trimmed_len * max_valley_frac:
logger.debug(
"%s edge: no spine (valley too wide: %d/%d = %.0f%%)",
side.capitalize(), valley_width, trimmed_len,
100.0 * valley_width / trimmed_len,
)
return None
# --- Check 4: Brightness must rise toward page content ---
rise_check_w = max(5, trimmed_len // 5)
if side == "left":
right_start = min(spine_idx + 5, trimmed_len - 1)
right_end = min(right_start + rise_check_w, trimmed_len)
if right_end > right_start:
rise_brightness = float(np.mean(smoothed[right_start:right_end]))
rise = rise_brightness - val_min
if rise < shadow_range * 0.3:
logger.debug(
"%s edge: no spine (insufficient rise: %.0f, need %.0f)",
side.capitalize(), rise, shadow_range * 0.3,
)
return None
else: # right
left_end = max(spine_idx - 5, 0)
left_start = max(left_end - rise_check_w, 0)
if left_end > left_start:
rise_brightness = float(np.mean(smoothed[left_start:left_end]))
rise = rise_brightness - val_min
if rise < shadow_range * 0.3:
logger.debug(
"%s edge: no spine (insufficient rise: %.0f, need %.0f)",
side.capitalize(), rise, shadow_range * 0.3,
)
return None
spine_x = offset_x + spine_local
logger.info(
"%s edge: spine center at x=%d (brightness=%.0f, range=%.0f, valley=%dpx)",
side.capitalize(), spine_x, val_min, shadow_range, valley_width,
)
return spine_x
def _detect_gutter_continuity(
gray: np.ndarray,
search_region: np.ndarray,
offset_x: int,
w: int,
side: str,
) -> Optional[int]:
"""Detect gutter shadow via vertical continuity analysis.
Camera book scans produce a subtle brightness gradient at the gutter
that is too faint for scanner-shadow detection (range < 40). However,
the gutter shadow has a unique property: it runs **continuously from
top to bottom** without interruption.
Algorithm:
1. Divide image into N horizontal strips (~60px each)
2. For each column, compute what fraction of strips are darker than
the page median (from the center 50% of the full image)
3. A "gutter column" has >= 75% of strips darker than page_median - d
4. Smooth the dark-fraction profile and find the transition point
5. Validate: gutter band must be 0.5%-10% of image width
"""
region_h, region_w = search_region.shape[:2]
if region_w < 20 or region_h < 100:
return None
# --- 1. Divide into horizontal strips ---
strip_target_h = 60
n_strips = max(10, region_h // strip_target_h)
strip_h = region_h // n_strips
strip_means = np.zeros((n_strips, region_w), dtype=np.float64)
for s in range(n_strips):
y0 = s * strip_h
y1 = min((s + 1) * strip_h, region_h)
strip_means[s] = np.mean(search_region[y0:y1, :], axis=0)
# --- 2. Page median from center 50% of full image ---
center_lo = w // 4
center_hi = 3 * w // 4
page_median = float(np.median(gray[:, center_lo:center_hi]))
dark_thresh = page_median - 5.0
if page_median < 180:
return None
# --- 3. Per-column dark fraction ---
dark_count = np.sum(strip_means < dark_thresh, axis=0).astype(np.float64)
dark_frac = dark_count / n_strips
# --- 4. Smooth and find transition ---
smooth_w = max(5, w // 100)
if smooth_w % 2 == 0:
smooth_w += 1
kernel = np.ones(smooth_w) / smooth_w
frac_smooth = np.convolve(dark_frac, kernel, mode="same")
margin = smooth_w // 2
if region_w <= 2 * margin + 10:
return None
transition_thresh = 0.50
peak_frac = float(np.max(frac_smooth[margin:region_w - margin]))
if peak_frac < 0.70:
logger.debug(
"%s gutter: peak dark fraction %.2f < 0.70", side.capitalize(), peak_frac,
)
return None
peak_x = int(np.argmax(frac_smooth[margin:region_w - margin])) + margin
gutter_inner = None
if side == "right":
for x in range(peak_x, margin, -1):
if frac_smooth[x] < transition_thresh:
gutter_inner = x + 1
break
else:
for x in range(peak_x, region_w - margin):
if frac_smooth[x] < transition_thresh:
gutter_inner = x - 1
break
if gutter_inner is None:
return None
# --- 5. Validate gutter width ---
if side == "right":
gutter_width = region_w - gutter_inner
else:
gutter_width = gutter_inner
min_gutter = max(3, int(w * 0.005))
max_gutter = int(w * 0.10)
if gutter_width < min_gutter:
logger.debug(
"%s gutter: too narrow (%dpx < %dpx)", side.capitalize(),
gutter_width, min_gutter,
)
return None
if gutter_width > max_gutter:
logger.debug(
"%s gutter: too wide (%dpx > %dpx)", side.capitalize(),
gutter_width, max_gutter,
)
return None
if side == "right":
gutter_brightness = float(np.mean(strip_means[:, gutter_inner:]))
else:
gutter_brightness = float(np.mean(strip_means[:, :gutter_inner]))
brightness_drop = page_median - gutter_brightness
if brightness_drop < 3:
logger.debug(
"%s gutter: insufficient brightness drop (%.1f levels)",
side.capitalize(), brightness_drop,
)
return None
gutter_x = offset_x + gutter_inner
logger.info(
"%s gutter (continuity): x=%d, width=%dpx (%.1f%%), "
"brightness=%.0f vs page=%.0f (drop=%.0f), frac@edge=%.2f",
side.capitalize(), gutter_x, gutter_width,
100.0 * gutter_width / w, gutter_brightness, page_median,
brightness_drop, float(frac_smooth[gutter_inner]),
)
return gutter_x
def _detect_left_edge_shadow(
gray: np.ndarray,
binary: np.ndarray,
w: int,
h: int,
) -> int:
"""Detect left content edge, accounting for book-spine shadow.
Tries three methods in order:
1. Scanner spine-shadow (dark gradient, range > 40)
2. Camera gutter continuity (subtle shadow running top-to-bottom)
3. Binary projection fallback (first ink column)
"""
search_w = max(1, w // 4)
spine_x = _detect_spine_shadow(gray, gray[:, :search_w], 0, w, "left")
if spine_x is not None:
return spine_x
gutter_x = _detect_gutter_continuity(gray, gray[:, :search_w], 0, w, "left")
if gutter_x is not None:
return gutter_x
return _detect_edge_projection(binary, axis=0, from_start=True, dim=w)
def _detect_right_edge_shadow(
gray: np.ndarray,
binary: np.ndarray,
w: int,
h: int,
) -> int:
"""Detect right content edge, accounting for book-spine shadow.
Tries three methods in order:
1. Scanner spine-shadow (dark gradient, range > 40)
2. Camera gutter continuity (subtle shadow running top-to-bottom)
3. Binary projection fallback (last ink column)
"""
search_w = max(1, w // 4)
right_start = w - search_w
spine_x = _detect_spine_shadow(gray, gray[:, right_start:], right_start, w, "right")
if spine_x is not None:
return spine_x
gutter_x = _detect_gutter_continuity(gray, gray[:, right_start:], right_start, w, "right")
if gutter_x is not None:
return gutter_x
return _detect_edge_projection(binary, axis=0, from_start=False, dim=w)
def _detect_top_bottom_edges(binary: np.ndarray, w: int, h: int) -> Tuple[int, int]:
"""Detect top and bottom content edges via binary horizontal projection."""
top = _detect_edge_projection(binary, axis=1, from_start=True, dim=h)
bottom = _detect_edge_projection(binary, axis=1, from_start=False, dim=h)
return top, bottom
def _detect_edge_projection(
binary: np.ndarray,
axis: int,
from_start: bool,
dim: int,
) -> int:
"""Find the first/last row or column with ink density above threshold.
axis=0 -> project vertically (column densities) -> returns x position
axis=1 -> project horizontally (row densities) -> returns y position
Filters out narrow noise runs shorter than _MIN_RUN_FRAC of the dimension.
"""
projection = np.mean(binary, axis=axis) / 255.0
ink_mask = projection >= _INK_THRESHOLD
min_run = max(1, int(dim * _MIN_RUN_FRAC))
ink_mask = _filter_narrow_runs(ink_mask, min_run)
ink_positions = np.where(ink_mask)[0]
if len(ink_positions) == 0:
return 0 if from_start else dim
if from_start:
return int(ink_positions[0])
else:
return int(ink_positions[-1])
def _filter_narrow_runs(mask: np.ndarray, min_run: int) -> np.ndarray:
"""Remove True-runs shorter than min_run pixels."""
if min_run <= 1:
return mask
result = mask.copy()
n = len(result)
i = 0
while i < n:
if result[i]:
start = i
while i < n and result[i]:
i += 1
if i - start < min_run:
result[start:i] = False
else:
i += 1
return result

View File

@@ -0,0 +1,189 @@
"""
Sub-session creation for multi-page spreads.
Used by both the page-split and crop steps when a double-page scan is detected.
"""
import logging
import uuid as uuid_mod
from typing import Any, Dict, List
import cv2
import numpy as np
from .page_crop import detect_and_crop_page
from .session_store import (
create_session_db,
get_sub_sessions,
update_session_db,
)
from .orientation_crop_helpers import get_cache_ref
logger = logging.getLogger(__name__)
async def create_page_sub_sessions(
parent_session_id: str,
parent_cached: dict,
full_img_bgr: np.ndarray,
page_splits: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
"""Create sub-sessions for each detected page in a multi-page spread.
Each page region is individually cropped, then stored as a sub-session
with its own cropped image ready for the rest of the pipeline.
"""
# Check for existing sub-sessions (idempotent)
existing = await get_sub_sessions(parent_session_id)
if existing:
return [
{"id": s["id"], "name": s["name"], "page_index": s.get("box_index", i)}
for i, s in enumerate(existing)
]
parent_name = parent_cached.get("name", "Scan")
parent_filename = parent_cached.get("filename", "scan.png")
sub_sessions: List[Dict[str, Any]] = []
for page in page_splits:
pi = page["page_index"]
px, py = page["x"], page["y"]
pw, ph = page["width"], page["height"]
# Extract page region
page_bgr = full_img_bgr[py:py + ph, px:px + pw].copy()
# Crop each page individually (remove its own borders)
cropped_page, page_crop_info = detect_and_crop_page(page_bgr)
# Encode as PNG
ok, png_buf = cv2.imencode(".png", cropped_page)
page_png = png_buf.tobytes() if ok else b""
sub_id = str(uuid_mod.uuid4())
sub_name = f"{parent_name} — Seite {pi + 1}"
await create_session_db(
session_id=sub_id,
name=sub_name,
filename=parent_filename,
original_png=page_png,
)
# Pre-populate: set cropped = original (already cropped)
await update_session_db(
sub_id,
cropped_png=page_png,
crop_result=page_crop_info,
current_step=5,
)
ch, cw = cropped_page.shape[:2]
sub_sessions.append({
"id": sub_id,
"name": sub_name,
"page_index": pi,
"source_rect": page,
"cropped_size": {"width": cw, "height": ch},
"detected_format": page_crop_info.get("detected_format"),
})
logger.info(
"Page sub-session %s: page %d, region x=%d w=%d -> cropped %dx%d",
sub_id, pi + 1, px, pw, cw, ch,
)
return sub_sessions
async def create_page_sub_sessions_full(
parent_session_id: str,
parent_cached: dict,
full_img_bgr: np.ndarray,
page_splits: List[Dict[str, Any]],
start_step: int = 2,
) -> List[Dict[str, Any]]:
"""Create sub-sessions for each page with RAW regions for full pipeline processing.
Unlike ``create_page_sub_sessions`` (used by the crop step), these
sub-sessions store the *uncropped* page region and start at
``start_step`` (default 2 = ready for deskew; 1 if orientation still
needed). Each page goes through its own pipeline independently,
which is essential for book spreads where each page has a different tilt.
"""
_cache = get_cache_ref()
# Idempotent: reuse existing sub-sessions
existing = await get_sub_sessions(parent_session_id)
if existing:
return [
{"id": s["id"], "name": s["name"], "page_index": s.get("box_index", i)}
for i, s in enumerate(existing)
]
parent_name = parent_cached.get("name", "Scan")
parent_filename = parent_cached.get("filename", "scan.png")
sub_sessions: List[Dict[str, Any]] = []
for page in page_splits:
pi = page["page_index"]
px, py = page["x"], page["y"]
pw, ph = page["width"], page["height"]
# Extract RAW page region — NO individual cropping here; each
# sub-session will run its own crop step after deskew + dewarp.
page_bgr = full_img_bgr[py:py + ph, px:px + pw].copy()
# Encode as PNG
ok, png_buf = cv2.imencode(".png", page_bgr)
page_png = png_buf.tobytes() if ok else b""
sub_id = str(uuid_mod.uuid4())
sub_name = f"{parent_name} — Seite {pi + 1}"
await create_session_db(
session_id=sub_id,
name=sub_name,
filename=parent_filename,
original_png=page_png,
)
# start_step=2 -> ready for deskew (orientation already done on spread)
# start_step=1 -> needs its own orientation (split from original image)
await update_session_db(sub_id, current_step=start_step)
# Cache the BGR so the pipeline can start immediately
_cache[sub_id] = {
"id": sub_id,
"filename": parent_filename,
"name": sub_name,
"original_bgr": page_bgr,
"oriented_bgr": None,
"cropped_bgr": None,
"deskewed_bgr": None,
"dewarped_bgr": None,
"orientation_result": None,
"crop_result": None,
"deskew_result": None,
"dewarp_result": None,
"ground_truth": {},
"current_step": start_step,
}
rh, rw = page_bgr.shape[:2]
sub_sessions.append({
"id": sub_id,
"name": sub_name,
"page_index": pi,
"source_rect": page,
"image_size": {"width": rw, "height": rh},
})
logger.info(
"Page sub-session %s (full pipeline): page %d, region x=%d w=%d -> %dx%d",
sub_id, pi + 1, px, pw, rw, rh,
)
return sub_sessions

View File

@@ -0,0 +1,26 @@
"""
OCR Pipeline Postprocessing API — composite router assembling LLM review,
reconstruction, export, validation, image detection/generation, and
handwriting removal endpoints.
Split into sub-modules:
ocr_pipeline_llm_review — LLM review + apply corrections
ocr_pipeline_reconstruction — reconstruction save, Fabric JSON, merged entries, PDF/DOCX
ocr_pipeline_validation — image detection, generation, validation, handwriting removal
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
from fastapi import APIRouter
from .llm_review import router as _llm_review_router
from .reconstruction import router as _reconstruction_router
from .validation import router as _validation_router
# Composite router — drop-in replacement for the old monolithic router.
# ocr_pipeline_api.py imports ``from ocr_pipeline_postprocess import router``.
router = APIRouter()
router.include_router(_llm_review_router)
router.include_router(_reconstruction_router)
router.include_router(_validation_router)

View File

@@ -0,0 +1,362 @@
"""
OCR Pipeline Reconstruction — save edits, Fabric JSON export, merged entries, PDF/DOCX export.
Extracted from ocr_pipeline_postprocess.py.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import re
from typing import Dict
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from .session_store import (
get_session_db,
get_sub_sessions,
update_session_db,
)
from .common import _cache
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Step 9: Reconstruction + Fabric JSON export
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/reconstruction")
async def save_reconstruction(session_id: str, request: Request):
"""Save edited cell texts from reconstruction step."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
word_result = session.get("word_result")
if not word_result:
raise HTTPException(status_code=400, detail="No word result found")
body = await request.json()
cell_updates = body.get("cells", [])
if not cell_updates:
await update_session_db(session_id, current_step=10)
return {"session_id": session_id, "updated": 0}
# Build update map: cell_id -> new text
update_map = {c["cell_id"]: c["text"] for c in cell_updates}
# Separate sub-session updates (cell_ids prefixed with "box{N}_")
sub_updates: Dict[int, Dict[str, str]] = {} # box_index -> {original_cell_id: text}
main_updates: Dict[str, str] = {}
for cell_id, text in update_map.items():
m = re.match(r'^box(\d+)_(.+)$', cell_id)
if m:
bi = int(m.group(1))
original_id = m.group(2)
sub_updates.setdefault(bi, {})[original_id] = text
else:
main_updates[cell_id] = text
# Update main session cells
cells = word_result.get("cells", [])
updated_count = 0
for cell in cells:
if cell["cell_id"] in main_updates:
cell["text"] = main_updates[cell["cell_id"]]
cell["status"] = "edited"
updated_count += 1
word_result["cells"] = cells
# Also update vocab_entries if present
entries = word_result.get("vocab_entries") or word_result.get("entries") or []
if entries:
for entry in entries:
row_idx = entry.get("row_index", -1)
for col_idx, field_name in enumerate(["english", "german", "example"]):
cell_id = f"R{row_idx:02d}_C{col_idx}"
cell_id_alt = f"R{row_idx}_C{col_idx}"
new_text = main_updates.get(cell_id) or main_updates.get(cell_id_alt)
if new_text is not None:
entry[field_name] = new_text
word_result["vocab_entries"] = entries
if "entries" in word_result:
word_result["entries"] = entries
await update_session_db(session_id, word_result=word_result, current_step=10)
if session_id in _cache:
_cache[session_id]["word_result"] = word_result
# Route sub-session updates
sub_updated = 0
if sub_updates:
subs = await get_sub_sessions(session_id)
sub_by_index = {s.get("box_index"): s["id"] for s in subs}
for bi, updates in sub_updates.items():
sub_id = sub_by_index.get(bi)
if not sub_id:
continue
sub_session = await get_session_db(sub_id)
if not sub_session:
continue
sub_word = sub_session.get("word_result")
if not sub_word:
continue
sub_cells = sub_word.get("cells", [])
for cell in sub_cells:
if cell["cell_id"] in updates:
cell["text"] = updates[cell["cell_id"]]
cell["status"] = "edited"
sub_updated += 1
sub_word["cells"] = sub_cells
await update_session_db(sub_id, word_result=sub_word)
if sub_id in _cache:
_cache[sub_id]["word_result"] = sub_word
total_updated = updated_count + sub_updated
logger.info(f"Reconstruction saved for session {session_id}: "
f"{updated_count} main + {sub_updated} sub-session cells updated")
return {
"session_id": session_id,
"updated": total_updated,
"main_updated": updated_count,
"sub_updated": sub_updated,
}
@router.get("/sessions/{session_id}/reconstruction/fabric-json")
async def get_fabric_json(session_id: str):
"""Return cell grid as Fabric.js-compatible JSON for the canvas editor."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
word_result = session.get("word_result")
if not word_result:
raise HTTPException(status_code=400, detail="No word result found")
cells = list(word_result.get("cells", []))
img_w = word_result.get("image_width", 800)
img_h = word_result.get("image_height", 600)
# Merge sub-session cells at box positions
subs = await get_sub_sessions(session_id)
if subs:
column_result = session.get("column_result") or {}
zones = column_result.get("zones") or []
box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
for sub in subs:
sub_session = await get_session_db(sub["id"])
if not sub_session:
continue
sub_word = sub_session.get("word_result")
if not sub_word or not sub_word.get("cells"):
continue
bi = sub.get("box_index", 0)
if bi < len(box_zones):
box = box_zones[bi]["box"]
box_y, box_x = box["y"], box["x"]
else:
box_y, box_x = 0, 0
for cell in sub_word["cells"]:
cell_copy = dict(cell)
cell_copy["cell_id"] = f"box{bi}_{cell_copy.get('cell_id', '')}"
cell_copy["source"] = f"box_{bi}"
bbox = cell_copy.get("bbox_px", {})
if bbox:
bbox = dict(bbox)
bbox["x"] = bbox.get("x", 0) + box_x
bbox["y"] = bbox.get("y", 0) + box_y
cell_copy["bbox_px"] = bbox
cells.append(cell_copy)
from services.layout_reconstruction_service import cells_to_fabric_json
fabric_json = cells_to_fabric_json(cells, img_w, img_h)
return fabric_json
# ---------------------------------------------------------------------------
# Vocab entries merged + PDF/DOCX export
# ---------------------------------------------------------------------------
@router.get("/sessions/{session_id}/vocab-entries/merged")
async def get_merged_vocab_entries(session_id: str):
"""Return vocab entries from main session + all sub-sessions, sorted by Y position."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
word_result = session.get("word_result") or {}
entries = list(word_result.get("vocab_entries") or word_result.get("entries") or [])
for e in entries:
e.setdefault("source", "main")
subs = await get_sub_sessions(session_id)
if subs:
column_result = session.get("column_result") or {}
zones = column_result.get("zones") or []
box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
for sub in subs:
sub_session = await get_session_db(sub["id"])
if not sub_session:
continue
sub_word = sub_session.get("word_result") or {}
sub_entries = sub_word.get("vocab_entries") or sub_word.get("entries") or []
bi = sub.get("box_index", 0)
box_y = 0
if bi < len(box_zones):
box_y = box_zones[bi]["box"]["y"]
for e in sub_entries:
e_copy = dict(e)
e_copy["source"] = f"box_{bi}"
e_copy["source_y"] = box_y
entries.append(e_copy)
def _sort_key(e):
if e.get("source", "main") == "main":
return e.get("row_index", 0) * 100
return e.get("source_y", 0) * 100 + e.get("row_index", 0)
entries.sort(key=_sort_key)
return {
"session_id": session_id,
"entries": entries,
"total": len(entries),
"sources": list(set(e.get("source", "main") for e in entries)),
}
@router.get("/sessions/{session_id}/reconstruction/export/pdf")
async def export_reconstruction_pdf(session_id: str):
"""Export the reconstructed cell grid as a PDF table."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
word_result = session.get("word_result")
if not word_result:
raise HTTPException(status_code=400, detail="No word result found")
cells = word_result.get("cells", [])
columns_used = word_result.get("columns_used", [])
grid_shape = word_result.get("grid_shape", {})
n_rows = grid_shape.get("rows", 0)
n_cols = grid_shape.get("cols", 0)
# Build table data: rows x columns
table_data: list[list[str]] = []
header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
if not header:
header = [f"Col {i}" for i in range(n_cols)]
table_data.append(header)
for r in range(n_rows):
row_texts = []
for ci in range(n_cols):
cell_id = f"R{r:02d}_C{ci}"
cell = next((c for c in cells if c.get("cell_id") == cell_id), None)
row_texts.append(cell.get("text", "") if cell else "")
table_data.append(row_texts)
try:
from reportlab.lib.pagesizes import A4
from reportlab.lib import colors
from reportlab.platypus import SimpleDocTemplate, Table, TableStyle
import io as _io
buf = _io.BytesIO()
doc = SimpleDocTemplate(buf, pagesize=A4)
if not table_data or not table_data[0]:
raise HTTPException(status_code=400, detail="No data to export")
t = Table(table_data)
t.setStyle(TableStyle([
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#0d9488')),
('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
('FONTSIZE', (0, 0), (-1, -1), 9),
('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
('VALIGN', (0, 0), (-1, -1), 'TOP'),
('WORDWRAP', (0, 0), (-1, -1), True),
]))
doc.build([t])
buf.seek(0)
return StreamingResponse(
buf,
media_type="application/pdf",
headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.pdf"'},
)
except ImportError:
raise HTTPException(status_code=501, detail="reportlab not installed")
@router.get("/sessions/{session_id}/reconstruction/export/docx")
async def export_reconstruction_docx(session_id: str):
"""Export the reconstructed cell grid as a DOCX table."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
word_result = session.get("word_result")
if not word_result:
raise HTTPException(status_code=400, detail="No word result found")
cells = word_result.get("cells", [])
columns_used = word_result.get("columns_used", [])
grid_shape = word_result.get("grid_shape", {})
n_rows = grid_shape.get("rows", 0)
n_cols = grid_shape.get("cols", 0)
try:
from docx import Document
from docx.shared import Pt
import io as _io
doc = Document()
doc.add_heading(f'Rekonstruktion -- Session {session_id[:8]}', level=1)
header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
if not header:
header = [f"Col {i}" for i in range(n_cols)]
table = doc.add_table(rows=1 + n_rows, cols=max(n_cols, 1))
table.style = 'Table Grid'
for ci, h in enumerate(header):
table.rows[0].cells[ci].text = h
for r in range(n_rows):
for ci in range(n_cols):
cell_id = f"R{r:02d}_C{ci}"
cell = next((c for c in cells if c.get("cell_id") == cell_id), None)
table.rows[r + 1].cells[ci].text = cell.get("text", "") if cell else ""
buf = _io.BytesIO()
doc.save(buf)
buf.seek(0)
return StreamingResponse(
buf,
media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.docx"'},
)
except ImportError:
raise HTTPException(status_code=501, detail="python-docx not installed")

View File

@@ -0,0 +1,22 @@
"""
OCR Pipeline Regression Tests — barrel re-export.
All implementation split into:
ocr_pipeline_regression_helpers — DB persistence, snapshot, comparison
ocr_pipeline_regression_endpoints — FastAPI routes
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
# Helpers (used by grid_editor_api_grid.py)
from .regression_helpers import ( # noqa: F401
_init_regression_table,
_persist_regression_run,
_extract_cells_for_comparison,
_build_reference_snapshot,
compare_grids,
)
# Endpoints (router used by ocr_pipeline_api.py)
from .regression_endpoints import router # noqa: F401

View File

@@ -0,0 +1,421 @@
"""
OCR Pipeline Regression Endpoints — FastAPI routes for ground truth and regression.
Extracted from ocr_pipeline_regression.py for modularity.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import time
from typing import Any, Dict, Optional
from fastapi import APIRouter, HTTPException, Query
from grid_editor_api import _build_grid_core
from .session_store import (
get_session_db,
list_ground_truth_sessions_db,
update_session_db,
)
from .regression_helpers import (
_build_reference_snapshot,
_init_regression_table,
_persist_regression_run,
compare_grids,
get_pool,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["regression"])
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/mark-ground-truth")
async def mark_ground_truth(
session_id: str,
pipeline: Optional[str] = Query(None, description="Pipeline used: kombi, pipeline, paddle-direct"),
):
"""Save the current build-grid result as ground-truth reference."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
grid_result = session.get("grid_editor_result")
if not grid_result or not grid_result.get("zones"):
raise HTTPException(
status_code=400,
detail="No grid_editor_result found. Run build-grid first.",
)
# Auto-detect pipeline from word_result if not provided
if not pipeline:
wr = session.get("word_result") or {}
engine = wr.get("ocr_engine", "")
if engine in ("kombi", "rapid_kombi"):
pipeline = "kombi"
elif engine == "paddle_direct":
pipeline = "paddle-direct"
else:
pipeline = "pipeline"
reference = _build_reference_snapshot(grid_result, pipeline=pipeline)
# Merge into existing ground_truth JSONB
gt = session.get("ground_truth") or {}
gt["build_grid_reference"] = reference
await update_session_db(session_id, ground_truth=gt, current_step=11)
# Compare with auto-snapshot if available (shows what the user corrected)
auto_snapshot = gt.get("auto_grid_snapshot")
correction_diff = None
if auto_snapshot:
correction_diff = compare_grids(auto_snapshot, reference)
logger.info(
"Ground truth marked for session %s: %d cells (corrections: %s)",
session_id,
len(reference["cells"]),
correction_diff["summary"] if correction_diff else "no auto-snapshot",
)
return {
"status": "ok",
"session_id": session_id,
"cells_saved": len(reference["cells"]),
"summary": reference["summary"],
"correction_diff": correction_diff,
}
@router.delete("/sessions/{session_id}/mark-ground-truth")
async def unmark_ground_truth(session_id: str):
"""Remove the ground-truth reference from a session."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
gt = session.get("ground_truth") or {}
if "build_grid_reference" not in gt:
raise HTTPException(status_code=404, detail="No ground truth reference found")
del gt["build_grid_reference"]
await update_session_db(session_id, ground_truth=gt)
logger.info("Ground truth removed for session %s", session_id)
return {"status": "ok", "session_id": session_id}
@router.get("/sessions/{session_id}/correction-diff")
async def get_correction_diff(session_id: str):
"""Compare automatic OCR grid with manually corrected ground truth.
Returns a diff showing exactly which cells the user corrected,
broken down by col_type (english, german, ipa, etc.).
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
gt = session.get("ground_truth") or {}
auto_snapshot = gt.get("auto_grid_snapshot")
reference = gt.get("build_grid_reference")
if not auto_snapshot:
raise HTTPException(
status_code=404,
detail="No auto_grid_snapshot found. Re-run build-grid to create one.",
)
if not reference:
raise HTTPException(
status_code=404,
detail="No ground truth reference found. Mark as ground truth first.",
)
diff = compare_grids(auto_snapshot, reference)
# Enrich with per-col_type breakdown
col_type_stats: Dict[str, Dict[str, int]] = {}
for cell_diff in diff.get("cell_diffs", []):
if cell_diff["type"] != "text_change":
continue
# Find col_type from reference cells
cell_id = cell_diff["cell_id"]
ref_cell = next(
(c for c in reference.get("cells", []) if c["cell_id"] == cell_id),
None,
)
ct = ref_cell.get("col_type", "unknown") if ref_cell else "unknown"
if ct not in col_type_stats:
col_type_stats[ct] = {"total": 0, "corrected": 0}
col_type_stats[ct]["corrected"] += 1
# Count total cells per col_type from reference
for cell in reference.get("cells", []):
ct = cell.get("col_type", "unknown")
if ct not in col_type_stats:
col_type_stats[ct] = {"total": 0, "corrected": 0}
col_type_stats[ct]["total"] += 1
# Calculate accuracy per col_type
for ct, stats in col_type_stats.items():
total = stats["total"]
corrected = stats["corrected"]
stats["accuracy_pct"] = round((total - corrected) / total * 100, 1) if total > 0 else 100.0
diff["col_type_breakdown"] = col_type_stats
return diff
@router.get("/ground-truth-sessions")
async def list_ground_truth_sessions():
"""List all sessions that have a ground-truth reference."""
sessions = await list_ground_truth_sessions_db()
result = []
for s in sessions:
gt = s.get("ground_truth") or {}
ref = gt.get("build_grid_reference", {})
result.append({
"session_id": s["id"],
"name": s.get("name", ""),
"filename": s.get("filename", ""),
"document_category": s.get("document_category"),
"pipeline": ref.get("pipeline"),
"saved_at": ref.get("saved_at"),
"summary": ref.get("summary", {}),
})
return {"sessions": result, "count": len(result)}
@router.post("/sessions/{session_id}/regression/run")
async def run_single_regression(session_id: str):
"""Re-run build_grid for a single session and compare to ground truth."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
gt = session.get("ground_truth") or {}
reference = gt.get("build_grid_reference")
if not reference:
raise HTTPException(
status_code=400,
detail="No ground truth reference found for this session",
)
# Re-compute grid without persisting
try:
new_result = await _build_grid_core(session_id, session)
except (ValueError, Exception) as e:
return {
"session_id": session_id,
"name": session.get("name", ""),
"status": "error",
"error": str(e),
}
new_snapshot = _build_reference_snapshot(new_result)
diff = compare_grids(reference, new_snapshot)
logger.info(
"Regression test session %s: %s (%d structural, %d cell diffs)",
session_id, diff["status"],
diff["summary"]["structural_changes"],
sum(v for k, v in diff["summary"].items() if k != "structural_changes"),
)
return {
"session_id": session_id,
"name": session.get("name", ""),
"status": diff["status"],
"diff": diff,
"reference_summary": reference.get("summary", {}),
"current_summary": new_snapshot.get("summary", {}),
}
@router.post("/regression/run")
async def run_all_regressions(
triggered_by: str = Query("manual", description="Who triggered: manual, script, ci"),
):
"""Re-run build_grid for ALL ground-truth sessions and compare."""
start_time = time.monotonic()
sessions = await list_ground_truth_sessions_db()
if not sessions:
return {
"status": "pass",
"message": "No ground truth sessions found",
"results": [],
"summary": {"total": 0, "passed": 0, "failed": 0, "errors": 0},
}
results = []
passed = 0
failed = 0
errors = 0
for s in sessions:
session_id = s["id"]
gt = s.get("ground_truth") or {}
reference = gt.get("build_grid_reference")
if not reference:
continue
# Re-load full session (list query may not include all JSONB fields)
full_session = await get_session_db(session_id)
if not full_session:
results.append({
"session_id": session_id,
"name": s.get("name", ""),
"status": "error",
"error": "Session not found during re-load",
})
errors += 1
continue
try:
new_result = await _build_grid_core(session_id, full_session)
except (ValueError, Exception) as e:
results.append({
"session_id": session_id,
"name": s.get("name", ""),
"status": "error",
"error": str(e),
})
errors += 1
continue
new_snapshot = _build_reference_snapshot(new_result)
diff = compare_grids(reference, new_snapshot)
entry = {
"session_id": session_id,
"name": s.get("name", ""),
"status": diff["status"],
"diff_summary": diff["summary"],
"reference_summary": reference.get("summary", {}),
"current_summary": new_snapshot.get("summary", {}),
}
# Include full diffs only for failures (keep response compact)
if diff["status"] == "fail":
entry["structural_diffs"] = diff["structural_diffs"]
entry["cell_diffs"] = diff["cell_diffs"]
failed += 1
else:
passed += 1
results.append(entry)
overall = "pass" if failed == 0 and errors == 0 else "fail"
duration_ms = int((time.monotonic() - start_time) * 1000)
summary = {
"total": len(results),
"passed": passed,
"failed": failed,
"errors": errors,
}
logger.info(
"Regression suite: %s%d passed, %d failed, %d errors (of %d) in %dms",
overall, passed, failed, errors, len(results), duration_ms,
)
# Persist to DB
run_id = await _persist_regression_run(
status=overall,
summary=summary,
results=results,
duration_ms=duration_ms,
triggered_by=triggered_by,
)
return {
"status": overall,
"run_id": run_id,
"duration_ms": duration_ms,
"results": results,
"summary": summary,
}
@router.get("/regression/history")
async def get_regression_history(
limit: int = Query(20, ge=1, le=100),
):
"""Get recent regression run history from the database."""
try:
await _init_regression_table()
pool = await get_pool()
async with pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT id, run_at, status, total, passed, failed, errors,
duration_ms, triggered_by
FROM regression_runs
ORDER BY run_at DESC
LIMIT $1
""",
limit,
)
return {
"runs": [
{
"id": str(row["id"]),
"run_at": row["run_at"].isoformat() if row["run_at"] else None,
"status": row["status"],
"total": row["total"],
"passed": row["passed"],
"failed": row["failed"],
"errors": row["errors"],
"duration_ms": row["duration_ms"],
"triggered_by": row["triggered_by"],
}
for row in rows
],
"count": len(rows),
}
except Exception as e:
logger.warning("Failed to fetch regression history: %s", e)
return {"runs": [], "count": 0, "error": str(e)}
@router.get("/regression/history/{run_id}")
async def get_regression_run_detail(run_id: str):
"""Get detailed results of a specific regression run."""
try:
await _init_regression_table()
pool = await get_pool()
async with pool.acquire() as conn:
row = await conn.fetchrow(
"SELECT * FROM regression_runs WHERE id = $1",
run_id,
)
if not row:
raise HTTPException(status_code=404, detail="Run not found")
return {
"id": str(row["id"]),
"run_at": row["run_at"].isoformat() if row["run_at"] else None,
"status": row["status"],
"total": row["total"],
"passed": row["passed"],
"failed": row["failed"],
"errors": row["errors"],
"duration_ms": row["duration_ms"],
"triggered_by": row["triggered_by"],
"results": json.loads(row["results"]) if row["results"] else [],
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -0,0 +1,207 @@
"""
OCR Pipeline Regression Helpers — DB persistence, snapshot building, comparison.
Extracted from ocr_pipeline_regression.py for modularity.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import os
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from .session_store import get_pool
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# DB persistence for regression runs
# ---------------------------------------------------------------------------
async def _init_regression_table():
"""Ensure regression_runs table exists (idempotent)."""
pool = await get_pool()
async with pool.acquire() as conn:
migration_path = os.path.join(
os.path.dirname(__file__),
"migrations/008_regression_runs.sql",
)
if os.path.exists(migration_path):
with open(migration_path, "r") as f:
sql = f.read()
await conn.execute(sql)
async def _persist_regression_run(
status: str,
summary: dict,
results: list,
duration_ms: int,
triggered_by: str = "manual",
) -> str:
"""Save a regression run to the database. Returns the run ID."""
try:
await _init_regression_table()
pool = await get_pool()
run_id = str(uuid.uuid4())
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO regression_runs
(id, status, total, passed, failed, errors, duration_ms, results, triggered_by)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8::jsonb, $9)
""",
run_id,
status,
summary.get("total", 0),
summary.get("passed", 0),
summary.get("failed", 0),
summary.get("errors", 0),
duration_ms,
json.dumps(results),
triggered_by,
)
logger.info("Regression run %s persisted: %s", run_id, status)
return run_id
except Exception as e:
logger.warning("Failed to persist regression run: %s", e)
return ""
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _extract_cells_for_comparison(grid_result: dict) -> List[Dict[str, Any]]:
"""Extract a flat list of cells from a grid_editor_result for comparison.
Only keeps fields relevant for comparison: cell_id, row_index, col_index,
col_type, text. Ignores confidence, bbox, word_boxes, duration, is_bold.
"""
cells = []
for zone in grid_result.get("zones", []):
for cell in zone.get("cells", []):
cells.append({
"cell_id": cell.get("cell_id", ""),
"row_index": cell.get("row_index"),
"col_index": cell.get("col_index"),
"col_type": cell.get("col_type", ""),
"text": cell.get("text", ""),
})
return cells
def _build_reference_snapshot(
grid_result: dict,
pipeline: Optional[str] = None,
) -> dict:
"""Build a ground-truth reference snapshot from a grid_editor_result."""
cells = _extract_cells_for_comparison(grid_result)
total_zones = len(grid_result.get("zones", []))
total_columns = sum(len(z.get("columns", [])) for z in grid_result.get("zones", []))
total_rows = sum(len(z.get("rows", [])) for z in grid_result.get("zones", []))
snapshot = {
"saved_at": datetime.now(timezone.utc).isoformat(),
"version": 1,
"pipeline": pipeline,
"summary": {
"total_zones": total_zones,
"total_columns": total_columns,
"total_rows": total_rows,
"total_cells": len(cells),
},
"cells": cells,
}
return snapshot
def compare_grids(reference: dict, current: dict) -> dict:
"""Compare a reference grid snapshot with a newly computed one.
Returns a diff report with:
- status: "pass" or "fail"
- structural_diffs: changes in zone/row/column counts
- cell_diffs: list of individual cell changes
"""
ref_summary = reference.get("summary", {})
cur_summary = current.get("summary", {})
structural_diffs = []
for key in ("total_zones", "total_columns", "total_rows", "total_cells"):
ref_val = ref_summary.get(key, 0)
cur_val = cur_summary.get(key, 0)
if ref_val != cur_val:
structural_diffs.append({
"field": key,
"reference": ref_val,
"current": cur_val,
})
# Build cell lookup by cell_id
ref_cells = {c["cell_id"]: c for c in reference.get("cells", [])}
cur_cells = {c["cell_id"]: c for c in current.get("cells", [])}
cell_diffs: List[Dict[str, Any]] = []
# Check for missing cells (in reference but not in current)
for cell_id in ref_cells:
if cell_id not in cur_cells:
cell_diffs.append({
"type": "cell_missing",
"cell_id": cell_id,
"reference_text": ref_cells[cell_id].get("text", ""),
})
# Check for added cells (in current but not in reference)
for cell_id in cur_cells:
if cell_id not in ref_cells:
cell_diffs.append({
"type": "cell_added",
"cell_id": cell_id,
"current_text": cur_cells[cell_id].get("text", ""),
})
# Check for changes in shared cells
for cell_id in ref_cells:
if cell_id not in cur_cells:
continue
ref_cell = ref_cells[cell_id]
cur_cell = cur_cells[cell_id]
if ref_cell.get("text", "") != cur_cell.get("text", ""):
cell_diffs.append({
"type": "text_change",
"cell_id": cell_id,
"reference": ref_cell.get("text", ""),
"current": cur_cell.get("text", ""),
})
if ref_cell.get("col_type", "") != cur_cell.get("col_type", ""):
cell_diffs.append({
"type": "col_type_change",
"cell_id": cell_id,
"reference": ref_cell.get("col_type", ""),
"current": cur_cell.get("col_type", ""),
})
status = "pass" if not structural_diffs and not cell_diffs else "fail"
return {
"status": status,
"structural_diffs": structural_diffs,
"cell_diffs": cell_diffs,
"summary": {
"structural_changes": len(structural_diffs),
"cells_missing": sum(1 for d in cell_diffs if d["type"] == "cell_missing"),
"cells_added": sum(1 for d in cell_diffs if d["type"] == "cell_added"),
"text_changes": sum(1 for d in cell_diffs if d["type"] == "text_change"),
"col_type_changes": sum(1 for d in cell_diffs if d["type"] == "col_type_change"),
},
}

View File

@@ -0,0 +1,94 @@
"""
OCR Pipeline Reprocess Endpoint.
POST /sessions/{session_id}/reprocess — clear downstream + restart from step.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
from typing import Any, Dict
from fastapi import APIRouter, HTTPException, Request
from .common import _cache
from .session_store import get_session_db, update_session_db
logger = logging.getLogger(__name__)
router = APIRouter(tags=["ocr-pipeline"])
@router.post("/sessions/{session_id}/reprocess")
async def reprocess_session(session_id: str, request: Request):
"""Re-run pipeline from a specific step, clearing downstream data.
Body: {"from_step": 5} (1-indexed step number)
Pipeline order: Orientation(1) -> Deskew(2) -> Dewarp(3) -> Crop(4) -> Columns(5) ->
Rows(6) -> Words(7) -> LLM-Review(8) -> Reconstruction(9) -> Validation(10)
Clears downstream results:
- from_step <= 1: orientation_result + all downstream
- from_step <= 2: deskew_result + all downstream
- from_step <= 3: dewarp_result + all downstream
- from_step <= 4: crop_result + all downstream
- from_step <= 5: column_result, row_result, word_result
- from_step <= 6: row_result, word_result
- from_step <= 7: word_result (cells, vocab_entries)
- from_step <= 8: word_result.llm_review only
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
body = await request.json()
from_step = body.get("from_step", 1)
if not isinstance(from_step, int) or from_step < 1 or from_step > 10:
raise HTTPException(status_code=400, detail="from_step must be between 1 and 10")
update_kwargs: Dict[str, Any] = {"current_step": from_step}
# Clear downstream data based on from_step
# New pipeline order: Orient(2) -> Deskew(3) -> Dewarp(4) -> Crop(5) ->
# Columns(6) -> Rows(7) -> Words(8) -> LLM(9) -> Recon(10) -> GT(11)
if from_step <= 8:
update_kwargs["word_result"] = None
elif from_step == 9:
# Only clear LLM review from word_result
word_result = session.get("word_result")
if word_result:
word_result.pop("llm_review", None)
word_result.pop("llm_corrections", None)
update_kwargs["word_result"] = word_result
if from_step <= 7:
update_kwargs["row_result"] = None
if from_step <= 6:
update_kwargs["column_result"] = None
if from_step <= 4:
update_kwargs["crop_result"] = None
if from_step <= 3:
update_kwargs["dewarp_result"] = None
if from_step <= 2:
update_kwargs["deskew_result"] = None
if from_step <= 1:
update_kwargs["orientation_result"] = None
await update_session_db(session_id, **update_kwargs)
# Also clear cache
if session_id in _cache:
for key in list(update_kwargs.keys()):
if key != "current_step":
_cache[session_id][key] = update_kwargs[key]
_cache[session_id]["current_step"] = from_step
logger.info(f"Session {session_id} reprocessing from step {from_step}")
return {
"session_id": session_id,
"from_step": from_step,
"cleared": [k for k in update_kwargs if k != "current_step"],
}

View File

@@ -0,0 +1,348 @@
"""
OCR Pipeline - Row Detection Endpoints.
Extracted from ocr_pipeline_api.py.
Handles row detection (auto + manual) and row ground truth.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import time
from datetime import datetime
from typing import Any, Dict, List, Optional
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException
from cv_vocab_pipeline import (
create_ocr_image,
detect_column_geometry,
detect_row_geometry,
)
from .common import (
_cache,
_load_session_to_cache,
_get_cached,
_append_pipeline_log,
ManualRowsRequest,
RowGroundTruthRequest,
)
from .session_store import (
get_session_db,
update_session_db,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Helper: Box-exclusion overlay (used by rows overlay and columns overlay)
# ---------------------------------------------------------------------------
def _draw_box_exclusion_overlay(
img: np.ndarray,
zones: List[Dict],
*,
label: str = "BOX — separat verarbeitet",
) -> None:
"""Draw red semi-transparent rectangles over box zones (in-place).
Reusable for columns, rows, and words overlays.
"""
for zone in zones:
if zone.get("zone_type") != "box" or not zone.get("box"):
continue
box = zone["box"]
bx, by = box["x"], box["y"]
bw, bh = box["width"], box["height"]
# Red semi-transparent fill (~25 %)
box_overlay = img.copy()
cv2.rectangle(box_overlay, (bx, by), (bx + bw, by + bh), (0, 0, 200), -1)
cv2.addWeighted(box_overlay, 0.25, img, 0.75, 0, img)
# Border
cv2.rectangle(img, (bx, by), (bx + bw, by + bh), (0, 0, 200), 2)
# Label
cv2.putText(img, label, (bx + 10, by + bh - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
# ---------------------------------------------------------------------------
# Row Detection Endpoints
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/rows")
async def detect_rows(session_id: str):
"""Run row detection on the cropped (or dewarped) image using horizontal gap analysis."""
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
if dewarped_bgr is None:
raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before row detection")
t0 = time.time()
# Try to reuse cached word_dicts and inv from column detection
word_dicts = cached.get("_word_dicts")
inv = cached.get("_inv")
content_bounds = cached.get("_content_bounds")
if word_dicts is None or inv is None or content_bounds is None:
# Not cached — run column geometry to get intermediates
ocr_img = create_ocr_image(dewarped_bgr)
geo_result = detect_column_geometry(ocr_img, dewarped_bgr)
if geo_result is None:
raise HTTPException(status_code=400, detail="Column geometry detection failed — cannot detect rows")
_geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
cached["_word_dicts"] = word_dicts
cached["_inv"] = inv
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
else:
left_x, right_x, top_y, bottom_y = content_bounds
# Read zones from column_result to exclude box regions
session = await get_session_db(session_id)
column_result = (session or {}).get("column_result") or {}
is_sub_session = bool((session or {}).get("parent_session_id"))
# Sub-sessions (box crops): use word-grouping instead of gap-based
# row detection. Box images are small with complex internal layouts
# (headings, sub-columns) where the horizontal projection approach
# merges rows. Word-grouping directly clusters words by Y proximity,
# which is more robust for these cases.
if is_sub_session and word_dicts:
from cv_layout import _build_rows_from_word_grouping
rows = _build_rows_from_word_grouping(
word_dicts, left_x, right_x, top_y, bottom_y,
right_x - left_x, bottom_y - top_y,
)
logger.info(f"OCR Pipeline: sub-session {session_id}: word-grouping found {len(rows)} rows")
else:
zones = column_result.get("zones") or [] # zones can be None for sub-sessions
# Collect box y-ranges for filtering.
# Use border_thickness to shrink the exclusion zone: the border pixels
# belong visually to the box frame, but text rows above/below the box
# may overlap with the border area and must not be clipped.
box_ranges = [] # [(y_start, y_end)]
box_ranges_inner = [] # [(y_start + border, y_end - border)] for row filtering
for zone in zones:
if zone.get("zone_type") == "box" and zone.get("box"):
box = zone["box"]
bt = max(box.get("border_thickness", 0), 5) # minimum 5px margin
box_ranges.append((box["y"], box["y"] + box["height"]))
# Inner range: shrink by border thickness so boundary rows aren't excluded
box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt))
if box_ranges and inv is not None:
# Combined-image approach: strip box regions from inv image,
# run row detection on the combined image, then remap y-coords back.
content_strips = [] # [(y_start, y_end)] in absolute coords
# Build content strips by subtracting box inner ranges from [top_y, bottom_y].
# Using inner ranges means the border area is included in the content
# strips, so the last row above a box isn't clipped by the border.
sorted_boxes = sorted(box_ranges_inner, key=lambda r: r[0])
strip_start = top_y
for by_start, by_end in sorted_boxes:
if by_start > strip_start:
content_strips.append((strip_start, by_start))
strip_start = max(strip_start, by_end)
if strip_start < bottom_y:
content_strips.append((strip_start, bottom_y))
# Filter to strips with meaningful height
content_strips = [(ys, ye) for ys, ye in content_strips if ye - ys >= 20]
if content_strips:
# Stack content strips vertically
inv_strips = [inv[ys:ye, :] for ys, ye in content_strips]
combined_inv = np.vstack(inv_strips)
# Filter word_dicts to only include words from content strips
combined_words = []
cum_y = 0
strip_offsets = [] # (combined_y_start, strip_height, abs_y_start)
for ys, ye in content_strips:
h = ye - ys
strip_offsets.append((cum_y, h, ys))
for w in word_dicts:
w_abs_y = w['top'] + top_y # word y is relative to content top
w_center = w_abs_y + w['height'] / 2
if ys <= w_center < ye:
# Remap to combined coordinates
w_copy = dict(w)
w_copy['top'] = cum_y + (w_abs_y - ys)
combined_words.append(w_copy)
cum_y += h
# Run row detection on combined image
combined_h = combined_inv.shape[0]
rows = detect_row_geometry(
combined_inv, combined_words, left_x, right_x, 0, combined_h,
)
# Remap y-coordinates back to absolute page coords
def _combined_y_to_abs(cy: int) -> int:
for c_start, s_h, abs_start in strip_offsets:
if cy < c_start + s_h:
return abs_start + (cy - c_start)
last_c, last_h, last_abs = strip_offsets[-1]
return last_abs + last_h
for r in rows:
abs_y = _combined_y_to_abs(r.y)
abs_y_end = _combined_y_to_abs(r.y + r.height)
r.y = abs_y
r.height = abs_y_end - abs_y
else:
rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
else:
# No boxes — standard row detection
rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
duration = time.time() - t0
# Assign zone_index based on which content zone each row falls in
# Build content zone list with indices
zones = column_result.get("zones") or []
content_zones = [(i, z) for i, z in enumerate(zones) if z.get("zone_type") == "content"] if zones else []
# Build serializable result (exclude words to keep payload small)
rows_data = []
for r in rows:
# Determine zone_index
zone_idx = 0
row_center_y = r.y + r.height / 2
for zi, zone in content_zones:
zy = zone["y"]
zh = zone["height"]
if zy <= row_center_y < zy + zh:
zone_idx = zi
break
rd = {
"index": r.index,
"x": r.x,
"y": r.y,
"width": r.width,
"height": r.height,
"word_count": r.word_count,
"row_type": r.row_type,
"gap_before": r.gap_before,
"zone_index": zone_idx,
}
rows_data.append(rd)
type_counts = {}
for r in rows:
type_counts[r.row_type] = type_counts.get(r.row_type, 0) + 1
row_result = {
"rows": rows_data,
"summary": type_counts,
"total_rows": len(rows),
"duration_seconds": round(duration, 2),
}
# Persist to DB — also invalidate word_result since rows changed
await update_session_db(
session_id,
row_result=row_result,
word_result=None,
current_step=7,
)
cached["row_result"] = row_result
cached.pop("word_result", None)
logger.info(f"OCR Pipeline: rows session {session_id}: "
f"{len(rows)} rows detected ({duration:.2f}s): {type_counts}")
content_rows = sum(1 for r in rows if r.row_type == "content")
avg_height = round(sum(r.height for r in rows) / len(rows)) if rows else 0
await _append_pipeline_log(session_id, "rows", {
"total_rows": len(rows),
"content_rows": content_rows,
"artifact_rows_removed": type_counts.get("header", 0) + type_counts.get("footer", 0),
"avg_row_height_px": avg_height,
}, duration_ms=int(duration * 1000))
return {
"session_id": session_id,
**row_result,
}
@router.post("/sessions/{session_id}/rows/manual")
async def set_manual_rows(session_id: str, req: ManualRowsRequest):
"""Override detected rows with manual definitions."""
row_result = {
"rows": req.rows,
"total_rows": len(req.rows),
"duration_seconds": 0,
"method": "manual",
}
await update_session_db(session_id, row_result=row_result, word_result=None)
if session_id in _cache:
_cache[session_id]["row_result"] = row_result
_cache[session_id].pop("word_result", None)
logger.info(f"OCR Pipeline: manual rows session {session_id}: "
f"{len(req.rows)} rows set")
return {"session_id": session_id, **row_result}
@router.post("/sessions/{session_id}/ground-truth/rows")
async def save_row_ground_truth(session_id: str, req: RowGroundTruthRequest):
"""Save ground truth feedback for the row detection step."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
gt = {
"is_correct": req.is_correct,
"corrected_rows": req.corrected_rows,
"notes": req.notes,
"saved_at": datetime.utcnow().isoformat(),
"row_result": session.get("row_result"),
}
ground_truth["rows"] = gt
await update_session_db(session_id, ground_truth=ground_truth)
if session_id in _cache:
_cache[session_id]["ground_truth"] = ground_truth
return {"session_id": session_id, "ground_truth": gt}
@router.get("/sessions/{session_id}/ground-truth/rows")
async def get_row_ground_truth(session_id: str):
"""Retrieve saved ground truth for row detection."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
rows_gt = ground_truth.get("rows")
if not rows_gt:
raise HTTPException(status_code=404, detail="No row ground truth saved")
return {
"session_id": session_id,
"rows_gt": rows_gt,
"rows_auto": session.get("row_result"),
}

View File

@@ -0,0 +1,102 @@
"""
Scan Quality Assessment — Measures image quality before OCR.
Computes blur score, contrast score, and an overall quality rating.
Used to gate enhancement steps and warn users about degraded scans.
All operations use OpenCV (Apache-2.0), no additional dependencies.
"""
import logging
from dataclasses import dataclass, asdict
from typing import Dict, Any
import cv2
import numpy as np
logger = logging.getLogger(__name__)
# Thresholds (empirically tuned on textbook scans)
BLUR_THRESHOLD = 100.0 # Laplacian variance below this = blurry
CONTRAST_THRESHOLD = 40.0 # Grayscale stddev below this = low contrast
CONFIDENCE_GOOD = 40 # OCR min confidence for good scans
CONFIDENCE_DEGRADED = 30 # OCR min confidence for degraded scans
@dataclass
class ScanQualityReport:
"""Result of scan quality assessment."""
blur_score: float # Laplacian variance (higher = sharper)
contrast_score: float # Grayscale std deviation (higher = more contrast)
brightness: float # Mean grayscale value (0-255)
is_blurry: bool
is_low_contrast: bool
is_degraded: bool # True if any quality issue detected
quality_pct: int # 0-100 overall quality estimate
recommended_min_conf: int # Recommended OCR confidence threshold
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
def score_scan_quality(img_bgr: np.ndarray) -> ScanQualityReport:
"""
Assess the quality of a scanned image.
Uses:
- Laplacian variance for blur detection
- Grayscale standard deviation for contrast
- Mean brightness for exposure assessment
Args:
img_bgr: BGR image (numpy array from OpenCV)
Returns:
ScanQualityReport with scores and recommendations
"""
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
# Blur detection: Laplacian variance
# Higher = sharper edges = better quality
laplacian = cv2.Laplacian(gray, cv2.CV_64F)
blur_score = float(laplacian.var())
# Contrast: standard deviation of grayscale
contrast_score = float(np.std(gray))
# Brightness: mean grayscale
brightness = float(np.mean(gray))
# Quality flags
is_blurry = blur_score < BLUR_THRESHOLD
is_low_contrast = contrast_score < CONTRAST_THRESHOLD
is_degraded = is_blurry or is_low_contrast
# Overall quality percentage (simple weighted combination)
blur_pct = min(100, blur_score / BLUR_THRESHOLD * 50)
contrast_pct = min(100, contrast_score / CONTRAST_THRESHOLD * 50)
quality_pct = int(min(100, blur_pct + contrast_pct))
# Recommended confidence threshold
recommended_min_conf = CONFIDENCE_DEGRADED if is_degraded else CONFIDENCE_GOOD
report = ScanQualityReport(
blur_score=round(blur_score, 1),
contrast_score=round(contrast_score, 1),
brightness=round(brightness, 1),
is_blurry=is_blurry,
is_low_contrast=is_low_contrast,
is_degraded=is_degraded,
quality_pct=quality_pct,
recommended_min_conf=recommended_min_conf,
)
logger.info(
f"Scan quality: blur={report.blur_score} "
f"contrast={report.contrast_score} "
f"quality={report.quality_pct}% "
f"degraded={report.is_degraded} "
f"min_conf={report.recommended_min_conf}"
)
return report

View File

@@ -0,0 +1,388 @@
"""
OCR Pipeline Session Store - PostgreSQL persistence for OCR pipeline sessions.
Replaces in-memory storage with database persistence.
See migrations/002_ocr_pipeline_sessions.sql for schema.
"""
import os
import uuid
import logging
import json
from typing import Optional, List, Dict, Any
import asyncpg
logger = logging.getLogger(__name__)
# Database configuration (same as vocab_session_store)
DATABASE_URL = os.getenv(
"DATABASE_URL",
"postgresql://breakpilot:breakpilot@postgres:5432/breakpilot_db"
)
# Connection pool (initialized lazily)
_pool: Optional[asyncpg.Pool] = None
async def get_pool() -> asyncpg.Pool:
"""Get or create the database connection pool."""
global _pool
if _pool is None:
_pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
return _pool
async def init_ocr_pipeline_tables():
"""Initialize OCR pipeline tables if they don't exist."""
pool = await get_pool()
async with pool.acquire() as conn:
tables_exist = await conn.fetchval("""
SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_name = 'ocr_pipeline_sessions'
)
""")
if not tables_exist:
logger.info("Creating OCR pipeline tables...")
migration_path = os.path.join(
os.path.dirname(__file__),
"migrations/002_ocr_pipeline_sessions.sql"
)
if os.path.exists(migration_path):
with open(migration_path, "r") as f:
sql = f.read()
await conn.execute(sql)
logger.info("OCR pipeline tables created successfully")
else:
logger.warning(f"Migration file not found: {migration_path}")
else:
logger.debug("OCR pipeline tables already exist")
# Ensure new columns exist (idempotent ALTER TABLE)
await conn.execute("""
ALTER TABLE ocr_pipeline_sessions
ADD COLUMN IF NOT EXISTS clean_png BYTEA,
ADD COLUMN IF NOT EXISTS handwriting_removal_meta JSONB,
ADD COLUMN IF NOT EXISTS doc_type VARCHAR(50),
ADD COLUMN IF NOT EXISTS doc_type_result JSONB,
ADD COLUMN IF NOT EXISTS document_category VARCHAR(50),
ADD COLUMN IF NOT EXISTS pipeline_log JSONB,
ADD COLUMN IF NOT EXISTS oriented_png BYTEA,
ADD COLUMN IF NOT EXISTS cropped_png BYTEA,
ADD COLUMN IF NOT EXISTS orientation_result JSONB,
ADD COLUMN IF NOT EXISTS crop_result JSONB,
ADD COLUMN IF NOT EXISTS parent_session_id UUID REFERENCES ocr_pipeline_sessions(id) ON DELETE CASCADE,
ADD COLUMN IF NOT EXISTS box_index INT,
ADD COLUMN IF NOT EXISTS grid_editor_result JSONB,
ADD COLUMN IF NOT EXISTS structure_result JSONB,
ADD COLUMN IF NOT EXISTS document_group_id UUID,
ADD COLUMN IF NOT EXISTS page_number INT
""")
# Index for document group lookups
await conn.execute("""
CREATE INDEX IF NOT EXISTS idx_ocr_sessions_document_group
ON ocr_pipeline_sessions (document_group_id)
WHERE document_group_id IS NOT NULL
""")
# =============================================================================
# SESSION CRUD
# =============================================================================
async def create_session_db(
session_id: str,
name: str,
filename: str,
original_png: bytes,
parent_session_id: Optional[str] = None,
box_index: Optional[int] = None,
document_group_id: Optional[str] = None,
page_number: Optional[int] = None,
) -> Dict[str, Any]:
"""Create a new OCR pipeline session.
Args:
parent_session_id: If set, this is a sub-session for a box region.
box_index: 0-based index of the box this sub-session represents.
document_group_id: Groups multi-page uploads into one document.
page_number: 1-based page index within the document group.
"""
pool = await get_pool()
parent_uuid = uuid.UUID(parent_session_id) if parent_session_id else None
group_uuid = uuid.UUID(document_group_id) if document_group_id else None
async with pool.acquire() as conn:
row = await conn.fetchrow("""
INSERT INTO ocr_pipeline_sessions (
id, name, filename, original_png, status, current_step,
parent_session_id, box_index, document_group_id, page_number
) VALUES ($1, $2, $3, $4, 'active', 1, $5, $6, $7, $8)
RETURNING id, name, filename, status, current_step,
orientation_result, crop_result,
deskew_result, dewarp_result, column_result, row_result,
word_result, ground_truth, auto_shear_degrees,
doc_type, doc_type_result,
document_category, pipeline_log,
grid_editor_result, structure_result,
parent_session_id, box_index,
document_group_id, page_number,
created_at, updated_at
""", uuid.UUID(session_id), name, filename, original_png,
parent_uuid, box_index, group_uuid, page_number)
return _row_to_dict(row)
async def get_session_db(session_id: str) -> Optional[Dict[str, Any]]:
"""Get session metadata (without images)."""
pool = await get_pool()
async with pool.acquire() as conn:
row = await conn.fetchrow("""
SELECT id, name, filename, status, current_step,
orientation_result, crop_result,
deskew_result, dewarp_result, column_result, row_result,
word_result, ground_truth, auto_shear_degrees,
doc_type, doc_type_result,
document_category, pipeline_log,
grid_editor_result, structure_result,
parent_session_id, box_index,
document_group_id, page_number,
created_at, updated_at
FROM ocr_pipeline_sessions WHERE id = $1
""", uuid.UUID(session_id))
if row:
return _row_to_dict(row)
return None
async def get_session_image(session_id: str, image_type: str) -> Optional[bytes]:
"""Load a single image (BYTEA) from the session."""
column_map = {
"original": "original_png",
"oriented": "oriented_png",
"cropped": "cropped_png",
"deskewed": "deskewed_png",
"binarized": "binarized_png",
"dewarped": "dewarped_png",
"clean": "clean_png",
}
column = column_map.get(image_type)
if not column:
return None
pool = await get_pool()
async with pool.acquire() as conn:
return await conn.fetchval(
f"SELECT {column} FROM ocr_pipeline_sessions WHERE id = $1",
uuid.UUID(session_id)
)
async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any]]:
"""Update session fields dynamically."""
pool = await get_pool()
fields = []
values = []
param_idx = 1
allowed_fields = {
'name', 'filename', 'status', 'current_step',
'original_png', 'oriented_png', 'cropped_png',
'deskewed_png', 'binarized_png', 'dewarped_png',
'clean_png', 'handwriting_removal_meta',
'orientation_result', 'crop_result',
'deskew_result', 'dewarp_result', 'column_result', 'row_result',
'word_result', 'ground_truth', 'auto_shear_degrees',
'doc_type', 'doc_type_result',
'document_category', 'pipeline_log',
'grid_editor_result', 'structure_result',
'parent_session_id', 'box_index',
'document_group_id', 'page_number',
}
jsonb_fields = {'orientation_result', 'crop_result', 'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'handwriting_removal_meta', 'doc_type_result', 'pipeline_log', 'grid_editor_result', 'structure_result'}
for key, value in kwargs.items():
if key in allowed_fields:
fields.append(f"{key} = ${param_idx}")
if key in jsonb_fields and value is not None and not isinstance(value, str):
value = json.dumps(value)
values.append(value)
param_idx += 1
if not fields:
return await get_session_db(session_id)
# Always update updated_at
fields.append(f"updated_at = NOW()")
values.append(uuid.UUID(session_id))
async with pool.acquire() as conn:
row = await conn.fetchrow(f"""
UPDATE ocr_pipeline_sessions
SET {', '.join(fields)}
WHERE id = ${param_idx}
RETURNING id, name, filename, status, current_step,
orientation_result, crop_result,
deskew_result, dewarp_result, column_result, row_result,
word_result, ground_truth, auto_shear_degrees,
doc_type, doc_type_result,
document_category, pipeline_log,
grid_editor_result, structure_result,
parent_session_id, box_index,
document_group_id, page_number,
created_at, updated_at
""", *values)
if row:
return _row_to_dict(row)
return None
async def list_sessions_db(
limit: int = 50,
include_sub_sessions: bool = False,
) -> List[Dict[str, Any]]:
"""List sessions (metadata only, no images).
By default, sub-sessions (those with parent_session_id) are excluded.
Pass include_sub_sessions=True to include them.
"""
pool = await get_pool()
async with pool.acquire() as conn:
where = "" if include_sub_sessions else "WHERE parent_session_id IS NULL AND (status IS NULL OR status != 'split')"
rows = await conn.fetch(f"""
SELECT id, name, filename, status, current_step,
document_category, doc_type,
parent_session_id, box_index,
document_group_id, page_number,
created_at, updated_at,
ground_truth
FROM ocr_pipeline_sessions
{where}
ORDER BY created_at DESC
LIMIT $1
""", limit)
results = []
for row in rows:
d = _row_to_dict(row)
# Derive is_ground_truth flag from JSONB, then drop the heavy field
gt = d.pop("ground_truth", None) or {}
d["is_ground_truth"] = bool(gt.get("build_grid_reference"))
results.append(d)
return results
async def get_sub_sessions(parent_session_id: str) -> List[Dict[str, Any]]:
"""Get all sub-sessions for a parent session, ordered by box_index."""
pool = await get_pool()
async with pool.acquire() as conn:
rows = await conn.fetch("""
SELECT id, name, filename, status, current_step,
document_category, doc_type,
parent_session_id, box_index,
document_group_id, page_number,
created_at, updated_at
FROM ocr_pipeline_sessions
WHERE parent_session_id = $1
ORDER BY box_index ASC
""", uuid.UUID(parent_session_id))
return [_row_to_dict(row) for row in rows]
async def get_document_group_sessions(document_group_id: str) -> List[Dict[str, Any]]:
"""Get all sessions in a document group, ordered by page_number."""
pool = await get_pool()
async with pool.acquire() as conn:
rows = await conn.fetch("""
SELECT id, name, filename, status, current_step,
document_category, doc_type,
parent_session_id, box_index,
document_group_id, page_number,
created_at, updated_at
FROM ocr_pipeline_sessions
WHERE document_group_id = $1
ORDER BY page_number ASC
""", uuid.UUID(document_group_id))
return [_row_to_dict(row) for row in rows]
async def list_ground_truth_sessions_db() -> List[Dict[str, Any]]:
"""List sessions that have a build_grid_reference in ground_truth."""
pool = await get_pool()
async with pool.acquire() as conn:
rows = await conn.fetch("""
SELECT id, name, filename, status, current_step,
document_category, doc_type,
ground_truth,
parent_session_id, box_index,
created_at, updated_at
FROM ocr_pipeline_sessions
WHERE ground_truth IS NOT NULL
AND ground_truth::text LIKE '%build_grid_reference%'
AND parent_session_id IS NULL
ORDER BY created_at DESC
""")
return [_row_to_dict(row) for row in rows]
async def delete_session_db(session_id: str) -> bool:
"""Delete a session."""
pool = await get_pool()
async with pool.acquire() as conn:
result = await conn.execute("""
DELETE FROM ocr_pipeline_sessions WHERE id = $1
""", uuid.UUID(session_id))
return result == "DELETE 1"
async def delete_all_sessions_db() -> int:
"""Delete all sessions. Returns number of deleted rows."""
pool = await get_pool()
async with pool.acquire() as conn:
result = await conn.execute("DELETE FROM ocr_pipeline_sessions")
# result is e.g. "DELETE 5"
try:
return int(result.split()[-1])
except (ValueError, IndexError):
return 0
# =============================================================================
# HELPER
# =============================================================================
def _row_to_dict(row: asyncpg.Record) -> Dict[str, Any]:
"""Convert asyncpg Record to JSON-serializable dict."""
if row is None:
return {}
result = dict(row)
# UUID → string
for key in ['id', 'session_id', 'parent_session_id', 'document_group_id']:
if key in result and result[key] is not None:
result[key] = str(result[key])
# datetime → ISO string
for key in ['created_at', 'updated_at']:
if key in result and result[key] is not None:
result[key] = result[key].isoformat()
# JSONB → parsed (asyncpg returns str for JSONB)
for key in ['orientation_result', 'crop_result', 'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'doc_type_result', 'pipeline_log', 'grid_editor_result', 'structure_result']:
if key in result and result[key] is not None:
if isinstance(result[key], str):
result[key] = json.loads(result[key])
return result

View File

@@ -0,0 +1,20 @@
"""
OCR Pipeline Sessions API — barrel re-export.
All implementation split into:
ocr_pipeline_sessions_crud — session CRUD, box sessions
ocr_pipeline_sessions_images — image serving, thumbnails, doc-type detection
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
from fastapi import APIRouter
from .sessions_crud import router as _crud_router # noqa: F401
from .sessions_images import router as _images_router # noqa: F401
# Composite router (used by ocr_pipeline_api.py)
router = APIRouter()
router.include_router(_crud_router)
router.include_router(_images_router)

View File

@@ -0,0 +1,449 @@
"""
OCR Pipeline Sessions CRUD — session create, read, update, delete, box sessions.
Extracted from ocr_pipeline_sessions.py for modularity.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import uuid
from typing import Any, Dict, Optional
import cv2
import numpy as np
from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
from cv_vocab_pipeline import render_image_high_res, render_pdf_high_res
from .common import (
VALID_DOCUMENT_CATEGORIES,
UpdateSessionRequest,
_cache,
)
from .session_store import (
create_session_db,
delete_all_sessions_db,
delete_session_db,
get_session_db,
get_session_image,
get_sub_sessions,
list_sessions_db,
update_session_db,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Session Management Endpoints
# ---------------------------------------------------------------------------
@router.get("/sessions")
async def list_sessions(include_sub_sessions: bool = False):
"""List OCR pipeline sessions.
By default, sub-sessions (box regions) are hidden.
Pass ?include_sub_sessions=true to show them.
"""
sessions = await list_sessions_db(include_sub_sessions=include_sub_sessions)
return {"sessions": sessions}
@router.post("/sessions")
async def create_session(
file: UploadFile = File(...),
name: Optional[str] = Form(None),
):
"""Upload a PDF or image file and create a pipeline session.
For multi-page PDFs (> 1 page), each page becomes its own session
grouped under a ``document_group_id``. The response includes a
``pages`` array with one entry per page/session.
"""
file_data = await file.read()
filename = file.filename or "upload"
content_type = file.content_type or ""
is_pdf = content_type == "application/pdf" or filename.lower().endswith(".pdf")
session_name = name or filename
# --- Multi-page PDF handling ---
if is_pdf:
try:
import fitz # PyMuPDF
pdf_doc = fitz.open(stream=file_data, filetype="pdf")
page_count = pdf_doc.page_count
pdf_doc.close()
except Exception as e:
raise HTTPException(status_code=400, detail=f"Could not read PDF: {e}")
if page_count > 1:
return await _create_multi_page_sessions(
file_data, filename, session_name, page_count,
)
# --- Single page (image or 1-page PDF) ---
session_id = str(uuid.uuid4())
try:
if is_pdf:
img_bgr = render_pdf_high_res(file_data, page_number=0, zoom=3.0)
else:
img_bgr = render_image_high_res(file_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"Could not process file: {e}")
# Encode original as PNG bytes
success, png_buf = cv2.imencode(".png", img_bgr)
if not success:
raise HTTPException(status_code=500, detail="Failed to encode image")
original_png = png_buf.tobytes()
# Persist to DB
await create_session_db(
session_id=session_id,
name=session_name,
filename=filename,
original_png=original_png,
)
# Cache BGR array for immediate processing
_cache[session_id] = {
"id": session_id,
"filename": filename,
"name": session_name,
"original_bgr": img_bgr,
"oriented_bgr": None,
"cropped_bgr": None,
"deskewed_bgr": None,
"dewarped_bgr": None,
"orientation_result": None,
"crop_result": None,
"deskew_result": None,
"dewarp_result": None,
"ground_truth": {},
"current_step": 1,
}
logger.info(f"OCR Pipeline: created session {session_id} from {filename} "
f"({img_bgr.shape[1]}x{img_bgr.shape[0]})")
return {
"session_id": session_id,
"filename": filename,
"name": session_name,
"image_width": img_bgr.shape[1],
"image_height": img_bgr.shape[0],
"original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
}
async def _create_multi_page_sessions(
pdf_data: bytes,
filename: str,
base_name: str,
page_count: int,
) -> dict:
"""Create one session per PDF page, grouped by document_group_id."""
document_group_id = str(uuid.uuid4())
pages = []
for page_idx in range(page_count):
session_id = str(uuid.uuid4())
page_name = f"{base_name} — Seite {page_idx + 1}"
try:
img_bgr = render_pdf_high_res(pdf_data, page_number=page_idx, zoom=3.0)
except Exception as e:
logger.warning(f"Failed to render PDF page {page_idx + 1}: {e}")
continue
ok, png_buf = cv2.imencode(".png", img_bgr)
if not ok:
continue
page_png = png_buf.tobytes()
await create_session_db(
session_id=session_id,
name=page_name,
filename=filename,
original_png=page_png,
document_group_id=document_group_id,
page_number=page_idx + 1,
)
_cache[session_id] = {
"id": session_id,
"filename": filename,
"name": page_name,
"original_bgr": img_bgr,
"oriented_bgr": None,
"cropped_bgr": None,
"deskewed_bgr": None,
"dewarped_bgr": None,
"orientation_result": None,
"crop_result": None,
"deskew_result": None,
"dewarp_result": None,
"ground_truth": {},
"current_step": 1,
}
h, w = img_bgr.shape[:2]
pages.append({
"session_id": session_id,
"name": page_name,
"page_number": page_idx + 1,
"image_width": w,
"image_height": h,
"original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
})
logger.info(
f"OCR Pipeline: created page session {session_id} "
f"(page {page_idx + 1}/{page_count}) from {filename} ({w}x{h})"
)
# Include session_id pointing to first page for backwards compatibility
# (frontends that expect a single session_id will navigate to page 1)
first_session_id = pages[0]["session_id"] if pages else None
return {
"session_id": first_session_id,
"document_group_id": document_group_id,
"filename": filename,
"name": base_name,
"page_count": page_count,
"pages": pages,
}
@router.get("/sessions/{session_id}")
async def get_session_info(session_id: str):
"""Get session info including deskew/dewarp/column results for step navigation."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
# Get image dimensions from original PNG
original_png = await get_session_image(session_id, "original")
if original_png:
arr = np.frombuffer(original_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
img_w, img_h = img.shape[1], img.shape[0] if img is not None else (0, 0)
else:
img_w, img_h = 0, 0
result = {
"session_id": session["id"],
"filename": session.get("filename", ""),
"name": session.get("name", ""),
"image_width": img_w,
"image_height": img_h,
"original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
"current_step": session.get("current_step", 1),
"document_category": session.get("document_category"),
"doc_type": session.get("doc_type"),
}
if session.get("orientation_result"):
result["orientation_result"] = session["orientation_result"]
if session.get("crop_result"):
result["crop_result"] = session["crop_result"]
if session.get("deskew_result"):
result["deskew_result"] = session["deskew_result"]
if session.get("dewarp_result"):
result["dewarp_result"] = session["dewarp_result"]
if session.get("column_result"):
result["column_result"] = session["column_result"]
if session.get("row_result"):
result["row_result"] = session["row_result"]
if session.get("word_result"):
result["word_result"] = session["word_result"]
if session.get("doc_type_result"):
result["doc_type_result"] = session["doc_type_result"]
if session.get("structure_result"):
result["structure_result"] = session["structure_result"]
if session.get("grid_editor_result"):
# Include summary only to keep response small
gr = session["grid_editor_result"]
result["grid_editor_result"] = {
"summary": gr.get("summary", {}),
"zones_count": len(gr.get("zones", [])),
"edited": gr.get("edited", False),
}
if session.get("ground_truth"):
result["ground_truth"] = session["ground_truth"]
# Box sub-session info (zone_type='box' from column detection — NOT page-split)
if session.get("parent_session_id"):
result["parent_session_id"] = session["parent_session_id"]
result["box_index"] = session.get("box_index")
else:
# Check for box sub-sessions (column detection creates these)
subs = await get_sub_sessions(session_id)
if subs:
result["sub_sessions"] = [
{"id": s["id"], "name": s.get("name"), "box_index": s.get("box_index")}
for s in subs
]
return result
@router.put("/sessions/{session_id}")
async def update_session(session_id: str, req: UpdateSessionRequest):
"""Update session name and/or document category."""
kwargs: Dict[str, Any] = {}
if req.name is not None:
kwargs["name"] = req.name
if req.document_category is not None:
if req.document_category not in VALID_DOCUMENT_CATEGORIES:
raise HTTPException(
status_code=400,
detail=f"Invalid category '{req.document_category}'. Valid: {sorted(VALID_DOCUMENT_CATEGORIES)}",
)
kwargs["document_category"] = req.document_category
if not kwargs:
raise HTTPException(status_code=400, detail="Nothing to update")
updated = await update_session_db(session_id, **kwargs)
if not updated:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
return {"session_id": session_id, **kwargs}
@router.delete("/sessions/{session_id}")
async def delete_session(session_id: str):
"""Delete a session."""
_cache.pop(session_id, None)
deleted = await delete_session_db(session_id)
if not deleted:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
return {"session_id": session_id, "deleted": True}
@router.delete("/sessions")
async def delete_all_sessions():
"""Delete ALL sessions (cleanup)."""
_cache.clear()
count = await delete_all_sessions_db()
return {"deleted_count": count}
@router.post("/sessions/{session_id}/create-box-sessions")
async def create_box_sessions(session_id: str):
"""Create sub-sessions for each detected box region.
Crops box regions from the cropped/dewarped image and creates
independent sub-sessions that can be processed through the pipeline.
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
column_result = session.get("column_result")
if not column_result:
raise HTTPException(status_code=400, detail="Column detection must be completed first")
zones = column_result.get("zones") or []
box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
if not box_zones:
return {"session_id": session_id, "sub_sessions": [], "message": "No boxes detected"}
# Check for existing sub-sessions
existing = await get_sub_sessions(session_id)
if existing:
return {
"session_id": session_id,
"sub_sessions": [{"id": s["id"], "box_index": s.get("box_index")} for s in existing],
"message": f"{len(existing)} sub-session(s) already exist",
}
# Load base image
base_png = await get_session_image(session_id, "cropped")
if not base_png:
base_png = await get_session_image(session_id, "dewarped")
if not base_png:
raise HTTPException(status_code=400, detail="No base image available")
arr = np.frombuffer(base_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=500, detail="Failed to decode image")
parent_name = session.get("name", "Session")
created = []
for i, zone in enumerate(box_zones):
box = zone["box"]
bx, by = box["x"], box["y"]
bw, bh = box["width"], box["height"]
# Crop box region with small padding
pad = 5
y1 = max(0, by - pad)
y2 = min(img.shape[0], by + bh + pad)
x1 = max(0, bx - pad)
x2 = min(img.shape[1], bx + bw + pad)
crop = img[y1:y2, x1:x2]
# Encode as PNG
success, png_buf = cv2.imencode(".png", crop)
if not success:
logger.warning(f"Failed to encode box {i} crop for session {session_id}")
continue
sub_id = str(uuid.uuid4())
sub_name = f"{parent_name} — Box {i + 1}"
await create_session_db(
session_id=sub_id,
name=sub_name,
filename=session.get("filename", "box-crop.png"),
original_png=png_buf.tobytes(),
parent_session_id=session_id,
box_index=i,
)
# Cache the BGR for immediate processing
# Promote original to cropped so column/row/word detection finds it
box_bgr = crop.copy()
_cache[sub_id] = {
"id": sub_id,
"filename": session.get("filename", "box-crop.png"),
"name": sub_name,
"parent_session_id": session_id,
"original_bgr": box_bgr,
"oriented_bgr": None,
"cropped_bgr": box_bgr,
"deskewed_bgr": None,
"dewarped_bgr": None,
"orientation_result": None,
"crop_result": None,
"deskew_result": None,
"dewarp_result": None,
"ground_truth": {},
"current_step": 1,
}
created.append({
"id": sub_id,
"name": sub_name,
"box_index": i,
"box": box,
"image_width": crop.shape[1],
"image_height": crop.shape[0],
})
logger.info(f"Created box sub-session {sub_id} for session {session_id} "
f"(box {i}, {crop.shape[1]}x{crop.shape[0]})")
return {
"session_id": session_id,
"sub_sessions": created,
"total": len(created),
}

View File

@@ -0,0 +1,176 @@
"""
OCR Pipeline Sessions Images — image serving, thumbnails, pipeline log,
categories, and document type detection.
Extracted from ocr_pipeline_sessions.py for modularity.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import time
from typing import Any, Dict
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import Response
from cv_vocab_pipeline import create_ocr_image, detect_document_type
from .common import (
VALID_DOCUMENT_CATEGORIES,
_append_pipeline_log,
_cache,
_get_base_image_png,
_get_cached,
_load_session_to_cache,
)
from .overlays import render_overlay
from .session_store import (
get_session_db,
get_session_image,
update_session_db,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Thumbnail & Log Endpoints
# ---------------------------------------------------------------------------
@router.get("/sessions/{session_id}/thumbnail")
async def get_session_thumbnail(session_id: str, size: int = Query(default=80, ge=16, le=400)):
"""Return a small thumbnail of the original image."""
original_png = await get_session_image(session_id, "original")
if not original_png:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found or no image")
arr = np.frombuffer(original_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=500, detail="Failed to decode image")
h, w = img.shape[:2]
scale = size / max(h, w)
new_w, new_h = int(w * scale), int(h * scale)
thumb = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
_, png_bytes = cv2.imencode(".png", thumb)
return Response(content=png_bytes.tobytes(), media_type="image/png",
headers={"Cache-Control": "public, max-age=3600"})
@router.get("/sessions/{session_id}/pipeline-log")
async def get_pipeline_log(session_id: str):
"""Get the pipeline execution log for a session."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
return {"session_id": session_id, "pipeline_log": session.get("pipeline_log") or {"steps": []}}
@router.get("/categories")
async def list_categories():
"""List valid document categories."""
return {"categories": sorted(VALID_DOCUMENT_CATEGORIES)}
# ---------------------------------------------------------------------------
# Image Endpoints
# ---------------------------------------------------------------------------
@router.get("/sessions/{session_id}/image/{image_type}")
async def get_image(session_id: str, image_type: str):
"""Serve session images: original, deskewed, dewarped, binarized, structure-overlay, columns-overlay, or rows-overlay."""
valid_types = {"original", "oriented", "cropped", "deskewed", "dewarped", "binarized", "structure-overlay", "columns-overlay", "rows-overlay", "words-overlay", "clean"}
if image_type not in valid_types:
raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")
if image_type == "structure-overlay":
return await render_overlay("structure", session_id)
if image_type == "columns-overlay":
return await render_overlay("columns", session_id)
if image_type == "rows-overlay":
return await render_overlay("rows", session_id)
if image_type == "words-overlay":
return await render_overlay("words", session_id)
# Try cache first for fast serving
cached = _cache.get(session_id)
if cached:
png_key = f"{image_type}_png" if image_type != "original" else None
bgr_key = f"{image_type}_bgr" if image_type != "binarized" else None
# For binarized, check if we have it cached as PNG
if image_type == "binarized" and cached.get("binarized_png"):
return Response(content=cached["binarized_png"], media_type="image/png")
# Load from DB — for cropped/dewarped, fall back through the chain
if image_type in ("cropped", "dewarped"):
data = await _get_base_image_png(session_id)
else:
data = await get_session_image(session_id, image_type)
if not data:
raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet")
return Response(content=data, media_type="image/png")
# ---------------------------------------------------------------------------
# Document Type Detection (between Dewarp and Columns)
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/detect-type")
async def detect_type(session_id: str):
"""Detect document type (vocab_table, full_text, generic_table).
Should be called after crop (clean image available).
Falls back to dewarped if crop was skipped.
Stores result in session for frontend to decide pipeline flow.
"""
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
img_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
if img_bgr is None:
raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first")
t0 = time.time()
ocr_img = create_ocr_image(img_bgr)
result = detect_document_type(ocr_img, img_bgr)
duration = time.time() - t0
result_dict = {
"doc_type": result.doc_type,
"confidence": result.confidence,
"pipeline": result.pipeline,
"skip_steps": result.skip_steps,
"features": result.features,
"duration_seconds": round(duration, 2),
}
# Persist to DB
await update_session_db(
session_id,
doc_type=result.doc_type,
doc_type_result=result_dict,
)
cached["doc_type_result"] = result_dict
logger.info(f"OCR Pipeline: detect-type session {session_id}: "
f"{result.doc_type} (confidence={result.confidence}, {duration:.2f}s)")
await _append_pipeline_log(session_id, "detect_type", {
"doc_type": result.doc_type,
"pipeline": result.pipeline,
"confidence": result.confidence,
**{k: v for k, v in (result.features or {}).items() if isinstance(v, (int, float, str, bool))},
}, duration_ms=int(duration * 1000))
return {"session_id": session_id, **result_dict}

View File

@@ -0,0 +1,299 @@
"""
OCR Pipeline Structure Detection and Exclude Regions
Detect document structure (boxes, zones, color regions, graphics)
and manage user-drawn exclude regions.
Extracted from ocr_pipeline_geometry.py for file-size compliance.
"""
import logging
import time
from typing import Any, Dict, List
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from cv_box_detect import detect_boxes
from cv_color_detect import _COLOR_RANGES, _COLOR_HEX
from cv_graphic_detect import detect_graphic_elements
from .session_store import (
get_session_db,
update_session_db,
)
from .common import (
_cache,
_load_session_to_cache,
_get_cached,
_filter_border_ghost_words,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Structure Detection Endpoint
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/detect-structure")
async def detect_structure(session_id: str):
"""Detect document structure: boxes, zones, and color regions.
Runs box detection (line + shading) and color analysis on the cropped
image. Returns structured JSON with all detected elements for the
structure visualization step.
"""
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
img_bgr = (
cached.get("cropped_bgr")
if cached.get("cropped_bgr") is not None
else cached.get("dewarped_bgr")
)
if img_bgr is None:
raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first")
t0 = time.time()
h, w = img_bgr.shape[:2]
# --- Content bounds from word result (if available) or full image ---
word_result = cached.get("word_result")
words: List[Dict] = []
if word_result and word_result.get("cells"):
for cell in word_result["cells"]:
for wb in (cell.get("word_boxes") or []):
words.append(wb)
# Fallback: use raw OCR words if cell word_boxes are empty
if not words and word_result:
for key in ("raw_paddle_words_split", "raw_tesseract_words", "raw_paddle_words"):
raw = word_result.get(key, [])
if raw:
words = raw
logger.info("detect-structure: using %d words from %s (no cell word_boxes)", len(words), key)
break
# If no words yet, use image dimensions with small margin
if words:
content_x = max(0, min(int(wb["left"]) for wb in words))
content_y = max(0, min(int(wb["top"]) for wb in words))
content_r = min(w, max(int(wb["left"] + wb["width"]) for wb in words))
content_b = min(h, max(int(wb["top"] + wb["height"]) for wb in words))
content_w_px = content_r - content_x
content_h_px = content_b - content_y
else:
margin = int(min(w, h) * 0.03)
content_x, content_y = margin, margin
content_w_px = w - 2 * margin
content_h_px = h - 2 * margin
# --- Box detection ---
boxes = detect_boxes(
img_bgr,
content_x=content_x,
content_w=content_w_px,
content_y=content_y,
content_h=content_h_px,
)
# --- Zone splitting ---
from cv_box_detect import split_page_into_zones as _split_zones
zones = _split_zones(content_x, content_y, content_w_px, content_h_px, boxes)
# --- Color region sampling ---
# Sample background shading in each detected box
hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)
box_colors = []
for box in boxes:
# Sample the center region of each box
cy1 = box.y + box.height // 4
cy2 = box.y + 3 * box.height // 4
cx1 = box.x + box.width // 4
cx2 = box.x + 3 * box.width // 4
cy1 = max(0, min(cy1, h - 1))
cy2 = max(0, min(cy2, h - 1))
cx1 = max(0, min(cx1, w - 1))
cx2 = max(0, min(cx2, w - 1))
if cy2 > cy1 and cx2 > cx1:
roi_hsv = hsv[cy1:cy2, cx1:cx2]
med_h = float(np.median(roi_hsv[:, :, 0]))
med_s = float(np.median(roi_hsv[:, :, 1]))
med_v = float(np.median(roi_hsv[:, :, 2]))
if med_s > 15:
from cv_color_detect import _hue_to_color_name
bg_name = _hue_to_color_name(med_h)
bg_hex = _COLOR_HEX.get(bg_name, "#6b7280")
else:
bg_name = "gray" if med_v < 220 else "white"
bg_hex = "#6b7280" if bg_name == "gray" else "#ffffff"
else:
bg_name = "unknown"
bg_hex = "#6b7280"
box_colors.append({"color_name": bg_name, "color_hex": bg_hex})
# --- Color text detection overview ---
# Quick scan for colored text regions across the page
color_summary: Dict[str, int] = {}
for color_name, ranges in _COLOR_RANGES.items():
mask = np.zeros((h, w), dtype=np.uint8)
for lower, upper in ranges:
mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper))
pixel_count = int(np.sum(mask > 0))
if pixel_count > 50: # minimum threshold
color_summary[color_name] = pixel_count
# --- Graphic element detection ---
box_dicts = [
{"x": b.x, "y": b.y, "w": b.width, "h": b.height}
for b in boxes
]
graphics = detect_graphic_elements(
img_bgr, words,
detected_boxes=box_dicts,
)
# --- Filter border-ghost words from OCR result ---
ghost_count = 0
if boxes and word_result:
ghost_count = _filter_border_ghost_words(word_result, boxes)
if ghost_count:
logger.info("detect-structure: removed %d border-ghost words", ghost_count)
await update_session_db(session_id, word_result=word_result)
cached["word_result"] = word_result
duration = time.time() - t0
# Preserve user-drawn exclude regions from previous run
prev_sr = cached.get("structure_result") or {}
prev_exclude = prev_sr.get("exclude_regions", [])
result_dict = {
"image_width": w,
"image_height": h,
"content_bounds": {
"x": content_x, "y": content_y,
"w": content_w_px, "h": content_h_px,
},
"boxes": [
{
"x": b.x, "y": b.y, "w": b.width, "h": b.height,
"confidence": b.confidence,
"border_thickness": b.border_thickness,
"bg_color_name": box_colors[i]["color_name"],
"bg_color_hex": box_colors[i]["color_hex"],
}
for i, b in enumerate(boxes)
],
"zones": [
{
"index": z.index,
"zone_type": z.zone_type,
"y": z.y, "h": z.height,
"x": z.x, "w": z.width,
}
for z in zones
],
"graphics": [
{
"x": g.x, "y": g.y, "w": g.width, "h": g.height,
"area": g.area,
"shape": g.shape,
"color_name": g.color_name,
"color_hex": g.color_hex,
"confidence": round(g.confidence, 2),
}
for g in graphics
],
"exclude_regions": prev_exclude,
"color_pixel_counts": color_summary,
"has_words": len(words) > 0,
"word_count": len(words),
"border_ghosts_removed": ghost_count,
"duration_seconds": round(duration, 2),
}
# Persist to session
await update_session_db(session_id, structure_result=result_dict)
cached["structure_result"] = result_dict
logger.info("detect-structure session %s: %d boxes, %d zones, %d graphics, %.2fs",
session_id, len(boxes), len(zones), len(graphics), duration)
return {"session_id": session_id, **result_dict}
# ---------------------------------------------------------------------------
# Exclude Regions -- user-drawn rectangles to exclude from OCR results
# ---------------------------------------------------------------------------
class _ExcludeRegionIn(BaseModel):
x: int
y: int
w: int
h: int
label: str = ""
class _ExcludeRegionsBatchIn(BaseModel):
regions: list[_ExcludeRegionIn]
@router.put("/sessions/{session_id}/exclude-regions")
async def set_exclude_regions(session_id: str, body: _ExcludeRegionsBatchIn):
"""Replace all exclude regions for a session.
Regions are stored inside ``structure_result.exclude_regions``.
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
sr = session.get("structure_result") or {}
sr["exclude_regions"] = [r.model_dump() for r in body.regions]
# Invalidate grid so it rebuilds with new exclude regions
await update_session_db(session_id, structure_result=sr, grid_editor_result=None)
# Update cache
if session_id in _cache:
_cache[session_id]["structure_result"] = sr
_cache[session_id].pop("grid_editor_result", None)
return {
"session_id": session_id,
"exclude_regions": sr["exclude_regions"],
"count": len(sr["exclude_regions"]),
}
@router.delete("/sessions/{session_id}/exclude-regions/{region_index}")
async def delete_exclude_region(session_id: str, region_index: int):
"""Remove a single exclude region by index."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
sr = session.get("structure_result") or {}
regions = sr.get("exclude_regions", [])
if region_index < 0 or region_index >= len(regions):
raise HTTPException(status_code=404, detail="Region index out of range")
removed = regions.pop(region_index)
sr["exclude_regions"] = regions
# Invalidate grid so it rebuilds with new exclude regions
await update_session_db(session_id, structure_result=sr, grid_editor_result=None)
if session_id in _cache:
_cache[session_id]["structure_result"] = sr
_cache[session_id].pop("grid_editor_result", None)
return {
"session_id": session_id,
"removed": removed,
"remaining": len(regions),
}

View File

@@ -0,0 +1,362 @@
"""
OCR Pipeline Validation — image detection, generation, validation save,
and handwriting removal endpoints.
Extracted from ocr_pipeline_postprocess.py.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import os
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from .session_store import (
get_session_db,
get_session_image,
update_session_db,
)
from .common import (
_cache,
RemoveHandwritingRequest,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Pydantic Models
# ---------------------------------------------------------------------------
STYLE_SUFFIXES = {
"educational": "educational illustration, textbook style, clear, colorful",
"cartoon": "cartoon, child-friendly, simple shapes",
"sketch": "pencil sketch, hand-drawn, black and white",
"clipart": "clipart, flat vector style, simple",
"realistic": "photorealistic, high detail",
}
class ValidationRequest(BaseModel):
notes: Optional[str] = None
score: Optional[int] = None
class GenerateImageRequest(BaseModel):
region_index: int
prompt: str
style: str = "educational"
# ---------------------------------------------------------------------------
# Image detection + generation
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/reconstruction/detect-images")
async def detect_image_regions(session_id: str):
"""Detect illustration/image regions in the original scan using VLM."""
import base64
import httpx
import re
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
original_png = await get_session_image(session_id, "original")
if not original_png:
raise HTTPException(status_code=400, detail="No original image found")
word_result = session.get("word_result") or {}
entries = word_result.get("vocab_entries") or word_result.get("entries") or []
vocab_context = ""
if entries:
sample = entries[:10]
words = [f"{e.get('english', '')} / {e.get('german', '')}" for e in sample if e.get('english')]
if words:
vocab_context = f"\nContext: This is a vocabulary page with words like: {', '.join(words)}"
ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
prompt = (
"Analyze this scanned page. Find ALL illustration/image/picture regions "
"(NOT text, NOT table cells, NOT blank areas). "
"For each image region found, return its bounding box as percentage of page dimensions "
"and a short English description of what the image shows. "
"Reply with ONLY a JSON array like: "
'[{"x": 10, "y": 20, "w": 30, "h": 25, "description": "drawing of a cat"}] '
"where x, y, w, h are percentages (0-100) of the page width/height. "
"If there are NO images on the page, return an empty array: []"
f"{vocab_context}"
)
img_b64 = base64.b64encode(original_png).decode("utf-8")
payload = {
"model": model,
"prompt": prompt,
"images": [img_b64],
"stream": False,
}
try:
async with httpx.AsyncClient(timeout=120.0) as client:
resp = await client.post(f"{ollama_base}/api/generate", json=payload)
resp.raise_for_status()
text = resp.json().get("response", "")
match = re.search(r'\[.*?\]', text, re.DOTALL)
if match:
raw_regions = json.loads(match.group(0))
else:
raw_regions = []
regions = []
for r in raw_regions:
regions.append({
"bbox_pct": {
"x": max(0, min(100, float(r.get("x", 0)))),
"y": max(0, min(100, float(r.get("y", 0)))),
"w": max(1, min(100, float(r.get("w", 10)))),
"h": max(1, min(100, float(r.get("h", 10)))),
},
"description": r.get("description", ""),
"prompt": r.get("description", ""),
"image_b64": None,
"style": "educational",
})
# Enrich prompts with nearby vocab context
if entries:
for region in regions:
ry = region["bbox_pct"]["y"]
rh = region["bbox_pct"]["h"]
nearby = [
e for e in entries
if e.get("bbox") and abs(e["bbox"].get("y", 0) - ry) < rh + 10
]
if nearby:
en_words = [e.get("english", "") for e in nearby if e.get("english")]
de_words = [e.get("german", "") for e in nearby if e.get("german")]
if en_words or de_words:
context = f" (vocabulary context: {', '.join(en_words[:5])}"
if de_words:
context += f" / {', '.join(de_words[:5])}"
context += ")"
region["prompt"] = region["description"] + context
ground_truth = session.get("ground_truth") or {}
validation = ground_truth.get("validation") or {}
validation["image_regions"] = regions
validation["detected_at"] = datetime.utcnow().isoformat()
ground_truth["validation"] = validation
await update_session_db(session_id, ground_truth=ground_truth)
if session_id in _cache:
_cache[session_id]["ground_truth"] = ground_truth
logger.info(f"Detected {len(regions)} image regions for session {session_id}")
return {"regions": regions, "count": len(regions)}
except httpx.ConnectError:
logger.warning(f"VLM not available at {ollama_base} for image detection")
return {"regions": [], "count": 0, "error": "VLM not available"}
except Exception as e:
logger.error(f"Image detection failed for {session_id}: {e}")
return {"regions": [], "count": 0, "error": str(e)}
@router.post("/sessions/{session_id}/reconstruction/generate-image")
async def generate_image_for_region(session_id: str, req: GenerateImageRequest):
"""Generate a replacement image for a detected region using mflux."""
import httpx
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
validation = ground_truth.get("validation") or {}
regions = validation.get("image_regions") or []
if req.region_index < 0 or req.region_index >= len(regions):
raise HTTPException(status_code=400, detail=f"Invalid region_index {req.region_index}, have {len(regions)} regions")
mflux_url = os.getenv("MFLUX_URL", "http://host.docker.internal:8095")
style_suffix = STYLE_SUFFIXES.get(req.style, STYLE_SUFFIXES["educational"])
full_prompt = f"{req.prompt}, {style_suffix}"
region = regions[req.region_index]
bbox = region["bbox_pct"]
aspect = bbox["w"] / max(bbox["h"], 1)
if aspect > 1.3:
width, height = 768, 512
elif aspect < 0.7:
width, height = 512, 768
else:
width, height = 512, 512
try:
async with httpx.AsyncClient(timeout=300.0) as client:
resp = await client.post(f"{mflux_url}/generate", json={
"prompt": full_prompt,
"width": width,
"height": height,
"steps": 4,
})
resp.raise_for_status()
data = resp.json()
image_b64 = data.get("image_b64")
if not image_b64:
return {"image_b64": None, "success": False, "error": "No image returned"}
regions[req.region_index]["image_b64"] = image_b64
regions[req.region_index]["prompt"] = req.prompt
regions[req.region_index]["style"] = req.style
validation["image_regions"] = regions
ground_truth["validation"] = validation
await update_session_db(session_id, ground_truth=ground_truth)
if session_id in _cache:
_cache[session_id]["ground_truth"] = ground_truth
logger.info(f"Generated image for session {session_id} region {req.region_index}")
return {"image_b64": image_b64, "success": True}
except httpx.ConnectError:
logger.warning(f"mflux-service not available at {mflux_url}")
return {"image_b64": None, "success": False, "error": f"mflux-service not available at {mflux_url}"}
except Exception as e:
logger.error(f"Image generation failed for {session_id}: {e}")
return {"image_b64": None, "success": False, "error": str(e)}
# ---------------------------------------------------------------------------
# Validation save/get
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/reconstruction/validate")
async def save_validation(session_id: str, req: ValidationRequest):
"""Save final validation results for step 8."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
validation = ground_truth.get("validation") or {}
validation["validated_at"] = datetime.utcnow().isoformat()
validation["notes"] = req.notes
validation["score"] = req.score
ground_truth["validation"] = validation
await update_session_db(session_id, ground_truth=ground_truth, current_step=11)
if session_id in _cache:
_cache[session_id]["ground_truth"] = ground_truth
logger.info(f"Validation saved for session {session_id}: score={req.score}")
return {"session_id": session_id, "validation": validation}
@router.get("/sessions/{session_id}/reconstruction/validation")
async def get_validation(session_id: str):
"""Retrieve saved validation data for step 8."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
validation = ground_truth.get("validation")
return {
"session_id": session_id,
"validation": validation,
"word_result": session.get("word_result"),
}
# ---------------------------------------------------------------------------
# Remove handwriting
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/remove-handwriting")
async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingRequest):
"""Remove handwriting from a session image using inpainting."""
import time as _time
from services.handwriting_detection import detect_handwriting
from services.inpainting_service import inpaint_image, dilate_mask as _dilate_mask, InpaintingMethod, image_to_png
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
t0 = _time.monotonic()
# 1. Determine source image
source = req.use_source
if source == "auto":
deskewed = await get_session_image(session_id, "deskewed")
source = "deskewed" if deskewed else "original"
image_bytes = await get_session_image(session_id, source)
if not image_bytes:
raise HTTPException(status_code=404, detail=f"Source image '{source}' not available")
# 2. Detect handwriting mask
detection = detect_handwriting(image_bytes, target_ink=req.target_ink)
# 3. Convert mask to PNG bytes and dilate
import io
from PIL import Image as _PILImage
mask_img = _PILImage.fromarray(detection.mask)
mask_buf = io.BytesIO()
mask_img.save(mask_buf, format="PNG")
mask_bytes = mask_buf.getvalue()
if req.dilation > 0:
mask_bytes = _dilate_mask(mask_bytes, iterations=req.dilation)
# 4. Inpaint
method_map = {
"telea": InpaintingMethod.OPENCV_TELEA,
"ns": InpaintingMethod.OPENCV_NS,
"auto": InpaintingMethod.AUTO,
}
inpaint_method = method_map.get(req.method, InpaintingMethod.AUTO)
result = inpaint_image(image_bytes, mask_bytes, method=inpaint_method)
if not result.success:
raise HTTPException(status_code=500, detail="Inpainting failed")
elapsed_ms = int((_time.monotonic() - t0) * 1000)
meta = {
"method_used": result.method_used.value if hasattr(result.method_used, "value") else str(result.method_used),
"handwriting_ratio": round(detection.handwriting_ratio, 4),
"detection_confidence": round(detection.confidence, 4),
"target_ink": req.target_ink,
"dilation": req.dilation,
"source_image": source,
"processing_time_ms": elapsed_ms,
}
# 5. Persist clean image
clean_png_bytes = image_to_png(result.image)
await update_session_db(session_id, clean_png=clean_png_bytes, handwriting_removal_meta=meta)
return {
**meta,
"image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean",
"session_id": session_id,
}

View File

@@ -0,0 +1,261 @@
"""
Vision-LLM OCR Fusion — Combines traditional OCR positions with Vision-LLM reading.
Sends the scan image + OCR word coordinates + document type to Qwen2.5-VL.
The LLM can read degraded text using context understanding and visual inspection,
while OCR coordinates provide structural hints (where text is, column positions).
Uses Ollama API (same pattern as handwriting_htr_api.py).
"""
import base64
import json
import logging
import os
import re
from typing import Any, Dict, List, Optional
import cv2
import httpx
import numpy as np
logger = logging.getLogger(__name__)
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
VISION_FUSION_MODEL = os.getenv("VISION_FUSION_MODEL", "llama3.2-vision:11b")
# Document category → prompt context
CATEGORY_PROMPTS: Dict[str, Dict[str, str]] = {
"vokabelseite": {
"label": "Vokabelseite eines Schulbuchs (Englisch-Deutsch)",
"columns": "Die Tabelle hat typischerweise 3 Spalten: Englisch, Deutsch, Beispielsatz.",
},
"woerterbuch": {
"label": "Woerterbuchseite",
"columns": "Die Eintraege haben: Stichwort, Lautschrift, Uebersetzung(en), Beispielsaetze.",
},
"arbeitsblatt": {
"label": "Arbeitsblatt",
"columns": "Erkenne die Spaltenstruktur aus dem Layout.",
},
"buchseite": {
"label": "Schulbuchseite",
"columns": "Erkenne die Spaltenstruktur aus dem Layout.",
},
}
def _group_words_into_lines(
words: List[Dict], y_tolerance: float = 15.0,
) -> List[List[Dict]]:
"""Group OCR words into lines by Y-proximity."""
if not words:
return []
sorted_w = sorted(words, key=lambda w: w.get("top", 0))
lines: List[List[Dict]] = [[sorted_w[0]]]
for w in sorted_w[1:]:
last_line = lines[-1]
avg_y = sum(ww["top"] for ww in last_line) / len(last_line)
if abs(w["top"] - avg_y) <= y_tolerance:
last_line.append(w)
else:
lines.append([w])
# Sort words within each line by X
for line in lines:
line.sort(key=lambda w: w.get("left", 0))
return lines
def _build_ocr_context(words: List[Dict], img_h: int) -> str:
"""Build a text description of OCR words with positions for the prompt."""
lines = _group_words_into_lines(words)
context_parts = []
for i, line in enumerate(lines):
word_descs = []
for w in line:
text = w.get("text", "").strip()
x = w.get("left", 0)
conf = w.get("conf", 0)
marker = " (?)" if conf < 50 else ""
word_descs.append(f'x={x} "{text}"{marker}')
avg_y = int(sum(w["top"] for w in line) / len(line))
context_parts.append(f"Zeile {i+1} (y~{avg_y}): {', '.join(word_descs)}")
return "\n".join(context_parts)
def _build_prompt(
ocr_context: str, category: str, img_w: int, img_h: int,
) -> str:
"""Build the Vision-LLM prompt with OCR context and document type."""
cat_info = CATEGORY_PROMPTS.get(category, CATEGORY_PROMPTS["buchseite"])
return f"""Du siehst eine eingescannte {cat_info['label']}.
{cat_info['columns']}
Die OCR-Software hat folgende Woerter an diesen Positionen erkannt.
Woerter mit (?) haben niedrige Erkennungssicherheit und sind wahrscheinlich falsch:
{ocr_context}
Bildgroesse: {img_w} x {img_h} Pixel.
AUFGABE: Schau dir das Bild genau an und erstelle die korrekte Tabelle.
- Korrigiere falsch erkannte Woerter anhand dessen was du im Bild siehst
- Fasse Fortsetzungszeilen zusammen (wenn eine Spalte in der naechsten Zeile leer ist,
gehoert der Text zur Zeile darueber — der Autor hat nur einen Zeilenumbruch innerhalb der Zelle gemacht)
- Behalte die Reihenfolge bei
Antworte NUR mit einem JSON-Array, keine Erklaerungen:
[
{{"row": 1, "english": "...", "german": "...", "example": "..."}},
{{"row": 2, "english": "...", "german": "...", "example": "..."}}
]"""
def _parse_llm_response(response_text: str) -> Optional[List[Dict]]:
"""Parse the LLM JSON response, handling markdown code blocks."""
text = response_text.strip()
# Strip markdown code block if present
if text.startswith("```"):
text = re.sub(r"^```(?:json)?\s*", "", text)
text = re.sub(r"\s*```\s*$", "", text)
# Try to find JSON array
match = re.search(r"\[[\s\S]*\]", text)
if not match:
logger.warning("vision_fuse_ocr: no JSON array found in LLM response")
return None
try:
data = json.loads(match.group())
if not isinstance(data, list):
return None
return data
except json.JSONDecodeError as e:
logger.warning(f"vision_fuse_ocr: JSON parse error: {e}")
return None
def _vocab_rows_to_words(
rows: List[Dict], img_w: int, img_h: int,
) -> List[Dict]:
"""Convert LLM vocab rows back to word dicts for grid building.
Distributes words across estimated column positions so the
existing grid builder can process them normally.
"""
words = []
# Estimate column positions (3-column vocab layout)
col_positions = [
(0.02, 0.28), # EN: 2%-28% of width
(0.30, 0.55), # DE: 30%-55%
(0.57, 0.98), # Example: 57%-98%
]
median_h = max(15, img_h // (len(rows) * 3)) if rows else 20
y_step = max(median_h + 5, img_h // max(len(rows), 1))
for i, row in enumerate(rows):
y = int(i * y_step + 20)
row_num = row.get("row", i + 1)
for col_idx, (field, (x_start_pct, x_end_pct)) in enumerate([
("english", col_positions[0]),
("german", col_positions[1]),
("example", col_positions[2]),
]):
text = (row.get(field) or "").strip()
if not text:
continue
x = int(x_start_pct * img_w)
w = int((x_end_pct - x_start_pct) * img_w)
words.append({
"text": text,
"left": x,
"top": y,
"width": w,
"height": median_h,
"conf": 95, # LLM-corrected → high confidence
"_source": "vision_llm",
"_row": row_num,
"_col_type": f"column_{['en', 'de', 'example'][col_idx]}",
})
logger.info(f"vision_fuse_ocr: converted {len(rows)} LLM rows → {len(words)} words")
return words
async def vision_fuse_ocr(
img_bgr: np.ndarray,
ocr_words: List[Dict],
document_category: str = "vokabelseite",
) -> List[Dict]:
"""Fuse traditional OCR results with Vision-LLM reading.
Sends the image + OCR word positions to Qwen2.5-VL which can:
- Read degraded text that traditional OCR cannot
- Use document context (knows what a vocab table looks like)
- Merge continuation rows (understands table structure)
Args:
img_bgr: The cropped/dewarped scan image (BGR)
ocr_words: Traditional OCR word list with positions
document_category: Type of document being scanned
Returns:
Corrected word list in same format as input, ready for grid building.
Falls back to original ocr_words on error.
"""
img_h, img_w = img_bgr.shape[:2]
# Build OCR context string
ocr_context = _build_ocr_context(ocr_words, img_h)
# Build prompt
prompt = _build_prompt(ocr_context, document_category, img_w, img_h)
# Encode image as base64
_, img_encoded = cv2.imencode(".png", img_bgr)
img_b64 = base64.b64encode(img_encoded.tobytes()).decode("utf-8")
# Call Qwen2.5-VL via Ollama
try:
async with httpx.AsyncClient(timeout=120.0) as client:
resp = await client.post(
f"{OLLAMA_BASE_URL}/api/generate",
json={
"model": VISION_FUSION_MODEL,
"prompt": prompt,
"images": [img_b64],
"stream": False,
"options": {"temperature": 0.1, "num_predict": 4096},
},
)
resp.raise_for_status()
data = resp.json()
response_text = data.get("response", "").strip()
except Exception as e:
logger.error(f"vision_fuse_ocr: Ollama call failed: {e}")
return ocr_words # Fallback to original
if not response_text:
logger.warning("vision_fuse_ocr: empty LLM response")
return ocr_words
# Parse JSON response
rows = _parse_llm_response(response_text)
if not rows:
logger.warning(
"vision_fuse_ocr: could not parse LLM response, "
"first 200 chars: %s", response_text[:200],
)
return ocr_words
logger.info(
f"vision_fuse_ocr: LLM returned {len(rows)} vocab rows "
f"(from {len(ocr_words)} OCR words)"
)
# Convert back to word format for grid building
return _vocab_rows_to_words(rows, img_w, img_h)

View File

@@ -0,0 +1,185 @@
"""
OCR Pipeline Words — composite router for word detection, PaddleOCR direct,
and ground truth endpoints.
Split into sub-modules:
ocr_pipeline_words_detect — main detect_words endpoint (Step 7)
ocr_pipeline_words_stream — SSE streaming generators
This barrel module contains the PaddleOCR direct endpoint and ground truth
endpoints, and assembles all word-related routers.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import time
from datetime import datetime
from typing import Any, Dict, List, Optional
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from cv_words_first import build_grid_from_words
from .session_store import (
get_session_db,
get_session_image,
update_session_db,
)
from .common import (
_cache,
_append_pipeline_log,
)
from .words_detect import router as _detect_router
logger = logging.getLogger(__name__)
_local_router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Pydantic models
# ---------------------------------------------------------------------------
class WordGroundTruthRequest(BaseModel):
is_correct: bool
corrected_entries: Optional[List[Dict[str, Any]]] = None
notes: Optional[str] = None
# ---------------------------------------------------------------------------
# PaddleOCR Direct Endpoint
# ---------------------------------------------------------------------------
@_local_router.post("/sessions/{session_id}/paddle-direct")
async def paddle_direct(session_id: str):
"""Run PaddleOCR on the preprocessed image and build a word grid directly."""
img_png = await get_session_image(session_id, "cropped")
if not img_png:
img_png = await get_session_image(session_id, "dewarped")
if not img_png:
img_png = await get_session_image(session_id, "original")
if not img_png:
raise HTTPException(status_code=404, detail="No image found for this session")
img_arr = np.frombuffer(img_png, dtype=np.uint8)
img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
if img_bgr is None:
raise HTTPException(status_code=400, detail="Failed to decode original image")
img_h, img_w = img_bgr.shape[:2]
from cv_ocr_engines import ocr_region_paddle
t0 = time.time()
word_dicts = await ocr_region_paddle(img_bgr, region=None)
if not word_dicts:
raise HTTPException(status_code=400, detail="PaddleOCR returned no words")
cells, columns_meta = build_grid_from_words(word_dicts, img_w, img_h)
duration = time.time() - t0
for cell in cells:
cell["ocr_engine"] = "paddle_direct"
n_rows = len(set(c["row_index"] for c in cells)) if cells else 0
n_cols = len(columns_meta)
col_types = {c.get("type") for c in columns_meta}
is_vocab = bool(col_types & {"column_en", "column_de"})
word_result = {
"cells": cells,
"grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
"columns_used": columns_meta,
"layout": "vocab" if is_vocab else "generic",
"image_width": img_w,
"image_height": img_h,
"duration_seconds": round(duration, 2),
"ocr_engine": "paddle_direct",
"grid_method": "paddle_direct",
"summary": {
"total_cells": len(cells),
"non_empty_cells": sum(1 for c in cells if c.get("text")),
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
},
}
await update_session_db(
session_id,
word_result=word_result,
cropped_png=img_png,
current_step=8,
)
logger.info(
"paddle_direct session %s: %d cells (%d rows, %d cols) in %.2fs",
session_id, len(cells), n_rows, n_cols, duration,
)
await _append_pipeline_log(session_id, "paddle_direct", {
"total_cells": len(cells),
"non_empty_cells": word_result["summary"]["non_empty_cells"],
"ocr_engine": "paddle_direct",
}, duration_ms=int(duration * 1000))
return {"session_id": session_id, **word_result}
# ---------------------------------------------------------------------------
# Ground Truth Words Endpoints
# ---------------------------------------------------------------------------
@_local_router.post("/sessions/{session_id}/ground-truth/words")
async def save_word_ground_truth(session_id: str, req: WordGroundTruthRequest):
"""Save ground truth feedback for the word recognition step."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
gt = {
"is_correct": req.is_correct,
"corrected_entries": req.corrected_entries,
"notes": req.notes,
"saved_at": datetime.utcnow().isoformat(),
"word_result": session.get("word_result"),
}
ground_truth["words"] = gt
await update_session_db(session_id, ground_truth=ground_truth)
if session_id in _cache:
_cache[session_id]["ground_truth"] = ground_truth
return {"session_id": session_id, "ground_truth": gt}
@_local_router.get("/sessions/{session_id}/ground-truth/words")
async def get_word_ground_truth(session_id: str):
"""Retrieve saved ground truth for word recognition."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
words_gt = ground_truth.get("words")
if not words_gt:
raise HTTPException(status_code=404, detail="No word ground truth saved")
return {
"session_id": session_id,
"words_gt": words_gt,
"words_auto": session.get("word_result"),
}
# ---------------------------------------------------------------------------
# Composite router
# ---------------------------------------------------------------------------
router = APIRouter()
router.include_router(_detect_router)
router.include_router(_local_router)

View File

@@ -0,0 +1,393 @@
"""
OCR Pipeline Words Detect — main word detection endpoint (Step 7).
Extracted from ocr_pipeline_words.py. Contains the ``detect_words``
endpoint which handles both v2 and words_first grid methods.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import time
from typing import Any, Dict, List
import numpy as np
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from cv_vocab_pipeline import (
PageRegion,
RowGeometry,
_cells_to_vocab_entries,
_fix_phonetic_brackets,
fix_cell_phonetics,
build_cell_grid_v2,
create_ocr_image,
detect_column_geometry,
)
from cv_words_first import build_grid_from_words
from .session_store import (
get_session_db,
update_session_db,
)
from .common import (
_cache,
_load_session_to_cache,
_get_cached,
_append_pipeline_log,
)
from .words_stream import (
_word_batch_stream_generator,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Word Detection Endpoint (Step 7)
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/words")
async def detect_words(
session_id: str,
request: Request,
engine: str = "auto",
pronunciation: str = "british",
stream: bool = False,
skip_heal_gaps: bool = False,
grid_method: str = "v2",
):
"""Build word grid from columns x rows, OCR each cell.
Query params:
engine: 'auto' (default), 'tesseract', 'rapid', or 'paddle'
pronunciation: 'british' (default) or 'american'
stream: false (default) for JSON response, true for SSE streaming
skip_heal_gaps: false (default). When true, cells keep exact row geometry.
grid_method: 'v2' (default) or 'words_first'
"""
# PaddleOCR is full-page remote OCR -> force words_first grid method
if engine == "paddle" and grid_method != "words_first":
logger.info("detect_words: engine=paddle requires words_first, overriding grid_method=%s", grid_method)
grid_method = "words_first"
if session_id not in _cache:
logger.info("detect_words: session %s not in cache, loading from DB", session_id)
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
if dewarped_bgr is None:
logger.warning("detect_words: no cropped/dewarped image for session %s (cache keys: %s)",
session_id, [k for k in cached.keys() if k.endswith('_bgr')])
raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before word detection")
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
column_result = session.get("column_result")
row_result = session.get("row_result")
if not column_result or not column_result.get("columns"):
img_h_tmp, img_w_tmp = dewarped_bgr.shape[:2]
column_result = {
"columns": [{
"type": "column_text",
"x": 0, "y": 0,
"width": img_w_tmp, "height": img_h_tmp,
"classification_confidence": 1.0,
"classification_method": "full_page_fallback",
}],
"zones": [],
"duration_seconds": 0,
}
logger.info("detect_words: no column_result -- using full-page pseudo-column %dx%d", img_w_tmp, img_h_tmp)
if grid_method != "words_first" and (not row_result or not row_result.get("rows")):
raise HTTPException(status_code=400, detail="Row detection must be completed first")
# Convert column dicts back to PageRegion objects
col_regions = [
PageRegion(
type=c["type"],
x=c["x"], y=c["y"],
width=c["width"], height=c["height"],
classification_confidence=c.get("classification_confidence", 1.0),
classification_method=c.get("classification_method", ""),
)
for c in column_result["columns"]
]
# Convert row dicts back to RowGeometry objects
row_geoms = [
RowGeometry(
index=r["index"],
x=r["x"], y=r["y"],
width=r["width"], height=r["height"],
word_count=r.get("word_count", 0),
words=[],
row_type=r.get("row_type", "content"),
gap_before=r.get("gap_before", 0),
)
for r in row_result["rows"]
]
# Populate word counts from cached words
word_dicts = cached.get("_word_dicts")
if word_dicts is None:
ocr_img_tmp = create_ocr_image(dewarped_bgr)
geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr)
if geo_result is not None:
_geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
cached["_word_dicts"] = word_dicts
cached["_inv"] = inv
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
if word_dicts:
content_bounds = cached.get("_content_bounds")
if content_bounds:
_lx, _rx, top_y, _by = content_bounds
else:
top_y = min(r.y for r in row_geoms) if row_geoms else 0
for row in row_geoms:
row_y_rel = row.y - top_y
row_bottom_rel = row_y_rel + row.height
row.words = [
w for w in word_dicts
if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
]
row.word_count = len(row.words)
# Exclude rows that fall within box zones
zones = column_result.get("zones") or []
box_ranges_inner = []
for zone in zones:
if zone.get("zone_type") == "box" and zone.get("box"):
box = zone["box"]
bt = max(box.get("border_thickness", 0), 5)
box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt))
if box_ranges_inner:
def _row_in_box(r):
center_y = r.y + r.height / 2
return any(by_s <= center_y < by_e for by_s, by_e in box_ranges_inner)
before_count = len(row_geoms)
row_geoms = [r for r in row_geoms if not _row_in_box(r)]
excluded = before_count - len(row_geoms)
if excluded:
logger.info(f"detect_words: excluded {excluded} rows inside box zones")
# --- Words-First path ---
if grid_method == "words_first":
return await _words_first_path(
session_id, cached, dewarped_bgr, engine, pronunciation, zones,
)
if stream:
return StreamingResponse(
_word_batch_stream_generator(
session_id, cached, col_regions, row_geoms,
dewarped_bgr, engine, pronunciation, request,
skip_heal_gaps=skip_heal_gaps,
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
# --- Non-streaming path (grid_method=v2) ---
return await _v2_path(
session_id, cached, col_regions, row_geoms,
dewarped_bgr, engine, pronunciation, skip_heal_gaps,
)
async def _words_first_path(
session_id: str,
cached: Dict[str, Any],
dewarped_bgr: np.ndarray,
engine: str,
pronunciation: str,
zones: list,
) -> dict:
"""Words-first grid construction path."""
t0 = time.time()
img_h, img_w = dewarped_bgr.shape[:2]
if engine == "paddle":
from cv_ocr_engines import ocr_region_paddle
wf_word_dicts = await ocr_region_paddle(dewarped_bgr, region=None)
cached["_paddle_word_dicts"] = wf_word_dicts
else:
wf_word_dicts = cached.get("_word_dicts")
if wf_word_dicts is None:
ocr_img_tmp = create_ocr_image(dewarped_bgr)
geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr)
if geo_result is not None:
_geoms, left_x, right_x, top_y, bottom_y, wf_word_dicts, inv = geo_result
cached["_word_dicts"] = wf_word_dicts
cached["_inv"] = inv
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
if not wf_word_dicts:
raise HTTPException(status_code=400, detail="No words detected -- cannot build words-first grid")
# Convert word coordinates to absolute if needed
if engine != "paddle":
content_bounds = cached.get("_content_bounds")
if content_bounds:
lx, _rx, ty, _by = content_bounds
abs_words = []
for w in wf_word_dicts:
abs_words.append({**w, 'left': w['left'] + lx, 'top': w['top'] + ty})
wf_word_dicts = abs_words
box_rects = []
for zone in zones:
if zone.get("zone_type") == "box" and zone.get("box"):
box_rects.append(zone["box"])
cells, columns_meta = build_grid_from_words(
wf_word_dicts, img_w, img_h, box_rects=box_rects or None,
)
duration = time.time() - t0
fix_cell_phonetics(cells, pronunciation=pronunciation)
for cell in cells:
cell.setdefault("zone_index", 0)
col_types = {c['type'] for c in columns_meta}
is_vocab = bool(col_types & {'column_en', 'column_de'})
n_rows = len(set(c['row_index'] for c in cells)) if cells else 0
n_cols = len(columns_meta)
used_engine = "paddle" if engine == "paddle" else "words_first"
word_result = {
"cells": cells,
"grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
"columns_used": columns_meta,
"layout": "vocab" if is_vocab else "generic",
"image_width": img_w,
"image_height": img_h,
"duration_seconds": round(duration, 2),
"ocr_engine": used_engine,
"grid_method": "words_first",
"summary": {
"total_cells": len(cells),
"non_empty_cells": sum(1 for c in cells if c.get("text")),
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
},
}
if is_vocab or 'column_text' in col_types:
entries = _cells_to_vocab_entries(cells, columns_meta)
entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
word_result["vocab_entries"] = entries
word_result["entries"] = entries
word_result["entry_count"] = len(entries)
word_result["summary"]["total_entries"] = len(entries)
word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
await update_session_db(session_id, word_result=word_result, current_step=8)
cached["word_result"] = word_result
logger.info(f"OCR Pipeline: words-first session {session_id}: "
f"{len(cells)} cells ({duration:.2f}s), {n_rows} rows, {n_cols} cols")
await _append_pipeline_log(session_id, "words", {
"grid_method": "words_first",
"total_cells": len(cells),
"non_empty_cells": word_result["summary"]["non_empty_cells"],
"ocr_engine": used_engine,
"layout": word_result["layout"],
}, duration_ms=int(duration * 1000))
return {"session_id": session_id, **word_result}
async def _v2_path(
session_id: str,
cached: Dict[str, Any],
col_regions: List[PageRegion],
row_geoms: List[RowGeometry],
dewarped_bgr: np.ndarray,
engine: str,
pronunciation: str,
skip_heal_gaps: bool,
) -> dict:
"""Cell-First OCR v2 non-streaming path."""
t0 = time.time()
ocr_img = create_ocr_image(dewarped_bgr)
img_h, img_w = dewarped_bgr.shape[:2]
cells, columns_meta = build_cell_grid_v2(
ocr_img, col_regions, row_geoms, img_w, img_h,
ocr_engine=engine, img_bgr=dewarped_bgr,
skip_heal_gaps=skip_heal_gaps,
)
duration = time.time() - t0
for cell in cells:
cell.setdefault("zone_index", 0)
col_types = {c['type'] for c in columns_meta}
is_vocab = bool(col_types & {'column_en', 'column_de'})
n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
n_cols = len(columns_meta)
used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine
fix_cell_phonetics(cells, pronunciation=pronunciation)
word_result = {
"cells": cells,
"grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(cells)},
"columns_used": columns_meta,
"layout": "vocab" if is_vocab else "generic",
"image_width": img_w,
"image_height": img_h,
"duration_seconds": round(duration, 2),
"ocr_engine": used_engine,
"summary": {
"total_cells": len(cells),
"non_empty_cells": sum(1 for c in cells if c.get("text")),
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
},
}
has_text_col = 'column_text' in col_types
if is_vocab or has_text_col:
entries = _cells_to_vocab_entries(cells, columns_meta)
entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
word_result["vocab_entries"] = entries
word_result["entries"] = entries
word_result["entry_count"] = len(entries)
word_result["summary"]["total_entries"] = len(entries)
word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
await update_session_db(session_id, word_result=word_result, current_step=8)
cached["word_result"] = word_result
logger.info(f"OCR Pipeline: words session {session_id}: "
f"layout={word_result['layout']}, "
f"{len(cells)} cells ({duration:.2f}s), summary: {word_result['summary']}")
await _append_pipeline_log(session_id, "words", {
"total_cells": len(cells),
"non_empty_cells": word_result["summary"]["non_empty_cells"],
"low_confidence_count": word_result["summary"]["low_confidence"],
"ocr_engine": used_engine,
"layout": word_result["layout"],
"entry_count": word_result.get("entry_count", 0),
}, duration_ms=int(duration * 1000))
return {"session_id": session_id, **word_result}

View File

@@ -0,0 +1,303 @@
"""
OCR Pipeline Words Stream — SSE streaming generators for word detection.
Extracted from ocr_pipeline_words.py.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import time
from typing import Any, Dict, List
import numpy as np
from fastapi import Request
from cv_vocab_pipeline import (
PageRegion,
RowGeometry,
_cells_to_vocab_entries,
_fix_character_confusion,
_fix_phonetic_brackets,
fix_cell_phonetics,
build_cell_grid_v2,
build_cell_grid_v2_streaming,
create_ocr_image,
)
from .session_store import update_session_db
from .common import _cache
logger = logging.getLogger(__name__)
async def _word_batch_stream_generator(
session_id: str,
cached: Dict[str, Any],
col_regions: List[PageRegion],
row_geoms: List[RowGeometry],
dewarped_bgr: np.ndarray,
engine: str,
pronunciation: str,
request: Request,
skip_heal_gaps: bool = False,
):
"""SSE generator that runs batch OCR (parallel) then streams results.
Uses build_cell_grid_v2 with ThreadPoolExecutor for parallel OCR,
then emits all cells as SSE events.
"""
import asyncio
t0 = time.time()
ocr_img = create_ocr_image(dewarped_bgr)
img_h, img_w = dewarped_bgr.shape[:2]
_skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
n_cols = len([c for c in col_regions if c.type not in _skip_types])
col_types = {c.type for c in col_regions if c.type not in _skip_types}
is_vocab = bool(col_types & {'column_en', 'column_de'})
total_cells = n_content_rows * n_cols
# 1. Send meta event immediately
meta_event = {
"type": "meta",
"grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
"layout": "vocab" if is_vocab else "generic",
}
yield f"data: {json.dumps(meta_event)}\n\n"
# 2. Send preparing event (keepalive for proxy)
yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR laeuft parallel...'})}\n\n"
# 3. Run batch OCR in thread pool with periodic keepalive events.
loop = asyncio.get_event_loop()
ocr_future = loop.run_in_executor(
None,
lambda: build_cell_grid_v2(
ocr_img, col_regions, row_geoms, img_w, img_h,
ocr_engine=engine, img_bgr=dewarped_bgr,
skip_heal_gaps=skip_heal_gaps,
),
)
# Send keepalive events every 5 seconds while OCR runs
keepalive_count = 0
while not ocr_future.done():
try:
cells, columns_meta = await asyncio.wait_for(
asyncio.shield(ocr_future), timeout=5.0,
)
break # OCR finished
except asyncio.TimeoutError:
keepalive_count += 1
elapsed = int(time.time() - t0)
yield f"data: {json.dumps({'type': 'keepalive', 'elapsed': elapsed, 'message': f'OCR laeuft... ({elapsed}s)'})}\n\n"
if await request.is_disconnected():
logger.info(f"SSE batch: client disconnected during OCR for {session_id}")
ocr_future.cancel()
return
else:
cells, columns_meta = ocr_future.result()
if await request.is_disconnected():
logger.info(f"SSE batch: client disconnected after OCR for {session_id}")
return
# 4. Apply IPA phonetic fixes
fix_cell_phonetics(cells, pronunciation=pronunciation)
# 5. Send columns meta
if columns_meta:
yield f"data: {json.dumps({'type': 'columns', 'columns_used': columns_meta})}\n\n"
# 6. Stream all cells
for idx, cell in enumerate(cells):
cell_event = {
"type": "cell",
"cell": cell,
"progress": {"current": idx + 1, "total": len(cells)},
}
yield f"data: {json.dumps(cell_event)}\n\n"
# 7. Build final result and persist
duration = time.time() - t0
used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine
word_result = {
"cells": cells,
"grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(cells)},
"columns_used": columns_meta,
"layout": "vocab" if is_vocab else "generic",
"image_width": img_w,
"image_height": img_h,
"duration_seconds": round(duration, 2),
"ocr_engine": used_engine,
"summary": {
"total_cells": len(cells),
"non_empty_cells": sum(1 for c in cells if c.get("text")),
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
},
}
vocab_entries = None
has_text_col = 'column_text' in col_types
if is_vocab or has_text_col:
entries = _cells_to_vocab_entries(cells, columns_meta)
entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
word_result["vocab_entries"] = entries
word_result["entries"] = entries
word_result["entry_count"] = len(entries)
word_result["summary"]["total_entries"] = len(entries)
word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
vocab_entries = entries
await update_session_db(session_id, word_result=word_result, current_step=8)
cached["word_result"] = word_result
logger.info(f"OCR Pipeline SSE batch: words session {session_id}: "
f"layout={word_result['layout']}, {len(cells)} cells ({duration:.2f}s)")
# 8. Send complete event
complete_event = {
"type": "complete",
"summary": word_result["summary"],
"duration_seconds": round(duration, 2),
"ocr_engine": used_engine,
}
if vocab_entries is not None:
complete_event["vocab_entries"] = vocab_entries
yield f"data: {json.dumps(complete_event)}\n\n"
async def _word_stream_generator(
session_id: str,
cached: Dict[str, Any],
col_regions: List[PageRegion],
row_geoms: List[RowGeometry],
dewarped_bgr: np.ndarray,
engine: str,
pronunciation: str,
request: Request,
):
"""SSE generator that yields cell-by-cell OCR progress."""
t0 = time.time()
ocr_img = create_ocr_image(dewarped_bgr)
img_h, img_w = dewarped_bgr.shape[:2]
n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
_skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
n_cols = len([c for c in col_regions if c.type not in _skip_types])
col_types = {c.type for c in col_regions if c.type not in _skip_types}
is_vocab = bool(col_types & {'column_en', 'column_de'})
columns_meta = None
total_cells = n_content_rows * n_cols
meta_event = {
"type": "meta",
"grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
"layout": "vocab" if is_vocab else "generic",
}
yield f"data: {json.dumps(meta_event)}\n\n"
yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR wird initialisiert...'})}\n\n"
all_cells: List[Dict[str, Any]] = []
cell_idx = 0
last_keepalive = time.time()
for cell, cols_meta, total in build_cell_grid_v2_streaming(
ocr_img, col_regions, row_geoms, img_w, img_h,
ocr_engine=engine, img_bgr=dewarped_bgr,
):
if await request.is_disconnected():
logger.info(f"SSE: client disconnected during streaming for {session_id}")
return
if columns_meta is None:
columns_meta = cols_meta
meta_update = {"type": "columns", "columns_used": cols_meta}
yield f"data: {json.dumps(meta_update)}\n\n"
all_cells.append(cell)
cell_idx += 1
cell_event = {
"type": "cell",
"cell": cell,
"progress": {"current": cell_idx, "total": total},
}
yield f"data: {json.dumps(cell_event)}\n\n"
# All cells done
duration = time.time() - t0
if columns_meta is None:
columns_meta = []
# Remove all-empty rows
rows_with_text: set = set()
for c in all_cells:
if c.get("text", "").strip():
rows_with_text.add(c["row_index"])
before_filter = len(all_cells)
all_cells = [c for c in all_cells if c["row_index"] in rows_with_text]
empty_rows_removed = (before_filter - len(all_cells)) // max(n_cols, 1)
if empty_rows_removed > 0:
logger.info(f"SSE: removed {empty_rows_removed} all-empty rows after OCR")
used_engine = all_cells[0].get("ocr_engine", "tesseract") if all_cells else engine
fix_cell_phonetics(all_cells, pronunciation=pronunciation)
word_result = {
"cells": all_cells,
"grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(all_cells)},
"columns_used": columns_meta,
"layout": "vocab" if is_vocab else "generic",
"image_width": img_w,
"image_height": img_h,
"duration_seconds": round(duration, 2),
"ocr_engine": used_engine,
"summary": {
"total_cells": len(all_cells),
"non_empty_cells": sum(1 for c in all_cells if c.get("text")),
"low_confidence": sum(1 for c in all_cells if 0 < c.get("confidence", 0) < 50),
},
}
vocab_entries = None
has_text_col = 'column_text' in col_types
if is_vocab or has_text_col:
entries = _cells_to_vocab_entries(all_cells, columns_meta)
entries = _fix_character_confusion(entries)
entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
word_result["vocab_entries"] = entries
word_result["entries"] = entries
word_result["entry_count"] = len(entries)
word_result["summary"]["total_entries"] = len(entries)
word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
vocab_entries = entries
await update_session_db(session_id, word_result=word_result, current_step=8)
cached["word_result"] = word_result
logger.info(f"OCR Pipeline SSE: words session {session_id}: "
f"layout={word_result['layout']}, "
f"{len(all_cells)} cells ({duration:.2f}s)")
complete_event = {
"type": "complete",
"summary": word_result["summary"],
"duration_seconds": round(duration, 2),
"ocr_engine": used_engine,
}
if vocab_entries is not None:
complete_event["vocab_entries"] = vocab_entries
yield f"data: {json.dumps(complete_event)}\n\n"

View File

@@ -1,81 +1,4 @@
"""
OCR Labeling API — Barrel Re-export
Split into:
- ocr_labeling_models.py — Pydantic models and constants
- ocr_labeling_helpers.py — OCR wrappers, image storage, hashing
- ocr_labeling_routes.py — Session/queue/labeling route handlers
- ocr_labeling_upload_routes.py — Upload, run-OCR, export route handlers
All public names are re-exported here for backward compatibility.
"""
# Models
from ocr_labeling_models import ( # noqa: F401
LOCAL_STORAGE_PATH,
SessionCreate,
SessionResponse,
ItemResponse,
ConfirmRequest,
CorrectRequest,
SkipRequest,
ExportRequest,
StatsResponse,
)
# Helpers
from ocr_labeling_helpers import ( # noqa: F401
VISION_OCR_AVAILABLE,
PADDLEOCR_AVAILABLE,
TROCR_AVAILABLE,
DONUT_AVAILABLE,
MINIO_AVAILABLE,
TRAINING_EXPORT_AVAILABLE,
compute_image_hash,
run_ocr_on_image,
run_vision_ocr_wrapper,
run_paddleocr_wrapper,
run_trocr_wrapper,
run_donut_wrapper,
save_image_locally,
get_image_url,
)
# Conditional re-exports from helpers' optional imports
try:
from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET # noqa: F401
except ImportError:
pass
try:
from training_export_service import ( # noqa: F401
TrainingExportService,
TrainingSample,
get_training_export_service,
)
except ImportError:
pass
try:
from hybrid_vocab_extractor import run_paddle_ocr # noqa: F401
except ImportError:
pass
try:
from services.trocr_service import run_trocr_ocr # noqa: F401
except ImportError:
pass
try:
from services.donut_ocr_service import run_donut_ocr # noqa: F401
except ImportError:
pass
try:
from vision_ocr_service import get_vision_ocr_service, VisionOCRService # noqa: F401
except ImportError:
pass
# Routes (router is the main export for app.include_router)
from ocr_labeling_routes import router # noqa: F401
from ocr_labeling_upload_routes import router as upload_router # noqa: F401
# Backward-compat shim -- module moved to ocr/labeling/api.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.labeling.api")

View File

@@ -1,205 +1,4 @@
"""
OCR Labeling - Helper Functions and OCR Wrappers
Extracted from ocr_labeling_api.py to keep files under 500 LOC.
DATENSCHUTZ/PRIVACY:
- Alle Verarbeitung erfolgt lokal (Mac Mini mit Ollama)
- Keine Daten werden an externe Server gesendet
"""
import os
import hashlib
from ocr_labeling_models import LOCAL_STORAGE_PATH
# Try to import Vision OCR service
try:
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend', 'klausur', 'services'))
from vision_ocr_service import get_vision_ocr_service, VisionOCRService
VISION_OCR_AVAILABLE = True
except ImportError:
VISION_OCR_AVAILABLE = False
print("Warning: Vision OCR service not available")
# Try to import PaddleOCR from hybrid_vocab_extractor
try:
from hybrid_vocab_extractor import run_paddle_ocr
PADDLEOCR_AVAILABLE = True
except ImportError:
PADDLEOCR_AVAILABLE = False
print("Warning: PaddleOCR not available")
# Try to import TrOCR service
try:
from services.trocr_service import run_trocr_ocr
TROCR_AVAILABLE = True
except ImportError:
TROCR_AVAILABLE = False
print("Warning: TrOCR service not available")
# Try to import Donut service
try:
from services.donut_ocr_service import run_donut_ocr
DONUT_AVAILABLE = True
except ImportError:
DONUT_AVAILABLE = False
print("Warning: Donut OCR service not available")
# Try to import MinIO storage
try:
from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET
MINIO_AVAILABLE = True
except ImportError:
MINIO_AVAILABLE = False
print("Warning: MinIO storage not available, using local storage")
# Try to import Training Export Service
try:
from training_export_service import (
TrainingExportService,
TrainingSample,
get_training_export_service,
)
TRAINING_EXPORT_AVAILABLE = True
except ImportError:
TRAINING_EXPORT_AVAILABLE = False
print("Warning: Training export service not available")
# =============================================================================
# Helper Functions
# =============================================================================
def compute_image_hash(image_data: bytes) -> str:
"""Compute SHA256 hash of image data."""
return hashlib.sha256(image_data).hexdigest()
async def run_ocr_on_image(image_data: bytes, filename: str, model: str = "llama3.2-vision:11b") -> tuple:
"""
Run OCR on an image using the specified model.
Models:
- llama3.2-vision:11b: Vision LLM (default, best for handwriting)
- trocr: Microsoft TrOCR (fast for printed text)
- paddleocr: PaddleOCR + LLM hybrid (4x faster)
- donut: Document Understanding Transformer (structured documents)
Returns:
Tuple of (ocr_text, confidence)
"""
print(f"Running OCR with model: {model}")
# Route to appropriate OCR service based on model
if model == "paddleocr":
return await run_paddleocr_wrapper(image_data, filename)
elif model == "donut":
return await run_donut_wrapper(image_data, filename)
elif model == "trocr":
return await run_trocr_wrapper(image_data, filename)
else:
# Default: Vision LLM (llama3.2-vision or similar)
return await run_vision_ocr_wrapper(image_data, filename)
async def run_vision_ocr_wrapper(image_data: bytes, filename: str) -> tuple:
"""Vision LLM OCR wrapper."""
if not VISION_OCR_AVAILABLE:
print("Vision OCR service not available")
return None, 0.0
try:
service = get_vision_ocr_service()
if not await service.is_available():
print("Vision OCR service not available (is_available check failed)")
return None, 0.0
result = await service.extract_text(
image_data,
filename=filename,
is_handwriting=True
)
return result.text, result.confidence
except Exception as e:
print(f"Vision OCR failed: {e}")
return None, 0.0
async def run_paddleocr_wrapper(image_data: bytes, filename: str) -> tuple:
"""PaddleOCR wrapper - uses hybrid_vocab_extractor."""
if not PADDLEOCR_AVAILABLE:
print("PaddleOCR not available, falling back to Vision OCR")
return await run_vision_ocr_wrapper(image_data, filename)
try:
# run_paddle_ocr returns (regions, raw_text)
regions, raw_text = run_paddle_ocr(image_data)
if not raw_text:
print("PaddleOCR returned empty text")
return None, 0.0
# Calculate average confidence from regions
if regions:
avg_confidence = sum(r.confidence for r in regions) / len(regions)
else:
avg_confidence = 0.5
return raw_text, avg_confidence
except Exception as e:
print(f"PaddleOCR failed: {e}, falling back to Vision OCR")
return await run_vision_ocr_wrapper(image_data, filename)
async def run_trocr_wrapper(image_data: bytes, filename: str) -> tuple:
"""TrOCR wrapper."""
if not TROCR_AVAILABLE:
print("TrOCR not available, falling back to Vision OCR")
return await run_vision_ocr_wrapper(image_data, filename)
try:
text, confidence = await run_trocr_ocr(image_data)
return text, confidence
except Exception as e:
print(f"TrOCR failed: {e}, falling back to Vision OCR")
return await run_vision_ocr_wrapper(image_data, filename)
async def run_donut_wrapper(image_data: bytes, filename: str) -> tuple:
"""Donut OCR wrapper."""
if not DONUT_AVAILABLE:
print("Donut not available, falling back to Vision OCR")
return await run_vision_ocr_wrapper(image_data, filename)
try:
text, confidence = await run_donut_ocr(image_data)
return text, confidence
except Exception as e:
print(f"Donut OCR failed: {e}, falling back to Vision OCR")
return await run_vision_ocr_wrapper(image_data, filename)
def save_image_locally(session_id: str, item_id: str, image_data: bytes, extension: str = "png") -> str:
"""Save image to local storage."""
session_dir = os.path.join(LOCAL_STORAGE_PATH, session_id)
os.makedirs(session_dir, exist_ok=True)
filename = f"{item_id}.{extension}"
filepath = os.path.join(session_dir, filename)
with open(filepath, 'wb') as f:
f.write(image_data)
return filepath
def get_image_url(image_path: str) -> str:
"""Get URL for an image."""
# For local images, return a relative path that the frontend can use
if image_path.startswith(LOCAL_STORAGE_PATH):
relative_path = image_path[len(LOCAL_STORAGE_PATH):].lstrip('/')
return f"/api/v1/ocr-label/images/{relative_path}"
# For MinIO images, the path is already a URL or key
return image_path
# Backward-compat shim -- module moved to ocr/labeling/helpers.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.labeling.helpers")

View File

@@ -1,86 +1,4 @@
"""
OCR Labeling - Pydantic Models and Constants
Extracted from ocr_labeling_api.py to keep files under 500 LOC.
"""
import os
from pydantic import BaseModel
from typing import Optional, Dict
from datetime import datetime
# Local storage path (fallback if MinIO not available)
LOCAL_STORAGE_PATH = os.getenv("OCR_STORAGE_PATH", "/app/ocr-labeling")
# =============================================================================
# Pydantic Models
# =============================================================================
class SessionCreate(BaseModel):
name: str
source_type: str = "klausur" # klausur, handwriting_sample, scan
description: Optional[str] = None
ocr_model: Optional[str] = "llama3.2-vision:11b"
class SessionResponse(BaseModel):
id: str
name: str
source_type: str
description: Optional[str]
ocr_model: Optional[str]
total_items: int
labeled_items: int
confirmed_items: int
corrected_items: int
skipped_items: int
created_at: datetime
class ItemResponse(BaseModel):
id: str
session_id: str
session_name: str
image_path: str
image_url: Optional[str]
ocr_text: Optional[str]
ocr_confidence: Optional[float]
ground_truth: Optional[str]
status: str
metadata: Optional[Dict]
created_at: datetime
class ConfirmRequest(BaseModel):
item_id: str
label_time_seconds: Optional[int] = None
class CorrectRequest(BaseModel):
item_id: str
ground_truth: str
label_time_seconds: Optional[int] = None
class SkipRequest(BaseModel):
item_id: str
class ExportRequest(BaseModel):
export_format: str = "generic" # generic, trocr, llama_vision
session_id: Optional[str] = None
batch_id: Optional[str] = None
class StatsResponse(BaseModel):
total_sessions: Optional[int] = None
total_items: int
labeled_items: int
confirmed_items: int
corrected_items: int
pending_items: int
exportable_items: Optional[int] = None
accuracy_rate: float
avg_label_time_seconds: Optional[float] = None
# Backward-compat shim -- module moved to ocr/labeling/models.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.labeling.models")

View File

@@ -1,241 +1,4 @@
"""
OCR Labeling - Session and Labeling Route Handlers
Extracted from ocr_labeling_api.py to keep files under 500 LOC.
Endpoints:
- POST /sessions - Create labeling session
- GET /sessions - List sessions
- GET /sessions/{id} - Get session
- GET /queue - Get labeling queue
- GET /items/{id} - Get item
- POST /confirm - Confirm OCR
- POST /correct - Correct ground truth
- POST /skip - Skip item
- GET /stats - Get statistics
"""
from fastapi import APIRouter, HTTPException, Query
from typing import Optional, List
from datetime import datetime
import uuid
from metrics_db import (
create_ocr_labeling_session,
get_ocr_labeling_sessions,
get_ocr_labeling_session,
get_ocr_labeling_queue,
get_ocr_labeling_item,
confirm_ocr_label,
correct_ocr_label,
skip_ocr_item,
get_ocr_labeling_stats,
)
from ocr_labeling_models import (
SessionCreate, SessionResponse, ItemResponse,
ConfirmRequest, CorrectRequest, SkipRequest,
)
from ocr_labeling_helpers import get_image_url
router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"])
# =============================================================================
# Session Endpoints
# =============================================================================
@router.post("/sessions", response_model=SessionResponse)
async def create_session(session: SessionCreate):
"""Create a new OCR labeling session."""
session_id = str(uuid.uuid4())
success = await create_ocr_labeling_session(
session_id=session_id,
name=session.name,
source_type=session.source_type,
description=session.description,
ocr_model=session.ocr_model,
)
if not success:
raise HTTPException(status_code=500, detail="Failed to create session")
return SessionResponse(
id=session_id,
name=session.name,
source_type=session.source_type,
description=session.description,
ocr_model=session.ocr_model,
total_items=0,
labeled_items=0,
confirmed_items=0,
corrected_items=0,
skipped_items=0,
created_at=datetime.utcnow(),
)
@router.get("/sessions", response_model=List[SessionResponse])
async def list_sessions(limit: int = Query(50, ge=1, le=100)):
"""List all OCR labeling sessions."""
sessions = await get_ocr_labeling_sessions(limit=limit)
return [
SessionResponse(
id=s['id'],
name=s['name'],
source_type=s['source_type'],
description=s.get('description'),
ocr_model=s.get('ocr_model'),
total_items=s.get('total_items', 0),
labeled_items=s.get('labeled_items', 0),
confirmed_items=s.get('confirmed_items', 0),
corrected_items=s.get('corrected_items', 0),
skipped_items=s.get('skipped_items', 0),
created_at=s.get('created_at', datetime.utcnow()),
)
for s in sessions
]
@router.get("/sessions/{session_id}", response_model=SessionResponse)
async def get_session(session_id: str):
"""Get a specific OCR labeling session."""
session = await get_ocr_labeling_session(session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
return SessionResponse(
id=session['id'],
name=session['name'],
source_type=session['source_type'],
description=session.get('description'),
ocr_model=session.get('ocr_model'),
total_items=session.get('total_items', 0),
labeled_items=session.get('labeled_items', 0),
confirmed_items=session.get('confirmed_items', 0),
corrected_items=session.get('corrected_items', 0),
skipped_items=session.get('skipped_items', 0),
created_at=session.get('created_at', datetime.utcnow()),
)
# =============================================================================
# Queue and Item Endpoints
# =============================================================================
@router.get("/queue", response_model=List[ItemResponse])
async def get_labeling_queue(
session_id: Optional[str] = Query(None),
status: str = Query("pending"),
limit: int = Query(10, ge=1, le=50),
):
"""Get items from the labeling queue."""
items = await get_ocr_labeling_queue(
session_id=session_id,
status=status,
limit=limit,
)
return [
ItemResponse(
id=item['id'],
session_id=item['session_id'],
session_name=item.get('session_name', ''),
image_path=item['image_path'],
image_url=get_image_url(item['image_path']),
ocr_text=item.get('ocr_text'),
ocr_confidence=item.get('ocr_confidence'),
ground_truth=item.get('ground_truth'),
status=item.get('status', 'pending'),
metadata=item.get('metadata'),
created_at=item.get('created_at', datetime.utcnow()),
)
for item in items
]
@router.get("/items/{item_id}", response_model=ItemResponse)
async def get_item(item_id: str):
"""Get a specific labeling item."""
item = await get_ocr_labeling_item(item_id)
if not item:
raise HTTPException(status_code=404, detail="Item not found")
return ItemResponse(
id=item['id'],
session_id=item['session_id'],
session_name=item.get('session_name', ''),
image_path=item['image_path'],
image_url=get_image_url(item['image_path']),
ocr_text=item.get('ocr_text'),
ocr_confidence=item.get('ocr_confidence'),
ground_truth=item.get('ground_truth'),
status=item.get('status', 'pending'),
metadata=item.get('metadata'),
created_at=item.get('created_at', datetime.utcnow()),
)
# =============================================================================
# Labeling Action Endpoints
# =============================================================================
@router.post("/confirm")
async def confirm_item(request: ConfirmRequest):
"""Confirm that OCR text is correct."""
success = await confirm_ocr_label(
item_id=request.item_id,
labeled_by="admin",
label_time_seconds=request.label_time_seconds,
)
if not success:
raise HTTPException(status_code=400, detail="Failed to confirm item")
return {"status": "confirmed", "item_id": request.item_id}
@router.post("/correct")
async def correct_item(request: CorrectRequest):
"""Save corrected ground truth for an item."""
success = await correct_ocr_label(
item_id=request.item_id,
ground_truth=request.ground_truth,
labeled_by="admin",
label_time_seconds=request.label_time_seconds,
)
if not success:
raise HTTPException(status_code=400, detail="Failed to correct item")
return {"status": "corrected", "item_id": request.item_id}
@router.post("/skip")
async def skip_item(request: SkipRequest):
"""Skip an item (unusable image, etc.)."""
success = await skip_ocr_item(
item_id=request.item_id,
labeled_by="admin",
)
if not success:
raise HTTPException(status_code=400, detail="Failed to skip item")
return {"status": "skipped", "item_id": request.item_id}
@router.get("/stats")
async def get_stats(session_id: Optional[str] = Query(None)):
"""Get labeling statistics."""
stats = await get_ocr_labeling_stats(session_id=session_id)
if "error" in stats:
raise HTTPException(status_code=500, detail=stats["error"])
return stats
# Backward-compat shim -- module moved to ocr/labeling/routes.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.labeling.routes")

View File

@@ -1,313 +1,4 @@
"""
OCR Labeling - Upload, Run-OCR, and Export Route Handlers
Extracted from ocr_labeling_routes.py to keep files under 500 LOC.
Endpoints:
- POST /sessions/{id}/upload - Upload images for labeling
- POST /run-ocr/{item_id} - Run OCR on existing item
- POST /export - Export training data
- GET /training-samples - List training samples
- GET /images/{path} - Serve images from local storage
- GET /exports - List exports
"""
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
from typing import Optional, List
import uuid
import os
from metrics_db import (
get_ocr_labeling_session,
add_ocr_labeling_item,
get_ocr_labeling_item,
export_training_samples,
get_training_samples,
)
from ocr_labeling_models import (
ExportRequest,
LOCAL_STORAGE_PATH,
)
from ocr_labeling_helpers import (
compute_image_hash, run_ocr_on_image,
save_image_locally,
MINIO_AVAILABLE, TRAINING_EXPORT_AVAILABLE,
)
# Conditional imports
try:
from minio_storage import upload_ocr_image, get_ocr_image
except ImportError:
pass
try:
from training_export_service import TrainingSample, get_training_export_service
except ImportError:
pass
router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"])
@router.post("/sessions/{session_id}/upload")
async def upload_images(
session_id: str,
files: List[UploadFile] = File(...),
run_ocr: bool = Form(True),
metadata: Optional[str] = Form(None),
):
"""
Upload images to a labeling session.
Args:
session_id: Session to add images to
files: Image files to upload (PNG, JPG, PDF)
run_ocr: Whether to run OCR immediately (default: True)
metadata: Optional JSON metadata (subject, year, etc.)
"""
import json
session = await get_ocr_labeling_session(session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
meta_dict = None
if metadata:
try:
meta_dict = json.loads(metadata)
except json.JSONDecodeError:
meta_dict = {"raw": metadata}
results = []
ocr_model = session.get('ocr_model', 'llama3.2-vision:11b')
for file in files:
content = await file.read()
image_hash = compute_image_hash(content)
item_id = str(uuid.uuid4())
extension = file.filename.split('.')[-1].lower() if file.filename else 'png'
if extension not in ['png', 'jpg', 'jpeg', 'pdf']:
extension = 'png'
if MINIO_AVAILABLE:
try:
image_path = upload_ocr_image(session_id, item_id, content, extension)
except Exception as e:
print(f"MinIO upload failed, using local storage: {e}")
image_path = save_image_locally(session_id, item_id, content, extension)
else:
image_path = save_image_locally(session_id, item_id, content, extension)
ocr_text = None
ocr_confidence = None
if run_ocr and extension != 'pdf':
ocr_text, ocr_confidence = await run_ocr_on_image(
content,
file.filename or f"{item_id}.{extension}",
model=ocr_model
)
success = await add_ocr_labeling_item(
item_id=item_id,
session_id=session_id,
image_path=image_path,
image_hash=image_hash,
ocr_text=ocr_text,
ocr_confidence=ocr_confidence,
ocr_model=ocr_model if ocr_text else None,
metadata=meta_dict,
)
if success:
results.append({
"id": item_id,
"filename": file.filename,
"image_path": image_path,
"image_hash": image_hash,
"ocr_text": ocr_text,
"ocr_confidence": ocr_confidence,
"status": "pending",
})
return {
"session_id": session_id,
"uploaded_count": len(results),
"items": results,
}
@router.post("/export")
async def export_data(request: ExportRequest):
"""Export labeled data for training."""
db_samples = await export_training_samples(
export_format=request.export_format,
session_id=request.session_id,
batch_id=request.batch_id,
exported_by="admin",
)
if not db_samples:
return {
"export_format": request.export_format,
"batch_id": request.batch_id,
"exported_count": 0,
"samples": [],
"message": "No labeled samples found to export",
}
export_result = None
if TRAINING_EXPORT_AVAILABLE:
try:
export_service = get_training_export_service()
training_samples = []
for s in db_samples:
training_samples.append(TrainingSample(
id=s.get('id', s.get('item_id', '')),
image_path=s.get('image_path', ''),
ground_truth=s.get('ground_truth', ''),
ocr_text=s.get('ocr_text'),
ocr_confidence=s.get('ocr_confidence'),
metadata=s.get('metadata'),
))
export_result = export_service.export(
samples=training_samples,
export_format=request.export_format,
batch_id=request.batch_id,
)
except Exception as e:
print(f"Training export failed: {e}")
response = {
"export_format": request.export_format,
"batch_id": request.batch_id or (export_result.batch_id if export_result else None),
"exported_count": len(db_samples),
"samples": db_samples,
}
if export_result:
response["export_path"] = export_result.export_path
response["manifest_path"] = export_result.manifest_path
return response
@router.get("/training-samples")
async def list_training_samples(
export_format: Optional[str] = Query(None),
batch_id: Optional[str] = Query(None),
limit: int = Query(100, ge=1, le=1000),
):
"""Get exported training samples."""
samples = await get_training_samples(
export_format=export_format,
batch_id=batch_id,
limit=limit,
)
return {
"count": len(samples),
"samples": samples,
}
@router.get("/images/{path:path}")
async def get_image(path: str):
"""Serve an image from local storage."""
from fastapi.responses import FileResponse
filepath = os.path.join(LOCAL_STORAGE_PATH, path)
if not os.path.exists(filepath):
raise HTTPException(status_code=404, detail="Image not found")
extension = filepath.split('.')[-1].lower()
content_type = {
'png': 'image/png',
'jpg': 'image/jpeg',
'jpeg': 'image/jpeg',
'pdf': 'application/pdf',
}.get(extension, 'application/octet-stream')
return FileResponse(filepath, media_type=content_type)
@router.post("/run-ocr/{item_id}")
async def run_ocr_for_item(item_id: str):
"""Run OCR on an existing item."""
item = await get_ocr_labeling_item(item_id)
if not item:
raise HTTPException(status_code=404, detail="Item not found")
image_path = item['image_path']
if image_path.startswith(LOCAL_STORAGE_PATH):
if not os.path.exists(image_path):
raise HTTPException(status_code=404, detail="Image file not found")
with open(image_path, 'rb') as f:
image_data = f.read()
elif MINIO_AVAILABLE:
try:
image_data = get_ocr_image(image_path)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to load image: {e}")
else:
raise HTTPException(status_code=500, detail="Cannot load image")
session = await get_ocr_labeling_session(item['session_id'])
ocr_model = session.get('ocr_model', 'llama3.2-vision:11b') if session else 'llama3.2-vision:11b'
ocr_text, ocr_confidence = await run_ocr_on_image(
image_data,
os.path.basename(image_path),
model=ocr_model
)
if ocr_text is None:
raise HTTPException(status_code=500, detail="OCR failed")
from metrics_db import get_pool
pool = await get_pool()
if pool:
async with pool.acquire() as conn:
await conn.execute(
"""
UPDATE ocr_labeling_items
SET ocr_text = $2, ocr_confidence = $3, ocr_model = $4
WHERE id = $1
""",
item_id, ocr_text, ocr_confidence, ocr_model
)
return {
"item_id": item_id,
"ocr_text": ocr_text,
"ocr_confidence": ocr_confidence,
"ocr_model": ocr_model,
}
@router.get("/exports")
async def list_exports(export_format: Optional[str] = Query(None)):
"""List all available training data exports."""
if not TRAINING_EXPORT_AVAILABLE:
return {
"exports": [],
"message": "Training export service not available",
}
try:
export_service = get_training_export_service()
exports = export_service.list_exports(export_format=export_format)
return {
"count": len(exports),
"exports": exports,
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to list exports: {e}")
# Backward-compat shim -- module moved to ocr/labeling/upload_routes.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.labeling.upload_routes")

View File

@@ -1,272 +1,4 @@
"""
OCR Merge Helpers — functions for combining PaddleOCR/RapidOCR with Tesseract results.
Extracted from ocr_pipeline_ocr_merge.py.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
from typing import List
logger = logging.getLogger(__name__)
def _split_paddle_multi_words(words: list) -> list:
"""Split PaddleOCR multi-word boxes into individual word boxes.
PaddleOCR often returns entire phrases as a single box, e.g.
"More than 200 singers took part in the" with one bounding box.
This splits them into individual words with proportional widths.
Also handles leading "!" (e.g. "!Betonung" -> ["!", "Betonung"])
and IPA brackets (e.g. "badge[bxd3]" -> ["badge", "[bxd3]"]).
"""
import re
result = []
for w in words:
raw_text = w.get("text", "").strip()
if not raw_text:
continue
# Split on whitespace, before "[" (IPA), and after "!" before letter
tokens = re.split(
r'\s+|(?=\[)|(?<=!)(?=[A-Za-z\u00c0-\u024f])', raw_text
)
tokens = [t for t in tokens if t]
if len(tokens) <= 1:
result.append(w)
else:
# Split proportionally by character count
total_chars = sum(len(t) for t in tokens)
if total_chars == 0:
continue
n_gaps = len(tokens) - 1
gap_px = w["width"] * 0.02
usable_w = w["width"] - gap_px * n_gaps
cursor = w["left"]
for t in tokens:
token_w = max(1, usable_w * len(t) / total_chars)
result.append({
"text": t,
"left": round(cursor),
"top": w["top"],
"width": round(token_w),
"height": w["height"],
"conf": w.get("conf", 0),
})
cursor += token_w + gap_px
return result
def _group_words_into_rows(words: list, row_gap: int = 12) -> list:
"""Group words into rows by Y-position clustering.
Words whose vertical centers are within `row_gap` pixels are on the same row.
Returns list of rows, each row is a list of words sorted left-to-right.
"""
if not words:
return []
# Sort by vertical center
sorted_words = sorted(words, key=lambda w: w["top"] + w.get("height", 0) / 2)
rows: list = []
current_row: list = [sorted_words[0]]
current_cy = sorted_words[0]["top"] + sorted_words[0].get("height", 0) / 2
for w in sorted_words[1:]:
cy = w["top"] + w.get("height", 0) / 2
if abs(cy - current_cy) <= row_gap:
current_row.append(w)
else:
# Sort current row left-to-right before saving
rows.append(sorted(current_row, key=lambda w: w["left"]))
current_row = [w]
current_cy = cy
if current_row:
rows.append(sorted(current_row, key=lambda w: w["left"]))
return rows
def _row_center_y(row: list) -> float:
"""Average vertical center of a row of words."""
if not row:
return 0.0
return sum(w["top"] + w.get("height", 0) / 2 for w in row) / len(row)
def _merge_row_sequences(paddle_row: list, tess_row: list) -> list:
"""Merge two word sequences from the same row using sequence alignment.
Both sequences are sorted left-to-right. Walk through both simultaneously:
- If words match (same/similar text): take Paddle text with averaged coords
- If they don't match: the extra word is unique to one engine, include it
"""
merged = []
pi, ti = 0, 0
while pi < len(paddle_row) and ti < len(tess_row):
pw = paddle_row[pi]
tw = tess_row[ti]
pt = pw.get("text", "").lower().strip()
tt = tw.get("text", "").lower().strip()
is_same = (pt == tt) or (len(pt) > 1 and len(tt) > 1 and (pt in tt or tt in pt))
# Spatial overlap check
spatial_match = False
if not is_same:
overlap_left = max(pw["left"], tw["left"])
overlap_right = min(
pw["left"] + pw.get("width", 0),
tw["left"] + tw.get("width", 0),
)
overlap_w = max(0, overlap_right - overlap_left)
min_w = min(pw.get("width", 1), tw.get("width", 1))
if min_w > 0 and overlap_w / min_w >= 0.4:
is_same = True
spatial_match = True
if is_same:
pc = pw.get("conf", 80)
tc = tw.get("conf", 50)
total = pc + tc
if total == 0:
total = 1
if spatial_match and pc < tc:
best_text = tw["text"]
else:
best_text = pw["text"]
merged.append({
"text": best_text,
"left": round((pw["left"] * pc + tw["left"] * tc) / total),
"top": round((pw["top"] * pc + tw["top"] * tc) / total),
"width": round((pw["width"] * pc + tw["width"] * tc) / total),
"height": round((pw["height"] * pc + tw["height"] * tc) / total),
"conf": max(pc, tc),
})
pi += 1
ti += 1
else:
paddle_ahead = any(
tess_row[t].get("text", "").lower().strip() == pt
for t in range(ti + 1, min(ti + 4, len(tess_row)))
)
tess_ahead = any(
paddle_row[p].get("text", "").lower().strip() == tt
for p in range(pi + 1, min(pi + 4, len(paddle_row)))
)
if paddle_ahead and not tess_ahead:
if tw.get("conf", 0) >= 30:
merged.append(tw)
ti += 1
elif tess_ahead and not paddle_ahead:
merged.append(pw)
pi += 1
else:
if pw["left"] <= tw["left"]:
merged.append(pw)
pi += 1
else:
if tw.get("conf", 0) >= 30:
merged.append(tw)
ti += 1
while pi < len(paddle_row):
merged.append(paddle_row[pi])
pi += 1
while ti < len(tess_row):
tw = tess_row[ti]
if tw.get("conf", 0) >= 30:
merged.append(tw)
ti += 1
return merged
def _merge_paddle_tesseract(paddle_words: list, tess_words: list) -> list:
"""Merge word boxes from PaddleOCR and Tesseract using row-based sequence alignment."""
if not paddle_words and not tess_words:
return []
if not paddle_words:
return [w for w in tess_words if w.get("conf", 0) >= 40]
if not tess_words:
return list(paddle_words)
paddle_rows = _group_words_into_rows(paddle_words)
tess_rows = _group_words_into_rows(tess_words)
used_tess_rows: set = set()
merged_all: list = []
for pr in paddle_rows:
pr_cy = _row_center_y(pr)
best_dist, best_tri = float("inf"), -1
for tri, tr in enumerate(tess_rows):
if tri in used_tess_rows:
continue
tr_cy = _row_center_y(tr)
dist = abs(pr_cy - tr_cy)
if dist < best_dist:
best_dist, best_tri = dist, tri
max_row_dist = max(
max((w.get("height", 20) for w in pr), default=20),
15,
)
if best_tri >= 0 and best_dist <= max_row_dist:
tr = tess_rows[best_tri]
used_tess_rows.add(best_tri)
merged_all.extend(_merge_row_sequences(pr, tr))
else:
merged_all.extend(pr)
for tri, tr in enumerate(tess_rows):
if tri not in used_tess_rows:
for tw in tr:
if tw.get("conf", 0) >= 40:
merged_all.append(tw)
return merged_all
def _deduplicate_words(words: list) -> list:
"""Remove duplicate words with same text at overlapping positions."""
if not words:
return words
result: list = []
for w in words:
wt = w.get("text", "").lower().strip()
if not wt:
continue
is_dup = False
w_right = w["left"] + w.get("width", 0)
w_bottom = w["top"] + w.get("height", 0)
for existing in result:
et = existing.get("text", "").lower().strip()
if wt != et:
continue
ox_l = max(w["left"], existing["left"])
ox_r = min(w_right, existing["left"] + existing.get("width", 0))
ox = max(0, ox_r - ox_l)
min_w = min(w.get("width", 1), existing.get("width", 1))
if min_w <= 0 or ox / min_w < 0.5:
continue
oy_t = max(w["top"], existing["top"])
oy_b = min(w_bottom, existing["top"] + existing.get("height", 0))
oy = max(0, oy_b - oy_t)
min_h = min(w.get("height", 1), existing.get("height", 1))
if min_h > 0 and oy / min_h >= 0.5:
is_dup = True
break
if not is_dup:
result.append(w)
removed = len(words) - len(result)
if removed:
logger.info("dedup: removed %d duplicate words", removed)
return result
# Backward-compat shim -- module moved to ocr/pipeline/merge_helpers.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.merge_helpers")

View File

@@ -1,63 +1,4 @@
"""
OCR Pipeline API - Schrittweise Seitenrekonstruktion.
Thin wrapper that assembles all sub-module routers into a single
composite router. Backward-compatible: main.py and tests can still
import ``router``, ``_cache``, and helper functions from here.
Sub-modules (each < 1 000 lines):
ocr_pipeline_common shared state, cache, Pydantic models, helpers
ocr_pipeline_sessions session CRUD, image serving, doc-type
ocr_pipeline_geometry deskew, dewarp, structure, columns
ocr_pipeline_rows row detection, box-overlay helper
ocr_pipeline_words word detection (SSE), paddle-direct, word GT
ocr_pipeline_ocr_merge paddle/tesseract merge helpers, kombi endpoints
ocr_pipeline_postprocess LLM review, reconstruction, export, validation
ocr_pipeline_auto auto-mode orchestrator, reprocess
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
from fastapi import APIRouter
# ---------------------------------------------------------------------------
# Shared state (imported by main.py and orientation_crop_api.py)
# ---------------------------------------------------------------------------
from ocr_pipeline_common import ( # noqa: F401 re-exported
_cache,
_BORDER_GHOST_CHARS,
_filter_border_ghost_words,
)
# ---------------------------------------------------------------------------
# Sub-module routers
# ---------------------------------------------------------------------------
from ocr_pipeline_sessions import router as _sessions_router
from ocr_pipeline_geometry import router as _geometry_router
from ocr_pipeline_rows import router as _rows_router
from ocr_pipeline_words import router as _words_router
from ocr_pipeline_ocr_merge import (
router as _ocr_merge_router,
# Re-export for test backward compatibility
_split_paddle_multi_words, # noqa: F401
_group_words_into_rows, # noqa: F401
_merge_row_sequences, # noqa: F401
_merge_paddle_tesseract, # noqa: F401
)
from ocr_pipeline_postprocess import router as _postprocess_router
from ocr_pipeline_auto import router as _auto_router
from ocr_pipeline_regression import router as _regression_router
# ---------------------------------------------------------------------------
# Composite router (used by main.py)
# ---------------------------------------------------------------------------
router = APIRouter()
router.include_router(_sessions_router)
router.include_router(_geometry_router)
router.include_router(_rows_router)
router.include_router(_words_router)
router.include_router(_ocr_merge_router)
router.include_router(_postprocess_router)
router.include_router(_auto_router)
router.include_router(_regression_router)
# Backward-compat shim -- module moved to ocr/pipeline/api.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.api")

View File

@@ -1,23 +1,4 @@
"""
OCR Pipeline Auto-Mode Orchestrator and Reprocess Endpoints — Barrel Re-export.
Split into submodules:
- ocr_pipeline_reprocess.py — POST /sessions/{id}/reprocess
- ocr_pipeline_auto_steps.py — POST /sessions/{id}/run-auto + VLM helper
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
from fastapi import APIRouter
from ocr_pipeline_reprocess import router as _reprocess_router
from ocr_pipeline_auto_steps import router as _steps_router
# Combine both sub-routers into a single router for backwards compatibility.
# The consumer imports `from ocr_pipeline_auto import router as _auto_router`.
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
router.include_router(_reprocess_router)
router.include_router(_steps_router)
__all__ = ["router"]
# Backward-compat shim -- module moved to ocr/pipeline/auto.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.auto")

View File

@@ -1,84 +1,4 @@
"""
OCR Pipeline Auto-Mode Helpers.
VLM shear detection, SSE event formatting, and request models.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import os
import re
from typing import Any, Dict
from pydantic import BaseModel
logger = logging.getLogger(__name__)
class RunAutoRequest(BaseModel):
from_step: int = 1 # 1=deskew, 2=dewarp, 3=columns, 4=rows, 5=words, 6=llm-review
ocr_engine: str = "auto" # "auto" | "rapid" | "tesseract"
pronunciation: str = "british"
skip_llm_review: bool = False
dewarp_method: str = "ensemble" # "ensemble" | "vlm" | "cv"
async def auto_sse_event(step: str, status: str, data: Dict[str, Any]) -> str:
"""Format a single SSE event line."""
payload = {"step": step, "status": status, **data}
return f"data: {json.dumps(payload)}\n\n"
async def detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]:
"""Ask qwen2.5vl:32b to estimate the vertical shear angle of a scanned page.
The VLM is shown the image and asked: are the column/table borders tilted?
If yes, by how many degrees? Returns a dict with shear_degrees and confidence.
Confidence is 0.0 if Ollama is unavailable or parsing fails.
"""
import httpx
import base64
ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
prompt = (
"This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. "
"Are they perfectly vertical, or do they tilt slightly? "
"If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). "
"Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} "
"Use confidence 0.0-1.0 based on how clearly you can see the tilt. "
"If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}"
)
img_b64 = base64.b64encode(image_bytes).decode("utf-8")
payload = {
"model": model,
"prompt": prompt,
"images": [img_b64],
"stream": False,
}
try:
async with httpx.AsyncClient(timeout=60.0) as client:
resp = await client.post(f"{ollama_base}/api/generate", json=payload)
resp.raise_for_status()
text = resp.json().get("response", "")
# Parse JSON from response (may have surrounding text)
match = re.search(r'\{[^}]+\}', text)
if match:
data = json.loads(match.group(0))
shear = float(data.get("shear_degrees", 0.0))
conf = float(data.get("confidence", 0.0))
# Clamp to reasonable range
shear = max(-3.0, min(3.0, shear))
conf = max(0.0, min(1.0, conf))
return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)}
except Exception as e:
logger.warning(f"VLM dewarp failed: {e}")
return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0}
# Backward-compat shim -- module moved to ocr/pipeline/auto_helpers.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.auto_helpers")

View File

@@ -1,528 +1,4 @@
"""
OCR Pipeline Auto-Mode Orchestrator.
POST /sessions/{session_id}/run-auto -- full auto-mode with SSE streaming.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import time
from dataclasses import asdict
from typing import Any, Dict, List, Optional
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from cv_vocab_pipeline import (
OLLAMA_REVIEW_MODEL,
PageRegion,
RowGeometry,
_cells_to_vocab_entries,
_detect_header_footer_gaps,
_detect_sub_columns,
_fix_character_confusion,
_fix_phonetic_brackets,
fix_cell_phonetics,
analyze_layout,
build_cell_grid,
classify_column_types,
create_layout_image,
create_ocr_image,
deskew_image,
deskew_image_by_word_alignment,
detect_column_geometry,
detect_row_geometry,
_apply_shear,
dewarp_image,
llm_review_entries,
)
from ocr_pipeline_common import (
_cache,
_load_session_to_cache,
_get_cached,
)
from ocr_pipeline_session_store import (
get_session_db,
update_session_db,
)
from ocr_pipeline_auto_helpers import (
RunAutoRequest,
auto_sse_event as _auto_sse_event,
detect_shear_with_vlm as _detect_shear_with_vlm,
)
logger = logging.getLogger(__name__)
router = APIRouter(tags=["ocr-pipeline"])
@router.post("/sessions/{session_id}/run-auto")
async def run_auto(session_id: str, req: RunAutoRequest, request: Request):
"""Run the full OCR pipeline automatically from a given step, streaming SSE progress.
Steps:
1. Deskew -- straighten the scan
2. Dewarp -- correct vertical shear (ensemble CV or VLM)
3. Columns -- detect column layout
4. Rows -- detect row layout
5. Words -- OCR each cell
6. LLM review -- correct OCR errors (optional)
Already-completed steps are skipped unless `from_step` forces a rerun.
Yields SSE events of the form:
data: {"step": "deskew", "status": "start"|"done"|"skipped"|"error", ...}
Final event:
data: {"step": "complete", "status": "done", "steps_run": [...], "steps_skipped": [...]}
"""
if req.from_step < 1 or req.from_step > 6:
raise HTTPException(status_code=400, detail="from_step must be 1-6")
if req.dewarp_method not in ("ensemble", "vlm", "cv"):
raise HTTPException(status_code=400, detail="dewarp_method must be: ensemble, vlm, cv")
if session_id not in _cache:
await _load_session_to_cache(session_id)
async def _generate():
steps_run: List[str] = []
steps_skipped: List[str] = []
error_step: Optional[str] = None
session = await get_session_db(session_id)
if not session:
yield await _auto_sse_event("error", "error", {"message": f"Session {session_id} not found"})
return
cached = _get_cached(session_id)
# Step 1: Deskew
if req.from_step <= 1:
yield await _auto_sse_event("deskew", "start", {})
try:
t0 = time.time()
orig_bgr = cached.get("original_bgr")
if orig_bgr is None:
raise ValueError("Original image not loaded")
try:
deskewed_hough, angle_hough = deskew_image(orig_bgr.copy())
except Exception:
deskewed_hough, angle_hough = orig_bgr, 0.0
success_enc, png_orig = cv2.imencode(".png", orig_bgr)
orig_bytes = png_orig.tobytes() if success_enc else b""
try:
deskewed_wa_bytes, angle_wa = deskew_image_by_word_alignment(orig_bytes)
except Exception:
deskewed_wa_bytes, angle_wa = orig_bytes, 0.0
if abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1:
method_used = "word_alignment"
angle_applied = angle_wa
wa_arr = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8)
deskewed_bgr = cv2.imdecode(wa_arr, cv2.IMREAD_COLOR)
if deskewed_bgr is None:
deskewed_bgr = deskewed_hough
method_used = "hough"
angle_applied = angle_hough
else:
method_used = "hough"
angle_applied = angle_hough
deskewed_bgr = deskewed_hough
success, png_buf = cv2.imencode(".png", deskewed_bgr)
deskewed_png = png_buf.tobytes() if success else b""
deskew_result = {
"method_used": method_used,
"rotation_degrees": round(float(angle_applied), 3),
"duration_seconds": round(time.time() - t0, 2),
}
cached["deskewed_bgr"] = deskewed_bgr
cached["deskew_result"] = deskew_result
await update_session_db(
session_id,
deskewed_png=deskewed_png,
deskew_result=deskew_result,
auto_rotation_degrees=float(angle_applied),
current_step=3,
)
session = await get_session_db(session_id)
steps_run.append("deskew")
yield await _auto_sse_event("deskew", "done", deskew_result)
except Exception as e:
logger.error(f"Auto-mode deskew failed for {session_id}: {e}")
error_step = "deskew"
yield await _auto_sse_event("deskew", "error", {"message": str(e)})
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
return
else:
steps_skipped.append("deskew")
yield await _auto_sse_event("deskew", "skipped", {"reason": "from_step > 1"})
# Step 2: Dewarp
if req.from_step <= 2:
yield await _auto_sse_event("dewarp", "start", {"method": req.dewarp_method})
try:
t0 = time.time()
deskewed_bgr = cached.get("deskewed_bgr")
if deskewed_bgr is None:
raise ValueError("Deskewed image not available")
if req.dewarp_method == "vlm":
success_enc, png_buf = cv2.imencode(".png", deskewed_bgr)
img_bytes = png_buf.tobytes() if success_enc else b""
vlm_det = await _detect_shear_with_vlm(img_bytes)
shear_deg = vlm_det["shear_degrees"]
if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3:
dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg)
else:
dewarped_bgr = deskewed_bgr
dewarp_info = {
"method": vlm_det["method"],
"shear_degrees": shear_deg,
"confidence": vlm_det["confidence"],
"detections": [vlm_det],
}
else:
dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr)
success_enc, png_buf = cv2.imencode(".png", dewarped_bgr)
dewarped_png = png_buf.tobytes() if success_enc else b""
dewarp_result = {
"method_used": dewarp_info["method"],
"shear_degrees": dewarp_info["shear_degrees"],
"confidence": dewarp_info["confidence"],
"duration_seconds": round(time.time() - t0, 2),
"detections": dewarp_info.get("detections", []),
}
cached["dewarped_bgr"] = dewarped_bgr
cached["dewarp_result"] = dewarp_result
await update_session_db(
session_id,
dewarped_png=dewarped_png,
dewarp_result=dewarp_result,
auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
current_step=4,
)
session = await get_session_db(session_id)
steps_run.append("dewarp")
yield await _auto_sse_event("dewarp", "done", dewarp_result)
except Exception as e:
logger.error(f"Auto-mode dewarp failed for {session_id}: {e}")
error_step = "dewarp"
yield await _auto_sse_event("dewarp", "error", {"message": str(e)})
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
return
else:
steps_skipped.append("dewarp")
yield await _auto_sse_event("dewarp", "skipped", {"reason": "from_step > 2"})
# Step 3: Columns
if req.from_step <= 3:
yield await _auto_sse_event("columns", "start", {})
try:
t0 = time.time()
col_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
if col_img is None:
raise ValueError("Cropped/dewarped image not available")
ocr_img = create_ocr_image(col_img)
h, w = ocr_img.shape[:2]
geo_result = detect_column_geometry(ocr_img, col_img)
if geo_result is None:
layout_img = create_layout_image(col_img)
regions = analyze_layout(layout_img, ocr_img)
cached["_word_dicts"] = None
cached["_inv"] = None
cached["_content_bounds"] = None
else:
geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
content_w = right_x - left_x
cached["_word_dicts"] = word_dicts
cached["_inv"] = inv
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None)
geometries = _detect_sub_columns(geometries, content_w, left_x=left_x,
top_y=top_y, header_y=header_y, footer_y=footer_y)
regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y,
left_x=left_x, right_x=right_x, inv=inv)
columns = [asdict(r) for r in regions]
column_result = {
"columns": columns,
"classification_methods": list({c.get("classification_method", "") for c in columns if c.get("classification_method")}),
"duration_seconds": round(time.time() - t0, 2),
}
cached["column_result"] = column_result
await update_session_db(session_id, column_result=column_result,
row_result=None, word_result=None, current_step=6)
session = await get_session_db(session_id)
steps_run.append("columns")
yield await _auto_sse_event("columns", "done", {
"column_count": len(columns),
"duration_seconds": column_result["duration_seconds"],
})
except Exception as e:
logger.error(f"Auto-mode columns failed for {session_id}: {e}")
error_step = "columns"
yield await _auto_sse_event("columns", "error", {"message": str(e)})
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
return
else:
steps_skipped.append("columns")
yield await _auto_sse_event("columns", "skipped", {"reason": "from_step > 3"})
# Step 4: Rows
if req.from_step <= 4:
yield await _auto_sse_event("rows", "start", {})
try:
t0 = time.time()
row_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
session = await get_session_db(session_id)
column_result = session.get("column_result") or cached.get("column_result")
if not column_result or not column_result.get("columns"):
raise ValueError("Column detection must complete first")
col_regions = [
PageRegion(
type=c["type"], x=c["x"], y=c["y"],
width=c["width"], height=c["height"],
classification_confidence=c.get("classification_confidence", 1.0),
classification_method=c.get("classification_method", ""),
)
for c in column_result["columns"]
]
word_dicts = cached.get("_word_dicts")
inv = cached.get("_inv")
content_bounds = cached.get("_content_bounds")
if word_dicts is None or inv is None or content_bounds is None:
ocr_img_tmp = create_ocr_image(row_img)
geo_result = detect_column_geometry(ocr_img_tmp, row_img)
if geo_result is None:
raise ValueError("Column geometry detection failed -- cannot detect rows")
_g, lx, rx, ty, by, word_dicts, inv = geo_result
cached["_word_dicts"] = word_dicts
cached["_inv"] = inv
cached["_content_bounds"] = (lx, rx, ty, by)
content_bounds = (lx, rx, ty, by)
left_x, right_x, top_y, bottom_y = content_bounds
row_geoms = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
row_list = [
{
"index": r.index, "x": r.x, "y": r.y,
"width": r.width, "height": r.height,
"word_count": r.word_count,
"row_type": r.row_type,
"gap_before": r.gap_before,
}
for r in row_geoms
]
row_result = {
"rows": row_list,
"row_count": len(row_list),
"content_rows": len([r for r in row_geoms if r.row_type == "content"]),
"duration_seconds": round(time.time() - t0, 2),
}
cached["row_result"] = row_result
await update_session_db(session_id, row_result=row_result, current_step=7)
session = await get_session_db(session_id)
steps_run.append("rows")
yield await _auto_sse_event("rows", "done", {
"row_count": len(row_list),
"content_rows": row_result["content_rows"],
"duration_seconds": row_result["duration_seconds"],
})
except Exception as e:
logger.error(f"Auto-mode rows failed for {session_id}: {e}")
error_step = "rows"
yield await _auto_sse_event("rows", "error", {"message": str(e)})
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
return
else:
steps_skipped.append("rows")
yield await _auto_sse_event("rows", "skipped", {"reason": "from_step > 4"})
# Step 5: Words (OCR)
if req.from_step <= 5:
yield await _auto_sse_event("words", "start", {"engine": req.ocr_engine})
try:
t0 = time.time()
word_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
session = await get_session_db(session_id)
column_result = session.get("column_result") or cached.get("column_result")
row_result = session.get("row_result") or cached.get("row_result")
col_regions = [
PageRegion(
type=c["type"], x=c["x"], y=c["y"],
width=c["width"], height=c["height"],
classification_confidence=c.get("classification_confidence", 1.0),
classification_method=c.get("classification_method", ""),
)
for c in column_result["columns"]
]
row_geoms = [
RowGeometry(
index=r["index"], x=r["x"], y=r["y"],
width=r["width"], height=r["height"],
word_count=r.get("word_count", 0), words=[],
row_type=r.get("row_type", "content"),
gap_before=r.get("gap_before", 0),
)
for r in row_result["rows"]
]
word_dicts = cached.get("_word_dicts")
if word_dicts is not None:
content_bounds = cached.get("_content_bounds")
top_y = content_bounds[2] if content_bounds else min(r.y for r in row_geoms)
for row in row_geoms:
row_y_rel = row.y - top_y
row_bottom_rel = row_y_rel + row.height
row.words = [
w for w in word_dicts
if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
]
row.word_count = len(row.words)
ocr_img = create_ocr_image(word_img)
img_h, img_w = word_img.shape[:2]
cells, columns_meta = build_cell_grid(
ocr_img, col_regions, row_geoms, img_w, img_h,
ocr_engine=req.ocr_engine, img_bgr=word_img,
)
duration = time.time() - t0
col_types = {c['type'] for c in columns_meta}
is_vocab = bool(col_types & {'column_en', 'column_de'})
n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
used_engine = cells[0].get("ocr_engine", "tesseract") if cells else req.ocr_engine
fix_cell_phonetics(cells, pronunciation=req.pronunciation)
word_result_data = {
"cells": cells,
"grid_shape": {
"rows": n_content_rows,
"cols": len(columns_meta),
"total_cells": len(cells),
},
"columns_used": columns_meta,
"layout": "vocab" if is_vocab else "generic",
"image_width": img_w,
"image_height": img_h,
"duration_seconds": round(duration, 2),
"ocr_engine": used_engine,
"summary": {
"total_cells": len(cells),
"non_empty_cells": sum(1 for c in cells if c.get("text")),
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
},
}
has_text_col = 'column_text' in col_types
if is_vocab or has_text_col:
entries = _cells_to_vocab_entries(cells, columns_meta)
entries = _fix_character_confusion(entries)
entries = _fix_phonetic_brackets(entries, pronunciation=req.pronunciation)
word_result_data["vocab_entries"] = entries
word_result_data["entries"] = entries
word_result_data["entry_count"] = len(entries)
word_result_data["summary"]["total_entries"] = len(entries)
await update_session_db(session_id, word_result=word_result_data, current_step=8)
cached["word_result"] = word_result_data
session = await get_session_db(session_id)
steps_run.append("words")
yield await _auto_sse_event("words", "done", {
"total_cells": len(cells),
"layout": word_result_data["layout"],
"duration_seconds": round(duration, 2),
"ocr_engine": used_engine,
"summary": word_result_data["summary"],
})
except Exception as e:
logger.error(f"Auto-mode words failed for {session_id}: {e}")
error_step = "words"
yield await _auto_sse_event("words", "error", {"message": str(e)})
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
return
else:
steps_skipped.append("words")
yield await _auto_sse_event("words", "skipped", {"reason": "from_step > 5"})
# Step 6: LLM Review (optional)
if req.from_step <= 6 and not req.skip_llm_review:
yield await _auto_sse_event("llm_review", "start", {"model": OLLAMA_REVIEW_MODEL})
try:
session = await get_session_db(session_id)
word_result = session.get("word_result") or cached.get("word_result")
entries = word_result.get("entries") or word_result.get("vocab_entries") or []
if not entries:
yield await _auto_sse_event("llm_review", "skipped", {"reason": "no entries"})
steps_skipped.append("llm_review")
else:
reviewed = await llm_review_entries(entries)
session = await get_session_db(session_id)
word_result_updated = dict(session.get("word_result") or {})
word_result_updated["entries"] = reviewed
word_result_updated["vocab_entries"] = reviewed
word_result_updated["llm_reviewed"] = True
word_result_updated["llm_model"] = OLLAMA_REVIEW_MODEL
await update_session_db(session_id, word_result=word_result_updated, current_step=9)
cached["word_result"] = word_result_updated
steps_run.append("llm_review")
yield await _auto_sse_event("llm_review", "done", {
"entries_reviewed": len(reviewed),
"model": OLLAMA_REVIEW_MODEL,
})
except Exception as e:
logger.warning(f"Auto-mode llm_review failed for {session_id} (non-fatal): {e}")
yield await _auto_sse_event("llm_review", "error", {"message": str(e), "fatal": False})
steps_skipped.append("llm_review")
else:
steps_skipped.append("llm_review")
reason = "skipped by request" if req.skip_llm_review else "from_step > 6"
yield await _auto_sse_event("llm_review", "skipped", {"reason": reason})
# Final event
yield await _auto_sse_event("complete", "done", {
"steps_run": steps_run,
"steps_skipped": steps_skipped,
})
return StreamingResponse(
_generate(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
# Backward-compat shim -- module moved to ocr/pipeline/auto_steps.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.auto_steps")

View File

@@ -1,293 +1,4 @@
"""
OCR Pipeline Column Detection Endpoints (Step 5)
Detect invisible columns, manual column override, and ground truth.
Extracted from ocr_pipeline_geometry.py for file-size compliance.
"""
import logging
import time
from dataclasses import asdict
from datetime import datetime
from typing import Dict, List
import cv2
from fastapi import APIRouter, HTTPException
from cv_vocab_pipeline import (
_detect_header_footer_gaps,
_detect_sub_columns,
classify_column_types,
create_layout_image,
create_ocr_image,
analyze_layout,
detect_column_geometry_zoned,
expand_narrow_columns,
)
from ocr_pipeline_session_store import (
get_session_db,
update_session_db,
)
from ocr_pipeline_common import (
_cache,
_load_session_to_cache,
_get_cached,
_append_pipeline_log,
ManualColumnsRequest,
ColumnGroundTruthRequest,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
@router.post("/sessions/{session_id}/columns")
async def detect_columns(session_id: str):
"""Run column detection on the cropped (or dewarped) image."""
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
img_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
if img_bgr is None:
raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before column detection")
# -----------------------------------------------------------------------
# Sub-sessions (box crops): skip column detection entirely.
# Instead, create a single pseudo-column spanning the full image width.
# Also run Tesseract + binarization here so that the row detection step
# can reuse the cached intermediates (_word_dicts, _inv, _content_bounds)
# instead of falling back to detect_column_geometry() which may fail
# on small box images with < 5 words.
# -----------------------------------------------------------------------
session = await get_session_db(session_id)
if session and session.get("parent_session_id"):
h, w = img_bgr.shape[:2]
# Binarize + invert for row detection (horizontal projection profile)
ocr_img = create_ocr_image(img_bgr)
inv = cv2.bitwise_not(ocr_img)
# Run Tesseract to get word bounding boxes.
try:
from PIL import Image as PILImage
pil_img = PILImage.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
import pytesseract
data = pytesseract.image_to_data(pil_img, lang='eng+deu', output_type=pytesseract.Output.DICT)
word_dicts = []
for i in range(len(data['text'])):
conf = int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1
text = str(data['text'][i]).strip()
if conf < 30 or not text:
continue
word_dicts.append({
'text': text, 'conf': conf,
'left': int(data['left'][i]),
'top': int(data['top'][i]),
'width': int(data['width'][i]),
'height': int(data['height'][i]),
})
# Log all words including low-confidence ones for debugging
all_count = sum(1 for i in range(len(data['text']))
if str(data['text'][i]).strip())
low_conf = [(str(data['text'][i]).strip(), int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1)
for i in range(len(data['text']))
if str(data['text'][i]).strip()
and (int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1) < 30
and (int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1) >= 0]
if low_conf:
logger.info(f"OCR Pipeline: sub-session {session_id}: {len(low_conf)} words below conf 30: {low_conf[:20]}")
logger.info(f"OCR Pipeline: sub-session {session_id}: Tesseract found {len(word_dicts)}/{all_count} words (conf>=30)")
except Exception as e:
logger.warning(f"OCR Pipeline: sub-session {session_id}: Tesseract failed: {e}")
word_dicts = []
# Cache intermediates for row detection (detect_rows reuses these)
cached["_word_dicts"] = word_dicts
cached["_inv"] = inv
cached["_content_bounds"] = (0, w, 0, h)
column_result = {
"columns": [{
"type": "column_text",
"x": 0, "y": 0,
"width": w, "height": h,
}],
"zones": None,
"boxes_detected": 0,
"duration_seconds": 0,
"method": "sub_session_pseudo_column",
}
await update_session_db(
session_id,
column_result=column_result,
row_result=None,
word_result=None,
current_step=6,
)
cached["column_result"] = column_result
cached.pop("row_result", None)
cached.pop("word_result", None)
logger.info(f"OCR Pipeline: sub-session {session_id}: pseudo-column {w}x{h}px")
return {"session_id": session_id, **column_result}
t0 = time.time()
# Binarized image for layout analysis
ocr_img = create_ocr_image(img_bgr)
h, w = ocr_img.shape[:2]
# Phase A: Zone-aware geometry detection
zoned_result = detect_column_geometry_zoned(ocr_img, img_bgr)
boxes_detected = 0
if zoned_result is None:
# Fallback to projection-based layout
layout_img = create_layout_image(img_bgr)
regions = analyze_layout(layout_img, ocr_img)
zones_data = None
else:
geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv, zones_data, boxes = zoned_result
content_w = right_x - left_x
boxes_detected = len(boxes)
# Cache intermediates for row detection (avoids second Tesseract run)
cached["_word_dicts"] = word_dicts
cached["_inv"] = inv
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
cached["_zones_data"] = zones_data
cached["_boxes_detected"] = boxes_detected
# Detect header/footer early so sub-column clustering ignores them
header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None)
# Split sub-columns (e.g. page references) before classification
geometries = _detect_sub_columns(geometries, content_w, left_x=left_x,
top_y=top_y, header_y=header_y, footer_y=footer_y)
# Expand narrow columns (sub-columns are often very narrow)
geometries = expand_narrow_columns(geometries, content_w, left_x, word_dicts)
# Phase B: Content-based classification
regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y,
left_x=left_x, right_x=right_x, inv=inv)
duration = time.time() - t0
columns = [asdict(r) for r in regions]
# Determine classification methods used
methods = list(set(
c.get("classification_method", "") for c in columns
if c.get("classification_method")
))
column_result = {
"columns": columns,
"classification_methods": methods,
"duration_seconds": round(duration, 2),
"boxes_detected": boxes_detected,
}
# Add zone data when boxes are present
if zones_data and boxes_detected > 0:
column_result["zones"] = zones_data
# Persist to DB -- also invalidate downstream results (rows, words)
await update_session_db(
session_id,
column_result=column_result,
row_result=None,
word_result=None,
current_step=6,
)
# Update cache
cached["column_result"] = column_result
cached.pop("row_result", None)
cached.pop("word_result", None)
col_count = len([c for c in columns if c["type"].startswith("column")])
logger.info(f"OCR Pipeline: columns session {session_id}: "
f"{col_count} columns detected, {boxes_detected} box(es) ({duration:.2f}s)")
img_w = img_bgr.shape[1]
await _append_pipeline_log(session_id, "columns", {
"total_columns": len(columns),
"column_widths_pct": [round(c["width"] / img_w * 100, 1) for c in columns],
"column_types": [c["type"] for c in columns],
"boxes_detected": boxes_detected,
}, duration_ms=int(duration * 1000))
return {
"session_id": session_id,
**column_result,
}
@router.post("/sessions/{session_id}/columns/manual")
async def set_manual_columns(session_id: str, req: ManualColumnsRequest):
"""Override detected columns with manual definitions."""
column_result = {
"columns": req.columns,
"duration_seconds": 0,
"method": "manual",
}
await update_session_db(session_id, column_result=column_result,
row_result=None, word_result=None)
if session_id in _cache:
_cache[session_id]["column_result"] = column_result
_cache[session_id].pop("row_result", None)
_cache[session_id].pop("word_result", None)
logger.info(f"OCR Pipeline: manual columns session {session_id}: "
f"{len(req.columns)} columns set")
return {"session_id": session_id, **column_result}
@router.post("/sessions/{session_id}/ground-truth/columns")
async def save_column_ground_truth(session_id: str, req: ColumnGroundTruthRequest):
"""Save ground truth feedback for the column detection step."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
gt = {
"is_correct": req.is_correct,
"corrected_columns": req.corrected_columns,
"notes": req.notes,
"saved_at": datetime.utcnow().isoformat(),
"column_result": session.get("column_result"),
}
ground_truth["columns"] = gt
await update_session_db(session_id, ground_truth=ground_truth)
if session_id in _cache:
_cache[session_id]["ground_truth"] = ground_truth
return {"session_id": session_id, "ground_truth": gt}
@router.get("/sessions/{session_id}/ground-truth/columns")
async def get_column_ground_truth(session_id: str):
"""Retrieve saved ground truth for column detection, including auto vs GT diff."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
columns_gt = ground_truth.get("columns")
if not columns_gt:
raise HTTPException(status_code=404, detail="No column ground truth saved")
return {
"session_id": session_id,
"columns_gt": columns_gt,
"columns_auto": session.get("column_result"),
}
# Backward-compat shim -- module moved to ocr/pipeline/columns.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.columns")

View File

@@ -1,354 +1,4 @@
"""
Shared common module for the OCR pipeline.
Contains in-memory cache, helper functions, Pydantic request models,
pipeline logging, and border-ghost word filtering used by the pipeline
API endpoints and related modules.
"""
import logging
import re
import time
from datetime import datetime
from typing import Any, Dict, List, Optional
import cv2
import numpy as np
from fastapi import HTTPException
from pydantic import BaseModel
from ocr_pipeline_session_store import get_session_db, get_session_image, update_session_db
__all__ = [
# Cache
"_cache",
# Helper functions
"_get_base_image_png",
"_load_session_to_cache",
"_get_cached",
# Pydantic models
"ManualDeskewRequest",
"DeskewGroundTruthRequest",
"ManualDewarpRequest",
"CombinedAdjustRequest",
"DewarpGroundTruthRequest",
"VALID_DOCUMENT_CATEGORIES",
"UpdateSessionRequest",
"ManualColumnsRequest",
"ColumnGroundTruthRequest",
"ManualRowsRequest",
"RowGroundTruthRequest",
"RemoveHandwritingRequest",
# Pipeline log
"_append_pipeline_log",
# Border-ghost filter
"_BORDER_GHOST_CHARS",
"_filter_border_ghost_words",
]
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# In-memory cache for active sessions (BGR numpy arrays for processing)
# DB is source of truth, cache holds BGR arrays during active processing.
# ---------------------------------------------------------------------------
_cache: Dict[str, Dict[str, Any]] = {}
async def _get_base_image_png(session_id: str) -> Optional[bytes]:
"""Get the best available base image for a session (cropped > dewarped > original)."""
for img_type in ("cropped", "dewarped", "original"):
png_data = await get_session_image(session_id, img_type)
if png_data:
return png_data
return None
async def _load_session_to_cache(session_id: str) -> Dict[str, Any]:
"""Load session from DB into cache, decoding PNGs to BGR arrays."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
if session_id in _cache:
return _cache[session_id]
cache_entry: Dict[str, Any] = {
"id": session_id,
**session,
"original_bgr": None,
"oriented_bgr": None,
"cropped_bgr": None,
"deskewed_bgr": None,
"dewarped_bgr": None,
}
# Decode images from DB into BGR numpy arrays
for img_type, bgr_key in [
("original", "original_bgr"),
("oriented", "oriented_bgr"),
("cropped", "cropped_bgr"),
("deskewed", "deskewed_bgr"),
("dewarped", "dewarped_bgr"),
]:
png_data = await get_session_image(session_id, img_type)
if png_data:
arr = np.frombuffer(png_data, dtype=np.uint8)
bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR)
cache_entry[bgr_key] = bgr
# Sub-sessions: original image IS the cropped box region.
# Promote original_bgr to cropped_bgr so downstream steps find it.
if session.get("parent_session_id") and cache_entry["original_bgr"] is not None:
if cache_entry["cropped_bgr"] is None and cache_entry["dewarped_bgr"] is None:
cache_entry["cropped_bgr"] = cache_entry["original_bgr"]
_cache[session_id] = cache_entry
return cache_entry
def _get_cached(session_id: str) -> Dict[str, Any]:
"""Get from cache or raise 404."""
entry = _cache.get(session_id)
if not entry:
raise HTTPException(status_code=404, detail=f"Session {session_id} not in cache — reload first")
return entry
# ---------------------------------------------------------------------------
# Pydantic Models
# ---------------------------------------------------------------------------
class ManualDeskewRequest(BaseModel):
angle: float
class DeskewGroundTruthRequest(BaseModel):
is_correct: bool
corrected_angle: Optional[float] = None
notes: Optional[str] = None
class ManualDewarpRequest(BaseModel):
shear_degrees: float
class CombinedAdjustRequest(BaseModel):
rotation_degrees: float = 0.0
shear_degrees: float = 0.0
class DewarpGroundTruthRequest(BaseModel):
is_correct: bool
corrected_shear: Optional[float] = None
notes: Optional[str] = None
VALID_DOCUMENT_CATEGORIES = {
'vokabelseite', 'woerterbuch', 'buchseite', 'arbeitsblatt', 'klausurseite',
'mathearbeit', 'statistik', 'zeitung', 'formular', 'handschrift', 'sonstiges',
}
class UpdateSessionRequest(BaseModel):
name: Optional[str] = None
document_category: Optional[str] = None
class ManualColumnsRequest(BaseModel):
columns: List[Dict[str, Any]]
class ColumnGroundTruthRequest(BaseModel):
is_correct: bool
corrected_columns: Optional[List[Dict[str, Any]]] = None
notes: Optional[str] = None
class ManualRowsRequest(BaseModel):
rows: List[Dict[str, Any]]
class RowGroundTruthRequest(BaseModel):
is_correct: bool
corrected_rows: Optional[List[Dict[str, Any]]] = None
notes: Optional[str] = None
class RemoveHandwritingRequest(BaseModel):
method: str = "auto" # "auto" | "telea" | "ns"
target_ink: str = "all" # "all" | "colored" | "pencil"
dilation: int = 2 # mask dilation iterations (0-5)
use_source: str = "auto" # "original" | "deskewed" | "auto"
# ---------------------------------------------------------------------------
# Pipeline Log Helper
# ---------------------------------------------------------------------------
async def _append_pipeline_log(
session_id: str,
step_name: str,
metrics: Dict[str, Any],
success: bool = True,
duration_ms: Optional[int] = None,
):
"""Append a step entry to the session's pipeline_log JSONB."""
session = await get_session_db(session_id)
if not session:
return
log = session.get("pipeline_log") or {"steps": []}
if not isinstance(log, dict):
log = {"steps": []}
entry = {
"step": step_name,
"completed_at": datetime.utcnow().isoformat(),
"success": success,
"metrics": metrics,
}
if duration_ms is not None:
entry["duration_ms"] = duration_ms
log.setdefault("steps", []).append(entry)
await update_session_db(session_id, pipeline_log=log)
# ---------------------------------------------------------------------------
# Border-ghost word filter
# ---------------------------------------------------------------------------
# Characters that OCR produces when reading box-border lines.
_BORDER_GHOST_CHARS = set("|1lI![](){}iíì/\\-—_~.,;:'\"")
def _filter_border_ghost_words(
word_result: Dict,
boxes: List,
) -> int:
"""Remove OCR words that are actually box border lines.
A word is considered a border ghost when it sits on a known box edge
(left, right, top, or bottom) and looks like a line artefact (narrow
aspect ratio or text consists only of line-like characters).
After removing ghost cells, columns that have become empty are also
removed from ``columns_used`` so the grid no longer shows phantom
columns.
Modifies *word_result* in-place and returns the number of removed cells.
"""
if not boxes or not word_result:
return 0
cells = word_result.get("cells")
if not cells:
return 0
# Build border bands — vertical (X) and horizontal (Y)
x_bands = [] # list of (x_lo, x_hi)
y_bands = [] # list of (y_lo, y_hi)
for b in boxes:
bx = b.x if hasattr(b, "x") else b.get("x", 0)
by = b.y if hasattr(b, "y") else b.get("y", 0)
bw = b.width if hasattr(b, "width") else b.get("w", b.get("width", 0))
bh = b.height if hasattr(b, "height") else b.get("h", b.get("height", 0))
bt = b.border_thickness if hasattr(b, "border_thickness") else b.get("border_thickness", 3)
margin = max(bt * 2, 10) + 6 # generous margin
# Vertical edges (left / right)
x_bands.append((bx - margin, bx + margin))
x_bands.append((bx + bw - margin, bx + bw + margin))
# Horizontal edges (top / bottom)
y_bands.append((by - margin, by + margin))
y_bands.append((by + bh - margin, by + bh + margin))
img_w = word_result.get("image_width", 1)
img_h = word_result.get("image_height", 1)
def _is_ghost(cell: Dict) -> bool:
text = (cell.get("text") or "").strip()
if not text:
return False
# Compute absolute pixel position
if cell.get("bbox_px"):
px = cell["bbox_px"]
cx = px["x"] + px["w"] / 2
cy = px["y"] + px["h"] / 2
cw = px["w"]
ch = px["h"]
elif cell.get("bbox_pct"):
pct = cell["bbox_pct"]
cx = (pct["x"] / 100) * img_w + (pct["w"] / 100) * img_w / 2
cy = (pct["y"] / 100) * img_h + (pct["h"] / 100) * img_h / 2
cw = (pct["w"] / 100) * img_w
ch = (pct["h"] / 100) * img_h
else:
return False
# Check if center sits on a vertical or horizontal border
on_vertical = any(lo <= cx <= hi for lo, hi in x_bands)
on_horizontal = any(lo <= cy <= hi for lo, hi in y_bands)
if not on_vertical and not on_horizontal:
return False
# Very short text (1-2 chars) on a border → very likely ghost
if len(text) <= 2:
# Narrow vertically (line-like) or narrow horizontally (dash-like)?
if ch > 0 and cw / ch < 0.5:
return True
if cw > 0 and ch / cw < 0.5:
return True
# Text is only border-ghost characters?
if all(c in _BORDER_GHOST_CHARS for c in text):
return True
# Longer text but still only ghost chars and very narrow
if all(c in _BORDER_GHOST_CHARS for c in text):
if ch > 0 and cw / ch < 0.35:
return True
if cw > 0 and ch / cw < 0.35:
return True
return True # all ghost chars on a border → remove
return False
before = len(cells)
word_result["cells"] = [c for c in cells if not _is_ghost(c)]
removed = before - len(word_result["cells"])
# --- Remove empty columns from columns_used ---
columns_used = word_result.get("columns_used")
if removed and columns_used and len(columns_used) > 1:
remaining_cells = word_result["cells"]
occupied_cols = {c.get("col_index") for c in remaining_cells}
before_cols = len(columns_used)
columns_used = [col for col in columns_used if col.get("index") in occupied_cols]
# Re-index columns and remap cell col_index values
if len(columns_used) < before_cols:
old_to_new = {}
for new_i, col in enumerate(columns_used):
old_to_new[col["index"]] = new_i
col["index"] = new_i
for cell in remaining_cells:
old_ci = cell.get("col_index")
if old_ci in old_to_new:
cell["col_index"] = old_to_new[old_ci]
word_result["columns_used"] = columns_used
logger.info("border-ghost: removed %d empty column(s), %d remaining",
before_cols - len(columns_used), len(columns_used))
if removed:
# Update summary counts
summary = word_result.get("summary", {})
summary["total_cells"] = len(word_result["cells"])
summary["non_empty_cells"] = sum(1 for c in word_result["cells"] if c.get("text"))
word_result["summary"] = summary
gs = word_result.get("grid_shape", {})
gs["total_cells"] = len(word_result["cells"])
if columns_used is not None:
gs["cols"] = len(columns_used)
word_result["grid_shape"] = gs
return removed
# Backward-compat shim -- module moved to ocr/pipeline/common.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.common")

View File

@@ -1,236 +1,4 @@
"""
OCR Pipeline Deskew Endpoints (Step 2)
Auto deskew, manual deskew, and ground truth for the deskew step.
Extracted from ocr_pipeline_geometry.py for file-size compliance.
"""
import logging
import time
from datetime import datetime
import cv2
from fastapi import APIRouter, HTTPException
from cv_vocab_pipeline import (
create_ocr_image,
deskew_image,
deskew_image_by_word_alignment,
deskew_two_pass,
)
from ocr_pipeline_session_store import (
get_session_db,
update_session_db,
)
from ocr_pipeline_common import (
_cache,
_load_session_to_cache,
_get_cached,
_append_pipeline_log,
ManualDeskewRequest,
DeskewGroundTruthRequest,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
@router.post("/sessions/{session_id}/deskew")
async def auto_deskew(session_id: str):
"""Two-pass deskew: iterative projection (wide range) + word-alignment residual."""
# Ensure session is in cache
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
# Deskew runs right after orientation -- use oriented image, fall back to original
img_bgr = next((v for k in ("oriented_bgr", "original_bgr")
if (v := cached.get(k)) is not None), None)
if img_bgr is None:
raise HTTPException(status_code=400, detail="No image available for deskewing")
t0 = time.time()
# Two-pass deskew: iterative (+-5 deg) + word-alignment residual check
deskewed_bgr, angle_applied, two_pass_debug = deskew_two_pass(img_bgr.copy())
# Also run individual methods for reporting (non-authoritative)
try:
_, angle_hough = deskew_image(img_bgr.copy())
except Exception:
angle_hough = 0.0
success_enc, png_orig = cv2.imencode(".png", img_bgr)
orig_bytes = png_orig.tobytes() if success_enc else b""
try:
_, angle_wa = deskew_image_by_word_alignment(orig_bytes)
except Exception:
angle_wa = 0.0
angle_iterative = two_pass_debug.get("pass1_angle", 0.0)
angle_residual = two_pass_debug.get("pass2_angle", 0.0)
angle_textline = two_pass_debug.get("pass3_angle", 0.0)
duration = time.time() - t0
method_used = "three_pass" if abs(angle_textline) >= 0.01 else (
"two_pass" if abs(angle_residual) >= 0.01 else "iterative"
)
# Encode as PNG
success, deskewed_png_buf = cv2.imencode(".png", deskewed_bgr)
deskewed_png = deskewed_png_buf.tobytes() if success else b""
# Create binarized version
binarized_png = None
try:
binarized = create_ocr_image(deskewed_bgr)
success_bin, bin_buf = cv2.imencode(".png", binarized)
binarized_png = bin_buf.tobytes() if success_bin else None
except Exception as e:
logger.warning(f"Binarization failed: {e}")
confidence = max(0.5, 1.0 - abs(angle_applied) / 5.0)
deskew_result = {
"angle_hough": round(angle_hough, 3),
"angle_word_alignment": round(angle_wa, 3),
"angle_iterative": round(angle_iterative, 3),
"angle_residual": round(angle_residual, 3),
"angle_textline": round(angle_textline, 3),
"angle_applied": round(angle_applied, 3),
"method_used": method_used,
"confidence": round(confidence, 2),
"duration_seconds": round(duration, 2),
"two_pass_debug": two_pass_debug,
}
# Update cache
cached["deskewed_bgr"] = deskewed_bgr
cached["binarized_png"] = binarized_png
cached["deskew_result"] = deskew_result
# Persist to DB
db_update = {
"deskewed_png": deskewed_png,
"deskew_result": deskew_result,
"current_step": 3,
}
if binarized_png:
db_update["binarized_png"] = binarized_png
await update_session_db(session_id, **db_update)
logger.info(f"OCR Pipeline: deskew session {session_id}: "
f"hough={angle_hough:.2f} wa={angle_wa:.2f} "
f"iter={angle_iterative:.2f} residual={angle_residual:.2f} "
f"textline={angle_textline:.2f} "
f"-> {method_used} total={angle_applied:.2f}")
await _append_pipeline_log(session_id, "deskew", {
"angle_applied": round(angle_applied, 3),
"angle_iterative": round(angle_iterative, 3),
"angle_residual": round(angle_residual, 3),
"angle_textline": round(angle_textline, 3),
"confidence": round(confidence, 2),
"method": method_used,
}, duration_ms=int(duration * 1000))
return {
"session_id": session_id,
**deskew_result,
"deskewed_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed",
"binarized_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/binarized",
}
@router.post("/sessions/{session_id}/deskew/manual")
async def manual_deskew(session_id: str, req: ManualDeskewRequest):
"""Apply a manual rotation angle to the oriented image."""
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
img_bgr = next((v for k in ("oriented_bgr", "original_bgr")
if (v := cached.get(k)) is not None), None)
if img_bgr is None:
raise HTTPException(status_code=400, detail="No image available for deskewing")
angle = max(-5.0, min(5.0, req.angle))
h, w = img_bgr.shape[:2]
center = (w // 2, h // 2)
M = cv2.getRotationMatrix2D(center, angle, 1.0)
rotated = cv2.warpAffine(img_bgr, M, (w, h),
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_REPLICATE)
success, png_buf = cv2.imencode(".png", rotated)
deskewed_png = png_buf.tobytes() if success else b""
# Binarize
binarized_png = None
try:
binarized = create_ocr_image(rotated)
success_bin, bin_buf = cv2.imencode(".png", binarized)
binarized_png = bin_buf.tobytes() if success_bin else None
except Exception:
pass
deskew_result = {
**(cached.get("deskew_result") or {}),
"angle_applied": round(angle, 3),
"method_used": "manual",
}
# Update cache
cached["deskewed_bgr"] = rotated
cached["binarized_png"] = binarized_png
cached["deskew_result"] = deskew_result
# Persist to DB
db_update = {
"deskewed_png": deskewed_png,
"deskew_result": deskew_result,
}
if binarized_png:
db_update["binarized_png"] = binarized_png
await update_session_db(session_id, **db_update)
logger.info(f"OCR Pipeline: manual deskew session {session_id}: {angle:.2f}")
return {
"session_id": session_id,
"angle_applied": round(angle, 3),
"method_used": "manual",
"deskewed_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed",
}
@router.post("/sessions/{session_id}/ground-truth/deskew")
async def save_deskew_ground_truth(session_id: str, req: DeskewGroundTruthRequest):
"""Save ground truth feedback for the deskew step."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
gt = {
"is_correct": req.is_correct,
"corrected_angle": req.corrected_angle,
"notes": req.notes,
"saved_at": datetime.utcnow().isoformat(),
"deskew_result": session.get("deskew_result"),
}
ground_truth["deskew"] = gt
await update_session_db(session_id, ground_truth=ground_truth)
# Update cache
if session_id in _cache:
_cache[session_id]["ground_truth"] = ground_truth
logger.info(f"OCR Pipeline: ground truth deskew session {session_id}: "
f"correct={req.is_correct}, corrected_angle={req.corrected_angle}")
return {"session_id": session_id, "ground_truth": gt}
# Backward-compat shim -- module moved to ocr/pipeline/deskew.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.deskew")

View File

@@ -1,346 +1,4 @@
"""
OCR Pipeline Dewarp Endpoints
Auto dewarp (with VLM/CV ensemble), manual dewarp, combined
rotation+shear adjustment, and ground truth.
Extracted from ocr_pipeline_geometry.py for file-size compliance.
"""
import json
import logging
import os
import re
import time
from datetime import datetime
from typing import Any, Dict
import cv2
from fastapi import APIRouter, HTTPException, Query
from cv_vocab_pipeline import (
_apply_shear,
create_ocr_image,
dewarp_image,
dewarp_image_manual,
)
from ocr_pipeline_session_store import (
get_session_db,
update_session_db,
)
from ocr_pipeline_common import (
_cache,
_load_session_to_cache,
_get_cached,
_append_pipeline_log,
ManualDewarpRequest,
CombinedAdjustRequest,
DewarpGroundTruthRequest,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
async def _detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]:
"""Ask qwen2.5vl:32b to estimate the vertical shear angle of a scanned page.
The VLM is shown the image and asked: are the column/table borders tilted?
If yes, by how many degrees? Returns a dict with shear_degrees and confidence.
Confidence is 0.0 if Ollama is unavailable or parsing fails.
"""
import httpx
import base64
ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
prompt = (
"This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. "
"Are they perfectly vertical, or do they tilt slightly? "
"If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). "
"Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} "
"Use confidence 0.0-1.0 based on how clearly you can see the tilt. "
"If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}"
)
img_b64 = base64.b64encode(image_bytes).decode("utf-8")
payload = {
"model": model,
"prompt": prompt,
"images": [img_b64],
"stream": False,
}
try:
async with httpx.AsyncClient(timeout=60.0) as client:
resp = await client.post(f"{ollama_base}/api/generate", json=payload)
resp.raise_for_status()
text = resp.json().get("response", "")
# Parse JSON from response (may have surrounding text)
match = re.search(r'\{[^}]+\}', text)
if match:
data = json.loads(match.group(0))
shear = float(data.get("shear_degrees", 0.0))
conf = float(data.get("confidence", 0.0))
# Clamp to reasonable range
shear = max(-3.0, min(3.0, shear))
conf = max(0.0, min(1.0, conf))
return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)}
except Exception as e:
logger.warning(f"VLM dewarp failed: {e}")
return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0}
@router.post("/sessions/{session_id}/dewarp")
async def auto_dewarp(
session_id: str,
method: str = Query("ensemble", description="Detection method: ensemble | vlm | cv"),
):
"""Detect and correct vertical shear on the deskewed image.
Methods:
- **ensemble** (default): 3-method CV ensemble (vertical edges + projection + Hough)
- **cv**: CV ensemble only (same as ensemble)
- **vlm**: Ask qwen2.5vl:32b to estimate the shear angle visually
"""
if method not in ("ensemble", "cv", "vlm"):
raise HTTPException(status_code=400, detail="method must be one of: ensemble, cv, vlm")
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
deskewed_bgr = cached.get("deskewed_bgr")
if deskewed_bgr is None:
raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")
t0 = time.time()
if method == "vlm":
# Encode deskewed image to PNG for VLM
success, png_buf = cv2.imencode(".png", deskewed_bgr)
img_bytes = png_buf.tobytes() if success else b""
vlm_det = await _detect_shear_with_vlm(img_bytes)
shear_deg = vlm_det["shear_degrees"]
if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3:
dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg)
else:
dewarped_bgr = deskewed_bgr
dewarp_info = {
"method": vlm_det["method"],
"shear_degrees": shear_deg,
"confidence": vlm_det["confidence"],
"detections": [vlm_det],
}
else:
dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr)
duration = time.time() - t0
# Encode as PNG
success, png_buf = cv2.imencode(".png", dewarped_bgr)
dewarped_png = png_buf.tobytes() if success else b""
dewarp_result = {
"method_used": dewarp_info["method"],
"shear_degrees": dewarp_info["shear_degrees"],
"confidence": dewarp_info["confidence"],
"duration_seconds": round(duration, 2),
"detections": dewarp_info.get("detections", []),
}
# Update cache
cached["dewarped_bgr"] = dewarped_bgr
cached["dewarp_result"] = dewarp_result
# Persist to DB
await update_session_db(
session_id,
dewarped_png=dewarped_png,
dewarp_result=dewarp_result,
auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
current_step=4,
)
logger.info(f"OCR Pipeline: dewarp session {session_id}: "
f"method={dewarp_info['method']} shear={dewarp_info['shear_degrees']:.3f} "
f"conf={dewarp_info['confidence']:.2f} ({duration:.2f}s)")
await _append_pipeline_log(session_id, "dewarp", {
"shear_degrees": dewarp_info["shear_degrees"],
"confidence": dewarp_info["confidence"],
"method": dewarp_info["method"],
"ensemble_methods": [d.get("method", "") for d in dewarp_info.get("detections", [])],
}, duration_ms=int(duration * 1000))
return {
"session_id": session_id,
**dewarp_result,
"dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
}
@router.post("/sessions/{session_id}/dewarp/manual")
async def manual_dewarp(session_id: str, req: ManualDewarpRequest):
"""Apply shear correction with a manual angle."""
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
deskewed_bgr = cached.get("deskewed_bgr")
if deskewed_bgr is None:
raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")
shear_deg = max(-2.0, min(2.0, req.shear_degrees))
if abs(shear_deg) < 0.001:
dewarped_bgr = deskewed_bgr
else:
dewarped_bgr = dewarp_image_manual(deskewed_bgr, shear_deg)
success, png_buf = cv2.imencode(".png", dewarped_bgr)
dewarped_png = png_buf.tobytes() if success else b""
dewarp_result = {
**(cached.get("dewarp_result") or {}),
"method_used": "manual",
"shear_degrees": round(shear_deg, 3),
}
# Update cache
cached["dewarped_bgr"] = dewarped_bgr
cached["dewarp_result"] = dewarp_result
# Persist to DB
await update_session_db(
session_id,
dewarped_png=dewarped_png,
dewarp_result=dewarp_result,
)
logger.info(f"OCR Pipeline: manual dewarp session {session_id}: shear={shear_deg:.3f}")
return {
"session_id": session_id,
"shear_degrees": round(shear_deg, 3),
"method_used": "manual",
"dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
}
@router.post("/sessions/{session_id}/adjust-combined")
async def adjust_combined(session_id: str, req: CombinedAdjustRequest):
"""Apply rotation + shear combined to the original image.
Used by the fine-tuning sliders to preview arbitrary rotation/shear
combinations without re-running the full deskew/dewarp pipeline.
"""
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
img_bgr = cached.get("original_bgr")
if img_bgr is None:
raise HTTPException(status_code=400, detail="Original image not available")
rotation = max(-15.0, min(15.0, req.rotation_degrees))
shear_deg = max(-5.0, min(5.0, req.shear_degrees))
h, w = img_bgr.shape[:2]
result_bgr = img_bgr
# Step 1: Apply rotation
if abs(rotation) >= 0.001:
center = (w // 2, h // 2)
M = cv2.getRotationMatrix2D(center, rotation, 1.0)
result_bgr = cv2.warpAffine(result_bgr, M, (w, h),
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_REPLICATE)
# Step 2: Apply shear
if abs(shear_deg) >= 0.001:
result_bgr = dewarp_image_manual(result_bgr, shear_deg)
# Encode
success, png_buf = cv2.imencode(".png", result_bgr)
dewarped_png = png_buf.tobytes() if success else b""
# Binarize
binarized_png = None
try:
binarized = create_ocr_image(result_bgr)
success_bin, bin_buf = cv2.imencode(".png", binarized)
binarized_png = bin_buf.tobytes() if success_bin else None
except Exception:
pass
# Build combined result dicts
deskew_result = {
**(cached.get("deskew_result") or {}),
"angle_applied": round(rotation, 3),
"method_used": "manual_combined",
}
dewarp_result = {
**(cached.get("dewarp_result") or {}),
"method_used": "manual_combined",
"shear_degrees": round(shear_deg, 3),
}
# Update cache
cached["deskewed_bgr"] = result_bgr
cached["dewarped_bgr"] = result_bgr
cached["deskew_result"] = deskew_result
cached["dewarp_result"] = dewarp_result
# Persist to DB
db_update = {
"dewarped_png": dewarped_png,
"deskew_result": deskew_result,
"dewarp_result": dewarp_result,
}
if binarized_png:
db_update["binarized_png"] = binarized_png
db_update["deskewed_png"] = dewarped_png
await update_session_db(session_id, **db_update)
logger.info(f"OCR Pipeline: combined adjust session {session_id}: "
f"rotation={rotation:.3f} shear={shear_deg:.3f}")
return {
"session_id": session_id,
"rotation_degrees": round(rotation, 3),
"shear_degrees": round(shear_deg, 3),
"method_used": "manual_combined",
"dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
}
@router.post("/sessions/{session_id}/ground-truth/dewarp")
async def save_dewarp_ground_truth(session_id: str, req: DewarpGroundTruthRequest):
"""Save ground truth feedback for the dewarp step."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
gt = {
"is_correct": req.is_correct,
"corrected_shear": req.corrected_shear,
"notes": req.notes,
"saved_at": datetime.utcnow().isoformat(),
"dewarp_result": session.get("dewarp_result"),
}
ground_truth["dewarp"] = gt
await update_session_db(session_id, ground_truth=ground_truth)
if session_id in _cache:
_cache[session_id]["ground_truth"] = ground_truth
logger.info(f"OCR Pipeline: ground truth dewarp session {session_id}: "
f"correct={req.is_correct}, corrected_shear={req.corrected_shear}")
return {"session_id": session_id, "ground_truth": gt}
# Backward-compat shim -- module moved to ocr/pipeline/dewarp.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.dewarp")

View File

@@ -1,27 +1,4 @@
"""
OCR Pipeline Geometry API (barrel re-export)
This module was split into:
- ocr_pipeline_deskew.py (Deskew endpoints)
- ocr_pipeline_dewarp.py (Dewarp endpoints)
- ocr_pipeline_structure.py (Structure detection + exclude regions)
- ocr_pipeline_columns.py (Column detection + ground truth)
The `router` object is assembled here by including all sub-routers.
Importers that did `from ocr_pipeline_geometry import router` continue to work.
"""
from fastapi import APIRouter
from ocr_pipeline_deskew import router as _deskew_router
from ocr_pipeline_dewarp import router as _dewarp_router
from ocr_pipeline_structure import router as _structure_router
from ocr_pipeline_columns import router as _columns_router
# Assemble the combined router.
# All sub-routers use prefix="/api/v1/ocr-pipeline", so include without extra prefix.
router = APIRouter()
router.include_router(_deskew_router)
router.include_router(_dewarp_router)
router.include_router(_structure_router)
router.include_router(_columns_router)
# Backward-compat shim -- module moved to ocr/pipeline/geometry.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.geometry")

View File

@@ -1,209 +1,4 @@
"""
OCR Pipeline LLM Review — LLM-based correction endpoints.
Extracted from ocr_pipeline_postprocess.py.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
from datetime import datetime
from typing import Dict, List
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from cv_vocab_pipeline import (
OLLAMA_REVIEW_MODEL,
llm_review_entries,
llm_review_entries_streaming,
)
from ocr_pipeline_session_store import (
get_session_db,
update_session_db,
)
from ocr_pipeline_common import (
_cache,
_append_pipeline_log,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Step 8: LLM Review
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/llm-review")
async def run_llm_review(session_id: str, request: Request, stream: bool = False):
"""Run LLM-based correction on vocab entries from Step 5.
Query params:
stream: false (default) for JSON response, true for SSE streaming
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
word_result = session.get("word_result")
if not word_result:
raise HTTPException(status_code=400, detail="No word result found — run Step 5 first")
entries = word_result.get("vocab_entries") or word_result.get("entries") or []
if not entries:
raise HTTPException(status_code=400, detail="No vocab entries found — run Step 5 first")
# Optional model override from request body
body = {}
try:
body = await request.json()
except Exception:
pass
model = body.get("model") or OLLAMA_REVIEW_MODEL
if stream:
return StreamingResponse(
_llm_review_stream_generator(session_id, entries, word_result, model, request),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
)
# Non-streaming path
try:
result = await llm_review_entries(entries, model=model)
except Exception as e:
import traceback
logger.error(f"LLM review failed for session {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}")
raise HTTPException(status_code=502, detail=f"LLM review failed ({type(e).__name__}): {e}")
# Store result inside word_result as a sub-key
word_result["llm_review"] = {
"changes": result["changes"],
"model_used": result["model_used"],
"duration_ms": result["duration_ms"],
"entries_corrected": result["entries_corrected"],
}
await update_session_db(session_id, word_result=word_result, current_step=9)
if session_id in _cache:
_cache[session_id]["word_result"] = word_result
logger.info(f"LLM review session {session_id}: {len(result['changes'])} changes, "
f"{result['duration_ms']}ms, model={result['model_used']}")
await _append_pipeline_log(session_id, "correction", {
"engine": "llm",
"model": result["model_used"],
"total_entries": len(entries),
"corrections_proposed": len(result["changes"]),
}, duration_ms=result["duration_ms"])
return {
"session_id": session_id,
"changes": result["changes"],
"model_used": result["model_used"],
"duration_ms": result["duration_ms"],
"total_entries": len(entries),
"corrections_found": len(result["changes"]),
}
async def _llm_review_stream_generator(
session_id: str,
entries: List[Dict],
word_result: Dict,
model: str,
request: Request,
):
"""SSE generator that yields batch-by-batch LLM review progress."""
try:
async for event in llm_review_entries_streaming(entries, model=model):
if await request.is_disconnected():
logger.info(f"SSE: client disconnected during LLM review for {session_id}")
return
yield f"data: {json.dumps(event, ensure_ascii=False)}\n\n"
# On complete: persist to DB
if event.get("type") == "complete":
word_result["llm_review"] = {
"changes": event["changes"],
"model_used": event["model_used"],
"duration_ms": event["duration_ms"],
"entries_corrected": event["entries_corrected"],
}
await update_session_db(session_id, word_result=word_result, current_step=9)
if session_id in _cache:
_cache[session_id]["word_result"] = word_result
logger.info(f"LLM review SSE session {session_id}: {event['corrections_found']} changes, "
f"{event['duration_ms']}ms, skipped={event['skipped']}, model={event['model_used']}")
except Exception as e:
import traceback
logger.error(f"LLM review SSE failed for {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}")
error_event = {"type": "error", "detail": f"{type(e).__name__}: {e}"}
yield f"data: {json.dumps(error_event)}\n\n"
@router.post("/sessions/{session_id}/llm-review/apply")
async def apply_llm_corrections(session_id: str, request: Request):
"""Apply selected LLM corrections to vocab entries."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
word_result = session.get("word_result")
if not word_result:
raise HTTPException(status_code=400, detail="No word result found")
llm_review = word_result.get("llm_review")
if not llm_review:
raise HTTPException(status_code=400, detail="No LLM review found — run /llm-review first")
body = await request.json()
accepted_indices = set(body.get("accepted_indices", [])) # indices into changes[]
changes = llm_review.get("changes", [])
entries = word_result.get("vocab_entries") or word_result.get("entries") or []
# Build a lookup: (row_index, field) -> new_value for accepted changes
corrections = {}
applied_count = 0
for idx, change in enumerate(changes):
if idx in accepted_indices:
key = (change["row_index"], change["field"])
corrections[key] = change["new"]
applied_count += 1
# Apply corrections to entries
for entry in entries:
row_idx = entry.get("row_index", -1)
for field_name in ("english", "german", "example"):
key = (row_idx, field_name)
if key in corrections:
entry[field_name] = corrections[key]
entry["llm_corrected"] = True
# Update word_result
word_result["vocab_entries"] = entries
word_result["entries"] = entries
word_result["llm_review"]["applied_count"] = applied_count
word_result["llm_review"]["applied_at"] = datetime.utcnow().isoformat()
await update_session_db(session_id, word_result=word_result)
if session_id in _cache:
_cache[session_id]["word_result"] = word_result
logger.info(f"Applied {applied_count}/{len(changes)} LLM corrections for session {session_id}")
return {
"session_id": session_id,
"applied_count": applied_count,
"total_changes": len(changes),
}
# Backward-compat shim -- module moved to ocr/pipeline/llm_review.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.llm_review")

View File

@@ -1,266 +1,4 @@
"""
OCR Merge Kombi Endpoints — paddle-kombi and rapid-kombi endpoints.
Merge helper functions live in ocr_merge_helpers.py.
This module re-exports them for backward compatibility.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import time
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException
from cv_words_first import build_grid_from_words
from ocr_pipeline_common import _cache, _append_pipeline_log
from ocr_pipeline_session_store import get_session_image, update_session_db
# Re-export merge helpers for backward compatibility
from ocr_merge_helpers import ( # noqa: F401
_split_paddle_multi_words,
_group_words_into_rows,
_row_center_y,
_merge_row_sequences,
_merge_paddle_tesseract,
_deduplicate_words,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
def _run_tesseract_words(img_bgr) -> list:
"""Run Tesseract OCR on an image and return word dicts."""
from PIL import Image
import pytesseract
pil_img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
data = pytesseract.image_to_data(
pil_img, lang="eng+deu",
config="--psm 6 --oem 3",
output_type=pytesseract.Output.DICT,
)
tess_words = []
for i in range(len(data["text"])):
text = str(data["text"][i]).strip()
conf_raw = str(data["conf"][i])
conf = int(conf_raw) if conf_raw.lstrip("-").isdigit() else -1
if not text or conf < 20:
continue
tess_words.append({
"text": text,
"left": data["left"][i],
"top": data["top"][i],
"width": data["width"][i],
"height": data["height"][i],
"conf": conf,
})
return tess_words
def _build_kombi_word_result(
cells: list,
columns_meta: list,
img_w: int,
img_h: int,
duration: float,
engine_name: str,
raw_engine_words: list,
raw_engine_words_split: list,
tess_words: list,
merged_words: list,
raw_engine_key: str = "raw_paddle_words",
raw_split_key: str = "raw_paddle_words_split",
) -> dict:
"""Build the word_result dict for kombi endpoints."""
n_rows = len(set(c["row_index"] for c in cells)) if cells else 0
n_cols = len(columns_meta)
col_types = {c.get("type") for c in columns_meta}
is_vocab = bool(col_types & {"column_en", "column_de"})
return {
"cells": cells,
"grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
"columns_used": columns_meta,
"layout": "vocab" if is_vocab else "generic",
"image_width": img_w,
"image_height": img_h,
"duration_seconds": round(duration, 2),
"ocr_engine": engine_name,
"grid_method": engine_name,
raw_engine_key: raw_engine_words,
raw_split_key: raw_engine_words_split,
"raw_tesseract_words": tess_words,
"summary": {
"total_cells": len(cells),
"non_empty_cells": sum(1 for c in cells if c.get("text")),
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
raw_engine_key.replace("raw_", "").replace("_words", "_words"): len(raw_engine_words),
raw_split_key.replace("raw_", "").replace("_words_split", "_words_split"): len(raw_engine_words_split),
"tesseract_words": len(tess_words),
"merged_words": len(merged_words),
},
}
async def _load_session_image(session_id: str):
"""Load preprocessed image for kombi endpoints."""
img_png = await get_session_image(session_id, "cropped")
if not img_png:
img_png = await get_session_image(session_id, "dewarped")
if not img_png:
img_png = await get_session_image(session_id, "original")
if not img_png:
raise HTTPException(status_code=404, detail="No image found for this session")
img_arr = np.frombuffer(img_png, dtype=np.uint8)
img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
if img_bgr is None:
raise HTTPException(status_code=400, detail="Failed to decode image")
return img_png, img_bgr
# ---------------------------------------------------------------------------
# Kombi endpoints
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/paddle-kombi")
async def paddle_kombi(session_id: str):
"""Run PaddleOCR + Tesseract on the preprocessed image and merge results."""
img_png, img_bgr = await _load_session_image(session_id)
img_h, img_w = img_bgr.shape[:2]
from cv_ocr_engines import ocr_region_paddle
t0 = time.time()
paddle_words = await ocr_region_paddle(img_bgr, region=None)
if not paddle_words:
paddle_words = []
tess_words = _run_tesseract_words(img_bgr)
paddle_words_split = _split_paddle_multi_words(paddle_words)
logger.info(
"paddle_kombi: split %d paddle boxes -> %d individual words",
len(paddle_words), len(paddle_words_split),
)
if not paddle_words_split and not tess_words:
raise HTTPException(status_code=400, detail="Both OCR engines returned no words")
merged_words = _merge_paddle_tesseract(paddle_words_split, tess_words)
merged_words = _deduplicate_words(merged_words)
cells, columns_meta = build_grid_from_words(merged_words, img_w, img_h)
duration = time.time() - t0
for cell in cells:
cell["ocr_engine"] = "kombi"
word_result = _build_kombi_word_result(
cells, columns_meta, img_w, img_h, duration, "kombi",
paddle_words, paddle_words_split, tess_words, merged_words,
"raw_paddle_words", "raw_paddle_words_split",
)
await update_session_db(
session_id, word_result=word_result, cropped_png=img_png, current_step=8,
)
if session_id in _cache:
_cache[session_id]["word_result"] = word_result
logger.info(
"paddle_kombi session %s: %d cells (%d rows, %d cols) in %.2fs "
"[paddle=%d, tess=%d, merged=%d]",
session_id, len(cells), word_result["grid_shape"]["rows"],
word_result["grid_shape"]["cols"], duration,
len(paddle_words), len(tess_words), len(merged_words),
)
await _append_pipeline_log(session_id, "paddle_kombi", {
"total_cells": len(cells),
"non_empty_cells": word_result["summary"]["non_empty_cells"],
"paddle_words": len(paddle_words),
"tesseract_words": len(tess_words),
"merged_words": len(merged_words),
"ocr_engine": "kombi",
}, duration_ms=int(duration * 1000))
return {"session_id": session_id, **word_result}
@router.post("/sessions/{session_id}/rapid-kombi")
async def rapid_kombi(session_id: str):
"""Run RapidOCR + Tesseract on the preprocessed image and merge results."""
img_png, img_bgr = await _load_session_image(session_id)
img_h, img_w = img_bgr.shape[:2]
from cv_ocr_engines import ocr_region_rapid
from cv_vocab_types import PageRegion
t0 = time.time()
full_region = PageRegion(
type="full_page", x=0, y=0, width=img_w, height=img_h,
)
rapid_words = ocr_region_rapid(img_bgr, full_region)
if not rapid_words:
rapid_words = []
tess_words = _run_tesseract_words(img_bgr)
rapid_words_split = _split_paddle_multi_words(rapid_words)
logger.info(
"rapid_kombi: split %d rapid boxes -> %d individual words",
len(rapid_words), len(rapid_words_split),
)
if not rapid_words_split and not tess_words:
raise HTTPException(status_code=400, detail="Both OCR engines returned no words")
merged_words = _merge_paddle_tesseract(rapid_words_split, tess_words)
merged_words = _deduplicate_words(merged_words)
cells, columns_meta = build_grid_from_words(merged_words, img_w, img_h)
duration = time.time() - t0
for cell in cells:
cell["ocr_engine"] = "rapid_kombi"
word_result = _build_kombi_word_result(
cells, columns_meta, img_w, img_h, duration, "rapid_kombi",
rapid_words, rapid_words_split, tess_words, merged_words,
"raw_rapid_words", "raw_rapid_words_split",
)
await update_session_db(
session_id, word_result=word_result, cropped_png=img_png, current_step=8,
)
if session_id in _cache:
_cache[session_id]["word_result"] = word_result
logger.info(
"rapid_kombi session %s: %d cells (%d rows, %d cols) in %.2fs "
"[rapid=%d, tess=%d, merged=%d]",
session_id, len(cells), word_result["grid_shape"]["rows"],
word_result["grid_shape"]["cols"], duration,
len(rapid_words), len(tess_words), len(merged_words),
)
await _append_pipeline_log(session_id, "rapid_kombi", {
"total_cells": len(cells),
"non_empty_cells": word_result["summary"]["non_empty_cells"],
"rapid_words": len(rapid_words),
"tesseract_words": len(tess_words),
"merged_words": len(merged_words),
"ocr_engine": "rapid_kombi",
}, duration_ms=int(duration * 1000))
return {"session_id": session_id, **word_result}
# Backward-compat shim -- module moved to ocr/pipeline/ocr_merge.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.ocr_merge")

View File

@@ -1,333 +1,4 @@
"""
Overlay rendering for columns, rows, and words (grid-based overlays).
Extracted from ocr_pipeline_overlays.py for modularity.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
from typing import Any, Dict, List
import cv2
import numpy as np
from fastapi import HTTPException
from fastapi.responses import Response
from ocr_pipeline_common import _get_base_image_png
from ocr_pipeline_session_store import get_session_db
from ocr_pipeline_rows import _draw_box_exclusion_overlay
logger = logging.getLogger(__name__)
async def _get_columns_overlay(session_id: str) -> Response:
"""Generate cropped (or dewarped) image with column borders drawn on it."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
column_result = session.get("column_result")
if not column_result or not column_result.get("columns"):
raise HTTPException(status_code=404, detail="No column data available")
# Load best available base image (cropped > dewarped > original)
base_png = await _get_base_image_png(session_id)
if not base_png:
raise HTTPException(status_code=404, detail="No base image available")
arr = np.frombuffer(base_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=500, detail="Failed to decode image")
# Color map for region types (BGR)
colors = {
"column_en": (255, 180, 0), # Blue
"column_de": (0, 200, 0), # Green
"column_example": (0, 140, 255), # Orange
"column_text": (200, 200, 0), # Cyan/Turquoise
"page_ref": (200, 0, 200), # Purple
"column_marker": (0, 0, 220), # Red
"column_ignore": (180, 180, 180), # Light Gray
"header": (128, 128, 128), # Gray
"footer": (128, 128, 128), # Gray
"margin_top": (100, 100, 100), # Dark Gray
"margin_bottom": (100, 100, 100), # Dark Gray
}
overlay = img.copy()
for col in column_result["columns"]:
x, y = col["x"], col["y"]
w, h = col["width"], col["height"]
color = colors.get(col.get("type", ""), (200, 200, 200))
# Semi-transparent fill
cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1)
# Solid border
cv2.rectangle(img, (x, y), (x + w, y + h), color, 3)
# Label with confidence
label = col.get("type", "unknown").replace("column_", "").upper()
conf = col.get("classification_confidence")
if conf is not None and conf < 1.0:
label = f"{label} {int(conf * 100)}%"
cv2.putText(img, label, (x + 10, y + 30),
cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2)
# Blend overlay at 20% opacity
cv2.addWeighted(overlay, 0.2, img, 0.8, 0, img)
# Draw detected box boundaries as dashed rectangles
zones = column_result.get("zones") or []
for zone in zones:
if zone.get("zone_type") == "box" and zone.get("box"):
box = zone["box"]
bx, by = box["x"], box["y"]
bw, bh = box["width"], box["height"]
box_color = (0, 200, 255) # Yellow (BGR)
# Draw dashed rectangle by drawing short line segments
dash_len = 15
for edge_x in range(bx, bx + bw, dash_len * 2):
end_x = min(edge_x + dash_len, bx + bw)
cv2.line(img, (edge_x, by), (end_x, by), box_color, 2)
cv2.line(img, (edge_x, by + bh), (end_x, by + bh), box_color, 2)
for edge_y in range(by, by + bh, dash_len * 2):
end_y = min(edge_y + dash_len, by + bh)
cv2.line(img, (bx, edge_y), (bx, end_y), box_color, 2)
cv2.line(img, (bx + bw, edge_y), (bx + bw, end_y), box_color, 2)
cv2.putText(img, "BOX", (bx + 10, by + bh - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, box_color, 2)
# Red semi-transparent overlay for box zones
_draw_box_exclusion_overlay(img, zones)
success, result_png = cv2.imencode(".png", img)
if not success:
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
return Response(content=result_png.tobytes(), media_type="image/png")
async def _get_rows_overlay(session_id: str) -> Response:
"""Generate cropped (or dewarped) image with row bands drawn on it."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
row_result = session.get("row_result")
if not row_result or not row_result.get("rows"):
raise HTTPException(status_code=404, detail="No row data available")
# Load best available base image (cropped > dewarped > original)
base_png = await _get_base_image_png(session_id)
if not base_png:
raise HTTPException(status_code=404, detail="No base image available")
arr = np.frombuffer(base_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=500, detail="Failed to decode image")
# Color map for row types (BGR)
row_colors = {
"content": (255, 180, 0), # Blue
"header": (128, 128, 128), # Gray
"footer": (128, 128, 128), # Gray
"margin_top": (100, 100, 100), # Dark Gray
"margin_bottom": (100, 100, 100), # Dark Gray
}
overlay = img.copy()
for row in row_result["rows"]:
x, y = row["x"], row["y"]
w, h = row["width"], row["height"]
row_type = row.get("row_type", "content")
color = row_colors.get(row_type, (200, 200, 200))
# Semi-transparent fill
cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1)
# Solid border
cv2.rectangle(img, (x, y), (x + w, y + h), color, 2)
# Label
idx = row.get("index", 0)
label = f"R{idx} {row_type.upper()}"
wc = row.get("word_count", 0)
if wc:
label = f"{label} ({wc}w)"
cv2.putText(img, label, (x + 5, y + 18),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
# Blend overlay at 15% opacity
cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img)
# Draw zone separator lines if zones exist
column_result = session.get("column_result") or {}
zones = column_result.get("zones") or []
if zones:
img_w_px = img.shape[1]
zone_color = (0, 200, 255) # Yellow (BGR)
dash_len = 20
for zone in zones:
if zone.get("zone_type") == "box":
zy = zone["y"]
zh = zone["height"]
for line_y in [zy, zy + zh]:
for sx in range(0, img_w_px, dash_len * 2):
ex = min(sx + dash_len, img_w_px)
cv2.line(img, (sx, line_y), (ex, line_y), zone_color, 2)
# Red semi-transparent overlay for box zones
_draw_box_exclusion_overlay(img, zones)
success, result_png = cv2.imencode(".png", img)
if not success:
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
return Response(content=result_png.tobytes(), media_type="image/png")
async def _get_words_overlay(session_id: str) -> Response:
"""Generate cropped (or dewarped) image with cell grid drawn on it."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
word_result = session.get("word_result")
if not word_result:
raise HTTPException(status_code=404, detail="No word data available")
# Support both new cell-based and legacy entry-based formats
cells = word_result.get("cells")
if not cells and not word_result.get("entries"):
raise HTTPException(status_code=404, detail="No word data available")
# Load best available base image (cropped > dewarped > original)
base_png = await _get_base_image_png(session_id)
if not base_png:
raise HTTPException(status_code=404, detail="No base image available")
arr = np.frombuffer(base_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=500, detail="Failed to decode image")
img_h, img_w = img.shape[:2]
overlay = img.copy()
if cells:
# New cell-based overlay: color by column index
col_palette = [
(255, 180, 0), # Blue (BGR)
(0, 200, 0), # Green
(0, 140, 255), # Orange
(200, 100, 200), # Purple
(200, 200, 0), # Cyan
(100, 200, 200), # Yellow-ish
]
for cell in cells:
bbox = cell.get("bbox_px", {})
cx = bbox.get("x", 0)
cy = bbox.get("y", 0)
cw = bbox.get("w", 0)
ch = bbox.get("h", 0)
if cw <= 0 or ch <= 0:
continue
col_idx = cell.get("col_index", 0)
color = col_palette[col_idx % len(col_palette)]
# Cell rectangle border
cv2.rectangle(img, (cx, cy), (cx + cw, cy + ch), color, 1)
# Semi-transparent fill
cv2.rectangle(overlay, (cx, cy), (cx + cw, cy + ch), color, -1)
# Cell-ID label (top-left corner)
cell_id = cell.get("cell_id", "")
cv2.putText(img, cell_id, (cx + 2, cy + 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.28, color, 1)
# Text label (bottom of cell)
text = cell.get("text", "")
if text:
conf = cell.get("confidence", 0)
if conf >= 70:
text_color = (0, 180, 0)
elif conf >= 50:
text_color = (0, 180, 220)
else:
text_color = (0, 0, 220)
label = text.replace('\n', ' ')[:30]
cv2.putText(img, label, (cx + 3, cy + ch - 4),
cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)
else:
# Legacy fallback: entry-based overlay (for old sessions)
column_result = session.get("column_result")
row_result = session.get("row_result")
col_colors = {
"column_en": (255, 180, 0),
"column_de": (0, 200, 0),
"column_example": (0, 140, 255),
}
columns = []
if column_result and column_result.get("columns"):
columns = [c for c in column_result["columns"]
if c.get("type", "").startswith("column_")]
content_rows_data = []
if row_result and row_result.get("rows"):
content_rows_data = [r for r in row_result["rows"]
if r.get("row_type") == "content"]
for col in columns:
col_type = col.get("type", "")
color = col_colors.get(col_type, (200, 200, 200))
cx, cw = col["x"], col["width"]
for row in content_rows_data:
ry, rh = row["y"], row["height"]
cv2.rectangle(img, (cx, ry), (cx + cw, ry + rh), color, 1)
cv2.rectangle(overlay, (cx, ry), (cx + cw, ry + rh), color, -1)
entries = word_result["entries"]
entry_by_row: Dict[int, Dict] = {}
for entry in entries:
entry_by_row[entry.get("row_index", -1)] = entry
for row_idx, row in enumerate(content_rows_data):
entry = entry_by_row.get(row_idx)
if not entry:
continue
conf = entry.get("confidence", 0)
text_color = (0, 180, 0) if conf >= 70 else (0, 180, 220) if conf >= 50 else (0, 0, 220)
ry, rh = row["y"], row["height"]
for col in columns:
col_type = col.get("type", "")
cx, cw = col["x"], col["width"]
field = {"column_en": "english", "column_de": "german", "column_example": "example"}.get(col_type, "")
text = entry.get(field, "") if field else ""
if text:
label = text.replace('\n', ' ')[:30]
cv2.putText(img, label, (cx + 3, ry + rh - 4),
cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)
# Blend overlay at 10% opacity
cv2.addWeighted(overlay, 0.1, img, 0.9, 0, img)
# Red semi-transparent overlay for box zones
column_result = session.get("column_result") or {}
zones = column_result.get("zones") or []
_draw_box_exclusion_overlay(img, zones)
success, result_png = cv2.imencode(".png", img)
if not success:
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
return Response(content=result_png.tobytes(), media_type="image/png")
# Backward-compat shim -- module moved to ocr/pipeline/overlay_grid.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.overlay_grid")

View File

@@ -1,205 +1,4 @@
"""
Overlay rendering for structure detection (boxes, zones, colors, graphics).
Extracted from ocr_pipeline_overlays.py for modularity.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
from typing import Any, Dict, List
import cv2
import numpy as np
from fastapi import HTTPException
from fastapi.responses import Response
from ocr_pipeline_common import _get_base_image_png
from ocr_pipeline_session_store import get_session_db
from cv_color_detect import _COLOR_HEX, _COLOR_RANGES
from cv_box_detect import detect_boxes, split_page_into_zones
logger = logging.getLogger(__name__)
async def _get_structure_overlay(session_id: str) -> Response:
"""Generate overlay image showing detected boxes, zones, and color regions."""
base_png = await _get_base_image_png(session_id)
if not base_png:
raise HTTPException(status_code=404, detail="No base image available")
arr = np.frombuffer(base_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=500, detail="Failed to decode image")
h, w = img.shape[:2]
# Get structure result (run detection if not cached)
session = await get_session_db(session_id)
structure = (session or {}).get("structure_result")
if not structure:
# Run detection on-the-fly
margin = int(min(w, h) * 0.03)
content_x, content_y = margin, margin
content_w_px = w - 2 * margin
content_h_px = h - 2 * margin
boxes = detect_boxes(img, content_x, content_w_px, content_y, content_h_px)
zones = split_page_into_zones(content_x, content_y, content_w_px, content_h_px, boxes)
structure = {
"boxes": [
{"x": b.x, "y": b.y, "w": b.width, "h": b.height,
"confidence": b.confidence, "border_thickness": b.border_thickness}
for b in boxes
],
"zones": [
{"index": z.index, "zone_type": z.zone_type,
"y": z.y, "h": z.height, "x": z.x, "w": z.width}
for z in zones
],
}
overlay = img.copy()
# --- Draw zone boundaries ---
zone_colors = {
"content": (200, 200, 200), # light gray
"box": (255, 180, 0), # blue-ish (BGR)
}
for zone in structure.get("zones", []):
zx = zone["x"]
zy = zone["y"]
zw = zone["w"]
zh = zone["h"]
color = zone_colors.get(zone["zone_type"], (200, 200, 200))
# Draw zone boundary as dashed line
dash_len = 12
for edge_x in range(zx, zx + zw, dash_len * 2):
end_x = min(edge_x + dash_len, zx + zw)
cv2.line(img, (edge_x, zy), (end_x, zy), color, 1)
cv2.line(img, (edge_x, zy + zh), (end_x, zy + zh), color, 1)
# Zone label
zone_label = f"Zone {zone['index']} ({zone['zone_type']})"
cv2.putText(img, zone_label, (zx + 5, zy + 15),
cv2.FONT_HERSHEY_SIMPLEX, 0.45, color, 1)
# --- Draw detected boxes ---
# Color map for box backgrounds (BGR)
bg_hex_to_bgr = {
"#dc2626": (38, 38, 220), # red
"#2563eb": (235, 99, 37), # blue
"#16a34a": (74, 163, 22), # green
"#ea580c": (12, 88, 234), # orange
"#9333ea": (234, 51, 147), # purple
"#ca8a04": (4, 138, 202), # yellow
"#6b7280": (128, 114, 107), # gray
}
for box_data in structure.get("boxes", []):
bx = box_data["x"]
by = box_data["y"]
bw = box_data["w"]
bh = box_data["h"]
conf = box_data.get("confidence", 0)
thickness = box_data.get("border_thickness", 0)
bg_hex = box_data.get("bg_color_hex", "#6b7280")
bg_name = box_data.get("bg_color_name", "")
# Box fill color
fill_bgr = bg_hex_to_bgr.get(bg_hex, (128, 114, 107))
# Semi-transparent fill
cv2.rectangle(overlay, (bx, by), (bx + bw, by + bh), fill_bgr, -1)
# Solid border
border_color = fill_bgr
cv2.rectangle(img, (bx, by), (bx + bw, by + bh), border_color, 3)
# Label
label = f"BOX"
if bg_name and bg_name not in ("unknown", "white"):
label += f" ({bg_name})"
if thickness > 0:
label += f" border={thickness}px"
label += f" {int(conf * 100)}%"
cv2.putText(img, label, (bx + 8, by + 22),
cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 2)
cv2.putText(img, label, (bx + 8, by + 22),
cv2.FONT_HERSHEY_SIMPLEX, 0.55, border_color, 1)
# Blend overlay at 15% opacity
cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img)
# --- Draw color regions (HSV masks) ---
hsv = cv2.cvtColor(
cv2.imdecode(np.frombuffer(base_png, dtype=np.uint8), cv2.IMREAD_COLOR),
cv2.COLOR_BGR2HSV,
)
color_bgr_map = {
"red": (0, 0, 255),
"orange": (0, 140, 255),
"yellow": (0, 200, 255),
"green": (0, 200, 0),
"blue": (255, 150, 0),
"purple": (200, 0, 200),
}
for color_name, ranges in _COLOR_RANGES.items():
mask = np.zeros((h, w), dtype=np.uint8)
for lower, upper in ranges:
mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper))
# Only draw if there are significant colored pixels
if np.sum(mask > 0) < 100:
continue
# Draw colored contours
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
draw_color = color_bgr_map.get(color_name, (200, 200, 200))
for cnt in contours:
area = cv2.contourArea(cnt)
if area < 20:
continue
cv2.drawContours(img, [cnt], -1, draw_color, 2)
# --- Draw graphic elements ---
graphics_data = structure.get("graphics", [])
shape_icons = {
"image": "IMAGE",
"illustration": "ILLUST",
}
for gfx in graphics_data:
gx, gy = gfx["x"], gfx["y"]
gw, gh = gfx["w"], gfx["h"]
shape = gfx.get("shape", "icon")
color_hex = gfx.get("color_hex", "#6b7280")
conf = gfx.get("confidence", 0)
# Pick draw color based on element color (BGR)
gfx_bgr = bg_hex_to_bgr.get(color_hex, (128, 114, 107))
# Draw bounding box (dashed style via short segments)
dash = 6
for seg_x in range(gx, gx + gw, dash * 2):
end_x = min(seg_x + dash, gx + gw)
cv2.line(img, (seg_x, gy), (end_x, gy), gfx_bgr, 2)
cv2.line(img, (seg_x, gy + gh), (end_x, gy + gh), gfx_bgr, 2)
for seg_y in range(gy, gy + gh, dash * 2):
end_y = min(seg_y + dash, gy + gh)
cv2.line(img, (gx, seg_y), (gx, end_y), gfx_bgr, 2)
cv2.line(img, (gx + gw, seg_y), (gx + gw, end_y), gfx_bgr, 2)
# Label
icon = shape_icons.get(shape, shape.upper()[:5])
label = f"{icon} {int(conf * 100)}%"
# White background for readability
(tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)
lx = gx + 2
ly = max(gy - 4, th + 4)
cv2.rectangle(img, (lx - 1, ly - th - 2), (lx + tw + 2, ly + 3), (255, 255, 255), -1)
cv2.putText(img, label, (lx, ly), cv2.FONT_HERSHEY_SIMPLEX, 0.4, gfx_bgr, 1)
# Encode result
_, png_buf = cv2.imencode(".png", img)
return Response(content=png_buf.tobytes(), media_type="image/png")
# Backward-compat shim -- module moved to ocr/pipeline/overlay_structure.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.overlay_structure")

View File

@@ -1,34 +1,4 @@
"""
Overlay image rendering for OCR pipeline — barrel re-export.
All implementation split into:
ocr_pipeline_overlay_structure — structure overlay (boxes, zones, colors, graphics)
ocr_pipeline_overlay_grid — columns, rows, words overlays
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
from fastapi import HTTPException
from fastapi.responses import Response
from ocr_pipeline_overlay_structure import _get_structure_overlay # noqa: F401
from ocr_pipeline_overlay_grid import ( # noqa: F401
_get_columns_overlay,
_get_rows_overlay,
_get_words_overlay,
)
async def render_overlay(overlay_type: str, session_id: str) -> Response:
"""Dispatch to the appropriate overlay renderer."""
if overlay_type == "structure":
return await _get_structure_overlay(session_id)
elif overlay_type == "columns":
return await _get_columns_overlay(session_id)
elif overlay_type == "rows":
return await _get_rows_overlay(session_id)
elif overlay_type == "words":
return await _get_words_overlay(session_id)
else:
raise HTTPException(status_code=400, detail=f"Unknown overlay type: {overlay_type}")
# Backward-compat shim -- module moved to ocr/pipeline/overlays.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.overlays")

View File

@@ -1,26 +1,4 @@
"""
OCR Pipeline Postprocessing API — composite router assembling LLM review,
reconstruction, export, validation, image detection/generation, and
handwriting removal endpoints.
Split into sub-modules:
ocr_pipeline_llm_review — LLM review + apply corrections
ocr_pipeline_reconstruction — reconstruction save, Fabric JSON, merged entries, PDF/DOCX
ocr_pipeline_validation — image detection, generation, validation, handwriting removal
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
from fastapi import APIRouter
from ocr_pipeline_llm_review import router as _llm_review_router
from ocr_pipeline_reconstruction import router as _reconstruction_router
from ocr_pipeline_validation import router as _validation_router
# Composite router — drop-in replacement for the old monolithic router.
# ocr_pipeline_api.py imports ``from ocr_pipeline_postprocess import router``.
router = APIRouter()
router.include_router(_llm_review_router)
router.include_router(_reconstruction_router)
router.include_router(_validation_router)
# Backward-compat shim -- module moved to ocr/pipeline/postprocess.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.postprocess")

View File

@@ -1,362 +1,4 @@
"""
OCR Pipeline Reconstruction — save edits, Fabric JSON export, merged entries, PDF/DOCX export.
Extracted from ocr_pipeline_postprocess.py.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import re
from typing import Dict
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from ocr_pipeline_session_store import (
get_session_db,
get_sub_sessions,
update_session_db,
)
from ocr_pipeline_common import _cache
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Step 9: Reconstruction + Fabric JSON export
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/reconstruction")
async def save_reconstruction(session_id: str, request: Request):
"""Save edited cell texts from reconstruction step."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
word_result = session.get("word_result")
if not word_result:
raise HTTPException(status_code=400, detail="No word result found")
body = await request.json()
cell_updates = body.get("cells", [])
if not cell_updates:
await update_session_db(session_id, current_step=10)
return {"session_id": session_id, "updated": 0}
# Build update map: cell_id -> new text
update_map = {c["cell_id"]: c["text"] for c in cell_updates}
# Separate sub-session updates (cell_ids prefixed with "box{N}_")
sub_updates: Dict[int, Dict[str, str]] = {} # box_index -> {original_cell_id: text}
main_updates: Dict[str, str] = {}
for cell_id, text in update_map.items():
m = re.match(r'^box(\d+)_(.+)$', cell_id)
if m:
bi = int(m.group(1))
original_id = m.group(2)
sub_updates.setdefault(bi, {})[original_id] = text
else:
main_updates[cell_id] = text
# Update main session cells
cells = word_result.get("cells", [])
updated_count = 0
for cell in cells:
if cell["cell_id"] in main_updates:
cell["text"] = main_updates[cell["cell_id"]]
cell["status"] = "edited"
updated_count += 1
word_result["cells"] = cells
# Also update vocab_entries if present
entries = word_result.get("vocab_entries") or word_result.get("entries") or []
if entries:
for entry in entries:
row_idx = entry.get("row_index", -1)
for col_idx, field_name in enumerate(["english", "german", "example"]):
cell_id = f"R{row_idx:02d}_C{col_idx}"
cell_id_alt = f"R{row_idx}_C{col_idx}"
new_text = main_updates.get(cell_id) or main_updates.get(cell_id_alt)
if new_text is not None:
entry[field_name] = new_text
word_result["vocab_entries"] = entries
if "entries" in word_result:
word_result["entries"] = entries
await update_session_db(session_id, word_result=word_result, current_step=10)
if session_id in _cache:
_cache[session_id]["word_result"] = word_result
# Route sub-session updates
sub_updated = 0
if sub_updates:
subs = await get_sub_sessions(session_id)
sub_by_index = {s.get("box_index"): s["id"] for s in subs}
for bi, updates in sub_updates.items():
sub_id = sub_by_index.get(bi)
if not sub_id:
continue
sub_session = await get_session_db(sub_id)
if not sub_session:
continue
sub_word = sub_session.get("word_result")
if not sub_word:
continue
sub_cells = sub_word.get("cells", [])
for cell in sub_cells:
if cell["cell_id"] in updates:
cell["text"] = updates[cell["cell_id"]]
cell["status"] = "edited"
sub_updated += 1
sub_word["cells"] = sub_cells
await update_session_db(sub_id, word_result=sub_word)
if sub_id in _cache:
_cache[sub_id]["word_result"] = sub_word
total_updated = updated_count + sub_updated
logger.info(f"Reconstruction saved for session {session_id}: "
f"{updated_count} main + {sub_updated} sub-session cells updated")
return {
"session_id": session_id,
"updated": total_updated,
"main_updated": updated_count,
"sub_updated": sub_updated,
}
@router.get("/sessions/{session_id}/reconstruction/fabric-json")
async def get_fabric_json(session_id: str):
"""Return cell grid as Fabric.js-compatible JSON for the canvas editor."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
word_result = session.get("word_result")
if not word_result:
raise HTTPException(status_code=400, detail="No word result found")
cells = list(word_result.get("cells", []))
img_w = word_result.get("image_width", 800)
img_h = word_result.get("image_height", 600)
# Merge sub-session cells at box positions
subs = await get_sub_sessions(session_id)
if subs:
column_result = session.get("column_result") or {}
zones = column_result.get("zones") or []
box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
for sub in subs:
sub_session = await get_session_db(sub["id"])
if not sub_session:
continue
sub_word = sub_session.get("word_result")
if not sub_word or not sub_word.get("cells"):
continue
bi = sub.get("box_index", 0)
if bi < len(box_zones):
box = box_zones[bi]["box"]
box_y, box_x = box["y"], box["x"]
else:
box_y, box_x = 0, 0
for cell in sub_word["cells"]:
cell_copy = dict(cell)
cell_copy["cell_id"] = f"box{bi}_{cell_copy.get('cell_id', '')}"
cell_copy["source"] = f"box_{bi}"
bbox = cell_copy.get("bbox_px", {})
if bbox:
bbox = dict(bbox)
bbox["x"] = bbox.get("x", 0) + box_x
bbox["y"] = bbox.get("y", 0) + box_y
cell_copy["bbox_px"] = bbox
cells.append(cell_copy)
from services.layout_reconstruction_service import cells_to_fabric_json
fabric_json = cells_to_fabric_json(cells, img_w, img_h)
return fabric_json
# ---------------------------------------------------------------------------
# Vocab entries merged + PDF/DOCX export
# ---------------------------------------------------------------------------
@router.get("/sessions/{session_id}/vocab-entries/merged")
async def get_merged_vocab_entries(session_id: str):
"""Return vocab entries from main session + all sub-sessions, sorted by Y position."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
word_result = session.get("word_result") or {}
entries = list(word_result.get("vocab_entries") or word_result.get("entries") or [])
for e in entries:
e.setdefault("source", "main")
subs = await get_sub_sessions(session_id)
if subs:
column_result = session.get("column_result") or {}
zones = column_result.get("zones") or []
box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
for sub in subs:
sub_session = await get_session_db(sub["id"])
if not sub_session:
continue
sub_word = sub_session.get("word_result") or {}
sub_entries = sub_word.get("vocab_entries") or sub_word.get("entries") or []
bi = sub.get("box_index", 0)
box_y = 0
if bi < len(box_zones):
box_y = box_zones[bi]["box"]["y"]
for e in sub_entries:
e_copy = dict(e)
e_copy["source"] = f"box_{bi}"
e_copy["source_y"] = box_y
entries.append(e_copy)
def _sort_key(e):
if e.get("source", "main") == "main":
return e.get("row_index", 0) * 100
return e.get("source_y", 0) * 100 + e.get("row_index", 0)
entries.sort(key=_sort_key)
return {
"session_id": session_id,
"entries": entries,
"total": len(entries),
"sources": list(set(e.get("source", "main") for e in entries)),
}
@router.get("/sessions/{session_id}/reconstruction/export/pdf")
async def export_reconstruction_pdf(session_id: str):
"""Export the reconstructed cell grid as a PDF table."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
word_result = session.get("word_result")
if not word_result:
raise HTTPException(status_code=400, detail="No word result found")
cells = word_result.get("cells", [])
columns_used = word_result.get("columns_used", [])
grid_shape = word_result.get("grid_shape", {})
n_rows = grid_shape.get("rows", 0)
n_cols = grid_shape.get("cols", 0)
# Build table data: rows x columns
table_data: list[list[str]] = []
header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
if not header:
header = [f"Col {i}" for i in range(n_cols)]
table_data.append(header)
for r in range(n_rows):
row_texts = []
for ci in range(n_cols):
cell_id = f"R{r:02d}_C{ci}"
cell = next((c for c in cells if c.get("cell_id") == cell_id), None)
row_texts.append(cell.get("text", "") if cell else "")
table_data.append(row_texts)
try:
from reportlab.lib.pagesizes import A4
from reportlab.lib import colors
from reportlab.platypus import SimpleDocTemplate, Table, TableStyle
import io as _io
buf = _io.BytesIO()
doc = SimpleDocTemplate(buf, pagesize=A4)
if not table_data or not table_data[0]:
raise HTTPException(status_code=400, detail="No data to export")
t = Table(table_data)
t.setStyle(TableStyle([
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#0d9488')),
('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
('FONTSIZE', (0, 0), (-1, -1), 9),
('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
('VALIGN', (0, 0), (-1, -1), 'TOP'),
('WORDWRAP', (0, 0), (-1, -1), True),
]))
doc.build([t])
buf.seek(0)
return StreamingResponse(
buf,
media_type="application/pdf",
headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.pdf"'},
)
except ImportError:
raise HTTPException(status_code=501, detail="reportlab not installed")
@router.get("/sessions/{session_id}/reconstruction/export/docx")
async def export_reconstruction_docx(session_id: str):
"""Export the reconstructed cell grid as a DOCX table."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
word_result = session.get("word_result")
if not word_result:
raise HTTPException(status_code=400, detail="No word result found")
cells = word_result.get("cells", [])
columns_used = word_result.get("columns_used", [])
grid_shape = word_result.get("grid_shape", {})
n_rows = grid_shape.get("rows", 0)
n_cols = grid_shape.get("cols", 0)
try:
from docx import Document
from docx.shared import Pt
import io as _io
doc = Document()
doc.add_heading(f'Rekonstruktion -- Session {session_id[:8]}', level=1)
header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
if not header:
header = [f"Col {i}" for i in range(n_cols)]
table = doc.add_table(rows=1 + n_rows, cols=max(n_cols, 1))
table.style = 'Table Grid'
for ci, h in enumerate(header):
table.rows[0].cells[ci].text = h
for r in range(n_rows):
for ci in range(n_cols):
cell_id = f"R{r:02d}_C{ci}"
cell = next((c for c in cells if c.get("cell_id") == cell_id), None)
table.rows[r + 1].cells[ci].text = cell.get("text", "") if cell else ""
buf = _io.BytesIO()
doc.save(buf)
buf.seek(0)
return StreamingResponse(
buf,
media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.docx"'},
)
except ImportError:
raise HTTPException(status_code=501, detail="python-docx not installed")
# Backward-compat shim -- module moved to ocr/pipeline/reconstruction.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.reconstruction")

View File

@@ -1,22 +1,4 @@
"""
OCR Pipeline Regression Tests — barrel re-export.
All implementation split into:
ocr_pipeline_regression_helpers — DB persistence, snapshot, comparison
ocr_pipeline_regression_endpoints — FastAPI routes
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
# Helpers (used by grid_editor_api_grid.py)
from ocr_pipeline_regression_helpers import ( # noqa: F401
_init_regression_table,
_persist_regression_run,
_extract_cells_for_comparison,
_build_reference_snapshot,
compare_grids,
)
# Endpoints (router used by ocr_pipeline_api.py)
from ocr_pipeline_regression_endpoints import router # noqa: F401
# Backward-compat shim -- module moved to ocr/pipeline/regression.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.regression")

View File

@@ -1,421 +1,4 @@
"""
OCR Pipeline Regression Endpoints — FastAPI routes for ground truth and regression.
Extracted from ocr_pipeline_regression.py for modularity.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import time
from typing import Any, Dict, Optional
from fastapi import APIRouter, HTTPException, Query
from grid_editor_api import _build_grid_core
from ocr_pipeline_session_store import (
get_session_db,
list_ground_truth_sessions_db,
update_session_db,
)
from ocr_pipeline_regression_helpers import (
_build_reference_snapshot,
_init_regression_table,
_persist_regression_run,
compare_grids,
get_pool,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["regression"])
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/mark-ground-truth")
async def mark_ground_truth(
session_id: str,
pipeline: Optional[str] = Query(None, description="Pipeline used: kombi, pipeline, paddle-direct"),
):
"""Save the current build-grid result as ground-truth reference."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
grid_result = session.get("grid_editor_result")
if not grid_result or not grid_result.get("zones"):
raise HTTPException(
status_code=400,
detail="No grid_editor_result found. Run build-grid first.",
)
# Auto-detect pipeline from word_result if not provided
if not pipeline:
wr = session.get("word_result") or {}
engine = wr.get("ocr_engine", "")
if engine in ("kombi", "rapid_kombi"):
pipeline = "kombi"
elif engine == "paddle_direct":
pipeline = "paddle-direct"
else:
pipeline = "pipeline"
reference = _build_reference_snapshot(grid_result, pipeline=pipeline)
# Merge into existing ground_truth JSONB
gt = session.get("ground_truth") or {}
gt["build_grid_reference"] = reference
await update_session_db(session_id, ground_truth=gt, current_step=11)
# Compare with auto-snapshot if available (shows what the user corrected)
auto_snapshot = gt.get("auto_grid_snapshot")
correction_diff = None
if auto_snapshot:
correction_diff = compare_grids(auto_snapshot, reference)
logger.info(
"Ground truth marked for session %s: %d cells (corrections: %s)",
session_id,
len(reference["cells"]),
correction_diff["summary"] if correction_diff else "no auto-snapshot",
)
return {
"status": "ok",
"session_id": session_id,
"cells_saved": len(reference["cells"]),
"summary": reference["summary"],
"correction_diff": correction_diff,
}
@router.delete("/sessions/{session_id}/mark-ground-truth")
async def unmark_ground_truth(session_id: str):
"""Remove the ground-truth reference from a session."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
gt = session.get("ground_truth") or {}
if "build_grid_reference" not in gt:
raise HTTPException(status_code=404, detail="No ground truth reference found")
del gt["build_grid_reference"]
await update_session_db(session_id, ground_truth=gt)
logger.info("Ground truth removed for session %s", session_id)
return {"status": "ok", "session_id": session_id}
@router.get("/sessions/{session_id}/correction-diff")
async def get_correction_diff(session_id: str):
"""Compare automatic OCR grid with manually corrected ground truth.
Returns a diff showing exactly which cells the user corrected,
broken down by col_type (english, german, ipa, etc.).
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
gt = session.get("ground_truth") or {}
auto_snapshot = gt.get("auto_grid_snapshot")
reference = gt.get("build_grid_reference")
if not auto_snapshot:
raise HTTPException(
status_code=404,
detail="No auto_grid_snapshot found. Re-run build-grid to create one.",
)
if not reference:
raise HTTPException(
status_code=404,
detail="No ground truth reference found. Mark as ground truth first.",
)
diff = compare_grids(auto_snapshot, reference)
# Enrich with per-col_type breakdown
col_type_stats: Dict[str, Dict[str, int]] = {}
for cell_diff in diff.get("cell_diffs", []):
if cell_diff["type"] != "text_change":
continue
# Find col_type from reference cells
cell_id = cell_diff["cell_id"]
ref_cell = next(
(c for c in reference.get("cells", []) if c["cell_id"] == cell_id),
None,
)
ct = ref_cell.get("col_type", "unknown") if ref_cell else "unknown"
if ct not in col_type_stats:
col_type_stats[ct] = {"total": 0, "corrected": 0}
col_type_stats[ct]["corrected"] += 1
# Count total cells per col_type from reference
for cell in reference.get("cells", []):
ct = cell.get("col_type", "unknown")
if ct not in col_type_stats:
col_type_stats[ct] = {"total": 0, "corrected": 0}
col_type_stats[ct]["total"] += 1
# Calculate accuracy per col_type
for ct, stats in col_type_stats.items():
total = stats["total"]
corrected = stats["corrected"]
stats["accuracy_pct"] = round((total - corrected) / total * 100, 1) if total > 0 else 100.0
diff["col_type_breakdown"] = col_type_stats
return diff
@router.get("/ground-truth-sessions")
async def list_ground_truth_sessions():
"""List all sessions that have a ground-truth reference."""
sessions = await list_ground_truth_sessions_db()
result = []
for s in sessions:
gt = s.get("ground_truth") or {}
ref = gt.get("build_grid_reference", {})
result.append({
"session_id": s["id"],
"name": s.get("name", ""),
"filename": s.get("filename", ""),
"document_category": s.get("document_category"),
"pipeline": ref.get("pipeline"),
"saved_at": ref.get("saved_at"),
"summary": ref.get("summary", {}),
})
return {"sessions": result, "count": len(result)}
@router.post("/sessions/{session_id}/regression/run")
async def run_single_regression(session_id: str):
"""Re-run build_grid for a single session and compare to ground truth."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
gt = session.get("ground_truth") or {}
reference = gt.get("build_grid_reference")
if not reference:
raise HTTPException(
status_code=400,
detail="No ground truth reference found for this session",
)
# Re-compute grid without persisting
try:
new_result = await _build_grid_core(session_id, session)
except (ValueError, Exception) as e:
return {
"session_id": session_id,
"name": session.get("name", ""),
"status": "error",
"error": str(e),
}
new_snapshot = _build_reference_snapshot(new_result)
diff = compare_grids(reference, new_snapshot)
logger.info(
"Regression test session %s: %s (%d structural, %d cell diffs)",
session_id, diff["status"],
diff["summary"]["structural_changes"],
sum(v for k, v in diff["summary"].items() if k != "structural_changes"),
)
return {
"session_id": session_id,
"name": session.get("name", ""),
"status": diff["status"],
"diff": diff,
"reference_summary": reference.get("summary", {}),
"current_summary": new_snapshot.get("summary", {}),
}
@router.post("/regression/run")
async def run_all_regressions(
triggered_by: str = Query("manual", description="Who triggered: manual, script, ci"),
):
"""Re-run build_grid for ALL ground-truth sessions and compare."""
start_time = time.monotonic()
sessions = await list_ground_truth_sessions_db()
if not sessions:
return {
"status": "pass",
"message": "No ground truth sessions found",
"results": [],
"summary": {"total": 0, "passed": 0, "failed": 0, "errors": 0},
}
results = []
passed = 0
failed = 0
errors = 0
for s in sessions:
session_id = s["id"]
gt = s.get("ground_truth") or {}
reference = gt.get("build_grid_reference")
if not reference:
continue
# Re-load full session (list query may not include all JSONB fields)
full_session = await get_session_db(session_id)
if not full_session:
results.append({
"session_id": session_id,
"name": s.get("name", ""),
"status": "error",
"error": "Session not found during re-load",
})
errors += 1
continue
try:
new_result = await _build_grid_core(session_id, full_session)
except (ValueError, Exception) as e:
results.append({
"session_id": session_id,
"name": s.get("name", ""),
"status": "error",
"error": str(e),
})
errors += 1
continue
new_snapshot = _build_reference_snapshot(new_result)
diff = compare_grids(reference, new_snapshot)
entry = {
"session_id": session_id,
"name": s.get("name", ""),
"status": diff["status"],
"diff_summary": diff["summary"],
"reference_summary": reference.get("summary", {}),
"current_summary": new_snapshot.get("summary", {}),
}
# Include full diffs only for failures (keep response compact)
if diff["status"] == "fail":
entry["structural_diffs"] = diff["structural_diffs"]
entry["cell_diffs"] = diff["cell_diffs"]
failed += 1
else:
passed += 1
results.append(entry)
overall = "pass" if failed == 0 and errors == 0 else "fail"
duration_ms = int((time.monotonic() - start_time) * 1000)
summary = {
"total": len(results),
"passed": passed,
"failed": failed,
"errors": errors,
}
logger.info(
"Regression suite: %s%d passed, %d failed, %d errors (of %d) in %dms",
overall, passed, failed, errors, len(results), duration_ms,
)
# Persist to DB
run_id = await _persist_regression_run(
status=overall,
summary=summary,
results=results,
duration_ms=duration_ms,
triggered_by=triggered_by,
)
return {
"status": overall,
"run_id": run_id,
"duration_ms": duration_ms,
"results": results,
"summary": summary,
}
@router.get("/regression/history")
async def get_regression_history(
limit: int = Query(20, ge=1, le=100),
):
"""Get recent regression run history from the database."""
try:
await _init_regression_table()
pool = await get_pool()
async with pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT id, run_at, status, total, passed, failed, errors,
duration_ms, triggered_by
FROM regression_runs
ORDER BY run_at DESC
LIMIT $1
""",
limit,
)
return {
"runs": [
{
"id": str(row["id"]),
"run_at": row["run_at"].isoformat() if row["run_at"] else None,
"status": row["status"],
"total": row["total"],
"passed": row["passed"],
"failed": row["failed"],
"errors": row["errors"],
"duration_ms": row["duration_ms"],
"triggered_by": row["triggered_by"],
}
for row in rows
],
"count": len(rows),
}
except Exception as e:
logger.warning("Failed to fetch regression history: %s", e)
return {"runs": [], "count": 0, "error": str(e)}
@router.get("/regression/history/{run_id}")
async def get_regression_run_detail(run_id: str):
"""Get detailed results of a specific regression run."""
try:
await _init_regression_table()
pool = await get_pool()
async with pool.acquire() as conn:
row = await conn.fetchrow(
"SELECT * FROM regression_runs WHERE id = $1",
run_id,
)
if not row:
raise HTTPException(status_code=404, detail="Run not found")
return {
"id": str(row["id"]),
"run_at": row["run_at"].isoformat() if row["run_at"] else None,
"status": row["status"],
"total": row["total"],
"passed": row["passed"],
"failed": row["failed"],
"errors": row["errors"],
"duration_ms": row["duration_ms"],
"triggered_by": row["triggered_by"],
"results": json.loads(row["results"]) if row["results"] else [],
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Backward-compat shim -- module moved to ocr/pipeline/regression_endpoints.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.regression_endpoints")

View File

@@ -1,207 +1,4 @@
"""
OCR Pipeline Regression Helpers — DB persistence, snapshot building, comparison.
Extracted from ocr_pipeline_regression.py for modularity.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import os
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from ocr_pipeline_session_store import get_pool
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# DB persistence for regression runs
# ---------------------------------------------------------------------------
async def _init_regression_table():
"""Ensure regression_runs table exists (idempotent)."""
pool = await get_pool()
async with pool.acquire() as conn:
migration_path = os.path.join(
os.path.dirname(__file__),
"migrations/008_regression_runs.sql",
)
if os.path.exists(migration_path):
with open(migration_path, "r") as f:
sql = f.read()
await conn.execute(sql)
async def _persist_regression_run(
status: str,
summary: dict,
results: list,
duration_ms: int,
triggered_by: str = "manual",
) -> str:
"""Save a regression run to the database. Returns the run ID."""
try:
await _init_regression_table()
pool = await get_pool()
run_id = str(uuid.uuid4())
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO regression_runs
(id, status, total, passed, failed, errors, duration_ms, results, triggered_by)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8::jsonb, $9)
""",
run_id,
status,
summary.get("total", 0),
summary.get("passed", 0),
summary.get("failed", 0),
summary.get("errors", 0),
duration_ms,
json.dumps(results),
triggered_by,
)
logger.info("Regression run %s persisted: %s", run_id, status)
return run_id
except Exception as e:
logger.warning("Failed to persist regression run: %s", e)
return ""
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _extract_cells_for_comparison(grid_result: dict) -> List[Dict[str, Any]]:
"""Extract a flat list of cells from a grid_editor_result for comparison.
Only keeps fields relevant for comparison: cell_id, row_index, col_index,
col_type, text. Ignores confidence, bbox, word_boxes, duration, is_bold.
"""
cells = []
for zone in grid_result.get("zones", []):
for cell in zone.get("cells", []):
cells.append({
"cell_id": cell.get("cell_id", ""),
"row_index": cell.get("row_index"),
"col_index": cell.get("col_index"),
"col_type": cell.get("col_type", ""),
"text": cell.get("text", ""),
})
return cells
def _build_reference_snapshot(
grid_result: dict,
pipeline: Optional[str] = None,
) -> dict:
"""Build a ground-truth reference snapshot from a grid_editor_result."""
cells = _extract_cells_for_comparison(grid_result)
total_zones = len(grid_result.get("zones", []))
total_columns = sum(len(z.get("columns", [])) for z in grid_result.get("zones", []))
total_rows = sum(len(z.get("rows", [])) for z in grid_result.get("zones", []))
snapshot = {
"saved_at": datetime.now(timezone.utc).isoformat(),
"version": 1,
"pipeline": pipeline,
"summary": {
"total_zones": total_zones,
"total_columns": total_columns,
"total_rows": total_rows,
"total_cells": len(cells),
},
"cells": cells,
}
return snapshot
def compare_grids(reference: dict, current: dict) -> dict:
"""Compare a reference grid snapshot with a newly computed one.
Returns a diff report with:
- status: "pass" or "fail"
- structural_diffs: changes in zone/row/column counts
- cell_diffs: list of individual cell changes
"""
ref_summary = reference.get("summary", {})
cur_summary = current.get("summary", {})
structural_diffs = []
for key in ("total_zones", "total_columns", "total_rows", "total_cells"):
ref_val = ref_summary.get(key, 0)
cur_val = cur_summary.get(key, 0)
if ref_val != cur_val:
structural_diffs.append({
"field": key,
"reference": ref_val,
"current": cur_val,
})
# Build cell lookup by cell_id
ref_cells = {c["cell_id"]: c for c in reference.get("cells", [])}
cur_cells = {c["cell_id"]: c for c in current.get("cells", [])}
cell_diffs: List[Dict[str, Any]] = []
# Check for missing cells (in reference but not in current)
for cell_id in ref_cells:
if cell_id not in cur_cells:
cell_diffs.append({
"type": "cell_missing",
"cell_id": cell_id,
"reference_text": ref_cells[cell_id].get("text", ""),
})
# Check for added cells (in current but not in reference)
for cell_id in cur_cells:
if cell_id not in ref_cells:
cell_diffs.append({
"type": "cell_added",
"cell_id": cell_id,
"current_text": cur_cells[cell_id].get("text", ""),
})
# Check for changes in shared cells
for cell_id in ref_cells:
if cell_id not in cur_cells:
continue
ref_cell = ref_cells[cell_id]
cur_cell = cur_cells[cell_id]
if ref_cell.get("text", "") != cur_cell.get("text", ""):
cell_diffs.append({
"type": "text_change",
"cell_id": cell_id,
"reference": ref_cell.get("text", ""),
"current": cur_cell.get("text", ""),
})
if ref_cell.get("col_type", "") != cur_cell.get("col_type", ""):
cell_diffs.append({
"type": "col_type_change",
"cell_id": cell_id,
"reference": ref_cell.get("col_type", ""),
"current": cur_cell.get("col_type", ""),
})
status = "pass" if not structural_diffs and not cell_diffs else "fail"
return {
"status": status,
"structural_diffs": structural_diffs,
"cell_diffs": cell_diffs,
"summary": {
"structural_changes": len(structural_diffs),
"cells_missing": sum(1 for d in cell_diffs if d["type"] == "cell_missing"),
"cells_added": sum(1 for d in cell_diffs if d["type"] == "cell_added"),
"text_changes": sum(1 for d in cell_diffs if d["type"] == "text_change"),
"col_type_changes": sum(1 for d in cell_diffs if d["type"] == "col_type_change"),
},
}
# Backward-compat shim -- module moved to ocr/pipeline/regression_helpers.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.regression_helpers")

View File

@@ -1,94 +1,4 @@
"""
OCR Pipeline Reprocess Endpoint.
POST /sessions/{session_id}/reprocess — clear downstream + restart from step.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
from typing import Any, Dict
from fastapi import APIRouter, HTTPException, Request
from ocr_pipeline_common import _cache
from ocr_pipeline_session_store import get_session_db, update_session_db
logger = logging.getLogger(__name__)
router = APIRouter(tags=["ocr-pipeline"])
@router.post("/sessions/{session_id}/reprocess")
async def reprocess_session(session_id: str, request: Request):
"""Re-run pipeline from a specific step, clearing downstream data.
Body: {"from_step": 5} (1-indexed step number)
Pipeline order: Orientation(1) -> Deskew(2) -> Dewarp(3) -> Crop(4) -> Columns(5) ->
Rows(6) -> Words(7) -> LLM-Review(8) -> Reconstruction(9) -> Validation(10)
Clears downstream results:
- from_step <= 1: orientation_result + all downstream
- from_step <= 2: deskew_result + all downstream
- from_step <= 3: dewarp_result + all downstream
- from_step <= 4: crop_result + all downstream
- from_step <= 5: column_result, row_result, word_result
- from_step <= 6: row_result, word_result
- from_step <= 7: word_result (cells, vocab_entries)
- from_step <= 8: word_result.llm_review only
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
body = await request.json()
from_step = body.get("from_step", 1)
if not isinstance(from_step, int) or from_step < 1 or from_step > 10:
raise HTTPException(status_code=400, detail="from_step must be between 1 and 10")
update_kwargs: Dict[str, Any] = {"current_step": from_step}
# Clear downstream data based on from_step
# New pipeline order: Orient(2) -> Deskew(3) -> Dewarp(4) -> Crop(5) ->
# Columns(6) -> Rows(7) -> Words(8) -> LLM(9) -> Recon(10) -> GT(11)
if from_step <= 8:
update_kwargs["word_result"] = None
elif from_step == 9:
# Only clear LLM review from word_result
word_result = session.get("word_result")
if word_result:
word_result.pop("llm_review", None)
word_result.pop("llm_corrections", None)
update_kwargs["word_result"] = word_result
if from_step <= 7:
update_kwargs["row_result"] = None
if from_step <= 6:
update_kwargs["column_result"] = None
if from_step <= 4:
update_kwargs["crop_result"] = None
if from_step <= 3:
update_kwargs["dewarp_result"] = None
if from_step <= 2:
update_kwargs["deskew_result"] = None
if from_step <= 1:
update_kwargs["orientation_result"] = None
await update_session_db(session_id, **update_kwargs)
# Also clear cache
if session_id in _cache:
for key in list(update_kwargs.keys()):
if key != "current_step":
_cache[session_id][key] = update_kwargs[key]
_cache[session_id]["current_step"] = from_step
logger.info(f"Session {session_id} reprocessing from step {from_step}")
return {
"session_id": session_id,
"from_step": from_step,
"cleared": [k for k in update_kwargs if k != "current_step"],
}
# Backward-compat shim -- module moved to ocr/pipeline/reprocess.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.reprocess")

View File

@@ -1,348 +1,4 @@
"""
OCR Pipeline - Row Detection Endpoints.
Extracted from ocr_pipeline_api.py.
Handles row detection (auto + manual) and row ground truth.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import time
from datetime import datetime
from typing import Any, Dict, List, Optional
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException
from cv_vocab_pipeline import (
create_ocr_image,
detect_column_geometry,
detect_row_geometry,
)
from ocr_pipeline_common import (
_cache,
_load_session_to_cache,
_get_cached,
_append_pipeline_log,
ManualRowsRequest,
RowGroundTruthRequest,
)
from ocr_pipeline_session_store import (
get_session_db,
update_session_db,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Helper: Box-exclusion overlay (used by rows overlay and columns overlay)
# ---------------------------------------------------------------------------
def _draw_box_exclusion_overlay(
img: np.ndarray,
zones: List[Dict],
*,
label: str = "BOX — separat verarbeitet",
) -> None:
"""Draw red semi-transparent rectangles over box zones (in-place).
Reusable for columns, rows, and words overlays.
"""
for zone in zones:
if zone.get("zone_type") != "box" or not zone.get("box"):
continue
box = zone["box"]
bx, by = box["x"], box["y"]
bw, bh = box["width"], box["height"]
# Red semi-transparent fill (~25 %)
box_overlay = img.copy()
cv2.rectangle(box_overlay, (bx, by), (bx + bw, by + bh), (0, 0, 200), -1)
cv2.addWeighted(box_overlay, 0.25, img, 0.75, 0, img)
# Border
cv2.rectangle(img, (bx, by), (bx + bw, by + bh), (0, 0, 200), 2)
# Label
cv2.putText(img, label, (bx + 10, by + bh - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
# ---------------------------------------------------------------------------
# Row Detection Endpoints
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/rows")
async def detect_rows(session_id: str):
"""Run row detection on the cropped (or dewarped) image using horizontal gap analysis."""
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
if dewarped_bgr is None:
raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before row detection")
t0 = time.time()
# Try to reuse cached word_dicts and inv from column detection
word_dicts = cached.get("_word_dicts")
inv = cached.get("_inv")
content_bounds = cached.get("_content_bounds")
if word_dicts is None or inv is None or content_bounds is None:
# Not cached — run column geometry to get intermediates
ocr_img = create_ocr_image(dewarped_bgr)
geo_result = detect_column_geometry(ocr_img, dewarped_bgr)
if geo_result is None:
raise HTTPException(status_code=400, detail="Column geometry detection failed — cannot detect rows")
_geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
cached["_word_dicts"] = word_dicts
cached["_inv"] = inv
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
else:
left_x, right_x, top_y, bottom_y = content_bounds
# Read zones from column_result to exclude box regions
session = await get_session_db(session_id)
column_result = (session or {}).get("column_result") or {}
is_sub_session = bool((session or {}).get("parent_session_id"))
# Sub-sessions (box crops): use word-grouping instead of gap-based
# row detection. Box images are small with complex internal layouts
# (headings, sub-columns) where the horizontal projection approach
# merges rows. Word-grouping directly clusters words by Y proximity,
# which is more robust for these cases.
if is_sub_session and word_dicts:
from cv_layout import _build_rows_from_word_grouping
rows = _build_rows_from_word_grouping(
word_dicts, left_x, right_x, top_y, bottom_y,
right_x - left_x, bottom_y - top_y,
)
logger.info(f"OCR Pipeline: sub-session {session_id}: word-grouping found {len(rows)} rows")
else:
zones = column_result.get("zones") or [] # zones can be None for sub-sessions
# Collect box y-ranges for filtering.
# Use border_thickness to shrink the exclusion zone: the border pixels
# belong visually to the box frame, but text rows above/below the box
# may overlap with the border area and must not be clipped.
box_ranges = [] # [(y_start, y_end)]
box_ranges_inner = [] # [(y_start + border, y_end - border)] for row filtering
for zone in zones:
if zone.get("zone_type") == "box" and zone.get("box"):
box = zone["box"]
bt = max(box.get("border_thickness", 0), 5) # minimum 5px margin
box_ranges.append((box["y"], box["y"] + box["height"]))
# Inner range: shrink by border thickness so boundary rows aren't excluded
box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt))
if box_ranges and inv is not None:
# Combined-image approach: strip box regions from inv image,
# run row detection on the combined image, then remap y-coords back.
content_strips = [] # [(y_start, y_end)] in absolute coords
# Build content strips by subtracting box inner ranges from [top_y, bottom_y].
# Using inner ranges means the border area is included in the content
# strips, so the last row above a box isn't clipped by the border.
sorted_boxes = sorted(box_ranges_inner, key=lambda r: r[0])
strip_start = top_y
for by_start, by_end in sorted_boxes:
if by_start > strip_start:
content_strips.append((strip_start, by_start))
strip_start = max(strip_start, by_end)
if strip_start < bottom_y:
content_strips.append((strip_start, bottom_y))
# Filter to strips with meaningful height
content_strips = [(ys, ye) for ys, ye in content_strips if ye - ys >= 20]
if content_strips:
# Stack content strips vertically
inv_strips = [inv[ys:ye, :] for ys, ye in content_strips]
combined_inv = np.vstack(inv_strips)
# Filter word_dicts to only include words from content strips
combined_words = []
cum_y = 0
strip_offsets = [] # (combined_y_start, strip_height, abs_y_start)
for ys, ye in content_strips:
h = ye - ys
strip_offsets.append((cum_y, h, ys))
for w in word_dicts:
w_abs_y = w['top'] + top_y # word y is relative to content top
w_center = w_abs_y + w['height'] / 2
if ys <= w_center < ye:
# Remap to combined coordinates
w_copy = dict(w)
w_copy['top'] = cum_y + (w_abs_y - ys)
combined_words.append(w_copy)
cum_y += h
# Run row detection on combined image
combined_h = combined_inv.shape[0]
rows = detect_row_geometry(
combined_inv, combined_words, left_x, right_x, 0, combined_h,
)
# Remap y-coordinates back to absolute page coords
def _combined_y_to_abs(cy: int) -> int:
for c_start, s_h, abs_start in strip_offsets:
if cy < c_start + s_h:
return abs_start + (cy - c_start)
last_c, last_h, last_abs = strip_offsets[-1]
return last_abs + last_h
for r in rows:
abs_y = _combined_y_to_abs(r.y)
abs_y_end = _combined_y_to_abs(r.y + r.height)
r.y = abs_y
r.height = abs_y_end - abs_y
else:
rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
else:
# No boxes — standard row detection
rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
duration = time.time() - t0
# Assign zone_index based on which content zone each row falls in
# Build content zone list with indices
zones = column_result.get("zones") or []
content_zones = [(i, z) for i, z in enumerate(zones) if z.get("zone_type") == "content"] if zones else []
# Build serializable result (exclude words to keep payload small)
rows_data = []
for r in rows:
# Determine zone_index
zone_idx = 0
row_center_y = r.y + r.height / 2
for zi, zone in content_zones:
zy = zone["y"]
zh = zone["height"]
if zy <= row_center_y < zy + zh:
zone_idx = zi
break
rd = {
"index": r.index,
"x": r.x,
"y": r.y,
"width": r.width,
"height": r.height,
"word_count": r.word_count,
"row_type": r.row_type,
"gap_before": r.gap_before,
"zone_index": zone_idx,
}
rows_data.append(rd)
type_counts = {}
for r in rows:
type_counts[r.row_type] = type_counts.get(r.row_type, 0) + 1
row_result = {
"rows": rows_data,
"summary": type_counts,
"total_rows": len(rows),
"duration_seconds": round(duration, 2),
}
# Persist to DB — also invalidate word_result since rows changed
await update_session_db(
session_id,
row_result=row_result,
word_result=None,
current_step=7,
)
cached["row_result"] = row_result
cached.pop("word_result", None)
logger.info(f"OCR Pipeline: rows session {session_id}: "
f"{len(rows)} rows detected ({duration:.2f}s): {type_counts}")
content_rows = sum(1 for r in rows if r.row_type == "content")
avg_height = round(sum(r.height for r in rows) / len(rows)) if rows else 0
await _append_pipeline_log(session_id, "rows", {
"total_rows": len(rows),
"content_rows": content_rows,
"artifact_rows_removed": type_counts.get("header", 0) + type_counts.get("footer", 0),
"avg_row_height_px": avg_height,
}, duration_ms=int(duration * 1000))
return {
"session_id": session_id,
**row_result,
}
@router.post("/sessions/{session_id}/rows/manual")
async def set_manual_rows(session_id: str, req: ManualRowsRequest):
"""Override detected rows with manual definitions."""
row_result = {
"rows": req.rows,
"total_rows": len(req.rows),
"duration_seconds": 0,
"method": "manual",
}
await update_session_db(session_id, row_result=row_result, word_result=None)
if session_id in _cache:
_cache[session_id]["row_result"] = row_result
_cache[session_id].pop("word_result", None)
logger.info(f"OCR Pipeline: manual rows session {session_id}: "
f"{len(req.rows)} rows set")
return {"session_id": session_id, **row_result}
@router.post("/sessions/{session_id}/ground-truth/rows")
async def save_row_ground_truth(session_id: str, req: RowGroundTruthRequest):
"""Save ground truth feedback for the row detection step."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
gt = {
"is_correct": req.is_correct,
"corrected_rows": req.corrected_rows,
"notes": req.notes,
"saved_at": datetime.utcnow().isoformat(),
"row_result": session.get("row_result"),
}
ground_truth["rows"] = gt
await update_session_db(session_id, ground_truth=ground_truth)
if session_id in _cache:
_cache[session_id]["ground_truth"] = ground_truth
return {"session_id": session_id, "ground_truth": gt}
@router.get("/sessions/{session_id}/ground-truth/rows")
async def get_row_ground_truth(session_id: str):
"""Retrieve saved ground truth for row detection."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
rows_gt = ground_truth.get("rows")
if not rows_gt:
raise HTTPException(status_code=404, detail="No row ground truth saved")
return {
"session_id": session_id,
"rows_gt": rows_gt,
"rows_auto": session.get("row_result"),
}
# Backward-compat shim -- module moved to ocr/pipeline/rows.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.rows")

View File

@@ -1,388 +1,4 @@
"""
OCR Pipeline Session Store - PostgreSQL persistence for OCR pipeline sessions.
Replaces in-memory storage with database persistence.
See migrations/002_ocr_pipeline_sessions.sql for schema.
"""
import os
import uuid
import logging
import json
from typing import Optional, List, Dict, Any
import asyncpg
logger = logging.getLogger(__name__)
# Database configuration (same as vocab_session_store)
DATABASE_URL = os.getenv(
"DATABASE_URL",
"postgresql://breakpilot:breakpilot@postgres:5432/breakpilot_db"
)
# Connection pool (initialized lazily)
_pool: Optional[asyncpg.Pool] = None
async def get_pool() -> asyncpg.Pool:
"""Get or create the database connection pool."""
global _pool
if _pool is None:
_pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
return _pool
async def init_ocr_pipeline_tables():
"""Initialize OCR pipeline tables if they don't exist."""
pool = await get_pool()
async with pool.acquire() as conn:
tables_exist = await conn.fetchval("""
SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_name = 'ocr_pipeline_sessions'
)
""")
if not tables_exist:
logger.info("Creating OCR pipeline tables...")
migration_path = os.path.join(
os.path.dirname(__file__),
"migrations/002_ocr_pipeline_sessions.sql"
)
if os.path.exists(migration_path):
with open(migration_path, "r") as f:
sql = f.read()
await conn.execute(sql)
logger.info("OCR pipeline tables created successfully")
else:
logger.warning(f"Migration file not found: {migration_path}")
else:
logger.debug("OCR pipeline tables already exist")
# Ensure new columns exist (idempotent ALTER TABLE)
await conn.execute("""
ALTER TABLE ocr_pipeline_sessions
ADD COLUMN IF NOT EXISTS clean_png BYTEA,
ADD COLUMN IF NOT EXISTS handwriting_removal_meta JSONB,
ADD COLUMN IF NOT EXISTS doc_type VARCHAR(50),
ADD COLUMN IF NOT EXISTS doc_type_result JSONB,
ADD COLUMN IF NOT EXISTS document_category VARCHAR(50),
ADD COLUMN IF NOT EXISTS pipeline_log JSONB,
ADD COLUMN IF NOT EXISTS oriented_png BYTEA,
ADD COLUMN IF NOT EXISTS cropped_png BYTEA,
ADD COLUMN IF NOT EXISTS orientation_result JSONB,
ADD COLUMN IF NOT EXISTS crop_result JSONB,
ADD COLUMN IF NOT EXISTS parent_session_id UUID REFERENCES ocr_pipeline_sessions(id) ON DELETE CASCADE,
ADD COLUMN IF NOT EXISTS box_index INT,
ADD COLUMN IF NOT EXISTS grid_editor_result JSONB,
ADD COLUMN IF NOT EXISTS structure_result JSONB,
ADD COLUMN IF NOT EXISTS document_group_id UUID,
ADD COLUMN IF NOT EXISTS page_number INT
""")
# Index for document group lookups
await conn.execute("""
CREATE INDEX IF NOT EXISTS idx_ocr_sessions_document_group
ON ocr_pipeline_sessions (document_group_id)
WHERE document_group_id IS NOT NULL
""")
# =============================================================================
# SESSION CRUD
# =============================================================================
async def create_session_db(
session_id: str,
name: str,
filename: str,
original_png: bytes,
parent_session_id: Optional[str] = None,
box_index: Optional[int] = None,
document_group_id: Optional[str] = None,
page_number: Optional[int] = None,
) -> Dict[str, Any]:
"""Create a new OCR pipeline session.
Args:
parent_session_id: If set, this is a sub-session for a box region.
box_index: 0-based index of the box this sub-session represents.
document_group_id: Groups multi-page uploads into one document.
page_number: 1-based page index within the document group.
"""
pool = await get_pool()
parent_uuid = uuid.UUID(parent_session_id) if parent_session_id else None
group_uuid = uuid.UUID(document_group_id) if document_group_id else None
async with pool.acquire() as conn:
row = await conn.fetchrow("""
INSERT INTO ocr_pipeline_sessions (
id, name, filename, original_png, status, current_step,
parent_session_id, box_index, document_group_id, page_number
) VALUES ($1, $2, $3, $4, 'active', 1, $5, $6, $7, $8)
RETURNING id, name, filename, status, current_step,
orientation_result, crop_result,
deskew_result, dewarp_result, column_result, row_result,
word_result, ground_truth, auto_shear_degrees,
doc_type, doc_type_result,
document_category, pipeline_log,
grid_editor_result, structure_result,
parent_session_id, box_index,
document_group_id, page_number,
created_at, updated_at
""", uuid.UUID(session_id), name, filename, original_png,
parent_uuid, box_index, group_uuid, page_number)
return _row_to_dict(row)
async def get_session_db(session_id: str) -> Optional[Dict[str, Any]]:
"""Get session metadata (without images)."""
pool = await get_pool()
async with pool.acquire() as conn:
row = await conn.fetchrow("""
SELECT id, name, filename, status, current_step,
orientation_result, crop_result,
deskew_result, dewarp_result, column_result, row_result,
word_result, ground_truth, auto_shear_degrees,
doc_type, doc_type_result,
document_category, pipeline_log,
grid_editor_result, structure_result,
parent_session_id, box_index,
document_group_id, page_number,
created_at, updated_at
FROM ocr_pipeline_sessions WHERE id = $1
""", uuid.UUID(session_id))
if row:
return _row_to_dict(row)
return None
async def get_session_image(session_id: str, image_type: str) -> Optional[bytes]:
"""Load a single image (BYTEA) from the session."""
column_map = {
"original": "original_png",
"oriented": "oriented_png",
"cropped": "cropped_png",
"deskewed": "deskewed_png",
"binarized": "binarized_png",
"dewarped": "dewarped_png",
"clean": "clean_png",
}
column = column_map.get(image_type)
if not column:
return None
pool = await get_pool()
async with pool.acquire() as conn:
return await conn.fetchval(
f"SELECT {column} FROM ocr_pipeline_sessions WHERE id = $1",
uuid.UUID(session_id)
)
async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any]]:
"""Update session fields dynamically."""
pool = await get_pool()
fields = []
values = []
param_idx = 1
allowed_fields = {
'name', 'filename', 'status', 'current_step',
'original_png', 'oriented_png', 'cropped_png',
'deskewed_png', 'binarized_png', 'dewarped_png',
'clean_png', 'handwriting_removal_meta',
'orientation_result', 'crop_result',
'deskew_result', 'dewarp_result', 'column_result', 'row_result',
'word_result', 'ground_truth', 'auto_shear_degrees',
'doc_type', 'doc_type_result',
'document_category', 'pipeline_log',
'grid_editor_result', 'structure_result',
'parent_session_id', 'box_index',
'document_group_id', 'page_number',
}
jsonb_fields = {'orientation_result', 'crop_result', 'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'handwriting_removal_meta', 'doc_type_result', 'pipeline_log', 'grid_editor_result', 'structure_result'}
for key, value in kwargs.items():
if key in allowed_fields:
fields.append(f"{key} = ${param_idx}")
if key in jsonb_fields and value is not None and not isinstance(value, str):
value = json.dumps(value)
values.append(value)
param_idx += 1
if not fields:
return await get_session_db(session_id)
# Always update updated_at
fields.append(f"updated_at = NOW()")
values.append(uuid.UUID(session_id))
async with pool.acquire() as conn:
row = await conn.fetchrow(f"""
UPDATE ocr_pipeline_sessions
SET {', '.join(fields)}
WHERE id = ${param_idx}
RETURNING id, name, filename, status, current_step,
orientation_result, crop_result,
deskew_result, dewarp_result, column_result, row_result,
word_result, ground_truth, auto_shear_degrees,
doc_type, doc_type_result,
document_category, pipeline_log,
grid_editor_result, structure_result,
parent_session_id, box_index,
document_group_id, page_number,
created_at, updated_at
""", *values)
if row:
return _row_to_dict(row)
return None
async def list_sessions_db(
limit: int = 50,
include_sub_sessions: bool = False,
) -> List[Dict[str, Any]]:
"""List sessions (metadata only, no images).
By default, sub-sessions (those with parent_session_id) are excluded.
Pass include_sub_sessions=True to include them.
"""
pool = await get_pool()
async with pool.acquire() as conn:
where = "" if include_sub_sessions else "WHERE parent_session_id IS NULL AND (status IS NULL OR status != 'split')"
rows = await conn.fetch(f"""
SELECT id, name, filename, status, current_step,
document_category, doc_type,
parent_session_id, box_index,
document_group_id, page_number,
created_at, updated_at,
ground_truth
FROM ocr_pipeline_sessions
{where}
ORDER BY created_at DESC
LIMIT $1
""", limit)
results = []
for row in rows:
d = _row_to_dict(row)
# Derive is_ground_truth flag from JSONB, then drop the heavy field
gt = d.pop("ground_truth", None) or {}
d["is_ground_truth"] = bool(gt.get("build_grid_reference"))
results.append(d)
return results
async def get_sub_sessions(parent_session_id: str) -> List[Dict[str, Any]]:
"""Get all sub-sessions for a parent session, ordered by box_index."""
pool = await get_pool()
async with pool.acquire() as conn:
rows = await conn.fetch("""
SELECT id, name, filename, status, current_step,
document_category, doc_type,
parent_session_id, box_index,
document_group_id, page_number,
created_at, updated_at
FROM ocr_pipeline_sessions
WHERE parent_session_id = $1
ORDER BY box_index ASC
""", uuid.UUID(parent_session_id))
return [_row_to_dict(row) for row in rows]
async def get_document_group_sessions(document_group_id: str) -> List[Dict[str, Any]]:
"""Get all sessions in a document group, ordered by page_number."""
pool = await get_pool()
async with pool.acquire() as conn:
rows = await conn.fetch("""
SELECT id, name, filename, status, current_step,
document_category, doc_type,
parent_session_id, box_index,
document_group_id, page_number,
created_at, updated_at
FROM ocr_pipeline_sessions
WHERE document_group_id = $1
ORDER BY page_number ASC
""", uuid.UUID(document_group_id))
return [_row_to_dict(row) for row in rows]
async def list_ground_truth_sessions_db() -> List[Dict[str, Any]]:
"""List sessions that have a build_grid_reference in ground_truth."""
pool = await get_pool()
async with pool.acquire() as conn:
rows = await conn.fetch("""
SELECT id, name, filename, status, current_step,
document_category, doc_type,
ground_truth,
parent_session_id, box_index,
created_at, updated_at
FROM ocr_pipeline_sessions
WHERE ground_truth IS NOT NULL
AND ground_truth::text LIKE '%build_grid_reference%'
AND parent_session_id IS NULL
ORDER BY created_at DESC
""")
return [_row_to_dict(row) for row in rows]
async def delete_session_db(session_id: str) -> bool:
"""Delete a session."""
pool = await get_pool()
async with pool.acquire() as conn:
result = await conn.execute("""
DELETE FROM ocr_pipeline_sessions WHERE id = $1
""", uuid.UUID(session_id))
return result == "DELETE 1"
async def delete_all_sessions_db() -> int:
"""Delete all sessions. Returns number of deleted rows."""
pool = await get_pool()
async with pool.acquire() as conn:
result = await conn.execute("DELETE FROM ocr_pipeline_sessions")
# result is e.g. "DELETE 5"
try:
return int(result.split()[-1])
except (ValueError, IndexError):
return 0
# =============================================================================
# HELPER
# =============================================================================
def _row_to_dict(row: asyncpg.Record) -> Dict[str, Any]:
"""Convert asyncpg Record to JSON-serializable dict."""
if row is None:
return {}
result = dict(row)
# UUID → string
for key in ['id', 'session_id', 'parent_session_id', 'document_group_id']:
if key in result and result[key] is not None:
result[key] = str(result[key])
# datetime → ISO string
for key in ['created_at', 'updated_at']:
if key in result and result[key] is not None:
result[key] = result[key].isoformat()
# JSONB → parsed (asyncpg returns str for JSONB)
for key in ['orientation_result', 'crop_result', 'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'doc_type_result', 'pipeline_log', 'grid_editor_result', 'structure_result']:
if key in result and result[key] is not None:
if isinstance(result[key], str):
result[key] = json.loads(result[key])
return result
# Backward-compat shim -- module moved to ocr/pipeline/session_store.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.session_store")

View File

@@ -1,20 +1,4 @@
"""
OCR Pipeline Sessions API — barrel re-export.
All implementation split into:
ocr_pipeline_sessions_crud — session CRUD, box sessions
ocr_pipeline_sessions_images — image serving, thumbnails, doc-type detection
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
from fastapi import APIRouter
from ocr_pipeline_sessions_crud import router as _crud_router # noqa: F401
from ocr_pipeline_sessions_images import router as _images_router # noqa: F401
# Composite router (used by ocr_pipeline_api.py)
router = APIRouter()
router.include_router(_crud_router)
router.include_router(_images_router)
# Backward-compat shim -- module moved to ocr/pipeline/sessions.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.sessions")

View File

@@ -1,449 +1,4 @@
"""
OCR Pipeline Sessions CRUD — session create, read, update, delete, box sessions.
Extracted from ocr_pipeline_sessions.py for modularity.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import uuid
from typing import Any, Dict, Optional
import cv2
import numpy as np
from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
from cv_vocab_pipeline import render_image_high_res, render_pdf_high_res
from ocr_pipeline_common import (
VALID_DOCUMENT_CATEGORIES,
UpdateSessionRequest,
_cache,
)
from ocr_pipeline_session_store import (
create_session_db,
delete_all_sessions_db,
delete_session_db,
get_session_db,
get_session_image,
get_sub_sessions,
list_sessions_db,
update_session_db,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Session Management Endpoints
# ---------------------------------------------------------------------------
@router.get("/sessions")
async def list_sessions(include_sub_sessions: bool = False):
"""List OCR pipeline sessions.
By default, sub-sessions (box regions) are hidden.
Pass ?include_sub_sessions=true to show them.
"""
sessions = await list_sessions_db(include_sub_sessions=include_sub_sessions)
return {"sessions": sessions}
@router.post("/sessions")
async def create_session(
file: UploadFile = File(...),
name: Optional[str] = Form(None),
):
"""Upload a PDF or image file and create a pipeline session.
For multi-page PDFs (> 1 page), each page becomes its own session
grouped under a ``document_group_id``. The response includes a
``pages`` array with one entry per page/session.
"""
file_data = await file.read()
filename = file.filename or "upload"
content_type = file.content_type or ""
is_pdf = content_type == "application/pdf" or filename.lower().endswith(".pdf")
session_name = name or filename
# --- Multi-page PDF handling ---
if is_pdf:
try:
import fitz # PyMuPDF
pdf_doc = fitz.open(stream=file_data, filetype="pdf")
page_count = pdf_doc.page_count
pdf_doc.close()
except Exception as e:
raise HTTPException(status_code=400, detail=f"Could not read PDF: {e}")
if page_count > 1:
return await _create_multi_page_sessions(
file_data, filename, session_name, page_count,
)
# --- Single page (image or 1-page PDF) ---
session_id = str(uuid.uuid4())
try:
if is_pdf:
img_bgr = render_pdf_high_res(file_data, page_number=0, zoom=3.0)
else:
img_bgr = render_image_high_res(file_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"Could not process file: {e}")
# Encode original as PNG bytes
success, png_buf = cv2.imencode(".png", img_bgr)
if not success:
raise HTTPException(status_code=500, detail="Failed to encode image")
original_png = png_buf.tobytes()
# Persist to DB
await create_session_db(
session_id=session_id,
name=session_name,
filename=filename,
original_png=original_png,
)
# Cache BGR array for immediate processing
_cache[session_id] = {
"id": session_id,
"filename": filename,
"name": session_name,
"original_bgr": img_bgr,
"oriented_bgr": None,
"cropped_bgr": None,
"deskewed_bgr": None,
"dewarped_bgr": None,
"orientation_result": None,
"crop_result": None,
"deskew_result": None,
"dewarp_result": None,
"ground_truth": {},
"current_step": 1,
}
logger.info(f"OCR Pipeline: created session {session_id} from {filename} "
f"({img_bgr.shape[1]}x{img_bgr.shape[0]})")
return {
"session_id": session_id,
"filename": filename,
"name": session_name,
"image_width": img_bgr.shape[1],
"image_height": img_bgr.shape[0],
"original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
}
async def _create_multi_page_sessions(
pdf_data: bytes,
filename: str,
base_name: str,
page_count: int,
) -> dict:
"""Create one session per PDF page, grouped by document_group_id."""
document_group_id = str(uuid.uuid4())
pages = []
for page_idx in range(page_count):
session_id = str(uuid.uuid4())
page_name = f"{base_name} — Seite {page_idx + 1}"
try:
img_bgr = render_pdf_high_res(pdf_data, page_number=page_idx, zoom=3.0)
except Exception as e:
logger.warning(f"Failed to render PDF page {page_idx + 1}: {e}")
continue
ok, png_buf = cv2.imencode(".png", img_bgr)
if not ok:
continue
page_png = png_buf.tobytes()
await create_session_db(
session_id=session_id,
name=page_name,
filename=filename,
original_png=page_png,
document_group_id=document_group_id,
page_number=page_idx + 1,
)
_cache[session_id] = {
"id": session_id,
"filename": filename,
"name": page_name,
"original_bgr": img_bgr,
"oriented_bgr": None,
"cropped_bgr": None,
"deskewed_bgr": None,
"dewarped_bgr": None,
"orientation_result": None,
"crop_result": None,
"deskew_result": None,
"dewarp_result": None,
"ground_truth": {},
"current_step": 1,
}
h, w = img_bgr.shape[:2]
pages.append({
"session_id": session_id,
"name": page_name,
"page_number": page_idx + 1,
"image_width": w,
"image_height": h,
"original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
})
logger.info(
f"OCR Pipeline: created page session {session_id} "
f"(page {page_idx + 1}/{page_count}) from {filename} ({w}x{h})"
)
# Include session_id pointing to first page for backwards compatibility
# (frontends that expect a single session_id will navigate to page 1)
first_session_id = pages[0]["session_id"] if pages else None
return {
"session_id": first_session_id,
"document_group_id": document_group_id,
"filename": filename,
"name": base_name,
"page_count": page_count,
"pages": pages,
}
@router.get("/sessions/{session_id}")
async def get_session_info(session_id: str):
"""Get session info including deskew/dewarp/column results for step navigation."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
# Get image dimensions from original PNG
original_png = await get_session_image(session_id, "original")
if original_png:
arr = np.frombuffer(original_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
img_w, img_h = img.shape[1], img.shape[0] if img is not None else (0, 0)
else:
img_w, img_h = 0, 0
result = {
"session_id": session["id"],
"filename": session.get("filename", ""),
"name": session.get("name", ""),
"image_width": img_w,
"image_height": img_h,
"original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
"current_step": session.get("current_step", 1),
"document_category": session.get("document_category"),
"doc_type": session.get("doc_type"),
}
if session.get("orientation_result"):
result["orientation_result"] = session["orientation_result"]
if session.get("crop_result"):
result["crop_result"] = session["crop_result"]
if session.get("deskew_result"):
result["deskew_result"] = session["deskew_result"]
if session.get("dewarp_result"):
result["dewarp_result"] = session["dewarp_result"]
if session.get("column_result"):
result["column_result"] = session["column_result"]
if session.get("row_result"):
result["row_result"] = session["row_result"]
if session.get("word_result"):
result["word_result"] = session["word_result"]
if session.get("doc_type_result"):
result["doc_type_result"] = session["doc_type_result"]
if session.get("structure_result"):
result["structure_result"] = session["structure_result"]
if session.get("grid_editor_result"):
# Include summary only to keep response small
gr = session["grid_editor_result"]
result["grid_editor_result"] = {
"summary": gr.get("summary", {}),
"zones_count": len(gr.get("zones", [])),
"edited": gr.get("edited", False),
}
if session.get("ground_truth"):
result["ground_truth"] = session["ground_truth"]
# Box sub-session info (zone_type='box' from column detection — NOT page-split)
if session.get("parent_session_id"):
result["parent_session_id"] = session["parent_session_id"]
result["box_index"] = session.get("box_index")
else:
# Check for box sub-sessions (column detection creates these)
subs = await get_sub_sessions(session_id)
if subs:
result["sub_sessions"] = [
{"id": s["id"], "name": s.get("name"), "box_index": s.get("box_index")}
for s in subs
]
return result
@router.put("/sessions/{session_id}")
async def update_session(session_id: str, req: UpdateSessionRequest):
"""Update session name and/or document category."""
kwargs: Dict[str, Any] = {}
if req.name is not None:
kwargs["name"] = req.name
if req.document_category is not None:
if req.document_category not in VALID_DOCUMENT_CATEGORIES:
raise HTTPException(
status_code=400,
detail=f"Invalid category '{req.document_category}'. Valid: {sorted(VALID_DOCUMENT_CATEGORIES)}",
)
kwargs["document_category"] = req.document_category
if not kwargs:
raise HTTPException(status_code=400, detail="Nothing to update")
updated = await update_session_db(session_id, **kwargs)
if not updated:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
return {"session_id": session_id, **kwargs}
@router.delete("/sessions/{session_id}")
async def delete_session(session_id: str):
"""Delete a session."""
_cache.pop(session_id, None)
deleted = await delete_session_db(session_id)
if not deleted:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
return {"session_id": session_id, "deleted": True}
@router.delete("/sessions")
async def delete_all_sessions():
"""Delete ALL sessions (cleanup)."""
_cache.clear()
count = await delete_all_sessions_db()
return {"deleted_count": count}
@router.post("/sessions/{session_id}/create-box-sessions")
async def create_box_sessions(session_id: str):
"""Create sub-sessions for each detected box region.
Crops box regions from the cropped/dewarped image and creates
independent sub-sessions that can be processed through the pipeline.
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
column_result = session.get("column_result")
if not column_result:
raise HTTPException(status_code=400, detail="Column detection must be completed first")
zones = column_result.get("zones") or []
box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
if not box_zones:
return {"session_id": session_id, "sub_sessions": [], "message": "No boxes detected"}
# Check for existing sub-sessions
existing = await get_sub_sessions(session_id)
if existing:
return {
"session_id": session_id,
"sub_sessions": [{"id": s["id"], "box_index": s.get("box_index")} for s in existing],
"message": f"{len(existing)} sub-session(s) already exist",
}
# Load base image
base_png = await get_session_image(session_id, "cropped")
if not base_png:
base_png = await get_session_image(session_id, "dewarped")
if not base_png:
raise HTTPException(status_code=400, detail="No base image available")
arr = np.frombuffer(base_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=500, detail="Failed to decode image")
parent_name = session.get("name", "Session")
created = []
for i, zone in enumerate(box_zones):
box = zone["box"]
bx, by = box["x"], box["y"]
bw, bh = box["width"], box["height"]
# Crop box region with small padding
pad = 5
y1 = max(0, by - pad)
y2 = min(img.shape[0], by + bh + pad)
x1 = max(0, bx - pad)
x2 = min(img.shape[1], bx + bw + pad)
crop = img[y1:y2, x1:x2]
# Encode as PNG
success, png_buf = cv2.imencode(".png", crop)
if not success:
logger.warning(f"Failed to encode box {i} crop for session {session_id}")
continue
sub_id = str(uuid.uuid4())
sub_name = f"{parent_name} — Box {i + 1}"
await create_session_db(
session_id=sub_id,
name=sub_name,
filename=session.get("filename", "box-crop.png"),
original_png=png_buf.tobytes(),
parent_session_id=session_id,
box_index=i,
)
# Cache the BGR for immediate processing
# Promote original to cropped so column/row/word detection finds it
box_bgr = crop.copy()
_cache[sub_id] = {
"id": sub_id,
"filename": session.get("filename", "box-crop.png"),
"name": sub_name,
"parent_session_id": session_id,
"original_bgr": box_bgr,
"oriented_bgr": None,
"cropped_bgr": box_bgr,
"deskewed_bgr": None,
"dewarped_bgr": None,
"orientation_result": None,
"crop_result": None,
"deskew_result": None,
"dewarp_result": None,
"ground_truth": {},
"current_step": 1,
}
created.append({
"id": sub_id,
"name": sub_name,
"box_index": i,
"box": box,
"image_width": crop.shape[1],
"image_height": crop.shape[0],
})
logger.info(f"Created box sub-session {sub_id} for session {session_id} "
f"(box {i}, {crop.shape[1]}x{crop.shape[0]})")
return {
"session_id": session_id,
"sub_sessions": created,
"total": len(created),
}
# Backward-compat shim -- module moved to ocr/pipeline/sessions_crud.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.sessions_crud")

View File

@@ -1,176 +1,4 @@
"""
OCR Pipeline Sessions Images — image serving, thumbnails, pipeline log,
categories, and document type detection.
Extracted from ocr_pipeline_sessions.py for modularity.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import time
from typing import Any, Dict
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import Response
from cv_vocab_pipeline import create_ocr_image, detect_document_type
from ocr_pipeline_common import (
VALID_DOCUMENT_CATEGORIES,
_append_pipeline_log,
_cache,
_get_base_image_png,
_get_cached,
_load_session_to_cache,
)
from ocr_pipeline_overlays import render_overlay
from ocr_pipeline_session_store import (
get_session_db,
get_session_image,
update_session_db,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Thumbnail & Log Endpoints
# ---------------------------------------------------------------------------
@router.get("/sessions/{session_id}/thumbnail")
async def get_session_thumbnail(session_id: str, size: int = Query(default=80, ge=16, le=400)):
"""Return a small thumbnail of the original image."""
original_png = await get_session_image(session_id, "original")
if not original_png:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found or no image")
arr = np.frombuffer(original_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=500, detail="Failed to decode image")
h, w = img.shape[:2]
scale = size / max(h, w)
new_w, new_h = int(w * scale), int(h * scale)
thumb = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
_, png_bytes = cv2.imencode(".png", thumb)
return Response(content=png_bytes.tobytes(), media_type="image/png",
headers={"Cache-Control": "public, max-age=3600"})
@router.get("/sessions/{session_id}/pipeline-log")
async def get_pipeline_log(session_id: str):
"""Get the pipeline execution log for a session."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
return {"session_id": session_id, "pipeline_log": session.get("pipeline_log") or {"steps": []}}
@router.get("/categories")
async def list_categories():
"""List valid document categories."""
return {"categories": sorted(VALID_DOCUMENT_CATEGORIES)}
# ---------------------------------------------------------------------------
# Image Endpoints
# ---------------------------------------------------------------------------
@router.get("/sessions/{session_id}/image/{image_type}")
async def get_image(session_id: str, image_type: str):
"""Serve session images: original, deskewed, dewarped, binarized, structure-overlay, columns-overlay, or rows-overlay."""
valid_types = {"original", "oriented", "cropped", "deskewed", "dewarped", "binarized", "structure-overlay", "columns-overlay", "rows-overlay", "words-overlay", "clean"}
if image_type not in valid_types:
raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")
if image_type == "structure-overlay":
return await render_overlay("structure", session_id)
if image_type == "columns-overlay":
return await render_overlay("columns", session_id)
if image_type == "rows-overlay":
return await render_overlay("rows", session_id)
if image_type == "words-overlay":
return await render_overlay("words", session_id)
# Try cache first for fast serving
cached = _cache.get(session_id)
if cached:
png_key = f"{image_type}_png" if image_type != "original" else None
bgr_key = f"{image_type}_bgr" if image_type != "binarized" else None
# For binarized, check if we have it cached as PNG
if image_type == "binarized" and cached.get("binarized_png"):
return Response(content=cached["binarized_png"], media_type="image/png")
# Load from DB — for cropped/dewarped, fall back through the chain
if image_type in ("cropped", "dewarped"):
data = await _get_base_image_png(session_id)
else:
data = await get_session_image(session_id, image_type)
if not data:
raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet")
return Response(content=data, media_type="image/png")
# ---------------------------------------------------------------------------
# Document Type Detection (between Dewarp and Columns)
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/detect-type")
async def detect_type(session_id: str):
"""Detect document type (vocab_table, full_text, generic_table).
Should be called after crop (clean image available).
Falls back to dewarped if crop was skipped.
Stores result in session for frontend to decide pipeline flow.
"""
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
img_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
if img_bgr is None:
raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first")
t0 = time.time()
ocr_img = create_ocr_image(img_bgr)
result = detect_document_type(ocr_img, img_bgr)
duration = time.time() - t0
result_dict = {
"doc_type": result.doc_type,
"confidence": result.confidence,
"pipeline": result.pipeline,
"skip_steps": result.skip_steps,
"features": result.features,
"duration_seconds": round(duration, 2),
}
# Persist to DB
await update_session_db(
session_id,
doc_type=result.doc_type,
doc_type_result=result_dict,
)
cached["doc_type_result"] = result_dict
logger.info(f"OCR Pipeline: detect-type session {session_id}: "
f"{result.doc_type} (confidence={result.confidence}, {duration:.2f}s)")
await _append_pipeline_log(session_id, "detect_type", {
"doc_type": result.doc_type,
"pipeline": result.pipeline,
"confidence": result.confidence,
**{k: v for k, v in (result.features or {}).items() if isinstance(v, (int, float, str, bool))},
}, duration_ms=int(duration * 1000))
return {"session_id": session_id, **result_dict}
# Backward-compat shim -- module moved to ocr/pipeline/sessions_images.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.sessions_images")

View File

@@ -1,299 +1,4 @@
"""
OCR Pipeline Structure Detection and Exclude Regions
Detect document structure (boxes, zones, color regions, graphics)
and manage user-drawn exclude regions.
Extracted from ocr_pipeline_geometry.py for file-size compliance.
"""
import logging
import time
from typing import Any, Dict, List
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from cv_box_detect import detect_boxes
from cv_color_detect import _COLOR_RANGES, _COLOR_HEX
from cv_graphic_detect import detect_graphic_elements
from ocr_pipeline_session_store import (
get_session_db,
update_session_db,
)
from ocr_pipeline_common import (
_cache,
_load_session_to_cache,
_get_cached,
_filter_border_ghost_words,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Structure Detection Endpoint
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/detect-structure")
async def detect_structure(session_id: str):
"""Detect document structure: boxes, zones, and color regions.
Runs box detection (line + shading) and color analysis on the cropped
image. Returns structured JSON with all detected elements for the
structure visualization step.
"""
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
img_bgr = (
cached.get("cropped_bgr")
if cached.get("cropped_bgr") is not None
else cached.get("dewarped_bgr")
)
if img_bgr is None:
raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first")
t0 = time.time()
h, w = img_bgr.shape[:2]
# --- Content bounds from word result (if available) or full image ---
word_result = cached.get("word_result")
words: List[Dict] = []
if word_result and word_result.get("cells"):
for cell in word_result["cells"]:
for wb in (cell.get("word_boxes") or []):
words.append(wb)
# Fallback: use raw OCR words if cell word_boxes are empty
if not words and word_result:
for key in ("raw_paddle_words_split", "raw_tesseract_words", "raw_paddle_words"):
raw = word_result.get(key, [])
if raw:
words = raw
logger.info("detect-structure: using %d words from %s (no cell word_boxes)", len(words), key)
break
# If no words yet, use image dimensions with small margin
if words:
content_x = max(0, min(int(wb["left"]) for wb in words))
content_y = max(0, min(int(wb["top"]) for wb in words))
content_r = min(w, max(int(wb["left"] + wb["width"]) for wb in words))
content_b = min(h, max(int(wb["top"] + wb["height"]) for wb in words))
content_w_px = content_r - content_x
content_h_px = content_b - content_y
else:
margin = int(min(w, h) * 0.03)
content_x, content_y = margin, margin
content_w_px = w - 2 * margin
content_h_px = h - 2 * margin
# --- Box detection ---
boxes = detect_boxes(
img_bgr,
content_x=content_x,
content_w=content_w_px,
content_y=content_y,
content_h=content_h_px,
)
# --- Zone splitting ---
from cv_box_detect import split_page_into_zones as _split_zones
zones = _split_zones(content_x, content_y, content_w_px, content_h_px, boxes)
# --- Color region sampling ---
# Sample background shading in each detected box
hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)
box_colors = []
for box in boxes:
# Sample the center region of each box
cy1 = box.y + box.height // 4
cy2 = box.y + 3 * box.height // 4
cx1 = box.x + box.width // 4
cx2 = box.x + 3 * box.width // 4
cy1 = max(0, min(cy1, h - 1))
cy2 = max(0, min(cy2, h - 1))
cx1 = max(0, min(cx1, w - 1))
cx2 = max(0, min(cx2, w - 1))
if cy2 > cy1 and cx2 > cx1:
roi_hsv = hsv[cy1:cy2, cx1:cx2]
med_h = float(np.median(roi_hsv[:, :, 0]))
med_s = float(np.median(roi_hsv[:, :, 1]))
med_v = float(np.median(roi_hsv[:, :, 2]))
if med_s > 15:
from cv_color_detect import _hue_to_color_name
bg_name = _hue_to_color_name(med_h)
bg_hex = _COLOR_HEX.get(bg_name, "#6b7280")
else:
bg_name = "gray" if med_v < 220 else "white"
bg_hex = "#6b7280" if bg_name == "gray" else "#ffffff"
else:
bg_name = "unknown"
bg_hex = "#6b7280"
box_colors.append({"color_name": bg_name, "color_hex": bg_hex})
# --- Color text detection overview ---
# Quick scan for colored text regions across the page
color_summary: Dict[str, int] = {}
for color_name, ranges in _COLOR_RANGES.items():
mask = np.zeros((h, w), dtype=np.uint8)
for lower, upper in ranges:
mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper))
pixel_count = int(np.sum(mask > 0))
if pixel_count > 50: # minimum threshold
color_summary[color_name] = pixel_count
# --- Graphic element detection ---
box_dicts = [
{"x": b.x, "y": b.y, "w": b.width, "h": b.height}
for b in boxes
]
graphics = detect_graphic_elements(
img_bgr, words,
detected_boxes=box_dicts,
)
# --- Filter border-ghost words from OCR result ---
ghost_count = 0
if boxes and word_result:
ghost_count = _filter_border_ghost_words(word_result, boxes)
if ghost_count:
logger.info("detect-structure: removed %d border-ghost words", ghost_count)
await update_session_db(session_id, word_result=word_result)
cached["word_result"] = word_result
duration = time.time() - t0
# Preserve user-drawn exclude regions from previous run
prev_sr = cached.get("structure_result") or {}
prev_exclude = prev_sr.get("exclude_regions", [])
result_dict = {
"image_width": w,
"image_height": h,
"content_bounds": {
"x": content_x, "y": content_y,
"w": content_w_px, "h": content_h_px,
},
"boxes": [
{
"x": b.x, "y": b.y, "w": b.width, "h": b.height,
"confidence": b.confidence,
"border_thickness": b.border_thickness,
"bg_color_name": box_colors[i]["color_name"],
"bg_color_hex": box_colors[i]["color_hex"],
}
for i, b in enumerate(boxes)
],
"zones": [
{
"index": z.index,
"zone_type": z.zone_type,
"y": z.y, "h": z.height,
"x": z.x, "w": z.width,
}
for z in zones
],
"graphics": [
{
"x": g.x, "y": g.y, "w": g.width, "h": g.height,
"area": g.area,
"shape": g.shape,
"color_name": g.color_name,
"color_hex": g.color_hex,
"confidence": round(g.confidence, 2),
}
for g in graphics
],
"exclude_regions": prev_exclude,
"color_pixel_counts": color_summary,
"has_words": len(words) > 0,
"word_count": len(words),
"border_ghosts_removed": ghost_count,
"duration_seconds": round(duration, 2),
}
# Persist to session
await update_session_db(session_id, structure_result=result_dict)
cached["structure_result"] = result_dict
logger.info("detect-structure session %s: %d boxes, %d zones, %d graphics, %.2fs",
session_id, len(boxes), len(zones), len(graphics), duration)
return {"session_id": session_id, **result_dict}
# ---------------------------------------------------------------------------
# Exclude Regions -- user-drawn rectangles to exclude from OCR results
# ---------------------------------------------------------------------------
class _ExcludeRegionIn(BaseModel):
x: int
y: int
w: int
h: int
label: str = ""
class _ExcludeRegionsBatchIn(BaseModel):
regions: list[_ExcludeRegionIn]
@router.put("/sessions/{session_id}/exclude-regions")
async def set_exclude_regions(session_id: str, body: _ExcludeRegionsBatchIn):
"""Replace all exclude regions for a session.
Regions are stored inside ``structure_result.exclude_regions``.
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
sr = session.get("structure_result") or {}
sr["exclude_regions"] = [r.model_dump() for r in body.regions]
# Invalidate grid so it rebuilds with new exclude regions
await update_session_db(session_id, structure_result=sr, grid_editor_result=None)
# Update cache
if session_id in _cache:
_cache[session_id]["structure_result"] = sr
_cache[session_id].pop("grid_editor_result", None)
return {
"session_id": session_id,
"exclude_regions": sr["exclude_regions"],
"count": len(sr["exclude_regions"]),
}
@router.delete("/sessions/{session_id}/exclude-regions/{region_index}")
async def delete_exclude_region(session_id: str, region_index: int):
"""Remove a single exclude region by index."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
sr = session.get("structure_result") or {}
regions = sr.get("exclude_regions", [])
if region_index < 0 or region_index >= len(regions):
raise HTTPException(status_code=404, detail="Region index out of range")
removed = regions.pop(region_index)
sr["exclude_regions"] = regions
# Invalidate grid so it rebuilds with new exclude regions
await update_session_db(session_id, structure_result=sr, grid_editor_result=None)
if session_id in _cache:
_cache[session_id]["structure_result"] = sr
_cache[session_id].pop("grid_editor_result", None)
return {
"session_id": session_id,
"removed": removed,
"remaining": len(regions),
}
# Backward-compat shim -- module moved to ocr/pipeline/structure.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.structure")

View File

@@ -1,362 +1,4 @@
"""
OCR Pipeline Validation — image detection, generation, validation save,
and handwriting removal endpoints.
Extracted from ocr_pipeline_postprocess.py.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import os
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from ocr_pipeline_session_store import (
get_session_db,
get_session_image,
update_session_db,
)
from ocr_pipeline_common import (
_cache,
RemoveHandwritingRequest,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Pydantic Models
# ---------------------------------------------------------------------------
STYLE_SUFFIXES = {
"educational": "educational illustration, textbook style, clear, colorful",
"cartoon": "cartoon, child-friendly, simple shapes",
"sketch": "pencil sketch, hand-drawn, black and white",
"clipart": "clipart, flat vector style, simple",
"realistic": "photorealistic, high detail",
}
class ValidationRequest(BaseModel):
notes: Optional[str] = None
score: Optional[int] = None
class GenerateImageRequest(BaseModel):
region_index: int
prompt: str
style: str = "educational"
# ---------------------------------------------------------------------------
# Image detection + generation
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/reconstruction/detect-images")
async def detect_image_regions(session_id: str):
"""Detect illustration/image regions in the original scan using VLM."""
import base64
import httpx
import re
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
original_png = await get_session_image(session_id, "original")
if not original_png:
raise HTTPException(status_code=400, detail="No original image found")
word_result = session.get("word_result") or {}
entries = word_result.get("vocab_entries") or word_result.get("entries") or []
vocab_context = ""
if entries:
sample = entries[:10]
words = [f"{e.get('english', '')} / {e.get('german', '')}" for e in sample if e.get('english')]
if words:
vocab_context = f"\nContext: This is a vocabulary page with words like: {', '.join(words)}"
ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
prompt = (
"Analyze this scanned page. Find ALL illustration/image/picture regions "
"(NOT text, NOT table cells, NOT blank areas). "
"For each image region found, return its bounding box as percentage of page dimensions "
"and a short English description of what the image shows. "
"Reply with ONLY a JSON array like: "
'[{"x": 10, "y": 20, "w": 30, "h": 25, "description": "drawing of a cat"}] '
"where x, y, w, h are percentages (0-100) of the page width/height. "
"If there are NO images on the page, return an empty array: []"
f"{vocab_context}"
)
img_b64 = base64.b64encode(original_png).decode("utf-8")
payload = {
"model": model,
"prompt": prompt,
"images": [img_b64],
"stream": False,
}
try:
async with httpx.AsyncClient(timeout=120.0) as client:
resp = await client.post(f"{ollama_base}/api/generate", json=payload)
resp.raise_for_status()
text = resp.json().get("response", "")
match = re.search(r'\[.*?\]', text, re.DOTALL)
if match:
raw_regions = json.loads(match.group(0))
else:
raw_regions = []
regions = []
for r in raw_regions:
regions.append({
"bbox_pct": {
"x": max(0, min(100, float(r.get("x", 0)))),
"y": max(0, min(100, float(r.get("y", 0)))),
"w": max(1, min(100, float(r.get("w", 10)))),
"h": max(1, min(100, float(r.get("h", 10)))),
},
"description": r.get("description", ""),
"prompt": r.get("description", ""),
"image_b64": None,
"style": "educational",
})
# Enrich prompts with nearby vocab context
if entries:
for region in regions:
ry = region["bbox_pct"]["y"]
rh = region["bbox_pct"]["h"]
nearby = [
e for e in entries
if e.get("bbox") and abs(e["bbox"].get("y", 0) - ry) < rh + 10
]
if nearby:
en_words = [e.get("english", "") for e in nearby if e.get("english")]
de_words = [e.get("german", "") for e in nearby if e.get("german")]
if en_words or de_words:
context = f" (vocabulary context: {', '.join(en_words[:5])}"
if de_words:
context += f" / {', '.join(de_words[:5])}"
context += ")"
region["prompt"] = region["description"] + context
ground_truth = session.get("ground_truth") or {}
validation = ground_truth.get("validation") or {}
validation["image_regions"] = regions
validation["detected_at"] = datetime.utcnow().isoformat()
ground_truth["validation"] = validation
await update_session_db(session_id, ground_truth=ground_truth)
if session_id in _cache:
_cache[session_id]["ground_truth"] = ground_truth
logger.info(f"Detected {len(regions)} image regions for session {session_id}")
return {"regions": regions, "count": len(regions)}
except httpx.ConnectError:
logger.warning(f"VLM not available at {ollama_base} for image detection")
return {"regions": [], "count": 0, "error": "VLM not available"}
except Exception as e:
logger.error(f"Image detection failed for {session_id}: {e}")
return {"regions": [], "count": 0, "error": str(e)}
@router.post("/sessions/{session_id}/reconstruction/generate-image")
async def generate_image_for_region(session_id: str, req: GenerateImageRequest):
"""Generate a replacement image for a detected region using mflux."""
import httpx
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
validation = ground_truth.get("validation") or {}
regions = validation.get("image_regions") or []
if req.region_index < 0 or req.region_index >= len(regions):
raise HTTPException(status_code=400, detail=f"Invalid region_index {req.region_index}, have {len(regions)} regions")
mflux_url = os.getenv("MFLUX_URL", "http://host.docker.internal:8095")
style_suffix = STYLE_SUFFIXES.get(req.style, STYLE_SUFFIXES["educational"])
full_prompt = f"{req.prompt}, {style_suffix}"
region = regions[req.region_index]
bbox = region["bbox_pct"]
aspect = bbox["w"] / max(bbox["h"], 1)
if aspect > 1.3:
width, height = 768, 512
elif aspect < 0.7:
width, height = 512, 768
else:
width, height = 512, 512
try:
async with httpx.AsyncClient(timeout=300.0) as client:
resp = await client.post(f"{mflux_url}/generate", json={
"prompt": full_prompt,
"width": width,
"height": height,
"steps": 4,
})
resp.raise_for_status()
data = resp.json()
image_b64 = data.get("image_b64")
if not image_b64:
return {"image_b64": None, "success": False, "error": "No image returned"}
regions[req.region_index]["image_b64"] = image_b64
regions[req.region_index]["prompt"] = req.prompt
regions[req.region_index]["style"] = req.style
validation["image_regions"] = regions
ground_truth["validation"] = validation
await update_session_db(session_id, ground_truth=ground_truth)
if session_id in _cache:
_cache[session_id]["ground_truth"] = ground_truth
logger.info(f"Generated image for session {session_id} region {req.region_index}")
return {"image_b64": image_b64, "success": True}
except httpx.ConnectError:
logger.warning(f"mflux-service not available at {mflux_url}")
return {"image_b64": None, "success": False, "error": f"mflux-service not available at {mflux_url}"}
except Exception as e:
logger.error(f"Image generation failed for {session_id}: {e}")
return {"image_b64": None, "success": False, "error": str(e)}
# ---------------------------------------------------------------------------
# Validation save/get
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/reconstruction/validate")
async def save_validation(session_id: str, req: ValidationRequest):
"""Save final validation results for step 8."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
validation = ground_truth.get("validation") or {}
validation["validated_at"] = datetime.utcnow().isoformat()
validation["notes"] = req.notes
validation["score"] = req.score
ground_truth["validation"] = validation
await update_session_db(session_id, ground_truth=ground_truth, current_step=11)
if session_id in _cache:
_cache[session_id]["ground_truth"] = ground_truth
logger.info(f"Validation saved for session {session_id}: score={req.score}")
return {"session_id": session_id, "validation": validation}
@router.get("/sessions/{session_id}/reconstruction/validation")
async def get_validation(session_id: str):
"""Retrieve saved validation data for step 8."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
validation = ground_truth.get("validation")
return {
"session_id": session_id,
"validation": validation,
"word_result": session.get("word_result"),
}
# ---------------------------------------------------------------------------
# Remove handwriting
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/remove-handwriting")
async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingRequest):
"""Remove handwriting from a session image using inpainting."""
import time as _time
from services.handwriting_detection import detect_handwriting
from services.inpainting_service import inpaint_image, dilate_mask as _dilate_mask, InpaintingMethod, image_to_png
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
t0 = _time.monotonic()
# 1. Determine source image
source = req.use_source
if source == "auto":
deskewed = await get_session_image(session_id, "deskewed")
source = "deskewed" if deskewed else "original"
image_bytes = await get_session_image(session_id, source)
if not image_bytes:
raise HTTPException(status_code=404, detail=f"Source image '{source}' not available")
# 2. Detect handwriting mask
detection = detect_handwriting(image_bytes, target_ink=req.target_ink)
# 3. Convert mask to PNG bytes and dilate
import io
from PIL import Image as _PILImage
mask_img = _PILImage.fromarray(detection.mask)
mask_buf = io.BytesIO()
mask_img.save(mask_buf, format="PNG")
mask_bytes = mask_buf.getvalue()
if req.dilation > 0:
mask_bytes = _dilate_mask(mask_bytes, iterations=req.dilation)
# 4. Inpaint
method_map = {
"telea": InpaintingMethod.OPENCV_TELEA,
"ns": InpaintingMethod.OPENCV_NS,
"auto": InpaintingMethod.AUTO,
}
inpaint_method = method_map.get(req.method, InpaintingMethod.AUTO)
result = inpaint_image(image_bytes, mask_bytes, method=inpaint_method)
if not result.success:
raise HTTPException(status_code=500, detail="Inpainting failed")
elapsed_ms = int((_time.monotonic() - t0) * 1000)
meta = {
"method_used": result.method_used.value if hasattr(result.method_used, "value") else str(result.method_used),
"handwriting_ratio": round(detection.handwriting_ratio, 4),
"detection_confidence": round(detection.confidence, 4),
"target_ink": req.target_ink,
"dilation": req.dilation,
"source_image": source,
"processing_time_ms": elapsed_ms,
}
# 5. Persist clean image
clean_png_bytes = image_to_png(result.image)
await update_session_db(session_id, clean_png=clean_png_bytes, handwriting_removal_meta=meta)
return {
**meta,
"image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean",
"session_id": session_id,
}
# Backward-compat shim -- module moved to ocr/pipeline/validation.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.validation")

View File

@@ -1,185 +1,4 @@
"""
OCR Pipeline Words — composite router for word detection, PaddleOCR direct,
and ground truth endpoints.
Split into sub-modules:
ocr_pipeline_words_detect — main detect_words endpoint (Step 7)
ocr_pipeline_words_stream — SSE streaming generators
This barrel module contains the PaddleOCR direct endpoint and ground truth
endpoints, and assembles all word-related routers.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import time
from datetime import datetime
from typing import Any, Dict, List, Optional
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from cv_words_first import build_grid_from_words
from ocr_pipeline_session_store import (
get_session_db,
get_session_image,
update_session_db,
)
from ocr_pipeline_common import (
_cache,
_append_pipeline_log,
)
from ocr_pipeline_words_detect import router as _detect_router
logger = logging.getLogger(__name__)
_local_router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Pydantic models
# ---------------------------------------------------------------------------
class WordGroundTruthRequest(BaseModel):
is_correct: bool
corrected_entries: Optional[List[Dict[str, Any]]] = None
notes: Optional[str] = None
# ---------------------------------------------------------------------------
# PaddleOCR Direct Endpoint
# ---------------------------------------------------------------------------
@_local_router.post("/sessions/{session_id}/paddle-direct")
async def paddle_direct(session_id: str):
"""Run PaddleOCR on the preprocessed image and build a word grid directly."""
img_png = await get_session_image(session_id, "cropped")
if not img_png:
img_png = await get_session_image(session_id, "dewarped")
if not img_png:
img_png = await get_session_image(session_id, "original")
if not img_png:
raise HTTPException(status_code=404, detail="No image found for this session")
img_arr = np.frombuffer(img_png, dtype=np.uint8)
img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
if img_bgr is None:
raise HTTPException(status_code=400, detail="Failed to decode original image")
img_h, img_w = img_bgr.shape[:2]
from cv_ocr_engines import ocr_region_paddle
t0 = time.time()
word_dicts = await ocr_region_paddle(img_bgr, region=None)
if not word_dicts:
raise HTTPException(status_code=400, detail="PaddleOCR returned no words")
cells, columns_meta = build_grid_from_words(word_dicts, img_w, img_h)
duration = time.time() - t0
for cell in cells:
cell["ocr_engine"] = "paddle_direct"
n_rows = len(set(c["row_index"] for c in cells)) if cells else 0
n_cols = len(columns_meta)
col_types = {c.get("type") for c in columns_meta}
is_vocab = bool(col_types & {"column_en", "column_de"})
word_result = {
"cells": cells,
"grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
"columns_used": columns_meta,
"layout": "vocab" if is_vocab else "generic",
"image_width": img_w,
"image_height": img_h,
"duration_seconds": round(duration, 2),
"ocr_engine": "paddle_direct",
"grid_method": "paddle_direct",
"summary": {
"total_cells": len(cells),
"non_empty_cells": sum(1 for c in cells if c.get("text")),
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
},
}
await update_session_db(
session_id,
word_result=word_result,
cropped_png=img_png,
current_step=8,
)
logger.info(
"paddle_direct session %s: %d cells (%d rows, %d cols) in %.2fs",
session_id, len(cells), n_rows, n_cols, duration,
)
await _append_pipeline_log(session_id, "paddle_direct", {
"total_cells": len(cells),
"non_empty_cells": word_result["summary"]["non_empty_cells"],
"ocr_engine": "paddle_direct",
}, duration_ms=int(duration * 1000))
return {"session_id": session_id, **word_result}
# ---------------------------------------------------------------------------
# Ground Truth Words Endpoints
# ---------------------------------------------------------------------------
@_local_router.post("/sessions/{session_id}/ground-truth/words")
async def save_word_ground_truth(session_id: str, req: WordGroundTruthRequest):
"""Save ground truth feedback for the word recognition step."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
gt = {
"is_correct": req.is_correct,
"corrected_entries": req.corrected_entries,
"notes": req.notes,
"saved_at": datetime.utcnow().isoformat(),
"word_result": session.get("word_result"),
}
ground_truth["words"] = gt
await update_session_db(session_id, ground_truth=ground_truth)
if session_id in _cache:
_cache[session_id]["ground_truth"] = ground_truth
return {"session_id": session_id, "ground_truth": gt}
@_local_router.get("/sessions/{session_id}/ground-truth/words")
async def get_word_ground_truth(session_id: str):
"""Retrieve saved ground truth for word recognition."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
words_gt = ground_truth.get("words")
if not words_gt:
raise HTTPException(status_code=404, detail="No word ground truth saved")
return {
"session_id": session_id,
"words_gt": words_gt,
"words_auto": session.get("word_result"),
}
# ---------------------------------------------------------------------------
# Composite router
# ---------------------------------------------------------------------------
router = APIRouter()
router.include_router(_detect_router)
router.include_router(_local_router)
# Backward-compat shim -- module moved to ocr/pipeline/words.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.words")

View File

@@ -1,393 +1,4 @@
"""
OCR Pipeline Words Detect — main word detection endpoint (Step 7).
Extracted from ocr_pipeline_words.py. Contains the ``detect_words``
endpoint which handles both v2 and words_first grid methods.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import time
from typing import Any, Dict, List
import numpy as np
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from cv_vocab_pipeline import (
PageRegion,
RowGeometry,
_cells_to_vocab_entries,
_fix_phonetic_brackets,
fix_cell_phonetics,
build_cell_grid_v2,
create_ocr_image,
detect_column_geometry,
)
from cv_words_first import build_grid_from_words
from ocr_pipeline_session_store import (
get_session_db,
update_session_db,
)
from ocr_pipeline_common import (
_cache,
_load_session_to_cache,
_get_cached,
_append_pipeline_log,
)
from ocr_pipeline_words_stream import (
_word_batch_stream_generator,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Word Detection Endpoint (Step 7)
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/words")
async def detect_words(
session_id: str,
request: Request,
engine: str = "auto",
pronunciation: str = "british",
stream: bool = False,
skip_heal_gaps: bool = False,
grid_method: str = "v2",
):
"""Build word grid from columns x rows, OCR each cell.
Query params:
engine: 'auto' (default), 'tesseract', 'rapid', or 'paddle'
pronunciation: 'british' (default) or 'american'
stream: false (default) for JSON response, true for SSE streaming
skip_heal_gaps: false (default). When true, cells keep exact row geometry.
grid_method: 'v2' (default) or 'words_first'
"""
# PaddleOCR is full-page remote OCR -> force words_first grid method
if engine == "paddle" and grid_method != "words_first":
logger.info("detect_words: engine=paddle requires words_first, overriding grid_method=%s", grid_method)
grid_method = "words_first"
if session_id not in _cache:
logger.info("detect_words: session %s not in cache, loading from DB", session_id)
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
if dewarped_bgr is None:
logger.warning("detect_words: no cropped/dewarped image for session %s (cache keys: %s)",
session_id, [k for k in cached.keys() if k.endswith('_bgr')])
raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before word detection")
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
column_result = session.get("column_result")
row_result = session.get("row_result")
if not column_result or not column_result.get("columns"):
img_h_tmp, img_w_tmp = dewarped_bgr.shape[:2]
column_result = {
"columns": [{
"type": "column_text",
"x": 0, "y": 0,
"width": img_w_tmp, "height": img_h_tmp,
"classification_confidence": 1.0,
"classification_method": "full_page_fallback",
}],
"zones": [],
"duration_seconds": 0,
}
logger.info("detect_words: no column_result -- using full-page pseudo-column %dx%d", img_w_tmp, img_h_tmp)
if grid_method != "words_first" and (not row_result or not row_result.get("rows")):
raise HTTPException(status_code=400, detail="Row detection must be completed first")
# Convert column dicts back to PageRegion objects
col_regions = [
PageRegion(
type=c["type"],
x=c["x"], y=c["y"],
width=c["width"], height=c["height"],
classification_confidence=c.get("classification_confidence", 1.0),
classification_method=c.get("classification_method", ""),
)
for c in column_result["columns"]
]
# Convert row dicts back to RowGeometry objects
row_geoms = [
RowGeometry(
index=r["index"],
x=r["x"], y=r["y"],
width=r["width"], height=r["height"],
word_count=r.get("word_count", 0),
words=[],
row_type=r.get("row_type", "content"),
gap_before=r.get("gap_before", 0),
)
for r in row_result["rows"]
]
# Populate word counts from cached words
word_dicts = cached.get("_word_dicts")
if word_dicts is None:
ocr_img_tmp = create_ocr_image(dewarped_bgr)
geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr)
if geo_result is not None:
_geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
cached["_word_dicts"] = word_dicts
cached["_inv"] = inv
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
if word_dicts:
content_bounds = cached.get("_content_bounds")
if content_bounds:
_lx, _rx, top_y, _by = content_bounds
else:
top_y = min(r.y for r in row_geoms) if row_geoms else 0
for row in row_geoms:
row_y_rel = row.y - top_y
row_bottom_rel = row_y_rel + row.height
row.words = [
w for w in word_dicts
if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
]
row.word_count = len(row.words)
# Exclude rows that fall within box zones
zones = column_result.get("zones") or []
box_ranges_inner = []
for zone in zones:
if zone.get("zone_type") == "box" and zone.get("box"):
box = zone["box"]
bt = max(box.get("border_thickness", 0), 5)
box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt))
if box_ranges_inner:
def _row_in_box(r):
center_y = r.y + r.height / 2
return any(by_s <= center_y < by_e for by_s, by_e in box_ranges_inner)
before_count = len(row_geoms)
row_geoms = [r for r in row_geoms if not _row_in_box(r)]
excluded = before_count - len(row_geoms)
if excluded:
logger.info(f"detect_words: excluded {excluded} rows inside box zones")
# --- Words-First path ---
if grid_method == "words_first":
return await _words_first_path(
session_id, cached, dewarped_bgr, engine, pronunciation, zones,
)
if stream:
return StreamingResponse(
_word_batch_stream_generator(
session_id, cached, col_regions, row_geoms,
dewarped_bgr, engine, pronunciation, request,
skip_heal_gaps=skip_heal_gaps,
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
# --- Non-streaming path (grid_method=v2) ---
return await _v2_path(
session_id, cached, col_regions, row_geoms,
dewarped_bgr, engine, pronunciation, skip_heal_gaps,
)
async def _words_first_path(
session_id: str,
cached: Dict[str, Any],
dewarped_bgr: np.ndarray,
engine: str,
pronunciation: str,
zones: list,
) -> dict:
"""Words-first grid construction path."""
t0 = time.time()
img_h, img_w = dewarped_bgr.shape[:2]
if engine == "paddle":
from cv_ocr_engines import ocr_region_paddle
wf_word_dicts = await ocr_region_paddle(dewarped_bgr, region=None)
cached["_paddle_word_dicts"] = wf_word_dicts
else:
wf_word_dicts = cached.get("_word_dicts")
if wf_word_dicts is None:
ocr_img_tmp = create_ocr_image(dewarped_bgr)
geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr)
if geo_result is not None:
_geoms, left_x, right_x, top_y, bottom_y, wf_word_dicts, inv = geo_result
cached["_word_dicts"] = wf_word_dicts
cached["_inv"] = inv
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
if not wf_word_dicts:
raise HTTPException(status_code=400, detail="No words detected -- cannot build words-first grid")
# Convert word coordinates to absolute if needed
if engine != "paddle":
content_bounds = cached.get("_content_bounds")
if content_bounds:
lx, _rx, ty, _by = content_bounds
abs_words = []
for w in wf_word_dicts:
abs_words.append({**w, 'left': w['left'] + lx, 'top': w['top'] + ty})
wf_word_dicts = abs_words
box_rects = []
for zone in zones:
if zone.get("zone_type") == "box" and zone.get("box"):
box_rects.append(zone["box"])
cells, columns_meta = build_grid_from_words(
wf_word_dicts, img_w, img_h, box_rects=box_rects or None,
)
duration = time.time() - t0
fix_cell_phonetics(cells, pronunciation=pronunciation)
for cell in cells:
cell.setdefault("zone_index", 0)
col_types = {c['type'] for c in columns_meta}
is_vocab = bool(col_types & {'column_en', 'column_de'})
n_rows = len(set(c['row_index'] for c in cells)) if cells else 0
n_cols = len(columns_meta)
used_engine = "paddle" if engine == "paddle" else "words_first"
word_result = {
"cells": cells,
"grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
"columns_used": columns_meta,
"layout": "vocab" if is_vocab else "generic",
"image_width": img_w,
"image_height": img_h,
"duration_seconds": round(duration, 2),
"ocr_engine": used_engine,
"grid_method": "words_first",
"summary": {
"total_cells": len(cells),
"non_empty_cells": sum(1 for c in cells if c.get("text")),
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
},
}
if is_vocab or 'column_text' in col_types:
entries = _cells_to_vocab_entries(cells, columns_meta)
entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
word_result["vocab_entries"] = entries
word_result["entries"] = entries
word_result["entry_count"] = len(entries)
word_result["summary"]["total_entries"] = len(entries)
word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
await update_session_db(session_id, word_result=word_result, current_step=8)
cached["word_result"] = word_result
logger.info(f"OCR Pipeline: words-first session {session_id}: "
f"{len(cells)} cells ({duration:.2f}s), {n_rows} rows, {n_cols} cols")
await _append_pipeline_log(session_id, "words", {
"grid_method": "words_first",
"total_cells": len(cells),
"non_empty_cells": word_result["summary"]["non_empty_cells"],
"ocr_engine": used_engine,
"layout": word_result["layout"],
}, duration_ms=int(duration * 1000))
return {"session_id": session_id, **word_result}
async def _v2_path(
session_id: str,
cached: Dict[str, Any],
col_regions: List[PageRegion],
row_geoms: List[RowGeometry],
dewarped_bgr: np.ndarray,
engine: str,
pronunciation: str,
skip_heal_gaps: bool,
) -> dict:
"""Cell-First OCR v2 non-streaming path."""
t0 = time.time()
ocr_img = create_ocr_image(dewarped_bgr)
img_h, img_w = dewarped_bgr.shape[:2]
cells, columns_meta = build_cell_grid_v2(
ocr_img, col_regions, row_geoms, img_w, img_h,
ocr_engine=engine, img_bgr=dewarped_bgr,
skip_heal_gaps=skip_heal_gaps,
)
duration = time.time() - t0
for cell in cells:
cell.setdefault("zone_index", 0)
col_types = {c['type'] for c in columns_meta}
is_vocab = bool(col_types & {'column_en', 'column_de'})
n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
n_cols = len(columns_meta)
used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine
fix_cell_phonetics(cells, pronunciation=pronunciation)
word_result = {
"cells": cells,
"grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(cells)},
"columns_used": columns_meta,
"layout": "vocab" if is_vocab else "generic",
"image_width": img_w,
"image_height": img_h,
"duration_seconds": round(duration, 2),
"ocr_engine": used_engine,
"summary": {
"total_cells": len(cells),
"non_empty_cells": sum(1 for c in cells if c.get("text")),
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
},
}
has_text_col = 'column_text' in col_types
if is_vocab or has_text_col:
entries = _cells_to_vocab_entries(cells, columns_meta)
entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
word_result["vocab_entries"] = entries
word_result["entries"] = entries
word_result["entry_count"] = len(entries)
word_result["summary"]["total_entries"] = len(entries)
word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
await update_session_db(session_id, word_result=word_result, current_step=8)
cached["word_result"] = word_result
logger.info(f"OCR Pipeline: words session {session_id}: "
f"layout={word_result['layout']}, "
f"{len(cells)} cells ({duration:.2f}s), summary: {word_result['summary']}")
await _append_pipeline_log(session_id, "words", {
"total_cells": len(cells),
"non_empty_cells": word_result["summary"]["non_empty_cells"],
"low_confidence_count": word_result["summary"]["low_confidence"],
"ocr_engine": used_engine,
"layout": word_result["layout"],
"entry_count": word_result.get("entry_count", 0),
}, duration_ms=int(duration * 1000))
return {"session_id": session_id, **word_result}
# Backward-compat shim -- module moved to ocr/pipeline/words_detect.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.words_detect")

View File

@@ -1,303 +1,4 @@
"""
OCR Pipeline Words Stream — SSE streaming generators for word detection.
Extracted from ocr_pipeline_words.py.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import time
from typing import Any, Dict, List
import numpy as np
from fastapi import Request
from cv_vocab_pipeline import (
PageRegion,
RowGeometry,
_cells_to_vocab_entries,
_fix_character_confusion,
_fix_phonetic_brackets,
fix_cell_phonetics,
build_cell_grid_v2,
build_cell_grid_v2_streaming,
create_ocr_image,
)
from ocr_pipeline_session_store import update_session_db
from ocr_pipeline_common import _cache
logger = logging.getLogger(__name__)
async def _word_batch_stream_generator(
session_id: str,
cached: Dict[str, Any],
col_regions: List[PageRegion],
row_geoms: List[RowGeometry],
dewarped_bgr: np.ndarray,
engine: str,
pronunciation: str,
request: Request,
skip_heal_gaps: bool = False,
):
"""SSE generator that runs batch OCR (parallel) then streams results.
Uses build_cell_grid_v2 with ThreadPoolExecutor for parallel OCR,
then emits all cells as SSE events.
"""
import asyncio
t0 = time.time()
ocr_img = create_ocr_image(dewarped_bgr)
img_h, img_w = dewarped_bgr.shape[:2]
_skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
n_cols = len([c for c in col_regions if c.type not in _skip_types])
col_types = {c.type for c in col_regions if c.type not in _skip_types}
is_vocab = bool(col_types & {'column_en', 'column_de'})
total_cells = n_content_rows * n_cols
# 1. Send meta event immediately
meta_event = {
"type": "meta",
"grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
"layout": "vocab" if is_vocab else "generic",
}
yield f"data: {json.dumps(meta_event)}\n\n"
# 2. Send preparing event (keepalive for proxy)
yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR laeuft parallel...'})}\n\n"
# 3. Run batch OCR in thread pool with periodic keepalive events.
loop = asyncio.get_event_loop()
ocr_future = loop.run_in_executor(
None,
lambda: build_cell_grid_v2(
ocr_img, col_regions, row_geoms, img_w, img_h,
ocr_engine=engine, img_bgr=dewarped_bgr,
skip_heal_gaps=skip_heal_gaps,
),
)
# Send keepalive events every 5 seconds while OCR runs
keepalive_count = 0
while not ocr_future.done():
try:
cells, columns_meta = await asyncio.wait_for(
asyncio.shield(ocr_future), timeout=5.0,
)
break # OCR finished
except asyncio.TimeoutError:
keepalive_count += 1
elapsed = int(time.time() - t0)
yield f"data: {json.dumps({'type': 'keepalive', 'elapsed': elapsed, 'message': f'OCR laeuft... ({elapsed}s)'})}\n\n"
if await request.is_disconnected():
logger.info(f"SSE batch: client disconnected during OCR for {session_id}")
ocr_future.cancel()
return
else:
cells, columns_meta = ocr_future.result()
if await request.is_disconnected():
logger.info(f"SSE batch: client disconnected after OCR for {session_id}")
return
# 4. Apply IPA phonetic fixes
fix_cell_phonetics(cells, pronunciation=pronunciation)
# 5. Send columns meta
if columns_meta:
yield f"data: {json.dumps({'type': 'columns', 'columns_used': columns_meta})}\n\n"
# 6. Stream all cells
for idx, cell in enumerate(cells):
cell_event = {
"type": "cell",
"cell": cell,
"progress": {"current": idx + 1, "total": len(cells)},
}
yield f"data: {json.dumps(cell_event)}\n\n"
# 7. Build final result and persist
duration = time.time() - t0
used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine
word_result = {
"cells": cells,
"grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(cells)},
"columns_used": columns_meta,
"layout": "vocab" if is_vocab else "generic",
"image_width": img_w,
"image_height": img_h,
"duration_seconds": round(duration, 2),
"ocr_engine": used_engine,
"summary": {
"total_cells": len(cells),
"non_empty_cells": sum(1 for c in cells if c.get("text")),
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
},
}
vocab_entries = None
has_text_col = 'column_text' in col_types
if is_vocab or has_text_col:
entries = _cells_to_vocab_entries(cells, columns_meta)
entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
word_result["vocab_entries"] = entries
word_result["entries"] = entries
word_result["entry_count"] = len(entries)
word_result["summary"]["total_entries"] = len(entries)
word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
vocab_entries = entries
await update_session_db(session_id, word_result=word_result, current_step=8)
cached["word_result"] = word_result
logger.info(f"OCR Pipeline SSE batch: words session {session_id}: "
f"layout={word_result['layout']}, {len(cells)} cells ({duration:.2f}s)")
# 8. Send complete event
complete_event = {
"type": "complete",
"summary": word_result["summary"],
"duration_seconds": round(duration, 2),
"ocr_engine": used_engine,
}
if vocab_entries is not None:
complete_event["vocab_entries"] = vocab_entries
yield f"data: {json.dumps(complete_event)}\n\n"
async def _word_stream_generator(
session_id: str,
cached: Dict[str, Any],
col_regions: List[PageRegion],
row_geoms: List[RowGeometry],
dewarped_bgr: np.ndarray,
engine: str,
pronunciation: str,
request: Request,
):
"""SSE generator that yields cell-by-cell OCR progress."""
t0 = time.time()
ocr_img = create_ocr_image(dewarped_bgr)
img_h, img_w = dewarped_bgr.shape[:2]
n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
_skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
n_cols = len([c for c in col_regions if c.type not in _skip_types])
col_types = {c.type for c in col_regions if c.type not in _skip_types}
is_vocab = bool(col_types & {'column_en', 'column_de'})
columns_meta = None
total_cells = n_content_rows * n_cols
meta_event = {
"type": "meta",
"grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
"layout": "vocab" if is_vocab else "generic",
}
yield f"data: {json.dumps(meta_event)}\n\n"
yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR wird initialisiert...'})}\n\n"
all_cells: List[Dict[str, Any]] = []
cell_idx = 0
last_keepalive = time.time()
for cell, cols_meta, total in build_cell_grid_v2_streaming(
ocr_img, col_regions, row_geoms, img_w, img_h,
ocr_engine=engine, img_bgr=dewarped_bgr,
):
if await request.is_disconnected():
logger.info(f"SSE: client disconnected during streaming for {session_id}")
return
if columns_meta is None:
columns_meta = cols_meta
meta_update = {"type": "columns", "columns_used": cols_meta}
yield f"data: {json.dumps(meta_update)}\n\n"
all_cells.append(cell)
cell_idx += 1
cell_event = {
"type": "cell",
"cell": cell,
"progress": {"current": cell_idx, "total": total},
}
yield f"data: {json.dumps(cell_event)}\n\n"
# All cells done
duration = time.time() - t0
if columns_meta is None:
columns_meta = []
# Remove all-empty rows
rows_with_text: set = set()
for c in all_cells:
if c.get("text", "").strip():
rows_with_text.add(c["row_index"])
before_filter = len(all_cells)
all_cells = [c for c in all_cells if c["row_index"] in rows_with_text]
empty_rows_removed = (before_filter - len(all_cells)) // max(n_cols, 1)
if empty_rows_removed > 0:
logger.info(f"SSE: removed {empty_rows_removed} all-empty rows after OCR")
used_engine = all_cells[0].get("ocr_engine", "tesseract") if all_cells else engine
fix_cell_phonetics(all_cells, pronunciation=pronunciation)
word_result = {
"cells": all_cells,
"grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(all_cells)},
"columns_used": columns_meta,
"layout": "vocab" if is_vocab else "generic",
"image_width": img_w,
"image_height": img_h,
"duration_seconds": round(duration, 2),
"ocr_engine": used_engine,
"summary": {
"total_cells": len(all_cells),
"non_empty_cells": sum(1 for c in all_cells if c.get("text")),
"low_confidence": sum(1 for c in all_cells if 0 < c.get("confidence", 0) < 50),
},
}
vocab_entries = None
has_text_col = 'column_text' in col_types
if is_vocab or has_text_col:
entries = _cells_to_vocab_entries(all_cells, columns_meta)
entries = _fix_character_confusion(entries)
entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
word_result["vocab_entries"] = entries
word_result["entries"] = entries
word_result["entry_count"] = len(entries)
word_result["summary"]["total_entries"] = len(entries)
word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
vocab_entries = entries
await update_session_db(session_id, word_result=word_result, current_step=8)
cached["word_result"] = word_result
logger.info(f"OCR Pipeline SSE: words session {session_id}: "
f"layout={word_result['layout']}, "
f"{len(all_cells)} cells ({duration:.2f}s)")
complete_event = {
"type": "complete",
"summary": word_result["summary"],
"duration_seconds": round(duration, 2),
"ocr_engine": used_engine,
}
if vocab_entries is not None:
complete_event["vocab_entries"] = vocab_entries
yield f"data: {json.dumps(complete_event)}\n\n"
# Backward-compat shim -- module moved to ocr/pipeline/words_stream.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.words_stream")

View File

@@ -1,188 +1,4 @@
"""
Orientation & Page-Split API endpoints (Steps 1 and 1b of OCR Pipeline).
"""
import logging
import time
from typing import Any, Dict
import cv2
from fastapi import APIRouter, HTTPException
from cv_vocab_pipeline import detect_and_fix_orientation
from page_crop import detect_page_splits
from ocr_pipeline_session_store import update_session_db
from orientation_crop_helpers import ensure_cached, append_pipeline_log
from page_sub_sessions import create_page_sub_sessions_full
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Step 1: Orientation
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/orientation")
async def detect_orientation(session_id: str):
"""Detect and fix 90/180/270 degree rotations from scanners.
Reads the original image, applies orientation correction,
stores the result as oriented_png.
"""
cached = await ensure_cached(session_id)
img_bgr = cached.get("original_bgr")
if img_bgr is None:
raise HTTPException(status_code=400, detail="Original image not available")
t0 = time.time()
# Detect and fix orientation
oriented_bgr, orientation_deg = detect_and_fix_orientation(img_bgr.copy())
duration = time.time() - t0
orientation_result = {
"orientation_degrees": orientation_deg,
"corrected": orientation_deg != 0,
"duration_seconds": round(duration, 2),
}
# Encode oriented image
success, png_buf = cv2.imencode(".png", oriented_bgr)
oriented_png = png_buf.tobytes() if success else b""
# Update cache
cached["oriented_bgr"] = oriented_bgr
cached["orientation_result"] = orientation_result
# Persist to DB
await update_session_db(
session_id,
oriented_png=oriented_png,
orientation_result=orientation_result,
current_step=2,
)
logger.info(
"OCR Pipeline: orientation session %s: %d° (%s) in %.2fs",
session_id, orientation_deg,
"corrected" if orientation_deg else "no change",
duration,
)
await append_pipeline_log(session_id, "orientation", {
"orientation_degrees": orientation_deg,
"corrected": orientation_deg != 0,
}, duration_ms=int(duration * 1000))
h, w = oriented_bgr.shape[:2]
return {
"session_id": session_id,
**orientation_result,
"image_width": w,
"image_height": h,
"oriented_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/oriented",
}
# ---------------------------------------------------------------------------
# Step 1b: Page-split detection — runs AFTER orientation, BEFORE deskew
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/page-split")
async def detect_page_split(session_id: str):
"""Detect if the image is a double-page book spread and split into sub-sessions.
Must be called **after orientation** (step 1) and **before deskew** (step 2).
Each sub-session receives the raw page region and goes through the full
pipeline (deskew -> dewarp -> crop -> columns -> rows -> words -> grid)
independently, so each page gets its own deskew correction.
Returns ``{"multi_page": false}`` if only one page is detected.
"""
cached = await ensure_cached(session_id)
# Use oriented (preferred), fall back to original
img_bgr = next(
(v for k in ("oriented_bgr", "original_bgr")
if (v := cached.get(k)) is not None),
None,
)
if img_bgr is None:
raise HTTPException(status_code=400, detail="No image available for page-split detection")
t0 = time.time()
page_splits = detect_page_splits(img_bgr)
used_original = False
if not page_splits or len(page_splits) < 2:
# Orientation may have rotated a landscape double-page spread to
# portrait. Try the original (pre-orientation) image as fallback.
orig_bgr = cached.get("original_bgr")
if orig_bgr is not None and orig_bgr is not img_bgr:
page_splits_orig = detect_page_splits(orig_bgr)
if page_splits_orig and len(page_splits_orig) >= 2:
logger.info(
"OCR Pipeline: page-split session %s: spread detected on "
"ORIGINAL (orientation rotated it away)",
session_id,
)
img_bgr = orig_bgr
page_splits = page_splits_orig
used_original = True
if not page_splits or len(page_splits) < 2:
duration = time.time() - t0
logger.info(
"OCR Pipeline: page-split session %s: single page (%.2fs)",
session_id, duration,
)
return {
"session_id": session_id,
"multi_page": False,
"duration_seconds": round(duration, 2),
}
# Multi-page spread detected — create sub-sessions for full pipeline.
# start_step=2 means "ready for deskew" (orientation already applied).
# start_step=1 means "needs orientation too" (split from original image).
start_step = 1 if used_original else 2
sub_sessions = await create_page_sub_sessions_full(
session_id, cached, img_bgr, page_splits, start_step=start_step,
)
duration = time.time() - t0
split_info: Dict[str, Any] = {
"multi_page": True,
"page_count": len(page_splits),
"page_splits": page_splits,
"used_original": used_original,
"duration_seconds": round(duration, 2),
}
# Mark parent session as split and hidden from session list
await update_session_db(session_id, crop_result=split_info, status='split')
cached["crop_result"] = split_info
await append_pipeline_log(session_id, "page_split", {
"multi_page": True,
"page_count": len(page_splits),
}, duration_ms=int(duration * 1000))
logger.info(
"OCR Pipeline: page-split session %s: %d pages detected in %.2fs",
session_id, len(page_splits), duration,
)
h, w = img_bgr.shape[:2]
return {
"session_id": session_id,
**split_info,
"image_width": w,
"image_height": h,
"sub_sessions": sub_sessions,
}
# Backward-compat shim -- module moved to ocr/pipeline/orientation_api.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.orientation_api")

View File

@@ -1,16 +1,4 @@
"""
Orientation & Crop API - Steps 1 and 4 of the OCR Pipeline.
Barrel re-export: merges routers from orientation_api and crop_api,
and re-exports set_cache_ref for main.py.
"""
from fastapi import APIRouter
from orientation_crop_helpers import set_cache_ref # noqa: F401
from orientation_api import router as _orientation_router
from crop_api import router as _crop_router
router = APIRouter()
router.include_router(_orientation_router)
router.include_router(_crop_router)
# Backward-compat shim -- module moved to ocr/pipeline/orientation_crop_api.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.orientation_crop_api")

View File

@@ -1,86 +1,4 @@
"""
Orientation & Crop shared helpers - cache management and pipeline logging.
"""
import logging
from typing import Any, Dict
import cv2
import numpy as np
from fastapi import HTTPException
from ocr_pipeline_session_store import (
get_session_db,
get_session_image,
update_session_db,
)
logger = logging.getLogger(__name__)
# Reference to the shared cache from ocr_pipeline_api (set in main.py)
_cache: Dict[str, Dict[str, Any]] = {}
def set_cache_ref(cache: Dict[str, Dict[str, Any]]):
"""Set reference to the shared cache from ocr_pipeline_api."""
global _cache
_cache = cache
def get_cache_ref() -> Dict[str, Dict[str, Any]]:
"""Get reference to the shared cache."""
return _cache
async def ensure_cached(session_id: str) -> Dict[str, Any]:
"""Ensure session is in cache, loading from DB if needed."""
if session_id in _cache:
return _cache[session_id]
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
cache_entry: Dict[str, Any] = {
"id": session_id,
**session,
"original_bgr": None,
"oriented_bgr": None,
"cropped_bgr": None,
"deskewed_bgr": None,
"dewarped_bgr": None,
}
for img_type, bgr_key in [
("original", "original_bgr"),
("oriented", "oriented_bgr"),
("cropped", "cropped_bgr"),
("deskewed", "deskewed_bgr"),
("dewarped", "dewarped_bgr"),
]:
png_data = await get_session_image(session_id, img_type)
if png_data:
arr = np.frombuffer(png_data, dtype=np.uint8)
bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR)
cache_entry[bgr_key] = bgr
_cache[session_id] = cache_entry
return cache_entry
async def append_pipeline_log(session_id: str, step: str, metrics: dict, duration_ms: int):
"""Append a step entry to the pipeline log."""
from datetime import datetime
session = await get_session_db(session_id)
if not session:
return
pipeline_log = session.get("pipeline_log") or {"steps": []}
pipeline_log["steps"].append({
"step": step,
"completed_at": datetime.utcnow().isoformat(),
"success": True,
"duration_ms": duration_ms,
"metrics": metrics,
})
await update_session_db(session_id, pipeline_log=pipeline_log)
# Backward-compat shim -- module moved to ocr/pipeline/orientation_crop_helpers.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.orientation_crop_helpers")

View File

@@ -1,33 +1,4 @@
"""
Page Crop — Barrel Re-export
Content-based crop for scanned pages and book scans.
Split into:
- page_crop_edges.py — Edge detection (spine shadow, gutter, projection)
- page_crop_core.py — Main crop algorithm and format detection
All public names are re-exported here for backward compatibility.
License: Apache 2.0
"""
# Core: main crop functions and format detection
from page_crop_core import ( # noqa: F401
PAPER_FORMATS,
detect_page_splits,
detect_and_crop_page,
_detect_format,
)
# Edge detection helpers
from page_crop_edges import ( # noqa: F401
_INK_THRESHOLD,
_MIN_RUN_FRAC,
_detect_spine_shadow,
_detect_gutter_continuity,
_detect_left_edge_shadow,
_detect_right_edge_shadow,
_detect_top_bottom_edges,
_detect_edge_projection,
_filter_narrow_runs,
)
# Backward-compat shim -- module moved to ocr/pipeline/page_crop.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.page_crop")

View File

@@ -1,342 +1,4 @@
"""
Page Crop - Core Crop and Format Detection
Content-based crop for scanned pages and book scans. Detects the content
boundary by analysing ink density projections and (for book scans) the
spine shadow gradient.
Extracted from page_crop.py to keep files under 500 LOC.
License: Apache 2.0
"""
import logging
from typing import Dict, Any, Tuple
import cv2
import numpy as np
from page_crop_edges import (
_detect_left_edge_shadow,
_detect_right_edge_shadow,
_detect_top_bottom_edges,
)
logger = logging.getLogger(__name__)
# Known paper format aspect ratios (height / width, portrait orientation)
PAPER_FORMATS = {
"A4": 297.0 / 210.0, # 1.4143
"A5": 210.0 / 148.0, # 1.4189
"Letter": 11.0 / 8.5, # 1.2941
"Legal": 14.0 / 8.5, # 1.6471
"A3": 420.0 / 297.0, # 1.4141
}
def detect_page_splits(
img_bgr: np.ndarray,
) -> list:
"""Detect if the image is a multi-page spread and return split rectangles.
Uses **brightness** (not ink density) to find the spine area:
the scanner bed produces a characteristic gray strip where pages meet,
which is darker than the white paper on either side.
Returns a list of page dicts ``{x, y, width, height, page_index}``
or an empty list if only one page is detected.
"""
h, w = img_bgr.shape[:2]
# Only check landscape-ish images (width > height * 1.15)
if w < h * 1.15:
return []
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
# Column-mean brightness (0-255) — the spine is darker (gray scanner bed)
col_brightness = np.mean(gray, axis=0).astype(np.float64)
# Heavy smoothing to ignore individual text lines
kern = max(11, w // 50)
if kern % 2 == 0:
kern += 1
brightness_smooth = np.convolve(col_brightness, np.ones(kern) / kern, mode="same")
# Page paper is bright (typically > 200), spine/scanner bed is darker
page_brightness = float(np.max(brightness_smooth))
if page_brightness < 100:
return [] # Very dark image, skip
# Spine threshold: significantly darker than the page
spine_thresh = page_brightness * 0.88
# Search in center region (30-70% of width)
center_lo = int(w * 0.30)
center_hi = int(w * 0.70)
# Find the darkest valley in the center region
center_brightness = brightness_smooth[center_lo:center_hi]
darkest_val = float(np.min(center_brightness))
if darkest_val >= spine_thresh:
logger.debug("No spine detected: min brightness %.0f >= threshold %.0f",
darkest_val, spine_thresh)
return []
# Find ALL contiguous dark runs in the center region
is_dark = center_brightness < spine_thresh
dark_runs: list = []
run_start = -1
for i in range(len(is_dark)):
if is_dark[i]:
if run_start < 0:
run_start = i
else:
if run_start >= 0:
dark_runs.append((run_start, i))
run_start = -1
if run_start >= 0:
dark_runs.append((run_start, len(is_dark)))
# Filter out runs that are too narrow (< 1% of image width)
min_spine_px = int(w * 0.01)
dark_runs = [(s, e) for s, e in dark_runs if e - s >= min_spine_px]
if not dark_runs:
logger.debug("No dark runs wider than %dpx in center region", min_spine_px)
return []
# Score each dark run: prefer centered, dark, narrow valleys
center_region_len = center_hi - center_lo
image_center_in_region = (w * 0.5 - center_lo)
best_score = -1.0
best_start, best_end = dark_runs[0]
for rs, re in dark_runs:
run_width = re - rs
run_center = (rs + re) / 2.0
sigma = center_region_len * 0.15
dist = abs(run_center - image_center_in_region)
center_factor = float(np.exp(-0.5 * (dist / sigma) ** 2))
run_brightness = float(np.mean(center_brightness[rs:re]))
darkness_factor = max(0.0, (spine_thresh - run_brightness) / spine_thresh)
width_frac = run_width / w
if width_frac <= 0.05:
narrowness_bonus = 1.0
elif width_frac <= 0.15:
narrowness_bonus = 1.0 - (width_frac - 0.05) / 0.10
else:
narrowness_bonus = 0.0
score = center_factor * darkness_factor * (0.3 + 0.7 * narrowness_bonus)
logger.debug(
"Dark run x=%d..%d (w=%d): center_f=%.3f dark_f=%.3f narrow_b=%.3f -> score=%.4f",
center_lo + rs, center_lo + re, run_width,
center_factor, darkness_factor, narrowness_bonus, score,
)
if score > best_score:
best_score = score
best_start, best_end = rs, re
spine_w = best_end - best_start
spine_x = center_lo + best_start
spine_center = spine_x + spine_w // 2
logger.debug(
"Best spine candidate: x=%d..%d (w=%d), score=%.4f",
spine_x, spine_x + spine_w, spine_w, best_score,
)
# Verify: must have bright (paper) content on BOTH sides
left_brightness = float(np.mean(brightness_smooth[max(0, spine_x - w // 10):spine_x]))
right_end = center_lo + best_end
right_brightness = float(np.mean(brightness_smooth[right_end:min(w, right_end + w // 10)]))
if left_brightness < spine_thresh or right_brightness < spine_thresh:
logger.debug("No bright paper flanking spine: left=%.0f right=%.0f thresh=%.0f",
left_brightness, right_brightness, spine_thresh)
return []
logger.info(
"Spine detected: x=%d..%d (w=%d), brightness=%.0f vs paper=%.0f, "
"left_paper=%.0f, right_paper=%.0f",
spine_x, right_end, spine_w, darkest_val, page_brightness,
left_brightness, right_brightness,
)
# Split at the spine center
split_points = [spine_center]
# Build page rectangles
pages: list = []
prev_x = 0
for i, sx in enumerate(split_points):
pages.append({"x": prev_x, "y": 0, "width": sx - prev_x,
"height": h, "page_index": i})
prev_x = sx
pages.append({"x": prev_x, "y": 0, "width": w - prev_x,
"height": h, "page_index": len(split_points)})
# Filter out tiny pages (< 15% of total width)
pages = [p for p in pages if p["width"] >= w * 0.15]
if len(pages) < 2:
return []
# Re-index
for i, p in enumerate(pages):
p["page_index"] = i
logger.info(
"Page split detected: %d pages, spine_w=%d, split_points=%s",
len(pages), spine_w, split_points,
)
return pages
def detect_and_crop_page(
img_bgr: np.ndarray,
margin_frac: float = 0.01,
) -> Tuple[np.ndarray, Dict[str, Any]]:
"""Detect content boundary and crop scanner/book borders.
Algorithm (4-edge detection):
1. Adaptive threshold -> binary (text=255, bg=0)
2. Left edge: spine-shadow detection via grayscale column means,
fallback to binary vertical projection
3. Right edge: binary vertical projection (last ink column)
4. Top/bottom edges: binary horizontal projection
5. Sanity checks, then crop with configurable margin
Args:
img_bgr: Input BGR image (should already be deskewed/dewarped)
margin_frac: Extra margin around content (fraction of dimension, default 1%)
Returns:
Tuple of (cropped_image, result_dict)
"""
h, w = img_bgr.shape[:2]
total_area = h * w
result: Dict[str, Any] = {
"crop_applied": False,
"crop_rect": None,
"crop_rect_pct": None,
"original_size": {"width": w, "height": h},
"cropped_size": {"width": w, "height": h},
"detected_format": None,
"format_confidence": 0.0,
"aspect_ratio": round(max(h, w) / max(min(h, w), 1), 4),
"border_fractions": {"top": 0.0, "bottom": 0.0, "left": 0.0, "right": 0.0},
}
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
# --- Binarise with adaptive threshold ---
binary = cv2.adaptiveThreshold(
gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
cv2.THRESH_BINARY_INV, blockSize=51, C=15,
)
# --- Edge detection ---
left_edge = _detect_left_edge_shadow(gray, binary, w, h)
right_edge = _detect_right_edge_shadow(gray, binary, w, h)
top_edge, bottom_edge = _detect_top_bottom_edges(binary, w, h)
# Compute border fractions
border_top = top_edge / h
border_bottom = (h - bottom_edge) / h
border_left = left_edge / w
border_right = (w - right_edge) / w
result["border_fractions"] = {
"top": round(border_top, 4),
"bottom": round(border_bottom, 4),
"left": round(border_left, 4),
"right": round(border_right, 4),
}
# Sanity: only crop if at least one edge has > 2% border
min_border = 0.02
if all(f < min_border for f in [border_top, border_bottom, border_left, border_right]):
logger.info("All borders < %.0f%% — no crop needed", min_border * 100)
result["detected_format"], result["format_confidence"] = _detect_format(w, h)
return img_bgr, result
# Add margin
margin_x = int(w * margin_frac)
margin_y = int(h * margin_frac)
crop_x = max(0, left_edge - margin_x)
crop_y = max(0, top_edge - margin_y)
crop_x2 = min(w, right_edge + margin_x)
crop_y2 = min(h, bottom_edge + margin_y)
crop_w = crop_x2 - crop_x
crop_h = crop_y2 - crop_y
# Sanity: cropped area must be >= 40% of original
if crop_w * crop_h < 0.40 * total_area:
logger.warning("Cropped area too small (%.0f%%) — skipping crop",
100.0 * crop_w * crop_h / total_area)
result["detected_format"], result["format_confidence"] = _detect_format(w, h)
return img_bgr, result
cropped = img_bgr[crop_y:crop_y2, crop_x:crop_x2].copy()
detected_format, format_confidence = _detect_format(crop_w, crop_h)
result["crop_applied"] = True
result["crop_rect"] = {"x": crop_x, "y": crop_y, "width": crop_w, "height": crop_h}
result["crop_rect_pct"] = {
"x": round(100.0 * crop_x / w, 2),
"y": round(100.0 * crop_y / h, 2),
"width": round(100.0 * crop_w / w, 2),
"height": round(100.0 * crop_h / h, 2),
}
result["cropped_size"] = {"width": crop_w, "height": crop_h}
result["detected_format"] = detected_format
result["format_confidence"] = format_confidence
result["aspect_ratio"] = round(max(crop_w, crop_h) / max(min(crop_w, crop_h), 1), 4)
logger.info(
"Page cropped: %dx%d -> %dx%d, format=%s (%.0f%%), "
"borders: T=%.1f%% B=%.1f%% L=%.1f%% R=%.1f%%",
w, h, crop_w, crop_h, detected_format, format_confidence * 100,
border_top * 100, border_bottom * 100,
border_left * 100, border_right * 100,
)
return cropped, result
# ---------------------------------------------------------------------------
# Format detection (kept as optional metadata)
# ---------------------------------------------------------------------------
def _detect_format(width: int, height: int) -> Tuple[str, float]:
"""Detect paper format from dimensions by comparing aspect ratios."""
if width <= 0 or height <= 0:
return "unknown", 0.0
aspect = max(width, height) / min(width, height)
best_format = "unknown"
best_diff = float("inf")
for fmt, expected_ratio in PAPER_FORMATS.items():
diff = abs(aspect - expected_ratio)
if diff < best_diff:
best_diff = diff
best_format = fmt
confidence = max(0.0, 1.0 - best_diff * 5.0)
if confidence < 0.3:
return "unknown", 0.0
return best_format, round(confidence, 3)
# Backward-compat shim -- module moved to ocr/pipeline/page_crop_core.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.page_crop_core")

View File

@@ -1,388 +1,4 @@
"""
Page Crop - Edge Detection Helpers
Spine shadow detection, gutter continuity analysis, projection-based
edge detection, and narrow-run filtering for content cropping.
Extracted from page_crop.py to keep files under 500 LOC.
License: Apache 2.0
"""
import logging
from typing import Optional, Tuple
import cv2
import numpy as np
logger = logging.getLogger(__name__)
# Minimum ink density (fraction of pixels) to count a row/column as "content"
_INK_THRESHOLD = 0.003 # 0.3%
# Minimum run length (fraction of dimension) to keep — shorter runs are noise
_MIN_RUN_FRAC = 0.005 # 0.5%
def _detect_spine_shadow(
gray: np.ndarray,
search_region: np.ndarray,
offset_x: int,
w: int,
side: str,
) -> Optional[int]:
"""Find the book spine center (darkest point) in a scanner shadow.
The scanner produces a gray strip where the book spine presses against
the glass. The darkest column in that strip is the spine center —
that's where we crop.
Distinguishes real spine shadows from text content by checking:
1. Strong brightness range (> 40 levels)
2. Darkest point is genuinely dark (< 180 mean brightness)
3. The dark area is a NARROW valley, not a text-content plateau
4. Brightness rises significantly toward the page content side
Args:
gray: Full grayscale image (for context).
search_region: Column slice of the grayscale image to search in.
offset_x: X offset of search_region relative to full image.
w: Full image width.
side: 'left' or 'right' (for logging).
Returns:
X coordinate (in full image) of the spine center, or None.
"""
region_w = search_region.shape[1]
if region_w < 10:
return None
# Column-mean brightness in the search region
col_means = np.mean(search_region, axis=0).astype(np.float64)
# Smooth with boxcar kernel (width = 1% of image width, min 5)
kernel_size = max(5, w // 100)
if kernel_size % 2 == 0:
kernel_size += 1
kernel = np.ones(kernel_size) / kernel_size
smoothed_raw = np.convolve(col_means, kernel, mode="same")
# Trim convolution edge artifacts (edges are zero-padded -> artificially low)
margin = kernel_size // 2
if region_w <= 2 * margin + 10:
return None
smoothed = smoothed_raw[margin:region_w - margin]
trim_offset = margin # offset of smoothed[0] relative to search_region
val_min = float(np.min(smoothed))
val_max = float(np.max(smoothed))
shadow_range = val_max - val_min
# --- Check 1: Strong brightness gradient ---
if shadow_range <= 40:
logger.debug(
"%s edge: no spine (range=%.0f <= 40)", side.capitalize(), shadow_range,
)
return None
# --- Check 2: Darkest point must be genuinely dark ---
if val_min > 180:
logger.debug(
"%s edge: no spine (darkest=%.0f > 180, likely text)", side.capitalize(), val_min,
)
return None
spine_idx = int(np.argmin(smoothed)) # index in trimmed array
spine_local = spine_idx + trim_offset # index in search_region
trimmed_len = len(smoothed)
# --- Check 3: Valley width (spine is narrow, text plateau is wide) ---
valley_thresh = val_min + shadow_range * 0.20
valley_mask = smoothed < valley_thresh
valley_width = int(np.sum(valley_mask))
max_valley_frac = 0.50
if valley_width > trimmed_len * max_valley_frac:
logger.debug(
"%s edge: no spine (valley too wide: %d/%d = %.0f%%)",
side.capitalize(), valley_width, trimmed_len,
100.0 * valley_width / trimmed_len,
)
return None
# --- Check 4: Brightness must rise toward page content ---
rise_check_w = max(5, trimmed_len // 5)
if side == "left":
right_start = min(spine_idx + 5, trimmed_len - 1)
right_end = min(right_start + rise_check_w, trimmed_len)
if right_end > right_start:
rise_brightness = float(np.mean(smoothed[right_start:right_end]))
rise = rise_brightness - val_min
if rise < shadow_range * 0.3:
logger.debug(
"%s edge: no spine (insufficient rise: %.0f, need %.0f)",
side.capitalize(), rise, shadow_range * 0.3,
)
return None
else: # right
left_end = max(spine_idx - 5, 0)
left_start = max(left_end - rise_check_w, 0)
if left_end > left_start:
rise_brightness = float(np.mean(smoothed[left_start:left_end]))
rise = rise_brightness - val_min
if rise < shadow_range * 0.3:
logger.debug(
"%s edge: no spine (insufficient rise: %.0f, need %.0f)",
side.capitalize(), rise, shadow_range * 0.3,
)
return None
spine_x = offset_x + spine_local
logger.info(
"%s edge: spine center at x=%d (brightness=%.0f, range=%.0f, valley=%dpx)",
side.capitalize(), spine_x, val_min, shadow_range, valley_width,
)
return spine_x
def _detect_gutter_continuity(
gray: np.ndarray,
search_region: np.ndarray,
offset_x: int,
w: int,
side: str,
) -> Optional[int]:
"""Detect gutter shadow via vertical continuity analysis.
Camera book scans produce a subtle brightness gradient at the gutter
that is too faint for scanner-shadow detection (range < 40). However,
the gutter shadow has a unique property: it runs **continuously from
top to bottom** without interruption.
Algorithm:
1. Divide image into N horizontal strips (~60px each)
2. For each column, compute what fraction of strips are darker than
the page median (from the center 50% of the full image)
3. A "gutter column" has >= 75% of strips darker than page_median - d
4. Smooth the dark-fraction profile and find the transition point
5. Validate: gutter band must be 0.5%-10% of image width
"""
region_h, region_w = search_region.shape[:2]
if region_w < 20 or region_h < 100:
return None
# --- 1. Divide into horizontal strips ---
strip_target_h = 60
n_strips = max(10, region_h // strip_target_h)
strip_h = region_h // n_strips
strip_means = np.zeros((n_strips, region_w), dtype=np.float64)
for s in range(n_strips):
y0 = s * strip_h
y1 = min((s + 1) * strip_h, region_h)
strip_means[s] = np.mean(search_region[y0:y1, :], axis=0)
# --- 2. Page median from center 50% of full image ---
center_lo = w // 4
center_hi = 3 * w // 4
page_median = float(np.median(gray[:, center_lo:center_hi]))
dark_thresh = page_median - 5.0
if page_median < 180:
return None
# --- 3. Per-column dark fraction ---
dark_count = np.sum(strip_means < dark_thresh, axis=0).astype(np.float64)
dark_frac = dark_count / n_strips
# --- 4. Smooth and find transition ---
smooth_w = max(5, w // 100)
if smooth_w % 2 == 0:
smooth_w += 1
kernel = np.ones(smooth_w) / smooth_w
frac_smooth = np.convolve(dark_frac, kernel, mode="same")
margin = smooth_w // 2
if region_w <= 2 * margin + 10:
return None
transition_thresh = 0.50
peak_frac = float(np.max(frac_smooth[margin:region_w - margin]))
if peak_frac < 0.70:
logger.debug(
"%s gutter: peak dark fraction %.2f < 0.70", side.capitalize(), peak_frac,
)
return None
peak_x = int(np.argmax(frac_smooth[margin:region_w - margin])) + margin
gutter_inner = None
if side == "right":
for x in range(peak_x, margin, -1):
if frac_smooth[x] < transition_thresh:
gutter_inner = x + 1
break
else:
for x in range(peak_x, region_w - margin):
if frac_smooth[x] < transition_thresh:
gutter_inner = x - 1
break
if gutter_inner is None:
return None
# --- 5. Validate gutter width ---
if side == "right":
gutter_width = region_w - gutter_inner
else:
gutter_width = gutter_inner
min_gutter = max(3, int(w * 0.005))
max_gutter = int(w * 0.10)
if gutter_width < min_gutter:
logger.debug(
"%s gutter: too narrow (%dpx < %dpx)", side.capitalize(),
gutter_width, min_gutter,
)
return None
if gutter_width > max_gutter:
logger.debug(
"%s gutter: too wide (%dpx > %dpx)", side.capitalize(),
gutter_width, max_gutter,
)
return None
if side == "right":
gutter_brightness = float(np.mean(strip_means[:, gutter_inner:]))
else:
gutter_brightness = float(np.mean(strip_means[:, :gutter_inner]))
brightness_drop = page_median - gutter_brightness
if brightness_drop < 3:
logger.debug(
"%s gutter: insufficient brightness drop (%.1f levels)",
side.capitalize(), brightness_drop,
)
return None
gutter_x = offset_x + gutter_inner
logger.info(
"%s gutter (continuity): x=%d, width=%dpx (%.1f%%), "
"brightness=%.0f vs page=%.0f (drop=%.0f), frac@edge=%.2f",
side.capitalize(), gutter_x, gutter_width,
100.0 * gutter_width / w, gutter_brightness, page_median,
brightness_drop, float(frac_smooth[gutter_inner]),
)
return gutter_x
def _detect_left_edge_shadow(
gray: np.ndarray,
binary: np.ndarray,
w: int,
h: int,
) -> int:
"""Detect left content edge, accounting for book-spine shadow.
Tries three methods in order:
1. Scanner spine-shadow (dark gradient, range > 40)
2. Camera gutter continuity (subtle shadow running top-to-bottom)
3. Binary projection fallback (first ink column)
"""
search_w = max(1, w // 4)
spine_x = _detect_spine_shadow(gray, gray[:, :search_w], 0, w, "left")
if spine_x is not None:
return spine_x
gutter_x = _detect_gutter_continuity(gray, gray[:, :search_w], 0, w, "left")
if gutter_x is not None:
return gutter_x
return _detect_edge_projection(binary, axis=0, from_start=True, dim=w)
def _detect_right_edge_shadow(
gray: np.ndarray,
binary: np.ndarray,
w: int,
h: int,
) -> int:
"""Detect right content edge, accounting for book-spine shadow.
Tries three methods in order:
1. Scanner spine-shadow (dark gradient, range > 40)
2. Camera gutter continuity (subtle shadow running top-to-bottom)
3. Binary projection fallback (last ink column)
"""
search_w = max(1, w // 4)
right_start = w - search_w
spine_x = _detect_spine_shadow(gray, gray[:, right_start:], right_start, w, "right")
if spine_x is not None:
return spine_x
gutter_x = _detect_gutter_continuity(gray, gray[:, right_start:], right_start, w, "right")
if gutter_x is not None:
return gutter_x
return _detect_edge_projection(binary, axis=0, from_start=False, dim=w)
def _detect_top_bottom_edges(binary: np.ndarray, w: int, h: int) -> Tuple[int, int]:
"""Detect top and bottom content edges via binary horizontal projection."""
top = _detect_edge_projection(binary, axis=1, from_start=True, dim=h)
bottom = _detect_edge_projection(binary, axis=1, from_start=False, dim=h)
return top, bottom
def _detect_edge_projection(
binary: np.ndarray,
axis: int,
from_start: bool,
dim: int,
) -> int:
"""Find the first/last row or column with ink density above threshold.
axis=0 -> project vertically (column densities) -> returns x position
axis=1 -> project horizontally (row densities) -> returns y position
Filters out narrow noise runs shorter than _MIN_RUN_FRAC of the dimension.
"""
projection = np.mean(binary, axis=axis) / 255.0
ink_mask = projection >= _INK_THRESHOLD
min_run = max(1, int(dim * _MIN_RUN_FRAC))
ink_mask = _filter_narrow_runs(ink_mask, min_run)
ink_positions = np.where(ink_mask)[0]
if len(ink_positions) == 0:
return 0 if from_start else dim
if from_start:
return int(ink_positions[0])
else:
return int(ink_positions[-1])
def _filter_narrow_runs(mask: np.ndarray, min_run: int) -> np.ndarray:
"""Remove True-runs shorter than min_run pixels."""
if min_run <= 1:
return mask
result = mask.copy()
n = len(result)
i = 0
while i < n:
if result[i]:
start = i
while i < n and result[i]:
i += 1
if i - start < min_run:
result[start:i] = False
else:
i += 1
return result
# Backward-compat shim -- module moved to ocr/pipeline/page_crop_edges.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.page_crop_edges")

View File

@@ -1,189 +1,4 @@
"""
Sub-session creation for multi-page spreads.
Used by both the page-split and crop steps when a double-page scan is detected.
"""
import logging
import uuid as uuid_mod
from typing import Any, Dict, List
import cv2
import numpy as np
from page_crop import detect_and_crop_page
from ocr_pipeline_session_store import (
create_session_db,
get_sub_sessions,
update_session_db,
)
from orientation_crop_helpers import get_cache_ref
logger = logging.getLogger(__name__)
async def create_page_sub_sessions(
parent_session_id: str,
parent_cached: dict,
full_img_bgr: np.ndarray,
page_splits: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
"""Create sub-sessions for each detected page in a multi-page spread.
Each page region is individually cropped, then stored as a sub-session
with its own cropped image ready for the rest of the pipeline.
"""
# Check for existing sub-sessions (idempotent)
existing = await get_sub_sessions(parent_session_id)
if existing:
return [
{"id": s["id"], "name": s["name"], "page_index": s.get("box_index", i)}
for i, s in enumerate(existing)
]
parent_name = parent_cached.get("name", "Scan")
parent_filename = parent_cached.get("filename", "scan.png")
sub_sessions: List[Dict[str, Any]] = []
for page in page_splits:
pi = page["page_index"]
px, py = page["x"], page["y"]
pw, ph = page["width"], page["height"]
# Extract page region
page_bgr = full_img_bgr[py:py + ph, px:px + pw].copy()
# Crop each page individually (remove its own borders)
cropped_page, page_crop_info = detect_and_crop_page(page_bgr)
# Encode as PNG
ok, png_buf = cv2.imencode(".png", cropped_page)
page_png = png_buf.tobytes() if ok else b""
sub_id = str(uuid_mod.uuid4())
sub_name = f"{parent_name} — Seite {pi + 1}"
await create_session_db(
session_id=sub_id,
name=sub_name,
filename=parent_filename,
original_png=page_png,
)
# Pre-populate: set cropped = original (already cropped)
await update_session_db(
sub_id,
cropped_png=page_png,
crop_result=page_crop_info,
current_step=5,
)
ch, cw = cropped_page.shape[:2]
sub_sessions.append({
"id": sub_id,
"name": sub_name,
"page_index": pi,
"source_rect": page,
"cropped_size": {"width": cw, "height": ch},
"detected_format": page_crop_info.get("detected_format"),
})
logger.info(
"Page sub-session %s: page %d, region x=%d w=%d -> cropped %dx%d",
sub_id, pi + 1, px, pw, cw, ch,
)
return sub_sessions
async def create_page_sub_sessions_full(
parent_session_id: str,
parent_cached: dict,
full_img_bgr: np.ndarray,
page_splits: List[Dict[str, Any]],
start_step: int = 2,
) -> List[Dict[str, Any]]:
"""Create sub-sessions for each page with RAW regions for full pipeline processing.
Unlike ``create_page_sub_sessions`` (used by the crop step), these
sub-sessions store the *uncropped* page region and start at
``start_step`` (default 2 = ready for deskew; 1 if orientation still
needed). Each page goes through its own pipeline independently,
which is essential for book spreads where each page has a different tilt.
"""
_cache = get_cache_ref()
# Idempotent: reuse existing sub-sessions
existing = await get_sub_sessions(parent_session_id)
if existing:
return [
{"id": s["id"], "name": s["name"], "page_index": s.get("box_index", i)}
for i, s in enumerate(existing)
]
parent_name = parent_cached.get("name", "Scan")
parent_filename = parent_cached.get("filename", "scan.png")
sub_sessions: List[Dict[str, Any]] = []
for page in page_splits:
pi = page["page_index"]
px, py = page["x"], page["y"]
pw, ph = page["width"], page["height"]
# Extract RAW page region — NO individual cropping here; each
# sub-session will run its own crop step after deskew + dewarp.
page_bgr = full_img_bgr[py:py + ph, px:px + pw].copy()
# Encode as PNG
ok, png_buf = cv2.imencode(".png", page_bgr)
page_png = png_buf.tobytes() if ok else b""
sub_id = str(uuid_mod.uuid4())
sub_name = f"{parent_name} — Seite {pi + 1}"
await create_session_db(
session_id=sub_id,
name=sub_name,
filename=parent_filename,
original_png=page_png,
)
# start_step=2 -> ready for deskew (orientation already done on spread)
# start_step=1 -> needs its own orientation (split from original image)
await update_session_db(sub_id, current_step=start_step)
# Cache the BGR so the pipeline can start immediately
_cache[sub_id] = {
"id": sub_id,
"filename": parent_filename,
"name": sub_name,
"original_bgr": page_bgr,
"oriented_bgr": None,
"cropped_bgr": None,
"deskewed_bgr": None,
"dewarped_bgr": None,
"orientation_result": None,
"crop_result": None,
"deskew_result": None,
"dewarp_result": None,
"ground_truth": {},
"current_step": start_step,
}
rh, rw = page_bgr.shape[:2]
sub_sessions.append({
"id": sub_id,
"name": sub_name,
"page_index": pi,
"source_rect": page,
"image_size": {"width": rw, "height": rh},
})
logger.info(
"Page sub-session %s (full pipeline): page %d, region x=%d w=%d -> %dx%d",
sub_id, pi + 1, px, pw, rw, rh,
)
return sub_sessions
# Backward-compat shim -- module moved to ocr/pipeline/page_sub_sessions.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.page_sub_sessions")

View File

@@ -1,102 +1,4 @@
"""
Scan Quality Assessment — Measures image quality before OCR.
Computes blur score, contrast score, and an overall quality rating.
Used to gate enhancement steps and warn users about degraded scans.
All operations use OpenCV (Apache-2.0), no additional dependencies.
"""
import logging
from dataclasses import dataclass, asdict
from typing import Dict, Any
import cv2
import numpy as np
logger = logging.getLogger(__name__)
# Thresholds (empirically tuned on textbook scans)
BLUR_THRESHOLD = 100.0 # Laplacian variance below this = blurry
CONTRAST_THRESHOLD = 40.0 # Grayscale stddev below this = low contrast
CONFIDENCE_GOOD = 40 # OCR min confidence for good scans
CONFIDENCE_DEGRADED = 30 # OCR min confidence for degraded scans
@dataclass
class ScanQualityReport:
"""Result of scan quality assessment."""
blur_score: float # Laplacian variance (higher = sharper)
contrast_score: float # Grayscale std deviation (higher = more contrast)
brightness: float # Mean grayscale value (0-255)
is_blurry: bool
is_low_contrast: bool
is_degraded: bool # True if any quality issue detected
quality_pct: int # 0-100 overall quality estimate
recommended_min_conf: int # Recommended OCR confidence threshold
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
def score_scan_quality(img_bgr: np.ndarray) -> ScanQualityReport:
"""
Assess the quality of a scanned image.
Uses:
- Laplacian variance for blur detection
- Grayscale standard deviation for contrast
- Mean brightness for exposure assessment
Args:
img_bgr: BGR image (numpy array from OpenCV)
Returns:
ScanQualityReport with scores and recommendations
"""
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
# Blur detection: Laplacian variance
# Higher = sharper edges = better quality
laplacian = cv2.Laplacian(gray, cv2.CV_64F)
blur_score = float(laplacian.var())
# Contrast: standard deviation of grayscale
contrast_score = float(np.std(gray))
# Brightness: mean grayscale
brightness = float(np.mean(gray))
# Quality flags
is_blurry = blur_score < BLUR_THRESHOLD
is_low_contrast = contrast_score < CONTRAST_THRESHOLD
is_degraded = is_blurry or is_low_contrast
# Overall quality percentage (simple weighted combination)
blur_pct = min(100, blur_score / BLUR_THRESHOLD * 50)
contrast_pct = min(100, contrast_score / CONTRAST_THRESHOLD * 50)
quality_pct = int(min(100, blur_pct + contrast_pct))
# Recommended confidence threshold
recommended_min_conf = CONFIDENCE_DEGRADED if is_degraded else CONFIDENCE_GOOD
report = ScanQualityReport(
blur_score=round(blur_score, 1),
contrast_score=round(contrast_score, 1),
brightness=round(brightness, 1),
is_blurry=is_blurry,
is_low_contrast=is_low_contrast,
is_degraded=is_degraded,
quality_pct=quality_pct,
recommended_min_conf=recommended_min_conf,
)
logger.info(
f"Scan quality: blur={report.blur_score} "
f"contrast={report.contrast_score} "
f"quality={report.quality_pct}% "
f"degraded={report.is_degraded} "
f"min_conf={report.recommended_min_conf}"
)
return report
# Backward-compat shim -- module moved to ocr/pipeline/scan_quality.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.scan_quality")

View File

@@ -1,261 +1,4 @@
"""
Vision-LLM OCR Fusion — Combines traditional OCR positions with Vision-LLM reading.
Sends the scan image + OCR word coordinates + document type to Qwen2.5-VL.
The LLM can read degraded text using context understanding and visual inspection,
while OCR coordinates provide structural hints (where text is, column positions).
Uses Ollama API (same pattern as handwriting_htr_api.py).
"""
import base64
import json
import logging
import os
import re
from typing import Any, Dict, List, Optional
import cv2
import httpx
import numpy as np
logger = logging.getLogger(__name__)
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
VISION_FUSION_MODEL = os.getenv("VISION_FUSION_MODEL", "llama3.2-vision:11b")
# Document category → prompt context
CATEGORY_PROMPTS: Dict[str, Dict[str, str]] = {
"vokabelseite": {
"label": "Vokabelseite eines Schulbuchs (Englisch-Deutsch)",
"columns": "Die Tabelle hat typischerweise 3 Spalten: Englisch, Deutsch, Beispielsatz.",
},
"woerterbuch": {
"label": "Woerterbuchseite",
"columns": "Die Eintraege haben: Stichwort, Lautschrift, Uebersetzung(en), Beispielsaetze.",
},
"arbeitsblatt": {
"label": "Arbeitsblatt",
"columns": "Erkenne die Spaltenstruktur aus dem Layout.",
},
"buchseite": {
"label": "Schulbuchseite",
"columns": "Erkenne die Spaltenstruktur aus dem Layout.",
},
}
def _group_words_into_lines(
words: List[Dict], y_tolerance: float = 15.0,
) -> List[List[Dict]]:
"""Group OCR words into lines by Y-proximity."""
if not words:
return []
sorted_w = sorted(words, key=lambda w: w.get("top", 0))
lines: List[List[Dict]] = [[sorted_w[0]]]
for w in sorted_w[1:]:
last_line = lines[-1]
avg_y = sum(ww["top"] for ww in last_line) / len(last_line)
if abs(w["top"] - avg_y) <= y_tolerance:
last_line.append(w)
else:
lines.append([w])
# Sort words within each line by X
for line in lines:
line.sort(key=lambda w: w.get("left", 0))
return lines
def _build_ocr_context(words: List[Dict], img_h: int) -> str:
"""Build a text description of OCR words with positions for the prompt."""
lines = _group_words_into_lines(words)
context_parts = []
for i, line in enumerate(lines):
word_descs = []
for w in line:
text = w.get("text", "").strip()
x = w.get("left", 0)
conf = w.get("conf", 0)
marker = " (?)" if conf < 50 else ""
word_descs.append(f'x={x} "{text}"{marker}')
avg_y = int(sum(w["top"] for w in line) / len(line))
context_parts.append(f"Zeile {i+1} (y~{avg_y}): {', '.join(word_descs)}")
return "\n".join(context_parts)
def _build_prompt(
ocr_context: str, category: str, img_w: int, img_h: int,
) -> str:
"""Build the Vision-LLM prompt with OCR context and document type."""
cat_info = CATEGORY_PROMPTS.get(category, CATEGORY_PROMPTS["buchseite"])
return f"""Du siehst eine eingescannte {cat_info['label']}.
{cat_info['columns']}
Die OCR-Software hat folgende Woerter an diesen Positionen erkannt.
Woerter mit (?) haben niedrige Erkennungssicherheit und sind wahrscheinlich falsch:
{ocr_context}
Bildgroesse: {img_w} x {img_h} Pixel.
AUFGABE: Schau dir das Bild genau an und erstelle die korrekte Tabelle.
- Korrigiere falsch erkannte Woerter anhand dessen was du im Bild siehst
- Fasse Fortsetzungszeilen zusammen (wenn eine Spalte in der naechsten Zeile leer ist,
gehoert der Text zur Zeile darueber — der Autor hat nur einen Zeilenumbruch innerhalb der Zelle gemacht)
- Behalte die Reihenfolge bei
Antworte NUR mit einem JSON-Array, keine Erklaerungen:
[
{{"row": 1, "english": "...", "german": "...", "example": "..."}},
{{"row": 2, "english": "...", "german": "...", "example": "..."}}
]"""
def _parse_llm_response(response_text: str) -> Optional[List[Dict]]:
"""Parse the LLM JSON response, handling markdown code blocks."""
text = response_text.strip()
# Strip markdown code block if present
if text.startswith("```"):
text = re.sub(r"^```(?:json)?\s*", "", text)
text = re.sub(r"\s*```\s*$", "", text)
# Try to find JSON array
match = re.search(r"\[[\s\S]*\]", text)
if not match:
logger.warning("vision_fuse_ocr: no JSON array found in LLM response")
return None
try:
data = json.loads(match.group())
if not isinstance(data, list):
return None
return data
except json.JSONDecodeError as e:
logger.warning(f"vision_fuse_ocr: JSON parse error: {e}")
return None
def _vocab_rows_to_words(
rows: List[Dict], img_w: int, img_h: int,
) -> List[Dict]:
"""Convert LLM vocab rows back to word dicts for grid building.
Distributes words across estimated column positions so the
existing grid builder can process them normally.
"""
words = []
# Estimate column positions (3-column vocab layout)
col_positions = [
(0.02, 0.28), # EN: 2%-28% of width
(0.30, 0.55), # DE: 30%-55%
(0.57, 0.98), # Example: 57%-98%
]
median_h = max(15, img_h // (len(rows) * 3)) if rows else 20
y_step = max(median_h + 5, img_h // max(len(rows), 1))
for i, row in enumerate(rows):
y = int(i * y_step + 20)
row_num = row.get("row", i + 1)
for col_idx, (field, (x_start_pct, x_end_pct)) in enumerate([
("english", col_positions[0]),
("german", col_positions[1]),
("example", col_positions[2]),
]):
text = (row.get(field) or "").strip()
if not text:
continue
x = int(x_start_pct * img_w)
w = int((x_end_pct - x_start_pct) * img_w)
words.append({
"text": text,
"left": x,
"top": y,
"width": w,
"height": median_h,
"conf": 95, # LLM-corrected → high confidence
"_source": "vision_llm",
"_row": row_num,
"_col_type": f"column_{['en', 'de', 'example'][col_idx]}",
})
logger.info(f"vision_fuse_ocr: converted {len(rows)} LLM rows → {len(words)} words")
return words
async def vision_fuse_ocr(
img_bgr: np.ndarray,
ocr_words: List[Dict],
document_category: str = "vokabelseite",
) -> List[Dict]:
"""Fuse traditional OCR results with Vision-LLM reading.
Sends the image + OCR word positions to Qwen2.5-VL which can:
- Read degraded text that traditional OCR cannot
- Use document context (knows what a vocab table looks like)
- Merge continuation rows (understands table structure)
Args:
img_bgr: The cropped/dewarped scan image (BGR)
ocr_words: Traditional OCR word list with positions
document_category: Type of document being scanned
Returns:
Corrected word list in same format as input, ready for grid building.
Falls back to original ocr_words on error.
"""
img_h, img_w = img_bgr.shape[:2]
# Build OCR context string
ocr_context = _build_ocr_context(ocr_words, img_h)
# Build prompt
prompt = _build_prompt(ocr_context, document_category, img_w, img_h)
# Encode image as base64
_, img_encoded = cv2.imencode(".png", img_bgr)
img_b64 = base64.b64encode(img_encoded.tobytes()).decode("utf-8")
# Call Qwen2.5-VL via Ollama
try:
async with httpx.AsyncClient(timeout=120.0) as client:
resp = await client.post(
f"{OLLAMA_BASE_URL}/api/generate",
json={
"model": VISION_FUSION_MODEL,
"prompt": prompt,
"images": [img_b64],
"stream": False,
"options": {"temperature": 0.1, "num_predict": 4096},
},
)
resp.raise_for_status()
data = resp.json()
response_text = data.get("response", "").strip()
except Exception as e:
logger.error(f"vision_fuse_ocr: Ollama call failed: {e}")
return ocr_words # Fallback to original
if not response_text:
logger.warning("vision_fuse_ocr: empty LLM response")
return ocr_words
# Parse JSON response
rows = _parse_llm_response(response_text)
if not rows:
logger.warning(
"vision_fuse_ocr: could not parse LLM response, "
"first 200 chars: %s", response_text[:200],
)
return ocr_words
logger.info(
f"vision_fuse_ocr: LLM returned {len(rows)} vocab rows "
f"(from {len(ocr_words)} OCR words)"
)
# Convert back to word format for grid building
return _vocab_rows_to_words(rows, img_w, img_h)
# Backward-compat shim -- module moved to ocr/pipeline/vision_fusion.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.vision_fusion")