""" Shared common module for the OCR pipeline. Contains in-memory cache, helper functions, Pydantic request models, pipeline logging, and border-ghost word filtering used by the pipeline API endpoints and related modules. """ import logging import re import time from datetime import datetime from typing import Any, Dict, List, Optional import cv2 import numpy as np from fastapi import HTTPException from pydantic import BaseModel from ocr_pipeline_session_store import get_session_db, get_session_image, update_session_db __all__ = [ # Cache "_cache", # Helper functions "_get_base_image_png", "_load_session_to_cache", "_get_cached", # Pydantic models "ManualDeskewRequest", "DeskewGroundTruthRequest", "ManualDewarpRequest", "CombinedAdjustRequest", "DewarpGroundTruthRequest", "VALID_DOCUMENT_CATEGORIES", "UpdateSessionRequest", "ManualColumnsRequest", "ColumnGroundTruthRequest", "ManualRowsRequest", "RowGroundTruthRequest", "RemoveHandwritingRequest", # Pipeline log "_append_pipeline_log", # Border-ghost filter "_BORDER_GHOST_CHARS", "_filter_border_ghost_words", ] logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # In-memory cache for active sessions (BGR numpy arrays for processing) # DB is source of truth, cache holds BGR arrays during active processing. 
# ---------------------------------------------------------------------------
# Process-local, unbounded cache: session_id -> session dict plus decoded
# ``*_bgr`` numpy arrays. Entries persist until the process restarts (or the
# session is evicted elsewhere); the DB remains the source of truth.
_cache: Dict[str, Dict[str, Any]] = {}


async def _get_base_image_png(session_id: str) -> Optional[bytes]:
    """Get the best available base image for a session (cropped > dewarped > original).

    Returns the raw PNG bytes of the first image type found, or ``None``
    when the session has none of the three stored.
    """
    for img_type in ("cropped", "dewarped", "original"):
        png_data = await get_session_image(session_id, img_type)
        if png_data:
            return png_data
    return None


async def _load_session_to_cache(session_id: str) -> Dict[str, Any]:
    """Load session from DB into cache, decoding PNGs to BGR arrays.

    Returns the cache entry: the DB session row merged with ``*_bgr`` keys
    holding decoded OpenCV BGR arrays (or ``None`` where no image exists).
    If the session is already cached, the cached entry is returned as-is.

    Raises:
        HTTPException: 404 when the session does not exist in the DB.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    # Re-checked after the await: the entry may have been populated while
    # this coroutine was suspended.
    if session_id in _cache:
        return _cache[session_id]
    cache_entry: Dict[str, Any] = {
        "id": session_id,
        **session,
        "original_bgr": None,
        "oriented_bgr": None,
        "cropped_bgr": None,
        "deskewed_bgr": None,
        "dewarped_bgr": None,
    }
    # Decode images from DB into BGR numpy arrays
    for img_type, bgr_key in [
        ("original", "original_bgr"),
        ("oriented", "oriented_bgr"),
        ("cropped", "cropped_bgr"),
        ("deskewed", "deskewed_bgr"),
        ("dewarped", "dewarped_bgr"),
    ]:
        png_data = await get_session_image(session_id, img_type)
        if png_data:
            arr = np.frombuffer(png_data, dtype=np.uint8)
            bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR)
            cache_entry[bgr_key] = bgr
    # Sub-sessions: original image IS the cropped box region.
    # Promote original_bgr to cropped_bgr so downstream steps find it.
    if session.get("parent_session_id") and cache_entry["original_bgr"] is not None:
        if cache_entry["cropped_bgr"] is None and cache_entry["dewarped_bgr"] is None:
            cache_entry["cropped_bgr"] = cache_entry["original_bgr"]
    _cache[session_id] = cache_entry
    return cache_entry


def _get_cached(session_id: str) -> Dict[str, Any]:
    """Get from cache or raise 404.

    Raises:
        HTTPException: 404 when the session is not (or no longer) cached.
    """
    entry = _cache.get(session_id)
    if not entry:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not in cache — reload first")
    return entry


# ---------------------------------------------------------------------------
# Pydantic Models
# ---------------------------------------------------------------------------

class ManualDeskewRequest(BaseModel):
    # Rotation angle in degrees to apply manually.
    angle: float


class DeskewGroundTruthRequest(BaseModel):
    # Operator feedback on the automatic deskew result.
    is_correct: bool
    corrected_angle: Optional[float] = None
    notes: Optional[str] = None


class ManualDewarpRequest(BaseModel):
    # Shear in degrees to apply manually.
    shear_degrees: float


class CombinedAdjustRequest(BaseModel):
    # Combined manual rotation + shear adjustment; both default to no-op.
    rotation_degrees: float = 0.0
    shear_degrees: float = 0.0


class DewarpGroundTruthRequest(BaseModel):
    # Operator feedback on the automatic dewarp result.
    is_correct: bool
    corrected_shear: Optional[float] = None
    notes: Optional[str] = None


# Accepted values for UpdateSessionRequest.document_category (German document
# type labels).
VALID_DOCUMENT_CATEGORIES = {
    'vokabelseite', 'woerterbuch', 'buchseite', 'arbeitsblatt', 'klausurseite',
    'mathearbeit', 'statistik', 'zeitung', 'formular', 'handschrift', 'sonstiges',
}


class UpdateSessionRequest(BaseModel):
    # Partial update: only non-None fields are applied.
    name: Optional[str] = None
    document_category: Optional[str] = None


class ManualColumnsRequest(BaseModel):
    columns: List[Dict[str, Any]]


class ColumnGroundTruthRequest(BaseModel):
    # Operator feedback on detected columns.
    is_correct: bool
    corrected_columns: Optional[List[Dict[str, Any]]] = None
    notes: Optional[str] = None


class ManualRowsRequest(BaseModel):
    rows: List[Dict[str, Any]]


class RowGroundTruthRequest(BaseModel):
    # Operator feedback on detected rows.
    is_correct: bool
    corrected_rows: Optional[List[Dict[str, Any]]] = None
    notes: Optional[str] = None


class RemoveHandwritingRequest(BaseModel):
    method: str = "auto"      # "auto" | "telea" | "ns"
    target_ink: str = "all"   # "all" | "colored" | "pencil"
    dilation: int = 2         # mask dilation iterations (0-5)
    use_source: str = "auto"  # "original" | "deskewed" | "auto"


# ---------------------------------------------------------------------------
# Pipeline Log Helper
# ---------------------------------------------------------------------------

async def _append_pipeline_log(
    session_id: str,
    step_name: str,
    metrics: Dict[str, Any],
    success: bool = True,
    duration_ms: Optional[int] = None,
):
    """Append a step entry to the session's pipeline_log JSONB.

    Best-effort: returns silently when the session does not exist. A
    malformed existing log (non-dict) is replaced with a fresh
    ``{"steps": []}`` structure.
    """
    session = await get_session_db(session_id)
    if not session:
        return
    log = session.get("pipeline_log") or {"steps": []}
    if not isinstance(log, dict):
        log = {"steps": []}
    entry = {
        "step": step_name,
        # NOTE(review): datetime.utcnow() is naive and deprecated since
        # Python 3.12; datetime.now(timezone.utc) would be preferred but
        # changes the stored timestamp format — confirm with consumers
        # before switching.
        "completed_at": datetime.utcnow().isoformat(),
        "success": success,
        "metrics": metrics,
    }
    if duration_ms is not None:
        entry["duration_ms"] = duration_ms
    log.setdefault("steps", []).append(entry)
    await update_session_db(session_id, pipeline_log=log)


# ---------------------------------------------------------------------------
# Border-ghost word filter
# ---------------------------------------------------------------------------

# Characters that OCR produces when reading box-border lines.
_BORDER_GHOST_CHARS = set("|1lI![](){}iíì/\\-—–_~.,;:'\"")


def _filter_border_ghost_words(
    word_result: Dict,
    boxes: List,
) -> int:
    """Remove OCR words that are actually box border lines.

    A word is considered a border ghost when it sits on a known box edge
    (left, right, top, or bottom) and looks like a line artefact (narrow
    aspect ratio or text consists only of line-like characters).

    After removing ghost cells, columns that have become empty are also
    removed from ``columns_used`` so the grid no longer shows phantom
    columns.

    Modifies *word_result* in-place and returns the number of removed cells.
    """
    if not boxes or not word_result:
        return 0
    cells = word_result.get("cells")
    if not cells:
        return 0

    # Build border bands — vertical (X) and horizontal (Y).
    # Boxes may be objects (attribute access) or plain dicts.
    x_bands = []  # list of (x_lo, x_hi)
    y_bands = []  # list of (y_lo, y_hi)
    for b in boxes:
        bx = b.x if hasattr(b, "x") else b.get("x", 0)
        by = b.y if hasattr(b, "y") else b.get("y", 0)
        bw = b.width if hasattr(b, "width") else b.get("w", b.get("width", 0))
        bh = b.height if hasattr(b, "height") else b.get("h", b.get("height", 0))
        bt = b.border_thickness if hasattr(b, "border_thickness") else b.get("border_thickness", 3)
        margin = max(bt * 2, 10) + 6  # generous margin
        # Vertical edges (left / right)
        x_bands.append((bx - margin, bx + margin))
        x_bands.append((bx + bw - margin, bx + bw + margin))
        # Horizontal edges (top / bottom)
        y_bands.append((by - margin, by + margin))
        y_bands.append((by + bh - margin, by + bh + margin))

    # Image dimensions, needed to convert percentage bboxes to pixels.
    img_w = word_result.get("image_width", 1)
    img_h = word_result.get("image_height", 1)

    def _is_ghost(cell: Dict) -> bool:
        """True when *cell* sits on a border band and looks like a line artefact."""
        text = (cell.get("text") or "").strip()
        if not text:
            return False
        # Compute absolute pixel position (prefer pixel bbox over percent)
        if cell.get("bbox_px"):
            px = cell["bbox_px"]
            cx = px["x"] + px["w"] / 2
            cy = px["y"] + px["h"] / 2
            cw = px["w"]
            ch = px["h"]
        elif cell.get("bbox_pct"):
            pct = cell["bbox_pct"]
            cx = (pct["x"] / 100) * img_w + (pct["w"] / 100) * img_w / 2
            cy = (pct["y"] / 100) * img_h + (pct["h"] / 100) * img_h / 2
            cw = (pct["w"] / 100) * img_w
            ch = (pct["h"] / 100) * img_h
        else:
            return False
        # Check if center sits on a vertical or horizontal border
        on_vertical = any(lo <= cx <= hi for lo, hi in x_bands)
        on_horizontal = any(lo <= cy <= hi for lo, hi in y_bands)
        if not on_vertical and not on_horizontal:
            return False
        # Very short text (1-2 chars) on a border with a line-like aspect
        # ratio (narrow vertically or narrow horizontally) → very likely ghost.
        if len(text) <= 2:
            if ch > 0 and cw / ch < 0.5:
                return True
            if cw > 0 and ch / cw < 0.5:
                return True
        # Any text (short or long) consisting solely of border-ghost
        # characters is removed. The original code had extra aspect-ratio
        # sub-checks here, but the branch returned True unconditionally,
        # making them dead code — this single check is behaviorally
        # identical.
        return all(c in _BORDER_GHOST_CHARS for c in text)

    before = len(cells)
    word_result["cells"] = [c for c in cells if not _is_ghost(c)]
    removed = before - len(word_result["cells"])

    # --- Remove empty columns from columns_used ---
    columns_used = word_result.get("columns_used")
    if removed and columns_used and len(columns_used) > 1:
        remaining_cells = word_result["cells"]
        occupied_cols = {c.get("col_index") for c in remaining_cells}
        before_cols = len(columns_used)
        columns_used = [col for col in columns_used if col.get("index") in occupied_cols]
        # Re-index columns and remap cell col_index values
        if len(columns_used) < before_cols:
            old_to_new = {}
            for new_i, col in enumerate(columns_used):
                old_to_new[col["index"]] = new_i
                col["index"] = new_i
            for cell in remaining_cells:
                old_ci = cell.get("col_index")
                if old_ci in old_to_new:
                    cell["col_index"] = old_to_new[old_ci]
            word_result["columns_used"] = columns_used
            logger.info("border-ghost: removed %d empty column(s), %d remaining",
                        before_cols - len(columns_used), len(columns_used))

    if removed:
        # Update summary counts so they agree with the filtered cell list.
        summary = word_result.get("summary", {})
        summary["total_cells"] = len(word_result["cells"])
        summary["non_empty_cells"] = sum(1 for c in word_result["cells"] if c.get("text"))
        word_result["summary"] = summary
        gs = word_result.get("grid_shape", {})
        gs["total_cells"] = len(word_result["cells"])
        if columns_used is not None:
            gs["cols"] = len(columns_used)
        word_result["grid_shape"] = gs

    return removed