Files
breakpilot-lehrer/klausur-service/backend/ocr_pipeline_common.py
Benjamin Admin 4e668660a7 feat: add Woerterbuch category + column add/delete in grid editor
- New document category "Woerterbuch" (frontend type + backend validation)
- Column delete: hover column header → red "x" button (with confirmation)
- Column add: hover column header → "+" button inserts after that column
- Both operations support undo/redo, update cell IDs and summary
- Available in both GridEditor and StepGridReview (Kombi last step)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-23 16:27:12 +01:00

355 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Shared common module for the OCR pipeline.
Contains in-memory cache, helper functions, Pydantic request models,
pipeline logging, and border-ghost word filtering used by the pipeline
API endpoints and related modules.
"""
import logging
import re
import time
from datetime import datetime
from typing import Any, Dict, List, Optional
import cv2
import numpy as np
from fastapi import HTTPException
from pydantic import BaseModel
from ocr_pipeline_session_store import get_session_db, get_session_image, update_session_db
# Explicit export list. Every helper here is underscore-prefixed, so a
# star-import would skip it unless it is named in ``__all__``.
__all__ = [
    # Cache
    "_cache",
    # Helper functions
    "_get_base_image_png",
    "_load_session_to_cache",
    "_get_cached",
    # Pydantic models
    "ManualDeskewRequest",
    "DeskewGroundTruthRequest",
    "ManualDewarpRequest",
    "CombinedAdjustRequest",
    "DewarpGroundTruthRequest",
    "VALID_DOCUMENT_CATEGORIES",
    "UpdateSessionRequest",
    "ManualColumnsRequest",
    "ColumnGroundTruthRequest",
    "ManualRowsRequest",
    "RowGroundTruthRequest",
    "RemoveHandwritingRequest",
    # Pipeline log
    "_append_pipeline_log",
    # Border-ghost filter
    "_BORDER_GHOST_CHARS",
    "_filter_border_ghost_words",
]
# Module-level logger named after this module; used by the border-ghost filter.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# In-memory cache for active sessions (BGR numpy arrays for processing)
# DB is source of truth, cache holds BGR arrays during active processing.
# Keyed by session id; each value is the session's DB fields plus one
# ``<image type>_bgr`` slot per image variant (None when no image is stored).
# Nothing in this module evicts entries.
# ---------------------------------------------------------------------------
_cache: Dict[str, Dict[str, Any]] = {}
async def _get_base_image_png(session_id: str) -> Optional[bytes]:
    """Return the PNG bytes of the most processed image stored for a session.

    Image variants are tried in the order cropped, dewarped, original and the
    first one that yields data wins.

    Args:
        session_id: Identifier of the session whose image is requested.

    Returns:
        The PNG payload of the first available variant, or ``None`` when the
        session has none of the three variants stored.
    """
    preference_order = ("cropped", "dewarped", "original")
    for variant in preference_order:
        payload = await get_session_image(session_id, variant)
        if not payload:
            # Variant missing (or empty) - fall through to the next-best one.
            continue
        return payload
    return None
async def _load_session_to_cache(session_id: str) -> Dict[str, Any]:
    """Return the cache entry for a session, building it from the DB on a miss.

    The session row is always fetched first, so a session that is missing
    from the DB produces a 404 even if an entry for it is still cached.
    On a cache miss each stored image variant is decoded from PNG bytes into
    a BGR numpy array, and the finished entry is stored in ``_cache``.

    Args:
        session_id: Identifier of the session to load.

    Returns:
        The cache entry: ``"id"``, all fields of the DB row, and one
        ``<variant>_bgr`` key per image variant (``None`` when not stored).

    Raises:
        HTTPException: 404 when the session does not exist in the DB.
    """
    db_row = await get_session_db(session_id)
    if not db_row:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # Already decoded during an earlier call - reuse it.
    try:
        return _cache[session_id]
    except KeyError:
        pass

    variants = ("original", "oriented", "cropped", "deskewed", "dewarped")

    entry: Dict[str, Any] = {"id": session_id, **db_row}
    # Every image slot exists up front so readers can test for ``None``.
    entry.update({f"{variant}_bgr": None for variant in variants})

    # Decode images from DB into BGR numpy arrays
    for variant in variants:
        raw_png = await get_session_image(session_id, variant)
        if not raw_png:
            continue
        pixel_buffer = np.frombuffer(raw_png, dtype=np.uint8)
        entry[f"{variant}_bgr"] = cv2.imdecode(pixel_buffer, cv2.IMREAD_COLOR)

    # Sub-sessions: original image IS the cropped box region.
    # Promote original_bgr to cropped_bgr so downstream steps find it.
    is_sub_session = bool(db_row.get("parent_session_id"))
    if is_sub_session and entry["original_bgr"] is not None:
        nothing_processed = entry["cropped_bgr"] is None and entry["dewarped_bgr"] is None
        if nothing_processed:
            entry["cropped_bgr"] = entry["original_bgr"]

    _cache[session_id] = entry
    return entry
def _get_cached(session_id: str) -> Dict[str, Any]:
    """Look up a session in the in-memory cache without touching the DB.

    Args:
        session_id: Identifier of the session to look up.

    Returns:
        The cached entry for the session.

    Raises:
        HTTPException: 404 when the session has no (non-empty) cache entry.
    """
    cached = _cache.get(session_id)
    if cached:
        return cached
    raise HTTPException(status_code=404, detail=f"Session {session_id} not in cache — reload first")
# ---------------------------------------------------------------------------
# Pydantic Models
# ---------------------------------------------------------------------------
class ManualDeskewRequest(BaseModel):
    # Request body carrying a manually chosen deskew angle.
    # NOTE(review): unit is presumably degrees - confirm against the endpoint.
    angle: float
class DeskewGroundTruthRequest(BaseModel):
    # Reviewer feedback on a deskew result.
    is_correct: bool
    # Optional replacement angle and free-text remark; both default to None.
    corrected_angle: Optional[float] = None
    notes: Optional[str] = None
class ManualDewarpRequest(BaseModel):
    # Request body carrying a manually chosen shear value (in degrees, per the field name).
    shear_degrees: float
class CombinedAdjustRequest(BaseModel):
    # Rotation and shear supplied together; each defaults to 0.0 when omitted.
    rotation_degrees: float = 0.0
    shear_degrees: float = 0.0
class DewarpGroundTruthRequest(BaseModel):
    # Reviewer feedback on a dewarp result.
    is_correct: bool
    # Optional replacement shear and free-text remark; both default to None.
    corrected_shear: Optional[float] = None
    notes: Optional[str] = None
# Set of accepted document category identifiers (lower-case strings).
# NOTE(review): UpdateSessionRequest.document_category is not checked against
# this set in this module - presumably the endpoint does that; confirm.
VALID_DOCUMENT_CATEGORIES = {
    'vokabelseite', 'woerterbuch', 'buchseite', 'arbeitsblatt', 'klausurseite',
    'mathearbeit', 'statistik', 'zeitung', 'formular', 'handschrift', 'sonstiges',
}
class UpdateSessionRequest(BaseModel):
    # Partial session update: both fields are optional and default to None.
    name: Optional[str] = None
    # Plain string here; no validation against a category list happens in this model.
    document_category: Optional[str] = None
class ManualColumnsRequest(BaseModel):
    # Manually supplied column definitions as free-form dicts
    # (the dict schema is not constrained by this model).
    columns: List[Dict[str, Any]]
class ColumnGroundTruthRequest(BaseModel):
    # Reviewer feedback on a column detection result.
    is_correct: bool
    # Optional corrected column dicts and free-text remark; both default to None.
    corrected_columns: Optional[List[Dict[str, Any]]] = None
    notes: Optional[str] = None
class ManualRowsRequest(BaseModel):
    # Manually supplied row definitions as free-form dicts
    # (the dict schema is not constrained by this model).
    rows: List[Dict[str, Any]]
class RowGroundTruthRequest(BaseModel):
    # Reviewer feedback on a row detection result.
    is_correct: bool
    # Optional corrected row dicts and free-text remark; both default to None.
    corrected_rows: Optional[List[Dict[str, Any]]] = None
    notes: Optional[str] = None
class RemoveHandwritingRequest(BaseModel):
    # Options for the handwriting-removal step. All fields are plain
    # str/int: the value sets listed in the trailing comments are NOT
    # enforced by this model.
    method: str = "auto" # "auto" | "telea" | "ns"
    target_ink: str = "all" # "all" | "colored" | "pencil"
    dilation: int = 2 # mask dilation iterations (0-5)
    use_source: str = "auto" # "original" | "deskewed" | "auto"
# ---------------------------------------------------------------------------
# Pipeline Log Helper
# ---------------------------------------------------------------------------
async def _append_pipeline_log(
    session_id: str,
    step_name: str,
    metrics: Dict[str, Any],
    success: bool = True,
    duration_ms: Optional[int] = None,
):
    """Append a step entry to the session's ``pipeline_log`` and persist it.

    The current log is read from the DB, the new entry is appended to its
    ``"steps"`` list and the whole log is written back. A log (or a
    ``"steps"`` value) of an unexpected type is replaced by a fresh, empty
    structure instead of raising.

    Args:
        session_id: Identifier of the session whose log is extended.
        step_name: Stored under the entry's ``"step"`` key.
        metrics: Stored unchanged under the entry's ``"metrics"`` key.
        success: Stored under the entry's ``"success"`` key.
        duration_ms: Stored under ``"duration_ms"`` only when not ``None``.

    Returns:
        ``None``. Does nothing when the session is not found in the DB.
    """
    # Function-scope import so the file's import block does not need to
    # change; ``timezone`` is only needed for the timestamp below.
    from datetime import timezone

    session = await get_session_db(session_id)
    if not session:
        return

    log = session.get("pipeline_log")
    if not isinstance(log, dict):
        log = {"steps": []}

    steps = log.get("steps")
    if not isinstance(steps, list):
        # Missing or malformed step list - start over rather than crash.
        steps = []
        log["steps"] = steps

    # ``datetime.utcnow()`` is deprecated since Python 3.12. Dropping tzinfo
    # after taking an aware UTC "now" yields the very same naive-UTC ISO
    # string, so the stored format stays unchanged.
    completed_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    entry = {
        "step": step_name,
        "completed_at": completed_at,
        "success": success,
        "metrics": metrics,
    }
    if duration_ms is not None:
        entry["duration_ms"] = duration_ms

    steps.append(entry)
    await update_session_db(session_id, pipeline_log=log)
# ---------------------------------------------------------------------------
# Border-ghost word filter
# ---------------------------------------------------------------------------
# Characters that OCR produces when reading box-border lines.
_BORDER_GHOST_CHARS = set("|1lI![](){}iíì/\\-—_~.,;:'\"")


def _box_value(box: Any, attr: str, dict_keys: tuple, default: Any) -> Any:
    """Read one geometry field from a box given as an object or as a dict.

    An attribute named *attr* wins when present. Otherwise *dict_keys* are
    looked up via ``box.get``; earlier keys take precedence over later ones
    and *default* is returned when none of them is present.
    """
    if hasattr(box, attr):
        return getattr(box, attr)
    value = default
    # Walk the keys back to front so the first key ends up overriding.
    for key in reversed(dict_keys):
        value = box.get(key, value)
    return value


def _border_zones(boxes: List) -> List[tuple]:
    """Build one ``(x_lo, x_hi, y_lo, y_hi)`` zone per box edge.

    Every zone is the edge segment widened by a margin derived from the
    box's border thickness, so a zone never reaches beyond the box it
    belongs to (plus that margin).
    """
    zones: List[tuple] = []
    for box in boxes:
        bx = _box_value(box, "x", ("x",), 0)
        by = _box_value(box, "y", ("y",), 0)
        bw = _box_value(box, "width", ("w", "width"), 0)
        bh = _box_value(box, "height", ("h", "height"), 0)
        bt = _box_value(box, "border_thickness", ("border_thickness",), 3)
        margin = max(bt * 2, 10) + 6  # generous margin

        left, right = bx - margin, bx + bw + margin
        top, bottom = by - margin, by + bh + margin
        zones.append((left, bx + margin, top, bottom))        # left edge
        zones.append((bx + bw - margin, right, top, bottom))  # right edge
        zones.append((left, right, top, by + margin))         # top edge
        zones.append((left, right, by + bh - margin, bottom)) # bottom edge
    return zones


def _filter_border_ghost_words(
    word_result: Dict,
    boxes: List,
) -> int:
    """Remove OCR words that are actually box border lines.

    A cell is treated as a border ghost when its centre lies on an edge of
    one of *boxes* (the edge segment itself plus a margin, not the infinite
    line through it) and either

    * its stripped text consists solely of ``_BORDER_GHOST_CHARS``, or
    * its text is at most two characters long and its bounding box is
      line-like (width/height or height/width below 0.5).

    Cells with empty text or without ``bbox_px`` / ``bbox_pct`` are never
    removed. After removal, columns left without any cell are dropped from
    ``columns_used`` (only when more than one column was listed), the
    remaining columns and the cells' ``col_index`` values are renumbered,
    and ``summary`` / ``grid_shape`` counts are refreshed.

    Args:
        word_result: Grid result dict; read keys are ``cells``,
            ``image_width``, ``image_height``, ``columns_used``, ``summary``
            and ``grid_shape``. Modified in place.
        boxes: Box objects (attributes ``x``, ``y``, ``width``, ``height``,
            ``border_thickness``) or dicts (keys ``x``, ``y``, ``w`` or
            ``width``, ``h`` or ``height``, ``border_thickness``).

    Returns:
        The number of removed cells (0 when there is nothing to do).
    """
    if not boxes or not word_result:
        return 0
    cells = word_result.get("cells")
    if not cells:
        return 0

    zones = _border_zones(boxes)
    img_w = word_result.get("image_width", 1)
    img_h = word_result.get("image_height", 1)

    def _is_ghost(cell: Dict) -> bool:
        text = (cell.get("text") or "").strip()
        if not text:
            return False

        # Compute absolute pixel position
        if cell.get("bbox_px"):
            px = cell["bbox_px"]
            cw, ch = px["w"], px["h"]
            cx = px["x"] + cw / 2
            cy = px["y"] + ch / 2
        elif cell.get("bbox_pct"):
            pct = cell["bbox_pct"]
            cw = (pct["w"] / 100) * img_w
            ch = (pct["h"] / 100) * img_h
            cx = (pct["x"] / 100) * img_w + cw / 2
            cy = (pct["y"] / 100) * img_h + ch / 2
        else:
            return False

        # The centre must sit on an actual edge segment of some box.
        on_border = any(
            x_lo <= cx <= x_hi and y_lo <= cy <= y_hi
            for x_lo, x_hi, y_lo, y_hi in zones
        )
        if not on_border:
            return False

        # Only line-like characters on a border -> ghost, whatever the size.
        if all(char in _BORDER_GHOST_CHARS for char in text):
            return True

        # Very short text on a border is a ghost when its box is line-like:
        # narrow and tall (vertical stroke) or wide and flat (dash).
        if len(text) <= 2:
            tall_and_thin = ch > 0 and cw / ch < 0.5
            wide_and_flat = cw > 0 and ch / cw < 0.5
            return bool(tall_and_thin or wide_and_flat)
        return False

    before = len(cells)
    word_result["cells"] = [c for c in cells if not _is_ghost(c)]
    removed = before - len(word_result["cells"])

    # --- Remove empty columns from columns_used ---
    columns_used = word_result.get("columns_used")
    if removed and columns_used and len(columns_used) > 1:
        remaining_cells = word_result["cells"]
        occupied_cols = {c.get("col_index") for c in remaining_cells}
        before_cols = len(columns_used)
        columns_used = [col for col in columns_used if col.get("index") in occupied_cols]

        # Re-index columns and remap cell col_index values
        if len(columns_used) < before_cols:
            old_to_new = {}
            for new_i, col in enumerate(columns_used):
                old_to_new[col["index"]] = new_i
                col["index"] = new_i
            for cell in remaining_cells:
                old_ci = cell.get("col_index")
                if old_ci in old_to_new:
                    cell["col_index"] = old_to_new[old_ci]
            word_result["columns_used"] = columns_used
            logger.info("border-ghost: removed %d empty column(s), %d remaining",
                        before_cols - len(columns_used), len(columns_used))

    if removed:
        # Update summary counts. ``or {}`` also covers keys stored as None.
        summary = word_result.get("summary") or {}
        summary["total_cells"] = len(word_result["cells"])
        summary["non_empty_cells"] = sum(1 for c in word_result["cells"] if c.get("text"))
        word_result["summary"] = summary

        gs = word_result.get("grid_shape") or {}
        gs["total_cells"] = len(word_result["cells"])
        if columns_used is not None:
            gs["cols"] = len(columns_used)
        word_result["grid_shape"] = gs

    return removed