0504d22b8e
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 29s
CI / test-go-edu-search (push) Successful in 29s
CI / test-python-klausur (push) Failing after 2m25s
CI / test-python-agent-core (push) Successful in 19s
CI / test-nodejs-website (push) Successful in 20s
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
355 lines
12 KiB
Python
355 lines
12 KiB
Python
"""
|
|
Shared common module for the OCR pipeline.
|
|
|
|
Contains in-memory cache, helper functions, Pydantic request models,
|
|
pipeline logging, and border-ghost word filtering used by the pipeline
|
|
API endpoints and related modules.
|
|
"""
|
|
|
|
import logging
|
|
import re
|
|
import time
|
|
from datetime import datetime
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import cv2
|
|
import numpy as np
|
|
from fastapi import HTTPException
|
|
from pydantic import BaseModel
|
|
|
|
from .session_store import get_session_db, get_session_image, update_session_db
|
|
|
|
# Names exported by ``from <this module> import *``. Several underscore-prefixed
# helpers are listed deliberately: without an entry here a star-import would
# skip them, because their names start with an underscore.
__all__ = [
    # Cache
    "_cache",
    # Helper functions
    "_get_base_image_png", "_load_session_to_cache", "_get_cached",
    # Pydantic models
    "ManualDeskewRequest", "DeskewGroundTruthRequest",
    "ManualDewarpRequest", "CombinedAdjustRequest", "DewarpGroundTruthRequest",
    "VALID_DOCUMENT_CATEGORIES", "UpdateSessionRequest",
    "ManualColumnsRequest", "ColumnGroundTruthRequest",
    "ManualRowsRequest", "RowGroundTruthRequest",
    "RemoveHandwritingRequest",
    # Pipeline log
    "_append_pipeline_log",
    # Border-ghost filter
    "_BORDER_GHOST_CHARS", "_filter_border_ghost_words",
]

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# In-memory cache for active sessions, keyed by session id.
# The DB stays the source of truth; each entry holds the session fields plus
# decoded BGR numpy arrays while the session is being processed.
# ---------------------------------------------------------------------------
_cache: Dict[str, Dict[str, Any]] = {}
|
|
|
|
|
|
async def _get_base_image_png(session_id: str) -> Optional[bytes]:
    """Return PNG bytes of the most-processed base image stored for a session.

    Variants are tried in the order cropped, dewarped, original; the first one
    that is present and non-empty wins.

    Returns:
        The PNG bytes, or ``None`` when none of the three variants is stored.
    """
    for kind in ("cropped", "dewarped", "original"):
        candidate = await get_session_image(session_id, kind)
        if not candidate:
            continue  # missing/empty -> fall back to the next, less-processed variant
        return candidate
    return None
|
|
|
|
|
|
async def _load_session_to_cache(session_id: str) -> Dict[str, Any]:
    """Return the cache entry for *session_id*, building it from the DB on a miss.

    The session row is always fetched first. A session that no longer exists in
    the DB therefore raises 404 even if a cache entry is still present. On a
    cache hit the existing entry is returned unchanged.

    On a miss, the entry contains ``"id"``, every field of the session row, and
    one ``<variant>_bgr`` slot per stored image variant. Each slot holds the
    PNG decoded with ``cv2.imdecode`` (``None`` when the variant is not stored).

    Raises:
        HTTPException: 404 if the session is not found in the DB.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    if session_id in _cache:
        return _cache[session_id]

    # Image variant stored in the DB -> key of its decoded array in the entry.
    variant_keys = {
        "original": "original_bgr",
        "oriented": "oriented_bgr",
        "cropped": "cropped_bgr",
        "deskewed": "deskewed_bgr",
        "dewarped": "dewarped_bgr",
    }

    entry: Dict[str, Any] = {"id": session_id, **session}
    for variant, key in variant_keys.items():
        png_bytes = await get_session_image(session_id, variant)
        decoded = None
        if png_bytes:
            buffer = np.frombuffer(png_bytes, dtype=np.uint8)
            decoded = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
        entry[key] = decoded

    # A sub-session's original image already IS the cropped box region, so
    # expose it as cropped_bgr when neither cropped nor dewarped exists.
    promote_original = (
        bool(session.get("parent_session_id"))
        and entry["original_bgr"] is not None
        and entry["cropped_bgr"] is None
        and entry["dewarped_bgr"] is None
    )
    if promote_original:
        entry["cropped_bgr"] = entry["original_bgr"]

    _cache[session_id] = entry
    return entry
|
|
|
|
|
|
def _get_cached(session_id: str) -> Dict[str, Any]:
    """Return the in-memory cache entry for *session_id*.

    Raises:
        HTTPException: 404 when no (truthy) entry is cached; the message asks
            the caller to reload the session first.
    """
    entry = _cache.get(session_id)
    if entry:
        return entry
    raise HTTPException(status_code=404, detail=f"Session {session_id} not in cache — reload first")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Pydantic Models
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class ManualDeskewRequest(BaseModel):
    # Request body carrying a user-chosen deskew angle.
    # NOTE(review): the unit is not stated here — presumably degrees; confirm
    # against the endpoint that consumes it.
    angle: float
|
|
|
|
|
|
class DeskewGroundTruthRequest(BaseModel):
    # Ground-truth feedback on a deskew result.
    # Whether the deskew result was judged correct.
    is_correct: bool
    # Optional replacement angle; presumably only sent when is_correct is
    # False — TODO confirm in the endpoint.
    corrected_angle: Optional[float] = None
    # Optional free-text remarks.
    notes: Optional[str] = None
|
|
|
|
|
|
class ManualDewarpRequest(BaseModel):
    # Request body carrying a user-chosen dewarp shear, in degrees (per the
    # field name).
    shear_degrees: float
|
|
|
|
|
|
class CombinedAdjustRequest(BaseModel):
    # Rotation and shear supplied together, both in degrees (per the field
    # names). Each defaults to 0.0, so either may be omitted from the body.
    rotation_degrees: float = 0.0
    shear_degrees: float = 0.0
|
|
|
|
|
|
class DewarpGroundTruthRequest(BaseModel):
    # Ground-truth feedback on a dewarp result.
    # Whether the dewarp result was judged correct.
    is_correct: bool
    # Optional replacement shear (degrees, per the sibling models' naming —
    # confirm); presumably only sent when is_correct is False.
    corrected_shear: Optional[float] = None
    # Optional free-text remarks.
    notes: Optional[str] = None
|
|
|
|
|
|
# Accepted document-category identifiers (German, lower-case ASCII).
# Note: this is a plain set, so membership checks are case-sensitive.
VALID_DOCUMENT_CATEGORIES = {
    "vokabelseite",   # vocabulary page
    "woerterbuch",    # dictionary
    "buchseite",      # book page
    "arbeitsblatt",   # worksheet
    "klausurseite",   # exam page
    "mathearbeit",    # maths test
    "statistik",      # statistics
    "zeitung",        # newspaper
    "formular",       # form
    "handschrift",    # handwriting
    "sonstiges",      # other
}
|
|
|
|
|
|
class UpdateSessionRequest(BaseModel):
    # Editable session metadata. Both fields are optional; presumably a None
    # field means "leave unchanged" — confirm in the endpoint.
    name: Optional[str] = None
    # Not validated by this model; VALID_DOCUMENT_CATEGORIES holds the accepted
    # identifiers (the check, if any, must happen in the caller).
    document_category: Optional[str] = None
|
|
|
|
|
|
class ManualColumnsRequest(BaseModel):
    # User-defined column layout. Each column is an arbitrary dict; this model
    # does not enforce its keys — schema is defined by the consumer (TODO doc).
    columns: List[Dict[str, Any]]
|
|
|
|
|
|
class ColumnGroundTruthRequest(BaseModel):
    # Ground-truth feedback on column detection.
    # Whether the detected columns were judged correct.
    is_correct: bool
    # Optional replacement column list (dict keys not enforced here);
    # presumably only sent when is_correct is False.
    corrected_columns: Optional[List[Dict[str, Any]]] = None
    # Optional free-text remarks.
    notes: Optional[str] = None
|
|
|
|
|
|
class ManualRowsRequest(BaseModel):
    # User-defined row layout. Each row is an arbitrary dict; this model does
    # not enforce its keys — schema is defined by the consumer (TODO doc).
    rows: List[Dict[str, Any]]
|
|
|
|
|
|
class RowGroundTruthRequest(BaseModel):
    # Ground-truth feedback on row detection.
    # Whether the detected rows were judged correct.
    is_correct: bool
    # Optional replacement row list (dict keys not enforced here);
    # presumably only sent when is_correct is False.
    corrected_rows: Optional[List[Dict[str, Any]]] = None
    # Optional free-text remarks.
    notes: Optional[str] = None
|
|
|
|
|
|
class RemoveHandwritingRequest(BaseModel):
    # Options for handwriting removal. The allowed values are documented only
    # by the trailing comments; the fields are plain str/int, so this model
    # does not reject other values.
    # NOTE(review): "telea"/"ns" look like OpenCV inpainting methods — confirm.
    method: str = "auto"  # "auto" | "telea" | "ns"
    target_ink: str = "all"  # "all" | "colored" | "pencil"
    dilation: int = 2  # mask dilation iterations (0-5)
    use_source: str = "auto"  # "original" | "deskewed" | "auto"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Pipeline Log Helper
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def _append_pipeline_log(
    session_id: str,
    step_name: str,
    metrics: Dict[str, Any],
    success: bool = True,
    duration_ms: Optional[int] = None,
):
    """Append a step entry to the session's ``pipeline_log`` JSONB and persist it.

    The log has the shape ``{"steps": [entry, ...]}``. Each appended entry
    records the step name, a UTC completion timestamp (naive ISO-8601 string),
    the success flag, the metrics dict and, when given, ``duration_ms``.

    Args:
        session_id: Session whose log is updated.
        step_name: Name of the pipeline step that just finished.
        metrics: Step-specific metrics, stored verbatim in the entry.
        success: Whether the step succeeded.
        duration_ms: Optional step runtime in milliseconds.

    Does nothing when the session does not exist. A missing or malformed log
    (not a dict, or ``steps`` not a list) is replaced by a fresh one instead of
    raising.
    """
    import json
    from datetime import timezone

    session = await get_session_db(session_id)
    if not session:
        return

    log = session.get("pipeline_log") or {"steps": []}
    if isinstance(log, str):
        # NOTE(review): defensive — in case the DB layer hands back the JSONB
        # column as serialized text, parse it rather than discarding history.
        try:
            log = json.loads(log)
        except ValueError:
            log = None
    if not isinstance(log, dict):
        log = {"steps": []}

    steps = log.get("steps")
    if not isinstance(steps, list):
        # e.g. {"steps": null} previously crashed on .append()
        steps = []
        log["steps"] = steps

    entry: Dict[str, Any] = {
        "step": step_name,
        # Naive UTC timestamp in exactly the format datetime.utcnow().isoformat()
        # produced; utcnow() itself is deprecated since Python 3.12.
        "completed_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
        "success": success,
        "metrics": metrics,
    }
    if duration_ms is not None:
        entry["duration_ms"] = duration_ms
    steps.append(entry)

    await update_session_db(session_id, pipeline_log=log)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Border-ghost word filter
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Characters OCR produces when it reads a box-border line as text. The pieces
# below are grouped only for readability; Python concatenates the adjacent
# string literals before ``set`` splits them into single characters.
_BORDER_GHOST_CHARS = set(
    "|1lI!"     # vertical strokes
    "[](){}"    # brackets
    "iíì"       # i-like glyphs
    "/\\"       # slashes
    "-—–_~"     # horizontal strokes (hyphen, em dash, en dash, ...)
    ".,;:'\""   # punctuation specks
)
|
|
|
|
|
|
def _box_edge_rects(boxes: List) -> List[tuple]:
    """Return the four edge bands of every box as ``(x_lo, x_hi, y_lo, y_hi)`` rects."""
    rects: List[tuple] = []
    for b in boxes:
        # Boxes may be objects with attributes or plain dicts (``w``/``h`` or
        # ``width``/``height`` keys).
        bx = b.x if hasattr(b, "x") else b.get("x", 0)
        by = b.y if hasattr(b, "y") else b.get("y", 0)
        bw = b.width if hasattr(b, "width") else b.get("w", b.get("width", 0))
        bh = b.height if hasattr(b, "height") else b.get("h", b.get("height", 0))
        bt = b.border_thickness if hasattr(b, "border_thickness") else b.get("border_thickness", 3)
        # Generous margin; ``or 3`` tolerates an explicit None thickness
        # (for 0..5 the result is identical because of the max(..., 10)).
        margin = max((bt or 3) * 2, 10) + 6

        # Each band is clipped to the box's own extent (plus margin) along the
        # edge. Otherwise text that merely shares an x or y coordinate with an
        # edge somewhere else on the page would count as "on the border".
        span_x = (bx - margin, bx + bw + margin)
        span_y = (by - margin, by + bh + margin)
        rects.append((bx - margin, bx + margin) + span_y)            # left
        rects.append((bx + bw - margin, bx + bw + margin) + span_y)  # right
        rects.append(span_x + (by - margin, by + margin))            # top
        rects.append(span_x + (by + bh - margin, by + bh + margin))  # bottom
    return rects


def _cell_geometry(cell: Dict, img_w: float, img_h: float) -> Optional[tuple]:
    """Return ``(cx, cy, w, h)`` of *cell* in pixels, or None when it has no bbox."""
    px = cell.get("bbox_px")
    if px:
        return px["x"] + px["w"] / 2, px["y"] + px["h"] / 2, px["w"], px["h"]
    pct = cell.get("bbox_pct")
    if pct:
        # Percent-of-image coordinates -> pixels.
        w = (pct["w"] / 100) * img_w
        h = (pct["h"] / 100) * img_h
        return (pct["x"] / 100) * img_w + w / 2, (pct["y"] / 100) * img_h + h / 2, w, h
    return None


def _looks_like_border_ghost(text: str, w: float, h: float) -> bool:
    """Decide whether *text* with a ``w`` x ``h`` bbox looks like a line artefact."""
    # Very short text (1-2 chars): a strongly elongated bbox is line-like —
    # tall and thin (vertical rule) or wide and flat (horizontal rule).
    if len(text) <= 2:
        if h > 0 and w / h < 0.5:
            return True
        if w > 0 and h / w < 0.5:
            return True
    # Any length: text made up solely of line-like characters.
    return all(c in _BORDER_GHOST_CHARS for c in text)


def _filter_border_ghost_words(
    word_result: Dict,
    boxes: List,
) -> int:
    """Remove OCR words that are actually box border lines.

    A word is considered a border ghost when its bbox centre sits on a known
    box edge (left, right, top or bottom, within a margin and within the
    box's extent along that edge) and it looks like a line artefact. "Looks
    like" means: 1-2 characters with a very elongated bbox, or text consisting
    only of ``_BORDER_GHOST_CHARS``.

    After removing ghost cells, columns that no longer have any cell are also
    removed from ``columns_used`` (remaining columns and cell ``col_index``
    values are re-numbered) so the grid no longer shows phantom columns. The
    ``summary`` and ``grid_shape`` counts are updated accordingly.

    Args:
        word_result: OCR grid result with ``cells`` (each having ``text`` and
            ``bbox_px`` or ``bbox_pct``) and optionally ``columns_used``,
            ``image_width``, ``image_height``, ``summary``, ``grid_shape``.
        boxes: Detected boxes, as objects (``x``, ``y``, ``width``,
            ``height``, ``border_thickness``) or equivalent dicts.

    Returns:
        Number of removed cells. *word_result* is modified in place.
    """
    if not boxes or not word_result:
        return 0

    cells = word_result.get("cells")
    if not cells:
        return 0

    edge_rects = _box_edge_rects(boxes)
    img_w = word_result.get("image_width") or 1
    img_h = word_result.get("image_height") or 1

    def _is_ghost(cell: Dict) -> bool:
        text = (cell.get("text") or "").strip()
        if not text:
            return False
        geometry = _cell_geometry(cell, img_w, img_h)
        if geometry is None:
            return False
        cx, cy, cw, ch = geometry
        on_edge = any(
            x_lo <= cx <= x_hi and y_lo <= cy <= y_hi
            for x_lo, x_hi, y_lo, y_hi in edge_rects
        )
        return on_edge and _looks_like_border_ghost(text, cw, ch)

    remaining = [c for c in cells if not _is_ghost(c)]
    removed = len(cells) - len(remaining)
    word_result["cells"] = remaining
    if not removed:
        return 0

    # --- Remove empty columns from columns_used ---
    # NOTE: any remaining cell (even one with empty text) keeps its column.
    columns_used = word_result.get("columns_used")
    if columns_used and len(columns_used) > 1:
        occupied_cols = {c.get("col_index") for c in remaining}
        before_cols = len(columns_used)
        columns_used = [col for col in columns_used if col.get("index") in occupied_cols]

        if len(columns_used) < before_cols:
            # Re-index the surviving columns and remap cell col_index values.
            old_to_new = {}
            for new_i, col in enumerate(columns_used):
                old_to_new[col["index"]] = new_i
                col["index"] = new_i
            for cell in remaining:
                old_ci = cell.get("col_index")
                if old_ci in old_to_new:
                    cell["col_index"] = old_to_new[old_ci]
            word_result["columns_used"] = columns_used
            logger.info("border-ghost: removed %d empty column(s), %d remaining",
                        before_cols - len(columns_used), len(columns_used))

    # Update summary counts (tolerating a missing or None summary/grid_shape).
    summary = word_result.get("summary") or {}
    summary["total_cells"] = len(remaining)
    summary["non_empty_cells"] = sum(1 for c in remaining if c.get("text"))
    word_result["summary"] = summary

    grid_shape = word_result.get("grid_shape") or {}
    grid_shape["total_cells"] = len(remaining)
    if columns_used is not None:
        grid_shape["cols"] = len(columns_used)
    word_result["grid_shape"] = grid_shape

    return removed
|