From 0504d22b8e969a3057701ec36654af2be38bafa8 Mon Sep 17 00:00:00 2001 From: Benjamin Admin Date: Sat, 25 Apr 2026 21:51:43 +0200 Subject: [PATCH] Restructure: Move ocr_pipeline + labeling + crop into ocr/ package Co-Authored-By: Claude Opus 4.6 (1M context) --- .claude/rules/loc-exceptions.txt | 1 + klausur-service/backend/crop_api.py | 294 +--------- klausur-service/backend/cv_vocab_pipeline.py | 4 +- klausur-service/backend/ocr/__init__.py | 2 +- .../ocr/{pipeline.py => cv_pipeline.py} | 0 .../backend/ocr/labeling/__init__.py | 6 + klausur-service/backend/ocr/labeling/api.py | 81 +++ .../backend/ocr/labeling/helpers.py | 205 +++++++ .../backend/ocr/labeling/models.py | 86 +++ .../backend/ocr/labeling/routes.py | 241 ++++++++ .../backend/ocr/labeling/upload_routes.py | 313 +++++++++++ .../backend/ocr/pipeline/__init__.py | 8 + klausur-service/backend/ocr/pipeline/api.py | 63 +++ klausur-service/backend/ocr/pipeline/auto.py | 23 + .../backend/ocr/pipeline/auto_helpers.py | 84 +++ .../backend/ocr/pipeline/auto_steps.py | 528 +++++++++++++++++ .../backend/ocr/pipeline/columns.py | 293 ++++++++++ .../backend/ocr/pipeline/common.py | 354 ++++++++++++ .../backend/ocr/pipeline/crop_api.py | 290 ++++++++++ .../backend/ocr/pipeline/deskew.py | 236 ++++++++ .../backend/ocr/pipeline/dewarp.py | 346 ++++++++++++ .../backend/ocr/pipeline/geometry.py | 27 + .../backend/ocr/pipeline/llm_review.py | 209 +++++++ .../backend/ocr/pipeline/merge_helpers.py | 272 +++++++++ .../backend/ocr/pipeline/ocr_merge.py | 266 +++++++++ .../backend/ocr/pipeline/orientation_api.py | 188 +++++++ .../ocr/pipeline/orientation_crop_api.py | 16 + .../ocr/pipeline/orientation_crop_helpers.py | 86 +++ .../backend/ocr/pipeline/overlay_grid.py | 333 +++++++++++ .../backend/ocr/pipeline/overlay_structure.py | 205 +++++++ .../backend/ocr/pipeline/overlays.py | 34 ++ .../backend/ocr/pipeline/page_crop.py | 33 ++ .../backend/ocr/pipeline/page_crop_core.py | 342 +++++++++++ 
.../backend/ocr/pipeline/page_crop_edges.py | 388 +++++++++++++ .../backend/ocr/pipeline/page_sub_sessions.py | 189 +++++++ .../backend/ocr/pipeline/postprocess.py | 26 + .../backend/ocr/pipeline/reconstruction.py | 362 ++++++++++++ .../backend/ocr/pipeline/regression.py | 22 + .../ocr/pipeline/regression_endpoints.py | 421 ++++++++++++++ .../ocr/pipeline/regression_helpers.py | 207 +++++++ .../backend/ocr/pipeline/reprocess.py | 94 ++++ klausur-service/backend/ocr/pipeline/rows.py | 348 ++++++++++++ .../backend/ocr/pipeline/scan_quality.py | 102 ++++ .../backend/ocr/pipeline/session_store.py | 388 +++++++++++++ .../backend/ocr/pipeline/sessions.py | 20 + .../backend/ocr/pipeline/sessions_crud.py | 449 +++++++++++++++ .../backend/ocr/pipeline/sessions_images.py | 176 ++++++ .../backend/ocr/pipeline/structure.py | 299 ++++++++++ .../backend/ocr/pipeline/validation.py | 362 ++++++++++++ .../backend/ocr/pipeline/vision_fusion.py | 261 +++++++++ klausur-service/backend/ocr/pipeline/words.py | 185 ++++++ .../backend/ocr/pipeline/words_detect.py | 393 +++++++++++++ .../backend/ocr/pipeline/words_stream.py | 303 ++++++++++ klausur-service/backend/ocr_labeling_api.py | 85 +-- .../backend/ocr_labeling_helpers.py | 209 +------ .../backend/ocr_labeling_models.py | 90 +-- .../backend/ocr_labeling_routes.py | 245 +------- .../backend/ocr_labeling_upload_routes.py | 317 +---------- klausur-service/backend/ocr_merge_helpers.py | 276 +-------- klausur-service/backend/ocr_pipeline_api.py | 67 +-- klausur-service/backend/ocr_pipeline_auto.py | 27 +- .../backend/ocr_pipeline_auto_helpers.py | 88 +-- .../backend/ocr_pipeline_auto_steps.py | 532 +----------------- .../backend/ocr_pipeline_columns.py | 297 +--------- .../backend/ocr_pipeline_common.py | 358 +----------- .../backend/ocr_pipeline_deskew.py | 240 +------- .../backend/ocr_pipeline_dewarp.py | 350 +----------- .../backend/ocr_pipeline_geometry.py | 31 +- .../backend/ocr_pipeline_llm_review.py | 213 +------ 
.../backend/ocr_pipeline_ocr_merge.py | 270 +-------- .../backend/ocr_pipeline_overlay_grid.py | 337 +---------- .../backend/ocr_pipeline_overlay_structure.py | 209 +------ .../backend/ocr_pipeline_overlays.py | 38 +- .../backend/ocr_pipeline_postprocess.py | 30 +- .../backend/ocr_pipeline_reconstruction.py | 366 +----------- .../backend/ocr_pipeline_regression.py | 26 +- .../ocr_pipeline_regression_endpoints.py | 425 +------------- .../ocr_pipeline_regression_helpers.py | 211 +------ .../backend/ocr_pipeline_reprocess.py | 98 +--- klausur-service/backend/ocr_pipeline_rows.py | 352 +----------- .../backend/ocr_pipeline_session_store.py | 392 +------------ .../backend/ocr_pipeline_sessions.py | 24 +- .../backend/ocr_pipeline_sessions_crud.py | 453 +-------------- .../backend/ocr_pipeline_sessions_images.py | 180 +----- .../backend/ocr_pipeline_structure.py | 303 +--------- .../backend/ocr_pipeline_validation.py | 366 +----------- klausur-service/backend/ocr_pipeline_words.py | 189 +------ .../backend/ocr_pipeline_words_detect.py | 397 +------------ .../backend/ocr_pipeline_words_stream.py | 307 +--------- klausur-service/backend/orientation_api.py | 192 +------ .../backend/orientation_crop_api.py | 20 +- .../backend/orientation_crop_helpers.py | 90 +-- klausur-service/backend/page_crop.py | 37 +- klausur-service/backend/page_crop_core.py | 346 +----------- klausur-service/backend/page_crop_edges.py | 392 +------------ klausur-service/backend/page_sub_sessions.py | 193 +------ klausur-service/backend/scan_quality.py | 106 +--- klausur-service/backend/vision_ocr_fusion.py | 265 +-------- 98 files changed, 10351 insertions(+), 10152 deletions(-) rename klausur-service/backend/ocr/{pipeline.py => cv_pipeline.py} (100%) create mode 100644 klausur-service/backend/ocr/labeling/__init__.py create mode 100644 klausur-service/backend/ocr/labeling/api.py create mode 100644 klausur-service/backend/ocr/labeling/helpers.py create mode 100644 
klausur-service/backend/ocr/labeling/models.py create mode 100644 klausur-service/backend/ocr/labeling/routes.py create mode 100644 klausur-service/backend/ocr/labeling/upload_routes.py create mode 100644 klausur-service/backend/ocr/pipeline/__init__.py create mode 100644 klausur-service/backend/ocr/pipeline/api.py create mode 100644 klausur-service/backend/ocr/pipeline/auto.py create mode 100644 klausur-service/backend/ocr/pipeline/auto_helpers.py create mode 100644 klausur-service/backend/ocr/pipeline/auto_steps.py create mode 100644 klausur-service/backend/ocr/pipeline/columns.py create mode 100644 klausur-service/backend/ocr/pipeline/common.py create mode 100644 klausur-service/backend/ocr/pipeline/crop_api.py create mode 100644 klausur-service/backend/ocr/pipeline/deskew.py create mode 100644 klausur-service/backend/ocr/pipeline/dewarp.py create mode 100644 klausur-service/backend/ocr/pipeline/geometry.py create mode 100644 klausur-service/backend/ocr/pipeline/llm_review.py create mode 100644 klausur-service/backend/ocr/pipeline/merge_helpers.py create mode 100644 klausur-service/backend/ocr/pipeline/ocr_merge.py create mode 100644 klausur-service/backend/ocr/pipeline/orientation_api.py create mode 100644 klausur-service/backend/ocr/pipeline/orientation_crop_api.py create mode 100644 klausur-service/backend/ocr/pipeline/orientation_crop_helpers.py create mode 100644 klausur-service/backend/ocr/pipeline/overlay_grid.py create mode 100644 klausur-service/backend/ocr/pipeline/overlay_structure.py create mode 100644 klausur-service/backend/ocr/pipeline/overlays.py create mode 100644 klausur-service/backend/ocr/pipeline/page_crop.py create mode 100644 klausur-service/backend/ocr/pipeline/page_crop_core.py create mode 100644 klausur-service/backend/ocr/pipeline/page_crop_edges.py create mode 100644 klausur-service/backend/ocr/pipeline/page_sub_sessions.py create mode 100644 klausur-service/backend/ocr/pipeline/postprocess.py create mode 100644 
klausur-service/backend/ocr/pipeline/reconstruction.py create mode 100644 klausur-service/backend/ocr/pipeline/regression.py create mode 100644 klausur-service/backend/ocr/pipeline/regression_endpoints.py create mode 100644 klausur-service/backend/ocr/pipeline/regression_helpers.py create mode 100644 klausur-service/backend/ocr/pipeline/reprocess.py create mode 100644 klausur-service/backend/ocr/pipeline/rows.py create mode 100644 klausur-service/backend/ocr/pipeline/scan_quality.py create mode 100644 klausur-service/backend/ocr/pipeline/session_store.py create mode 100644 klausur-service/backend/ocr/pipeline/sessions.py create mode 100644 klausur-service/backend/ocr/pipeline/sessions_crud.py create mode 100644 klausur-service/backend/ocr/pipeline/sessions_images.py create mode 100644 klausur-service/backend/ocr/pipeline/structure.py create mode 100644 klausur-service/backend/ocr/pipeline/validation.py create mode 100644 klausur-service/backend/ocr/pipeline/vision_fusion.py create mode 100644 klausur-service/backend/ocr/pipeline/words.py create mode 100644 klausur-service/backend/ocr/pipeline/words_detect.py create mode 100644 klausur-service/backend/ocr/pipeline/words_stream.py diff --git a/.claude/rules/loc-exceptions.txt b/.claude/rules/loc-exceptions.txt index 486a0e1..082e607 100644 --- a/.claude/rules/loc-exceptions.txt +++ b/.claude/rules/loc-exceptions.txt @@ -47,6 +47,7 @@ # Single SSE generator orchestrating 6 pipeline steps — cannot split generator context **/ocr_pipeline_auto_steps.py | owner=klausur | reason=run_auto is a single async generator yielding SSE events across 6 steps (528 LOC) | review=2026-10-01 +**/ocr/pipeline/auto_steps.py | owner=klausur | reason=Same file moved to ocr/ package | review=2026-10-01 # Legacy — TEMPORAER bis Refactoring abgeschlossen # Dateien hier werden Phase fuer Phase abgearbeitet und entfernt. 
diff --git a/klausur-service/backend/crop_api.py b/klausur-service/backend/crop_api.py index ea5a72a..97971a7 100644 --- a/klausur-service/backend/crop_api.py +++ b/klausur-service/backend/crop_api.py @@ -1,290 +1,4 @@ -""" -Crop API endpoints (Step 4 / UI index 3 of OCR Pipeline). - -Auto-crop, manual crop, and skip-crop for scanner/book borders. -""" - -import logging -import time -from typing import Any, Dict - -import cv2 -from fastapi import APIRouter, HTTPException -from pydantic import BaseModel - -from page_crop import detect_and_crop_page, detect_page_splits -from ocr_pipeline_session_store import get_sub_sessions, update_session_db - -from orientation_crop_helpers import ensure_cached, append_pipeline_log -from page_sub_sessions import create_page_sub_sessions - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) - - -# --------------------------------------------------------------------------- -# Step 4 (UI index 3): Crop — runs after deskew + dewarp -# --------------------------------------------------------------------------- - -@router.post("/sessions/{session_id}/crop") -async def auto_crop(session_id: str): - """Auto-detect and crop scanner/book borders. - - Reads the dewarped image (post-deskew + dewarp, so the page is straight). - Falls back to oriented -> original if earlier steps were skipped. - - If the image is a multi-page spread (e.g. book on scanner), it will - automatically split into separate sub-sessions per page, crop each - individually, and return the split info. 
- """ - cached = await ensure_cached(session_id) - - # Use dewarped (preferred), fall back to oriented, then original - img_bgr = next( - (v for k in ("dewarped_bgr", "oriented_bgr", "original_bgr") - if (v := cached.get(k)) is not None), - None, - ) - if img_bgr is None: - raise HTTPException(status_code=400, detail="No image available for cropping") - - t0 = time.time() - - # --- Check for existing sub-sessions (from page-split step) --- - # If page-split already created sub-sessions, skip multi-page detection - # in the crop step. Each sub-session runs its own crop independently. - existing_subs = await get_sub_sessions(session_id) - if existing_subs: - crop_result = cached.get("crop_result") or {} - if crop_result.get("multi_page"): - # Already split -- just return the existing info - duration = time.time() - t0 - h, w = img_bgr.shape[:2] - return { - "session_id": session_id, - **crop_result, - "image_width": w, - "image_height": h, - "sub_sessions": [ - {"id": s["id"], "name": s.get("name"), "page_index": s.get("box_index", i)} - for i, s in enumerate(existing_subs) - ], - "note": "Page split was already performed; each sub-session runs its own crop.", - } - - # --- Multi-page detection (fallback for sessions that skipped page-split) --- - page_splits = detect_page_splits(img_bgr) - - if page_splits and len(page_splits) >= 2: - # Multi-page spread detected -- create sub-sessions - sub_sessions = await create_page_sub_sessions( - session_id, cached, img_bgr, page_splits, - ) - duration = time.time() - t0 - - crop_info: Dict[str, Any] = { - "crop_applied": True, - "multi_page": True, - "page_count": len(page_splits), - "page_splits": page_splits, - "duration_seconds": round(duration, 2), - } - cached["crop_result"] = crop_info - - # Store the first page as the main cropped image for backward compat - first_page = page_splits[0] - first_bgr = img_bgr[ - first_page["y"]:first_page["y"] + first_page["height"], - first_page["x"]:first_page["x"] + 
first_page["width"], - ].copy() - first_cropped, _ = detect_and_crop_page(first_bgr) - cached["cropped_bgr"] = first_cropped - - ok, png_buf = cv2.imencode(".png", first_cropped) - await update_session_db( - session_id, - cropped_png=png_buf.tobytes() if ok else b"", - crop_result=crop_info, - current_step=5, - status='split', - ) - - logger.info( - "OCR Pipeline: crop session %s: multi-page split into %d pages in %.2fs", - session_id, len(page_splits), duration, - ) - - await append_pipeline_log(session_id, "crop", { - "multi_page": True, - "page_count": len(page_splits), - }, duration_ms=int(duration * 1000)) - - h, w = first_cropped.shape[:2] - return { - "session_id": session_id, - **crop_info, - "image_width": w, - "image_height": h, - "cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped", - "sub_sessions": sub_sessions, - } - - # --- Single page (normal) --- - cropped_bgr, crop_info = detect_and_crop_page(img_bgr) - - duration = time.time() - t0 - crop_info["duration_seconds"] = round(duration, 2) - crop_info["multi_page"] = False - - # Encode cropped image - success, png_buf = cv2.imencode(".png", cropped_bgr) - cropped_png = png_buf.tobytes() if success else b"" - - # Update cache - cached["cropped_bgr"] = cropped_bgr - cached["crop_result"] = crop_info - - # Persist to DB - await update_session_db( - session_id, - cropped_png=cropped_png, - crop_result=crop_info, - current_step=5, - ) - - logger.info( - "OCR Pipeline: crop session %s: applied=%s format=%s in %.2fs", - session_id, crop_info["crop_applied"], - crop_info.get("detected_format", "?"), - duration, - ) - - await append_pipeline_log(session_id, "crop", { - "crop_applied": crop_info["crop_applied"], - "detected_format": crop_info.get("detected_format"), - "format_confidence": crop_info.get("format_confidence"), - }, duration_ms=int(duration * 1000)) - - h, w = cropped_bgr.shape[:2] - return { - "session_id": session_id, - **crop_info, - "image_width": w, - "image_height": 
h, - "cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped", - } - - -class ManualCropRequest(BaseModel): - x: float # percentage 0-100 - y: float # percentage 0-100 - width: float # percentage 0-100 - height: float # percentage 0-100 - - -@router.post("/sessions/{session_id}/crop/manual") -async def manual_crop(session_id: str, req: ManualCropRequest): - """Manually crop using percentage coordinates.""" - cached = await ensure_cached(session_id) - - img_bgr = next( - (v for k in ("dewarped_bgr", "oriented_bgr", "original_bgr") - if (v := cached.get(k)) is not None), - None, - ) - if img_bgr is None: - raise HTTPException(status_code=400, detail="No image available for cropping") - - h, w = img_bgr.shape[:2] - - # Convert percentages to pixels - px_x = int(w * req.x / 100.0) - px_y = int(h * req.y / 100.0) - px_w = int(w * req.width / 100.0) - px_h = int(h * req.height / 100.0) - - # Clamp - px_x = max(0, min(px_x, w - 1)) - px_y = max(0, min(px_y, h - 1)) - px_w = max(1, min(px_w, w - px_x)) - px_h = max(1, min(px_h, h - px_y)) - - cropped_bgr = img_bgr[px_y:px_y + px_h, px_x:px_x + px_w].copy() - - success, png_buf = cv2.imencode(".png", cropped_bgr) - cropped_png = png_buf.tobytes() if success else b"" - - crop_result = { - "crop_applied": True, - "crop_rect": {"x": px_x, "y": px_y, "width": px_w, "height": px_h}, - "crop_rect_pct": {"x": round(req.x, 2), "y": round(req.y, 2), - "width": round(req.width, 2), "height": round(req.height, 2)}, - "original_size": {"width": w, "height": h}, - "cropped_size": {"width": px_w, "height": px_h}, - "method": "manual", - } - - cached["cropped_bgr"] = cropped_bgr - cached["crop_result"] = crop_result - - await update_session_db( - session_id, - cropped_png=cropped_png, - crop_result=crop_result, - current_step=5, - ) - - ch, cw = cropped_bgr.shape[:2] - return { - "session_id": session_id, - **crop_result, - "image_width": cw, - "image_height": ch, - "cropped_image_url": 
f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped", - } - - -@router.post("/sessions/{session_id}/crop/skip") -async def skip_crop(session_id: str): - """Skip cropping -- use dewarped (or oriented/original) image as-is.""" - cached = await ensure_cached(session_id) - - img_bgr = next( - (v for k in ("dewarped_bgr", "oriented_bgr", "original_bgr") - if (v := cached.get(k)) is not None), - None, - ) - if img_bgr is None: - raise HTTPException(status_code=400, detail="No image available") - - h, w = img_bgr.shape[:2] - - # Store the dewarped image as cropped (identity crop) - success, png_buf = cv2.imencode(".png", img_bgr) - cropped_png = png_buf.tobytes() if success else b"" - - crop_result = { - "crop_applied": False, - "skipped": True, - "original_size": {"width": w, "height": h}, - "cropped_size": {"width": w, "height": h}, - } - - cached["cropped_bgr"] = img_bgr - cached["crop_result"] = crop_result - - await update_session_db( - session_id, - cropped_png=cropped_png, - crop_result=crop_result, - current_step=5, - ) - - return { - "session_id": session_id, - **crop_result, - "image_width": w, - "image_height": h, - "cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped", - } +# Backward-compat shim -- module moved to ocr/pipeline/crop_api.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.crop_api") diff --git a/klausur-service/backend/cv_vocab_pipeline.py b/klausur-service/backend/cv_vocab_pipeline.py index 01a0e6f..7875825 100644 --- a/klausur-service/backend/cv_vocab_pipeline.py +++ b/klausur-service/backend/cv_vocab_pipeline.py @@ -1,4 +1,4 @@ -# Backward-compat shim -- module moved to ocr\/pipeline.py +# Backward-compat shim -- module moved to ocr/cv_pipeline.py import importlib as _importlib import sys as _sys -_sys.modules[__name__] = _importlib.import_module("ocr.pipeline") +_sys.modules[__name__] = _importlib.import_module("ocr.cv_pipeline") diff --git 
a/klausur-service/backend/ocr/__init__.py b/klausur-service/backend/ocr/__init__.py index 36e3203..679c58a 100644 --- a/klausur-service/backend/ocr/__init__.py +++ b/klausur-service/backend/ocr/__init__.py @@ -6,4 +6,4 @@ Backward-compatible re-exports: consumers can still use """ from .types import * # noqa: F401,F403 -from .pipeline import * # noqa: F401,F403 +from .cv_pipeline import * # noqa: F401,F403 diff --git a/klausur-service/backend/ocr/pipeline.py b/klausur-service/backend/ocr/cv_pipeline.py similarity index 100% rename from klausur-service/backend/ocr/pipeline.py rename to klausur-service/backend/ocr/cv_pipeline.py diff --git a/klausur-service/backend/ocr/labeling/__init__.py b/klausur-service/backend/ocr/labeling/__init__.py new file mode 100644 index 0000000..41eae6a --- /dev/null +++ b/klausur-service/backend/ocr/labeling/__init__.py @@ -0,0 +1,6 @@ +""" +OCR Labeling sub-package — labeling API, models, helpers, and route handlers. + +Moved from backend/ flat modules (ocr_labeling_*.py). +Backward-compatible shim files remain at the old locations. +""" diff --git a/klausur-service/backend/ocr/labeling/api.py b/klausur-service/backend/ocr/labeling/api.py new file mode 100644 index 0000000..30a1bc6 --- /dev/null +++ b/klausur-service/backend/ocr/labeling/api.py @@ -0,0 +1,81 @@ +""" +OCR Labeling API — Barrel Re-export + +Split into: +- ocr_labeling_models.py — Pydantic models and constants +- ocr_labeling_helpers.py — OCR wrappers, image storage, hashing +- ocr_labeling_routes.py — Session/queue/labeling route handlers +- ocr_labeling_upload_routes.py — Upload, run-OCR, export route handlers + +All public names are re-exported here for backward compatibility. 
+""" + +# Models +from .models import ( # noqa: F401 + LOCAL_STORAGE_PATH, + SessionCreate, + SessionResponse, + ItemResponse, + ConfirmRequest, + CorrectRequest, + SkipRequest, + ExportRequest, + StatsResponse, +) + +# Helpers +from .helpers import ( # noqa: F401 + VISION_OCR_AVAILABLE, + PADDLEOCR_AVAILABLE, + TROCR_AVAILABLE, + DONUT_AVAILABLE, + MINIO_AVAILABLE, + TRAINING_EXPORT_AVAILABLE, + compute_image_hash, + run_ocr_on_image, + run_vision_ocr_wrapper, + run_paddleocr_wrapper, + run_trocr_wrapper, + run_donut_wrapper, + save_image_locally, + get_image_url, +) + +# Conditional re-exports from helpers' optional imports +try: + from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET # noqa: F401 +except ImportError: + pass + +try: + from training_export_service import ( # noqa: F401 + TrainingExportService, + TrainingSample, + get_training_export_service, + ) +except ImportError: + pass + +try: + from hybrid_vocab_extractor import run_paddle_ocr # noqa: F401 +except ImportError: + pass + +try: + from services.trocr_service import run_trocr_ocr # noqa: F401 +except ImportError: + pass + +try: + from services.donut_ocr_service import run_donut_ocr # noqa: F401 +except ImportError: + pass + +try: + from vision_ocr_service import get_vision_ocr_service, VisionOCRService # noqa: F401 +except ImportError: + pass + +# Routes (router is the main export for app.include_router) +from .routes import router # noqa: F401 +from .upload_routes import router as upload_router # noqa: F401 diff --git a/klausur-service/backend/ocr/labeling/helpers.py b/klausur-service/backend/ocr/labeling/helpers.py new file mode 100644 index 0000000..a4af3f1 --- /dev/null +++ b/klausur-service/backend/ocr/labeling/helpers.py @@ -0,0 +1,205 @@ +""" +OCR Labeling - Helper Functions and OCR Wrappers + +Extracted from ocr_labeling_api.py to keep files under 500 LOC. 
+ +DATENSCHUTZ/PRIVACY: +- Alle Verarbeitung erfolgt lokal (Mac Mini mit Ollama) +- Keine Daten werden an externe Server gesendet +""" + +import os +import hashlib + +from .models import LOCAL_STORAGE_PATH + +# Try to import Vision OCR service +try: + import sys + sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend', 'klausur', 'services')) + from vision_ocr_service import get_vision_ocr_service, VisionOCRService + VISION_OCR_AVAILABLE = True +except ImportError: + VISION_OCR_AVAILABLE = False + print("Warning: Vision OCR service not available") + +# Try to import PaddleOCR from hybrid_vocab_extractor +try: + from hybrid_vocab_extractor import run_paddle_ocr + PADDLEOCR_AVAILABLE = True +except ImportError: + PADDLEOCR_AVAILABLE = False + print("Warning: PaddleOCR not available") + +# Try to import TrOCR service +try: + from services.trocr_service import run_trocr_ocr + TROCR_AVAILABLE = True +except ImportError: + TROCR_AVAILABLE = False + print("Warning: TrOCR service not available") + +# Try to import Donut service +try: + from services.donut_ocr_service import run_donut_ocr + DONUT_AVAILABLE = True +except ImportError: + DONUT_AVAILABLE = False + print("Warning: Donut OCR service not available") + +# Try to import MinIO storage +try: + from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET + MINIO_AVAILABLE = True +except ImportError: + MINIO_AVAILABLE = False + print("Warning: MinIO storage not available, using local storage") + +# Try to import Training Export Service +try: + from training_export_service import ( + TrainingExportService, + TrainingSample, + get_training_export_service, + ) + TRAINING_EXPORT_AVAILABLE = True +except ImportError: + TRAINING_EXPORT_AVAILABLE = False + print("Warning: Training export service not available") + + +# ============================================================================= +# Helper Functions +# 
============================================================================= + +def compute_image_hash(image_data: bytes) -> str: + """Compute SHA256 hash of image data.""" + return hashlib.sha256(image_data).hexdigest() + + +async def run_ocr_on_image(image_data: bytes, filename: str, model: str = "llama3.2-vision:11b") -> tuple: + """ + Run OCR on an image using the specified model. + + Models: + - llama3.2-vision:11b: Vision LLM (default, best for handwriting) + - trocr: Microsoft TrOCR (fast for printed text) + - paddleocr: PaddleOCR + LLM hybrid (4x faster) + - donut: Document Understanding Transformer (structured documents) + + Returns: + Tuple of (ocr_text, confidence) + """ + print(f"Running OCR with model: {model}") + + # Route to appropriate OCR service based on model + if model == "paddleocr": + return await run_paddleocr_wrapper(image_data, filename) + elif model == "donut": + return await run_donut_wrapper(image_data, filename) + elif model == "trocr": + return await run_trocr_wrapper(image_data, filename) + else: + # Default: Vision LLM (llama3.2-vision or similar) + return await run_vision_ocr_wrapper(image_data, filename) + + +async def run_vision_ocr_wrapper(image_data: bytes, filename: str) -> tuple: + """Vision LLM OCR wrapper.""" + if not VISION_OCR_AVAILABLE: + print("Vision OCR service not available") + return None, 0.0 + + try: + service = get_vision_ocr_service() + if not await service.is_available(): + print("Vision OCR service not available (is_available check failed)") + return None, 0.0 + + result = await service.extract_text( + image_data, + filename=filename, + is_handwriting=True + ) + return result.text, result.confidence + except Exception as e: + print(f"Vision OCR failed: {e}") + return None, 0.0 + + +async def run_paddleocr_wrapper(image_data: bytes, filename: str) -> tuple: + """PaddleOCR wrapper - uses hybrid_vocab_extractor.""" + if not PADDLEOCR_AVAILABLE: + print("PaddleOCR not available, falling back to Vision OCR") + 
return await run_vision_ocr_wrapper(image_data, filename) + + try: + # run_paddle_ocr returns (regions, raw_text) + regions, raw_text = run_paddle_ocr(image_data) + + if not raw_text: + print("PaddleOCR returned empty text") + return None, 0.0 + + # Calculate average confidence from regions + if regions: + avg_confidence = sum(r.confidence for r in regions) / len(regions) + else: + avg_confidence = 0.5 + + return raw_text, avg_confidence + except Exception as e: + print(f"PaddleOCR failed: {e}, falling back to Vision OCR") + return await run_vision_ocr_wrapper(image_data, filename) + + +async def run_trocr_wrapper(image_data: bytes, filename: str) -> tuple: + """TrOCR wrapper.""" + if not TROCR_AVAILABLE: + print("TrOCR not available, falling back to Vision OCR") + return await run_vision_ocr_wrapper(image_data, filename) + + try: + text, confidence = await run_trocr_ocr(image_data) + return text, confidence + except Exception as e: + print(f"TrOCR failed: {e}, falling back to Vision OCR") + return await run_vision_ocr_wrapper(image_data, filename) + + +async def run_donut_wrapper(image_data: bytes, filename: str) -> tuple: + """Donut OCR wrapper.""" + if not DONUT_AVAILABLE: + print("Donut not available, falling back to Vision OCR") + return await run_vision_ocr_wrapper(image_data, filename) + + try: + text, confidence = await run_donut_ocr(image_data) + return text, confidence + except Exception as e: + print(f"Donut OCR failed: {e}, falling back to Vision OCR") + return await run_vision_ocr_wrapper(image_data, filename) + + +def save_image_locally(session_id: str, item_id: str, image_data: bytes, extension: str = "png") -> str: + """Save image to local storage.""" + session_dir = os.path.join(LOCAL_STORAGE_PATH, session_id) + os.makedirs(session_dir, exist_ok=True) + + filename = f"{item_id}.{extension}" + filepath = os.path.join(session_dir, filename) + + with open(filepath, 'wb') as f: + f.write(image_data) + + return filepath + + +def 
get_image_url(image_path: str) -> str: + """Get URL for an image.""" + # For local images, return a relative path that the frontend can use + if image_path.startswith(LOCAL_STORAGE_PATH): + relative_path = image_path[len(LOCAL_STORAGE_PATH):].lstrip('/') + return f"/api/v1/ocr-label/images/{relative_path}" + # For MinIO images, the path is already a URL or key + return image_path diff --git a/klausur-service/backend/ocr/labeling/models.py b/klausur-service/backend/ocr/labeling/models.py new file mode 100644 index 0000000..f27601f --- /dev/null +++ b/klausur-service/backend/ocr/labeling/models.py @@ -0,0 +1,86 @@ +""" +OCR Labeling - Pydantic Models and Constants + +Extracted from ocr_labeling_api.py to keep files under 500 LOC. +""" + +import os +from pydantic import BaseModel +from typing import Optional, Dict +from datetime import datetime + + +# Local storage path (fallback if MinIO not available) +LOCAL_STORAGE_PATH = os.getenv("OCR_STORAGE_PATH", "/app/ocr-labeling") + + +# ============================================================================= +# Pydantic Models +# ============================================================================= + +class SessionCreate(BaseModel): + name: str + source_type: str = "klausur" # klausur, handwriting_sample, scan + description: Optional[str] = None + ocr_model: Optional[str] = "llama3.2-vision:11b" + + +class SessionResponse(BaseModel): + id: str + name: str + source_type: str + description: Optional[str] + ocr_model: Optional[str] + total_items: int + labeled_items: int + confirmed_items: int + corrected_items: int + skipped_items: int + created_at: datetime + + +class ItemResponse(BaseModel): + id: str + session_id: str + session_name: str + image_path: str + image_url: Optional[str] + ocr_text: Optional[str] + ocr_confidence: Optional[float] + ground_truth: Optional[str] + status: str + metadata: Optional[Dict] + created_at: datetime + + +class ConfirmRequest(BaseModel): + item_id: str + label_time_seconds: 
Optional[int] = None + + +class CorrectRequest(BaseModel): + item_id: str + ground_truth: str + label_time_seconds: Optional[int] = None + + +class SkipRequest(BaseModel): + item_id: str + + +class ExportRequest(BaseModel): + export_format: str = "generic" # generic, trocr, llama_vision + session_id: Optional[str] = None + batch_id: Optional[str] = None + + +class StatsResponse(BaseModel): + total_sessions: Optional[int] = None + total_items: int + labeled_items: int + confirmed_items: int + corrected_items: int + pending_items: int + exportable_items: Optional[int] = None + accuracy_rate: float + avg_label_time_seconds: Optional[float] = None diff --git a/klausur-service/backend/ocr/labeling/routes.py b/klausur-service/backend/ocr/labeling/routes.py new file mode 100644 index 0000000..8674353 --- /dev/null +++ b/klausur-service/backend/ocr/labeling/routes.py @@ -0,0 +1,241 @@ +""" +OCR Labeling - Session and Labeling Route Handlers + +Extracted from ocr_labeling_api.py to keep files under 500 LOC. 
+ +Endpoints: +- POST /sessions - Create labeling session +- GET /sessions - List sessions +- GET /sessions/{id} - Get session +- GET /queue - Get labeling queue +- GET /items/{id} - Get item +- POST /confirm - Confirm OCR +- POST /correct - Correct ground truth +- POST /skip - Skip item +- GET /stats - Get statistics +""" + +from fastapi import APIRouter, HTTPException, Query +from typing import Optional, List +from datetime import datetime +import uuid + +from metrics_db import ( + create_ocr_labeling_session, + get_ocr_labeling_sessions, + get_ocr_labeling_session, + get_ocr_labeling_queue, + get_ocr_labeling_item, + confirm_ocr_label, + correct_ocr_label, + skip_ocr_item, + get_ocr_labeling_stats, +) + +from .models import ( + SessionCreate, SessionResponse, ItemResponse, + ConfirmRequest, CorrectRequest, SkipRequest, +) +from .helpers import get_image_url + + +router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"]) + + +# ============================================================================= +# Session Endpoints +# ============================================================================= + +@router.post("/sessions", response_model=SessionResponse) +async def create_session(session: SessionCreate): + """Create a new OCR labeling session.""" + session_id = str(uuid.uuid4()) + + success = await create_ocr_labeling_session( + session_id=session_id, + name=session.name, + source_type=session.source_type, + description=session.description, + ocr_model=session.ocr_model, + ) + + if not success: + raise HTTPException(status_code=500, detail="Failed to create session") + + return SessionResponse( + id=session_id, + name=session.name, + source_type=session.source_type, + description=session.description, + ocr_model=session.ocr_model, + total_items=0, + labeled_items=0, + confirmed_items=0, + corrected_items=0, + skipped_items=0, + created_at=datetime.utcnow(), + ) + + +@router.get("/sessions", response_model=List[SessionResponse]) +async def 
list_sessions(limit: int = Query(50, ge=1, le=100)): + """List all OCR labeling sessions.""" + sessions = await get_ocr_labeling_sessions(limit=limit) + + return [ + SessionResponse( + id=s['id'], + name=s['name'], + source_type=s['source_type'], + description=s.get('description'), + ocr_model=s.get('ocr_model'), + total_items=s.get('total_items', 0), + labeled_items=s.get('labeled_items', 0), + confirmed_items=s.get('confirmed_items', 0), + corrected_items=s.get('corrected_items', 0), + skipped_items=s.get('skipped_items', 0), + created_at=s.get('created_at', datetime.utcnow()), + ) + for s in sessions + ] + + +@router.get("/sessions/{session_id}", response_model=SessionResponse) +async def get_session(session_id: str): + """Get a specific OCR labeling session.""" + session = await get_ocr_labeling_session(session_id) + + if not session: + raise HTTPException(status_code=404, detail="Session not found") + + return SessionResponse( + id=session['id'], + name=session['name'], + source_type=session['source_type'], + description=session.get('description'), + ocr_model=session.get('ocr_model'), + total_items=session.get('total_items', 0), + labeled_items=session.get('labeled_items', 0), + confirmed_items=session.get('confirmed_items', 0), + corrected_items=session.get('corrected_items', 0), + skipped_items=session.get('skipped_items', 0), + created_at=session.get('created_at', datetime.utcnow()), + ) + + +# ============================================================================= +# Queue and Item Endpoints +# ============================================================================= + +@router.get("/queue", response_model=List[ItemResponse]) +async def get_labeling_queue( + session_id: Optional[str] = Query(None), + status: str = Query("pending"), + limit: int = Query(10, ge=1, le=50), +): + """Get items from the labeling queue.""" + items = await get_ocr_labeling_queue( + session_id=session_id, + status=status, + limit=limit, + ) + + return [ + ItemResponse( + 
id=item['id'], + session_id=item['session_id'], + session_name=item.get('session_name', ''), + image_path=item['image_path'], + image_url=get_image_url(item['image_path']), + ocr_text=item.get('ocr_text'), + ocr_confidence=item.get('ocr_confidence'), + ground_truth=item.get('ground_truth'), + status=item.get('status', 'pending'), + metadata=item.get('metadata'), + created_at=item.get('created_at', datetime.utcnow()), + ) + for item in items + ] + + +@router.get("/items/{item_id}", response_model=ItemResponse) +async def get_item(item_id: str): + """Get a specific labeling item.""" + item = await get_ocr_labeling_item(item_id) + + if not item: + raise HTTPException(status_code=404, detail="Item not found") + + return ItemResponse( + id=item['id'], + session_id=item['session_id'], + session_name=item.get('session_name', ''), + image_path=item['image_path'], + image_url=get_image_url(item['image_path']), + ocr_text=item.get('ocr_text'), + ocr_confidence=item.get('ocr_confidence'), + ground_truth=item.get('ground_truth'), + status=item.get('status', 'pending'), + metadata=item.get('metadata'), + created_at=item.get('created_at', datetime.utcnow()), + ) + + +# ============================================================================= +# Labeling Action Endpoints +# ============================================================================= + +@router.post("/confirm") +async def confirm_item(request: ConfirmRequest): + """Confirm that OCR text is correct.""" + success = await confirm_ocr_label( + item_id=request.item_id, + labeled_by="admin", + label_time_seconds=request.label_time_seconds, + ) + + if not success: + raise HTTPException(status_code=400, detail="Failed to confirm item") + + return {"status": "confirmed", "item_id": request.item_id} + + +@router.post("/correct") +async def correct_item(request: CorrectRequest): + """Save corrected ground truth for an item.""" + success = await correct_ocr_label( + item_id=request.item_id, + 
ground_truth=request.ground_truth, + labeled_by="admin", + label_time_seconds=request.label_time_seconds, + ) + + if not success: + raise HTTPException(status_code=400, detail="Failed to correct item") + + return {"status": "corrected", "item_id": request.item_id} + + +@router.post("/skip") +async def skip_item(request: SkipRequest): + """Skip an item (unusable image, etc.).""" + success = await skip_ocr_item( + item_id=request.item_id, + labeled_by="admin", + ) + + if not success: + raise HTTPException(status_code=400, detail="Failed to skip item") + + return {"status": "skipped", "item_id": request.item_id} + + +@router.get("/stats") +async def get_stats(session_id: Optional[str] = Query(None)): + """Get labeling statistics.""" + stats = await get_ocr_labeling_stats(session_id=session_id) + + if "error" in stats: + raise HTTPException(status_code=500, detail=stats["error"]) + + return stats diff --git a/klausur-service/backend/ocr/labeling/upload_routes.py b/klausur-service/backend/ocr/labeling/upload_routes.py new file mode 100644 index 0000000..bad8bc5 --- /dev/null +++ b/klausur-service/backend/ocr/labeling/upload_routes.py @@ -0,0 +1,313 @@ +""" +OCR Labeling - Upload, Run-OCR, and Export Route Handlers + +Extracted from ocr_labeling_routes.py to keep files under 500 LOC. 
+ +Endpoints: +- POST /sessions/{id}/upload - Upload images for labeling +- POST /run-ocr/{item_id} - Run OCR on existing item +- POST /export - Export training data +- GET /training-samples - List training samples +- GET /images/{path} - Serve images from local storage +- GET /exports - List exports +""" + +from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query +from typing import Optional, List +import uuid +import os + +from metrics_db import ( + get_ocr_labeling_session, + add_ocr_labeling_item, + get_ocr_labeling_item, + export_training_samples, + get_training_samples, +) + +from .models import ( + ExportRequest, + LOCAL_STORAGE_PATH, +) +from .helpers import ( + compute_image_hash, run_ocr_on_image, + save_image_locally, + MINIO_AVAILABLE, TRAINING_EXPORT_AVAILABLE, +) + +# Conditional imports +try: + from minio_storage import upload_ocr_image, get_ocr_image +except ImportError: + pass + +try: + from training_export_service import TrainingSample, get_training_export_service +except ImportError: + pass + + +router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"]) + + +@router.post("/sessions/{session_id}/upload") +async def upload_images( + session_id: str, + files: List[UploadFile] = File(...), + run_ocr: bool = Form(True), + metadata: Optional[str] = Form(None), +): + """ + Upload images to a labeling session. + + Args: + session_id: Session to add images to + files: Image files to upload (PNG, JPG, PDF) + run_ocr: Whether to run OCR immediately (default: True) + metadata: Optional JSON metadata (subject, year, etc.) 
+ """ + import json + + session = await get_ocr_labeling_session(session_id) + if not session: + raise HTTPException(status_code=404, detail="Session not found") + + meta_dict = None + if metadata: + try: + meta_dict = json.loads(metadata) + except json.JSONDecodeError: + meta_dict = {"raw": metadata} + + results = [] + ocr_model = session.get('ocr_model', 'llama3.2-vision:11b') + + for file in files: + content = await file.read() + image_hash = compute_image_hash(content) + item_id = str(uuid.uuid4()) + + extension = file.filename.split('.')[-1].lower() if file.filename else 'png' + if extension not in ['png', 'jpg', 'jpeg', 'pdf']: + extension = 'png' + + if MINIO_AVAILABLE: + try: + image_path = upload_ocr_image(session_id, item_id, content, extension) + except Exception as e: + print(f"MinIO upload failed, using local storage: {e}") + image_path = save_image_locally(session_id, item_id, content, extension) + else: + image_path = save_image_locally(session_id, item_id, content, extension) + + ocr_text = None + ocr_confidence = None + + if run_ocr and extension != 'pdf': + ocr_text, ocr_confidence = await run_ocr_on_image( + content, + file.filename or f"{item_id}.{extension}", + model=ocr_model + ) + + success = await add_ocr_labeling_item( + item_id=item_id, + session_id=session_id, + image_path=image_path, + image_hash=image_hash, + ocr_text=ocr_text, + ocr_confidence=ocr_confidence, + ocr_model=ocr_model if ocr_text else None, + metadata=meta_dict, + ) + + if success: + results.append({ + "id": item_id, + "filename": file.filename, + "image_path": image_path, + "image_hash": image_hash, + "ocr_text": ocr_text, + "ocr_confidence": ocr_confidence, + "status": "pending", + }) + + return { + "session_id": session_id, + "uploaded_count": len(results), + "items": results, + } + + +@router.post("/export") +async def export_data(request: ExportRequest): + """Export labeled data for training.""" + db_samples = await export_training_samples( + 
export_format=request.export_format, + session_id=request.session_id, + batch_id=request.batch_id, + exported_by="admin", + ) + + if not db_samples: + return { + "export_format": request.export_format, + "batch_id": request.batch_id, + "exported_count": 0, + "samples": [], + "message": "No labeled samples found to export", + } + + export_result = None + if TRAINING_EXPORT_AVAILABLE: + try: + export_service = get_training_export_service() + + training_samples = [] + for s in db_samples: + training_samples.append(TrainingSample( + id=s.get('id', s.get('item_id', '')), + image_path=s.get('image_path', ''), + ground_truth=s.get('ground_truth', ''), + ocr_text=s.get('ocr_text'), + ocr_confidence=s.get('ocr_confidence'), + metadata=s.get('metadata'), + )) + + export_result = export_service.export( + samples=training_samples, + export_format=request.export_format, + batch_id=request.batch_id, + ) + except Exception as e: + print(f"Training export failed: {e}") + + response = { + "export_format": request.export_format, + "batch_id": request.batch_id or (export_result.batch_id if export_result else None), + "exported_count": len(db_samples), + "samples": db_samples, + } + + if export_result: + response["export_path"] = export_result.export_path + response["manifest_path"] = export_result.manifest_path + + return response + + +@router.get("/training-samples") +async def list_training_samples( + export_format: Optional[str] = Query(None), + batch_id: Optional[str] = Query(None), + limit: int = Query(100, ge=1, le=1000), +): + """Get exported training samples.""" + samples = await get_training_samples( + export_format=export_format, + batch_id=batch_id, + limit=limit, + ) + + return { + "count": len(samples), + "samples": samples, + } + + +@router.get("/images/{path:path}") +async def get_image(path: str): + """Serve an image from local storage.""" + from fastapi.responses import FileResponse + + filepath = os.path.join(LOCAL_STORAGE_PATH, path) + + if not 
os.path.exists(filepath): + raise HTTPException(status_code=404, detail="Image not found") + + extension = filepath.split('.')[-1].lower() + content_type = { + 'png': 'image/png', + 'jpg': 'image/jpeg', + 'jpeg': 'image/jpeg', + 'pdf': 'application/pdf', + }.get(extension, 'application/octet-stream') + + return FileResponse(filepath, media_type=content_type) + + +@router.post("/run-ocr/{item_id}") +async def run_ocr_for_item(item_id: str): + """Run OCR on an existing item.""" + item = await get_ocr_labeling_item(item_id) + + if not item: + raise HTTPException(status_code=404, detail="Item not found") + + image_path = item['image_path'] + + if image_path.startswith(LOCAL_STORAGE_PATH): + if not os.path.exists(image_path): + raise HTTPException(status_code=404, detail="Image file not found") + with open(image_path, 'rb') as f: + image_data = f.read() + elif MINIO_AVAILABLE: + try: + image_data = get_ocr_image(image_path) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to load image: {e}") + else: + raise HTTPException(status_code=500, detail="Cannot load image") + + session = await get_ocr_labeling_session(item['session_id']) + ocr_model = session.get('ocr_model', 'llama3.2-vision:11b') if session else 'llama3.2-vision:11b' + + ocr_text, ocr_confidence = await run_ocr_on_image( + image_data, + os.path.basename(image_path), + model=ocr_model + ) + + if ocr_text is None: + raise HTTPException(status_code=500, detail="OCR failed") + + from metrics_db import get_pool + pool = await get_pool() + if pool: + async with pool.acquire() as conn: + await conn.execute( + """ + UPDATE ocr_labeling_items + SET ocr_text = $2, ocr_confidence = $3, ocr_model = $4 + WHERE id = $1 + """, + item_id, ocr_text, ocr_confidence, ocr_model + ) + + return { + "item_id": item_id, + "ocr_text": ocr_text, + "ocr_confidence": ocr_confidence, + "ocr_model": ocr_model, + } + + +@router.get("/exports") +async def list_exports(export_format: Optional[str] = 
Query(None)): + """List all available training data exports.""" + if not TRAINING_EXPORT_AVAILABLE: + return { + "exports": [], + "message": "Training export service not available", + } + + try: + export_service = get_training_export_service() + exports = export_service.list_exports(export_format=export_format) + + return { + "count": len(exports), + "exports": exports, + } + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to list exports: {e}") diff --git a/klausur-service/backend/ocr/pipeline/__init__.py b/klausur-service/backend/ocr/pipeline/__init__.py new file mode 100644 index 0000000..024e886 --- /dev/null +++ b/klausur-service/backend/ocr/pipeline/__init__.py @@ -0,0 +1,8 @@ +""" +OCR Pipeline sub-package — API endpoints, session management, overlays, +geometry steps, word detection, regression testing, and related utilities. + +Moved from backend/ flat modules (ocr_pipeline_*.py, page_crop*.py, +orientation_*.py, crop_api.py, etc.). +Backward-compatible shim files remain at the old locations. +""" diff --git a/klausur-service/backend/ocr/pipeline/api.py b/klausur-service/backend/ocr/pipeline/api.py new file mode 100644 index 0000000..8300416 --- /dev/null +++ b/klausur-service/backend/ocr/pipeline/api.py @@ -0,0 +1,63 @@ +""" +OCR Pipeline API - Schrittweise Seitenrekonstruktion. + +Thin wrapper that assembles all sub-module routers into a single +composite router. Backward-compatible: main.py and tests can still +import ``router``, ``_cache``, and helper functions from here. 
+ +Sub-modules (each < 1 000 lines): + ocr_pipeline_common – shared state, cache, Pydantic models, helpers + ocr_pipeline_sessions – session CRUD, image serving, doc-type + ocr_pipeline_geometry – deskew, dewarp, structure, columns + ocr_pipeline_rows – row detection, box-overlay helper + ocr_pipeline_words – word detection (SSE), paddle-direct, word GT + ocr_pipeline_ocr_merge – paddle/tesseract merge helpers, kombi endpoints + ocr_pipeline_postprocess – LLM review, reconstruction, export, validation + ocr_pipeline_auto – auto-mode orchestrator, reprocess + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +from fastapi import APIRouter + +# --------------------------------------------------------------------------- +# Shared state (imported by main.py and orientation_crop_api.py) +# --------------------------------------------------------------------------- +from .common import ( # noqa: F401 – re-exported + _cache, + _BORDER_GHOST_CHARS, + _filter_border_ghost_words, +) + +# --------------------------------------------------------------------------- +# Sub-module routers +# --------------------------------------------------------------------------- +from .sessions import router as _sessions_router +from .geometry import router as _geometry_router +from .rows import router as _rows_router +from .words import router as _words_router +from .ocr_merge import ( + router as _ocr_merge_router, + # Re-export for test backward compatibility + _split_paddle_multi_words, # noqa: F401 + _group_words_into_rows, # noqa: F401 + _merge_row_sequences, # noqa: F401 + _merge_paddle_tesseract, # noqa: F401 +) +from .postprocess import router as _postprocess_router +from .auto import router as _auto_router +from .regression import router as _regression_router + +# --------------------------------------------------------------------------- +# Composite router (used by main.py) +# --------------------------------------------------------------------------- 
+router = APIRouter() +router.include_router(_sessions_router) +router.include_router(_geometry_router) +router.include_router(_rows_router) +router.include_router(_words_router) +router.include_router(_ocr_merge_router) +router.include_router(_postprocess_router) +router.include_router(_auto_router) +router.include_router(_regression_router) diff --git a/klausur-service/backend/ocr/pipeline/auto.py b/klausur-service/backend/ocr/pipeline/auto.py new file mode 100644 index 0000000..7b28616 --- /dev/null +++ b/klausur-service/backend/ocr/pipeline/auto.py @@ -0,0 +1,23 @@ +""" +OCR Pipeline Auto-Mode Orchestrator and Reprocess Endpoints — Barrel Re-export. + +Split into submodules: +- ocr_pipeline_reprocess.py — POST /sessions/{id}/reprocess +- ocr_pipeline_auto_steps.py — POST /sessions/{id}/run-auto + VLM helper + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +from fastapi import APIRouter + +from .reprocess import router as _reprocess_router +from .auto_steps import router as _steps_router + +# Combine both sub-routers into a single router for backwards compatibility. +# The consumer imports `from ocr_pipeline_auto import router as _auto_router`. +router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) +router.include_router(_reprocess_router) +router.include_router(_steps_router) + +__all__ = ["router"] diff --git a/klausur-service/backend/ocr/pipeline/auto_helpers.py b/klausur-service/backend/ocr/pipeline/auto_helpers.py new file mode 100644 index 0000000..05df86d --- /dev/null +++ b/klausur-service/backend/ocr/pipeline/auto_helpers.py @@ -0,0 +1,84 @@ +""" +OCR Pipeline Auto-Mode Helpers. + +VLM shear detection, SSE event formatting, and request models. + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
+""" + +import json +import logging +import os +import re +from typing import Any, Dict + +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + + +class RunAutoRequest(BaseModel): + from_step: int = 1 # 1=deskew, 2=dewarp, 3=columns, 4=rows, 5=words, 6=llm-review + ocr_engine: str = "auto" # "auto" | "rapid" | "tesseract" + pronunciation: str = "british" + skip_llm_review: bool = False + dewarp_method: str = "ensemble" # "ensemble" | "vlm" | "cv" + + +async def auto_sse_event(step: str, status: str, data: Dict[str, Any]) -> str: + """Format a single SSE event line.""" + payload = {"step": step, "status": status, **data} + return f"data: {json.dumps(payload)}\n\n" + + +async def detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]: + """Ask qwen2.5vl:32b to estimate the vertical shear angle of a scanned page. + + The VLM is shown the image and asked: are the column/table borders tilted? + If yes, by how many degrees? Returns a dict with shear_degrees and confidence. + Confidence is 0.0 if Ollama is unavailable or parsing fails. + """ + import httpx + import base64 + + ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434") + model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b") + + prompt = ( + "This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. " + "Are they perfectly vertical, or do they tilt slightly? " + "If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). " + "Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} " + "Use confidence 0.0-1.0 based on how clearly you can see the tilt. 
" + "If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}" + ) + + img_b64 = base64.b64encode(image_bytes).decode("utf-8") + payload = { + "model": model, + "prompt": prompt, + "images": [img_b64], + "stream": False, + } + + try: + async with httpx.AsyncClient(timeout=60.0) as client: + resp = await client.post(f"{ollama_base}/api/generate", json=payload) + resp.raise_for_status() + text = resp.json().get("response", "") + + # Parse JSON from response (may have surrounding text) + match = re.search(r'\{[^}]+\}', text) + if match: + data = json.loads(match.group(0)) + shear = float(data.get("shear_degrees", 0.0)) + conf = float(data.get("confidence", 0.0)) + # Clamp to reasonable range + shear = max(-3.0, min(3.0, shear)) + conf = max(0.0, min(1.0, conf)) + return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)} + except Exception as e: + logger.warning(f"VLM dewarp failed: {e}") + + return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0} diff --git a/klausur-service/backend/ocr/pipeline/auto_steps.py b/klausur-service/backend/ocr/pipeline/auto_steps.py new file mode 100644 index 0000000..86c7897 --- /dev/null +++ b/klausur-service/backend/ocr/pipeline/auto_steps.py @@ -0,0 +1,528 @@ +""" +OCR Pipeline Auto-Mode Orchestrator. + +POST /sessions/{session_id}/run-auto -- full auto-mode with SSE streaming. + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
+""" + +import logging +import time +from dataclasses import asdict +from typing import Any, Dict, List, Optional + +import cv2 +import numpy as np +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import StreamingResponse + +from cv_vocab_pipeline import ( + OLLAMA_REVIEW_MODEL, + PageRegion, + RowGeometry, + _cells_to_vocab_entries, + _detect_header_footer_gaps, + _detect_sub_columns, + _fix_character_confusion, + _fix_phonetic_brackets, + fix_cell_phonetics, + analyze_layout, + build_cell_grid, + classify_column_types, + create_layout_image, + create_ocr_image, + deskew_image, + deskew_image_by_word_alignment, + detect_column_geometry, + detect_row_geometry, + _apply_shear, + dewarp_image, + llm_review_entries, +) +from .common import ( + _cache, + _load_session_to_cache, + _get_cached, +) +from .session_store import ( + get_session_db, + update_session_db, +) +from .auto_helpers import ( + RunAutoRequest, + auto_sse_event as _auto_sse_event, + detect_shear_with_vlm as _detect_shear_with_vlm, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["ocr-pipeline"]) + +@router.post("/sessions/{session_id}/run-auto") +async def run_auto(session_id: str, req: RunAutoRequest, request: Request): + """Run the full OCR pipeline automatically from a given step, streaming SSE progress. + + Steps: + 1. Deskew -- straighten the scan + 2. Dewarp -- correct vertical shear (ensemble CV or VLM) + 3. Columns -- detect column layout + 4. Rows -- detect row layout + 5. Words -- OCR each cell + 6. LLM review -- correct OCR errors (optional) + + Already-completed steps are skipped unless `from_step` forces a rerun. 
+ Yields SSE events of the form: + data: {"step": "deskew", "status": "start"|"done"|"skipped"|"error", ...} + + Final event: + data: {"step": "complete", "status": "done", "steps_run": [...], "steps_skipped": [...]} + """ + if req.from_step < 1 or req.from_step > 6: + raise HTTPException(status_code=400, detail="from_step must be 1-6") + if req.dewarp_method not in ("ensemble", "vlm", "cv"): + raise HTTPException(status_code=400, detail="dewarp_method must be: ensemble, vlm, cv") + + if session_id not in _cache: + await _load_session_to_cache(session_id) + + async def _generate(): + steps_run: List[str] = [] + steps_skipped: List[str] = [] + error_step: Optional[str] = None + + session = await get_session_db(session_id) + if not session: + yield await _auto_sse_event("error", "error", {"message": f"Session {session_id} not found"}) + return + + cached = _get_cached(session_id) + + # Step 1: Deskew + if req.from_step <= 1: + yield await _auto_sse_event("deskew", "start", {}) + try: + t0 = time.time() + orig_bgr = cached.get("original_bgr") + if orig_bgr is None: + raise ValueError("Original image not loaded") + + try: + deskewed_hough, angle_hough = deskew_image(orig_bgr.copy()) + except Exception: + deskewed_hough, angle_hough = orig_bgr, 0.0 + + success_enc, png_orig = cv2.imencode(".png", orig_bgr) + orig_bytes = png_orig.tobytes() if success_enc else b"" + try: + deskewed_wa_bytes, angle_wa = deskew_image_by_word_alignment(orig_bytes) + except Exception: + deskewed_wa_bytes, angle_wa = orig_bytes, 0.0 + + if abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1: + method_used = "word_alignment" + angle_applied = angle_wa + wa_arr = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8) + deskewed_bgr = cv2.imdecode(wa_arr, cv2.IMREAD_COLOR) + if deskewed_bgr is None: + deskewed_bgr = deskewed_hough + method_used = "hough" + angle_applied = angle_hough + else: + method_used = "hough" + angle_applied = angle_hough + deskewed_bgr = deskewed_hough + + success, 
png_buf = cv2.imencode(".png", deskewed_bgr) + deskewed_png = png_buf.tobytes() if success else b"" + + deskew_result = { + "method_used": method_used, + "rotation_degrees": round(float(angle_applied), 3), + "duration_seconds": round(time.time() - t0, 2), + } + + cached["deskewed_bgr"] = deskewed_bgr + cached["deskew_result"] = deskew_result + await update_session_db( + session_id, + deskewed_png=deskewed_png, + deskew_result=deskew_result, + auto_rotation_degrees=float(angle_applied), + current_step=3, + ) + session = await get_session_db(session_id) + + steps_run.append("deskew") + yield await _auto_sse_event("deskew", "done", deskew_result) + except Exception as e: + logger.error(f"Auto-mode deskew failed for {session_id}: {e}") + error_step = "deskew" + yield await _auto_sse_event("deskew", "error", {"message": str(e)}) + yield await _auto_sse_event("complete", "error", {"error_step": error_step}) + return + else: + steps_skipped.append("deskew") + yield await _auto_sse_event("deskew", "skipped", {"reason": "from_step > 1"}) + + # Step 2: Dewarp + if req.from_step <= 2: + yield await _auto_sse_event("dewarp", "start", {"method": req.dewarp_method}) + try: + t0 = time.time() + deskewed_bgr = cached.get("deskewed_bgr") + if deskewed_bgr is None: + raise ValueError("Deskewed image not available") + + if req.dewarp_method == "vlm": + success_enc, png_buf = cv2.imencode(".png", deskewed_bgr) + img_bytes = png_buf.tobytes() if success_enc else b"" + vlm_det = await _detect_shear_with_vlm(img_bytes) + shear_deg = vlm_det["shear_degrees"] + if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3: + dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg) + else: + dewarped_bgr = deskewed_bgr + dewarp_info = { + "method": vlm_det["method"], + "shear_degrees": shear_deg, + "confidence": vlm_det["confidence"], + "detections": [vlm_det], + } + else: + dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr) + + success_enc, png_buf = cv2.imencode(".png", dewarped_bgr) + 
dewarped_png = png_buf.tobytes() if success_enc else b"" + + dewarp_result = { + "method_used": dewarp_info["method"], + "shear_degrees": dewarp_info["shear_degrees"], + "confidence": dewarp_info["confidence"], + "duration_seconds": round(time.time() - t0, 2), + "detections": dewarp_info.get("detections", []), + } + + cached["dewarped_bgr"] = dewarped_bgr + cached["dewarp_result"] = dewarp_result + await update_session_db( + session_id, + dewarped_png=dewarped_png, + dewarp_result=dewarp_result, + auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0), + current_step=4, + ) + session = await get_session_db(session_id) + + steps_run.append("dewarp") + yield await _auto_sse_event("dewarp", "done", dewarp_result) + except Exception as e: + logger.error(f"Auto-mode dewarp failed for {session_id}: {e}") + error_step = "dewarp" + yield await _auto_sse_event("dewarp", "error", {"message": str(e)}) + yield await _auto_sse_event("complete", "error", {"error_step": error_step}) + return + else: + steps_skipped.append("dewarp") + yield await _auto_sse_event("dewarp", "skipped", {"reason": "from_step > 2"}) + + # Step 3: Columns + if req.from_step <= 3: + yield await _auto_sse_event("columns", "start", {}) + try: + t0 = time.time() + col_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") + if col_img is None: + raise ValueError("Cropped/dewarped image not available") + + ocr_img = create_ocr_image(col_img) + h, w = ocr_img.shape[:2] + + geo_result = detect_column_geometry(ocr_img, col_img) + if geo_result is None: + layout_img = create_layout_image(col_img) + regions = analyze_layout(layout_img, ocr_img) + cached["_word_dicts"] = None + cached["_inv"] = None + cached["_content_bounds"] = None + else: + geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result + content_w = right_x - left_x + cached["_word_dicts"] = word_dicts + cached["_inv"] = inv + cached["_content_bounds"] = (left_x, right_x, top_y, 
bottom_y) + + header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None) + geometries = _detect_sub_columns(geometries, content_w, left_x=left_x, + top_y=top_y, header_y=header_y, footer_y=footer_y) + regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y, + left_x=left_x, right_x=right_x, inv=inv) + + columns = [asdict(r) for r in regions] + column_result = { + "columns": columns, + "classification_methods": list({c.get("classification_method", "") for c in columns if c.get("classification_method")}), + "duration_seconds": round(time.time() - t0, 2), + } + + cached["column_result"] = column_result + await update_session_db(session_id, column_result=column_result, + row_result=None, word_result=None, current_step=6) + session = await get_session_db(session_id) + + steps_run.append("columns") + yield await _auto_sse_event("columns", "done", { + "column_count": len(columns), + "duration_seconds": column_result["duration_seconds"], + }) + except Exception as e: + logger.error(f"Auto-mode columns failed for {session_id}: {e}") + error_step = "columns" + yield await _auto_sse_event("columns", "error", {"message": str(e)}) + yield await _auto_sse_event("complete", "error", {"error_step": error_step}) + return + else: + steps_skipped.append("columns") + yield await _auto_sse_event("columns", "skipped", {"reason": "from_step > 3"}) + + # Step 4: Rows + if req.from_step <= 4: + yield await _auto_sse_event("rows", "start", {}) + try: + t0 = time.time() + row_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") + session = await get_session_db(session_id) + column_result = session.get("column_result") or cached.get("column_result") + if not column_result or not column_result.get("columns"): + raise ValueError("Column detection must complete first") + + col_regions = [ + PageRegion( + type=c["type"], x=c["x"], y=c["y"], + width=c["width"], height=c["height"], + 
classification_confidence=c.get("classification_confidence", 1.0), + classification_method=c.get("classification_method", ""), + ) + for c in column_result["columns"] + ] + + word_dicts = cached.get("_word_dicts") + inv = cached.get("_inv") + content_bounds = cached.get("_content_bounds") + + if word_dicts is None or inv is None or content_bounds is None: + ocr_img_tmp = create_ocr_image(row_img) + geo_result = detect_column_geometry(ocr_img_tmp, row_img) + if geo_result is None: + raise ValueError("Column geometry detection failed -- cannot detect rows") + _g, lx, rx, ty, by, word_dicts, inv = geo_result + cached["_word_dicts"] = word_dicts + cached["_inv"] = inv + cached["_content_bounds"] = (lx, rx, ty, by) + content_bounds = (lx, rx, ty, by) + + left_x, right_x, top_y, bottom_y = content_bounds + row_geoms = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y) + + row_list = [ + { + "index": r.index, "x": r.x, "y": r.y, + "width": r.width, "height": r.height, + "word_count": r.word_count, + "row_type": r.row_type, + "gap_before": r.gap_before, + } + for r in row_geoms + ] + row_result = { + "rows": row_list, + "row_count": len(row_list), + "content_rows": len([r for r in row_geoms if r.row_type == "content"]), + "duration_seconds": round(time.time() - t0, 2), + } + + cached["row_result"] = row_result + await update_session_db(session_id, row_result=row_result, current_step=7) + session = await get_session_db(session_id) + + steps_run.append("rows") + yield await _auto_sse_event("rows", "done", { + "row_count": len(row_list), + "content_rows": row_result["content_rows"], + "duration_seconds": row_result["duration_seconds"], + }) + except Exception as e: + logger.error(f"Auto-mode rows failed for {session_id}: {e}") + error_step = "rows" + yield await _auto_sse_event("rows", "error", {"message": str(e)}) + yield await _auto_sse_event("complete", "error", {"error_step": error_step}) + return + else: + steps_skipped.append("rows") + yield await 
_auto_sse_event("rows", "skipped", {"reason": "from_step > 4"}) + + # Step 5: Words (OCR) + if req.from_step <= 5: + yield await _auto_sse_event("words", "start", {"engine": req.ocr_engine}) + try: + t0 = time.time() + word_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") + session = await get_session_db(session_id) + + column_result = session.get("column_result") or cached.get("column_result") + row_result = session.get("row_result") or cached.get("row_result") + + col_regions = [ + PageRegion( + type=c["type"], x=c["x"], y=c["y"], + width=c["width"], height=c["height"], + classification_confidence=c.get("classification_confidence", 1.0), + classification_method=c.get("classification_method", ""), + ) + for c in column_result["columns"] + ] + row_geoms = [ + RowGeometry( + index=r["index"], x=r["x"], y=r["y"], + width=r["width"], height=r["height"], + word_count=r.get("word_count", 0), words=[], + row_type=r.get("row_type", "content"), + gap_before=r.get("gap_before", 0), + ) + for r in row_result["rows"] + ] + + word_dicts = cached.get("_word_dicts") + if word_dicts is not None: + content_bounds = cached.get("_content_bounds") + top_y = content_bounds[2] if content_bounds else min(r.y for r in row_geoms) + for row in row_geoms: + row_y_rel = row.y - top_y + row_bottom_rel = row_y_rel + row.height + row.words = [ + w for w in word_dicts + if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel + ] + row.word_count = len(row.words) + + ocr_img = create_ocr_image(word_img) + img_h, img_w = word_img.shape[:2] + + cells, columns_meta = build_cell_grid( + ocr_img, col_regions, row_geoms, img_w, img_h, + ocr_engine=req.ocr_engine, img_bgr=word_img, + ) + duration = time.time() - t0 + + col_types = {c['type'] for c in columns_meta} + is_vocab = bool(col_types & {'column_en', 'column_de'}) + n_content_rows = len([r for r in row_geoms if r.row_type == 'content']) + used_engine = cells[0].get("ocr_engine", 
"tesseract") if cells else req.ocr_engine + + fix_cell_phonetics(cells, pronunciation=req.pronunciation) + + word_result_data = { + "cells": cells, + "grid_shape": { + "rows": n_content_rows, + "cols": len(columns_meta), + "total_cells": len(cells), + }, + "columns_used": columns_meta, + "layout": "vocab" if is_vocab else "generic", + "image_width": img_w, + "image_height": img_h, + "duration_seconds": round(duration, 2), + "ocr_engine": used_engine, + "summary": { + "total_cells": len(cells), + "non_empty_cells": sum(1 for c in cells if c.get("text")), + "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), + }, + } + + has_text_col = 'column_text' in col_types + if is_vocab or has_text_col: + entries = _cells_to_vocab_entries(cells, columns_meta) + entries = _fix_character_confusion(entries) + entries = _fix_phonetic_brackets(entries, pronunciation=req.pronunciation) + word_result_data["vocab_entries"] = entries + word_result_data["entries"] = entries + word_result_data["entry_count"] = len(entries) + word_result_data["summary"]["total_entries"] = len(entries) + + await update_session_db(session_id, word_result=word_result_data, current_step=8) + cached["word_result"] = word_result_data + session = await get_session_db(session_id) + + steps_run.append("words") + yield await _auto_sse_event("words", "done", { + "total_cells": len(cells), + "layout": word_result_data["layout"], + "duration_seconds": round(duration, 2), + "ocr_engine": used_engine, + "summary": word_result_data["summary"], + }) + except Exception as e: + logger.error(f"Auto-mode words failed for {session_id}: {e}") + error_step = "words" + yield await _auto_sse_event("words", "error", {"message": str(e)}) + yield await _auto_sse_event("complete", "error", {"error_step": error_step}) + return + else: + steps_skipped.append("words") + yield await _auto_sse_event("words", "skipped", {"reason": "from_step > 5"}) + + # Step 6: LLM Review (optional) + if req.from_step <= 6 and not 
req.skip_llm_review: + yield await _auto_sse_event("llm_review", "start", {"model": OLLAMA_REVIEW_MODEL}) + try: + session = await get_session_db(session_id) + word_result = session.get("word_result") or cached.get("word_result") + entries = word_result.get("entries") or word_result.get("vocab_entries") or [] + + if not entries: + yield await _auto_sse_event("llm_review", "skipped", {"reason": "no entries"}) + steps_skipped.append("llm_review") + else: + reviewed = await llm_review_entries(entries) + + session = await get_session_db(session_id) + word_result_updated = dict(session.get("word_result") or {}) + word_result_updated["entries"] = reviewed + word_result_updated["vocab_entries"] = reviewed + word_result_updated["llm_reviewed"] = True + word_result_updated["llm_model"] = OLLAMA_REVIEW_MODEL + + await update_session_db(session_id, word_result=word_result_updated, current_step=9) + cached["word_result"] = word_result_updated + + steps_run.append("llm_review") + yield await _auto_sse_event("llm_review", "done", { + "entries_reviewed": len(reviewed), + "model": OLLAMA_REVIEW_MODEL, + }) + except Exception as e: + logger.warning(f"Auto-mode llm_review failed for {session_id} (non-fatal): {e}") + yield await _auto_sse_event("llm_review", "error", {"message": str(e), "fatal": False}) + steps_skipped.append("llm_review") + else: + steps_skipped.append("llm_review") + reason = "skipped by request" if req.skip_llm_review else "from_step > 6" + yield await _auto_sse_event("llm_review", "skipped", {"reason": reason}) + + # Final event + yield await _auto_sse_event("complete", "done", { + "steps_run": steps_run, + "steps_skipped": steps_skipped, + }) + + return StreamingResponse( + _generate(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) diff --git a/klausur-service/backend/ocr/pipeline/columns.py b/klausur-service/backend/ocr/pipeline/columns.py new file mode 100644 
index 0000000..6fa2672 --- /dev/null +++ b/klausur-service/backend/ocr/pipeline/columns.py @@ -0,0 +1,293 @@ +""" +OCR Pipeline Column Detection Endpoints (Step 5) + +Detect invisible columns, manual column override, and ground truth. +Extracted from ocr_pipeline_geometry.py for file-size compliance. +""" + +import logging +import time +from dataclasses import asdict +from datetime import datetime +from typing import Dict, List + +import cv2 +from fastapi import APIRouter, HTTPException + +from cv_vocab_pipeline import ( + _detect_header_footer_gaps, + _detect_sub_columns, + classify_column_types, + create_layout_image, + create_ocr_image, + analyze_layout, + detect_column_geometry_zoned, + expand_narrow_columns, +) +from .session_store import ( + get_session_db, + update_session_db, +) +from .common import ( + _cache, + _load_session_to_cache, + _get_cached, + _append_pipeline_log, + ManualColumnsRequest, + ColumnGroundTruthRequest, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) + + +@router.post("/sessions/{session_id}/columns") +async def detect_columns(session_id: str): + """Run column detection on the cropped (or dewarped) image.""" + if session_id not in _cache: + await _load_session_to_cache(session_id) + cached = _get_cached(session_id) + + img_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") + if img_bgr is None: + raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before column detection") + + # ----------------------------------------------------------------------- + # Sub-sessions (box crops): skip column detection entirely. + # Instead, create a single pseudo-column spanning the full image width. 
+ # Also run Tesseract + binarization here so that the row detection step + # can reuse the cached intermediates (_word_dicts, _inv, _content_bounds) + # instead of falling back to detect_column_geometry() which may fail + # on small box images with < 5 words. + # ----------------------------------------------------------------------- + session = await get_session_db(session_id) + if session and session.get("parent_session_id"): + h, w = img_bgr.shape[:2] + + # Binarize + invert for row detection (horizontal projection profile) + ocr_img = create_ocr_image(img_bgr) + inv = cv2.bitwise_not(ocr_img) + + # Run Tesseract to get word bounding boxes. + try: + from PIL import Image as PILImage + pil_img = PILImage.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)) + import pytesseract + data = pytesseract.image_to_data(pil_img, lang='eng+deu', output_type=pytesseract.Output.DICT) + word_dicts = [] + for i in range(len(data['text'])): + conf = int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1 + text = str(data['text'][i]).strip() + if conf < 30 or not text: + continue + word_dicts.append({ + 'text': text, 'conf': conf, + 'left': int(data['left'][i]), + 'top': int(data['top'][i]), + 'width': int(data['width'][i]), + 'height': int(data['height'][i]), + }) + # Log all words including low-confidence ones for debugging + all_count = sum(1 for i in range(len(data['text'])) + if str(data['text'][i]).strip()) + low_conf = [(str(data['text'][i]).strip(), int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1) + for i in range(len(data['text'])) + if str(data['text'][i]).strip() + and (int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1) < 30 + and (int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1) >= 0] + if low_conf: + logger.info(f"OCR Pipeline: sub-session {session_id}: {len(low_conf)} words below conf 30: {low_conf[:20]}") + logger.info(f"OCR Pipeline: sub-session {session_id}: 
Tesseract found {len(word_dicts)}/{all_count} words (conf>=30)") + except Exception as e: + logger.warning(f"OCR Pipeline: sub-session {session_id}: Tesseract failed: {e}") + word_dicts = [] + + # Cache intermediates for row detection (detect_rows reuses these) + cached["_word_dicts"] = word_dicts + cached["_inv"] = inv + cached["_content_bounds"] = (0, w, 0, h) + + column_result = { + "columns": [{ + "type": "column_text", + "x": 0, "y": 0, + "width": w, "height": h, + }], + "zones": None, + "boxes_detected": 0, + "duration_seconds": 0, + "method": "sub_session_pseudo_column", + } + await update_session_db( + session_id, + column_result=column_result, + row_result=None, + word_result=None, + current_step=6, + ) + cached["column_result"] = column_result + cached.pop("row_result", None) + cached.pop("word_result", None) + logger.info(f"OCR Pipeline: sub-session {session_id}: pseudo-column {w}x{h}px") + return {"session_id": session_id, **column_result} + + t0 = time.time() + + # Binarized image for layout analysis + ocr_img = create_ocr_image(img_bgr) + h, w = ocr_img.shape[:2] + + # Phase A: Zone-aware geometry detection + zoned_result = detect_column_geometry_zoned(ocr_img, img_bgr) + + boxes_detected = 0 + if zoned_result is None: + # Fallback to projection-based layout + layout_img = create_layout_image(img_bgr) + regions = analyze_layout(layout_img, ocr_img) + zones_data = None + else: + geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv, zones_data, boxes = zoned_result + content_w = right_x - left_x + boxes_detected = len(boxes) + + # Cache intermediates for row detection (avoids second Tesseract run) + cached["_word_dicts"] = word_dicts + cached["_inv"] = inv + cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y) + cached["_zones_data"] = zones_data + cached["_boxes_detected"] = boxes_detected + + # Detect header/footer early so sub-column clustering ignores them + header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not 
None else (None, None) + + # Split sub-columns (e.g. page references) before classification + geometries = _detect_sub_columns(geometries, content_w, left_x=left_x, + top_y=top_y, header_y=header_y, footer_y=footer_y) + + # Expand narrow columns (sub-columns are often very narrow) + geometries = expand_narrow_columns(geometries, content_w, left_x, word_dicts) + + # Phase B: Content-based classification + regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y, + left_x=left_x, right_x=right_x, inv=inv) + + duration = time.time() - t0 + + columns = [asdict(r) for r in regions] + + # Determine classification methods used + methods = list(set( + c.get("classification_method", "") for c in columns + if c.get("classification_method") + )) + + column_result = { + "columns": columns, + "classification_methods": methods, + "duration_seconds": round(duration, 2), + "boxes_detected": boxes_detected, + } + + # Add zone data when boxes are present + if zones_data and boxes_detected > 0: + column_result["zones"] = zones_data + + # Persist to DB -- also invalidate downstream results (rows, words) + await update_session_db( + session_id, + column_result=column_result, + row_result=None, + word_result=None, + current_step=6, + ) + + # Update cache + cached["column_result"] = column_result + cached.pop("row_result", None) + cached.pop("word_result", None) + + col_count = len([c for c in columns if c["type"].startswith("column")]) + logger.info(f"OCR Pipeline: columns session {session_id}: " + f"{col_count} columns detected, {boxes_detected} box(es) ({duration:.2f}s)") + + img_w = img_bgr.shape[1] + await _append_pipeline_log(session_id, "columns", { + "total_columns": len(columns), + "column_widths_pct": [round(c["width"] / img_w * 100, 1) for c in columns], + "column_types": [c["type"] for c in columns], + "boxes_detected": boxes_detected, + }, duration_ms=int(duration * 1000)) + + return { + "session_id": session_id, + **column_result, + } + + 
+@router.post("/sessions/{session_id}/columns/manual") +async def set_manual_columns(session_id: str, req: ManualColumnsRequest): + """Override detected columns with manual definitions.""" + column_result = { + "columns": req.columns, + "duration_seconds": 0, + "method": "manual", + } + + await update_session_db(session_id, column_result=column_result, + row_result=None, word_result=None) + + if session_id in _cache: + _cache[session_id]["column_result"] = column_result + _cache[session_id].pop("row_result", None) + _cache[session_id].pop("word_result", None) + + logger.info(f"OCR Pipeline: manual columns session {session_id}: " + f"{len(req.columns)} columns set") + + return {"session_id": session_id, **column_result} + + +@router.post("/sessions/{session_id}/ground-truth/columns") +async def save_column_ground_truth(session_id: str, req: ColumnGroundTruthRequest): + """Save ground truth feedback for the column detection step.""" + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + ground_truth = session.get("ground_truth") or {} + gt = { + "is_correct": req.is_correct, + "corrected_columns": req.corrected_columns, + "notes": req.notes, + "saved_at": datetime.utcnow().isoformat(), + "column_result": session.get("column_result"), + } + ground_truth["columns"] = gt + + await update_session_db(session_id, ground_truth=ground_truth) + + if session_id in _cache: + _cache[session_id]["ground_truth"] = ground_truth + + return {"session_id": session_id, "ground_truth": gt} + + +@router.get("/sessions/{session_id}/ground-truth/columns") +async def get_column_ground_truth(session_id: str): + """Retrieve saved ground truth for column detection, including auto vs GT diff.""" + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + ground_truth = session.get("ground_truth") or {} + columns_gt = 
ground_truth.get("columns") + if not columns_gt: + raise HTTPException(status_code=404, detail="No column ground truth saved") + + return { + "session_id": session_id, + "columns_gt": columns_gt, + "columns_auto": session.get("column_result"), + } diff --git a/klausur-service/backend/ocr/pipeline/common.py b/klausur-service/backend/ocr/pipeline/common.py new file mode 100644 index 0000000..9aac36b --- /dev/null +++ b/klausur-service/backend/ocr/pipeline/common.py @@ -0,0 +1,354 @@ +""" +Shared common module for the OCR pipeline. + +Contains in-memory cache, helper functions, Pydantic request models, +pipeline logging, and border-ghost word filtering used by the pipeline +API endpoints and related modules. +""" + +import logging +import re +import time +from datetime import datetime +from typing import Any, Dict, List, Optional + +import cv2 +import numpy as np +from fastapi import HTTPException +from pydantic import BaseModel + +from .session_store import get_session_db, get_session_image, update_session_db + +__all__ = [ + # Cache + "_cache", + # Helper functions + "_get_base_image_png", + "_load_session_to_cache", + "_get_cached", + # Pydantic models + "ManualDeskewRequest", + "DeskewGroundTruthRequest", + "ManualDewarpRequest", + "CombinedAdjustRequest", + "DewarpGroundTruthRequest", + "VALID_DOCUMENT_CATEGORIES", + "UpdateSessionRequest", + "ManualColumnsRequest", + "ColumnGroundTruthRequest", + "ManualRowsRequest", + "RowGroundTruthRequest", + "RemoveHandwritingRequest", + # Pipeline log + "_append_pipeline_log", + # Border-ghost filter + "_BORDER_GHOST_CHARS", + "_filter_border_ghost_words", +] + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# In-memory cache for active sessions (BGR numpy arrays for processing) +# DB is source of truth, cache holds BGR arrays during active processing. 
+# --------------------------------------------------------------------------- + +_cache: Dict[str, Dict[str, Any]] = {} + + +async def _get_base_image_png(session_id: str) -> Optional[bytes]: + """Get the best available base image for a session (cropped > dewarped > original).""" + for img_type in ("cropped", "dewarped", "original"): + png_data = await get_session_image(session_id, img_type) + if png_data: + return png_data + return None + + +async def _load_session_to_cache(session_id: str) -> Dict[str, Any]: + """Load session from DB into cache, decoding PNGs to BGR arrays.""" + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + if session_id in _cache: + return _cache[session_id] + + cache_entry: Dict[str, Any] = { + "id": session_id, + **session, + "original_bgr": None, + "oriented_bgr": None, + "cropped_bgr": None, + "deskewed_bgr": None, + "dewarped_bgr": None, + } + + # Decode images from DB into BGR numpy arrays + for img_type, bgr_key in [ + ("original", "original_bgr"), + ("oriented", "oriented_bgr"), + ("cropped", "cropped_bgr"), + ("deskewed", "deskewed_bgr"), + ("dewarped", "dewarped_bgr"), + ]: + png_data = await get_session_image(session_id, img_type) + if png_data: + arr = np.frombuffer(png_data, dtype=np.uint8) + bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR) + cache_entry[bgr_key] = bgr + + # Sub-sessions: original image IS the cropped box region. + # Promote original_bgr to cropped_bgr so downstream steps find it. 
+ if session.get("parent_session_id") and cache_entry["original_bgr"] is not None: + if cache_entry["cropped_bgr"] is None and cache_entry["dewarped_bgr"] is None: + cache_entry["cropped_bgr"] = cache_entry["original_bgr"] + + _cache[session_id] = cache_entry + return cache_entry + + +def _get_cached(session_id: str) -> Dict[str, Any]: + """Get from cache or raise 404.""" + entry = _cache.get(session_id) + if not entry: + raise HTTPException(status_code=404, detail=f"Session {session_id} not in cache — reload first") + return entry + + +# --------------------------------------------------------------------------- +# Pydantic Models +# --------------------------------------------------------------------------- + +class ManualDeskewRequest(BaseModel): + angle: float + + +class DeskewGroundTruthRequest(BaseModel): + is_correct: bool + corrected_angle: Optional[float] = None + notes: Optional[str] = None + + +class ManualDewarpRequest(BaseModel): + shear_degrees: float + + +class CombinedAdjustRequest(BaseModel): + rotation_degrees: float = 0.0 + shear_degrees: float = 0.0 + + +class DewarpGroundTruthRequest(BaseModel): + is_correct: bool + corrected_shear: Optional[float] = None + notes: Optional[str] = None + + +VALID_DOCUMENT_CATEGORIES = { + 'vokabelseite', 'woerterbuch', 'buchseite', 'arbeitsblatt', 'klausurseite', + 'mathearbeit', 'statistik', 'zeitung', 'formular', 'handschrift', 'sonstiges', +} + + +class UpdateSessionRequest(BaseModel): + name: Optional[str] = None + document_category: Optional[str] = None + + +class ManualColumnsRequest(BaseModel): + columns: List[Dict[str, Any]] + + +class ColumnGroundTruthRequest(BaseModel): + is_correct: bool + corrected_columns: Optional[List[Dict[str, Any]]] = None + notes: Optional[str] = None + + +class ManualRowsRequest(BaseModel): + rows: List[Dict[str, Any]] + + +class RowGroundTruthRequest(BaseModel): + is_correct: bool + corrected_rows: Optional[List[Dict[str, Any]]] = None + notes: Optional[str] = None + + +class 
RemoveHandwritingRequest(BaseModel): + method: str = "auto" # "auto" | "telea" | "ns" + target_ink: str = "all" # "all" | "colored" | "pencil" + dilation: int = 2 # mask dilation iterations (0-5) + use_source: str = "auto" # "original" | "deskewed" | "auto" + + +# --------------------------------------------------------------------------- +# Pipeline Log Helper +# --------------------------------------------------------------------------- + +async def _append_pipeline_log( + session_id: str, + step_name: str, + metrics: Dict[str, Any], + success: bool = True, + duration_ms: Optional[int] = None, +): + """Append a step entry to the session's pipeline_log JSONB.""" + session = await get_session_db(session_id) + if not session: + return + log = session.get("pipeline_log") or {"steps": []} + if not isinstance(log, dict): + log = {"steps": []} + entry = { + "step": step_name, + "completed_at": datetime.utcnow().isoformat(), + "success": success, + "metrics": metrics, + } + if duration_ms is not None: + entry["duration_ms"] = duration_ms + log.setdefault("steps", []).append(entry) + await update_session_db(session_id, pipeline_log=log) + + +# --------------------------------------------------------------------------- +# Border-ghost word filter +# --------------------------------------------------------------------------- + +# Characters that OCR produces when reading box-border lines. +_BORDER_GHOST_CHARS = set("|1lI![](){}iíì/\\-—–_~.,;:'\"") + + +def _filter_border_ghost_words( + word_result: Dict, + boxes: List, +) -> int: + """Remove OCR words that are actually box border lines. + + A word is considered a border ghost when it sits on a known box edge + (left, right, top, or bottom) and looks like a line artefact (narrow + aspect ratio or text consists only of line-like characters). + + After removing ghost cells, columns that have become empty are also + removed from ``columns_used`` so the grid no longer shows phantom + columns. 
+ + Modifies *word_result* in-place and returns the number of removed cells. + """ + if not boxes or not word_result: + return 0 + + cells = word_result.get("cells") + if not cells: + return 0 + + # Build border bands — vertical (X) and horizontal (Y) + x_bands = [] # list of (x_lo, x_hi) + y_bands = [] # list of (y_lo, y_hi) + for b in boxes: + bx = b.x if hasattr(b, "x") else b.get("x", 0) + by = b.y if hasattr(b, "y") else b.get("y", 0) + bw = b.width if hasattr(b, "width") else b.get("w", b.get("width", 0)) + bh = b.height if hasattr(b, "height") else b.get("h", b.get("height", 0)) + bt = b.border_thickness if hasattr(b, "border_thickness") else b.get("border_thickness", 3) + margin = max(bt * 2, 10) + 6 # generous margin + + # Vertical edges (left / right) + x_bands.append((bx - margin, bx + margin)) + x_bands.append((bx + bw - margin, bx + bw + margin)) + # Horizontal edges (top / bottom) + y_bands.append((by - margin, by + margin)) + y_bands.append((by + bh - margin, by + bh + margin)) + + img_w = word_result.get("image_width", 1) + img_h = word_result.get("image_height", 1) + + def _is_ghost(cell: Dict) -> bool: + text = (cell.get("text") or "").strip() + if not text: + return False + + # Compute absolute pixel position + if cell.get("bbox_px"): + px = cell["bbox_px"] + cx = px["x"] + px["w"] / 2 + cy = px["y"] + px["h"] / 2 + cw = px["w"] + ch = px["h"] + elif cell.get("bbox_pct"): + pct = cell["bbox_pct"] + cx = (pct["x"] / 100) * img_w + (pct["w"] / 100) * img_w / 2 + cy = (pct["y"] / 100) * img_h + (pct["h"] / 100) * img_h / 2 + cw = (pct["w"] / 100) * img_w + ch = (pct["h"] / 100) * img_h + else: + return False + + # Check if center sits on a vertical or horizontal border + on_vertical = any(lo <= cx <= hi for lo, hi in x_bands) + on_horizontal = any(lo <= cy <= hi for lo, hi in y_bands) + if not on_vertical and not on_horizontal: + return False + + # Very short text (1-2 chars) on a border → very likely ghost + if len(text) <= 2: + # Narrow vertically 
(line-like) or narrow horizontally (dash-like)? + if ch > 0 and cw / ch < 0.5: + return True + if cw > 0 and ch / cw < 0.5: + return True + # Text is only border-ghost characters? + if all(c in _BORDER_GHOST_CHARS for c in text): + return True + + # Longer text but still only ghost chars and very narrow + if all(c in _BORDER_GHOST_CHARS for c in text): + if ch > 0 and cw / ch < 0.35: + return True + if cw > 0 and ch / cw < 0.35: + return True + return True # all ghost chars on a border → remove + + return False + + before = len(cells) + word_result["cells"] = [c for c in cells if not _is_ghost(c)] + removed = before - len(word_result["cells"]) + + # --- Remove empty columns from columns_used --- + columns_used = word_result.get("columns_used") + if removed and columns_used and len(columns_used) > 1: + remaining_cells = word_result["cells"] + occupied_cols = {c.get("col_index") for c in remaining_cells} + before_cols = len(columns_used) + columns_used = [col for col in columns_used if col.get("index") in occupied_cols] + + # Re-index columns and remap cell col_index values + if len(columns_used) < before_cols: + old_to_new = {} + for new_i, col in enumerate(columns_used): + old_to_new[col["index"]] = new_i + col["index"] = new_i + for cell in remaining_cells: + old_ci = cell.get("col_index") + if old_ci in old_to_new: + cell["col_index"] = old_to_new[old_ci] + word_result["columns_used"] = columns_used + logger.info("border-ghost: removed %d empty column(s), %d remaining", + before_cols - len(columns_used), len(columns_used)) + + if removed: + # Update summary counts + summary = word_result.get("summary", {}) + summary["total_cells"] = len(word_result["cells"]) + summary["non_empty_cells"] = sum(1 for c in word_result["cells"] if c.get("text")) + word_result["summary"] = summary + gs = word_result.get("grid_shape", {}) + gs["total_cells"] = len(word_result["cells"]) + if columns_used is not None: + gs["cols"] = len(columns_used) + word_result["grid_shape"] = gs + + 
return removed diff --git a/klausur-service/backend/ocr/pipeline/crop_api.py b/klausur-service/backend/ocr/pipeline/crop_api.py new file mode 100644 index 0000000..e4cdbd3 --- /dev/null +++ b/klausur-service/backend/ocr/pipeline/crop_api.py @@ -0,0 +1,290 @@ +""" +Crop API endpoints (Step 4 / UI index 3 of OCR Pipeline). + +Auto-crop, manual crop, and skip-crop for scanner/book borders. +""" + +import logging +import time +from typing import Any, Dict + +import cv2 +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from .page_crop import detect_and_crop_page, detect_page_splits +from .session_store import get_sub_sessions, update_session_db + +from .orientation_crop_helpers import ensure_cached, append_pipeline_log +from .page_sub_sessions import create_page_sub_sessions + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) + + +# --------------------------------------------------------------------------- +# Step 4 (UI index 3): Crop — runs after deskew + dewarp +# --------------------------------------------------------------------------- + +@router.post("/sessions/{session_id}/crop") +async def auto_crop(session_id: str): + """Auto-detect and crop scanner/book borders. + + Reads the dewarped image (post-deskew + dewarp, so the page is straight). + Falls back to oriented -> original if earlier steps were skipped. + + If the image is a multi-page spread (e.g. book on scanner), it will + automatically split into separate sub-sessions per page, crop each + individually, and return the split info. 
+ """ + cached = await ensure_cached(session_id) + + # Use dewarped (preferred), fall back to oriented, then original + img_bgr = next( + (v for k in ("dewarped_bgr", "oriented_bgr", "original_bgr") + if (v := cached.get(k)) is not None), + None, + ) + if img_bgr is None: + raise HTTPException(status_code=400, detail="No image available for cropping") + + t0 = time.time() + + # --- Check for existing sub-sessions (from page-split step) --- + # If page-split already created sub-sessions, skip multi-page detection + # in the crop step. Each sub-session runs its own crop independently. + existing_subs = await get_sub_sessions(session_id) + if existing_subs: + crop_result = cached.get("crop_result") or {} + if crop_result.get("multi_page"): + # Already split -- just return the existing info + duration = time.time() - t0 + h, w = img_bgr.shape[:2] + return { + "session_id": session_id, + **crop_result, + "image_width": w, + "image_height": h, + "sub_sessions": [ + {"id": s["id"], "name": s.get("name"), "page_index": s.get("box_index", i)} + for i, s in enumerate(existing_subs) + ], + "note": "Page split was already performed; each sub-session runs its own crop.", + } + + # --- Multi-page detection (fallback for sessions that skipped page-split) --- + page_splits = detect_page_splits(img_bgr) + + if page_splits and len(page_splits) >= 2: + # Multi-page spread detected -- create sub-sessions + sub_sessions = await create_page_sub_sessions( + session_id, cached, img_bgr, page_splits, + ) + duration = time.time() - t0 + + crop_info: Dict[str, Any] = { + "crop_applied": True, + "multi_page": True, + "page_count": len(page_splits), + "page_splits": page_splits, + "duration_seconds": round(duration, 2), + } + cached["crop_result"] = crop_info + + # Store the first page as the main cropped image for backward compat + first_page = page_splits[0] + first_bgr = img_bgr[ + first_page["y"]:first_page["y"] + first_page["height"], + first_page["x"]:first_page["x"] + 
first_page["width"], + ].copy() + first_cropped, _ = detect_and_crop_page(first_bgr) + cached["cropped_bgr"] = first_cropped + + ok, png_buf = cv2.imencode(".png", first_cropped) + await update_session_db( + session_id, + cropped_png=png_buf.tobytes() if ok else b"", + crop_result=crop_info, + current_step=5, + status='split', + ) + + logger.info( + "OCR Pipeline: crop session %s: multi-page split into %d pages in %.2fs", + session_id, len(page_splits), duration, + ) + + await append_pipeline_log(session_id, "crop", { + "multi_page": True, + "page_count": len(page_splits), + }, duration_ms=int(duration * 1000)) + + h, w = first_cropped.shape[:2] + return { + "session_id": session_id, + **crop_info, + "image_width": w, + "image_height": h, + "cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped", + "sub_sessions": sub_sessions, + } + + # --- Single page (normal) --- + cropped_bgr, crop_info = detect_and_crop_page(img_bgr) + + duration = time.time() - t0 + crop_info["duration_seconds"] = round(duration, 2) + crop_info["multi_page"] = False + + # Encode cropped image + success, png_buf = cv2.imencode(".png", cropped_bgr) + cropped_png = png_buf.tobytes() if success else b"" + + # Update cache + cached["cropped_bgr"] = cropped_bgr + cached["crop_result"] = crop_info + + # Persist to DB + await update_session_db( + session_id, + cropped_png=cropped_png, + crop_result=crop_info, + current_step=5, + ) + + logger.info( + "OCR Pipeline: crop session %s: applied=%s format=%s in %.2fs", + session_id, crop_info["crop_applied"], + crop_info.get("detected_format", "?"), + duration, + ) + + await append_pipeline_log(session_id, "crop", { + "crop_applied": crop_info["crop_applied"], + "detected_format": crop_info.get("detected_format"), + "format_confidence": crop_info.get("format_confidence"), + }, duration_ms=int(duration * 1000)) + + h, w = cropped_bgr.shape[:2] + return { + "session_id": session_id, + **crop_info, + "image_width": w, + "image_height": 
h, + "cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped", + } + + +class ManualCropRequest(BaseModel): + x: float # percentage 0-100 + y: float # percentage 0-100 + width: float # percentage 0-100 + height: float # percentage 0-100 + + +@router.post("/sessions/{session_id}/crop/manual") +async def manual_crop(session_id: str, req: ManualCropRequest): + """Manually crop using percentage coordinates.""" + cached = await ensure_cached(session_id) + + img_bgr = next( + (v for k in ("dewarped_bgr", "oriented_bgr", "original_bgr") + if (v := cached.get(k)) is not None), + None, + ) + if img_bgr is None: + raise HTTPException(status_code=400, detail="No image available for cropping") + + h, w = img_bgr.shape[:2] + + # Convert percentages to pixels + px_x = int(w * req.x / 100.0) + px_y = int(h * req.y / 100.0) + px_w = int(w * req.width / 100.0) + px_h = int(h * req.height / 100.0) + + # Clamp + px_x = max(0, min(px_x, w - 1)) + px_y = max(0, min(px_y, h - 1)) + px_w = max(1, min(px_w, w - px_x)) + px_h = max(1, min(px_h, h - px_y)) + + cropped_bgr = img_bgr[px_y:px_y + px_h, px_x:px_x + px_w].copy() + + success, png_buf = cv2.imencode(".png", cropped_bgr) + cropped_png = png_buf.tobytes() if success else b"" + + crop_result = { + "crop_applied": True, + "crop_rect": {"x": px_x, "y": px_y, "width": px_w, "height": px_h}, + "crop_rect_pct": {"x": round(req.x, 2), "y": round(req.y, 2), + "width": round(req.width, 2), "height": round(req.height, 2)}, + "original_size": {"width": w, "height": h}, + "cropped_size": {"width": px_w, "height": px_h}, + "method": "manual", + } + + cached["cropped_bgr"] = cropped_bgr + cached["crop_result"] = crop_result + + await update_session_db( + session_id, + cropped_png=cropped_png, + crop_result=crop_result, + current_step=5, + ) + + ch, cw = cropped_bgr.shape[:2] + return { + "session_id": session_id, + **crop_result, + "image_width": cw, + "image_height": ch, + "cropped_image_url": 
f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped", + } + + +@router.post("/sessions/{session_id}/crop/skip") +async def skip_crop(session_id: str): + """Skip cropping -- use dewarped (or oriented/original) image as-is.""" + cached = await ensure_cached(session_id) + + img_bgr = next( + (v for k in ("dewarped_bgr", "oriented_bgr", "original_bgr") + if (v := cached.get(k)) is not None), + None, + ) + if img_bgr is None: + raise HTTPException(status_code=400, detail="No image available") + + h, w = img_bgr.shape[:2] + + # Store the dewarped image as cropped (identity crop) + success, png_buf = cv2.imencode(".png", img_bgr) + cropped_png = png_buf.tobytes() if success else b"" + + crop_result = { + "crop_applied": False, + "skipped": True, + "original_size": {"width": w, "height": h}, + "cropped_size": {"width": w, "height": h}, + } + + cached["cropped_bgr"] = img_bgr + cached["crop_result"] = crop_result + + await update_session_db( + session_id, + cropped_png=cropped_png, + crop_result=crop_result, + current_step=5, + ) + + return { + "session_id": session_id, + **crop_result, + "image_width": w, + "image_height": h, + "cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped", + } diff --git a/klausur-service/backend/ocr/pipeline/deskew.py b/klausur-service/backend/ocr/pipeline/deskew.py new file mode 100644 index 0000000..1caeec2 --- /dev/null +++ b/klausur-service/backend/ocr/pipeline/deskew.py @@ -0,0 +1,236 @@ +""" +OCR Pipeline Deskew Endpoints (Step 2) + +Auto deskew, manual deskew, and ground truth for the deskew step. +Extracted from ocr_pipeline_geometry.py for file-size compliance. 
+""" + +import logging +import time +from datetime import datetime + +import cv2 +from fastapi import APIRouter, HTTPException + +from cv_vocab_pipeline import ( + create_ocr_image, + deskew_image, + deskew_image_by_word_alignment, + deskew_two_pass, +) +from .session_store import ( + get_session_db, + update_session_db, +) +from .common import ( + _cache, + _load_session_to_cache, + _get_cached, + _append_pipeline_log, + ManualDeskewRequest, + DeskewGroundTruthRequest, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) + + +@router.post("/sessions/{session_id}/deskew") +async def auto_deskew(session_id: str): + """Two-pass deskew: iterative projection (wide range) + word-alignment residual.""" + # Ensure session is in cache + if session_id not in _cache: + await _load_session_to_cache(session_id) + cached = _get_cached(session_id) + + # Deskew runs right after orientation -- use oriented image, fall back to original + img_bgr = next((v for k in ("oriented_bgr", "original_bgr") + if (v := cached.get(k)) is not None), None) + if img_bgr is None: + raise HTTPException(status_code=400, detail="No image available for deskewing") + + t0 = time.time() + + # Two-pass deskew: iterative (+-5 deg) + word-alignment residual check + deskewed_bgr, angle_applied, two_pass_debug = deskew_two_pass(img_bgr.copy()) + + # Also run individual methods for reporting (non-authoritative) + try: + _, angle_hough = deskew_image(img_bgr.copy()) + except Exception: + angle_hough = 0.0 + + success_enc, png_orig = cv2.imencode(".png", img_bgr) + orig_bytes = png_orig.tobytes() if success_enc else b"" + try: + _, angle_wa = deskew_image_by_word_alignment(orig_bytes) + except Exception: + angle_wa = 0.0 + + angle_iterative = two_pass_debug.get("pass1_angle", 0.0) + angle_residual = two_pass_debug.get("pass2_angle", 0.0) + angle_textline = two_pass_debug.get("pass3_angle", 0.0) + + duration = time.time() - t0 + + method_used = 
"three_pass" if abs(angle_textline) >= 0.01 else ( + "two_pass" if abs(angle_residual) >= 0.01 else "iterative" + ) + + # Encode as PNG + success, deskewed_png_buf = cv2.imencode(".png", deskewed_bgr) + deskewed_png = deskewed_png_buf.tobytes() if success else b"" + + # Create binarized version + binarized_png = None + try: + binarized = create_ocr_image(deskewed_bgr) + success_bin, bin_buf = cv2.imencode(".png", binarized) + binarized_png = bin_buf.tobytes() if success_bin else None + except Exception as e: + logger.warning(f"Binarization failed: {e}") + + confidence = max(0.5, 1.0 - abs(angle_applied) / 5.0) + + deskew_result = { + "angle_hough": round(angle_hough, 3), + "angle_word_alignment": round(angle_wa, 3), + "angle_iterative": round(angle_iterative, 3), + "angle_residual": round(angle_residual, 3), + "angle_textline": round(angle_textline, 3), + "angle_applied": round(angle_applied, 3), + "method_used": method_used, + "confidence": round(confidence, 2), + "duration_seconds": round(duration, 2), + "two_pass_debug": two_pass_debug, + } + + # Update cache + cached["deskewed_bgr"] = deskewed_bgr + cached["binarized_png"] = binarized_png + cached["deskew_result"] = deskew_result + + # Persist to DB + db_update = { + "deskewed_png": deskewed_png, + "deskew_result": deskew_result, + "current_step": 3, + } + if binarized_png: + db_update["binarized_png"] = binarized_png + await update_session_db(session_id, **db_update) + + logger.info(f"OCR Pipeline: deskew session {session_id}: " + f"hough={angle_hough:.2f} wa={angle_wa:.2f} " + f"iter={angle_iterative:.2f} residual={angle_residual:.2f} " + f"textline={angle_textline:.2f} " + f"-> {method_used} total={angle_applied:.2f}") + + await _append_pipeline_log(session_id, "deskew", { + "angle_applied": round(angle_applied, 3), + "angle_iterative": round(angle_iterative, 3), + "angle_residual": round(angle_residual, 3), + "angle_textline": round(angle_textline, 3), + "confidence": round(confidence, 2), + "method": 
method_used, + }, duration_ms=int(duration * 1000)) + + return { + "session_id": session_id, + **deskew_result, + "deskewed_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed", + "binarized_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/binarized", + } + + +@router.post("/sessions/{session_id}/deskew/manual") +async def manual_deskew(session_id: str, req: ManualDeskewRequest): + """Apply a manual rotation angle to the oriented image.""" + if session_id not in _cache: + await _load_session_to_cache(session_id) + cached = _get_cached(session_id) + + img_bgr = next((v for k in ("oriented_bgr", "original_bgr") + if (v := cached.get(k)) is not None), None) + if img_bgr is None: + raise HTTPException(status_code=400, detail="No image available for deskewing") + + angle = max(-5.0, min(5.0, req.angle)) + + h, w = img_bgr.shape[:2] + center = (w // 2, h // 2) + M = cv2.getRotationMatrix2D(center, angle, 1.0) + rotated = cv2.warpAffine(img_bgr, M, (w, h), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_REPLICATE) + + success, png_buf = cv2.imencode(".png", rotated) + deskewed_png = png_buf.tobytes() if success else b"" + + # Binarize + binarized_png = None + try: + binarized = create_ocr_image(rotated) + success_bin, bin_buf = cv2.imencode(".png", binarized) + binarized_png = bin_buf.tobytes() if success_bin else None + except Exception: + pass + + deskew_result = { + **(cached.get("deskew_result") or {}), + "angle_applied": round(angle, 3), + "method_used": "manual", + } + + # Update cache + cached["deskewed_bgr"] = rotated + cached["binarized_png"] = binarized_png + cached["deskew_result"] = deskew_result + + # Persist to DB + db_update = { + "deskewed_png": deskewed_png, + "deskew_result": deskew_result, + } + if binarized_png: + db_update["binarized_png"] = binarized_png + await update_session_db(session_id, **db_update) + + logger.info(f"OCR Pipeline: manual deskew session {session_id}: {angle:.2f}") + + return { + "session_id": 
session_id, + "angle_applied": round(angle, 3), + "method_used": "manual", + "deskewed_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed", + } + + +@router.post("/sessions/{session_id}/ground-truth/deskew") +async def save_deskew_ground_truth(session_id: str, req: DeskewGroundTruthRequest): + """Save ground truth feedback for the deskew step.""" + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + ground_truth = session.get("ground_truth") or {} + gt = { + "is_correct": req.is_correct, + "corrected_angle": req.corrected_angle, + "notes": req.notes, + "saved_at": datetime.utcnow().isoformat(), + "deskew_result": session.get("deskew_result"), + } + ground_truth["deskew"] = gt + + await update_session_db(session_id, ground_truth=ground_truth) + + # Update cache + if session_id in _cache: + _cache[session_id]["ground_truth"] = ground_truth + + logger.info(f"OCR Pipeline: ground truth deskew session {session_id}: " + f"correct={req.is_correct}, corrected_angle={req.corrected_angle}") + + return {"session_id": session_id, "ground_truth": gt} diff --git a/klausur-service/backend/ocr/pipeline/dewarp.py b/klausur-service/backend/ocr/pipeline/dewarp.py new file mode 100644 index 0000000..21fd7d0 --- /dev/null +++ b/klausur-service/backend/ocr/pipeline/dewarp.py @@ -0,0 +1,346 @@ +""" +OCR Pipeline Dewarp Endpoints + +Auto dewarp (with VLM/CV ensemble), manual dewarp, combined +rotation+shear adjustment, and ground truth. +Extracted from ocr_pipeline_geometry.py for file-size compliance. 
+""" + +import json +import logging +import os +import re +import time +from datetime import datetime +from typing import Any, Dict + +import cv2 +from fastapi import APIRouter, HTTPException, Query + +from cv_vocab_pipeline import ( + _apply_shear, + create_ocr_image, + dewarp_image, + dewarp_image_manual, +) +from .session_store import ( + get_session_db, + update_session_db, +) +from .common import ( + _cache, + _load_session_to_cache, + _get_cached, + _append_pipeline_log, + ManualDewarpRequest, + CombinedAdjustRequest, + DewarpGroundTruthRequest, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) + + +async def _detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]: + """Ask qwen2.5vl:32b to estimate the vertical shear angle of a scanned page. + + The VLM is shown the image and asked: are the column/table borders tilted? + If yes, by how many degrees? Returns a dict with shear_degrees and confidence. + Confidence is 0.0 if Ollama is unavailable or parsing fails. + """ + import httpx + import base64 + + ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434") + model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b") + + prompt = ( + "This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. " + "Are they perfectly vertical, or do they tilt slightly? " + "If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). " + "Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} " + "Use confidence 0.0-1.0 based on how clearly you can see the tilt. 
" + "If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}" + ) + + img_b64 = base64.b64encode(image_bytes).decode("utf-8") + payload = { + "model": model, + "prompt": prompt, + "images": [img_b64], + "stream": False, + } + + try: + async with httpx.AsyncClient(timeout=60.0) as client: + resp = await client.post(f"{ollama_base}/api/generate", json=payload) + resp.raise_for_status() + text = resp.json().get("response", "") + + # Parse JSON from response (may have surrounding text) + match = re.search(r'\{[^}]+\}', text) + if match: + data = json.loads(match.group(0)) + shear = float(data.get("shear_degrees", 0.0)) + conf = float(data.get("confidence", 0.0)) + # Clamp to reasonable range + shear = max(-3.0, min(3.0, shear)) + conf = max(0.0, min(1.0, conf)) + return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)} + except Exception as e: + logger.warning(f"VLM dewarp failed: {e}") + + return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0} + + +@router.post("/sessions/{session_id}/dewarp") +async def auto_dewarp( + session_id: str, + method: str = Query("ensemble", description="Detection method: ensemble | vlm | cv"), +): + """Detect and correct vertical shear on the deskewed image. 
+ + Methods: + - **ensemble** (default): 3-method CV ensemble (vertical edges + projection + Hough) + - **cv**: CV ensemble only (same as ensemble) + - **vlm**: Ask qwen2.5vl:32b to estimate the shear angle visually + """ + if method not in ("ensemble", "cv", "vlm"): + raise HTTPException(status_code=400, detail="method must be one of: ensemble, cv, vlm") + + if session_id not in _cache: + await _load_session_to_cache(session_id) + cached = _get_cached(session_id) + + deskewed_bgr = cached.get("deskewed_bgr") + if deskewed_bgr is None: + raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp") + + t0 = time.time() + + if method == "vlm": + # Encode deskewed image to PNG for VLM + success, png_buf = cv2.imencode(".png", deskewed_bgr) + img_bytes = png_buf.tobytes() if success else b"" + vlm_det = await _detect_shear_with_vlm(img_bytes) + shear_deg = vlm_det["shear_degrees"] + if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3: + dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg) + else: + dewarped_bgr = deskewed_bgr + dewarp_info = { + "method": vlm_det["method"], + "shear_degrees": shear_deg, + "confidence": vlm_det["confidence"], + "detections": [vlm_det], + } + else: + dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr) + + duration = time.time() - t0 + + # Encode as PNG + success, png_buf = cv2.imencode(".png", dewarped_bgr) + dewarped_png = png_buf.tobytes() if success else b"" + + dewarp_result = { + "method_used": dewarp_info["method"], + "shear_degrees": dewarp_info["shear_degrees"], + "confidence": dewarp_info["confidence"], + "duration_seconds": round(duration, 2), + "detections": dewarp_info.get("detections", []), + } + + # Update cache + cached["dewarped_bgr"] = dewarped_bgr + cached["dewarp_result"] = dewarp_result + + # Persist to DB + await update_session_db( + session_id, + dewarped_png=dewarped_png, + dewarp_result=dewarp_result, + auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0), + current_step=4, 
+ ) + + logger.info(f"OCR Pipeline: dewarp session {session_id}: " + f"method={dewarp_info['method']} shear={dewarp_info['shear_degrees']:.3f} " + f"conf={dewarp_info['confidence']:.2f} ({duration:.2f}s)") + + await _append_pipeline_log(session_id, "dewarp", { + "shear_degrees": dewarp_info["shear_degrees"], + "confidence": dewarp_info["confidence"], + "method": dewarp_info["method"], + "ensemble_methods": [d.get("method", "") for d in dewarp_info.get("detections", [])], + }, duration_ms=int(duration * 1000)) + + return { + "session_id": session_id, + **dewarp_result, + "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped", + } + + +@router.post("/sessions/{session_id}/dewarp/manual") +async def manual_dewarp(session_id: str, req: ManualDewarpRequest): + """Apply shear correction with a manual angle.""" + if session_id not in _cache: + await _load_session_to_cache(session_id) + cached = _get_cached(session_id) + + deskewed_bgr = cached.get("deskewed_bgr") + if deskewed_bgr is None: + raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp") + + shear_deg = max(-2.0, min(2.0, req.shear_degrees)) + + if abs(shear_deg) < 0.001: + dewarped_bgr = deskewed_bgr + else: + dewarped_bgr = dewarp_image_manual(deskewed_bgr, shear_deg) + + success, png_buf = cv2.imencode(".png", dewarped_bgr) + dewarped_png = png_buf.tobytes() if success else b"" + + dewarp_result = { + **(cached.get("dewarp_result") or {}), + "method_used": "manual", + "shear_degrees": round(shear_deg, 3), + } + + # Update cache + cached["dewarped_bgr"] = dewarped_bgr + cached["dewarp_result"] = dewarp_result + + # Persist to DB + await update_session_db( + session_id, + dewarped_png=dewarped_png, + dewarp_result=dewarp_result, + ) + + logger.info(f"OCR Pipeline: manual dewarp session {session_id}: shear={shear_deg:.3f}") + + return { + "session_id": session_id, + "shear_degrees": round(shear_deg, 3), + "method_used": "manual", + "dewarped_image_url": 
f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped", + } + + +@router.post("/sessions/{session_id}/adjust-combined") +async def adjust_combined(session_id: str, req: CombinedAdjustRequest): + """Apply rotation + shear combined to the original image. + + Used by the fine-tuning sliders to preview arbitrary rotation/shear + combinations without re-running the full deskew/dewarp pipeline. + """ + if session_id not in _cache: + await _load_session_to_cache(session_id) + cached = _get_cached(session_id) + + img_bgr = cached.get("original_bgr") + if img_bgr is None: + raise HTTPException(status_code=400, detail="Original image not available") + + rotation = max(-15.0, min(15.0, req.rotation_degrees)) + shear_deg = max(-5.0, min(5.0, req.shear_degrees)) + + h, w = img_bgr.shape[:2] + result_bgr = img_bgr + + # Step 1: Apply rotation + if abs(rotation) >= 0.001: + center = (w // 2, h // 2) + M = cv2.getRotationMatrix2D(center, rotation, 1.0) + result_bgr = cv2.warpAffine(result_bgr, M, (w, h), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_REPLICATE) + + # Step 2: Apply shear + if abs(shear_deg) >= 0.001: + result_bgr = dewarp_image_manual(result_bgr, shear_deg) + + # Encode + success, png_buf = cv2.imencode(".png", result_bgr) + dewarped_png = png_buf.tobytes() if success else b"" + + # Binarize + binarized_png = None + try: + binarized = create_ocr_image(result_bgr) + success_bin, bin_buf = cv2.imencode(".png", binarized) + binarized_png = bin_buf.tobytes() if success_bin else None + except Exception: + pass + + # Build combined result dicts + deskew_result = { + **(cached.get("deskew_result") or {}), + "angle_applied": round(rotation, 3), + "method_used": "manual_combined", + } + dewarp_result = { + **(cached.get("dewarp_result") or {}), + "method_used": "manual_combined", + "shear_degrees": round(shear_deg, 3), + } + + # Update cache + cached["deskewed_bgr"] = result_bgr + cached["dewarped_bgr"] = result_bgr + cached["deskew_result"] = deskew_result + 
cached["dewarp_result"] = dewarp_result + + # Persist to DB + db_update = { + "dewarped_png": dewarped_png, + "deskew_result": deskew_result, + "dewarp_result": dewarp_result, + } + if binarized_png: + db_update["binarized_png"] = binarized_png + db_update["deskewed_png"] = dewarped_png + await update_session_db(session_id, **db_update) + + logger.info(f"OCR Pipeline: combined adjust session {session_id}: " + f"rotation={rotation:.3f} shear={shear_deg:.3f}") + + return { + "session_id": session_id, + "rotation_degrees": round(rotation, 3), + "shear_degrees": round(shear_deg, 3), + "method_used": "manual_combined", + "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped", + } + + +@router.post("/sessions/{session_id}/ground-truth/dewarp") +async def save_dewarp_ground_truth(session_id: str, req: DewarpGroundTruthRequest): + """Save ground truth feedback for the dewarp step.""" + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + ground_truth = session.get("ground_truth") or {} + gt = { + "is_correct": req.is_correct, + "corrected_shear": req.corrected_shear, + "notes": req.notes, + "saved_at": datetime.utcnow().isoformat(), + "dewarp_result": session.get("dewarp_result"), + } + ground_truth["dewarp"] = gt + + await update_session_db(session_id, ground_truth=ground_truth) + + if session_id in _cache: + _cache[session_id]["ground_truth"] = ground_truth + + logger.info(f"OCR Pipeline: ground truth dewarp session {session_id}: " + f"correct={req.is_correct}, corrected_shear={req.corrected_shear}") + + return {"session_id": session_id, "ground_truth": gt} diff --git a/klausur-service/backend/ocr/pipeline/geometry.py b/klausur-service/backend/ocr/pipeline/geometry.py new file mode 100644 index 0000000..2e77a75 --- /dev/null +++ b/klausur-service/backend/ocr/pipeline/geometry.py @@ -0,0 +1,27 @@ +""" +OCR Pipeline Geometry API (barrel re-export) + 
+This module was split into: + - ocr_pipeline_deskew.py (Deskew endpoints) + - ocr_pipeline_dewarp.py (Dewarp endpoints) + - ocr_pipeline_structure.py (Structure detection + exclude regions) + - ocr_pipeline_columns.py (Column detection + ground truth) + +The `router` object is assembled here by including all sub-routers. +Importers that did `from ocr_pipeline_geometry import router` continue to work. +""" + +from fastapi import APIRouter + +from .deskew import router as _deskew_router +from .dewarp import router as _dewarp_router +from .structure import router as _structure_router +from .columns import router as _columns_router + +# Assemble the combined router. +# All sub-routers use prefix="/api/v1/ocr-pipeline", so include without extra prefix. +router = APIRouter() +router.include_router(_deskew_router) +router.include_router(_dewarp_router) +router.include_router(_structure_router) +router.include_router(_columns_router) diff --git a/klausur-service/backend/ocr/pipeline/llm_review.py b/klausur-service/backend/ocr/pipeline/llm_review.py new file mode 100644 index 0000000..f2dd054 --- /dev/null +++ b/klausur-service/backend/ocr/pipeline/llm_review.py @@ -0,0 +1,209 @@ +""" +OCR Pipeline LLM Review — LLM-based correction endpoints. + +Extracted from ocr_pipeline_postprocess.py. + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
+""" + +import json +import logging +from datetime import datetime +from typing import Dict, List + +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import StreamingResponse + +from cv_vocab_pipeline import ( + OLLAMA_REVIEW_MODEL, + llm_review_entries, + llm_review_entries_streaming, +) +from .session_store import ( + get_session_db, + update_session_db, +) +from .common import ( + _cache, + _append_pipeline_log, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) + + +# --------------------------------------------------------------------------- +# Step 8: LLM Review +# --------------------------------------------------------------------------- + +@router.post("/sessions/{session_id}/llm-review") +async def run_llm_review(session_id: str, request: Request, stream: bool = False): + """Run LLM-based correction on vocab entries from Step 5. + + Query params: + stream: false (default) for JSON response, true for SSE streaming + """ + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + word_result = session.get("word_result") + if not word_result: + raise HTTPException(status_code=400, detail="No word result found — run Step 5 first") + + entries = word_result.get("vocab_entries") or word_result.get("entries") or [] + if not entries: + raise HTTPException(status_code=400, detail="No vocab entries found — run Step 5 first") + + # Optional model override from request body + body = {} + try: + body = await request.json() + except Exception: + pass + model = body.get("model") or OLLAMA_REVIEW_MODEL + + if stream: + return StreamingResponse( + _llm_review_stream_generator(session_id, entries, word_result, model, request), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}, + ) + + # Non-streaming path + try: + 
result = await llm_review_entries(entries, model=model) + except Exception as e: + import traceback + logger.error(f"LLM review failed for session {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=502, detail=f"LLM review failed ({type(e).__name__}): {e}") + + # Store result inside word_result as a sub-key + word_result["llm_review"] = { + "changes": result["changes"], + "model_used": result["model_used"], + "duration_ms": result["duration_ms"], + "entries_corrected": result["entries_corrected"], + } + await update_session_db(session_id, word_result=word_result, current_step=9) + + if session_id in _cache: + _cache[session_id]["word_result"] = word_result + + logger.info(f"LLM review session {session_id}: {len(result['changes'])} changes, " + f"{result['duration_ms']}ms, model={result['model_used']}") + + await _append_pipeline_log(session_id, "correction", { + "engine": "llm", + "model": result["model_used"], + "total_entries": len(entries), + "corrections_proposed": len(result["changes"]), + }, duration_ms=result["duration_ms"]) + + return { + "session_id": session_id, + "changes": result["changes"], + "model_used": result["model_used"], + "duration_ms": result["duration_ms"], + "total_entries": len(entries), + "corrections_found": len(result["changes"]), + } + + +async def _llm_review_stream_generator( + session_id: str, + entries: List[Dict], + word_result: Dict, + model: str, + request: Request, +): + """SSE generator that yields batch-by-batch LLM review progress.""" + try: + async for event in llm_review_entries_streaming(entries, model=model): + if await request.is_disconnected(): + logger.info(f"SSE: client disconnected during LLM review for {session_id}") + return + + yield f"data: {json.dumps(event, ensure_ascii=False)}\n\n" + + # On complete: persist to DB + if event.get("type") == "complete": + word_result["llm_review"] = { + "changes": event["changes"], + "model_used": event["model_used"], + 
"duration_ms": event["duration_ms"], + "entries_corrected": event["entries_corrected"], + } + await update_session_db(session_id, word_result=word_result, current_step=9) + if session_id in _cache: + _cache[session_id]["word_result"] = word_result + + logger.info(f"LLM review SSE session {session_id}: {event['corrections_found']} changes, " + f"{event['duration_ms']}ms, skipped={event['skipped']}, model={event['model_used']}") + + except Exception as e: + import traceback + logger.error(f"LLM review SSE failed for {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}") + error_event = {"type": "error", "detail": f"{type(e).__name__}: {e}"} + yield f"data: {json.dumps(error_event)}\n\n" + + +@router.post("/sessions/{session_id}/llm-review/apply") +async def apply_llm_corrections(session_id: str, request: Request): + """Apply selected LLM corrections to vocab entries.""" + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + word_result = session.get("word_result") + if not word_result: + raise HTTPException(status_code=400, detail="No word result found") + + llm_review = word_result.get("llm_review") + if not llm_review: + raise HTTPException(status_code=400, detail="No LLM review found — run /llm-review first") + + body = await request.json() + accepted_indices = set(body.get("accepted_indices", [])) # indices into changes[] + + changes = llm_review.get("changes", []) + entries = word_result.get("vocab_entries") or word_result.get("entries") or [] + + # Build a lookup: (row_index, field) -> new_value for accepted changes + corrections = {} + applied_count = 0 + for idx, change in enumerate(changes): + if idx in accepted_indices: + key = (change["row_index"], change["field"]) + corrections[key] = change["new"] + applied_count += 1 + + # Apply corrections to entries + for entry in entries: + row_idx = entry.get("row_index", -1) + for field_name in ("english", 
"german", "example"): + key = (row_idx, field_name) + if key in corrections: + entry[field_name] = corrections[key] + entry["llm_corrected"] = True + + # Update word_result + word_result["vocab_entries"] = entries + word_result["entries"] = entries + word_result["llm_review"]["applied_count"] = applied_count + word_result["llm_review"]["applied_at"] = datetime.utcnow().isoformat() + + await update_session_db(session_id, word_result=word_result) + + if session_id in _cache: + _cache[session_id]["word_result"] = word_result + + logger.info(f"Applied {applied_count}/{len(changes)} LLM corrections for session {session_id}") + + return { + "session_id": session_id, + "applied_count": applied_count, + "total_changes": len(changes), + } diff --git a/klausur-service/backend/ocr/pipeline/merge_helpers.py b/klausur-service/backend/ocr/pipeline/merge_helpers.py new file mode 100644 index 0000000..571c116 --- /dev/null +++ b/klausur-service/backend/ocr/pipeline/merge_helpers.py @@ -0,0 +1,272 @@ +""" +OCR Merge Helpers — functions for combining PaddleOCR/RapidOCR with Tesseract results. + +Extracted from ocr_pipeline_ocr_merge.py. + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import logging +from typing import List + +logger = logging.getLogger(__name__) + + +def _split_paddle_multi_words(words: list) -> list: + """Split PaddleOCR multi-word boxes into individual word boxes. + + PaddleOCR often returns entire phrases as a single box, e.g. + "More than 200 singers took part in the" with one bounding box. + This splits them into individual words with proportional widths. + Also handles leading "!" (e.g. "!Betonung" -> ["!", "Betonung"]) + and IPA brackets (e.g. "badge[bxd3]" -> ["badge", "[bxd3]"]). + """ + import re + + result = [] + for w in words: + raw_text = w.get("text", "").strip() + if not raw_text: + continue + # Split on whitespace, before "[" (IPA), and after "!" 
def _word_center_y(word: dict) -> float:
    """Return the vertical center of a word box (a missing height counts as 0)."""
    return word["top"] + word.get("height", 0) / 2


def _norm_text(word: dict) -> str:
    """Return the word's text lower-cased and stripped, for comparisons."""
    return word.get("text", "").lower().strip()


def _group_words_into_rows(words: list, row_gap: int = 12) -> list:
    """Group words into rows by clustering their vertical centers.

    Words are visited in order of vertical center. A word joins the current
    row while its center lies within ``row_gap`` pixels of the center of the
    word that *opened* that row; otherwise it opens a new row.

    Args:
        words: Word dicts with ``top`` and ``left`` (``height`` optional).
        row_gap: Maximum vertical distance, in pixels, to the row's first word.

    Returns:
        A list of rows; each row is a list of words sorted left-to-right.
    """
    if not words:
        return []

    by_center = sorted(words, key=_word_center_y)
    rows: list = []
    current_row: list = [by_center[0]]
    # The row is anchored on its first word; the anchor is not a running mean.
    anchor_cy = _word_center_y(by_center[0])

    for word in by_center[1:]:
        cy = _word_center_y(word)
        if abs(cy - anchor_cy) <= row_gap:
            current_row.append(word)
            continue
        rows.append(sorted(current_row, key=lambda item: item["left"]))
        current_row = [word]
        anchor_cy = cy

    rows.append(sorted(current_row, key=lambda item: item["left"]))
    return rows


def _row_center_y(row: list) -> float:
    """Return the average vertical center of a row of words (0.0 if empty)."""
    if not row:
        return 0.0
    return sum(_word_center_y(w) for w in row) / len(row)


def _overlaps_horizontally(pw: dict, tw: dict, min_ratio: float = 0.4) -> bool:
    """Return True if the boxes overlap by >= ``min_ratio`` of the narrower width."""
    overlap_left = max(pw["left"], tw["left"])
    overlap_right = min(
        pw["left"] + pw.get("width", 0),
        tw["left"] + tw.get("width", 0),
    )
    overlap_w = max(0, overlap_right - overlap_left)
    min_w = min(pw.get("width", 1), tw.get("width", 1))
    return min_w > 0 and overlap_w / min_w >= min_ratio


def _blend_matched_words(pw: dict, tw: dict, spatial_match: bool) -> dict:
    """Merge a matched Paddle/Tesseract pair into one confidence-weighted box.

    The Paddle text wins unless the pair matched only by position and
    Tesseract is the more confident engine.
    """
    pc = pw.get("conf", 80)
    tc = tw.get("conf", 50)

    # Confidences are the averaging weights. Negative values carry no weight,
    # and if neither engine has any weight the boxes are averaged equally --
    # dividing a zero-weighted sum would collapse the box to the origin.
    wp, wt = max(pc, 0), max(tc, 0)
    if wp + wt == 0:
        wp = wt = 1
    total = wp + wt

    def blend(p_val, t_val):
        return round((p_val * wp + t_val * wt) / total)

    use_tess_text = spatial_match and pc < tc
    return {
        "text": tw["text"] if use_tess_text else pw["text"],
        "left": blend(pw["left"], tw["left"]),
        "top": blend(pw["top"], tw["top"]),
        # width/height are optional everywhere else in this module.
        "width": blend(pw.get("width", 0), tw.get("width", 0)),
        "height": blend(pw.get("height", 0), tw.get("height", 0)),
        "conf": max(pc, tc),
    }


def _merge_row_sequences(paddle_row: list, tess_row: list) -> list:
    """Merge two left-to-right word sequences from the same row.

    Both sequences are walked simultaneously:

    - Words match when their normalized texts are equal, one contains the
      other (both longer than one character), or their boxes overlap
      horizontally by at least 40% of the narrower width. A match yields one
      blended word (see ``_blend_matched_words``).
    - Otherwise a three-word lookahead decides which engine has the extra
      word; if that is inconclusive the left-most word is emitted first.
    - Unmatched Tesseract words are kept only with ``conf >= 30``; unmatched
      Paddle words are always kept.

    Returns:
        The merged word list. Unmatched words are the original dict objects.
    """
    merged: list = []
    pi, ti = 0, 0

    while pi < len(paddle_row) and ti < len(tess_row):
        pw, tw = paddle_row[pi], tess_row[ti]
        pt, tt = _norm_text(pw), _norm_text(tw)

        text_match = pt == tt or (
            len(pt) > 1 and len(tt) > 1 and (pt in tt or tt in pt)
        )
        spatial_match = not text_match and _overlaps_horizontally(pw, tw)

        if text_match or spatial_match:
            merged.append(_blend_matched_words(pw, tw, spatial_match))
            pi += 1
            ti += 1
            continue

        # Does the current word of one engine show up shortly in the other?
        paddle_ahead = any(_norm_text(w) == pt for w in tess_row[ti + 1:ti + 4])
        tess_ahead = any(_norm_text(w) == tt for w in paddle_row[pi + 1:pi + 4])

        if paddle_ahead and not tess_ahead:
            take_tess = True    # Tesseract has an extra word here
        elif tess_ahead and not paddle_ahead:
            take_tess = False   # Paddle has an extra word here
        else:
            take_tess = pw["left"] > tw["left"]

        if take_tess:
            if tw.get("conf", 0) >= 30:
                merged.append(tw)
            ti += 1
        else:
            merged.append(pw)
            pi += 1

    merged.extend(paddle_row[pi:])
    merged.extend(w for w in tess_row[ti:] if w.get("conf", 0) >= 30)
    return merged


def _merge_paddle_tesseract(paddle_words: list, tess_words: list) -> list:
    """Merge PaddleOCR and Tesseract word boxes via row-based alignment.

    Each Paddle row is paired with the nearest unused Tesseract row, provided
    their centers are no further apart than the tallest Paddle word in the
    row (at least 15 px). Paired rows go through ``_merge_row_sequences``;
    unpaired Paddle rows are kept as-is, and unpaired Tesseract rows
    contribute only words with ``conf >= 40``.
    """
    if not paddle_words and not tess_words:
        return []
    if not paddle_words:
        return [w for w in tess_words if w.get("conf", 0) >= 40]
    if not tess_words:
        return list(paddle_words)

    paddle_rows = _group_words_into_rows(paddle_words)
    tess_rows = _group_words_into_rows(tess_words)
    # Row centers do not change while pairing, so compute them once.
    tess_centers = [_row_center_y(tr) for tr in tess_rows]

    used_tess_rows: set = set()
    merged_all: list = []

    for pr in paddle_rows:
        pr_cy = _row_center_y(pr)
        best_dist, best_tri = float("inf"), -1
        for tri, tr_cy in enumerate(tess_centers):
            if tri in used_tess_rows:
                continue
            dist = abs(pr_cy - tr_cy)
            if dist < best_dist:
                best_dist, best_tri = dist, tri

        max_row_dist = max(
            max((w.get("height", 20) for w in pr), default=20),
            15,
        )

        if best_tri >= 0 and best_dist <= max_row_dist:
            used_tess_rows.add(best_tri)
            merged_all.extend(_merge_row_sequences(pr, tess_rows[best_tri]))
        else:
            merged_all.extend(pr)

    for tri, tr in enumerate(tess_rows):
        if tri not in used_tess_rows:
            merged_all.extend(w for w in tr if w.get("conf", 0) >= 40)

    return merged_all


def _is_duplicate_box(word: dict, word_text: str, kept: dict) -> bool:
    """Return True if ``word`` repeats ``kept``: same text, >= 50% overlap on both axes."""
    if word_text != _norm_text(kept):
        return False

    ox_l = max(word["left"], kept["left"])
    ox_r = min(word["left"] + word.get("width", 0),
               kept["left"] + kept.get("width", 0))
    min_w = min(word.get("width", 1), kept.get("width", 1))
    if min_w <= 0 or max(0, ox_r - ox_l) / min_w < 0.5:
        return False

    oy_t = max(word["top"], kept["top"])
    oy_b = min(word["top"] + word.get("height", 0),
               kept["top"] + kept.get("height", 0))
    min_h = min(word.get("height", 1), kept.get("height", 1))
    return min_h > 0 and max(0, oy_b - oy_t) / min_h >= 0.5


def _deduplicate_words(words: list) -> list:
    """Remove duplicate words with the same text at overlapping positions.

    The first occurrence wins. Words with empty text are dropped as well and
    are included in the logged removal count.
    """
    if not words:
        return words

    result: list = []
    for w in words:
        wt = _norm_text(w)
        if not wt:
            continue
        if not any(_is_duplicate_box(w, wt, kept) for kept in result):
            result.append(w)

    removed = len(words) - len(result)
    if removed:
        logger.info("dedup: removed %d duplicate words", removed)
    return result
def _parse_tess_conf(raw) -> int:
    """Parse a Tesseract confidence value into an int, or -1 if unparseable.

    Accepts ints, floats and their string forms ("95", "96.08", "-1").
    Fractional confidences are truncated rather than rejected, so a word is
    not dropped just because its confidence is reported as a float.
    """
    try:
        return int(float(raw))
    except (TypeError, ValueError, OverflowError):
        return -1


def _run_tesseract_words(img_bgr) -> list:
    """Run Tesseract OCR on a BGR image and return word dicts.

    Words with empty text or a confidence below 20 are skipped.

    Returns:
        A list of dicts with ``text``, ``left``, ``top``, ``width``,
        ``height`` and an integer ``conf``.
    """
    from PIL import Image
    import pytesseract

    pil_img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    data = pytesseract.image_to_data(
        pil_img, lang="eng+deu",
        config="--psm 6 --oem 3",
        output_type=pytesseract.Output.DICT,
    )
    tess_words = []
    for i, raw_text in enumerate(data["text"]):
        text = str(raw_text).strip()
        conf = _parse_tess_conf(data["conf"][i])
        if not text or conf < 20:
            continue
        tess_words.append({
            "text": text,
            "left": data["left"][i],
            "top": data["top"][i],
            "width": data["width"][i],
            "height": data["height"][i],
            "conf": conf,
        })
    return tess_words


def _build_kombi_word_result(
    cells: list,
    columns_meta: list,
    img_w: int,
    img_h: int,
    duration: float,
    engine_name: str,
    raw_engine_words: list,
    raw_engine_words_split: list,
    tess_words: list,
    merged_words: list,
    raw_engine_key: str = "raw_paddle_words",
    raw_split_key: str = "raw_paddle_words_split",
) -> dict:
    """Build the ``word_result`` dict returned and stored by the kombi endpoints.

    ``raw_engine_key`` / ``raw_split_key`` name the entries that carry the
    primary engine's raw and split words. The matching ``summary`` counters
    use the same names with every ``"raw_"`` removed (for example
    ``raw_paddle_words`` -> ``paddle_words``).

    The layout is ``"vocab"`` when any column has type ``column_en`` or
    ``column_de``, otherwise ``"generic"``.
    """
    n_rows = len({c["row_index"] for c in cells}) if cells else 0
    n_cols = len(columns_meta)
    col_types = {c.get("type") for c in columns_meta}
    is_vocab = bool(col_types & {"column_en", "column_de"})

    return {
        "cells": cells,
        "grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": engine_name,
        "grid_method": engine_name,
        raw_engine_key: raw_engine_words,
        raw_split_key: raw_engine_words_split,
        "raw_tesseract_words": tess_words,
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            # Cells with confidence 0 are not counted as low confidence.
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
            raw_engine_key.replace("raw_", ""): len(raw_engine_words),
            raw_split_key.replace("raw_", ""): len(raw_engine_words_split),
            "tesseract_words": len(tess_words),
            "merged_words": len(merged_words),
        },
    }


async def _load_session_image(session_id: str):
    """Load and decode the best available image of a session.

    Tries the ``cropped``, ``dewarped`` and ``original`` images in that order
    and uses the first one that exists.

    Returns:
        ``(img_png, img_bgr)``: the stored PNG bytes and the decoded image.

    Raises:
        HTTPException: 404 if the session has none of the images, 400 if the
            stored bytes cannot be decoded.
    """
    img_png = None
    for image_type in ("cropped", "dewarped", "original"):
        img_png = await get_session_image(session_id, image_type)
        if img_png:
            break
    if not img_png:
        raise HTTPException(status_code=404, detail="No image found for this session")

    img_arr = np.frombuffer(img_png, dtype=np.uint8)
    img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Failed to decode image")

    return img_png, img_bgr
_load_session_image(session_id) + img_h, img_w = img_bgr.shape[:2] + + from cv_ocr_engines import ocr_region_paddle + + t0 = time.time() + + paddle_words = await ocr_region_paddle(img_bgr, region=None) + if not paddle_words: + paddle_words = [] + + tess_words = _run_tesseract_words(img_bgr) + + paddle_words_split = _split_paddle_multi_words(paddle_words) + logger.info( + "paddle_kombi: split %d paddle boxes -> %d individual words", + len(paddle_words), len(paddle_words_split), + ) + + if not paddle_words_split and not tess_words: + raise HTTPException(status_code=400, detail="Both OCR engines returned no words") + + merged_words = _merge_paddle_tesseract(paddle_words_split, tess_words) + merged_words = _deduplicate_words(merged_words) + + cells, columns_meta = build_grid_from_words(merged_words, img_w, img_h) + duration = time.time() - t0 + + for cell in cells: + cell["ocr_engine"] = "kombi" + + word_result = _build_kombi_word_result( + cells, columns_meta, img_w, img_h, duration, "kombi", + paddle_words, paddle_words_split, tess_words, merged_words, + "raw_paddle_words", "raw_paddle_words_split", + ) + + await update_session_db( + session_id, word_result=word_result, cropped_png=img_png, current_step=8, + ) + if session_id in _cache: + _cache[session_id]["word_result"] = word_result + + logger.info( + "paddle_kombi session %s: %d cells (%d rows, %d cols) in %.2fs " + "[paddle=%d, tess=%d, merged=%d]", + session_id, len(cells), word_result["grid_shape"]["rows"], + word_result["grid_shape"]["cols"], duration, + len(paddle_words), len(tess_words), len(merged_words), + ) + + await _append_pipeline_log(session_id, "paddle_kombi", { + "total_cells": len(cells), + "non_empty_cells": word_result["summary"]["non_empty_cells"], + "paddle_words": len(paddle_words), + "tesseract_words": len(tess_words), + "merged_words": len(merged_words), + "ocr_engine": "kombi", + }, duration_ms=int(duration * 1000)) + + return {"session_id": session_id, **word_result} + + 
@router.post("/sessions/{session_id}/rapid-kombi")
async def rapid_kombi(session_id: str):
    """Run RapidOCR + Tesseract on the preprocessed image and merge results.

    The merged words are turned into a cell grid, stored as the session's
    ``word_result`` (DB and cache), and returned together with the session id.

    Raises:
        HTTPException: 400 if neither engine returned any words; image
            loading errors propagate from ``_load_session_image``.
    """
    img_png, img_bgr = await _load_session_image(session_id)
    img_h, img_w = img_bgr.shape[:2]

    from cv_ocr_engines import ocr_region_rapid
    from cv_vocab_types import PageRegion

    t0 = time.time()

    full_region = PageRegion(
        type="full_page", x=0, y=0, width=img_w, height=img_h,
    )
    rapid_words = ocr_region_rapid(img_bgr, full_region)
    if not rapid_words:
        rapid_words = []

    tess_words = _run_tesseract_words(img_bgr)

    # The Paddle splitter is reused to break multi-word boxes into single words.
    rapid_words_split = _split_paddle_multi_words(rapid_words)
    logger.info(
        "rapid_kombi: split %d rapid boxes -> %d individual words",
        len(rapid_words), len(rapid_words_split),
    )

    if not rapid_words_split and not tess_words:
        raise HTTPException(status_code=400, detail="Both OCR engines returned no words")

    merged_words = _merge_paddle_tesseract(rapid_words_split, tess_words)
    merged_words = _deduplicate_words(merged_words)

    cells, columns_meta = build_grid_from_words(merged_words, img_w, img_h)
    duration = time.time() - t0

    for cell in cells:
        cell["ocr_engine"] = "rapid_kombi"

    word_result = _build_kombi_word_result(
        cells, columns_meta, img_w, img_h, duration, "rapid_kombi",
        rapid_words, rapid_words_split, tess_words, merged_words,
        "raw_rapid_words", "raw_rapid_words_split",
    )

    await update_session_db(
        session_id, word_result=word_result, cropped_png=img_png, current_step=8,
    )
    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    logger.info(
        "rapid_kombi session %s: %d cells (%d rows, %d cols) in %.2fs "
        "[rapid=%d, tess=%d, merged=%d]",
        session_id, len(cells), word_result["grid_shape"]["rows"],
        word_result["grid_shape"]["cols"], duration,
        len(rapid_words), len(tess_words), len(merged_words),
    )

    await _append_pipeline_log(session_id, "rapid_kombi", {
        "total_cells": len(cells),
        "non_empty_cells": word_result["summary"]["non_empty_cells"],
        "rapid_words": len(rapid_words),
        "tesseract_words": len(tess_words),
        "merged_words": len(merged_words),
        "ocr_engine": "rapid_kombi",
    }, duration_ms=int(duration * 1000))

    return {"session_id": session_id, **word_result}


# ---------------------------------------------------------------------------
# Step 1: Orientation
# ---------------------------------------------------------------------------

@router.post("/sessions/{session_id}/orientation")
async def detect_orientation(session_id: str):
    """Detect and fix 90/180/270 degree rotations from scanners.

    Reads the original image, applies orientation correction, and stores the
    result as ``oriented_png`` (DB) and ``oriented_bgr`` (cache).

    Raises:
        HTTPException: 400 if the session has no original image, 500 if the
            corrected image cannot be encoded as PNG.
    """
    cached = await ensure_cached(session_id)

    img_bgr = cached.get("original_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Original image not available")

    t0 = time.time()

    # Work on a copy so the cached original stays untouched.
    oriented_bgr, orientation_deg = detect_and_fix_orientation(img_bgr.copy())

    duration = time.time() - t0

    orientation_result = {
        "orientation_degrees": orientation_deg,
        "corrected": orientation_deg != 0,
        "duration_seconds": round(duration, 2),
    }

    # Encode before touching cache or DB: a failed encode must not leave an
    # empty image persisted while the session advances to the next step.
    success, png_buf = cv2.imencode(".png", oriented_bgr)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode oriented image")
    oriented_png = png_buf.tobytes()

    # Update cache
    cached["oriented_bgr"] = oriented_bgr
    cached["orientation_result"] = orientation_result

    # Persist to DB
    await update_session_db(
        session_id,
        oriented_png=oriented_png,
        orientation_result=orientation_result,
        current_step=2,
    )

    logger.info(
        "OCR Pipeline: orientation session %s: %d° (%s) in %.2fs",
        session_id, orientation_deg,
        "corrected" if orientation_deg else "no change",
        duration,
    )

    await append_pipeline_log(session_id, "orientation", {
        "orientation_degrees": orientation_deg,
        "corrected": orientation_deg != 0,
    }, duration_ms=int(duration * 1000))

    h, w = oriented_bgr.shape[:2]
    return {
        "session_id": session_id,
        **orientation_result,
        "image_width": w,
        "image_height": h,
        "oriented_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/oriented",
    }


# ---------------------------------------------------------------------------
# Step 1b: Page-split detection — runs AFTER orientation, BEFORE deskew
# ---------------------------------------------------------------------------

@router.post("/sessions/{session_id}/page-split")
async def detect_page_split(session_id: str):
    """Detect if the image is a double-page book spread and split into sub-sessions.

    Must be called **after orientation** (step 1) and **before deskew** (step 2).
    Each sub-session receives the raw page region and goes through the full
    pipeline (deskew -> dewarp -> crop -> columns -> rows -> words -> grid)
    independently, so each page gets its own deskew correction.

    Returns ``{"multi_page": false}`` if only one page is detected.

    Raises:
        HTTPException: 400 if the session has neither an oriented nor an
            original image.
    """
    cached = await ensure_cached(session_id)

    # Use oriented (preferred), fall back to original
    img_bgr = next(
        (v for k in ("oriented_bgr", "original_bgr")
         if (v := cached.get(k)) is not None),
        None,
    )
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="No image available for page-split detection")

    t0 = time.time()
    page_splits = detect_page_splits(img_bgr)
    used_original = False

    if not page_splits or len(page_splits) < 2:
        # Orientation may have rotated a landscape double-page spread to
        # portrait. Try the original (pre-orientation) image as fallback.
        orig_bgr = cached.get("original_bgr")
        if orig_bgr is not None and orig_bgr is not img_bgr:
            page_splits_orig = detect_page_splits(orig_bgr)
            if page_splits_orig and len(page_splits_orig) >= 2:
                logger.info(
                    "OCR Pipeline: page-split session %s: spread detected on "
                    "ORIGINAL (orientation rotated it away)",
                    session_id,
                )
                img_bgr = orig_bgr
                page_splits = page_splits_orig
                used_original = True

    if not page_splits or len(page_splits) < 2:
        duration = time.time() - t0
        logger.info(
            "OCR Pipeline: page-split session %s: single page (%.2fs)",
            session_id, duration,
        )
        return {
            "session_id": session_id,
            "multi_page": False,
            "duration_seconds": round(duration, 2),
        }

    # Multi-page spread detected — create sub-sessions for full pipeline.
    # start_step=2 means "ready for deskew" (orientation already applied).
    # start_step=1 means "needs orientation too" (split from original image).
    start_step = 1 if used_original else 2
    sub_sessions = await create_page_sub_sessions_full(
        session_id, cached, img_bgr, page_splits, start_step=start_step,
    )
    duration = time.time() - t0

    split_info: Dict[str, Any] = {
        "multi_page": True,
        "page_count": len(page_splits),
        "page_splits": page_splits,
        "used_original": used_original,
        "duration_seconds": round(duration, 2),
    }

    # Mark parent session as split and hidden from session list
    await update_session_db(session_id, crop_result=split_info, status='split')
    cached["crop_result"] = split_info

    await append_pipeline_log(session_id, "page_split", {
        "multi_page": True,
        "page_count": len(page_splits),
    }, duration_ms=int(duration * 1000))

    logger.info(
        "OCR Pipeline: page-split session %s: %d pages detected in %.2fs",
        session_id, len(page_splits), duration,
    )

    h, w = img_bgr.shape[:2]
    return {
        "session_id": session_id,
        **split_info,
        "image_width": w,
        "image_height": h,
        "sub_sessions": sub_sessions,
    }
# Reference to the shared cache from ocr_pipeline_api (set in main.py)
_cache: Dict[str, Dict[str, Any]] = {}

# Image types stored per session; each is cached under "<type>_bgr".
_CACHED_IMAGE_TYPES = ("original", "oriented", "cropped", "deskewed", "dewarped")


def set_cache_ref(cache: Dict[str, Dict[str, Any]]):
    """Set reference to the shared cache from ocr_pipeline_api."""
    global _cache
    _cache = cache


def get_cache_ref() -> Dict[str, Dict[str, Any]]:
    """Get reference to the shared cache."""
    return _cache


async def ensure_cached(session_id: str) -> Dict[str, Any]:
    """Return the cache entry for a session, loading it from the DB if needed.

    A freshly loaded entry holds the session's DB fields plus one
    ``<type>_bgr`` key per stored image type; the value is the decoded image,
    or ``None`` when that image is missing or cannot be decoded.

    Raises:
        HTTPException: 404 if the session does not exist in the DB.
    """
    if session_id in _cache:
        return _cache[session_id]

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    cache_entry: Dict[str, Any] = {
        "id": session_id,
        **session,
        **{f"{img_type}_bgr": None for img_type in _CACHED_IMAGE_TYPES},
    }

    for img_type in _CACHED_IMAGE_TYPES:
        png_data = await get_session_image(session_id, img_type)
        if png_data:
            arr = np.frombuffer(png_data, dtype=np.uint8)
            cache_entry[f"{img_type}_bgr"] = cv2.imdecode(arr, cv2.IMREAD_COLOR)

    _cache[session_id] = cache_entry
    return cache_entry


async def append_pipeline_log(session_id: str, step: str, metrics: dict, duration_ms: int):
    """Append a step entry to the session's pipeline log and persist it.

    Does nothing if the session does not exist. Every entry is recorded with
    ``success: True`` and a naive-UTC ISO timestamp.
    """
    from datetime import datetime, timezone

    session = await get_session_db(session_id)
    if not session:
        return

    pipeline_log = session.get("pipeline_log") or {"steps": []}
    # Tolerate a stored log that has no usable "steps" list.
    steps = pipeline_log.get("steps")
    if not isinstance(steps, list):
        steps = []
        pipeline_log["steps"] = steps

    steps.append({
        "step": step,
        # Same naive-UTC string that utcnow() produced, without the
        # deprecated call.
        "completed_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
        "success": True,
        "duration_ms": duration_ms,
        "metrics": metrics,
    })
    await update_session_db(session_id, pipeline_log=pipeline_log)
async def _get_columns_overlay(session_id: str) -> Response:
    """Generate cropped (or dewarped) image with column borders drawn on it.

    Each column from the session's ``column_result`` gets a translucent fill,
    a solid border and a type label (with the classification confidence when
    it is below 1.0). Zones of type ``box`` get a dashed outline.

    Returns:
        A PNG ``Response`` containing the annotated image.

    Raises:
        HTTPException: 404 if the session, its column data or its base image
            is missing; 500 if the image cannot be decoded or encoded.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    column_result = session.get("column_result")
    if not column_result or not column_result.get("columns"):
        raise HTTPException(status_code=404, detail="No column data available")

    # Load best available base image (cropped > dewarped > original)
    base_png = await _get_base_image_png(session_id)
    if not base_png:
        raise HTTPException(status_code=404, detail="No base image available")

    arr = np.frombuffer(base_png, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    # Color map for region types (BGR)
    colors = {
        "column_en": (255, 180, 0),       # Blue
        "column_de": (0, 200, 0),         # Green
        "column_example": (0, 140, 255),  # Orange
        "column_text": (200, 200, 0),     # Cyan/Turquoise
        "page_ref": (200, 0, 200),        # Purple
        "column_marker": (0, 0, 220),     # Red
        "column_ignore": (180, 180, 180), # Light Gray
        "header": (128, 128, 128),        # Gray
        "footer": (128, 128, 128),        # Gray
        "margin_top": (100, 100, 100),    # Dark Gray
        "margin_bottom": (100, 100, 100), # Dark Gray
    }

    # Fills go onto a separate copy so they can be blended in translucently;
    # borders and labels are drawn on the image itself.
    overlay = img.copy()
    for col in column_result["columns"]:
        x, y = col["x"], col["y"]
        w, h = col["width"], col["height"]
        # Unknown region types fall back to a neutral gray.
        color = colors.get(col.get("type", ""), (200, 200, 200))

        # Semi-transparent fill
        cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1)

        # Solid border
        cv2.rectangle(img, (x, y), (x + w, y + h), color, 3)

        # Label with confidence
        label = col.get("type", "unknown").replace("column_", "").upper()
        conf = col.get("classification_confidence")
        if conf is not None and conf < 1.0:
            label = f"{label} {int(conf * 100)}%"
        cv2.putText(img, label, (x + 10, y + 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2)

    # Blend overlay at 20% opacity
    cv2.addWeighted(overlay, 0.2, img, 0.8, 0, img)

    # Draw detected box boundaries as dashed rectangles
    zones = column_result.get("zones") or []
    for zone in zones:
        if zone.get("zone_type") == "box" and zone.get("box"):
            box = zone["box"]
            bx, by = box["x"], box["y"]
            bw, bh = box["width"], box["height"]
            box_color = (0, 200, 255)  # Yellow (BGR)
            # Draw dashed rectangle by drawing short line segments
            dash_len = 15
            # Top and bottom edges: one dash, then a gap of the same length.
            for edge_x in range(bx, bx + bw, dash_len * 2):
                end_x = min(edge_x + dash_len, bx + bw)
                cv2.line(img, (edge_x, by), (end_x, by), box_color, 2)
                cv2.line(img, (edge_x, by + bh), (end_x, by + bh), box_color, 2)
            # Left and right edges.
            for edge_y in range(by, by + bh, dash_len * 2):
                end_y = min(edge_y + dash_len, by + bh)
                cv2.line(img, (bx, edge_y), (bx, end_y), box_color, 2)
                cv2.line(img, (bx + bw, edge_y), (bx + bw, end_y), box_color, 2)
            cv2.putText(img, "BOX", (bx + 10, by + bh - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, box_color, 2)

    # Red semi-transparent overlay for box zones
    _draw_box_exclusion_overlay(img, zones)

    success, result_png = cv2.imencode(".png", img)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")

    return Response(content=result_png.tobytes(), media_type="image/png")
session.get("row_result") + if not row_result or not row_result.get("rows"): + raise HTTPException(status_code=404, detail="No row data available") + + # Load best available base image (cropped > dewarped > original) + base_png = await _get_base_image_png(session_id) + if not base_png: + raise HTTPException(status_code=404, detail="No base image available") + + arr = np.frombuffer(base_png, dtype=np.uint8) + img = cv2.imdecode(arr, cv2.IMREAD_COLOR) + if img is None: + raise HTTPException(status_code=500, detail="Failed to decode image") + + # Color map for row types (BGR) + row_colors = { + "content": (255, 180, 0), # Blue + "header": (128, 128, 128), # Gray + "footer": (128, 128, 128), # Gray + "margin_top": (100, 100, 100), # Dark Gray + "margin_bottom": (100, 100, 100), # Dark Gray + } + + overlay = img.copy() + for row in row_result["rows"]: + x, y = row["x"], row["y"] + w, h = row["width"], row["height"] + row_type = row.get("row_type", "content") + color = row_colors.get(row_type, (200, 200, 200)) + + # Semi-transparent fill + cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1) + + # Solid border + cv2.rectangle(img, (x, y), (x + w, y + h), color, 2) + + # Label + idx = row.get("index", 0) + label = f"R{idx} {row_type.upper()}" + wc = row.get("word_count", 0) + if wc: + label = f"{label} ({wc}w)" + cv2.putText(img, label, (x + 5, y + 18), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1) + + # Blend overlay at 15% opacity + cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img) + + # Draw zone separator lines if zones exist + column_result = session.get("column_result") or {} + zones = column_result.get("zones") or [] + if zones: + img_w_px = img.shape[1] + zone_color = (0, 200, 255) # Yellow (BGR) + dash_len = 20 + for zone in zones: + if zone.get("zone_type") == "box": + zy = zone["y"] + zh = zone["height"] + for line_y in [zy, zy + zh]: + for sx in range(0, img_w_px, dash_len * 2): + ex = min(sx + dash_len, img_w_px) + cv2.line(img, (sx, line_y), (ex, line_y), 
async def _get_words_overlay(session_id: str) -> Response:
    """Generate cropped (or dewarped) image with cell grid drawn on it.

    Uses the cell-based ``word_result`` format when ``cells`` is present and
    falls back to the legacy entry-based format (columns x content rows)
    otherwise.

    Returns:
        A PNG ``Response`` containing the annotated image.

    Raises:
        HTTPException: 404 if the session, its word data or its base image
            is missing; 500 if the image cannot be decoded or encoded.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=404, detail="No word data available")

    # Support both new cell-based and legacy entry-based formats
    cells = word_result.get("cells")
    if not cells and not word_result.get("entries"):
        raise HTTPException(status_code=404, detail="No word data available")

    # Load best available base image (cropped > dewarped > original)
    base_png = await _get_base_image_png(session_id)
    if not base_png:
        raise HTTPException(status_code=404, detail="No base image available")

    arr = np.frombuffer(base_png, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    # NOTE(review): img_h / img_w are not used below.
    img_h, img_w = img.shape[:2]

    # Fills go onto a separate copy so they can be blended in translucently.
    overlay = img.copy()

    if cells:
        # New cell-based overlay: color by column index
        col_palette = [
            (255, 180, 0),   # Blue (BGR)
            (0, 200, 0),     # Green
            (0, 140, 255),   # Orange
            (200, 100, 200), # Purple
            (200, 200, 0),   # Cyan
            (100, 200, 200), # Yellow-ish
        ]

        for cell in cells:
            bbox = cell.get("bbox_px", {})
            cx = bbox.get("x", 0)
            cy = bbox.get("y", 0)
            cw = bbox.get("w", 0)
            ch = bbox.get("h", 0)
            # Skip cells without a drawable box.
            if cw <= 0 or ch <= 0:
                continue

            col_idx = cell.get("col_index", 0)
            # The palette repeats when there are more columns than colors.
            color = col_palette[col_idx % len(col_palette)]

            # Cell rectangle border
            cv2.rectangle(img, (cx, cy), (cx + cw, cy + ch), color, 1)
            # Semi-transparent fill
            cv2.rectangle(overlay, (cx, cy), (cx + cw, cy + ch), color, -1)

            # Cell-ID label (top-left corner)
            cell_id = cell.get("cell_id", "")
            cv2.putText(img, cell_id, (cx + 2, cy + 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.28, color, 1)

            # Text label (bottom of cell)
            text = cell.get("text", "")
            if text:
                # Text color encodes confidence: >= 70, >= 50, below 50.
                conf = cell.get("confidence", 0)
                if conf >= 70:
                    text_color = (0, 180, 0)
                elif conf >= 50:
                    text_color = (0, 180, 220)
                else:
                    text_color = (0, 0, 220)

                # Single line, at most 30 characters.
                label = text.replace('\n', ' ')[:30]
                cv2.putText(img, label, (cx + 3, cy + ch - 4),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)
    else:
        # Legacy fallback: entry-based overlay (for old sessions)
        column_result = session.get("column_result")
        row_result = session.get("row_result")
        col_colors = {
            "column_en": (255, 180, 0),
            "column_de": (0, 200, 0),
            "column_example": (0, 140, 255),
        }

        # Only regions whose type starts with "column_" form grid columns.
        columns = []
        if column_result and column_result.get("columns"):
            columns = [c for c in column_result["columns"]
                       if c.get("type", "").startswith("column_")]

        # Only rows of type "content" form grid rows.
        content_rows_data = []
        if row_result and row_result.get("rows"):
            content_rows_data = [r for r in row_result["rows"]
                                 if r.get("row_type") == "content"]

        # Draw one grid cell per (column, content row) pair.
        for col in columns:
            col_type = col.get("type", "")
            color = col_colors.get(col_type, (200, 200, 200))
            cx, cw = col["x"], col["width"]
            for row in content_rows_data:
                ry, rh = row["y"], row["height"]
                cv2.rectangle(img, (cx, ry), (cx + cw, ry + rh), color, 1)
                cv2.rectangle(overlay, (cx, ry), (cx + cw, ry + rh), color, -1)

        # Index entries by row_index; a later entry with the same index
        # replaces an earlier one.
        entries = word_result["entries"]
        entry_by_row: Dict[int, Dict] = {}
        for entry in entries:
            entry_by_row[entry.get("row_index", -1)] = entry

        # Entries are looked up by position within the content rows.
        for row_idx, row in enumerate(content_rows_data):
            entry = entry_by_row.get(row_idx)
            if not entry:
                continue
            conf = entry.get("confidence", 0)
            text_color = (0, 180, 0) if conf >= 70 else (0, 180, 220) if conf >= 50 else (0, 0, 220)
            ry, rh = row["y"], row["height"]
            for col in columns:
                col_type = col.get("type", "")
                cx, cw = col["x"], col["width"]
                # Map the column type to the entry field that holds its text.
                field = {"column_en": "english", "column_de": "german", "column_example": "example"}.get(col_type, "")
                text = entry.get(field, "") if field else ""
                if text:
                    label = text.replace('\n', ' ')[:30]
                    cv2.putText(img, label, (cx + 3, ry + rh - 4),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)

    # Blend overlay at 10% opacity
    cv2.addWeighted(overlay, 0.1, img, 0.9, 0, img)

    # Red semi-transparent overlay for box zones
    column_result = session.get("column_result") or {}
    zones = column_result.get("zones") or []
    _draw_box_exclusion_overlay(img, zones)

    success, result_png = cv2.imencode(".png", img)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")

    return Response(content=result_png.tobytes(), media_type="image/png")
async def _get_structure_overlay(session_id: str) -> Response:
    """Render an overlay showing detected boxes, zones, color regions and graphics.

    The session's cached ``structure_result`` is used when present; otherwise
    box/zone detection is run on the fly inside a 3% page margin.

    Drawing order: dashed zone boundaries and labels, box fills (blended at
    15% opacity) with solid borders and labels, contours of colored regions
    found through the ``_COLOR_RANGES`` HSV masks, and finally dashed
    bounding boxes with labels for graphic elements.

    Args:
        session_id: Identifier of the OCR pipeline session.

    Returns:
        A ``image/png`` response containing the rendered overlay.

    Raises:
        HTTPException: 404 if no base image is available; 500 if the image
            cannot be decoded or the overlay cannot be encoded as PNG.
    """
    base_png = await _get_base_image_png(session_id)
    if not base_png:
        raise HTTPException(status_code=404, detail="No base image available")

    arr = np.frombuffer(base_png, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    h, w = img.shape[:2]

    # Color masks must be computed from the undrawn page, so convert to HSV
    # now, before anything is painted onto ``img`` (avoids a second PNG decode).
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)

    # Get structure result (run detection if not cached)
    session = await get_session_db(session_id)
    structure = (session or {}).get("structure_result")

    if not structure:
        # Run detection on-the-fly
        margin = int(min(w, h) * 0.03)
        content_x, content_y = margin, margin
        content_w_px = w - 2 * margin
        content_h_px = h - 2 * margin
        boxes = detect_boxes(img, content_x, content_w_px, content_y, content_h_px)
        zones = split_page_into_zones(content_x, content_y, content_w_px, content_h_px, boxes)
        structure = {
            "boxes": [
                {"x": b.x, "y": b.y, "w": b.width, "h": b.height,
                 "confidence": b.confidence, "border_thickness": b.border_thickness}
                for b in boxes
            ],
            "zones": [
                {"index": z.index, "zone_type": z.zone_type,
                 "y": z.y, "h": z.height, "x": z.x, "w": z.width}
                for z in zones
            ],
        }

    # Filled shapes are painted onto this copy and blended in later.
    overlay = img.copy()

    # --- Draw zone boundaries ---
    zone_colors = {
        "content": (200, 200, 200),   # light gray
        "box": (255, 180, 0),         # blue-ish (BGR)
    }
    for zone in structure.get("zones", []):
        zx = zone["x"]
        zy = zone["y"]
        zw = zone["w"]
        zh = zone["h"]
        color = zone_colors.get(zone["zone_type"], (200, 200, 200))

        # Draw zone boundary as dashed line (top and bottom edges only)
        dash_len = 12
        for edge_x in range(zx, zx + zw, dash_len * 2):
            end_x = min(edge_x + dash_len, zx + zw)
            cv2.line(img, (edge_x, zy), (end_x, zy), color, 1)
            cv2.line(img, (edge_x, zy + zh), (end_x, zy + zh), color, 1)

        # Zone label
        zone_label = f"Zone {zone['index']} ({zone['zone_type']})"
        cv2.putText(img, zone_label, (zx + 5, zy + 15),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.45, color, 1)

    # --- Draw detected boxes ---
    # Color map for box backgrounds (BGR)
    bg_hex_to_bgr = {
        "#dc2626": (38, 38, 220),     # red
        "#2563eb": (235, 99, 37),     # blue
        "#16a34a": (74, 163, 22),     # green
        "#ea580c": (12, 88, 234),     # orange
        "#9333ea": (234, 51, 147),    # purple
        "#ca8a04": (4, 138, 202),     # yellow
        "#6b7280": (128, 114, 107),   # gray
    }

    for box_data in structure.get("boxes", []):
        bx = box_data["x"]
        by = box_data["y"]
        bw = box_data["w"]
        bh = box_data["h"]
        conf = box_data.get("confidence", 0)
        thickness = box_data.get("border_thickness", 0)
        bg_hex = box_data.get("bg_color_hex", "#6b7280")
        bg_name = box_data.get("bg_color_name", "")

        # Box fill color (gray when the hex value is not in the map)
        fill_bgr = bg_hex_to_bgr.get(bg_hex, (128, 114, 107))

        # Semi-transparent fill
        cv2.rectangle(overlay, (bx, by), (bx + bw, by + bh), fill_bgr, -1)

        # Solid border
        border_color = fill_bgr
        cv2.rectangle(img, (bx, by), (bx + bw, by + bh), border_color, 3)

        # Label
        label = "BOX"
        if bg_name and bg_name not in ("unknown", "white"):
            label += f" ({bg_name})"
        if thickness > 0:
            label += f" border={thickness}px"
        label += f" {int(conf * 100)}%"
        # Thick white pass first, then the colored text on top, so the label
        # stays legible on any background.
        cv2.putText(img, label, (bx + 8, by + 22),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 2)
        cv2.putText(img, label, (bx + 8, by + 22),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.55, border_color, 1)

    # Blend overlay at 15% opacity
    cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img)

    # --- Draw color regions (HSV masks) ---
    color_bgr_map = {
        "red": (0, 0, 255),
        "orange": (0, 140, 255),
        "yellow": (0, 200, 255),
        "green": (0, 200, 0),
        "blue": (255, 150, 0),
        "purple": (200, 0, 200),
    }
    for color_name, ranges in _COLOR_RANGES.items():
        mask = np.zeros((h, w), dtype=np.uint8)
        for lower, upper in ranges:
            mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper))
        # Only draw if there are significant colored pixels
        if np.sum(mask > 0) < 100:
            continue
        # Draw colored contours
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        draw_color = color_bgr_map.get(color_name, (200, 200, 200))
        for cnt in contours:
            area = cv2.contourArea(cnt)
            if area < 20:
                continue  # skip specks
            cv2.drawContours(img, [cnt], -1, draw_color, 2)

    # --- Draw graphic elements ---
    graphics_data = structure.get("graphics", [])
    shape_icons = {
        "image": "IMAGE",
        "illustration": "ILLUST",
    }
    for gfx in graphics_data:
        gx, gy = gfx["x"], gfx["y"]
        gw, gh = gfx["w"], gfx["h"]
        shape = gfx.get("shape", "icon")
        color_hex = gfx.get("color_hex", "#6b7280")
        conf = gfx.get("confidence", 0)

        # Pick draw color based on element color (BGR)
        gfx_bgr = bg_hex_to_bgr.get(color_hex, (128, 114, 107))

        # Draw bounding box (dashed style via short segments)
        dash = 6
        for seg_x in range(gx, gx + gw, dash * 2):
            end_x = min(seg_x + dash, gx + gw)
            cv2.line(img, (seg_x, gy), (end_x, gy), gfx_bgr, 2)
            cv2.line(img, (seg_x, gy + gh), (end_x, gy + gh), gfx_bgr, 2)
        for seg_y in range(gy, gy + gh, dash * 2):
            end_y = min(seg_y + dash, gy + gh)
            cv2.line(img, (gx, seg_y), (gx, end_y), gfx_bgr, 2)
            cv2.line(img, (gx + gw, seg_y), (gx + gw, end_y), gfx_bgr, 2)

        # Label
        icon = shape_icons.get(shape, shape.upper()[:5])
        label = f"{icon} {int(conf * 100)}%"
        # White background for readability
        (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)
        lx = gx + 2
        ly = max(gy - 4, th + 4)  # keep the label inside the image at the top
        cv2.rectangle(img, (lx - 1, ly - th - 2), (lx + tw + 2, ly + 3), (255, 255, 255), -1)
        cv2.putText(img, label, (lx, ly), cv2.FONT_HERSHEY_SIMPLEX, 0.4, gfx_bgr, 1)

    # Encode result; report a failed encode the same way the other overlay
    # renderers do instead of returning a broken body.
    success, png_buf = cv2.imencode(".png", img)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")
    return Response(content=png_buf.tobytes(), media_type="image/png")
Crop — Barrel Re-export + +Content-based crop for scanned pages and book scans. + +Split into: +- page_crop_edges.py — Edge detection (spine shadow, gutter, projection) +- page_crop_core.py — Main crop algorithm and format detection + +All public names are re-exported here for backward compatibility. +License: Apache 2.0 +""" + +# Core: main crop functions and format detection +from .page_crop_core import ( # noqa: F401 + PAPER_FORMATS, + detect_page_splits, + detect_and_crop_page, + _detect_format, +) + +# Edge detection helpers +from .page_crop_edges import ( # noqa: F401 + _INK_THRESHOLD, + _MIN_RUN_FRAC, + _detect_spine_shadow, + _detect_gutter_continuity, + _detect_left_edge_shadow, + _detect_right_edge_shadow, + _detect_top_bottom_edges, + _detect_edge_projection, + _filter_narrow_runs, +) diff --git a/klausur-service/backend/ocr/pipeline/page_crop_core.py b/klausur-service/backend/ocr/pipeline/page_crop_core.py new file mode 100644 index 0000000..3eca367 --- /dev/null +++ b/klausur-service/backend/ocr/pipeline/page_crop_core.py @@ -0,0 +1,342 @@ +""" +Page Crop - Core Crop and Format Detection + +Content-based crop for scanned pages and book scans. Detects the content +boundary by analysing ink density projections and (for book scans) the +spine shadow gradient. + +Extracted from page_crop.py to keep files under 500 LOC. 
def detect_page_splits(
    img_bgr: np.ndarray,
) -> list:
    """Detect if the image is a multi-page spread and return split rectangles.

    Uses **brightness** (not ink density) to find the spine area:
    the scanner bed produces a characteristic gray strip where pages meet,
    which is darker than the white paper on either side.

    Steps:
    1. Skip images that are not landscape-ish (``w < 1.15 * h``).
    2. Build a heavily smoothed column-mean brightness profile.
    3. In the central 30-70% of the width, collect contiguous runs of columns
       darker than 88% of the brightest smoothed column.
    4. Score each run (centered, dark and narrow is best) and keep the best.
    5. Require bright paper on both sides of that run, then split at its
       center.

    Args:
        img_bgr: Input image as a BGR array.

    Returns:
        A list of page dicts ``{x, y, width, height, page_index}``
        or an empty list if only one page is detected.
    """
    h, w = img_bgr.shape[:2]

    # Only check landscape-ish images (width > height * 1.15)
    if w < h * 1.15:
        return []

    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

    # Column-mean brightness (0-255) — the spine is darker (gray scanner bed)
    col_brightness = np.mean(gray, axis=0).astype(np.float64)

    # Heavy smoothing to ignore individual text lines (odd kernel width)
    kern = max(11, w // 50)
    if kern % 2 == 0:
        kern += 1
    brightness_smooth = np.convolve(col_brightness, np.ones(kern) / kern, mode="same")

    # Page paper is bright (typically > 200), spine/scanner bed is darker
    page_brightness = float(np.max(brightness_smooth))
    if page_brightness < 100:
        return []  # Very dark image, skip

    # Spine threshold: significantly darker than the page
    spine_thresh = page_brightness * 0.88

    # Search in center region (30-70% of width)
    center_lo = int(w * 0.30)
    center_hi = int(w * 0.70)

    # Find the darkest valley in the center region
    center_brightness = brightness_smooth[center_lo:center_hi]
    darkest_val = float(np.min(center_brightness))

    if darkest_val >= spine_thresh:
        logger.debug("No spine detected: min brightness %.0f >= threshold %.0f",
                     darkest_val, spine_thresh)
        return []

    # Find ALL contiguous dark runs in the center region.
    # Padding with 0 on both sides makes every run produce exactly one rising
    # (+1) and one falling (-1) edge; runs are (start, end_exclusive) pairs
    # relative to the center region.
    is_dark = center_brightness < spine_thresh
    edges = np.diff(np.concatenate(([0], is_dark.astype(np.int8), [0])))
    run_starts = np.flatnonzero(edges == 1).tolist()
    run_ends = np.flatnonzero(edges == -1).tolist()
    dark_runs: list = list(zip(run_starts, run_ends))

    # Filter out runs that are too narrow (< 1% of image width)
    min_spine_px = int(w * 0.01)
    dark_runs = [(s, e) for s, e in dark_runs if e - s >= min_spine_px]

    if not dark_runs:
        logger.debug("No dark runs wider than %dpx in center region", min_spine_px)
        return []

    # Score each dark run: prefer centered, dark, narrow valleys
    center_region_len = center_hi - center_lo
    image_center_in_region = (w * 0.5 - center_lo)
    sigma = center_region_len * 0.15  # width of the Gaussian centering prior
    best_score = -1.0
    best_start, best_end = dark_runs[0]

    for run_lo, run_hi in dark_runs:
        run_width = run_hi - run_lo
        run_center = (run_lo + run_hi) / 2.0

        # 1.0 at the exact image center, falling off as a Gaussian
        dist = abs(run_center - image_center_in_region)
        center_factor = float(np.exp(-0.5 * (dist / sigma) ** 2))

        # How far below the threshold the run is, as a fraction of it
        run_brightness = float(np.mean(center_brightness[run_lo:run_hi]))
        darkness_factor = max(0.0, (spine_thresh - run_brightness) / spine_thresh)

        # Full bonus up to 5% of the width, linearly down to 0 at 15%
        width_frac = run_width / w
        if width_frac <= 0.05:
            narrowness_bonus = 1.0
        elif width_frac <= 0.15:
            narrowness_bonus = 1.0 - (width_frac - 0.05) / 0.10
        else:
            narrowness_bonus = 0.0

        score = center_factor * darkness_factor * (0.3 + 0.7 * narrowness_bonus)

        logger.debug(
            "Dark run x=%d..%d (w=%d): center_f=%.3f dark_f=%.3f narrow_b=%.3f -> score=%.4f",
            center_lo + run_lo, center_lo + run_hi, run_width,
            center_factor, darkness_factor, narrowness_bonus, score,
        )

        if score > best_score:
            best_score = score
            best_start, best_end = run_lo, run_hi

    spine_w = best_end - best_start
    spine_x = center_lo + best_start
    spine_center = spine_x + spine_w // 2

    logger.debug(
        "Best spine candidate: x=%d..%d (w=%d), score=%.4f",
        spine_x, spine_x + spine_w, spine_w, best_score,
    )

    # Verify: must have bright (paper) content on BOTH sides
    # (mean over a band 10% of the image width wide next to the run)
    left_brightness = float(np.mean(brightness_smooth[max(0, spine_x - w // 10):spine_x]))
    right_end = center_lo + best_end
    right_brightness = float(np.mean(brightness_smooth[right_end:min(w, right_end + w // 10)]))

    if left_brightness < spine_thresh or right_brightness < spine_thresh:
        logger.debug("No bright paper flanking spine: left=%.0f right=%.0f thresh=%.0f",
                     left_brightness, right_brightness, spine_thresh)
        return []

    logger.info(
        "Spine detected: x=%d..%d (w=%d), brightness=%.0f vs paper=%.0f, "
        "left_paper=%.0f, right_paper=%.0f",
        spine_x, right_end, spine_w, darkest_val, page_brightness,
        left_brightness, right_brightness,
    )

    # Split at the spine center
    split_points = [spine_center]

    # Build page rectangles
    pages: list = []
    prev_x = 0
    for i, sx in enumerate(split_points):
        pages.append({"x": prev_x, "y": 0, "width": sx - prev_x,
                      "height": h, "page_index": i})
        prev_x = sx
    pages.append({"x": prev_x, "y": 0, "width": w - prev_x,
                  "height": h, "page_index": len(split_points)})

    # Filter out tiny pages (< 15% of total width)
    pages = [p for p in pages if p["width"] >= w * 0.15]
    if len(pages) < 2:
        return []

    # Re-index
    for i, p in enumerate(pages):
        p["page_index"] = i

    logger.info(
        "Page split detected: %d pages, spine_w=%d, split_points=%s",
        len(pages), spine_w, split_points,
    )
    return pages
Adaptive threshold -> binary (text=255, bg=0) + 2. Left edge: spine-shadow detection via grayscale column means, + fallback to binary vertical projection + 3. Right edge: binary vertical projection (last ink column) + 4. Top/bottom edges: binary horizontal projection + 5. Sanity checks, then crop with configurable margin + + Args: + img_bgr: Input BGR image (should already be deskewed/dewarped) + margin_frac: Extra margin around content (fraction of dimension, default 1%) + + Returns: + Tuple of (cropped_image, result_dict) + """ + h, w = img_bgr.shape[:2] + total_area = h * w + + result: Dict[str, Any] = { + "crop_applied": False, + "crop_rect": None, + "crop_rect_pct": None, + "original_size": {"width": w, "height": h}, + "cropped_size": {"width": w, "height": h}, + "detected_format": None, + "format_confidence": 0.0, + "aspect_ratio": round(max(h, w) / max(min(h, w), 1), 4), + "border_fractions": {"top": 0.0, "bottom": 0.0, "left": 0.0, "right": 0.0}, + } + + gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY) + + # --- Binarise with adaptive threshold --- + binary = cv2.adaptiveThreshold( + gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, + cv2.THRESH_BINARY_INV, blockSize=51, C=15, + ) + + # --- Edge detection --- + left_edge = _detect_left_edge_shadow(gray, binary, w, h) + right_edge = _detect_right_edge_shadow(gray, binary, w, h) + top_edge, bottom_edge = _detect_top_bottom_edges(binary, w, h) + + # Compute border fractions + border_top = top_edge / h + border_bottom = (h - bottom_edge) / h + border_left = left_edge / w + border_right = (w - right_edge) / w + + result["border_fractions"] = { + "top": round(border_top, 4), + "bottom": round(border_bottom, 4), + "left": round(border_left, 4), + "right": round(border_right, 4), + } + + # Sanity: only crop if at least one edge has > 2% border + min_border = 0.02 + if all(f < min_border for f in [border_top, border_bottom, border_left, border_right]): + logger.info("All borders < %.0f%% — no crop needed", min_border * 
def _detect_format(
    width: int,
    height: int,
    formats: "Dict[str, float] | None" = None,
) -> Tuple[str, float]:
    """Detect paper format from dimensions by comparing aspect ratios.

    The aspect ratio is computed orientation-independently (long side over
    short side) and matched against the closest known ratio.

    Args:
        width: Page width in pixels.
        height: Page height in pixels.
        formats: Optional mapping of format name to expected long/short
            aspect ratio. Defaults to the module-level ``PAPER_FORMATS``.

    Returns:
        ``(format_name, confidence)`` with confidence rounded to 3 decimals,
        or ``("unknown", 0.0)`` when a dimension is non-positive, no formats
        are given, or the confidence falls below 0.3.

    Note:
        Only the ratio is compared, so formats with near-identical ratios
        (A3, A4 and A5 in ``PAPER_FORMATS`` all lie around 1.414-1.419) are
        separated by very small differences.
    """
    if width <= 0 or height <= 0:
        return "unknown", 0.0

    if formats is None:
        formats = PAPER_FORMATS
    if not formats:
        return "unknown", 0.0

    aspect = max(width, height) / min(width, height)

    # min() keeps the first entry on ties, same as a strict "<" scan.
    best_format, best_ratio = min(
        formats.items(), key=lambda item: abs(aspect - item[1])
    )
    best_diff = abs(aspect - best_ratio)

    # Confidence falls linearly: 1.0 at an exact match, 0.0 at a ratio
    # difference of 0.2 or more.
    confidence = max(0.0, 1.0 - best_diff * 5.0)

    if confidence < 0.3:
        return "unknown", 0.0

    return best_format, round(confidence, 3)
The dark area is a NARROW valley, not a text-content plateau + 4. Brightness rises significantly toward the page content side + + Args: + gray: Full grayscale image (for context). + search_region: Column slice of the grayscale image to search in. + offset_x: X offset of search_region relative to full image. + w: Full image width. + side: 'left' or 'right' (for logging). + + Returns: + X coordinate (in full image) of the spine center, or None. + """ + region_w = search_region.shape[1] + if region_w < 10: + return None + + # Column-mean brightness in the search region + col_means = np.mean(search_region, axis=0).astype(np.float64) + + # Smooth with boxcar kernel (width = 1% of image width, min 5) + kernel_size = max(5, w // 100) + if kernel_size % 2 == 0: + kernel_size += 1 + kernel = np.ones(kernel_size) / kernel_size + smoothed_raw = np.convolve(col_means, kernel, mode="same") + + # Trim convolution edge artifacts (edges are zero-padded -> artificially low) + margin = kernel_size // 2 + if region_w <= 2 * margin + 10: + return None + smoothed = smoothed_raw[margin:region_w - margin] + trim_offset = margin # offset of smoothed[0] relative to search_region + + val_min = float(np.min(smoothed)) + val_max = float(np.max(smoothed)) + shadow_range = val_max - val_min + + # --- Check 1: Strong brightness gradient --- + if shadow_range <= 40: + logger.debug( + "%s edge: no spine (range=%.0f <= 40)", side.capitalize(), shadow_range, + ) + return None + + # --- Check 2: Darkest point must be genuinely dark --- + if val_min > 180: + logger.debug( + "%s edge: no spine (darkest=%.0f > 180, likely text)", side.capitalize(), val_min, + ) + return None + + spine_idx = int(np.argmin(smoothed)) # index in trimmed array + spine_local = spine_idx + trim_offset # index in search_region + trimmed_len = len(smoothed) + + # --- Check 3: Valley width (spine is narrow, text plateau is wide) --- + valley_thresh = val_min + shadow_range * 0.20 + valley_mask = smoothed < valley_thresh + 
def _detect_gutter_continuity(
    gray: np.ndarray,
    search_region: np.ndarray,
    offset_x: int,
    w: int,
    side: str,
) -> Optional[int]:
    """Detect gutter shadow via vertical continuity analysis.

    Camera book scans produce a subtle brightness gradient at the gutter
    that is too faint for scanner-shadow detection (range < 40). However,
    the gutter shadow has a unique property: it runs **continuously from
    top to bottom** without interruption.

    Algorithm:
    1. Divide the search region into N horizontal strips (~60px each,
       at least 10 strips)
    2. For each column, compute what fraction of strips are darker than
       the page median minus 5 levels (median taken from the center 50%
       of the full image width)
    3. Smooth the dark-fraction profile; its peak must reach 0.70
    4. Walk from the peak toward the page content until the smoothed
       fraction drops below 0.50 - that is the gutter boundary
    5. Validate: gutter band must be 0.5%-10% of image width (min 3px) and
       at least 3 brightness levels darker than the page median

    Args:
        gray: Full grayscale image; used only for the page median.
        search_region: Column slice of ``gray`` to search in.
        offset_x: X offset of ``search_region`` relative to the full image.
        w: Full image width.
        side: ``"right"`` selects the right-edge logic; any other value is
            handled as the left edge.

    Returns:
        X coordinate (in the full image) of the gutter boundary, or None if
        no gutter passes all checks.
    """
    region_h, region_w = search_region.shape[:2]
    # Too small for a meaningful strip analysis.
    if region_w < 20 or region_h < 100:
        return None

    # --- 1. Divide into horizontal strips ---
    strip_target_h = 60
    n_strips = max(10, region_h // strip_target_h)
    strip_h = region_h // n_strips

    # strip_means[s, x] = mean brightness of column x within strip s.
    # Rows beyond n_strips * strip_h (integer-division remainder) are not used.
    strip_means = np.zeros((n_strips, region_w), dtype=np.float64)
    for s in range(n_strips):
        y0 = s * strip_h
        y1 = min((s + 1) * strip_h, region_h)
        strip_means[s] = np.mean(search_region[y0:y1, :], axis=0)

    # --- 2. Page median from center 50% of full image ---
    center_lo = w // 4
    center_hi = 3 * w // 4
    page_median = float(np.median(gray[:, center_lo:center_hi]))

    # A strip counts as "dark" when it is at least 5 levels below the median.
    dark_thresh = page_median - 5.0

    # Give up on pages whose median brightness is below 180.
    if page_median < 180:
        return None

    # --- 3. Per-column dark fraction ---
    # Number of dark strips per column, normalised to 0..1.
    dark_count = np.sum(strip_means < dark_thresh, axis=0).astype(np.float64)
    dark_frac = dark_count / n_strips

    # --- 4. Smooth and find transition ---
    # Boxcar width = 1% of the image width, at least 5, forced odd.
    smooth_w = max(5, w // 100)
    if smooth_w % 2 == 0:
        smooth_w += 1
    kernel = np.ones(smooth_w) / smooth_w
    frac_smooth = np.convolve(dark_frac, kernel, mode="same")

    # The outer half-kernel on each side is affected by zero padding, so it
    # is left out of the peak search below.
    margin = smooth_w // 2
    if region_w <= 2 * margin + 10:
        return None

    transition_thresh = 0.50
    peak_frac = float(np.max(frac_smooth[margin:region_w - margin]))

    if peak_frac < 0.70:
        logger.debug(
            "%s gutter: peak dark fraction %.2f < 0.70", side.capitalize(), peak_frac,
        )
        return None

    # Index (within search_region) of the column with the highest dark fraction.
    peak_x = int(np.argmax(frac_smooth[margin:region_w - margin])) + margin
    gutter_inner = None

    if side == "right":
        # Walk left from the peak; the boundary is the last column that is
        # still at or above the transition threshold.
        for x in range(peak_x, margin, -1):
            if frac_smooth[x] < transition_thresh:
                gutter_inner = x + 1
                break
    else:
        # Walk right from the peak (mirror image of the right-side case).
        for x in range(peak_x, region_w - margin):
            if frac_smooth[x] < transition_thresh:
                gutter_inner = x - 1
                break

    # The dark fraction never dropped below the threshold: no boundary found.
    if gutter_inner is None:
        return None

    # --- 5. Validate gutter width ---
    # The band runs from the boundary to the outer edge of the search region.
    if side == "right":
        gutter_width = region_w - gutter_inner
    else:
        gutter_width = gutter_inner

    min_gutter = max(3, int(w * 0.005))
    max_gutter = int(w * 0.10)

    if gutter_width < min_gutter:
        logger.debug(
            "%s gutter: too narrow (%dpx < %dpx)", side.capitalize(),
            gutter_width, min_gutter,
        )
        return None

    if gutter_width > max_gutter:
        logger.debug(
            "%s gutter: too wide (%dpx > %dpx)", side.capitalize(),
            gutter_width, max_gutter,
        )
        return None

    # Mean brightness of the gutter band itself.
    if side == "right":
        gutter_brightness = float(np.mean(strip_means[:, gutter_inner:]))
    else:
        gutter_brightness = float(np.mean(strip_means[:, :gutter_inner]))

    # The band must be measurably darker than the page.
    brightness_drop = page_median - gutter_brightness
    if brightness_drop < 3:
        logger.debug(
            "%s gutter: insufficient brightness drop (%.1f levels)",
            side.capitalize(), brightness_drop,
        )
        return None

    # Convert from search-region to full-image coordinates.
    gutter_x = offset_x + gutter_inner

    logger.info(
        "%s gutter (continuity): x=%d, width=%dpx (%.1f%%), "
        "brightness=%.0f vs page=%.0f (drop=%.0f), frac@edge=%.2f",
        side.capitalize(), gutter_x, gutter_width,
        100.0 * gutter_width / w, gutter_brightness, page_median,
        brightness_drop, float(frac_smooth[gutter_inner]),
    )
    return gutter_x
def _detect_top_bottom_edges(binary: np.ndarray, w: int, h: int) -> Tuple[int, int]:
    """Locate the top and bottom content edges of a binarised page.

    Both edges come from the horizontal (row-density) projection of
    ``binary``: the top edge is searched from the start, the bottom edge
    from the end.

    Args:
        binary: Binarised page image.
        w: Image width (not used by this function).
        h: Image height; passed on as the projection dimension.

    Returns:
        ``(top, bottom)`` y positions.
    """
    return (
        _detect_edge_projection(binary, axis=1, from_start=True, dim=h),
        _detect_edge_projection(binary, axis=1, from_start=False, dim=h),
    )
def _filter_narrow_runs(mask: np.ndarray, min_run: int) -> np.ndarray:
    """Clear every run of consecutive truthy entries shorter than ``min_run``.

    ``mask`` is treated as a 1-D sequence.  With ``min_run <= 1`` nothing can
    be too short, so ``mask`` itself is returned (not a copy).  Otherwise a
    copy is returned and the input is left unmodified.
    """
    if min_run <= 1:
        return mask

    filtered = mask.copy()

    # Pad with False on both sides so every run has a rising and a falling
    # edge; edges then alternate start, end, start, end, ...
    padded = np.concatenate(([False], filtered.astype(bool), [False]))
    edges = np.flatnonzero(padded[1:] != padded[:-1])

    for run_start, run_end in zip(edges[::2], edges[1::2]):
        if run_end - run_start < min_run:
            filtered[run_start:run_end] = False

    return filtered
async def create_page_sub_sessions_full(
    parent_session_id: str,
    parent_cached: dict,
    full_img_bgr: np.ndarray,
    page_splits: List[Dict[str, Any]],
    start_step: int = 2,
) -> List[Dict[str, Any]]:
    """Create one sub-session per page split, storing the uncropped region.

    If the parent already has sub-sessions they are returned as
    ``{"id", "name", "page_index"}`` dicts and nothing is created.

    Otherwise, for every entry of ``page_splits`` (keys ``page_index``,
    ``x``, ``y``, ``width``, ``height``) the matching region of
    ``full_img_bgr`` is copied, PNG-encoded, stored as a new session whose
    ``current_step`` is ``start_step``, and placed in the shared cache so
    later steps can use the BGR image directly.  No cropping happens here.
    When PNG encoding fails, empty bytes are stored as the original image.

    Returns one summary dict per created sub-session (id, name, page_index,
    source_rect, image_size).
    """
    cache = get_cache_ref()

    # Idempotent: reuse existing sub-sessions
    already_there = await get_sub_sessions(parent_session_id)
    if already_there:
        return [
            {"id": sub["id"], "name": sub["name"], "page_index": sub.get("box_index", pos)}
            for pos, sub in enumerate(already_there)
        ]

    base_name = parent_cached.get("name", "Scan")
    base_filename = parent_cached.get("filename", "scan.png")

    # Pipeline artefacts that do not exist yet for a freshly split page.
    pending_fields = (
        "oriented_bgr",
        "cropped_bgr",
        "deskewed_bgr",
        "dewarped_bgr",
        "orientation_result",
        "crop_result",
        "deskew_result",
        "dewarp_result",
    )

    created: List[Dict[str, Any]] = []

    for page in page_splits:
        page_no = page["page_index"]
        left, top = page["x"], page["y"]
        width, height = page["width"], page["height"]

        # Raw page region; each sub-session runs its own crop step later.
        region = full_img_bgr[top:top + height, left:left + width].copy()

        encoded_ok, encoded = cv2.imencode(".png", region)
        png_bytes = encoded.tobytes() if encoded_ok else b""

        new_id = str(uuid_mod.uuid4())
        new_name = f"{base_name} — Seite {page_no + 1}"

        await create_session_db(
            session_id=new_id,
            name=new_name,
            filename=base_filename,
            original_png=png_bytes,
        )

        # start_step=2 -> ready for deskew (orientation already done on spread)
        # start_step=1 -> needs its own orientation (split from original image)
        await update_session_db(new_id, current_step=start_step)

        # Cache the BGR so the pipeline can start immediately
        cache_entry: Dict[str, Any] = {
            "id": new_id,
            "filename": base_filename,
            "name": new_name,
            "original_bgr": region,
        }
        cache_entry.update(dict.fromkeys(pending_fields))
        cache_entry["ground_truth"] = {}
        cache_entry["current_step"] = start_step
        cache[new_id] = cache_entry

        region_h, region_w = region.shape[:2]
        created.append({
            "id": new_id,
            "name": new_name,
            "page_index": page_no,
            "source_rect": page,
            "image_size": {"width": region_w, "height": region_h},
        })

        logger.info(
            "Page sub-session %s (full pipeline): page %d, region x=%d w=%d -> %dx%d",
            new_id, page_no + 1, left, width, region_w, region_h,
        )

    return created
+ +Split into sub-modules: + ocr_pipeline_llm_review — LLM review + apply corrections + ocr_pipeline_reconstruction — reconstruction save, Fabric JSON, merged entries, PDF/DOCX + ocr_pipeline_validation — image detection, generation, validation, handwriting removal + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +from fastapi import APIRouter + +from .llm_review import router as _llm_review_router +from .reconstruction import router as _reconstruction_router +from .validation import router as _validation_router + +# Composite router — drop-in replacement for the old monolithic router. +# ocr_pipeline_api.py imports ``from ocr_pipeline_postprocess import router``. +router = APIRouter() +router.include_router(_llm_review_router) +router.include_router(_reconstruction_router) +router.include_router(_validation_router) diff --git a/klausur-service/backend/ocr/pipeline/reconstruction.py b/klausur-service/backend/ocr/pipeline/reconstruction.py new file mode 100644 index 0000000..0cff5a8 --- /dev/null +++ b/klausur-service/backend/ocr/pipeline/reconstruction.py @@ -0,0 +1,362 @@ +""" +OCR Pipeline Reconstruction — save edits, Fabric JSON export, merged entries, PDF/DOCX export. + +Extracted from ocr_pipeline_postprocess.py. + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
def _apply_cell_updates(cells, updates: Dict[str, str]) -> int:
    """Write edited texts into matching cells in place; return how many changed."""
    changed = 0
    for cell in cells:
        cell_id = cell.get("cell_id")
        if cell_id in updates:
            cell["text"] = updates[cell_id]
            cell["status"] = "edited"
            changed += 1
    return changed


def _sync_vocab_entries(entries, main_updates: Dict[str, str]) -> None:
    """Copy edited cell texts into the english/german/example fields of vocab entries."""
    for entry in entries:
        row_idx = entry.get("row_index", -1)
        for col_idx, field_name in enumerate(("english", "german", "example")):
            # Both the zero-padded and the unpadded id form are looked up.
            # Compare against None rather than truthiness so an edit that
            # clears a cell (text == "") is still propagated.
            new_text = main_updates.get(f"R{row_idx:02d}_C{col_idx}")
            if new_text is None:
                new_text = main_updates.get(f"R{row_idx}_C{col_idx}")
            if new_text is not None:
                entry[field_name] = new_text


@router.post("/sessions/{session_id}/reconstruction")
async def save_reconstruction(session_id: str, request: Request):
    """Save edited cell texts from the reconstruction step.

    The JSON body is expected to look like
    ``{"cells": [{"cell_id": ..., "text": ...}, ...]}``.  Cell ids of the
    form ``box{N}_{id}`` are routed to the sub-session whose ``box_index``
    is ``N``; all other ids update the main session's ``word_result``.
    Vocab entries of the main session are kept in sync with the edited
    cells, and the session advances to step 10.

    Raises:
        HTTPException 404: the session does not exist.
        HTTPException 400: the session has no word result, or the request
            body / a cell update is malformed.

    Returns:
        Dict with the session id and update counts (``updated`` only when
        the request contained no cell updates).
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    body = await request.json()
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    cell_updates = body.get("cells", [])

    if not cell_updates:
        await update_session_db(session_id, current_step=10)
        return {"session_id": session_id, "updated": 0}

    # Build update map: cell_id -> new text (reject malformed items with 400
    # instead of letting a KeyError surface as a server error).
    update_map: Dict[str, str] = {}
    for update in cell_updates:
        if (
            not isinstance(update, dict)
            or not isinstance(update.get("cell_id"), str)
            or "text" not in update
        ):
            raise HTTPException(
                status_code=400,
                detail="Each cell update requires a string 'cell_id' and a 'text'",
            )
        update_map[update["cell_id"]] = update["text"]

    # Separate sub-session updates (cell_ids prefixed with "box{N}_")
    sub_updates: Dict[int, Dict[str, str]] = {}  # box_index -> {original_cell_id: text}
    main_updates: Dict[str, str] = {}
    for cell_id, text in update_map.items():
        m = re.match(r'^box(\d+)_(.+)$', cell_id)
        if m:
            sub_updates.setdefault(int(m.group(1)), {})[m.group(2)] = text
        else:
            main_updates[cell_id] = text

    # Update main session cells
    cells = word_result.get("cells", [])
    updated_count = _apply_cell_updates(cells, main_updates)
    word_result["cells"] = cells

    # Also update vocab_entries if present
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    if entries:
        _sync_vocab_entries(entries, main_updates)
        word_result["vocab_entries"] = entries
        if "entries" in word_result:
            word_result["entries"] = entries

    await update_session_db(session_id, word_result=word_result, current_step=10)

    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    # Route sub-session updates
    sub_updated = 0
    if sub_updates:
        subs = await get_sub_sessions(session_id)
        sub_by_index = {s.get("box_index"): s["id"] for s in subs}
        for bi, updates in sub_updates.items():
            sub_id = sub_by_index.get(bi)
            if not sub_id:
                continue
            sub_session = await get_session_db(sub_id)
            if not sub_session:
                continue
            sub_word = sub_session.get("word_result")
            if not sub_word:
                continue
            sub_cells = sub_word.get("cells", [])
            sub_updated += _apply_cell_updates(sub_cells, updates)
            sub_word["cells"] = sub_cells
            await update_session_db(sub_id, word_result=sub_word)
            if sub_id in _cache:
                _cache[sub_id]["word_result"] = sub_word

    total_updated = updated_count + sub_updated
    logger.info(
        "Reconstruction saved for session %s: %d main + %d sub-session cells updated",
        session_id, updated_count, sub_updated,
    )

    return {
        "session_id": session_id,
        "updated": total_updated,
        "main_updated": updated_count,
        "sub_updated": sub_updated,
    }
@router.get("/sessions/{session_id}/reconstruction/fabric-json")
async def get_fabric_json(session_id: str):
    """Build the Fabric.js JSON for a session's cell grid.

    Starts from the main session's cells.  For every sub-session that has
    cells, each cell is copied, its id is prefixed with ``box{N}_``, a
    ``source`` of ``box_{N}`` is set, and a non-empty ``bbox_px`` is shifted
    by the x/y of the N-th box zone from ``column_result`` (no shift when
    that zone does not exist).  The combined list is converted with
    ``cells_to_fabric_json`` using the stored image size (default 800x600).

    Raises HTTPException 404 for an unknown session and 400 when the
    session has no word result.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    all_cells = list(word_result.get("cells", []))
    canvas_w = word_result.get("image_width", 800)
    canvas_h = word_result.get("image_height", 600)

    # Merge sub-session cells at box positions
    sub_sessions = await get_sub_sessions(session_id)
    if sub_sessions:
        zones = (session.get("column_result") or {}).get("zones") or []
        box_zones = [
            zone for zone in zones
            if zone.get("zone_type") == "box" and zone.get("box")
        ]

        for sub in sub_sessions:
            sub_session = await get_session_db(sub["id"])
            if not sub_session:
                continue
            sub_word = sub_session.get("word_result")
            if not sub_word or not sub_word.get("cells"):
                continue

            box_idx = sub.get("box_index", 0)
            offset_y, offset_x = 0, 0
            if box_idx < len(box_zones):
                zone_box = box_zones[box_idx]["box"]
                offset_y, offset_x = zone_box["y"], zone_box["x"]

            for cell in sub_word["cells"]:
                merged = dict(cell)
                merged["cell_id"] = f"box{box_idx}_{merged.get('cell_id', '')}"
                merged["source"] = f"box_{box_idx}"
                local_bbox = merged.get("bbox_px", {})
                if local_bbox:
                    # Translate from sub-session coordinates to page coordinates.
                    merged["bbox_px"] = {
                        **local_bbox,
                        "x": local_bbox.get("x", 0) + offset_x,
                        "y": local_bbox.get("y", 0) + offset_y,
                    }
                all_cells.append(merged)

    from services.layout_reconstruction_service import cells_to_fabric_json
    return cells_to_fabric_json(all_cells, canvas_w, canvas_h)
@router.get("/sessions/{session_id}/vocab-entries/merged")
async def get_merged_vocab_entries(session_id: str):
    """Return the vocab entries of a session together with those of its sub-sessions.

    Main-session entries get ``source="main"`` unless they already carry a
    source.  Sub-session entries are copied and tagged with
    ``source="box_{N}"`` and ``source_y`` (the y of the N-th box zone, or 0
    when that zone does not exist).  The merged list is sorted by
    ``row_index * 100`` for main entries and ``source_y * 100 + row_index``
    for box entries.

    Raises HTTPException 404 for an unknown session.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    main_word = session.get("word_result") or {}
    merged = list(main_word.get("vocab_entries") or main_word.get("entries") or [])

    for item in merged:
        item.setdefault("source", "main")

    sub_sessions = await get_sub_sessions(session_id)
    if sub_sessions:
        zones = (session.get("column_result") or {}).get("zones") or []
        box_zones = [
            zone for zone in zones
            if zone.get("zone_type") == "box" and zone.get("box")
        ]

        for sub in sub_sessions:
            sub_session = await get_session_db(sub["id"])
            if not sub_session:
                continue
            sub_word = sub_session.get("word_result") or {}
            sub_entries = sub_word.get("vocab_entries") or sub_word.get("entries") or []

            box_idx = sub.get("box_index", 0)
            zone_y = box_zones[box_idx]["box"]["y"] if box_idx < len(box_zones) else 0

            merged.extend(
                {**item, "source": f"box_{box_idx}", "source_y": zone_y}
                for item in sub_entries
            )

    def _position(item):
        row = item.get("row_index", 0)
        if item.get("source", "main") == "main":
            return row * 100
        return item.get("source_y", 0) * 100 + row

    merged.sort(key=_position)

    return {
        "session_id": session_id,
        "entries": merged,
        "total": len(merged),
        "sources": list({item.get("source", "main") for item in merged}),
    }
@router.get("/sessions/{session_id}/reconstruction/export/docx")
async def export_reconstruction_docx(session_id: str):
    """Export the reconstructed cell grid as a DOCX table.

    The table has one header row (labels from ``columns_used``, or
    ``Col {i}`` placeholders) followed by ``grid_shape.rows`` data rows.
    Cell texts are looked up by id ``R{row:02d}_C{col}``; missing cells are
    rendered empty.

    Raises:
        HTTPException 404: the session does not exist.
        HTTPException 400: the session has no word result.
        HTTPException 501: python-docx is not installed.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    # Only the optional dependency is guarded, so unrelated ImportErrors
    # raised later are not misreported as "not installed".
    try:
        from docx import Document
    except ImportError:
        raise HTTPException(status_code=501, detail="python-docx not installed")
    import io as _io

    cells = word_result.get("cells", [])
    columns_used = word_result.get("columns_used", [])
    grid_shape = word_result.get("grid_shape", {})
    n_rows = grid_shape.get("rows", 0)
    n_cols = grid_shape.get("cols", 0)

    header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
    if not header:
        header = [f"Col {i}" for i in range(n_cols)]

    # Index cells once instead of scanning the list for every grid position.
    # setdefault keeps the first cell per id, as the former next(...) scan did.
    cell_by_id = {}
    for c in cells:
        cell_by_id.setdefault(c.get("cell_id"), c)

    doc = Document()
    doc.add_heading(f'Rekonstruktion -- Session {session_id[:8]}', level=1)

    # The table must be wide enough for the header as well as the grid;
    # otherwise writing a header label past the last column raises IndexError.
    table = doc.add_table(rows=1 + n_rows, cols=max(n_cols, len(header), 1))
    table.style = 'Table Grid'

    header_cells = table.rows[0].cells
    for ci, label in enumerate(header):
        header_cells[ci].text = label

    for r in range(n_rows):
        row_cells = table.rows[r + 1].cells
        for ci in range(n_cols):
            cell = cell_by_id.get(f"R{r:02d}_C{ci}")
            row_cells[ci].text = cell.get("text", "") if cell else ""

    buf = _io.BytesIO()
    doc.save(buf)
    buf.seek(0)

    return StreamingResponse(
        buf,
        media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.docx"'},
    )
+ +All implementation split into: + ocr_pipeline_regression_helpers — DB persistence, snapshot, comparison + ocr_pipeline_regression_endpoints — FastAPI routes + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +# Helpers (used by grid_editor_api_grid.py) +from .regression_helpers import ( # noqa: F401 + _init_regression_table, + _persist_regression_run, + _extract_cells_for_comparison, + _build_reference_snapshot, + compare_grids, +) + +# Endpoints (router used by ocr_pipeline_api.py) +from .regression_endpoints import router # noqa: F401 diff --git a/klausur-service/backend/ocr/pipeline/regression_endpoints.py b/klausur-service/backend/ocr/pipeline/regression_endpoints.py new file mode 100644 index 0000000..d2d2c34 --- /dev/null +++ b/klausur-service/backend/ocr/pipeline/regression_endpoints.py @@ -0,0 +1,421 @@ +""" +OCR Pipeline Regression Endpoints — FastAPI routes for ground truth and regression. + +Extracted from ocr_pipeline_regression.py for modularity. + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
+""" + +import json +import logging +import time +from typing import Any, Dict, Optional + +from fastapi import APIRouter, HTTPException, Query + +from grid_editor_api import _build_grid_core +from .session_store import ( + get_session_db, + list_ground_truth_sessions_db, + update_session_db, +) +from .regression_helpers import ( + _build_reference_snapshot, + _init_regression_table, + _persist_regression_run, + compare_grids, + get_pool, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["regression"]) + + +# --------------------------------------------------------------------------- +# Endpoints +# --------------------------------------------------------------------------- + +@router.post("/sessions/{session_id}/mark-ground-truth") +async def mark_ground_truth( + session_id: str, + pipeline: Optional[str] = Query(None, description="Pipeline used: kombi, pipeline, paddle-direct"), +): + """Save the current build-grid result as ground-truth reference.""" + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + grid_result = session.get("grid_editor_result") + if not grid_result or not grid_result.get("zones"): + raise HTTPException( + status_code=400, + detail="No grid_editor_result found. 
@router.get("/sessions/{session_id}/correction-diff")
async def get_correction_diff(session_id: str):
    """Compare the automatic OCR grid with the manually corrected ground truth.

    Runs ``compare_grids`` on the session's ``auto_grid_snapshot`` and
    ``build_grid_reference`` and adds ``col_type_breakdown`` to the result:
    per column type the number of reference cells (``total``), the number of
    ``text_change`` diffs (``corrected``) and the resulting ``accuracy_pct``.

    Raises:
        HTTPException 404: the session, the auto snapshot, or the ground
            truth reference is missing.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    gt = session.get("ground_truth") or {}
    auto_snapshot = gt.get("auto_grid_snapshot")
    reference = gt.get("build_grid_reference")

    if not auto_snapshot:
        raise HTTPException(
            status_code=404,
            detail="No auto_grid_snapshot found. Re-run build-grid to create one.",
        )
    if not reference:
        raise HTTPException(
            status_code=404,
            detail="No ground truth reference found. Mark as ground truth first.",
        )

    diff = compare_grids(auto_snapshot, reference)

    reference_cells = reference.get("cells", [])

    # Index col_type by cell id once instead of scanning all reference cells
    # for every diff.  setdefault keeps the first cell per id, matching the
    # former linear search; cells without an id are skipped rather than
    # raising KeyError.
    col_type_by_id: Dict[Any, str] = {}
    for cell in reference_cells:
        cell_id = cell.get("cell_id")
        if cell_id is not None:
            col_type_by_id.setdefault(cell_id, cell.get("col_type", "unknown"))

    # Enrich with per-col_type breakdown
    col_type_stats: Dict[str, Dict[str, int]] = {}
    for cell_diff in diff.get("cell_diffs", []):
        if cell_diff.get("type") != "text_change":
            continue
        ct = col_type_by_id.get(cell_diff.get("cell_id"), "unknown")
        col_type_stats.setdefault(ct, {"total": 0, "corrected": 0})["corrected"] += 1

    # Count total cells per col_type from reference
    for cell in reference_cells:
        ct = cell.get("col_type", "unknown")
        col_type_stats.setdefault(ct, {"total": 0, "corrected": 0})["total"] += 1

    # Calculate accuracy per col_type
    for stats in col_type_stats.values():
        total = stats["total"]
        corrected = stats["corrected"]
        stats["accuracy_pct"] = round((total - corrected) / total * 100, 1) if total > 0 else 100.0

    diff["col_type_breakdown"] = col_type_stats

    return diff
@router.post("/sessions/{session_id}/regression/run")
async def run_single_regression(session_id: str):
    """Rebuild the grid for one session and diff it against its ground truth.

    The freshly computed grid is turned into a snapshot and compared with
    the stored ``build_grid_reference``; this function itself does not
    write anything back to the session.  If the grid build raises, a result
    with ``status="error"`` and the exception text is returned instead of
    propagating the exception.

    Raises HTTPException 404 for an unknown session and 400 when the
    session has no ground truth reference.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    reference = (session.get("ground_truth") or {}).get("build_grid_reference")
    if not reference:
        raise HTTPException(
            status_code=400,
            detail="No ground truth reference found for this session",
        )

    # Re-compute grid without persisting
    try:
        rebuilt = await _build_grid_core(session_id, session)
    except Exception as exc:
        return {
            "session_id": session_id,
            "name": session.get("name", ""),
            "status": "error",
            "error": str(exc),
        }

    current = _build_reference_snapshot(rebuilt)
    diff = compare_grids(reference, current)

    outcome = diff["status"]
    structural = diff["summary"]["structural_changes"]
    cell_level = sum(
        count for key, count in diff["summary"].items() if key != "structural_changes"
    )
    logger.info(
        "Regression test session %s: %s (%d structural, %d cell diffs)",
        session_id, outcome, structural, cell_level,
    )

    return {
        "session_id": session_id,
        "name": session.get("name", ""),
        "status": diff["status"],
        "diff": diff,
        "reference_summary": reference.get("summary", {}),
        "current_summary": current.get("summary", {}),
    }
@router.post("/regression/run")
async def run_all_regressions(
    triggered_by: str = Query("manual", description="Who triggered: manual, script, ci"),
):
    """Rebuild the grid for every ground-truth session and compare to its reference.

    Sessions without a ``build_grid_reference`` are skipped.  A session that
    cannot be re-loaded or whose grid build raises is recorded with
    ``status="error"``.  Full structural and cell diffs are attached only to
    failing sessions.  The overall status is ``"pass"`` only when nothing
    failed or errored.  The run is persisted via ``_persist_regression_run``
    and its id is included in the response.
    """
    started = time.monotonic()
    gt_sessions = await list_ground_truth_sessions_db()

    if not gt_sessions:
        return {
            "status": "pass",
            "message": "No ground truth sessions found",
            "results": [],
            "summary": {"total": 0, "passed": 0, "failed": 0, "errors": 0},
        }

    results = []
    tally = {"passed": 0, "failed": 0, "errors": 0}

    def _record_error(sid, name, message):
        results.append({
            "session_id": sid,
            "name": name,
            "status": "error",
            "error": message,
        })
        tally["errors"] += 1

    for row in gt_sessions:
        sid = row["id"]
        reference = (row.get("ground_truth") or {}).get("build_grid_reference")
        if not reference:
            continue
        name = row.get("name", "")

        # Re-load full session (list query may not include all JSONB fields)
        full_session = await get_session_db(sid)
        if not full_session:
            _record_error(sid, name, "Session not found during re-load")
            continue

        try:
            rebuilt = await _build_grid_core(sid, full_session)
        except Exception as exc:
            _record_error(sid, name, str(exc))
            continue

        current = _build_reference_snapshot(rebuilt)
        diff = compare_grids(reference, current)
        is_failure = diff["status"] == "fail"

        entry = {
            "session_id": sid,
            "name": name,
            "status": diff["status"],
            "diff_summary": diff["summary"],
            "reference_summary": reference.get("summary", {}),
            "current_summary": current.get("summary", {}),
        }

        # Include full diffs only for failures (keep response compact)
        if is_failure:
            entry["structural_diffs"] = diff["structural_diffs"]
            entry["cell_diffs"] = diff["cell_diffs"]
        tally["failed" if is_failure else "passed"] += 1

        results.append(entry)

    overall = "fail" if (tally["failed"] or tally["errors"]) else "pass"
    duration_ms = int((time.monotonic() - started) * 1000)

    summary = {"total": len(results), **tally}

    logger.info(
        "Regression suite: %s — %d passed, %d failed, %d errors (of %d) in %dms",
        overall, tally["passed"], tally["failed"], tally["errors"], len(results), duration_ms,
    )

    # Persist to DB
    run_id = await _persist_regression_run(
        status=overall,
        summary=summary,
        results=results,
        duration_ms=duration_ms,
        triggered_by=triggered_by,
    )

    return {
        "status": overall,
        "run_id": run_id,
        "duration_ms": duration_ms,
        "results": results,
        "summary": summary,
    }
detail="Run not found") + return { + "id": str(row["id"]), + "run_at": row["run_at"].isoformat() if row["run_at"] else None, + "status": row["status"], + "total": row["total"], + "passed": row["passed"], + "failed": row["failed"], + "errors": row["errors"], + "duration_ms": row["duration_ms"], + "triggered_by": row["triggered_by"], + "results": json.loads(row["results"]) if row["results"] else [], + } + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/klausur-service/backend/ocr/pipeline/regression_helpers.py b/klausur-service/backend/ocr/pipeline/regression_helpers.py new file mode 100644 index 0000000..e1a2d67 --- /dev/null +++ b/klausur-service/backend/ocr/pipeline/regression_helpers.py @@ -0,0 +1,207 @@ +""" +OCR Pipeline Regression Helpers — DB persistence, snapshot building, comparison. + +Extracted from ocr_pipeline_regression.py for modularity. + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
async def _init_regression_table():
    """Ensure the regression_runs table exists by running its migration.

    Executes ``migrations/008_regression_runs.sql`` (resolved relative to this
    module's directory). When the file is missing nothing is executed; a
    warning is logged once per process so a broken path does not go
    unnoticed.
    """
    migration_path = os.path.join(
        os.path.dirname(__file__),
        "migrations/008_regression_runs.sql",
    )

    # Read the file before borrowing a pooled connection so no connection is
    # held during file I/O.
    try:
        with open(migration_path, "r", encoding="utf-8") as f:
            sql = f.read()
    except FileNotFoundError:
        # This helper runs on every regression request, so warn only once.
        if not getattr(_init_regression_table, "_missing_logged", False):
            logger.warning(
                "Regression migration file not found, table init skipped: %s",
                migration_path,
            )
            _init_regression_table._missing_logged = True
        return

    pool = await get_pool()
    async with pool.acquire() as conn:
        await conn.execute(sql)
def _extract_cells_for_comparison(grid_result: dict) -> List[Dict[str, Any]]:
    """Flatten the cells of a grid_editor_result into a comparable list.

    Only the fields relevant for comparison are kept: cell_id, row_index,
    col_index, col_type and text. Everything else on a cell is dropped.

    Args:
        grid_result: Grid result dict with an optional ``zones`` list, each
            zone carrying an optional ``cells`` list.

    Returns:
        One dict per cell, in zone order then cell order. A missing or
        ``None`` ``zones``/``cells`` entry contributes no cells.
    """
    return [
        {
            "cell_id": cell.get("cell_id", ""),
            "row_index": cell.get("row_index"),
            "col_index": cell.get("col_index"),
            "col_type": cell.get("col_type", ""),
            "text": cell.get("text", ""),
        }
        # "or []" (rather than a .get default) also covers keys that are
        # present but explicitly None.
        for zone in grid_result.get("zones") or []
        for cell in zone.get("cells") or []
    ]


def _build_reference_snapshot(
    grid_result: dict,
    pipeline: Optional[str] = None,
) -> dict:
    """Build a ground-truth reference snapshot from a grid_editor_result.

    Args:
        grid_result: Grid result dict (see ``_extract_cells_for_comparison``).
        pipeline: Optional pipeline label stored verbatim in the snapshot.

    Returns:
        A dict with ``saved_at`` (UTC ISO timestamp), ``version`` (1),
        ``pipeline``, a ``summary`` of zone/column/row/cell counts and the
        flattened ``cells``.
    """
    zones = grid_result.get("zones") or []
    cells = _extract_cells_for_comparison(grid_result)

    return {
        "saved_at": datetime.now(timezone.utc).isoformat(),
        "version": 1,
        "pipeline": pipeline,
        "summary": {
            "total_zones": len(zones),
            "total_columns": sum(len(z.get("columns") or []) for z in zones),
            "total_rows": sum(len(z.get("rows") or []) for z in zones),
            "total_cells": len(cells),
        },
        "cells": cells,
    }
+ + Returns a diff report with: + - status: "pass" or "fail" + - structural_diffs: changes in zone/row/column counts + - cell_diffs: list of individual cell changes + """ + ref_summary = reference.get("summary", {}) + cur_summary = current.get("summary", {}) + + structural_diffs = [] + for key in ("total_zones", "total_columns", "total_rows", "total_cells"): + ref_val = ref_summary.get(key, 0) + cur_val = cur_summary.get(key, 0) + if ref_val != cur_val: + structural_diffs.append({ + "field": key, + "reference": ref_val, + "current": cur_val, + }) + + # Build cell lookup by cell_id + ref_cells = {c["cell_id"]: c for c in reference.get("cells", [])} + cur_cells = {c["cell_id"]: c for c in current.get("cells", [])} + + cell_diffs: List[Dict[str, Any]] = [] + + # Check for missing cells (in reference but not in current) + for cell_id in ref_cells: + if cell_id not in cur_cells: + cell_diffs.append({ + "type": "cell_missing", + "cell_id": cell_id, + "reference_text": ref_cells[cell_id].get("text", ""), + }) + + # Check for added cells (in current but not in reference) + for cell_id in cur_cells: + if cell_id not in ref_cells: + cell_diffs.append({ + "type": "cell_added", + "cell_id": cell_id, + "current_text": cur_cells[cell_id].get("text", ""), + }) + + # Check for changes in shared cells + for cell_id in ref_cells: + if cell_id not in cur_cells: + continue + ref_cell = ref_cells[cell_id] + cur_cell = cur_cells[cell_id] + + if ref_cell.get("text", "") != cur_cell.get("text", ""): + cell_diffs.append({ + "type": "text_change", + "cell_id": cell_id, + "reference": ref_cell.get("text", ""), + "current": cur_cell.get("text", ""), + }) + + if ref_cell.get("col_type", "") != cur_cell.get("col_type", ""): + cell_diffs.append({ + "type": "col_type_change", + "cell_id": cell_id, + "reference": ref_cell.get("col_type", ""), + "current": cur_cell.get("col_type", ""), + }) + + status = "pass" if not structural_diffs and not cell_diffs else "fail" + + return { + "status": status, + 
@router.post("/sessions/{session_id}/reprocess")
async def reprocess_session(session_id: str, request: Request):
    """Re-run the pipeline from a specific step, clearing downstream data.

    Body: {"from_step": 5}  (1-indexed step number, defaults to 1)

    Pipeline order: Orient(2) -> Deskew(3) -> Dewarp(4) -> Crop(5) ->
    Columns(6) -> Rows(7) -> Words(8) -> LLM(9) -> Recon(10) -> GT(11)

    Results cleared (each step's own result and everything downstream):
    - from_step <= 2: orientation_result
    - from_step <= 3: deskew_result
    - from_step <= 4: dewarp_result
    - from_step <= 5: crop_result
    - from_step <= 6: column_result
    - from_step <= 7: row_result
    - from_step <= 8: word_result
    - from_step == 9: only llm_review / llm_corrections inside word_result
    - from_step == 10: nothing

    NOTE(review): the accepted range is still 1..10 although the order above
    lists GT as step 11 — confirm whether 11 should be accepted.

    Returns:
        ``{"session_id", "from_step", "cleared"}`` where ``cleared`` lists the
        session fields that were written.

    Raises:
        HTTPException: 404 if the session does not exist; 400 if the body is
            not a JSON object or ``from_step`` is not an integer in 1..10.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # A malformed or non-object body is a client error, not a server fault.
    try:
        body = await request.json()
    except ValueError:
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from None
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    from_step = body.get("from_step", 1)
    # bool is a subclass of int; reject it so True is not treated as step 1.
    if (
        isinstance(from_step, bool)
        or not isinstance(from_step, int)
        or from_step < 1
        or from_step > 10
    ):
        raise HTTPException(status_code=400, detail="from_step must be between 1 and 10")

    update_kwargs: Dict[str, Any] = {"current_step": from_step}

    if from_step == 9:
        # Reprocessing the LLM step keeps the OCR words and only drops the
        # review data. Build a new dict rather than mutating the session's.
        word_result = session.get("word_result")
        if word_result:
            update_kwargs["word_result"] = {
                key: value
                for key, value in word_result.items()
                if key not in ("llm_review", "llm_corrections")
            }

    # (threshold, field): the field is cleared when from_step <= threshold,
    # i.e. when restarting at or before the step that produces it.
    clear_at_or_before = (
        (8, "word_result"),
        (7, "row_result"),
        (6, "column_result"),
        (5, "crop_result"),
        (4, "dewarp_result"),
        (3, "deskew_result"),
        (2, "orientation_result"),
    )
    for threshold, field in clear_at_or_before:
        if from_step <= threshold:
            update_kwargs[field] = None

    await update_session_db(session_id, **update_kwargs)

    # Mirror the same changes into the in-memory cache entry, if present.
    if session_id in _cache:
        cached = _cache[session_id]
        for key, value in update_kwargs.items():
            cached[key] = value

    logger.info("Session %s reprocessing from step %s", session_id, from_step)

    return {
        "session_id": session_id,
        "from_step": from_step,
        "cleared": [k for k in update_kwargs if k != "current_step"],
    }
+ + Reusable for columns, rows, and words overlays. + """ + for zone in zones: + if zone.get("zone_type") != "box" or not zone.get("box"): + continue + box = zone["box"] + bx, by = box["x"], box["y"] + bw, bh = box["width"], box["height"] + + # Red semi-transparent fill (~25 %) + box_overlay = img.copy() + cv2.rectangle(box_overlay, (bx, by), (bx + bw, by + bh), (0, 0, 200), -1) + cv2.addWeighted(box_overlay, 0.25, img, 0.75, 0, img) + + # Border + cv2.rectangle(img, (bx, by), (bx + bw, by + bh), (0, 0, 200), 2) + + # Label + cv2.putText(img, label, (bx + 10, by + bh - 10), + cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2) + + +# --------------------------------------------------------------------------- +# Row Detection Endpoints +# --------------------------------------------------------------------------- + +@router.post("/sessions/{session_id}/rows") +async def detect_rows(session_id: str): + """Run row detection on the cropped (or dewarped) image using horizontal gap analysis.""" + if session_id not in _cache: + await _load_session_to_cache(session_id) + cached = _get_cached(session_id) + + dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") + if dewarped_bgr is None: + raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before row detection") + + t0 = time.time() + + # Try to reuse cached word_dicts and inv from column detection + word_dicts = cached.get("_word_dicts") + inv = cached.get("_inv") + content_bounds = cached.get("_content_bounds") + + if word_dicts is None or inv is None or content_bounds is None: + # Not cached — run column geometry to get intermediates + ocr_img = create_ocr_image(dewarped_bgr) + geo_result = detect_column_geometry(ocr_img, dewarped_bgr) + if geo_result is None: + raise HTTPException(status_code=400, detail="Column geometry detection failed — cannot detect rows") + _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result + 
cached["_word_dicts"] = word_dicts + cached["_inv"] = inv + cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y) + else: + left_x, right_x, top_y, bottom_y = content_bounds + + # Read zones from column_result to exclude box regions + session = await get_session_db(session_id) + column_result = (session or {}).get("column_result") or {} + is_sub_session = bool((session or {}).get("parent_session_id")) + + # Sub-sessions (box crops): use word-grouping instead of gap-based + # row detection. Box images are small with complex internal layouts + # (headings, sub-columns) where the horizontal projection approach + # merges rows. Word-grouping directly clusters words by Y proximity, + # which is more robust for these cases. + if is_sub_session and word_dicts: + from cv_layout import _build_rows_from_word_grouping + rows = _build_rows_from_word_grouping( + word_dicts, left_x, right_x, top_y, bottom_y, + right_x - left_x, bottom_y - top_y, + ) + logger.info(f"OCR Pipeline: sub-session {session_id}: word-grouping found {len(rows)} rows") + else: + zones = column_result.get("zones") or [] # zones can be None for sub-sessions + + # Collect box y-ranges for filtering. + # Use border_thickness to shrink the exclusion zone: the border pixels + # belong visually to the box frame, but text rows above/below the box + # may overlap with the border area and must not be clipped. 
+ box_ranges = [] # [(y_start, y_end)] + box_ranges_inner = [] # [(y_start + border, y_end - border)] for row filtering + for zone in zones: + if zone.get("zone_type") == "box" and zone.get("box"): + box = zone["box"] + bt = max(box.get("border_thickness", 0), 5) # minimum 5px margin + box_ranges.append((box["y"], box["y"] + box["height"])) + # Inner range: shrink by border thickness so boundary rows aren't excluded + box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt)) + + if box_ranges and inv is not None: + # Combined-image approach: strip box regions from inv image, + # run row detection on the combined image, then remap y-coords back. + content_strips = [] # [(y_start, y_end)] in absolute coords + # Build content strips by subtracting box inner ranges from [top_y, bottom_y]. + # Using inner ranges means the border area is included in the content + # strips, so the last row above a box isn't clipped by the border. + sorted_boxes = sorted(box_ranges_inner, key=lambda r: r[0]) + strip_start = top_y + for by_start, by_end in sorted_boxes: + if by_start > strip_start: + content_strips.append((strip_start, by_start)) + strip_start = max(strip_start, by_end) + if strip_start < bottom_y: + content_strips.append((strip_start, bottom_y)) + + # Filter to strips with meaningful height + content_strips = [(ys, ye) for ys, ye in content_strips if ye - ys >= 20] + + if content_strips: + # Stack content strips vertically + inv_strips = [inv[ys:ye, :] for ys, ye in content_strips] + combined_inv = np.vstack(inv_strips) + + # Filter word_dicts to only include words from content strips + combined_words = [] + cum_y = 0 + strip_offsets = [] # (combined_y_start, strip_height, abs_y_start) + for ys, ye in content_strips: + h = ye - ys + strip_offsets.append((cum_y, h, ys)) + for w in word_dicts: + w_abs_y = w['top'] + top_y # word y is relative to content top + w_center = w_abs_y + w['height'] / 2 + if ys <= w_center < ye: + # Remap to combined coordinates + 
w_copy = dict(w) + w_copy['top'] = cum_y + (w_abs_y - ys) + combined_words.append(w_copy) + cum_y += h + + # Run row detection on combined image + combined_h = combined_inv.shape[0] + rows = detect_row_geometry( + combined_inv, combined_words, left_x, right_x, 0, combined_h, + ) + + # Remap y-coordinates back to absolute page coords + def _combined_y_to_abs(cy: int) -> int: + for c_start, s_h, abs_start in strip_offsets: + if cy < c_start + s_h: + return abs_start + (cy - c_start) + last_c, last_h, last_abs = strip_offsets[-1] + return last_abs + last_h + + for r in rows: + abs_y = _combined_y_to_abs(r.y) + abs_y_end = _combined_y_to_abs(r.y + r.height) + r.y = abs_y + r.height = abs_y_end - abs_y + else: + rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y) + else: + # No boxes — standard row detection + rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y) + + duration = time.time() - t0 + + # Assign zone_index based on which content zone each row falls in + # Build content zone list with indices + zones = column_result.get("zones") or [] + content_zones = [(i, z) for i, z in enumerate(zones) if z.get("zone_type") == "content"] if zones else [] + + # Build serializable result (exclude words to keep payload small) + rows_data = [] + for r in rows: + # Determine zone_index + zone_idx = 0 + row_center_y = r.y + r.height / 2 + for zi, zone in content_zones: + zy = zone["y"] + zh = zone["height"] + if zy <= row_center_y < zy + zh: + zone_idx = zi + break + + rd = { + "index": r.index, + "x": r.x, + "y": r.y, + "width": r.width, + "height": r.height, + "word_count": r.word_count, + "row_type": r.row_type, + "gap_before": r.gap_before, + "zone_index": zone_idx, + } + rows_data.append(rd) + + type_counts = {} + for r in rows: + type_counts[r.row_type] = type_counts.get(r.row_type, 0) + 1 + + row_result = { + "rows": rows_data, + "summary": type_counts, + "total_rows": len(rows), + "duration_seconds": round(duration, 2), + 
@router.post("/sessions/{session_id}/rows/manual")
async def set_manual_rows(session_id: str, req: ManualRowsRequest):
    """Override detected rows with manual definitions.

    Stores the manual rows as the session's ``row_result`` and drops
    ``word_result`` in the same update.

    Args:
        session_id: Target session.
        req: Request carrying the manual ``rows`` list.

    Returns:
        ``{"session_id": ..., **row_result}`` with ``method`` set to
        ``"manual"``.

    Raises:
        HTTPException: 404 if no session with this id exists.
    """
    row_result = {
        "rows": req.rows,
        "total_rows": len(req.rows),
        "duration_seconds": 0,
        "method": "manual",
    }

    updated = await update_session_db(session_id, row_result=row_result, word_result=None)
    # update_session_db returns None when the UPDATE matched no row; report
    # that instead of claiming success for a session that does not exist.
    if updated is None:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    if session_id in _cache:
        _cache[session_id]["row_result"] = row_result
        _cache[session_id].pop("word_result", None)

    logger.info(f"OCR Pipeline: manual rows session {session_id}: "
                f"{len(req.rows)} rows set")

    return {"session_id": session_id, **row_result}
@router.get("/sessions/{session_id}/ground-truth/rows")
async def get_row_ground_truth(session_id: str):
    """Return the saved row ground truth next to the stored row result.

    Raises:
        HTTPException: 404 if the session is unknown or has no row ground
            truth saved.
    """
    record = await get_session_db(session_id)
    if not record:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # ground_truth may be absent or None on sessions without any feedback.
    saved = (record.get("ground_truth") or {}).get("rows")
    if not saved:
        raise HTTPException(status_code=404, detail="No row ground truth saved")

    response = {"session_id": session_id}
    response["rows_gt"] = saved
    response["rows_auto"] = record.get("row_result")
    return response
def score_scan_quality(img_bgr: np.ndarray) -> ScanQualityReport:
    """
    Rate a scanned image by sharpness, contrast and brightness.

    Sharpness is the variance of the Laplacian, contrast the standard
    deviation of the grayscale image, brightness its mean. The scan counts as
    degraded when sharpness or contrast falls below its module threshold.

    Args:
        img_bgr: BGR image (numpy array from OpenCV)

    Returns:
        ScanQualityReport with rounded scores, quality flags, a 0-100
        quality estimate and the recommended OCR confidence threshold.
    """

    def _points(score: float, threshold: float) -> float:
        # A score exactly at its threshold is worth 50 points; capped at 100.
        return min(100, score / threshold * 50)

    grayscale = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

    # Higher Laplacian variance = sharper edges.
    sharpness = float(cv2.Laplacian(grayscale, cv2.CV_64F).var())
    contrast = float(np.std(grayscale))
    mean_gray = float(np.mean(grayscale))

    blurry = sharpness < BLUR_THRESHOLD
    flat = contrast < CONTRAST_THRESHOLD
    degraded = blurry or flat

    overall = int(min(
        100,
        _points(sharpness, BLUR_THRESHOLD) + _points(contrast, CONTRAST_THRESHOLD),
    ))

    # Degraded scans get the more lenient OCR confidence cut-off.
    if degraded:
        min_conf = CONFIDENCE_DEGRADED
    else:
        min_conf = CONFIDENCE_GOOD

    report = ScanQualityReport(
        blur_score=round(sharpness, 1),
        contrast_score=round(contrast, 1),
        brightness=round(mean_gray, 1),
        is_blurry=blurry,
        is_low_contrast=flat,
        is_degraded=degraded,
        quality_pct=overall,
        recommended_min_conf=min_conf,
    )

    logger.info(
        f"Scan quality: blur={report.blur_score} "
        f"contrast={report.contrast_score} "
        f"quality={report.quality_pct}% "
        f"degraded={report.is_degraded} "
        f"min_conf={report.recommended_min_conf}"
    )

    return report
async def get_pool() -> asyncpg.Pool:
    """Get or lazily create the shared database connection pool.

    Returns:
        The module-wide pool. Concurrent first callers all receive the same
        pool instance.
    """
    global _pool
    if _pool is None:
        new_pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
        # Another coroutine may have finished creating the pool while this
        # one was awaiting. Keep the first pool and close the surplus one so
        # its connections are not leaked.
        if _pool is None:
            _pool = new_pool
        else:
            await new_pool.close()
    return _pool
async def create_session_db(
    session_id: str,
    name: str,
    filename: str,
    original_png: bytes,
    parent_session_id: Optional[str] = None,
    box_index: Optional[int] = None,
    document_group_id: Optional[str] = None,
    page_number: Optional[int] = None,
) -> Dict[str, Any]:
    """Insert a new OCR pipeline session and return its metadata.

    The session is created with status 'active' and current_step 1.

    Args:
        session_id: UUID string used as the primary key.
        name: Display name of the session.
        filename: Name of the uploaded file.
        original_png: Original image bytes.
        parent_session_id: If set, this is a sub-session for a box region.
        box_index: 0-based index of the box this sub-session represents.
        document_group_id: Groups multi-page uploads into one document.
        page_number: 1-based page index within the document group.

    Returns:
        The inserted row converted with ``_row_to_dict`` (no image columns).
    """

    def _optional_uuid(text: Optional[str]) -> Optional[uuid.UUID]:
        # Empty or missing ids are stored as NULL.
        return uuid.UUID(text) if text else None

    insert_sql = """
        INSERT INTO ocr_pipeline_sessions (
            id, name, filename, original_png, status, current_step,
            parent_session_id, box_index, document_group_id, page_number
        ) VALUES ($1, $2, $3, $4, 'active', 1, $5, $6, $7, $8)
        RETURNING id, name, filename, status, current_step,
                  orientation_result, crop_result,
                  deskew_result, dewarp_result, column_result, row_result,
                  word_result, ground_truth, auto_shear_degrees,
                  doc_type, doc_type_result,
                  document_category, pipeline_log,
                  grid_editor_result, structure_result,
                  parent_session_id, box_index,
                  document_group_id, page_number,
                  created_at, updated_at
    """

    db_pool = await get_pool()
    parent_key = _optional_uuid(parent_session_id)
    group_key = _optional_uuid(document_group_id)

    async with db_pool.acquire() as connection:
        inserted = await connection.fetchrow(
            insert_sql,
            uuid.UUID(session_id),
            name,
            filename,
            original_png,
            parent_key,
            box_index,
            group_key,
            page_number,
        )

    return _row_to_dict(inserted)
async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any]]:
    """Update the given session columns and return the refreshed metadata.

    Keyword arguments that are not known columns are ignored. Values for
    JSONB columns are serialised with ``json.dumps`` unless they are ``None``
    or already a string. ``updated_at`` is always refreshed.

    Args:
        session_id: UUID string of the session to update.
        **kwargs: Column name -> new value.

    Returns:
        The updated row as a dict, the current session (via
        ``get_session_db``) when no known column was passed, or ``None`` when
        no row matched.
    """
    db_pool = await get_pool()

    updatable = frozenset((
        'name', 'filename', 'status', 'current_step',
        'original_png', 'oriented_png', 'cropped_png',
        'deskewed_png', 'binarized_png', 'dewarped_png',
        'clean_png', 'handwriting_removal_meta',
        'orientation_result', 'crop_result',
        'deskew_result', 'dewarp_result', 'column_result', 'row_result',
        'word_result', 'ground_truth', 'auto_shear_degrees',
        'doc_type', 'doc_type_result',
        'document_category', 'pipeline_log',
        'grid_editor_result', 'structure_result',
        'parent_session_id', 'box_index',
        'document_group_id', 'page_number',
    ))
    json_columns = frozenset((
        'orientation_result', 'crop_result', 'deskew_result', 'dewarp_result',
        'column_result', 'row_result', 'word_result', 'ground_truth',
        'handwriting_removal_meta', 'doc_type_result', 'pipeline_log',
        'grid_editor_result', 'structure_result',
    ))

    assignments: List[str] = []
    params: List[Any] = []
    for column, raw in kwargs.items():
        if column not in updatable:
            continue
        needs_encoding = (
            column in json_columns
            and raw is not None
            and not isinstance(raw, str)
        )
        params.append(json.dumps(raw) if needs_encoding else raw)
        # Placeholders are 1-based, so the new parameter's index is the
        # current length of the parameter list.
        assignments.append(f"{column} = ${len(params)}")

    if not assignments:
        return await get_session_db(session_id)

    # Always update updated_at
    assignments.append("updated_at = NOW()")

    params.append(uuid.UUID(session_id))
    id_placeholder = len(params)

    async with db_pool.acquire() as connection:
        updated = await connection.fetchrow(f"""
            UPDATE ocr_pipeline_sessions
            SET {', '.join(assignments)}
            WHERE id = ${id_placeholder}
            RETURNING id, name, filename, status, current_step,
                      orientation_result, crop_result,
                      deskew_result, dewarp_result, column_result, row_result,
                      word_result, ground_truth, auto_shear_degrees,
                      doc_type, doc_type_result,
                      document_category, pipeline_log,
                      grid_editor_result, structure_result,
                      parent_session_id, box_index,
                      document_group_id, page_number,
                      created_at, updated_at
        """, *params)

    return _row_to_dict(updated) if updated else None
async def delete_session_db(session_id: str) -> bool:
    """Remove one session row; ``True`` only if exactly one row was deleted."""
    db_pool = await get_pool()
    async with db_pool.acquire() as connection:
        status_tag = await connection.execute("""
            DELETE FROM ocr_pipeline_sessions WHERE id = $1
        """, uuid.UUID(session_id))
    # The command status string reports how many rows the DELETE affected.
    return status_tag == "DELETE 1"
def _row_to_dict(row: "Optional[asyncpg.Record]") -> Dict[str, Any]:
    """Convert an asyncpg Record into a JSON-serializable dict.

    Args:
        row: A record (or any mapping ``dict()`` accepts), or ``None``.

    Returns:
        A new dict in which UUID columns are strings, timestamp columns are
        ISO-8601 strings, and JSONB columns delivered as text are parsed.
        ``None`` input yields an empty dict.
    """
    if row is None:
        return {}

    result = dict(row)

    uuid_keys = ('id', 'session_id', 'parent_session_id', 'document_group_id')
    timestamp_keys = ('created_at', 'updated_at')
    # Kept in sync with the JSONB columns that update_session_db serializes
    # (handwriting_removal_meta was previously missing here).
    jsonb_keys = (
        'orientation_result', 'crop_result', 'deskew_result',
        'dewarp_result', 'column_result', 'row_result', 'word_result',
        'ground_truth', 'handwriting_removal_meta', 'doc_type_result',
        'pipeline_log', 'grid_editor_result', 'structure_result',
    )

    # UUID -> string
    for key in uuid_keys:
        if result.get(key) is not None:
            result[key] = str(result[key])

    # datetime -> ISO string
    for key in timestamp_keys:
        if result.get(key) is not None:
            result[key] = result[key].isoformat()

    # JSONB -> parsed; values that are already decoded are left untouched
    for key in jsonb_keys:
        if isinstance(result.get(key), str):
            result[key] = json.loads(result[key])

    return result
+""" + +from fastapi import APIRouter + +from .sessions_crud import router as _crud_router # noqa: F401 +from .sessions_images import router as _images_router # noqa: F401 + +# Composite router (used by ocr_pipeline_api.py) +router = APIRouter() +router.include_router(_crud_router) +router.include_router(_images_router) diff --git a/klausur-service/backend/ocr/pipeline/sessions_crud.py b/klausur-service/backend/ocr/pipeline/sessions_crud.py new file mode 100644 index 0000000..41bd7ad --- /dev/null +++ b/klausur-service/backend/ocr/pipeline/sessions_crud.py @@ -0,0 +1,449 @@ +""" +OCR Pipeline Sessions CRUD — session create, read, update, delete, box sessions. + +Extracted from ocr_pipeline_sessions.py for modularity. + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import logging +import uuid +from typing import Any, Dict, Optional + +import cv2 +import numpy as np +from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile + +from cv_vocab_pipeline import render_image_high_res, render_pdf_high_res +from .common import ( + VALID_DOCUMENT_CATEGORIES, + UpdateSessionRequest, + _cache, +) +from .session_store import ( + create_session_db, + delete_all_sessions_db, + delete_session_db, + get_session_db, + get_session_image, + get_sub_sessions, + list_sessions_db, + update_session_db, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) + + +# --------------------------------------------------------------------------- +# Session Management Endpoints +# --------------------------------------------------------------------------- + +@router.get("/sessions") +async def list_sessions(include_sub_sessions: bool = False): + """List OCR pipeline sessions. + + By default, sub-sessions (box regions) are hidden. + Pass ?include_sub_sessions=true to show them. 
@router.post("/sessions")
async def create_session(
    file: UploadFile = File(...),
    name: Optional[str] = Form(None),
):
    """Upload a PDF or image file and create a pipeline session.

    For multi-page PDFs (> 1 page), each page becomes its own session
    grouped under a ``document_group_id``. The response includes a
    ``pages`` array with one entry per page/session.

    Raises:
        HTTPException: 400 if the PDF cannot be opened or the file cannot
            be rendered, 500 if the rendered image cannot be PNG-encoded.
    """
    file_data = await file.read()
    filename = file.filename or "upload"
    content_type = file.content_type or ""

    is_pdf = content_type == "application/pdf" or filename.lower().endswith(".pdf")
    session_name = name or filename

    # --- Multi-page PDF handling ---
    if is_pdf:
        try:
            import fitz  # PyMuPDF
            pdf_doc = fitz.open(stream=file_data, filetype="pdf")
            try:
                page_count = pdf_doc.page_count
            finally:
                # Release the document even if reading the page count fails.
                pdf_doc.close()
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Could not read PDF: {e}") from e

        if page_count > 1:
            return await _create_multi_page_sessions(
                file_data, filename, session_name, page_count,
            )

    # --- Single page (image or 1-page PDF) ---
    session_id = str(uuid.uuid4())

    try:
        if is_pdf:
            img_bgr = render_pdf_high_res(file_data, page_number=0, zoom=3.0)
        else:
            img_bgr = render_image_high_res(file_data)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Could not process file: {e}") from e

    # Encode original as PNG bytes
    success, png_buf = cv2.imencode(".png", img_bgr)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode image")

    original_png = png_buf.tobytes()

    # Persist to DB
    await create_session_db(
        session_id=session_id,
        name=session_name,
        filename=filename,
        original_png=original_png,
    )

    # Cache BGR array for immediate processing
    _cache[session_id] = {
        "id": session_id,
        "filename": filename,
        "name": session_name,
        "original_bgr": img_bgr,
        "oriented_bgr": None,
        "cropped_bgr": None,
        "deskewed_bgr": None,
        "dewarped_bgr": None,
        "orientation_result": None,
        "crop_result": None,
        "deskew_result": None,
        "dewarp_result": None,
        "ground_truth": {},
        "current_step": 1,
    }

    logger.info(f"OCR Pipeline: created session {session_id} from {filename} "
                f"({img_bgr.shape[1]}x{img_bgr.shape[0]})")

    return {
        "session_id": session_id,
        "filename": filename,
        "name": session_name,
        "image_width": img_bgr.shape[1],
        "image_height": img_bgr.shape[0],
        "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
    }
@router.get("/sessions/{session_id}")
async def get_session_info(session_id: str):
    """Get session info including deskew/dewarp/column results for step navigation.

    The image size is read by decoding the stored original PNG; it is
    reported as 0x0 when no original exists or it cannot be decoded.
    Step results are included only when they are set. The grid editor
    result is reduced to a summary to keep the response small.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # Get image dimensions from original PNG.
    # The previous one-liner
    #   img_w, img_h = img.shape[1], img.shape[0] if img is not None else (0, 0)
    # applied the conditional to the second tuple element only, so a failed
    # decode raised AttributeError on ``img.shape[1]``.
    img_w, img_h = 0, 0
    original_png = await get_session_image(session_id, "original")
    if original_png:
        arr = np.frombuffer(original_png, dtype=np.uint8)
        img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
        if img is not None:
            img_h, img_w = img.shape[:2]

    result = {
        "session_id": session["id"],
        "filename": session.get("filename", ""),
        "name": session.get("name", ""),
        "image_width": img_w,
        "image_height": img_h,
        "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
        "current_step": session.get("current_step", 1),
        "document_category": session.get("document_category"),
        "doc_type": session.get("doc_type"),
    }

    # Step results are passed through verbatim, in pipeline order, when set.
    passthrough_keys = (
        "orientation_result", "crop_result", "deskew_result",
        "dewarp_result", "column_result", "row_result", "word_result",
        "doc_type_result", "structure_result",
    )
    for key in passthrough_keys:
        if session.get(key):
            result[key] = session[key]

    if session.get("grid_editor_result"):
        # Include summary only to keep response small
        gr = session["grid_editor_result"]
        result["grid_editor_result"] = {
            "summary": gr.get("summary", {}),
            "zones_count": len(gr.get("zones", [])),
            "edited": gr.get("edited", False),
        }
    if session.get("ground_truth"):
        result["ground_truth"] = session["ground_truth"]

    # Box sub-session info (zone_type='box' from column detection — NOT page-split)
    if session.get("parent_session_id"):
        result["parent_session_id"] = session["parent_session_id"]
        result["box_index"] = session.get("box_index")
    else:
        # Check for box sub-sessions (column detection creates these)
        subs = await get_sub_sessions(session_id)
        if subs:
            result["sub_sessions"] = [
                {"id": s["id"], "name": s.get("name"), "box_index": s.get("box_index")}
                for s in subs
            ]

    return result
@router.post("/sessions/{session_id}/create-box-sessions")
async def create_box_sessions(session_id: str):
    """Create sub-sessions for each detected box region.

    Crops box regions from the cropped/dewarped image and creates
    independent sub-sessions that can be processed through the pipeline.
    If sub-sessions already exist for this session they are returned
    instead of creating new ones. Boxes whose clipped crop is empty or
    cannot be PNG-encoded are skipped with a warning.

    Raises:
        HTTPException: 404 if the session does not exist, 400 if column
            detection has not run or no base image is stored, 500 if the
            base image cannot be decoded.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    column_result = session.get("column_result")
    if not column_result:
        raise HTTPException(status_code=400, detail="Column detection must be completed first")

    zones = column_result.get("zones") or []
    box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
    if not box_zones:
        return {"session_id": session_id, "sub_sessions": [], "message": "No boxes detected"}

    # Check for existing sub-sessions
    existing = await get_sub_sessions(session_id)
    if existing:
        return {
            "session_id": session_id,
            "sub_sessions": [{"id": s["id"], "box_index": s.get("box_index")} for s in existing],
            "message": f"{len(existing)} sub-session(s) already exist",
        }

    # Load base image
    base_png = await get_session_image(session_id, "cropped")
    if not base_png:
        base_png = await get_session_image(session_id, "dewarped")
    if not base_png:
        raise HTTPException(status_code=400, detail="No base image available")

    arr = np.frombuffer(base_png, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    parent_name = session.get("name", "Session")
    created = []
    pad = 5  # small padding around each box, in pixels

    for i, zone in enumerate(box_zones):
        box = zone["box"]
        # Box values come from stored JSON; coerce so that array slicing
        # does not fail on non-integer numbers.
        bx, by = int(box["x"]), int(box["y"])
        bw, bh = int(box["width"]), int(box["height"])

        # Crop box region with small padding, clipped to the image
        y1 = max(0, by - pad)
        y2 = min(img.shape[0], by + bh + pad)
        x1 = max(0, bx - pad)
        x2 = min(img.shape[1], bx + bw + pad)
        crop = img[y1:y2, x1:x2]

        # A box outside the image clips to nothing; encoding an empty
        # array would raise instead of reporting failure.
        if crop.size == 0:
            logger.warning(f"Box {i} lies outside the image for session {session_id}")
            continue

        # Encode as PNG
        success, png_buf = cv2.imencode(".png", crop)
        if not success:
            logger.warning(f"Failed to encode box {i} crop for session {session_id}")
            continue

        sub_id = str(uuid.uuid4())
        sub_name = f"{parent_name} — Box {i + 1}"

        await create_session_db(
            session_id=sub_id,
            name=sub_name,
            filename=session.get("filename", "box-crop.png"),
            original_png=png_buf.tobytes(),
            parent_session_id=session_id,
            box_index=i,
        )

        # Cache the BGR for immediate processing
        # Promote original to cropped so column/row/word detection finds it
        box_bgr = crop.copy()
        _cache[sub_id] = {
            "id": sub_id,
            "filename": session.get("filename", "box-crop.png"),
            "name": sub_name,
            "parent_session_id": session_id,
            "original_bgr": box_bgr,
            "oriented_bgr": None,
            "cropped_bgr": box_bgr,
            "deskewed_bgr": None,
            "dewarped_bgr": None,
            "orientation_result": None,
            "crop_result": None,
            "deskew_result": None,
            "dewarp_result": None,
            "ground_truth": {},
            "current_step": 1,
        }

        created.append({
            "id": sub_id,
            "name": sub_name,
            "box_index": i,
            "box": box,
            "image_width": crop.shape[1],
            "image_height": crop.shape[0],
        })

        logger.info(f"Created box sub-session {sub_id} for session {session_id} "
                    f"(box {i}, {crop.shape[1]}x{crop.shape[0]})")

    return {
        "session_id": session_id,
        "sub_sessions": created,
        "total": len(created),
    }
@router.get("/sessions/{session_id}/thumbnail")
async def get_session_thumbnail(session_id: str, size: int = Query(default=80, ge=16, le=400)):
    """Return a small PNG thumbnail of the original image.

    The image is scaled so that its longer side equals ``size`` pixels;
    the shorter side keeps the aspect ratio but is never below one pixel.

    Raises:
        HTTPException: 404 if no original image is stored, 500 if the
            image cannot be decoded or the thumbnail cannot be encoded.
    """
    original_png = await get_session_image(session_id, "original")
    if not original_png:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found or no image")
    arr = np.frombuffer(original_png, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")
    h, w = img.shape[:2]
    scale = size / max(h, w)
    # For very elongated images the short side truncates to 0, which
    # cv2.resize rejects; clamp both sides to at least one pixel.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    thumb = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
    ok, png_bytes = cv2.imencode(".png", thumb)
    if not ok:
        raise HTTPException(status_code=500, detail="Failed to encode thumbnail")
    return Response(content=png_bytes.tobytes(), media_type="image/png",
                    headers={"Cache-Control": "public, max-age=3600"})
@router.get("/sessions/{session_id}/image/{image_type}")
async def get_image(session_id: str, image_type: str):
    """Serve a session image or a rendered overlay as PNG.

    Stored image types: original, oriented, cropped, deskewed, dewarped,
    binarized, clean. Overlay types: structure-overlay, columns-overlay,
    rows-overlay, words-overlay.

    Raises:
        HTTPException: 400 for an unknown image type, 404 if the
            requested image is not available yet.
    """
    overlay_kinds = {
        "structure-overlay": "structure",
        "columns-overlay": "columns",
        "rows-overlay": "rows",
        "words-overlay": "words",
    }
    stored_types = {
        "original", "oriented", "cropped", "deskewed",
        "dewarped", "binarized", "clean",
    }

    if image_type not in overlay_kinds and image_type not in stored_types:
        raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")

    if image_type in overlay_kinds:
        return await render_overlay(overlay_kinds[image_type], session_id)

    # Only the binarized PNG is served from the in-memory cache; every
    # other stored type is read from the database.
    if image_type == "binarized":
        cached = _cache.get(session_id)
        if cached and cached.get("binarized_png"):
            return Response(content=cached["binarized_png"], media_type="image/png")

    # Load from DB — for cropped/dewarped, fall back through the chain
    if image_type in ("cropped", "dewarped"):
        data = await _get_base_image_png(session_id)
    else:
        data = await get_session_image(session_id, image_type)
    if not data:
        raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet")

    return Response(content=data, media_type="image/png")
@router.post("/sessions/{session_id}/detect-structure")
async def detect_structure(session_id: str):
    """Detect document structure: boxes, zones, and color regions.

    Runs box detection (line + shading) and color analysis on the cropped
    image. Returns structured JSON with all detected elements for the
    structure visualization step.

    The cropped image is used when present, otherwise the dewarped one.
    Content bounds come from the session's OCR words when there are any,
    else from the image size minus a 3% margin. Exclude regions stored by
    an earlier run are carried over into the new result. The result is
    persisted as ``structure_result`` and mirrored into the cache entry.

    Raises:
        HTTPException: 400 if neither a cropped nor a dewarped image is
            available in the cache entry.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    img_bgr = (
        cached.get("cropped_bgr")
        if cached.get("cropped_bgr") is not None
        else cached.get("dewarped_bgr")
    )
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first")

    t0 = time.time()
    h, w = img_bgr.shape[:2]

    # --- Content bounds from word result (if available) or full image ---
    word_result = cached.get("word_result")
    words: List[Dict] = []
    if word_result and word_result.get("cells"):
        for cell in word_result["cells"]:
            for wb in (cell.get("word_boxes") or []):
                words.append(wb)
    # Fallback: use raw OCR words if cell word_boxes are empty
    if not words and word_result:
        # The first non-empty raw list wins, in this order of preference.
        for key in ("raw_paddle_words_split", "raw_tesseract_words", "raw_paddle_words"):
            raw = word_result.get(key, [])
            if raw:
                words = raw
                logger.info("detect-structure: using %d words from %s (no cell word_boxes)", len(words), key)
                break
    # If no words yet, use image dimensions with small margin
    if words:
        # Bounding rectangle of all words, clamped to the image.
        content_x = max(0, min(int(wb["left"]) for wb in words))
        content_y = max(0, min(int(wb["top"]) for wb in words))
        content_r = min(w, max(int(wb["left"] + wb["width"]) for wb in words))
        content_b = min(h, max(int(wb["top"] + wb["height"]) for wb in words))
        content_w_px = content_r - content_x
        content_h_px = content_b - content_y
    else:
        margin = int(min(w, h) * 0.03)
        content_x, content_y = margin, margin
        content_w_px = w - 2 * margin
        content_h_px = h - 2 * margin

    # --- Box detection ---
    boxes = detect_boxes(
        img_bgr,
        content_x=content_x,
        content_w=content_w_px,
        content_y=content_y,
        content_h=content_h_px,
    )

    # --- Zone splitting ---
    from cv_box_detect import split_page_into_zones as _split_zones
    zones = _split_zones(content_x, content_y, content_w_px, content_h_px, boxes)

    # --- Color region sampling ---
    # Sample background shading in each detected box
    hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)
    box_colors = []
    for box in boxes:
        # Sample the center region of each box
        cy1 = box.y + box.height // 4
        cy2 = box.y + 3 * box.height // 4
        cx1 = box.x + box.width // 4
        cx2 = box.x + 3 * box.width // 4
        cy1 = max(0, min(cy1, h - 1))
        cy2 = max(0, min(cy2, h - 1))
        cx1 = max(0, min(cx1, w - 1))
        cx2 = max(0, min(cx2, w - 1))
        if cy2 > cy1 and cx2 > cx1:
            roi_hsv = hsv[cy1:cy2, cx1:cx2]
            med_h = float(np.median(roi_hsv[:, :, 0]))
            med_s = float(np.median(roi_hsv[:, :, 1]))
            med_v = float(np.median(roi_hsv[:, :, 2]))
            # Median saturation above 15 is treated as a colored
            # background; below that the box is classed gray or white
            # by its median brightness.
            if med_s > 15:
                from cv_color_detect import _hue_to_color_name
                bg_name = _hue_to_color_name(med_h)
                bg_hex = _COLOR_HEX.get(bg_name, "#6b7280")
            else:
                bg_name = "gray" if med_v < 220 else "white"
                bg_hex = "#6b7280" if bg_name == "gray" else "#ffffff"
        else:
            # Sample window collapsed after clamping; nothing to measure.
            bg_name = "unknown"
            bg_hex = "#6b7280"
        box_colors.append({"color_name": bg_name, "color_hex": bg_hex})

    # --- Color text detection overview ---
    # Quick scan for colored text regions across the page
    color_summary: Dict[str, int] = {}
    for color_name, ranges in _COLOR_RANGES.items():
        mask = np.zeros((h, w), dtype=np.uint8)
        for lower, upper in ranges:
            mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper))
        pixel_count = int(np.sum(mask > 0))
        if pixel_count > 50:  # minimum threshold
            color_summary[color_name] = pixel_count

    # --- Graphic element detection ---
    box_dicts = [
        {"x": b.x, "y": b.y, "w": b.width, "h": b.height}
        for b in boxes
    ]
    graphics = detect_graphic_elements(
        img_bgr, words,
        detected_boxes=box_dicts,
    )

    # --- Filter border-ghost words from OCR result ---
    # NOTE(review): `words` was collected before this step, so graphic
    # detection above and `word_count` below use the pre-filter list.
    # The helper presumably edits `word_result` in place (it is persisted
    # right after without reassignment) — confirm against its definition.
    ghost_count = 0
    if boxes and word_result:
        ghost_count = _filter_border_ghost_words(word_result, boxes)
        if ghost_count:
            logger.info("detect-structure: removed %d border-ghost words", ghost_count)
            await update_session_db(session_id, word_result=word_result)
            cached["word_result"] = word_result

    duration = time.time() - t0

    # Preserve user-drawn exclude regions from previous run
    prev_sr = cached.get("structure_result") or {}
    prev_exclude = prev_sr.get("exclude_regions", [])

    result_dict = {
        "image_width": w,
        "image_height": h,
        "content_bounds": {
            "x": content_x, "y": content_y,
            "w": content_w_px, "h": content_h_px,
        },
        "boxes": [
            {
                "x": b.x, "y": b.y, "w": b.width, "h": b.height,
                "confidence": b.confidence,
                "border_thickness": b.border_thickness,
                "bg_color_name": box_colors[i]["color_name"],
                "bg_color_hex": box_colors[i]["color_hex"],
            }
            for i, b in enumerate(boxes)
        ],
        "zones": [
            {
                "index": z.index,
                "zone_type": z.zone_type,
                "y": z.y, "h": z.height,
                "x": z.x, "w": z.width,
            }
            for z in zones
        ],
        "graphics": [
            {
                "x": g.x, "y": g.y, "w": g.width, "h": g.height,
                "area": g.area,
                "shape": g.shape,
                "color_name": g.color_name,
                "color_hex": g.color_hex,
                "confidence": round(g.confidence, 2),
            }
            for g in graphics
        ],
        "exclude_regions": prev_exclude,
        "color_pixel_counts": color_summary,
        "has_words": len(words) > 0,
        "word_count": len(words),
        "border_ghosts_removed": ghost_count,
        "duration_seconds": round(duration, 2),
    }

    # Persist to session
    await update_session_db(session_id, structure_result=result_dict)
    cached["structure_result"] = result_dict

    logger.info("detect-structure session %s: %d boxes, %d zones, %d graphics, %.2fs",
                session_id, len(boxes), len(zones), len(graphics), duration)

    return {"session_id": session_id, **result_dict}
async def _store_structure_result(session_id: str, structure_result: dict) -> None:
    """Persist *structure_result* and drop the grid so it is rebuilt.

    The grid editor result is reset in the database and removed from the
    in-memory cache entry (when one exists), because it was computed with
    the previous set of exclude regions.
    """
    await update_session_db(
        session_id,
        structure_result=structure_result,
        grid_editor_result=None,
    )

    cache_entry = _cache.get(session_id) if session_id in _cache else None
    if cache_entry is not None:
        cache_entry["structure_result"] = structure_result
        cache_entry.pop("grid_editor_result", None)


@router.put("/sessions/{session_id}/exclude-regions")
async def set_exclude_regions(session_id: str, body: _ExcludeRegionsBatchIn):
    """Replace all exclude regions for a session.

    Regions are stored inside ``structure_result.exclude_regions``.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    structure = session.get("structure_result") or {}
    new_regions = [region.model_dump() for region in body.regions]
    structure["exclude_regions"] = new_regions

    await _store_structure_result(session_id, structure)

    return {
        "session_id": session_id,
        "exclude_regions": new_regions,
        "count": len(new_regions),
    }


@router.delete("/sessions/{session_id}/exclude-regions/{region_index}")
async def delete_exclude_region(session_id: str, region_index: int):
    """Remove a single exclude region by its list index.

    Raises:
        HTTPException: 404 when the session does not exist or the index is
            outside the stored region list.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    structure = session.get("structure_result") or {}
    stored = structure.get("exclude_regions", [])

    if not 0 <= region_index < len(stored):
        raise HTTPException(status_code=404, detail="Region index out of range")

    removed = stored.pop(region_index)
    structure["exclude_regions"] = stored

    await _store_structure_result(session_id, structure)

    return {
        "session_id": session_id,
        "removed": removed,
        "remaining": len(stored),
    }
# Prompt suffix appended to the user prompt, keyed by the requested style.
STYLE_SUFFIXES = {
    "educational": "educational illustration, textbook style, clear, colorful",
    "cartoon": "cartoon, child-friendly, simple shapes",
    "sketch": "pencil sketch, hand-drawn, black and white",
    "clipart": "clipart, flat vector style, simple",
    "realistic": "photorealistic, high detail",
}


class ValidationRequest(BaseModel):
    """Body of the validation-save endpoint."""

    notes: Optional[str] = None
    score: Optional[int] = None


class GenerateImageRequest(BaseModel):
    """Body of the image-generation endpoint."""

    # Index into the previously detected ``image_regions`` list.
    region_index: int
    prompt: str
    # Key into STYLE_SUFFIXES; unknown styles fall back to "educational".
    style: str = "educational"


# ---------------------------------------------------------------------------
# Image detection + generation
# ---------------------------------------------------------------------------

def _extract_region_array(text: str) -> list:
    """Return the first JSON array of objects embedded in *text*.

    A real JSON decoder is used instead of a bracket regex so that ``]``
    characters inside string values (e.g. a description such as
    ``"sign reading [EXIT]"``) do not truncate the array.

    Returns an empty list when *text* contains no ``[`` at all.

    Raises:
        ValueError: when brackets are present but no array of objects
            could be decoded.
    """
    start = text.find("[")
    if start == -1:
        return []

    decoder = json.JSONDecoder()
    first_error = None
    while start != -1:
        try:
            data, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError as exc:
            if first_error is None:
                first_error = exc
        else:
            # Decoding from "[" always yields a list; accept it only when
            # every element is an object (an empty array is valid too).
            if all(isinstance(item, dict) for item in data):
                return data
        start = text.find("[", start + 1)

    raise ValueError("no JSON array of region objects in VLM response") from first_error


def _normalize_region(raw: dict) -> Optional[dict]:
    """Convert one raw VLM region into the stored region dict.

    Coordinates are clamped to 0-100 (width/height to 1-100). Returns
    ``None`` when a coordinate cannot be interpreted as a number.
    """
    try:
        bbox = {
            "x": max(0, min(100, float(raw.get("x", 0)))),
            "y": max(0, min(100, float(raw.get("y", 0)))),
            "w": max(1, min(100, float(raw.get("w", 10)))),
            "h": max(1, min(100, float(raw.get("h", 10)))),
        }
    except (TypeError, ValueError):
        return None

    # Coerce to str: the description is concatenated with context below.
    description = str(raw.get("description") or "")
    return {
        "bbox_pct": bbox,
        "description": description,
        "prompt": description,
        "image_b64": None,
        "style": "educational",
    }


@router.post("/sessions/{session_id}/reconstruction/detect-images")
async def detect_image_regions(session_id: str):
    """Detect illustration/image regions in the original scan using a VLM.

    The original session image is sent to the Ollama ``/api/generate``
    endpoint. Detected regions are stored under
    ``ground_truth.validation.image_regions`` and mirrored into the cache.

    Returns:
        ``{"regions": [...], "count": n}``; on VLM/parse failure the same
        shape with empty regions and an ``error`` string.

    Raises:
        HTTPException: 404 for an unknown session, 400 when the session has
            no original image.
    """
    import base64
    import httpx

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    original_png = await get_session_image(session_id, "original")
    if not original_png:
        raise HTTPException(status_code=400, detail="No original image found")

    word_result = session.get("word_result") or {}
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    vocab_context = ""
    if entries:
        # Only the first few entries are used as a hint for the model.
        sample = entries[:10]
        words = [f"{e.get('english', '')} / {e.get('german', '')}" for e in sample if e.get('english')]
        if words:
            vocab_context = f"\nContext: This is a vocabulary page with words like: {', '.join(words)}"

    ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
    model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")

    prompt = (
        "Analyze this scanned page. Find ALL illustration/image/picture regions "
        "(NOT text, NOT table cells, NOT blank areas). "
        "For each image region found, return its bounding box as percentage of page dimensions "
        "and a short English description of what the image shows. "
        "Reply with ONLY a JSON array like: "
        '[{"x": 10, "y": 20, "w": 30, "h": 25, "description": "drawing of a cat"}] '
        "where x, y, w, h are percentages (0-100) of the page width/height. "
        "If there are NO images on the page, return an empty array: []"
        f"{vocab_context}"
    )

    img_b64 = base64.b64encode(original_png).decode("utf-8")
    payload = {
        "model": model,
        "prompt": prompt,
        "images": [img_b64],
        "stream": False,
    }

    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            resp = await client.post(f"{ollama_base}/api/generate", json=payload)
            resp.raise_for_status()
            text = resp.json().get("response", "")

        regions = []
        for raw in _extract_region_array(text):
            region = _normalize_region(raw)
            if region is None:
                # One bad entry must not discard the other detections.
                logger.warning(f"Skipping malformed image region for session {session_id}: {raw!r}")
                continue
            regions.append(region)

        # Enrich prompts with nearby vocab context
        if entries:
            for region in regions:
                ry = region["bbox_pct"]["y"]
                rh = region["bbox_pct"]["h"]
                nearby = [
                    e for e in entries
                    if e.get("bbox") and abs(e["bbox"].get("y", 0) - ry) < rh + 10
                ]
                if nearby:
                    en_words = [e.get("english", "") for e in nearby if e.get("english")]
                    de_words = [e.get("german", "") for e in nearby if e.get("german")]
                    if en_words or de_words:
                        context = f" (vocabulary context: {', '.join(en_words[:5])}"
                        if de_words:
                            context += f" / {', '.join(de_words[:5])}"
                        context += ")"
                        region["prompt"] = region["description"] + context

        ground_truth = session.get("ground_truth") or {}
        validation = ground_truth.get("validation") or {}
        validation["image_regions"] = regions
        validation["detected_at"] = datetime.utcnow().isoformat()
        ground_truth["validation"] = validation
        await update_session_db(session_id, ground_truth=ground_truth)

        if session_id in _cache:
            _cache[session_id]["ground_truth"] = ground_truth

        logger.info(f"Detected {len(regions)} image regions for session {session_id}")

        return {"regions": regions, "count": len(regions)}

    except httpx.ConnectError:
        logger.warning(f"VLM not available at {ollama_base} for image detection")
        return {"regions": [], "count": 0, "error": "VLM not available"}
    except Exception as e:
        # Best-effort endpoint: report the failure to the client instead of a 500,
        # but keep the traceback in the log.
        logger.exception(f"Image detection failed for {session_id}: {e}")
        return {"regions": [], "count": 0, "error": str(e)}


@router.post("/sessions/{session_id}/reconstruction/generate-image")
async def generate_image_for_region(session_id: str, req: GenerateImageRequest):
    """Generate a replacement image for a detected region using mflux.

    The output size is chosen from the region's aspect ratio (landscape,
    portrait or square). On success the image, prompt and style are stored
    on the region inside ``ground_truth.validation.image_regions``.

    Returns:
        ``{"image_b64": ..., "success": bool}`` plus ``error`` on failure.

    Raises:
        HTTPException: 404 for an unknown session, 400 for an invalid
            ``region_index``.
    """
    import httpx

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    validation = ground_truth.get("validation") or {}
    regions = validation.get("image_regions") or []

    if req.region_index < 0 or req.region_index >= len(regions):
        raise HTTPException(status_code=400, detail=f"Invalid region_index {req.region_index}, have {len(regions)} regions")

    mflux_url = os.getenv("MFLUX_URL", "http://host.docker.internal:8095")
    style_suffix = STYLE_SUFFIXES.get(req.style, STYLE_SUFFIXES["educational"])
    full_prompt = f"{req.prompt}, {style_suffix}"

    region = regions[req.region_index]
    bbox = region["bbox_pct"]
    # max(..., 1) guards against division by zero for degenerate boxes.
    aspect = bbox["w"] / max(bbox["h"], 1)
    if aspect > 1.3:
        width, height = 768, 512
    elif aspect < 0.7:
        width, height = 512, 768
    else:
        width, height = 512, 512

    try:
        async with httpx.AsyncClient(timeout=300.0) as client:
            resp = await client.post(f"{mflux_url}/generate", json={
                "prompt": full_prompt,
                "width": width,
                "height": height,
                "steps": 4,
            })
            resp.raise_for_status()
            data = resp.json()
            image_b64 = data.get("image_b64")

        if not image_b64:
            return {"image_b64": None, "success": False, "error": "No image returned"}

        region["image_b64"] = image_b64
        region["prompt"] = req.prompt
        region["style"] = req.style
        validation["image_regions"] = regions
        ground_truth["validation"] = validation
        await update_session_db(session_id, ground_truth=ground_truth)

        if session_id in _cache:
            _cache[session_id]["ground_truth"] = ground_truth

        logger.info(f"Generated image for session {session_id} region {req.region_index}")
        return {"image_b64": image_b64, "success": True}

    except httpx.ConnectError:
        logger.warning(f"mflux-service not available at {mflux_url}")
        return {"image_b64": None, "success": False, "error": f"mflux-service not available at {mflux_url}"}
    except Exception as e:
        # Best-effort endpoint: surface the error in the response, keep the traceback in the log.
        logger.exception(f"Image generation failed for {session_id}: {e}")
        return {"image_b64": None, "success": False, "error": str(e)}


# ---------------------------------------------------------------------------
# Validation save/get
# ---------------------------------------------------------------------------

@router.post("/sessions/{session_id}/reconstruction/validate")
async def save_validation(session_id: str, req: ValidationRequest):
    """Save the final validation notes/score and set ``current_step`` to 11.

    Existing validation data (e.g. detected image regions) is kept; only
    ``validated_at``, ``notes`` and ``score`` are overwritten.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    validation = ground_truth.get("validation") or {}
    validation["validated_at"] = datetime.utcnow().isoformat()
    validation["notes"] = req.notes
    validation["score"] = req.score
    ground_truth["validation"] = validation

    await update_session_db(session_id, ground_truth=ground_truth, current_step=11)

    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    logger.info(f"Validation saved for session {session_id}: score={req.score}")

    return {"session_id": session_id, "validation": validation}


@router.get("/sessions/{session_id}/reconstruction/validation")
async def get_validation(session_id: str):
    """Return the saved validation data together with the session's word result.

    ``validation`` is ``None`` when nothing has been saved yet.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}

    return {
        "session_id": session_id,
        "validation": ground_truth.get("validation"),
        "word_result": session.get("word_result"),
    }


# ---------------------------------------------------------------------------
# Remove handwriting
# ---------------------------------------------------------------------------

@router.post("/sessions/{session_id}/remove-handwriting")
async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingRequest):
    """Remove handwriting from a session image using inpainting.

    Steps: pick the source image, detect a handwriting mask, optionally
    dilate the mask, inpaint, then persist the result as ``clean_png``
    together with processing metadata.

    Raises:
        HTTPException: 404 for an unknown session or missing source image,
            500 when inpainting reports failure.
    """
    import io
    import time as _time

    from PIL import Image as _PILImage

    from services.handwriting_detection import detect_handwriting
    from services.inpainting_service import inpaint_image, dilate_mask as _dilate_mask, InpaintingMethod, image_to_png

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    t0 = _time.monotonic()

    # 1. Determine source image ("auto" prefers the deskewed image when present)
    source = req.use_source
    if source == "auto":
        deskewed = await get_session_image(session_id, "deskewed")
        source = "deskewed" if deskewed else "original"

    image_bytes = await get_session_image(session_id, source)
    if not image_bytes:
        raise HTTPException(status_code=404, detail=f"Source image '{source}' not available")

    # 2. Detect handwriting mask
    # NOTE(review): detection and inpainting are synchronous calls inside an
    # async handler -- confirm they are fast enough not to stall the event loop.
    detection = detect_handwriting(image_bytes, target_ink=req.target_ink)

    # 3. Convert mask to PNG bytes and dilate
    mask_buf = io.BytesIO()
    _PILImage.fromarray(detection.mask).save(mask_buf, format="PNG")
    mask_bytes = mask_buf.getvalue()

    if req.dilation > 0:
        mask_bytes = _dilate_mask(mask_bytes, iterations=req.dilation)

    # 4. Inpaint (unknown method names fall back to AUTO)
    method_map = {
        "telea": InpaintingMethod.OPENCV_TELEA,
        "ns": InpaintingMethod.OPENCV_NS,
        "auto": InpaintingMethod.AUTO,
    }
    inpaint_method = method_map.get(req.method, InpaintingMethod.AUTO)

    result = inpaint_image(image_bytes, mask_bytes, method=inpaint_method)
    if not result.success:
        raise HTTPException(status_code=500, detail="Inpainting failed")

    elapsed_ms = int((_time.monotonic() - t0) * 1000)

    meta = {
        "method_used": result.method_used.value if hasattr(result.method_used, "value") else str(result.method_used),
        "handwriting_ratio": round(detection.handwriting_ratio, 4),
        "detection_confidence": round(detection.confidence, 4),
        "target_ink": req.target_ink,
        "dilation": req.dilation,
        "source_image": source,
        "processing_time_ms": elapsed_ms,
    }

    # 5. Persist clean image
    clean_png_bytes = image_to_png(result.image)
    await update_session_db(session_id, clean_png=clean_png_bytes, handwriting_removal_meta=meta)

    return {
        **meta,
        "image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean",
        "session_id": session_id,
    }
+The LLM can read degraded text using context understanding and visual inspection, +while OCR coordinates provide structural hints (where text is, column positions). + +Uses Ollama API (same pattern as handwriting_htr_api.py). +""" + +import base64 +import json +import logging +import os +import re +from typing import Any, Dict, List, Optional + +import cv2 +import httpx +import numpy as np + +logger = logging.getLogger(__name__) + +OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434") +VISION_FUSION_MODEL = os.getenv("VISION_FUSION_MODEL", "llama3.2-vision:11b") + +# Document category → prompt context +CATEGORY_PROMPTS: Dict[str, Dict[str, str]] = { + "vokabelseite": { + "label": "Vokabelseite eines Schulbuchs (Englisch-Deutsch)", + "columns": "Die Tabelle hat typischerweise 3 Spalten: Englisch, Deutsch, Beispielsatz.", + }, + "woerterbuch": { + "label": "Woerterbuchseite", + "columns": "Die Eintraege haben: Stichwort, Lautschrift, Uebersetzung(en), Beispielsaetze.", + }, + "arbeitsblatt": { + "label": "Arbeitsblatt", + "columns": "Erkenne die Spaltenstruktur aus dem Layout.", + }, + "buchseite": { + "label": "Schulbuchseite", + "columns": "Erkenne die Spaltenstruktur aus dem Layout.", + }, +} + + +def _group_words_into_lines( + words: List[Dict], y_tolerance: float = 15.0, +) -> List[List[Dict]]: + """Group OCR words into lines by Y-proximity.""" + if not words: + return [] + sorted_w = sorted(words, key=lambda w: w.get("top", 0)) + lines: List[List[Dict]] = [[sorted_w[0]]] + for w in sorted_w[1:]: + last_line = lines[-1] + avg_y = sum(ww["top"] for ww in last_line) / len(last_line) + if abs(w["top"] - avg_y) <= y_tolerance: + last_line.append(w) + else: + lines.append([w]) + # Sort words within each line by X + for line in lines: + line.sort(key=lambda w: w.get("left", 0)) + return lines + + +def _build_ocr_context(words: List[Dict], img_h: int) -> str: + """Build a text description of OCR words with positions for the prompt.""" + 
lines = _group_words_into_lines(words) + context_parts = [] + for i, line in enumerate(lines): + word_descs = [] + for w in line: + text = w.get("text", "").strip() + x = w.get("left", 0) + conf = w.get("conf", 0) + marker = " (?)" if conf < 50 else "" + word_descs.append(f'x={x} "{text}"{marker}') + avg_y = int(sum(w["top"] for w in line) / len(line)) + context_parts.append(f"Zeile {i+1} (y~{avg_y}): {', '.join(word_descs)}") + return "\n".join(context_parts) + + +def _build_prompt( + ocr_context: str, category: str, img_w: int, img_h: int, +) -> str: + """Build the Vision-LLM prompt with OCR context and document type.""" + cat_info = CATEGORY_PROMPTS.get(category, CATEGORY_PROMPTS["buchseite"]) + + return f"""Du siehst eine eingescannte {cat_info['label']}. +{cat_info['columns']} + +Die OCR-Software hat folgende Woerter an diesen Positionen erkannt. +Woerter mit (?) haben niedrige Erkennungssicherheit und sind wahrscheinlich falsch: + +{ocr_context} + +Bildgroesse: {img_w} x {img_h} Pixel. + +AUFGABE: Schau dir das Bild genau an und erstelle die korrekte Tabelle. 
+- Korrigiere falsch erkannte Woerter anhand dessen was du im Bild siehst +- Fasse Fortsetzungszeilen zusammen (wenn eine Spalte in der naechsten Zeile leer ist, + gehoert der Text zur Zeile darueber — der Autor hat nur einen Zeilenumbruch innerhalb der Zelle gemacht) +- Behalte die Reihenfolge bei + +Antworte NUR mit einem JSON-Array, keine Erklaerungen: +[ + {{"row": 1, "english": "...", "german": "...", "example": "..."}}, + {{"row": 2, "english": "...", "german": "...", "example": "..."}} +]""" + + +def _parse_llm_response(response_text: str) -> Optional[List[Dict]]: + """Parse the LLM JSON response, handling markdown code blocks.""" + text = response_text.strip() + + # Strip markdown code block if present + if text.startswith("```"): + text = re.sub(r"^```(?:json)?\s*", "", text) + text = re.sub(r"\s*```\s*$", "", text) + + # Try to find JSON array + match = re.search(r"\[[\s\S]*\]", text) + if not match: + logger.warning("vision_fuse_ocr: no JSON array found in LLM response") + return None + + try: + data = json.loads(match.group()) + if not isinstance(data, list): + return None + return data + except json.JSONDecodeError as e: + logger.warning(f"vision_fuse_ocr: JSON parse error: {e}") + return None + + +def _vocab_rows_to_words( + rows: List[Dict], img_w: int, img_h: int, +) -> List[Dict]: + """Convert LLM vocab rows back to word dicts for grid building. + + Distributes words across estimated column positions so the + existing grid builder can process them normally. 
+ """ + words = [] + # Estimate column positions (3-column vocab layout) + col_positions = [ + (0.02, 0.28), # EN: 2%-28% of width + (0.30, 0.55), # DE: 30%-55% + (0.57, 0.98), # Example: 57%-98% + ] + + median_h = max(15, img_h // (len(rows) * 3)) if rows else 20 + y_step = max(median_h + 5, img_h // max(len(rows), 1)) + + for i, row in enumerate(rows): + y = int(i * y_step + 20) + row_num = row.get("row", i + 1) + + for col_idx, (field, (x_start_pct, x_end_pct)) in enumerate([ + ("english", col_positions[0]), + ("german", col_positions[1]), + ("example", col_positions[2]), + ]): + text = (row.get(field) or "").strip() + if not text: + continue + x = int(x_start_pct * img_w) + w = int((x_end_pct - x_start_pct) * img_w) + words.append({ + "text": text, + "left": x, + "top": y, + "width": w, + "height": median_h, + "conf": 95, # LLM-corrected → high confidence + "_source": "vision_llm", + "_row": row_num, + "_col_type": f"column_{['en', 'de', 'example'][col_idx]}", + }) + + logger.info(f"vision_fuse_ocr: converted {len(rows)} LLM rows → {len(words)} words") + return words + + +async def vision_fuse_ocr( + img_bgr: np.ndarray, + ocr_words: List[Dict], + document_category: str = "vokabelseite", +) -> List[Dict]: + """Fuse traditional OCR results with Vision-LLM reading. + + Sends the image + OCR word positions to Qwen2.5-VL which can: + - Read degraded text that traditional OCR cannot + - Use document context (knows what a vocab table looks like) + - Merge continuation rows (understands table structure) + + Args: + img_bgr: The cropped/dewarped scan image (BGR) + ocr_words: Traditional OCR word list with positions + document_category: Type of document being scanned + + Returns: + Corrected word list in same format as input, ready for grid building. + Falls back to original ocr_words on error. 
+ """ + img_h, img_w = img_bgr.shape[:2] + + # Build OCR context string + ocr_context = _build_ocr_context(ocr_words, img_h) + + # Build prompt + prompt = _build_prompt(ocr_context, document_category, img_w, img_h) + + # Encode image as base64 + _, img_encoded = cv2.imencode(".png", img_bgr) + img_b64 = base64.b64encode(img_encoded.tobytes()).decode("utf-8") + + # Call Qwen2.5-VL via Ollama + try: + async with httpx.AsyncClient(timeout=120.0) as client: + resp = await client.post( + f"{OLLAMA_BASE_URL}/api/generate", + json={ + "model": VISION_FUSION_MODEL, + "prompt": prompt, + "images": [img_b64], + "stream": False, + "options": {"temperature": 0.1, "num_predict": 4096}, + }, + ) + resp.raise_for_status() + data = resp.json() + response_text = data.get("response", "").strip() + except Exception as e: + logger.error(f"vision_fuse_ocr: Ollama call failed: {e}") + return ocr_words # Fallback to original + + if not response_text: + logger.warning("vision_fuse_ocr: empty LLM response") + return ocr_words + + # Parse JSON response + rows = _parse_llm_response(response_text) + if not rows: + logger.warning( + "vision_fuse_ocr: could not parse LLM response, " + "first 200 chars: %s", response_text[:200], + ) + return ocr_words + + logger.info( + f"vision_fuse_ocr: LLM returned {len(rows)} vocab rows " + f"(from {len(ocr_words)} OCR words)" + ) + + # Convert back to word format for grid building + return _vocab_rows_to_words(rows, img_w, img_h) diff --git a/klausur-service/backend/ocr/pipeline/words.py b/klausur-service/backend/ocr/pipeline/words.py new file mode 100644 index 0000000..dd43c56 --- /dev/null +++ b/klausur-service/backend/ocr/pipeline/words.py @@ -0,0 +1,185 @@ +""" +OCR Pipeline Words — composite router for word detection, PaddleOCR direct, +and ground truth endpoints. 
class WordGroundTruthRequest(BaseModel):
    """Body of the word ground-truth save endpoint."""

    is_correct: bool
    corrected_entries: Optional[List[Dict[str, Any]]] = None
    notes: Optional[str] = None


# ---------------------------------------------------------------------------
# PaddleOCR Direct Endpoint
# ---------------------------------------------------------------------------

@_local_router.post("/sessions/{session_id}/paddle-direct")
async def paddle_direct(session_id: str):
    """Run PaddleOCR on the preprocessed image and build a word grid directly.

    The first available image among ``cropped``, ``dewarped`` and
    ``original`` is used. The resulting ``word_result`` is persisted
    (together with the used image as ``cropped_png`` and
    ``current_step=8``) and mirrored into the in-memory cache entry when
    one exists, so cached readers do not see a stale word result.

    Raises:
        HTTPException: 404 when the session has no image, 400 when the
            image cannot be decoded or PaddleOCR returns no words.
    """
    img_png = None
    for image_kind in ("cropped", "dewarped", "original"):
        img_png = await get_session_image(session_id, image_kind)
        if img_png:
            break
    if not img_png:
        raise HTTPException(status_code=404, detail="No image found for this session")

    img_arr = np.frombuffer(img_png, dtype=np.uint8)
    img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
    if img_bgr is None:
        # The image may be the cropped or dewarped one, not only the original.
        raise HTTPException(status_code=400, detail="Failed to decode session image")

    img_h, img_w = img_bgr.shape[:2]

    from cv_ocr_engines import ocr_region_paddle

    t0 = time.time()
    word_dicts = await ocr_region_paddle(img_bgr, region=None)
    if not word_dicts:
        raise HTTPException(status_code=400, detail="PaddleOCR returned no words")

    cells, columns_meta = build_grid_from_words(word_dicts, img_w, img_h)
    duration = time.time() - t0

    for cell in cells:
        cell["ocr_engine"] = "paddle_direct"

    n_rows = len({c["row_index"] for c in cells})
    n_cols = len(columns_meta)
    col_types = {c.get("type") for c in columns_meta}
    is_vocab = bool(col_types & {"column_en", "column_de"})

    word_result = {
        "cells": cells,
        "grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": "paddle_direct",
        "grid_method": "paddle_direct",
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
        },
    }

    await update_session_db(
        session_id,
        word_result=word_result,
        cropped_png=img_png,
        current_step=8,
    )

    # Keep the in-memory session copy in sync with what was just persisted.
    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    logger.info(
        "paddle_direct session %s: %d cells (%d rows, %d cols) in %.2fs",
        session_id, len(cells), n_rows, n_cols, duration,
    )

    await _append_pipeline_log(session_id, "paddle_direct", {
        "total_cells": len(cells),
        "non_empty_cells": word_result["summary"]["non_empty_cells"],
        "ocr_engine": "paddle_direct",
    }, duration_ms=int(duration * 1000))

    return {"session_id": session_id, **word_result}


# ---------------------------------------------------------------------------
# Ground Truth Words Endpoints
# ---------------------------------------------------------------------------

@_local_router.post("/sessions/{session_id}/ground-truth/words")
async def save_word_ground_truth(session_id: str, req: WordGroundTruthRequest):
    """Save ground truth feedback for the word recognition step.

    The feedback is stored under ``ground_truth["words"]`` together with a
    snapshot of the session's current ``word_result``.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    gt = {
        "is_correct": req.is_correct,
        "corrected_entries": req.corrected_entries,
        "notes": req.notes,
        "saved_at": datetime.utcnow().isoformat(),
        "word_result": session.get("word_result"),
    }
    ground_truth["words"] = gt

    await update_session_db(session_id, ground_truth=ground_truth)

    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    return {"session_id": session_id, "ground_truth": gt}


@_local_router.get("/sessions/{session_id}/ground-truth/words")
async def get_word_ground_truth(session_id: str):
    """Retrieve saved ground truth for word recognition.

    Raises:
        HTTPException: 404 when the session does not exist or no word
            ground truth has been saved.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    words_gt = ground_truth.get("words")
    if not words_gt:
        raise HTTPException(status_code=404, detail="No word ground truth saved")

    return {
        "session_id": session_id,
        "words_gt": words_gt,
        "words_auto": session.get("word_result"),
    }


# ---------------------------------------------------------------------------
# Composite router
# ---------------------------------------------------------------------------

# Combines the detect_words router with the endpoints defined in this module.
router = APIRouter()
router.include_router(_detect_router)
router.include_router(_local_router)
@router.post("/sessions/{session_id}/words")
async def detect_words(
    session_id: str,
    request: Request,
    engine: str = "auto",
    pronunciation: str = "british",
    stream: bool = False,
    skip_heal_gaps: bool = False,
    grid_method: str = "v2",
):
    """Build word grid from columns x rows, OCR each cell.

    Query params:
        engine: 'auto' (default), 'tesseract', 'rapid', or 'paddle'
        pronunciation: 'british' (default) or 'american'
        stream: false (default) for JSON response, true for SSE streaming
        skip_heal_gaps: false (default). When true, cells keep exact row geometry.
        grid_method: 'v2' (default) or 'words_first'

    ``engine='paddle'`` always uses the ``words_first`` grid method. The
    ``words_first`` method does not require a prior row detection; a
    missing ``row_result`` is treated as "no rows". A missing column
    result is replaced by a single full-page pseudo-column.

    Raises:
        HTTPException: 400 when no cropped/dewarped image is cached or when
            row detection is missing for the ``v2`` method; 404 when the
            session does not exist.
    """
    # PaddleOCR is full-page remote OCR -> force words_first grid method
    if engine == "paddle" and grid_method != "words_first":
        logger.info("detect_words: engine=paddle requires words_first, overriding grid_method=%s", grid_method)
        grid_method = "words_first"

    if session_id not in _cache:
        logger.info("detect_words: session %s not in cache, loading from DB", session_id)
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    # Prefer the cropped image; fall back to the dewarped one.
    dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
    if dewarped_bgr is None:
        logger.warning("detect_words: no cropped/dewarped image for session %s (cache keys: %s)",
                       session_id, [k for k in cached.keys() if k.endswith('_bgr')])
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before word detection")

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    column_result = session.get("column_result")
    row_result = session.get("row_result")
    if not column_result or not column_result.get("columns"):
        img_h_tmp, img_w_tmp = dewarped_bgr.shape[:2]
        column_result = {
            "columns": [{
                "type": "column_text",
                "x": 0, "y": 0,
                "width": img_w_tmp, "height": img_h_tmp,
                "classification_confidence": 1.0,
                "classification_method": "full_page_fallback",
            }],
            "zones": [],
            "duration_seconds": 0,
        }
        logger.info("detect_words: no column_result -- using full-page pseudo-column %dx%d", img_w_tmp, img_h_tmp)
    if grid_method != "words_first" and (not row_result or not row_result.get("rows")):
        raise HTTPException(status_code=400, detail="Row detection must be completed first")

    # Convert column dicts back to PageRegion objects
    col_regions = [
        PageRegion(
            type=c["type"],
            x=c["x"], y=c["y"],
            width=c["width"], height=c["height"],
            classification_confidence=c.get("classification_confidence", 1.0),
            classification_method=c.get("classification_method", ""),
        )
        for c in column_result["columns"]
    ]

    # Convert row dicts back to RowGeometry objects.
    # row_result may be None here on the words_first path (the check above
    # is skipped for it), so it must not be subscripted directly.
    stored_rows = (row_result or {}).get("rows") or []
    row_geoms = [
        RowGeometry(
            index=r["index"],
            x=r["x"], y=r["y"],
            width=r["width"], height=r["height"],
            word_count=r.get("word_count", 0),
            words=[],
            row_type=r.get("row_type", "content"),
            gap_before=r.get("gap_before", 0),
        )
        for r in stored_rows
    ]

    # Populate word counts from cached words
    word_dicts = cached.get("_word_dicts")
    if word_dicts is None:
        ocr_img_tmp = create_ocr_image(dewarped_bgr)
        geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr)
        if geo_result is not None:
            _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
            cached["_word_dicts"] = word_dicts
            cached["_inv"] = inv
            cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)

    if word_dicts:
        content_bounds = cached.get("_content_bounds")
        if content_bounds:
            _lx, _rx, top_y, _by = content_bounds
        else:
            top_y = min(r.y for r in row_geoms) if row_geoms else 0

        for row in row_geoms:
            # Word 'top' values are compared against row positions shifted by top_y.
            row_y_rel = row.y - top_y
            row_bottom_rel = row_y_rel + row.height
            row.words = [
                w for w in word_dicts
                if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
            ]
            row.word_count = len(row.words)

    # Exclude rows that fall within box zones
    zones = column_result.get("zones") or []
    box_ranges_inner = []
    for zone in zones:
        if zone.get("zone_type") == "box" and zone.get("box"):
            box = zone["box"]
            # Shrink the box by its border (at least 5 px) on top and bottom.
            bt = max(box.get("border_thickness", 0), 5)
            box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt))

    if box_ranges_inner:
        def _row_in_box(r):
            center_y = r.y + r.height / 2
            return any(by_s <= center_y < by_e for by_s, by_e in box_ranges_inner)

        before_count = len(row_geoms)
        row_geoms = [r for r in row_geoms if not _row_in_box(r)]
        excluded = before_count - len(row_geoms)
        if excluded:
            logger.info(f"detect_words: excluded {excluded} rows inside box zones")

    # --- Words-First path ---
    if grid_method == "words_first":
        return await _words_first_path(
            session_id, cached, dewarped_bgr, engine, pronunciation, zones,
        )

    if stream:
        return StreamingResponse(
            _word_batch_stream_generator(
                session_id, cached, col_regions, row_geoms,
                dewarped_bgr, engine, pronunciation, request,
                skip_heal_gaps=skip_heal_gaps,
            ),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    # --- Non-streaming path (grid_method=v2) ---
    return await _v2_path(
        session_id, cached, col_regions, row_geoms,
        dewarped_bgr, engine, pronunciation, skip_heal_gaps,
    )
_by = content_bounds + abs_words = [] + for w in wf_word_dicts: + abs_words.append({**w, 'left': w['left'] + lx, 'top': w['top'] + ty}) + wf_word_dicts = abs_words + + box_rects = [] + for zone in zones: + if zone.get("zone_type") == "box" and zone.get("box"): + box_rects.append(zone["box"]) + + cells, columns_meta = build_grid_from_words( + wf_word_dicts, img_w, img_h, box_rects=box_rects or None, + ) + duration = time.time() - t0 + + fix_cell_phonetics(cells, pronunciation=pronunciation) + for cell in cells: + cell.setdefault("zone_index", 0) + + col_types = {c['type'] for c in columns_meta} + is_vocab = bool(col_types & {'column_en', 'column_de'}) + n_rows = len(set(c['row_index'] for c in cells)) if cells else 0 + n_cols = len(columns_meta) + used_engine = "paddle" if engine == "paddle" else "words_first" + + word_result = { + "cells": cells, + "grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)}, + "columns_used": columns_meta, + "layout": "vocab" if is_vocab else "generic", + "image_width": img_w, + "image_height": img_h, + "duration_seconds": round(duration, 2), + "ocr_engine": used_engine, + "grid_method": "words_first", + "summary": { + "total_cells": len(cells), + "non_empty_cells": sum(1 for c in cells if c.get("text")), + "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), + }, + } + + if is_vocab or 'column_text' in col_types: + entries = _cells_to_vocab_entries(cells, columns_meta) + entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation) + word_result["vocab_entries"] = entries + word_result["entries"] = entries + word_result["entry_count"] = len(entries) + word_result["summary"]["total_entries"] = len(entries) + word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english")) + word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german")) + + await update_session_db(session_id, word_result=word_result, current_step=8) + cached["word_result"] = 
word_result + + logger.info(f"OCR Pipeline: words-first session {session_id}: " + f"{len(cells)} cells ({duration:.2f}s), {n_rows} rows, {n_cols} cols") + + await _append_pipeline_log(session_id, "words", { + "grid_method": "words_first", + "total_cells": len(cells), + "non_empty_cells": word_result["summary"]["non_empty_cells"], + "ocr_engine": used_engine, + "layout": word_result["layout"], + }, duration_ms=int(duration * 1000)) + + return {"session_id": session_id, **word_result} + + +async def _v2_path( + session_id: str, + cached: Dict[str, Any], + col_regions: List[PageRegion], + row_geoms: List[RowGeometry], + dewarped_bgr: np.ndarray, + engine: str, + pronunciation: str, + skip_heal_gaps: bool, +) -> dict: + """Cell-First OCR v2 non-streaming path.""" + t0 = time.time() + ocr_img = create_ocr_image(dewarped_bgr) + img_h, img_w = dewarped_bgr.shape[:2] + + cells, columns_meta = build_cell_grid_v2( + ocr_img, col_regions, row_geoms, img_w, img_h, + ocr_engine=engine, img_bgr=dewarped_bgr, + skip_heal_gaps=skip_heal_gaps, + ) + duration = time.time() - t0 + + for cell in cells: + cell.setdefault("zone_index", 0) + + col_types = {c['type'] for c in columns_meta} + is_vocab = bool(col_types & {'column_en', 'column_de'}) + n_content_rows = len([r for r in row_geoms if r.row_type == 'content']) + n_cols = len(columns_meta) + used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine + + fix_cell_phonetics(cells, pronunciation=pronunciation) + + word_result = { + "cells": cells, + "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(cells)}, + "columns_used": columns_meta, + "layout": "vocab" if is_vocab else "generic", + "image_width": img_w, + "image_height": img_h, + "duration_seconds": round(duration, 2), + "ocr_engine": used_engine, + "summary": { + "total_cells": len(cells), + "non_empty_cells": sum(1 for c in cells if c.get("text")), + "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), + }, + } + 
+ has_text_col = 'column_text' in col_types + if is_vocab or has_text_col: + entries = _cells_to_vocab_entries(cells, columns_meta) + entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation) + word_result["vocab_entries"] = entries + word_result["entries"] = entries + word_result["entry_count"] = len(entries) + word_result["summary"]["total_entries"] = len(entries) + word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english")) + word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german")) + + await update_session_db(session_id, word_result=word_result, current_step=8) + cached["word_result"] = word_result + + logger.info(f"OCR Pipeline: words session {session_id}: " + f"layout={word_result['layout']}, " + f"{len(cells)} cells ({duration:.2f}s), summary: {word_result['summary']}") + + await _append_pipeline_log(session_id, "words", { + "total_cells": len(cells), + "non_empty_cells": word_result["summary"]["non_empty_cells"], + "low_confidence_count": word_result["summary"]["low_confidence"], + "ocr_engine": used_engine, + "layout": word_result["layout"], + "entry_count": word_result.get("entry_count", 0), + }, duration_ms=int(duration * 1000)) + + return {"session_id": session_id, **word_result} diff --git a/klausur-service/backend/ocr/pipeline/words_stream.py b/klausur-service/backend/ocr/pipeline/words_stream.py new file mode 100644 index 0000000..9ff06d1 --- /dev/null +++ b/klausur-service/backend/ocr/pipeline/words_stream.py @@ -0,0 +1,303 @@ +""" +OCR Pipeline Words Stream — SSE streaming generators for word detection. + +Extracted from ocr_pipeline_words.py. + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
+""" + +import json +import logging +import time +from typing import Any, Dict, List + +import numpy as np +from fastapi import Request + +from cv_vocab_pipeline import ( + PageRegion, + RowGeometry, + _cells_to_vocab_entries, + _fix_character_confusion, + _fix_phonetic_brackets, + fix_cell_phonetics, + build_cell_grid_v2, + build_cell_grid_v2_streaming, + create_ocr_image, +) +from .session_store import update_session_db +from .common import _cache + +logger = logging.getLogger(__name__) + + +async def _word_batch_stream_generator( + session_id: str, + cached: Dict[str, Any], + col_regions: List[PageRegion], + row_geoms: List[RowGeometry], + dewarped_bgr: np.ndarray, + engine: str, + pronunciation: str, + request: Request, + skip_heal_gaps: bool = False, +): + """SSE generator that runs batch OCR (parallel) then streams results. + + Uses build_cell_grid_v2 with ThreadPoolExecutor for parallel OCR, + then emits all cells as SSE events. + """ + import asyncio + + t0 = time.time() + ocr_img = create_ocr_image(dewarped_bgr) + img_h, img_w = dewarped_bgr.shape[:2] + + _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'} + n_content_rows = len([r for r in row_geoms if r.row_type == 'content']) + n_cols = len([c for c in col_regions if c.type not in _skip_types]) + col_types = {c.type for c in col_regions if c.type not in _skip_types} + is_vocab = bool(col_types & {'column_en', 'column_de'}) + total_cells = n_content_rows * n_cols + + # 1. Send meta event immediately + meta_event = { + "type": "meta", + "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells}, + "layout": "vocab" if is_vocab else "generic", + } + yield f"data: {json.dumps(meta_event)}\n\n" + + # 2. Send preparing event (keepalive for proxy) + yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR laeuft parallel...'})}\n\n" + + # 3. Run batch OCR in thread pool with periodic keepalive events. 
+ loop = asyncio.get_event_loop() + ocr_future = loop.run_in_executor( + None, + lambda: build_cell_grid_v2( + ocr_img, col_regions, row_geoms, img_w, img_h, + ocr_engine=engine, img_bgr=dewarped_bgr, + skip_heal_gaps=skip_heal_gaps, + ), + ) + + # Send keepalive events every 5 seconds while OCR runs + keepalive_count = 0 + while not ocr_future.done(): + try: + cells, columns_meta = await asyncio.wait_for( + asyncio.shield(ocr_future), timeout=5.0, + ) + break # OCR finished + except asyncio.TimeoutError: + keepalive_count += 1 + elapsed = int(time.time() - t0) + yield f"data: {json.dumps({'type': 'keepalive', 'elapsed': elapsed, 'message': f'OCR laeuft... ({elapsed}s)'})}\n\n" + if await request.is_disconnected(): + logger.info(f"SSE batch: client disconnected during OCR for {session_id}") + ocr_future.cancel() + return + else: + cells, columns_meta = ocr_future.result() + + if await request.is_disconnected(): + logger.info(f"SSE batch: client disconnected after OCR for {session_id}") + return + + # 4. Apply IPA phonetic fixes + fix_cell_phonetics(cells, pronunciation=pronunciation) + + # 5. Send columns meta + if columns_meta: + yield f"data: {json.dumps({'type': 'columns', 'columns_used': columns_meta})}\n\n" + + # 6. Stream all cells + for idx, cell in enumerate(cells): + cell_event = { + "type": "cell", + "cell": cell, + "progress": {"current": idx + 1, "total": len(cells)}, + } + yield f"data: {json.dumps(cell_event)}\n\n" + + # 7. 
Build final result and persist + duration = time.time() - t0 + used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine + + word_result = { + "cells": cells, + "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(cells)}, + "columns_used": columns_meta, + "layout": "vocab" if is_vocab else "generic", + "image_width": img_w, + "image_height": img_h, + "duration_seconds": round(duration, 2), + "ocr_engine": used_engine, + "summary": { + "total_cells": len(cells), + "non_empty_cells": sum(1 for c in cells if c.get("text")), + "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), + }, + } + + vocab_entries = None + has_text_col = 'column_text' in col_types + if is_vocab or has_text_col: + entries = _cells_to_vocab_entries(cells, columns_meta) + entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation) + word_result["vocab_entries"] = entries + word_result["entries"] = entries + word_result["entry_count"] = len(entries) + word_result["summary"]["total_entries"] = len(entries) + word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english")) + word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german")) + vocab_entries = entries + + await update_session_db(session_id, word_result=word_result, current_step=8) + cached["word_result"] = word_result + + logger.info(f"OCR Pipeline SSE batch: words session {session_id}: " + f"layout={word_result['layout']}, {len(cells)} cells ({duration:.2f}s)") + + # 8. 
Send complete event + complete_event = { + "type": "complete", + "summary": word_result["summary"], + "duration_seconds": round(duration, 2), + "ocr_engine": used_engine, + } + if vocab_entries is not None: + complete_event["vocab_entries"] = vocab_entries + yield f"data: {json.dumps(complete_event)}\n\n" + + +async def _word_stream_generator( + session_id: str, + cached: Dict[str, Any], + col_regions: List[PageRegion], + row_geoms: List[RowGeometry], + dewarped_bgr: np.ndarray, + engine: str, + pronunciation: str, + request: Request, +): + """SSE generator that yields cell-by-cell OCR progress.""" + t0 = time.time() + + ocr_img = create_ocr_image(dewarped_bgr) + img_h, img_w = dewarped_bgr.shape[:2] + + n_content_rows = len([r for r in row_geoms if r.row_type == 'content']) + _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'} + n_cols = len([c for c in col_regions if c.type not in _skip_types]) + + col_types = {c.type for c in col_regions if c.type not in _skip_types} + is_vocab = bool(col_types & {'column_en', 'column_de'}) + + columns_meta = None + total_cells = n_content_rows * n_cols + + meta_event = { + "type": "meta", + "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells}, + "layout": "vocab" if is_vocab else "generic", + } + yield f"data: {json.dumps(meta_event)}\n\n" + + yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR wird initialisiert...'})}\n\n" + + all_cells: List[Dict[str, Any]] = [] + cell_idx = 0 + last_keepalive = time.time() + + for cell, cols_meta, total in build_cell_grid_v2_streaming( + ocr_img, col_regions, row_geoms, img_w, img_h, + ocr_engine=engine, img_bgr=dewarped_bgr, + ): + if await request.is_disconnected(): + logger.info(f"SSE: client disconnected during streaming for {session_id}") + return + + if columns_meta is None: + columns_meta = cols_meta + meta_update = {"type": "columns", "columns_used": cols_meta} + yield 
f"data: {json.dumps(meta_update)}\n\n" + + all_cells.append(cell) + cell_idx += 1 + + cell_event = { + "type": "cell", + "cell": cell, + "progress": {"current": cell_idx, "total": total}, + } + yield f"data: {json.dumps(cell_event)}\n\n" + + # All cells done + duration = time.time() - t0 + if columns_meta is None: + columns_meta = [] + + # Remove all-empty rows + rows_with_text: set = set() + for c in all_cells: + if c.get("text", "").strip(): + rows_with_text.add(c["row_index"]) + before_filter = len(all_cells) + all_cells = [c for c in all_cells if c["row_index"] in rows_with_text] + empty_rows_removed = (before_filter - len(all_cells)) // max(n_cols, 1) + if empty_rows_removed > 0: + logger.info(f"SSE: removed {empty_rows_removed} all-empty rows after OCR") + + used_engine = all_cells[0].get("ocr_engine", "tesseract") if all_cells else engine + + fix_cell_phonetics(all_cells, pronunciation=pronunciation) + + word_result = { + "cells": all_cells, + "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(all_cells)}, + "columns_used": columns_meta, + "layout": "vocab" if is_vocab else "generic", + "image_width": img_w, + "image_height": img_h, + "duration_seconds": round(duration, 2), + "ocr_engine": used_engine, + "summary": { + "total_cells": len(all_cells), + "non_empty_cells": sum(1 for c in all_cells if c.get("text")), + "low_confidence": sum(1 for c in all_cells if 0 < c.get("confidence", 0) < 50), + }, + } + + vocab_entries = None + has_text_col = 'column_text' in col_types + if is_vocab or has_text_col: + entries = _cells_to_vocab_entries(all_cells, columns_meta) + entries = _fix_character_confusion(entries) + entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation) + word_result["vocab_entries"] = entries + word_result["entries"] = entries + word_result["entry_count"] = len(entries) + word_result["summary"]["total_entries"] = len(entries) + word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english")) 
+ word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german")) + vocab_entries = entries + + await update_session_db(session_id, word_result=word_result, current_step=8) + cached["word_result"] = word_result + + logger.info(f"OCR Pipeline SSE: words session {session_id}: " + f"layout={word_result['layout']}, " + f"{len(all_cells)} cells ({duration:.2f}s)") + + complete_event = { + "type": "complete", + "summary": word_result["summary"], + "duration_seconds": round(duration, 2), + "ocr_engine": used_engine, + } + if vocab_entries is not None: + complete_event["vocab_entries"] = vocab_entries + yield f"data: {json.dumps(complete_event)}\n\n" diff --git a/klausur-service/backend/ocr_labeling_api.py b/klausur-service/backend/ocr_labeling_api.py index 924f964..0c415e9 100644 --- a/klausur-service/backend/ocr_labeling_api.py +++ b/klausur-service/backend/ocr_labeling_api.py @@ -1,81 +1,4 @@ -""" -OCR Labeling API — Barrel Re-export - -Split into: -- ocr_labeling_models.py — Pydantic models and constants -- ocr_labeling_helpers.py — OCR wrappers, image storage, hashing -- ocr_labeling_routes.py — Session/queue/labeling route handlers -- ocr_labeling_upload_routes.py — Upload, run-OCR, export route handlers - -All public names are re-exported here for backward compatibility. 
-""" - -# Models -from ocr_labeling_models import ( # noqa: F401 - LOCAL_STORAGE_PATH, - SessionCreate, - SessionResponse, - ItemResponse, - ConfirmRequest, - CorrectRequest, - SkipRequest, - ExportRequest, - StatsResponse, -) - -# Helpers -from ocr_labeling_helpers import ( # noqa: F401 - VISION_OCR_AVAILABLE, - PADDLEOCR_AVAILABLE, - TROCR_AVAILABLE, - DONUT_AVAILABLE, - MINIO_AVAILABLE, - TRAINING_EXPORT_AVAILABLE, - compute_image_hash, - run_ocr_on_image, - run_vision_ocr_wrapper, - run_paddleocr_wrapper, - run_trocr_wrapper, - run_donut_wrapper, - save_image_locally, - get_image_url, -) - -# Conditional re-exports from helpers' optional imports -try: - from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET # noqa: F401 -except ImportError: - pass - -try: - from training_export_service import ( # noqa: F401 - TrainingExportService, - TrainingSample, - get_training_export_service, - ) -except ImportError: - pass - -try: - from hybrid_vocab_extractor import run_paddle_ocr # noqa: F401 -except ImportError: - pass - -try: - from services.trocr_service import run_trocr_ocr # noqa: F401 -except ImportError: - pass - -try: - from services.donut_ocr_service import run_donut_ocr # noqa: F401 -except ImportError: - pass - -try: - from vision_ocr_service import get_vision_ocr_service, VisionOCRService # noqa: F401 -except ImportError: - pass - -# Routes (router is the main export for app.include_router) -from ocr_labeling_routes import router # noqa: F401 -from ocr_labeling_upload_routes import router as upload_router # noqa: F401 +# Backward-compat shim -- module moved to ocr/labeling/api.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.labeling.api") diff --git a/klausur-service/backend/ocr_labeling_helpers.py b/klausur-service/backend/ocr_labeling_helpers.py index 5188670..85835d2 100644 --- a/klausur-service/backend/ocr_labeling_helpers.py +++ 
b/klausur-service/backend/ocr_labeling_helpers.py @@ -1,205 +1,4 @@ -""" -OCR Labeling - Helper Functions and OCR Wrappers - -Extracted from ocr_labeling_api.py to keep files under 500 LOC. - -DATENSCHUTZ/PRIVACY: -- Alle Verarbeitung erfolgt lokal (Mac Mini mit Ollama) -- Keine Daten werden an externe Server gesendet -""" - -import os -import hashlib - -from ocr_labeling_models import LOCAL_STORAGE_PATH - -# Try to import Vision OCR service -try: - import sys - sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend', 'klausur', 'services')) - from vision_ocr_service import get_vision_ocr_service, VisionOCRService - VISION_OCR_AVAILABLE = True -except ImportError: - VISION_OCR_AVAILABLE = False - print("Warning: Vision OCR service not available") - -# Try to import PaddleOCR from hybrid_vocab_extractor -try: - from hybrid_vocab_extractor import run_paddle_ocr - PADDLEOCR_AVAILABLE = True -except ImportError: - PADDLEOCR_AVAILABLE = False - print("Warning: PaddleOCR not available") - -# Try to import TrOCR service -try: - from services.trocr_service import run_trocr_ocr - TROCR_AVAILABLE = True -except ImportError: - TROCR_AVAILABLE = False - print("Warning: TrOCR service not available") - -# Try to import Donut service -try: - from services.donut_ocr_service import run_donut_ocr - DONUT_AVAILABLE = True -except ImportError: - DONUT_AVAILABLE = False - print("Warning: Donut OCR service not available") - -# Try to import MinIO storage -try: - from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET - MINIO_AVAILABLE = True -except ImportError: - MINIO_AVAILABLE = False - print("Warning: MinIO storage not available, using local storage") - -# Try to import Training Export Service -try: - from training_export_service import ( - TrainingExportService, - TrainingSample, - get_training_export_service, - ) - TRAINING_EXPORT_AVAILABLE = True -except ImportError: - TRAINING_EXPORT_AVAILABLE = False - print("Warning: Training export 
service not available") - - -# ============================================================================= -# Helper Functions -# ============================================================================= - -def compute_image_hash(image_data: bytes) -> str: - """Compute SHA256 hash of image data.""" - return hashlib.sha256(image_data).hexdigest() - - -async def run_ocr_on_image(image_data: bytes, filename: str, model: str = "llama3.2-vision:11b") -> tuple: - """ - Run OCR on an image using the specified model. - - Models: - - llama3.2-vision:11b: Vision LLM (default, best for handwriting) - - trocr: Microsoft TrOCR (fast for printed text) - - paddleocr: PaddleOCR + LLM hybrid (4x faster) - - donut: Document Understanding Transformer (structured documents) - - Returns: - Tuple of (ocr_text, confidence) - """ - print(f"Running OCR with model: {model}") - - # Route to appropriate OCR service based on model - if model == "paddleocr": - return await run_paddleocr_wrapper(image_data, filename) - elif model == "donut": - return await run_donut_wrapper(image_data, filename) - elif model == "trocr": - return await run_trocr_wrapper(image_data, filename) - else: - # Default: Vision LLM (llama3.2-vision or similar) - return await run_vision_ocr_wrapper(image_data, filename) - - -async def run_vision_ocr_wrapper(image_data: bytes, filename: str) -> tuple: - """Vision LLM OCR wrapper.""" - if not VISION_OCR_AVAILABLE: - print("Vision OCR service not available") - return None, 0.0 - - try: - service = get_vision_ocr_service() - if not await service.is_available(): - print("Vision OCR service not available (is_available check failed)") - return None, 0.0 - - result = await service.extract_text( - image_data, - filename=filename, - is_handwriting=True - ) - return result.text, result.confidence - except Exception as e: - print(f"Vision OCR failed: {e}") - return None, 0.0 - - -async def run_paddleocr_wrapper(image_data: bytes, filename: str) -> tuple: - """PaddleOCR wrapper - 
uses hybrid_vocab_extractor.""" - if not PADDLEOCR_AVAILABLE: - print("PaddleOCR not available, falling back to Vision OCR") - return await run_vision_ocr_wrapper(image_data, filename) - - try: - # run_paddle_ocr returns (regions, raw_text) - regions, raw_text = run_paddle_ocr(image_data) - - if not raw_text: - print("PaddleOCR returned empty text") - return None, 0.0 - - # Calculate average confidence from regions - if regions: - avg_confidence = sum(r.confidence for r in regions) / len(regions) - else: - avg_confidence = 0.5 - - return raw_text, avg_confidence - except Exception as e: - print(f"PaddleOCR failed: {e}, falling back to Vision OCR") - return await run_vision_ocr_wrapper(image_data, filename) - - -async def run_trocr_wrapper(image_data: bytes, filename: str) -> tuple: - """TrOCR wrapper.""" - if not TROCR_AVAILABLE: - print("TrOCR not available, falling back to Vision OCR") - return await run_vision_ocr_wrapper(image_data, filename) - - try: - text, confidence = await run_trocr_ocr(image_data) - return text, confidence - except Exception as e: - print(f"TrOCR failed: {e}, falling back to Vision OCR") - return await run_vision_ocr_wrapper(image_data, filename) - - -async def run_donut_wrapper(image_data: bytes, filename: str) -> tuple: - """Donut OCR wrapper.""" - if not DONUT_AVAILABLE: - print("Donut not available, falling back to Vision OCR") - return await run_vision_ocr_wrapper(image_data, filename) - - try: - text, confidence = await run_donut_ocr(image_data) - return text, confidence - except Exception as e: - print(f"Donut OCR failed: {e}, falling back to Vision OCR") - return await run_vision_ocr_wrapper(image_data, filename) - - -def save_image_locally(session_id: str, item_id: str, image_data: bytes, extension: str = "png") -> str: - """Save image to local storage.""" - session_dir = os.path.join(LOCAL_STORAGE_PATH, session_id) - os.makedirs(session_dir, exist_ok=True) - - filename = f"{item_id}.{extension}" - filepath = 
os.path.join(session_dir, filename) - - with open(filepath, 'wb') as f: - f.write(image_data) - - return filepath - - -def get_image_url(image_path: str) -> str: - """Get URL for an image.""" - # For local images, return a relative path that the frontend can use - if image_path.startswith(LOCAL_STORAGE_PATH): - relative_path = image_path[len(LOCAL_STORAGE_PATH):].lstrip('/') - return f"/api/v1/ocr-label/images/{relative_path}" - # For MinIO images, the path is already a URL or key - return image_path +# Backward-compat shim -- module moved to ocr/labeling/helpers.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.labeling.helpers") diff --git a/klausur-service/backend/ocr_labeling_models.py b/klausur-service/backend/ocr_labeling_models.py index f27601f..5985c90 100644 --- a/klausur-service/backend/ocr_labeling_models.py +++ b/klausur-service/backend/ocr_labeling_models.py @@ -1,86 +1,4 @@ -""" -OCR Labeling - Pydantic Models and Constants - -Extracted from ocr_labeling_api.py to keep files under 500 LOC. 
-""" - -import os -from pydantic import BaseModel -from typing import Optional, Dict -from datetime import datetime - - -# Local storage path (fallback if MinIO not available) -LOCAL_STORAGE_PATH = os.getenv("OCR_STORAGE_PATH", "/app/ocr-labeling") - - -# ============================================================================= -# Pydantic Models -# ============================================================================= - -class SessionCreate(BaseModel): - name: str - source_type: str = "klausur" # klausur, handwriting_sample, scan - description: Optional[str] = None - ocr_model: Optional[str] = "llama3.2-vision:11b" - - -class SessionResponse(BaseModel): - id: str - name: str - source_type: str - description: Optional[str] - ocr_model: Optional[str] - total_items: int - labeled_items: int - confirmed_items: int - corrected_items: int - skipped_items: int - created_at: datetime - - -class ItemResponse(BaseModel): - id: str - session_id: str - session_name: str - image_path: str - image_url: Optional[str] - ocr_text: Optional[str] - ocr_confidence: Optional[float] - ground_truth: Optional[str] - status: str - metadata: Optional[Dict] - created_at: datetime - - -class ConfirmRequest(BaseModel): - item_id: str - label_time_seconds: Optional[int] = None - - -class CorrectRequest(BaseModel): - item_id: str - ground_truth: str - label_time_seconds: Optional[int] = None - - -class SkipRequest(BaseModel): - item_id: str - - -class ExportRequest(BaseModel): - export_format: str = "generic" # generic, trocr, llama_vision - session_id: Optional[str] = None - batch_id: Optional[str] = None - - -class StatsResponse(BaseModel): - total_sessions: Optional[int] = None - total_items: int - labeled_items: int - confirmed_items: int - corrected_items: int - pending_items: int - exportable_items: Optional[int] = None - accuracy_rate: float - avg_label_time_seconds: Optional[float] = None +# Backward-compat shim -- module moved to ocr/labeling/models.py +import importlib as 
_importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.labeling.models") diff --git a/klausur-service/backend/ocr_labeling_routes.py b/klausur-service/backend/ocr_labeling_routes.py index b2365da..4d5f73a 100644 --- a/klausur-service/backend/ocr_labeling_routes.py +++ b/klausur-service/backend/ocr_labeling_routes.py @@ -1,241 +1,4 @@ -""" -OCR Labeling - Session and Labeling Route Handlers - -Extracted from ocr_labeling_api.py to keep files under 500 LOC. - -Endpoints: -- POST /sessions - Create labeling session -- GET /sessions - List sessions -- GET /sessions/{id} - Get session -- GET /queue - Get labeling queue -- GET /items/{id} - Get item -- POST /confirm - Confirm OCR -- POST /correct - Correct ground truth -- POST /skip - Skip item -- GET /stats - Get statistics -""" - -from fastapi import APIRouter, HTTPException, Query -from typing import Optional, List -from datetime import datetime -import uuid - -from metrics_db import ( - create_ocr_labeling_session, - get_ocr_labeling_sessions, - get_ocr_labeling_session, - get_ocr_labeling_queue, - get_ocr_labeling_item, - confirm_ocr_label, - correct_ocr_label, - skip_ocr_item, - get_ocr_labeling_stats, -) - -from ocr_labeling_models import ( - SessionCreate, SessionResponse, ItemResponse, - ConfirmRequest, CorrectRequest, SkipRequest, -) -from ocr_labeling_helpers import get_image_url - - -router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"]) - - -# ============================================================================= -# Session Endpoints -# ============================================================================= - -@router.post("/sessions", response_model=SessionResponse) -async def create_session(session: SessionCreate): - """Create a new OCR labeling session.""" - session_id = str(uuid.uuid4()) - - success = await create_ocr_labeling_session( - session_id=session_id, - name=session.name, - source_type=session.source_type, - 
description=session.description, - ocr_model=session.ocr_model, - ) - - if not success: - raise HTTPException(status_code=500, detail="Failed to create session") - - return SessionResponse( - id=session_id, - name=session.name, - source_type=session.source_type, - description=session.description, - ocr_model=session.ocr_model, - total_items=0, - labeled_items=0, - confirmed_items=0, - corrected_items=0, - skipped_items=0, - created_at=datetime.utcnow(), - ) - - -@router.get("/sessions", response_model=List[SessionResponse]) -async def list_sessions(limit: int = Query(50, ge=1, le=100)): - """List all OCR labeling sessions.""" - sessions = await get_ocr_labeling_sessions(limit=limit) - - return [ - SessionResponse( - id=s['id'], - name=s['name'], - source_type=s['source_type'], - description=s.get('description'), - ocr_model=s.get('ocr_model'), - total_items=s.get('total_items', 0), - labeled_items=s.get('labeled_items', 0), - confirmed_items=s.get('confirmed_items', 0), - corrected_items=s.get('corrected_items', 0), - skipped_items=s.get('skipped_items', 0), - created_at=s.get('created_at', datetime.utcnow()), - ) - for s in sessions - ] - - -@router.get("/sessions/{session_id}", response_model=SessionResponse) -async def get_session(session_id: str): - """Get a specific OCR labeling session.""" - session = await get_ocr_labeling_session(session_id) - - if not session: - raise HTTPException(status_code=404, detail="Session not found") - - return SessionResponse( - id=session['id'], - name=session['name'], - source_type=session['source_type'], - description=session.get('description'), - ocr_model=session.get('ocr_model'), - total_items=session.get('total_items', 0), - labeled_items=session.get('labeled_items', 0), - confirmed_items=session.get('confirmed_items', 0), - corrected_items=session.get('corrected_items', 0), - skipped_items=session.get('skipped_items', 0), - created_at=session.get('created_at', datetime.utcnow()), - ) - - -# 
============================================================================= -# Queue and Item Endpoints -# ============================================================================= - -@router.get("/queue", response_model=List[ItemResponse]) -async def get_labeling_queue( - session_id: Optional[str] = Query(None), - status: str = Query("pending"), - limit: int = Query(10, ge=1, le=50), -): - """Get items from the labeling queue.""" - items = await get_ocr_labeling_queue( - session_id=session_id, - status=status, - limit=limit, - ) - - return [ - ItemResponse( - id=item['id'], - session_id=item['session_id'], - session_name=item.get('session_name', ''), - image_path=item['image_path'], - image_url=get_image_url(item['image_path']), - ocr_text=item.get('ocr_text'), - ocr_confidence=item.get('ocr_confidence'), - ground_truth=item.get('ground_truth'), - status=item.get('status', 'pending'), - metadata=item.get('metadata'), - created_at=item.get('created_at', datetime.utcnow()), - ) - for item in items - ] - - -@router.get("/items/{item_id}", response_model=ItemResponse) -async def get_item(item_id: str): - """Get a specific labeling item.""" - item = await get_ocr_labeling_item(item_id) - - if not item: - raise HTTPException(status_code=404, detail="Item not found") - - return ItemResponse( - id=item['id'], - session_id=item['session_id'], - session_name=item.get('session_name', ''), - image_path=item['image_path'], - image_url=get_image_url(item['image_path']), - ocr_text=item.get('ocr_text'), - ocr_confidence=item.get('ocr_confidence'), - ground_truth=item.get('ground_truth'), - status=item.get('status', 'pending'), - metadata=item.get('metadata'), - created_at=item.get('created_at', datetime.utcnow()), - ) - - -# ============================================================================= -# Labeling Action Endpoints -# ============================================================================= - -@router.post("/confirm") -async def confirm_item(request: 
ConfirmRequest): - """Confirm that OCR text is correct.""" - success = await confirm_ocr_label( - item_id=request.item_id, - labeled_by="admin", - label_time_seconds=request.label_time_seconds, - ) - - if not success: - raise HTTPException(status_code=400, detail="Failed to confirm item") - - return {"status": "confirmed", "item_id": request.item_id} - - -@router.post("/correct") -async def correct_item(request: CorrectRequest): - """Save corrected ground truth for an item.""" - success = await correct_ocr_label( - item_id=request.item_id, - ground_truth=request.ground_truth, - labeled_by="admin", - label_time_seconds=request.label_time_seconds, - ) - - if not success: - raise HTTPException(status_code=400, detail="Failed to correct item") - - return {"status": "corrected", "item_id": request.item_id} - - -@router.post("/skip") -async def skip_item(request: SkipRequest): - """Skip an item (unusable image, etc.).""" - success = await skip_ocr_item( - item_id=request.item_id, - labeled_by="admin", - ) - - if not success: - raise HTTPException(status_code=400, detail="Failed to skip item") - - return {"status": "skipped", "item_id": request.item_id} - - -@router.get("/stats") -async def get_stats(session_id: Optional[str] = Query(None)): - """Get labeling statistics.""" - stats = await get_ocr_labeling_stats(session_id=session_id) - - if "error" in stats: - raise HTTPException(status_code=500, detail=stats["error"]) - - return stats +# Backward-compat shim -- module moved to ocr/labeling/routes.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.labeling.routes") diff --git a/klausur-service/backend/ocr_labeling_upload_routes.py b/klausur-service/backend/ocr_labeling_upload_routes.py index 0e8a684..e8579e4 100644 --- a/klausur-service/backend/ocr_labeling_upload_routes.py +++ b/klausur-service/backend/ocr_labeling_upload_routes.py @@ -1,313 +1,4 @@ -""" -OCR Labeling - Upload, Run-OCR, and Export Route 
Handlers - -Extracted from ocr_labeling_routes.py to keep files under 500 LOC. - -Endpoints: -- POST /sessions/{id}/upload - Upload images for labeling -- POST /run-ocr/{item_id} - Run OCR on existing item -- POST /export - Export training data -- GET /training-samples - List training samples -- GET /images/{path} - Serve images from local storage -- GET /exports - List exports -""" - -from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query -from typing import Optional, List -import uuid -import os - -from metrics_db import ( - get_ocr_labeling_session, - add_ocr_labeling_item, - get_ocr_labeling_item, - export_training_samples, - get_training_samples, -) - -from ocr_labeling_models import ( - ExportRequest, - LOCAL_STORAGE_PATH, -) -from ocr_labeling_helpers import ( - compute_image_hash, run_ocr_on_image, - save_image_locally, - MINIO_AVAILABLE, TRAINING_EXPORT_AVAILABLE, -) - -# Conditional imports -try: - from minio_storage import upload_ocr_image, get_ocr_image -except ImportError: - pass - -try: - from training_export_service import TrainingSample, get_training_export_service -except ImportError: - pass - - -router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"]) - - -@router.post("/sessions/{session_id}/upload") -async def upload_images( - session_id: str, - files: List[UploadFile] = File(...), - run_ocr: bool = Form(True), - metadata: Optional[str] = Form(None), -): - """ - Upload images to a labeling session. - - Args: - session_id: Session to add images to - files: Image files to upload (PNG, JPG, PDF) - run_ocr: Whether to run OCR immediately (default: True) - metadata: Optional JSON metadata (subject, year, etc.) 
- """ - import json - - session = await get_ocr_labeling_session(session_id) - if not session: - raise HTTPException(status_code=404, detail="Session not found") - - meta_dict = None - if metadata: - try: - meta_dict = json.loads(metadata) - except json.JSONDecodeError: - meta_dict = {"raw": metadata} - - results = [] - ocr_model = session.get('ocr_model', 'llama3.2-vision:11b') - - for file in files: - content = await file.read() - image_hash = compute_image_hash(content) - item_id = str(uuid.uuid4()) - - extension = file.filename.split('.')[-1].lower() if file.filename else 'png' - if extension not in ['png', 'jpg', 'jpeg', 'pdf']: - extension = 'png' - - if MINIO_AVAILABLE: - try: - image_path = upload_ocr_image(session_id, item_id, content, extension) - except Exception as e: - print(f"MinIO upload failed, using local storage: {e}") - image_path = save_image_locally(session_id, item_id, content, extension) - else: - image_path = save_image_locally(session_id, item_id, content, extension) - - ocr_text = None - ocr_confidence = None - - if run_ocr and extension != 'pdf': - ocr_text, ocr_confidence = await run_ocr_on_image( - content, - file.filename or f"{item_id}.{extension}", - model=ocr_model - ) - - success = await add_ocr_labeling_item( - item_id=item_id, - session_id=session_id, - image_path=image_path, - image_hash=image_hash, - ocr_text=ocr_text, - ocr_confidence=ocr_confidence, - ocr_model=ocr_model if ocr_text else None, - metadata=meta_dict, - ) - - if success: - results.append({ - "id": item_id, - "filename": file.filename, - "image_path": image_path, - "image_hash": image_hash, - "ocr_text": ocr_text, - "ocr_confidence": ocr_confidence, - "status": "pending", - }) - - return { - "session_id": session_id, - "uploaded_count": len(results), - "items": results, - } - - -@router.post("/export") -async def export_data(request: ExportRequest): - """Export labeled data for training.""" - db_samples = await export_training_samples( - 
export_format=request.export_format, - session_id=request.session_id, - batch_id=request.batch_id, - exported_by="admin", - ) - - if not db_samples: - return { - "export_format": request.export_format, - "batch_id": request.batch_id, - "exported_count": 0, - "samples": [], - "message": "No labeled samples found to export", - } - - export_result = None - if TRAINING_EXPORT_AVAILABLE: - try: - export_service = get_training_export_service() - - training_samples = [] - for s in db_samples: - training_samples.append(TrainingSample( - id=s.get('id', s.get('item_id', '')), - image_path=s.get('image_path', ''), - ground_truth=s.get('ground_truth', ''), - ocr_text=s.get('ocr_text'), - ocr_confidence=s.get('ocr_confidence'), - metadata=s.get('metadata'), - )) - - export_result = export_service.export( - samples=training_samples, - export_format=request.export_format, - batch_id=request.batch_id, - ) - except Exception as e: - print(f"Training export failed: {e}") - - response = { - "export_format": request.export_format, - "batch_id": request.batch_id or (export_result.batch_id if export_result else None), - "exported_count": len(db_samples), - "samples": db_samples, - } - - if export_result: - response["export_path"] = export_result.export_path - response["manifest_path"] = export_result.manifest_path - - return response - - -@router.get("/training-samples") -async def list_training_samples( - export_format: Optional[str] = Query(None), - batch_id: Optional[str] = Query(None), - limit: int = Query(100, ge=1, le=1000), -): - """Get exported training samples.""" - samples = await get_training_samples( - export_format=export_format, - batch_id=batch_id, - limit=limit, - ) - - return { - "count": len(samples), - "samples": samples, - } - - -@router.get("/images/{path:path}") -async def get_image(path: str): - """Serve an image from local storage.""" - from fastapi.responses import FileResponse - - filepath = os.path.join(LOCAL_STORAGE_PATH, path) - - if not 
os.path.exists(filepath): - raise HTTPException(status_code=404, detail="Image not found") - - extension = filepath.split('.')[-1].lower() - content_type = { - 'png': 'image/png', - 'jpg': 'image/jpeg', - 'jpeg': 'image/jpeg', - 'pdf': 'application/pdf', - }.get(extension, 'application/octet-stream') - - return FileResponse(filepath, media_type=content_type) - - -@router.post("/run-ocr/{item_id}") -async def run_ocr_for_item(item_id: str): - """Run OCR on an existing item.""" - item = await get_ocr_labeling_item(item_id) - - if not item: - raise HTTPException(status_code=404, detail="Item not found") - - image_path = item['image_path'] - - if image_path.startswith(LOCAL_STORAGE_PATH): - if not os.path.exists(image_path): - raise HTTPException(status_code=404, detail="Image file not found") - with open(image_path, 'rb') as f: - image_data = f.read() - elif MINIO_AVAILABLE: - try: - image_data = get_ocr_image(image_path) - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to load image: {e}") - else: - raise HTTPException(status_code=500, detail="Cannot load image") - - session = await get_ocr_labeling_session(item['session_id']) - ocr_model = session.get('ocr_model', 'llama3.2-vision:11b') if session else 'llama3.2-vision:11b' - - ocr_text, ocr_confidence = await run_ocr_on_image( - image_data, - os.path.basename(image_path), - model=ocr_model - ) - - if ocr_text is None: - raise HTTPException(status_code=500, detail="OCR failed") - - from metrics_db import get_pool - pool = await get_pool() - if pool: - async with pool.acquire() as conn: - await conn.execute( - """ - UPDATE ocr_labeling_items - SET ocr_text = $2, ocr_confidence = $3, ocr_model = $4 - WHERE id = $1 - """, - item_id, ocr_text, ocr_confidence, ocr_model - ) - - return { - "item_id": item_id, - "ocr_text": ocr_text, - "ocr_confidence": ocr_confidence, - "ocr_model": ocr_model, - } - - -@router.get("/exports") -async def list_exports(export_format: Optional[str] = 
Query(None)): - """List all available training data exports.""" - if not TRAINING_EXPORT_AVAILABLE: - return { - "exports": [], - "message": "Training export service not available", - } - - try: - export_service = get_training_export_service() - exports = export_service.list_exports(export_format=export_format) - - return { - "count": len(exports), - "exports": exports, - } - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to list exports: {e}") +# Backward-compat shim -- module moved to ocr/labeling/upload_routes.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.labeling.upload_routes") diff --git a/klausur-service/backend/ocr_merge_helpers.py b/klausur-service/backend/ocr_merge_helpers.py index 571c116..40b211c 100644 --- a/klausur-service/backend/ocr_merge_helpers.py +++ b/klausur-service/backend/ocr_merge_helpers.py @@ -1,272 +1,4 @@ -""" -OCR Merge Helpers — functions for combining PaddleOCR/RapidOCR with Tesseract results. - -Extracted from ocr_pipeline_ocr_merge.py. - -Lizenz: Apache 2.0 -DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. -""" - -import logging -from typing import List - -logger = logging.getLogger(__name__) - - -def _split_paddle_multi_words(words: list) -> list: - """Split PaddleOCR multi-word boxes into individual word boxes. - - PaddleOCR often returns entire phrases as a single box, e.g. - "More than 200 singers took part in the" with one bounding box. - This splits them into individual words with proportional widths. - Also handles leading "!" (e.g. "!Betonung" -> ["!", "Betonung"]) - and IPA brackets (e.g. "badge[bxd3]" -> ["badge", "[bxd3]"]). - """ - import re - - result = [] - for w in words: - raw_text = w.get("text", "").strip() - if not raw_text: - continue - # Split on whitespace, before "[" (IPA), and after "!" 
before letter - tokens = re.split( - r'\s+|(?=\[)|(?<=!)(?=[A-Za-z\u00c0-\u024f])', raw_text - ) - tokens = [t for t in tokens if t] - - if len(tokens) <= 1: - result.append(w) - else: - # Split proportionally by character count - total_chars = sum(len(t) for t in tokens) - if total_chars == 0: - continue - n_gaps = len(tokens) - 1 - gap_px = w["width"] * 0.02 - usable_w = w["width"] - gap_px * n_gaps - cursor = w["left"] - for t in tokens: - token_w = max(1, usable_w * len(t) / total_chars) - result.append({ - "text": t, - "left": round(cursor), - "top": w["top"], - "width": round(token_w), - "height": w["height"], - "conf": w.get("conf", 0), - }) - cursor += token_w + gap_px - return result - - -def _group_words_into_rows(words: list, row_gap: int = 12) -> list: - """Group words into rows by Y-position clustering. - - Words whose vertical centers are within `row_gap` pixels are on the same row. - Returns list of rows, each row is a list of words sorted left-to-right. - """ - if not words: - return [] - # Sort by vertical center - sorted_words = sorted(words, key=lambda w: w["top"] + w.get("height", 0) / 2) - rows: list = [] - current_row: list = [sorted_words[0]] - current_cy = sorted_words[0]["top"] + sorted_words[0].get("height", 0) / 2 - - for w in sorted_words[1:]: - cy = w["top"] + w.get("height", 0) / 2 - if abs(cy - current_cy) <= row_gap: - current_row.append(w) - else: - # Sort current row left-to-right before saving - rows.append(sorted(current_row, key=lambda w: w["left"])) - current_row = [w] - current_cy = cy - if current_row: - rows.append(sorted(current_row, key=lambda w: w["left"])) - return rows - - -def _row_center_y(row: list) -> float: - """Average vertical center of a row of words.""" - if not row: - return 0.0 - return sum(w["top"] + w.get("height", 0) / 2 for w in row) / len(row) - - -def _merge_row_sequences(paddle_row: list, tess_row: list) -> list: - """Merge two word sequences from the same row using sequence alignment. 
- - Both sequences are sorted left-to-right. Walk through both simultaneously: - - If words match (same/similar text): take Paddle text with averaged coords - - If they don't match: the extra word is unique to one engine, include it - """ - merged = [] - pi, ti = 0, 0 - - while pi < len(paddle_row) and ti < len(tess_row): - pw = paddle_row[pi] - tw = tess_row[ti] - - pt = pw.get("text", "").lower().strip() - tt = tw.get("text", "").lower().strip() - - is_same = (pt == tt) or (len(pt) > 1 and len(tt) > 1 and (pt in tt or tt in pt)) - - # Spatial overlap check - spatial_match = False - if not is_same: - overlap_left = max(pw["left"], tw["left"]) - overlap_right = min( - pw["left"] + pw.get("width", 0), - tw["left"] + tw.get("width", 0), - ) - overlap_w = max(0, overlap_right - overlap_left) - min_w = min(pw.get("width", 1), tw.get("width", 1)) - if min_w > 0 and overlap_w / min_w >= 0.4: - is_same = True - spatial_match = True - - if is_same: - pc = pw.get("conf", 80) - tc = tw.get("conf", 50) - total = pc + tc - if total == 0: - total = 1 - if spatial_match and pc < tc: - best_text = tw["text"] - else: - best_text = pw["text"] - merged.append({ - "text": best_text, - "left": round((pw["left"] * pc + tw["left"] * tc) / total), - "top": round((pw["top"] * pc + tw["top"] * tc) / total), - "width": round((pw["width"] * pc + tw["width"] * tc) / total), - "height": round((pw["height"] * pc + tw["height"] * tc) / total), - "conf": max(pc, tc), - }) - pi += 1 - ti += 1 - else: - paddle_ahead = any( - tess_row[t].get("text", "").lower().strip() == pt - for t in range(ti + 1, min(ti + 4, len(tess_row))) - ) - tess_ahead = any( - paddle_row[p].get("text", "").lower().strip() == tt - for p in range(pi + 1, min(pi + 4, len(paddle_row))) - ) - - if paddle_ahead and not tess_ahead: - if tw.get("conf", 0) >= 30: - merged.append(tw) - ti += 1 - elif tess_ahead and not paddle_ahead: - merged.append(pw) - pi += 1 - else: - if pw["left"] <= tw["left"]: - merged.append(pw) - pi += 1 - 
else: - if tw.get("conf", 0) >= 30: - merged.append(tw) - ti += 1 - - while pi < len(paddle_row): - merged.append(paddle_row[pi]) - pi += 1 - while ti < len(tess_row): - tw = tess_row[ti] - if tw.get("conf", 0) >= 30: - merged.append(tw) - ti += 1 - - return merged - - -def _merge_paddle_tesseract(paddle_words: list, tess_words: list) -> list: - """Merge word boxes from PaddleOCR and Tesseract using row-based sequence alignment.""" - if not paddle_words and not tess_words: - return [] - if not paddle_words: - return [w for w in tess_words if w.get("conf", 0) >= 40] - if not tess_words: - return list(paddle_words) - - paddle_rows = _group_words_into_rows(paddle_words) - tess_rows = _group_words_into_rows(tess_words) - - used_tess_rows: set = set() - merged_all: list = [] - - for pr in paddle_rows: - pr_cy = _row_center_y(pr) - best_dist, best_tri = float("inf"), -1 - for tri, tr in enumerate(tess_rows): - if tri in used_tess_rows: - continue - tr_cy = _row_center_y(tr) - dist = abs(pr_cy - tr_cy) - if dist < best_dist: - best_dist, best_tri = dist, tri - - max_row_dist = max( - max((w.get("height", 20) for w in pr), default=20), - 15, - ) - - if best_tri >= 0 and best_dist <= max_row_dist: - tr = tess_rows[best_tri] - used_tess_rows.add(best_tri) - merged_all.extend(_merge_row_sequences(pr, tr)) - else: - merged_all.extend(pr) - - for tri, tr in enumerate(tess_rows): - if tri not in used_tess_rows: - for tw in tr: - if tw.get("conf", 0) >= 40: - merged_all.append(tw) - - return merged_all - - -def _deduplicate_words(words: list) -> list: - """Remove duplicate words with same text at overlapping positions.""" - if not words: - return words - - result: list = [] - for w in words: - wt = w.get("text", "").lower().strip() - if not wt: - continue - is_dup = False - w_right = w["left"] + w.get("width", 0) - w_bottom = w["top"] + w.get("height", 0) - for existing in result: - et = existing.get("text", "").lower().strip() - if wt != et: - continue - ox_l = max(w["left"], 
existing["left"]) - ox_r = min(w_right, existing["left"] + existing.get("width", 0)) - ox = max(0, ox_r - ox_l) - min_w = min(w.get("width", 1), existing.get("width", 1)) - if min_w <= 0 or ox / min_w < 0.5: - continue - oy_t = max(w["top"], existing["top"]) - oy_b = min(w_bottom, existing["top"] + existing.get("height", 0)) - oy = max(0, oy_b - oy_t) - min_h = min(w.get("height", 1), existing.get("height", 1)) - if min_h > 0 and oy / min_h >= 0.5: - is_dup = True - break - if not is_dup: - result.append(w) - - removed = len(words) - len(result) - if removed: - logger.info("dedup: removed %d duplicate words", removed) - return result +# Backward-compat shim -- module moved to ocr/pipeline/merge_helpers.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.merge_helpers") diff --git a/klausur-service/backend/ocr_pipeline_api.py b/klausur-service/backend/ocr_pipeline_api.py index 56eda09..18ca941 100644 --- a/klausur-service/backend/ocr_pipeline_api.py +++ b/klausur-service/backend/ocr_pipeline_api.py @@ -1,63 +1,4 @@ -""" -OCR Pipeline API - Schrittweise Seitenrekonstruktion. - -Thin wrapper that assembles all sub-module routers into a single -composite router. Backward-compatible: main.py and tests can still -import ``router``, ``_cache``, and helper functions from here. - -Sub-modules (each < 1 000 lines): - ocr_pipeline_common – shared state, cache, Pydantic models, helpers - ocr_pipeline_sessions – session CRUD, image serving, doc-type - ocr_pipeline_geometry – deskew, dewarp, structure, columns - ocr_pipeline_rows – row detection, box-overlay helper - ocr_pipeline_words – word detection (SSE), paddle-direct, word GT - ocr_pipeline_ocr_merge – paddle/tesseract merge helpers, kombi endpoints - ocr_pipeline_postprocess – LLM review, reconstruction, export, validation - ocr_pipeline_auto – auto-mode orchestrator, reprocess - -Lizenz: Apache 2.0 -DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
-""" - -from fastapi import APIRouter - -# --------------------------------------------------------------------------- -# Shared state (imported by main.py and orientation_crop_api.py) -# --------------------------------------------------------------------------- -from ocr_pipeline_common import ( # noqa: F401 – re-exported - _cache, - _BORDER_GHOST_CHARS, - _filter_border_ghost_words, -) - -# --------------------------------------------------------------------------- -# Sub-module routers -# --------------------------------------------------------------------------- -from ocr_pipeline_sessions import router as _sessions_router -from ocr_pipeline_geometry import router as _geometry_router -from ocr_pipeline_rows import router as _rows_router -from ocr_pipeline_words import router as _words_router -from ocr_pipeline_ocr_merge import ( - router as _ocr_merge_router, - # Re-export for test backward compatibility - _split_paddle_multi_words, # noqa: F401 - _group_words_into_rows, # noqa: F401 - _merge_row_sequences, # noqa: F401 - _merge_paddle_tesseract, # noqa: F401 -) -from ocr_pipeline_postprocess import router as _postprocess_router -from ocr_pipeline_auto import router as _auto_router -from ocr_pipeline_regression import router as _regression_router - -# --------------------------------------------------------------------------- -# Composite router (used by main.py) -# --------------------------------------------------------------------------- -router = APIRouter() -router.include_router(_sessions_router) -router.include_router(_geometry_router) -router.include_router(_rows_router) -router.include_router(_words_router) -router.include_router(_ocr_merge_router) -router.include_router(_postprocess_router) -router.include_router(_auto_router) -router.include_router(_regression_router) +# Backward-compat shim -- module moved to ocr/pipeline/api.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = 
_importlib.import_module("ocr.pipeline.api") diff --git a/klausur-service/backend/ocr_pipeline_auto.py b/klausur-service/backend/ocr_pipeline_auto.py index f354659..3ef8168 100644 --- a/klausur-service/backend/ocr_pipeline_auto.py +++ b/klausur-service/backend/ocr_pipeline_auto.py @@ -1,23 +1,4 @@ -""" -OCR Pipeline Auto-Mode Orchestrator and Reprocess Endpoints — Barrel Re-export. - -Split into submodules: -- ocr_pipeline_reprocess.py — POST /sessions/{id}/reprocess -- ocr_pipeline_auto_steps.py — POST /sessions/{id}/run-auto + VLM helper - -Lizenz: Apache 2.0 -DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. -""" - -from fastapi import APIRouter - -from ocr_pipeline_reprocess import router as _reprocess_router -from ocr_pipeline_auto_steps import router as _steps_router - -# Combine both sub-routers into a single router for backwards compatibility. -# The consumer imports `from ocr_pipeline_auto import router as _auto_router`. -router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) -router.include_router(_reprocess_router) -router.include_router(_steps_router) - -__all__ = ["router"] +# Backward-compat shim -- module moved to ocr/pipeline/auto.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.auto") diff --git a/klausur-service/backend/ocr_pipeline_auto_helpers.py b/klausur-service/backend/ocr_pipeline_auto_helpers.py index 05df86d..306d2b3 100644 --- a/klausur-service/backend/ocr_pipeline_auto_helpers.py +++ b/klausur-service/backend/ocr_pipeline_auto_helpers.py @@ -1,84 +1,4 @@ -""" -OCR Pipeline Auto-Mode Helpers. - -VLM shear detection, SSE event formatting, and request models. - -Lizenz: Apache 2.0 -DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
-""" - -import json -import logging -import os -import re -from typing import Any, Dict - -from pydantic import BaseModel - -logger = logging.getLogger(__name__) - - -class RunAutoRequest(BaseModel): - from_step: int = 1 # 1=deskew, 2=dewarp, 3=columns, 4=rows, 5=words, 6=llm-review - ocr_engine: str = "auto" # "auto" | "rapid" | "tesseract" - pronunciation: str = "british" - skip_llm_review: bool = False - dewarp_method: str = "ensemble" # "ensemble" | "vlm" | "cv" - - -async def auto_sse_event(step: str, status: str, data: Dict[str, Any]) -> str: - """Format a single SSE event line.""" - payload = {"step": step, "status": status, **data} - return f"data: {json.dumps(payload)}\n\n" - - -async def detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]: - """Ask qwen2.5vl:32b to estimate the vertical shear angle of a scanned page. - - The VLM is shown the image and asked: are the column/table borders tilted? - If yes, by how many degrees? Returns a dict with shear_degrees and confidence. - Confidence is 0.0 if Ollama is unavailable or parsing fails. - """ - import httpx - import base64 - - ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434") - model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b") - - prompt = ( - "This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. " - "Are they perfectly vertical, or do they tilt slightly? " - "If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). " - "Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} " - "Use confidence 0.0-1.0 based on how clearly you can see the tilt. 
" - "If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}" - ) - - img_b64 = base64.b64encode(image_bytes).decode("utf-8") - payload = { - "model": model, - "prompt": prompt, - "images": [img_b64], - "stream": False, - } - - try: - async with httpx.AsyncClient(timeout=60.0) as client: - resp = await client.post(f"{ollama_base}/api/generate", json=payload) - resp.raise_for_status() - text = resp.json().get("response", "") - - # Parse JSON from response (may have surrounding text) - match = re.search(r'\{[^}]+\}', text) - if match: - data = json.loads(match.group(0)) - shear = float(data.get("shear_degrees", 0.0)) - conf = float(data.get("confidence", 0.0)) - # Clamp to reasonable range - shear = max(-3.0, min(3.0, shear)) - conf = max(0.0, min(1.0, conf)) - return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)} - except Exception as e: - logger.warning(f"VLM dewarp failed: {e}") - - return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0} +# Backward-compat shim -- module moved to ocr/pipeline/auto_helpers.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.auto_helpers") diff --git a/klausur-service/backend/ocr_pipeline_auto_steps.py b/klausur-service/backend/ocr_pipeline_auto_steps.py index 4961ee9..eff30b2 100644 --- a/klausur-service/backend/ocr_pipeline_auto_steps.py +++ b/klausur-service/backend/ocr_pipeline_auto_steps.py @@ -1,528 +1,4 @@ -""" -OCR Pipeline Auto-Mode Orchestrator. - -POST /sessions/{session_id}/run-auto -- full auto-mode with SSE streaming. - -Lizenz: Apache 2.0 -DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
-""" - -import logging -import time -from dataclasses import asdict -from typing import Any, Dict, List, Optional - -import cv2 -import numpy as np -from fastapi import APIRouter, HTTPException, Request -from fastapi.responses import StreamingResponse - -from cv_vocab_pipeline import ( - OLLAMA_REVIEW_MODEL, - PageRegion, - RowGeometry, - _cells_to_vocab_entries, - _detect_header_footer_gaps, - _detect_sub_columns, - _fix_character_confusion, - _fix_phonetic_brackets, - fix_cell_phonetics, - analyze_layout, - build_cell_grid, - classify_column_types, - create_layout_image, - create_ocr_image, - deskew_image, - deskew_image_by_word_alignment, - detect_column_geometry, - detect_row_geometry, - _apply_shear, - dewarp_image, - llm_review_entries, -) -from ocr_pipeline_common import ( - _cache, - _load_session_to_cache, - _get_cached, -) -from ocr_pipeline_session_store import ( - get_session_db, - update_session_db, -) -from ocr_pipeline_auto_helpers import ( - RunAutoRequest, - auto_sse_event as _auto_sse_event, - detect_shear_with_vlm as _detect_shear_with_vlm, -) - -logger = logging.getLogger(__name__) - -router = APIRouter(tags=["ocr-pipeline"]) - -@router.post("/sessions/{session_id}/run-auto") -async def run_auto(session_id: str, req: RunAutoRequest, request: Request): - """Run the full OCR pipeline automatically from a given step, streaming SSE progress. - - Steps: - 1. Deskew -- straighten the scan - 2. Dewarp -- correct vertical shear (ensemble CV or VLM) - 3. Columns -- detect column layout - 4. Rows -- detect row layout - 5. Words -- OCR each cell - 6. LLM review -- correct OCR errors (optional) - - Already-completed steps are skipped unless `from_step` forces a rerun. 
- Yields SSE events of the form: - data: {"step": "deskew", "status": "start"|"done"|"skipped"|"error", ...} - - Final event: - data: {"step": "complete", "status": "done", "steps_run": [...], "steps_skipped": [...]} - """ - if req.from_step < 1 or req.from_step > 6: - raise HTTPException(status_code=400, detail="from_step must be 1-6") - if req.dewarp_method not in ("ensemble", "vlm", "cv"): - raise HTTPException(status_code=400, detail="dewarp_method must be: ensemble, vlm, cv") - - if session_id not in _cache: - await _load_session_to_cache(session_id) - - async def _generate(): - steps_run: List[str] = [] - steps_skipped: List[str] = [] - error_step: Optional[str] = None - - session = await get_session_db(session_id) - if not session: - yield await _auto_sse_event("error", "error", {"message": f"Session {session_id} not found"}) - return - - cached = _get_cached(session_id) - - # Step 1: Deskew - if req.from_step <= 1: - yield await _auto_sse_event("deskew", "start", {}) - try: - t0 = time.time() - orig_bgr = cached.get("original_bgr") - if orig_bgr is None: - raise ValueError("Original image not loaded") - - try: - deskewed_hough, angle_hough = deskew_image(orig_bgr.copy()) - except Exception: - deskewed_hough, angle_hough = orig_bgr, 0.0 - - success_enc, png_orig = cv2.imencode(".png", orig_bgr) - orig_bytes = png_orig.tobytes() if success_enc else b"" - try: - deskewed_wa_bytes, angle_wa = deskew_image_by_word_alignment(orig_bytes) - except Exception: - deskewed_wa_bytes, angle_wa = orig_bytes, 0.0 - - if abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1: - method_used = "word_alignment" - angle_applied = angle_wa - wa_arr = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8) - deskewed_bgr = cv2.imdecode(wa_arr, cv2.IMREAD_COLOR) - if deskewed_bgr is None: - deskewed_bgr = deskewed_hough - method_used = "hough" - angle_applied = angle_hough - else: - method_used = "hough" - angle_applied = angle_hough - deskewed_bgr = deskewed_hough - - success, 
png_buf = cv2.imencode(".png", deskewed_bgr) - deskewed_png = png_buf.tobytes() if success else b"" - - deskew_result = { - "method_used": method_used, - "rotation_degrees": round(float(angle_applied), 3), - "duration_seconds": round(time.time() - t0, 2), - } - - cached["deskewed_bgr"] = deskewed_bgr - cached["deskew_result"] = deskew_result - await update_session_db( - session_id, - deskewed_png=deskewed_png, - deskew_result=deskew_result, - auto_rotation_degrees=float(angle_applied), - current_step=3, - ) - session = await get_session_db(session_id) - - steps_run.append("deskew") - yield await _auto_sse_event("deskew", "done", deskew_result) - except Exception as e: - logger.error(f"Auto-mode deskew failed for {session_id}: {e}") - error_step = "deskew" - yield await _auto_sse_event("deskew", "error", {"message": str(e)}) - yield await _auto_sse_event("complete", "error", {"error_step": error_step}) - return - else: - steps_skipped.append("deskew") - yield await _auto_sse_event("deskew", "skipped", {"reason": "from_step > 1"}) - - # Step 2: Dewarp - if req.from_step <= 2: - yield await _auto_sse_event("dewarp", "start", {"method": req.dewarp_method}) - try: - t0 = time.time() - deskewed_bgr = cached.get("deskewed_bgr") - if deskewed_bgr is None: - raise ValueError("Deskewed image not available") - - if req.dewarp_method == "vlm": - success_enc, png_buf = cv2.imencode(".png", deskewed_bgr) - img_bytes = png_buf.tobytes() if success_enc else b"" - vlm_det = await _detect_shear_with_vlm(img_bytes) - shear_deg = vlm_det["shear_degrees"] - if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3: - dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg) - else: - dewarped_bgr = deskewed_bgr - dewarp_info = { - "method": vlm_det["method"], - "shear_degrees": shear_deg, - "confidence": vlm_det["confidence"], - "detections": [vlm_det], - } - else: - dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr) - - success_enc, png_buf = cv2.imencode(".png", dewarped_bgr) - 
dewarped_png = png_buf.tobytes() if success_enc else b"" - - dewarp_result = { - "method_used": dewarp_info["method"], - "shear_degrees": dewarp_info["shear_degrees"], - "confidence": dewarp_info["confidence"], - "duration_seconds": round(time.time() - t0, 2), - "detections": dewarp_info.get("detections", []), - } - - cached["dewarped_bgr"] = dewarped_bgr - cached["dewarp_result"] = dewarp_result - await update_session_db( - session_id, - dewarped_png=dewarped_png, - dewarp_result=dewarp_result, - auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0), - current_step=4, - ) - session = await get_session_db(session_id) - - steps_run.append("dewarp") - yield await _auto_sse_event("dewarp", "done", dewarp_result) - except Exception as e: - logger.error(f"Auto-mode dewarp failed for {session_id}: {e}") - error_step = "dewarp" - yield await _auto_sse_event("dewarp", "error", {"message": str(e)}) - yield await _auto_sse_event("complete", "error", {"error_step": error_step}) - return - else: - steps_skipped.append("dewarp") - yield await _auto_sse_event("dewarp", "skipped", {"reason": "from_step > 2"}) - - # Step 3: Columns - if req.from_step <= 3: - yield await _auto_sse_event("columns", "start", {}) - try: - t0 = time.time() - col_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") - if col_img is None: - raise ValueError("Cropped/dewarped image not available") - - ocr_img = create_ocr_image(col_img) - h, w = ocr_img.shape[:2] - - geo_result = detect_column_geometry(ocr_img, col_img) - if geo_result is None: - layout_img = create_layout_image(col_img) - regions = analyze_layout(layout_img, ocr_img) - cached["_word_dicts"] = None - cached["_inv"] = None - cached["_content_bounds"] = None - else: - geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result - content_w = right_x - left_x - cached["_word_dicts"] = word_dicts - cached["_inv"] = inv - cached["_content_bounds"] = (left_x, right_x, top_y, 
bottom_y) - - header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None) - geometries = _detect_sub_columns(geometries, content_w, left_x=left_x, - top_y=top_y, header_y=header_y, footer_y=footer_y) - regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y, - left_x=left_x, right_x=right_x, inv=inv) - - columns = [asdict(r) for r in regions] - column_result = { - "columns": columns, - "classification_methods": list({c.get("classification_method", "") for c in columns if c.get("classification_method")}), - "duration_seconds": round(time.time() - t0, 2), - } - - cached["column_result"] = column_result - await update_session_db(session_id, column_result=column_result, - row_result=None, word_result=None, current_step=6) - session = await get_session_db(session_id) - - steps_run.append("columns") - yield await _auto_sse_event("columns", "done", { - "column_count": len(columns), - "duration_seconds": column_result["duration_seconds"], - }) - except Exception as e: - logger.error(f"Auto-mode columns failed for {session_id}: {e}") - error_step = "columns" - yield await _auto_sse_event("columns", "error", {"message": str(e)}) - yield await _auto_sse_event("complete", "error", {"error_step": error_step}) - return - else: - steps_skipped.append("columns") - yield await _auto_sse_event("columns", "skipped", {"reason": "from_step > 3"}) - - # Step 4: Rows - if req.from_step <= 4: - yield await _auto_sse_event("rows", "start", {}) - try: - t0 = time.time() - row_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") - session = await get_session_db(session_id) - column_result = session.get("column_result") or cached.get("column_result") - if not column_result or not column_result.get("columns"): - raise ValueError("Column detection must complete first") - - col_regions = [ - PageRegion( - type=c["type"], x=c["x"], y=c["y"], - width=c["width"], height=c["height"], - 
classification_confidence=c.get("classification_confidence", 1.0), - classification_method=c.get("classification_method", ""), - ) - for c in column_result["columns"] - ] - - word_dicts = cached.get("_word_dicts") - inv = cached.get("_inv") - content_bounds = cached.get("_content_bounds") - - if word_dicts is None or inv is None or content_bounds is None: - ocr_img_tmp = create_ocr_image(row_img) - geo_result = detect_column_geometry(ocr_img_tmp, row_img) - if geo_result is None: - raise ValueError("Column geometry detection failed -- cannot detect rows") - _g, lx, rx, ty, by, word_dicts, inv = geo_result - cached["_word_dicts"] = word_dicts - cached["_inv"] = inv - cached["_content_bounds"] = (lx, rx, ty, by) - content_bounds = (lx, rx, ty, by) - - left_x, right_x, top_y, bottom_y = content_bounds - row_geoms = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y) - - row_list = [ - { - "index": r.index, "x": r.x, "y": r.y, - "width": r.width, "height": r.height, - "word_count": r.word_count, - "row_type": r.row_type, - "gap_before": r.gap_before, - } - for r in row_geoms - ] - row_result = { - "rows": row_list, - "row_count": len(row_list), - "content_rows": len([r for r in row_geoms if r.row_type == "content"]), - "duration_seconds": round(time.time() - t0, 2), - } - - cached["row_result"] = row_result - await update_session_db(session_id, row_result=row_result, current_step=7) - session = await get_session_db(session_id) - - steps_run.append("rows") - yield await _auto_sse_event("rows", "done", { - "row_count": len(row_list), - "content_rows": row_result["content_rows"], - "duration_seconds": row_result["duration_seconds"], - }) - except Exception as e: - logger.error(f"Auto-mode rows failed for {session_id}: {e}") - error_step = "rows" - yield await _auto_sse_event("rows", "error", {"message": str(e)}) - yield await _auto_sse_event("complete", "error", {"error_step": error_step}) - return - else: - steps_skipped.append("rows") - yield await 
_auto_sse_event("rows", "skipped", {"reason": "from_step > 4"}) - - # Step 5: Words (OCR) - if req.from_step <= 5: - yield await _auto_sse_event("words", "start", {"engine": req.ocr_engine}) - try: - t0 = time.time() - word_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") - session = await get_session_db(session_id) - - column_result = session.get("column_result") or cached.get("column_result") - row_result = session.get("row_result") or cached.get("row_result") - - col_regions = [ - PageRegion( - type=c["type"], x=c["x"], y=c["y"], - width=c["width"], height=c["height"], - classification_confidence=c.get("classification_confidence", 1.0), - classification_method=c.get("classification_method", ""), - ) - for c in column_result["columns"] - ] - row_geoms = [ - RowGeometry( - index=r["index"], x=r["x"], y=r["y"], - width=r["width"], height=r["height"], - word_count=r.get("word_count", 0), words=[], - row_type=r.get("row_type", "content"), - gap_before=r.get("gap_before", 0), - ) - for r in row_result["rows"] - ] - - word_dicts = cached.get("_word_dicts") - if word_dicts is not None: - content_bounds = cached.get("_content_bounds") - top_y = content_bounds[2] if content_bounds else min(r.y for r in row_geoms) - for row in row_geoms: - row_y_rel = row.y - top_y - row_bottom_rel = row_y_rel + row.height - row.words = [ - w for w in word_dicts - if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel - ] - row.word_count = len(row.words) - - ocr_img = create_ocr_image(word_img) - img_h, img_w = word_img.shape[:2] - - cells, columns_meta = build_cell_grid( - ocr_img, col_regions, row_geoms, img_w, img_h, - ocr_engine=req.ocr_engine, img_bgr=word_img, - ) - duration = time.time() - t0 - - col_types = {c['type'] for c in columns_meta} - is_vocab = bool(col_types & {'column_en', 'column_de'}) - n_content_rows = len([r for r in row_geoms if r.row_type == 'content']) - used_engine = cells[0].get("ocr_engine", 
"tesseract") if cells else req.ocr_engine - - fix_cell_phonetics(cells, pronunciation=req.pronunciation) - - word_result_data = { - "cells": cells, - "grid_shape": { - "rows": n_content_rows, - "cols": len(columns_meta), - "total_cells": len(cells), - }, - "columns_used": columns_meta, - "layout": "vocab" if is_vocab else "generic", - "image_width": img_w, - "image_height": img_h, - "duration_seconds": round(duration, 2), - "ocr_engine": used_engine, - "summary": { - "total_cells": len(cells), - "non_empty_cells": sum(1 for c in cells if c.get("text")), - "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), - }, - } - - has_text_col = 'column_text' in col_types - if is_vocab or has_text_col: - entries = _cells_to_vocab_entries(cells, columns_meta) - entries = _fix_character_confusion(entries) - entries = _fix_phonetic_brackets(entries, pronunciation=req.pronunciation) - word_result_data["vocab_entries"] = entries - word_result_data["entries"] = entries - word_result_data["entry_count"] = len(entries) - word_result_data["summary"]["total_entries"] = len(entries) - - await update_session_db(session_id, word_result=word_result_data, current_step=8) - cached["word_result"] = word_result_data - session = await get_session_db(session_id) - - steps_run.append("words") - yield await _auto_sse_event("words", "done", { - "total_cells": len(cells), - "layout": word_result_data["layout"], - "duration_seconds": round(duration, 2), - "ocr_engine": used_engine, - "summary": word_result_data["summary"], - }) - except Exception as e: - logger.error(f"Auto-mode words failed for {session_id}: {e}") - error_step = "words" - yield await _auto_sse_event("words", "error", {"message": str(e)}) - yield await _auto_sse_event("complete", "error", {"error_step": error_step}) - return - else: - steps_skipped.append("words") - yield await _auto_sse_event("words", "skipped", {"reason": "from_step > 5"}) - - # Step 6: LLM Review (optional) - if req.from_step <= 6 and not 
req.skip_llm_review: - yield await _auto_sse_event("llm_review", "start", {"model": OLLAMA_REVIEW_MODEL}) - try: - session = await get_session_db(session_id) - word_result = session.get("word_result") or cached.get("word_result") - entries = word_result.get("entries") or word_result.get("vocab_entries") or [] - - if not entries: - yield await _auto_sse_event("llm_review", "skipped", {"reason": "no entries"}) - steps_skipped.append("llm_review") - else: - reviewed = await llm_review_entries(entries) - - session = await get_session_db(session_id) - word_result_updated = dict(session.get("word_result") or {}) - word_result_updated["entries"] = reviewed - word_result_updated["vocab_entries"] = reviewed - word_result_updated["llm_reviewed"] = True - word_result_updated["llm_model"] = OLLAMA_REVIEW_MODEL - - await update_session_db(session_id, word_result=word_result_updated, current_step=9) - cached["word_result"] = word_result_updated - - steps_run.append("llm_review") - yield await _auto_sse_event("llm_review", "done", { - "entries_reviewed": len(reviewed), - "model": OLLAMA_REVIEW_MODEL, - }) - except Exception as e: - logger.warning(f"Auto-mode llm_review failed for {session_id} (non-fatal): {e}") - yield await _auto_sse_event("llm_review", "error", {"message": str(e), "fatal": False}) - steps_skipped.append("llm_review") - else: - steps_skipped.append("llm_review") - reason = "skipped by request" if req.skip_llm_review else "from_step > 6" - yield await _auto_sse_event("llm_review", "skipped", {"reason": reason}) - - # Final event - yield await _auto_sse_event("complete", "done", { - "steps_run": steps_run, - "steps_skipped": steps_skipped, - }) - - return StreamingResponse( - _generate(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) +# Backward-compat shim -- module moved to ocr/pipeline/auto_steps.py +import importlib as _importlib +import sys as _sys 
+_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.auto_steps") diff --git a/klausur-service/backend/ocr_pipeline_columns.py b/klausur-service/backend/ocr_pipeline_columns.py index 6572509..6a9d510 100644 --- a/klausur-service/backend/ocr_pipeline_columns.py +++ b/klausur-service/backend/ocr_pipeline_columns.py @@ -1,293 +1,4 @@ -""" -OCR Pipeline Column Detection Endpoints (Step 5) - -Detect invisible columns, manual column override, and ground truth. -Extracted from ocr_pipeline_geometry.py for file-size compliance. -""" - -import logging -import time -from dataclasses import asdict -from datetime import datetime -from typing import Dict, List - -import cv2 -from fastapi import APIRouter, HTTPException - -from cv_vocab_pipeline import ( - _detect_header_footer_gaps, - _detect_sub_columns, - classify_column_types, - create_layout_image, - create_ocr_image, - analyze_layout, - detect_column_geometry_zoned, - expand_narrow_columns, -) -from ocr_pipeline_session_store import ( - get_session_db, - update_session_db, -) -from ocr_pipeline_common import ( - _cache, - _load_session_to_cache, - _get_cached, - _append_pipeline_log, - ManualColumnsRequest, - ColumnGroundTruthRequest, -) - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) - - -@router.post("/sessions/{session_id}/columns") -async def detect_columns(session_id: str): - """Run column detection on the cropped (or dewarped) image.""" - if session_id not in _cache: - await _load_session_to_cache(session_id) - cached = _get_cached(session_id) - - img_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") - if img_bgr is None: - raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before column detection") - - # ----------------------------------------------------------------------- - # Sub-sessions (box crops): skip column detection entirely. 
- # Instead, create a single pseudo-column spanning the full image width. - # Also run Tesseract + binarization here so that the row detection step - # can reuse the cached intermediates (_word_dicts, _inv, _content_bounds) - # instead of falling back to detect_column_geometry() which may fail - # on small box images with < 5 words. - # ----------------------------------------------------------------------- - session = await get_session_db(session_id) - if session and session.get("parent_session_id"): - h, w = img_bgr.shape[:2] - - # Binarize + invert for row detection (horizontal projection profile) - ocr_img = create_ocr_image(img_bgr) - inv = cv2.bitwise_not(ocr_img) - - # Run Tesseract to get word bounding boxes. - try: - from PIL import Image as PILImage - pil_img = PILImage.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)) - import pytesseract - data = pytesseract.image_to_data(pil_img, lang='eng+deu', output_type=pytesseract.Output.DICT) - word_dicts = [] - for i in range(len(data['text'])): - conf = int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1 - text = str(data['text'][i]).strip() - if conf < 30 or not text: - continue - word_dicts.append({ - 'text': text, 'conf': conf, - 'left': int(data['left'][i]), - 'top': int(data['top'][i]), - 'width': int(data['width'][i]), - 'height': int(data['height'][i]), - }) - # Log all words including low-confidence ones for debugging - all_count = sum(1 for i in range(len(data['text'])) - if str(data['text'][i]).strip()) - low_conf = [(str(data['text'][i]).strip(), int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1) - for i in range(len(data['text'])) - if str(data['text'][i]).strip() - and (int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1) < 30 - and (int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1) >= 0] - if low_conf: - logger.info(f"OCR Pipeline: sub-session {session_id}: {len(low_conf)} words below conf 30: 
{low_conf[:20]}") - logger.info(f"OCR Pipeline: sub-session {session_id}: Tesseract found {len(word_dicts)}/{all_count} words (conf>=30)") - except Exception as e: - logger.warning(f"OCR Pipeline: sub-session {session_id}: Tesseract failed: {e}") - word_dicts = [] - - # Cache intermediates for row detection (detect_rows reuses these) - cached["_word_dicts"] = word_dicts - cached["_inv"] = inv - cached["_content_bounds"] = (0, w, 0, h) - - column_result = { - "columns": [{ - "type": "column_text", - "x": 0, "y": 0, - "width": w, "height": h, - }], - "zones": None, - "boxes_detected": 0, - "duration_seconds": 0, - "method": "sub_session_pseudo_column", - } - await update_session_db( - session_id, - column_result=column_result, - row_result=None, - word_result=None, - current_step=6, - ) - cached["column_result"] = column_result - cached.pop("row_result", None) - cached.pop("word_result", None) - logger.info(f"OCR Pipeline: sub-session {session_id}: pseudo-column {w}x{h}px") - return {"session_id": session_id, **column_result} - - t0 = time.time() - - # Binarized image for layout analysis - ocr_img = create_ocr_image(img_bgr) - h, w = ocr_img.shape[:2] - - # Phase A: Zone-aware geometry detection - zoned_result = detect_column_geometry_zoned(ocr_img, img_bgr) - - boxes_detected = 0 - if zoned_result is None: - # Fallback to projection-based layout - layout_img = create_layout_image(img_bgr) - regions = analyze_layout(layout_img, ocr_img) - zones_data = None - else: - geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv, zones_data, boxes = zoned_result - content_w = right_x - left_x - boxes_detected = len(boxes) - - # Cache intermediates for row detection (avoids second Tesseract run) - cached["_word_dicts"] = word_dicts - cached["_inv"] = inv - cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y) - cached["_zones_data"] = zones_data - cached["_boxes_detected"] = boxes_detected - - # Detect header/footer early so sub-column clustering ignores them - 
header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None) - - # Split sub-columns (e.g. page references) before classification - geometries = _detect_sub_columns(geometries, content_w, left_x=left_x, - top_y=top_y, header_y=header_y, footer_y=footer_y) - - # Expand narrow columns (sub-columns are often very narrow) - geometries = expand_narrow_columns(geometries, content_w, left_x, word_dicts) - - # Phase B: Content-based classification - regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y, - left_x=left_x, right_x=right_x, inv=inv) - - duration = time.time() - t0 - - columns = [asdict(r) for r in regions] - - # Determine classification methods used - methods = list(set( - c.get("classification_method", "") for c in columns - if c.get("classification_method") - )) - - column_result = { - "columns": columns, - "classification_methods": methods, - "duration_seconds": round(duration, 2), - "boxes_detected": boxes_detected, - } - - # Add zone data when boxes are present - if zones_data and boxes_detected > 0: - column_result["zones"] = zones_data - - # Persist to DB -- also invalidate downstream results (rows, words) - await update_session_db( - session_id, - column_result=column_result, - row_result=None, - word_result=None, - current_step=6, - ) - - # Update cache - cached["column_result"] = column_result - cached.pop("row_result", None) - cached.pop("word_result", None) - - col_count = len([c for c in columns if c["type"].startswith("column")]) - logger.info(f"OCR Pipeline: columns session {session_id}: " - f"{col_count} columns detected, {boxes_detected} box(es) ({duration:.2f}s)") - - img_w = img_bgr.shape[1] - await _append_pipeline_log(session_id, "columns", { - "total_columns": len(columns), - "column_widths_pct": [round(c["width"] / img_w * 100, 1) for c in columns], - "column_types": [c["type"] for c in columns], - "boxes_detected": boxes_detected, - }, duration_ms=int(duration * 1000)) - - return 
{ - "session_id": session_id, - **column_result, - } - - -@router.post("/sessions/{session_id}/columns/manual") -async def set_manual_columns(session_id: str, req: ManualColumnsRequest): - """Override detected columns with manual definitions.""" - column_result = { - "columns": req.columns, - "duration_seconds": 0, - "method": "manual", - } - - await update_session_db(session_id, column_result=column_result, - row_result=None, word_result=None) - - if session_id in _cache: - _cache[session_id]["column_result"] = column_result - _cache[session_id].pop("row_result", None) - _cache[session_id].pop("word_result", None) - - logger.info(f"OCR Pipeline: manual columns session {session_id}: " - f"{len(req.columns)} columns set") - - return {"session_id": session_id, **column_result} - - -@router.post("/sessions/{session_id}/ground-truth/columns") -async def save_column_ground_truth(session_id: str, req: ColumnGroundTruthRequest): - """Save ground truth feedback for the column detection step.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - ground_truth = session.get("ground_truth") or {} - gt = { - "is_correct": req.is_correct, - "corrected_columns": req.corrected_columns, - "notes": req.notes, - "saved_at": datetime.utcnow().isoformat(), - "column_result": session.get("column_result"), - } - ground_truth["columns"] = gt - - await update_session_db(session_id, ground_truth=ground_truth) - - if session_id in _cache: - _cache[session_id]["ground_truth"] = ground_truth - - return {"session_id": session_id, "ground_truth": gt} - - -@router.get("/sessions/{session_id}/ground-truth/columns") -async def get_column_ground_truth(session_id: str): - """Retrieve saved ground truth for column detection, including auto vs GT diff.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - 
ground_truth = session.get("ground_truth") or {} - columns_gt = ground_truth.get("columns") - if not columns_gt: - raise HTTPException(status_code=404, detail="No column ground truth saved") - - return { - "session_id": session_id, - "columns_gt": columns_gt, - "columns_auto": session.get("column_result"), - } +# Backward-compat shim -- module moved to ocr/pipeline/columns.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.columns") diff --git a/klausur-service/backend/ocr_pipeline_common.py b/klausur-service/backend/ocr_pipeline_common.py index 789df26..4eacaa4 100644 --- a/klausur-service/backend/ocr_pipeline_common.py +++ b/klausur-service/backend/ocr_pipeline_common.py @@ -1,354 +1,4 @@ -""" -Shared common module for the OCR pipeline. - -Contains in-memory cache, helper functions, Pydantic request models, -pipeline logging, and border-ghost word filtering used by the pipeline -API endpoints and related modules. -""" - -import logging -import re -import time -from datetime import datetime -from typing import Any, Dict, List, Optional - -import cv2 -import numpy as np -from fastapi import HTTPException -from pydantic import BaseModel - -from ocr_pipeline_session_store import get_session_db, get_session_image, update_session_db - -__all__ = [ - # Cache - "_cache", - # Helper functions - "_get_base_image_png", - "_load_session_to_cache", - "_get_cached", - # Pydantic models - "ManualDeskewRequest", - "DeskewGroundTruthRequest", - "ManualDewarpRequest", - "CombinedAdjustRequest", - "DewarpGroundTruthRequest", - "VALID_DOCUMENT_CATEGORIES", - "UpdateSessionRequest", - "ManualColumnsRequest", - "ColumnGroundTruthRequest", - "ManualRowsRequest", - "RowGroundTruthRequest", - "RemoveHandwritingRequest", - # Pipeline log - "_append_pipeline_log", - # Border-ghost filter - "_BORDER_GHOST_CHARS", - "_filter_border_ghost_words", -] - -logger = logging.getLogger(__name__) - -# 
--------------------------------------------------------------------------- -# In-memory cache for active sessions (BGR numpy arrays for processing) -# DB is source of truth, cache holds BGR arrays during active processing. -# --------------------------------------------------------------------------- - -_cache: Dict[str, Dict[str, Any]] = {} - - -async def _get_base_image_png(session_id: str) -> Optional[bytes]: - """Get the best available base image for a session (cropped > dewarped > original).""" - for img_type in ("cropped", "dewarped", "original"): - png_data = await get_session_image(session_id, img_type) - if png_data: - return png_data - return None - - -async def _load_session_to_cache(session_id: str) -> Dict[str, Any]: - """Load session from DB into cache, decoding PNGs to BGR arrays.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - if session_id in _cache: - return _cache[session_id] - - cache_entry: Dict[str, Any] = { - "id": session_id, - **session, - "original_bgr": None, - "oriented_bgr": None, - "cropped_bgr": None, - "deskewed_bgr": None, - "dewarped_bgr": None, - } - - # Decode images from DB into BGR numpy arrays - for img_type, bgr_key in [ - ("original", "original_bgr"), - ("oriented", "oriented_bgr"), - ("cropped", "cropped_bgr"), - ("deskewed", "deskewed_bgr"), - ("dewarped", "dewarped_bgr"), - ]: - png_data = await get_session_image(session_id, img_type) - if png_data: - arr = np.frombuffer(png_data, dtype=np.uint8) - bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR) - cache_entry[bgr_key] = bgr - - # Sub-sessions: original image IS the cropped box region. - # Promote original_bgr to cropped_bgr so downstream steps find it. 
- if session.get("parent_session_id") and cache_entry["original_bgr"] is not None: - if cache_entry["cropped_bgr"] is None and cache_entry["dewarped_bgr"] is None: - cache_entry["cropped_bgr"] = cache_entry["original_bgr"] - - _cache[session_id] = cache_entry - return cache_entry - - -def _get_cached(session_id: str) -> Dict[str, Any]: - """Get from cache or raise 404.""" - entry = _cache.get(session_id) - if not entry: - raise HTTPException(status_code=404, detail=f"Session {session_id} not in cache — reload first") - return entry - - -# --------------------------------------------------------------------------- -# Pydantic Models -# --------------------------------------------------------------------------- - -class ManualDeskewRequest(BaseModel): - angle: float - - -class DeskewGroundTruthRequest(BaseModel): - is_correct: bool - corrected_angle: Optional[float] = None - notes: Optional[str] = None - - -class ManualDewarpRequest(BaseModel): - shear_degrees: float - - -class CombinedAdjustRequest(BaseModel): - rotation_degrees: float = 0.0 - shear_degrees: float = 0.0 - - -class DewarpGroundTruthRequest(BaseModel): - is_correct: bool - corrected_shear: Optional[float] = None - notes: Optional[str] = None - - -VALID_DOCUMENT_CATEGORIES = { - 'vokabelseite', 'woerterbuch', 'buchseite', 'arbeitsblatt', 'klausurseite', - 'mathearbeit', 'statistik', 'zeitung', 'formular', 'handschrift', 'sonstiges', -} - - -class UpdateSessionRequest(BaseModel): - name: Optional[str] = None - document_category: Optional[str] = None - - -class ManualColumnsRequest(BaseModel): - columns: List[Dict[str, Any]] - - -class ColumnGroundTruthRequest(BaseModel): - is_correct: bool - corrected_columns: Optional[List[Dict[str, Any]]] = None - notes: Optional[str] = None - - -class ManualRowsRequest(BaseModel): - rows: List[Dict[str, Any]] - - -class RowGroundTruthRequest(BaseModel): - is_correct: bool - corrected_rows: Optional[List[Dict[str, Any]]] = None - notes: Optional[str] = None - - -class 
RemoveHandwritingRequest(BaseModel): - method: str = "auto" # "auto" | "telea" | "ns" - target_ink: str = "all" # "all" | "colored" | "pencil" - dilation: int = 2 # mask dilation iterations (0-5) - use_source: str = "auto" # "original" | "deskewed" | "auto" - - -# --------------------------------------------------------------------------- -# Pipeline Log Helper -# --------------------------------------------------------------------------- - -async def _append_pipeline_log( - session_id: str, - step_name: str, - metrics: Dict[str, Any], - success: bool = True, - duration_ms: Optional[int] = None, -): - """Append a step entry to the session's pipeline_log JSONB.""" - session = await get_session_db(session_id) - if not session: - return - log = session.get("pipeline_log") or {"steps": []} - if not isinstance(log, dict): - log = {"steps": []} - entry = { - "step": step_name, - "completed_at": datetime.utcnow().isoformat(), - "success": success, - "metrics": metrics, - } - if duration_ms is not None: - entry["duration_ms"] = duration_ms - log.setdefault("steps", []).append(entry) - await update_session_db(session_id, pipeline_log=log) - - -# --------------------------------------------------------------------------- -# Border-ghost word filter -# --------------------------------------------------------------------------- - -# Characters that OCR produces when reading box-border lines. -_BORDER_GHOST_CHARS = set("|1lI![](){}iíì/\\-—–_~.,;:'\"") - - -def _filter_border_ghost_words( - word_result: Dict, - boxes: List, -) -> int: - """Remove OCR words that are actually box border lines. - - A word is considered a border ghost when it sits on a known box edge - (left, right, top, or bottom) and looks like a line artefact (narrow - aspect ratio or text consists only of line-like characters). - - After removing ghost cells, columns that have become empty are also - removed from ``columns_used`` so the grid no longer shows phantom - columns. 
- - Modifies *word_result* in-place and returns the number of removed cells. - """ - if not boxes or not word_result: - return 0 - - cells = word_result.get("cells") - if not cells: - return 0 - - # Build border bands — vertical (X) and horizontal (Y) - x_bands = [] # list of (x_lo, x_hi) - y_bands = [] # list of (y_lo, y_hi) - for b in boxes: - bx = b.x if hasattr(b, "x") else b.get("x", 0) - by = b.y if hasattr(b, "y") else b.get("y", 0) - bw = b.width if hasattr(b, "width") else b.get("w", b.get("width", 0)) - bh = b.height if hasattr(b, "height") else b.get("h", b.get("height", 0)) - bt = b.border_thickness if hasattr(b, "border_thickness") else b.get("border_thickness", 3) - margin = max(bt * 2, 10) + 6 # generous margin - - # Vertical edges (left / right) - x_bands.append((bx - margin, bx + margin)) - x_bands.append((bx + bw - margin, bx + bw + margin)) - # Horizontal edges (top / bottom) - y_bands.append((by - margin, by + margin)) - y_bands.append((by + bh - margin, by + bh + margin)) - - img_w = word_result.get("image_width", 1) - img_h = word_result.get("image_height", 1) - - def _is_ghost(cell: Dict) -> bool: - text = (cell.get("text") or "").strip() - if not text: - return False - - # Compute absolute pixel position - if cell.get("bbox_px"): - px = cell["bbox_px"] - cx = px["x"] + px["w"] / 2 - cy = px["y"] + px["h"] / 2 - cw = px["w"] - ch = px["h"] - elif cell.get("bbox_pct"): - pct = cell["bbox_pct"] - cx = (pct["x"] / 100) * img_w + (pct["w"] / 100) * img_w / 2 - cy = (pct["y"] / 100) * img_h + (pct["h"] / 100) * img_h / 2 - cw = (pct["w"] / 100) * img_w - ch = (pct["h"] / 100) * img_h - else: - return False - - # Check if center sits on a vertical or horizontal border - on_vertical = any(lo <= cx <= hi for lo, hi in x_bands) - on_horizontal = any(lo <= cy <= hi for lo, hi in y_bands) - if not on_vertical and not on_horizontal: - return False - - # Very short text (1-2 chars) on a border → very likely ghost - if len(text) <= 2: - # Narrow vertically 
(line-like) or narrow horizontally (dash-like)? - if ch > 0 and cw / ch < 0.5: - return True - if cw > 0 and ch / cw < 0.5: - return True - # Text is only border-ghost characters? - if all(c in _BORDER_GHOST_CHARS for c in text): - return True - - # Longer text but still only ghost chars and very narrow - if all(c in _BORDER_GHOST_CHARS for c in text): - if ch > 0 and cw / ch < 0.35: - return True - if cw > 0 and ch / cw < 0.35: - return True - return True # all ghost chars on a border → remove - - return False - - before = len(cells) - word_result["cells"] = [c for c in cells if not _is_ghost(c)] - removed = before - len(word_result["cells"]) - - # --- Remove empty columns from columns_used --- - columns_used = word_result.get("columns_used") - if removed and columns_used and len(columns_used) > 1: - remaining_cells = word_result["cells"] - occupied_cols = {c.get("col_index") for c in remaining_cells} - before_cols = len(columns_used) - columns_used = [col for col in columns_used if col.get("index") in occupied_cols] - - # Re-index columns and remap cell col_index values - if len(columns_used) < before_cols: - old_to_new = {} - for new_i, col in enumerate(columns_used): - old_to_new[col["index"]] = new_i - col["index"] = new_i - for cell in remaining_cells: - old_ci = cell.get("col_index") - if old_ci in old_to_new: - cell["col_index"] = old_to_new[old_ci] - word_result["columns_used"] = columns_used - logger.info("border-ghost: removed %d empty column(s), %d remaining", - before_cols - len(columns_used), len(columns_used)) - - if removed: - # Update summary counts - summary = word_result.get("summary", {}) - summary["total_cells"] = len(word_result["cells"]) - summary["non_empty_cells"] = sum(1 for c in word_result["cells"] if c.get("text")) - word_result["summary"] = summary - gs = word_result.get("grid_shape", {}) - gs["total_cells"] = len(word_result["cells"]) - if columns_used is not None: - gs["cols"] = len(columns_used) - word_result["grid_shape"] = gs - - 
return removed +# Backward-compat shim -- module moved to ocr/pipeline/common.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.common") diff --git a/klausur-service/backend/ocr_pipeline_deskew.py b/klausur-service/backend/ocr_pipeline_deskew.py index 07dc270..c15dd4b 100644 --- a/klausur-service/backend/ocr_pipeline_deskew.py +++ b/klausur-service/backend/ocr_pipeline_deskew.py @@ -1,236 +1,4 @@ -""" -OCR Pipeline Deskew Endpoints (Step 2) - -Auto deskew, manual deskew, and ground truth for the deskew step. -Extracted from ocr_pipeline_geometry.py for file-size compliance. -""" - -import logging -import time -from datetime import datetime - -import cv2 -from fastapi import APIRouter, HTTPException - -from cv_vocab_pipeline import ( - create_ocr_image, - deskew_image, - deskew_image_by_word_alignment, - deskew_two_pass, -) -from ocr_pipeline_session_store import ( - get_session_db, - update_session_db, -) -from ocr_pipeline_common import ( - _cache, - _load_session_to_cache, - _get_cached, - _append_pipeline_log, - ManualDeskewRequest, - DeskewGroundTruthRequest, -) - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) - - -@router.post("/sessions/{session_id}/deskew") -async def auto_deskew(session_id: str): - """Two-pass deskew: iterative projection (wide range) + word-alignment residual.""" - # Ensure session is in cache - if session_id not in _cache: - await _load_session_to_cache(session_id) - cached = _get_cached(session_id) - - # Deskew runs right after orientation -- use oriented image, fall back to original - img_bgr = next((v for k in ("oriented_bgr", "original_bgr") - if (v := cached.get(k)) is not None), None) - if img_bgr is None: - raise HTTPException(status_code=400, detail="No image available for deskewing") - - t0 = time.time() - - # Two-pass deskew: iterative (+-5 deg) + word-alignment residual check - deskewed_bgr, 
angle_applied, two_pass_debug = deskew_two_pass(img_bgr.copy()) - - # Also run individual methods for reporting (non-authoritative) - try: - _, angle_hough = deskew_image(img_bgr.copy()) - except Exception: - angle_hough = 0.0 - - success_enc, png_orig = cv2.imencode(".png", img_bgr) - orig_bytes = png_orig.tobytes() if success_enc else b"" - try: - _, angle_wa = deskew_image_by_word_alignment(orig_bytes) - except Exception: - angle_wa = 0.0 - - angle_iterative = two_pass_debug.get("pass1_angle", 0.0) - angle_residual = two_pass_debug.get("pass2_angle", 0.0) - angle_textline = two_pass_debug.get("pass3_angle", 0.0) - - duration = time.time() - t0 - - method_used = "three_pass" if abs(angle_textline) >= 0.01 else ( - "two_pass" if abs(angle_residual) >= 0.01 else "iterative" - ) - - # Encode as PNG - success, deskewed_png_buf = cv2.imencode(".png", deskewed_bgr) - deskewed_png = deskewed_png_buf.tobytes() if success else b"" - - # Create binarized version - binarized_png = None - try: - binarized = create_ocr_image(deskewed_bgr) - success_bin, bin_buf = cv2.imencode(".png", binarized) - binarized_png = bin_buf.tobytes() if success_bin else None - except Exception as e: - logger.warning(f"Binarization failed: {e}") - - confidence = max(0.5, 1.0 - abs(angle_applied) / 5.0) - - deskew_result = { - "angle_hough": round(angle_hough, 3), - "angle_word_alignment": round(angle_wa, 3), - "angle_iterative": round(angle_iterative, 3), - "angle_residual": round(angle_residual, 3), - "angle_textline": round(angle_textline, 3), - "angle_applied": round(angle_applied, 3), - "method_used": method_used, - "confidence": round(confidence, 2), - "duration_seconds": round(duration, 2), - "two_pass_debug": two_pass_debug, - } - - # Update cache - cached["deskewed_bgr"] = deskewed_bgr - cached["binarized_png"] = binarized_png - cached["deskew_result"] = deskew_result - - # Persist to DB - db_update = { - "deskewed_png": deskewed_png, - "deskew_result": deskew_result, - "current_step": 3, 
- } - if binarized_png: - db_update["binarized_png"] = binarized_png - await update_session_db(session_id, **db_update) - - logger.info(f"OCR Pipeline: deskew session {session_id}: " - f"hough={angle_hough:.2f} wa={angle_wa:.2f} " - f"iter={angle_iterative:.2f} residual={angle_residual:.2f} " - f"textline={angle_textline:.2f} " - f"-> {method_used} total={angle_applied:.2f}") - - await _append_pipeline_log(session_id, "deskew", { - "angle_applied": round(angle_applied, 3), - "angle_iterative": round(angle_iterative, 3), - "angle_residual": round(angle_residual, 3), - "angle_textline": round(angle_textline, 3), - "confidence": round(confidence, 2), - "method": method_used, - }, duration_ms=int(duration * 1000)) - - return { - "session_id": session_id, - **deskew_result, - "deskewed_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed", - "binarized_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/binarized", - } - - -@router.post("/sessions/{session_id}/deskew/manual") -async def manual_deskew(session_id: str, req: ManualDeskewRequest): - """Apply a manual rotation angle to the oriented image.""" - if session_id not in _cache: - await _load_session_to_cache(session_id) - cached = _get_cached(session_id) - - img_bgr = next((v for k in ("oriented_bgr", "original_bgr") - if (v := cached.get(k)) is not None), None) - if img_bgr is None: - raise HTTPException(status_code=400, detail="No image available for deskewing") - - angle = max(-5.0, min(5.0, req.angle)) - - h, w = img_bgr.shape[:2] - center = (w // 2, h // 2) - M = cv2.getRotationMatrix2D(center, angle, 1.0) - rotated = cv2.warpAffine(img_bgr, M, (w, h), - flags=cv2.INTER_LINEAR, - borderMode=cv2.BORDER_REPLICATE) - - success, png_buf = cv2.imencode(".png", rotated) - deskewed_png = png_buf.tobytes() if success else b"" - - # Binarize - binarized_png = None - try: - binarized = create_ocr_image(rotated) - success_bin, bin_buf = cv2.imencode(".png", binarized) - binarized_png = 
bin_buf.tobytes() if success_bin else None - except Exception: - pass - - deskew_result = { - **(cached.get("deskew_result") or {}), - "angle_applied": round(angle, 3), - "method_used": "manual", - } - - # Update cache - cached["deskewed_bgr"] = rotated - cached["binarized_png"] = binarized_png - cached["deskew_result"] = deskew_result - - # Persist to DB - db_update = { - "deskewed_png": deskewed_png, - "deskew_result": deskew_result, - } - if binarized_png: - db_update["binarized_png"] = binarized_png - await update_session_db(session_id, **db_update) - - logger.info(f"OCR Pipeline: manual deskew session {session_id}: {angle:.2f}") - - return { - "session_id": session_id, - "angle_applied": round(angle, 3), - "method_used": "manual", - "deskewed_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed", - } - - -@router.post("/sessions/{session_id}/ground-truth/deskew") -async def save_deskew_ground_truth(session_id: str, req: DeskewGroundTruthRequest): - """Save ground truth feedback for the deskew step.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - ground_truth = session.get("ground_truth") or {} - gt = { - "is_correct": req.is_correct, - "corrected_angle": req.corrected_angle, - "notes": req.notes, - "saved_at": datetime.utcnow().isoformat(), - "deskew_result": session.get("deskew_result"), - } - ground_truth["deskew"] = gt - - await update_session_db(session_id, ground_truth=ground_truth) - - # Update cache - if session_id in _cache: - _cache[session_id]["ground_truth"] = ground_truth - - logger.info(f"OCR Pipeline: ground truth deskew session {session_id}: " - f"correct={req.is_correct}, corrected_angle={req.corrected_angle}") - - return {"session_id": session_id, "ground_truth": gt} +# Backward-compat shim -- module moved to ocr/pipeline/deskew.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = 
_importlib.import_module("ocr.pipeline.deskew") diff --git a/klausur-service/backend/ocr_pipeline_dewarp.py b/klausur-service/backend/ocr_pipeline_dewarp.py index b8eaa38..7291b19 100644 --- a/klausur-service/backend/ocr_pipeline_dewarp.py +++ b/klausur-service/backend/ocr_pipeline_dewarp.py @@ -1,346 +1,4 @@ -""" -OCR Pipeline Dewarp Endpoints - -Auto dewarp (with VLM/CV ensemble), manual dewarp, combined -rotation+shear adjustment, and ground truth. -Extracted from ocr_pipeline_geometry.py for file-size compliance. -""" - -import json -import logging -import os -import re -import time -from datetime import datetime -from typing import Any, Dict - -import cv2 -from fastapi import APIRouter, HTTPException, Query - -from cv_vocab_pipeline import ( - _apply_shear, - create_ocr_image, - dewarp_image, - dewarp_image_manual, -) -from ocr_pipeline_session_store import ( - get_session_db, - update_session_db, -) -from ocr_pipeline_common import ( - _cache, - _load_session_to_cache, - _get_cached, - _append_pipeline_log, - ManualDewarpRequest, - CombinedAdjustRequest, - DewarpGroundTruthRequest, -) - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) - - -async def _detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]: - """Ask qwen2.5vl:32b to estimate the vertical shear angle of a scanned page. - - The VLM is shown the image and asked: are the column/table borders tilted? - If yes, by how many degrees? Returns a dict with shear_degrees and confidence. - Confidence is 0.0 if Ollama is unavailable or parsing fails. - """ - import httpx - import base64 - - ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434") - model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b") - - prompt = ( - "This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. " - "Are they perfectly vertical, or do they tilt slightly? 
" - "If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). " - "Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} " - "Use confidence 0.0-1.0 based on how clearly you can see the tilt. " - "If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}" - ) - - img_b64 = base64.b64encode(image_bytes).decode("utf-8") - payload = { - "model": model, - "prompt": prompt, - "images": [img_b64], - "stream": False, - } - - try: - async with httpx.AsyncClient(timeout=60.0) as client: - resp = await client.post(f"{ollama_base}/api/generate", json=payload) - resp.raise_for_status() - text = resp.json().get("response", "") - - # Parse JSON from response (may have surrounding text) - match = re.search(r'\{[^}]+\}', text) - if match: - data = json.loads(match.group(0)) - shear = float(data.get("shear_degrees", 0.0)) - conf = float(data.get("confidence", 0.0)) - # Clamp to reasonable range - shear = max(-3.0, min(3.0, shear)) - conf = max(0.0, min(1.0, conf)) - return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)} - except Exception as e: - logger.warning(f"VLM dewarp failed: {e}") - - return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0} - - -@router.post("/sessions/{session_id}/dewarp") -async def auto_dewarp( - session_id: str, - method: str = Query("ensemble", description="Detection method: ensemble | vlm | cv"), -): - """Detect and correct vertical shear on the deskewed image. 
- - Methods: - - **ensemble** (default): 3-method CV ensemble (vertical edges + projection + Hough) - - **cv**: CV ensemble only (same as ensemble) - - **vlm**: Ask qwen2.5vl:32b to estimate the shear angle visually - """ - if method not in ("ensemble", "cv", "vlm"): - raise HTTPException(status_code=400, detail="method must be one of: ensemble, cv, vlm") - - if session_id not in _cache: - await _load_session_to_cache(session_id) - cached = _get_cached(session_id) - - deskewed_bgr = cached.get("deskewed_bgr") - if deskewed_bgr is None: - raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp") - - t0 = time.time() - - if method == "vlm": - # Encode deskewed image to PNG for VLM - success, png_buf = cv2.imencode(".png", deskewed_bgr) - img_bytes = png_buf.tobytes() if success else b"" - vlm_det = await _detect_shear_with_vlm(img_bytes) - shear_deg = vlm_det["shear_degrees"] - if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3: - dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg) - else: - dewarped_bgr = deskewed_bgr - dewarp_info = { - "method": vlm_det["method"], - "shear_degrees": shear_deg, - "confidence": vlm_det["confidence"], - "detections": [vlm_det], - } - else: - dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr) - - duration = time.time() - t0 - - # Encode as PNG - success, png_buf = cv2.imencode(".png", dewarped_bgr) - dewarped_png = png_buf.tobytes() if success else b"" - - dewarp_result = { - "method_used": dewarp_info["method"], - "shear_degrees": dewarp_info["shear_degrees"], - "confidence": dewarp_info["confidence"], - "duration_seconds": round(duration, 2), - "detections": dewarp_info.get("detections", []), - } - - # Update cache - cached["dewarped_bgr"] = dewarped_bgr - cached["dewarp_result"] = dewarp_result - - # Persist to DB - await update_session_db( - session_id, - dewarped_png=dewarped_png, - dewarp_result=dewarp_result, - auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0), - current_step=4, 
- ) - - logger.info(f"OCR Pipeline: dewarp session {session_id}: " - f"method={dewarp_info['method']} shear={dewarp_info['shear_degrees']:.3f} " - f"conf={dewarp_info['confidence']:.2f} ({duration:.2f}s)") - - await _append_pipeline_log(session_id, "dewarp", { - "shear_degrees": dewarp_info["shear_degrees"], - "confidence": dewarp_info["confidence"], - "method": dewarp_info["method"], - "ensemble_methods": [d.get("method", "") for d in dewarp_info.get("detections", [])], - }, duration_ms=int(duration * 1000)) - - return { - "session_id": session_id, - **dewarp_result, - "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped", - } - - -@router.post("/sessions/{session_id}/dewarp/manual") -async def manual_dewarp(session_id: str, req: ManualDewarpRequest): - """Apply shear correction with a manual angle.""" - if session_id not in _cache: - await _load_session_to_cache(session_id) - cached = _get_cached(session_id) - - deskewed_bgr = cached.get("deskewed_bgr") - if deskewed_bgr is None: - raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp") - - shear_deg = max(-2.0, min(2.0, req.shear_degrees)) - - if abs(shear_deg) < 0.001: - dewarped_bgr = deskewed_bgr - else: - dewarped_bgr = dewarp_image_manual(deskewed_bgr, shear_deg) - - success, png_buf = cv2.imencode(".png", dewarped_bgr) - dewarped_png = png_buf.tobytes() if success else b"" - - dewarp_result = { - **(cached.get("dewarp_result") or {}), - "method_used": "manual", - "shear_degrees": round(shear_deg, 3), - } - - # Update cache - cached["dewarped_bgr"] = dewarped_bgr - cached["dewarp_result"] = dewarp_result - - # Persist to DB - await update_session_db( - session_id, - dewarped_png=dewarped_png, - dewarp_result=dewarp_result, - ) - - logger.info(f"OCR Pipeline: manual dewarp session {session_id}: shear={shear_deg:.3f}") - - return { - "session_id": session_id, - "shear_degrees": round(shear_deg, 3), - "method_used": "manual", - "dewarped_image_url": 
f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped", - } - - -@router.post("/sessions/{session_id}/adjust-combined") -async def adjust_combined(session_id: str, req: CombinedAdjustRequest): - """Apply rotation + shear combined to the original image. - - Used by the fine-tuning sliders to preview arbitrary rotation/shear - combinations without re-running the full deskew/dewarp pipeline. - """ - if session_id not in _cache: - await _load_session_to_cache(session_id) - cached = _get_cached(session_id) - - img_bgr = cached.get("original_bgr") - if img_bgr is None: - raise HTTPException(status_code=400, detail="Original image not available") - - rotation = max(-15.0, min(15.0, req.rotation_degrees)) - shear_deg = max(-5.0, min(5.0, req.shear_degrees)) - - h, w = img_bgr.shape[:2] - result_bgr = img_bgr - - # Step 1: Apply rotation - if abs(rotation) >= 0.001: - center = (w // 2, h // 2) - M = cv2.getRotationMatrix2D(center, rotation, 1.0) - result_bgr = cv2.warpAffine(result_bgr, M, (w, h), - flags=cv2.INTER_LINEAR, - borderMode=cv2.BORDER_REPLICATE) - - # Step 2: Apply shear - if abs(shear_deg) >= 0.001: - result_bgr = dewarp_image_manual(result_bgr, shear_deg) - - # Encode - success, png_buf = cv2.imencode(".png", result_bgr) - dewarped_png = png_buf.tobytes() if success else b"" - - # Binarize - binarized_png = None - try: - binarized = create_ocr_image(result_bgr) - success_bin, bin_buf = cv2.imencode(".png", binarized) - binarized_png = bin_buf.tobytes() if success_bin else None - except Exception: - pass - - # Build combined result dicts - deskew_result = { - **(cached.get("deskew_result") or {}), - "angle_applied": round(rotation, 3), - "method_used": "manual_combined", - } - dewarp_result = { - **(cached.get("dewarp_result") or {}), - "method_used": "manual_combined", - "shear_degrees": round(shear_deg, 3), - } - - # Update cache - cached["deskewed_bgr"] = result_bgr - cached["dewarped_bgr"] = result_bgr - cached["deskew_result"] = deskew_result - 
cached["dewarp_result"] = dewarp_result - - # Persist to DB - db_update = { - "dewarped_png": dewarped_png, - "deskew_result": deskew_result, - "dewarp_result": dewarp_result, - } - if binarized_png: - db_update["binarized_png"] = binarized_png - db_update["deskewed_png"] = dewarped_png - await update_session_db(session_id, **db_update) - - logger.info(f"OCR Pipeline: combined adjust session {session_id}: " - f"rotation={rotation:.3f} shear={shear_deg:.3f}") - - return { - "session_id": session_id, - "rotation_degrees": round(rotation, 3), - "shear_degrees": round(shear_deg, 3), - "method_used": "manual_combined", - "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped", - } - - -@router.post("/sessions/{session_id}/ground-truth/dewarp") -async def save_dewarp_ground_truth(session_id: str, req: DewarpGroundTruthRequest): - """Save ground truth feedback for the dewarp step.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - ground_truth = session.get("ground_truth") or {} - gt = { - "is_correct": req.is_correct, - "corrected_shear": req.corrected_shear, - "notes": req.notes, - "saved_at": datetime.utcnow().isoformat(), - "dewarp_result": session.get("dewarp_result"), - } - ground_truth["dewarp"] = gt - - await update_session_db(session_id, ground_truth=ground_truth) - - if session_id in _cache: - _cache[session_id]["ground_truth"] = ground_truth - - logger.info(f"OCR Pipeline: ground truth dewarp session {session_id}: " - f"correct={req.is_correct}, corrected_shear={req.corrected_shear}") - - return {"session_id": session_id, "ground_truth": gt} +# Backward-compat shim -- module moved to ocr/pipeline/dewarp.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.dewarp") diff --git a/klausur-service/backend/ocr_pipeline_geometry.py 
b/klausur-service/backend/ocr_pipeline_geometry.py index 3d03619..f20d948 100644 --- a/klausur-service/backend/ocr_pipeline_geometry.py +++ b/klausur-service/backend/ocr_pipeline_geometry.py @@ -1,27 +1,4 @@ -""" -OCR Pipeline Geometry API (barrel re-export) - -This module was split into: - - ocr_pipeline_deskew.py (Deskew endpoints) - - ocr_pipeline_dewarp.py (Dewarp endpoints) - - ocr_pipeline_structure.py (Structure detection + exclude regions) - - ocr_pipeline_columns.py (Column detection + ground truth) - -The `router` object is assembled here by including all sub-routers. -Importers that did `from ocr_pipeline_geometry import router` continue to work. -""" - -from fastapi import APIRouter - -from ocr_pipeline_deskew import router as _deskew_router -from ocr_pipeline_dewarp import router as _dewarp_router -from ocr_pipeline_structure import router as _structure_router -from ocr_pipeline_columns import router as _columns_router - -# Assemble the combined router. -# All sub-routers use prefix="/api/v1/ocr-pipeline", so include without extra prefix. -router = APIRouter() -router.include_router(_deskew_router) -router.include_router(_dewarp_router) -router.include_router(_structure_router) -router.include_router(_columns_router) +# Backward-compat shim -- module moved to ocr/pipeline/geometry.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.geometry") diff --git a/klausur-service/backend/ocr_pipeline_llm_review.py b/klausur-service/backend/ocr_pipeline_llm_review.py index 37e8df7..ef8e16e 100644 --- a/klausur-service/backend/ocr_pipeline_llm_review.py +++ b/klausur-service/backend/ocr_pipeline_llm_review.py @@ -1,209 +1,4 @@ -""" -OCR Pipeline LLM Review — LLM-based correction endpoints. - -Extracted from ocr_pipeline_postprocess.py. - -Lizenz: Apache 2.0 -DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
-""" - -import json -import logging -from datetime import datetime -from typing import Dict, List - -from fastapi import APIRouter, HTTPException, Request -from fastapi.responses import StreamingResponse - -from cv_vocab_pipeline import ( - OLLAMA_REVIEW_MODEL, - llm_review_entries, - llm_review_entries_streaming, -) -from ocr_pipeline_session_store import ( - get_session_db, - update_session_db, -) -from ocr_pipeline_common import ( - _cache, - _append_pipeline_log, -) - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) - - -# --------------------------------------------------------------------------- -# Step 8: LLM Review -# --------------------------------------------------------------------------- - -@router.post("/sessions/{session_id}/llm-review") -async def run_llm_review(session_id: str, request: Request, stream: bool = False): - """Run LLM-based correction on vocab entries from Step 5. - - Query params: - stream: false (default) for JSON response, true for SSE streaming - """ - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - word_result = session.get("word_result") - if not word_result: - raise HTTPException(status_code=400, detail="No word result found — run Step 5 first") - - entries = word_result.get("vocab_entries") or word_result.get("entries") or [] - if not entries: - raise HTTPException(status_code=400, detail="No vocab entries found — run Step 5 first") - - # Optional model override from request body - body = {} - try: - body = await request.json() - except Exception: - pass - model = body.get("model") or OLLAMA_REVIEW_MODEL - - if stream: - return StreamingResponse( - _llm_review_stream_generator(session_id, entries, word_result, model, request), - media_type="text/event-stream", - headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}, - ) - - # 
Non-streaming path - try: - result = await llm_review_entries(entries, model=model) - except Exception as e: - import traceback - logger.error(f"LLM review failed for session {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}") - raise HTTPException(status_code=502, detail=f"LLM review failed ({type(e).__name__}): {e}") - - # Store result inside word_result as a sub-key - word_result["llm_review"] = { - "changes": result["changes"], - "model_used": result["model_used"], - "duration_ms": result["duration_ms"], - "entries_corrected": result["entries_corrected"], - } - await update_session_db(session_id, word_result=word_result, current_step=9) - - if session_id in _cache: - _cache[session_id]["word_result"] = word_result - - logger.info(f"LLM review session {session_id}: {len(result['changes'])} changes, " - f"{result['duration_ms']}ms, model={result['model_used']}") - - await _append_pipeline_log(session_id, "correction", { - "engine": "llm", - "model": result["model_used"], - "total_entries": len(entries), - "corrections_proposed": len(result["changes"]), - }, duration_ms=result["duration_ms"]) - - return { - "session_id": session_id, - "changes": result["changes"], - "model_used": result["model_used"], - "duration_ms": result["duration_ms"], - "total_entries": len(entries), - "corrections_found": len(result["changes"]), - } - - -async def _llm_review_stream_generator( - session_id: str, - entries: List[Dict], - word_result: Dict, - model: str, - request: Request, -): - """SSE generator that yields batch-by-batch LLM review progress.""" - try: - async for event in llm_review_entries_streaming(entries, model=model): - if await request.is_disconnected(): - logger.info(f"SSE: client disconnected during LLM review for {session_id}") - return - - yield f"data: {json.dumps(event, ensure_ascii=False)}\n\n" - - # On complete: persist to DB - if event.get("type") == "complete": - word_result["llm_review"] = { - "changes": event["changes"], - "model_used": 
event["model_used"], - "duration_ms": event["duration_ms"], - "entries_corrected": event["entries_corrected"], - } - await update_session_db(session_id, word_result=word_result, current_step=9) - if session_id in _cache: - _cache[session_id]["word_result"] = word_result - - logger.info(f"LLM review SSE session {session_id}: {event['corrections_found']} changes, " - f"{event['duration_ms']}ms, skipped={event['skipped']}, model={event['model_used']}") - - except Exception as e: - import traceback - logger.error(f"LLM review SSE failed for {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}") - error_event = {"type": "error", "detail": f"{type(e).__name__}: {e}"} - yield f"data: {json.dumps(error_event)}\n\n" - - -@router.post("/sessions/{session_id}/llm-review/apply") -async def apply_llm_corrections(session_id: str, request: Request): - """Apply selected LLM corrections to vocab entries.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - word_result = session.get("word_result") - if not word_result: - raise HTTPException(status_code=400, detail="No word result found") - - llm_review = word_result.get("llm_review") - if not llm_review: - raise HTTPException(status_code=400, detail="No LLM review found — run /llm-review first") - - body = await request.json() - accepted_indices = set(body.get("accepted_indices", [])) # indices into changes[] - - changes = llm_review.get("changes", []) - entries = word_result.get("vocab_entries") or word_result.get("entries") or [] - - # Build a lookup: (row_index, field) -> new_value for accepted changes - corrections = {} - applied_count = 0 - for idx, change in enumerate(changes): - if idx in accepted_indices: - key = (change["row_index"], change["field"]) - corrections[key] = change["new"] - applied_count += 1 - - # Apply corrections to entries - for entry in entries: - row_idx = entry.get("row_index", -1) - for 
field_name in ("english", "german", "example"): - key = (row_idx, field_name) - if key in corrections: - entry[field_name] = corrections[key] - entry["llm_corrected"] = True - - # Update word_result - word_result["vocab_entries"] = entries - word_result["entries"] = entries - word_result["llm_review"]["applied_count"] = applied_count - word_result["llm_review"]["applied_at"] = datetime.utcnow().isoformat() - - await update_session_db(session_id, word_result=word_result) - - if session_id in _cache: - _cache[session_id]["word_result"] = word_result - - logger.info(f"Applied {applied_count}/{len(changes)} LLM corrections for session {session_id}") - - return { - "session_id": session_id, - "applied_count": applied_count, - "total_changes": len(changes), - } +# Backward-compat shim -- module moved to ocr/pipeline/llm_review.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.llm_review") diff --git a/klausur-service/backend/ocr_pipeline_ocr_merge.py b/klausur-service/backend/ocr_pipeline_ocr_merge.py index c91f8b2..e9f45db 100644 --- a/klausur-service/backend/ocr_pipeline_ocr_merge.py +++ b/klausur-service/backend/ocr_pipeline_ocr_merge.py @@ -1,266 +1,4 @@ -""" -OCR Merge Kombi Endpoints — paddle-kombi and rapid-kombi endpoints. - -Merge helper functions live in ocr_merge_helpers.py. -This module re-exports them for backward compatibility. - -Lizenz: Apache 2.0 -DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
-""" - -import logging -import time - -import cv2 -import numpy as np -from fastapi import APIRouter, HTTPException - -from cv_words_first import build_grid_from_words -from ocr_pipeline_common import _cache, _append_pipeline_log -from ocr_pipeline_session_store import get_session_image, update_session_db - -# Re-export merge helpers for backward compatibility -from ocr_merge_helpers import ( # noqa: F401 - _split_paddle_multi_words, - _group_words_into_rows, - _row_center_y, - _merge_row_sequences, - _merge_paddle_tesseract, - _deduplicate_words, -) - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) - - -def _run_tesseract_words(img_bgr) -> list: - """Run Tesseract OCR on an image and return word dicts.""" - from PIL import Image - import pytesseract - - pil_img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)) - data = pytesseract.image_to_data( - pil_img, lang="eng+deu", - config="--psm 6 --oem 3", - output_type=pytesseract.Output.DICT, - ) - tess_words = [] - for i in range(len(data["text"])): - text = str(data["text"][i]).strip() - conf_raw = str(data["conf"][i]) - conf = int(conf_raw) if conf_raw.lstrip("-").isdigit() else -1 - if not text or conf < 20: - continue - tess_words.append({ - "text": text, - "left": data["left"][i], - "top": data["top"][i], - "width": data["width"][i], - "height": data["height"][i], - "conf": conf, - }) - return tess_words - - -def _build_kombi_word_result( - cells: list, - columns_meta: list, - img_w: int, - img_h: int, - duration: float, - engine_name: str, - raw_engine_words: list, - raw_engine_words_split: list, - tess_words: list, - merged_words: list, - raw_engine_key: str = "raw_paddle_words", - raw_split_key: str = "raw_paddle_words_split", -) -> dict: - """Build the word_result dict for kombi endpoints.""" - n_rows = len(set(c["row_index"] for c in cells)) if cells else 0 - n_cols = len(columns_meta) - col_types = {c.get("type") for c in 
columns_meta} - is_vocab = bool(col_types & {"column_en", "column_de"}) - - return { - "cells": cells, - "grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)}, - "columns_used": columns_meta, - "layout": "vocab" if is_vocab else "generic", - "image_width": img_w, - "image_height": img_h, - "duration_seconds": round(duration, 2), - "ocr_engine": engine_name, - "grid_method": engine_name, - raw_engine_key: raw_engine_words, - raw_split_key: raw_engine_words_split, - "raw_tesseract_words": tess_words, - "summary": { - "total_cells": len(cells), - "non_empty_cells": sum(1 for c in cells if c.get("text")), - "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), - raw_engine_key.replace("raw_", "").replace("_words", "_words"): len(raw_engine_words), - raw_split_key.replace("raw_", "").replace("_words_split", "_words_split"): len(raw_engine_words_split), - "tesseract_words": len(tess_words), - "merged_words": len(merged_words), - }, - } - - -async def _load_session_image(session_id: str): - """Load preprocessed image for kombi endpoints.""" - img_png = await get_session_image(session_id, "cropped") - if not img_png: - img_png = await get_session_image(session_id, "dewarped") - if not img_png: - img_png = await get_session_image(session_id, "original") - if not img_png: - raise HTTPException(status_code=404, detail="No image found for this session") - - img_arr = np.frombuffer(img_png, dtype=np.uint8) - img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR) - if img_bgr is None: - raise HTTPException(status_code=400, detail="Failed to decode image") - - return img_png, img_bgr - - -# --------------------------------------------------------------------------- -# Kombi endpoints -# --------------------------------------------------------------------------- - -@router.post("/sessions/{session_id}/paddle-kombi") -async def paddle_kombi(session_id: str): - """Run PaddleOCR + Tesseract on the preprocessed image and merge results.""" - 
img_png, img_bgr = await _load_session_image(session_id) - img_h, img_w = img_bgr.shape[:2] - - from cv_ocr_engines import ocr_region_paddle - - t0 = time.time() - - paddle_words = await ocr_region_paddle(img_bgr, region=None) - if not paddle_words: - paddle_words = [] - - tess_words = _run_tesseract_words(img_bgr) - - paddle_words_split = _split_paddle_multi_words(paddle_words) - logger.info( - "paddle_kombi: split %d paddle boxes -> %d individual words", - len(paddle_words), len(paddle_words_split), - ) - - if not paddle_words_split and not tess_words: - raise HTTPException(status_code=400, detail="Both OCR engines returned no words") - - merged_words = _merge_paddle_tesseract(paddle_words_split, tess_words) - merged_words = _deduplicate_words(merged_words) - - cells, columns_meta = build_grid_from_words(merged_words, img_w, img_h) - duration = time.time() - t0 - - for cell in cells: - cell["ocr_engine"] = "kombi" - - word_result = _build_kombi_word_result( - cells, columns_meta, img_w, img_h, duration, "kombi", - paddle_words, paddle_words_split, tess_words, merged_words, - "raw_paddle_words", "raw_paddle_words_split", - ) - - await update_session_db( - session_id, word_result=word_result, cropped_png=img_png, current_step=8, - ) - if session_id in _cache: - _cache[session_id]["word_result"] = word_result - - logger.info( - "paddle_kombi session %s: %d cells (%d rows, %d cols) in %.2fs " - "[paddle=%d, tess=%d, merged=%d]", - session_id, len(cells), word_result["grid_shape"]["rows"], - word_result["grid_shape"]["cols"], duration, - len(paddle_words), len(tess_words), len(merged_words), - ) - - await _append_pipeline_log(session_id, "paddle_kombi", { - "total_cells": len(cells), - "non_empty_cells": word_result["summary"]["non_empty_cells"], - "paddle_words": len(paddle_words), - "tesseract_words": len(tess_words), - "merged_words": len(merged_words), - "ocr_engine": "kombi", - }, duration_ms=int(duration * 1000)) - - return {"session_id": session_id, 
**word_result} - - -@router.post("/sessions/{session_id}/rapid-kombi") -async def rapid_kombi(session_id: str): - """Run RapidOCR + Tesseract on the preprocessed image and merge results.""" - img_png, img_bgr = await _load_session_image(session_id) - img_h, img_w = img_bgr.shape[:2] - - from cv_ocr_engines import ocr_region_rapid - from cv_vocab_types import PageRegion - - t0 = time.time() - - full_region = PageRegion( - type="full_page", x=0, y=0, width=img_w, height=img_h, - ) - rapid_words = ocr_region_rapid(img_bgr, full_region) - if not rapid_words: - rapid_words = [] - - tess_words = _run_tesseract_words(img_bgr) - - rapid_words_split = _split_paddle_multi_words(rapid_words) - logger.info( - "rapid_kombi: split %d rapid boxes -> %d individual words", - len(rapid_words), len(rapid_words_split), - ) - - if not rapid_words_split and not tess_words: - raise HTTPException(status_code=400, detail="Both OCR engines returned no words") - - merged_words = _merge_paddle_tesseract(rapid_words_split, tess_words) - merged_words = _deduplicate_words(merged_words) - - cells, columns_meta = build_grid_from_words(merged_words, img_w, img_h) - duration = time.time() - t0 - - for cell in cells: - cell["ocr_engine"] = "rapid_kombi" - - word_result = _build_kombi_word_result( - cells, columns_meta, img_w, img_h, duration, "rapid_kombi", - rapid_words, rapid_words_split, tess_words, merged_words, - "raw_rapid_words", "raw_rapid_words_split", - ) - - await update_session_db( - session_id, word_result=word_result, cropped_png=img_png, current_step=8, - ) - if session_id in _cache: - _cache[session_id]["word_result"] = word_result - - logger.info( - "rapid_kombi session %s: %d cells (%d rows, %d cols) in %.2fs " - "[rapid=%d, tess=%d, merged=%d]", - session_id, len(cells), word_result["grid_shape"]["rows"], - word_result["grid_shape"]["cols"], duration, - len(rapid_words), len(tess_words), len(merged_words), - ) - - await _append_pipeline_log(session_id, "rapid_kombi", { - 
"total_cells": len(cells), - "non_empty_cells": word_result["summary"]["non_empty_cells"], - "rapid_words": len(rapid_words), - "tesseract_words": len(tess_words), - "merged_words": len(merged_words), - "ocr_engine": "rapid_kombi", - }, duration_ms=int(duration * 1000)) - - return {"session_id": session_id, **word_result} +# Backward-compat shim -- module moved to ocr/pipeline/ocr_merge.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.ocr_merge") diff --git a/klausur-service/backend/ocr_pipeline_overlay_grid.py b/klausur-service/backend/ocr_pipeline_overlay_grid.py index 769ef0f..de01832 100644 --- a/klausur-service/backend/ocr_pipeline_overlay_grid.py +++ b/klausur-service/backend/ocr_pipeline_overlay_grid.py @@ -1,333 +1,4 @@ -""" -Overlay rendering for columns, rows, and words (grid-based overlays). - -Extracted from ocr_pipeline_overlays.py for modularity. - -Lizenz: Apache 2.0 -DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
-""" - -import logging -from typing import Any, Dict, List - -import cv2 -import numpy as np -from fastapi import HTTPException -from fastapi.responses import Response - -from ocr_pipeline_common import _get_base_image_png -from ocr_pipeline_session_store import get_session_db -from ocr_pipeline_rows import _draw_box_exclusion_overlay - -logger = logging.getLogger(__name__) - - -async def _get_columns_overlay(session_id: str) -> Response: - """Generate cropped (or dewarped) image with column borders drawn on it.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - column_result = session.get("column_result") - if not column_result or not column_result.get("columns"): - raise HTTPException(status_code=404, detail="No column data available") - - # Load best available base image (cropped > dewarped > original) - base_png = await _get_base_image_png(session_id) - if not base_png: - raise HTTPException(status_code=404, detail="No base image available") - - arr = np.frombuffer(base_png, dtype=np.uint8) - img = cv2.imdecode(arr, cv2.IMREAD_COLOR) - if img is None: - raise HTTPException(status_code=500, detail="Failed to decode image") - - # Color map for region types (BGR) - colors = { - "column_en": (255, 180, 0), # Blue - "column_de": (0, 200, 0), # Green - "column_example": (0, 140, 255), # Orange - "column_text": (200, 200, 0), # Cyan/Turquoise - "page_ref": (200, 0, 200), # Purple - "column_marker": (0, 0, 220), # Red - "column_ignore": (180, 180, 180), # Light Gray - "header": (128, 128, 128), # Gray - "footer": (128, 128, 128), # Gray - "margin_top": (100, 100, 100), # Dark Gray - "margin_bottom": (100, 100, 100), # Dark Gray - } - - overlay = img.copy() - for col in column_result["columns"]: - x, y = col["x"], col["y"] - w, h = col["width"], col["height"] - color = colors.get(col.get("type", ""), (200, 200, 200)) - - # Semi-transparent fill - cv2.rectangle(overlay, 
(x, y), (x + w, y + h), color, -1) - - # Solid border - cv2.rectangle(img, (x, y), (x + w, y + h), color, 3) - - # Label with confidence - label = col.get("type", "unknown").replace("column_", "").upper() - conf = col.get("classification_confidence") - if conf is not None and conf < 1.0: - label = f"{label} {int(conf * 100)}%" - cv2.putText(img, label, (x + 10, y + 30), - cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2) - - # Blend overlay at 20% opacity - cv2.addWeighted(overlay, 0.2, img, 0.8, 0, img) - - # Draw detected box boundaries as dashed rectangles - zones = column_result.get("zones") or [] - for zone in zones: - if zone.get("zone_type") == "box" and zone.get("box"): - box = zone["box"] - bx, by = box["x"], box["y"] - bw, bh = box["width"], box["height"] - box_color = (0, 200, 255) # Yellow (BGR) - # Draw dashed rectangle by drawing short line segments - dash_len = 15 - for edge_x in range(bx, bx + bw, dash_len * 2): - end_x = min(edge_x + dash_len, bx + bw) - cv2.line(img, (edge_x, by), (end_x, by), box_color, 2) - cv2.line(img, (edge_x, by + bh), (end_x, by + bh), box_color, 2) - for edge_y in range(by, by + bh, dash_len * 2): - end_y = min(edge_y + dash_len, by + bh) - cv2.line(img, (bx, edge_y), (bx, end_y), box_color, 2) - cv2.line(img, (bx + bw, edge_y), (bx + bw, end_y), box_color, 2) - cv2.putText(img, "BOX", (bx + 10, by + bh - 10), - cv2.FONT_HERSHEY_SIMPLEX, 0.7, box_color, 2) - - # Red semi-transparent overlay for box zones - _draw_box_exclusion_overlay(img, zones) - - success, result_png = cv2.imencode(".png", img) - if not success: - raise HTTPException(status_code=500, detail="Failed to encode overlay image") - - return Response(content=result_png.tobytes(), media_type="image/png") - - -async def _get_rows_overlay(session_id: str) -> Response: - """Generate cropped (or dewarped) image with row bands drawn on it.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not 
found") - - row_result = session.get("row_result") - if not row_result or not row_result.get("rows"): - raise HTTPException(status_code=404, detail="No row data available") - - # Load best available base image (cropped > dewarped > original) - base_png = await _get_base_image_png(session_id) - if not base_png: - raise HTTPException(status_code=404, detail="No base image available") - - arr = np.frombuffer(base_png, dtype=np.uint8) - img = cv2.imdecode(arr, cv2.IMREAD_COLOR) - if img is None: - raise HTTPException(status_code=500, detail="Failed to decode image") - - # Color map for row types (BGR) - row_colors = { - "content": (255, 180, 0), # Blue - "header": (128, 128, 128), # Gray - "footer": (128, 128, 128), # Gray - "margin_top": (100, 100, 100), # Dark Gray - "margin_bottom": (100, 100, 100), # Dark Gray - } - - overlay = img.copy() - for row in row_result["rows"]: - x, y = row["x"], row["y"] - w, h = row["width"], row["height"] - row_type = row.get("row_type", "content") - color = row_colors.get(row_type, (200, 200, 200)) - - # Semi-transparent fill - cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1) - - # Solid border - cv2.rectangle(img, (x, y), (x + w, y + h), color, 2) - - # Label - idx = row.get("index", 0) - label = f"R{idx} {row_type.upper()}" - wc = row.get("word_count", 0) - if wc: - label = f"{label} ({wc}w)" - cv2.putText(img, label, (x + 5, y + 18), - cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1) - - # Blend overlay at 15% opacity - cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img) - - # Draw zone separator lines if zones exist - column_result = session.get("column_result") or {} - zones = column_result.get("zones") or [] - if zones: - img_w_px = img.shape[1] - zone_color = (0, 200, 255) # Yellow (BGR) - dash_len = 20 - for zone in zones: - if zone.get("zone_type") == "box": - zy = zone["y"] - zh = zone["height"] - for line_y in [zy, zy + zh]: - for sx in range(0, img_w_px, dash_len * 2): - ex = min(sx + dash_len, img_w_px) - cv2.line(img, 
(sx, line_y), (ex, line_y), zone_color, 2) - - # Red semi-transparent overlay for box zones - _draw_box_exclusion_overlay(img, zones) - - success, result_png = cv2.imencode(".png", img) - if not success: - raise HTTPException(status_code=500, detail="Failed to encode overlay image") - - return Response(content=result_png.tobytes(), media_type="image/png") - - -async def _get_words_overlay(session_id: str) -> Response: - """Generate cropped (or dewarped) image with cell grid drawn on it.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - word_result = session.get("word_result") - if not word_result: - raise HTTPException(status_code=404, detail="No word data available") - - # Support both new cell-based and legacy entry-based formats - cells = word_result.get("cells") - if not cells and not word_result.get("entries"): - raise HTTPException(status_code=404, detail="No word data available") - - # Load best available base image (cropped > dewarped > original) - base_png = await _get_base_image_png(session_id) - if not base_png: - raise HTTPException(status_code=404, detail="No base image available") - - arr = np.frombuffer(base_png, dtype=np.uint8) - img = cv2.imdecode(arr, cv2.IMREAD_COLOR) - if img is None: - raise HTTPException(status_code=500, detail="Failed to decode image") - - img_h, img_w = img.shape[:2] - - overlay = img.copy() - - if cells: - # New cell-based overlay: color by column index - col_palette = [ - (255, 180, 0), # Blue (BGR) - (0, 200, 0), # Green - (0, 140, 255), # Orange - (200, 100, 200), # Purple - (200, 200, 0), # Cyan - (100, 200, 200), # Yellow-ish - ] - - for cell in cells: - bbox = cell.get("bbox_px", {}) - cx = bbox.get("x", 0) - cy = bbox.get("y", 0) - cw = bbox.get("w", 0) - ch = bbox.get("h", 0) - if cw <= 0 or ch <= 0: - continue - - col_idx = cell.get("col_index", 0) - color = col_palette[col_idx % len(col_palette)] - - # Cell 
rectangle border - cv2.rectangle(img, (cx, cy), (cx + cw, cy + ch), color, 1) - # Semi-transparent fill - cv2.rectangle(overlay, (cx, cy), (cx + cw, cy + ch), color, -1) - - # Cell-ID label (top-left corner) - cell_id = cell.get("cell_id", "") - cv2.putText(img, cell_id, (cx + 2, cy + 10), - cv2.FONT_HERSHEY_SIMPLEX, 0.28, color, 1) - - # Text label (bottom of cell) - text = cell.get("text", "") - if text: - conf = cell.get("confidence", 0) - if conf >= 70: - text_color = (0, 180, 0) - elif conf >= 50: - text_color = (0, 180, 220) - else: - text_color = (0, 0, 220) - - label = text.replace('\n', ' ')[:30] - cv2.putText(img, label, (cx + 3, cy + ch - 4), - cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1) - else: - # Legacy fallback: entry-based overlay (for old sessions) - column_result = session.get("column_result") - row_result = session.get("row_result") - col_colors = { - "column_en": (255, 180, 0), - "column_de": (0, 200, 0), - "column_example": (0, 140, 255), - } - - columns = [] - if column_result and column_result.get("columns"): - columns = [c for c in column_result["columns"] - if c.get("type", "").startswith("column_")] - - content_rows_data = [] - if row_result and row_result.get("rows"): - content_rows_data = [r for r in row_result["rows"] - if r.get("row_type") == "content"] - - for col in columns: - col_type = col.get("type", "") - color = col_colors.get(col_type, (200, 200, 200)) - cx, cw = col["x"], col["width"] - for row in content_rows_data: - ry, rh = row["y"], row["height"] - cv2.rectangle(img, (cx, ry), (cx + cw, ry + rh), color, 1) - cv2.rectangle(overlay, (cx, ry), (cx + cw, ry + rh), color, -1) - - entries = word_result["entries"] - entry_by_row: Dict[int, Dict] = {} - for entry in entries: - entry_by_row[entry.get("row_index", -1)] = entry - - for row_idx, row in enumerate(content_rows_data): - entry = entry_by_row.get(row_idx) - if not entry: - continue - conf = entry.get("confidence", 0) - text_color = (0, 180, 0) if conf >= 70 else (0, 
180, 220) if conf >= 50 else (0, 0, 220) - ry, rh = row["y"], row["height"] - for col in columns: - col_type = col.get("type", "") - cx, cw = col["x"], col["width"] - field = {"column_en": "english", "column_de": "german", "column_example": "example"}.get(col_type, "") - text = entry.get(field, "") if field else "" - if text: - label = text.replace('\n', ' ')[:30] - cv2.putText(img, label, (cx + 3, ry + rh - 4), - cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1) - - # Blend overlay at 10% opacity - cv2.addWeighted(overlay, 0.1, img, 0.9, 0, img) - - # Red semi-transparent overlay for box zones - column_result = session.get("column_result") or {} - zones = column_result.get("zones") or [] - _draw_box_exclusion_overlay(img, zones) - - success, result_png = cv2.imencode(".png", img) - if not success: - raise HTTPException(status_code=500, detail="Failed to encode overlay image") - - return Response(content=result_png.tobytes(), media_type="image/png") +# Backward-compat shim -- module moved to ocr/pipeline/overlay_grid.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.overlay_grid") diff --git a/klausur-service/backend/ocr_pipeline_overlay_structure.py b/klausur-service/backend/ocr_pipeline_overlay_structure.py index ad48382..721748d 100644 --- a/klausur-service/backend/ocr_pipeline_overlay_structure.py +++ b/klausur-service/backend/ocr_pipeline_overlay_structure.py @@ -1,205 +1,4 @@ -""" -Overlay rendering for structure detection (boxes, zones, colors, graphics). - -Extracted from ocr_pipeline_overlays.py for modularity. - -Lizenz: Apache 2.0 -DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
-""" - -import logging -from typing import Any, Dict, List - -import cv2 -import numpy as np -from fastapi import HTTPException -from fastapi.responses import Response - -from ocr_pipeline_common import _get_base_image_png -from ocr_pipeline_session_store import get_session_db -from cv_color_detect import _COLOR_HEX, _COLOR_RANGES -from cv_box_detect import detect_boxes, split_page_into_zones - -logger = logging.getLogger(__name__) - - -async def _get_structure_overlay(session_id: str) -> Response: - """Generate overlay image showing detected boxes, zones, and color regions.""" - base_png = await _get_base_image_png(session_id) - if not base_png: - raise HTTPException(status_code=404, detail="No base image available") - - arr = np.frombuffer(base_png, dtype=np.uint8) - img = cv2.imdecode(arr, cv2.IMREAD_COLOR) - if img is None: - raise HTTPException(status_code=500, detail="Failed to decode image") - - h, w = img.shape[:2] - - # Get structure result (run detection if not cached) - session = await get_session_db(session_id) - structure = (session or {}).get("structure_result") - - if not structure: - # Run detection on-the-fly - margin = int(min(w, h) * 0.03) - content_x, content_y = margin, margin - content_w_px = w - 2 * margin - content_h_px = h - 2 * margin - boxes = detect_boxes(img, content_x, content_w_px, content_y, content_h_px) - zones = split_page_into_zones(content_x, content_y, content_w_px, content_h_px, boxes) - structure = { - "boxes": [ - {"x": b.x, "y": b.y, "w": b.width, "h": b.height, - "confidence": b.confidence, "border_thickness": b.border_thickness} - for b in boxes - ], - "zones": [ - {"index": z.index, "zone_type": z.zone_type, - "y": z.y, "h": z.height, "x": z.x, "w": z.width} - for z in zones - ], - } - - overlay = img.copy() - - # --- Draw zone boundaries --- - zone_colors = { - "content": (200, 200, 200), # light gray - "box": (255, 180, 0), # blue-ish (BGR) - } - for zone in structure.get("zones", []): - zx = zone["x"] - zy = zone["y"] 
- zw = zone["w"] - zh = zone["h"] - color = zone_colors.get(zone["zone_type"], (200, 200, 200)) - - # Draw zone boundary as dashed line - dash_len = 12 - for edge_x in range(zx, zx + zw, dash_len * 2): - end_x = min(edge_x + dash_len, zx + zw) - cv2.line(img, (edge_x, zy), (end_x, zy), color, 1) - cv2.line(img, (edge_x, zy + zh), (end_x, zy + zh), color, 1) - - # Zone label - zone_label = f"Zone {zone['index']} ({zone['zone_type']})" - cv2.putText(img, zone_label, (zx + 5, zy + 15), - cv2.FONT_HERSHEY_SIMPLEX, 0.45, color, 1) - - # --- Draw detected boxes --- - # Color map for box backgrounds (BGR) - bg_hex_to_bgr = { - "#dc2626": (38, 38, 220), # red - "#2563eb": (235, 99, 37), # blue - "#16a34a": (74, 163, 22), # green - "#ea580c": (12, 88, 234), # orange - "#9333ea": (234, 51, 147), # purple - "#ca8a04": (4, 138, 202), # yellow - "#6b7280": (128, 114, 107), # gray - } - - for box_data in structure.get("boxes", []): - bx = box_data["x"] - by = box_data["y"] - bw = box_data["w"] - bh = box_data["h"] - conf = box_data.get("confidence", 0) - thickness = box_data.get("border_thickness", 0) - bg_hex = box_data.get("bg_color_hex", "#6b7280") - bg_name = box_data.get("bg_color_name", "") - - # Box fill color - fill_bgr = bg_hex_to_bgr.get(bg_hex, (128, 114, 107)) - - # Semi-transparent fill - cv2.rectangle(overlay, (bx, by), (bx + bw, by + bh), fill_bgr, -1) - - # Solid border - border_color = fill_bgr - cv2.rectangle(img, (bx, by), (bx + bw, by + bh), border_color, 3) - - # Label - label = f"BOX" - if bg_name and bg_name not in ("unknown", "white"): - label += f" ({bg_name})" - if thickness > 0: - label += f" border={thickness}px" - label += f" {int(conf * 100)}%" - cv2.putText(img, label, (bx + 8, by + 22), - cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 2) - cv2.putText(img, label, (bx + 8, by + 22), - cv2.FONT_HERSHEY_SIMPLEX, 0.55, border_color, 1) - - # Blend overlay at 15% opacity - cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img) - - # --- Draw color regions 
(HSV masks) --- - hsv = cv2.cvtColor( - cv2.imdecode(np.frombuffer(base_png, dtype=np.uint8), cv2.IMREAD_COLOR), - cv2.COLOR_BGR2HSV, - ) - color_bgr_map = { - "red": (0, 0, 255), - "orange": (0, 140, 255), - "yellow": (0, 200, 255), - "green": (0, 200, 0), - "blue": (255, 150, 0), - "purple": (200, 0, 200), - } - for color_name, ranges in _COLOR_RANGES.items(): - mask = np.zeros((h, w), dtype=np.uint8) - for lower, upper in ranges: - mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper)) - # Only draw if there are significant colored pixels - if np.sum(mask > 0) < 100: - continue - # Draw colored contours - contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) - draw_color = color_bgr_map.get(color_name, (200, 200, 200)) - for cnt in contours: - area = cv2.contourArea(cnt) - if area < 20: - continue - cv2.drawContours(img, [cnt], -1, draw_color, 2) - - # --- Draw graphic elements --- - graphics_data = structure.get("graphics", []) - shape_icons = { - "image": "IMAGE", - "illustration": "ILLUST", - } - for gfx in graphics_data: - gx, gy = gfx["x"], gfx["y"] - gw, gh = gfx["w"], gfx["h"] - shape = gfx.get("shape", "icon") - color_hex = gfx.get("color_hex", "#6b7280") - conf = gfx.get("confidence", 0) - - # Pick draw color based on element color (BGR) - gfx_bgr = bg_hex_to_bgr.get(color_hex, (128, 114, 107)) - - # Draw bounding box (dashed style via short segments) - dash = 6 - for seg_x in range(gx, gx + gw, dash * 2): - end_x = min(seg_x + dash, gx + gw) - cv2.line(img, (seg_x, gy), (end_x, gy), gfx_bgr, 2) - cv2.line(img, (seg_x, gy + gh), (end_x, gy + gh), gfx_bgr, 2) - for seg_y in range(gy, gy + gh, dash * 2): - end_y = min(seg_y + dash, gy + gh) - cv2.line(img, (gx, seg_y), (gx, end_y), gfx_bgr, 2) - cv2.line(img, (gx + gw, seg_y), (gx + gw, end_y), gfx_bgr, 2) - - # Label - icon = shape_icons.get(shape, shape.upper()[:5]) - label = f"{icon} {int(conf * 100)}%" - # White background for readability - (tw, th), _ = 
cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1) - lx = gx + 2 - ly = max(gy - 4, th + 4) - cv2.rectangle(img, (lx - 1, ly - th - 2), (lx + tw + 2, ly + 3), (255, 255, 255), -1) - cv2.putText(img, label, (lx, ly), cv2.FONT_HERSHEY_SIMPLEX, 0.4, gfx_bgr, 1) - - # Encode result - _, png_buf = cv2.imencode(".png", img) - return Response(content=png_buf.tobytes(), media_type="image/png") +# Backward-compat shim -- module moved to ocr/pipeline/overlay_structure.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.overlay_structure") diff --git a/klausur-service/backend/ocr_pipeline_overlays.py b/klausur-service/backend/ocr_pipeline_overlays.py index 7a30f9b..ce6fc95 100644 --- a/klausur-service/backend/ocr_pipeline_overlays.py +++ b/klausur-service/backend/ocr_pipeline_overlays.py @@ -1,34 +1,4 @@ -""" -Overlay image rendering for OCR pipeline — barrel re-export. - -All implementation split into: - ocr_pipeline_overlay_structure — structure overlay (boxes, zones, colors, graphics) - ocr_pipeline_overlay_grid — columns, rows, words overlays - -Lizenz: Apache 2.0 -DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
-""" - -from fastapi import HTTPException -from fastapi.responses import Response - -from ocr_pipeline_overlay_structure import _get_structure_overlay # noqa: F401 -from ocr_pipeline_overlay_grid import ( # noqa: F401 - _get_columns_overlay, - _get_rows_overlay, - _get_words_overlay, -) - - -async def render_overlay(overlay_type: str, session_id: str) -> Response: - """Dispatch to the appropriate overlay renderer.""" - if overlay_type == "structure": - return await _get_structure_overlay(session_id) - elif overlay_type == "columns": - return await _get_columns_overlay(session_id) - elif overlay_type == "rows": - return await _get_rows_overlay(session_id) - elif overlay_type == "words": - return await _get_words_overlay(session_id) - else: - raise HTTPException(status_code=400, detail=f"Unknown overlay type: {overlay_type}") +# Backward-compat shim -- module moved to ocr/pipeline/overlays.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.overlays") diff --git a/klausur-service/backend/ocr_pipeline_postprocess.py b/klausur-service/backend/ocr_pipeline_postprocess.py index 388f5e2..815ff82 100644 --- a/klausur-service/backend/ocr_pipeline_postprocess.py +++ b/klausur-service/backend/ocr_pipeline_postprocess.py @@ -1,26 +1,4 @@ -""" -OCR Pipeline Postprocessing API — composite router assembling LLM review, -reconstruction, export, validation, image detection/generation, and -handwriting removal endpoints. - -Split into sub-modules: - ocr_pipeline_llm_review — LLM review + apply corrections - ocr_pipeline_reconstruction — reconstruction save, Fabric JSON, merged entries, PDF/DOCX - ocr_pipeline_validation — image detection, generation, validation, handwriting removal - -Lizenz: Apache 2.0 -DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
-""" - -from fastapi import APIRouter - -from ocr_pipeline_llm_review import router as _llm_review_router -from ocr_pipeline_reconstruction import router as _reconstruction_router -from ocr_pipeline_validation import router as _validation_router - -# Composite router — drop-in replacement for the old monolithic router. -# ocr_pipeline_api.py imports ``from ocr_pipeline_postprocess import router``. -router = APIRouter() -router.include_router(_llm_review_router) -router.include_router(_reconstruction_router) -router.include_router(_validation_router) +# Backward-compat shim -- module moved to ocr/pipeline/postprocess.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.postprocess") diff --git a/klausur-service/backend/ocr_pipeline_reconstruction.py b/klausur-service/backend/ocr_pipeline_reconstruction.py index 99081c4..cd22501 100644 --- a/klausur-service/backend/ocr_pipeline_reconstruction.py +++ b/klausur-service/backend/ocr_pipeline_reconstruction.py @@ -1,362 +1,4 @@ -""" -OCR Pipeline Reconstruction — save edits, Fabric JSON export, merged entries, PDF/DOCX export. - -Extracted from ocr_pipeline_postprocess.py. - -Lizenz: Apache 2.0 -DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
-""" - -import logging -import re -from typing import Dict - -from fastapi import APIRouter, HTTPException, Request -from fastapi.responses import StreamingResponse - -from ocr_pipeline_session_store import ( - get_session_db, - get_sub_sessions, - update_session_db, -) -from ocr_pipeline_common import _cache - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) - - -# --------------------------------------------------------------------------- -# Step 9: Reconstruction + Fabric JSON export -# --------------------------------------------------------------------------- - -@router.post("/sessions/{session_id}/reconstruction") -async def save_reconstruction(session_id: str, request: Request): - """Save edited cell texts from reconstruction step.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - word_result = session.get("word_result") - if not word_result: - raise HTTPException(status_code=400, detail="No word result found") - - body = await request.json() - cell_updates = body.get("cells", []) - - if not cell_updates: - await update_session_db(session_id, current_step=10) - return {"session_id": session_id, "updated": 0} - - # Build update map: cell_id -> new text - update_map = {c["cell_id"]: c["text"] for c in cell_updates} - - # Separate sub-session updates (cell_ids prefixed with "box{N}_") - sub_updates: Dict[int, Dict[str, str]] = {} # box_index -> {original_cell_id: text} - main_updates: Dict[str, str] = {} - for cell_id, text in update_map.items(): - m = re.match(r'^box(\d+)_(.+)$', cell_id) - if m: - bi = int(m.group(1)) - original_id = m.group(2) - sub_updates.setdefault(bi, {})[original_id] = text - else: - main_updates[cell_id] = text - - # Update main session cells - cells = word_result.get("cells", []) - updated_count = 0 - for cell in cells: - if cell["cell_id"] in main_updates: - cell["text"] 
= main_updates[cell["cell_id"]] - cell["status"] = "edited" - updated_count += 1 - - word_result["cells"] = cells - - # Also update vocab_entries if present - entries = word_result.get("vocab_entries") or word_result.get("entries") or [] - if entries: - for entry in entries: - row_idx = entry.get("row_index", -1) - for col_idx, field_name in enumerate(["english", "german", "example"]): - cell_id = f"R{row_idx:02d}_C{col_idx}" - cell_id_alt = f"R{row_idx}_C{col_idx}" - new_text = main_updates.get(cell_id) or main_updates.get(cell_id_alt) - if new_text is not None: - entry[field_name] = new_text - - word_result["vocab_entries"] = entries - if "entries" in word_result: - word_result["entries"] = entries - - await update_session_db(session_id, word_result=word_result, current_step=10) - - if session_id in _cache: - _cache[session_id]["word_result"] = word_result - - # Route sub-session updates - sub_updated = 0 - if sub_updates: - subs = await get_sub_sessions(session_id) - sub_by_index = {s.get("box_index"): s["id"] for s in subs} - for bi, updates in sub_updates.items(): - sub_id = sub_by_index.get(bi) - if not sub_id: - continue - sub_session = await get_session_db(sub_id) - if not sub_session: - continue - sub_word = sub_session.get("word_result") - if not sub_word: - continue - sub_cells = sub_word.get("cells", []) - for cell in sub_cells: - if cell["cell_id"] in updates: - cell["text"] = updates[cell["cell_id"]] - cell["status"] = "edited" - sub_updated += 1 - sub_word["cells"] = sub_cells - await update_session_db(sub_id, word_result=sub_word) - if sub_id in _cache: - _cache[sub_id]["word_result"] = sub_word - - total_updated = updated_count + sub_updated - logger.info(f"Reconstruction saved for session {session_id}: " - f"{updated_count} main + {sub_updated} sub-session cells updated") - - return { - "session_id": session_id, - "updated": total_updated, - "main_updated": updated_count, - "sub_updated": sub_updated, - } - - 
-@router.get("/sessions/{session_id}/reconstruction/fabric-json") -async def get_fabric_json(session_id: str): - """Return cell grid as Fabric.js-compatible JSON for the canvas editor.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - word_result = session.get("word_result") - if not word_result: - raise HTTPException(status_code=400, detail="No word result found") - - cells = list(word_result.get("cells", [])) - img_w = word_result.get("image_width", 800) - img_h = word_result.get("image_height", 600) - - # Merge sub-session cells at box positions - subs = await get_sub_sessions(session_id) - if subs: - column_result = session.get("column_result") or {} - zones = column_result.get("zones") or [] - box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")] - - for sub in subs: - sub_session = await get_session_db(sub["id"]) - if not sub_session: - continue - sub_word = sub_session.get("word_result") - if not sub_word or not sub_word.get("cells"): - continue - - bi = sub.get("box_index", 0) - if bi < len(box_zones): - box = box_zones[bi]["box"] - box_y, box_x = box["y"], box["x"] - else: - box_y, box_x = 0, 0 - - for cell in sub_word["cells"]: - cell_copy = dict(cell) - cell_copy["cell_id"] = f"box{bi}_{cell_copy.get('cell_id', '')}" - cell_copy["source"] = f"box_{bi}" - bbox = cell_copy.get("bbox_px", {}) - if bbox: - bbox = dict(bbox) - bbox["x"] = bbox.get("x", 0) + box_x - bbox["y"] = bbox.get("y", 0) + box_y - cell_copy["bbox_px"] = bbox - cells.append(cell_copy) - - from services.layout_reconstruction_service import cells_to_fabric_json - fabric_json = cells_to_fabric_json(cells, img_w, img_h) - - return fabric_json - - -# --------------------------------------------------------------------------- -# Vocab entries merged + PDF/DOCX export -# --------------------------------------------------------------------------- - 
-@router.get("/sessions/{session_id}/vocab-entries/merged") -async def get_merged_vocab_entries(session_id: str): - """Return vocab entries from main session + all sub-sessions, sorted by Y position.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - word_result = session.get("word_result") or {} - entries = list(word_result.get("vocab_entries") or word_result.get("entries") or []) - - for e in entries: - e.setdefault("source", "main") - - subs = await get_sub_sessions(session_id) - if subs: - column_result = session.get("column_result") or {} - zones = column_result.get("zones") or [] - box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")] - - for sub in subs: - sub_session = await get_session_db(sub["id"]) - if not sub_session: - continue - sub_word = sub_session.get("word_result") or {} - sub_entries = sub_word.get("vocab_entries") or sub_word.get("entries") or [] - - bi = sub.get("box_index", 0) - box_y = 0 - if bi < len(box_zones): - box_y = box_zones[bi]["box"]["y"] - - for e in sub_entries: - e_copy = dict(e) - e_copy["source"] = f"box_{bi}" - e_copy["source_y"] = box_y - entries.append(e_copy) - - def _sort_key(e): - if e.get("source", "main") == "main": - return e.get("row_index", 0) * 100 - return e.get("source_y", 0) * 100 + e.get("row_index", 0) - - entries.sort(key=_sort_key) - - return { - "session_id": session_id, - "entries": entries, - "total": len(entries), - "sources": list(set(e.get("source", "main") for e in entries)), - } - - -@router.get("/sessions/{session_id}/reconstruction/export/pdf") -async def export_reconstruction_pdf(session_id: str): - """Export the reconstructed cell grid as a PDF table.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - word_result = session.get("word_result") - if not word_result: - raise 
HTTPException(status_code=400, detail="No word result found") - - cells = word_result.get("cells", []) - columns_used = word_result.get("columns_used", []) - grid_shape = word_result.get("grid_shape", {}) - n_rows = grid_shape.get("rows", 0) - n_cols = grid_shape.get("cols", 0) - - # Build table data: rows x columns - table_data: list[list[str]] = [] - header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)] - if not header: - header = [f"Col {i}" for i in range(n_cols)] - table_data.append(header) - - for r in range(n_rows): - row_texts = [] - for ci in range(n_cols): - cell_id = f"R{r:02d}_C{ci}" - cell = next((c for c in cells if c.get("cell_id") == cell_id), None) - row_texts.append(cell.get("text", "") if cell else "") - table_data.append(row_texts) - - try: - from reportlab.lib.pagesizes import A4 - from reportlab.lib import colors - from reportlab.platypus import SimpleDocTemplate, Table, TableStyle - import io as _io - - buf = _io.BytesIO() - doc = SimpleDocTemplate(buf, pagesize=A4) - if not table_data or not table_data[0]: - raise HTTPException(status_code=400, detail="No data to export") - - t = Table(table_data) - t.setStyle(TableStyle([ - ('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#0d9488')), - ('TEXTCOLOR', (0, 0), (-1, 0), colors.white), - ('FONTSIZE', (0, 0), (-1, -1), 9), - ('GRID', (0, 0), (-1, -1), 0.5, colors.grey), - ('VALIGN', (0, 0), (-1, -1), 'TOP'), - ('WORDWRAP', (0, 0), (-1, -1), True), - ])) - doc.build([t]) - buf.seek(0) - - return StreamingResponse( - buf, - media_type="application/pdf", - headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.pdf"'}, - ) - except ImportError: - raise HTTPException(status_code=501, detail="reportlab not installed") - - -@router.get("/sessions/{session_id}/reconstruction/export/docx") -async def export_reconstruction_docx(session_id: str): - """Export the reconstructed cell grid as a DOCX table.""" - session = await 
get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - word_result = session.get("word_result") - if not word_result: - raise HTTPException(status_code=400, detail="No word result found") - - cells = word_result.get("cells", []) - columns_used = word_result.get("columns_used", []) - grid_shape = word_result.get("grid_shape", {}) - n_rows = grid_shape.get("rows", 0) - n_cols = grid_shape.get("cols", 0) - - try: - from docx import Document - from docx.shared import Pt - import io as _io - - doc = Document() - doc.add_heading(f'Rekonstruktion -- Session {session_id[:8]}', level=1) - - header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)] - if not header: - header = [f"Col {i}" for i in range(n_cols)] - - table = doc.add_table(rows=1 + n_rows, cols=max(n_cols, 1)) - table.style = 'Table Grid' - - for ci, h in enumerate(header): - table.rows[0].cells[ci].text = h - - for r in range(n_rows): - for ci in range(n_cols): - cell_id = f"R{r:02d}_C{ci}" - cell = next((c for c in cells if c.get("cell_id") == cell_id), None) - table.rows[r + 1].cells[ci].text = cell.get("text", "") if cell else "" - - buf = _io.BytesIO() - doc.save(buf) - buf.seek(0) - - return StreamingResponse( - buf, - media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document", - headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.docx"'}, - ) - except ImportError: - raise HTTPException(status_code=501, detail="python-docx not installed") +# Backward-compat shim -- module moved to ocr/pipeline/reconstruction.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.reconstruction") diff --git a/klausur-service/backend/ocr_pipeline_regression.py b/klausur-service/backend/ocr_pipeline_regression.py index 5c8ff89..b5147ed 100644 --- a/klausur-service/backend/ocr_pipeline_regression.py 
+++ b/klausur-service/backend/ocr_pipeline_regression.py @@ -1,22 +1,4 @@ -""" -OCR Pipeline Regression Tests — barrel re-export. - -All implementation split into: - ocr_pipeline_regression_helpers — DB persistence, snapshot, comparison - ocr_pipeline_regression_endpoints — FastAPI routes - -Lizenz: Apache 2.0 -DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. -""" - -# Helpers (used by grid_editor_api_grid.py) -from ocr_pipeline_regression_helpers import ( # noqa: F401 - _init_regression_table, - _persist_regression_run, - _extract_cells_for_comparison, - _build_reference_snapshot, - compare_grids, -) - -# Endpoints (router used by ocr_pipeline_api.py) -from ocr_pipeline_regression_endpoints import router # noqa: F401 +# Backward-compat shim -- module moved to ocr/pipeline/regression.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.regression") diff --git a/klausur-service/backend/ocr_pipeline_regression_endpoints.py b/klausur-service/backend/ocr_pipeline_regression_endpoints.py index a91d6d6..375d135 100644 --- a/klausur-service/backend/ocr_pipeline_regression_endpoints.py +++ b/klausur-service/backend/ocr_pipeline_regression_endpoints.py @@ -1,421 +1,4 @@ -""" -OCR Pipeline Regression Endpoints — FastAPI routes for ground truth and regression. - -Extracted from ocr_pipeline_regression.py for modularity. - -Lizenz: Apache 2.0 -DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
-""" - -import json -import logging -import time -from typing import Any, Dict, Optional - -from fastapi import APIRouter, HTTPException, Query - -from grid_editor_api import _build_grid_core -from ocr_pipeline_session_store import ( - get_session_db, - list_ground_truth_sessions_db, - update_session_db, -) -from ocr_pipeline_regression_helpers import ( - _build_reference_snapshot, - _init_regression_table, - _persist_regression_run, - compare_grids, - get_pool, -) - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["regression"]) - - -# --------------------------------------------------------------------------- -# Endpoints -# --------------------------------------------------------------------------- - -@router.post("/sessions/{session_id}/mark-ground-truth") -async def mark_ground_truth( - session_id: str, - pipeline: Optional[str] = Query(None, description="Pipeline used: kombi, pipeline, paddle-direct"), -): - """Save the current build-grid result as ground-truth reference.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - grid_result = session.get("grid_editor_result") - if not grid_result or not grid_result.get("zones"): - raise HTTPException( - status_code=400, - detail="No grid_editor_result found. 
Run build-grid first.", - ) - - # Auto-detect pipeline from word_result if not provided - if not pipeline: - wr = session.get("word_result") or {} - engine = wr.get("ocr_engine", "") - if engine in ("kombi", "rapid_kombi"): - pipeline = "kombi" - elif engine == "paddle_direct": - pipeline = "paddle-direct" - else: - pipeline = "pipeline" - - reference = _build_reference_snapshot(grid_result, pipeline=pipeline) - - # Merge into existing ground_truth JSONB - gt = session.get("ground_truth") or {} - gt["build_grid_reference"] = reference - await update_session_db(session_id, ground_truth=gt, current_step=11) - - # Compare with auto-snapshot if available (shows what the user corrected) - auto_snapshot = gt.get("auto_grid_snapshot") - correction_diff = None - if auto_snapshot: - correction_diff = compare_grids(auto_snapshot, reference) - - logger.info( - "Ground truth marked for session %s: %d cells (corrections: %s)", - session_id, - len(reference["cells"]), - correction_diff["summary"] if correction_diff else "no auto-snapshot", - ) - - return { - "status": "ok", - "session_id": session_id, - "cells_saved": len(reference["cells"]), - "summary": reference["summary"], - "correction_diff": correction_diff, - } - - -@router.delete("/sessions/{session_id}/mark-ground-truth") -async def unmark_ground_truth(session_id: str): - """Remove the ground-truth reference from a session.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - gt = session.get("ground_truth") or {} - if "build_grid_reference" not in gt: - raise HTTPException(status_code=404, detail="No ground truth reference found") - - del gt["build_grid_reference"] - await update_session_db(session_id, ground_truth=gt) - - logger.info("Ground truth removed for session %s", session_id) - return {"status": "ok", "session_id": session_id} - - -@router.get("/sessions/{session_id}/correction-diff") -async def 
get_correction_diff(session_id: str): - """Compare automatic OCR grid with manually corrected ground truth. - - Returns a diff showing exactly which cells the user corrected, - broken down by col_type (english, german, ipa, etc.). - """ - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - gt = session.get("ground_truth") or {} - auto_snapshot = gt.get("auto_grid_snapshot") - reference = gt.get("build_grid_reference") - - if not auto_snapshot: - raise HTTPException( - status_code=404, - detail="No auto_grid_snapshot found. Re-run build-grid to create one.", - ) - if not reference: - raise HTTPException( - status_code=404, - detail="No ground truth reference found. Mark as ground truth first.", - ) - - diff = compare_grids(auto_snapshot, reference) - - # Enrich with per-col_type breakdown - col_type_stats: Dict[str, Dict[str, int]] = {} - for cell_diff in diff.get("cell_diffs", []): - if cell_diff["type"] != "text_change": - continue - # Find col_type from reference cells - cell_id = cell_diff["cell_id"] - ref_cell = next( - (c for c in reference.get("cells", []) if c["cell_id"] == cell_id), - None, - ) - ct = ref_cell.get("col_type", "unknown") if ref_cell else "unknown" - if ct not in col_type_stats: - col_type_stats[ct] = {"total": 0, "corrected": 0} - col_type_stats[ct]["corrected"] += 1 - - # Count total cells per col_type from reference - for cell in reference.get("cells", []): - ct = cell.get("col_type", "unknown") - if ct not in col_type_stats: - col_type_stats[ct] = {"total": 0, "corrected": 0} - col_type_stats[ct]["total"] += 1 - - # Calculate accuracy per col_type - for ct, stats in col_type_stats.items(): - total = stats["total"] - corrected = stats["corrected"] - stats["accuracy_pct"] = round((total - corrected) / total * 100, 1) if total > 0 else 100.0 - - diff["col_type_breakdown"] = col_type_stats - - return diff - - -@router.get("/ground-truth-sessions") 
-async def list_ground_truth_sessions(): - """List all sessions that have a ground-truth reference.""" - sessions = await list_ground_truth_sessions_db() - - result = [] - for s in sessions: - gt = s.get("ground_truth") or {} - ref = gt.get("build_grid_reference", {}) - result.append({ - "session_id": s["id"], - "name": s.get("name", ""), - "filename": s.get("filename", ""), - "document_category": s.get("document_category"), - "pipeline": ref.get("pipeline"), - "saved_at": ref.get("saved_at"), - "summary": ref.get("summary", {}), - }) - - return {"sessions": result, "count": len(result)} - - -@router.post("/sessions/{session_id}/regression/run") -async def run_single_regression(session_id: str): - """Re-run build_grid for a single session and compare to ground truth.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - gt = session.get("ground_truth") or {} - reference = gt.get("build_grid_reference") - if not reference: - raise HTTPException( - status_code=400, - detail="No ground truth reference found for this session", - ) - - # Re-compute grid without persisting - try: - new_result = await _build_grid_core(session_id, session) - except (ValueError, Exception) as e: - return { - "session_id": session_id, - "name": session.get("name", ""), - "status": "error", - "error": str(e), - } - - new_snapshot = _build_reference_snapshot(new_result) - diff = compare_grids(reference, new_snapshot) - - logger.info( - "Regression test session %s: %s (%d structural, %d cell diffs)", - session_id, diff["status"], - diff["summary"]["structural_changes"], - sum(v for k, v in diff["summary"].items() if k != "structural_changes"), - ) - - return { - "session_id": session_id, - "name": session.get("name", ""), - "status": diff["status"], - "diff": diff, - "reference_summary": reference.get("summary", {}), - "current_summary": new_snapshot.get("summary", {}), - } - - 
-@router.post("/regression/run") -async def run_all_regressions( - triggered_by: str = Query("manual", description="Who triggered: manual, script, ci"), -): - """Re-run build_grid for ALL ground-truth sessions and compare.""" - start_time = time.monotonic() - sessions = await list_ground_truth_sessions_db() - - if not sessions: - return { - "status": "pass", - "message": "No ground truth sessions found", - "results": [], - "summary": {"total": 0, "passed": 0, "failed": 0, "errors": 0}, - } - - results = [] - passed = 0 - failed = 0 - errors = 0 - - for s in sessions: - session_id = s["id"] - gt = s.get("ground_truth") or {} - reference = gt.get("build_grid_reference") - if not reference: - continue - - # Re-load full session (list query may not include all JSONB fields) - full_session = await get_session_db(session_id) - if not full_session: - results.append({ - "session_id": session_id, - "name": s.get("name", ""), - "status": "error", - "error": "Session not found during re-load", - }) - errors += 1 - continue - - try: - new_result = await _build_grid_core(session_id, full_session) - except (ValueError, Exception) as e: - results.append({ - "session_id": session_id, - "name": s.get("name", ""), - "status": "error", - "error": str(e), - }) - errors += 1 - continue - - new_snapshot = _build_reference_snapshot(new_result) - diff = compare_grids(reference, new_snapshot) - - entry = { - "session_id": session_id, - "name": s.get("name", ""), - "status": diff["status"], - "diff_summary": diff["summary"], - "reference_summary": reference.get("summary", {}), - "current_summary": new_snapshot.get("summary", {}), - } - - # Include full diffs only for failures (keep response compact) - if diff["status"] == "fail": - entry["structural_diffs"] = diff["structural_diffs"] - entry["cell_diffs"] = diff["cell_diffs"] - failed += 1 - else: - passed += 1 - - results.append(entry) - - overall = "pass" if failed == 0 and errors == 0 else "fail" - duration_ms = int((time.monotonic() - 
start_time) * 1000) - - summary = { - "total": len(results), - "passed": passed, - "failed": failed, - "errors": errors, - } - - logger.info( - "Regression suite: %s — %d passed, %d failed, %d errors (of %d) in %dms", - overall, passed, failed, errors, len(results), duration_ms, - ) - - # Persist to DB - run_id = await _persist_regression_run( - status=overall, - summary=summary, - results=results, - duration_ms=duration_ms, - triggered_by=triggered_by, - ) - - return { - "status": overall, - "run_id": run_id, - "duration_ms": duration_ms, - "results": results, - "summary": summary, - } - - -@router.get("/regression/history") -async def get_regression_history( - limit: int = Query(20, ge=1, le=100), -): - """Get recent regression run history from the database.""" - try: - await _init_regression_table() - pool = await get_pool() - async with pool.acquire() as conn: - rows = await conn.fetch( - """ - SELECT id, run_at, status, total, passed, failed, errors, - duration_ms, triggered_by - FROM regression_runs - ORDER BY run_at DESC - LIMIT $1 - """, - limit, - ) - return { - "runs": [ - { - "id": str(row["id"]), - "run_at": row["run_at"].isoformat() if row["run_at"] else None, - "status": row["status"], - "total": row["total"], - "passed": row["passed"], - "failed": row["failed"], - "errors": row["errors"], - "duration_ms": row["duration_ms"], - "triggered_by": row["triggered_by"], - } - for row in rows - ], - "count": len(rows), - } - except Exception as e: - logger.warning("Failed to fetch regression history: %s", e) - return {"runs": [], "count": 0, "error": str(e)} - - -@router.get("/regression/history/{run_id}") -async def get_regression_run_detail(run_id: str): - """Get detailed results of a specific regression run.""" - try: - await _init_regression_table() - pool = await get_pool() - async with pool.acquire() as conn: - row = await conn.fetchrow( - "SELECT * FROM regression_runs WHERE id = $1", - run_id, - ) - if not row: - raise HTTPException(status_code=404, 
detail="Run not found") - return { - "id": str(row["id"]), - "run_at": row["run_at"].isoformat() if row["run_at"] else None, - "status": row["status"], - "total": row["total"], - "passed": row["passed"], - "failed": row["failed"], - "errors": row["errors"], - "duration_ms": row["duration_ms"], - "triggered_by": row["triggered_by"], - "results": json.loads(row["results"]) if row["results"] else [], - } - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) +# Backward-compat shim -- module moved to ocr/pipeline/regression_endpoints.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.regression_endpoints") diff --git a/klausur-service/backend/ocr_pipeline_regression_helpers.py b/klausur-service/backend/ocr_pipeline_regression_helpers.py index b8e0a57..865dc35 100644 --- a/klausur-service/backend/ocr_pipeline_regression_helpers.py +++ b/klausur-service/backend/ocr_pipeline_regression_helpers.py @@ -1,207 +1,4 @@ -""" -OCR Pipeline Regression Helpers — DB persistence, snapshot building, comparison. - -Extracted from ocr_pipeline_regression.py for modularity. - -Lizenz: Apache 2.0 -DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
-""" - -import json -import logging -import os -import uuid -from datetime import datetime, timezone -from typing import Any, Dict, List, Optional - -from ocr_pipeline_session_store import get_pool - -logger = logging.getLogger(__name__) - - -# --------------------------------------------------------------------------- -# DB persistence for regression runs -# --------------------------------------------------------------------------- - -async def _init_regression_table(): - """Ensure regression_runs table exists (idempotent).""" - pool = await get_pool() - async with pool.acquire() as conn: - migration_path = os.path.join( - os.path.dirname(__file__), - "migrations/008_regression_runs.sql", - ) - if os.path.exists(migration_path): - with open(migration_path, "r") as f: - sql = f.read() - await conn.execute(sql) - - -async def _persist_regression_run( - status: str, - summary: dict, - results: list, - duration_ms: int, - triggered_by: str = "manual", -) -> str: - """Save a regression run to the database. 
Returns the run ID.""" - try: - await _init_regression_table() - pool = await get_pool() - run_id = str(uuid.uuid4()) - async with pool.acquire() as conn: - await conn.execute( - """ - INSERT INTO regression_runs - (id, status, total, passed, failed, errors, duration_ms, results, triggered_by) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8::jsonb, $9) - """, - run_id, - status, - summary.get("total", 0), - summary.get("passed", 0), - summary.get("failed", 0), - summary.get("errors", 0), - duration_ms, - json.dumps(results), - triggered_by, - ) - logger.info("Regression run %s persisted: %s", run_id, status) - return run_id - except Exception as e: - logger.warning("Failed to persist regression run: %s", e) - return "" - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -def _extract_cells_for_comparison(grid_result: dict) -> List[Dict[str, Any]]: - """Extract a flat list of cells from a grid_editor_result for comparison. - - Only keeps fields relevant for comparison: cell_id, row_index, col_index, - col_type, text. Ignores confidence, bbox, word_boxes, duration, is_bold. 
- """ - cells = [] - for zone in grid_result.get("zones", []): - for cell in zone.get("cells", []): - cells.append({ - "cell_id": cell.get("cell_id", ""), - "row_index": cell.get("row_index"), - "col_index": cell.get("col_index"), - "col_type": cell.get("col_type", ""), - "text": cell.get("text", ""), - }) - return cells - - -def _build_reference_snapshot( - grid_result: dict, - pipeline: Optional[str] = None, -) -> dict: - """Build a ground-truth reference snapshot from a grid_editor_result.""" - cells = _extract_cells_for_comparison(grid_result) - - total_zones = len(grid_result.get("zones", [])) - total_columns = sum(len(z.get("columns", [])) for z in grid_result.get("zones", [])) - total_rows = sum(len(z.get("rows", [])) for z in grid_result.get("zones", [])) - - snapshot = { - "saved_at": datetime.now(timezone.utc).isoformat(), - "version": 1, - "pipeline": pipeline, - "summary": { - "total_zones": total_zones, - "total_columns": total_columns, - "total_rows": total_rows, - "total_cells": len(cells), - }, - "cells": cells, - } - return snapshot - - -def compare_grids(reference: dict, current: dict) -> dict: - """Compare a reference grid snapshot with a newly computed one. 
- - Returns a diff report with: - - status: "pass" or "fail" - - structural_diffs: changes in zone/row/column counts - - cell_diffs: list of individual cell changes - """ - ref_summary = reference.get("summary", {}) - cur_summary = current.get("summary", {}) - - structural_diffs = [] - for key in ("total_zones", "total_columns", "total_rows", "total_cells"): - ref_val = ref_summary.get(key, 0) - cur_val = cur_summary.get(key, 0) - if ref_val != cur_val: - structural_diffs.append({ - "field": key, - "reference": ref_val, - "current": cur_val, - }) - - # Build cell lookup by cell_id - ref_cells = {c["cell_id"]: c for c in reference.get("cells", [])} - cur_cells = {c["cell_id"]: c for c in current.get("cells", [])} - - cell_diffs: List[Dict[str, Any]] = [] - - # Check for missing cells (in reference but not in current) - for cell_id in ref_cells: - if cell_id not in cur_cells: - cell_diffs.append({ - "type": "cell_missing", - "cell_id": cell_id, - "reference_text": ref_cells[cell_id].get("text", ""), - }) - - # Check for added cells (in current but not in reference) - for cell_id in cur_cells: - if cell_id not in ref_cells: - cell_diffs.append({ - "type": "cell_added", - "cell_id": cell_id, - "current_text": cur_cells[cell_id].get("text", ""), - }) - - # Check for changes in shared cells - for cell_id in ref_cells: - if cell_id not in cur_cells: - continue - ref_cell = ref_cells[cell_id] - cur_cell = cur_cells[cell_id] - - if ref_cell.get("text", "") != cur_cell.get("text", ""): - cell_diffs.append({ - "type": "text_change", - "cell_id": cell_id, - "reference": ref_cell.get("text", ""), - "current": cur_cell.get("text", ""), - }) - - if ref_cell.get("col_type", "") != cur_cell.get("col_type", ""): - cell_diffs.append({ - "type": "col_type_change", - "cell_id": cell_id, - "reference": ref_cell.get("col_type", ""), - "current": cur_cell.get("col_type", ""), - }) - - status = "pass" if not structural_diffs and not cell_diffs else "fail" - - return { - "status": status, - 
"structural_diffs": structural_diffs, - "cell_diffs": cell_diffs, - "summary": { - "structural_changes": len(structural_diffs), - "cells_missing": sum(1 for d in cell_diffs if d["type"] == "cell_missing"), - "cells_added": sum(1 for d in cell_diffs if d["type"] == "cell_added"), - "text_changes": sum(1 for d in cell_diffs if d["type"] == "text_change"), - "col_type_changes": sum(1 for d in cell_diffs if d["type"] == "col_type_change"), - }, - } +# Backward-compat shim -- module moved to ocr/pipeline/regression_helpers.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.regression_helpers") diff --git a/klausur-service/backend/ocr_pipeline_reprocess.py b/klausur-service/backend/ocr_pipeline_reprocess.py index 62d68fa..5d5dd49 100644 --- a/klausur-service/backend/ocr_pipeline_reprocess.py +++ b/klausur-service/backend/ocr_pipeline_reprocess.py @@ -1,94 +1,4 @@ -""" -OCR Pipeline Reprocess Endpoint. - -POST /sessions/{session_id}/reprocess — clear downstream + restart from step. - -Lizenz: Apache 2.0 -DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. -""" - -import logging -from typing import Any, Dict - -from fastapi import APIRouter, HTTPException, Request - -from ocr_pipeline_common import _cache -from ocr_pipeline_session_store import get_session_db, update_session_db - -logger = logging.getLogger(__name__) - -router = APIRouter(tags=["ocr-pipeline"]) - - -@router.post("/sessions/{session_id}/reprocess") -async def reprocess_session(session_id: str, request: Request): - """Re-run pipeline from a specific step, clearing downstream data. 
- - Body: {"from_step": 5} (1-indexed step number) - - Pipeline order: Orientation(1) -> Deskew(2) -> Dewarp(3) -> Crop(4) -> Columns(5) -> - Rows(6) -> Words(7) -> LLM-Review(8) -> Reconstruction(9) -> Validation(10) - - Clears downstream results: - - from_step <= 1: orientation_result + all downstream - - from_step <= 2: deskew_result + all downstream - - from_step <= 3: dewarp_result + all downstream - - from_step <= 4: crop_result + all downstream - - from_step <= 5: column_result, row_result, word_result - - from_step <= 6: row_result, word_result - - from_step <= 7: word_result (cells, vocab_entries) - - from_step <= 8: word_result.llm_review only - """ - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - body = await request.json() - from_step = body.get("from_step", 1) - if not isinstance(from_step, int) or from_step < 1 or from_step > 10: - raise HTTPException(status_code=400, detail="from_step must be between 1 and 10") - - update_kwargs: Dict[str, Any] = {"current_step": from_step} - - # Clear downstream data based on from_step - # New pipeline order: Orient(2) -> Deskew(3) -> Dewarp(4) -> Crop(5) -> - # Columns(6) -> Rows(7) -> Words(8) -> LLM(9) -> Recon(10) -> GT(11) - if from_step <= 8: - update_kwargs["word_result"] = None - elif from_step == 9: - # Only clear LLM review from word_result - word_result = session.get("word_result") - if word_result: - word_result.pop("llm_review", None) - word_result.pop("llm_corrections", None) - update_kwargs["word_result"] = word_result - - if from_step <= 7: - update_kwargs["row_result"] = None - if from_step <= 6: - update_kwargs["column_result"] = None - if from_step <= 4: - update_kwargs["crop_result"] = None - if from_step <= 3: - update_kwargs["dewarp_result"] = None - if from_step <= 2: - update_kwargs["deskew_result"] = None - if from_step <= 1: - update_kwargs["orientation_result"] = None - - await 
update_session_db(session_id, **update_kwargs) - - # Also clear cache - if session_id in _cache: - for key in list(update_kwargs.keys()): - if key != "current_step": - _cache[session_id][key] = update_kwargs[key] - _cache[session_id]["current_step"] = from_step - - logger.info(f"Session {session_id} reprocessing from step {from_step}") - - return { - "session_id": session_id, - "from_step": from_step, - "cleared": [k for k in update_kwargs if k != "current_step"], - } +# Backward-compat shim -- module moved to ocr/pipeline/reprocess.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.reprocess") diff --git a/klausur-service/backend/ocr_pipeline_rows.py b/klausur-service/backend/ocr_pipeline_rows.py index 9fb9915..6d387cc 100644 --- a/klausur-service/backend/ocr_pipeline_rows.py +++ b/klausur-service/backend/ocr_pipeline_rows.py @@ -1,348 +1,4 @@ -""" -OCR Pipeline - Row Detection Endpoints. - -Extracted from ocr_pipeline_api.py. -Handles row detection (auto + manual) and row ground truth. - -Lizenz: Apache 2.0 -DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
-""" - -import logging -import time -from datetime import datetime -from typing import Any, Dict, List, Optional - -import cv2 -import numpy as np -from fastapi import APIRouter, HTTPException - -from cv_vocab_pipeline import ( - create_ocr_image, - detect_column_geometry, - detect_row_geometry, -) -from ocr_pipeline_common import ( - _cache, - _load_session_to_cache, - _get_cached, - _append_pipeline_log, - ManualRowsRequest, - RowGroundTruthRequest, -) -from ocr_pipeline_session_store import ( - get_session_db, - update_session_db, -) - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) - - -# --------------------------------------------------------------------------- -# Helper: Box-exclusion overlay (used by rows overlay and columns overlay) -# --------------------------------------------------------------------------- - -def _draw_box_exclusion_overlay( - img: np.ndarray, - zones: List[Dict], - *, - label: str = "BOX — separat verarbeitet", -) -> None: - """Draw red semi-transparent rectangles over box zones (in-place). - - Reusable for columns, rows, and words overlays. 
- """ - for zone in zones: - if zone.get("zone_type") != "box" or not zone.get("box"): - continue - box = zone["box"] - bx, by = box["x"], box["y"] - bw, bh = box["width"], box["height"] - - # Red semi-transparent fill (~25 %) - box_overlay = img.copy() - cv2.rectangle(box_overlay, (bx, by), (bx + bw, by + bh), (0, 0, 200), -1) - cv2.addWeighted(box_overlay, 0.25, img, 0.75, 0, img) - - # Border - cv2.rectangle(img, (bx, by), (bx + bw, by + bh), (0, 0, 200), 2) - - # Label - cv2.putText(img, label, (bx + 10, by + bh - 10), - cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2) - - -# --------------------------------------------------------------------------- -# Row Detection Endpoints -# --------------------------------------------------------------------------- - -@router.post("/sessions/{session_id}/rows") -async def detect_rows(session_id: str): - """Run row detection on the cropped (or dewarped) image using horizontal gap analysis.""" - if session_id not in _cache: - await _load_session_to_cache(session_id) - cached = _get_cached(session_id) - - dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") - if dewarped_bgr is None: - raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before row detection") - - t0 = time.time() - - # Try to reuse cached word_dicts and inv from column detection - word_dicts = cached.get("_word_dicts") - inv = cached.get("_inv") - content_bounds = cached.get("_content_bounds") - - if word_dicts is None or inv is None or content_bounds is None: - # Not cached — run column geometry to get intermediates - ocr_img = create_ocr_image(dewarped_bgr) - geo_result = detect_column_geometry(ocr_img, dewarped_bgr) - if geo_result is None: - raise HTTPException(status_code=400, detail="Column geometry detection failed — cannot detect rows") - _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result - cached["_word_dicts"] = word_dicts - cached["_inv"] = 
inv - cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y) - else: - left_x, right_x, top_y, bottom_y = content_bounds - - # Read zones from column_result to exclude box regions - session = await get_session_db(session_id) - column_result = (session or {}).get("column_result") or {} - is_sub_session = bool((session or {}).get("parent_session_id")) - - # Sub-sessions (box crops): use word-grouping instead of gap-based - # row detection. Box images are small with complex internal layouts - # (headings, sub-columns) where the horizontal projection approach - # merges rows. Word-grouping directly clusters words by Y proximity, - # which is more robust for these cases. - if is_sub_session and word_dicts: - from cv_layout import _build_rows_from_word_grouping - rows = _build_rows_from_word_grouping( - word_dicts, left_x, right_x, top_y, bottom_y, - right_x - left_x, bottom_y - top_y, - ) - logger.info(f"OCR Pipeline: sub-session {session_id}: word-grouping found {len(rows)} rows") - else: - zones = column_result.get("zones") or [] # zones can be None for sub-sessions - - # Collect box y-ranges for filtering. - # Use border_thickness to shrink the exclusion zone: the border pixels - # belong visually to the box frame, but text rows above/below the box - # may overlap with the border area and must not be clipped. 
- box_ranges = [] # [(y_start, y_end)] - box_ranges_inner = [] # [(y_start + border, y_end - border)] for row filtering - for zone in zones: - if zone.get("zone_type") == "box" and zone.get("box"): - box = zone["box"] - bt = max(box.get("border_thickness", 0), 5) # minimum 5px margin - box_ranges.append((box["y"], box["y"] + box["height"])) - # Inner range: shrink by border thickness so boundary rows aren't excluded - box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt)) - - if box_ranges and inv is not None: - # Combined-image approach: strip box regions from inv image, - # run row detection on the combined image, then remap y-coords back. - content_strips = [] # [(y_start, y_end)] in absolute coords - # Build content strips by subtracting box inner ranges from [top_y, bottom_y]. - # Using inner ranges means the border area is included in the content - # strips, so the last row above a box isn't clipped by the border. - sorted_boxes = sorted(box_ranges_inner, key=lambda r: r[0]) - strip_start = top_y - for by_start, by_end in sorted_boxes: - if by_start > strip_start: - content_strips.append((strip_start, by_start)) - strip_start = max(strip_start, by_end) - if strip_start < bottom_y: - content_strips.append((strip_start, bottom_y)) - - # Filter to strips with meaningful height - content_strips = [(ys, ye) for ys, ye in content_strips if ye - ys >= 20] - - if content_strips: - # Stack content strips vertically - inv_strips = [inv[ys:ye, :] for ys, ye in content_strips] - combined_inv = np.vstack(inv_strips) - - # Filter word_dicts to only include words from content strips - combined_words = [] - cum_y = 0 - strip_offsets = [] # (combined_y_start, strip_height, abs_y_start) - for ys, ye in content_strips: - h = ye - ys - strip_offsets.append((cum_y, h, ys)) - for w in word_dicts: - w_abs_y = w['top'] + top_y # word y is relative to content top - w_center = w_abs_y + w['height'] / 2 - if ys <= w_center < ye: - # Remap to combined coordinates - 
w_copy = dict(w) - w_copy['top'] = cum_y + (w_abs_y - ys) - combined_words.append(w_copy) - cum_y += h - - # Run row detection on combined image - combined_h = combined_inv.shape[0] - rows = detect_row_geometry( - combined_inv, combined_words, left_x, right_x, 0, combined_h, - ) - - # Remap y-coordinates back to absolute page coords - def _combined_y_to_abs(cy: int) -> int: - for c_start, s_h, abs_start in strip_offsets: - if cy < c_start + s_h: - return abs_start + (cy - c_start) - last_c, last_h, last_abs = strip_offsets[-1] - return last_abs + last_h - - for r in rows: - abs_y = _combined_y_to_abs(r.y) - abs_y_end = _combined_y_to_abs(r.y + r.height) - r.y = abs_y - r.height = abs_y_end - abs_y - else: - rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y) - else: - # No boxes — standard row detection - rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y) - - duration = time.time() - t0 - - # Assign zone_index based on which content zone each row falls in - # Build content zone list with indices - zones = column_result.get("zones") or [] - content_zones = [(i, z) for i, z in enumerate(zones) if z.get("zone_type") == "content"] if zones else [] - - # Build serializable result (exclude words to keep payload small) - rows_data = [] - for r in rows: - # Determine zone_index - zone_idx = 0 - row_center_y = r.y + r.height / 2 - for zi, zone in content_zones: - zy = zone["y"] - zh = zone["height"] - if zy <= row_center_y < zy + zh: - zone_idx = zi - break - - rd = { - "index": r.index, - "x": r.x, - "y": r.y, - "width": r.width, - "height": r.height, - "word_count": r.word_count, - "row_type": r.row_type, - "gap_before": r.gap_before, - "zone_index": zone_idx, - } - rows_data.append(rd) - - type_counts = {} - for r in rows: - type_counts[r.row_type] = type_counts.get(r.row_type, 0) + 1 - - row_result = { - "rows": rows_data, - "summary": type_counts, - "total_rows": len(rows), - "duration_seconds": round(duration, 2), - 
} - - # Persist to DB — also invalidate word_result since rows changed - await update_session_db( - session_id, - row_result=row_result, - word_result=None, - current_step=7, - ) - - cached["row_result"] = row_result - cached.pop("word_result", None) - - logger.info(f"OCR Pipeline: rows session {session_id}: " - f"{len(rows)} rows detected ({duration:.2f}s): {type_counts}") - - content_rows = sum(1 for r in rows if r.row_type == "content") - avg_height = round(sum(r.height for r in rows) / len(rows)) if rows else 0 - await _append_pipeline_log(session_id, "rows", { - "total_rows": len(rows), - "content_rows": content_rows, - "artifact_rows_removed": type_counts.get("header", 0) + type_counts.get("footer", 0), - "avg_row_height_px": avg_height, - }, duration_ms=int(duration * 1000)) - - return { - "session_id": session_id, - **row_result, - } - - -@router.post("/sessions/{session_id}/rows/manual") -async def set_manual_rows(session_id: str, req: ManualRowsRequest): - """Override detected rows with manual definitions.""" - row_result = { - "rows": req.rows, - "total_rows": len(req.rows), - "duration_seconds": 0, - "method": "manual", - } - - await update_session_db(session_id, row_result=row_result, word_result=None) - - if session_id in _cache: - _cache[session_id]["row_result"] = row_result - _cache[session_id].pop("word_result", None) - - logger.info(f"OCR Pipeline: manual rows session {session_id}: " - f"{len(req.rows)} rows set") - - return {"session_id": session_id, **row_result} - - -@router.post("/sessions/{session_id}/ground-truth/rows") -async def save_row_ground_truth(session_id: str, req: RowGroundTruthRequest): - """Save ground truth feedback for the row detection step.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - ground_truth = session.get("ground_truth") or {} - gt = { - "is_correct": req.is_correct, - "corrected_rows": req.corrected_rows, - 
"notes": req.notes, - "saved_at": datetime.utcnow().isoformat(), - "row_result": session.get("row_result"), - } - ground_truth["rows"] = gt - - await update_session_db(session_id, ground_truth=ground_truth) - - if session_id in _cache: - _cache[session_id]["ground_truth"] = ground_truth - - return {"session_id": session_id, "ground_truth": gt} - - -@router.get("/sessions/{session_id}/ground-truth/rows") -async def get_row_ground_truth(session_id: str): - """Retrieve saved ground truth for row detection.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - ground_truth = session.get("ground_truth") or {} - rows_gt = ground_truth.get("rows") - if not rows_gt: - raise HTTPException(status_code=404, detail="No row ground truth saved") - - return { - "session_id": session_id, - "rows_gt": rows_gt, - "rows_auto": session.get("row_result"), - } +# Backward-compat shim -- module moved to ocr/pipeline/rows.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.rows") diff --git a/klausur-service/backend/ocr_pipeline_session_store.py b/klausur-service/backend/ocr_pipeline_session_store.py index dc80c55..6efa9ec 100644 --- a/klausur-service/backend/ocr_pipeline_session_store.py +++ b/klausur-service/backend/ocr_pipeline_session_store.py @@ -1,388 +1,4 @@ -""" -OCR Pipeline Session Store - PostgreSQL persistence for OCR pipeline sessions. - -Replaces in-memory storage with database persistence. -See migrations/002_ocr_pipeline_sessions.sql for schema. 
-""" - -import os -import uuid -import logging -import json -from typing import Optional, List, Dict, Any - -import asyncpg - -logger = logging.getLogger(__name__) - -# Database configuration (same as vocab_session_store) -DATABASE_URL = os.getenv( - "DATABASE_URL", - "postgresql://breakpilot:breakpilot@postgres:5432/breakpilot_db" -) - -# Connection pool (initialized lazily) -_pool: Optional[asyncpg.Pool] = None - - -async def get_pool() -> asyncpg.Pool: - """Get or create the database connection pool.""" - global _pool - if _pool is None: - _pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10) - return _pool - - -async def init_ocr_pipeline_tables(): - """Initialize OCR pipeline tables if they don't exist.""" - pool = await get_pool() - async with pool.acquire() as conn: - tables_exist = await conn.fetchval(""" - SELECT EXISTS ( - SELECT FROM information_schema.tables - WHERE table_name = 'ocr_pipeline_sessions' - ) - """) - - if not tables_exist: - logger.info("Creating OCR pipeline tables...") - migration_path = os.path.join( - os.path.dirname(__file__), - "migrations/002_ocr_pipeline_sessions.sql" - ) - if os.path.exists(migration_path): - with open(migration_path, "r") as f: - sql = f.read() - await conn.execute(sql) - logger.info("OCR pipeline tables created successfully") - else: - logger.warning(f"Migration file not found: {migration_path}") - else: - logger.debug("OCR pipeline tables already exist") - - # Ensure new columns exist (idempotent ALTER TABLE) - await conn.execute(""" - ALTER TABLE ocr_pipeline_sessions - ADD COLUMN IF NOT EXISTS clean_png BYTEA, - ADD COLUMN IF NOT EXISTS handwriting_removal_meta JSONB, - ADD COLUMN IF NOT EXISTS doc_type VARCHAR(50), - ADD COLUMN IF NOT EXISTS doc_type_result JSONB, - ADD COLUMN IF NOT EXISTS document_category VARCHAR(50), - ADD COLUMN IF NOT EXISTS pipeline_log JSONB, - ADD COLUMN IF NOT EXISTS oriented_png BYTEA, - ADD COLUMN IF NOT EXISTS cropped_png BYTEA, - ADD COLUMN IF NOT EXISTS 
orientation_result JSONB, - ADD COLUMN IF NOT EXISTS crop_result JSONB, - ADD COLUMN IF NOT EXISTS parent_session_id UUID REFERENCES ocr_pipeline_sessions(id) ON DELETE CASCADE, - ADD COLUMN IF NOT EXISTS box_index INT, - ADD COLUMN IF NOT EXISTS grid_editor_result JSONB, - ADD COLUMN IF NOT EXISTS structure_result JSONB, - ADD COLUMN IF NOT EXISTS document_group_id UUID, - ADD COLUMN IF NOT EXISTS page_number INT - """) - - # Index for document group lookups - await conn.execute(""" - CREATE INDEX IF NOT EXISTS idx_ocr_sessions_document_group - ON ocr_pipeline_sessions (document_group_id) - WHERE document_group_id IS NOT NULL - """) - - -# ============================================================================= -# SESSION CRUD -# ============================================================================= - -async def create_session_db( - session_id: str, - name: str, - filename: str, - original_png: bytes, - parent_session_id: Optional[str] = None, - box_index: Optional[int] = None, - document_group_id: Optional[str] = None, - page_number: Optional[int] = None, -) -> Dict[str, Any]: - """Create a new OCR pipeline session. - - Args: - parent_session_id: If set, this is a sub-session for a box region. - box_index: 0-based index of the box this sub-session represents. - document_group_id: Groups multi-page uploads into one document. - page_number: 1-based page index within the document group. 
- """ - pool = await get_pool() - parent_uuid = uuid.UUID(parent_session_id) if parent_session_id else None - group_uuid = uuid.UUID(document_group_id) if document_group_id else None - async with pool.acquire() as conn: - row = await conn.fetchrow(""" - INSERT INTO ocr_pipeline_sessions ( - id, name, filename, original_png, status, current_step, - parent_session_id, box_index, document_group_id, page_number - ) VALUES ($1, $2, $3, $4, 'active', 1, $5, $6, $7, $8) - RETURNING id, name, filename, status, current_step, - orientation_result, crop_result, - deskew_result, dewarp_result, column_result, row_result, - word_result, ground_truth, auto_shear_degrees, - doc_type, doc_type_result, - document_category, pipeline_log, - grid_editor_result, structure_result, - parent_session_id, box_index, - document_group_id, page_number, - created_at, updated_at - """, uuid.UUID(session_id), name, filename, original_png, - parent_uuid, box_index, group_uuid, page_number) - - return _row_to_dict(row) - - -async def get_session_db(session_id: str) -> Optional[Dict[str, Any]]: - """Get session metadata (without images).""" - pool = await get_pool() - async with pool.acquire() as conn: - row = await conn.fetchrow(""" - SELECT id, name, filename, status, current_step, - orientation_result, crop_result, - deskew_result, dewarp_result, column_result, row_result, - word_result, ground_truth, auto_shear_degrees, - doc_type, doc_type_result, - document_category, pipeline_log, - grid_editor_result, structure_result, - parent_session_id, box_index, - document_group_id, page_number, - created_at, updated_at - FROM ocr_pipeline_sessions WHERE id = $1 - """, uuid.UUID(session_id)) - - if row: - return _row_to_dict(row) - return None - - -async def get_session_image(session_id: str, image_type: str) -> Optional[bytes]: - """Load a single image (BYTEA) from the session.""" - column_map = { - "original": "original_png", - "oriented": "oriented_png", - "cropped": "cropped_png", - "deskewed": 
"deskewed_png", - "binarized": "binarized_png", - "dewarped": "dewarped_png", - "clean": "clean_png", - } - column = column_map.get(image_type) - if not column: - return None - - pool = await get_pool() - async with pool.acquire() as conn: - return await conn.fetchval( - f"SELECT {column} FROM ocr_pipeline_sessions WHERE id = $1", - uuid.UUID(session_id) - ) - - -async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any]]: - """Update session fields dynamically.""" - pool = await get_pool() - - fields = [] - values = [] - param_idx = 1 - - allowed_fields = { - 'name', 'filename', 'status', 'current_step', - 'original_png', 'oriented_png', 'cropped_png', - 'deskewed_png', 'binarized_png', 'dewarped_png', - 'clean_png', 'handwriting_removal_meta', - 'orientation_result', 'crop_result', - 'deskew_result', 'dewarp_result', 'column_result', 'row_result', - 'word_result', 'ground_truth', 'auto_shear_degrees', - 'doc_type', 'doc_type_result', - 'document_category', 'pipeline_log', - 'grid_editor_result', 'structure_result', - 'parent_session_id', 'box_index', - 'document_group_id', 'page_number', - } - - jsonb_fields = {'orientation_result', 'crop_result', 'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'handwriting_removal_meta', 'doc_type_result', 'pipeline_log', 'grid_editor_result', 'structure_result'} - - for key, value in kwargs.items(): - if key in allowed_fields: - fields.append(f"{key} = ${param_idx}") - if key in jsonb_fields and value is not None and not isinstance(value, str): - value = json.dumps(value) - values.append(value) - param_idx += 1 - - if not fields: - return await get_session_db(session_id) - - # Always update updated_at - fields.append(f"updated_at = NOW()") - - values.append(uuid.UUID(session_id)) - - async with pool.acquire() as conn: - row = await conn.fetchrow(f""" - UPDATE ocr_pipeline_sessions - SET {', '.join(fields)} - WHERE id = ${param_idx} - RETURNING id, name, 
filename, status, current_step, - orientation_result, crop_result, - deskew_result, dewarp_result, column_result, row_result, - word_result, ground_truth, auto_shear_degrees, - doc_type, doc_type_result, - document_category, pipeline_log, - grid_editor_result, structure_result, - parent_session_id, box_index, - document_group_id, page_number, - created_at, updated_at - """, *values) - - if row: - return _row_to_dict(row) - return None - - -async def list_sessions_db( - limit: int = 50, - include_sub_sessions: bool = False, -) -> List[Dict[str, Any]]: - """List sessions (metadata only, no images). - - By default, sub-sessions (those with parent_session_id) are excluded. - Pass include_sub_sessions=True to include them. - """ - pool = await get_pool() - async with pool.acquire() as conn: - where = "" if include_sub_sessions else "WHERE parent_session_id IS NULL AND (status IS NULL OR status != 'split')" - rows = await conn.fetch(f""" - SELECT id, name, filename, status, current_step, - document_category, doc_type, - parent_session_id, box_index, - document_group_id, page_number, - created_at, updated_at, - ground_truth - FROM ocr_pipeline_sessions - {where} - ORDER BY created_at DESC - LIMIT $1 - """, limit) - - results = [] - for row in rows: - d = _row_to_dict(row) - # Derive is_ground_truth flag from JSONB, then drop the heavy field - gt = d.pop("ground_truth", None) or {} - d["is_ground_truth"] = bool(gt.get("build_grid_reference")) - results.append(d) - return results - - -async def get_sub_sessions(parent_session_id: str) -> List[Dict[str, Any]]: - """Get all sub-sessions for a parent session, ordered by box_index.""" - pool = await get_pool() - async with pool.acquire() as conn: - rows = await conn.fetch(""" - SELECT id, name, filename, status, current_step, - document_category, doc_type, - parent_session_id, box_index, - document_group_id, page_number, - created_at, updated_at - FROM ocr_pipeline_sessions - WHERE parent_session_id = $1 - ORDER BY box_index 
ASC - """, uuid.UUID(parent_session_id)) - - return [_row_to_dict(row) for row in rows] - - -async def get_document_group_sessions(document_group_id: str) -> List[Dict[str, Any]]: - """Get all sessions in a document group, ordered by page_number.""" - pool = await get_pool() - async with pool.acquire() as conn: - rows = await conn.fetch(""" - SELECT id, name, filename, status, current_step, - document_category, doc_type, - parent_session_id, box_index, - document_group_id, page_number, - created_at, updated_at - FROM ocr_pipeline_sessions - WHERE document_group_id = $1 - ORDER BY page_number ASC - """, uuid.UUID(document_group_id)) - - return [_row_to_dict(row) for row in rows] - - -async def list_ground_truth_sessions_db() -> List[Dict[str, Any]]: - """List sessions that have a build_grid_reference in ground_truth.""" - pool = await get_pool() - async with pool.acquire() as conn: - rows = await conn.fetch(""" - SELECT id, name, filename, status, current_step, - document_category, doc_type, - ground_truth, - parent_session_id, box_index, - created_at, updated_at - FROM ocr_pipeline_sessions - WHERE ground_truth IS NOT NULL - AND ground_truth::text LIKE '%build_grid_reference%' - AND parent_session_id IS NULL - ORDER BY created_at DESC - """) - - return [_row_to_dict(row) for row in rows] - - -async def delete_session_db(session_id: str) -> bool: - """Delete a session.""" - pool = await get_pool() - async with pool.acquire() as conn: - result = await conn.execute(""" - DELETE FROM ocr_pipeline_sessions WHERE id = $1 - """, uuid.UUID(session_id)) - return result == "DELETE 1" - - -async def delete_all_sessions_db() -> int: - """Delete all sessions. Returns number of deleted rows.""" - pool = await get_pool() - async with pool.acquire() as conn: - result = await conn.execute("DELETE FROM ocr_pipeline_sessions") - # result is e.g. 
"DELETE 5" - try: - return int(result.split()[-1]) - except (ValueError, IndexError): - return 0 - - -# ============================================================================= -# HELPER -# ============================================================================= - -def _row_to_dict(row: asyncpg.Record) -> Dict[str, Any]: - """Convert asyncpg Record to JSON-serializable dict.""" - if row is None: - return {} - - result = dict(row) - - # UUID → string - for key in ['id', 'session_id', 'parent_session_id', 'document_group_id']: - if key in result and result[key] is not None: - result[key] = str(result[key]) - - # datetime → ISO string - for key in ['created_at', 'updated_at']: - if key in result and result[key] is not None: - result[key] = result[key].isoformat() - - # JSONB → parsed (asyncpg returns str for JSONB) - for key in ['orientation_result', 'crop_result', 'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'doc_type_result', 'pipeline_log', 'grid_editor_result', 'structure_result']: - if key in result and result[key] is not None: - if isinstance(result[key], str): - result[key] = json.loads(result[key]) - - return result +# Backward-compat shim -- module moved to ocr/pipeline/session_store.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.session_store") diff --git a/klausur-service/backend/ocr_pipeline_sessions.py b/klausur-service/backend/ocr_pipeline_sessions.py index ae3f771..d1f978f 100644 --- a/klausur-service/backend/ocr_pipeline_sessions.py +++ b/klausur-service/backend/ocr_pipeline_sessions.py @@ -1,20 +1,4 @@ -""" -OCR Pipeline Sessions API — barrel re-export. - -All implementation split into: - ocr_pipeline_sessions_crud — session CRUD, box sessions - ocr_pipeline_sessions_images — image serving, thumbnails, doc-type detection - -Lizenz: Apache 2.0 -DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
-""" - -from fastapi import APIRouter - -from ocr_pipeline_sessions_crud import router as _crud_router # noqa: F401 -from ocr_pipeline_sessions_images import router as _images_router # noqa: F401 - -# Composite router (used by ocr_pipeline_api.py) -router = APIRouter() -router.include_router(_crud_router) -router.include_router(_images_router) +# Backward-compat shim -- module moved to ocr/pipeline/sessions.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.sessions") diff --git a/klausur-service/backend/ocr_pipeline_sessions_crud.py b/klausur-service/backend/ocr_pipeline_sessions_crud.py index 19343d7..a507df6 100644 --- a/klausur-service/backend/ocr_pipeline_sessions_crud.py +++ b/klausur-service/backend/ocr_pipeline_sessions_crud.py @@ -1,449 +1,4 @@ -""" -OCR Pipeline Sessions CRUD — session create, read, update, delete, box sessions. - -Extracted from ocr_pipeline_sessions.py for modularity. - -Lizenz: Apache 2.0 -DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
-""" - -import logging -import uuid -from typing import Any, Dict, Optional - -import cv2 -import numpy as np -from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile - -from cv_vocab_pipeline import render_image_high_res, render_pdf_high_res -from ocr_pipeline_common import ( - VALID_DOCUMENT_CATEGORIES, - UpdateSessionRequest, - _cache, -) -from ocr_pipeline_session_store import ( - create_session_db, - delete_all_sessions_db, - delete_session_db, - get_session_db, - get_session_image, - get_sub_sessions, - list_sessions_db, - update_session_db, -) - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) - - -# --------------------------------------------------------------------------- -# Session Management Endpoints -# --------------------------------------------------------------------------- - -@router.get("/sessions") -async def list_sessions(include_sub_sessions: bool = False): - """List OCR pipeline sessions. - - By default, sub-sessions (box regions) are hidden. - Pass ?include_sub_sessions=true to show them. - """ - sessions = await list_sessions_db(include_sub_sessions=include_sub_sessions) - return {"sessions": sessions} - - -@router.post("/sessions") -async def create_session( - file: UploadFile = File(...), - name: Optional[str] = Form(None), -): - """Upload a PDF or image file and create a pipeline session. - - For multi-page PDFs (> 1 page), each page becomes its own session - grouped under a ``document_group_id``. The response includes a - ``pages`` array with one entry per page/session. 
- """ - file_data = await file.read() - filename = file.filename or "upload" - content_type = file.content_type or "" - - is_pdf = content_type == "application/pdf" or filename.lower().endswith(".pdf") - session_name = name or filename - - # --- Multi-page PDF handling --- - if is_pdf: - try: - import fitz # PyMuPDF - pdf_doc = fitz.open(stream=file_data, filetype="pdf") - page_count = pdf_doc.page_count - pdf_doc.close() - except Exception as e: - raise HTTPException(status_code=400, detail=f"Could not read PDF: {e}") - - if page_count > 1: - return await _create_multi_page_sessions( - file_data, filename, session_name, page_count, - ) - - # --- Single page (image or 1-page PDF) --- - session_id = str(uuid.uuid4()) - - try: - if is_pdf: - img_bgr = render_pdf_high_res(file_data, page_number=0, zoom=3.0) - else: - img_bgr = render_image_high_res(file_data) - except Exception as e: - raise HTTPException(status_code=400, detail=f"Could not process file: {e}") - - # Encode original as PNG bytes - success, png_buf = cv2.imencode(".png", img_bgr) - if not success: - raise HTTPException(status_code=500, detail="Failed to encode image") - - original_png = png_buf.tobytes() - - # Persist to DB - await create_session_db( - session_id=session_id, - name=session_name, - filename=filename, - original_png=original_png, - ) - - # Cache BGR array for immediate processing - _cache[session_id] = { - "id": session_id, - "filename": filename, - "name": session_name, - "original_bgr": img_bgr, - "oriented_bgr": None, - "cropped_bgr": None, - "deskewed_bgr": None, - "dewarped_bgr": None, - "orientation_result": None, - "crop_result": None, - "deskew_result": None, - "dewarp_result": None, - "ground_truth": {}, - "current_step": 1, - } - - logger.info(f"OCR Pipeline: created session {session_id} from {filename} " - f"({img_bgr.shape[1]}x{img_bgr.shape[0]})") - - return { - "session_id": session_id, - "filename": filename, - "name": session_name, - "image_width": img_bgr.shape[1], - 
"image_height": img_bgr.shape[0], - "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original", - } - - -async def _create_multi_page_sessions( - pdf_data: bytes, - filename: str, - base_name: str, - page_count: int, -) -> dict: - """Create one session per PDF page, grouped by document_group_id.""" - document_group_id = str(uuid.uuid4()) - pages = [] - - for page_idx in range(page_count): - session_id = str(uuid.uuid4()) - page_name = f"{base_name} — Seite {page_idx + 1}" - - try: - img_bgr = render_pdf_high_res(pdf_data, page_number=page_idx, zoom=3.0) - except Exception as e: - logger.warning(f"Failed to render PDF page {page_idx + 1}: {e}") - continue - - ok, png_buf = cv2.imencode(".png", img_bgr) - if not ok: - continue - page_png = png_buf.tobytes() - - await create_session_db( - session_id=session_id, - name=page_name, - filename=filename, - original_png=page_png, - document_group_id=document_group_id, - page_number=page_idx + 1, - ) - - _cache[session_id] = { - "id": session_id, - "filename": filename, - "name": page_name, - "original_bgr": img_bgr, - "oriented_bgr": None, - "cropped_bgr": None, - "deskewed_bgr": None, - "dewarped_bgr": None, - "orientation_result": None, - "crop_result": None, - "deskew_result": None, - "dewarp_result": None, - "ground_truth": {}, - "current_step": 1, - } - - h, w = img_bgr.shape[:2] - pages.append({ - "session_id": session_id, - "name": page_name, - "page_number": page_idx + 1, - "image_width": w, - "image_height": h, - "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original", - }) - - logger.info( - f"OCR Pipeline: created page session {session_id} " - f"(page {page_idx + 1}/{page_count}) from {filename} ({w}x{h})" - ) - - # Include session_id pointing to first page for backwards compatibility - # (frontends that expect a single session_id will navigate to page 1) - first_session_id = pages[0]["session_id"] if pages else None - - return { - "session_id": first_session_id, 
- "document_group_id": document_group_id, - "filename": filename, - "name": base_name, - "page_count": page_count, - "pages": pages, - } - - -@router.get("/sessions/{session_id}") -async def get_session_info(session_id: str): - """Get session info including deskew/dewarp/column results for step navigation.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - # Get image dimensions from original PNG - original_png = await get_session_image(session_id, "original") - if original_png: - arr = np.frombuffer(original_png, dtype=np.uint8) - img = cv2.imdecode(arr, cv2.IMREAD_COLOR) - img_w, img_h = img.shape[1], img.shape[0] if img is not None else (0, 0) - else: - img_w, img_h = 0, 0 - - result = { - "session_id": session["id"], - "filename": session.get("filename", ""), - "name": session.get("name", ""), - "image_width": img_w, - "image_height": img_h, - "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original", - "current_step": session.get("current_step", 1), - "document_category": session.get("document_category"), - "doc_type": session.get("doc_type"), - } - - if session.get("orientation_result"): - result["orientation_result"] = session["orientation_result"] - if session.get("crop_result"): - result["crop_result"] = session["crop_result"] - if session.get("deskew_result"): - result["deskew_result"] = session["deskew_result"] - if session.get("dewarp_result"): - result["dewarp_result"] = session["dewarp_result"] - if session.get("column_result"): - result["column_result"] = session["column_result"] - if session.get("row_result"): - result["row_result"] = session["row_result"] - if session.get("word_result"): - result["word_result"] = session["word_result"] - if session.get("doc_type_result"): - result["doc_type_result"] = session["doc_type_result"] - if session.get("structure_result"): - result["structure_result"] = session["structure_result"] - 
if session.get("grid_editor_result"): - # Include summary only to keep response small - gr = session["grid_editor_result"] - result["grid_editor_result"] = { - "summary": gr.get("summary", {}), - "zones_count": len(gr.get("zones", [])), - "edited": gr.get("edited", False), - } - if session.get("ground_truth"): - result["ground_truth"] = session["ground_truth"] - - # Box sub-session info (zone_type='box' from column detection — NOT page-split) - if session.get("parent_session_id"): - result["parent_session_id"] = session["parent_session_id"] - result["box_index"] = session.get("box_index") - else: - # Check for box sub-sessions (column detection creates these) - subs = await get_sub_sessions(session_id) - if subs: - result["sub_sessions"] = [ - {"id": s["id"], "name": s.get("name"), "box_index": s.get("box_index")} - for s in subs - ] - - return result - - -@router.put("/sessions/{session_id}") -async def update_session(session_id: str, req: UpdateSessionRequest): - """Update session name and/or document category.""" - kwargs: Dict[str, Any] = {} - if req.name is not None: - kwargs["name"] = req.name - if req.document_category is not None: - if req.document_category not in VALID_DOCUMENT_CATEGORIES: - raise HTTPException( - status_code=400, - detail=f"Invalid category '{req.document_category}'. 
Valid: {sorted(VALID_DOCUMENT_CATEGORIES)}", - ) - kwargs["document_category"] = req.document_category - if not kwargs: - raise HTTPException(status_code=400, detail="Nothing to update") - updated = await update_session_db(session_id, **kwargs) - if not updated: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - return {"session_id": session_id, **kwargs} - - -@router.delete("/sessions/{session_id}") -async def delete_session(session_id: str): - """Delete a session.""" - _cache.pop(session_id, None) - deleted = await delete_session_db(session_id) - if not deleted: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - return {"session_id": session_id, "deleted": True} - - -@router.delete("/sessions") -async def delete_all_sessions(): - """Delete ALL sessions (cleanup).""" - _cache.clear() - count = await delete_all_sessions_db() - return {"deleted_count": count} - - -@router.post("/sessions/{session_id}/create-box-sessions") -async def create_box_sessions(session_id: str): - """Create sub-sessions for each detected box region. - - Crops box regions from the cropped/dewarped image and creates - independent sub-sessions that can be processed through the pipeline. 
- """ - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - column_result = session.get("column_result") - if not column_result: - raise HTTPException(status_code=400, detail="Column detection must be completed first") - - zones = column_result.get("zones") or [] - box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")] - if not box_zones: - return {"session_id": session_id, "sub_sessions": [], "message": "No boxes detected"} - - # Check for existing sub-sessions - existing = await get_sub_sessions(session_id) - if existing: - return { - "session_id": session_id, - "sub_sessions": [{"id": s["id"], "box_index": s.get("box_index")} for s in existing], - "message": f"{len(existing)} sub-session(s) already exist", - } - - # Load base image - base_png = await get_session_image(session_id, "cropped") - if not base_png: - base_png = await get_session_image(session_id, "dewarped") - if not base_png: - raise HTTPException(status_code=400, detail="No base image available") - - arr = np.frombuffer(base_png, dtype=np.uint8) - img = cv2.imdecode(arr, cv2.IMREAD_COLOR) - if img is None: - raise HTTPException(status_code=500, detail="Failed to decode image") - - parent_name = session.get("name", "Session") - created = [] - - for i, zone in enumerate(box_zones): - box = zone["box"] - bx, by = box["x"], box["y"] - bw, bh = box["width"], box["height"] - - # Crop box region with small padding - pad = 5 - y1 = max(0, by - pad) - y2 = min(img.shape[0], by + bh + pad) - x1 = max(0, bx - pad) - x2 = min(img.shape[1], bx + bw + pad) - crop = img[y1:y2, x1:x2] - - # Encode as PNG - success, png_buf = cv2.imencode(".png", crop) - if not success: - logger.warning(f"Failed to encode box {i} crop for session {session_id}") - continue - - sub_id = str(uuid.uuid4()) - sub_name = f"{parent_name} — Box {i + 1}" - - await create_session_db( - session_id=sub_id, - name=sub_name, - 
filename=session.get("filename", "box-crop.png"), - original_png=png_buf.tobytes(), - parent_session_id=session_id, - box_index=i, - ) - - # Cache the BGR for immediate processing - # Promote original to cropped so column/row/word detection finds it - box_bgr = crop.copy() - _cache[sub_id] = { - "id": sub_id, - "filename": session.get("filename", "box-crop.png"), - "name": sub_name, - "parent_session_id": session_id, - "original_bgr": box_bgr, - "oriented_bgr": None, - "cropped_bgr": box_bgr, - "deskewed_bgr": None, - "dewarped_bgr": None, - "orientation_result": None, - "crop_result": None, - "deskew_result": None, - "dewarp_result": None, - "ground_truth": {}, - "current_step": 1, - } - - created.append({ - "id": sub_id, - "name": sub_name, - "box_index": i, - "box": box, - "image_width": crop.shape[1], - "image_height": crop.shape[0], - }) - - logger.info(f"Created box sub-session {sub_id} for session {session_id} " - f"(box {i}, {crop.shape[1]}x{crop.shape[0]})") - - return { - "session_id": session_id, - "sub_sessions": created, - "total": len(created), - } +# Backward-compat shim -- module moved to ocr/pipeline/sessions_crud.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.sessions_crud") diff --git a/klausur-service/backend/ocr_pipeline_sessions_images.py b/klausur-service/backend/ocr_pipeline_sessions_images.py index 79da448..6283e3f 100644 --- a/klausur-service/backend/ocr_pipeline_sessions_images.py +++ b/klausur-service/backend/ocr_pipeline_sessions_images.py @@ -1,176 +1,4 @@ -""" -OCR Pipeline Sessions Images — image serving, thumbnails, pipeline log, -categories, and document type detection. - -Extracted from ocr_pipeline_sessions.py for modularity. - -Lizenz: Apache 2.0 -DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
-""" - -import logging -import time -from typing import Any, Dict - -import cv2 -import numpy as np -from fastapi import APIRouter, HTTPException, Query -from fastapi.responses import Response - -from cv_vocab_pipeline import create_ocr_image, detect_document_type -from ocr_pipeline_common import ( - VALID_DOCUMENT_CATEGORIES, - _append_pipeline_log, - _cache, - _get_base_image_png, - _get_cached, - _load_session_to_cache, -) -from ocr_pipeline_overlays import render_overlay -from ocr_pipeline_session_store import ( - get_session_db, - get_session_image, - update_session_db, -) - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) - - -# --------------------------------------------------------------------------- -# Thumbnail & Log Endpoints -# --------------------------------------------------------------------------- - -@router.get("/sessions/{session_id}/thumbnail") -async def get_session_thumbnail(session_id: str, size: int = Query(default=80, ge=16, le=400)): - """Return a small thumbnail of the original image.""" - original_png = await get_session_image(session_id, "original") - if not original_png: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found or no image") - arr = np.frombuffer(original_png, dtype=np.uint8) - img = cv2.imdecode(arr, cv2.IMREAD_COLOR) - if img is None: - raise HTTPException(status_code=500, detail="Failed to decode image") - h, w = img.shape[:2] - scale = size / max(h, w) - new_w, new_h = int(w * scale), int(h * scale) - thumb = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA) - _, png_bytes = cv2.imencode(".png", thumb) - return Response(content=png_bytes.tobytes(), media_type="image/png", - headers={"Cache-Control": "public, max-age=3600"}) - - -@router.get("/sessions/{session_id}/pipeline-log") -async def get_pipeline_log(session_id: str): - """Get the pipeline execution log for a session.""" - session = await 
get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - return {"session_id": session_id, "pipeline_log": session.get("pipeline_log") or {"steps": []}} - - -@router.get("/categories") -async def list_categories(): - """List valid document categories.""" - return {"categories": sorted(VALID_DOCUMENT_CATEGORIES)} - - -# --------------------------------------------------------------------------- -# Image Endpoints -# --------------------------------------------------------------------------- - -@router.get("/sessions/{session_id}/image/{image_type}") -async def get_image(session_id: str, image_type: str): - """Serve session images: original, deskewed, dewarped, binarized, structure-overlay, columns-overlay, or rows-overlay.""" - valid_types = {"original", "oriented", "cropped", "deskewed", "dewarped", "binarized", "structure-overlay", "columns-overlay", "rows-overlay", "words-overlay", "clean"} - if image_type not in valid_types: - raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}") - - if image_type == "structure-overlay": - return await render_overlay("structure", session_id) - - if image_type == "columns-overlay": - return await render_overlay("columns", session_id) - - if image_type == "rows-overlay": - return await render_overlay("rows", session_id) - - if image_type == "words-overlay": - return await render_overlay("words", session_id) - - # Try cache first for fast serving - cached = _cache.get(session_id) - if cached: - png_key = f"{image_type}_png" if image_type != "original" else None - bgr_key = f"{image_type}_bgr" if image_type != "binarized" else None - - # For binarized, check if we have it cached as PNG - if image_type == "binarized" and cached.get("binarized_png"): - return Response(content=cached["binarized_png"], media_type="image/png") - - # Load from DB — for cropped/dewarped, fall back through the chain - if image_type in ("cropped", "dewarped"): - 
data = await _get_base_image_png(session_id) - else: - data = await get_session_image(session_id, image_type) - if not data: - raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet") - - return Response(content=data, media_type="image/png") - - -# --------------------------------------------------------------------------- -# Document Type Detection (between Dewarp and Columns) -# --------------------------------------------------------------------------- - -@router.post("/sessions/{session_id}/detect-type") -async def detect_type(session_id: str): - """Detect document type (vocab_table, full_text, generic_table). - - Should be called after crop (clean image available). - Falls back to dewarped if crop was skipped. - Stores result in session for frontend to decide pipeline flow. - """ - if session_id not in _cache: - await _load_session_to_cache(session_id) - cached = _get_cached(session_id) - - img_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") - if img_bgr is None: - raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first") - - t0 = time.time() - ocr_img = create_ocr_image(img_bgr) - result = detect_document_type(ocr_img, img_bgr) - duration = time.time() - t0 - - result_dict = { - "doc_type": result.doc_type, - "confidence": result.confidence, - "pipeline": result.pipeline, - "skip_steps": result.skip_steps, - "features": result.features, - "duration_seconds": round(duration, 2), - } - - # Persist to DB - await update_session_db( - session_id, - doc_type=result.doc_type, - doc_type_result=result_dict, - ) - - cached["doc_type_result"] = result_dict - - logger.info(f"OCR Pipeline: detect-type session {session_id}: " - f"{result.doc_type} (confidence={result.confidence}, {duration:.2f}s)") - - await _append_pipeline_log(session_id, "detect_type", { - "doc_type": result.doc_type, - "pipeline": result.pipeline, - "confidence": result.confidence, - 
**{k: v for k, v in (result.features or {}).items() if isinstance(v, (int, float, str, bool))}, - }, duration_ms=int(duration * 1000)) - - return {"session_id": session_id, **result_dict} +# Backward-compat shim -- module moved to ocr/pipeline/sessions_images.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.sessions_images") diff --git a/klausur-service/backend/ocr_pipeline_structure.py b/klausur-service/backend/ocr_pipeline_structure.py index 77a5b1b..5c697df 100644 --- a/klausur-service/backend/ocr_pipeline_structure.py +++ b/klausur-service/backend/ocr_pipeline_structure.py @@ -1,299 +1,4 @@ -""" -OCR Pipeline Structure Detection and Exclude Regions - -Detect document structure (boxes, zones, color regions, graphics) -and manage user-drawn exclude regions. -Extracted from ocr_pipeline_geometry.py for file-size compliance. -""" - -import logging -import time -from typing import Any, Dict, List - -import cv2 -import numpy as np -from fastapi import APIRouter, HTTPException -from pydantic import BaseModel - -from cv_box_detect import detect_boxes -from cv_color_detect import _COLOR_RANGES, _COLOR_HEX -from cv_graphic_detect import detect_graphic_elements -from ocr_pipeline_session_store import ( - get_session_db, - update_session_db, -) -from ocr_pipeline_common import ( - _cache, - _load_session_to_cache, - _get_cached, - _filter_border_ghost_words, -) - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) - - -# --------------------------------------------------------------------------- -# Structure Detection Endpoint -# --------------------------------------------------------------------------- - -@router.post("/sessions/{session_id}/detect-structure") -async def detect_structure(session_id: str): - """Detect document structure: boxes, zones, and color regions. 
- - Runs box detection (line + shading) and color analysis on the cropped - image. Returns structured JSON with all detected elements for the - structure visualization step. - """ - if session_id not in _cache: - await _load_session_to_cache(session_id) - cached = _get_cached(session_id) - - img_bgr = ( - cached.get("cropped_bgr") - if cached.get("cropped_bgr") is not None - else cached.get("dewarped_bgr") - ) - if img_bgr is None: - raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first") - - t0 = time.time() - h, w = img_bgr.shape[:2] - - # --- Content bounds from word result (if available) or full image --- - word_result = cached.get("word_result") - words: List[Dict] = [] - if word_result and word_result.get("cells"): - for cell in word_result["cells"]: - for wb in (cell.get("word_boxes") or []): - words.append(wb) - # Fallback: use raw OCR words if cell word_boxes are empty - if not words and word_result: - for key in ("raw_paddle_words_split", "raw_tesseract_words", "raw_paddle_words"): - raw = word_result.get(key, []) - if raw: - words = raw - logger.info("detect-structure: using %d words from %s (no cell word_boxes)", len(words), key) - break - # If no words yet, use image dimensions with small margin - if words: - content_x = max(0, min(int(wb["left"]) for wb in words)) - content_y = max(0, min(int(wb["top"]) for wb in words)) - content_r = min(w, max(int(wb["left"] + wb["width"]) for wb in words)) - content_b = min(h, max(int(wb["top"] + wb["height"]) for wb in words)) - content_w_px = content_r - content_x - content_h_px = content_b - content_y - else: - margin = int(min(w, h) * 0.03) - content_x, content_y = margin, margin - content_w_px = w - 2 * margin - content_h_px = h - 2 * margin - - # --- Box detection --- - boxes = detect_boxes( - img_bgr, - content_x=content_x, - content_w=content_w_px, - content_y=content_y, - content_h=content_h_px, - ) - - # --- Zone splitting --- - from cv_box_detect import split_page_into_zones 
as _split_zones - zones = _split_zones(content_x, content_y, content_w_px, content_h_px, boxes) - - # --- Color region sampling --- - # Sample background shading in each detected box - hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV) - box_colors = [] - for box in boxes: - # Sample the center region of each box - cy1 = box.y + box.height // 4 - cy2 = box.y + 3 * box.height // 4 - cx1 = box.x + box.width // 4 - cx2 = box.x + 3 * box.width // 4 - cy1 = max(0, min(cy1, h - 1)) - cy2 = max(0, min(cy2, h - 1)) - cx1 = max(0, min(cx1, w - 1)) - cx2 = max(0, min(cx2, w - 1)) - if cy2 > cy1 and cx2 > cx1: - roi_hsv = hsv[cy1:cy2, cx1:cx2] - med_h = float(np.median(roi_hsv[:, :, 0])) - med_s = float(np.median(roi_hsv[:, :, 1])) - med_v = float(np.median(roi_hsv[:, :, 2])) - if med_s > 15: - from cv_color_detect import _hue_to_color_name - bg_name = _hue_to_color_name(med_h) - bg_hex = _COLOR_HEX.get(bg_name, "#6b7280") - else: - bg_name = "gray" if med_v < 220 else "white" - bg_hex = "#6b7280" if bg_name == "gray" else "#ffffff" - else: - bg_name = "unknown" - bg_hex = "#6b7280" - box_colors.append({"color_name": bg_name, "color_hex": bg_hex}) - - # --- Color text detection overview --- - # Quick scan for colored text regions across the page - color_summary: Dict[str, int] = {} - for color_name, ranges in _COLOR_RANGES.items(): - mask = np.zeros((h, w), dtype=np.uint8) - for lower, upper in ranges: - mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper)) - pixel_count = int(np.sum(mask > 0)) - if pixel_count > 50: # minimum threshold - color_summary[color_name] = pixel_count - - # --- Graphic element detection --- - box_dicts = [ - {"x": b.x, "y": b.y, "w": b.width, "h": b.height} - for b in boxes - ] - graphics = detect_graphic_elements( - img_bgr, words, - detected_boxes=box_dicts, - ) - - # --- Filter border-ghost words from OCR result --- - ghost_count = 0 - if boxes and word_result: - ghost_count = _filter_border_ghost_words(word_result, boxes) - if ghost_count: - 
logger.info("detect-structure: removed %d border-ghost words", ghost_count) - await update_session_db(session_id, word_result=word_result) - cached["word_result"] = word_result - - duration = time.time() - t0 - - # Preserve user-drawn exclude regions from previous run - prev_sr = cached.get("structure_result") or {} - prev_exclude = prev_sr.get("exclude_regions", []) - - result_dict = { - "image_width": w, - "image_height": h, - "content_bounds": { - "x": content_x, "y": content_y, - "w": content_w_px, "h": content_h_px, - }, - "boxes": [ - { - "x": b.x, "y": b.y, "w": b.width, "h": b.height, - "confidence": b.confidence, - "border_thickness": b.border_thickness, - "bg_color_name": box_colors[i]["color_name"], - "bg_color_hex": box_colors[i]["color_hex"], - } - for i, b in enumerate(boxes) - ], - "zones": [ - { - "index": z.index, - "zone_type": z.zone_type, - "y": z.y, "h": z.height, - "x": z.x, "w": z.width, - } - for z in zones - ], - "graphics": [ - { - "x": g.x, "y": g.y, "w": g.width, "h": g.height, - "area": g.area, - "shape": g.shape, - "color_name": g.color_name, - "color_hex": g.color_hex, - "confidence": round(g.confidence, 2), - } - for g in graphics - ], - "exclude_regions": prev_exclude, - "color_pixel_counts": color_summary, - "has_words": len(words) > 0, - "word_count": len(words), - "border_ghosts_removed": ghost_count, - "duration_seconds": round(duration, 2), - } - - # Persist to session - await update_session_db(session_id, structure_result=result_dict) - cached["structure_result"] = result_dict - - logger.info("detect-structure session %s: %d boxes, %d zones, %d graphics, %.2fs", - session_id, len(boxes), len(zones), len(graphics), duration) - - return {"session_id": session_id, **result_dict} - - -# --------------------------------------------------------------------------- -# Exclude Regions -- user-drawn rectangles to exclude from OCR results -# --------------------------------------------------------------------------- - -class 
_ExcludeRegionIn(BaseModel): - x: int - y: int - w: int - h: int - label: str = "" - - -class _ExcludeRegionsBatchIn(BaseModel): - regions: list[_ExcludeRegionIn] - - -@router.put("/sessions/{session_id}/exclude-regions") -async def set_exclude_regions(session_id: str, body: _ExcludeRegionsBatchIn): - """Replace all exclude regions for a session. - - Regions are stored inside ``structure_result.exclude_regions``. - """ - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail="Session not found") - - sr = session.get("structure_result") or {} - sr["exclude_regions"] = [r.model_dump() for r in body.regions] - - # Invalidate grid so it rebuilds with new exclude regions - await update_session_db(session_id, structure_result=sr, grid_editor_result=None) - - # Update cache - if session_id in _cache: - _cache[session_id]["structure_result"] = sr - _cache[session_id].pop("grid_editor_result", None) - - return { - "session_id": session_id, - "exclude_regions": sr["exclude_regions"], - "count": len(sr["exclude_regions"]), - } - - -@router.delete("/sessions/{session_id}/exclude-regions/{region_index}") -async def delete_exclude_region(session_id: str, region_index: int): - """Remove a single exclude region by index.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail="Session not found") - - sr = session.get("structure_result") or {} - regions = sr.get("exclude_regions", []) - - if region_index < 0 or region_index >= len(regions): - raise HTTPException(status_code=404, detail="Region index out of range") - - removed = regions.pop(region_index) - sr["exclude_regions"] = regions - - # Invalidate grid so it rebuilds with new exclude regions - await update_session_db(session_id, structure_result=sr, grid_editor_result=None) - - if session_id in _cache: - _cache[session_id]["structure_result"] = sr - _cache[session_id].pop("grid_editor_result", None) - - return { - 
"session_id": session_id, - "removed": removed, - "remaining": len(regions), - } +# Backward-compat shim -- module moved to ocr/pipeline/structure.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.structure") diff --git a/klausur-service/backend/ocr_pipeline_validation.py b/klausur-service/backend/ocr_pipeline_validation.py index 3382a3f..08187e8 100644 --- a/klausur-service/backend/ocr_pipeline_validation.py +++ b/klausur-service/backend/ocr_pipeline_validation.py @@ -1,362 +1,4 @@ -""" -OCR Pipeline Validation — image detection, generation, validation save, -and handwriting removal endpoints. - -Extracted from ocr_pipeline_postprocess.py. - -Lizenz: Apache 2.0 -DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. -""" - -import json -import logging -import os -from datetime import datetime -from typing import Optional - -from fastapi import APIRouter, HTTPException -from pydantic import BaseModel - -from ocr_pipeline_session_store import ( - get_session_db, - get_session_image, - update_session_db, -) -from ocr_pipeline_common import ( - _cache, - RemoveHandwritingRequest, -) - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) - -# --------------------------------------------------------------------------- -# Pydantic Models -# --------------------------------------------------------------------------- - -STYLE_SUFFIXES = { - "educational": "educational illustration, textbook style, clear, colorful", - "cartoon": "cartoon, child-friendly, simple shapes", - "sketch": "pencil sketch, hand-drawn, black and white", - "clipart": "clipart, flat vector style, simple", - "realistic": "photorealistic, high detail", -} - - -class ValidationRequest(BaseModel): - notes: Optional[str] = None - score: Optional[int] = None - - -class GenerateImageRequest(BaseModel): - region_index: int - prompt: str - style: str = "educational" - - -# 
--------------------------------------------------------------------------- -# Image detection + generation -# --------------------------------------------------------------------------- - -@router.post("/sessions/{session_id}/reconstruction/detect-images") -async def detect_image_regions(session_id: str): - """Detect illustration/image regions in the original scan using VLM.""" - import base64 - import httpx - import re - - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - original_png = await get_session_image(session_id, "original") - if not original_png: - raise HTTPException(status_code=400, detail="No original image found") - - word_result = session.get("word_result") or {} - entries = word_result.get("vocab_entries") or word_result.get("entries") or [] - vocab_context = "" - if entries: - sample = entries[:10] - words = [f"{e.get('english', '')} / {e.get('german', '')}" for e in sample if e.get('english')] - if words: - vocab_context = f"\nContext: This is a vocabulary page with words like: {', '.join(words)}" - - ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434") - model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b") - - prompt = ( - "Analyze this scanned page. Find ALL illustration/image/picture regions " - "(NOT text, NOT table cells, NOT blank areas). " - "For each image region found, return its bounding box as percentage of page dimensions " - "and a short English description of what the image shows. " - "Reply with ONLY a JSON array like: " - '[{"x": 10, "y": 20, "w": 30, "h": 25, "description": "drawing of a cat"}] ' - "where x, y, w, h are percentages (0-100) of the page width/height. 
" - "If there are NO images on the page, return an empty array: []" - f"{vocab_context}" - ) - - img_b64 = base64.b64encode(original_png).decode("utf-8") - payload = { - "model": model, - "prompt": prompt, - "images": [img_b64], - "stream": False, - } - - try: - async with httpx.AsyncClient(timeout=120.0) as client: - resp = await client.post(f"{ollama_base}/api/generate", json=payload) - resp.raise_for_status() - text = resp.json().get("response", "") - - match = re.search(r'\[.*?\]', text, re.DOTALL) - if match: - raw_regions = json.loads(match.group(0)) - else: - raw_regions = [] - - regions = [] - for r in raw_regions: - regions.append({ - "bbox_pct": { - "x": max(0, min(100, float(r.get("x", 0)))), - "y": max(0, min(100, float(r.get("y", 0)))), - "w": max(1, min(100, float(r.get("w", 10)))), - "h": max(1, min(100, float(r.get("h", 10)))), - }, - "description": r.get("description", ""), - "prompt": r.get("description", ""), - "image_b64": None, - "style": "educational", - }) - - # Enrich prompts with nearby vocab context - if entries: - for region in regions: - ry = region["bbox_pct"]["y"] - rh = region["bbox_pct"]["h"] - nearby = [ - e for e in entries - if e.get("bbox") and abs(e["bbox"].get("y", 0) - ry) < rh + 10 - ] - if nearby: - en_words = [e.get("english", "") for e in nearby if e.get("english")] - de_words = [e.get("german", "") for e in nearby if e.get("german")] - if en_words or de_words: - context = f" (vocabulary context: {', '.join(en_words[:5])}" - if de_words: - context += f" / {', '.join(de_words[:5])}" - context += ")" - region["prompt"] = region["description"] + context - - ground_truth = session.get("ground_truth") or {} - validation = ground_truth.get("validation") or {} - validation["image_regions"] = regions - validation["detected_at"] = datetime.utcnow().isoformat() - ground_truth["validation"] = validation - await update_session_db(session_id, ground_truth=ground_truth) - - if session_id in _cache: - _cache[session_id]["ground_truth"] = 
ground_truth - - logger.info(f"Detected {len(regions)} image regions for session {session_id}") - - return {"regions": regions, "count": len(regions)} - - except httpx.ConnectError: - logger.warning(f"VLM not available at {ollama_base} for image detection") - return {"regions": [], "count": 0, "error": "VLM not available"} - except Exception as e: - logger.error(f"Image detection failed for {session_id}: {e}") - return {"regions": [], "count": 0, "error": str(e)} - - -@router.post("/sessions/{session_id}/reconstruction/generate-image") -async def generate_image_for_region(session_id: str, req: GenerateImageRequest): - """Generate a replacement image for a detected region using mflux.""" - import httpx - - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - ground_truth = session.get("ground_truth") or {} - validation = ground_truth.get("validation") or {} - regions = validation.get("image_regions") or [] - - if req.region_index < 0 or req.region_index >= len(regions): - raise HTTPException(status_code=400, detail=f"Invalid region_index {req.region_index}, have {len(regions)} regions") - - mflux_url = os.getenv("MFLUX_URL", "http://host.docker.internal:8095") - style_suffix = STYLE_SUFFIXES.get(req.style, STYLE_SUFFIXES["educational"]) - full_prompt = f"{req.prompt}, {style_suffix}" - - region = regions[req.region_index] - bbox = region["bbox_pct"] - aspect = bbox["w"] / max(bbox["h"], 1) - if aspect > 1.3: - width, height = 768, 512 - elif aspect < 0.7: - width, height = 512, 768 - else: - width, height = 512, 512 - - try: - async with httpx.AsyncClient(timeout=300.0) as client: - resp = await client.post(f"{mflux_url}/generate", json={ - "prompt": full_prompt, - "width": width, - "height": height, - "steps": 4, - }) - resp.raise_for_status() - data = resp.json() - image_b64 = data.get("image_b64") - - if not image_b64: - return {"image_b64": None, "success": False, 
"error": "No image returned"} - - regions[req.region_index]["image_b64"] = image_b64 - regions[req.region_index]["prompt"] = req.prompt - regions[req.region_index]["style"] = req.style - validation["image_regions"] = regions - ground_truth["validation"] = validation - await update_session_db(session_id, ground_truth=ground_truth) - - if session_id in _cache: - _cache[session_id]["ground_truth"] = ground_truth - - logger.info(f"Generated image for session {session_id} region {req.region_index}") - return {"image_b64": image_b64, "success": True} - - except httpx.ConnectError: - logger.warning(f"mflux-service not available at {mflux_url}") - return {"image_b64": None, "success": False, "error": f"mflux-service not available at {mflux_url}"} - except Exception as e: - logger.error(f"Image generation failed for {session_id}: {e}") - return {"image_b64": None, "success": False, "error": str(e)} - - -# --------------------------------------------------------------------------- -# Validation save/get -# --------------------------------------------------------------------------- - -@router.post("/sessions/{session_id}/reconstruction/validate") -async def save_validation(session_id: str, req: ValidationRequest): - """Save final validation results for step 8.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - ground_truth = session.get("ground_truth") or {} - validation = ground_truth.get("validation") or {} - validation["validated_at"] = datetime.utcnow().isoformat() - validation["notes"] = req.notes - validation["score"] = req.score - ground_truth["validation"] = validation - - await update_session_db(session_id, ground_truth=ground_truth, current_step=11) - - if session_id in _cache: - _cache[session_id]["ground_truth"] = ground_truth - - logger.info(f"Validation saved for session {session_id}: score={req.score}") - - return {"session_id": session_id, "validation": 
validation} - - -@router.get("/sessions/{session_id}/reconstruction/validation") -async def get_validation(session_id: str): - """Retrieve saved validation data for step 8.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - ground_truth = session.get("ground_truth") or {} - validation = ground_truth.get("validation") - - return { - "session_id": session_id, - "validation": validation, - "word_result": session.get("word_result"), - } - - -# --------------------------------------------------------------------------- -# Remove handwriting -# --------------------------------------------------------------------------- - -@router.post("/sessions/{session_id}/remove-handwriting") -async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingRequest): - """Remove handwriting from a session image using inpainting.""" - import time as _time - - from services.handwriting_detection import detect_handwriting - from services.inpainting_service import inpaint_image, dilate_mask as _dilate_mask, InpaintingMethod, image_to_png - - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - t0 = _time.monotonic() - - # 1. Determine source image - source = req.use_source - if source == "auto": - deskewed = await get_session_image(session_id, "deskewed") - source = "deskewed" if deskewed else "original" - - image_bytes = await get_session_image(session_id, source) - if not image_bytes: - raise HTTPException(status_code=404, detail=f"Source image '{source}' not available") - - # 2. Detect handwriting mask - detection = detect_handwriting(image_bytes, target_ink=req.target_ink) - - # 3. 
Convert mask to PNG bytes and dilate - import io - from PIL import Image as _PILImage - mask_img = _PILImage.fromarray(detection.mask) - mask_buf = io.BytesIO() - mask_img.save(mask_buf, format="PNG") - mask_bytes = mask_buf.getvalue() - - if req.dilation > 0: - mask_bytes = _dilate_mask(mask_bytes, iterations=req.dilation) - - # 4. Inpaint - method_map = { - "telea": InpaintingMethod.OPENCV_TELEA, - "ns": InpaintingMethod.OPENCV_NS, - "auto": InpaintingMethod.AUTO, - } - inpaint_method = method_map.get(req.method, InpaintingMethod.AUTO) - - result = inpaint_image(image_bytes, mask_bytes, method=inpaint_method) - if not result.success: - raise HTTPException(status_code=500, detail="Inpainting failed") - - elapsed_ms = int((_time.monotonic() - t0) * 1000) - - meta = { - "method_used": result.method_used.value if hasattr(result.method_used, "value") else str(result.method_used), - "handwriting_ratio": round(detection.handwriting_ratio, 4), - "detection_confidence": round(detection.confidence, 4), - "target_ink": req.target_ink, - "dilation": req.dilation, - "source_image": source, - "processing_time_ms": elapsed_ms, - } - - # 5. 
Persist clean image - clean_png_bytes = image_to_png(result.image) - await update_session_db(session_id, clean_png=clean_png_bytes, handwriting_removal_meta=meta) - - return { - **meta, - "image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean", - "session_id": session_id, - } +# Backward-compat shim -- module moved to ocr/pipeline/validation.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.validation") diff --git a/klausur-service/backend/ocr_pipeline_words.py b/klausur-service/backend/ocr_pipeline_words.py index a1d0f87..189962d 100644 --- a/klausur-service/backend/ocr_pipeline_words.py +++ b/klausur-service/backend/ocr_pipeline_words.py @@ -1,185 +1,4 @@ -""" -OCR Pipeline Words — composite router for word detection, PaddleOCR direct, -and ground truth endpoints. - -Split into sub-modules: - ocr_pipeline_words_detect — main detect_words endpoint (Step 7) - ocr_pipeline_words_stream — SSE streaming generators - -This barrel module contains the PaddleOCR direct endpoint and ground truth -endpoints, and assembles all word-related routers. - -Lizenz: Apache 2.0 -DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
-""" - -import logging -import time -from datetime import datetime -from typing import Any, Dict, List, Optional - -import cv2 -import numpy as np -from fastapi import APIRouter, HTTPException -from pydantic import BaseModel - -from cv_words_first import build_grid_from_words -from ocr_pipeline_session_store import ( - get_session_db, - get_session_image, - update_session_db, -) -from ocr_pipeline_common import ( - _cache, - _append_pipeline_log, -) -from ocr_pipeline_words_detect import router as _detect_router - -logger = logging.getLogger(__name__) - -_local_router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) - - -# --------------------------------------------------------------------------- -# Pydantic models -# --------------------------------------------------------------------------- - -class WordGroundTruthRequest(BaseModel): - is_correct: bool - corrected_entries: Optional[List[Dict[str, Any]]] = None - notes: Optional[str] = None - - -# --------------------------------------------------------------------------- -# PaddleOCR Direct Endpoint -# --------------------------------------------------------------------------- - -@_local_router.post("/sessions/{session_id}/paddle-direct") -async def paddle_direct(session_id: str): - """Run PaddleOCR on the preprocessed image and build a word grid directly.""" - img_png = await get_session_image(session_id, "cropped") - if not img_png: - img_png = await get_session_image(session_id, "dewarped") - if not img_png: - img_png = await get_session_image(session_id, "original") - if not img_png: - raise HTTPException(status_code=404, detail="No image found for this session") - - img_arr = np.frombuffer(img_png, dtype=np.uint8) - img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR) - if img_bgr is None: - raise HTTPException(status_code=400, detail="Failed to decode original image") - - img_h, img_w = img_bgr.shape[:2] - - from cv_ocr_engines import ocr_region_paddle - - t0 = time.time() - word_dicts = 
await ocr_region_paddle(img_bgr, region=None) - if not word_dicts: - raise HTTPException(status_code=400, detail="PaddleOCR returned no words") - - cells, columns_meta = build_grid_from_words(word_dicts, img_w, img_h) - duration = time.time() - t0 - - for cell in cells: - cell["ocr_engine"] = "paddle_direct" - - n_rows = len(set(c["row_index"] for c in cells)) if cells else 0 - n_cols = len(columns_meta) - col_types = {c.get("type") for c in columns_meta} - is_vocab = bool(col_types & {"column_en", "column_de"}) - - word_result = { - "cells": cells, - "grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)}, - "columns_used": columns_meta, - "layout": "vocab" if is_vocab else "generic", - "image_width": img_w, - "image_height": img_h, - "duration_seconds": round(duration, 2), - "ocr_engine": "paddle_direct", - "grid_method": "paddle_direct", - "summary": { - "total_cells": len(cells), - "non_empty_cells": sum(1 for c in cells if c.get("text")), - "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), - }, - } - - await update_session_db( - session_id, - word_result=word_result, - cropped_png=img_png, - current_step=8, - ) - - logger.info( - "paddle_direct session %s: %d cells (%d rows, %d cols) in %.2fs", - session_id, len(cells), n_rows, n_cols, duration, - ) - - await _append_pipeline_log(session_id, "paddle_direct", { - "total_cells": len(cells), - "non_empty_cells": word_result["summary"]["non_empty_cells"], - "ocr_engine": "paddle_direct", - }, duration_ms=int(duration * 1000)) - - return {"session_id": session_id, **word_result} - - -# --------------------------------------------------------------------------- -# Ground Truth Words Endpoints -# --------------------------------------------------------------------------- - -@_local_router.post("/sessions/{session_id}/ground-truth/words") -async def save_word_ground_truth(session_id: str, req: WordGroundTruthRequest): - """Save ground truth feedback for the word recognition 
step.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - ground_truth = session.get("ground_truth") or {} - gt = { - "is_correct": req.is_correct, - "corrected_entries": req.corrected_entries, - "notes": req.notes, - "saved_at": datetime.utcnow().isoformat(), - "word_result": session.get("word_result"), - } - ground_truth["words"] = gt - - await update_session_db(session_id, ground_truth=ground_truth) - - if session_id in _cache: - _cache[session_id]["ground_truth"] = ground_truth - - return {"session_id": session_id, "ground_truth": gt} - - -@_local_router.get("/sessions/{session_id}/ground-truth/words") -async def get_word_ground_truth(session_id: str): - """Retrieve saved ground truth for word recognition.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - ground_truth = session.get("ground_truth") or {} - words_gt = ground_truth.get("words") - if not words_gt: - raise HTTPException(status_code=404, detail="No word ground truth saved") - - return { - "session_id": session_id, - "words_gt": words_gt, - "words_auto": session.get("word_result"), - } - - -# --------------------------------------------------------------------------- -# Composite router -# --------------------------------------------------------------------------- - -router = APIRouter() -router.include_router(_detect_router) -router.include_router(_local_router) +# Backward-compat shim -- module moved to ocr/pipeline/words.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.words") diff --git a/klausur-service/backend/ocr_pipeline_words_detect.py b/klausur-service/backend/ocr_pipeline_words_detect.py index b70cff3..6824efe 100644 --- a/klausur-service/backend/ocr_pipeline_words_detect.py +++ 
b/klausur-service/backend/ocr_pipeline_words_detect.py @@ -1,393 +1,4 @@ -""" -OCR Pipeline Words Detect — main word detection endpoint (Step 7). - -Extracted from ocr_pipeline_words.py. Contains the ``detect_words`` -endpoint which handles both v2 and words_first grid methods. - -Lizenz: Apache 2.0 -DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. -""" - -import json -import logging -import time -from typing import Any, Dict, List - -import numpy as np -from fastapi import APIRouter, HTTPException, Request -from fastapi.responses import StreamingResponse - -from cv_vocab_pipeline import ( - PageRegion, - RowGeometry, - _cells_to_vocab_entries, - _fix_phonetic_brackets, - fix_cell_phonetics, - build_cell_grid_v2, - create_ocr_image, - detect_column_geometry, -) -from cv_words_first import build_grid_from_words -from ocr_pipeline_session_store import ( - get_session_db, - update_session_db, -) -from ocr_pipeline_common import ( - _cache, - _load_session_to_cache, - _get_cached, - _append_pipeline_log, -) -from ocr_pipeline_words_stream import ( - _word_batch_stream_generator, -) - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) - - -# --------------------------------------------------------------------------- -# Word Detection Endpoint (Step 7) -# --------------------------------------------------------------------------- - -@router.post("/sessions/{session_id}/words") -async def detect_words( - session_id: str, - request: Request, - engine: str = "auto", - pronunciation: str = "british", - stream: bool = False, - skip_heal_gaps: bool = False, - grid_method: str = "v2", -): - """Build word grid from columns x rows, OCR each cell. - - Query params: - engine: 'auto' (default), 'tesseract', 'rapid', or 'paddle' - pronunciation: 'british' (default) or 'american' - stream: false (default) for JSON response, true for SSE streaming - skip_heal_gaps: false (default). When true, cells keep exact row geometry. 
- grid_method: 'v2' (default) or 'words_first' - """ - # PaddleOCR is full-page remote OCR -> force words_first grid method - if engine == "paddle" and grid_method != "words_first": - logger.info("detect_words: engine=paddle requires words_first, overriding grid_method=%s", grid_method) - grid_method = "words_first" - - if session_id not in _cache: - logger.info("detect_words: session %s not in cache, loading from DB", session_id) - await _load_session_to_cache(session_id) - cached = _get_cached(session_id) - - dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") - if dewarped_bgr is None: - logger.warning("detect_words: no cropped/dewarped image for session %s (cache keys: %s)", - session_id, [k for k in cached.keys() if k.endswith('_bgr')]) - raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before word detection") - - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - column_result = session.get("column_result") - row_result = session.get("row_result") - if not column_result or not column_result.get("columns"): - img_h_tmp, img_w_tmp = dewarped_bgr.shape[:2] - column_result = { - "columns": [{ - "type": "column_text", - "x": 0, "y": 0, - "width": img_w_tmp, "height": img_h_tmp, - "classification_confidence": 1.0, - "classification_method": "full_page_fallback", - }], - "zones": [], - "duration_seconds": 0, - } - logger.info("detect_words: no column_result -- using full-page pseudo-column %dx%d", img_w_tmp, img_h_tmp) - if grid_method != "words_first" and (not row_result or not row_result.get("rows")): - raise HTTPException(status_code=400, detail="Row detection must be completed first") - - # Convert column dicts back to PageRegion objects - col_regions = [ - PageRegion( - type=c["type"], - x=c["x"], y=c["y"], - width=c["width"], height=c["height"], - 
classification_confidence=c.get("classification_confidence", 1.0), - classification_method=c.get("classification_method", ""), - ) - for c in column_result["columns"] - ] - - # Convert row dicts back to RowGeometry objects - row_geoms = [ - RowGeometry( - index=r["index"], - x=r["x"], y=r["y"], - width=r["width"], height=r["height"], - word_count=r.get("word_count", 0), - words=[], - row_type=r.get("row_type", "content"), - gap_before=r.get("gap_before", 0), - ) - for r in row_result["rows"] - ] - - # Populate word counts from cached words - word_dicts = cached.get("_word_dicts") - if word_dicts is None: - ocr_img_tmp = create_ocr_image(dewarped_bgr) - geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr) - if geo_result is not None: - _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result - cached["_word_dicts"] = word_dicts - cached["_inv"] = inv - cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y) - - if word_dicts: - content_bounds = cached.get("_content_bounds") - if content_bounds: - _lx, _rx, top_y, _by = content_bounds - else: - top_y = min(r.y for r in row_geoms) if row_geoms else 0 - - for row in row_geoms: - row_y_rel = row.y - top_y - row_bottom_rel = row_y_rel + row.height - row.words = [ - w for w in word_dicts - if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel - ] - row.word_count = len(row.words) - - # Exclude rows that fall within box zones - zones = column_result.get("zones") or [] - box_ranges_inner = [] - for zone in zones: - if zone.get("zone_type") == "box" and zone.get("box"): - box = zone["box"] - bt = max(box.get("border_thickness", 0), 5) - box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt)) - - if box_ranges_inner: - def _row_in_box(r): - center_y = r.y + r.height / 2 - return any(by_s <= center_y < by_e for by_s, by_e in box_ranges_inner) - - before_count = len(row_geoms) - row_geoms = [r for r in row_geoms if not _row_in_box(r)] - excluded = before_count - 
len(row_geoms) - if excluded: - logger.info(f"detect_words: excluded {excluded} rows inside box zones") - - # --- Words-First path --- - if grid_method == "words_first": - return await _words_first_path( - session_id, cached, dewarped_bgr, engine, pronunciation, zones, - ) - - if stream: - return StreamingResponse( - _word_batch_stream_generator( - session_id, cached, col_regions, row_geoms, - dewarped_bgr, engine, pronunciation, request, - skip_heal_gaps=skip_heal_gaps, - ), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) - - # --- Non-streaming path (grid_method=v2) --- - return await _v2_path( - session_id, cached, col_regions, row_geoms, - dewarped_bgr, engine, pronunciation, skip_heal_gaps, - ) - - -async def _words_first_path( - session_id: str, - cached: Dict[str, Any], - dewarped_bgr: np.ndarray, - engine: str, - pronunciation: str, - zones: list, -) -> dict: - """Words-first grid construction path.""" - t0 = time.time() - img_h, img_w = dewarped_bgr.shape[:2] - - if engine == "paddle": - from cv_ocr_engines import ocr_region_paddle - wf_word_dicts = await ocr_region_paddle(dewarped_bgr, region=None) - cached["_paddle_word_dicts"] = wf_word_dicts - else: - wf_word_dicts = cached.get("_word_dicts") - if wf_word_dicts is None: - ocr_img_tmp = create_ocr_image(dewarped_bgr) - geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr) - if geo_result is not None: - _geoms, left_x, right_x, top_y, bottom_y, wf_word_dicts, inv = geo_result - cached["_word_dicts"] = wf_word_dicts - cached["_inv"] = inv - cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y) - - if not wf_word_dicts: - raise HTTPException(status_code=400, detail="No words detected -- cannot build words-first grid") - - # Convert word coordinates to absolute if needed - if engine != "paddle": - content_bounds = cached.get("_content_bounds") - if content_bounds: - lx, _rx, ty, _by = 
content_bounds - abs_words = [] - for w in wf_word_dicts: - abs_words.append({**w, 'left': w['left'] + lx, 'top': w['top'] + ty}) - wf_word_dicts = abs_words - - box_rects = [] - for zone in zones: - if zone.get("zone_type") == "box" and zone.get("box"): - box_rects.append(zone["box"]) - - cells, columns_meta = build_grid_from_words( - wf_word_dicts, img_w, img_h, box_rects=box_rects or None, - ) - duration = time.time() - t0 - - fix_cell_phonetics(cells, pronunciation=pronunciation) - for cell in cells: - cell.setdefault("zone_index", 0) - - col_types = {c['type'] for c in columns_meta} - is_vocab = bool(col_types & {'column_en', 'column_de'}) - n_rows = len(set(c['row_index'] for c in cells)) if cells else 0 - n_cols = len(columns_meta) - used_engine = "paddle" if engine == "paddle" else "words_first" - - word_result = { - "cells": cells, - "grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)}, - "columns_used": columns_meta, - "layout": "vocab" if is_vocab else "generic", - "image_width": img_w, - "image_height": img_h, - "duration_seconds": round(duration, 2), - "ocr_engine": used_engine, - "grid_method": "words_first", - "summary": { - "total_cells": len(cells), - "non_empty_cells": sum(1 for c in cells if c.get("text")), - "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), - }, - } - - if is_vocab or 'column_text' in col_types: - entries = _cells_to_vocab_entries(cells, columns_meta) - entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation) - word_result["vocab_entries"] = entries - word_result["entries"] = entries - word_result["entry_count"] = len(entries) - word_result["summary"]["total_entries"] = len(entries) - word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english")) - word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german")) - - await update_session_db(session_id, word_result=word_result, current_step=8) - cached["word_result"] = word_result 
- - logger.info(f"OCR Pipeline: words-first session {session_id}: " - f"{len(cells)} cells ({duration:.2f}s), {n_rows} rows, {n_cols} cols") - - await _append_pipeline_log(session_id, "words", { - "grid_method": "words_first", - "total_cells": len(cells), - "non_empty_cells": word_result["summary"]["non_empty_cells"], - "ocr_engine": used_engine, - "layout": word_result["layout"], - }, duration_ms=int(duration * 1000)) - - return {"session_id": session_id, **word_result} - - -async def _v2_path( - session_id: str, - cached: Dict[str, Any], - col_regions: List[PageRegion], - row_geoms: List[RowGeometry], - dewarped_bgr: np.ndarray, - engine: str, - pronunciation: str, - skip_heal_gaps: bool, -) -> dict: - """Cell-First OCR v2 non-streaming path.""" - t0 = time.time() - ocr_img = create_ocr_image(dewarped_bgr) - img_h, img_w = dewarped_bgr.shape[:2] - - cells, columns_meta = build_cell_grid_v2( - ocr_img, col_regions, row_geoms, img_w, img_h, - ocr_engine=engine, img_bgr=dewarped_bgr, - skip_heal_gaps=skip_heal_gaps, - ) - duration = time.time() - t0 - - for cell in cells: - cell.setdefault("zone_index", 0) - - col_types = {c['type'] for c in columns_meta} - is_vocab = bool(col_types & {'column_en', 'column_de'}) - n_content_rows = len([r for r in row_geoms if r.row_type == 'content']) - n_cols = len(columns_meta) - used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine - - fix_cell_phonetics(cells, pronunciation=pronunciation) - - word_result = { - "cells": cells, - "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(cells)}, - "columns_used": columns_meta, - "layout": "vocab" if is_vocab else "generic", - "image_width": img_w, - "image_height": img_h, - "duration_seconds": round(duration, 2), - "ocr_engine": used_engine, - "summary": { - "total_cells": len(cells), - "non_empty_cells": sum(1 for c in cells if c.get("text")), - "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), - }, - } - - 
has_text_col = 'column_text' in col_types - if is_vocab or has_text_col: - entries = _cells_to_vocab_entries(cells, columns_meta) - entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation) - word_result["vocab_entries"] = entries - word_result["entries"] = entries - word_result["entry_count"] = len(entries) - word_result["summary"]["total_entries"] = len(entries) - word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english")) - word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german")) - - await update_session_db(session_id, word_result=word_result, current_step=8) - cached["word_result"] = word_result - - logger.info(f"OCR Pipeline: words session {session_id}: " - f"layout={word_result['layout']}, " - f"{len(cells)} cells ({duration:.2f}s), summary: {word_result['summary']}") - - await _append_pipeline_log(session_id, "words", { - "total_cells": len(cells), - "non_empty_cells": word_result["summary"]["non_empty_cells"], - "low_confidence_count": word_result["summary"]["low_confidence"], - "ocr_engine": used_engine, - "layout": word_result["layout"], - "entry_count": word_result.get("entry_count", 0), - }, duration_ms=int(duration * 1000)) - - return {"session_id": session_id, **word_result} +# Backward-compat shim -- module moved to ocr/pipeline/words_detect.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.words_detect") diff --git a/klausur-service/backend/ocr_pipeline_words_stream.py b/klausur-service/backend/ocr_pipeline_words_stream.py index bb7d990..ff8c452 100644 --- a/klausur-service/backend/ocr_pipeline_words_stream.py +++ b/klausur-service/backend/ocr_pipeline_words_stream.py @@ -1,303 +1,4 @@ -""" -OCR Pipeline Words Stream — SSE streaming generators for word detection. - -Extracted from ocr_pipeline_words.py. - -Lizenz: Apache 2.0 -DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
-""" - -import json -import logging -import time -from typing import Any, Dict, List - -import numpy as np -from fastapi import Request - -from cv_vocab_pipeline import ( - PageRegion, - RowGeometry, - _cells_to_vocab_entries, - _fix_character_confusion, - _fix_phonetic_brackets, - fix_cell_phonetics, - build_cell_grid_v2, - build_cell_grid_v2_streaming, - create_ocr_image, -) -from ocr_pipeline_session_store import update_session_db -from ocr_pipeline_common import _cache - -logger = logging.getLogger(__name__) - - -async def _word_batch_stream_generator( - session_id: str, - cached: Dict[str, Any], - col_regions: List[PageRegion], - row_geoms: List[RowGeometry], - dewarped_bgr: np.ndarray, - engine: str, - pronunciation: str, - request: Request, - skip_heal_gaps: bool = False, -): - """SSE generator that runs batch OCR (parallel) then streams results. - - Uses build_cell_grid_v2 with ThreadPoolExecutor for parallel OCR, - then emits all cells as SSE events. - """ - import asyncio - - t0 = time.time() - ocr_img = create_ocr_image(dewarped_bgr) - img_h, img_w = dewarped_bgr.shape[:2] - - _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'} - n_content_rows = len([r for r in row_geoms if r.row_type == 'content']) - n_cols = len([c for c in col_regions if c.type not in _skip_types]) - col_types = {c.type for c in col_regions if c.type not in _skip_types} - is_vocab = bool(col_types & {'column_en', 'column_de'}) - total_cells = n_content_rows * n_cols - - # 1. Send meta event immediately - meta_event = { - "type": "meta", - "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells}, - "layout": "vocab" if is_vocab else "generic", - } - yield f"data: {json.dumps(meta_event)}\n\n" - - # 2. Send preparing event (keepalive for proxy) - yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR laeuft parallel...'})}\n\n" - - # 3. 
Run batch OCR in thread pool with periodic keepalive events. - loop = asyncio.get_event_loop() - ocr_future = loop.run_in_executor( - None, - lambda: build_cell_grid_v2( - ocr_img, col_regions, row_geoms, img_w, img_h, - ocr_engine=engine, img_bgr=dewarped_bgr, - skip_heal_gaps=skip_heal_gaps, - ), - ) - - # Send keepalive events every 5 seconds while OCR runs - keepalive_count = 0 - while not ocr_future.done(): - try: - cells, columns_meta = await asyncio.wait_for( - asyncio.shield(ocr_future), timeout=5.0, - ) - break # OCR finished - except asyncio.TimeoutError: - keepalive_count += 1 - elapsed = int(time.time() - t0) - yield f"data: {json.dumps({'type': 'keepalive', 'elapsed': elapsed, 'message': f'OCR laeuft... ({elapsed}s)'})}\n\n" - if await request.is_disconnected(): - logger.info(f"SSE batch: client disconnected during OCR for {session_id}") - ocr_future.cancel() - return - else: - cells, columns_meta = ocr_future.result() - - if await request.is_disconnected(): - logger.info(f"SSE batch: client disconnected after OCR for {session_id}") - return - - # 4. Apply IPA phonetic fixes - fix_cell_phonetics(cells, pronunciation=pronunciation) - - # 5. Send columns meta - if columns_meta: - yield f"data: {json.dumps({'type': 'columns', 'columns_used': columns_meta})}\n\n" - - # 6. Stream all cells - for idx, cell in enumerate(cells): - cell_event = { - "type": "cell", - "cell": cell, - "progress": {"current": idx + 1, "total": len(cells)}, - } - yield f"data: {json.dumps(cell_event)}\n\n" - - # 7. 
Build final result and persist - duration = time.time() - t0 - used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine - - word_result = { - "cells": cells, - "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(cells)}, - "columns_used": columns_meta, - "layout": "vocab" if is_vocab else "generic", - "image_width": img_w, - "image_height": img_h, - "duration_seconds": round(duration, 2), - "ocr_engine": used_engine, - "summary": { - "total_cells": len(cells), - "non_empty_cells": sum(1 for c in cells if c.get("text")), - "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), - }, - } - - vocab_entries = None - has_text_col = 'column_text' in col_types - if is_vocab or has_text_col: - entries = _cells_to_vocab_entries(cells, columns_meta) - entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation) - word_result["vocab_entries"] = entries - word_result["entries"] = entries - word_result["entry_count"] = len(entries) - word_result["summary"]["total_entries"] = len(entries) - word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english")) - word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german")) - vocab_entries = entries - - await update_session_db(session_id, word_result=word_result, current_step=8) - cached["word_result"] = word_result - - logger.info(f"OCR Pipeline SSE batch: words session {session_id}: " - f"layout={word_result['layout']}, {len(cells)} cells ({duration:.2f}s)") - - # 8. 
Send complete event - complete_event = { - "type": "complete", - "summary": word_result["summary"], - "duration_seconds": round(duration, 2), - "ocr_engine": used_engine, - } - if vocab_entries is not None: - complete_event["vocab_entries"] = vocab_entries - yield f"data: {json.dumps(complete_event)}\n\n" - - -async def _word_stream_generator( - session_id: str, - cached: Dict[str, Any], - col_regions: List[PageRegion], - row_geoms: List[RowGeometry], - dewarped_bgr: np.ndarray, - engine: str, - pronunciation: str, - request: Request, -): - """SSE generator that yields cell-by-cell OCR progress.""" - t0 = time.time() - - ocr_img = create_ocr_image(dewarped_bgr) - img_h, img_w = dewarped_bgr.shape[:2] - - n_content_rows = len([r for r in row_geoms if r.row_type == 'content']) - _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'} - n_cols = len([c for c in col_regions if c.type not in _skip_types]) - - col_types = {c.type for c in col_regions if c.type not in _skip_types} - is_vocab = bool(col_types & {'column_en', 'column_de'}) - - columns_meta = None - total_cells = n_content_rows * n_cols - - meta_event = { - "type": "meta", - "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells}, - "layout": "vocab" if is_vocab else "generic", - } - yield f"data: {json.dumps(meta_event)}\n\n" - - yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR wird initialisiert...'})}\n\n" - - all_cells: List[Dict[str, Any]] = [] - cell_idx = 0 - last_keepalive = time.time() - - for cell, cols_meta, total in build_cell_grid_v2_streaming( - ocr_img, col_regions, row_geoms, img_w, img_h, - ocr_engine=engine, img_bgr=dewarped_bgr, - ): - if await request.is_disconnected(): - logger.info(f"SSE: client disconnected during streaming for {session_id}") - return - - if columns_meta is None: - columns_meta = cols_meta - meta_update = {"type": "columns", "columns_used": cols_meta} - yield 
f"data: {json.dumps(meta_update)}\n\n" - - all_cells.append(cell) - cell_idx += 1 - - cell_event = { - "type": "cell", - "cell": cell, - "progress": {"current": cell_idx, "total": total}, - } - yield f"data: {json.dumps(cell_event)}\n\n" - - # All cells done - duration = time.time() - t0 - if columns_meta is None: - columns_meta = [] - - # Remove all-empty rows - rows_with_text: set = set() - for c in all_cells: - if c.get("text", "").strip(): - rows_with_text.add(c["row_index"]) - before_filter = len(all_cells) - all_cells = [c for c in all_cells if c["row_index"] in rows_with_text] - empty_rows_removed = (before_filter - len(all_cells)) // max(n_cols, 1) - if empty_rows_removed > 0: - logger.info(f"SSE: removed {empty_rows_removed} all-empty rows after OCR") - - used_engine = all_cells[0].get("ocr_engine", "tesseract") if all_cells else engine - - fix_cell_phonetics(all_cells, pronunciation=pronunciation) - - word_result = { - "cells": all_cells, - "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(all_cells)}, - "columns_used": columns_meta, - "layout": "vocab" if is_vocab else "generic", - "image_width": img_w, - "image_height": img_h, - "duration_seconds": round(duration, 2), - "ocr_engine": used_engine, - "summary": { - "total_cells": len(all_cells), - "non_empty_cells": sum(1 for c in all_cells if c.get("text")), - "low_confidence": sum(1 for c in all_cells if 0 < c.get("confidence", 0) < 50), - }, - } - - vocab_entries = None - has_text_col = 'column_text' in col_types - if is_vocab or has_text_col: - entries = _cells_to_vocab_entries(all_cells, columns_meta) - entries = _fix_character_confusion(entries) - entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation) - word_result["vocab_entries"] = entries - word_result["entries"] = entries - word_result["entry_count"] = len(entries) - word_result["summary"]["total_entries"] = len(entries) - word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english")) 
- word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german")) - vocab_entries = entries - - await update_session_db(session_id, word_result=word_result, current_step=8) - cached["word_result"] = word_result - - logger.info(f"OCR Pipeline SSE: words session {session_id}: " - f"layout={word_result['layout']}, " - f"{len(all_cells)} cells ({duration:.2f}s)") - - complete_event = { - "type": "complete", - "summary": word_result["summary"], - "duration_seconds": round(duration, 2), - "ocr_engine": used_engine, - } - if vocab_entries is not None: - complete_event["vocab_entries"] = vocab_entries - yield f"data: {json.dumps(complete_event)}\n\n" +# Backward-compat shim -- module moved to ocr/pipeline/words_stream.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.words_stream") diff --git a/klausur-service/backend/orientation_api.py b/klausur-service/backend/orientation_api.py index f2b322b..75ba362 100644 --- a/klausur-service/backend/orientation_api.py +++ b/klausur-service/backend/orientation_api.py @@ -1,188 +1,4 @@ -""" -Orientation & Page-Split API endpoints (Steps 1 and 1b of OCR Pipeline). 
-""" - -import logging -import time -from typing import Any, Dict - -import cv2 -from fastapi import APIRouter, HTTPException - -from cv_vocab_pipeline import detect_and_fix_orientation -from page_crop import detect_page_splits -from ocr_pipeline_session_store import update_session_db - -from orientation_crop_helpers import ensure_cached, append_pipeline_log -from page_sub_sessions import create_page_sub_sessions_full - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) - - -# --------------------------------------------------------------------------- -# Step 1: Orientation -# --------------------------------------------------------------------------- - -@router.post("/sessions/{session_id}/orientation") -async def detect_orientation(session_id: str): - """Detect and fix 90/180/270 degree rotations from scanners. - - Reads the original image, applies orientation correction, - stores the result as oriented_png. - """ - cached = await ensure_cached(session_id) - - img_bgr = cached.get("original_bgr") - if img_bgr is None: - raise HTTPException(status_code=400, detail="Original image not available") - - t0 = time.time() - - # Detect and fix orientation - oriented_bgr, orientation_deg = detect_and_fix_orientation(img_bgr.copy()) - - duration = time.time() - t0 - - orientation_result = { - "orientation_degrees": orientation_deg, - "corrected": orientation_deg != 0, - "duration_seconds": round(duration, 2), - } - - # Encode oriented image - success, png_buf = cv2.imencode(".png", oriented_bgr) - oriented_png = png_buf.tobytes() if success else b"" - - # Update cache - cached["oriented_bgr"] = oriented_bgr - cached["orientation_result"] = orientation_result - - # Persist to DB - await update_session_db( - session_id, - oriented_png=oriented_png, - orientation_result=orientation_result, - current_step=2, - ) - - logger.info( - "OCR Pipeline: orientation session %s: %d° (%s) in %.2fs", - session_id, 
orientation_deg, - "corrected" if orientation_deg else "no change", - duration, - ) - - await append_pipeline_log(session_id, "orientation", { - "orientation_degrees": orientation_deg, - "corrected": orientation_deg != 0, - }, duration_ms=int(duration * 1000)) - - h, w = oriented_bgr.shape[:2] - return { - "session_id": session_id, - **orientation_result, - "image_width": w, - "image_height": h, - "oriented_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/oriented", - } - - -# --------------------------------------------------------------------------- -# Step 1b: Page-split detection — runs AFTER orientation, BEFORE deskew -# --------------------------------------------------------------------------- - -@router.post("/sessions/{session_id}/page-split") -async def detect_page_split(session_id: str): - """Detect if the image is a double-page book spread and split into sub-sessions. - - Must be called **after orientation** (step 1) and **before deskew** (step 2). - Each sub-session receives the raw page region and goes through the full - pipeline (deskew -> dewarp -> crop -> columns -> rows -> words -> grid) - independently, so each page gets its own deskew correction. - - Returns ``{"multi_page": false}`` if only one page is detected. - """ - cached = await ensure_cached(session_id) - - # Use oriented (preferred), fall back to original - img_bgr = next( - (v for k in ("oriented_bgr", "original_bgr") - if (v := cached.get(k)) is not None), - None, - ) - if img_bgr is None: - raise HTTPException(status_code=400, detail="No image available for page-split detection") - - t0 = time.time() - page_splits = detect_page_splits(img_bgr) - used_original = False - - if not page_splits or len(page_splits) < 2: - # Orientation may have rotated a landscape double-page spread to - # portrait. Try the original (pre-orientation) image as fallback. 
- orig_bgr = cached.get("original_bgr") - if orig_bgr is not None and orig_bgr is not img_bgr: - page_splits_orig = detect_page_splits(orig_bgr) - if page_splits_orig and len(page_splits_orig) >= 2: - logger.info( - "OCR Pipeline: page-split session %s: spread detected on " - "ORIGINAL (orientation rotated it away)", - session_id, - ) - img_bgr = orig_bgr - page_splits = page_splits_orig - used_original = True - - if not page_splits or len(page_splits) < 2: - duration = time.time() - t0 - logger.info( - "OCR Pipeline: page-split session %s: single page (%.2fs)", - session_id, duration, - ) - return { - "session_id": session_id, - "multi_page": False, - "duration_seconds": round(duration, 2), - } - - # Multi-page spread detected — create sub-sessions for full pipeline. - # start_step=2 means "ready for deskew" (orientation already applied). - # start_step=1 means "needs orientation too" (split from original image). - start_step = 1 if used_original else 2 - sub_sessions = await create_page_sub_sessions_full( - session_id, cached, img_bgr, page_splits, start_step=start_step, - ) - duration = time.time() - t0 - - split_info: Dict[str, Any] = { - "multi_page": True, - "page_count": len(page_splits), - "page_splits": page_splits, - "used_original": used_original, - "duration_seconds": round(duration, 2), - } - - # Mark parent session as split and hidden from session list - await update_session_db(session_id, crop_result=split_info, status='split') - cached["crop_result"] = split_info - - await append_pipeline_log(session_id, "page_split", { - "multi_page": True, - "page_count": len(page_splits), - }, duration_ms=int(duration * 1000)) - - logger.info( - "OCR Pipeline: page-split session %s: %d pages detected in %.2fs", - session_id, len(page_splits), duration, - ) - - h, w = img_bgr.shape[:2] - return { - "session_id": session_id, - **split_info, - "image_width": w, - "image_height": h, - "sub_sessions": sub_sessions, - } +# Backward-compat shim -- module moved to 
ocr/pipeline/orientation_api.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.orientation_api") diff --git a/klausur-service/backend/orientation_crop_api.py b/klausur-service/backend/orientation_crop_api.py index 7eefa55..a22db49 100644 --- a/klausur-service/backend/orientation_crop_api.py +++ b/klausur-service/backend/orientation_crop_api.py @@ -1,16 +1,4 @@ -""" -Orientation & Crop API - Steps 1 and 4 of the OCR Pipeline. - -Barrel re-export: merges routers from orientation_api and crop_api, -and re-exports set_cache_ref for main.py. -""" - -from fastapi import APIRouter - -from orientation_crop_helpers import set_cache_ref # noqa: F401 -from orientation_api import router as _orientation_router -from crop_api import router as _crop_router - -router = APIRouter() -router.include_router(_orientation_router) -router.include_router(_crop_router) +# Backward-compat shim -- module moved to ocr/pipeline/orientation_crop_api.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.orientation_crop_api") diff --git a/klausur-service/backend/orientation_crop_helpers.py b/klausur-service/backend/orientation_crop_helpers.py index dee66eb..3a0ab6b 100644 --- a/klausur-service/backend/orientation_crop_helpers.py +++ b/klausur-service/backend/orientation_crop_helpers.py @@ -1,86 +1,4 @@ -""" -Orientation & Crop shared helpers - cache management and pipeline logging. 
-""" - -import logging -from typing import Any, Dict - -import cv2 -import numpy as np -from fastapi import HTTPException - -from ocr_pipeline_session_store import ( - get_session_db, - get_session_image, - update_session_db, -) - -logger = logging.getLogger(__name__) - - -# Reference to the shared cache from ocr_pipeline_api (set in main.py) -_cache: Dict[str, Dict[str, Any]] = {} - - -def set_cache_ref(cache: Dict[str, Dict[str, Any]]): - """Set reference to the shared cache from ocr_pipeline_api.""" - global _cache - _cache = cache - - -def get_cache_ref() -> Dict[str, Dict[str, Any]]: - """Get reference to the shared cache.""" - return _cache - - -async def ensure_cached(session_id: str) -> Dict[str, Any]: - """Ensure session is in cache, loading from DB if needed.""" - if session_id in _cache: - return _cache[session_id] - - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - cache_entry: Dict[str, Any] = { - "id": session_id, - **session, - "original_bgr": None, - "oriented_bgr": None, - "cropped_bgr": None, - "deskewed_bgr": None, - "dewarped_bgr": None, - } - - for img_type, bgr_key in [ - ("original", "original_bgr"), - ("oriented", "oriented_bgr"), - ("cropped", "cropped_bgr"), - ("deskewed", "deskewed_bgr"), - ("dewarped", "dewarped_bgr"), - ]: - png_data = await get_session_image(session_id, img_type) - if png_data: - arr = np.frombuffer(png_data, dtype=np.uint8) - bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR) - cache_entry[bgr_key] = bgr - - _cache[session_id] = cache_entry - return cache_entry - - -async def append_pipeline_log(session_id: str, step: str, metrics: dict, duration_ms: int): - """Append a step entry to the pipeline log.""" - from datetime import datetime - session = await get_session_db(session_id) - if not session: - return - pipeline_log = session.get("pipeline_log") or {"steps": []} - pipeline_log["steps"].append({ - "step": step, - 
"completed_at": datetime.utcnow().isoformat(), - "success": True, - "duration_ms": duration_ms, - "metrics": metrics, - }) - await update_session_db(session_id, pipeline_log=pipeline_log) +# Backward-compat shim -- module moved to ocr/pipeline/orientation_crop_helpers.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.orientation_crop_helpers") diff --git a/klausur-service/backend/page_crop.py b/klausur-service/backend/page_crop.py index ca4a8d0..de0fa82 100644 --- a/klausur-service/backend/page_crop.py +++ b/klausur-service/backend/page_crop.py @@ -1,33 +1,4 @@ -""" -Page Crop — Barrel Re-export - -Content-based crop for scanned pages and book scans. - -Split into: -- page_crop_edges.py — Edge detection (spine shadow, gutter, projection) -- page_crop_core.py — Main crop algorithm and format detection - -All public names are re-exported here for backward compatibility. -License: Apache 2.0 -""" - -# Core: main crop functions and format detection -from page_crop_core import ( # noqa: F401 - PAPER_FORMATS, - detect_page_splits, - detect_and_crop_page, - _detect_format, -) - -# Edge detection helpers -from page_crop_edges import ( # noqa: F401 - _INK_THRESHOLD, - _MIN_RUN_FRAC, - _detect_spine_shadow, - _detect_gutter_continuity, - _detect_left_edge_shadow, - _detect_right_edge_shadow, - _detect_top_bottom_edges, - _detect_edge_projection, - _filter_narrow_runs, -) +# Backward-compat shim -- module moved to ocr/pipeline/page_crop.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.page_crop") diff --git a/klausur-service/backend/page_crop_core.py b/klausur-service/backend/page_crop_core.py index 53c3723..fefcdf5 100644 --- a/klausur-service/backend/page_crop_core.py +++ b/klausur-service/backend/page_crop_core.py @@ -1,342 +1,4 @@ -""" -Page Crop - Core Crop and Format Detection - -Content-based crop for scanned pages and book scans. 
Detects the content -boundary by analysing ink density projections and (for book scans) the -spine shadow gradient. - -Extracted from page_crop.py to keep files under 500 LOC. -License: Apache 2.0 -""" - -import logging -from typing import Dict, Any, Tuple - -import cv2 -import numpy as np - -from page_crop_edges import ( - _detect_left_edge_shadow, - _detect_right_edge_shadow, - _detect_top_bottom_edges, -) - -logger = logging.getLogger(__name__) - -# Known paper format aspect ratios (height / width, portrait orientation) -PAPER_FORMATS = { - "A4": 297.0 / 210.0, # 1.4143 - "A5": 210.0 / 148.0, # 1.4189 - "Letter": 11.0 / 8.5, # 1.2941 - "Legal": 14.0 / 8.5, # 1.6471 - "A3": 420.0 / 297.0, # 1.4141 -} - - -def detect_page_splits( - img_bgr: np.ndarray, -) -> list: - """Detect if the image is a multi-page spread and return split rectangles. - - Uses **brightness** (not ink density) to find the spine area: - the scanner bed produces a characteristic gray strip where pages meet, - which is darker than the white paper on either side. - - Returns a list of page dicts ``{x, y, width, height, page_index}`` - or an empty list if only one page is detected. 
- """ - h, w = img_bgr.shape[:2] - - # Only check landscape-ish images (width > height * 1.15) - if w < h * 1.15: - return [] - - gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY) - - # Column-mean brightness (0-255) — the spine is darker (gray scanner bed) - col_brightness = np.mean(gray, axis=0).astype(np.float64) - - # Heavy smoothing to ignore individual text lines - kern = max(11, w // 50) - if kern % 2 == 0: - kern += 1 - brightness_smooth = np.convolve(col_brightness, np.ones(kern) / kern, mode="same") - - # Page paper is bright (typically > 200), spine/scanner bed is darker - page_brightness = float(np.max(brightness_smooth)) - if page_brightness < 100: - return [] # Very dark image, skip - - # Spine threshold: significantly darker than the page - spine_thresh = page_brightness * 0.88 - - # Search in center region (30-70% of width) - center_lo = int(w * 0.30) - center_hi = int(w * 0.70) - - # Find the darkest valley in the center region - center_brightness = brightness_smooth[center_lo:center_hi] - darkest_val = float(np.min(center_brightness)) - - if darkest_val >= spine_thresh: - logger.debug("No spine detected: min brightness %.0f >= threshold %.0f", - darkest_val, spine_thresh) - return [] - - # Find ALL contiguous dark runs in the center region - is_dark = center_brightness < spine_thresh - dark_runs: list = [] - run_start = -1 - for i in range(len(is_dark)): - if is_dark[i]: - if run_start < 0: - run_start = i - else: - if run_start >= 0: - dark_runs.append((run_start, i)) - run_start = -1 - if run_start >= 0: - dark_runs.append((run_start, len(is_dark))) - - # Filter out runs that are too narrow (< 1% of image width) - min_spine_px = int(w * 0.01) - dark_runs = [(s, e) for s, e in dark_runs if e - s >= min_spine_px] - - if not dark_runs: - logger.debug("No dark runs wider than %dpx in center region", min_spine_px) - return [] - - # Score each dark run: prefer centered, dark, narrow valleys - center_region_len = center_hi - center_lo - 
image_center_in_region = (w * 0.5 - center_lo) - best_score = -1.0 - best_start, best_end = dark_runs[0] - - for rs, re in dark_runs: - run_width = re - rs - run_center = (rs + re) / 2.0 - - sigma = center_region_len * 0.15 - dist = abs(run_center - image_center_in_region) - center_factor = float(np.exp(-0.5 * (dist / sigma) ** 2)) - - run_brightness = float(np.mean(center_brightness[rs:re])) - darkness_factor = max(0.0, (spine_thresh - run_brightness) / spine_thresh) - - width_frac = run_width / w - if width_frac <= 0.05: - narrowness_bonus = 1.0 - elif width_frac <= 0.15: - narrowness_bonus = 1.0 - (width_frac - 0.05) / 0.10 - else: - narrowness_bonus = 0.0 - - score = center_factor * darkness_factor * (0.3 + 0.7 * narrowness_bonus) - - logger.debug( - "Dark run x=%d..%d (w=%d): center_f=%.3f dark_f=%.3f narrow_b=%.3f -> score=%.4f", - center_lo + rs, center_lo + re, run_width, - center_factor, darkness_factor, narrowness_bonus, score, - ) - - if score > best_score: - best_score = score - best_start, best_end = rs, re - - spine_w = best_end - best_start - spine_x = center_lo + best_start - spine_center = spine_x + spine_w // 2 - - logger.debug( - "Best spine candidate: x=%d..%d (w=%d), score=%.4f", - spine_x, spine_x + spine_w, spine_w, best_score, - ) - - # Verify: must have bright (paper) content on BOTH sides - left_brightness = float(np.mean(brightness_smooth[max(0, spine_x - w // 10):spine_x])) - right_end = center_lo + best_end - right_brightness = float(np.mean(brightness_smooth[right_end:min(w, right_end + w // 10)])) - - if left_brightness < spine_thresh or right_brightness < spine_thresh: - logger.debug("No bright paper flanking spine: left=%.0f right=%.0f thresh=%.0f", - left_brightness, right_brightness, spine_thresh) - return [] - - logger.info( - "Spine detected: x=%d..%d (w=%d), brightness=%.0f vs paper=%.0f, " - "left_paper=%.0f, right_paper=%.0f", - spine_x, right_end, spine_w, darkest_val, page_brightness, - left_brightness, right_brightness, - 
) - - # Split at the spine center - split_points = [spine_center] - - # Build page rectangles - pages: list = [] - prev_x = 0 - for i, sx in enumerate(split_points): - pages.append({"x": prev_x, "y": 0, "width": sx - prev_x, - "height": h, "page_index": i}) - prev_x = sx - pages.append({"x": prev_x, "y": 0, "width": w - prev_x, - "height": h, "page_index": len(split_points)}) - - # Filter out tiny pages (< 15% of total width) - pages = [p for p in pages if p["width"] >= w * 0.15] - if len(pages) < 2: - return [] - - # Re-index - for i, p in enumerate(pages): - p["page_index"] = i - - logger.info( - "Page split detected: %d pages, spine_w=%d, split_points=%s", - len(pages), spine_w, split_points, - ) - return pages - - -def detect_and_crop_page( - img_bgr: np.ndarray, - margin_frac: float = 0.01, -) -> Tuple[np.ndarray, Dict[str, Any]]: - """Detect content boundary and crop scanner/book borders. - - Algorithm (4-edge detection): - 1. Adaptive threshold -> binary (text=255, bg=0) - 2. Left edge: spine-shadow detection via grayscale column means, - fallback to binary vertical projection - 3. Right edge: binary vertical projection (last ink column) - 4. Top/bottom edges: binary horizontal projection - 5. 
Sanity checks, then crop with configurable margin - - Args: - img_bgr: Input BGR image (should already be deskewed/dewarped) - margin_frac: Extra margin around content (fraction of dimension, default 1%) - - Returns: - Tuple of (cropped_image, result_dict) - """ - h, w = img_bgr.shape[:2] - total_area = h * w - - result: Dict[str, Any] = { - "crop_applied": False, - "crop_rect": None, - "crop_rect_pct": None, - "original_size": {"width": w, "height": h}, - "cropped_size": {"width": w, "height": h}, - "detected_format": None, - "format_confidence": 0.0, - "aspect_ratio": round(max(h, w) / max(min(h, w), 1), 4), - "border_fractions": {"top": 0.0, "bottom": 0.0, "left": 0.0, "right": 0.0}, - } - - gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY) - - # --- Binarise with adaptive threshold --- - binary = cv2.adaptiveThreshold( - gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, - cv2.THRESH_BINARY_INV, blockSize=51, C=15, - ) - - # --- Edge detection --- - left_edge = _detect_left_edge_shadow(gray, binary, w, h) - right_edge = _detect_right_edge_shadow(gray, binary, w, h) - top_edge, bottom_edge = _detect_top_bottom_edges(binary, w, h) - - # Compute border fractions - border_top = top_edge / h - border_bottom = (h - bottom_edge) / h - border_left = left_edge / w - border_right = (w - right_edge) / w - - result["border_fractions"] = { - "top": round(border_top, 4), - "bottom": round(border_bottom, 4), - "left": round(border_left, 4), - "right": round(border_right, 4), - } - - # Sanity: only crop if at least one edge has > 2% border - min_border = 0.02 - if all(f < min_border for f in [border_top, border_bottom, border_left, border_right]): - logger.info("All borders < %.0f%% — no crop needed", min_border * 100) - result["detected_format"], result["format_confidence"] = _detect_format(w, h) - return img_bgr, result - - # Add margin - margin_x = int(w * margin_frac) - margin_y = int(h * margin_frac) - - crop_x = max(0, left_edge - margin_x) - crop_y = max(0, top_edge - margin_y) 
- crop_x2 = min(w, right_edge + margin_x) - crop_y2 = min(h, bottom_edge + margin_y) - - crop_w = crop_x2 - crop_x - crop_h = crop_y2 - crop_y - - # Sanity: cropped area must be >= 40% of original - if crop_w * crop_h < 0.40 * total_area: - logger.warning("Cropped area too small (%.0f%%) — skipping crop", - 100.0 * crop_w * crop_h / total_area) - result["detected_format"], result["format_confidence"] = _detect_format(w, h) - return img_bgr, result - - cropped = img_bgr[crop_y:crop_y2, crop_x:crop_x2].copy() - - detected_format, format_confidence = _detect_format(crop_w, crop_h) - - result["crop_applied"] = True - result["crop_rect"] = {"x": crop_x, "y": crop_y, "width": crop_w, "height": crop_h} - result["crop_rect_pct"] = { - "x": round(100.0 * crop_x / w, 2), - "y": round(100.0 * crop_y / h, 2), - "width": round(100.0 * crop_w / w, 2), - "height": round(100.0 * crop_h / h, 2), - } - result["cropped_size"] = {"width": crop_w, "height": crop_h} - result["detected_format"] = detected_format - result["format_confidence"] = format_confidence - result["aspect_ratio"] = round(max(crop_w, crop_h) / max(min(crop_w, crop_h), 1), 4) - - logger.info( - "Page cropped: %dx%d -> %dx%d, format=%s (%.0f%%), " - "borders: T=%.1f%% B=%.1f%% L=%.1f%% R=%.1f%%", - w, h, crop_w, crop_h, detected_format, format_confidence * 100, - border_top * 100, border_bottom * 100, - border_left * 100, border_right * 100, - ) - - return cropped, result - - -# --------------------------------------------------------------------------- -# Format detection (kept as optional metadata) -# --------------------------------------------------------------------------- - -def _detect_format(width: int, height: int) -> Tuple[str, float]: - """Detect paper format from dimensions by comparing aspect ratios.""" - if width <= 0 or height <= 0: - return "unknown", 0.0 - - aspect = max(width, height) / min(width, height) - - best_format = "unknown" - best_diff = float("inf") - - for fmt, expected_ratio in 
PAPER_FORMATS.items(): - diff = abs(aspect - expected_ratio) - if diff < best_diff: - best_diff = diff - best_format = fmt - - confidence = max(0.0, 1.0 - best_diff * 5.0) - - if confidence < 0.3: - return "unknown", 0.0 - - return best_format, round(confidence, 3) +# Backward-compat shim -- module moved to ocr/pipeline/page_crop_core.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.page_crop_core") diff --git a/klausur-service/backend/page_crop_edges.py b/klausur-service/backend/page_crop_edges.py index b231078..21a87eb 100644 --- a/klausur-service/backend/page_crop_edges.py +++ b/klausur-service/backend/page_crop_edges.py @@ -1,388 +1,4 @@ -""" -Page Crop - Edge Detection Helpers - -Spine shadow detection, gutter continuity analysis, projection-based -edge detection, and narrow-run filtering for content cropping. - -Extracted from page_crop.py to keep files under 500 LOC. -License: Apache 2.0 -""" - -import logging -from typing import Optional, Tuple - -import cv2 -import numpy as np - -logger = logging.getLogger(__name__) - -# Minimum ink density (fraction of pixels) to count a row/column as "content" -_INK_THRESHOLD = 0.003 # 0.3% - -# Minimum run length (fraction of dimension) to keep — shorter runs are noise -_MIN_RUN_FRAC = 0.005 # 0.5% - - -def _detect_spine_shadow( - gray: np.ndarray, - search_region: np.ndarray, - offset_x: int, - w: int, - side: str, -) -> Optional[int]: - """Find the book spine center (darkest point) in a scanner shadow. - - The scanner produces a gray strip where the book spine presses against - the glass. The darkest column in that strip is the spine center — - that's where we crop. - - Distinguishes real spine shadows from text content by checking: - 1. Strong brightness range (> 40 levels) - 2. Darkest point is genuinely dark (< 180 mean brightness) - 3. The dark area is a NARROW valley, not a text-content plateau - 4. 
Brightness rises significantly toward the page content side - - Args: - gray: Full grayscale image (for context). - search_region: Column slice of the grayscale image to search in. - offset_x: X offset of search_region relative to full image. - w: Full image width. - side: 'left' or 'right' (for logging). - - Returns: - X coordinate (in full image) of the spine center, or None. - """ - region_w = search_region.shape[1] - if region_w < 10: - return None - - # Column-mean brightness in the search region - col_means = np.mean(search_region, axis=0).astype(np.float64) - - # Smooth with boxcar kernel (width = 1% of image width, min 5) - kernel_size = max(5, w // 100) - if kernel_size % 2 == 0: - kernel_size += 1 - kernel = np.ones(kernel_size) / kernel_size - smoothed_raw = np.convolve(col_means, kernel, mode="same") - - # Trim convolution edge artifacts (edges are zero-padded -> artificially low) - margin = kernel_size // 2 - if region_w <= 2 * margin + 10: - return None - smoothed = smoothed_raw[margin:region_w - margin] - trim_offset = margin # offset of smoothed[0] relative to search_region - - val_min = float(np.min(smoothed)) - val_max = float(np.max(smoothed)) - shadow_range = val_max - val_min - - # --- Check 1: Strong brightness gradient --- - if shadow_range <= 40: - logger.debug( - "%s edge: no spine (range=%.0f <= 40)", side.capitalize(), shadow_range, - ) - return None - - # --- Check 2: Darkest point must be genuinely dark --- - if val_min > 180: - logger.debug( - "%s edge: no spine (darkest=%.0f > 180, likely text)", side.capitalize(), val_min, - ) - return None - - spine_idx = int(np.argmin(smoothed)) # index in trimmed array - spine_local = spine_idx + trim_offset # index in search_region - trimmed_len = len(smoothed) - - # --- Check 3: Valley width (spine is narrow, text plateau is wide) --- - valley_thresh = val_min + shadow_range * 0.20 - valley_mask = smoothed < valley_thresh - valley_width = int(np.sum(valley_mask)) - max_valley_frac = 0.50 - if 
valley_width > trimmed_len * max_valley_frac: - logger.debug( - "%s edge: no spine (valley too wide: %d/%d = %.0f%%)", - side.capitalize(), valley_width, trimmed_len, - 100.0 * valley_width / trimmed_len, - ) - return None - - # --- Check 4: Brightness must rise toward page content --- - rise_check_w = max(5, trimmed_len // 5) - if side == "left": - right_start = min(spine_idx + 5, trimmed_len - 1) - right_end = min(right_start + rise_check_w, trimmed_len) - if right_end > right_start: - rise_brightness = float(np.mean(smoothed[right_start:right_end])) - rise = rise_brightness - val_min - if rise < shadow_range * 0.3: - logger.debug( - "%s edge: no spine (insufficient rise: %.0f, need %.0f)", - side.capitalize(), rise, shadow_range * 0.3, - ) - return None - else: # right - left_end = max(spine_idx - 5, 0) - left_start = max(left_end - rise_check_w, 0) - if left_end > left_start: - rise_brightness = float(np.mean(smoothed[left_start:left_end])) - rise = rise_brightness - val_min - if rise < shadow_range * 0.3: - logger.debug( - "%s edge: no spine (insufficient rise: %.0f, need %.0f)", - side.capitalize(), rise, shadow_range * 0.3, - ) - return None - - spine_x = offset_x + spine_local - - logger.info( - "%s edge: spine center at x=%d (brightness=%.0f, range=%.0f, valley=%dpx)", - side.capitalize(), spine_x, val_min, shadow_range, valley_width, - ) - return spine_x - - -def _detect_gutter_continuity( - gray: np.ndarray, - search_region: np.ndarray, - offset_x: int, - w: int, - side: str, -) -> Optional[int]: - """Detect gutter shadow via vertical continuity analysis. - - Camera book scans produce a subtle brightness gradient at the gutter - that is too faint for scanner-shadow detection (range < 40). However, - the gutter shadow has a unique property: it runs **continuously from - top to bottom** without interruption. - - Algorithm: - 1. Divide image into N horizontal strips (~60px each) - 2. 
For each column, compute what fraction of strips are darker than - the page median (from the center 50% of the full image) - 3. A "gutter column" has >= 75% of strips darker than page_median - d - 4. Smooth the dark-fraction profile and find the transition point - 5. Validate: gutter band must be 0.5%-10% of image width - """ - region_h, region_w = search_region.shape[:2] - if region_w < 20 or region_h < 100: - return None - - # --- 1. Divide into horizontal strips --- - strip_target_h = 60 - n_strips = max(10, region_h // strip_target_h) - strip_h = region_h // n_strips - - strip_means = np.zeros((n_strips, region_w), dtype=np.float64) - for s in range(n_strips): - y0 = s * strip_h - y1 = min((s + 1) * strip_h, region_h) - strip_means[s] = np.mean(search_region[y0:y1, :], axis=0) - - # --- 2. Page median from center 50% of full image --- - center_lo = w // 4 - center_hi = 3 * w // 4 - page_median = float(np.median(gray[:, center_lo:center_hi])) - - dark_thresh = page_median - 5.0 - - if page_median < 180: - return None - - # --- 3. Per-column dark fraction --- - dark_count = np.sum(strip_means < dark_thresh, axis=0).astype(np.float64) - dark_frac = dark_count / n_strips - - # --- 4. 
Smooth and find transition --- - smooth_w = max(5, w // 100) - if smooth_w % 2 == 0: - smooth_w += 1 - kernel = np.ones(smooth_w) / smooth_w - frac_smooth = np.convolve(dark_frac, kernel, mode="same") - - margin = smooth_w // 2 - if region_w <= 2 * margin + 10: - return None - - transition_thresh = 0.50 - peak_frac = float(np.max(frac_smooth[margin:region_w - margin])) - - if peak_frac < 0.70: - logger.debug( - "%s gutter: peak dark fraction %.2f < 0.70", side.capitalize(), peak_frac, - ) - return None - - peak_x = int(np.argmax(frac_smooth[margin:region_w - margin])) + margin - gutter_inner = None - - if side == "right": - for x in range(peak_x, margin, -1): - if frac_smooth[x] < transition_thresh: - gutter_inner = x + 1 - break - else: - for x in range(peak_x, region_w - margin): - if frac_smooth[x] < transition_thresh: - gutter_inner = x - 1 - break - - if gutter_inner is None: - return None - - # --- 5. Validate gutter width --- - if side == "right": - gutter_width = region_w - gutter_inner - else: - gutter_width = gutter_inner - - min_gutter = max(3, int(w * 0.005)) - max_gutter = int(w * 0.10) - - if gutter_width < min_gutter: - logger.debug( - "%s gutter: too narrow (%dpx < %dpx)", side.capitalize(), - gutter_width, min_gutter, - ) - return None - - if gutter_width > max_gutter: - logger.debug( - "%s gutter: too wide (%dpx > %dpx)", side.capitalize(), - gutter_width, max_gutter, - ) - return None - - if side == "right": - gutter_brightness = float(np.mean(strip_means[:, gutter_inner:])) - else: - gutter_brightness = float(np.mean(strip_means[:, :gutter_inner])) - - brightness_drop = page_median - gutter_brightness - if brightness_drop < 3: - logger.debug( - "%s gutter: insufficient brightness drop (%.1f levels)", - side.capitalize(), brightness_drop, - ) - return None - - gutter_x = offset_x + gutter_inner - - logger.info( - "%s gutter (continuity): x=%d, width=%dpx (%.1f%%), " - "brightness=%.0f vs page=%.0f (drop=%.0f), frac@edge=%.2f", - 
side.capitalize(), gutter_x, gutter_width, - 100.0 * gutter_width / w, gutter_brightness, page_median, - brightness_drop, float(frac_smooth[gutter_inner]), - ) - return gutter_x - - -def _detect_left_edge_shadow( - gray: np.ndarray, - binary: np.ndarray, - w: int, - h: int, -) -> int: - """Detect left content edge, accounting for book-spine shadow. - - Tries three methods in order: - 1. Scanner spine-shadow (dark gradient, range > 40) - 2. Camera gutter continuity (subtle shadow running top-to-bottom) - 3. Binary projection fallback (first ink column) - """ - search_w = max(1, w // 4) - spine_x = _detect_spine_shadow(gray, gray[:, :search_w], 0, w, "left") - if spine_x is not None: - return spine_x - - gutter_x = _detect_gutter_continuity(gray, gray[:, :search_w], 0, w, "left") - if gutter_x is not None: - return gutter_x - - return _detect_edge_projection(binary, axis=0, from_start=True, dim=w) - - -def _detect_right_edge_shadow( - gray: np.ndarray, - binary: np.ndarray, - w: int, - h: int, -) -> int: - """Detect right content edge, accounting for book-spine shadow. - - Tries three methods in order: - 1. Scanner spine-shadow (dark gradient, range > 40) - 2. Camera gutter continuity (subtle shadow running top-to-bottom) - 3. 
Binary projection fallback (last ink column) - """ - search_w = max(1, w // 4) - right_start = w - search_w - spine_x = _detect_spine_shadow(gray, gray[:, right_start:], right_start, w, "right") - if spine_x is not None: - return spine_x - - gutter_x = _detect_gutter_continuity(gray, gray[:, right_start:], right_start, w, "right") - if gutter_x is not None: - return gutter_x - - return _detect_edge_projection(binary, axis=0, from_start=False, dim=w) - - -def _detect_top_bottom_edges(binary: np.ndarray, w: int, h: int) -> Tuple[int, int]: - """Detect top and bottom content edges via binary horizontal projection.""" - top = _detect_edge_projection(binary, axis=1, from_start=True, dim=h) - bottom = _detect_edge_projection(binary, axis=1, from_start=False, dim=h) - return top, bottom - - -def _detect_edge_projection( - binary: np.ndarray, - axis: int, - from_start: bool, - dim: int, -) -> int: - """Find the first/last row or column with ink density above threshold. - - axis=0 -> project vertically (column densities) -> returns x position - axis=1 -> project horizontally (row densities) -> returns y position - - Filters out narrow noise runs shorter than _MIN_RUN_FRAC of the dimension. 
- """ - projection = np.mean(binary, axis=axis) / 255.0 - - ink_mask = projection >= _INK_THRESHOLD - - min_run = max(1, int(dim * _MIN_RUN_FRAC)) - ink_mask = _filter_narrow_runs(ink_mask, min_run) - - ink_positions = np.where(ink_mask)[0] - if len(ink_positions) == 0: - return 0 if from_start else dim - - if from_start: - return int(ink_positions[0]) - else: - return int(ink_positions[-1]) - - -def _filter_narrow_runs(mask: np.ndarray, min_run: int) -> np.ndarray: - """Remove True-runs shorter than min_run pixels.""" - if min_run <= 1: - return mask - - result = mask.copy() - n = len(result) - i = 0 - while i < n: - if result[i]: - start = i - while i < n and result[i]: - i += 1 - if i - start < min_run: - result[start:i] = False - else: - i += 1 - return result +# Backward-compat shim -- module moved to ocr/pipeline/page_crop_edges.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.page_crop_edges") diff --git a/klausur-service/backend/page_sub_sessions.py b/klausur-service/backend/page_sub_sessions.py index f6a918c..aee311b 100644 --- a/klausur-service/backend/page_sub_sessions.py +++ b/klausur-service/backend/page_sub_sessions.py @@ -1,189 +1,4 @@ -""" -Sub-session creation for multi-page spreads. - -Used by both the page-split and crop steps when a double-page scan is detected. -""" - -import logging -import uuid as uuid_mod -from typing import Any, Dict, List - -import cv2 -import numpy as np - -from page_crop import detect_and_crop_page -from ocr_pipeline_session_store import ( - create_session_db, - get_sub_sessions, - update_session_db, -) -from orientation_crop_helpers import get_cache_ref - -logger = logging.getLogger(__name__) - - -async def create_page_sub_sessions( - parent_session_id: str, - parent_cached: dict, - full_img_bgr: np.ndarray, - page_splits: List[Dict[str, Any]], -) -> List[Dict[str, Any]]: - """Create sub-sessions for each detected page in a multi-page spread. 
- - Each page region is individually cropped, then stored as a sub-session - with its own cropped image ready for the rest of the pipeline. - """ - # Check for existing sub-sessions (idempotent) - existing = await get_sub_sessions(parent_session_id) - if existing: - return [ - {"id": s["id"], "name": s["name"], "page_index": s.get("box_index", i)} - for i, s in enumerate(existing) - ] - - parent_name = parent_cached.get("name", "Scan") - parent_filename = parent_cached.get("filename", "scan.png") - - sub_sessions: List[Dict[str, Any]] = [] - - for page in page_splits: - pi = page["page_index"] - px, py = page["x"], page["y"] - pw, ph = page["width"], page["height"] - - # Extract page region - page_bgr = full_img_bgr[py:py + ph, px:px + pw].copy() - - # Crop each page individually (remove its own borders) - cropped_page, page_crop_info = detect_and_crop_page(page_bgr) - - # Encode as PNG - ok, png_buf = cv2.imencode(".png", cropped_page) - page_png = png_buf.tobytes() if ok else b"" - - sub_id = str(uuid_mod.uuid4()) - sub_name = f"{parent_name} — Seite {pi + 1}" - - await create_session_db( - session_id=sub_id, - name=sub_name, - filename=parent_filename, - original_png=page_png, - ) - - # Pre-populate: set cropped = original (already cropped) - await update_session_db( - sub_id, - cropped_png=page_png, - crop_result=page_crop_info, - current_step=5, - ) - - ch, cw = cropped_page.shape[:2] - sub_sessions.append({ - "id": sub_id, - "name": sub_name, - "page_index": pi, - "source_rect": page, - "cropped_size": {"width": cw, "height": ch}, - "detected_format": page_crop_info.get("detected_format"), - }) - - logger.info( - "Page sub-session %s: page %d, region x=%d w=%d -> cropped %dx%d", - sub_id, pi + 1, px, pw, cw, ch, - ) - - return sub_sessions - - -async def create_page_sub_sessions_full( - parent_session_id: str, - parent_cached: dict, - full_img_bgr: np.ndarray, - page_splits: List[Dict[str, Any]], - start_step: int = 2, -) -> List[Dict[str, Any]]: - """Create 
sub-sessions for each page with RAW regions for full pipeline processing. - - Unlike ``create_page_sub_sessions`` (used by the crop step), these - sub-sessions store the *uncropped* page region and start at - ``start_step`` (default 2 = ready for deskew; 1 if orientation still - needed). Each page goes through its own pipeline independently, - which is essential for book spreads where each page has a different tilt. - """ - _cache = get_cache_ref() - - # Idempotent: reuse existing sub-sessions - existing = await get_sub_sessions(parent_session_id) - if existing: - return [ - {"id": s["id"], "name": s["name"], "page_index": s.get("box_index", i)} - for i, s in enumerate(existing) - ] - - parent_name = parent_cached.get("name", "Scan") - parent_filename = parent_cached.get("filename", "scan.png") - - sub_sessions: List[Dict[str, Any]] = [] - - for page in page_splits: - pi = page["page_index"] - px, py = page["x"], page["y"] - pw, ph = page["width"], page["height"] - - # Extract RAW page region — NO individual cropping here; each - # sub-session will run its own crop step after deskew + dewarp. 
- page_bgr = full_img_bgr[py:py + ph, px:px + pw].copy() - - # Encode as PNG - ok, png_buf = cv2.imencode(".png", page_bgr) - page_png = png_buf.tobytes() if ok else b"" - - sub_id = str(uuid_mod.uuid4()) - sub_name = f"{parent_name} — Seite {pi + 1}" - - await create_session_db( - session_id=sub_id, - name=sub_name, - filename=parent_filename, - original_png=page_png, - ) - - # start_step=2 -> ready for deskew (orientation already done on spread) - # start_step=1 -> needs its own orientation (split from original image) - await update_session_db(sub_id, current_step=start_step) - - # Cache the BGR so the pipeline can start immediately - _cache[sub_id] = { - "id": sub_id, - "filename": parent_filename, - "name": sub_name, - "original_bgr": page_bgr, - "oriented_bgr": None, - "cropped_bgr": None, - "deskewed_bgr": None, - "dewarped_bgr": None, - "orientation_result": None, - "crop_result": None, - "deskew_result": None, - "dewarp_result": None, - "ground_truth": {}, - "current_step": start_step, - } - - rh, rw = page_bgr.shape[:2] - sub_sessions.append({ - "id": sub_id, - "name": sub_name, - "page_index": pi, - "source_rect": page, - "image_size": {"width": rw, "height": rh}, - }) - - logger.info( - "Page sub-session %s (full pipeline): page %d, region x=%d w=%d -> %dx%d", - sub_id, pi + 1, px, pw, rw, rh, - ) - - return sub_sessions +# Backward-compat shim -- module moved to ocr/pipeline/page_sub_sessions.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.page_sub_sessions") diff --git a/klausur-service/backend/scan_quality.py b/klausur-service/backend/scan_quality.py index d869140..831639c 100644 --- a/klausur-service/backend/scan_quality.py +++ b/klausur-service/backend/scan_quality.py @@ -1,102 +1,4 @@ -""" -Scan Quality Assessment — Measures image quality before OCR. - -Computes blur score, contrast score, and an overall quality rating. 
-Used to gate enhancement steps and warn users about degraded scans. - -All operations use OpenCV (Apache-2.0), no additional dependencies. -""" - -import logging -from dataclasses import dataclass, asdict -from typing import Dict, Any - -import cv2 -import numpy as np - -logger = logging.getLogger(__name__) - -# Thresholds (empirically tuned on textbook scans) -BLUR_THRESHOLD = 100.0 # Laplacian variance below this = blurry -CONTRAST_THRESHOLD = 40.0 # Grayscale stddev below this = low contrast -CONFIDENCE_GOOD = 40 # OCR min confidence for good scans -CONFIDENCE_DEGRADED = 30 # OCR min confidence for degraded scans - - -@dataclass -class ScanQualityReport: - """Result of scan quality assessment.""" - blur_score: float # Laplacian variance (higher = sharper) - contrast_score: float # Grayscale std deviation (higher = more contrast) - brightness: float # Mean grayscale value (0-255) - is_blurry: bool - is_low_contrast: bool - is_degraded: bool # True if any quality issue detected - quality_pct: int # 0-100 overall quality estimate - recommended_min_conf: int # Recommended OCR confidence threshold - - def to_dict(self) -> Dict[str, Any]: - return asdict(self) - - -def score_scan_quality(img_bgr: np.ndarray) -> ScanQualityReport: - """ - Assess the quality of a scanned image. 
- - Uses: - - Laplacian variance for blur detection - - Grayscale standard deviation for contrast - - Mean brightness for exposure assessment - - Args: - img_bgr: BGR image (numpy array from OpenCV) - - Returns: - ScanQualityReport with scores and recommendations - """ - gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY) - - # Blur detection: Laplacian variance - # Higher = sharper edges = better quality - laplacian = cv2.Laplacian(gray, cv2.CV_64F) - blur_score = float(laplacian.var()) - - # Contrast: standard deviation of grayscale - contrast_score = float(np.std(gray)) - - # Brightness: mean grayscale - brightness = float(np.mean(gray)) - - # Quality flags - is_blurry = blur_score < BLUR_THRESHOLD - is_low_contrast = contrast_score < CONTRAST_THRESHOLD - is_degraded = is_blurry or is_low_contrast - - # Overall quality percentage (simple weighted combination) - blur_pct = min(100, blur_score / BLUR_THRESHOLD * 50) - contrast_pct = min(100, contrast_score / CONTRAST_THRESHOLD * 50) - quality_pct = int(min(100, blur_pct + contrast_pct)) - - # Recommended confidence threshold - recommended_min_conf = CONFIDENCE_DEGRADED if is_degraded else CONFIDENCE_GOOD - - report = ScanQualityReport( - blur_score=round(blur_score, 1), - contrast_score=round(contrast_score, 1), - brightness=round(brightness, 1), - is_blurry=is_blurry, - is_low_contrast=is_low_contrast, - is_degraded=is_degraded, - quality_pct=quality_pct, - recommended_min_conf=recommended_min_conf, - ) - - logger.info( - f"Scan quality: blur={report.blur_score} " - f"contrast={report.contrast_score} " - f"quality={report.quality_pct}% " - f"degraded={report.is_degraded} " - f"min_conf={report.recommended_min_conf}" - ) - - return report +# Backward-compat shim -- module moved to ocr/pipeline/scan_quality.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.scan_quality") diff --git a/klausur-service/backend/vision_ocr_fusion.py 
b/klausur-service/backend/vision_ocr_fusion.py index c2ef216..dc175fb 100644 --- a/klausur-service/backend/vision_ocr_fusion.py +++ b/klausur-service/backend/vision_ocr_fusion.py @@ -1,261 +1,4 @@ -""" -Vision-LLM OCR Fusion — Combines traditional OCR positions with Vision-LLM reading. - -Sends the scan image + OCR word coordinates + document type to Qwen2.5-VL. -The LLM can read degraded text using context understanding and visual inspection, -while OCR coordinates provide structural hints (where text is, column positions). - -Uses Ollama API (same pattern as handwriting_htr_api.py). -""" - -import base64 -import json -import logging -import os -import re -from typing import Any, Dict, List, Optional - -import cv2 -import httpx -import numpy as np - -logger = logging.getLogger(__name__) - -OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434") -VISION_FUSION_MODEL = os.getenv("VISION_FUSION_MODEL", "llama3.2-vision:11b") - -# Document category → prompt context -CATEGORY_PROMPTS: Dict[str, Dict[str, str]] = { - "vokabelseite": { - "label": "Vokabelseite eines Schulbuchs (Englisch-Deutsch)", - "columns": "Die Tabelle hat typischerweise 3 Spalten: Englisch, Deutsch, Beispielsatz.", - }, - "woerterbuch": { - "label": "Woerterbuchseite", - "columns": "Die Eintraege haben: Stichwort, Lautschrift, Uebersetzung(en), Beispielsaetze.", - }, - "arbeitsblatt": { - "label": "Arbeitsblatt", - "columns": "Erkenne die Spaltenstruktur aus dem Layout.", - }, - "buchseite": { - "label": "Schulbuchseite", - "columns": "Erkenne die Spaltenstruktur aus dem Layout.", - }, -} - - -def _group_words_into_lines( - words: List[Dict], y_tolerance: float = 15.0, -) -> List[List[Dict]]: - """Group OCR words into lines by Y-proximity.""" - if not words: - return [] - sorted_w = sorted(words, key=lambda w: w.get("top", 0)) - lines: List[List[Dict]] = [[sorted_w[0]]] - for w in sorted_w[1:]: - last_line = lines[-1] - avg_y = sum(ww["top"] for ww in last_line) / 
len(last_line) - if abs(w["top"] - avg_y) <= y_tolerance: - last_line.append(w) - else: - lines.append([w]) - # Sort words within each line by X - for line in lines: - line.sort(key=lambda w: w.get("left", 0)) - return lines - - -def _build_ocr_context(words: List[Dict], img_h: int) -> str: - """Build a text description of OCR words with positions for the prompt.""" - lines = _group_words_into_lines(words) - context_parts = [] - for i, line in enumerate(lines): - word_descs = [] - for w in line: - text = w.get("text", "").strip() - x = w.get("left", 0) - conf = w.get("conf", 0) - marker = " (?)" if conf < 50 else "" - word_descs.append(f'x={x} "{text}"{marker}') - avg_y = int(sum(w["top"] for w in line) / len(line)) - context_parts.append(f"Zeile {i+1} (y~{avg_y}): {', '.join(word_descs)}") - return "\n".join(context_parts) - - -def _build_prompt( - ocr_context: str, category: str, img_w: int, img_h: int, -) -> str: - """Build the Vision-LLM prompt with OCR context and document type.""" - cat_info = CATEGORY_PROMPTS.get(category, CATEGORY_PROMPTS["buchseite"]) - - return f"""Du siehst eine eingescannte {cat_info['label']}. -{cat_info['columns']} - -Die OCR-Software hat folgende Woerter an diesen Positionen erkannt. -Woerter mit (?) haben niedrige Erkennungssicherheit und sind wahrscheinlich falsch: - -{ocr_context} - -Bildgroesse: {img_w} x {img_h} Pixel. - -AUFGABE: Schau dir das Bild genau an und erstelle die korrekte Tabelle. 
-- Korrigiere falsch erkannte Woerter anhand dessen was du im Bild siehst -- Fasse Fortsetzungszeilen zusammen (wenn eine Spalte in der naechsten Zeile leer ist, - gehoert der Text zur Zeile darueber — der Autor hat nur einen Zeilenumbruch innerhalb der Zelle gemacht) -- Behalte die Reihenfolge bei - -Antworte NUR mit einem JSON-Array, keine Erklaerungen: -[ - {{"row": 1, "english": "...", "german": "...", "example": "..."}}, - {{"row": 2, "english": "...", "german": "...", "example": "..."}} -]""" - - -def _parse_llm_response(response_text: str) -> Optional[List[Dict]]: - """Parse the LLM JSON response, handling markdown code blocks.""" - text = response_text.strip() - - # Strip markdown code block if present - if text.startswith("```"): - text = re.sub(r"^```(?:json)?\s*", "", text) - text = re.sub(r"\s*```\s*$", "", text) - - # Try to find JSON array - match = re.search(r"\[[\s\S]*\]", text) - if not match: - logger.warning("vision_fuse_ocr: no JSON array found in LLM response") - return None - - try: - data = json.loads(match.group()) - if not isinstance(data, list): - return None - return data - except json.JSONDecodeError as e: - logger.warning(f"vision_fuse_ocr: JSON parse error: {e}") - return None - - -def _vocab_rows_to_words( - rows: List[Dict], img_w: int, img_h: int, -) -> List[Dict]: - """Convert LLM vocab rows back to word dicts for grid building. - - Distributes words across estimated column positions so the - existing grid builder can process them normally. 
- """ - words = [] - # Estimate column positions (3-column vocab layout) - col_positions = [ - (0.02, 0.28), # EN: 2%-28% of width - (0.30, 0.55), # DE: 30%-55% - (0.57, 0.98), # Example: 57%-98% - ] - - median_h = max(15, img_h // (len(rows) * 3)) if rows else 20 - y_step = max(median_h + 5, img_h // max(len(rows), 1)) - - for i, row in enumerate(rows): - y = int(i * y_step + 20) - row_num = row.get("row", i + 1) - - for col_idx, (field, (x_start_pct, x_end_pct)) in enumerate([ - ("english", col_positions[0]), - ("german", col_positions[1]), - ("example", col_positions[2]), - ]): - text = (row.get(field) or "").strip() - if not text: - continue - x = int(x_start_pct * img_w) - w = int((x_end_pct - x_start_pct) * img_w) - words.append({ - "text": text, - "left": x, - "top": y, - "width": w, - "height": median_h, - "conf": 95, # LLM-corrected → high confidence - "_source": "vision_llm", - "_row": row_num, - "_col_type": f"column_{['en', 'de', 'example'][col_idx]}", - }) - - logger.info(f"vision_fuse_ocr: converted {len(rows)} LLM rows → {len(words)} words") - return words - - -async def vision_fuse_ocr( - img_bgr: np.ndarray, - ocr_words: List[Dict], - document_category: str = "vokabelseite", -) -> List[Dict]: - """Fuse traditional OCR results with Vision-LLM reading. - - Sends the image + OCR word positions to Qwen2.5-VL which can: - - Read degraded text that traditional OCR cannot - - Use document context (knows what a vocab table looks like) - - Merge continuation rows (understands table structure) - - Args: - img_bgr: The cropped/dewarped scan image (BGR) - ocr_words: Traditional OCR word list with positions - document_category: Type of document being scanned - - Returns: - Corrected word list in same format as input, ready for grid building. - Falls back to original ocr_words on error. 
- """ - img_h, img_w = img_bgr.shape[:2] - - # Build OCR context string - ocr_context = _build_ocr_context(ocr_words, img_h) - - # Build prompt - prompt = _build_prompt(ocr_context, document_category, img_w, img_h) - - # Encode image as base64 - _, img_encoded = cv2.imencode(".png", img_bgr) - img_b64 = base64.b64encode(img_encoded.tobytes()).decode("utf-8") - - # Call Qwen2.5-VL via Ollama - try: - async with httpx.AsyncClient(timeout=120.0) as client: - resp = await client.post( - f"{OLLAMA_BASE_URL}/api/generate", - json={ - "model": VISION_FUSION_MODEL, - "prompt": prompt, - "images": [img_b64], - "stream": False, - "options": {"temperature": 0.1, "num_predict": 4096}, - }, - ) - resp.raise_for_status() - data = resp.json() - response_text = data.get("response", "").strip() - except Exception as e: - logger.error(f"vision_fuse_ocr: Ollama call failed: {e}") - return ocr_words # Fallback to original - - if not response_text: - logger.warning("vision_fuse_ocr: empty LLM response") - return ocr_words - - # Parse JSON response - rows = _parse_llm_response(response_text) - if not rows: - logger.warning( - "vision_fuse_ocr: could not parse LLM response, " - "first 200 chars: %s", response_text[:200], - ) - return ocr_words - - logger.info( - f"vision_fuse_ocr: LLM returned {len(rows)} vocab rows " - f"(from {len(ocr_words)} OCR words)" - ) - - # Convert back to word format for grid building - return _vocab_rows_to_words(rows, img_w, img_h) +# Backward-compat shim -- module moved to ocr/pipeline/vision_fusion.py +import importlib as _importlib +import sys as _sys +_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.vision_fusion")