diff --git a/klausur-service/backend/ocr_pipeline_api.py b/klausur-service/backend/ocr_pipeline_api.py index 0a18d9b..642d122 100644 --- a/klausur-service/backend/ocr_pipeline_api.py +++ b/klausur-service/backend/ocr_pipeline_api.py @@ -1,5426 +1,61 @@ """ OCR Pipeline API - Schrittweise Seitenrekonstruktion. -Zerlegt den OCR-Prozess in 10 einzelne Schritte: -1. Orientierung - 90/180/270° Drehungen korrigieren (orientation_crop_api.py) -2. Begradigung (Deskew) - Scan begradigen -3. Entzerrung (Dewarp) - Buchwoelbung entzerren -4. Zuschneiden - Scannerraender/Buchruecken entfernen (orientation_crop_api.py) -5. Spaltenerkennung - Unsichtbare Spalten finden -6. Zeilenerkennung - Horizontale Zeilen + Kopf-/Fusszeilen -7. Worterkennung - OCR mit Bounding Boxes -8. LLM-Korrektur - OCR-Fehler per LLM korrigieren -9. Seitenrekonstruktion - Seite nachbauen -10. Ground Truth Validierung - Gesamtpruefung +Thin wrapper that assembles all sub-module routers into a single +composite router. Backward-compatible: main.py and tests can still +import ``router``, ``_cache``, and helper functions from here. + +Sub-modules (each < 1 000 lines): + ocr_pipeline_common – shared state, cache, Pydantic models, helpers + ocr_pipeline_sessions – session CRUD, image serving, doc-type + ocr_pipeline_geometry – deskew, dewarp, structure, columns + ocr_pipeline_rows – row detection, box-overlay helper + ocr_pipeline_words – word detection (SSE), paddle-direct, word GT + ocr_pipeline_ocr_merge – paddle/tesseract merge helpers, kombi endpoints + ocr_pipeline_postprocess – LLM review, reconstruction, export, validation + ocr_pipeline_auto – auto-mode orchestrator, reprocess Lizenz: Apache 2.0 DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
""" -import json -import logging -import os -import re -import time -import uuid -from dataclasses import asdict -from datetime import datetime -from typing import Any, Dict, List, Optional +from fastapi import APIRouter -import cv2 -import numpy as np -from fastapi import APIRouter, File, Form, HTTPException, Query, Request, UploadFile -from fastapi.responses import Response, StreamingResponse -from pydantic import BaseModel - -from cv_vocab_pipeline import ( - OLLAMA_REVIEW_MODEL, - DocumentTypeResult, - PageRegion, - RowGeometry, - _cells_to_vocab_entries, - _detect_header_footer_gaps, - _detect_sub_columns, - _fix_character_confusion, - _fix_phonetic_brackets, - fix_cell_phonetics, - analyze_layout, - analyze_layout_by_words, - build_cell_grid, - build_cell_grid_streaming, - build_cell_grid_v2, - build_cell_grid_v2_streaming, - build_word_grid, - classify_column_types, - create_layout_image, - create_ocr_image, - deskew_image, - deskew_image_by_word_alignment, - deskew_image_iterative, - deskew_two_pass, - detect_column_geometry, - detect_column_geometry_zoned, - detect_document_type, - detect_row_geometry, - expand_narrow_columns, - _apply_shear, - dewarp_image, - dewarp_image_manual, - llm_review_entries, - llm_review_entries_streaming, - render_image_high_res, - render_pdf_high_res, -) -from cv_box_detect import detect_boxes, split_page_into_zones -from cv_color_detect import detect_word_colors, recover_colored_text, _COLOR_RANGES, _COLOR_HEX -from cv_graphic_detect import detect_graphic_elements -from cv_words_first import build_grid_from_words -from ocr_pipeline_session_store import ( - create_session_db, - delete_all_sessions_db, - delete_session_db, - get_session_db, - get_session_image, - get_sub_sessions, - init_ocr_pipeline_tables, - list_sessions_db, - update_session_db, +# --------------------------------------------------------------------------- +# Shared state (imported by main.py and orientation_crop_api.py) +# 
--------------------------------------------------------------------------- +from ocr_pipeline_common import ( # noqa: F401 – re-exported + _cache, + _BORDER_GHOST_CHARS, + _filter_border_ghost_words, ) -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) +# --------------------------------------------------------------------------- +# Sub-module routers +# --------------------------------------------------------------------------- +from ocr_pipeline_sessions import router as _sessions_router +from ocr_pipeline_geometry import router as _geometry_router +from ocr_pipeline_rows import router as _rows_router +from ocr_pipeline_words import router as _words_router +from ocr_pipeline_ocr_merge import ( + router as _ocr_merge_router, + # Re-export for test backward compatibility + _split_paddle_multi_words, # noqa: F401 + _group_words_into_rows, # noqa: F401 + _merge_row_sequences, # noqa: F401 + _merge_paddle_tesseract, # noqa: F401 +) +from ocr_pipeline_postprocess import router as _postprocess_router +from ocr_pipeline_auto import router as _auto_router # --------------------------------------------------------------------------- -# In-memory cache for active sessions (BGR numpy arrays for processing) -# DB is source of truth, cache holds BGR arrays during active processing. 
+# Composite router (used by main.py) # --------------------------------------------------------------------------- - -_cache: Dict[str, Dict[str, Any]] = {} - - -async def _get_base_image_png(session_id: str) -> Optional[bytes]: - """Get the best available base image for a session (cropped > dewarped > original).""" - for img_type in ("cropped", "dewarped", "original"): - png_data = await get_session_image(session_id, img_type) - if png_data: - return png_data - return None - - -async def _load_session_to_cache(session_id: str) -> Dict[str, Any]: - """Load session from DB into cache, decoding PNGs to BGR arrays.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - if session_id in _cache: - return _cache[session_id] - - cache_entry: Dict[str, Any] = { - "id": session_id, - **session, - "original_bgr": None, - "oriented_bgr": None, - "cropped_bgr": None, - "deskewed_bgr": None, - "dewarped_bgr": None, - } - - # Decode images from DB into BGR numpy arrays - for img_type, bgr_key in [ - ("original", "original_bgr"), - ("oriented", "oriented_bgr"), - ("cropped", "cropped_bgr"), - ("deskewed", "deskewed_bgr"), - ("dewarped", "dewarped_bgr"), - ]: - png_data = await get_session_image(session_id, img_type) - if png_data: - arr = np.frombuffer(png_data, dtype=np.uint8) - bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR) - cache_entry[bgr_key] = bgr - - # Sub-sessions: original image IS the cropped box region. - # Promote original_bgr to cropped_bgr so downstream steps find it. 
- if session.get("parent_session_id") and cache_entry["original_bgr"] is not None: - if cache_entry["cropped_bgr"] is None and cache_entry["dewarped_bgr"] is None: - cache_entry["cropped_bgr"] = cache_entry["original_bgr"] - - _cache[session_id] = cache_entry - return cache_entry - - -def _get_cached(session_id: str) -> Dict[str, Any]: - """Get from cache or raise 404.""" - entry = _cache.get(session_id) - if not entry: - raise HTTPException(status_code=404, detail=f"Session {session_id} not in cache — reload first") - return entry - - -# --------------------------------------------------------------------------- -# Pydantic Models -# --------------------------------------------------------------------------- - -class ManualDeskewRequest(BaseModel): - angle: float - - -class DeskewGroundTruthRequest(BaseModel): - is_correct: bool - corrected_angle: Optional[float] = None - notes: Optional[str] = None - - -class ManualDewarpRequest(BaseModel): - shear_degrees: float - - -class CombinedAdjustRequest(BaseModel): - rotation_degrees: float = 0.0 - shear_degrees: float = 0.0 - - -class DewarpGroundTruthRequest(BaseModel): - is_correct: bool - corrected_shear: Optional[float] = None - notes: Optional[str] = None - - -VALID_DOCUMENT_CATEGORIES = { - 'vokabelseite', 'buchseite', 'arbeitsblatt', 'klausurseite', - 'mathearbeit', 'statistik', 'zeitung', 'formular', 'handschrift', 'sonstiges', -} - - -class UpdateSessionRequest(BaseModel): - name: Optional[str] = None - document_category: Optional[str] = None - - -class ManualColumnsRequest(BaseModel): - columns: List[Dict[str, Any]] - - -class ColumnGroundTruthRequest(BaseModel): - is_correct: bool - corrected_columns: Optional[List[Dict[str, Any]]] = None - notes: Optional[str] = None - - -class ManualRowsRequest(BaseModel): - rows: List[Dict[str, Any]] - - -class RowGroundTruthRequest(BaseModel): - is_correct: bool - corrected_rows: Optional[List[Dict[str, Any]]] = None - notes: Optional[str] = None - - -class 
RemoveHandwritingRequest(BaseModel): - method: str = "auto" # "auto" | "telea" | "ns" - target_ink: str = "all" # "all" | "colored" | "pencil" - dilation: int = 2 # mask dilation iterations (0-5) - use_source: str = "auto" # "original" | "deskewed" | "auto" - - -# --------------------------------------------------------------------------- -# Session Management Endpoints -# --------------------------------------------------------------------------- - -@router.get("/sessions") -async def list_sessions(include_sub_sessions: bool = False): - """List OCR pipeline sessions. - - By default, sub-sessions (box regions) are hidden. - Pass ?include_sub_sessions=true to show them. - """ - sessions = await list_sessions_db(include_sub_sessions=include_sub_sessions) - return {"sessions": sessions} - - -@router.post("/sessions") -async def create_session( - file: UploadFile = File(...), - name: Optional[str] = Form(None), -): - """Upload a PDF or image file and create a pipeline session.""" - file_data = await file.read() - filename = file.filename or "upload" - content_type = file.content_type or "" - - session_id = str(uuid.uuid4()) - is_pdf = content_type == "application/pdf" or filename.lower().endswith(".pdf") - - try: - if is_pdf: - img_bgr = render_pdf_high_res(file_data, page_number=0, zoom=3.0) - else: - img_bgr = render_image_high_res(file_data) - except Exception as e: - raise HTTPException(status_code=400, detail=f"Could not process file: {e}") - - # Encode original as PNG bytes - success, png_buf = cv2.imencode(".png", img_bgr) - if not success: - raise HTTPException(status_code=500, detail="Failed to encode image") - - original_png = png_buf.tobytes() - session_name = name or filename - - # Persist to DB - await create_session_db( - session_id=session_id, - name=session_name, - filename=filename, - original_png=original_png, - ) - - # Cache BGR array for immediate processing - _cache[session_id] = { - "id": session_id, - "filename": filename, - "name": session_name, 
- "original_bgr": img_bgr, - "oriented_bgr": None, - "cropped_bgr": None, - "deskewed_bgr": None, - "dewarped_bgr": None, - "orientation_result": None, - "crop_result": None, - "deskew_result": None, - "dewarp_result": None, - "ground_truth": {}, - "current_step": 1, - } - - logger.info(f"OCR Pipeline: created session {session_id} from {filename} " - f"({img_bgr.shape[1]}x{img_bgr.shape[0]})") - - return { - "session_id": session_id, - "filename": filename, - "name": session_name, - "image_width": img_bgr.shape[1], - "image_height": img_bgr.shape[0], - "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original", - } - - -@router.get("/sessions/{session_id}") -async def get_session_info(session_id: str): - """Get session info including deskew/dewarp/column results for step navigation.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - # Get image dimensions from original PNG - original_png = await get_session_image(session_id, "original") - if original_png: - arr = np.frombuffer(original_png, dtype=np.uint8) - img = cv2.imdecode(arr, cv2.IMREAD_COLOR) - img_w, img_h = img.shape[1], img.shape[0] if img is not None else (0, 0) - else: - img_w, img_h = 0, 0 - - result = { - "session_id": session["id"], - "filename": session.get("filename", ""), - "name": session.get("name", ""), - "image_width": img_w, - "image_height": img_h, - "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original", - "current_step": session.get("current_step", 1), - "document_category": session.get("document_category"), - "doc_type": session.get("doc_type"), - } - - if session.get("orientation_result"): - result["orientation_result"] = session["orientation_result"] - if session.get("crop_result"): - result["crop_result"] = session["crop_result"] - if session.get("deskew_result"): - result["deskew_result"] = session["deskew_result"] - if 
session.get("dewarp_result"): - result["dewarp_result"] = session["dewarp_result"] - if session.get("column_result"): - result["column_result"] = session["column_result"] - if session.get("row_result"): - result["row_result"] = session["row_result"] - if session.get("word_result"): - result["word_result"] = session["word_result"] - if session.get("doc_type_result"): - result["doc_type_result"] = session["doc_type_result"] - - # Sub-session info - if session.get("parent_session_id"): - result["parent_session_id"] = session["parent_session_id"] - result["box_index"] = session.get("box_index") - else: - # Check for sub-sessions - subs = await get_sub_sessions(session_id) - if subs: - result["sub_sessions"] = [ - {"id": s["id"], "name": s.get("name"), "box_index": s.get("box_index")} - for s in subs - ] - - return result - - -@router.put("/sessions/{session_id}") -async def update_session(session_id: str, req: UpdateSessionRequest): - """Update session name and/or document category.""" - kwargs: Dict[str, Any] = {} - if req.name is not None: - kwargs["name"] = req.name - if req.document_category is not None: - if req.document_category not in VALID_DOCUMENT_CATEGORIES: - raise HTTPException( - status_code=400, - detail=f"Invalid category '{req.document_category}'. 
Valid: {sorted(VALID_DOCUMENT_CATEGORIES)}", - ) - kwargs["document_category"] = req.document_category - if not kwargs: - raise HTTPException(status_code=400, detail="Nothing to update") - updated = await update_session_db(session_id, **kwargs) - if not updated: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - return {"session_id": session_id, **kwargs} - - -@router.delete("/sessions/{session_id}") -async def delete_session(session_id: str): - """Delete a session.""" - _cache.pop(session_id, None) - deleted = await delete_session_db(session_id) - if not deleted: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - return {"session_id": session_id, "deleted": True} - - -@router.delete("/sessions") -async def delete_all_sessions(): - """Delete ALL sessions (cleanup).""" - _cache.clear() - count = await delete_all_sessions_db() - return {"deleted_count": count} - - -@router.post("/sessions/{session_id}/create-box-sessions") -async def create_box_sessions(session_id: str): - """Create sub-sessions for each detected box region. - - Crops box regions from the cropped/dewarped image and creates - independent sub-sessions that can be processed through the pipeline. 
- """ - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - column_result = session.get("column_result") - if not column_result: - raise HTTPException(status_code=400, detail="Column detection must be completed first") - - zones = column_result.get("zones") or [] - box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")] - if not box_zones: - return {"session_id": session_id, "sub_sessions": [], "message": "No boxes detected"} - - # Check for existing sub-sessions - existing = await get_sub_sessions(session_id) - if existing: - return { - "session_id": session_id, - "sub_sessions": [{"id": s["id"], "box_index": s.get("box_index")} for s in existing], - "message": f"{len(existing)} sub-session(s) already exist", - } - - # Load base image - base_png = await get_session_image(session_id, "cropped") - if not base_png: - base_png = await get_session_image(session_id, "dewarped") - if not base_png: - raise HTTPException(status_code=400, detail="No base image available") - - arr = np.frombuffer(base_png, dtype=np.uint8) - img = cv2.imdecode(arr, cv2.IMREAD_COLOR) - if img is None: - raise HTTPException(status_code=500, detail="Failed to decode image") - - parent_name = session.get("name", "Session") - created = [] - - for i, zone in enumerate(box_zones): - box = zone["box"] - bx, by = box["x"], box["y"] - bw, bh = box["width"], box["height"] - - # Crop box region with small padding - pad = 5 - y1 = max(0, by - pad) - y2 = min(img.shape[0], by + bh + pad) - x1 = max(0, bx - pad) - x2 = min(img.shape[1], bx + bw + pad) - crop = img[y1:y2, x1:x2] - - # Encode as PNG - success, png_buf = cv2.imencode(".png", crop) - if not success: - logger.warning(f"Failed to encode box {i} crop for session {session_id}") - continue - - sub_id = str(uuid.uuid4()) - sub_name = f"{parent_name} — Box {i + 1}" - - await create_session_db( - session_id=sub_id, - name=sub_name, - 
filename=session.get("filename", "box-crop.png"), - original_png=png_buf.tobytes(), - parent_session_id=session_id, - box_index=i, - ) - - # Cache the BGR for immediate processing - # Promote original to cropped so column/row/word detection finds it - box_bgr = crop.copy() - _cache[sub_id] = { - "id": sub_id, - "filename": session.get("filename", "box-crop.png"), - "name": sub_name, - "parent_session_id": session_id, - "original_bgr": box_bgr, - "oriented_bgr": None, - "cropped_bgr": box_bgr, - "deskewed_bgr": None, - "dewarped_bgr": None, - "orientation_result": None, - "crop_result": None, - "deskew_result": None, - "dewarp_result": None, - "ground_truth": {}, - "current_step": 1, - } - - created.append({ - "id": sub_id, - "name": sub_name, - "box_index": i, - "box": box, - "image_width": crop.shape[1], - "image_height": crop.shape[0], - }) - - logger.info(f"Created box sub-session {sub_id} for session {session_id} " - f"(box {i}, {crop.shape[1]}x{crop.shape[0]})") - - return { - "session_id": session_id, - "sub_sessions": created, - "total": len(created), - } - - -@router.get("/sessions/{session_id}/thumbnail") -async def get_session_thumbnail(session_id: str, size: int = Query(default=80, ge=16, le=400)): - """Return a small thumbnail of the original image.""" - original_png = await get_session_image(session_id, "original") - if not original_png: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found or no image") - arr = np.frombuffer(original_png, dtype=np.uint8) - img = cv2.imdecode(arr, cv2.IMREAD_COLOR) - if img is None: - raise HTTPException(status_code=500, detail="Failed to decode image") - h, w = img.shape[:2] - scale = size / max(h, w) - new_w, new_h = int(w * scale), int(h * scale) - thumb = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA) - _, png_bytes = cv2.imencode(".png", thumb) - return Response(content=png_bytes.tobytes(), media_type="image/png", - headers={"Cache-Control": "public, max-age=3600"}) - - 
-@router.get("/sessions/{session_id}/pipeline-log") -async def get_pipeline_log(session_id: str): - """Get the pipeline execution log for a session.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - return {"session_id": session_id, "pipeline_log": session.get("pipeline_log") or {"steps": []}} - - -@router.get("/categories") -async def list_categories(): - """List valid document categories.""" - return {"categories": sorted(VALID_DOCUMENT_CATEGORIES)} - - -# --------------------------------------------------------------------------- -# Pipeline Log Helper -# --------------------------------------------------------------------------- - -async def _append_pipeline_log( - session_id: str, - step_name: str, - metrics: Dict[str, Any], - success: bool = True, - duration_ms: Optional[int] = None, -): - """Append a step entry to the session's pipeline_log JSONB.""" - session = await get_session_db(session_id) - if not session: - return - log = session.get("pipeline_log") or {"steps": []} - if not isinstance(log, dict): - log = {"steps": []} - entry = { - "step": step_name, - "completed_at": datetime.utcnow().isoformat(), - "success": success, - "metrics": metrics, - } - if duration_ms is not None: - entry["duration_ms"] = duration_ms - log.setdefault("steps", []).append(entry) - await update_session_db(session_id, pipeline_log=log) - - -# --------------------------------------------------------------------------- -# Image Endpoints -# --------------------------------------------------------------------------- - -@router.get("/sessions/{session_id}/image/{image_type}") -async def get_image(session_id: str, image_type: str): - """Serve session images: original, deskewed, dewarped, binarized, structure-overlay, columns-overlay, or rows-overlay.""" - valid_types = {"original", "oriented", "cropped", "deskewed", "dewarped", "binarized", "structure-overlay", "columns-overlay", 
"rows-overlay", "words-overlay", "clean"} - if image_type not in valid_types: - raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}") - - if image_type == "structure-overlay": - return await _get_structure_overlay(session_id) - - if image_type == "columns-overlay": - return await _get_columns_overlay(session_id) - - if image_type == "rows-overlay": - return await _get_rows_overlay(session_id) - - if image_type == "words-overlay": - return await _get_words_overlay(session_id) - - # Try cache first for fast serving - cached = _cache.get(session_id) - if cached: - png_key = f"{image_type}_png" if image_type != "original" else None - bgr_key = f"{image_type}_bgr" if image_type != "binarized" else None - - # For binarized, check if we have it cached as PNG - if image_type == "binarized" and cached.get("binarized_png"): - return Response(content=cached["binarized_png"], media_type="image/png") - - # Load from DB — for cropped/dewarped, fall back through the chain - if image_type in ("cropped", "dewarped"): - data = await _get_base_image_png(session_id) - else: - data = await get_session_image(session_id, image_type) - if not data: - raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet") - - return Response(content=data, media_type="image/png") - - -# --------------------------------------------------------------------------- -# Deskew Endpoints -# --------------------------------------------------------------------------- - -@router.post("/sessions/{session_id}/deskew") -async def auto_deskew(session_id: str): - """Two-pass deskew: iterative projection (wide range) + word-alignment residual.""" - # Ensure session is in cache - if session_id not in _cache: - await _load_session_to_cache(session_id) - cached = _get_cached(session_id) - - # Deskew runs right after orientation — use oriented image, fall back to original - img_bgr = next((v for k in ("oriented_bgr", "original_bgr") - if (v := cached.get(k)) is not 
None), None) - if img_bgr is None: - raise HTTPException(status_code=400, detail="No image available for deskewing") - - t0 = time.time() - - # Two-pass deskew: iterative (±5°) + word-alignment residual check - deskewed_bgr, angle_applied, two_pass_debug = deskew_two_pass(img_bgr.copy()) - - # Also run individual methods for reporting (non-authoritative) - try: - _, angle_hough = deskew_image(img_bgr.copy()) - except Exception: - angle_hough = 0.0 - - success_enc, png_orig = cv2.imencode(".png", img_bgr) - orig_bytes = png_orig.tobytes() if success_enc else b"" - try: - _, angle_wa = deskew_image_by_word_alignment(orig_bytes) - except Exception: - angle_wa = 0.0 - - angle_iterative = two_pass_debug.get("pass1_angle", 0.0) - angle_residual = two_pass_debug.get("pass2_angle", 0.0) - angle_textline = two_pass_debug.get("pass3_angle", 0.0) - - duration = time.time() - t0 - - method_used = "three_pass" if abs(angle_textline) >= 0.01 else ( - "two_pass" if abs(angle_residual) >= 0.01 else "iterative" - ) - - # Encode as PNG - success, deskewed_png_buf = cv2.imencode(".png", deskewed_bgr) - deskewed_png = deskewed_png_buf.tobytes() if success else b"" - - # Create binarized version - binarized_png = None - try: - binarized = create_ocr_image(deskewed_bgr) - success_bin, bin_buf = cv2.imencode(".png", binarized) - binarized_png = bin_buf.tobytes() if success_bin else None - except Exception as e: - logger.warning(f"Binarization failed: {e}") - - confidence = max(0.5, 1.0 - abs(angle_applied) / 5.0) - - deskew_result = { - "angle_hough": round(angle_hough, 3), - "angle_word_alignment": round(angle_wa, 3), - "angle_iterative": round(angle_iterative, 3), - "angle_residual": round(angle_residual, 3), - "angle_textline": round(angle_textline, 3), - "angle_applied": round(angle_applied, 3), - "method_used": method_used, - "confidence": round(confidence, 2), - "duration_seconds": round(duration, 2), - "two_pass_debug": two_pass_debug, - } - - # Update cache - 
cached["deskewed_bgr"] = deskewed_bgr - cached["binarized_png"] = binarized_png - cached["deskew_result"] = deskew_result - - # Persist to DB - db_update = { - "deskewed_png": deskewed_png, - "deskew_result": deskew_result, - "current_step": 3, - } - if binarized_png: - db_update["binarized_png"] = binarized_png - await update_session_db(session_id, **db_update) - - logger.info(f"OCR Pipeline: deskew session {session_id}: " - f"hough={angle_hough:.2f} wa={angle_wa:.2f} " - f"iter={angle_iterative:.2f} residual={angle_residual:.2f} " - f"textline={angle_textline:.2f} " - f"-> {method_used} total={angle_applied:.2f}") - - await _append_pipeline_log(session_id, "deskew", { - "angle_applied": round(angle_applied, 3), - "angle_iterative": round(angle_iterative, 3), - "angle_residual": round(angle_residual, 3), - "angle_textline": round(angle_textline, 3), - "confidence": round(confidence, 2), - "method": method_used, - }, duration_ms=int(duration * 1000)) - - return { - "session_id": session_id, - **deskew_result, - "deskewed_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed", - "binarized_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/binarized", - } - - -@router.post("/sessions/{session_id}/deskew/manual") -async def manual_deskew(session_id: str, req: ManualDeskewRequest): - """Apply a manual rotation angle to the oriented image.""" - if session_id not in _cache: - await _load_session_to_cache(session_id) - cached = _get_cached(session_id) - - img_bgr = next((v for k in ("oriented_bgr", "original_bgr") - if (v := cached.get(k)) is not None), None) - if img_bgr is None: - raise HTTPException(status_code=400, detail="No image available for deskewing") - - angle = max(-5.0, min(5.0, req.angle)) - - h, w = img_bgr.shape[:2] - center = (w // 2, h // 2) - M = cv2.getRotationMatrix2D(center, angle, 1.0) - rotated = cv2.warpAffine(img_bgr, M, (w, h), - flags=cv2.INTER_LINEAR, - borderMode=cv2.BORDER_REPLICATE) - - success, png_buf = 
cv2.imencode(".png", rotated) - deskewed_png = png_buf.tobytes() if success else b"" - - # Binarize - binarized_png = None - try: - binarized = create_ocr_image(rotated) - success_bin, bin_buf = cv2.imencode(".png", binarized) - binarized_png = bin_buf.tobytes() if success_bin else None - except Exception: - pass - - deskew_result = { - **(cached.get("deskew_result") or {}), - "angle_applied": round(angle, 3), - "method_used": "manual", - } - - # Update cache - cached["deskewed_bgr"] = rotated - cached["binarized_png"] = binarized_png - cached["deskew_result"] = deskew_result - - # Persist to DB - db_update = { - "deskewed_png": deskewed_png, - "deskew_result": deskew_result, - } - if binarized_png: - db_update["binarized_png"] = binarized_png - await update_session_db(session_id, **db_update) - - logger.info(f"OCR Pipeline: manual deskew session {session_id}: {angle:.2f}") - - return { - "session_id": session_id, - "angle_applied": round(angle, 3), - "method_used": "manual", - "deskewed_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed", - } - - -@router.post("/sessions/{session_id}/ground-truth/deskew") -async def save_deskew_ground_truth(session_id: str, req: DeskewGroundTruthRequest): - """Save ground truth feedback for the deskew step.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - ground_truth = session.get("ground_truth") or {} - gt = { - "is_correct": req.is_correct, - "corrected_angle": req.corrected_angle, - "notes": req.notes, - "saved_at": datetime.utcnow().isoformat(), - "deskew_result": session.get("deskew_result"), - } - ground_truth["deskew"] = gt - - await update_session_db(session_id, ground_truth=ground_truth) - - # Update cache - if session_id in _cache: - _cache[session_id]["ground_truth"] = ground_truth - - logger.info(f"OCR Pipeline: ground truth deskew session {session_id}: " - f"correct={req.is_correct}, 
corrected_angle={req.corrected_angle}") - - return {"session_id": session_id, "ground_truth": gt} - - -# --------------------------------------------------------------------------- -# Dewarp Endpoints -# --------------------------------------------------------------------------- - -async def _detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]: - """Ask qwen2.5vl:32b to estimate the vertical shear angle of a scanned page. - - The VLM is shown the image and asked: are the column/table borders tilted? - If yes, by how many degrees? Returns a dict with shear_degrees and confidence. - Confidence is 0.0 if Ollama is unavailable or parsing fails. - """ - import httpx - import base64 - import re - - ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434") - model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b") - - prompt = ( - "This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. " - "Are they perfectly vertical, or do they tilt slightly? " - "If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). " - "Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} " - "Use confidence 0.0-1.0 based on how clearly you can see the tilt. 
" - "If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}" - ) - - img_b64 = base64.b64encode(image_bytes).decode("utf-8") - payload = { - "model": model, - "prompt": prompt, - "images": [img_b64], - "stream": False, - } - - try: - async with httpx.AsyncClient(timeout=60.0) as client: - resp = await client.post(f"{ollama_base}/api/generate", json=payload) - resp.raise_for_status() - text = resp.json().get("response", "") - - # Parse JSON from response (may have surrounding text) - match = re.search(r'\{[^}]+\}', text) - if match: - import json - data = json.loads(match.group(0)) - shear = float(data.get("shear_degrees", 0.0)) - conf = float(data.get("confidence", 0.0)) - # Clamp to reasonable range - shear = max(-3.0, min(3.0, shear)) - conf = max(0.0, min(1.0, conf)) - return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)} - except Exception as e: - logger.warning(f"VLM dewarp failed: {e}") - - return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0} - - -@router.post("/sessions/{session_id}/dewarp") -async def auto_dewarp( - session_id: str, - method: str = Query("ensemble", description="Detection method: ensemble | vlm | cv"), -): - """Detect and correct vertical shear on the deskewed image. 
- - Methods: - - **ensemble** (default): 3-method CV ensemble (vertical edges + projection + Hough) - - **cv**: CV ensemble only (same as ensemble) - - **vlm**: Ask qwen2.5vl:32b to estimate the shear angle visually - """ - if method not in ("ensemble", "cv", "vlm"): - raise HTTPException(status_code=400, detail="method must be one of: ensemble, cv, vlm") - - if session_id not in _cache: - await _load_session_to_cache(session_id) - cached = _get_cached(session_id) - - deskewed_bgr = cached.get("deskewed_bgr") - if deskewed_bgr is None: - raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp") - - t0 = time.time() - - if method == "vlm": - # Encode deskewed image to PNG for VLM - success, png_buf = cv2.imencode(".png", deskewed_bgr) - img_bytes = png_buf.tobytes() if success else b"" - vlm_det = await _detect_shear_with_vlm(img_bytes) - shear_deg = vlm_det["shear_degrees"] - if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3: - dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg) - else: - dewarped_bgr = deskewed_bgr - dewarp_info = { - "method": vlm_det["method"], - "shear_degrees": shear_deg, - "confidence": vlm_det["confidence"], - "detections": [vlm_det], - } - else: - dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr) - - duration = time.time() - t0 - - # Encode as PNG - success, png_buf = cv2.imencode(".png", dewarped_bgr) - dewarped_png = png_buf.tobytes() if success else b"" - - dewarp_result = { - "method_used": dewarp_info["method"], - "shear_degrees": dewarp_info["shear_degrees"], - "confidence": dewarp_info["confidence"], - "duration_seconds": round(duration, 2), - "detections": dewarp_info.get("detections", []), - } - - # Update cache - cached["dewarped_bgr"] = dewarped_bgr - cached["dewarp_result"] = dewarp_result - - # Persist to DB - await update_session_db( - session_id, - dewarped_png=dewarped_png, - dewarp_result=dewarp_result, - auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0), - current_step=4, 
- ) - - logger.info(f"OCR Pipeline: dewarp session {session_id}: " - f"method={dewarp_info['method']} shear={dewarp_info['shear_degrees']:.3f} " - f"conf={dewarp_info['confidence']:.2f} ({duration:.2f}s)") - - await _append_pipeline_log(session_id, "dewarp", { - "shear_degrees": dewarp_info["shear_degrees"], - "confidence": dewarp_info["confidence"], - "method": dewarp_info["method"], - "ensemble_methods": [d.get("method", "") for d in dewarp_info.get("detections", [])], - }, duration_ms=int(duration * 1000)) - - return { - "session_id": session_id, - **dewarp_result, - "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped", - } - - -@router.post("/sessions/{session_id}/dewarp/manual") -async def manual_dewarp(session_id: str, req: ManualDewarpRequest): - """Apply shear correction with a manual angle.""" - if session_id not in _cache: - await _load_session_to_cache(session_id) - cached = _get_cached(session_id) - - deskewed_bgr = cached.get("deskewed_bgr") - if deskewed_bgr is None: - raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp") - - shear_deg = max(-2.0, min(2.0, req.shear_degrees)) - - if abs(shear_deg) < 0.001: - dewarped_bgr = deskewed_bgr - else: - dewarped_bgr = dewarp_image_manual(deskewed_bgr, shear_deg) - - success, png_buf = cv2.imencode(".png", dewarped_bgr) - dewarped_png = png_buf.tobytes() if success else b"" - - dewarp_result = { - **(cached.get("dewarp_result") or {}), - "method_used": "manual", - "shear_degrees": round(shear_deg, 3), - } - - # Update cache - cached["dewarped_bgr"] = dewarped_bgr - cached["dewarp_result"] = dewarp_result - - # Persist to DB - await update_session_db( - session_id, - dewarped_png=dewarped_png, - dewarp_result=dewarp_result, - ) - - logger.info(f"OCR Pipeline: manual dewarp session {session_id}: shear={shear_deg:.3f}") - - return { - "session_id": session_id, - "shear_degrees": round(shear_deg, 3), - "method_used": "manual", - "dewarped_image_url": 
f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped", - } - - -@router.post("/sessions/{session_id}/adjust-combined") -async def adjust_combined(session_id: str, req: CombinedAdjustRequest): - """Apply rotation + shear combined to the original image. - - Used by the fine-tuning sliders to preview arbitrary rotation/shear - combinations without re-running the full deskew/dewarp pipeline. - """ - if session_id not in _cache: - await _load_session_to_cache(session_id) - cached = _get_cached(session_id) - - img_bgr = cached.get("original_bgr") - if img_bgr is None: - raise HTTPException(status_code=400, detail="Original image not available") - - rotation = max(-15.0, min(15.0, req.rotation_degrees)) - shear_deg = max(-5.0, min(5.0, req.shear_degrees)) - - h, w = img_bgr.shape[:2] - result_bgr = img_bgr - - # Step 1: Apply rotation - if abs(rotation) >= 0.001: - center = (w // 2, h // 2) - M = cv2.getRotationMatrix2D(center, rotation, 1.0) - result_bgr = cv2.warpAffine(result_bgr, M, (w, h), - flags=cv2.INTER_LINEAR, - borderMode=cv2.BORDER_REPLICATE) - - # Step 2: Apply shear - if abs(shear_deg) >= 0.001: - result_bgr = dewarp_image_manual(result_bgr, shear_deg) - - # Encode - success, png_buf = cv2.imencode(".png", result_bgr) - dewarped_png = png_buf.tobytes() if success else b"" - - # Binarize - binarized_png = None - try: - binarized = create_ocr_image(result_bgr) - success_bin, bin_buf = cv2.imencode(".png", binarized) - binarized_png = bin_buf.tobytes() if success_bin else None - except Exception: - pass - - # Build combined result dicts - deskew_result = { - **(cached.get("deskew_result") or {}), - "angle_applied": round(rotation, 3), - "method_used": "manual_combined", - } - dewarp_result = { - **(cached.get("dewarp_result") or {}), - "method_used": "manual_combined", - "shear_degrees": round(shear_deg, 3), - } - - # Update cache - cached["deskewed_bgr"] = result_bgr - cached["dewarped_bgr"] = result_bgr - cached["deskew_result"] = deskew_result - 
cached["dewarp_result"] = dewarp_result - - # Persist to DB - db_update = { - "dewarped_png": dewarped_png, - "deskew_result": deskew_result, - "dewarp_result": dewarp_result, - } - if binarized_png: - db_update["binarized_png"] = binarized_png - db_update["deskewed_png"] = dewarped_png - await update_session_db(session_id, **db_update) - - logger.info(f"OCR Pipeline: combined adjust session {session_id}: " - f"rotation={rotation:.3f} shear={shear_deg:.3f}") - - return { - "session_id": session_id, - "rotation_degrees": round(rotation, 3), - "shear_degrees": round(shear_deg, 3), - "method_used": "manual_combined", - "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped", - } - - -@router.post("/sessions/{session_id}/ground-truth/dewarp") -async def save_dewarp_ground_truth(session_id: str, req: DewarpGroundTruthRequest): - """Save ground truth feedback for the dewarp step.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - ground_truth = session.get("ground_truth") or {} - gt = { - "is_correct": req.is_correct, - "corrected_shear": req.corrected_shear, - "notes": req.notes, - "saved_at": datetime.utcnow().isoformat(), - "dewarp_result": session.get("dewarp_result"), - } - ground_truth["dewarp"] = gt - - await update_session_db(session_id, ground_truth=ground_truth) - - if session_id in _cache: - _cache[session_id]["ground_truth"] = ground_truth - - logger.info(f"OCR Pipeline: ground truth dewarp session {session_id}: " - f"correct={req.is_correct}, corrected_shear={req.corrected_shear}") - - return {"session_id": session_id, "ground_truth": gt} - - -# --------------------------------------------------------------------------- -# Document Type Detection (between Dewarp and Columns) -# --------------------------------------------------------------------------- - -@router.post("/sessions/{session_id}/detect-type") -async def 
detect_type(session_id: str): - """Detect document type (vocab_table, full_text, generic_table). - - Should be called after crop (clean image available). - Falls back to dewarped if crop was skipped. - Stores result in session for frontend to decide pipeline flow. - """ - if session_id not in _cache: - await _load_session_to_cache(session_id) - cached = _get_cached(session_id) - - img_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") - if img_bgr is None: - raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first") - - t0 = time.time() - ocr_img = create_ocr_image(img_bgr) - result = detect_document_type(ocr_img, img_bgr) - duration = time.time() - t0 - - result_dict = { - "doc_type": result.doc_type, - "confidence": result.confidence, - "pipeline": result.pipeline, - "skip_steps": result.skip_steps, - "features": result.features, - "duration_seconds": round(duration, 2), - } - - # Persist to DB - await update_session_db( - session_id, - doc_type=result.doc_type, - doc_type_result=result_dict, - ) - - cached["doc_type_result"] = result_dict - - logger.info(f"OCR Pipeline: detect-type session {session_id}: " - f"{result.doc_type} (confidence={result.confidence}, {duration:.2f}s)") - - await _append_pipeline_log(session_id, "detect_type", { - "doc_type": result.doc_type, - "pipeline": result.pipeline, - "confidence": result.confidence, - **{k: v for k, v in (result.features or {}).items() if isinstance(v, (int, float, str, bool))}, - }, duration_ms=int(duration * 1000)) - - return {"session_id": session_id, **result_dict} - - -# --------------------------------------------------------------------------- -# Border-ghost word filter -# --------------------------------------------------------------------------- - -# Characters that OCR produces when reading box-border lines. 
-_BORDER_GHOST_CHARS = set("|1lI![](){}iíì/\\-—–_~.,;:'\"") - - -def _filter_border_ghost_words( - word_result: Dict, - boxes: List, -) -> int: - """Remove OCR words that are actually box border lines. - - A word is considered a border ghost when it sits on a known box edge - (left, right, top, or bottom) and looks like a line artefact (narrow - aspect ratio or text consists only of line-like characters). - - After removing ghost cells, columns that have become empty are also - removed from ``columns_used`` so the grid no longer shows phantom - columns. - - Modifies *word_result* in-place and returns the number of removed cells. - """ - if not boxes or not word_result: - return 0 - - cells = word_result.get("cells") - if not cells: - return 0 - - # Build border bands — vertical (X) and horizontal (Y) - x_bands = [] # list of (x_lo, x_hi) - y_bands = [] # list of (y_lo, y_hi) - for b in boxes: - bx = b.x if hasattr(b, "x") else b.get("x", 0) - by = b.y if hasattr(b, "y") else b.get("y", 0) - bw = b.width if hasattr(b, "width") else b.get("w", b.get("width", 0)) - bh = b.height if hasattr(b, "height") else b.get("h", b.get("height", 0)) - bt = b.border_thickness if hasattr(b, "border_thickness") else b.get("border_thickness", 3) - margin = max(bt * 2, 10) + 6 # generous margin - - # Vertical edges (left / right) - x_bands.append((bx - margin, bx + margin)) - x_bands.append((bx + bw - margin, bx + bw + margin)) - # Horizontal edges (top / bottom) - y_bands.append((by - margin, by + margin)) - y_bands.append((by + bh - margin, by + bh + margin)) - - img_w = word_result.get("image_width", 1) - img_h = word_result.get("image_height", 1) - - def _is_ghost(cell: Dict) -> bool: - text = (cell.get("text") or "").strip() - if not text: - return False - - # Compute absolute pixel position - if cell.get("bbox_px"): - px = cell["bbox_px"] - cx = px["x"] + px["w"] / 2 - cy = px["y"] + px["h"] / 2 - cw = px["w"] - ch = px["h"] - elif cell.get("bbox_pct"): - pct = cell["bbox_pct"] 
- cx = (pct["x"] / 100) * img_w + (pct["w"] / 100) * img_w / 2 - cy = (pct["y"] / 100) * img_h + (pct["h"] / 100) * img_h / 2 - cw = (pct["w"] / 100) * img_w - ch = (pct["h"] / 100) * img_h - else: - return False - - # Check if center sits on a vertical or horizontal border - on_vertical = any(lo <= cx <= hi for lo, hi in x_bands) - on_horizontal = any(lo <= cy <= hi for lo, hi in y_bands) - if not on_vertical and not on_horizontal: - return False - - # Very short text (1-2 chars) on a border → very likely ghost - if len(text) <= 2: - # Narrow vertically (line-like) or narrow horizontally (dash-like)? - if ch > 0 and cw / ch < 0.5: - return True - if cw > 0 and ch / cw < 0.5: - return True - # Text is only border-ghost characters? - if all(c in _BORDER_GHOST_CHARS for c in text): - return True - - # Longer text but still only ghost chars and very narrow - if all(c in _BORDER_GHOST_CHARS for c in text): - if ch > 0 and cw / ch < 0.35: - return True - if cw > 0 and ch / cw < 0.35: - return True - return True # all ghost chars on a border → remove - - return False - - before = len(cells) - word_result["cells"] = [c for c in cells if not _is_ghost(c)] - removed = before - len(word_result["cells"]) - - # --- Remove empty columns from columns_used --- - columns_used = word_result.get("columns_used") - if removed and columns_used and len(columns_used) > 1: - remaining_cells = word_result["cells"] - occupied_cols = {c.get("col_index") for c in remaining_cells} - before_cols = len(columns_used) - columns_used = [col for col in columns_used if col.get("index") in occupied_cols] - - # Re-index columns and remap cell col_index values - if len(columns_used) < before_cols: - old_to_new = {} - for new_i, col in enumerate(columns_used): - old_to_new[col["index"]] = new_i - col["index"] = new_i - for cell in remaining_cells: - old_ci = cell.get("col_index") - if old_ci in old_to_new: - cell["col_index"] = old_to_new[old_ci] - word_result["columns_used"] = columns_used - 
logger.info("border-ghost: removed %d empty column(s), %d remaining", - before_cols - len(columns_used), len(columns_used)) - - if removed: - # Update summary counts - summary = word_result.get("summary", {}) - summary["total_cells"] = len(word_result["cells"]) - summary["non_empty_cells"] = sum(1 for c in word_result["cells"] if c.get("text")) - word_result["summary"] = summary - gs = word_result.get("grid_shape", {}) - gs["total_cells"] = len(word_result["cells"]) - if columns_used is not None: - gs["cols"] = len(columns_used) - word_result["grid_shape"] = gs - - return removed - - -# --------------------------------------------------------------------------- -# Structure Detection Endpoint -# --------------------------------------------------------------------------- - -@router.post("/sessions/{session_id}/detect-structure") -async def detect_structure(session_id: str): - """Detect document structure: boxes, zones, and color regions. - - Runs box detection (line + shading) and color analysis on the cropped - image. Returns structured JSON with all detected elements for the - structure visualization step. 
- """ - if session_id not in _cache: - await _load_session_to_cache(session_id) - cached = _get_cached(session_id) - - img_bgr = ( - cached.get("cropped_bgr") - if cached.get("cropped_bgr") is not None - else cached.get("dewarped_bgr") - ) - if img_bgr is None: - raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first") - - t0 = time.time() - h, w = img_bgr.shape[:2] - - # --- Content bounds from word result (if available) or full image --- - word_result = cached.get("word_result") - words: List[Dict] = [] - if word_result and word_result.get("cells"): - for cell in word_result["cells"]: - for wb in (cell.get("word_boxes") or []): - words.append(wb) - # Fallback: use raw OCR words if cell word_boxes are empty - if not words and word_result: - for key in ("raw_paddle_words_split", "raw_tesseract_words", "raw_paddle_words"): - raw = word_result.get(key, []) - if raw: - words = raw - logger.info("detect-structure: using %d words from %s (no cell word_boxes)", len(words), key) - break - # If no words yet, use image dimensions with small margin - if words: - content_x = max(0, min(int(wb["left"]) for wb in words)) - content_y = max(0, min(int(wb["top"]) for wb in words)) - content_r = min(w, max(int(wb["left"] + wb["width"]) for wb in words)) - content_b = min(h, max(int(wb["top"] + wb["height"]) for wb in words)) - content_w_px = content_r - content_x - content_h_px = content_b - content_y - else: - margin = int(min(w, h) * 0.03) - content_x, content_y = margin, margin - content_w_px = w - 2 * margin - content_h_px = h - 2 * margin - - # --- Box detection --- - boxes = detect_boxes( - img_bgr, - content_x=content_x, - content_w=content_w_px, - content_y=content_y, - content_h=content_h_px, - ) - - # --- Zone splitting --- - from cv_box_detect import split_page_into_zones as _split_zones - zones = _split_zones(content_x, content_y, content_w_px, content_h_px, boxes) - - # --- Color region sampling --- - # Sample background shading in each 
detected box - hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV) - box_colors = [] - for box in boxes: - # Sample the center region of each box - cy1 = box.y + box.height // 4 - cy2 = box.y + 3 * box.height // 4 - cx1 = box.x + box.width // 4 - cx2 = box.x + 3 * box.width // 4 - cy1 = max(0, min(cy1, h - 1)) - cy2 = max(0, min(cy2, h - 1)) - cx1 = max(0, min(cx1, w - 1)) - cx2 = max(0, min(cx2, w - 1)) - if cy2 > cy1 and cx2 > cx1: - roi_hsv = hsv[cy1:cy2, cx1:cx2] - med_h = float(np.median(roi_hsv[:, :, 0])) - med_s = float(np.median(roi_hsv[:, :, 1])) - med_v = float(np.median(roi_hsv[:, :, 2])) - if med_s > 15: - from cv_color_detect import _hue_to_color_name - bg_name = _hue_to_color_name(med_h) - bg_hex = _COLOR_HEX.get(bg_name, "#6b7280") - else: - bg_name = "gray" if med_v < 220 else "white" - bg_hex = "#6b7280" if bg_name == "gray" else "#ffffff" - else: - bg_name = "unknown" - bg_hex = "#6b7280" - box_colors.append({"color_name": bg_name, "color_hex": bg_hex}) - - # --- Color text detection overview --- - # Quick scan for colored text regions across the page - color_summary: Dict[str, int] = {} - for color_name, ranges in _COLOR_RANGES.items(): - mask = np.zeros((h, w), dtype=np.uint8) - for lower, upper in ranges: - mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper)) - pixel_count = int(np.sum(mask > 0)) - if pixel_count > 50: # minimum threshold - color_summary[color_name] = pixel_count - - # --- Graphic element detection --- - box_dicts = [ - {"x": b.x, "y": b.y, "w": b.width, "h": b.height} - for b in boxes - ] - graphics = detect_graphic_elements( - img_bgr, words, - detected_boxes=box_dicts, - ) - - # --- Filter border-ghost words from OCR result --- - ghost_count = 0 - if boxes and word_result: - ghost_count = _filter_border_ghost_words(word_result, boxes) - if ghost_count: - logger.info("detect-structure: removed %d border-ghost words", ghost_count) - await update_session_db(session_id, word_result=word_result) - cached["word_result"] = 
word_result - - duration = time.time() - t0 - - result_dict = { - "image_width": w, - "image_height": h, - "content_bounds": { - "x": content_x, "y": content_y, - "w": content_w_px, "h": content_h_px, - }, - "boxes": [ - { - "x": b.x, "y": b.y, "w": b.width, "h": b.height, - "confidence": b.confidence, - "border_thickness": b.border_thickness, - "bg_color_name": box_colors[i]["color_name"], - "bg_color_hex": box_colors[i]["color_hex"], - } - for i, b in enumerate(boxes) - ], - "zones": [ - { - "index": z.index, - "zone_type": z.zone_type, - "y": z.y, "h": z.height, - "x": z.x, "w": z.width, - } - for z in zones - ], - "graphics": [ - { - "x": g.x, "y": g.y, "w": g.width, "h": g.height, - "area": g.area, - "shape": g.shape, - "color_name": g.color_name, - "color_hex": g.color_hex, - "confidence": round(g.confidence, 2), - } - for g in graphics - ], - "color_pixel_counts": color_summary, - "has_words": len(words) > 0, - "word_count": len(words), - "border_ghosts_removed": ghost_count, - "duration_seconds": round(duration, 2), - } - - # Persist to session - await update_session_db(session_id, structure_result=result_dict) - cached["structure_result"] = result_dict - - logger.info("detect-structure session %s: %d boxes, %d zones, %d graphics, %.2fs", - session_id, len(boxes), len(zones), len(graphics), duration) - - return {"session_id": session_id, **result_dict} - - -# --------------------------------------------------------------------------- -# Column Detection Endpoints (Step 3) -# --------------------------------------------------------------------------- - -@router.post("/sessions/{session_id}/columns") -async def detect_columns(session_id: str): - """Run column detection on the cropped (or dewarped) image.""" - if session_id not in _cache: - await _load_session_to_cache(session_id) - cached = _get_cached(session_id) - - img_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") - if img_bgr is None: - raise 
HTTPException(status_code=400, detail="Crop or dewarp must be completed before column detection") - - # ----------------------------------------------------------------------- - # Sub-sessions (box crops): skip column detection entirely. - # Instead, create a single pseudo-column spanning the full image width. - # Also run Tesseract + binarization here so that the row detection step - # can reuse the cached intermediates (_word_dicts, _inv, _content_bounds) - # instead of falling back to detect_column_geometry() which may fail - # on small box images with < 5 words. - # ----------------------------------------------------------------------- - session = await get_session_db(session_id) - if session and session.get("parent_session_id"): - h, w = img_bgr.shape[:2] - - # Binarize + invert for row detection (horizontal projection profile) - ocr_img = create_ocr_image(img_bgr) - inv = cv2.bitwise_not(ocr_img) - - # Run Tesseract to get word bounding boxes. - # Word positions are relative to the full image (no ROI crop needed - # because the sub-session image IS the cropped box already). - # detect_row_geometry expects word positions relative to content ROI, - # so with content_bounds = (0, w, 0, h) the coordinates are correct. 
- try: - from PIL import Image as PILImage - pil_img = PILImage.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)) - import pytesseract - data = pytesseract.image_to_data(pil_img, lang='eng+deu', output_type=pytesseract.Output.DICT) - word_dicts = [] - for i in range(len(data['text'])): - conf = int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1 - text = str(data['text'][i]).strip() - if conf < 30 or not text: - continue - word_dicts.append({ - 'text': text, 'conf': conf, - 'left': int(data['left'][i]), - 'top': int(data['top'][i]), - 'width': int(data['width'][i]), - 'height': int(data['height'][i]), - }) - # Log all words including low-confidence ones for debugging - all_count = sum(1 for i in range(len(data['text'])) - if str(data['text'][i]).strip()) - low_conf = [(str(data['text'][i]).strip(), int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1) - for i in range(len(data['text'])) - if str(data['text'][i]).strip() - and (int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1) < 30 - and (int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1) >= 0] - if low_conf: - logger.info(f"OCR Pipeline: sub-session {session_id}: {len(low_conf)} words below conf 30: {low_conf[:20]}") - logger.info(f"OCR Pipeline: sub-session {session_id}: Tesseract found {len(word_dicts)}/{all_count} words (conf>=30)") - except Exception as e: - logger.warning(f"OCR Pipeline: sub-session {session_id}: Tesseract failed: {e}") - word_dicts = [] - - # Cache intermediates for row detection (detect_rows reuses these) - cached["_word_dicts"] = word_dicts - cached["_inv"] = inv - cached["_content_bounds"] = (0, w, 0, h) - - column_result = { - "columns": [{ - "type": "column_text", - "x": 0, "y": 0, - "width": w, "height": h, - }], - "zones": None, - "boxes_detected": 0, - "duration_seconds": 0, - "method": "sub_session_pseudo_column", - } - await update_session_db( - session_id, - 
column_result=column_result, - row_result=None, - word_result=None, - current_step=6, - ) - cached["column_result"] = column_result - cached.pop("row_result", None) - cached.pop("word_result", None) - logger.info(f"OCR Pipeline: sub-session {session_id}: pseudo-column {w}x{h}px") - return {"session_id": session_id, **column_result} - - t0 = time.time() - - # Binarized image for layout analysis - ocr_img = create_ocr_image(img_bgr) - h, w = ocr_img.shape[:2] - - # Phase A: Zone-aware geometry detection - zoned_result = detect_column_geometry_zoned(ocr_img, img_bgr) - - if zoned_result is None: - # Fallback to projection-based layout - layout_img = create_layout_image(img_bgr) - regions = analyze_layout(layout_img, ocr_img) - zones_data = None - boxes_detected = 0 - else: - geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv, zones_data, boxes = zoned_result - content_w = right_x - left_x - boxes_detected = len(boxes) - - # Cache intermediates for row detection (avoids second Tesseract run) - cached["_word_dicts"] = word_dicts - cached["_inv"] = inv - cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y) - cached["_zones_data"] = zones_data - cached["_boxes_detected"] = boxes_detected - - # Detect header/footer early so sub-column clustering ignores them - header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None) - - # Split sub-columns (e.g. 
page references) before classification - geometries = _detect_sub_columns(geometries, content_w, left_x=left_x, - top_y=top_y, header_y=header_y, footer_y=footer_y) - - # Expand narrow columns (sub-columns are often very narrow) - geometries = expand_narrow_columns(geometries, content_w, left_x, word_dicts) - - # Phase B: Content-based classification - regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y, - left_x=left_x, right_x=right_x, inv=inv) - - duration = time.time() - t0 - - columns = [asdict(r) for r in regions] - - # Determine classification methods used - methods = list(set( - c.get("classification_method", "") for c in columns - if c.get("classification_method") - )) - - column_result = { - "columns": columns, - "classification_methods": methods, - "duration_seconds": round(duration, 2), - "boxes_detected": boxes_detected, - } - - # Add zone data when boxes are present - if zones_data and boxes_detected > 0: - column_result["zones"] = zones_data - - # Persist to DB — also invalidate downstream results (rows, words) - await update_session_db( - session_id, - column_result=column_result, - row_result=None, - word_result=None, - current_step=6, - ) - - # Update cache - cached["column_result"] = column_result - cached.pop("row_result", None) - cached.pop("word_result", None) - - col_count = len([c for c in columns if c["type"].startswith("column")]) - logger.info(f"OCR Pipeline: columns session {session_id}: " - f"{col_count} columns detected, {boxes_detected} box(es) ({duration:.2f}s)") - - img_w = img_bgr.shape[1] - await _append_pipeline_log(session_id, "columns", { - "total_columns": len(columns), - "column_widths_pct": [round(c["width"] / img_w * 100, 1) for c in columns], - "column_types": [c["type"] for c in columns], - "boxes_detected": boxes_detected, - }, duration_ms=int(duration * 1000)) - - return { - "session_id": session_id, - **column_result, - } - - -@router.post("/sessions/{session_id}/columns/manual") -async def 
set_manual_columns(session_id: str, req: ManualColumnsRequest): - """Override detected columns with manual definitions.""" - column_result = { - "columns": req.columns, - "duration_seconds": 0, - "method": "manual", - } - - await update_session_db(session_id, column_result=column_result, - row_result=None, word_result=None) - - if session_id in _cache: - _cache[session_id]["column_result"] = column_result - _cache[session_id].pop("row_result", None) - _cache[session_id].pop("word_result", None) - - logger.info(f"OCR Pipeline: manual columns session {session_id}: " - f"{len(req.columns)} columns set") - - return {"session_id": session_id, **column_result} - - -@router.post("/sessions/{session_id}/ground-truth/columns") -async def save_column_ground_truth(session_id: str, req: ColumnGroundTruthRequest): - """Save ground truth feedback for the column detection step.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - ground_truth = session.get("ground_truth") or {} - gt = { - "is_correct": req.is_correct, - "corrected_columns": req.corrected_columns, - "notes": req.notes, - "saved_at": datetime.utcnow().isoformat(), - "column_result": session.get("column_result"), - } - ground_truth["columns"] = gt - - await update_session_db(session_id, ground_truth=ground_truth) - - if session_id in _cache: - _cache[session_id]["ground_truth"] = ground_truth - - return {"session_id": session_id, "ground_truth": gt} - - -@router.get("/sessions/{session_id}/ground-truth/columns") -async def get_column_ground_truth(session_id: str): - """Retrieve saved ground truth for column detection, including auto vs GT diff.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - ground_truth = session.get("ground_truth") or {} - columns_gt = ground_truth.get("columns") - if not columns_gt: - raise 
HTTPException(status_code=404, detail="No column ground truth saved") - - return { - "session_id": session_id, - "columns_gt": columns_gt, - "columns_auto": session.get("column_result"), - } - - -def _draw_box_exclusion_overlay( - img: np.ndarray, - zones: List[Dict], - *, - label: str = "BOX — separat verarbeitet", -) -> None: - """Draw red semi-transparent rectangles over box zones (in-place). - - Reusable for columns, rows, and words overlays. - """ - for zone in zones: - if zone.get("zone_type") != "box" or not zone.get("box"): - continue - box = zone["box"] - bx, by = box["x"], box["y"] - bw, bh = box["width"], box["height"] - - # Red semi-transparent fill (~25 %) - box_overlay = img.copy() - cv2.rectangle(box_overlay, (bx, by), (bx + bw, by + bh), (0, 0, 200), -1) - cv2.addWeighted(box_overlay, 0.25, img, 0.75, 0, img) - - # Border - cv2.rectangle(img, (bx, by), (bx + bw, by + bh), (0, 0, 200), 2) - - # Label - cv2.putText(img, label, (bx + 10, by + bh - 10), - cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2) - - -async def _get_structure_overlay(session_id: str) -> Response: - """Generate overlay image showing detected boxes, zones, and color regions.""" - base_png = await _get_base_image_png(session_id) - if not base_png: - raise HTTPException(status_code=404, detail="No base image available") - - arr = np.frombuffer(base_png, dtype=np.uint8) - img = cv2.imdecode(arr, cv2.IMREAD_COLOR) - if img is None: - raise HTTPException(status_code=500, detail="Failed to decode image") - - h, w = img.shape[:2] - - # Get structure result (run detection if not cached) - session = await get_session_db(session_id) - structure = (session or {}).get("structure_result") - - if not structure: - # Run detection on-the-fly - margin = int(min(w, h) * 0.03) - content_x, content_y = margin, margin - content_w_px = w - 2 * margin - content_h_px = h - 2 * margin - boxes = detect_boxes(img, content_x, content_w_px, content_y, content_h_px) - zones = split_page_into_zones(content_x, 
content_y, content_w_px, content_h_px, boxes) - structure = { - "boxes": [ - {"x": b.x, "y": b.y, "w": b.width, "h": b.height, - "confidence": b.confidence, "border_thickness": b.border_thickness} - for b in boxes - ], - "zones": [ - {"index": z.index, "zone_type": z.zone_type, - "y": z.y, "h": z.height, "x": z.x, "w": z.width} - for z in zones - ], - } - - overlay = img.copy() - - # --- Draw zone boundaries --- - zone_colors = { - "content": (200, 200, 200), # light gray - "box": (255, 180, 0), # blue-ish (BGR) - } - for zone in structure.get("zones", []): - zx = zone["x"] - zy = zone["y"] - zw = zone["w"] - zh = zone["h"] - color = zone_colors.get(zone["zone_type"], (200, 200, 200)) - - # Draw zone boundary as dashed line - dash_len = 12 - for edge_x in range(zx, zx + zw, dash_len * 2): - end_x = min(edge_x + dash_len, zx + zw) - cv2.line(img, (edge_x, zy), (end_x, zy), color, 1) - cv2.line(img, (edge_x, zy + zh), (end_x, zy + zh), color, 1) - - # Zone label - zone_label = f"Zone {zone['index']} ({zone['zone_type']})" - cv2.putText(img, zone_label, (zx + 5, zy + 15), - cv2.FONT_HERSHEY_SIMPLEX, 0.45, color, 1) - - # --- Draw detected boxes --- - # Color map for box backgrounds (BGR) - bg_hex_to_bgr = { - "#dc2626": (38, 38, 220), # red - "#2563eb": (235, 99, 37), # blue - "#16a34a": (74, 163, 22), # green - "#ea580c": (12, 88, 234), # orange - "#9333ea": (234, 51, 147), # purple - "#ca8a04": (4, 138, 202), # yellow - "#6b7280": (128, 114, 107), # gray - } - - for box_data in structure.get("boxes", []): - bx = box_data["x"] - by = box_data["y"] - bw = box_data["w"] - bh = box_data["h"] - conf = box_data.get("confidence", 0) - thickness = box_data.get("border_thickness", 0) - bg_hex = box_data.get("bg_color_hex", "#6b7280") - bg_name = box_data.get("bg_color_name", "") - - # Box fill color - fill_bgr = bg_hex_to_bgr.get(bg_hex, (128, 114, 107)) - - # Semi-transparent fill - cv2.rectangle(overlay, (bx, by), (bx + bw, by + bh), fill_bgr, -1) - - # Solid border - 
border_color = fill_bgr - cv2.rectangle(img, (bx, by), (bx + bw, by + bh), border_color, 3) - - # Label - label = f"BOX" - if bg_name and bg_name not in ("unknown", "white"): - label += f" ({bg_name})" - if thickness > 0: - label += f" border={thickness}px" - label += f" {int(conf * 100)}%" - cv2.putText(img, label, (bx + 8, by + 22), - cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 2) - cv2.putText(img, label, (bx + 8, by + 22), - cv2.FONT_HERSHEY_SIMPLEX, 0.55, border_color, 1) - - # Blend overlay at 15% opacity - cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img) - - # --- Draw color regions (HSV masks) --- - hsv = cv2.cvtColor( - cv2.imdecode(np.frombuffer(base_png, dtype=np.uint8), cv2.IMREAD_COLOR), - cv2.COLOR_BGR2HSV, - ) - color_bgr_map = { - "red": (0, 0, 255), - "orange": (0, 140, 255), - "yellow": (0, 200, 255), - "green": (0, 200, 0), - "blue": (255, 150, 0), - "purple": (200, 0, 200), - } - for color_name, ranges in _COLOR_RANGES.items(): - mask = np.zeros((h, w), dtype=np.uint8) - for lower, upper in ranges: - mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper)) - # Only draw if there are significant colored pixels - if np.sum(mask > 0) < 100: - continue - # Draw colored contours - contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) - draw_color = color_bgr_map.get(color_name, (200, 200, 200)) - for cnt in contours: - area = cv2.contourArea(cnt) - if area < 20: - continue - cv2.drawContours(img, [cnt], -1, draw_color, 2) - - # --- Draw graphic elements --- - graphics_data = structure.get("graphics", []) - shape_icons = { - "image": "IMAGE", - "illustration": "ILLUST", - } - for gfx in graphics_data: - gx, gy = gfx["x"], gfx["y"] - gw, gh = gfx["w"], gfx["h"] - shape = gfx.get("shape", "icon") - color_hex = gfx.get("color_hex", "#6b7280") - conf = gfx.get("confidence", 0) - - # Pick draw color based on element color (BGR) - gfx_bgr = bg_hex_to_bgr.get(color_hex, (128, 114, 107)) - - # Draw bounding box (dashed 
async def _get_columns_overlay(session_id: str) -> Response:
    """Return the session's base image as PNG with detected columns drawn on it.

    Every column in the stored ``column_result`` gets a 20 % opacity fill,
    a solid border and a type label (suffixed with the classification
    confidence when that is below 1.0).  Zones of type ``box`` are outlined
    with a dashed rectangle and then handed to ``_draw_box_exclusion_overlay``.

    Raises:
        HTTPException: 404 when the session, its column data or a base image
            is missing; 500 when the image cannot be decoded or re-encoded.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    column_result = session.get("column_result")
    if not (column_result and column_result.get("columns")):
        raise HTTPException(status_code=404, detail="No column data available")

    base_png = await _get_base_image_png(session_id)
    if not base_png:
        raise HTTPException(status_code=404, detail="No base image available")

    img = cv2.imdecode(np.frombuffer(base_png, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    # Region type -> BGR colour.
    gray = (128, 128, 128)
    dark_gray = (100, 100, 100)
    fallback_color = (200, 200, 200)
    type_colors = {
        "column_en": (255, 180, 0),         # blue
        "column_de": (0, 200, 0),           # green
        "column_example": (0, 140, 255),    # orange
        "column_text": (200, 200, 0),       # cyan / turquoise
        "page_ref": (200, 0, 200),          # purple
        "column_marker": (0, 0, 220),       # red
        "column_ignore": (180, 180, 180),   # light gray
        "header": gray,
        "footer": gray,
        "margin_top": dark_gray,
        "margin_bottom": dark_gray,
    }

    def _dashed_rect(x0, y0, width, height, bgr, dash, thickness):
        """Draw a dashed rectangle on ``img`` out of short line segments."""
        x1, y1 = x0 + width, y0 + height
        for seg_x in range(x0, x1, dash * 2):
            seg_end = min(seg_x + dash, x1)
            cv2.line(img, (seg_x, y0), (seg_end, y0), bgr, thickness)
            cv2.line(img, (seg_x, y1), (seg_end, y1), bgr, thickness)
        for seg_y in range(y0, y1, dash * 2):
            seg_end = min(seg_y + dash, y1)
            cv2.line(img, (x0, seg_y), (x0, seg_end), bgr, thickness)
            cv2.line(img, (x1, seg_y), (x1, seg_end), bgr, thickness)

    columns = column_result["columns"]

    # Pass 1: filled rectangles go onto a separate copy that is blended in
    # later, so the fills end up semi-transparent.
    overlay = img.copy()
    for col in columns:
        left, top = col["x"], col["y"]
        right, bottom = left + col["width"], top + col["height"]
        fill = type_colors.get(col.get("type", ""), fallback_color)
        cv2.rectangle(overlay, (left, top), (right, bottom), fill, -1)

    # Pass 2: solid borders and labels go directly onto the image.
    for col in columns:
        left, top = col["x"], col["y"]
        right, bottom = left + col["width"], top + col["height"]
        stroke = type_colors.get(col.get("type", ""), fallback_color)
        cv2.rectangle(img, (left, top), (right, bottom), stroke, 3)

        label = col.get("type", "unknown").replace("column_", "").upper()
        conf = col.get("classification_confidence")
        if conf is not None and conf < 1.0:
            label = f"{label} {int(conf * 100)}%"
        cv2.putText(img, label, (left + 10, top + 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, stroke, 2)

    # Blend the fills in at 20 % opacity.
    cv2.addWeighted(overlay, 0.2, img, 0.8, 0, img)

    # Dashed yellow outline plus "BOX" tag for every detected box zone.
    zones = column_result.get("zones") or []
    box_color = (0, 200, 255)  # yellow (BGR)
    for zone in zones:
        if zone.get("zone_type") != "box" or not zone.get("box"):
            continue
        box = zone["box"]
        bx, by = box["x"], box["y"]
        bw, bh = box["width"], box["height"]
        _dashed_rect(bx, by, bw, bh, box_color, 15, 2)
        cv2.putText(img, "BOX", (bx + 10, by + bh - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, box_color, 2)

    # Red semi-transparent overlay for box zones (drawn by the shared helper).
    _draw_box_exclusion_overlay(img, zones)

    encoded_ok, result_png = cv2.imencode(".png", img)
    if not encoded_ok:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")

    return Response(content=result_png.tobytes(), media_type="image/png")
# ---------------------------------------------------------------------------
# Row Detection Endpoints
# ---------------------------------------------------------------------------

def _content_strips_outside_boxes(top_y, bottom_y, box_ranges_inner, min_height=20):
    """Return the ``(y_start, y_end)`` strips of ``[top_y, bottom_y]`` outside all boxes.

    ``box_ranges_inner`` holds the box y-ranges already shrunk by their
    border thickness.  Strips shorter than ``min_height`` pixels are dropped.
    """
    strips = []
    strip_start = top_y
    for by_start, by_end in sorted(box_ranges_inner, key=lambda r: r[0]):
        if by_start > strip_start:
            strips.append((strip_start, by_start))
        strip_start = max(strip_start, by_end)
    if strip_start < bottom_y:
        strips.append((strip_start, bottom_y))
    return [(ys, ye) for ys, ye in strips if ye - ys >= min_height]


def _combined_y_to_abs(cy, strip_offsets, is_end=False):
    """Map a y-coordinate of the stacked strip image back to page coordinates.

    ``strip_offsets`` is a list of ``(combined_y_start, strip_height,
    abs_y_start)`` tuples in stacking order.  A start coordinate lying exactly
    on a strip boundary belongs to the *following* strip.  An end coordinate
    (``is_end=True``) on that boundary belongs to the *preceding* strip;
    otherwise a row that ends at the bottom of a strip would be stretched
    across the excluded box below it.
    """
    for c_start, s_h, abs_start in strip_offsets:
        c_end = c_start + s_h
        if cy < c_end or (is_end and cy == c_end):
            return abs_start + (cy - c_start)
    _last_c, last_h, last_abs = strip_offsets[-1]
    return last_abs + last_h


@router.post("/sessions/{session_id}/rows")
async def detect_rows(session_id: str):
    """Run row detection on the cropped (or dewarped) image.

    Sub-sessions (sessions with a ``parent_session_id``) group words by Y
    proximity.  All other sessions use ``detect_row_geometry``; when the
    stored column result contains box zones, the box interiors are cut out
    of the inverted image first, rows are detected on the stacked remainder
    and then mapped back to page coordinates.

    The result is persisted as ``row_result`` (any stored ``word_result`` is
    cleared because it depends on the rows) and returned together with the
    session id.

    Raises:
        HTTPException: 400 when neither a cropped nor a dewarped image is
            cached, or when column geometry detection fails.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    cropped_bgr = cached.get("cropped_bgr")
    dewarped_bgr = cropped_bgr if cropped_bgr is not None else cached.get("dewarped_bgr")
    if dewarped_bgr is None:
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before row detection")

    t0 = time.time()

    # Try to reuse the intermediates cached by column detection.
    word_dicts = cached.get("_word_dicts")
    inv = cached.get("_inv")
    content_bounds = cached.get("_content_bounds")

    if word_dicts is None or inv is None or content_bounds is None:
        # Not cached: run column geometry to obtain the intermediates.
        ocr_img = create_ocr_image(dewarped_bgr)
        geo_result = detect_column_geometry(ocr_img, dewarped_bgr)
        if geo_result is None:
            raise HTTPException(status_code=400, detail="Column geometry detection failed — cannot detect rows")
        _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
        cached["_word_dicts"] = word_dicts
        cached["_inv"] = inv
        cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
    else:
        left_x, right_x, top_y, bottom_y = content_bounds

    # Zones from the column result tell us which regions are boxes.
    session = await get_session_db(session_id)
    column_result = (session or {}).get("column_result") or {}
    is_sub_session = bool((session or {}).get("parent_session_id"))
    zones = column_result.get("zones") or []  # zones can be None for sub-sessions

    # Sub-sessions (box crops): use word-grouping instead of gap-based row
    # detection.  Box images are small with complex internal layouts
    # (headings, sub-columns) where the horizontal projection approach
    # merges rows; clustering words by Y proximity is more robust there.
    if is_sub_session and word_dicts:
        from cv_layout import _build_rows_from_word_grouping
        rows = _build_rows_from_word_grouping(
            word_dicts, left_x, right_x, top_y, bottom_y,
            right_x - left_x, bottom_y - top_y,
        )
        logger.info(f"OCR Pipeline: sub-session {session_id}: word-grouping found {len(rows)} rows")
    else:
        # Box y-ranges shrunk by the border thickness (at least 5 px): the
        # border belongs visually to the box frame, but text rows directly
        # above/below may overlap it and must not be clipped.
        box_ranges_inner = []
        for zone in zones:
            if zone.get("zone_type") == "box" and zone.get("box"):
                box = zone["box"]
                bt = max(box.get("border_thickness", 0), 5)
                box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt))

        content_strips = []
        if box_ranges_inner and inv is not None:
            content_strips = _content_strips_outside_boxes(top_y, bottom_y, box_ranges_inner)

        if content_strips:
            # Combined-image approach: stack the content strips vertically,
            # detect rows on the stack, then remap y-coordinates back.
            combined_inv = np.vstack([inv[ys:ye, :] for ys, ye in content_strips])

            combined_words = []
            strip_offsets = []  # (combined_y_start, strip_height, abs_y_start)
            cum_y = 0
            for ys, ye in content_strips:
                strip_h = ye - ys
                strip_offsets.append((cum_y, strip_h, ys))
                for word in word_dicts:
                    word_abs_y = word['top'] + top_y  # word y is relative to content top
                    word_center = word_abs_y + word['height'] / 2
                    if ys <= word_center < ye:
                        # Remap to combined coordinates.
                        word_copy = dict(word)
                        word_copy['top'] = cum_y + (word_abs_y - ys)
                        combined_words.append(word_copy)
                cum_y += strip_h

            combined_h = combined_inv.shape[0]
            rows = detect_row_geometry(
                combined_inv, combined_words, left_x, right_x, 0, combined_h,
            )

            for r in rows:
                abs_y = _combined_y_to_abs(r.y, strip_offsets)
                # Only a non-empty row has a distinct end edge that may sit
                # on the boundary of the strip it started in.
                abs_y_end = _combined_y_to_abs(
                    r.y + r.height, strip_offsets, is_end=r.height > 0,
                )
                r.y = abs_y
                r.height = abs_y_end - abs_y
        else:
            # No boxes (or no usable strip left): standard row detection.
            rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)

    duration = time.time() - t0

    # Content zones keep their index within the full zone list.
    content_zones = [(i, z) for i, z in enumerate(zones) if z.get("zone_type") == "content"]

    # Build a serializable result (words are excluded to keep the payload small).
    rows_data = []
    for r in rows:
        zone_idx = 0
        row_center_y = r.y + r.height / 2
        for zi, zone in content_zones:
            zy = zone["y"]
            zh = zone["height"]
            if zy <= row_center_y < zy + zh:
                zone_idx = zi
                break

        rows_data.append({
            "index": r.index,
            "x": r.x,
            "y": r.y,
            "width": r.width,
            "height": r.height,
            "word_count": r.word_count,
            "row_type": r.row_type,
            "gap_before": r.gap_before,
            "zone_index": zone_idx,
        })

    type_counts = {}
    for r in rows:
        type_counts[r.row_type] = type_counts.get(r.row_type, 0) + 1

    row_result = {
        "rows": rows_data,
        "summary": type_counts,
        "total_rows": len(rows),
        "duration_seconds": round(duration, 2),
    }

    # Persist to DB and invalidate word_result, since the rows changed.
    await update_session_db(
        session_id,
        row_result=row_result,
        word_result=None,
        current_step=7,
    )

    cached["row_result"] = row_result
    cached.pop("word_result", None)

    logger.info(f"OCR Pipeline: rows session {session_id}: "
                f"{len(rows)} rows detected ({duration:.2f}s): {type_counts}")

    content_rows = sum(1 for r in rows if r.row_type == "content")
    avg_height = round(sum(r.height for r in rows) / len(rows)) if rows else 0
    await _append_pipeline_log(session_id, "rows", {
        "total_rows": len(rows),
        "content_rows": content_rows,
        "artifact_rows_removed": type_counts.get("header", 0) + type_counts.get("footer", 0),
        "avg_row_height_px": avg_height,
    }, duration_ms=int(duration * 1000))

    return {
        "session_id": session_id,
        **row_result,
    }
@router.post("/sessions/{session_id}/rows/manual")
async def set_manual_rows(session_id: str, req: ManualRowsRequest):
    """Replace the detected rows of a session with manually supplied ones.

    The stored ``word_result`` is cleared (in the DB and in the in-memory
    cache entry, if one exists) because it was derived from the old rows.
    """
    manual_rows = req.rows
    row_result = dict(
        rows=manual_rows,
        total_rows=len(manual_rows),
        duration_seconds=0,
        method="manual",
    )

    await update_session_db(session_id, row_result=row_result, word_result=None)

    if session_id in _cache:
        entry = _cache[session_id]
        entry["row_result"] = row_result
        entry.pop("word_result", None)

    logger.info(f"OCR Pipeline: manual rows session {session_id}: "
                f"{len(req.rows)} rows set")

    return {"session_id": session_id, **row_result}


@router.post("/sessions/{session_id}/ground-truth/rows")
async def save_row_ground_truth(session_id: str, req: RowGroundTruthRequest):
    """Store reviewer feedback for the row detection step.

    The feedback is saved under the ``"rows"`` key of the session's
    ``ground_truth`` dict, together with a UTC timestamp and a snapshot of
    the current ``row_result``.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    rows_feedback = {
        "is_correct": req.is_correct,
        "corrected_rows": req.corrected_rows,
        "notes": req.notes,
        "saved_at": datetime.utcnow().isoformat(),
        "row_result": session.get("row_result"),
    }
    all_ground_truth = session.get("ground_truth") or {}
    all_ground_truth["rows"] = rows_feedback

    await update_session_db(session_id, ground_truth=all_ground_truth)

    if session_id in _cache:
        _cache[session_id]["ground_truth"] = all_ground_truth

    return {"session_id": session_id, "ground_truth": rows_feedback}


@router.get("/sessions/{session_id}/ground-truth/rows")
async def get_row_ground_truth(session_id: str):
    """Return the saved row ground truth next to the automatic row result.

    Raises:
        HTTPException: 404 when the session does not exist or has no row
            ground truth stored.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    rows_gt = (session.get("ground_truth") or {}).get("rows")
    if not rows_gt:
        raise HTTPException(status_code=404, detail="No row ground truth saved")

    return {
        "session_id": session_id,
        "rows_gt": rows_gt,
        "rows_auto": session.get("row_result"),
    }
# ---------------------------------------------------------------------------
# Word Recognition Endpoints (Step 5)
# ---------------------------------------------------------------------------

@router.post("/sessions/{session_id}/words")
async def detect_words(
    session_id: str,
    request: Request,
    engine: str = "auto",
    pronunciation: str = "british",
    stream: bool = False,
    skip_heal_gaps: bool = False,
    grid_method: str = "v2",
):
    """Build word grid from columns × rows, OCR each cell.

    Query params:
        engine: 'auto' (default), 'tesseract', 'rapid', or 'paddle'
        pronunciation: 'british' (default) or 'american' — for IPA dictionary lookup
        stream: false (default) for JSON response, true for SSE streaming
        skip_heal_gaps: false (default). When true, cells keep exact row geometry
            positions without gap-healing expansion. Better for overlay rendering.
        grid_method: 'v2' (default) or 'words_first' — grid construction strategy.
            'v2' uses pre-detected columns/rows (top-down).
            'words_first' clusters words bottom-up (no column/row detection needed).

    The result is persisted as ``word_result`` and returned together with the
    session id.  With ``grid_method='words_first'`` the response is always
    JSON, even when ``stream`` is true.

    Raises:
        HTTPException: 400 when no cropped/dewarped image is cached, when
            rows are missing for ``grid_method='v2'``, or when the
            words-first path finds no words; 404 when the session does not
            exist.
    """
    # PaddleOCR is full-page remote OCR, so it forces the words_first grid method.
    if engine == "paddle" and grid_method != "words_first":
        logger.info("detect_words: engine=paddle requires words_first, overriding grid_method=%s", grid_method)
        grid_method = "words_first"

    if session_id not in _cache:
        logger.info("detect_words: session %s not in cache, loading from DB", session_id)
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
    if dewarped_bgr is None:
        logger.warning("detect_words: no cropped/dewarped image for session %s (cache keys: %s)",
                       session_id, [k for k in cached.keys() if k.endswith('_bgr')])
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before word detection")

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    column_result = session.get("column_result")
    row_result = session.get("row_result")
    if not column_result or not column_result.get("columns"):
        # No column detection: synthesize a single full-page pseudo-column.
        # This enables the overlay pipeline, which skips column detection.
        img_h_tmp, img_w_tmp = dewarped_bgr.shape[:2]
        column_result = {
            "columns": [{
                "type": "column_text",
                "x": 0, "y": 0,
                "width": img_w_tmp, "height": img_h_tmp,
                "classification_confidence": 1.0,
                "classification_method": "full_page_fallback",
            }],
            "zones": [],
            "duration_seconds": 0,
        }
        logger.info("detect_words: no column_result — using full-page pseudo-column %dx%d", img_w_tmp, img_h_tmp)
    if grid_method != "words_first" and (not row_result or not row_result.get("rows")):
        raise HTTPException(status_code=400, detail="Row detection must be completed first")

    # Convert column dicts back to PageRegion objects.
    col_regions = [
        PageRegion(
            type=c["type"],
            x=c["x"], y=c["y"],
            width=c["width"], height=c["height"],
            classification_confidence=c.get("classification_confidence", 1.0),
            classification_method=c.get("classification_method", ""),
        )
        for c in column_result["columns"]
    ]

    # Convert row dicts back to RowGeometry objects.  The words_first path is
    # allowed to run without any row detection, so a missing row_result must
    # yield an empty list instead of failing on ``None["rows"]``.
    stored_rows = (row_result or {}).get("rows") or []
    row_geoms = [
        RowGeometry(
            index=r["index"],
            x=r["x"], y=r["y"],
            width=r["width"], height=r["height"],
            word_count=r.get("word_count", 0),
            words=[],
            row_type=r.get("row_type", "content"),
            gap_before=r.get("gap_before", 0),
        )
        for r in stored_rows
    ]

    # Cell-First OCR (v2): no full-page word re-population needed.
    # Each cell is cropped and OCR'd in isolation, so there is no neighbour
    # bleeding.  We still need word_count > 0 for row filtering in
    # build_cell_grid_v2, so populate from cached words if available (just
    # for counting).
    word_dicts = cached.get("_word_dicts")
    if word_dicts is None:
        ocr_img_tmp = create_ocr_image(dewarped_bgr)
        geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr)
        if geo_result is not None:
            _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
            cached["_word_dicts"] = word_dicts
            cached["_inv"] = inv
            cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)

    if word_dicts:
        content_bounds = cached.get("_content_bounds")
        if content_bounds:
            _lx, _rx, top_y, _by = content_bounds
        else:
            top_y = min(r.y for r in row_geoms) if row_geoms else 0

        for row in row_geoms:
            row_y_rel = row.y - top_y
            row_bottom_rel = row_y_rel + row.height
            row.words = [
                w for w in word_dicts
                if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
            ]
            row.word_count = len(row.words)

    # Exclude rows that fall within box zones.
    # Use the inner box range (shrunk by border_thickness) so that rows at
    # the boundary (overlapping with the box border) are NOT excluded.
    zones = column_result.get("zones") or []
    box_ranges_inner = []
    for zone in zones:
        if zone.get("zone_type") == "box" and zone.get("box"):
            box = zone["box"]
            bt = max(box.get("border_thickness", 0), 5)  # minimum 5px margin
            box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt))

    if box_ranges_inner:
        def _row_in_box(r):
            center_y = r.y + r.height / 2
            return any(by_s <= center_y < by_e for by_s, by_e in box_ranges_inner)

        before_count = len(row_geoms)
        row_geoms = [r for r in row_geoms if not _row_in_box(r)]
        excluded = before_count - len(row_geoms)
        if excluded:
            logger.info(f"detect_words: excluded {excluded} rows inside box zones")

    # --- Words-First path: bottom-up grid from word boxes ---
    if grid_method == "words_first":
        t0 = time.time()
        img_h, img_w = dewarped_bgr.shape[:2]

        # For paddle engine: run remote PaddleOCR full-page instead of Tesseract.
        if engine == "paddle":
            from cv_ocr_engines import ocr_region_paddle

            wf_word_dicts = await ocr_region_paddle(dewarped_bgr, region=None)
            # PaddleOCR returns absolute coordinates, no content_bounds offset needed.
            cached["_paddle_word_dicts"] = wf_word_dicts
        else:
            # Get word_dicts from cache or run Tesseract full-page.
            wf_word_dicts = cached.get("_word_dicts")
            if wf_word_dicts is None:
                ocr_img_tmp = create_ocr_image(dewarped_bgr)
                geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr)
                if geo_result is not None:
                    _geoms, left_x, right_x, top_y, bottom_y, wf_word_dicts, inv = geo_result
                    cached["_word_dicts"] = wf_word_dicts
                    cached["_inv"] = inv
                    cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)

        if not wf_word_dicts:
            raise HTTPException(status_code=400, detail="No words detected — cannot build words-first grid")

        # Convert word coordinates to absolute image coordinates if needed
        # (detect_column_geometry returns words relative to the content ROI).
        # PaddleOCR already returns absolute coordinates, so skip the offset.
        if engine != "paddle":
            content_bounds = cached.get("_content_bounds")
            if content_bounds:
                lx, _rx, ty, _by = content_bounds
                wf_word_dicts = [
                    {**w, 'left': w['left'] + lx, 'top': w['top'] + ty}
                    for w in wf_word_dicts
                ]

        # Extract box rects for box-aware column clustering.
        box_rects = [
            zone["box"] for zone in zones
            if zone.get("zone_type") == "box" and zone.get("box")
        ]

        cells, columns_meta = build_grid_from_words(
            wf_word_dicts, img_w, img_h, box_rects=box_rects or None,
        )
        duration = time.time() - t0

        # Apply IPA phonetic fixes.
        fix_cell_phonetics(cells, pronunciation=pronunciation)

        # Add zone_index for backward compat.
        for cell in cells:
            cell.setdefault("zone_index", 0)

        col_types = {c['type'] for c in columns_meta}
        is_vocab = bool(col_types & {'column_en', 'column_de'})
        n_rows = len(set(c['row_index'] for c in cells)) if cells else 0
        n_cols = len(columns_meta)
        used_engine = "paddle" if engine == "paddle" else "words_first"

        word_result = {
            "cells": cells,
            "grid_shape": {
                "rows": n_rows,
                "cols": n_cols,
                "total_cells": len(cells),
            },
            "columns_used": columns_meta,
            "layout": "vocab" if is_vocab else "generic",
            "image_width": img_w,
            "image_height": img_h,
            "duration_seconds": round(duration, 2),
            "ocr_engine": used_engine,
            "grid_method": "words_first",
            "summary": {
                "total_cells": len(cells),
                "non_empty_cells": sum(1 for c in cells if c.get("text")),
                "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
            },
        }

        if is_vocab or 'column_text' in col_types:
            entries = _cells_to_vocab_entries(cells, columns_meta)
            entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
            word_result["vocab_entries"] = entries
            word_result["entries"] = entries
            word_result["entry_count"] = len(entries)
            word_result["summary"]["total_entries"] = len(entries)
            word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
            word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))

        await update_session_db(session_id, word_result=word_result, current_step=8)
        cached["word_result"] = word_result

        logger.info(f"OCR Pipeline: words-first session {session_id}: "
                    f"{len(cells)} cells ({duration:.2f}s), {n_rows} rows, {n_cols} cols")

        await _append_pipeline_log(session_id, "words", {
            "grid_method": "words_first",
            "total_cells": len(cells),
            "non_empty_cells": word_result["summary"]["non_empty_cells"],
            "ocr_engine": used_engine,
            "layout": word_result["layout"],
        }, duration_ms=int(duration * 1000))

        return {"session_id": session_id, **word_result}

    if stream:
        # Cell-First OCR v2: use the batch-then-stream approach instead of
        # per-cell streaming.  The parallel ThreadPoolExecutor in
        # build_cell_grid_v2 is much faster than sequential streaming.
        return StreamingResponse(
            _word_batch_stream_generator(
                session_id, cached, col_regions, row_geoms,
                dewarped_bgr, engine, pronunciation, request,
                skip_heal_gaps=skip_heal_gaps,
            ),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    # --- Non-streaming path (grid_method=v2) ---
    t0 = time.time()

    # Create binarized OCR image (for Tesseract).
    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    # Build cell grid using Cell-First OCR (v2): each cell cropped in isolation.
    cells, columns_meta = build_cell_grid_v2(
        ocr_img, col_regions, row_geoms, img_w, img_h,
        ocr_engine=engine, img_bgr=dewarped_bgr,
        skip_heal_gaps=skip_heal_gaps,
    )
    duration = time.time() - t0

    # Add zone_index to each cell (default 0 for backward compatibility).
    for cell in cells:
        cell.setdefault("zone_index", 0)

    # Layout detection.
    col_types = {c['type'] for c in columns_meta}
    is_vocab = bool(col_types & {'column_en', 'column_de'})

    # Count content rows and columns for grid_shape.
    n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
    n_cols = len(columns_meta)

    # Determine which engine was actually used.
    used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine

    # Apply IPA phonetic fixes directly to cell texts (for overlay mode).
    fix_cell_phonetics(cells, pronunciation=pronunciation)

    # Grid result (always generic).
    word_result = {
        "cells": cells,
        "grid_shape": {
            "rows": n_content_rows,
            "cols": n_cols,
            "total_cells": len(cells),
        },
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
        },
    }

    # For vocab layout or single-column (box sub-sessions): map cells 1:1
    # to vocab entries (row -> entry).
    has_text_col = 'column_text' in col_types
    if is_vocab or has_text_col:
        entries = _cells_to_vocab_entries(cells, columns_meta)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        word_result["summary"]["total_entries"] = len(entries)
        word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
        word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))

    # Persist to DB.
    await update_session_db(
        session_id,
        word_result=word_result,
        current_step=8,
    )

    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline: words session {session_id}: "
                f"layout={word_result['layout']}, "
                f"{len(cells)} cells ({duration:.2f}s), summary: {word_result['summary']}")

    await _append_pipeline_log(session_id, "words", {
        "total_cells": len(cells),
        "non_empty_cells": word_result["summary"]["non_empty_cells"],
        "low_confidence_count": word_result["summary"]["low_confidence"],
        "ocr_engine": used_engine,
        "layout": word_result["layout"],
        "entry_count": word_result.get("entry_count", 0),
    }, duration_ms=int(duration * 1000))

    return {
        "session_id": session_id,
        **word_result,
    }
async def _word_batch_stream_generator(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    request: Request,
    skip_heal_gaps: bool = False,
):
    """Run the batch cell OCR in an executor and stream the outcome as SSE.

    Event sequence: ``meta`` -> ``preparing`` -> zero or more ``keepalive``
    (one per 5 s while the OCR is still running) -> ``columns`` (only when
    column metadata is non-empty) -> one ``cell`` per cell -> ``complete``.
    The final ``word_result`` is persisted and cached before the
    ``complete`` event is sent.  The generator returns early, without
    persisting anything, once the client is found to be disconnected.
    """
    import asyncio

    def _sse(payload) -> str:
        """Serialize one payload as a single SSE ``data:`` frame."""
        return f"data: {json.dumps(payload)}\n\n"

    t0 = time.time()
    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    # Grid shape and layout are derived up front for the meta event.
    _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
    active_cols = [c for c in col_regions if c.type not in _skip_types]
    n_content_rows = sum(1 for r in row_geoms if r.row_type == 'content')
    n_cols = len(active_cols)
    col_types = {c.type for c in active_cols}
    is_vocab = bool(col_types & {'column_en', 'column_de'})
    layout = "vocab" if is_vocab else "generic"

    # 1. Meta event goes out immediately.
    yield _sse({
        "type": "meta",
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": n_content_rows * n_cols},
        "layout": layout,
    })

    # 2. Preparing event (first sign of life for the proxy).
    yield _sse({'type': 'preparing', 'message': 'Cell-First OCR laeuft parallel...'})

    # 3. Batch OCR runs in the default executor.  Proxies may drop an idle
    #    SSE connection, so a keepalive is emitted for every 5 s of waiting.
    loop = asyncio.get_running_loop()
    ocr_future = loop.run_in_executor(
        None,
        lambda: build_cell_grid_v2(
            ocr_img, col_regions, row_geoms, img_w, img_h,
            ocr_engine=engine, img_bgr=dewarped_bgr,
            skip_heal_gaps=skip_heal_gaps,
        ),
    )

    while True:
        # asyncio.wait leaves the future untouched on timeout, so no shield
        # is required to keep the OCR running across keepalives.
        finished, _pending = await asyncio.wait({ocr_future}, timeout=5.0)
        if finished:
            break
        elapsed = int(time.time() - t0)
        yield _sse({'type': 'keepalive', 'elapsed': elapsed, 'message': f'OCR laeuft... ({elapsed}s)'})
        if await request.is_disconnected():
            logger.info(f"SSE batch: client disconnected during OCR for {session_id}")
            ocr_future.cancel()
            return

    # Re-raises here if the OCR call itself failed.
    cells, columns_meta = ocr_future.result()

    if await request.is_disconnected():
        logger.info(f"SSE batch: client disconnected after OCR for {session_id}")
        return

    # 4. Apply IPA phonetic fixes directly to the cell texts (for overlay mode).
    fix_cell_phonetics(cells, pronunciation=pronunciation)

    # 5. Columns metadata.
    if columns_meta:
        yield _sse({'type': 'columns', 'columns_used': columns_meta})

    # 6. One event per cell.
    cell_total = len(cells)
    for position, cell in enumerate(cells, start=1):
        yield _sse({
            "type": "cell",
            "cell": cell,
            "progress": {"current": position, "total": cell_total},
        })

    # 7. Build the final result and persist it.
    duration = time.time() - t0
    used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine

    summary = {
        "total_cells": len(cells),
        "non_empty_cells": sum(1 for c in cells if c.get("text")),
        "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
    }
    word_result = {
        "cells": cells,
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(cells)},
        "columns_used": columns_meta,
        "layout": layout,
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": summary,
    }

    vocab_entries = None
    if is_vocab or 'column_text' in col_types:
        vocab_entries = _fix_phonetic_brackets(
            _cells_to_vocab_entries(cells, columns_meta),
            pronunciation=pronunciation,
        )
        word_result["vocab_entries"] = vocab_entries
        word_result["entries"] = vocab_entries
        word_result["entry_count"] = len(vocab_entries)
        summary["total_entries"] = len(vocab_entries)
        summary["with_english"] = sum(1 for e in vocab_entries if e.get("english"))
        summary["with_german"] = sum(1 for e in vocab_entries if e.get("german"))

    await update_session_db(session_id, word_result=word_result, current_step=8)
    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline SSE batch: words session {session_id}: "
                f"layout={word_result['layout']}, {len(cells)} cells ({duration:.2f}s)")

    # 8. Complete event.
    complete_event = {
        "type": "complete",
        "summary": word_result["summary"],
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
    }
    if vocab_entries is not None:
        complete_event["vocab_entries"] = vocab_entries
    yield _sse(complete_event)
word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german")) - vocab_entries = entries - - await update_session_db(session_id, word_result=word_result, current_step=8) - cached["word_result"] = word_result - - logger.info(f"OCR Pipeline SSE batch: words session {session_id}: " - f"layout={word_result['layout']}, {len(cells)} cells ({duration:.2f}s)") - - # 7. Send complete event - complete_event = { - "type": "complete", - "summary": word_result["summary"], - "duration_seconds": round(duration, 2), - "ocr_engine": used_engine, - } - if vocab_entries is not None: - complete_event["vocab_entries"] = vocab_entries - yield f"data: {json.dumps(complete_event)}\n\n" - - -async def _word_stream_generator( - session_id: str, - cached: Dict[str, Any], - col_regions: List[PageRegion], - row_geoms: List[RowGeometry], - dewarped_bgr: np.ndarray, - engine: str, - pronunciation: str, - request: Request, -): - """SSE generator that yields cell-by-cell OCR progress.""" - t0 = time.time() - - ocr_img = create_ocr_image(dewarped_bgr) - img_h, img_w = dewarped_bgr.shape[:2] - - # Compute grid shape upfront for the meta event - n_content_rows = len([r for r in row_geoms if r.row_type == 'content']) - _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'} - n_cols = len([c for c in col_regions if c.type not in _skip_types]) - - # Determine layout - col_types = {c.type for c in col_regions if c.type not in _skip_types} - is_vocab = bool(col_types & {'column_en', 'column_de'}) - - # Start streaming — first event: meta - columns_meta = None # will be set from first yield - total_cells = n_content_rows * n_cols - - meta_event = { - "type": "meta", - "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells}, - "layout": "vocab" if is_vocab else "generic", - } - yield f"data: {json.dumps(meta_event)}\n\n" - - # Keepalive: send preparing event so proxy doesn't timeout during OCR init - 
yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR wird initialisiert...'})}\n\n" - - # Stream cells one by one - all_cells: List[Dict[str, Any]] = [] - cell_idx = 0 - last_keepalive = time.time() - - for cell, cols_meta, total in build_cell_grid_v2_streaming( - ocr_img, col_regions, row_geoms, img_w, img_h, - ocr_engine=engine, img_bgr=dewarped_bgr, - ): - if await request.is_disconnected(): - logger.info(f"SSE: client disconnected during streaming for {session_id}") - return - - if columns_meta is None: - columns_meta = cols_meta - # Send columns_used as part of first cell or update meta - meta_update = { - "type": "columns", - "columns_used": cols_meta, - } - yield f"data: {json.dumps(meta_update)}\n\n" - - all_cells.append(cell) - cell_idx += 1 - - cell_event = { - "type": "cell", - "cell": cell, - "progress": {"current": cell_idx, "total": total}, - } - yield f"data: {json.dumps(cell_event)}\n\n" - - # All cells done — build final result - duration = time.time() - t0 - if columns_meta is None: - columns_meta = [] - - # Post-OCR: remove rows where ALL cells are empty (inter-row gaps - # that had stray Tesseract artifacts giving word_count > 0). 
- rows_with_text: set = set() - for c in all_cells: - if c.get("text", "").strip(): - rows_with_text.add(c["row_index"]) - before_filter = len(all_cells) - all_cells = [c for c in all_cells if c["row_index"] in rows_with_text] - empty_rows_removed = (before_filter - len(all_cells)) // max(n_cols, 1) - if empty_rows_removed > 0: - logger.info(f"SSE: removed {empty_rows_removed} all-empty rows after OCR") - - used_engine = all_cells[0].get("ocr_engine", "tesseract") if all_cells else engine - - # Apply IPA phonetic fixes directly to cell texts (for overlay mode) - fix_cell_phonetics(all_cells, pronunciation=pronunciation) - - word_result = { - "cells": all_cells, - "grid_shape": { - "rows": n_content_rows, - "cols": n_cols, - "total_cells": len(all_cells), - }, - "columns_used": columns_meta, - "layout": "vocab" if is_vocab else "generic", - "image_width": img_w, - "image_height": img_h, - "duration_seconds": round(duration, 2), - "ocr_engine": used_engine, - "summary": { - "total_cells": len(all_cells), - "non_empty_cells": sum(1 for c in all_cells if c.get("text")), - "low_confidence": sum(1 for c in all_cells if 0 < c.get("confidence", 0) < 50), - }, - } - - # For vocab layout or single-column (box sub-sessions): map cells 1:1 - # to vocab entries (row→entry). 
- vocab_entries = None - has_text_col = 'column_text' in col_types - if is_vocab or has_text_col: - entries = _cells_to_vocab_entries(all_cells, columns_meta) - entries = _fix_character_confusion(entries) - entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation) - word_result["vocab_entries"] = entries - word_result["entries"] = entries - word_result["entry_count"] = len(entries) - word_result["summary"]["total_entries"] = len(entries) - word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english")) - word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german")) - vocab_entries = entries - - # Persist to DB - await update_session_db( - session_id, - word_result=word_result, - current_step=8, - ) - cached["word_result"] = word_result - - logger.info(f"OCR Pipeline SSE: words session {session_id}: " - f"layout={word_result['layout']}, " - f"{len(all_cells)} cells ({duration:.2f}s)") - - # Final complete event - complete_event = { - "type": "complete", - "summary": word_result["summary"], - "duration_seconds": round(duration, 2), - "ocr_engine": used_engine, - } - if vocab_entries is not None: - complete_event["vocab_entries"] = vocab_entries - yield f"data: {json.dumps(complete_event)}\n\n" - - -@router.post("/sessions/{session_id}/paddle-direct") -async def paddle_direct(session_id: str): - """Run PaddleOCR on the preprocessed image and build a word grid directly. - - Expects orientation/deskew/dewarp/crop to be done already. - Uses the cropped image (falls back to dewarped, then original). - The used image is stored as cropped_png so OverlayReconstruction - can display it as the background. 
- """ - # Try preprocessed images first (crop > dewarp > original) - img_png = await get_session_image(session_id, "cropped") - if not img_png: - img_png = await get_session_image(session_id, "dewarped") - if not img_png: - img_png = await get_session_image(session_id, "original") - if not img_png: - raise HTTPException(status_code=404, detail="No image found for this session") - - img_arr = np.frombuffer(img_png, dtype=np.uint8) - img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR) - if img_bgr is None: - raise HTTPException(status_code=400, detail="Failed to decode original image") - - img_h, img_w = img_bgr.shape[:2] - - from cv_ocr_engines import ocr_region_paddle - - t0 = time.time() - word_dicts = await ocr_region_paddle(img_bgr, region=None) - if not word_dicts: - raise HTTPException(status_code=400, detail="PaddleOCR returned no words") - - # Reuse build_grid_from_words — same function that works in the regular - # pipeline with PaddleOCR (engine=paddle, grid_method=words_first). - # Handles phrase splitting, column clustering, and reading order. 
- cells, columns_meta = build_grid_from_words(word_dicts, img_w, img_h) - duration = time.time() - t0 - - # Tag cells as paddle_direct - for cell in cells: - cell["ocr_engine"] = "paddle_direct" - - n_rows = len(set(c["row_index"] for c in cells)) if cells else 0 - n_cols = len(columns_meta) - col_types = {c.get("type") for c in columns_meta} - is_vocab = bool(col_types & {"column_en", "column_de"}) - - word_result = { - "cells": cells, - "grid_shape": { - "rows": n_rows, - "cols": n_cols, - "total_cells": len(cells), - }, - "columns_used": columns_meta, - "layout": "vocab" if is_vocab else "generic", - "image_width": img_w, - "image_height": img_h, - "duration_seconds": round(duration, 2), - "ocr_engine": "paddle_direct", - "grid_method": "paddle_direct", - "summary": { - "total_cells": len(cells), - "non_empty_cells": sum(1 for c in cells if c.get("text")), - "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), - }, - } - - # Store preprocessed image as cropped_png so OverlayReconstruction shows it - await update_session_db( - session_id, - word_result=word_result, - cropped_png=img_png, - current_step=8, - ) - - logger.info( - "paddle_direct session %s: %d cells (%d rows, %d cols) in %.2fs", - session_id, len(cells), n_rows, n_cols, duration, - ) - - await _append_pipeline_log(session_id, "paddle_direct", { - "total_cells": len(cells), - "non_empty_cells": word_result["summary"]["non_empty_cells"], - "ocr_engine": "paddle_direct", - }, duration_ms=int(duration * 1000)) - - return {"session_id": session_id, **word_result} - - -def _split_paddle_multi_words(words: list) -> list: - """Split PaddleOCR multi-word boxes into individual word boxes. - - PaddleOCR often returns entire phrases as a single box, e.g. - "More than 200 singers took part in the" with one bounding box. - This splits them into individual words with proportional widths. - Also handles leading "!" (e.g. "!Betonung" → ["!", "Betonung"]) - and IPA brackets (e.g. 
"badge[bxd3]" → ["badge", "[bxd3]"]). - """ - import re - - result = [] - for w in words: - raw_text = w.get("text", "").strip() - if not raw_text: - continue - # Split on whitespace, before "[" (IPA), and after "!" before letter - tokens = re.split( - r'\s+|(?=\[)|(?<=!)(?=[A-Za-z\u00c0-\u024f])', raw_text - ) - tokens = [t for t in tokens if t] - - if len(tokens) <= 1: - result.append(w) - else: - # Split proportionally by character count - total_chars = sum(len(t) for t in tokens) - if total_chars == 0: - continue - n_gaps = len(tokens) - 1 - gap_px = w["width"] * 0.02 - usable_w = w["width"] - gap_px * n_gaps - cursor = w["left"] - for t in tokens: - token_w = max(1, usable_w * len(t) / total_chars) - result.append({ - "text": t, - "left": round(cursor), - "top": w["top"], - "width": round(token_w), - "height": w["height"], - "conf": w.get("conf", 0), - }) - cursor += token_w + gap_px - return result - - -def _group_words_into_rows(words: list, row_gap: int = 12) -> list: - """Group words into rows by Y-position clustering. - - Words whose vertical centers are within `row_gap` pixels are on the same row. - Returns list of rows, each row is a list of words sorted left-to-right. 
- """ - if not words: - return [] - # Sort by vertical center - sorted_words = sorted(words, key=lambda w: w["top"] + w.get("height", 0) / 2) - rows: list = [] - current_row: list = [sorted_words[0]] - current_cy = sorted_words[0]["top"] + sorted_words[0].get("height", 0) / 2 - - for w in sorted_words[1:]: - cy = w["top"] + w.get("height", 0) / 2 - if abs(cy - current_cy) <= row_gap: - current_row.append(w) - else: - # Sort current row left-to-right before saving - rows.append(sorted(current_row, key=lambda w: w["left"])) - current_row = [w] - current_cy = cy - if current_row: - rows.append(sorted(current_row, key=lambda w: w["left"])) - return rows - - -def _row_center_y(row: list) -> float: - """Average vertical center of a row of words.""" - if not row: - return 0.0 - return sum(w["top"] + w.get("height", 0) / 2 for w in row) / len(row) - - -def _merge_row_sequences(paddle_row: list, tess_row: list) -> list: - """Merge two word sequences from the same row using sequence alignment. - - Both sequences are sorted left-to-right. Walk through both simultaneously: - - If words match (same/similar text): take Paddle text with averaged coords - - If they don't match: the extra word is unique to one engine, include it - - This prevents duplicates because both engines produce words in the same order. - """ - merged = [] - pi, ti = 0, 0 - - while pi < len(paddle_row) and ti < len(tess_row): - pw = paddle_row[pi] - tw = tess_row[ti] - - # Check if these are the same word - pt = pw.get("text", "").lower().strip() - tt = tw.get("text", "").lower().strip() - - # Same text or one contains the other - is_same = (pt == tt) or (len(pt) > 1 and len(tt) > 1 and (pt in tt or tt in pt)) - - # Spatial overlap check: if words overlap >= 40% horizontally, - # they're the same physical word regardless of OCR text differences. 
- # (40% catches borderline cases like "Stick"/"Stück" at 48% overlap) - spatial_match = False - if not is_same: - overlap_left = max(pw["left"], tw["left"]) - overlap_right = min( - pw["left"] + pw.get("width", 0), - tw["left"] + tw.get("width", 0), - ) - overlap_w = max(0, overlap_right - overlap_left) - min_w = min(pw.get("width", 1), tw.get("width", 1)) - if min_w > 0 and overlap_w / min_w >= 0.4: - is_same = True - spatial_match = True - - if is_same: - # Matched — average coordinates weighted by confidence - pc = pw.get("conf", 80) - tc = tw.get("conf", 50) - total = pc + tc - if total == 0: - total = 1 - # Text: prefer higher-confidence engine when texts differ - # (e.g. Tesseract "Stück" conf=98 vs PaddleOCR "Stick" conf=80) - if spatial_match and pc < tc: - best_text = tw["text"] - else: - best_text = pw["text"] - merged.append({ - "text": best_text, - "left": round((pw["left"] * pc + tw["left"] * tc) / total), - "top": round((pw["top"] * pc + tw["top"] * tc) / total), - "width": round((pw["width"] * pc + tw["width"] * tc) / total), - "height": round((pw["height"] * pc + tw["height"] * tc) / total), - "conf": max(pc, tc), - }) - pi += 1 - ti += 1 - else: - # Different text — one engine found something extra - # Look ahead: is the current Paddle word somewhere in Tesseract ahead? - paddle_ahead = any( - tess_row[t].get("text", "").lower().strip() == pt - for t in range(ti + 1, min(ti + 4, len(tess_row))) - ) - # Is the current Tesseract word somewhere in Paddle ahead? - tess_ahead = any( - paddle_row[p].get("text", "").lower().strip() == tt - for p in range(pi + 1, min(pi + 4, len(paddle_row))) - ) - - if paddle_ahead and not tess_ahead: - # Tesseract has an extra word (e.g. "!" 
or bullet) → include it - if tw.get("conf", 0) >= 30: - merged.append(tw) - ti += 1 - elif tess_ahead and not paddle_ahead: - # Paddle has an extra word → include it - merged.append(pw) - pi += 1 - else: - # Both have unique words or neither found ahead → take leftmost first - if pw["left"] <= tw["left"]: - merged.append(pw) - pi += 1 - else: - if tw.get("conf", 0) >= 30: - merged.append(tw) - ti += 1 - - # Remaining words from either engine - while pi < len(paddle_row): - merged.append(paddle_row[pi]) - pi += 1 - while ti < len(tess_row): - tw = tess_row[ti] - if tw.get("conf", 0) >= 30: - merged.append(tw) - ti += 1 - - return merged - - -def _merge_paddle_tesseract(paddle_words: list, tess_words: list) -> list: - """Merge word boxes from PaddleOCR and Tesseract using row-based sequence alignment. - - Strategy: - 1. Group each engine's words into rows (by Y-position clustering) - 2. Match rows between engines (by vertical center proximity) - 3. Within each matched row: merge sequences left-to-right, deduplicating - words that appear in both engines at the same sequence position - 4. 
Unmatched rows from either engine: keep as-is - - This prevents: - - Cross-line averaging (words from different lines being merged) - - Duplicate words (same word from both engines shown twice) - """ - if not paddle_words and not tess_words: - return [] - if not paddle_words: - return [w for w in tess_words if w.get("conf", 0) >= 40] - if not tess_words: - return list(paddle_words) - - # Step 1: Group into rows - paddle_rows = _group_words_into_rows(paddle_words) - tess_rows = _group_words_into_rows(tess_words) - - # Step 2: Match rows between engines by vertical center proximity - used_tess_rows: set = set() - merged_all: list = [] - - for pr in paddle_rows: - pr_cy = _row_center_y(pr) - best_dist, best_tri = float("inf"), -1 - for tri, tr in enumerate(tess_rows): - if tri in used_tess_rows: - continue - tr_cy = _row_center_y(tr) - dist = abs(pr_cy - tr_cy) - if dist < best_dist: - best_dist, best_tri = dist, tri - - # Row height threshold — rows must be within ~1.5x typical line height - max_row_dist = max( - max((w.get("height", 20) for w in pr), default=20), - 15, - ) - - if best_tri >= 0 and best_dist <= max_row_dist: - # Matched row — merge sequences - tr = tess_rows[best_tri] - used_tess_rows.add(best_tri) - merged_all.extend(_merge_row_sequences(pr, tr)) - else: - # No matching Tesseract row — keep Paddle row as-is - merged_all.extend(pr) - - # Add unmatched Tesseract rows - for tri, tr in enumerate(tess_rows): - if tri not in used_tess_rows: - for tw in tr: - if tw.get("conf", 0) >= 40: - merged_all.append(tw) - - return merged_all - - -def _deduplicate_words(words: list) -> list: - """Remove duplicate words with same text at overlapping positions. - - PaddleOCR can return overlapping phrases (e.g. "von jm." and "jm. =") - that produce duplicate words after splitting. This pass removes them. - - A word is a duplicate only when BOTH horizontal AND vertical overlap - exceed 50% — same text on the same visual line at the same position. 
- """ - if not words: - return words - - result: list = [] - for w in words: - wt = w.get("text", "").lower().strip() - if not wt: - continue - is_dup = False - w_right = w["left"] + w.get("width", 0) - w_bottom = w["top"] + w.get("height", 0) - for existing in result: - et = existing.get("text", "").lower().strip() - if wt != et: - continue - # Horizontal overlap - ox_l = max(w["left"], existing["left"]) - ox_r = min(w_right, existing["left"] + existing.get("width", 0)) - ox = max(0, ox_r - ox_l) - min_w = min(w.get("width", 1), existing.get("width", 1)) - if min_w <= 0 or ox / min_w < 0.5: - continue - # Vertical overlap — must also be on the same line - oy_t = max(w["top"], existing["top"]) - oy_b = min(w_bottom, existing["top"] + existing.get("height", 0)) - oy = max(0, oy_b - oy_t) - min_h = min(w.get("height", 1), existing.get("height", 1)) - if min_h > 0 and oy / min_h >= 0.5: - is_dup = True - break - if not is_dup: - result.append(w) - - removed = len(words) - len(result) - if removed: - logger.info("dedup: removed %d duplicate words", removed) - return result - - -@router.post("/sessions/{session_id}/paddle-kombi") -async def paddle_kombi(session_id: str): - """Run PaddleOCR + Tesseract on the preprocessed image and merge results. - - Both engines run on the same preprocessed (cropped/dewarped) image. - Word boxes are matched by IoU and coordinates are averaged weighted by - confidence. Unmatched Tesseract words (bullets, symbols) are added. 
- """ - img_png = await get_session_image(session_id, "cropped") - if not img_png: - img_png = await get_session_image(session_id, "dewarped") - if not img_png: - img_png = await get_session_image(session_id, "original") - if not img_png: - raise HTTPException(status_code=404, detail="No image found for this session") - - img_arr = np.frombuffer(img_png, dtype=np.uint8) - img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR) - if img_bgr is None: - raise HTTPException(status_code=400, detail="Failed to decode image") - - img_h, img_w = img_bgr.shape[:2] - - from cv_ocr_engines import ocr_region_paddle - - t0 = time.time() - - # --- PaddleOCR --- - paddle_words = await ocr_region_paddle(img_bgr, region=None) - if not paddle_words: - paddle_words = [] - - # --- Tesseract --- - from PIL import Image - import pytesseract - - pil_img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)) - data = pytesseract.image_to_data( - pil_img, lang="eng+deu", - config="--psm 6 --oem 3", - output_type=pytesseract.Output.DICT, - ) - tess_words = [] - for i in range(len(data["text"])): - text = str(data["text"][i]).strip() - conf_raw = str(data["conf"][i]) - conf = int(conf_raw) if conf_raw.lstrip("-").isdigit() else -1 - if not text or conf < 20: - continue - tess_words.append({ - "text": text, - "left": data["left"][i], - "top": data["top"][i], - "width": data["width"][i], - "height": data["height"][i], - "conf": conf, - }) - - # --- Split multi-word Paddle boxes into individual words --- - paddle_words_split = _split_paddle_multi_words(paddle_words) - logger.info( - "paddle_kombi: split %d paddle boxes → %d individual words", - len(paddle_words), len(paddle_words_split), - ) - - # --- Merge --- - if not paddle_words_split and not tess_words: - raise HTTPException(status_code=400, detail="Both OCR engines returned no words") - - merged_words = _merge_paddle_tesseract(paddle_words_split, tess_words) - merged_words = _deduplicate_words(merged_words) - - cells, columns_meta = 
build_grid_from_words(merged_words, img_w, img_h) - duration = time.time() - t0 - - for cell in cells: - cell["ocr_engine"] = "kombi" - - n_rows = len(set(c["row_index"] for c in cells)) if cells else 0 - n_cols = len(columns_meta) - col_types = {c.get("type") for c in columns_meta} - is_vocab = bool(col_types & {"column_en", "column_de"}) - - word_result = { - "cells": cells, - "grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)}, - "columns_used": columns_meta, - "layout": "vocab" if is_vocab else "generic", - "image_width": img_w, - "image_height": img_h, - "duration_seconds": round(duration, 2), - "ocr_engine": "kombi", - "grid_method": "kombi", - "raw_paddle_words": paddle_words, - "raw_paddle_words_split": paddle_words_split, - "raw_tesseract_words": tess_words, - "summary": { - "total_cells": len(cells), - "non_empty_cells": sum(1 for c in cells if c.get("text")), - "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), - "paddle_words": len(paddle_words), - "paddle_words_split": len(paddle_words_split), - "tesseract_words": len(tess_words), - "merged_words": len(merged_words), - }, - } - - await update_session_db( - session_id, - word_result=word_result, - cropped_png=img_png, - current_step=8, - ) - # Update in-memory cache so detect-structure can access word_result - if session_id in _cache: - _cache[session_id]["word_result"] = word_result - - logger.info( - "paddle_kombi session %s: %d cells (%d rows, %d cols) in %.2fs " - "[paddle=%d, tess=%d, merged=%d]", - session_id, len(cells), n_rows, n_cols, duration, - len(paddle_words), len(tess_words), len(merged_words), - ) - - await _append_pipeline_log(session_id, "paddle_kombi", { - "total_cells": len(cells), - "non_empty_cells": word_result["summary"]["non_empty_cells"], - "paddle_words": len(paddle_words), - "tesseract_words": len(tess_words), - "merged_words": len(merged_words), - "ocr_engine": "kombi", - }, duration_ms=int(duration * 1000)) - - return 
{"session_id": session_id, **word_result} - - -@router.post("/sessions/{session_id}/rapid-kombi") -async def rapid_kombi(session_id: str): - """Run RapidOCR + Tesseract on the preprocessed image and merge results. - - Same merge logic as paddle-kombi, but uses local RapidOCR (ONNX Runtime) - instead of remote PaddleOCR service. - """ - img_png = await get_session_image(session_id, "cropped") - if not img_png: - img_png = await get_session_image(session_id, "dewarped") - if not img_png: - img_png = await get_session_image(session_id, "original") - if not img_png: - raise HTTPException(status_code=404, detail="No image found for this session") - - img_arr = np.frombuffer(img_png, dtype=np.uint8) - img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR) - if img_bgr is None: - raise HTTPException(status_code=400, detail="Failed to decode image") - - img_h, img_w = img_bgr.shape[:2] - - from cv_ocr_engines import ocr_region_rapid - from cv_vocab_types import PageRegion - - t0 = time.time() - - # --- RapidOCR (local, synchronous) --- - full_region = PageRegion( - type="full_page", x=0, y=0, width=img_w, height=img_h, - ) - rapid_words = ocr_region_rapid(img_bgr, full_region) - if not rapid_words: - rapid_words = [] - - # --- Tesseract --- - from PIL import Image - import pytesseract - - pil_img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)) - data = pytesseract.image_to_data( - pil_img, lang="eng+deu", - config="--psm 6 --oem 3", - output_type=pytesseract.Output.DICT, - ) - tess_words = [] - for i in range(len(data["text"])): - text = str(data["text"][i]).strip() - conf_raw = str(data["conf"][i]) - conf = int(conf_raw) if conf_raw.lstrip("-").isdigit() else -1 - if not text or conf < 20: - continue - tess_words.append({ - "text": text, - "left": data["left"][i], - "top": data["top"][i], - "width": data["width"][i], - "height": data["height"][i], - "conf": conf, - }) - - # --- Split multi-word RapidOCR boxes into individual words --- - rapid_words_split = 
_split_paddle_multi_words(rapid_words) - logger.info( - "rapid_kombi: split %d rapid boxes → %d individual words", - len(rapid_words), len(rapid_words_split), - ) - - # --- Merge --- - if not rapid_words_split and not tess_words: - raise HTTPException(status_code=400, detail="Both OCR engines returned no words") - - merged_words = _merge_paddle_tesseract(rapid_words_split, tess_words) - merged_words = _deduplicate_words(merged_words) - - cells, columns_meta = build_grid_from_words(merged_words, img_w, img_h) - duration = time.time() - t0 - - for cell in cells: - cell["ocr_engine"] = "rapid_kombi" - - n_rows = len(set(c["row_index"] for c in cells)) if cells else 0 - n_cols = len(columns_meta) - col_types = {c.get("type") for c in columns_meta} - is_vocab = bool(col_types & {"column_en", "column_de"}) - - word_result = { - "cells": cells, - "grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)}, - "columns_used": columns_meta, - "layout": "vocab" if is_vocab else "generic", - "image_width": img_w, - "image_height": img_h, - "duration_seconds": round(duration, 2), - "ocr_engine": "rapid_kombi", - "grid_method": "rapid_kombi", - "raw_rapid_words": rapid_words, - "raw_rapid_words_split": rapid_words_split, - "raw_tesseract_words": tess_words, - "summary": { - "total_cells": len(cells), - "non_empty_cells": sum(1 for c in cells if c.get("text")), - "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), - "rapid_words": len(rapid_words), - "rapid_words_split": len(rapid_words_split), - "tesseract_words": len(tess_words), - "merged_words": len(merged_words), - }, - } - - await update_session_db( - session_id, - word_result=word_result, - cropped_png=img_png, - current_step=8, - ) - # Update in-memory cache so detect-structure can access word_result - if session_id in _cache: - _cache[session_id]["word_result"] = word_result - - logger.info( - "rapid_kombi session %s: %d cells (%d rows, %d cols) in %.2fs " - "[rapid=%d, tess=%d, 
merged=%d]", - session_id, len(cells), n_rows, n_cols, duration, - len(rapid_words), len(tess_words), len(merged_words), - ) - - await _append_pipeline_log(session_id, "rapid_kombi", { - "total_cells": len(cells), - "non_empty_cells": word_result["summary"]["non_empty_cells"], - "rapid_words": len(rapid_words), - "tesseract_words": len(tess_words), - "merged_words": len(merged_words), - "ocr_engine": "rapid_kombi", - }, duration_ms=int(duration * 1000)) - - return {"session_id": session_id, **word_result} - - -class WordGroundTruthRequest(BaseModel): - is_correct: bool - corrected_entries: Optional[List[Dict[str, Any]]] = None - notes: Optional[str] = None - - -@router.post("/sessions/{session_id}/ground-truth/words") -async def save_word_ground_truth(session_id: str, req: WordGroundTruthRequest): - """Save ground truth feedback for the word recognition step.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - ground_truth = session.get("ground_truth") or {} - gt = { - "is_correct": req.is_correct, - "corrected_entries": req.corrected_entries, - "notes": req.notes, - "saved_at": datetime.utcnow().isoformat(), - "word_result": session.get("word_result"), - } - ground_truth["words"] = gt - - await update_session_db(session_id, ground_truth=ground_truth) - - if session_id in _cache: - _cache[session_id]["ground_truth"] = ground_truth - - return {"session_id": session_id, "ground_truth": gt} - - -@router.get("/sessions/{session_id}/ground-truth/words") -async def get_word_ground_truth(session_id: str): - """Retrieve saved ground truth for word recognition.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - ground_truth = session.get("ground_truth") or {} - words_gt = ground_truth.get("words") - if not words_gt: - raise HTTPException(status_code=404, detail="No word ground 
truth saved") - - return { - "session_id": session_id, - "words_gt": words_gt, - "words_auto": session.get("word_result"), - } - - -# --------------------------------------------------------------------------- -# LLM Review Endpoints (Step 6) -# --------------------------------------------------------------------------- - - -@router.post("/sessions/{session_id}/llm-review") -async def run_llm_review(session_id: str, request: Request, stream: bool = False): - """Run LLM-based correction on vocab entries from Step 5. - - Query params: - stream: false (default) for JSON response, true for SSE streaming - """ - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - word_result = session.get("word_result") - if not word_result: - raise HTTPException(status_code=400, detail="No word result found — run Step 5 first") - - entries = word_result.get("vocab_entries") or word_result.get("entries") or [] - if not entries: - raise HTTPException(status_code=400, detail="No vocab entries found — run Step 5 first") - - # Optional model override from request body - body = {} - try: - body = await request.json() - except Exception: - pass - model = body.get("model") or OLLAMA_REVIEW_MODEL - - if stream: - return StreamingResponse( - _llm_review_stream_generator(session_id, entries, word_result, model, request), - media_type="text/event-stream", - headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}, - ) - - # Non-streaming path - try: - result = await llm_review_entries(entries, model=model) - except Exception as e: - import traceback - logger.error(f"LLM review failed for session {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}") - raise HTTPException(status_code=502, detail=f"LLM review failed ({type(e).__name__}): {e}") - - # Store result inside word_result as a sub-key - word_result["llm_review"] = { - "changes": result["changes"], - 
"model_used": result["model_used"], - "duration_ms": result["duration_ms"], - "entries_corrected": result["entries_corrected"], - } - await update_session_db(session_id, word_result=word_result, current_step=9) - - if session_id in _cache: - _cache[session_id]["word_result"] = word_result - - logger.info(f"LLM review session {session_id}: {len(result['changes'])} changes, " - f"{result['duration_ms']}ms, model={result['model_used']}") - - await _append_pipeline_log(session_id, "correction", { - "engine": "llm", - "model": result["model_used"], - "total_entries": len(entries), - "corrections_proposed": len(result["changes"]), - }, duration_ms=result["duration_ms"]) - - return { - "session_id": session_id, - "changes": result["changes"], - "model_used": result["model_used"], - "duration_ms": result["duration_ms"], - "total_entries": len(entries), - "corrections_found": len(result["changes"]), - } - - -async def _llm_review_stream_generator( - session_id: str, - entries: List[Dict], - word_result: Dict, - model: str, - request: Request, -): - """SSE generator that yields batch-by-batch LLM review progress.""" - try: - async for event in llm_review_entries_streaming(entries, model=model): - if await request.is_disconnected(): - logger.info(f"SSE: client disconnected during LLM review for {session_id}") - return - - yield f"data: {json.dumps(event, ensure_ascii=False)}\n\n" - - # On complete: persist to DB - if event.get("type") == "complete": - word_result["llm_review"] = { - "changes": event["changes"], - "model_used": event["model_used"], - "duration_ms": event["duration_ms"], - "entries_corrected": event["entries_corrected"], - } - await update_session_db(session_id, word_result=word_result, current_step=9) - if session_id in _cache: - _cache[session_id]["word_result"] = word_result - - logger.info(f"LLM review SSE session {session_id}: {event['corrections_found']} changes, " - f"{event['duration_ms']}ms, skipped={event['skipped']}, model={event['model_used']}") - - 
except Exception as e: - import traceback - logger.error(f"LLM review SSE failed for {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}") - error_event = {"type": "error", "detail": f"{type(e).__name__}: {e}"} - yield f"data: {json.dumps(error_event)}\n\n" - - -@router.post("/sessions/{session_id}/llm-review/apply") -async def apply_llm_corrections(session_id: str, request: Request): - """Apply selected LLM corrections to vocab entries.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - word_result = session.get("word_result") - if not word_result: - raise HTTPException(status_code=400, detail="No word result found") - - llm_review = word_result.get("llm_review") - if not llm_review: - raise HTTPException(status_code=400, detail="No LLM review found — run /llm-review first") - - body = await request.json() - accepted_indices = set(body.get("accepted_indices", [])) # indices into changes[] - - changes = llm_review.get("changes", []) - entries = word_result.get("vocab_entries") or word_result.get("entries") or [] - - # Build a lookup: (row_index, field) -> new_value for accepted changes - corrections = {} - applied_count = 0 - for idx, change in enumerate(changes): - if idx in accepted_indices: - key = (change["row_index"], change["field"]) - corrections[key] = change["new"] - applied_count += 1 - - # Apply corrections to entries - for entry in entries: - row_idx = entry.get("row_index", -1) - for field_name in ("english", "german", "example"): - key = (row_idx, field_name) - if key in corrections: - entry[field_name] = corrections[key] - entry["llm_corrected"] = True - - # Update word_result - word_result["vocab_entries"] = entries - word_result["entries"] = entries - word_result["llm_review"]["applied_count"] = applied_count - word_result["llm_review"]["applied_at"] = datetime.utcnow().isoformat() - - await update_session_db(session_id, 
word_result=word_result) - - if session_id in _cache: - _cache[session_id]["word_result"] = word_result - - logger.info(f"Applied {applied_count}/{len(changes)} LLM corrections for session {session_id}") - - return { - "session_id": session_id, - "applied_count": applied_count, - "total_changes": len(changes), - } - - -@router.post("/sessions/{session_id}/reconstruction") -async def save_reconstruction(session_id: str, request: Request): - """Save edited cell texts from reconstruction step.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - word_result = session.get("word_result") - if not word_result: - raise HTTPException(status_code=400, detail="No word result found") - - body = await request.json() - cell_updates = body.get("cells", []) - - if not cell_updates: - await update_session_db(session_id, current_step=10) - return {"session_id": session_id, "updated": 0} - - # Build update map: cell_id -> new text - update_map = {c["cell_id"]: c["text"] for c in cell_updates} - - # Separate sub-session updates (cell_ids prefixed with "box{N}_") - sub_updates: Dict[int, Dict[str, str]] = {} # box_index -> {original_cell_id: text} - main_updates: Dict[str, str] = {} - for cell_id, text in update_map.items(): - m = re.match(r'^box(\d+)_(.+)$', cell_id) - if m: - bi = int(m.group(1)) - original_id = m.group(2) - sub_updates.setdefault(bi, {})[original_id] = text - else: - main_updates[cell_id] = text - - # Update main session cells - cells = word_result.get("cells", []) - updated_count = 0 - for cell in cells: - if cell["cell_id"] in main_updates: - cell["text"] = main_updates[cell["cell_id"]] - cell["status"] = "edited" - updated_count += 1 - - word_result["cells"] = cells - - # Also update vocab_entries if present - entries = word_result.get("vocab_entries") or word_result.get("entries") or [] - if entries: - # Map cell_id pattern "R{row}_C{col}" to entry fields - for 
entry in entries: - row_idx = entry.get("row_index", -1) - # Check each field's cell - for col_idx, field_name in enumerate(["english", "german", "example"]): - cell_id = f"R{row_idx:02d}_C{col_idx}" - # Also try without zero-padding - cell_id_alt = f"R{row_idx}_C{col_idx}" - new_text = main_updates.get(cell_id) or main_updates.get(cell_id_alt) - if new_text is not None: - entry[field_name] = new_text - - word_result["vocab_entries"] = entries - if "entries" in word_result: - word_result["entries"] = entries - - await update_session_db(session_id, word_result=word_result, current_step=10) - - if session_id in _cache: - _cache[session_id]["word_result"] = word_result - - # Route sub-session updates - sub_updated = 0 - if sub_updates: - subs = await get_sub_sessions(session_id) - sub_by_index = {s.get("box_index"): s["id"] for s in subs} - for bi, updates in sub_updates.items(): - sub_id = sub_by_index.get(bi) - if not sub_id: - continue - sub_session = await get_session_db(sub_id) - if not sub_session: - continue - sub_word = sub_session.get("word_result") - if not sub_word: - continue - sub_cells = sub_word.get("cells", []) - for cell in sub_cells: - if cell["cell_id"] in updates: - cell["text"] = updates[cell["cell_id"]] - cell["status"] = "edited" - sub_updated += 1 - sub_word["cells"] = sub_cells - await update_session_db(sub_id, word_result=sub_word) - if sub_id in _cache: - _cache[sub_id]["word_result"] = sub_word - - total_updated = updated_count + sub_updated - logger.info(f"Reconstruction saved for session {session_id}: " - f"{updated_count} main + {sub_updated} sub-session cells updated") - - return { - "session_id": session_id, - "updated": total_updated, - "main_updated": updated_count, - "sub_updated": sub_updated, - } - - -@router.get("/sessions/{session_id}/reconstruction/fabric-json") -async def get_fabric_json(session_id: str): - """Return cell grid as Fabric.js-compatible JSON for the canvas editor. 
- - If the session has sub-sessions (box regions), their cells are merged - into the result at the correct Y positions. - """ - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - word_result = session.get("word_result") - if not word_result: - raise HTTPException(status_code=400, detail="No word result found") - - cells = list(word_result.get("cells", [])) - img_w = word_result.get("image_width", 800) - img_h = word_result.get("image_height", 600) - - # Merge sub-session cells at box positions - subs = await get_sub_sessions(session_id) - if subs: - column_result = session.get("column_result") or {} - zones = column_result.get("zones") or [] - box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")] - - for sub in subs: - sub_session = await get_session_db(sub["id"]) - if not sub_session: - continue - sub_word = sub_session.get("word_result") - if not sub_word or not sub_word.get("cells"): - continue - - bi = sub.get("box_index", 0) - if bi < len(box_zones): - box = box_zones[bi]["box"] - box_y, box_x = box["y"], box["x"] - else: - box_y, box_x = 0, 0 - - # Offset sub-session cells to absolute page coordinates - for cell in sub_word["cells"]: - cell_copy = dict(cell) - # Prefix cell_id with box index - cell_copy["cell_id"] = f"box{bi}_{cell_copy.get('cell_id', '')}" - cell_copy["source"] = f"box_{bi}" - # Offset bbox_px - bbox = cell_copy.get("bbox_px", {}) - if bbox: - bbox = dict(bbox) - bbox["x"] = bbox.get("x", 0) + box_x - bbox["y"] = bbox.get("y", 0) + box_y - cell_copy["bbox_px"] = bbox - cells.append(cell_copy) - - from services.layout_reconstruction_service import cells_to_fabric_json - fabric_json = cells_to_fabric_json(cells, img_w, img_h) - - return fabric_json - - -@router.get("/sessions/{session_id}/vocab-entries/merged") -async def get_merged_vocab_entries(session_id: str): - """Return vocab entries from main session + all 
sub-sessions, sorted by Y position.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - word_result = session.get("word_result") or {} - entries = list(word_result.get("vocab_entries") or word_result.get("entries") or []) - - # Tag main entries - for e in entries: - e.setdefault("source", "main") - - # Merge sub-session entries - subs = await get_sub_sessions(session_id) - if subs: - column_result = session.get("column_result") or {} - zones = column_result.get("zones") or [] - box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")] - - for sub in subs: - sub_session = await get_session_db(sub["id"]) - if not sub_session: - continue - sub_word = sub_session.get("word_result") or {} - sub_entries = sub_word.get("vocab_entries") or sub_word.get("entries") or [] - - bi = sub.get("box_index", 0) - box_y = 0 - if bi < len(box_zones): - box_y = box_zones[bi]["box"]["y"] - - for e in sub_entries: - e_copy = dict(e) - e_copy["source"] = f"box_{bi}" - e_copy["source_y"] = box_y # for sorting - entries.append(e_copy) - - # Sort by approximate Y position - def _sort_key(e): - if e.get("source", "main") == "main": - return e.get("row_index", 0) * 100 # main entries by row index - return e.get("source_y", 0) * 100 + e.get("row_index", 0) - - entries.sort(key=_sort_key) - - return { - "session_id": session_id, - "entries": entries, - "total": len(entries), - "sources": list(set(e.get("source", "main") for e in entries)), - } - - -@router.get("/sessions/{session_id}/reconstruction/export/pdf") -async def export_reconstruction_pdf(session_id: str): - """Export the reconstructed cell grid as a PDF table.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - word_result = session.get("word_result") - if not word_result: - raise HTTPException(status_code=400, 
detail="No word result found") - - cells = word_result.get("cells", []) - columns_used = word_result.get("columns_used", []) - grid_shape = word_result.get("grid_shape", {}) - n_rows = grid_shape.get("rows", 0) - n_cols = grid_shape.get("cols", 0) - - # Build table data: rows × columns - table_data: list[list[str]] = [] - header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)] - if not header: - header = [f"Col {i}" for i in range(n_cols)] - table_data.append(header) - - for r in range(n_rows): - row_texts = [] - for ci in range(n_cols): - cell_id = f"R{r:02d}_C{ci}" - cell = next((c for c in cells if c.get("cell_id") == cell_id), None) - row_texts.append(cell.get("text", "") if cell else "") - table_data.append(row_texts) - - # Generate PDF with reportlab - try: - from reportlab.lib.pagesizes import A4 - from reportlab.lib import colors - from reportlab.platypus import SimpleDocTemplate, Table, TableStyle - import io as _io - - buf = _io.BytesIO() - doc = SimpleDocTemplate(buf, pagesize=A4) - if not table_data or not table_data[0]: - raise HTTPException(status_code=400, detail="No data to export") - - t = Table(table_data) - t.setStyle(TableStyle([ - ('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#0d9488')), - ('TEXTCOLOR', (0, 0), (-1, 0), colors.white), - ('FONTSIZE', (0, 0), (-1, -1), 9), - ('GRID', (0, 0), (-1, -1), 0.5, colors.grey), - ('VALIGN', (0, 0), (-1, -1), 'TOP'), - ('WORDWRAP', (0, 0), (-1, -1), True), - ])) - doc.build([t]) - buf.seek(0) - - from fastapi.responses import StreamingResponse - return StreamingResponse( - buf, - media_type="application/pdf", - headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.pdf"'}, - ) - except ImportError: - raise HTTPException(status_code=501, detail="reportlab not installed") - - -@router.get("/sessions/{session_id}/reconstruction/export/docx") -async def export_reconstruction_docx(session_id: str): - """Export the reconstructed cell grid as a 
DOCX table.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - word_result = session.get("word_result") - if not word_result: - raise HTTPException(status_code=400, detail="No word result found") - - cells = word_result.get("cells", []) - columns_used = word_result.get("columns_used", []) - grid_shape = word_result.get("grid_shape", {}) - n_rows = grid_shape.get("rows", 0) - n_cols = grid_shape.get("cols", 0) - - try: - from docx import Document - from docx.shared import Pt - import io as _io - - doc = Document() - doc.add_heading(f'Rekonstruktion – Session {session_id[:8]}', level=1) - - # Build header - header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)] - if not header: - header = [f"Col {i}" for i in range(n_cols)] - - table = doc.add_table(rows=1 + n_rows, cols=max(n_cols, 1)) - table.style = 'Table Grid' - - # Header row - for ci, h in enumerate(header): - table.rows[0].cells[ci].text = h - - # Data rows - for r in range(n_rows): - for ci in range(n_cols): - cell_id = f"R{r:02d}_C{ci}" - cell = next((c for c in cells if c.get("cell_id") == cell_id), None) - table.rows[r + 1].cells[ci].text = cell.get("text", "") if cell else "" - - buf = _io.BytesIO() - doc.save(buf) - buf.seek(0) - - from fastapi.responses import StreamingResponse - return StreamingResponse( - buf, - media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document", - headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.docx"'}, - ) - except ImportError: - raise HTTPException(status_code=501, detail="python-docx not installed") - - -# --------------------------------------------------------------------------- -# Step 8: Validation — Original vs. 
Reconstruction -# --------------------------------------------------------------------------- - -STYLE_SUFFIXES = { - "educational": "educational illustration, textbook style, clear, colorful", - "cartoon": "cartoon, child-friendly, simple shapes", - "sketch": "pencil sketch, hand-drawn, black and white", - "clipart": "clipart, flat vector style, simple", - "realistic": "photorealistic, high detail", -} - - -class ValidationRequest(BaseModel): - notes: Optional[str] = None - score: Optional[int] = None - - -class GenerateImageRequest(BaseModel): - region_index: int - prompt: str - style: str = "educational" - - -@router.post("/sessions/{session_id}/reconstruction/detect-images") -async def detect_image_regions(session_id: str): - """Detect illustration/image regions in the original scan using VLM. - - Sends the original image to qwen2.5vl to find non-text, non-table - image areas, returning bounding boxes (in %) and descriptions. - """ - import base64 - import httpx - import re - - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - # Get original image bytes - original_png = await get_session_image(session_id, "original") - if not original_png: - raise HTTPException(status_code=400, detail="No original image found") - - # Build context from vocab entries for richer descriptions - word_result = session.get("word_result") or {} - entries = word_result.get("vocab_entries") or word_result.get("entries") or [] - vocab_context = "" - if entries: - sample = entries[:10] - words = [f"{e.get('english', '')} / {e.get('german', '')}" for e in sample if e.get('english')] - if words: - vocab_context = f"\nContext: This is a vocabulary page with words like: {', '.join(words)}" - - ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434") - model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b") - - prompt = ( - "Analyze this scanned page. 
Find ALL illustration/image/picture regions " - "(NOT text, NOT table cells, NOT blank areas). " - "For each image region found, return its bounding box as percentage of page dimensions " - "and a short English description of what the image shows. " - "Reply with ONLY a JSON array like: " - '[{"x": 10, "y": 20, "w": 30, "h": 25, "description": "drawing of a cat"}] ' - "where x, y, w, h are percentages (0-100) of the page width/height. " - "If there are NO images on the page, return an empty array: []" - f"{vocab_context}" - ) - - img_b64 = base64.b64encode(original_png).decode("utf-8") - payload = { - "model": model, - "prompt": prompt, - "images": [img_b64], - "stream": False, - } - - try: - async with httpx.AsyncClient(timeout=120.0) as client: - resp = await client.post(f"{ollama_base}/api/generate", json=payload) - resp.raise_for_status() - text = resp.json().get("response", "") - - # Parse JSON array from response - match = re.search(r'\[.*?\]', text, re.DOTALL) - if match: - raw_regions = json.loads(match.group(0)) - else: - raw_regions = [] - - # Normalize to ImageRegion format - regions = [] - for r in raw_regions: - regions.append({ - "bbox_pct": { - "x": max(0, min(100, float(r.get("x", 0)))), - "y": max(0, min(100, float(r.get("y", 0)))), - "w": max(1, min(100, float(r.get("w", 10)))), - "h": max(1, min(100, float(r.get("h", 10)))), - }, - "description": r.get("description", ""), - "prompt": r.get("description", ""), - "image_b64": None, - "style": "educational", - }) - - # Enrich prompts with nearby vocab context - if entries: - for region in regions: - ry = region["bbox_pct"]["y"] - rh = region["bbox_pct"]["h"] - nearby = [ - e for e in entries - if e.get("bbox") and abs(e["bbox"].get("y", 0) - ry) < rh + 10 - ] - if nearby: - en_words = [e.get("english", "") for e in nearby if e.get("english")] - de_words = [e.get("german", "") for e in nearby if e.get("german")] - if en_words or de_words: - context = f" (vocabulary context: {', '.join(en_words[:5])}" 
- if de_words: - context += f" / {', '.join(de_words[:5])}" - context += ")" - region["prompt"] = region["description"] + context - - # Save to ground_truth JSONB - ground_truth = session.get("ground_truth") or {} - validation = ground_truth.get("validation") or {} - validation["image_regions"] = regions - validation["detected_at"] = datetime.utcnow().isoformat() - ground_truth["validation"] = validation - await update_session_db(session_id, ground_truth=ground_truth) - - if session_id in _cache: - _cache[session_id]["ground_truth"] = ground_truth - - logger.info(f"Detected {len(regions)} image regions for session {session_id}") - - return {"regions": regions, "count": len(regions)} - - except httpx.ConnectError: - logger.warning(f"VLM not available at {ollama_base} for image detection") - return {"regions": [], "count": 0, "error": "VLM not available"} - except Exception as e: - logger.error(f"Image detection failed for {session_id}: {e}") - return {"regions": [], "count": 0, "error": str(e)} - - -@router.post("/sessions/{session_id}/reconstruction/generate-image") -async def generate_image_for_region(session_id: str, req: GenerateImageRequest): - """Generate a replacement image for a detected region using mflux. - - Sends the prompt (with style suffix) to the mflux-service running - natively on the Mac Mini (Metal GPU required). 
- """ - import httpx - - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - ground_truth = session.get("ground_truth") or {} - validation = ground_truth.get("validation") or {} - regions = validation.get("image_regions") or [] - - if req.region_index < 0 or req.region_index >= len(regions): - raise HTTPException(status_code=400, detail=f"Invalid region_index {req.region_index}, have {len(regions)} regions") - - mflux_url = os.getenv("MFLUX_URL", "http://host.docker.internal:8095") - style_suffix = STYLE_SUFFIXES.get(req.style, STYLE_SUFFIXES["educational"]) - full_prompt = f"{req.prompt}, {style_suffix}" - - # Determine image size from region aspect ratio (snap to multiples of 64) - region = regions[req.region_index] - bbox = region["bbox_pct"] - aspect = bbox["w"] / max(bbox["h"], 1) - if aspect > 1.3: - width, height = 768, 512 - elif aspect < 0.7: - width, height = 512, 768 - else: - width, height = 512, 512 - - try: - async with httpx.AsyncClient(timeout=300.0) as client: - resp = await client.post(f"{mflux_url}/generate", json={ - "prompt": full_prompt, - "width": width, - "height": height, - "steps": 4, - }) - resp.raise_for_status() - data = resp.json() - image_b64 = data.get("image_b64") - - if not image_b64: - return {"image_b64": None, "success": False, "error": "No image returned"} - - # Save to ground_truth - regions[req.region_index]["image_b64"] = image_b64 - regions[req.region_index]["prompt"] = req.prompt - regions[req.region_index]["style"] = req.style - validation["image_regions"] = regions - ground_truth["validation"] = validation - await update_session_db(session_id, ground_truth=ground_truth) - - if session_id in _cache: - _cache[session_id]["ground_truth"] = ground_truth - - logger.info(f"Generated image for session {session_id} region {req.region_index}") - return {"image_b64": image_b64, "success": True} - - except httpx.ConnectError: - 
logger.warning(f"mflux-service not available at {mflux_url}") - return {"image_b64": None, "success": False, "error": f"mflux-service not available at {mflux_url}"} - except Exception as e: - logger.error(f"Image generation failed for {session_id}: {e}") - return {"image_b64": None, "success": False, "error": str(e)} - - -@router.post("/sessions/{session_id}/reconstruction/validate") -async def save_validation(session_id: str, req: ValidationRequest): - """Save final validation results for step 8. - - Stores notes, score, and preserves any detected/generated image regions. - Sets current_step = 10 to mark pipeline as complete. - """ - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - ground_truth = session.get("ground_truth") or {} - validation = ground_truth.get("validation") or {} - validation["validated_at"] = datetime.utcnow().isoformat() - validation["notes"] = req.notes - validation["score"] = req.score - ground_truth["validation"] = validation - - await update_session_db(session_id, ground_truth=ground_truth, current_step=11) - - if session_id in _cache: - _cache[session_id]["ground_truth"] = ground_truth - - logger.info(f"Validation saved for session {session_id}: score={req.score}") - - return {"session_id": session_id, "validation": validation} - - -@router.get("/sessions/{session_id}/reconstruction/validation") -async def get_validation(session_id: str): - """Retrieve saved validation data for step 8.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - ground_truth = session.get("ground_truth") or {} - validation = ground_truth.get("validation") - - return { - "session_id": session_id, - "validation": validation, - "word_result": session.get("word_result"), - } - - -@router.post("/sessions/{session_id}/reprocess") -async def reprocess_session(session_id: str, 
request: Request): - """Re-run pipeline from a specific step, clearing downstream data. - - Body: {"from_step": 5} (1-indexed step number) - - Pipeline order: Orientation(1) → Deskew(2) → Dewarp(3) → Crop(4) → Columns(5) → - Rows(6) → Words(7) → LLM-Review(8) → Reconstruction(9) → Validation(10) - - Clears downstream results: - - from_step <= 1: orientation_result + all downstream - - from_step <= 2: deskew_result + all downstream - - from_step <= 3: dewarp_result + all downstream - - from_step <= 4: crop_result + all downstream - - from_step <= 5: column_result, row_result, word_result - - from_step <= 6: row_result, word_result - - from_step <= 7: word_result (cells, vocab_entries) - - from_step <= 8: word_result.llm_review only - """ - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - body = await request.json() - from_step = body.get("from_step", 1) - if not isinstance(from_step, int) or from_step < 1 or from_step > 10: - raise HTTPException(status_code=400, detail="from_step must be between 1 and 10") - - update_kwargs: Dict[str, Any] = {"current_step": from_step} - - # Clear downstream data based on from_step - # New pipeline order: Orient(2) → Deskew(3) → Dewarp(4) → Crop(5) → - # Columns(6) → Rows(7) → Words(8) → LLM(9) → Recon(10) → GT(11) - if from_step <= 8: - update_kwargs["word_result"] = None - elif from_step == 9: - # Only clear LLM review from word_result - word_result = session.get("word_result") - if word_result: - word_result.pop("llm_review", None) - word_result.pop("llm_corrections", None) - update_kwargs["word_result"] = word_result - - if from_step <= 7: - update_kwargs["row_result"] = None - if from_step <= 6: - update_kwargs["column_result"] = None - if from_step <= 4: - update_kwargs["crop_result"] = None - if from_step <= 3: - update_kwargs["dewarp_result"] = None - if from_step <= 2: - update_kwargs["deskew_result"] = None - if from_step <= 1: 
- update_kwargs["orientation_result"] = None - - await update_session_db(session_id, **update_kwargs) - - # Also clear cache - if session_id in _cache: - for key in list(update_kwargs.keys()): - if key != "current_step": - _cache[session_id][key] = update_kwargs[key] - _cache[session_id]["current_step"] = from_step - - logger.info(f"Session {session_id} reprocessing from step {from_step}") - - return { - "session_id": session_id, - "from_step": from_step, - "cleared": [k for k in update_kwargs if k != "current_step"], - } - - -async def _get_rows_overlay(session_id: str) -> Response: - """Generate cropped (or dewarped) image with row bands drawn on it.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - row_result = session.get("row_result") - if not row_result or not row_result.get("rows"): - raise HTTPException(status_code=404, detail="No row data available") - - # Load best available base image (cropped > dewarped > original) - base_png = await _get_base_image_png(session_id) - if not base_png: - raise HTTPException(status_code=404, detail="No base image available") - - arr = np.frombuffer(base_png, dtype=np.uint8) - img = cv2.imdecode(arr, cv2.IMREAD_COLOR) - if img is None: - raise HTTPException(status_code=500, detail="Failed to decode image") - - # Color map for row types (BGR) - row_colors = { - "content": (255, 180, 0), # Blue - "header": (128, 128, 128), # Gray - "footer": (128, 128, 128), # Gray - "margin_top": (100, 100, 100), # Dark Gray - "margin_bottom": (100, 100, 100), # Dark Gray - } - - overlay = img.copy() - for row in row_result["rows"]: - x, y = row["x"], row["y"] - w, h = row["width"], row["height"] - row_type = row.get("row_type", "content") - color = row_colors.get(row_type, (200, 200, 200)) - - # Semi-transparent fill - cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1) - - # Solid border - cv2.rectangle(img, (x, y), (x + w, y + h), 
color, 2) - - # Label - idx = row.get("index", 0) - label = f"R{idx} {row_type.upper()}" - wc = row.get("word_count", 0) - if wc: - label = f"{label} ({wc}w)" - cv2.putText(img, label, (x + 5, y + 18), - cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1) - - # Blend overlay at 15% opacity - cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img) - - # Draw zone separator lines if zones exist - column_result = session.get("column_result") or {} - zones = column_result.get("zones") or [] - if zones: - img_w_px = img.shape[1] - zone_color = (0, 200, 255) # Yellow (BGR) - dash_len = 20 - for zone in zones: - if zone.get("zone_type") == "box": - zy = zone["y"] - zh = zone["height"] - for line_y in [zy, zy + zh]: - for sx in range(0, img_w_px, dash_len * 2): - ex = min(sx + dash_len, img_w_px) - cv2.line(img, (sx, line_y), (ex, line_y), zone_color, 2) - - # Red semi-transparent overlay for box zones - _draw_box_exclusion_overlay(img, zones) - - success, result_png = cv2.imencode(".png", img) - if not success: - raise HTTPException(status_code=500, detail="Failed to encode overlay image") - - return Response(content=result_png.tobytes(), media_type="image/png") - - -async def _get_words_overlay(session_id: str) -> Response: - """Generate cropped (or dewarped) image with cell grid drawn on it.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - word_result = session.get("word_result") - if not word_result: - raise HTTPException(status_code=404, detail="No word data available") - - # Support both new cell-based and legacy entry-based formats - cells = word_result.get("cells") - if not cells and not word_result.get("entries"): - raise HTTPException(status_code=404, detail="No word data available") - - # Load best available base image (cropped > dewarped > original) - base_png = await _get_base_image_png(session_id) - if not base_png: - raise HTTPException(status_code=404, detail="No base image 
available") - - arr = np.frombuffer(base_png, dtype=np.uint8) - img = cv2.imdecode(arr, cv2.IMREAD_COLOR) - if img is None: - raise HTTPException(status_code=500, detail="Failed to decode image") - - img_h, img_w = img.shape[:2] - - overlay = img.copy() - - if cells: - # New cell-based overlay: color by column index - col_palette = [ - (255, 180, 0), # Blue (BGR) - (0, 200, 0), # Green - (0, 140, 255), # Orange - (200, 100, 200), # Purple - (200, 200, 0), # Cyan - (100, 200, 200), # Yellow-ish - ] - - for cell in cells: - bbox = cell.get("bbox_px", {}) - cx = bbox.get("x", 0) - cy = bbox.get("y", 0) - cw = bbox.get("w", 0) - ch = bbox.get("h", 0) - if cw <= 0 or ch <= 0: - continue - - col_idx = cell.get("col_index", 0) - color = col_palette[col_idx % len(col_palette)] - - # Cell rectangle border - cv2.rectangle(img, (cx, cy), (cx + cw, cy + ch), color, 1) - # Semi-transparent fill - cv2.rectangle(overlay, (cx, cy), (cx + cw, cy + ch), color, -1) - - # Cell-ID label (top-left corner) - cell_id = cell.get("cell_id", "") - cv2.putText(img, cell_id, (cx + 2, cy + 10), - cv2.FONT_HERSHEY_SIMPLEX, 0.28, color, 1) - - # Text label (bottom of cell) - text = cell.get("text", "") - if text: - conf = cell.get("confidence", 0) - if conf >= 70: - text_color = (0, 180, 0) - elif conf >= 50: - text_color = (0, 180, 220) - else: - text_color = (0, 0, 220) - - label = text.replace('\n', ' ')[:30] - cv2.putText(img, label, (cx + 3, cy + ch - 4), - cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1) - else: - # Legacy fallback: entry-based overlay (for old sessions) - column_result = session.get("column_result") - row_result = session.get("row_result") - col_colors = { - "column_en": (255, 180, 0), - "column_de": (0, 200, 0), - "column_example": (0, 140, 255), - } - - columns = [] - if column_result and column_result.get("columns"): - columns = [c for c in column_result["columns"] - if c.get("type", "").startswith("column_")] - - content_rows_data = [] - if row_result and 
row_result.get("rows"): - content_rows_data = [r for r in row_result["rows"] - if r.get("row_type") == "content"] - - for col in columns: - col_type = col.get("type", "") - color = col_colors.get(col_type, (200, 200, 200)) - cx, cw = col["x"], col["width"] - for row in content_rows_data: - ry, rh = row["y"], row["height"] - cv2.rectangle(img, (cx, ry), (cx + cw, ry + rh), color, 1) - cv2.rectangle(overlay, (cx, ry), (cx + cw, ry + rh), color, -1) - - entries = word_result["entries"] - entry_by_row: Dict[int, Dict] = {} - for entry in entries: - entry_by_row[entry.get("row_index", -1)] = entry - - for row_idx, row in enumerate(content_rows_data): - entry = entry_by_row.get(row_idx) - if not entry: - continue - conf = entry.get("confidence", 0) - text_color = (0, 180, 0) if conf >= 70 else (0, 180, 220) if conf >= 50 else (0, 0, 220) - ry, rh = row["y"], row["height"] - for col in columns: - col_type = col.get("type", "") - cx, cw = col["x"], col["width"] - field = {"column_en": "english", "column_de": "german", "column_example": "example"}.get(col_type, "") - text = entry.get(field, "") if field else "" - if text: - label = text.replace('\n', ' ')[:30] - cv2.putText(img, label, (cx + 3, ry + rh - 4), - cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1) - - # Blend overlay at 10% opacity - cv2.addWeighted(overlay, 0.1, img, 0.9, 0, img) - - # Red semi-transparent overlay for box zones - column_result = session.get("column_result") or {} - zones = column_result.get("zones") or [] - _draw_box_exclusion_overlay(img, zones) - - success, result_png = cv2.imencode(".png", img) - if not success: - raise HTTPException(status_code=500, detail="Failed to encode overlay image") - - return Response(content=result_png.tobytes(), media_type="image/png") - - -# --------------------------------------------------------------------------- -# Handwriting Removal Endpoint -# --------------------------------------------------------------------------- - 
-@router.post("/sessions/{session_id}/remove-handwriting") -async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingRequest): - """ - Remove handwriting from a session image using inpainting. - - Steps: - 1. Load source image (auto → deskewed if available, else original) - 2. Detect handwriting mask (filtered by target_ink) - 3. Dilate mask to cover stroke edges - 4. Inpaint the image - 5. Store result as clean_png in the session - - Returns metadata including the URL to fetch the clean image. - """ - import time as _time - t0 = _time.monotonic() - - from services.handwriting_detection import detect_handwriting - from services.inpainting_service import inpaint_image, dilate_mask as _dilate_mask, InpaintingMethod, image_to_png - - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - # 1. Determine source image - source = req.use_source - if source == "auto": - deskewed = await get_session_image(session_id, "deskewed") - source = "deskewed" if deskewed else "original" - - image_bytes = await get_session_image(session_id, source) - if not image_bytes: - raise HTTPException(status_code=404, detail=f"Source image '{source}' not available") - - # 2. Detect handwriting mask - detection = detect_handwriting(image_bytes, target_ink=req.target_ink) - - # 3. Convert mask to PNG bytes and dilate - import io - from PIL import Image as _PILImage - mask_img = _PILImage.fromarray(detection.mask) - mask_buf = io.BytesIO() - mask_img.save(mask_buf, format="PNG") - mask_bytes = mask_buf.getvalue() - - if req.dilation > 0: - mask_bytes = _dilate_mask(mask_bytes, iterations=req.dilation) - - # 4. 
Inpaint - method_map = { - "telea": InpaintingMethod.OPENCV_TELEA, - "ns": InpaintingMethod.OPENCV_NS, - "auto": InpaintingMethod.AUTO, - } - inpaint_method = method_map.get(req.method, InpaintingMethod.AUTO) - - result = inpaint_image(image_bytes, mask_bytes, method=inpaint_method) - if not result.success: - raise HTTPException(status_code=500, detail="Inpainting failed") - - elapsed_ms = int((_time.monotonic() - t0) * 1000) - - meta = { - "method_used": result.method_used.value if hasattr(result.method_used, "value") else str(result.method_used), - "handwriting_ratio": round(detection.handwriting_ratio, 4), - "detection_confidence": round(detection.confidence, 4), - "target_ink": req.target_ink, - "dilation": req.dilation, - "source_image": source, - "processing_time_ms": elapsed_ms, - } - - # 5. Persist clean image (convert BGR ndarray → PNG bytes) - clean_png_bytes = image_to_png(result.image) - await update_session_db(session_id, clean_png=clean_png_bytes, handwriting_removal_meta=meta) - - return { - **meta, - "image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean", - "session_id": session_id, - } - - -# --------------------------------------------------------------------------- -# Auto-Mode Endpoint (Improvement 3) -# --------------------------------------------------------------------------- - -class RunAutoRequest(BaseModel): - from_step: int = 1 # 1=deskew, 2=dewarp, 3=columns, 4=rows, 5=words, 6=llm-review - ocr_engine: str = "auto" # "auto" | "rapid" | "tesseract" - pronunciation: str = "british" - skip_llm_review: bool = False - dewarp_method: str = "ensemble" # "ensemble" | "vlm" | "cv" - - -async def _auto_sse_event(step: str, status: str, data: Dict[str, Any]) -> str: - """Format a single SSE event line.""" - import json as _json - payload = {"step": step, "status": status, **data} - return f"data: {_json.dumps(payload)}\n\n" - - -@router.post("/sessions/{session_id}/run-auto") -async def run_auto(session_id: str, req: RunAutoRequest, 
request: Request): - """Run the full OCR pipeline automatically from a given step, streaming SSE progress. - - Steps: - 1. Deskew — straighten the scan - 2. Dewarp — correct vertical shear (ensemble CV or VLM) - 3. Columns — detect column layout - 4. Rows — detect row layout - 5. Words — OCR each cell - 6. LLM review — correct OCR errors (optional) - - Already-completed steps are skipped unless `from_step` forces a rerun. - Yields SSE events of the form: - data: {"step": "deskew", "status": "start"|"done"|"skipped"|"error", ...} - - Final event: - data: {"step": "complete", "status": "done", "steps_run": [...], "steps_skipped": [...]} - """ - if req.from_step < 1 or req.from_step > 6: - raise HTTPException(status_code=400, detail="from_step must be 1-6") - if req.dewarp_method not in ("ensemble", "vlm", "cv"): - raise HTTPException(status_code=400, detail="dewarp_method must be: ensemble, vlm, cv") - - if session_id not in _cache: - await _load_session_to_cache(session_id) - - async def _generate(): - steps_run: List[str] = [] - steps_skipped: List[str] = [] - error_step: Optional[str] = None - - session = await get_session_db(session_id) - if not session: - yield await _auto_sse_event("error", "error", {"message": f"Session {session_id} not found"}) - return - - cached = _get_cached(session_id) - - # ----------------------------------------------------------------- - # Step 1: Deskew - # ----------------------------------------------------------------- - if req.from_step <= 1: - yield await _auto_sse_event("deskew", "start", {}) - try: - t0 = time.time() - orig_bgr = cached.get("original_bgr") - if orig_bgr is None: - raise ValueError("Original image not loaded") - - # Method 1: Hough lines - try: - deskewed_hough, angle_hough = deskew_image(orig_bgr.copy()) - except Exception: - deskewed_hough, angle_hough = orig_bgr, 0.0 - - # Method 2: Word alignment - success_enc, png_orig = cv2.imencode(".png", orig_bgr) - orig_bytes = png_orig.tobytes() if success_enc else 
b"" - try: - deskewed_wa_bytes, angle_wa = deskew_image_by_word_alignment(orig_bytes) - except Exception: - deskewed_wa_bytes, angle_wa = orig_bytes, 0.0 - - # Pick best method - if abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1: - method_used = "word_alignment" - angle_applied = angle_wa - wa_arr = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8) - deskewed_bgr = cv2.imdecode(wa_arr, cv2.IMREAD_COLOR) - if deskewed_bgr is None: - deskewed_bgr = deskewed_hough - method_used = "hough" - angle_applied = angle_hough - else: - method_used = "hough" - angle_applied = angle_hough - deskewed_bgr = deskewed_hough - - success, png_buf = cv2.imencode(".png", deskewed_bgr) - deskewed_png = png_buf.tobytes() if success else b"" - - deskew_result = { - "method_used": method_used, - "rotation_degrees": round(float(angle_applied), 3), - "duration_seconds": round(time.time() - t0, 2), - } - - cached["deskewed_bgr"] = deskewed_bgr - cached["deskew_result"] = deskew_result - await update_session_db( - session_id, - deskewed_png=deskewed_png, - deskew_result=deskew_result, - auto_rotation_degrees=float(angle_applied), - current_step=3, - ) - session = await get_session_db(session_id) - - steps_run.append("deskew") - yield await _auto_sse_event("deskew", "done", deskew_result) - except Exception as e: - logger.error(f"Auto-mode deskew failed for {session_id}: {e}") - error_step = "deskew" - yield await _auto_sse_event("deskew", "error", {"message": str(e)}) - yield await _auto_sse_event("complete", "error", {"error_step": error_step}) - return - else: - steps_skipped.append("deskew") - yield await _auto_sse_event("deskew", "skipped", {"reason": "from_step > 1"}) - - # ----------------------------------------------------------------- - # Step 2: Dewarp - # ----------------------------------------------------------------- - if req.from_step <= 2: - yield await _auto_sse_event("dewarp", "start", {"method": req.dewarp_method}) - try: - t0 = time.time() - deskewed_bgr = 
cached.get("deskewed_bgr") - if deskewed_bgr is None: - raise ValueError("Deskewed image not available") - - if req.dewarp_method == "vlm": - success_enc, png_buf = cv2.imencode(".png", deskewed_bgr) - img_bytes = png_buf.tobytes() if success_enc else b"" - vlm_det = await _detect_shear_with_vlm(img_bytes) - shear_deg = vlm_det["shear_degrees"] - if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3: - dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg) - else: - dewarped_bgr = deskewed_bgr - dewarp_info = { - "method": vlm_det["method"], - "shear_degrees": shear_deg, - "confidence": vlm_det["confidence"], - "detections": [vlm_det], - } - else: - dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr) - - success_enc, png_buf = cv2.imencode(".png", dewarped_bgr) - dewarped_png = png_buf.tobytes() if success_enc else b"" - - dewarp_result = { - "method_used": dewarp_info["method"], - "shear_degrees": dewarp_info["shear_degrees"], - "confidence": dewarp_info["confidence"], - "duration_seconds": round(time.time() - t0, 2), - "detections": dewarp_info.get("detections", []), - } - - cached["dewarped_bgr"] = dewarped_bgr - cached["dewarp_result"] = dewarp_result - await update_session_db( - session_id, - dewarped_png=dewarped_png, - dewarp_result=dewarp_result, - auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0), - current_step=4, - ) - session = await get_session_db(session_id) - - steps_run.append("dewarp") - yield await _auto_sse_event("dewarp", "done", dewarp_result) - except Exception as e: - logger.error(f"Auto-mode dewarp failed for {session_id}: {e}") - error_step = "dewarp" - yield await _auto_sse_event("dewarp", "error", {"message": str(e)}) - yield await _auto_sse_event("complete", "error", {"error_step": error_step}) - return - else: - steps_skipped.append("dewarp") - yield await _auto_sse_event("dewarp", "skipped", {"reason": "from_step > 2"}) - - # ----------------------------------------------------------------- - # Step 3: Columns - # 
----------------------------------------------------------------- - if req.from_step <= 3: - yield await _auto_sse_event("columns", "start", {}) - try: - t0 = time.time() - col_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") - if col_img is None: - raise ValueError("Cropped/dewarped image not available") - - ocr_img = create_ocr_image(col_img) - h, w = ocr_img.shape[:2] - - geo_result = detect_column_geometry(ocr_img, col_img) - if geo_result is None: - layout_img = create_layout_image(col_img) - regions = analyze_layout(layout_img, ocr_img) - cached["_word_dicts"] = None - cached["_inv"] = None - cached["_content_bounds"] = None - else: - geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result - content_w = right_x - left_x - cached["_word_dicts"] = word_dicts - cached["_inv"] = inv - cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y) - - header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None) - geometries = _detect_sub_columns(geometries, content_w, left_x=left_x, - top_y=top_y, header_y=header_y, footer_y=footer_y) - regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y, - left_x=left_x, right_x=right_x, inv=inv) - - columns = [asdict(r) for r in regions] - column_result = { - "columns": columns, - "classification_methods": list({c.get("classification_method", "") for c in columns if c.get("classification_method")}), - "duration_seconds": round(time.time() - t0, 2), - } - - cached["column_result"] = column_result - await update_session_db(session_id, column_result=column_result, - row_result=None, word_result=None, current_step=6) - session = await get_session_db(session_id) - - steps_run.append("columns") - yield await _auto_sse_event("columns", "done", { - "column_count": len(columns), - "duration_seconds": column_result["duration_seconds"], - }) - except Exception as e: - logger.error(f"Auto-mode columns failed 
for {session_id}: {e}") - error_step = "columns" - yield await _auto_sse_event("columns", "error", {"message": str(e)}) - yield await _auto_sse_event("complete", "error", {"error_step": error_step}) - return - else: - steps_skipped.append("columns") - yield await _auto_sse_event("columns", "skipped", {"reason": "from_step > 3"}) - - # ----------------------------------------------------------------- - # Step 4: Rows - # ----------------------------------------------------------------- - if req.from_step <= 4: - yield await _auto_sse_event("rows", "start", {}) - try: - t0 = time.time() - row_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") - session = await get_session_db(session_id) - column_result = session.get("column_result") or cached.get("column_result") - if not column_result or not column_result.get("columns"): - raise ValueError("Column detection must complete first") - - col_regions = [ - PageRegion( - type=c["type"], x=c["x"], y=c["y"], - width=c["width"], height=c["height"], - classification_confidence=c.get("classification_confidence", 1.0), - classification_method=c.get("classification_method", ""), - ) - for c in column_result["columns"] - ] - - word_dicts = cached.get("_word_dicts") - inv = cached.get("_inv") - content_bounds = cached.get("_content_bounds") - - if word_dicts is None or inv is None or content_bounds is None: - ocr_img_tmp = create_ocr_image(row_img) - geo_result = detect_column_geometry(ocr_img_tmp, row_img) - if geo_result is None: - raise ValueError("Column geometry detection failed — cannot detect rows") - _g, lx, rx, ty, by, word_dicts, inv = geo_result - cached["_word_dicts"] = word_dicts - cached["_inv"] = inv - cached["_content_bounds"] = (lx, rx, ty, by) - content_bounds = (lx, rx, ty, by) - - left_x, right_x, top_y, bottom_y = content_bounds - row_geoms = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y) - - row_list = [ - { - "index": r.index, "x": 
r.x, "y": r.y, - "width": r.width, "height": r.height, - "word_count": r.word_count, - "row_type": r.row_type, - "gap_before": r.gap_before, - } - for r in row_geoms - ] - row_result = { - "rows": row_list, - "row_count": len(row_list), - "content_rows": len([r for r in row_geoms if r.row_type == "content"]), - "duration_seconds": round(time.time() - t0, 2), - } - - cached["row_result"] = row_result - await update_session_db(session_id, row_result=row_result, current_step=7) - session = await get_session_db(session_id) - - steps_run.append("rows") - yield await _auto_sse_event("rows", "done", { - "row_count": len(row_list), - "content_rows": row_result["content_rows"], - "duration_seconds": row_result["duration_seconds"], - }) - except Exception as e: - logger.error(f"Auto-mode rows failed for {session_id}: {e}") - error_step = "rows" - yield await _auto_sse_event("rows", "error", {"message": str(e)}) - yield await _auto_sse_event("complete", "error", {"error_step": error_step}) - return - else: - steps_skipped.append("rows") - yield await _auto_sse_event("rows", "skipped", {"reason": "from_step > 4"}) - - # ----------------------------------------------------------------- - # Step 5: Words (OCR) - # ----------------------------------------------------------------- - if req.from_step <= 5: - yield await _auto_sse_event("words", "start", {"engine": req.ocr_engine}) - try: - t0 = time.time() - word_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") - session = await get_session_db(session_id) - - column_result = session.get("column_result") or cached.get("column_result") - row_result = session.get("row_result") or cached.get("row_result") - - col_regions = [ - PageRegion( - type=c["type"], x=c["x"], y=c["y"], - width=c["width"], height=c["height"], - classification_confidence=c.get("classification_confidence", 1.0), - classification_method=c.get("classification_method", ""), - ) - for c in column_result["columns"] 
- ] - row_geoms = [ - RowGeometry( - index=r["index"], x=r["x"], y=r["y"], - width=r["width"], height=r["height"], - word_count=r.get("word_count", 0), words=[], - row_type=r.get("row_type", "content"), - gap_before=r.get("gap_before", 0), - ) - for r in row_result["rows"] - ] - - word_dicts = cached.get("_word_dicts") - if word_dicts is not None: - content_bounds = cached.get("_content_bounds") - top_y = content_bounds[2] if content_bounds else min(r.y for r in row_geoms) - for row in row_geoms: - row_y_rel = row.y - top_y - row_bottom_rel = row_y_rel + row.height - row.words = [ - w for w in word_dicts - if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel - ] - row.word_count = len(row.words) - - ocr_img = create_ocr_image(word_img) - img_h, img_w = word_img.shape[:2] - - cells, columns_meta = build_cell_grid( - ocr_img, col_regions, row_geoms, img_w, img_h, - ocr_engine=req.ocr_engine, img_bgr=word_img, - ) - duration = time.time() - t0 - - col_types = {c['type'] for c in columns_meta} - is_vocab = bool(col_types & {'column_en', 'column_de'}) - n_content_rows = len([r for r in row_geoms if r.row_type == 'content']) - used_engine = cells[0].get("ocr_engine", "tesseract") if cells else req.ocr_engine - - # Apply IPA phonetic fixes directly to cell texts - fix_cell_phonetics(cells, pronunciation=req.pronunciation) - - word_result_data = { - "cells": cells, - "grid_shape": { - "rows": n_content_rows, - "cols": len(columns_meta), - "total_cells": len(cells), - }, - "columns_used": columns_meta, - "layout": "vocab" if is_vocab else "generic", - "image_width": img_w, - "image_height": img_h, - "duration_seconds": round(duration, 2), - "ocr_engine": used_engine, - "summary": { - "total_cells": len(cells), - "non_empty_cells": sum(1 for c in cells if c.get("text")), - "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), - }, - } - - has_text_col = 'column_text' in col_types - if is_vocab or has_text_col: - entries = 
_cells_to_vocab_entries(cells, columns_meta) - entries = _fix_character_confusion(entries) - entries = _fix_phonetic_brackets(entries, pronunciation=req.pronunciation) - word_result_data["vocab_entries"] = entries - word_result_data["entries"] = entries - word_result_data["entry_count"] = len(entries) - word_result_data["summary"]["total_entries"] = len(entries) - - await update_session_db(session_id, word_result=word_result_data, current_step=8) - cached["word_result"] = word_result_data - session = await get_session_db(session_id) - - steps_run.append("words") - yield await _auto_sse_event("words", "done", { - "total_cells": len(cells), - "layout": word_result_data["layout"], - "duration_seconds": round(duration, 2), - "ocr_engine": used_engine, - "summary": word_result_data["summary"], - }) - except Exception as e: - logger.error(f"Auto-mode words failed for {session_id}: {e}") - error_step = "words" - yield await _auto_sse_event("words", "error", {"message": str(e)}) - yield await _auto_sse_event("complete", "error", {"error_step": error_step}) - return - else: - steps_skipped.append("words") - yield await _auto_sse_event("words", "skipped", {"reason": "from_step > 5"}) - - # ----------------------------------------------------------------- - # Step 6: LLM Review (optional) - # ----------------------------------------------------------------- - if req.from_step <= 6 and not req.skip_llm_review: - yield await _auto_sse_event("llm_review", "start", {"model": OLLAMA_REVIEW_MODEL}) - try: - session = await get_session_db(session_id) - word_result = session.get("word_result") or cached.get("word_result") - entries = word_result.get("entries") or word_result.get("vocab_entries") or [] - - if not entries: - yield await _auto_sse_event("llm_review", "skipped", {"reason": "no entries"}) - steps_skipped.append("llm_review") - else: - reviewed = await llm_review_entries(entries) - - session = await get_session_db(session_id) - word_result_updated = 
dict(session.get("word_result") or {}) - word_result_updated["entries"] = reviewed - word_result_updated["vocab_entries"] = reviewed - word_result_updated["llm_reviewed"] = True - word_result_updated["llm_model"] = OLLAMA_REVIEW_MODEL - - await update_session_db(session_id, word_result=word_result_updated, current_step=9) - cached["word_result"] = word_result_updated - - steps_run.append("llm_review") - yield await _auto_sse_event("llm_review", "done", { - "entries_reviewed": len(reviewed), - "model": OLLAMA_REVIEW_MODEL, - }) - except Exception as e: - logger.warning(f"Auto-mode llm_review failed for {session_id} (non-fatal): {e}") - yield await _auto_sse_event("llm_review", "error", {"message": str(e), "fatal": False}) - steps_skipped.append("llm_review") - else: - steps_skipped.append("llm_review") - reason = "skipped by request" if req.skip_llm_review else "from_step > 6" - yield await _auto_sse_event("llm_review", "skipped", {"reason": reason}) - - # ----------------------------------------------------------------- - # Final event - # ----------------------------------------------------------------- - yield await _auto_sse_event("complete", "done", { - "steps_run": steps_run, - "steps_skipped": steps_skipped, - }) - - return StreamingResponse( - _generate(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) +router = APIRouter() +router.include_router(_sessions_router) +router.include_router(_geometry_router) +router.include_router(_rows_router) +router.include_router(_words_router) +router.include_router(_ocr_merge_router) +router.include_router(_postprocess_router) +router.include_router(_auto_router) diff --git a/klausur-service/backend/ocr_pipeline_auto.py b/klausur-service/backend/ocr_pipeline_auto.py new file mode 100644 index 0000000..c85ac49 --- /dev/null +++ b/klausur-service/backend/ocr_pipeline_auto.py @@ -0,0 +1,705 @@ +""" +OCR Pipeline Auto-Mode 
Orchestrator and Reprocess Endpoints. + +Extracted from ocr_pipeline_api.py — contains: +- POST /sessions/{session_id}/reprocess (clear downstream + restart from step) +- POST /sessions/{session_id}/run-auto (full auto-mode with SSE streaming) + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import json +import logging +import os +import re +import time +from dataclasses import asdict +from typing import Any, Dict, List, Optional + +import cv2 +import numpy as np +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import StreamingResponse +from pydantic import BaseModel + +from cv_vocab_pipeline import ( + OLLAMA_REVIEW_MODEL, + PageRegion, + RowGeometry, + _cells_to_vocab_entries, + _detect_header_footer_gaps, + _detect_sub_columns, + _fix_character_confusion, + _fix_phonetic_brackets, + fix_cell_phonetics, + analyze_layout, + build_cell_grid, + classify_column_types, + create_layout_image, + create_ocr_image, + deskew_image, + deskew_image_by_word_alignment, + detect_column_geometry, + detect_row_geometry, + _apply_shear, + dewarp_image, + llm_review_entries, +) +from ocr_pipeline_common import ( + _cache, + _load_session_to_cache, + _get_cached, + _get_base_image_png, + _append_pipeline_log, +) +from ocr_pipeline_session_store import ( + get_session_db, + update_session_db, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) + + +# --------------------------------------------------------------------------- +# Reprocess endpoint +# --------------------------------------------------------------------------- + +@router.post("/sessions/{session_id}/reprocess") +async def reprocess_session(session_id: str, request: Request): + """Re-run pipeline from a specific step, clearing downstream data. 
+ + Body: {"from_step": 5} (1-indexed step number) + + Pipeline order: Orientation(1) → Deskew(2) → Dewarp(3) → Crop(4) → Columns(5) → + Rows(6) → Words(7) → LLM-Review(8) → Reconstruction(9) → Validation(10) + + Clears downstream results: + - from_step <= 1: orientation_result + all downstream + - from_step <= 2: deskew_result + all downstream + - from_step <= 3: dewarp_result + all downstream + - from_step <= 4: crop_result + all downstream + - from_step <= 5: column_result, row_result, word_result + - from_step <= 6: row_result, word_result + - from_step <= 7: word_result (cells, vocab_entries) + - from_step <= 8: word_result.llm_review only + """ + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + body = await request.json() + from_step = body.get("from_step", 1) + if not isinstance(from_step, int) or from_step < 1 or from_step > 10: + raise HTTPException(status_code=400, detail="from_step must be between 1 and 10") + + update_kwargs: Dict[str, Any] = {"current_step": from_step} + + # Clear downstream data based on from_step + # New pipeline order: Orient(2) → Deskew(3) → Dewarp(4) → Crop(5) → + # Columns(6) → Rows(7) → Words(8) → LLM(9) → Recon(10) → GT(11) + if from_step <= 8: + update_kwargs["word_result"] = None + elif from_step == 9: + # Only clear LLM review from word_result + word_result = session.get("word_result") + if word_result: + word_result.pop("llm_review", None) + word_result.pop("llm_corrections", None) + update_kwargs["word_result"] = word_result + + if from_step <= 7: + update_kwargs["row_result"] = None + if from_step <= 6: + update_kwargs["column_result"] = None + if from_step <= 4: + update_kwargs["crop_result"] = None + if from_step <= 3: + update_kwargs["dewarp_result"] = None + if from_step <= 2: + update_kwargs["deskew_result"] = None + if from_step <= 1: + update_kwargs["orientation_result"] = None + + await update_session_db(session_id, 
**update_kwargs) + + # Also clear cache + if session_id in _cache: + for key in list(update_kwargs.keys()): + if key != "current_step": + _cache[session_id][key] = update_kwargs[key] + _cache[session_id]["current_step"] = from_step + + logger.info(f"Session {session_id} reprocessing from step {from_step}") + + return { + "session_id": session_id, + "from_step": from_step, + "cleared": [k for k in update_kwargs if k != "current_step"], + } + + +# --------------------------------------------------------------------------- +# VLM shear detection helper (used by dewarp step in auto-mode) +# --------------------------------------------------------------------------- + +async def _detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]: + """Ask qwen2.5vl:32b to estimate the vertical shear angle of a scanned page. + + The VLM is shown the image and asked: are the column/table borders tilted? + If yes, by how many degrees? Returns a dict with shear_degrees and confidence. + Confidence is 0.0 if Ollama is unavailable or parsing fails. + """ + import httpx + import base64 + import re + + ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434") + model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b") + + prompt = ( + "This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. " + "Are they perfectly vertical, or do they tilt slightly? " + "If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). " + "Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} " + "Use confidence 0.0-1.0 based on how clearly you can see the tilt. 
" + "If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}" + ) + + img_b64 = base64.b64encode(image_bytes).decode("utf-8") + payload = { + "model": model, + "prompt": prompt, + "images": [img_b64], + "stream": False, + } + + try: + async with httpx.AsyncClient(timeout=60.0) as client: + resp = await client.post(f"{ollama_base}/api/generate", json=payload) + resp.raise_for_status() + text = resp.json().get("response", "") + + # Parse JSON from response (may have surrounding text) + match = re.search(r'\{[^}]+\}', text) + if match: + import json + data = json.loads(match.group(0)) + shear = float(data.get("shear_degrees", 0.0)) + conf = float(data.get("confidence", 0.0)) + # Clamp to reasonable range + shear = max(-3.0, min(3.0, shear)) + conf = max(0.0, min(1.0, conf)) + return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)} + except Exception as e: + logger.warning(f"VLM dewarp failed: {e}") + + return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0} + + +# --------------------------------------------------------------------------- +# Auto-mode orchestrator +# --------------------------------------------------------------------------- + +class RunAutoRequest(BaseModel): + from_step: int = 1 # 1=deskew, 2=dewarp, 3=columns, 4=rows, 5=words, 6=llm-review + ocr_engine: str = "auto" # "auto" | "rapid" | "tesseract" + pronunciation: str = "british" + skip_llm_review: bool = False + dewarp_method: str = "ensemble" # "ensemble" | "vlm" | "cv" + + +async def _auto_sse_event(step: str, status: str, data: Dict[str, Any]) -> str: + """Format a single SSE event line.""" + import json as _json + payload = {"step": step, "status": status, **data} + return f"data: {_json.dumps(payload)}\n\n" + + +@router.post("/sessions/{session_id}/run-auto") +async def run_auto(session_id: str, req: RunAutoRequest, request: Request): + """Run the full OCR pipeline automatically from a given step, 
streaming SSE progress. + + Steps: + 1. Deskew — straighten the scan + 2. Dewarp — correct vertical shear (ensemble CV or VLM) + 3. Columns — detect column layout + 4. Rows — detect row layout + 5. Words — OCR each cell + 6. LLM review — correct OCR errors (optional) + + Already-completed steps are skipped unless `from_step` forces a rerun. + Yields SSE events of the form: + data: {"step": "deskew", "status": "start"|"done"|"skipped"|"error", ...} + + Final event: + data: {"step": "complete", "status": "done", "steps_run": [...], "steps_skipped": [...]} + """ + if req.from_step < 1 or req.from_step > 6: + raise HTTPException(status_code=400, detail="from_step must be 1-6") + if req.dewarp_method not in ("ensemble", "vlm", "cv"): + raise HTTPException(status_code=400, detail="dewarp_method must be: ensemble, vlm, cv") + + if session_id not in _cache: + await _load_session_to_cache(session_id) + + async def _generate(): + steps_run: List[str] = [] + steps_skipped: List[str] = [] + error_step: Optional[str] = None + + session = await get_session_db(session_id) + if not session: + yield await _auto_sse_event("error", "error", {"message": f"Session {session_id} not found"}) + return + + cached = _get_cached(session_id) + + # ----------------------------------------------------------------- + # Step 1: Deskew + # ----------------------------------------------------------------- + if req.from_step <= 1: + yield await _auto_sse_event("deskew", "start", {}) + try: + t0 = time.time() + orig_bgr = cached.get("original_bgr") + if orig_bgr is None: + raise ValueError("Original image not loaded") + + # Method 1: Hough lines + try: + deskewed_hough, angle_hough = deskew_image(orig_bgr.copy()) + except Exception: + deskewed_hough, angle_hough = orig_bgr, 0.0 + + # Method 2: Word alignment + success_enc, png_orig = cv2.imencode(".png", orig_bgr) + orig_bytes = png_orig.tobytes() if success_enc else b"" + try: + deskewed_wa_bytes, angle_wa = 
deskew_image_by_word_alignment(orig_bytes) + except Exception: + deskewed_wa_bytes, angle_wa = orig_bytes, 0.0 + + # Pick best method + if abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1: + method_used = "word_alignment" + angle_applied = angle_wa + wa_arr = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8) + deskewed_bgr = cv2.imdecode(wa_arr, cv2.IMREAD_COLOR) + if deskewed_bgr is None: + deskewed_bgr = deskewed_hough + method_used = "hough" + angle_applied = angle_hough + else: + method_used = "hough" + angle_applied = angle_hough + deskewed_bgr = deskewed_hough + + success, png_buf = cv2.imencode(".png", deskewed_bgr) + deskewed_png = png_buf.tobytes() if success else b"" + + deskew_result = { + "method_used": method_used, + "rotation_degrees": round(float(angle_applied), 3), + "duration_seconds": round(time.time() - t0, 2), + } + + cached["deskewed_bgr"] = deskewed_bgr + cached["deskew_result"] = deskew_result + await update_session_db( + session_id, + deskewed_png=deskewed_png, + deskew_result=deskew_result, + auto_rotation_degrees=float(angle_applied), + current_step=3, + ) + session = await get_session_db(session_id) + + steps_run.append("deskew") + yield await _auto_sse_event("deskew", "done", deskew_result) + except Exception as e: + logger.error(f"Auto-mode deskew failed for {session_id}: {e}") + error_step = "deskew" + yield await _auto_sse_event("deskew", "error", {"message": str(e)}) + yield await _auto_sse_event("complete", "error", {"error_step": error_step}) + return + else: + steps_skipped.append("deskew") + yield await _auto_sse_event("deskew", "skipped", {"reason": "from_step > 1"}) + + # ----------------------------------------------------------------- + # Step 2: Dewarp + # ----------------------------------------------------------------- + if req.from_step <= 2: + yield await _auto_sse_event("dewarp", "start", {"method": req.dewarp_method}) + try: + t0 = time.time() + deskewed_bgr = cached.get("deskewed_bgr") + if deskewed_bgr is 
None: + raise ValueError("Deskewed image not available") + + if req.dewarp_method == "vlm": + success_enc, png_buf = cv2.imencode(".png", deskewed_bgr) + img_bytes = png_buf.tobytes() if success_enc else b"" + vlm_det = await _detect_shear_with_vlm(img_bytes) + shear_deg = vlm_det["shear_degrees"] + if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3: + dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg) + else: + dewarped_bgr = deskewed_bgr + dewarp_info = { + "method": vlm_det["method"], + "shear_degrees": shear_deg, + "confidence": vlm_det["confidence"], + "detections": [vlm_det], + } + else: + dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr) + + success_enc, png_buf = cv2.imencode(".png", dewarped_bgr) + dewarped_png = png_buf.tobytes() if success_enc else b"" + + dewarp_result = { + "method_used": dewarp_info["method"], + "shear_degrees": dewarp_info["shear_degrees"], + "confidence": dewarp_info["confidence"], + "duration_seconds": round(time.time() - t0, 2), + "detections": dewarp_info.get("detections", []), + } + + cached["dewarped_bgr"] = dewarped_bgr + cached["dewarp_result"] = dewarp_result + await update_session_db( + session_id, + dewarped_png=dewarped_png, + dewarp_result=dewarp_result, + auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0), + current_step=4, + ) + session = await get_session_db(session_id) + + steps_run.append("dewarp") + yield await _auto_sse_event("dewarp", "done", dewarp_result) + except Exception as e: + logger.error(f"Auto-mode dewarp failed for {session_id}: {e}") + error_step = "dewarp" + yield await _auto_sse_event("dewarp", "error", {"message": str(e)}) + yield await _auto_sse_event("complete", "error", {"error_step": error_step}) + return + else: + steps_skipped.append("dewarp") + yield await _auto_sse_event("dewarp", "skipped", {"reason": "from_step > 2"}) + + # ----------------------------------------------------------------- + # Step 3: Columns + # 
----------------------------------------------------------------- + if req.from_step <= 3: + yield await _auto_sse_event("columns", "start", {}) + try: + t0 = time.time() + col_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") + if col_img is None: + raise ValueError("Cropped/dewarped image not available") + + ocr_img = create_ocr_image(col_img) + h, w = ocr_img.shape[:2] + + geo_result = detect_column_geometry(ocr_img, col_img) + if geo_result is None: + layout_img = create_layout_image(col_img) + regions = analyze_layout(layout_img, ocr_img) + cached["_word_dicts"] = None + cached["_inv"] = None + cached["_content_bounds"] = None + else: + geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result + content_w = right_x - left_x + cached["_word_dicts"] = word_dicts + cached["_inv"] = inv + cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y) + + header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None) + geometries = _detect_sub_columns(geometries, content_w, left_x=left_x, + top_y=top_y, header_y=header_y, footer_y=footer_y) + regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y, + left_x=left_x, right_x=right_x, inv=inv) + + columns = [asdict(r) for r in regions] + column_result = { + "columns": columns, + "classification_methods": list({c.get("classification_method", "") for c in columns if c.get("classification_method")}), + "duration_seconds": round(time.time() - t0, 2), + } + + cached["column_result"] = column_result + await update_session_db(session_id, column_result=column_result, + row_result=None, word_result=None, current_step=6) + session = await get_session_db(session_id) + + steps_run.append("columns") + yield await _auto_sse_event("columns", "done", { + "column_count": len(columns), + "duration_seconds": column_result["duration_seconds"], + }) + except Exception as e: + logger.error(f"Auto-mode columns failed 
for {session_id}: {e}") + error_step = "columns" + yield await _auto_sse_event("columns", "error", {"message": str(e)}) + yield await _auto_sse_event("complete", "error", {"error_step": error_step}) + return + else: + steps_skipped.append("columns") + yield await _auto_sse_event("columns", "skipped", {"reason": "from_step > 3"}) + + # ----------------------------------------------------------------- + # Step 4: Rows + # ----------------------------------------------------------------- + if req.from_step <= 4: + yield await _auto_sse_event("rows", "start", {}) + try: + t0 = time.time() + row_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") + session = await get_session_db(session_id) + column_result = session.get("column_result") or cached.get("column_result") + if not column_result or not column_result.get("columns"): + raise ValueError("Column detection must complete first") + + col_regions = [ + PageRegion( + type=c["type"], x=c["x"], y=c["y"], + width=c["width"], height=c["height"], + classification_confidence=c.get("classification_confidence", 1.0), + classification_method=c.get("classification_method", ""), + ) + for c in column_result["columns"] + ] + + word_dicts = cached.get("_word_dicts") + inv = cached.get("_inv") + content_bounds = cached.get("_content_bounds") + + if word_dicts is None or inv is None or content_bounds is None: + ocr_img_tmp = create_ocr_image(row_img) + geo_result = detect_column_geometry(ocr_img_tmp, row_img) + if geo_result is None: + raise ValueError("Column geometry detection failed — cannot detect rows") + _g, lx, rx, ty, by, word_dicts, inv = geo_result + cached["_word_dicts"] = word_dicts + cached["_inv"] = inv + cached["_content_bounds"] = (lx, rx, ty, by) + content_bounds = (lx, rx, ty, by) + + left_x, right_x, top_y, bottom_y = content_bounds + row_geoms = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y) + + row_list = [ + { + "index": r.index, "x": 
r.x, "y": r.y, + "width": r.width, "height": r.height, + "word_count": r.word_count, + "row_type": r.row_type, + "gap_before": r.gap_before, + } + for r in row_geoms + ] + row_result = { + "rows": row_list, + "row_count": len(row_list), + "content_rows": len([r for r in row_geoms if r.row_type == "content"]), + "duration_seconds": round(time.time() - t0, 2), + } + + cached["row_result"] = row_result + await update_session_db(session_id, row_result=row_result, current_step=7) + session = await get_session_db(session_id) + + steps_run.append("rows") + yield await _auto_sse_event("rows", "done", { + "row_count": len(row_list), + "content_rows": row_result["content_rows"], + "duration_seconds": row_result["duration_seconds"], + }) + except Exception as e: + logger.error(f"Auto-mode rows failed for {session_id}: {e}") + error_step = "rows" + yield await _auto_sse_event("rows", "error", {"message": str(e)}) + yield await _auto_sse_event("complete", "error", {"error_step": error_step}) + return + else: + steps_skipped.append("rows") + yield await _auto_sse_event("rows", "skipped", {"reason": "from_step > 4"}) + + # ----------------------------------------------------------------- + # Step 5: Words (OCR) + # ----------------------------------------------------------------- + if req.from_step <= 5: + yield await _auto_sse_event("words", "start", {"engine": req.ocr_engine}) + try: + t0 = time.time() + word_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") + session = await get_session_db(session_id) + + column_result = session.get("column_result") or cached.get("column_result") + row_result = session.get("row_result") or cached.get("row_result") + + col_regions = [ + PageRegion( + type=c["type"], x=c["x"], y=c["y"], + width=c["width"], height=c["height"], + classification_confidence=c.get("classification_confidence", 1.0), + classification_method=c.get("classification_method", ""), + ) + for c in column_result["columns"] 
+ ] + row_geoms = [ + RowGeometry( + index=r["index"], x=r["x"], y=r["y"], + width=r["width"], height=r["height"], + word_count=r.get("word_count", 0), words=[], + row_type=r.get("row_type", "content"), + gap_before=r.get("gap_before", 0), + ) + for r in row_result["rows"] + ] + + word_dicts = cached.get("_word_dicts") + if word_dicts is not None: + content_bounds = cached.get("_content_bounds") + top_y = content_bounds[2] if content_bounds else min(r.y for r in row_geoms) + for row in row_geoms: + row_y_rel = row.y - top_y + row_bottom_rel = row_y_rel + row.height + row.words = [ + w for w in word_dicts + if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel + ] + row.word_count = len(row.words) + + ocr_img = create_ocr_image(word_img) + img_h, img_w = word_img.shape[:2] + + cells, columns_meta = build_cell_grid( + ocr_img, col_regions, row_geoms, img_w, img_h, + ocr_engine=req.ocr_engine, img_bgr=word_img, + ) + duration = time.time() - t0 + + col_types = {c['type'] for c in columns_meta} + is_vocab = bool(col_types & {'column_en', 'column_de'}) + n_content_rows = len([r for r in row_geoms if r.row_type == 'content']) + used_engine = cells[0].get("ocr_engine", "tesseract") if cells else req.ocr_engine + + # Apply IPA phonetic fixes directly to cell texts + fix_cell_phonetics(cells, pronunciation=req.pronunciation) + + word_result_data = { + "cells": cells, + "grid_shape": { + "rows": n_content_rows, + "cols": len(columns_meta), + "total_cells": len(cells), + }, + "columns_used": columns_meta, + "layout": "vocab" if is_vocab else "generic", + "image_width": img_w, + "image_height": img_h, + "duration_seconds": round(duration, 2), + "ocr_engine": used_engine, + "summary": { + "total_cells": len(cells), + "non_empty_cells": sum(1 for c in cells if c.get("text")), + "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), + }, + } + + has_text_col = 'column_text' in col_types + if is_vocab or has_text_col: + entries = 
_cells_to_vocab_entries(cells, columns_meta) + entries = _fix_character_confusion(entries) + entries = _fix_phonetic_brackets(entries, pronunciation=req.pronunciation) + word_result_data["vocab_entries"] = entries + word_result_data["entries"] = entries + word_result_data["entry_count"] = len(entries) + word_result_data["summary"]["total_entries"] = len(entries) + + await update_session_db(session_id, word_result=word_result_data, current_step=8) + cached["word_result"] = word_result_data + session = await get_session_db(session_id) + + steps_run.append("words") + yield await _auto_sse_event("words", "done", { + "total_cells": len(cells), + "layout": word_result_data["layout"], + "duration_seconds": round(duration, 2), + "ocr_engine": used_engine, + "summary": word_result_data["summary"], + }) + except Exception as e: + logger.error(f"Auto-mode words failed for {session_id}: {e}") + error_step = "words" + yield await _auto_sse_event("words", "error", {"message": str(e)}) + yield await _auto_sse_event("complete", "error", {"error_step": error_step}) + return + else: + steps_skipped.append("words") + yield await _auto_sse_event("words", "skipped", {"reason": "from_step > 5"}) + + # ----------------------------------------------------------------- + # Step 6: LLM Review (optional) + # ----------------------------------------------------------------- + if req.from_step <= 6 and not req.skip_llm_review: + yield await _auto_sse_event("llm_review", "start", {"model": OLLAMA_REVIEW_MODEL}) + try: + session = await get_session_db(session_id) + word_result = session.get("word_result") or cached.get("word_result") + entries = word_result.get("entries") or word_result.get("vocab_entries") or [] + + if not entries: + yield await _auto_sse_event("llm_review", "skipped", {"reason": "no entries"}) + steps_skipped.append("llm_review") + else: + reviewed = await llm_review_entries(entries) + + session = await get_session_db(session_id) + word_result_updated = 
dict(session.get("word_result") or {}) + word_result_updated["entries"] = reviewed + word_result_updated["vocab_entries"] = reviewed + word_result_updated["llm_reviewed"] = True + word_result_updated["llm_model"] = OLLAMA_REVIEW_MODEL + + await update_session_db(session_id, word_result=word_result_updated, current_step=9) + cached["word_result"] = word_result_updated + + steps_run.append("llm_review") + yield await _auto_sse_event("llm_review", "done", { + "entries_reviewed": len(reviewed), + "model": OLLAMA_REVIEW_MODEL, + }) + except Exception as e: + logger.warning(f"Auto-mode llm_review failed for {session_id} (non-fatal): {e}") + yield await _auto_sse_event("llm_review", "error", {"message": str(e), "fatal": False}) + steps_skipped.append("llm_review") + else: + steps_skipped.append("llm_review") + reason = "skipped by request" if req.skip_llm_review else "from_step > 6" + yield await _auto_sse_event("llm_review", "skipped", {"reason": reason}) + + # ----------------------------------------------------------------- + # Final event + # ----------------------------------------------------------------- + yield await _auto_sse_event("complete", "done", { + "steps_run": steps_run, + "steps_skipped": steps_skipped, + }) + + return StreamingResponse( + _generate(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) diff --git a/klausur-service/backend/ocr_pipeline_common.py b/klausur-service/backend/ocr_pipeline_common.py new file mode 100644 index 0000000..3a13cff --- /dev/null +++ b/klausur-service/backend/ocr_pipeline_common.py @@ -0,0 +1,354 @@ +""" +Shared common module for the OCR pipeline. + +Contains in-memory cache, helper functions, Pydantic request models, +pipeline logging, and border-ghost word filtering used by the pipeline +API endpoints and related modules. 
+""" + +import logging +import re +import time +from datetime import datetime +from typing import Any, Dict, List, Optional + +import cv2 +import numpy as np +from fastapi import HTTPException +from pydantic import BaseModel + +from ocr_pipeline_session_store import get_session_db, get_session_image, update_session_db + +__all__ = [ + # Cache + "_cache", + # Helper functions + "_get_base_image_png", + "_load_session_to_cache", + "_get_cached", + # Pydantic models + "ManualDeskewRequest", + "DeskewGroundTruthRequest", + "ManualDewarpRequest", + "CombinedAdjustRequest", + "DewarpGroundTruthRequest", + "VALID_DOCUMENT_CATEGORIES", + "UpdateSessionRequest", + "ManualColumnsRequest", + "ColumnGroundTruthRequest", + "ManualRowsRequest", + "RowGroundTruthRequest", + "RemoveHandwritingRequest", + # Pipeline log + "_append_pipeline_log", + # Border-ghost filter + "_BORDER_GHOST_CHARS", + "_filter_border_ghost_words", +] + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# In-memory cache for active sessions (BGR numpy arrays for processing) +# DB is source of truth, cache holds BGR arrays during active processing. 
# Per-process cache of active sessions: session_id -> session fields plus
# decoded BGR arrays. The database remains the source of truth.
_cache: Dict[str, Dict[str, Any]] = {}


async def _get_base_image_png(session_id: str) -> Optional[bytes]:
    """Return the PNG bytes of the best stored image for a session.

    Image types are tried in the order cropped, dewarped, original and the
    first non-empty result wins. Returns ``None`` when none is stored.
    """
    preference = ("cropped", "dewarped", "original")
    for kind in preference:
        blob = await get_session_image(session_id, kind)
        if blob:
            return blob
    return None


async def _load_session_to_cache(session_id: str) -> Dict[str, Any]:
    """Ensure a session is present in ``_cache`` and return its cache entry.

    The session row is always fetched first, so an unknown session raises
    ``HTTPException(404)`` even if a stale cache entry exists. An existing
    cache entry is returned untouched; otherwise a new entry is built from
    the row, the stored PNGs are decoded to BGR arrays, and the entry is
    stored in ``_cache``.
    """
    row = await get_session_db(session_id)
    if not row:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    try:
        return _cache[session_id]
    except KeyError:
        pass

    image_kinds = ("original", "oriented", "cropped", "deskewed", "dewarped")

    entry: Dict[str, Any] = {"id": session_id, **row}
    for kind in image_kinds:
        entry[f"{kind}_bgr"] = None

    # Decode every stored PNG into a BGR array (slot stays None if absent).
    for kind in image_kinds:
        blob = await get_session_image(session_id, kind)
        if not blob:
            continue
        raw = np.frombuffer(blob, dtype=np.uint8)
        entry[f"{kind}_bgr"] = cv2.imdecode(raw, cv2.IMREAD_COLOR)

    # Sub-sessions: the original image IS the cropped box region, so promote
    # it to cropped_bgr when neither a cropped nor a dewarped image exists.
    is_sub_session = bool(row.get("parent_session_id"))
    nothing_downstream = entry["cropped_bgr"] is None and entry["dewarped_bgr"] is None
    if is_sub_session and entry["original_bgr"] is not None and nothing_downstream:
        entry["cropped_bgr"] = entry["original_bgr"]

    _cache[session_id] = entry
    return entry
def _get_cached(session_id: str) -> Dict[str, Any]:
    """Return the cache entry for *session_id*.

    Raises ``HTTPException(404)`` when the session has no (non-empty) entry
    in ``_cache``; callers are expected to load the session first.
    """
    cached = _cache.get(session_id)
    if cached:
        return cached
    raise HTTPException(status_code=404, detail=f"Session {session_id} not in cache — reload first")


# ---------------------------------------------------------------------------
# Pydantic Models
# ---------------------------------------------------------------------------

# Request body: manual rotation angle for the deskew step.
class ManualDeskewRequest(BaseModel):
    angle: float


# Request body: reviewer feedback on the deskew result.
class DeskewGroundTruthRequest(BaseModel):
    is_correct: bool
    corrected_angle: Optional[float] = None
    notes: Optional[str] = None


# Request body: manual shear angle for the dewarp step.
class ManualDewarpRequest(BaseModel):
    shear_degrees: float


# Request body: rotation and shear supplied together; both default to 0.0.
class CombinedAdjustRequest(BaseModel):
    rotation_degrees: float = 0.0
    shear_degrees: float = 0.0


# Request body: reviewer feedback on the dewarp result.
class DewarpGroundTruthRequest(BaseModel):
    is_correct: bool
    corrected_shear: Optional[float] = None
    notes: Optional[str] = None


# Closed set of accepted document category identifiers.
VALID_DOCUMENT_CATEGORIES = {
    'vokabelseite', 'buchseite', 'arbeitsblatt', 'klausurseite',
    'mathearbeit', 'statistik', 'zeitung', 'formular', 'handschrift', 'sonstiges',
}


# Request body: partial session update; omitted fields stay None.
class UpdateSessionRequest(BaseModel):
    name: Optional[str] = None
    document_category: Optional[str] = None


# Request body: manually specified column definitions.
class ManualColumnsRequest(BaseModel):
    columns: List[Dict[str, Any]]


# Request body: reviewer feedback on column detection.
class ColumnGroundTruthRequest(BaseModel):
    is_correct: bool
    corrected_columns: Optional[List[Dict[str, Any]]] = None
    notes: Optional[str] = None


# Request body: manually specified row definitions.
class ManualRowsRequest(BaseModel):
    rows: List[Dict[str, Any]]


# Request body: reviewer feedback on row detection.
class RowGroundTruthRequest(BaseModel):
    is_correct: bool
    corrected_rows: Optional[List[Dict[str, Any]]] = None
    notes: Optional[str] = None


# Request body: options for handwriting removal.
class RemoveHandwritingRequest(BaseModel):
    method: str = "auto"          # "auto" | "telea" | "ns"
    target_ink: str = "all"       # "all" | "colored" | "pencil"
    dilation: int = 2             # mask dilation iterations (0-5)
    use_source: str = "auto"      # "original" | "deskewed" | "auto"
# ---------------------------------------------------------------------------
# Pipeline Log Helper
# ---------------------------------------------------------------------------

async def _append_pipeline_log(
    session_id: str,
    step_name: str,
    metrics: Dict[str, Any],
    success: bool = True,
    duration_ms: Optional[int] = None,
):
    """Append a step entry to the session's ``pipeline_log``.

    Does nothing when the session cannot be loaded. A missing or non-dict
    log is replaced by ``{"steps": []}`` before the entry is appended.

    Args:
        session_id: Session whose log is extended.
        step_name: Stored under the ``"step"`` key.
        metrics: Stored as-is under the ``"metrics"`` key.
        success: Stored under the ``"success"`` key.
        duration_ms: Stored under ``"duration_ms"`` only when not ``None``.
    """
    session = await get_session_db(session_id)
    if not session:
        return
    log = session.get("pipeline_log") or {"steps": []}
    if not isinstance(log, dict):
        log = {"steps": []}
    entry = {
        "step": step_name,
        # Naive UTC timestamp; kept as-is so the stored format does not change.
        "completed_at": datetime.utcnow().isoformat(),
        "success": success,
        "metrics": metrics,
    }
    if duration_ms is not None:
        entry["duration_ms"] = duration_ms
    log.setdefault("steps", []).append(entry)
    await update_session_db(session_id, pipeline_log=log)


# ---------------------------------------------------------------------------
# Border-ghost word filter
# ---------------------------------------------------------------------------

# Characters that OCR produces when reading box-border lines.
_BORDER_GHOST_CHARS = set("|1lI![](){}iíì/\\-—–_~.,;:'\"")


def _filter_border_ghost_words(
    word_result: Dict,
    boxes: List,
) -> int:
    """Remove OCR cells that are actually box border lines.

    A cell is a border ghost when its centre lies inside the margin band of
    a box edge (left, right, top or bottom) and it looks like a line
    artefact: either its text consists only of ``_BORDER_GHOST_CHARS``, or
    the text is at most two characters long and the bounding box is strongly
    elongated (aspect ratio below 0.5 in either direction).

    After removal, columns that no longer hold any cell are dropped from
    ``columns_used`` (when it has more than one column) and the remaining
    columns and cells are re-indexed. ``summary`` and ``grid_shape`` counts
    are refreshed, including ``summary["low_confidence"]`` when that key is
    present.

    Args:
        word_result: Word-detection result dict; modified in place.
        boxes: Detected boxes, either objects with ``x``/``y``/``width``/
            ``height``/``border_thickness`` attributes or dicts with the
            keys ``x``, ``y``, ``w`` (or ``width``), ``h`` (or ``height``)
            and optionally ``border_thickness`` (default 3).

    Returns:
        The number of removed cells (0 when there is nothing to do).
    """
    if not boxes or not word_result:
        return 0

    cells = word_result.get("cells")
    if not cells:
        return 0

    # Build border bands — vertical (X) and horizontal (Y)
    x_bands = []  # list of (x_lo, x_hi)
    y_bands = []  # list of (y_lo, y_hi)
    for b in boxes:
        bx = b.x if hasattr(b, "x") else b.get("x", 0)
        by = b.y if hasattr(b, "y") else b.get("y", 0)
        bw = b.width if hasattr(b, "width") else b.get("w", b.get("width", 0))
        bh = b.height if hasattr(b, "height") else b.get("h", b.get("height", 0))
        bt = b.border_thickness if hasattr(b, "border_thickness") else b.get("border_thickness", 3)
        margin = max(bt * 2, 10) + 6  # generous margin

        # Vertical edges (left / right)
        x_bands.append((bx - margin, bx + margin))
        x_bands.append((bx + bw - margin, bx + bw + margin))
        # Horizontal edges (top / bottom)
        y_bands.append((by - margin, by + margin))
        y_bands.append((by + bh - margin, by + bh + margin))

    img_w = word_result.get("image_width", 1)
    img_h = word_result.get("image_height", 1)

    def _is_ghost(cell: Dict) -> bool:
        text = (cell.get("text") or "").strip()
        if not text:
            return False

        # Compute absolute pixel position (bbox_px wins over bbox_pct)
        if cell.get("bbox_px"):
            px = cell["bbox_px"]
            cw = px["w"]
            ch = px["h"]
            cx = px["x"] + cw / 2
            cy = px["y"] + ch / 2
        elif cell.get("bbox_pct"):
            pct = cell["bbox_pct"]
            cw = (pct["w"] / 100) * img_w
            ch = (pct["h"] / 100) * img_h
            cx = (pct["x"] / 100) * img_w + (pct["w"] / 100) * img_w / 2
            cy = (pct["y"] / 100) * img_h + (pct["h"] / 100) * img_h / 2
        else:
            return False

        # Only cells whose centre sits on a box edge can be ghosts
        on_vertical = any(lo <= cx <= hi for lo, hi in x_bands)
        on_horizontal = any(lo <= cy <= hi for lo, hi in y_bands)
        if not on_vertical and not on_horizontal:
            return False

        # Text made only of line-like characters on a border → ghost,
        # regardless of length or aspect ratio
        if all(c in _BORDER_GHOST_CHARS for c in text):
            return True

        # Very short text (1-2 chars) on a border is a ghost when the box is
        # narrow vertically (line-like) or narrow horizontally (dash-like)
        if len(text) <= 2:
            if ch > 0 and cw / ch < 0.5:
                return True
            if cw > 0 and ch / cw < 0.5:
                return True

        return False

    before = len(cells)
    word_result["cells"] = [c for c in cells if not _is_ghost(c)]
    removed = before - len(word_result["cells"])

    # --- Remove empty columns from columns_used ---
    columns_used = word_result.get("columns_used")
    if removed and columns_used and len(columns_used) > 1:
        remaining_cells = word_result["cells"]
        occupied_cols = {c.get("col_index") for c in remaining_cells}
        before_cols = len(columns_used)
        columns_used = [col for col in columns_used if col.get("index") in occupied_cols]

        # Re-index columns and remap cell col_index values
        if len(columns_used) < before_cols:
            old_to_new = {}
            for new_i, col in enumerate(columns_used):
                old_to_new[col["index"]] = new_i
                col["index"] = new_i
            for cell in remaining_cells:
                old_ci = cell.get("col_index")
                if old_ci in old_to_new:
                    cell["col_index"] = old_to_new[old_ci]
            word_result["columns_used"] = columns_used
            logger.info("border-ghost: removed %d empty column(s), %d remaining",
                        before_cols - len(columns_used), len(columns_used))

    if removed:
        kept = word_result["cells"]
        # Update summary counts
        summary = word_result.get("summary", {})
        summary["total_cells"] = len(kept)
        summary["non_empty_cells"] = sum(1 for c in kept if c.get("text"))
        # Keep the low-confidence counter in sync with the removed cells; it
        # is only refreshed when the producer put it there in the first place.
        if "low_confidence" in summary:
            summary["low_confidence"] = sum(
                1 for c in kept if 0 < (c.get("confidence") or 0) < 50
            )
        word_result["summary"] = summary
        gs = word_result.get("grid_shape", {})
        gs["total_cells"] = len(kept)
        if columns_used is not None:
            gs["cols"] = len(columns_used)
        word_result["grid_shape"] = gs

    return removed
@router.post("/sessions/{session_id}/deskew")
async def auto_deskew(session_id: str):
    """Run the automatic deskew step for a session and persist the result.

    The oriented image is used as input, falling back to the original image;
    a 400 is raised when neither is cached. ``deskew_two_pass`` provides the
    image and angle that are actually applied. The Hough and word-alignment
    angles are computed in addition for reporting only (0.0 when the method
    raises). The deskewed PNG, an optional binarized PNG and the result dict
    are written to the cache and the DB, and a pipeline-log entry is added.

    Returns the result dict together with the session id and the image URLs.
    """
    # Ensure session is in cache
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    # Deskew runs right after orientation — prefer oriented, else original
    source_bgr = cached.get("oriented_bgr")
    if source_bgr is None:
        source_bgr = cached.get("original_bgr")
    if source_bgr is None:
        raise HTTPException(status_code=400, detail="No image available for deskewing")

    t0 = time.time()

    # Authoritative correction
    deskewed_bgr, angle_applied, two_pass_debug = deskew_two_pass(source_bgr.copy())

    # Individual methods, for reporting only (non-authoritative)
    try:
        _, angle_hough = deskew_image(source_bgr.copy())
    except Exception:
        angle_hough = 0.0

    ok_orig, orig_buf = cv2.imencode(".png", source_bgr)
    orig_bytes = orig_buf.tobytes() if ok_orig else b""
    try:
        _, angle_wa = deskew_image_by_word_alignment(orig_bytes)
    except Exception:
        angle_wa = 0.0

    angle_iterative = two_pass_debug.get("pass1_angle", 0.0)
    angle_residual = two_pass_debug.get("pass2_angle", 0.0)
    angle_textline = two_pass_debug.get("pass3_angle", 0.0)

    duration = time.time() - t0

    # Name the method after the last pass that contributed a visible angle
    if abs(angle_textline) >= 0.01:
        method_used = "three_pass"
    elif abs(angle_residual) >= 0.01:
        method_used = "two_pass"
    else:
        method_used = "iterative"

    # Encode as PNG
    ok_png, deskewed_buf = cv2.imencode(".png", deskewed_bgr)
    deskewed_png = deskewed_buf.tobytes() if ok_png else b""

    # Create binarized version (best effort)
    binarized_png = None
    try:
        binarized = create_ocr_image(deskewed_bgr)
        ok_bin, bin_buf = cv2.imencode(".png", binarized)
        binarized_png = bin_buf.tobytes() if ok_bin else None
    except Exception as e:
        logger.warning(f"Binarization failed: {e}")

    confidence = max(0.5, 1.0 - abs(angle_applied) / 5.0)

    applied_r = round(angle_applied, 3)
    iterative_r = round(angle_iterative, 3)
    residual_r = round(angle_residual, 3)
    textline_r = round(angle_textline, 3)
    confidence_r = round(confidence, 2)

    deskew_result = {
        "angle_hough": round(angle_hough, 3),
        "angle_word_alignment": round(angle_wa, 3),
        "angle_iterative": iterative_r,
        "angle_residual": residual_r,
        "angle_textline": textline_r,
        "angle_applied": applied_r,
        "method_used": method_used,
        "confidence": confidence_r,
        "duration_seconds": round(duration, 2),
        "two_pass_debug": two_pass_debug,
    }

    # Update cache
    cached["deskewed_bgr"] = deskewed_bgr
    cached["binarized_png"] = binarized_png
    cached["deskew_result"] = deskew_result

    # Persist to DB (binarized image only when it was produced)
    db_update = {
        "deskewed_png": deskewed_png,
        "deskew_result": deskew_result,
        "current_step": 3,
    }
    if binarized_png:
        db_update["binarized_png"] = binarized_png
    await update_session_db(session_id, **db_update)

    logger.info(f"OCR Pipeline: deskew session {session_id}: "
                f"hough={angle_hough:.2f} wa={angle_wa:.2f} "
                f"iter={angle_iterative:.2f} residual={angle_residual:.2f} "
                f"textline={angle_textline:.2f} "
                f"-> {method_used} total={angle_applied:.2f}")

    log_metrics = {
        "angle_applied": applied_r,
        "angle_iterative": iterative_r,
        "angle_residual": residual_r,
        "angle_textline": textline_r,
        "confidence": confidence_r,
        "method": method_used,
    }
    await _append_pipeline_log(session_id, "deskew", log_metrics,
                               duration_ms=int(duration * 1000))

    image_base = f"/api/v1/ocr-pipeline/sessions/{session_id}/image"
    return {
        "session_id": session_id,
        **deskew_result,
        "deskewed_image_url": f"{image_base}/deskewed",
        "binarized_image_url": f"{image_base}/binarized",
    }
@router.post("/sessions/{session_id}/deskew/manual")
async def manual_deskew(session_id: str, req: ManualDeskewRequest):
    """Apply a manual rotation angle to the oriented image.

    The oriented image is used as input, falling back to the original image;
    a 400 is raised when neither is cached. The requested angle is clamped
    to [-5.0, 5.0] degrees and applied around the image centre with
    replicated borders. The rotated PNG, an optional binarized PNG and the
    updated ``deskew_result`` are written to the cache and the DB.

    Returns the session id, the applied angle, the method name ``"manual"``
    and the URL of the deskewed image.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    img_bgr = next((v for k in ("oriented_bgr", "original_bgr")
                    if (v := cached.get(k)) is not None), None)
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="No image available for deskewing")

    # Clamp to the same ±5° range the automatic step works in
    angle = max(-5.0, min(5.0, req.angle))

    h, w = img_bgr.shape[:2]
    center = (w // 2, h // 2)
    M = cv2.getRotationMatrix2D(center, angle, 1.0)
    rotated = cv2.warpAffine(img_bgr, M, (w, h),
                             flags=cv2.INTER_LINEAR,
                             borderMode=cv2.BORDER_REPLICATE)

    success, png_buf = cv2.imencode(".png", rotated)
    deskewed_png = png_buf.tobytes() if success else b""

    # Binarize (best effort). Failures are logged like in auto_deskew instead
    # of being swallowed silently, so a missing binarized image is traceable.
    binarized_png = None
    try:
        binarized = create_ocr_image(rotated)
        success_bin, bin_buf = cv2.imencode(".png", binarized)
        binarized_png = bin_buf.tobytes() if success_bin else None
    except Exception as e:
        logger.warning(f"Binarization failed: {e}")

    # Keep the reporting fields of a previous automatic run, override the rest
    deskew_result = {
        **(cached.get("deskew_result") or {}),
        "angle_applied": round(angle, 3),
        "method_used": "manual",
    }

    # Update cache
    cached["deskewed_bgr"] = rotated
    cached["binarized_png"] = binarized_png
    cached["deskew_result"] = deskew_result

    # Persist to DB
    db_update = {
        "deskewed_png": deskewed_png,
        "deskew_result": deskew_result,
    }
    if binarized_png:
        db_update["binarized_png"] = binarized_png
    await update_session_db(session_id, **db_update)

    logger.info(f"OCR Pipeline: manual deskew session {session_id}: {angle:.2f}")

    return {
        "session_id": session_id,
        "angle_applied": round(angle, 3),
        "method_used": "manual",
        "deskewed_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed",
    }


@router.post("/sessions/{session_id}/ground-truth/deskew")
async def save_deskew_ground_truth(session_id: str, req: DeskewGroundTruthRequest):
    """Save ground truth feedback for the deskew step.

    Raises a 404 when the session does not exist. The feedback is stored
    under ``ground_truth["deskew"]`` together with a timestamp and a copy of
    the session's current ``deskew_result``; the cached session, if any, is
    updated as well.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    gt = {
        "is_correct": req.is_correct,
        "corrected_angle": req.corrected_angle,
        "notes": req.notes,
        "saved_at": datetime.utcnow().isoformat(),
        "deskew_result": session.get("deskew_result"),
    }
    ground_truth["deskew"] = gt

    await update_session_db(session_id, ground_truth=ground_truth)

    # Update cache
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    logger.info(f"OCR Pipeline: ground truth deskew session {session_id}: "
                f"correct={req.is_correct}, corrected_angle={req.corrected_angle}")

    return {"session_id": session_id, "ground_truth": gt}
# ---------------------------------------------------------------------------
# Dewarp Endpoints
# ---------------------------------------------------------------------------

def _parse_vlm_shear_reply(text: str) -> Optional[Dict[str, Any]]:
    """Extract the shear estimate from a VLM text reply.

    The first ``{...}`` span in *text* is parsed as JSON. Shear is clamped
    to [-3.0, 3.0] degrees and confidence to [0.0, 1.0]. Returns ``None``
    when the reply contains no JSON object or when either value is NaN or
    infinite; malformed JSON or non-numeric values raise and are handled by
    the caller.
    """
    import json
    import math
    import re

    match = re.search(r'\{[^}]+\}', text)
    if not match:
        return None
    data = json.loads(match.group(0))
    shear = float(data.get("shear_degrees", 0.0))
    conf = float(data.get("confidence", 0.0))
    # json.loads accepts NaN/Infinity, and max/min clamping would silently
    # turn NaN into the upper bound — treat such replies as unusable.
    if not (math.isfinite(shear) and math.isfinite(conf)):
        return None
    # Clamp to reasonable range
    shear = max(-3.0, min(3.0, shear))
    conf = max(0.0, min(1.0, conf))
    return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)}


async def _detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]:
    """Ask a vision-language model to estimate the vertical shear of a page.

    The image is sent base64-encoded to the Ollama ``/api/generate``
    endpoint (base URL from ``OLLAMA_BASE_URL``, model from
    ``OLLAMA_HTR_MODEL``, default ``qwen2.5vl:32b``) with a prompt asking
    for a JSON object containing ``shear_degrees`` and ``confidence``.

    Returns a dict with ``method``, ``shear_degrees`` and ``confidence``.
    Shear is 0.0 and confidence 0.0 whenever the request fails, the HTTP
    client is not installed, or the reply cannot be parsed; failures are
    logged as warnings and never raised.
    """
    import base64

    ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
    model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")

    prompt = (
        "This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. "
        "Are they perfectly vertical, or do they tilt slightly? "
        "If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). "
        "Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} "
        "Use confidence 0.0-1.0 based on how clearly you can see the tilt. "
        "If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}"
    )

    img_b64 = base64.b64encode(image_bytes).decode("utf-8")
    payload = {
        "model": model,
        "prompt": prompt,
        "images": [img_b64],
        "stream": False,
    }

    try:
        # Imported inside the try so a missing client degrades to the
        # documented zero-confidence fallback instead of raising.
        import httpx

        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.post(f"{ollama_base}/api/generate", json=payload)
            resp.raise_for_status()
            text = resp.json().get("response", "")

        # Parse JSON from response (may have surrounding text)
        parsed = _parse_vlm_shear_reply(text)
        if parsed is not None:
            return parsed
    except Exception as e:
        logger.warning(f"VLM dewarp failed: {e}")

    return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0}
+ + Methods: + - **ensemble** (default): 3-method CV ensemble (vertical edges + projection + Hough) + - **cv**: CV ensemble only (same as ensemble) + - **vlm**: Ask qwen2.5vl:32b to estimate the shear angle visually + """ + if method not in ("ensemble", "cv", "vlm"): + raise HTTPException(status_code=400, detail="method must be one of: ensemble, cv, vlm") + + if session_id not in _cache: + await _load_session_to_cache(session_id) + cached = _get_cached(session_id) + + deskewed_bgr = cached.get("deskewed_bgr") + if deskewed_bgr is None: + raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp") + + t0 = time.time() + + if method == "vlm": + # Encode deskewed image to PNG for VLM + success, png_buf = cv2.imencode(".png", deskewed_bgr) + img_bytes = png_buf.tobytes() if success else b"" + vlm_det = await _detect_shear_with_vlm(img_bytes) + shear_deg = vlm_det["shear_degrees"] + if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3: + dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg) + else: + dewarped_bgr = deskewed_bgr + dewarp_info = { + "method": vlm_det["method"], + "shear_degrees": shear_deg, + "confidence": vlm_det["confidence"], + "detections": [vlm_det], + } + else: + dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr) + + duration = time.time() - t0 + + # Encode as PNG + success, png_buf = cv2.imencode(".png", dewarped_bgr) + dewarped_png = png_buf.tobytes() if success else b"" + + dewarp_result = { + "method_used": dewarp_info["method"], + "shear_degrees": dewarp_info["shear_degrees"], + "confidence": dewarp_info["confidence"], + "duration_seconds": round(duration, 2), + "detections": dewarp_info.get("detections", []), + } + + # Update cache + cached["dewarped_bgr"] = dewarped_bgr + cached["dewarp_result"] = dewarp_result + + # Persist to DB + await update_session_db( + session_id, + dewarped_png=dewarped_png, + dewarp_result=dewarp_result, + auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0), + current_step=4, 
@router.post("/sessions/{session_id}/dewarp/manual")
async def manual_dewarp(session_id: str, req: ManualDewarpRequest):
    """Apply a user-chosen shear correction to the deskewed image.

    The requested shear is clamped to [-2.0, 2.0]. Values with a magnitude
    below 0.001 leave the deskewed image untouched. The result is written to
    the cache and persisted with an updated ``dewarp_result``.

    Raises:
        HTTPException: 400 when no deskewed image is cached for the session.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    entry = _get_cached(session_id)

    source_bgr = entry.get("deskewed_bgr")
    if source_bgr is None:
        raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")

    shear_deg = max(-2.0, min(2.0, req.shear_degrees))

    # A near-zero shear is a no-op; reuse the deskewed image as-is.
    dewarped_bgr = (
        source_bgr
        if abs(shear_deg) < 0.001
        else dewarp_image_manual(source_bgr, shear_deg)
    )

    encoded_ok, encoded = cv2.imencode(".png", dewarped_bgr)
    dewarped_png = encoded.tobytes() if encoded_ok else b""

    # Start from the previous result so fields of an automatic run survive.
    dewarp_result = dict(entry.get("dewarp_result") or {})
    dewarp_result["method_used"] = "manual"
    dewarp_result["shear_degrees"] = round(shear_deg, 3)

    # Cache first, then database.
    entry["dewarped_bgr"] = dewarped_bgr
    entry["dewarp_result"] = dewarp_result
    await update_session_db(
        session_id,
        dewarped_png=dewarped_png,
        dewarp_result=dewarp_result,
    )

    logger.info(f"OCR Pipeline: manual dewarp session {session_id}: shear={shear_deg:.3f}")

    response = {"session_id": session_id}
    response["shear_degrees"] = round(shear_deg, 3)
    response["method_used"] = "manual"
    response["dewarped_image_url"] = f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped"
    return response
@router.post("/sessions/{session_id}/ground-truth/dewarp")
async def save_dewarp_ground_truth(session_id: str, req: DewarpGroundTruthRequest):
    """Store reviewer feedback for the dewarp step of a session.

    The feedback is saved under the ``"dewarp"`` key of the session's
    ``ground_truth`` together with a timestamp and the ``dewarp_result``
    currently stored for the session.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    all_ground_truth = session.get("ground_truth") or {}
    feedback = {
        "is_correct": req.is_correct,
        "corrected_shear": req.corrected_shear,
        "notes": req.notes,
        "saved_at": datetime.utcnow().isoformat(),
        "dewarp_result": session.get("dewarp_result"),
    }
    all_ground_truth["dewarp"] = feedback

    await update_session_db(session_id, ground_truth=all_ground_truth)

    # Mirror the change into the in-memory copy, if the session is cached.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = all_ground_truth

    summary = (f"OCR Pipeline: ground truth dewarp session {session_id}: "
               f"correct={req.is_correct}, corrected_shear={req.corrected_shear}")
    logger.info(summary)

    return {"session_id": session_id, "ground_truth": feedback}
"""Detect document structure: boxes, zones, and color regions. + + Runs box detection (line + shading) and color analysis on the cropped + image. Returns structured JSON with all detected elements for the + structure visualization step. + """ + if session_id not in _cache: + await _load_session_to_cache(session_id) + cached = _get_cached(session_id) + + img_bgr = ( + cached.get("cropped_bgr") + if cached.get("cropped_bgr") is not None + else cached.get("dewarped_bgr") + ) + if img_bgr is None: + raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first") + + t0 = time.time() + h, w = img_bgr.shape[:2] + + # --- Content bounds from word result (if available) or full image --- + word_result = cached.get("word_result") + words: List[Dict] = [] + if word_result and word_result.get("cells"): + for cell in word_result["cells"]: + for wb in (cell.get("word_boxes") or []): + words.append(wb) + # Fallback: use raw OCR words if cell word_boxes are empty + if not words and word_result: + for key in ("raw_paddle_words_split", "raw_tesseract_words", "raw_paddle_words"): + raw = word_result.get(key, []) + if raw: + words = raw + logger.info("detect-structure: using %d words from %s (no cell word_boxes)", len(words), key) + break + # If no words yet, use image dimensions with small margin + if words: + content_x = max(0, min(int(wb["left"]) for wb in words)) + content_y = max(0, min(int(wb["top"]) for wb in words)) + content_r = min(w, max(int(wb["left"] + wb["width"]) for wb in words)) + content_b = min(h, max(int(wb["top"] + wb["height"]) for wb in words)) + content_w_px = content_r - content_x + content_h_px = content_b - content_y + else: + margin = int(min(w, h) * 0.03) + content_x, content_y = margin, margin + content_w_px = w - 2 * margin + content_h_px = h - 2 * margin + + # --- Box detection --- + boxes = detect_boxes( + img_bgr, + content_x=content_x, + content_w=content_w_px, + content_y=content_y, + content_h=content_h_px, + ) + + # --- Zone 
splitting --- + from cv_box_detect import split_page_into_zones as _split_zones + zones = _split_zones(content_x, content_y, content_w_px, content_h_px, boxes) + + # --- Color region sampling --- + # Sample background shading in each detected box + hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV) + box_colors = [] + for box in boxes: + # Sample the center region of each box + cy1 = box.y + box.height // 4 + cy2 = box.y + 3 * box.height // 4 + cx1 = box.x + box.width // 4 + cx2 = box.x + 3 * box.width // 4 + cy1 = max(0, min(cy1, h - 1)) + cy2 = max(0, min(cy2, h - 1)) + cx1 = max(0, min(cx1, w - 1)) + cx2 = max(0, min(cx2, w - 1)) + if cy2 > cy1 and cx2 > cx1: + roi_hsv = hsv[cy1:cy2, cx1:cx2] + med_h = float(np.median(roi_hsv[:, :, 0])) + med_s = float(np.median(roi_hsv[:, :, 1])) + med_v = float(np.median(roi_hsv[:, :, 2])) + if med_s > 15: + from cv_color_detect import _hue_to_color_name + bg_name = _hue_to_color_name(med_h) + bg_hex = _COLOR_HEX.get(bg_name, "#6b7280") + else: + bg_name = "gray" if med_v < 220 else "white" + bg_hex = "#6b7280" if bg_name == "gray" else "#ffffff" + else: + bg_name = "unknown" + bg_hex = "#6b7280" + box_colors.append({"color_name": bg_name, "color_hex": bg_hex}) + + # --- Color text detection overview --- + # Quick scan for colored text regions across the page + color_summary: Dict[str, int] = {} + for color_name, ranges in _COLOR_RANGES.items(): + mask = np.zeros((h, w), dtype=np.uint8) + for lower, upper in ranges: + mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper)) + pixel_count = int(np.sum(mask > 0)) + if pixel_count > 50: # minimum threshold + color_summary[color_name] = pixel_count + + # --- Graphic element detection --- + box_dicts = [ + {"x": b.x, "y": b.y, "w": b.width, "h": b.height} + for b in boxes + ] + graphics = detect_graphic_elements( + img_bgr, words, + detected_boxes=box_dicts, + ) + + # --- Filter border-ghost words from OCR result --- + ghost_count = 0 + if boxes and word_result: + ghost_count = 
@router.post("/sessions/{session_id}/columns")
async def detect_columns(session_id: str):
    """Run column detection on the cropped (or dewarped) image.

    Two paths exist:

    * Sub-sessions (the stored session has a ``parent_session_id``): no
      column detection is run. A single pseudo-column covering the whole
      image is returned, and Tesseract words with confidence >= 30 plus the
      inverted binarized image are cached for the row detection step.
    * Regular sessions: zone-aware geometry detection is tried first; when
      it returns ``None`` the projection-based layout analysis is used
      instead. Detected geometries are split into sub-columns, widened and
      classified.

    In both paths the new ``column_result`` is persisted, ``row_result`` and
    ``word_result`` are reset (DB and cache) and ``current_step`` is set to 6.

    Returns:
        Dict with ``session_id`` plus the fields of ``column_result``.

    Raises:
        HTTPException: 400 when neither a cropped nor a dewarped image is
            cached for the session.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    # Prefer the cropped image; fall back to the dewarped one.
    img_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before column detection")

    # -----------------------------------------------------------------------
    # Sub-sessions (box crops): skip column detection entirely.
    # Instead, create a single pseudo-column spanning the full image width.
    # Also run Tesseract + binarization here so that the row detection step
    # can reuse the cached intermediates (_word_dicts, _inv, _content_bounds)
    # instead of falling back to detect_column_geometry() which may fail
    # on small box images with < 5 words.
    # -----------------------------------------------------------------------
    session = await get_session_db(session_id)
    if session and session.get("parent_session_id"):
        h, w = img_bgr.shape[:2]

        # Binarize + invert for row detection (horizontal projection profile)
        ocr_img = create_ocr_image(img_bgr)
        inv = cv2.bitwise_not(ocr_img)

        # Run Tesseract to get word bounding boxes.
        # Word positions are relative to the full image (no ROI crop needed
        # because the sub-session image IS the cropped box already).
        # detect_row_geometry expects word positions relative to content ROI,
        # so with content_bounds = (0, w, 0, h) the coordinates are correct.
        try:
            from PIL import Image as PILImage
            pil_img = PILImage.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
            import pytesseract
            data = pytesseract.image_to_data(pil_img, lang='eng+deu', output_type=pytesseract.Output.DICT)
            word_dicts = []
            for i in range(len(data['text'])):
                # A confidence that is not a plain (optionally negative)
                # integer string is treated as -1 and therefore dropped.
                conf = int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1
                text = str(data['text'][i]).strip()
                if conf < 30 or not text:
                    continue
                word_dicts.append({
                    'text': text, 'conf': conf,
                    'left': int(data['left'][i]),
                    'top': int(data['top'][i]),
                    'width': int(data['width'][i]),
                    'height': int(data['height'][i]),
                })
            # Log all words including low-confidence ones for debugging
            all_count = sum(1 for i in range(len(data['text']))
                            if str(data['text'][i]).strip())
            # Non-empty words with 0 <= conf < 30, i.e. the ones filtered
            # out above (entries parsed as -1 are excluded from this list).
            low_conf = [(str(data['text'][i]).strip(), int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1)
                        for i in range(len(data['text']))
                        if str(data['text'][i]).strip()
                        and (int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1) < 30
                        and (int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1) >= 0]
            if low_conf:
                logger.info(f"OCR Pipeline: sub-session {session_id}: {len(low_conf)} words below conf 30: {low_conf[:20]}")
            logger.info(f"OCR Pipeline: sub-session {session_id}: Tesseract found {len(word_dicts)}/{all_count} words (conf>=30)")
        except Exception as e:
            # Best effort: without Tesseract words the pseudo-column is
            # still returned, only the cached word list stays empty.
            logger.warning(f"OCR Pipeline: sub-session {session_id}: Tesseract failed: {e}")
            word_dicts = []

        # Cache intermediates for row detection (detect_rows reuses these)
        cached["_word_dicts"] = word_dicts
        cached["_inv"] = inv
        cached["_content_bounds"] = (0, w, 0, h)

        column_result = {
            "columns": [{
                "type": "column_text",
                "x": 0, "y": 0,
                "width": w, "height": h,
            }],
            "zones": None,
            "boxes_detected": 0,
            "duration_seconds": 0,
            "method": "sub_session_pseudo_column",
        }
        # Persist and invalidate downstream results (rows, words).
        await update_session_db(
            session_id,
            column_result=column_result,
            row_result=None,
            word_result=None,
            current_step=6,
        )
        cached["column_result"] = column_result
        cached.pop("row_result", None)
        cached.pop("word_result", None)
        logger.info(f"OCR Pipeline: sub-session {session_id}: pseudo-column {w}x{h}px")
        return {"session_id": session_id, **column_result}

    t0 = time.time()

    # Binarized image for layout analysis
    ocr_img = create_ocr_image(img_bgr)
    h, w = ocr_img.shape[:2]

    # Phase A: Zone-aware geometry detection
    zoned_result = detect_column_geometry_zoned(ocr_img, img_bgr)

    if zoned_result is None:
        # Fallback to projection-based layout
        # NOTE(review): this branch does not touch the cached _word_dicts /
        # _inv / _content_bounds; values from an earlier run may remain in
        # the cache — confirm that the row step tolerates this.
        layout_img = create_layout_image(img_bgr)
        regions = analyze_layout(layout_img, ocr_img)
        zones_data = None
        boxes_detected = 0
    else:
        geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv, zones_data, boxes = zoned_result
        content_w = right_x - left_x
        boxes_detected = len(boxes)

        # Cache intermediates for row detection (avoids second Tesseract run)
        cached["_word_dicts"] = word_dicts
        cached["_inv"] = inv
        cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
        cached["_zones_data"] = zones_data
        cached["_boxes_detected"] = boxes_detected

        # Detect header/footer early so sub-column clustering ignores them
        header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None)

        # Split sub-columns (e.g. page references) before classification
        geometries = _detect_sub_columns(geometries, content_w, left_x=left_x,
                                         top_y=top_y, header_y=header_y, footer_y=footer_y)

        # Expand narrow columns (sub-columns are often very narrow)
        geometries = expand_narrow_columns(geometries, content_w, left_x, word_dicts)

        # Phase B: Content-based classification
        regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y,
                                        left_x=left_x, right_x=right_x, inv=inv)

    duration = time.time() - t0

    columns = [asdict(r) for r in regions]

    # Determine classification methods used (distinct, non-empty values;
    # the order is not deterministic because it goes through a set)
    methods = list(set(
        c.get("classification_method", "") for c in columns
        if c.get("classification_method")
    ))

    column_result = {
        "columns": columns,
        "classification_methods": methods,
        "duration_seconds": round(duration, 2),
        "boxes_detected": boxes_detected,
    }

    # Add zone data when boxes are present
    if zones_data and boxes_detected > 0:
        column_result["zones"] = zones_data

    # Persist to DB — also invalidate downstream results (rows, words)
    await update_session_db(
        session_id,
        column_result=column_result,
        row_result=None,
        word_result=None,
        current_step=6,
    )

    # Update cache
    cached["column_result"] = column_result
    cached.pop("row_result", None)
    cached.pop("word_result", None)

    # Only regions whose type starts with "column" count as columns here.
    col_count = len([c for c in columns if c["type"].startswith("column")])
    logger.info(f"OCR Pipeline: columns session {session_id}: "
                f"{col_count} columns detected, {boxes_detected} box(es) ({duration:.2f}s)")

    # Widths are logged as a percentage of the input image width.
    img_w = img_bgr.shape[1]
    await _append_pipeline_log(session_id, "columns", {
        "total_columns": len(columns),
        "column_widths_pct": [round(c["width"] / img_w * 100, 1) for c in columns],
        "column_types": [c["type"] for c in columns],
        "boxes_detected": boxes_detected,
    }, duration_ms=int(duration * 1000))

    return {
        "session_id": session_id,
        **column_result,
    }
@router.post("/sessions/{session_id}/ground-truth/columns")
async def save_column_ground_truth(session_id: str, req: ColumnGroundTruthRequest):
    """Store reviewer feedback for the column detection step of a session.

    The feedback is saved under the ``"columns"`` key of the session's
    ``ground_truth`` together with a timestamp and the ``column_result``
    currently stored for the session.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    all_ground_truth = session.get("ground_truth") or {}
    feedback = {
        "is_correct": req.is_correct,
        "corrected_columns": req.corrected_columns,
        "notes": req.notes,
        "saved_at": datetime.utcnow().isoformat(),
        "column_result": session.get("column_result"),
    }
    all_ground_truth["columns"] = feedback

    await update_session_db(session_id, ground_truth=all_ground_truth)

    # Mirror the change into the in-memory copy, if the session is cached.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = all_ground_truth

    return {"session_id": session_id, "ground_truth": feedback}
HTTPException(status_code=404, detail="No column ground truth saved") + + return { + "session_id": session_id, + "columns_gt": columns_gt, + "columns_auto": session.get("column_result"), + } diff --git a/klausur-service/backend/ocr_pipeline_ocr_merge.py b/klausur-service/backend/ocr_pipeline_ocr_merge.py new file mode 100644 index 0000000..d8b4c8c --- /dev/null +++ b/klausur-service/backend/ocr_pipeline_ocr_merge.py @@ -0,0 +1,615 @@ +""" +OCR Merge Helpers and Kombi Endpoints. + +Contains merge helper functions for combining PaddleOCR/RapidOCR with Tesseract +results, plus the paddle-kombi and rapid-kombi endpoints. + +Extracted from ocr_pipeline_api.py for modularity. + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import logging +import time +from typing import Any, Dict, List + +import cv2 +import httpx +import numpy as np +from fastapi import APIRouter, HTTPException + +from cv_words_first import build_grid_from_words +from ocr_pipeline_common import _cache, _append_pipeline_log +from ocr_pipeline_session_store import get_session_image, update_session_db + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) + + +# --------------------------------------------------------------------------- +# Merge helper functions +# --------------------------------------------------------------------------- + + +def _split_paddle_multi_words(words: list) -> list: + """Split PaddleOCR multi-word boxes into individual word boxes. + + PaddleOCR often returns entire phrases as a single box, e.g. + "More than 200 singers took part in the" with one bounding box. + This splits them into individual words with proportional widths. + Also handles leading "!" (e.g. "!Betonung" → ["!", "Betonung"]) + and IPA brackets (e.g. "badge[bxd3]" → ["badge", "[bxd3]"]). 
+ """ + import re + + result = [] + for w in words: + raw_text = w.get("text", "").strip() + if not raw_text: + continue + # Split on whitespace, before "[" (IPA), and after "!" before letter + tokens = re.split( + r'\s+|(?=\[)|(?<=!)(?=[A-Za-z\u00c0-\u024f])', raw_text + ) + tokens = [t for t in tokens if t] + + if len(tokens) <= 1: + result.append(w) + else: + # Split proportionally by character count + total_chars = sum(len(t) for t in tokens) + if total_chars == 0: + continue + n_gaps = len(tokens) - 1 + gap_px = w["width"] * 0.02 + usable_w = w["width"] - gap_px * n_gaps + cursor = w["left"] + for t in tokens: + token_w = max(1, usable_w * len(t) / total_chars) + result.append({ + "text": t, + "left": round(cursor), + "top": w["top"], + "width": round(token_w), + "height": w["height"], + "conf": w.get("conf", 0), + }) + cursor += token_w + gap_px + return result + + +def _group_words_into_rows(words: list, row_gap: int = 12) -> list: + """Group words into rows by Y-position clustering. + + Words whose vertical centers are within `row_gap` pixels are on the same row. + Returns list of rows, each row is a list of words sorted left-to-right. 
+ """ + if not words: + return [] + # Sort by vertical center + sorted_words = sorted(words, key=lambda w: w["top"] + w.get("height", 0) / 2) + rows: list = [] + current_row: list = [sorted_words[0]] + current_cy = sorted_words[0]["top"] + sorted_words[0].get("height", 0) / 2 + + for w in sorted_words[1:]: + cy = w["top"] + w.get("height", 0) / 2 + if abs(cy - current_cy) <= row_gap: + current_row.append(w) + else: + # Sort current row left-to-right before saving + rows.append(sorted(current_row, key=lambda w: w["left"])) + current_row = [w] + current_cy = cy + if current_row: + rows.append(sorted(current_row, key=lambda w: w["left"])) + return rows + + +def _row_center_y(row: list) -> float: + """Average vertical center of a row of words.""" + if not row: + return 0.0 + return sum(w["top"] + w.get("height", 0) / 2 for w in row) / len(row) + + +def _merge_row_sequences(paddle_row: list, tess_row: list) -> list: + """Merge two word sequences from the same row using sequence alignment. + + Both sequences are sorted left-to-right. Walk through both simultaneously: + - If words match (same/similar text): take Paddle text with averaged coords + - If they don't match: the extra word is unique to one engine, include it + + This prevents duplicates because both engines produce words in the same order. + """ + merged = [] + pi, ti = 0, 0 + + while pi < len(paddle_row) and ti < len(tess_row): + pw = paddle_row[pi] + tw = tess_row[ti] + + # Check if these are the same word + pt = pw.get("text", "").lower().strip() + tt = tw.get("text", "").lower().strip() + + # Same text or one contains the other + is_same = (pt == tt) or (len(pt) > 1 and len(tt) > 1 and (pt in tt or tt in pt)) + + # Spatial overlap check: if words overlap >= 40% horizontally, + # they're the same physical word regardless of OCR text differences. 
+ # (40% catches borderline cases like "Stick"/"Stück" at 48% overlap) + spatial_match = False + if not is_same: + overlap_left = max(pw["left"], tw["left"]) + overlap_right = min( + pw["left"] + pw.get("width", 0), + tw["left"] + tw.get("width", 0), + ) + overlap_w = max(0, overlap_right - overlap_left) + min_w = min(pw.get("width", 1), tw.get("width", 1)) + if min_w > 0 and overlap_w / min_w >= 0.4: + is_same = True + spatial_match = True + + if is_same: + # Matched — average coordinates weighted by confidence + pc = pw.get("conf", 80) + tc = tw.get("conf", 50) + total = pc + tc + if total == 0: + total = 1 + # Text: prefer higher-confidence engine when texts differ + # (e.g. Tesseract "Stück" conf=98 vs PaddleOCR "Stick" conf=80) + if spatial_match and pc < tc: + best_text = tw["text"] + else: + best_text = pw["text"] + merged.append({ + "text": best_text, + "left": round((pw["left"] * pc + tw["left"] * tc) / total), + "top": round((pw["top"] * pc + tw["top"] * tc) / total), + "width": round((pw["width"] * pc + tw["width"] * tc) / total), + "height": round((pw["height"] * pc + tw["height"] * tc) / total), + "conf": max(pc, tc), + }) + pi += 1 + ti += 1 + else: + # Different text — one engine found something extra + # Look ahead: is the current Paddle word somewhere in Tesseract ahead? + paddle_ahead = any( + tess_row[t].get("text", "").lower().strip() == pt + for t in range(ti + 1, min(ti + 4, len(tess_row))) + ) + # Is the current Tesseract word somewhere in Paddle ahead? + tess_ahead = any( + paddle_row[p].get("text", "").lower().strip() == tt + for p in range(pi + 1, min(pi + 4, len(paddle_row))) + ) + + if paddle_ahead and not tess_ahead: + # Tesseract has an extra word (e.g. "!" 
or bullet) → include it + if tw.get("conf", 0) >= 30: + merged.append(tw) + ti += 1 + elif tess_ahead and not paddle_ahead: + # Paddle has an extra word → include it + merged.append(pw) + pi += 1 + else: + # Both have unique words or neither found ahead → take leftmost first + if pw["left"] <= tw["left"]: + merged.append(pw) + pi += 1 + else: + if tw.get("conf", 0) >= 30: + merged.append(tw) + ti += 1 + + # Remaining words from either engine + while pi < len(paddle_row): + merged.append(paddle_row[pi]) + pi += 1 + while ti < len(tess_row): + tw = tess_row[ti] + if tw.get("conf", 0) >= 30: + merged.append(tw) + ti += 1 + + return merged + + +def _merge_paddle_tesseract(paddle_words: list, tess_words: list) -> list: + """Merge word boxes from PaddleOCR and Tesseract using row-based sequence alignment. + + Strategy: + 1. Group each engine's words into rows (by Y-position clustering) + 2. Match rows between engines (by vertical center proximity) + 3. Within each matched row: merge sequences left-to-right, deduplicating + words that appear in both engines at the same sequence position + 4. 
def _overlaps_by_half(a: dict, b: dict, pos_key: str, size_key: str) -> bool:
    """Return True when boxes *a* and *b* overlap by at least 50 % on one axis.

    ``pos_key`` is the start coordinate key ("left" or "top") and ``size_key``
    the extent key ("width" or "height").  The overlap length is measured
    against the smaller of the two extents.  A missing extent counts as 0 when
    computing the far edge and as 1 when picking the smaller extent; a smaller
    extent <= 0 never counts as overlapping.
    """
    start = max(a[pos_key], b[pos_key])
    end = min(a[pos_key] + a.get(size_key, 0), b[pos_key] + b.get(size_key, 0))
    smaller = min(a.get(size_key, 1), b.get(size_key, 1))
    return smaller > 0 and max(0, end - start) / smaller >= 0.5


def _deduplicate_words(words: list) -> list:
    """Remove repeated words that carry the same text at overlapping positions.

    Overlapping OCR phrases (e.g. "von jm." and "jm. =") can yield the same
    word twice once phrases are split into single words.  A word is dropped as
    a duplicate only when a previously kept word has the same text (compared
    case-insensitively, surrounding whitespace ignored) AND overlaps it by at
    least 50 % both horizontally and vertically.  Words whose text is missing,
    ``None`` or blank are dropped as well.

    Args:
        words: word dicts with "text", "left", "top" and optional "width" /
            "height" keys.

    Returns:
        The kept words in their original order.  An empty input is returned
        unchanged (same object).

    Raises:
        KeyError: if a word with text lacks "left" or "top".
    """
    if not words:
        return words

    kept: list = []
    # Normalised text -> words already kept with that text.  Only these can be
    # duplicates of a new word, so unrelated words are never compared.
    kept_by_text: dict = {}
    blank = 0

    for word in words:
        text = str(word.get("text") or "").lower().strip()
        if not text:
            blank += 1
            continue

        same_text = kept_by_text.setdefault(text, [])
        if any(
            _overlaps_by_half(word, other, "left", "width")
            and _overlaps_by_half(word, other, "top", "height")
            for other in same_text
        ):
            continue

        kept.append(word)
        same_text.append(word)

    # Report blank drops separately so the duplicate count stays accurate.
    duplicates = len(words) - len(kept) - blank
    if duplicates:
        logger.info("dedup: removed %d duplicate words", duplicates)
    if blank:
        logger.info("dedup: dropped %d words without text", blank)
    return kept
+ """ + img_png = await get_session_image(session_id, "cropped") + if not img_png: + img_png = await get_session_image(session_id, "dewarped") + if not img_png: + img_png = await get_session_image(session_id, "original") + if not img_png: + raise HTTPException(status_code=404, detail="No image found for this session") + + img_arr = np.frombuffer(img_png, dtype=np.uint8) + img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR) + if img_bgr is None: + raise HTTPException(status_code=400, detail="Failed to decode image") + + img_h, img_w = img_bgr.shape[:2] + + from cv_ocr_engines import ocr_region_paddle + + t0 = time.time() + + # --- PaddleOCR --- + paddle_words = await ocr_region_paddle(img_bgr, region=None) + if not paddle_words: + paddle_words = [] + + # --- Tesseract --- + from PIL import Image + import pytesseract + + pil_img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)) + data = pytesseract.image_to_data( + pil_img, lang="eng+deu", + config="--psm 6 --oem 3", + output_type=pytesseract.Output.DICT, + ) + tess_words = [] + for i in range(len(data["text"])): + text = str(data["text"][i]).strip() + conf_raw = str(data["conf"][i]) + conf = int(conf_raw) if conf_raw.lstrip("-").isdigit() else -1 + if not text or conf < 20: + continue + tess_words.append({ + "text": text, + "left": data["left"][i], + "top": data["top"][i], + "width": data["width"][i], + "height": data["height"][i], + "conf": conf, + }) + + # --- Split multi-word Paddle boxes into individual words --- + paddle_words_split = _split_paddle_multi_words(paddle_words) + logger.info( + "paddle_kombi: split %d paddle boxes → %d individual words", + len(paddle_words), len(paddle_words_split), + ) + + # --- Merge --- + if not paddle_words_split and not tess_words: + raise HTTPException(status_code=400, detail="Both OCR engines returned no words") + + merged_words = _merge_paddle_tesseract(paddle_words_split, tess_words) + merged_words = _deduplicate_words(merged_words) + + cells, columns_meta = 
build_grid_from_words(merged_words, img_w, img_h) + duration = time.time() - t0 + + for cell in cells: + cell["ocr_engine"] = "kombi" + + n_rows = len(set(c["row_index"] for c in cells)) if cells else 0 + n_cols = len(columns_meta) + col_types = {c.get("type") for c in columns_meta} + is_vocab = bool(col_types & {"column_en", "column_de"}) + + word_result = { + "cells": cells, + "grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)}, + "columns_used": columns_meta, + "layout": "vocab" if is_vocab else "generic", + "image_width": img_w, + "image_height": img_h, + "duration_seconds": round(duration, 2), + "ocr_engine": "kombi", + "grid_method": "kombi", + "raw_paddle_words": paddle_words, + "raw_paddle_words_split": paddle_words_split, + "raw_tesseract_words": tess_words, + "summary": { + "total_cells": len(cells), + "non_empty_cells": sum(1 for c in cells if c.get("text")), + "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), + "paddle_words": len(paddle_words), + "paddle_words_split": len(paddle_words_split), + "tesseract_words": len(tess_words), + "merged_words": len(merged_words), + }, + } + + await update_session_db( + session_id, + word_result=word_result, + cropped_png=img_png, + current_step=8, + ) + # Update in-memory cache so detect-structure can access word_result + if session_id in _cache: + _cache[session_id]["word_result"] = word_result + + logger.info( + "paddle_kombi session %s: %d cells (%d rows, %d cols) in %.2fs " + "[paddle=%d, tess=%d, merged=%d]", + session_id, len(cells), n_rows, n_cols, duration, + len(paddle_words), len(tess_words), len(merged_words), + ) + + await _append_pipeline_log(session_id, "paddle_kombi", { + "total_cells": len(cells), + "non_empty_cells": word_result["summary"]["non_empty_cells"], + "paddle_words": len(paddle_words), + "tesseract_words": len(tess_words), + "merged_words": len(merged_words), + "ocr_engine": "kombi", + }, duration_ms=int(duration * 1000)) + + return 
def _parse_tesseract_conf(raw) -> int:
    """Convert one Tesseract ``conf`` value to an int; return -1 if not numeric.

    Accepts ints, floats and their string forms.  Float-formatted values such
    as "96.06" are truncated to 96 instead of being rejected, so the word is
    not discarded by the confidence filter.
    NOTE(review): whether conf arrives as int or float depends on the installed
    Tesseract / pytesseract versions — confirm against the deployed stack.
    """
    try:
        return int(float(raw))
    except (TypeError, ValueError, OverflowError):
        return -1


@router.post("/sessions/{session_id}/rapid-kombi")
async def rapid_kombi(session_id: str):
    """Run RapidOCR + Tesseract on the session image and merge the results.

    The image is looked up as "cropped", then "dewarped", then "original".
    RapidOCR boxes are split into single words, merged with the Tesseract
    words, de-duplicated and turned into a cell grid.  The result is stored as
    the session's ``word_result`` (DB and in-memory cache) and returned.

    Raises:
        HTTPException 404: no image is stored for the session.
        HTTPException 400: the image cannot be decoded, or both engines
            returned no words.
    """
    # Use the most processed image variant that exists for this session.
    img_png = None
    for image_kind in ("cropped", "dewarped", "original"):
        img_png = await get_session_image(session_id, image_kind)
        if img_png:
            break
    if not img_png:
        raise HTTPException(status_code=404, detail="No image found for this session")

    img_arr = np.frombuffer(img_png, dtype=np.uint8)
    img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Failed to decode image")

    img_h, img_w = img_bgr.shape[:2]

    from cv_ocr_engines import ocr_region_rapid
    from cv_vocab_types import PageRegion

    t0 = time.time()

    # --- RapidOCR (local, synchronous) ---
    full_region = PageRegion(
        type="full_page", x=0, y=0, width=img_w, height=img_h,
    )
    rapid_words = ocr_region_rapid(img_bgr, full_region) or []

    # --- Tesseract ---
    from PIL import Image
    import pytesseract

    pil_img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    data = pytesseract.image_to_data(
        pil_img, lang="eng+deu",
        config="--psm 6 --oem 3",
        output_type=pytesseract.Output.DICT,
    )
    tess_words = []
    for i, raw_text in enumerate(data["text"]):
        text = str(raw_text).strip()
        conf = _parse_tesseract_conf(data["conf"][i])
        # Skip empty tokens and very low-confidence noise.
        if not text or conf < 20:
            continue
        tess_words.append({
            "text": text,
            "left": data["left"][i],
            "top": data["top"][i],
            "width": data["width"][i],
            "height": data["height"][i],
            "conf": conf,
        })

    # --- Split multi-word RapidOCR boxes into individual words ---
    rapid_words_split = _split_paddle_multi_words(rapid_words)
    logger.info(
        "rapid_kombi: split %d rapid boxes → %d individual words",
        len(rapid_words), len(rapid_words_split),
    )

    # --- Merge ---
    if not rapid_words_split and not tess_words:
        raise HTTPException(status_code=400, detail="Both OCR engines returned no words")

    merged_words = _merge_paddle_tesseract(rapid_words_split, tess_words)
    merged_words = _deduplicate_words(merged_words)

    cells, columns_meta = build_grid_from_words(merged_words, img_w, img_h)
    duration = time.time() - t0

    for cell in cells:
        cell["ocr_engine"] = "rapid_kombi"

    n_rows = len({c["row_index"] for c in cells})
    n_cols = len(columns_meta)
    col_types = {c.get("type") for c in columns_meta}
    is_vocab = bool(col_types & {"column_en", "column_de"})

    word_result = {
        "cells": cells,
        "grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": "rapid_kombi",
        "grid_method": "rapid_kombi",
        "raw_rapid_words": rapid_words,
        "raw_rapid_words_split": rapid_words_split,
        "raw_tesseract_words": tess_words,
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
            "rapid_words": len(rapid_words),
            "rapid_words_split": len(rapid_words_split),
            "tesseract_words": len(tess_words),
            "merged_words": len(merged_words),
        },
    }

    await update_session_db(
        session_id,
        word_result=word_result,
        cropped_png=img_png,
        current_step=8,
    )
    # Update in-memory cache so detect-structure can access word_result
    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    logger.info(
        "rapid_kombi session %s: %d cells (%d rows, %d cols) in %.2fs "
        "[rapid=%d, tess=%d, merged=%d]",
        session_id, len(cells), n_rows, n_cols, duration,
        len(rapid_words), len(tess_words), len(merged_words),
    )

    await _append_pipeline_log(session_id, "rapid_kombi", {
        "total_cells": len(cells),
        "non_empty_cells": word_result["summary"]["non_empty_cells"],
        "rapid_words": len(rapid_words),
        "tesseract_words": len(tess_words),
        "merged_words": len(merged_words),
        "ocr_engine": "rapid_kombi",
    }, duration_ms=int(duration * 1000))

    return {"session_id": session_id, **word_result}
@router.post("/sessions/{session_id}/llm-review")
async def run_llm_review(session_id: str, request: Request, stream: bool = False):
    """Run LLM-based correction on the session's vocab entries.

    Query params:
        stream: false (default) for a JSON response, true for SSE streaming.

    Request body (optional): a JSON object; its "model" key overrides the
    default review model.  A missing, empty, malformed or non-object body is
    ignored and the default model is used.

    Raises:
        HTTPException 404: the session does not exist.
        HTTPException 400: the session has no word result or no vocab entries.
        HTTPException 502: the (non-streaming) LLM call failed.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found — run Step 5 first")

    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    if not entries:
        raise HTTPException(status_code=400, detail="No vocab entries found — run Step 5 first")

    # Optional model override from the request body.  The body is best-effort:
    # an absent or malformed one must not fail the request.
    try:
        body = await request.json()
    except ValueError:
        body = None
    # Valid JSON that is not an object ([], null, "x") carries no override.
    override = body.get("model") if isinstance(body, dict) else None
    model = override or OLLAMA_REVIEW_MODEL

    if stream:
        return StreamingResponse(
            _llm_review_stream_generator(session_id, entries, word_result, model, request),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
        )

    # Non-streaming path
    try:
        result = await llm_review_entries(entries, model=model)
    except Exception as e:
        import traceback
        logger.error(f"LLM review failed for session {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}")
        raise HTTPException(status_code=502, detail=f"LLM review failed ({type(e).__name__}): {e}") from e

    # Store result inside word_result as a sub-key
    word_result["llm_review"] = {
        "changes": result["changes"],
        "model_used": result["model_used"],
        "duration_ms": result["duration_ms"],
        "entries_corrected": result["entries_corrected"],
    }
    await update_session_db(session_id, word_result=word_result, current_step=9)

    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    logger.info(f"LLM review session {session_id}: {len(result['changes'])} changes, "
                f"{result['duration_ms']}ms, model={result['model_used']}")

    await _append_pipeline_log(session_id, "correction", {
        "engine": "llm",
        "model": result["model_used"],
        "total_entries": len(entries),
        "corrections_proposed": len(result["changes"]),
    }, duration_ms=result["duration_ms"])

    return {
        "session_id": session_id,
        "changes": result["changes"],
        "model_used": result["model_used"],
        "duration_ms": result["duration_ms"],
        "total_entries": len(entries),
        "corrections_found": len(result["changes"]),
    }
session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + word_result = session.get("word_result") + if not word_result: + raise HTTPException(status_code=400, detail="No word result found") + + llm_review = word_result.get("llm_review") + if not llm_review: + raise HTTPException(status_code=400, detail="No LLM review found — run /llm-review first") + + body = await request.json() + accepted_indices = set(body.get("accepted_indices", [])) # indices into changes[] + + changes = llm_review.get("changes", []) + entries = word_result.get("vocab_entries") or word_result.get("entries") or [] + + # Build a lookup: (row_index, field) -> new_value for accepted changes + corrections = {} + applied_count = 0 + for idx, change in enumerate(changes): + if idx in accepted_indices: + key = (change["row_index"], change["field"]) + corrections[key] = change["new"] + applied_count += 1 + + # Apply corrections to entries + for entry in entries: + row_idx = entry.get("row_index", -1) + for field_name in ("english", "german", "example"): + key = (row_idx, field_name) + if key in corrections: + entry[field_name] = corrections[key] + entry["llm_corrected"] = True + + # Update word_result + word_result["vocab_entries"] = entries + word_result["entries"] = entries + word_result["llm_review"]["applied_count"] = applied_count + word_result["llm_review"]["applied_at"] = datetime.utcnow().isoformat() + + await update_session_db(session_id, word_result=word_result) + + if session_id in _cache: + _cache[session_id]["word_result"] = word_result + + logger.info(f"Applied {applied_count}/{len(changes)} LLM corrections for session {session_id}") + + return { + "session_id": session_id, + "applied_count": applied_count, + "total_changes": len(changes), + } + + +# --------------------------------------------------------------------------- +# Step 9: Reconstruction + Fabric JSON export +# --------------------------------------------------------------------------- + 
@router.post("/sessions/{session_id}/reconstruction")
async def save_reconstruction(session_id: str, request: Request):
    """Save edited cell texts from the reconstruction step.

    Request body: ``{"cells": [{"cell_id": ..., "text": ...}, ...]}``.
    Cell ids of the form ``box{N}_{id}`` are routed to the sub-session whose
    ``box_index`` is N; all other ids update the main session's cells and,
    where a matching ``R{row}_C{col}`` id exists, the vocab entries
    (column 0/1/2 -> english/german/example).  An empty ``cells`` list only
    advances the session to step 10.

    Returns:
        Dict with ``session_id`` and ``updated``; when cells were supplied
        also ``main_updated`` and ``sub_updated``.

    Raises:
        HTTPException 404: the session does not exist.
        HTTPException 400: no word result is stored, or the body is malformed.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    try:
        body = await request.json()
    except ValueError as exc:
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from exc
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    cell_updates = body.get("cells") or []
    if not cell_updates:
        await update_session_db(session_id, current_step=10)
        return {"session_id": session_id, "updated": 0}
    if not isinstance(cell_updates, list):
        raise HTTPException(status_code=400, detail='"cells" must be a list')

    # Build update map: cell_id -> new text (a later duplicate id wins).
    update_map: Dict[str, str] = {}
    for update in cell_updates:
        if not isinstance(update, dict) or "cell_id" not in update or "text" not in update:
            raise HTTPException(
                status_code=400,
                detail='Each entry in "cells" needs "cell_id" and "text"',
            )
        update_map[str(update["cell_id"])] = update["text"]

    # Separate sub-session updates (cell_ids prefixed with "box{N}_")
    sub_updates: Dict[int, Dict[str, str]] = {}  # box_index -> {original_cell_id: text}
    main_updates: Dict[str, str] = {}
    for cell_id, text in update_map.items():
        m = re.match(r'^box(\d+)_(.+)$', cell_id)
        if m:
            sub_updates.setdefault(int(m.group(1)), {})[m.group(2)] = text
        else:
            main_updates[cell_id] = text

    # Update main session cells
    cells = word_result.get("cells", [])
    updated_count = 0
    for cell in cells:
        if cell["cell_id"] in main_updates:
            cell["text"] = main_updates[cell["cell_id"]]
            cell["status"] = "edited"
            updated_count += 1

    word_result["cells"] = cells

    # Also update vocab_entries if present
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    if entries:
        for entry in entries:
            row_idx = entry.get("row_index", -1)
            for col_idx, field_name in enumerate(("english", "german", "example")):
                # Cell ids may or may not be zero-padded; try both spellings.
                new_text = main_updates.get(f"R{row_idx:02d}_C{col_idx}")
                if new_text is None:
                    new_text = main_updates.get(f"R{row_idx}_C{col_idx}")
                # Compare against None, not truthiness: "" is a valid edit
                # (the user cleared the cell) and must reach the entry.
                if new_text is not None:
                    entry[field_name] = new_text

        word_result["vocab_entries"] = entries
        if "entries" in word_result:
            word_result["entries"] = entries

    await update_session_db(session_id, word_result=word_result, current_step=10)

    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    # Route sub-session updates
    sub_updated = 0
    if sub_updates:
        subs = await get_sub_sessions(session_id)
        sub_by_index = {s.get("box_index"): s["id"] for s in subs}
        for bi, updates in sub_updates.items():
            sub_id = sub_by_index.get(bi)
            if not sub_id:
                continue
            sub_session = await get_session_db(sub_id)
            if not sub_session:
                continue
            sub_word = sub_session.get("word_result")
            if not sub_word:
                continue
            sub_cells = sub_word.get("cells", [])
            for cell in sub_cells:
                if cell["cell_id"] in updates:
                    cell["text"] = updates[cell["cell_id"]]
                    cell["status"] = "edited"
                    sub_updated += 1
            sub_word["cells"] = sub_cells
            await update_session_db(sub_id, word_result=sub_word)
            if sub_id in _cache:
                _cache[sub_id]["word_result"] = sub_word

    total_updated = updated_count + sub_updated
    logger.info(f"Reconstruction saved for session {session_id}: "
                f"{updated_count} main + {sub_updated} sub-session cells updated")

    return {
        "session_id": session_id,
        "updated": total_updated,
        "main_updated": updated_count,
        "sub_updated": sub_updated,
    }
@router.get("/sessions/{session_id}/vocab-entries/merged")
async def get_merged_vocab_entries(session_id: str):
    """Return vocab entries of the main session plus all sub-sessions, sorted.

    Every returned entry is a copy tagged with ``source``: "main" (unless the
    entry already has a source) or ``box_{index}`` for sub-session entries,
    which additionally get ``source_y`` (Y of the matching box zone, 0 if no
    zone matches).

    Raises:
        HTTPException 404: the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result") or {}
    main_entries = word_result.get("vocab_entries") or word_result.get("entries") or []

    # Tag copies, not the stored dicts, so building the response never mutates
    # session data (sub-session entries below are copied the same way).
    entries = []
    for e in main_entries:
        e_copy = dict(e)
        e_copy.setdefault("source", "main")
        entries.append(e_copy)

    # Merge sub-session entries
    subs = await get_sub_sessions(session_id)
    if subs:
        column_result = session.get("column_result") or {}
        zones = column_result.get("zones") or []
        box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]

        for sub in subs:
            sub_session = await get_session_db(sub["id"])
            if not sub_session:
                continue
            sub_word = sub_session.get("word_result") or {}
            sub_entries = sub_word.get("vocab_entries") or sub_word.get("entries") or []

            # A stored null box_index is treated like a missing one (0).
            bi = sub.get("box_index") or 0
            box_y = box_zones[bi]["box"]["y"] if bi < len(box_zones) else 0

            for e in sub_entries:
                e_copy = dict(e)
                e_copy["source"] = f"box_{bi}"
                e_copy["source_y"] = box_y  # for sorting
                entries.append(e_copy)

    # Sort by approximate Y position.
    # NOTE(review): main entries are keyed by row index, box entries by the
    # box's Y value — these look like different units; confirm the intended order.
    def _sort_key(e):
        if e.get("source", "main") == "main":
            return e.get("row_index", 0) * 100  # main entries by row index
        return e.get("source_y", 0) * 100 + e.get("row_index", 0)

    entries.sort(key=_sort_key)

    return {
        "session_id": session_id,
        "entries": entries,
        "total": len(entries),
        # Sorted so the response is deterministic across calls.
        "sources": sorted({e.get("source", "main") for e in entries}, key=str),
    }
@router.get("/sessions/{session_id}/reconstruction/export/docx")
async def export_reconstruction_docx(session_id: str):
    """Export the reconstructed cell grid as a DOCX table.

    The table has one header row (column label, else column type, else
    "Col {i}") followed by ``grid_shape.rows`` data rows.  Cells are looked up
    by id ``R{row:02d}_C{col}`` with ``R{row}_C{col}`` as fallback; missing
    cells are left empty.

    Raises:
        HTTPException 404: the session does not exist.
        HTTPException 400: the session has no word result.
        HTTPException 501: python-docx is not installed.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    # Only the optional dependency is guarded, so unrelated errors during
    # document generation are not misreported as "not installed".
    try:
        from docx import Document
    except ImportError:
        raise HTTPException(status_code=501, detail="python-docx not installed")
    import io as _io

    cells = word_result.get("cells", [])
    columns_used = word_result.get("columns_used", [])
    grid_shape = word_result.get("grid_shape", {})
    n_rows = grid_shape.get("rows", 0)
    n_cols = grid_shape.get("cols", 0)

    # Index cells once; setdefault keeps the first cell for a repeated id.
    cell_by_id: dict = {}
    for cell in cells:
        cell_by_id.setdefault(cell.get("cell_id"), cell)

    doc = Document()
    doc.add_heading(f'Rekonstruktion – Session {session_id[:8]}', level=1)

    # Build header
    header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
    if not header:
        header = [f"Col {i}" for i in range(n_cols)]

    # Make the table wide enough for every header label; otherwise writing a
    # header past the last column raises IndexError.
    table = doc.add_table(rows=1 + n_rows, cols=max(n_cols, len(header), 1))
    table.style = 'Table Grid'

    # Header row
    for ci, h in enumerate(header):
        table.rows[0].cells[ci].text = h

    # Data rows
    for r in range(n_rows):
        for ci in range(n_cols):
            # Cell ids may or may not be zero-padded; try both spellings.
            cell = cell_by_id.get(f"R{r:02d}_C{ci}") or cell_by_id.get(f"R{r}_C{ci}")
            table.rows[r + 1].cells[ci].text = cell.get("text", "") if cell else ""

    buf = _io.BytesIO()
    doc.save(buf)
    buf.seek(0)

    return StreamingResponse(
        buf,
        media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.docx"'},
    )
# Reconstruction
# ---------------------------------------------------------------------------

def _extract_json_array(text):
    """Return the first JSON array embedded in *text*.

    Scans every ``[`` and lets the JSON decoder find the matching end, so
    brackets inside string values (e.g. a description containing ``]``) do
    not truncate the array the way a non-greedy regex does.

    Returns an empty list when *text* contains no ``[`` at all.
    Raises ValueError when brackets are present but none starts a JSON array.
    """
    decoder = json.JSONDecoder()
    pos = text.find("[")
    if pos == -1:
        return []
    while pos != -1:
        try:
            value, _end = decoder.raw_decode(text, pos)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            value = None
        if isinstance(value, list):
            return value
        pos = text.find("[", pos + 1)
    raise ValueError("VLM response did not contain a valid JSON array")


def _clamp_pct(value, default, low):
    """Coerce *value* to float and clamp it to ``[low, 100]``.

    Falls back to *default* when the value is missing or not numeric, so a
    single malformed region does not abort the whole detection.
    """
    try:
        number = float(value)
    except (TypeError, ValueError):
        number = float(default)
    return max(low, min(100, number))


@router.post("/sessions/{session_id}/reconstruction/detect-images")
async def detect_image_regions(session_id: str):
    """Detect illustration/image regions in the original scan using a VLM.

    Posts the original PNG (base64) to the Ollama ``/api/generate`` endpoint
    and parses a JSON array of bounding boxes (percent of page size) plus
    descriptions from the reply. The normalized regions are stored under
    ``ground_truth["validation"]["image_regions"]``.

    Returns ``{"regions": [...], "count": n}``; on a VLM/parse failure the
    same shape with an empty list and an ``error`` string.

    Raises:
        HTTPException 404: session not found.
        HTTPException 400: session has no original image.
    """
    import base64
    import httpx

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # Get original image bytes
    original_png = await get_session_image(session_id, "original")
    if not original_png:
        raise HTTPException(status_code=400, detail="No original image found")

    # Build context from vocab entries for richer descriptions
    word_result = session.get("word_result") or {}
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    vocab_context = ""
    if entries:
        words = [
            f"{e.get('english', '')} / {e.get('german', '')}"
            for e in entries[:10]
            if e.get('english')
        ]
        if words:
            vocab_context = f"\nContext: This is a vocabulary page with words like: {', '.join(words)}"

    ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
    model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")

    prompt = (
        "Analyze this scanned page. Find ALL illustration/image/picture regions "
        "(NOT text, NOT table cells, NOT blank areas). "
        "For each image region found, return its bounding box as percentage of page dimensions "
        "and a short English description of what the image shows. "
        "Reply with ONLY a JSON array like: "
        '[{"x": 10, "y": 20, "w": 30, "h": 25, "description": "drawing of a cat"}] '
        "where x, y, w, h are percentages (0-100) of the page width/height. "
        "If there are NO images on the page, return an empty array: []"
        f"{vocab_context}"
    )

    img_b64 = base64.b64encode(original_png).decode("utf-8")
    payload = {
        "model": model,
        "prompt": prompt,
        "images": [img_b64],
        "stream": False,
    }

    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            resp = await client.post(f"{ollama_base}/api/generate", json=payload)
            resp.raise_for_status()
            text = resp.json().get("response", "")

        # Parse JSON array from response
        raw_regions = _extract_json_array(text)

        # Normalize to ImageRegion format; skip entries that are not objects
        regions = []
        for r in raw_regions:
            if not isinstance(r, dict):
                continue
            description = str(r.get("description") or "")
            regions.append({
                "bbox_pct": {
                    "x": _clamp_pct(r.get("x", 0), 0, 0),
                    "y": _clamp_pct(r.get("y", 0), 0, 0),
                    "w": _clamp_pct(r.get("w", 10), 10, 1),
                    "h": _clamp_pct(r.get("h", 10), 10, 1),
                },
                "description": description,
                "prompt": description,
                "image_b64": None,
                "style": "educational",
            })

        # Enrich prompts with nearby vocab context
        if entries:
            for region in regions:
                ry = region["bbox_pct"]["y"]
                rh = region["bbox_pct"]["h"]
                nearby = [
                    e for e in entries
                    if e.get("bbox") and abs(e["bbox"].get("y", 0) - ry) < rh + 10
                ]
                if not nearby:
                    continue
                en_words = [e.get("english", "") for e in nearby if e.get("english")]
                de_words = [e.get("german", "") for e in nearby if e.get("german")]
                if en_words or de_words:
                    context = f" (vocabulary context: {', '.join(en_words[:5])}"
                    if de_words:
                        context += f" / {', '.join(de_words[:5])}"
                    context += ")"
                    region["prompt"] = region["description"] + context

        # Save to ground_truth JSONB
        ground_truth = session.get("ground_truth") or {}
        validation = ground_truth.get("validation") or {}
        validation["image_regions"] = regions
        validation["detected_at"] = datetime.utcnow().isoformat()
        ground_truth["validation"] = validation
        await update_session_db(session_id, ground_truth=ground_truth)

        if session_id in _cache:
            _cache[session_id]["ground_truth"] = ground_truth

        logger.info(f"Detected {len(regions)} image regions for session {session_id}")

        return {"regions": regions, "count": len(regions)}

    except httpx.ConnectError:
        logger.warning(f"VLM not available at {ollama_base} for image detection")
        return {"regions": [], "count": 0, "error": "VLM not available"}
    except Exception as e:
        # Best-effort endpoint: report the failure in the payload instead of a 500.
        logger.error(f"Image detection failed for {session_id}: {e}")
        return {"regions": [], "count": 0, "error": str(e)}


@router.post("/sessions/{session_id}/reconstruction/generate-image")
async def generate_image_for_region(session_id: str, req: GenerateImageRequest):
    """Generate a replacement image for a detected region via mflux-service.

    Posts the request prompt plus a style suffix to ``{MFLUX_URL}/generate``
    and stores the returned base64 image, prompt and style on the selected
    entry of ``ground_truth["validation"]["image_regions"]``.

    Returns ``{"image_b64": ..., "success": bool}`` (plus ``error`` on failure).

    Raises:
        HTTPException 404: session not found.
        HTTPException 400: ``req.region_index`` is out of range.
    """
    import httpx

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    validation = ground_truth.get("validation") or {}
    regions = validation.get("image_regions") or []

    if not 0 <= req.region_index < len(regions):
        raise HTTPException(status_code=400, detail=f"Invalid region_index {req.region_index}, have {len(regions)} regions")

    mflux_url = os.getenv("MFLUX_URL", "http://host.docker.internal:8095")
    style_suffix = STYLE_SUFFIXES.get(req.style, STYLE_SUFFIXES["educational"])
    full_prompt = f"{req.prompt}, {style_suffix}"

    # Pick one of three fixed output sizes from the region's aspect ratio
    region = regions[req.region_index]
    bbox = region["bbox_pct"]
    aspect = bbox["w"] / max(bbox["h"], 1)
    if aspect > 1.3:
        width, height = 768, 512
    elif aspect < 0.7:
        width, height = 512, 768
    else:
        width, height = 512, 512

    try:
        async with httpx.AsyncClient(timeout=300.0) as client:
            resp = await client.post(f"{mflux_url}/generate", json={
                "prompt": full_prompt,
                "width": width,
                "height": height,
                "steps": 4,
            })
            resp.raise_for_status()
            data = resp.json()
            image_b64 = data.get("image_b64")

        if not image_b64:
            return {"image_b64": None, "success": False, "error": "No image returned"}

        # Save to ground_truth (the raw prompt is stored, without style suffix)
        region["image_b64"] = image_b64
        region["prompt"] = req.prompt
        region["style"] = req.style
        validation["image_regions"] = regions
        ground_truth["validation"] = validation
        await update_session_db(session_id, ground_truth=ground_truth)

        if session_id in _cache:
            _cache[session_id]["ground_truth"] = ground_truth

        logger.info(f"Generated image for session {session_id} region {req.region_index}")
        return {"image_b64": image_b64, "success": True}

    except httpx.ConnectError:
        logger.warning(f"mflux-service not available at {mflux_url}")
        return {"image_b64": None, "success": False, "error": f"mflux-service not available at {mflux_url}"}
    except Exception as e:
        logger.error(f"Image generation failed for {session_id}: {e}")
        return {"image_b64": None, "success": False, "error": str(e)}


@router.post("/sessions/{session_id}/reconstruction/validate")
async def save_validation(session_id: str, req: ValidationRequest):
    """Save final validation results (notes and score) for the session.

    Existing keys under ``ground_truth["validation"]`` (e.g. detected or
    generated image regions) are preserved. Also writes ``current_step=11``
    to the session.

    Raises:
        HTTPException 404: session not found.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    validation = ground_truth.get("validation") or {}
    validation["validated_at"] = datetime.utcnow().isoformat()
    validation["notes"] = req.notes
    validation["score"] = req.score
    ground_truth["validation"] = validation

    await update_session_db(session_id, ground_truth=ground_truth, current_step=11)

    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    logger.info(f"Validation saved for session {session_id}: score={req.score}")

    return {"session_id": session_id, "validation": validation}


@router.get("/sessions/{session_id}/reconstruction/validation")
async def get_validation(session_id: str):
    """Return the saved validation data together with the session's word_result.

    ``validation`` is ``None`` when nothing has been saved yet.

    Raises:
        HTTPException 404: session not found.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}

    return {
        "session_id": session_id,
        "validation": ground_truth.get("validation"),
        "word_result": session.get("word_result"),
    }


# ---------------------------------------------------------------------------
# Remove handwriting
# ---------------------------------------------------------------------------

@router.post("/sessions/{session_id}/remove-handwriting")
async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingRequest):
    """Remove handwriting from a session image using inpainting.

    Steps:
      1. Load source image (auto -> deskewed if available, else original)
      2. Detect handwriting mask (filtered by target_ink)
      3. Dilate mask (when ``req.dilation > 0``)
      4. Inpaint the image
      5. Store result as clean_png in the session

    Returns processing metadata including the URL to fetch the clean image.

    Raises:
        HTTPException 404: session or source image not found.
        HTTPException 500: inpainting reported failure.
    """
    import io
    import time as _time
    t0 = _time.monotonic()

    from PIL import Image as _PILImage
    from services.handwriting_detection import detect_handwriting
    from services.inpainting_service import inpaint_image, dilate_mask as _dilate_mask, InpaintingMethod, image_to_png

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # 1. Determine source image; in auto mode reuse the deskewed bytes we
    #    already fetched instead of reading them from the store twice.
    source = req.use_source
    image_bytes = None
    if source == "auto":
        image_bytes = await get_session_image(session_id, "deskewed")
        source = "deskewed" if image_bytes else "original"
    if not image_bytes:
        image_bytes = await get_session_image(session_id, source)
    if not image_bytes:
        raise HTTPException(status_code=404, detail=f"Source image '{source}' not available")

    # 2. Detect handwriting mask
    detection = detect_handwriting(image_bytes, target_ink=req.target_ink)

    # 3. Convert mask to PNG bytes and dilate
    mask_buf = io.BytesIO()
    _PILImage.fromarray(detection.mask).save(mask_buf, format="PNG")
    mask_bytes = mask_buf.getvalue()

    if req.dilation > 0:
        mask_bytes = _dilate_mask(mask_bytes, iterations=req.dilation)

    # 4. Inpaint (unknown method names fall back to AUTO)
    method_map = {
        "telea": InpaintingMethod.OPENCV_TELEA,
        "ns": InpaintingMethod.OPENCV_NS,
        "auto": InpaintingMethod.AUTO,
    }
    inpaint_method = method_map.get(req.method, InpaintingMethod.AUTO)

    result = inpaint_image(image_bytes, mask_bytes, method=inpaint_method)
    if not result.success:
        raise HTTPException(status_code=500, detail="Inpainting failed")

    elapsed_ms = int((_time.monotonic() - t0) * 1000)

    meta = {
        "method_used": result.method_used.value if hasattr(result.method_used, "value") else str(result.method_used),
        "handwriting_ratio": round(detection.handwriting_ratio, 4),
        "detection_confidence": round(detection.confidence, 4),
        "target_ink": req.target_ink,
        "dilation": req.dilation,
        "source_image": source,
        "processing_time_ms": elapsed_ms,
    }

    # 5. Persist clean image (image_to_png converts the result image to PNG bytes)
    clean_png_bytes = image_to_png(result.image)
    await update_session_db(session_id, clean_png=clean_png_bytes, handwriting_removal_meta=meta)

    return {
        **meta,
        "image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean",
        "session_id": session_id,
    }
import logging
import time
from datetime import datetime
from typing import Any, Dict, List, Optional

import cv2
import numpy as np
from fastapi import APIRouter, HTTPException

from cv_vocab_pipeline import (
    create_ocr_image,
    detect_column_geometry,
    detect_row_geometry,
)
from ocr_pipeline_common import (
    _cache,
    _load_session_to_cache,
    _get_cached,
    _append_pipeline_log,
    ManualRowsRequest,
    RowGroundTruthRequest,
)
from ocr_pipeline_session_store import (
    get_session_db,
    update_session_db,
)

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])


# ---------------------------------------------------------------------------
# Helper: Box-exclusion overlay (used by rows overlay and columns overlay)
# ---------------------------------------------------------------------------

def _draw_box_exclusion_overlay(
    img: np.ndarray,
    zones: List[Dict],
    *,
    label: str = "BOX — separat verarbeitet",
) -> None:
    """Draw red semi-transparent rectangles over box zones (in-place).

    Only zones with ``zone_type == "box"`` and a non-empty ``box`` dict are
    drawn; each gets a ~25 % red fill, a 2 px border and *label* near its
    bottom-left corner. *img* is modified directly; nothing is returned.
    """
    # NOTE(review): the default label contains a non-ASCII dash, which the
    # Hershey fonts used by cv2.putText may not render — confirm on output.
    for zone in zones:
        if zone.get("zone_type") != "box" or not zone.get("box"):
            continue
        box = zone["box"]
        bx, by = box["x"], box["y"]
        bw, bh = box["width"], box["height"]

        # Red semi-transparent fill (~25 %)
        box_overlay = img.copy()
        cv2.rectangle(box_overlay, (bx, by), (bx + bw, by + bh), (0, 0, 200), -1)
        cv2.addWeighted(box_overlay, 0.25, img, 0.75, 0, img)

        # Border
        cv2.rectangle(img, (bx, by), (bx + bw, by + bh), (0, 0, 200), 2)

        # Label
        cv2.putText(img, label, (bx + 10, by + bh - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)


def _combined_y_to_abs(cy, strip_offsets, *, is_end=False):
    """Map a y coordinate of the stacked strip image back to page coordinates.

    *strip_offsets* holds ``(combined_y_start, strip_height, abs_y_start)``
    per content strip, in stacking order.

    With ``is_end=True`` the coordinate is treated as an exclusive end: a
    value that lands exactly on a strip's lower edge stays in that strip
    instead of jumping to the start of the next one (which would stretch
    the row across the excluded box in between).
    """
    for c_start, s_h, abs_start in strip_offsets:
        limit = c_start + s_h
        if cy < limit or (is_end and cy == limit):
            return abs_start + (cy - c_start)
    _last_c, last_h, last_abs = strip_offsets[-1]
    return last_abs + last_h


# ---------------------------------------------------------------------------
# Row Detection Endpoints
# ---------------------------------------------------------------------------

@router.post("/sessions/{session_id}/rows")
async def detect_rows(session_id: str):
    """Run row detection on the cropped (or dewarped) image.

    Reuses the word list / inverted image cached by column detection when
    present, otherwise recomputes them. Sub-sessions use word grouping;
    regular sessions use ``detect_row_geometry``, with box zones cut out of
    the image first when the column result contains any. The result is
    persisted as ``row_result`` and any previous ``word_result`` is cleared.

    Raises:
        HTTPException 400: no cropped/dewarped image, or column geometry
            detection failed.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
    if dewarped_bgr is None:
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before row detection")

    t0 = time.time()

    # Try to reuse cached word_dicts and inv from column detection
    word_dicts = cached.get("_word_dicts")
    inv = cached.get("_inv")
    content_bounds = cached.get("_content_bounds")

    if word_dicts is None or inv is None or content_bounds is None:
        # Not cached — run column geometry to get intermediates
        ocr_img = create_ocr_image(dewarped_bgr)
        geo_result = detect_column_geometry(ocr_img, dewarped_bgr)
        if geo_result is None:
            raise HTTPException(status_code=400, detail="Column geometry detection failed — cannot detect rows")
        _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
        cached["_word_dicts"] = word_dicts
        cached["_inv"] = inv
        cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
    else:
        left_x, right_x, top_y, bottom_y = content_bounds

    # Read zones from column_result to exclude box regions
    session = await get_session_db(session_id)
    column_result = (session or {}).get("column_result") or {}
    is_sub_session = bool((session or {}).get("parent_session_id"))
    zones = column_result.get("zones") or []  # zones can be None for sub-sessions

    # Sub-sessions (box crops): use word-grouping instead of gap-based
    # row detection. Box images are small with complex internal layouts
    # (headings, sub-columns) where the horizontal projection approach
    # merges rows. Word-grouping directly clusters words by Y proximity.
    if is_sub_session and word_dicts:
        from cv_layout import _build_rows_from_word_grouping
        rows = _build_rows_from_word_grouping(
            word_dicts, left_x, right_x, top_y, bottom_y,
            right_x - left_x, bottom_y - top_y,
        )
        logger.info(f"OCR Pipeline: sub-session {session_id}: word-grouping found {len(rows)} rows")
    else:
        # Box y-ranges shrunk by the border thickness (min. 5 px): the border
        # belongs to the frame, but text rows directly above/below a box may
        # overlap it and must not be clipped.
        box_ranges_inner = []
        for zone in zones:
            if zone.get("zone_type") == "box" and zone.get("box"):
                box = zone["box"]
                bt = max(box.get("border_thickness", 0), 5)
                box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt))

        content_strips = []  # [(y_start, y_end)] in absolute coords
        if box_ranges_inner and inv is not None:
            # Subtract the box ranges from [top_y, bottom_y]
            strip_start = top_y
            for by_start, by_end in sorted(box_ranges_inner, key=lambda r: r[0]):
                if by_start > strip_start:
                    content_strips.append((strip_start, by_start))
                strip_start = max(strip_start, by_end)
            if strip_start < bottom_y:
                content_strips.append((strip_start, bottom_y))

            # Filter to strips with meaningful height
            content_strips = [(ys, ye) for ys, ye in content_strips if ye - ys >= 20]

        if content_strips:
            # Combined-image approach: stack the content strips vertically,
            # detect rows on the stack, then remap y-coords back.
            combined_inv = np.vstack([inv[ys:ye, :] for ys, ye in content_strips])

            combined_words = []
            strip_offsets = []  # (combined_y_start, strip_height, abs_y_start)
            cum_y = 0
            for ys, ye in content_strips:
                h = ye - ys
                strip_offsets.append((cum_y, h, ys))
                for w in word_dicts:
                    w_abs_y = w['top'] + top_y  # word y is relative to content top
                    w_center = w_abs_y + w['height'] / 2
                    if ys <= w_center < ye:
                        # Remap to combined coordinates
                        w_copy = dict(w)
                        w_copy['top'] = cum_y + (w_abs_y - ys)
                        combined_words.append(w_copy)
                cum_y += h

            combined_h = combined_inv.shape[0]
            rows = detect_row_geometry(
                combined_inv, combined_words, left_x, right_x, 0, combined_h,
            )

            # Remap y-coordinates back to absolute page coords
            for r in rows:
                abs_y = _combined_y_to_abs(r.y, strip_offsets)
                abs_y_end = _combined_y_to_abs(r.y + r.height, strip_offsets, is_end=True)
                r.y = abs_y
                r.height = abs_y_end - abs_y
        else:
            # No boxes (or no usable strip) — standard row detection
            rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)

    duration = time.time() - t0

    # Assign zone_index based on which content zone each row's center falls in
    content_zones = [(i, z) for i, z in enumerate(zones) if z.get("zone_type") == "content"]

    # Build serializable result (exclude words to keep payload small)
    rows_data = []
    type_counts = {}
    for r in rows:
        zone_idx = 0  # default when no content zone contains the row center
        row_center_y = r.y + r.height / 2
        for zi, zone in content_zones:
            if zone["y"] <= row_center_y < zone["y"] + zone["height"]:
                zone_idx = zi
                break

        rows_data.append({
            "index": r.index,
            "x": r.x,
            "y": r.y,
            "width": r.width,
            "height": r.height,
            "word_count": r.word_count,
            "row_type": r.row_type,
            "gap_before": r.gap_before,
            "zone_index": zone_idx,
        })
        type_counts[r.row_type] = type_counts.get(r.row_type, 0) + 1

    row_result = {
        "rows": rows_data,
        "summary": type_counts,
        "total_rows": len(rows),
        "duration_seconds": round(duration, 2),
    }

    # Persist to DB — also invalidate word_result since rows changed
    await update_session_db(
        session_id,
        row_result=row_result,
        word_result=None,
        current_step=7,
    )

    cached["row_result"] = row_result
    cached.pop("word_result", None)

    logger.info(f"OCR Pipeline: rows session {session_id}: "
                f"{len(rows)} rows detected ({duration:.2f}s): {type_counts}")

    content_rows = sum(1 for r in rows if r.row_type == "content")
    avg_height = round(sum(r.height for r in rows) / len(rows)) if rows else 0
    await _append_pipeline_log(session_id, "rows", {
        "total_rows": len(rows),
        "content_rows": content_rows,
        "artifact_rows_removed": type_counts.get("header", 0) + type_counts.get("footer", 0),
        "avg_row_height_px": avg_height,
    }, duration_ms=int(duration * 1000))

    return {
        "session_id": session_id,
        **row_result,
    }


@router.post("/sessions/{session_id}/rows/manual")
async def set_manual_rows(session_id: str, req: ManualRowsRequest):
    """Override detected rows with manual definitions.

    Stores ``req.rows`` as the session's ``row_result`` (``method="manual"``)
    and clears ``word_result``.

    Raises:
        HTTPException 404: the store reports that no session was updated.
    """
    row_result = {
        "rows": req.rows,
        "total_rows": len(req.rows),
        "duration_seconds": 0,
        "method": "manual",
    }

    # A falsy result from the store means the session does not exist; report
    # that instead of answering 200 for an unknown id.
    updated = await update_session_db(session_id, row_result=row_result, word_result=None)
    if not updated:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    if session_id in _cache:
        _cache[session_id]["row_result"] = row_result
        _cache[session_id].pop("word_result", None)

    logger.info(f"OCR Pipeline: manual rows session {session_id}: "
                f"{len(req.rows)} rows set")

    return {"session_id": session_id, **row_result}


@router.post("/sessions/{session_id}/ground-truth/rows")
async def save_row_ground_truth(session_id: str, req: RowGroundTruthRequest):
    """Save ground truth feedback for the row detection step.

    The feedback is stored under ``ground_truth["rows"]`` together with a
    snapshot of the session's current ``row_result``.

    Raises:
        HTTPException 404: session not found.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    gt = {
        "is_correct": req.is_correct,
        "corrected_rows": req.corrected_rows,
        "notes": req.notes,
        "saved_at": datetime.utcnow().isoformat(),
        "row_result": session.get("row_result"),
    }
    ground_truth["rows"] = gt

    await update_session_db(session_id, ground_truth=ground_truth)

    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    return {"session_id": session_id, "ground_truth": gt}


@router.get("/sessions/{session_id}/ground-truth/rows")
async def get_row_ground_truth(session_id: str):
    """Retrieve saved ground truth for row detection.

    Raises:
        HTTPException 404: session not found, or no row ground truth saved.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    rows_gt = (session.get("ground_truth") or {}).get("rows")
    if not rows_gt:
        raise HTTPException(status_code=404, detail="No row ground truth saved")

    return {
        "session_id": session_id,
        "rows_gt": rows_gt,
        "rows_auto": session.get("row_result"),
    }
import logging
import time
import uuid
from typing import Any, Dict, Optional

import cv2
import numpy as np
from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
from fastapi.responses import Response

from cv_vocab_pipeline import (
    create_ocr_image,
    detect_document_type,
    render_image_high_res,
    render_pdf_high_res,
)
from ocr_pipeline_common import (
    VALID_DOCUMENT_CATEGORIES,
    UpdateSessionRequest,
    _append_pipeline_log,
    _cache,
    _get_base_image_png,
    _get_cached,
    _load_session_to_cache,
)
from ocr_pipeline_overlays import render_overlay
from ocr_pipeline_session_store import (
    create_session_db,
    delete_all_sessions_db,
    delete_session_db,
    get_session_db,
    get_session_image,
    get_sub_sessions,
    list_sessions_db,
    update_session_db,
)

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])


# ---------------------------------------------------------------------------
# Session Management Endpoints
# ---------------------------------------------------------------------------

@router.get("/sessions")
async def list_sessions(include_sub_sessions: bool = False):
    """List OCR pipeline sessions.

    Sub-sessions (box regions) are only included when the caller passes
    ``?include_sub_sessions=true``.
    """
    return {"sessions": await list_sessions_db(include_sub_sessions=include_sub_sessions)}


@router.post("/sessions")
async def create_session(
    file: UploadFile = File(...),
    name: Optional[str] = Form(None),
):
    """Create a pipeline session from an uploaded PDF or image.

    The upload is rendered to a BGR image (first page at zoom 3.0 for PDFs),
    stored as PNG in the database, and kept in the in-memory cache for
    immediate processing. The session name defaults to the file name.

    Raises:
        HTTPException 400: the file could not be rendered.
        HTTPException 500: PNG encoding failed.
    """
    file_data = await file.read()
    filename = file.filename or "upload"
    content_type = file.content_type or ""

    session_id = str(uuid.uuid4())
    looks_like_pdf = content_type == "application/pdf" or filename.lower().endswith(".pdf")

    try:
        img_bgr = (
            render_pdf_high_res(file_data, page_number=0, zoom=3.0)
            if looks_like_pdf
            else render_image_high_res(file_data)
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Could not process file: {e}")

    # Encode original as PNG bytes
    encoded_ok, png_buf = cv2.imencode(".png", img_bgr)
    if not encoded_ok:
        raise HTTPException(status_code=500, detail="Failed to encode image")

    session_name = name or filename

    # Persist to DB
    await create_session_db(
        session_id=session_id,
        name=session_name,
        filename=filename,
        original_png=png_buf.tobytes(),
    )

    # Cache BGR array for immediate processing; later pipeline stages start empty
    entry = {
        "id": session_id,
        "filename": filename,
        "name": session_name,
        "original_bgr": img_bgr,
    }
    entry.update(dict.fromkeys((
        "oriented_bgr",
        "cropped_bgr",
        "deskewed_bgr",
        "dewarped_bgr",
        "orientation_result",
        "crop_result",
        "deskew_result",
        "dewarp_result",
    )))
    entry["ground_truth"] = {}
    entry["current_step"] = 1
    _cache[session_id] = entry

    img_height, img_width = img_bgr.shape[0], img_bgr.shape[1]
    logger.info(f"OCR Pipeline: created session {session_id} from {filename} "
                f"({img_width}x{img_height})")

    return {
        "session_id": session_id,
        "filename": filename,
        "name": session_name,
        "image_width": img_width,
        "image_height": img_height,
        "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
    }
@router.get("/sessions/{session_id}")
async def get_session_info(session_id: str):
    """Get session info including per-step results for step navigation.

    Image dimensions are read from the stored original PNG and reported as
    0x0 when the image is missing or cannot be decoded. Step results are
    only included when present on the session.

    Raises:
        HTTPException 404: session not found.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # Get image dimensions from original PNG. The fallback is applied to the
    # whole (w, h) pair; binding the conditional to only the second tuple
    # element would dereference img.shape on a failed decode.
    img_w, img_h = 0, 0
    original_png = await get_session_image(session_id, "original")
    if original_png:
        img = cv2.imdecode(np.frombuffer(original_png, dtype=np.uint8), cv2.IMREAD_COLOR)
        if img is not None:
            img_h, img_w = img.shape[:2]

    result = {
        "session_id": session["id"],
        "filename": session.get("filename", ""),
        "name": session.get("name", ""),
        "image_width": img_w,
        "image_height": img_h,
        "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
        "current_step": session.get("current_step", 1),
        "document_category": session.get("document_category"),
        "doc_type": session.get("doc_type"),
    }

    # Copy over whichever step results already exist
    for key in (
        "orientation_result",
        "crop_result",
        "deskew_result",
        "dewarp_result",
        "column_result",
        "row_result",
        "word_result",
        "doc_type_result",
    ):
        if session.get(key):
            result[key] = session[key]

    # Sub-session info
    if session.get("parent_session_id"):
        result["parent_session_id"] = session["parent_session_id"]
        result["box_index"] = session.get("box_index")
    else:
        # Check for sub-sessions
        subs = await get_sub_sessions(session_id)
        if subs:
            result["sub_sessions"] = [
                {"id": s["id"], "name": s.get("name"), "box_index": s.get("box_index")}
                for s in subs
            ]

    return result


@router.put("/sessions/{session_id}")
async def update_session(session_id: str, req: UpdateSessionRequest):
    """Update session name and/or document category.

    Raises:
        HTTPException 400: invalid category, or neither field was supplied.
        HTTPException 404: session not found.
    """
    kwargs: Dict[str, Any] = {}
    if req.name is not None:
        kwargs["name"] = req.name
    if req.document_category is not None:
        if req.document_category not in VALID_DOCUMENT_CATEGORIES:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid category '{req.document_category}'. Valid: {sorted(VALID_DOCUMENT_CATEGORIES)}",
            )
        kwargs["document_category"] = req.document_category
    if not kwargs:
        raise HTTPException(status_code=400, detail="Nothing to update")
    updated = await update_session_db(session_id, **kwargs)
    if not updated:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    return {"session_id": session_id, **kwargs}


@router.delete("/sessions/{session_id}")
async def delete_session(session_id: str):
    """Delete a session from the cache and the database.

    Raises:
        HTTPException 404: the database reports nothing was deleted.
    """
    _cache.pop(session_id, None)
    deleted = await delete_session_db(session_id)
    if not deleted:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    return {"session_id": session_id, "deleted": True}


@router.delete("/sessions")
async def delete_all_sessions():
    """Delete ALL sessions (cleanup) and clear the in-memory cache."""
    _cache.clear()
    count = await delete_all_sessions_db()
    return {"deleted_count": count}


@router.post("/sessions/{session_id}/create-box-sessions")
async def create_box_sessions(session_id: str):
    """Create sub-sessions for each detected box region.

    Crops every box zone of the column result (with 5 px padding) from the
    cropped image — or the dewarped one when no cropped image exists — and
    creates an independent sub-session per crop. If sub-sessions already
    exist they are returned unchanged.

    Raises:
        HTTPException 404: session not found.
        HTTPException 400: column detection missing, or no base image.
        HTTPException 500: base image could not be decoded.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    column_result = session.get("column_result")
    if not column_result:
        raise HTTPException(status_code=400, detail="Column detection must be completed first")

    zones = column_result.get("zones") or []
    box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
    if not box_zones:
        return {"session_id": session_id, "sub_sessions": [], "message": "No boxes detected"}

    # Check for existing sub-sessions
    existing = await get_sub_sessions(session_id)
    if existing:
        return {
            "session_id": session_id,
            "sub_sessions": [{"id": s["id"], "box_index": s.get("box_index")} for s in existing],
            "message": f"{len(existing)} sub-session(s) already exist",
        }

    # Load base image
    base_png = await get_session_image(session_id, "cropped")
    if not base_png:
        base_png = await get_session_image(session_id, "dewarped")
    if not base_png:
        raise HTTPException(status_code=400, detail="No base image available")

    img = cv2.imdecode(np.frombuffer(base_png, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    parent_name = session.get("name", "Session")
    sub_filename = session.get("filename", "box-crop.png")
    created = []

    for i, zone in enumerate(box_zones):
        box = zone["box"]
        bx, by = box["x"], box["y"]
        bw, bh = box["width"], box["height"]

        # Crop box region with small padding, clamped to the image
        pad = 5
        y1 = max(0, by - pad)
        y2 = min(img.shape[0], by + bh + pad)
        x1 = max(0, bx - pad)
        x2 = min(img.shape[1], bx + bw + pad)
        crop = img[y1:y2, x1:x2]

        # A box lying outside the image yields an empty crop, which cannot be
        # encoded; skip it like any other encoding failure.
        if crop.size == 0:
            logger.warning(f"Failed to encode box {i} crop for session {session_id}")
            continue

        # Encode as PNG
        success, png_buf = cv2.imencode(".png", crop)
        if not success:
            logger.warning(f"Failed to encode box {i} crop for session {session_id}")
            continue

        sub_id = str(uuid.uuid4())
        sub_name = f"{parent_name} — Box {i + 1}"

        await create_session_db(
            session_id=sub_id,
            name=sub_name,
            filename=sub_filename,
            original_png=png_buf.tobytes(),
            parent_session_id=session_id,
            box_index=i,
        )

        # Cache the BGR for immediate processing
        # Promote original to cropped so column/row/word detection finds it
        box_bgr = crop.copy()
        _cache[sub_id] = {
            "id": sub_id,
            "filename": sub_filename,
            "name": sub_name,
            "parent_session_id": session_id,
            "original_bgr": box_bgr,
            "oriented_bgr": None,
            "cropped_bgr": box_bgr,
            "deskewed_bgr": None,
            "dewarped_bgr": None,
            "orientation_result": None,
            "crop_result": None,
            "deskew_result": None,
            "dewarp_result": None,
            "ground_truth": {},
            "current_step": 1,
        }

        created.append({
            "id": sub_id,
            "name": sub_name,
            "box_index": i,
            "box": box,
            "image_width": crop.shape[1],
            "image_height": crop.shape[0],
        })

        logger.info(f"Created box sub-session {sub_id} for session {session_id} "
                    f"(box {i}, {crop.shape[1]}x{crop.shape[0]})")

    return {
        "session_id": session_id,
        "sub_sessions": created,
        "total": len(created),
    }


@router.get("/sessions/{session_id}/thumbnail")
async def get_session_thumbnail(session_id: str, size: int = Query(default=80, ge=16, le=400)):
    """Return a small PNG thumbnail of the original image.

    The image is scaled so its longer side equals *size*; the shorter side
    is kept at a minimum of 1 px so very elongated images still resize.

    Raises:
        HTTPException 404: no original image stored for the session.
        HTTPException 500: image could not be decoded or encoded.
    """
    original_png = await get_session_image(session_id, "original")
    if not original_png:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found or no image")
    img = cv2.imdecode(np.frombuffer(original_png, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")
    h, w = img.shape[:2]
    scale = size / max(h, w)
    # int() truncation can produce 0 for extreme aspect ratios, which
    # cv2.resize rejects.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    thumb = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
    encoded_ok, png_bytes = cv2.imencode(".png", thumb)
    if not encoded_ok:
        raise HTTPException(status_code=500, detail="Failed to encode image")
    return Response(content=png_bytes.tobytes(), media_type="image/png",
                    headers={"Cache-Control": "public, max-age=3600"})
@router.get("/sessions/{session_id}/pipeline-log")
async def get_pipeline_log(session_id: str):
    """Return the pipeline execution log stored for a session.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    # Sessions without a log yet get an empty step list.
    return {"session_id": session_id, "pipeline_log": session.get("pipeline_log") or {"steps": []}}


@router.get("/categories")
async def list_categories():
    """List valid document categories (sorted)."""
    return {"categories": sorted(VALID_DOCUMENT_CATEGORIES)}


# ---------------------------------------------------------------------------
# Image Endpoints
# ---------------------------------------------------------------------------

@router.get("/sessions/{session_id}/image/{image_type}")
async def get_image(session_id: str, image_type: str):
    """Serve a session image as PNG.

    Accepted ``image_type`` values: original, oriented, cropped, deskewed,
    dewarped, binarized, clean, structure-overlay, columns-overlay,
    rows-overlay, words-overlay.

    Raises:
        HTTPException: 400 for an unknown image type, 404 if the image is
            not available.
    """
    valid_types = {"original", "oriented", "cropped", "deskewed", "dewarped", "binarized", "structure-overlay", "columns-overlay", "rows-overlay", "words-overlay", "clean"}
    if image_type not in valid_types:
        raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")

    # Overlays are rendered on demand; map the URL name to the overlay kind.
    overlay_kinds = {
        "structure-overlay": "structure",
        "columns-overlay": "columns",
        "rows-overlay": "rows",
        "words-overlay": "words",
    }
    if image_type in overlay_kinds:
        return await render_overlay(overlay_kinds[image_type], session_id)

    # Try cache first for fast serving (only the binarized PNG is cached as bytes)
    cached = _cache.get(session_id)
    if cached and image_type == "binarized" and cached.get("binarized_png"):
        return Response(content=cached["binarized_png"], media_type="image/png")

    # Load from DB — for cropped/dewarped, fall back through the chain
    if image_type in ("cropped", "dewarped"):
        data = await _get_base_image_png(session_id)
    else:
        data = await get_session_image(session_id, image_type)
    if not data:
        raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet")

    return Response(content=data, media_type="image/png")


# ---------------------------------------------------------------------------
# Document Type Detection (between Dewarp and Columns)
# ---------------------------------------------------------------------------

@router.post("/sessions/{session_id}/detect-type")
async def detect_type(session_id: str):
    """Detect document type (vocab_table, full_text, generic_table).

    Should be called after crop (clean image available).
    Falls back to dewarped if crop was skipped.
    Stores result in session for frontend to decide pipeline flow.

    Raises:
        HTTPException: 400 if neither a cropped nor a dewarped image exists.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    # Explicit None check: the cached images are numpy arrays, whose truth
    # value is ambiguous.
    img_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first")

    t0 = time.time()
    ocr_img = create_ocr_image(img_bgr)
    result = detect_document_type(ocr_img, img_bgr)
    duration = time.time() - t0

    result_dict = {
        "doc_type": result.doc_type,
        "confidence": result.confidence,
        "pipeline": result.pipeline,
        "skip_steps": result.skip_steps,
        "features": result.features,
        "duration_seconds": round(duration, 2),
    }

    # Persist to DB
    await update_session_db(
        session_id,
        doc_type=result.doc_type,
        doc_type_result=result_dict,
    )

    cached["doc_type_result"] = result_dict

    logger.info(f"OCR Pipeline: detect-type session {session_id}: "
                f"{result.doc_type} (confidence={result.confidence}, {duration:.2f}s)")

    # Only scalar features are copied into the pipeline log entry.
    await _append_pipeline_log(session_id, "detect_type", {
        "doc_type": result.doc_type,
        "pipeline": result.pipeline,
        "confidence": result.confidence,
        **{k: v for k, v in (result.features or {}).items() if isinstance(v, (int, float, str, bool))},
    }, duration_ms=int(duration * 1000))

    return {"session_id": session_id, **result_dict}


# ---------------------------------------------------------------------------
# diff --git a/klausur-service/backend/ocr_pipeline_words.py b/klausur-service/backend/ocr_pipeline_words.py
# new file mode 100644
# index 0000000..c3eb7d6
# --- /dev/null
# +++ b/klausur-service/backend/ocr_pipeline_words.py
# @@ -0,0 +1,876 @@
# ---------------------------------------------------------------------------
"""
OCR Pipeline Words - Word detection and ground truth endpoints.

Extracted from ocr_pipeline_api.py.
Handles:
- POST /sessions/{session_id}/words — main SSE streaming word detection
- POST /sessions/{session_id}/paddle-direct — PaddleOCR direct endpoint
- POST /sessions/{session_id}/ground-truth/words — save ground truth
- GET /sessions/{session_id}/ground-truth/words — get ground truth

License: Apache 2.0
PRIVACY: all processing happens locally.
"""

import json
import logging
import time
from datetime import datetime
from typing import Any, Dict, List, Optional

import cv2
import numpy as np
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from cv_vocab_pipeline import (
    PageRegion,
    RowGeometry,
    _cells_to_vocab_entries,
    _fix_character_confusion,
    _fix_phonetic_brackets,
    fix_cell_phonetics,
    build_cell_grid_v2,
    build_cell_grid_v2_streaming,
    create_ocr_image,
    detect_column_geometry,
)
from cv_words_first import build_grid_from_words
from ocr_pipeline_session_store import (
    get_session_db,
    get_session_image,
    update_session_db,
)
from ocr_pipeline_common import (
    _cache,
    _load_session_to_cache,
    _get_cached,
    _get_base_image_png,
    _append_pipeline_log,
)

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])


# ---------------------------------------------------------------------------
# Pydantic models
# ---------------------------------------------------------------------------

class WordGroundTruthRequest(BaseModel):
    """Request body for saving word-step ground truth feedback."""

    is_correct: bool
    corrected_entries: Optional[List[Dict[str, Any]]] = None
    notes: Optional[str] = None


# ---------------------------------------------------------------------------
# Word Detection Endpoint (Step 7)
# ---------------------------------------------------------------------------

@router.post("/sessions/{session_id}/words")
async def detect_words(
    session_id: str,
    request: Request,
    engine: str = "auto",
    pronunciation: str = "british",
    stream: bool = False,
    skip_heal_gaps: bool = False,
    grid_method: str = "v2",
):
    """Build word grid from columns × rows, OCR each cell.

    Query params:
        engine: 'auto' (default), 'tesseract', 'rapid', or 'paddle'
        pronunciation: 'british' (default) or 'american' — for IPA dictionary lookup
        stream: false (default) for JSON response, true for SSE streaming
        skip_heal_gaps: false (default). When true, cells keep exact row geometry
            positions without gap-healing expansion. Better for overlay rendering.
        grid_method: 'v2' (default) or 'words_first' — grid construction strategy.
            'v2' uses pre-detected columns/rows (top-down).
            'words_first' clusters words bottom-up (no column/row detection needed).

    Raises:
        HTTPException: 400 if no cropped/dewarped image exists, if rows are
            missing for grid_method 'v2', or if words_first finds no words;
            404 if the session is unknown.
    """
    # PaddleOCR is full-page remote OCR → force words_first grid method
    if engine == "paddle" and grid_method != "words_first":
        logger.info("detect_words: engine=paddle requires words_first, overriding grid_method=%s", grid_method)
        grid_method = "words_first"

    if session_id not in _cache:
        logger.info("detect_words: session %s not in cache, loading from DB", session_id)
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
    if dewarped_bgr is None:
        logger.warning("detect_words: no cropped/dewarped image for session %s (cache keys: %s)",
                       session_id, [k for k in cached.keys() if k.endswith('_bgr')])
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before word detection")

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    column_result = session.get("column_result")
    row_result = session.get("row_result")
    if not column_result or not column_result.get("columns"):
        # No column detection — synthesize a single full-page pseudo-column.
        # This enables the overlay pipeline which skips column detection.
        img_h_tmp, img_w_tmp = dewarped_bgr.shape[:2]
        column_result = {
            "columns": [{
                "type": "column_text",
                "x": 0, "y": 0,
                "width": img_w_tmp, "height": img_h_tmp,
                "classification_confidence": 1.0,
                "classification_method": "full_page_fallback",
            }],
            "zones": [],
            "duration_seconds": 0,
        }
        logger.info("detect_words: no column_result — using full-page pseudo-column %dx%d", img_w_tmp, img_h_tmp)
    if grid_method != "words_first" and (not row_result or not row_result.get("rows")):
        raise HTTPException(status_code=400, detail="Row detection must be completed first")

    # Convert column dicts back to PageRegion objects
    col_regions = [
        PageRegion(
            type=c["type"],
            x=c["x"], y=c["y"],
            width=c["width"], height=c["height"],
            classification_confidence=c.get("classification_confidence", 1.0),
            classification_method=c.get("classification_method", ""),
        )
        for c in column_result["columns"]
    ]

    # Convert row dicts back to RowGeometry objects.
    # row_result may legitimately be missing on the words_first path (the
    # guard above only enforces rows for 'v2'), so index it None-safely
    # instead of crashing with a TypeError.
    row_dicts = (row_result or {}).get("rows") or []
    row_geoms = [
        RowGeometry(
            index=r["index"],
            x=r["x"], y=r["y"],
            width=r["width"], height=r["height"],
            word_count=r.get("word_count", 0),
            words=[],
            row_type=r.get("row_type", "content"),
            gap_before=r.get("gap_before", 0),
        )
        for r in row_dicts
    ]

    # Cell-First OCR (v2): no full-page word re-population needed.
    # Each cell is cropped and OCR'd in isolation → no neighbour bleeding.
    # We still need word_count > 0 for row filtering in build_cell_grid_v2,
    # so populate from cached words if available (just for counting).
    word_dicts = cached.get("_word_dicts")
    if word_dicts is None:
        ocr_img_tmp = create_ocr_image(dewarped_bgr)
        geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr)
        if geo_result is not None:
            _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
            cached["_word_dicts"] = word_dicts
            cached["_inv"] = inv
            cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)

    if word_dicts:
        content_bounds = cached.get("_content_bounds")
        if content_bounds:
            _lx, _rx, top_y, _by = content_bounds
        else:
            top_y = min(r.y for r in row_geoms) if row_geoms else 0

        for row in row_geoms:
            row_y_rel = row.y - top_y
            row_bottom_rel = row_y_rel + row.height
            # A word belongs to the row that contains its vertical centre.
            row.words = [
                w for w in word_dicts
                if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
            ]
            row.word_count = len(row.words)

    # Exclude rows that fall within box zones.
    # Use inner box range (shrunk by border_thickness) so that rows at
    # the boundary (overlapping with the box border) are NOT excluded.
    zones = column_result.get("zones") or []
    box_ranges_inner = []
    for zone in zones:
        if zone.get("zone_type") == "box" and zone.get("box"):
            box = zone["box"]
            bt = max(box.get("border_thickness", 0), 5)  # minimum 5px margin
            box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt))

    if box_ranges_inner:
        def _row_in_box(r):
            center_y = r.y + r.height / 2
            return any(by_s <= center_y < by_e for by_s, by_e in box_ranges_inner)

        before_count = len(row_geoms)
        row_geoms = [r for r in row_geoms if not _row_in_box(r)]
        excluded = before_count - len(row_geoms)
        if excluded:
            logger.info(f"detect_words: excluded {excluded} rows inside box zones")

    # --- Words-First path: bottom-up grid from word boxes ---
    if grid_method == "words_first":
        t0 = time.time()
        img_h, img_w = dewarped_bgr.shape[:2]

        # For paddle engine: run remote PaddleOCR full-page instead of Tesseract
        if engine == "paddle":
            from cv_ocr_engines import ocr_region_paddle

            wf_word_dicts = await ocr_region_paddle(dewarped_bgr, region=None)
            # PaddleOCR returns absolute coordinates, no content_bounds offset needed
            cached["_paddle_word_dicts"] = wf_word_dicts
        else:
            # Get word_dicts from cache or run Tesseract full-page
            wf_word_dicts = cached.get("_word_dicts")
            if wf_word_dicts is None:
                ocr_img_tmp = create_ocr_image(dewarped_bgr)
                geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr)
                if geo_result is not None:
                    _geoms, left_x, right_x, top_y, bottom_y, wf_word_dicts, inv = geo_result
                    cached["_word_dicts"] = wf_word_dicts
                    cached["_inv"] = inv
                    cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)

        if not wf_word_dicts:
            raise HTTPException(status_code=400, detail="No words detected — cannot build words-first grid")

        # Convert word coordinates to absolute image coordinates if needed
        # (detect_column_geometry returns words relative to content ROI)
        # PaddleOCR already returns absolute coordinates — skip offset.
        if engine != "paddle":
            content_bounds = cached.get("_content_bounds")
            if content_bounds:
                lx, _rx, ty, _by = content_bounds
                wf_word_dicts = [
                    {**w, 'left': w['left'] + lx, 'top': w['top'] + ty}
                    for w in wf_word_dicts
                ]

        # Extract box rects for box-aware column clustering
        box_rects = [
            zone["box"] for zone in zones
            if zone.get("zone_type") == "box" and zone.get("box")
        ]

        cells, columns_meta = build_grid_from_words(
            wf_word_dicts, img_w, img_h, box_rects=box_rects or None,
        )
        duration = time.time() - t0

        # Apply IPA phonetic fixes
        fix_cell_phonetics(cells, pronunciation=pronunciation)

        # Add zone_index for backward compat
        for cell in cells:
            cell.setdefault("zone_index", 0)

        col_types = {c['type'] for c in columns_meta}
        is_vocab = bool(col_types & {'column_en', 'column_de'})
        n_rows = len(set(c['row_index'] for c in cells)) if cells else 0
        n_cols = len(columns_meta)
        used_engine = "paddle" if engine == "paddle" else "words_first"

        word_result = {
            "cells": cells,
            "grid_shape": {
                "rows": n_rows,
                "cols": n_cols,
                "total_cells": len(cells),
            },
            "columns_used": columns_meta,
            "layout": "vocab" if is_vocab else "generic",
            "image_width": img_w,
            "image_height": img_h,
            "duration_seconds": round(duration, 2),
            "ocr_engine": used_engine,
            "grid_method": "words_first",
            "summary": {
                "total_cells": len(cells),
                "non_empty_cells": sum(1 for c in cells if c.get("text")),
                "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
            },
        }

        if is_vocab or 'column_text' in col_types:
            entries = _cells_to_vocab_entries(cells, columns_meta)
            entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
            word_result["vocab_entries"] = entries
            word_result["entries"] = entries
            word_result["entry_count"] = len(entries)
            word_result["summary"]["total_entries"] = len(entries)
            word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
            word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))

        await update_session_db(session_id, word_result=word_result, current_step=8)
        cached["word_result"] = word_result

        logger.info(f"OCR Pipeline: words-first session {session_id}: "
                    f"{len(cells)} cells ({duration:.2f}s), {n_rows} rows, {n_cols} cols")

        await _append_pipeline_log(session_id, "words", {
            "grid_method": "words_first",
            "total_cells": len(cells),
            "non_empty_cells": word_result["summary"]["non_empty_cells"],
            "ocr_engine": used_engine,
            "layout": word_result["layout"],
        }, duration_ms=int(duration * 1000))

        return {"session_id": session_id, **word_result}

    if stream:
        # Cell-First OCR v2: use batch-then-stream approach instead of
        # per-cell streaming. The parallel ThreadPoolExecutor in
        # build_cell_grid_v2 is much faster than sequential streaming.
        return StreamingResponse(
            _word_batch_stream_generator(
                session_id, cached, col_regions, row_geoms,
                dewarped_bgr, engine, pronunciation, request,
                skip_heal_gaps=skip_heal_gaps,
            ),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    # --- Non-streaming path (grid_method=v2) ---
    t0 = time.time()

    # Create binarized OCR image (for Tesseract)
    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    # Build cell grid using Cell-First OCR (v2) — each cell cropped in isolation
    cells, columns_meta = build_cell_grid_v2(
        ocr_img, col_regions, row_geoms, img_w, img_h,
        ocr_engine=engine, img_bgr=dewarped_bgr,
        skip_heal_gaps=skip_heal_gaps,
    )
    duration = time.time() - t0

    # Add zone_index to each cell (default 0 for backward compatibility)
    for cell in cells:
        cell.setdefault("zone_index", 0)

    # Layout detection
    col_types = {c['type'] for c in columns_meta}
    is_vocab = bool(col_types & {'column_en', 'column_de'})

    # Count content rows and columns for grid_shape
    n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
    n_cols = len(columns_meta)

    # Determine which engine was actually used
    used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine

    # Apply IPA phonetic fixes directly to cell texts (for overlay mode)
    fix_cell_phonetics(cells, pronunciation=pronunciation)

    # Grid result (always generic)
    word_result = {
        "cells": cells,
        "grid_shape": {
            "rows": n_content_rows,
            "cols": n_cols,
            "total_cells": len(cells),
        },
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
        },
    }

    # For vocab layout or single-column (box sub-sessions): map cells 1:1
    # to vocab entries (row→entry).
    has_text_col = 'column_text' in col_types
    if is_vocab or has_text_col:
        entries = _cells_to_vocab_entries(cells, columns_meta)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        word_result["summary"]["total_entries"] = len(entries)
        word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
        word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))

    # Persist to DB
    await update_session_db(
        session_id,
        word_result=word_result,
        current_step=8,
    )

    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline: words session {session_id}: "
                f"layout={word_result['layout']}, "
                f"{len(cells)} cells ({duration:.2f}s), summary: {word_result['summary']}")

    await _append_pipeline_log(session_id, "words", {
        "total_cells": len(cells),
        "non_empty_cells": word_result["summary"]["non_empty_cells"],
        "low_confidence_count": word_result["summary"]["low_confidence"],
        "ocr_engine": used_engine,
        "layout": word_result["layout"],
        "entry_count": word_result.get("entry_count", 0),
    }, duration_ms=int(duration * 1000))

    return {
        "session_id": session_id,
        **word_result,
    }


async def _word_batch_stream_generator(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    request: Request,
    skip_heal_gaps: bool = False,
):
    """SSE generator that runs batch OCR (parallel) then streams results.

    Unlike the old per-cell streaming, this uses build_cell_grid_v2 with
    ThreadPoolExecutor for parallel OCR, then emits all cells as SSE events.
    The 'preparing' event keeps the connection alive during OCR processing.

    Yields SSE ``data:`` frames of type meta, preparing, keepalive, columns,
    cell and complete. Stops early (without persisting) if the client
    disconnects.
    """
    import asyncio

    t0 = time.time()
    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
    n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
    n_cols = len([c for c in col_regions if c.type not in _skip_types])
    col_types = {c.type for c in col_regions if c.type not in _skip_types}
    is_vocab = bool(col_types & {'column_en', 'column_de'})
    total_cells = n_content_rows * n_cols

    # 1. Send meta event immediately
    meta_event = {
        "type": "meta",
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
        "layout": "vocab" if is_vocab else "generic",
    }
    yield f"data: {json.dumps(meta_event)}\n\n"

    # 2. Send preparing event (keepalive for proxy)
    yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR laeuft parallel...'})}\n\n"

    # 3. Run batch OCR in thread pool with periodic keepalive events.
    # The OCR takes 30-60s and proxy servers (Nginx) may drop idle SSE
    # connections after 30-60s. Send keepalive every 5s to prevent this.
    # We are inside a coroutine, so the running loop is the right handle.
    loop = asyncio.get_running_loop()
    ocr_future = loop.run_in_executor(
        None,
        lambda: build_cell_grid_v2(
            ocr_img, col_regions, row_geoms, img_w, img_h,
            ocr_engine=engine, img_bgr=dewarped_bgr,
            skip_heal_gaps=skip_heal_gaps,
        ),
    )

    # Send keepalive events every 5 seconds while OCR runs
    while not ocr_future.done():
        try:
            # shield() keeps the timeout from cancelling the OCR future itself
            cells, columns_meta = await asyncio.wait_for(
                asyncio.shield(ocr_future), timeout=5.0,
            )
            break  # OCR finished
        except asyncio.TimeoutError:
            elapsed = int(time.time() - t0)
            yield f"data: {json.dumps({'type': 'keepalive', 'elapsed': elapsed, 'message': f'OCR laeuft... ({elapsed}s)'})}\n\n"
            if await request.is_disconnected():
                logger.info(f"SSE batch: client disconnected during OCR for {session_id}")
                ocr_future.cancel()
                return
    else:
        # Loop ended without break: the future completed between polls.
        cells, columns_meta = ocr_future.result()

    if await request.is_disconnected():
        logger.info(f"SSE batch: client disconnected after OCR for {session_id}")
        return

    # 4. Apply IPA phonetic fixes directly to cell texts (for overlay mode)
    fix_cell_phonetics(cells, pronunciation=pronunciation)

    # 5. Send columns meta
    if columns_meta:
        yield f"data: {json.dumps({'type': 'columns', 'columns_used': columns_meta})}\n\n"

    # 6. Stream all cells
    for idx, cell in enumerate(cells):
        cell_event = {
            "type": "cell",
            "cell": cell,
            "progress": {"current": idx + 1, "total": len(cells)},
        }
        yield f"data: {json.dumps(cell_event)}\n\n"

    # 7. Build final result and persist
    duration = time.time() - t0
    used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine

    word_result = {
        "cells": cells,
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
        },
    }

    vocab_entries = None
    has_text_col = 'column_text' in col_types
    if is_vocab or has_text_col:
        entries = _cells_to_vocab_entries(cells, columns_meta)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        word_result["summary"]["total_entries"] = len(entries)
        word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
        word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
        vocab_entries = entries

    await update_session_db(session_id, word_result=word_result, current_step=8)
    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline SSE batch: words session {session_id}: "
                f"layout={word_result['layout']}, {len(cells)} cells ({duration:.2f}s)")

    # 8. Send complete event
    complete_event = {
        "type": "complete",
        "summary": word_result["summary"],
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
    }
    if vocab_entries is not None:
        complete_event["vocab_entries"] = vocab_entries
    yield f"data: {json.dumps(complete_event)}\n\n"


async def _word_stream_generator(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    request: Request,
):
    """SSE generator that yields cell-by-cell OCR progress.

    Yields SSE ``data:`` frames of type meta, preparing, columns, cell and
    complete. Stops early (without persisting) if the client disconnects.
    """
    t0 = time.time()

    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    # Compute grid shape upfront for the meta event
    n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
    _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
    n_cols = len([c for c in col_regions if c.type not in _skip_types])

    # Determine layout
    col_types = {c.type for c in col_regions if c.type not in _skip_types}
    is_vocab = bool(col_types & {'column_en', 'column_de'})

    # Start streaming — first event: meta
    columns_meta = None  # will be set from first yield
    total_cells = n_content_rows * n_cols

    meta_event = {
        "type": "meta",
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
        "layout": "vocab" if is_vocab else "generic",
    }
    yield f"data: {json.dumps(meta_event)}\n\n"

    # Keepalive: send preparing event so proxy doesn't timeout during OCR init
    yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR wird initialisiert...'})}\n\n"

    # Stream cells one by one
    all_cells: List[Dict[str, Any]] = []
    cell_idx = 0

    for cell, cols_meta, total in build_cell_grid_v2_streaming(
        ocr_img, col_regions, row_geoms, img_w, img_h,
        ocr_engine=engine, img_bgr=dewarped_bgr,
    ):
        if await request.is_disconnected():
            logger.info(f"SSE: client disconnected during streaming for {session_id}")
            return

        if columns_meta is None:
            columns_meta = cols_meta
            # Send columns_used as part of first cell or update meta
            meta_update = {
                "type": "columns",
                "columns_used": cols_meta,
            }
            yield f"data: {json.dumps(meta_update)}\n\n"

        all_cells.append(cell)
        cell_idx += 1

        cell_event = {
            "type": "cell",
            "cell": cell,
            "progress": {"current": cell_idx, "total": total},
        }
        yield f"data: {json.dumps(cell_event)}\n\n"

    # All cells done — build final result
    duration = time.time() - t0
    if columns_meta is None:
        columns_meta = []

    # Post-OCR: remove rows where ALL cells are empty (inter-row gaps
    # that had stray Tesseract artifacts giving word_count > 0).
    rows_with_text: set = set()
    for c in all_cells:
        if c.get("text", "").strip():
            rows_with_text.add(c["row_index"])
    before_filter = len(all_cells)
    all_cells = [c for c in all_cells if c["row_index"] in rows_with_text]
    empty_rows_removed = (before_filter - len(all_cells)) // max(n_cols, 1)
    if empty_rows_removed > 0:
        logger.info(f"SSE: removed {empty_rows_removed} all-empty rows after OCR")

    used_engine = all_cells[0].get("ocr_engine", "tesseract") if all_cells else engine

    # Apply IPA phonetic fixes directly to cell texts (for overlay mode)
    fix_cell_phonetics(all_cells, pronunciation=pronunciation)

    word_result = {
        "cells": all_cells,
        "grid_shape": {
            "rows": n_content_rows,
            "cols": n_cols,
            "total_cells": len(all_cells),
        },
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": {
            "total_cells": len(all_cells),
            "non_empty_cells": sum(1 for c in all_cells if c.get("text")),
            "low_confidence": sum(1 for c in all_cells if 0 < c.get("confidence", 0) < 50),
        },
    }

    # For vocab layout or single-column (box sub-sessions): map cells 1:1
    # to vocab entries (row→entry).
    vocab_entries = None
    has_text_col = 'column_text' in col_types
    if is_vocab or has_text_col:
        entries = _cells_to_vocab_entries(all_cells, columns_meta)
        entries = _fix_character_confusion(entries)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        word_result["summary"]["total_entries"] = len(entries)
        word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
        word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
        vocab_entries = entries

    # Persist to DB
    await update_session_db(
        session_id,
        word_result=word_result,
        current_step=8,
    )
    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline SSE: words session {session_id}: "
                f"layout={word_result['layout']}, "
                f"{len(all_cells)} cells ({duration:.2f}s)")

    # Final complete event
    complete_event = {
        "type": "complete",
        "summary": word_result["summary"],
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
    }
    if vocab_entries is not None:
        complete_event["vocab_entries"] = vocab_entries
    yield f"data: {json.dumps(complete_event)}\n\n"


# ---------------------------------------------------------------------------
# PaddleOCR Direct Endpoint
# ---------------------------------------------------------------------------

@router.post("/sessions/{session_id}/paddle-direct")
async def paddle_direct(session_id: str):
    """Run PaddleOCR on the preprocessed image and build a word grid directly.

    Expects orientation/deskew/dewarp/crop to be done already.
    Uses the cropped image (falls back to dewarped, then original).
    The used image is stored as cropped_png so OverlayReconstruction
    can display it as the background.

    Raises:
        HTTPException: 404 if the session has no image; 400 if the image
            cannot be decoded or PaddleOCR returns no words.
    """
    # Try preprocessed images first (crop > dewarp > original)
    img_png = await get_session_image(session_id, "cropped")
    if not img_png:
        img_png = await get_session_image(session_id, "dewarped")
    if not img_png:
        img_png = await get_session_image(session_id, "original")
    if not img_png:
        raise HTTPException(status_code=404, detail="No image found for this session")

    img_arr = np.frombuffer(img_png, dtype=np.uint8)
    img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Failed to decode original image")

    img_h, img_w = img_bgr.shape[:2]

    from cv_ocr_engines import ocr_region_paddle

    t0 = time.time()
    word_dicts = await ocr_region_paddle(img_bgr, region=None)
    if not word_dicts:
        raise HTTPException(status_code=400, detail="PaddleOCR returned no words")

    # Reuse build_grid_from_words — same function that works in the regular
    # pipeline with PaddleOCR (engine=paddle, grid_method=words_first).
    # Handles phrase splitting, column clustering, and reading order.
    cells, columns_meta = build_grid_from_words(word_dicts, img_w, img_h)
    duration = time.time() - t0

    # Tag cells as paddle_direct
    for cell in cells:
        cell["ocr_engine"] = "paddle_direct"

    n_rows = len(set(c["row_index"] for c in cells)) if cells else 0
    n_cols = len(columns_meta)
    col_types = {c.get("type") for c in columns_meta}
    is_vocab = bool(col_types & {"column_en", "column_de"})

    word_result = {
        "cells": cells,
        "grid_shape": {
            "rows": n_rows,
            "cols": n_cols,
            "total_cells": len(cells),
        },
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": "paddle_direct",
        "grid_method": "paddle_direct",
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
        },
    }

    # Store preprocessed image as cropped_png so OverlayReconstruction shows it
    await update_session_db(
        session_id,
        word_result=word_result,
        cropped_png=img_png,
        current_step=8,
    )

    # Keep the in-memory cache in step with the DB, as the other word
    # endpoints do, so cached readers do not see a stale word_result.
    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    logger.info(
        "paddle_direct session %s: %d cells (%d rows, %d cols) in %.2fs",
        session_id, len(cells), n_rows, n_cols, duration,
    )

    await _append_pipeline_log(session_id, "paddle_direct", {
        "total_cells": len(cells),
        "non_empty_cells": word_result["summary"]["non_empty_cells"],
        "ocr_engine": "paddle_direct",
    }, duration_ms=int(duration * 1000))

    return {"session_id": session_id, **word_result}


# ---------------------------------------------------------------------------
# Ground Truth Words Endpoints
# ---------------------------------------------------------------------------

@router.post("/sessions/{session_id}/ground-truth/words")
async def save_word_ground_truth(session_id: str, req: WordGroundTruthRequest):
    """Save ground truth feedback for the word recognition step.

    The current word_result is snapshotted next to the feedback.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    gt = {
        "is_correct": req.is_correct,
        "corrected_entries": req.corrected_entries,
        "notes": req.notes,
        # NOTE(review): naive UTC timestamp; utcnow() is deprecated since
        # Python 3.12 — switching to an aware datetime changes the stored format.
        "saved_at": datetime.utcnow().isoformat(),
        "word_result": session.get("word_result"),
    }
    ground_truth["words"] = gt

    await update_session_db(session_id, ground_truth=ground_truth)

    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    return {"session_id": session_id, "ground_truth": gt}


@router.get("/sessions/{session_id}/ground-truth/words")
async def get_word_ground_truth(session_id: str):
    """Retrieve saved ground truth for word recognition.

    Raises:
        HTTPException: 404 if the session or its word ground truth is missing.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    words_gt = ground_truth.get("words")
    if not words_gt:
        raise HTTPException(status_code=404, detail="No word ground truth saved")

    return {
        "session_id": session_id,
        "words_gt": words_gt,
        "words_auto": session.get("word_result"),
    }