refactor: split ocr_pipeline_api.py (5426 lines) into 8 modules
Each module is under 1050 lines: - ocr_pipeline_common.py (354) - shared state, cache, models, helpers - ocr_pipeline_sessions.py (483) - session CRUD, image serving, doc-type - ocr_pipeline_geometry.py (1025) - deskew, dewarp, structure, columns - ocr_pipeline_rows.py (348) - row detection, box-overlay helper - ocr_pipeline_words.py (876) - word detection (SSE), paddle-direct - ocr_pipeline_ocr_merge.py (615) - merge helpers, kombi endpoints - ocr_pipeline_postprocess.py (929) - LLM review, reconstruction, export - ocr_pipeline_auto.py (705) - auto-mode orchestrator, reprocess ocr_pipeline_api.py is now a 61-line thin wrapper that re-exports router, _cache, and test-imported symbols for backward compatibility. No changes needed in main.py or tests. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
354
klausur-service/backend/ocr_pipeline_common.py
Normal file
354
klausur-service/backend/ocr_pipeline_common.py
Normal file
@@ -0,0 +1,354 @@
|
||||
"""
|
||||
Shared common module for the OCR pipeline.
|
||||
|
||||
Contains in-memory cache, helper functions, Pydantic request models,
|
||||
pipeline logging, and border-ghost word filtering used by the pipeline
|
||||
API endpoints and related modules.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ocr_pipeline_session_store import get_session_db, get_session_image, update_session_db
|
||||
|
||||
# Public API of this module, re-exported by the thin ocr_pipeline_api wrapper.
__all__ = [
    # Cache
    "_cache",
    # Helper functions
    "_get_base_image_png",
    "_load_session_to_cache",
    "_get_cached",
    # Pydantic models
    "ManualDeskewRequest",
    "DeskewGroundTruthRequest",
    "ManualDewarpRequest",
    "CombinedAdjustRequest",
    "DewarpGroundTruthRequest",
    "VALID_DOCUMENT_CATEGORIES",
    "UpdateSessionRequest",
    "ManualColumnsRequest",
    "ColumnGroundTruthRequest",
    "ManualRowsRequest",
    "RowGroundTruthRequest",
    "RemoveHandwritingRequest",
    # Pipeline log
    "_append_pipeline_log",
    # Border-ghost filter
    "_BORDER_GHOST_CHARS",
    "_filter_border_ghost_words",
]

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# In-memory cache for active sessions, keyed by session id.
# The DB is the source of truth; the cache only holds decoded BGR numpy
# arrays (plus session metadata) while a session is actively processed,
# so pipeline steps do not re-decode PNGs on every call.
# ---------------------------------------------------------------------------

_cache: Dict[str, Dict[str, Any]] = {}
||||
async def _get_base_image_png(session_id: str) -> Optional[bytes]:
    """Return the most processed PNG stored for *session_id*.

    Preference order is cropped, then dewarped, then original; ``None``
    is returned when the session has no stored image at all.
    """
    for variant in ("cropped", "dewarped", "original"):
        data = await get_session_image(session_id, variant)
        if data:
            return data
    return None
|
||||
|
||||
|
||||
async def _load_session_to_cache(session_id: str) -> Dict[str, Any]:
    """Load a session from the DB into the in-memory cache.

    Each stored PNG variant is decoded into a BGR numpy array and placed
    under the corresponding ``*_bgr`` key.  If the session is already
    cached, the existing entry is returned untouched.

    Raises:
        HTTPException: 404 when the session does not exist in the DB.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    if session_id in _cache:
        return _cache[session_id]

    bgr_keys = ("original_bgr", "oriented_bgr", "cropped_bgr", "deskewed_bgr", "dewarped_bgr")
    entry: Dict[str, Any] = {"id": session_id, **session}
    entry.update(dict.fromkeys(bgr_keys))  # default all image slots to None

    # Decode each stored PNG variant from the DB into a BGR numpy array.
    for key in bgr_keys:
        img_type = key[: -len("_bgr")]
        png_data = await get_session_image(session_id, img_type)
        if png_data:
            buf = np.frombuffer(png_data, dtype=np.uint8)
            entry[key] = cv2.imdecode(buf, cv2.IMREAD_COLOR)

    # Sub-sessions: their "original" image IS already the cropped box
    # region, so promote it to cropped_bgr for downstream steps — but only
    # when no cropped/dewarped image exists yet.
    if session.get("parent_session_id") and entry["original_bgr"] is not None:
        if entry["cropped_bgr"] is None and entry["dewarped_bgr"] is None:
            entry["cropped_bgr"] = entry["original_bgr"]

    _cache[session_id] = entry
    return entry
|
||||
|
||||
|
||||
def _get_cached(session_id: str) -> Dict[str, Any]:
    """Return the cache entry for *session_id*, or raise a 404."""
    entry = _cache.get(session_id)
    if entry:
        return entry
    raise HTTPException(status_code=404, detail=f"Session {session_id} not in cache — reload first")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pydantic Models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ManualDeskewRequest(BaseModel):
    """Request body for applying a manual deskew rotation."""

    angle: float  # rotation angle — presumably degrees (cf. CombinedAdjustRequest.rotation_degrees); confirm
|
||||
|
||||
|
||||
class DeskewGroundTruthRequest(BaseModel):
    """Ground-truth feedback on an automatic deskew result."""

    is_correct: bool
    corrected_angle: Optional[float] = None  # reviewer-provided angle when the auto result was wrong
    notes: Optional[str] = None
|
||||
|
||||
|
||||
class ManualDewarpRequest(BaseModel):
    """Request body for applying a manual dewarp (shear) correction."""

    shear_degrees: float
|
||||
|
||||
|
||||
class CombinedAdjustRequest(BaseModel):
    """Request body for a combined rotation + shear adjustment."""

    rotation_degrees: float = 0.0
    shear_degrees: float = 0.0
|
||||
|
||||
|
||||
class DewarpGroundTruthRequest(BaseModel):
    """Ground-truth feedback on an automatic dewarp result."""

    is_correct: bool
    corrected_shear: Optional[float] = None  # reviewer-provided shear when the auto result was wrong
    notes: Optional[str] = None
|
||||
|
||||
|
||||
# Allowed values for UpdateSessionRequest.document_category (German labels:
# vocabulary page, book page, worksheet, exam page, math exam, statistics,
# newspaper, form, handwriting, miscellaneous).
VALID_DOCUMENT_CATEGORIES = {
    'vokabelseite', 'buchseite', 'arbeitsblatt', 'klausurseite',
    'mathearbeit', 'statistik', 'zeitung', 'formular', 'handschrift', 'sonstiges',
}
|
||||
|
||||
|
||||
class UpdateSessionRequest(BaseModel):
    """Partial update of session metadata; omitted fields are left unchanged."""

    name: Optional[str] = None
    document_category: Optional[str] = None  # expected to be one of VALID_DOCUMENT_CATEGORIES
|
||||
|
||||
|
||||
class ManualColumnsRequest(BaseModel):
    """Request body replacing detected columns with manually drawn ones."""

    columns: List[Dict[str, Any]]  # column geometry dicts — exact schema defined by the geometry module
|
||||
|
||||
|
||||
class ColumnGroundTruthRequest(BaseModel):
    """Ground-truth feedback on automatic column detection."""

    is_correct: bool
    corrected_columns: Optional[List[Dict[str, Any]]] = None  # replacement columns when detection was wrong
    notes: Optional[str] = None
|
||||
|
||||
|
||||
class ManualRowsRequest(BaseModel):
    """Request body replacing detected rows with manually drawn ones."""

    rows: List[Dict[str, Any]]  # row geometry dicts — exact schema defined by the row-detection module
|
||||
|
||||
|
||||
class RowGroundTruthRequest(BaseModel):
    """Ground-truth feedback on automatic row detection."""

    is_correct: bool
    corrected_rows: Optional[List[Dict[str, Any]]] = None  # replacement rows when detection was wrong
    notes: Optional[str] = None
|
||||
|
||||
|
||||
class RemoveHandwritingRequest(BaseModel):
    """Parameters for the handwriting-removal (inpainting) step."""

    method: str = "auto"  # "auto" | "telea" | "ns"
    target_ink: str = "all"  # "all" | "colored" | "pencil"
    dilation: int = 2  # mask dilation iterations (0-5)
    use_source: str = "auto"  # "original" | "deskewed" | "auto"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pipeline Log Helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _append_pipeline_log(
    session_id: str,
    step_name: str,
    metrics: Dict[str, Any],
    success: bool = True,
    duration_ms: Optional[int] = None,
):
    """Append one step entry to the session's ``pipeline_log`` JSONB.

    No-op when the session does not exist.  A missing or malformed log
    value is replaced by a fresh ``{"steps": []}`` structure.
    """
    session = await get_session_db(session_id)
    if not session:
        return

    log = session.get("pipeline_log") or {"steps": []}
    if not isinstance(log, dict):
        # Defensive: reset a corrupt/non-dict value instead of crashing.
        log = {"steps": []}

    # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
    # kept here so stored timestamps retain their naive-UTC isoformat.
    entry: Dict[str, Any] = {
        "step": step_name,
        "completed_at": datetime.utcnow().isoformat(),
        "success": success,
        "metrics": metrics,
        **({"duration_ms": duration_ms} if duration_ms is not None else {}),
    }

    log.setdefault("steps", []).append(entry)
    await update_session_db(session_id, pipeline_log=log)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Border-ghost word filter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Characters that OCR produces when reading box-border lines.
|
||||
_BORDER_GHOST_CHARS = set("|1lI![](){}iíì/\\-—–_~.,;:'\"")
|
||||
|
||||
|
||||
def _filter_border_ghost_words(
|
||||
word_result: Dict,
|
||||
boxes: List,
|
||||
) -> int:
|
||||
"""Remove OCR words that are actually box border lines.
|
||||
|
||||
A word is considered a border ghost when it sits on a known box edge
|
||||
(left, right, top, or bottom) and looks like a line artefact (narrow
|
||||
aspect ratio or text consists only of line-like characters).
|
||||
|
||||
After removing ghost cells, columns that have become empty are also
|
||||
removed from ``columns_used`` so the grid no longer shows phantom
|
||||
columns.
|
||||
|
||||
Modifies *word_result* in-place and returns the number of removed cells.
|
||||
"""
|
||||
if not boxes or not word_result:
|
||||
return 0
|
||||
|
||||
cells = word_result.get("cells")
|
||||
if not cells:
|
||||
return 0
|
||||
|
||||
# Build border bands — vertical (X) and horizontal (Y)
|
||||
x_bands = [] # list of (x_lo, x_hi)
|
||||
y_bands = [] # list of (y_lo, y_hi)
|
||||
for b in boxes:
|
||||
bx = b.x if hasattr(b, "x") else b.get("x", 0)
|
||||
by = b.y if hasattr(b, "y") else b.get("y", 0)
|
||||
bw = b.width if hasattr(b, "width") else b.get("w", b.get("width", 0))
|
||||
bh = b.height if hasattr(b, "height") else b.get("h", b.get("height", 0))
|
||||
bt = b.border_thickness if hasattr(b, "border_thickness") else b.get("border_thickness", 3)
|
||||
margin = max(bt * 2, 10) + 6 # generous margin
|
||||
|
||||
# Vertical edges (left / right)
|
||||
x_bands.append((bx - margin, bx + margin))
|
||||
x_bands.append((bx + bw - margin, bx + bw + margin))
|
||||
# Horizontal edges (top / bottom)
|
||||
y_bands.append((by - margin, by + margin))
|
||||
y_bands.append((by + bh - margin, by + bh + margin))
|
||||
|
||||
img_w = word_result.get("image_width", 1)
|
||||
img_h = word_result.get("image_height", 1)
|
||||
|
||||
def _is_ghost(cell: Dict) -> bool:
|
||||
text = (cell.get("text") or "").strip()
|
||||
if not text:
|
||||
return False
|
||||
|
||||
# Compute absolute pixel position
|
||||
if cell.get("bbox_px"):
|
||||
px = cell["bbox_px"]
|
||||
cx = px["x"] + px["w"] / 2
|
||||
cy = px["y"] + px["h"] / 2
|
||||
cw = px["w"]
|
||||
ch = px["h"]
|
||||
elif cell.get("bbox_pct"):
|
||||
pct = cell["bbox_pct"]
|
||||
cx = (pct["x"] / 100) * img_w + (pct["w"] / 100) * img_w / 2
|
||||
cy = (pct["y"] / 100) * img_h + (pct["h"] / 100) * img_h / 2
|
||||
cw = (pct["w"] / 100) * img_w
|
||||
ch = (pct["h"] / 100) * img_h
|
||||
else:
|
||||
return False
|
||||
|
||||
# Check if center sits on a vertical or horizontal border
|
||||
on_vertical = any(lo <= cx <= hi for lo, hi in x_bands)
|
||||
on_horizontal = any(lo <= cy <= hi for lo, hi in y_bands)
|
||||
if not on_vertical and not on_horizontal:
|
||||
return False
|
||||
|
||||
# Very short text (1-2 chars) on a border → very likely ghost
|
||||
if len(text) <= 2:
|
||||
# Narrow vertically (line-like) or narrow horizontally (dash-like)?
|
||||
if ch > 0 and cw / ch < 0.5:
|
||||
return True
|
||||
if cw > 0 and ch / cw < 0.5:
|
||||
return True
|
||||
# Text is only border-ghost characters?
|
||||
if all(c in _BORDER_GHOST_CHARS for c in text):
|
||||
return True
|
||||
|
||||
# Longer text but still only ghost chars and very narrow
|
||||
if all(c in _BORDER_GHOST_CHARS for c in text):
|
||||
if ch > 0 and cw / ch < 0.35:
|
||||
return True
|
||||
if cw > 0 and ch / cw < 0.35:
|
||||
return True
|
||||
return True # all ghost chars on a border → remove
|
||||
|
||||
return False
|
||||
|
||||
before = len(cells)
|
||||
word_result["cells"] = [c for c in cells if not _is_ghost(c)]
|
||||
removed = before - len(word_result["cells"])
|
||||
|
||||
# --- Remove empty columns from columns_used ---
|
||||
columns_used = word_result.get("columns_used")
|
||||
if removed and columns_used and len(columns_used) > 1:
|
||||
remaining_cells = word_result["cells"]
|
||||
occupied_cols = {c.get("col_index") for c in remaining_cells}
|
||||
before_cols = len(columns_used)
|
||||
columns_used = [col for col in columns_used if col.get("index") in occupied_cols]
|
||||
|
||||
# Re-index columns and remap cell col_index values
|
||||
if len(columns_used) < before_cols:
|
||||
old_to_new = {}
|
||||
for new_i, col in enumerate(columns_used):
|
||||
old_to_new[col["index"]] = new_i
|
||||
col["index"] = new_i
|
||||
for cell in remaining_cells:
|
||||
old_ci = cell.get("col_index")
|
||||
if old_ci in old_to_new:
|
||||
cell["col_index"] = old_to_new[old_ci]
|
||||
word_result["columns_used"] = columns_used
|
||||
logger.info("border-ghost: removed %d empty column(s), %d remaining",
|
||||
before_cols - len(columns_used), len(columns_used))
|
||||
|
||||
if removed:
|
||||
# Update summary counts
|
||||
summary = word_result.get("summary", {})
|
||||
summary["total_cells"] = len(word_result["cells"])
|
||||
summary["non_empty_cells"] = sum(1 for c in word_result["cells"] if c.get("text"))
|
||||
word_result["summary"] = summary
|
||||
gs = word_result.get("grid_shape", {})
|
||||
gs["total_cells"] = len(word_result["cells"])
|
||||
if columns_used is not None:
|
||||
gs["cols"] = len(columns_used)
|
||||
word_result["grid_shape"] = gs
|
||||
|
||||
return removed
|
||||
Reference in New Issue
Block a user