Files
breakpilot-lehrer/klausur-service/backend/ocr_pipeline_rows.py
Benjamin Admin ec287fd12e refactor: split ocr_pipeline_api.py (5426 lines) into 8 modules
Each module is under 1050 lines:
- ocr_pipeline_common.py (354) - shared state, cache, models, helpers
- ocr_pipeline_sessions.py (483) - session CRUD, image serving, doc-type
- ocr_pipeline_geometry.py (1025) - deskew, dewarp, structure, columns
- ocr_pipeline_rows.py (348) - row detection, box-overlay helper
- ocr_pipeline_words.py (876) - word detection (SSE), paddle-direct
- ocr_pipeline_ocr_merge.py (615) - merge helpers, kombi endpoints
- ocr_pipeline_postprocess.py (929) - LLM review, reconstruction, export
- ocr_pipeline_auto.py (705) - auto-mode orchestrator, reprocess

ocr_pipeline_api.py is now a 61-line thin wrapper that re-exports
router, _cache, and test-imported symbols for backward compatibility.
No changes needed in main.py or tests.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 08:42:00 +01:00

349 lines
13 KiB
Python

"""
OCR Pipeline - Row Detection Endpoints.
Extracted from ocr_pipeline_api.py.
Handles row detection (auto + manual) and row ground truth.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import time
from datetime import datetime
from typing import Any, Dict, List, Optional
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException
from cv_vocab_pipeline import (
create_ocr_image,
detect_column_geometry,
detect_row_geometry,
)
from ocr_pipeline_common import (
_cache,
_load_session_to_cache,
_get_cached,
_append_pipeline_log,
ManualRowsRequest,
RowGroundTruthRequest,
)
from ocr_pipeline_session_store import (
get_session_db,
update_session_db,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Helper: Box-exclusion overlay (used by rows overlay and columns overlay)
# ---------------------------------------------------------------------------
def _draw_box_exclusion_overlay(
    img: np.ndarray,
    zones: List[Dict],
    *,
    label: str = "BOX — separat verarbeitet",
) -> None:
    """Mark every box zone on *img* in place with a red exclusion overlay.

    For each zone of type ``"box"`` that carries geometry, draws a
    semi-transparent red fill (~25 % opacity), a solid red border, and a
    white text label inside the bottom edge. Shared by the columns, rows,
    and words overlay endpoints.

    Args:
        img: BGR image to annotate (modified in place).
        zones: Zone dicts; only entries with ``zone_type == "box"`` and a
            truthy ``"box"`` geometry (``x``/``y``/``width``/``height``)
            are drawn.
        label: Text rendered inside each box.
    """
    red = (0, 0, 200)  # BGR
    for entry in zones:
        if entry.get("zone_type") != "box":
            continue
        geom = entry.get("box")
        if not geom:
            continue
        x0, y0 = geom["x"], geom["y"]
        x1, y1 = x0 + geom["width"], y0 + geom["height"]
        # Blend a filled copy back onto img for the translucent fill.
        tinted = img.copy()
        cv2.rectangle(tinted, (x0, y0), (x1, y1), red, -1)
        cv2.addWeighted(tinted, 0.25, img, 0.75, 0, img)
        # Solid border on top of the blend.
        cv2.rectangle(img, (x0, y0), (x1, y1), red, 2)
        # Label anchored just above the bottom-left corner.
        cv2.putText(
            img, label, (x0 + 10, y1 - 10),
            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2,
        )
# ---------------------------------------------------------------------------
# Row Detection Endpoints
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/rows")
async def detect_rows(session_id: str):
    """Run row detection on the cropped (or dewarped) image using horizontal gap analysis.

    Pipeline step 7. Three strategies, chosen per session:
      * Sub-sessions (box crops) with words: cluster words by Y proximity
        via ``_build_rows_from_word_grouping`` — more robust for small box
        images with complex internal layouts.
      * Sessions whose column result contains box zones: cut the box strips
        out of the binarized image, run gap-based detection on the stacked
        remainder, then remap row y-coordinates back to page space.
      * Otherwise: plain gap-based detection over the full content bounds.

    Side effects: persists ``row_result`` to the session DB, invalidates
    ``word_result`` (word assignments are stale once rows change), sets
    ``current_step`` to 7, updates the in-process cache, and appends a
    pipeline-log entry.

    Raises:
        HTTPException: 400 if neither crop nor dewarp has completed, or if
            column geometry detection fails when intermediates are missing.

    Returns:
        Dict with ``session_id`` plus the serialized ``row_result``.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)
    # Prefer the cropped image; fall back to the dewarped one.
    dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
    if dewarped_bgr is None:
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before row detection")
    t0 = time.time()
    # Try to reuse cached word_dicts and inv from column detection
    word_dicts = cached.get("_word_dicts")
    inv = cached.get("_inv")
    content_bounds = cached.get("_content_bounds")
    if word_dicts is None or inv is None or content_bounds is None:
        # Not cached — run column geometry to get intermediates
        ocr_img = create_ocr_image(dewarped_bgr)
        geo_result = detect_column_geometry(ocr_img, dewarped_bgr)
        if geo_result is None:
            raise HTTPException(status_code=400, detail="Column geometry detection failed — cannot detect rows")
        _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
        # Cache intermediates so repeated row runs skip column geometry.
        cached["_word_dicts"] = word_dicts
        cached["_inv"] = inv
        cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
    else:
        left_x, right_x, top_y, bottom_y = content_bounds
    # Read zones from column_result to exclude box regions
    session = await get_session_db(session_id)
    column_result = (session or {}).get("column_result") or {}
    is_sub_session = bool((session or {}).get("parent_session_id"))
    # Sub-sessions (box crops): use word-grouping instead of gap-based
    # row detection. Box images are small with complex internal layouts
    # (headings, sub-columns) where the horizontal projection approach
    # merges rows. Word-grouping directly clusters words by Y proximity,
    # which is more robust for these cases.
    if is_sub_session and word_dicts:
        from cv_layout import _build_rows_from_word_grouping
        rows = _build_rows_from_word_grouping(
            word_dicts, left_x, right_x, top_y, bottom_y,
            right_x - left_x, bottom_y - top_y,
        )
        logger.info(f"OCR Pipeline: sub-session {session_id}: word-grouping found {len(rows)} rows")
    else:
        zones = column_result.get("zones") or []  # zones can be None for sub-sessions
        # Collect box y-ranges for filtering.
        # Use border_thickness to shrink the exclusion zone: the border pixels
        # belong visually to the box frame, but text rows above/below the box
        # may overlap with the border area and must not be clipped.
        box_ranges = []  # [(y_start, y_end)]
        box_ranges_inner = []  # [(y_start + border, y_end - border)] for row filtering
        for zone in zones:
            if zone.get("zone_type") == "box" and zone.get("box"):
                box = zone["box"]
                bt = max(box.get("border_thickness", 0), 5)  # minimum 5px margin
                box_ranges.append((box["y"], box["y"] + box["height"]))
                # Inner range: shrink by border thickness so boundary rows aren't excluded
                box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt))
        if box_ranges and inv is not None:
            # Combined-image approach: strip box regions from inv image,
            # run row detection on the combined image, then remap y-coords back.
            content_strips = []  # [(y_start, y_end)] in absolute coords
            # Build content strips by subtracting box inner ranges from [top_y, bottom_y].
            # Using inner ranges means the border area is included in the content
            # strips, so the last row above a box isn't clipped by the border.
            sorted_boxes = sorted(box_ranges_inner, key=lambda r: r[0])
            strip_start = top_y
            for by_start, by_end in sorted_boxes:
                if by_start > strip_start:
                    content_strips.append((strip_start, by_start))
                # max() guards against overlapping / nested box ranges.
                strip_start = max(strip_start, by_end)
            if strip_start < bottom_y:
                content_strips.append((strip_start, bottom_y))
            # Filter to strips with meaningful height
            content_strips = [(ys, ye) for ys, ye in content_strips if ye - ys >= 20]
            if content_strips:
                # Stack content strips vertically
                inv_strips = [inv[ys:ye, :] for ys, ye in content_strips]
                combined_inv = np.vstack(inv_strips)
                # Filter word_dicts to only include words from content strips
                combined_words = []
                cum_y = 0
                strip_offsets = []  # (combined_y_start, strip_height, abs_y_start)
                for ys, ye in content_strips:
                    h = ye - ys
                    strip_offsets.append((cum_y, h, ys))
                    for w in word_dicts:
                        # assumes word dicts carry 'top'/'height' in pixels,
                        # with 'top' relative to the content area — the +top_y
                        # shift below converts to absolute page coords.
                        w_abs_y = w['top'] + top_y  # word y is relative to content top
                        w_center = w_abs_y + w['height'] / 2
                        if ys <= w_center < ye:
                            # Remap to combined coordinates
                            w_copy = dict(w)
                            w_copy['top'] = cum_y + (w_abs_y - ys)
                            combined_words.append(w_copy)
                    cum_y += h
                # Run row detection on combined image
                combined_h = combined_inv.shape[0]
                rows = detect_row_geometry(
                    combined_inv, combined_words, left_x, right_x, 0, combined_h,
                )
                # Remap y-coordinates back to absolute page coords
                def _combined_y_to_abs(cy: int) -> int:
                    # Linear scan over strips; strip_offsets is small.
                    for c_start, s_h, abs_start in strip_offsets:
                        if cy < c_start + s_h:
                            return abs_start + (cy - c_start)
                    # cy beyond the last strip (e.g. row end at image bottom):
                    # clamp to the end of the last strip.
                    last_c, last_h, last_abs = strip_offsets[-1]
                    return last_abs + last_h
                for r in rows:
                    # Remap start and end separately: a row's height may span
                    # a strip seam, so abs height is end - start, not r.height.
                    abs_y = _combined_y_to_abs(r.y)
                    abs_y_end = _combined_y_to_abs(r.y + r.height)
                    r.y = abs_y
                    r.height = abs_y_end - abs_y
            else:
                rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
        else:
            # No boxes — standard row detection
            rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
    duration = time.time() - t0
    # Assign zone_index based on which content zone each row falls in
    # Build content zone list with indices
    zones = column_result.get("zones") or []
    content_zones = [(i, z) for i, z in enumerate(zones) if z.get("zone_type") == "content"] if zones else []
    # Build serializable result (exclude words to keep payload small)
    rows_data = []
    for r in rows:
        # Determine zone_index: first content zone whose y-range contains
        # the row's vertical center; defaults to 0 when none matches.
        zone_idx = 0
        row_center_y = r.y + r.height / 2
        for zi, zone in content_zones:
            zy = zone["y"]
            zh = zone["height"]
            if zy <= row_center_y < zy + zh:
                zone_idx = zi
                break
        rd = {
            "index": r.index,
            "x": r.x,
            "y": r.y,
            "width": r.width,
            "height": r.height,
            "word_count": r.word_count,
            "row_type": r.row_type,
            "gap_before": r.gap_before,
            "zone_index": zone_idx,
        }
        rows_data.append(rd)
    # Per-type counts (e.g. content/header/footer) for the summary.
    type_counts = {}
    for r in rows:
        type_counts[r.row_type] = type_counts.get(r.row_type, 0) + 1
    row_result = {
        "rows": rows_data,
        "summary": type_counts,
        "total_rows": len(rows),
        "duration_seconds": round(duration, 2),
    }
    # Persist to DB — also invalidate word_result since rows changed
    await update_session_db(
        session_id,
        row_result=row_result,
        word_result=None,
        current_step=7,
    )
    cached["row_result"] = row_result
    cached.pop("word_result", None)
    logger.info(f"OCR Pipeline: rows session {session_id}: "
                f"{len(rows)} rows detected ({duration:.2f}s): {type_counts}")
    content_rows = sum(1 for r in rows if r.row_type == "content")
    avg_height = round(sum(r.height for r in rows) / len(rows)) if rows else 0
    await _append_pipeline_log(session_id, "rows", {
        "total_rows": len(rows),
        "content_rows": content_rows,
        "artifact_rows_removed": type_counts.get("header", 0) + type_counts.get("footer", 0),
        "avg_row_height_px": avg_height,
    }, duration_ms=int(duration * 1000))
    return {
        "session_id": session_id,
        **row_result,
    }
@router.post("/sessions/{session_id}/rows/manual")
async def set_manual_rows(session_id: str, req: ManualRowsRequest):
    """Override detected rows with manual definitions.

    Replaces the session's ``row_result`` with the client-supplied rows
    (marked ``method: "manual"``) and invalidates ``word_result``, since
    word assignments depend on row geometry. The in-process cache is
    updated when the session is loaded.

    Args:
        session_id: Target pipeline session.
        req: Manual row definitions.

    Raises:
        HTTPException: 404 if the session does not exist — consistent with
            the ground-truth endpoints; previously a write against an
            unknown session was attempted silently.

    Returns:
        Dict with ``session_id`` plus the stored ``row_result``.
    """
    # Validate existence up front instead of writing blindly.
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    row_result = {
        "rows": req.rows,
        "total_rows": len(req.rows),
        "duration_seconds": 0,
        "method": "manual",
    }
    # Rows changed → word_result is stale and must be cleared.
    await update_session_db(session_id, row_result=row_result, word_result=None)
    if session_id in _cache:
        _cache[session_id]["row_result"] = row_result
        _cache[session_id].pop("word_result", None)
    logger.info(f"OCR Pipeline: manual rows session {session_id}: "
                f"{len(req.rows)} rows set")
    return {"session_id": session_id, **row_result}
@router.post("/sessions/{session_id}/ground-truth/rows")
async def save_row_ground_truth(session_id: str, req: RowGroundTruthRequest):
    """Save ground truth feedback for the row detection step.

    Stores the reviewer's verdict alongside a snapshot of the current
    ``row_result`` under ``ground_truth["rows"]`` so auto vs. corrected
    rows can be compared later.

    Args:
        session_id: Target pipeline session.
        req: Ground-truth payload (verdict, corrected rows, notes).

    Raises:
        HTTPException: 404 if the session does not exist.

    Returns:
        Dict with ``session_id`` and the stored ground-truth entry.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    ground_truth = session.get("ground_truth") or {}
    gt = {
        "is_correct": req.is_correct,
        "corrected_rows": req.corrected_rows,
        "notes": req.notes,
        # Timezone-aware UTC timestamp. datetime.utcnow() is deprecated
        # (Python 3.12+) and returns a naive datetime; now(timezone.utc)
        # yields an unambiguous ISO string with explicit +00:00 offset.
        "saved_at": datetime.now(timezone.utc).isoformat(),
        # Snapshot the auto result so later comparisons survive re-detection.
        "row_result": session.get("row_result"),
    }
    ground_truth["rows"] = gt
    await update_session_db(session_id, ground_truth=ground_truth)
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth
    return {"session_id": session_id, "ground_truth": gt}
@router.get("/sessions/{session_id}/ground-truth/rows")
async def get_row_ground_truth(session_id: str):
    """Retrieve saved ground truth for row detection.

    Returns the stored ground-truth entry together with the current
    auto-detected ``row_result`` for side-by-side comparison.

    Raises:
        HTTPException: 404 if the session does not exist or no row
            ground truth has been saved yet.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    rows_gt = (session.get("ground_truth") or {}).get("rows")
    if not rows_gt:
        raise HTTPException(status_code=404, detail="No row ground truth saved")
    return {
        "session_id": session_id,
        "rows_gt": rows_gt,
        "rows_auto": session.get("row_result"),
    }