"""
OCR Pipeline Column Detection Endpoints (Step 5)

Detect invisible columns, manual column override, and ground truth.
Extracted from ocr_pipeline_geometry.py for file-size compliance.
"""

import logging
import time
from dataclasses import asdict
from datetime import datetime, timezone
from typing import Dict, List

import cv2
from fastapi import APIRouter, HTTPException

from cv_vocab_pipeline import (
    _detect_header_footer_gaps,
    _detect_sub_columns,
    classify_column_types,
    create_layout_image,
    create_ocr_image,
    analyze_layout,
    detect_column_geometry_zoned,
    expand_narrow_columns,
)
from ocr_pipeline_session_store import (
    get_session_db,
    update_session_db,
)
from ocr_pipeline_common import (
    _cache,
    _load_session_to_cache,
    _get_cached,
    _append_pipeline_log,
    ManualColumnsRequest,
    ColumnGroundTruthRequest,
)

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])

# Words whose Tesseract confidence is below this value are not cached for the
# sub-session row detection step.
_MIN_WORD_CONF = 30


def _parse_conf(value: object) -> int:
    """Convert one Tesseract confidence entry to an int, or -1 if unparseable.

    Accepts ints, floats and numeric strings. Depending on the
    Tesseract/pytesseract versions in use the confidence may arrive as a float
    or float string (e.g. ``"96.5"``); such values are truncated to an int
    instead of being treated as "no confidence".
    """
    try:
        return int(float(value))
    except (TypeError, ValueError, OverflowError):
        # Empty strings, None, NaN/inf and other junk all mean "unknown".
        return -1


def _ocr_sub_session_words(img_bgr, session_id: str) -> List[Dict]:
    """Run Tesseract on a sub-session image and return its word boxes.

    Args:
        img_bgr: Image array passed to ``cv2.cvtColor`` with
            ``COLOR_BGR2RGB`` before OCR.
        session_id: Only used to tag log messages.

    Returns:
        One dict per non-empty word with confidence >= ``_MIN_WORD_CONF``,
        holding ``text``, ``conf``, ``left``, ``top``, ``width``, ``height``.
        An empty list when OCR (or importing its dependencies) fails.
    """
    try:
        # Imported lazily inside the try so that a missing OCR dependency
        # degrades to "no words" instead of failing the request.
        from PIL import Image as PILImage
        import pytesseract

        pil_img = PILImage.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
        data = pytesseract.image_to_data(
            pil_img, lang='eng+deu', output_type=pytesseract.Output.DICT,
        )

        word_dicts = []
        low_conf = []  # (text, conf) pairs, collected for debugging only
        all_count = 0  # every non-empty word, regardless of confidence
        for i, raw_text in enumerate(data['text']):
            text = str(raw_text).strip()
            if not text:
                continue
            all_count += 1
            conf = _parse_conf(data['conf'][i])
            if conf < _MIN_WORD_CONF:
                # Negative confidences are excluded from the debug listing.
                if conf >= 0:
                    low_conf.append((text, conf))
                continue
            word_dicts.append({
                'text': text,
                'conf': conf,
                'left': int(data['left'][i]),
                'top': int(data['top'][i]),
                'width': int(data['width'][i]),
                'height': int(data['height'][i]),
            })

        if low_conf:
            logger.info(f"OCR Pipeline: sub-session {session_id}: {len(low_conf)} words below conf {_MIN_WORD_CONF}: {low_conf[:20]}")
        logger.info(f"OCR Pipeline: sub-session {session_id}: Tesseract found {len(word_dicts)}/{all_count} words (conf>={_MIN_WORD_CONF})")
    except Exception as e:
        # Deliberate best-effort: any OCR failure is logged and treated as
        # "no words found" so the pseudo-column can still be created.
        logger.warning(f"OCR Pipeline: sub-session {session_id}: Tesseract failed: {e}")
        return []
    return word_dicts


async def _store_column_result(session_id: str, cached: dict, column_result: dict) -> None:
    """Persist a detection result and invalidate downstream results.

    Writes ``column_result`` to the DB, clears row/word results there, sets
    ``current_step`` to 6, and then mirrors the same change in ``cached``.
    """
    await update_session_db(
        session_id,
        column_result=column_result,
        row_result=None,
        word_result=None,
        current_step=6,
    )
    cached["column_result"] = column_result
    cached.pop("row_result", None)
    cached.pop("word_result", None)


async def _detect_sub_session_columns(session_id: str, cached: dict, img_bgr) -> dict:
    """Create a single full-width pseudo-column for a sub-session (box crop).

    Column detection is skipped entirely. Tesseract and binarization are run
    here so that the row detection step can reuse the cached intermediates
    (``_word_dicts``, ``_inv``, ``_content_bounds``) instead of falling back
    to detect_column_geometry(), which may fail on small box images with
    fewer than 5 words.
    """
    h, w = img_bgr.shape[:2]

    # Binarize + invert for row detection (horizontal projection profile).
    ocr_img = create_ocr_image(img_bgr)
    inv = cv2.bitwise_not(ocr_img)

    # Cache intermediates for row detection (detect_rows reuses these).
    cached["_word_dicts"] = _ocr_sub_session_words(img_bgr, session_id)
    cached["_inv"] = inv
    cached["_content_bounds"] = (0, w, 0, h)

    column_result = {
        "columns": [{
            "type": "column_text",
            "x": 0,
            "y": 0,
            "width": w,
            "height": h,
        }],
        "zones": None,
        "boxes_detected": 0,
        "duration_seconds": 0,
        "method": "sub_session_pseudo_column",
    }
    await _store_column_result(session_id, cached, column_result)

    logger.info(f"OCR Pipeline: sub-session {session_id}: pseudo-column {w}x{h}px")
    return {"session_id": session_id, **column_result}


@router.post("/sessions/{session_id}/columns")
async def detect_columns(session_id: str):
    """Run column detection on the cropped (or dewarped) image.

    The cropped image is preferred; the dewarped image is used when no crop is
    cached. Sessions that have a ``parent_session_id`` get a single
    pseudo-column instead of real detection.

    Returns:
        ``{"session_id": ..., **column_result}``.

    Raises:
        HTTPException: 400 when neither a cropped nor a dewarped image is
            cached for the session.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    img_bgr = cached.get("cropped_bgr")
    if img_bgr is None:
        img_bgr = cached.get("dewarped_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before column detection")

    # Sub-sessions (box crops): skip column detection entirely.
    session = await get_session_db(session_id)
    if session and session.get("parent_session_id"):
        return await _detect_sub_session_columns(session_id, cached, img_bgr)

    t0 = time.time()

    # Binarized image for layout analysis
    ocr_img = create_ocr_image(img_bgr)
    h, w = ocr_img.shape[:2]

    # Phase A: Zone-aware geometry detection
    zoned_result = detect_column_geometry_zoned(ocr_img, img_bgr)

    boxes_detected = 0
    if zoned_result is None:
        # Fallback to projection-based layout
        layout_img = create_layout_image(img_bgr)
        regions = analyze_layout(layout_img, ocr_img)
        zones_data = None
    else:
        (geometries, left_x, right_x, top_y, bottom_y,
         word_dicts, inv, zones_data, boxes) = zoned_result
        content_w = right_x - left_x
        boxes_detected = len(boxes)

        # Cache intermediates for row detection (avoids second Tesseract run)
        cached["_word_dicts"] = word_dicts
        cached["_inv"] = inv
        cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
        cached["_zones_data"] = zones_data
        cached["_boxes_detected"] = boxes_detected

        # Detect header/footer early so sub-column clustering ignores them
        if inv is not None:
            header_y, footer_y = _detect_header_footer_gaps(inv, w, h)
        else:
            header_y, footer_y = None, None

        # Split sub-columns (e.g. page references) before classification
        geometries = _detect_sub_columns(geometries, content_w, left_x=left_x,
                                         top_y=top_y, header_y=header_y, footer_y=footer_y)

        # Expand narrow columns (sub-columns are often very narrow)
        geometries = expand_narrow_columns(geometries, content_w, left_x, word_dicts)

        # Phase B: Content-based classification
        regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y,
                                        left_x=left_x, right_x=right_x, inv=inv)

    duration = time.time() - t0

    columns = [asdict(r) for r in regions]

    # Distinct classification methods, in first-seen order so the response is
    # stable across runs (a plain set has arbitrary iteration order).
    methods = list(dict.fromkeys(
        c["classification_method"]
        for c in columns
        if c.get("classification_method")
    ))

    column_result = {
        "columns": columns,
        "classification_methods": methods,
        "duration_seconds": round(duration, 2),
        "boxes_detected": boxes_detected,
    }

    # Add zone data when boxes are present
    if zones_data and boxes_detected > 0:
        column_result["zones"] = zones_data

    # Persist to DB and cache -- also invalidates downstream results (rows, words)
    await _store_column_result(session_id, cached, column_result)

    col_count = sum(1 for c in columns if c["type"].startswith("column"))
    logger.info(f"OCR Pipeline: columns session {session_id}: "
                f"{col_count} columns detected, {boxes_detected} box(es) ({duration:.2f}s)")

    img_w = img_bgr.shape[1]
    await _append_pipeline_log(session_id, "columns", {
        "total_columns": len(columns),
        "column_widths_pct": [round(c["width"] / img_w * 100, 1) for c in columns],
        "column_types": [c["type"] for c in columns],
        "boxes_detected": boxes_detected,
    }, duration_ms=int(duration * 1000))

    return {
        "session_id": session_id,
        **column_result,
    }


@router.post("/sessions/{session_id}/columns/manual")
async def set_manual_columns(session_id: str, req: ManualColumnsRequest):
    """Override detected columns with manual definitions.

    Stores ``req.columns`` as the session's column result and clears row and
    word results in the DB and, if the session is cached, in the cache.

    Returns:
        ``{"session_id": ..., **column_result}`` with ``method`` set to
        ``"manual"``.
    """
    # NOTE(review): unlike the ground-truth endpoints, this does not check that
    # the session exists before writing -- confirm how update_session_db
    # behaves for an unknown session_id.
    column_result = {
        "columns": req.columns,
        "duration_seconds": 0,
        "method": "manual",
    }

    await update_session_db(session_id, column_result=column_result,
                            row_result=None, word_result=None)

    if session_id in _cache:
        _cache[session_id]["column_result"] = column_result
        _cache[session_id].pop("row_result", None)
        _cache[session_id].pop("word_result", None)

    logger.info(f"OCR Pipeline: manual columns session {session_id}: "
                f"{len(req.columns)} columns set")

    return {"session_id": session_id, **column_result}


@router.post("/sessions/{session_id}/ground-truth/columns")
async def save_column_ground_truth(session_id: str, req: ColumnGroundTruthRequest):
    """Save ground truth feedback for the column detection step.

    The feedback is stored under ``ground_truth["columns"]`` together with a
    snapshot of the session's current ``column_result``.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    gt = {
        "is_correct": req.is_correct,
        "corrected_columns": req.corrected_columns,
        "notes": req.notes,
        # Naive UTC timestamp, same string format as the deprecated
        # datetime.utcnow().isoformat().
        "saved_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
        "column_result": session.get("column_result"),
    }
    ground_truth["columns"] = gt

    await update_session_db(session_id, ground_truth=ground_truth)

    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    return {"session_id": session_id, "ground_truth": gt}


@router.get("/sessions/{session_id}/ground-truth/columns")
async def get_column_ground_truth(session_id: str):
    """Retrieve saved ground truth for column detection, including auto vs GT diff.

    Returns:
        The saved ground truth (``columns_gt``) alongside the session's
        current ``column_result`` (``columns_auto``).

    Raises:
        HTTPException: 404 when the session does not exist or has no column
            ground truth saved.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    columns_gt = ground_truth.get("columns")
    if not columns_gt:
        raise HTTPException(status_code=404, detail="No column ground truth saved")

    return {
        "session_id": session_id,
        "columns_gt": columns_gt,
        "columns_auto": session.get("column_result"),
    }