# Refactor summary — each module is under 1050 lines:
#   - ocr_pipeline_common.py (354)      — shared state, cache, models, helpers
#   - ocr_pipeline_sessions.py (483)    — session CRUD, image serving, doc-type
#   - ocr_pipeline_geometry.py (1025)   — deskew, dewarp, structure, columns
#   - ocr_pipeline_rows.py (348)        — row detection, box-overlay helper
#   - ocr_pipeline_words.py (876)       — word detection (SSE), paddle-direct
#   - ocr_pipeline_ocr_merge.py (615)   — merge helpers, kombi endpoints
#   - ocr_pipeline_postprocess.py (929) — LLM review, reconstruction, export
#   - ocr_pipeline_auto.py (705)        — auto-mode orchestrator, reprocess
# ocr_pipeline_api.py is now a 61-line thin wrapper that re-exports router,
# _cache, and test-imported symbols for backward compatibility.
# No changes needed in main.py or tests.
# Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
# (This file: 349 lines, 13 KiB, Python.)
"""
OCR Pipeline - Row Detection Endpoints.

Extracted from ocr_pipeline_api.py.
Handles row detection (auto + manual) and row ground truth.

Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""

import logging
import time
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import cv2
import numpy as np
from fastapi import APIRouter, HTTPException

from cv_vocab_pipeline import (
    create_ocr_image,
    detect_column_geometry,
    detect_row_geometry,
)
from ocr_pipeline_common import (
    _cache,
    _load_session_to_cache,
    _get_cached,
    _append_pipeline_log,
    ManualRowsRequest,
    RowGroundTruthRequest,
)
from ocr_pipeline_session_store import (
    get_session_db,
    update_session_db,
)

logger = logging.getLogger(__name__)
|
|
|
|
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helper: Box-exclusion overlay (used by rows overlay and columns overlay)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _draw_box_exclusion_overlay(
|
|
img: np.ndarray,
|
|
zones: List[Dict],
|
|
*,
|
|
label: str = "BOX — separat verarbeitet",
|
|
) -> None:
|
|
"""Draw red semi-transparent rectangles over box zones (in-place).
|
|
|
|
Reusable for columns, rows, and words overlays.
|
|
"""
|
|
for zone in zones:
|
|
if zone.get("zone_type") != "box" or not zone.get("box"):
|
|
continue
|
|
box = zone["box"]
|
|
bx, by = box["x"], box["y"]
|
|
bw, bh = box["width"], box["height"]
|
|
|
|
# Red semi-transparent fill (~25 %)
|
|
box_overlay = img.copy()
|
|
cv2.rectangle(box_overlay, (bx, by), (bx + bw, by + bh), (0, 0, 200), -1)
|
|
cv2.addWeighted(box_overlay, 0.25, img, 0.75, 0, img)
|
|
|
|
# Border
|
|
cv2.rectangle(img, (bx, by), (bx + bw, by + bh), (0, 0, 200), 2)
|
|
|
|
# Label
|
|
cv2.putText(img, label, (bx + 10, by + bh - 10),
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)


# ---------------------------------------------------------------------------
# Row Detection Endpoints
# ---------------------------------------------------------------------------

@router.post("/sessions/{session_id}/rows")
|
|
async def detect_rows(session_id: str):
|
|
"""Run row detection on the cropped (or dewarped) image using horizontal gap analysis."""
|
|
if session_id not in _cache:
|
|
await _load_session_to_cache(session_id)
|
|
cached = _get_cached(session_id)
|
|
|
|
dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
|
|
if dewarped_bgr is None:
|
|
raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before row detection")
|
|
|
|
t0 = time.time()
|
|
|
|
# Try to reuse cached word_dicts and inv from column detection
|
|
word_dicts = cached.get("_word_dicts")
|
|
inv = cached.get("_inv")
|
|
content_bounds = cached.get("_content_bounds")
|
|
|
|
if word_dicts is None or inv is None or content_bounds is None:
|
|
# Not cached — run column geometry to get intermediates
|
|
ocr_img = create_ocr_image(dewarped_bgr)
|
|
geo_result = detect_column_geometry(ocr_img, dewarped_bgr)
|
|
if geo_result is None:
|
|
raise HTTPException(status_code=400, detail="Column geometry detection failed — cannot detect rows")
|
|
_geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
|
|
cached["_word_dicts"] = word_dicts
|
|
cached["_inv"] = inv
|
|
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
|
|
else:
|
|
left_x, right_x, top_y, bottom_y = content_bounds
|
|
|
|
# Read zones from column_result to exclude box regions
|
|
session = await get_session_db(session_id)
|
|
column_result = (session or {}).get("column_result") or {}
|
|
is_sub_session = bool((session or {}).get("parent_session_id"))
|
|
|
|
# Sub-sessions (box crops): use word-grouping instead of gap-based
|
|
# row detection. Box images are small with complex internal layouts
|
|
# (headings, sub-columns) where the horizontal projection approach
|
|
# merges rows. Word-grouping directly clusters words by Y proximity,
|
|
# which is more robust for these cases.
|
|
if is_sub_session and word_dicts:
|
|
from cv_layout import _build_rows_from_word_grouping
|
|
rows = _build_rows_from_word_grouping(
|
|
word_dicts, left_x, right_x, top_y, bottom_y,
|
|
right_x - left_x, bottom_y - top_y,
|
|
)
|
|
logger.info(f"OCR Pipeline: sub-session {session_id}: word-grouping found {len(rows)} rows")
|
|
else:
|
|
zones = column_result.get("zones") or [] # zones can be None for sub-sessions
|
|
|
|
# Collect box y-ranges for filtering.
|
|
# Use border_thickness to shrink the exclusion zone: the border pixels
|
|
# belong visually to the box frame, but text rows above/below the box
|
|
# may overlap with the border area and must not be clipped.
|
|
box_ranges = [] # [(y_start, y_end)]
|
|
box_ranges_inner = [] # [(y_start + border, y_end - border)] for row filtering
|
|
for zone in zones:
|
|
if zone.get("zone_type") == "box" and zone.get("box"):
|
|
box = zone["box"]
|
|
bt = max(box.get("border_thickness", 0), 5) # minimum 5px margin
|
|
box_ranges.append((box["y"], box["y"] + box["height"]))
|
|
# Inner range: shrink by border thickness so boundary rows aren't excluded
|
|
box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt))
|
|
|
|
if box_ranges and inv is not None:
|
|
# Combined-image approach: strip box regions from inv image,
|
|
# run row detection on the combined image, then remap y-coords back.
|
|
content_strips = [] # [(y_start, y_end)] in absolute coords
|
|
# Build content strips by subtracting box inner ranges from [top_y, bottom_y].
|
|
# Using inner ranges means the border area is included in the content
|
|
# strips, so the last row above a box isn't clipped by the border.
|
|
sorted_boxes = sorted(box_ranges_inner, key=lambda r: r[0])
|
|
strip_start = top_y
|
|
for by_start, by_end in sorted_boxes:
|
|
if by_start > strip_start:
|
|
content_strips.append((strip_start, by_start))
|
|
strip_start = max(strip_start, by_end)
|
|
if strip_start < bottom_y:
|
|
content_strips.append((strip_start, bottom_y))
|
|
|
|
# Filter to strips with meaningful height
|
|
content_strips = [(ys, ye) for ys, ye in content_strips if ye - ys >= 20]
|
|
|
|
if content_strips:
|
|
# Stack content strips vertically
|
|
inv_strips = [inv[ys:ye, :] for ys, ye in content_strips]
|
|
combined_inv = np.vstack(inv_strips)
|
|
|
|
# Filter word_dicts to only include words from content strips
|
|
combined_words = []
|
|
cum_y = 0
|
|
strip_offsets = [] # (combined_y_start, strip_height, abs_y_start)
|
|
for ys, ye in content_strips:
|
|
h = ye - ys
|
|
strip_offsets.append((cum_y, h, ys))
|
|
for w in word_dicts:
|
|
w_abs_y = w['top'] + top_y # word y is relative to content top
|
|
w_center = w_abs_y + w['height'] / 2
|
|
if ys <= w_center < ye:
|
|
# Remap to combined coordinates
|
|
w_copy = dict(w)
|
|
w_copy['top'] = cum_y + (w_abs_y - ys)
|
|
combined_words.append(w_copy)
|
|
cum_y += h
|
|
|
|
# Run row detection on combined image
|
|
combined_h = combined_inv.shape[0]
|
|
rows = detect_row_geometry(
|
|
combined_inv, combined_words, left_x, right_x, 0, combined_h,
|
|
)
|
|
|
|
# Remap y-coordinates back to absolute page coords
|
|
def _combined_y_to_abs(cy: int) -> int:
|
|
for c_start, s_h, abs_start in strip_offsets:
|
|
if cy < c_start + s_h:
|
|
return abs_start + (cy - c_start)
|
|
last_c, last_h, last_abs = strip_offsets[-1]
|
|
return last_abs + last_h
|
|
|
|
for r in rows:
|
|
abs_y = _combined_y_to_abs(r.y)
|
|
abs_y_end = _combined_y_to_abs(r.y + r.height)
|
|
r.y = abs_y
|
|
r.height = abs_y_end - abs_y
|
|
else:
|
|
rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
|
|
else:
|
|
# No boxes — standard row detection
|
|
rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
|
|
|
|
duration = time.time() - t0
|
|
|
|
# Assign zone_index based on which content zone each row falls in
|
|
# Build content zone list with indices
|
|
zones = column_result.get("zones") or []
|
|
content_zones = [(i, z) for i, z in enumerate(zones) if z.get("zone_type") == "content"] if zones else []
|
|
|
|
# Build serializable result (exclude words to keep payload small)
|
|
rows_data = []
|
|
for r in rows:
|
|
# Determine zone_index
|
|
zone_idx = 0
|
|
row_center_y = r.y + r.height / 2
|
|
for zi, zone in content_zones:
|
|
zy = zone["y"]
|
|
zh = zone["height"]
|
|
if zy <= row_center_y < zy + zh:
|
|
zone_idx = zi
|
|
break
|
|
|
|
rd = {
|
|
"index": r.index,
|
|
"x": r.x,
|
|
"y": r.y,
|
|
"width": r.width,
|
|
"height": r.height,
|
|
"word_count": r.word_count,
|
|
"row_type": r.row_type,
|
|
"gap_before": r.gap_before,
|
|
"zone_index": zone_idx,
|
|
}
|
|
rows_data.append(rd)
|
|
|
|
type_counts = {}
|
|
for r in rows:
|
|
type_counts[r.row_type] = type_counts.get(r.row_type, 0) + 1
|
|
|
|
row_result = {
|
|
"rows": rows_data,
|
|
"summary": type_counts,
|
|
"total_rows": len(rows),
|
|
"duration_seconds": round(duration, 2),
|
|
}
|
|
|
|
# Persist to DB — also invalidate word_result since rows changed
|
|
await update_session_db(
|
|
session_id,
|
|
row_result=row_result,
|
|
word_result=None,
|
|
current_step=7,
|
|
)
|
|
|
|
cached["row_result"] = row_result
|
|
cached.pop("word_result", None)
|
|
|
|
logger.info(f"OCR Pipeline: rows session {session_id}: "
|
|
f"{len(rows)} rows detected ({duration:.2f}s): {type_counts}")
|
|
|
|
content_rows = sum(1 for r in rows if r.row_type == "content")
|
|
avg_height = round(sum(r.height for r in rows) / len(rows)) if rows else 0
|
|
await _append_pipeline_log(session_id, "rows", {
|
|
"total_rows": len(rows),
|
|
"content_rows": content_rows,
|
|
"artifact_rows_removed": type_counts.get("header", 0) + type_counts.get("footer", 0),
|
|
"avg_row_height_px": avg_height,
|
|
}, duration_ms=int(duration * 1000))
|
|
|
|
return {
|
|
"session_id": session_id,
|
|
**row_result,
|
|
}
|
|
|
|
|
|
@router.post("/sessions/{session_id}/rows/manual")
|
|
async def set_manual_rows(session_id: str, req: ManualRowsRequest):
|
|
"""Override detected rows with manual definitions."""
|
|
row_result = {
|
|
"rows": req.rows,
|
|
"total_rows": len(req.rows),
|
|
"duration_seconds": 0,
|
|
"method": "manual",
|
|
}
|
|
|
|
await update_session_db(session_id, row_result=row_result, word_result=None)
|
|
|
|
if session_id in _cache:
|
|
_cache[session_id]["row_result"] = row_result
|
|
_cache[session_id].pop("word_result", None)
|
|
|
|
logger.info(f"OCR Pipeline: manual rows session {session_id}: "
|
|
f"{len(req.rows)} rows set")
|
|
|
|
return {"session_id": session_id, **row_result}
|
|
|
|
|
|
@router.post("/sessions/{session_id}/ground-truth/rows")
|
|
async def save_row_ground_truth(session_id: str, req: RowGroundTruthRequest):
|
|
"""Save ground truth feedback for the row detection step."""
|
|
session = await get_session_db(session_id)
|
|
if not session:
|
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
|
|
|
ground_truth = session.get("ground_truth") or {}
|
|
gt = {
|
|
"is_correct": req.is_correct,
|
|
"corrected_rows": req.corrected_rows,
|
|
"notes": req.notes,
|
|
"saved_at": datetime.utcnow().isoformat(),
|
|
"row_result": session.get("row_result"),
|
|
}
|
|
ground_truth["rows"] = gt
|
|
|
|
await update_session_db(session_id, ground_truth=ground_truth)
|
|
|
|
if session_id in _cache:
|
|
_cache[session_id]["ground_truth"] = ground_truth
|
|
|
|
return {"session_id": session_id, "ground_truth": gt}
|
|
|
|
|
|
@router.get("/sessions/{session_id}/ground-truth/rows")
|
|
async def get_row_ground_truth(session_id: str):
|
|
"""Retrieve saved ground truth for row detection."""
|
|
session = await get_session_db(session_id)
|
|
if not session:
|
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
|
|
|
ground_truth = session.get("ground_truth") or {}
|
|
rows_gt = ground_truth.get("rows")
|
|
if not rows_gt:
|
|
raise HTTPException(status_code=404, detail="No row ground truth saved")
|
|
|
|
return {
|
|
"session_id": session_id,
|
|
"rows_gt": rows_gt,
|
|
"rows_auto": session.get("row_result"),
|
|
}
|