""" OCR Pipeline - Row Detection Endpoints. Extracted from ocr_pipeline_api.py. Handles row detection (auto + manual) and row ground truth. Lizenz: Apache 2.0 DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. """ import logging import time from datetime import datetime from typing import Any, Dict, List, Optional import cv2 import numpy as np from fastapi import APIRouter, HTTPException from cv_vocab_pipeline import ( create_ocr_image, detect_column_geometry, detect_row_geometry, ) from ocr_pipeline_common import ( _cache, _load_session_to_cache, _get_cached, _append_pipeline_log, ManualRowsRequest, RowGroundTruthRequest, ) from ocr_pipeline_session_store import ( get_session_db, update_session_db, ) logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) # --------------------------------------------------------------------------- # Helper: Box-exclusion overlay (used by rows overlay and columns overlay) # --------------------------------------------------------------------------- def _draw_box_exclusion_overlay( img: np.ndarray, zones: List[Dict], *, label: str = "BOX — separat verarbeitet", ) -> None: """Draw red semi-transparent rectangles over box zones (in-place). Reusable for columns, rows, and words overlays. """ for zone in zones: if zone.get("zone_type") != "box" or not zone.get("box"): continue box = zone["box"] bx, by = box["x"], box["y"] bw, bh = box["width"], box["height"] # Red semi-transparent fill (~25 %) box_overlay = img.copy() cv2.rectangle(box_overlay, (bx, by), (bx + bw, by + bh), (0, 0, 200), -1) cv2.addWeighted(box_overlay, 0.25, img, 0.75, 0, img) # Border cv2.rectangle(img, (bx, by), (bx + bw, by + bh), (0, 0, 200), 2) # Label cv2.putText(img, label, (bx + 10, by + bh - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2) # --------------------------------------------------------------------------- # Row Detection Endpoints # --------------------------------------------------------------------------- @router.post("/sessions/{session_id}/rows") async def detect_rows(session_id: str): """Run row detection on the cropped (or dewarped) image using horizontal gap analysis.""" if session_id not in _cache: await _load_session_to_cache(session_id) cached = _get_cached(session_id) dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") if dewarped_bgr is None: raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before row detection") t0 = time.time() # Try to reuse cached word_dicts and inv from column detection word_dicts = cached.get("_word_dicts") inv = cached.get("_inv") content_bounds = cached.get("_content_bounds") if word_dicts is None or inv is None or content_bounds is None: # Not cached — run column geometry to get intermediates ocr_img = create_ocr_image(dewarped_bgr) geo_result = detect_column_geometry(ocr_img, dewarped_bgr) if geo_result is None: raise HTTPException(status_code=400, detail="Column geometry detection failed — cannot detect rows") _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result cached["_word_dicts"] = word_dicts cached["_inv"] = inv cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y) else: left_x, right_x, top_y, bottom_y = content_bounds # Read zones from column_result to exclude box regions session = await get_session_db(session_id) column_result = (session or {}).get("column_result") or {} is_sub_session = bool((session or {}).get("parent_session_id")) # Sub-sessions (box crops): use word-grouping instead of gap-based # row detection. Box images are small with complex internal layouts # (headings, sub-columns) where the horizontal projection approach # merges rows. Word-grouping directly clusters words by Y proximity, # which is more robust for these cases. if is_sub_session and word_dicts: from cv_layout import _build_rows_from_word_grouping rows = _build_rows_from_word_grouping( word_dicts, left_x, right_x, top_y, bottom_y, right_x - left_x, bottom_y - top_y, ) logger.info(f"OCR Pipeline: sub-session {session_id}: word-grouping found {len(rows)} rows") else: zones = column_result.get("zones") or [] # zones can be None for sub-sessions # Collect box y-ranges for filtering. # Use border_thickness to shrink the exclusion zone: the border pixels # belong visually to the box frame, but text rows above/below the box # may overlap with the border area and must not be clipped. box_ranges = [] # [(y_start, y_end)] box_ranges_inner = [] # [(y_start + border, y_end - border)] for row filtering for zone in zones: if zone.get("zone_type") == "box" and zone.get("box"): box = zone["box"] bt = max(box.get("border_thickness", 0), 5) # minimum 5px margin box_ranges.append((box["y"], box["y"] + box["height"])) # Inner range: shrink by border thickness so boundary rows aren't excluded box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt)) if box_ranges and inv is not None: # Combined-image approach: strip box regions from inv image, # run row detection on the combined image, then remap y-coords back. content_strips = [] # [(y_start, y_end)] in absolute coords # Build content strips by subtracting box inner ranges from [top_y, bottom_y]. # Using inner ranges means the border area is included in the content # strips, so the last row above a box isn't clipped by the border. sorted_boxes = sorted(box_ranges_inner, key=lambda r: r[0]) strip_start = top_y for by_start, by_end in sorted_boxes: if by_start > strip_start: content_strips.append((strip_start, by_start)) strip_start = max(strip_start, by_end) if strip_start < bottom_y: content_strips.append((strip_start, bottom_y)) # Filter to strips with meaningful height content_strips = [(ys, ye) for ys, ye in content_strips if ye - ys >= 20] if content_strips: # Stack content strips vertically inv_strips = [inv[ys:ye, :] for ys, ye in content_strips] combined_inv = np.vstack(inv_strips) # Filter word_dicts to only include words from content strips combined_words = [] cum_y = 0 strip_offsets = [] # (combined_y_start, strip_height, abs_y_start) for ys, ye in content_strips: h = ye - ys strip_offsets.append((cum_y, h, ys)) for w in word_dicts: w_abs_y = w['top'] + top_y # word y is relative to content top w_center = w_abs_y + w['height'] / 2 if ys <= w_center < ye: # Remap to combined coordinates w_copy = dict(w) w_copy['top'] = cum_y + (w_abs_y - ys) combined_words.append(w_copy) cum_y += h # Run row detection on combined image combined_h = combined_inv.shape[0] rows = detect_row_geometry( combined_inv, combined_words, left_x, right_x, 0, combined_h, ) # Remap y-coordinates back to absolute page coords def _combined_y_to_abs(cy: int) -> int: for c_start, s_h, abs_start in strip_offsets: if cy < c_start + s_h: return abs_start + (cy - c_start) last_c, last_h, last_abs = strip_offsets[-1] return last_abs + last_h for r in rows: abs_y = _combined_y_to_abs(r.y) abs_y_end = _combined_y_to_abs(r.y + r.height) r.y = abs_y r.height = abs_y_end - abs_y else: rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y) else: # No boxes — standard row detection rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y) duration = time.time() - t0 # Assign zone_index based on which content zone each row falls in # Build content zone list with indices zones = column_result.get("zones") or [] content_zones = [(i, z) for i, z in enumerate(zones) if z.get("zone_type") == "content"] if zones else [] # Build serializable result (exclude words to keep payload small) rows_data = [] for r in rows: # Determine zone_index zone_idx = 0 row_center_y = r.y + r.height / 2 for zi, zone in content_zones: zy = zone["y"] zh = zone["height"] if zy <= row_center_y < zy + zh: zone_idx = zi break rd = { "index": r.index, "x": r.x, "y": r.y, "width": r.width, "height": r.height, "word_count": r.word_count, "row_type": r.row_type, "gap_before": r.gap_before, "zone_index": zone_idx, } rows_data.append(rd) type_counts = {} for r in rows: type_counts[r.row_type] = type_counts.get(r.row_type, 0) + 1 row_result = { "rows": rows_data, "summary": type_counts, "total_rows": len(rows), "duration_seconds": round(duration, 2), } # Persist to DB — also invalidate word_result since rows changed await update_session_db( session_id, row_result=row_result, word_result=None, current_step=7, ) cached["row_result"] = row_result cached.pop("word_result", None) logger.info(f"OCR Pipeline: rows session {session_id}: " f"{len(rows)} rows detected ({duration:.2f}s): {type_counts}") content_rows = sum(1 for r in rows if r.row_type == "content") avg_height = round(sum(r.height for r in rows) / len(rows)) if rows else 0 await _append_pipeline_log(session_id, "rows", { "total_rows": len(rows), "content_rows": content_rows, "artifact_rows_removed": type_counts.get("header", 0) + type_counts.get("footer", 0), "avg_row_height_px": avg_height, }, duration_ms=int(duration * 1000)) return { "session_id": session_id, **row_result, } @router.post("/sessions/{session_id}/rows/manual") async def set_manual_rows(session_id: str, req: ManualRowsRequest): """Override detected rows with manual definitions.""" row_result = { "rows": req.rows, "total_rows": len(req.rows), "duration_seconds": 0, "method": "manual", } await update_session_db(session_id, row_result=row_result, word_result=None) if session_id in _cache: _cache[session_id]["row_result"] = row_result _cache[session_id].pop("word_result", None) logger.info(f"OCR Pipeline: manual rows session {session_id}: " f"{len(req.rows)} rows set") return {"session_id": session_id, **row_result} @router.post("/sessions/{session_id}/ground-truth/rows") async def save_row_ground_truth(session_id: str, req: RowGroundTruthRequest): """Save ground truth feedback for the row detection step.""" session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") ground_truth = session.get("ground_truth") or {} gt = { "is_correct": req.is_correct, "corrected_rows": req.corrected_rows, "notes": req.notes, "saved_at": datetime.utcnow().isoformat(), "row_result": session.get("row_result"), } ground_truth["rows"] = gt await update_session_db(session_id, ground_truth=ground_truth) if session_id in _cache: _cache[session_id]["ground_truth"] = ground_truth return {"session_id": session_id, "ground_truth": gt} @router.get("/sessions/{session_id}/ground-truth/rows") async def get_row_ground_truth(session_id: str): """Retrieve saved ground truth for row detection.""" session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") ground_truth = session.get("ground_truth") or {} rows_gt = ground_truth.get("rows") if not rows_gt: raise HTTPException(status_code=404, detail="No row ground truth saved") return { "session_id": session_id, "rows_gt": rows_gt, "rows_auto": session.get("row_result"), }