Files
breakpilot-lehrer/klausur-service/backend/ocr_pipeline_geometry.py
Benjamin Admin 8e4cbd84c2
Some checks failed
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 25s
CI / test-go-edu-search (push) Successful in 27s
CI / test-python-klausur (push) Failing after 1m59s
CI / test-python-agent-core (push) Successful in 17s
CI / test-nodejs-website (push) Successful in 17s
Invalidate grid_editor_result when exclude regions change
When exclude regions are saved or deleted, the cached grid result is
cleared so the grid rebuilds with updated exclusions on the next step.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-19 09:19:09 +01:00

1106 lines
40 KiB
Python

"""
OCR Pipeline Geometry API - Deskew, Dewarp, Structure Detection, Column Detection.
Extracted from ocr_pipeline_api.py to keep modules focused.
Each endpoint group handles a geometric correction or detection step:
- Deskew (Step 2): Correct scan rotation
- Dewarp (Step 3): Correct vertical shear / book warp
- Structure Detection: Boxes, zones, color regions, graphics
- Column Detection (Step 5): Find invisible columns
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import os
import time
from dataclasses import asdict
from datetime import datetime
from typing import Any, Dict, List, Optional
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
from cv_vocab_pipeline import (
_apply_shear,
_detect_header_footer_gaps,
_detect_sub_columns,
classify_column_types,
create_layout_image,
create_ocr_image,
analyze_layout,
deskew_image,
deskew_image_by_word_alignment,
deskew_two_pass,
detect_column_geometry_zoned,
dewarp_image,
dewarp_image_manual,
expand_narrow_columns,
)
from cv_box_detect import detect_boxes
from cv_color_detect import _COLOR_RANGES, _COLOR_HEX
from cv_graphic_detect import detect_graphic_elements
from ocr_pipeline_session_store import (
get_session_db,
update_session_db,
)
from ocr_pipeline_common import (
_cache,
_load_session_to_cache,
_get_cached,
_get_base_image_png,
_append_pipeline_log,
_filter_border_ghost_words,
ManualDeskewRequest,
DeskewGroundTruthRequest,
ManualDewarpRequest,
CombinedAdjustRequest,
DewarpGroundTruthRequest,
ManualColumnsRequest,
ColumnGroundTruthRequest,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Deskew Endpoints (Step 2)
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/deskew")
async def auto_deskew(session_id: str):
    """Automatically deskew the session image and persist the result.

    The applied correction comes from ``deskew_two_pass``. The Hough and
    word-alignment estimators are also run, but only so their angles can be
    reported alongside the applied one.

    Raises:
        HTTPException: 400 if neither an oriented nor an original image is cached.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    # Prefer the oriented image; fall back to the original one.
    source_bgr = cached.get("oriented_bgr")
    if source_bgr is None:
        source_bgr = cached.get("original_bgr")
    if source_bgr is None:
        raise HTTPException(status_code=400, detail="No image available for deskewing")

    started = time.time()

    # Authoritative correction.
    deskewed_bgr, angle_applied, two_pass_debug = deskew_two_pass(source_bgr.copy())

    # Reporting-only estimators: a failure in either is recorded as 0.0.
    try:
        _, angle_hough = deskew_image(source_bgr.copy())
    except Exception:
        angle_hough = 0.0

    source_ok, source_buf = cv2.imencode(".png", source_bgr)
    if source_ok:
        source_bytes = source_buf.tobytes()
    else:
        source_bytes = b""
    try:
        _, angle_wa = deskew_image_by_word_alignment(source_bytes)
    except Exception:
        angle_wa = 0.0

    angle_iterative = two_pass_debug.get("pass1_angle", 0.0)
    angle_residual = two_pass_debug.get("pass2_angle", 0.0)
    angle_textline = two_pass_debug.get("pass3_angle", 0.0)

    duration = time.time() - started

    # Name the result after the last pass that contributed a noticeable angle.
    if abs(angle_textline) >= 0.01:
        method_used = "three_pass"
    elif abs(angle_residual) >= 0.01:
        method_used = "two_pass"
    else:
        method_used = "iterative"

    # PNG of the corrected image (empty bytes if encoding fails).
    encoded_ok, encoded_buf = cv2.imencode(".png", deskewed_bgr)
    deskewed_png = encoded_buf.tobytes() if encoded_ok else b""

    # Binarized companion image; best effort.
    binarized_png = None
    try:
        bin_ok, bin_buf = cv2.imencode(".png", create_ocr_image(deskewed_bgr))
        if bin_ok:
            binarized_png = bin_buf.tobytes()
    except Exception as e:
        logger.warning(f"Binarization failed: {e}")

    # Confidence falls linearly with the applied angle and is floored at 0.5.
    confidence = max(0.5, 1.0 - abs(angle_applied) / 5.0)

    deskew_result = {
        "angle_hough": round(angle_hough, 3),
        "angle_word_alignment": round(angle_wa, 3),
        "angle_iterative": round(angle_iterative, 3),
        "angle_residual": round(angle_residual, 3),
        "angle_textline": round(angle_textline, 3),
        "angle_applied": round(angle_applied, 3),
        "method_used": method_used,
        "confidence": round(confidence, 2),
        "duration_seconds": round(duration, 2),
        "two_pass_debug": two_pass_debug,
    }

    # Cache first, then the database.
    cached["deskewed_bgr"] = deskewed_bgr
    cached["binarized_png"] = binarized_png
    cached["deskew_result"] = deskew_result

    db_update = {
        "deskewed_png": deskewed_png,
        "deskew_result": deskew_result,
        "current_step": 3,
    }
    if binarized_png:
        db_update["binarized_png"] = binarized_png
    await update_session_db(session_id, **db_update)

    logger.info(f"OCR Pipeline: deskew session {session_id}: "
                f"hough={angle_hough:.2f} wa={angle_wa:.2f} "
                f"iter={angle_iterative:.2f} residual={angle_residual:.2f} "
                f"textline={angle_textline:.2f} "
                f"-> {method_used} total={angle_applied:.2f}")

    # The pipeline log reuses the already-rounded values of the result dict.
    log_payload = {
        key: deskew_result[key]
        for key in ("angle_applied", "angle_iterative", "angle_residual",
                    "angle_textline", "confidence")
    }
    log_payload["method"] = method_used
    await _append_pipeline_log(session_id, "deskew", log_payload,
                               duration_ms=int(duration * 1000))

    response = {"session_id": session_id}
    response.update(deskew_result)
    response["deskewed_image_url"] = f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed"
    response["binarized_image_url"] = f"/api/v1/ocr-pipeline/sessions/{session_id}/image/binarized"
    return response
@router.post("/sessions/{session_id}/deskew/manual")
async def manual_deskew(session_id: str, req: ManualDeskewRequest):
    """Apply a manual rotation angle to the oriented image.

    The requested angle is clamped to [-5.0, 5.0] degrees, and the image is
    rotated around its centre with replicated borders. The rotated image, a
    binarized version (best effort) and the updated ``deskew_result`` are
    written to the cache and the database.

    Raises:
        HTTPException: 400 if neither an oriented nor an original image is cached.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    # Prefer the oriented image; fall back to the original one.
    img_bgr = next((v for k in ("oriented_bgr", "original_bgr")
                    if (v := cached.get(k)) is not None), None)
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="No image available for deskewing")

    angle = max(-5.0, min(5.0, req.angle))

    h, w = img_bgr.shape[:2]
    center = (w // 2, h // 2)
    M = cv2.getRotationMatrix2D(center, angle, 1.0)
    rotated = cv2.warpAffine(img_bgr, M, (w, h),
                             flags=cv2.INTER_LINEAR,
                             borderMode=cv2.BORDER_REPLICATE)

    success, png_buf = cv2.imencode(".png", rotated)
    deskewed_png = png_buf.tobytes() if success else b""

    # Binarization is best effort, but a failure is logged (as auto_deskew
    # does) instead of being swallowed silently.
    binarized_png = None
    try:
        binarized = create_ocr_image(rotated)
        success_bin, bin_buf = cv2.imencode(".png", binarized)
        binarized_png = bin_buf.tobytes() if success_bin else None
    except Exception as e:
        logger.warning(f"Binarization failed: {e}")

    # Keep any previously stored fields; override only what the manual step sets.
    deskew_result = {
        **(cached.get("deskew_result") or {}),
        "angle_applied": round(angle, 3),
        "method_used": "manual",
    }

    # Update cache
    cached["deskewed_bgr"] = rotated
    cached["binarized_png"] = binarized_png
    cached["deskew_result"] = deskew_result

    # Persist to DB
    db_update = {
        "deskewed_png": deskewed_png,
        "deskew_result": deskew_result,
    }
    if binarized_png:
        db_update["binarized_png"] = binarized_png
    await update_session_db(session_id, **db_update)

    logger.info(f"OCR Pipeline: manual deskew session {session_id}: {angle:.2f}")

    return {
        "session_id": session_id,
        "angle_applied": round(angle, 3),
        "method_used": "manual",
        "deskewed_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed",
    }
@router.post("/sessions/{session_id}/ground-truth/deskew")
async def save_deskew_ground_truth(session_id: str, req: DeskewGroundTruthRequest):
    """Store reviewer feedback for the deskew step under ``ground_truth["deskew"]``.

    The stored entry also snapshots the session's current ``deskew_result``.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    feedback = {
        "is_correct": req.is_correct,
        "corrected_angle": req.corrected_angle,
        "notes": req.notes,
        # Naive UTC timestamp in ISO format.
        "saved_at": datetime.utcnow().isoformat(),
        "deskew_result": session.get("deskew_result"),
    }

    # Feedback for other steps stored in the same dict is left untouched.
    all_ground_truth = session.get("ground_truth") or {}
    all_ground_truth["deskew"] = feedback
    await update_session_db(session_id, ground_truth=all_ground_truth)

    # Mirror the change into the cache if the session is loaded.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = all_ground_truth

    logger.info(f"OCR Pipeline: ground truth deskew session {session_id}: "
                f"correct={req.is_correct}, corrected_angle={req.corrected_angle}")

    return {"session_id": session_id, "ground_truth": feedback}
# ---------------------------------------------------------------------------
# Dewarp Endpoints
# ---------------------------------------------------------------------------
async def _detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]:
    """Ask a vision-language model to estimate the vertical shear of a page.

    The image is sent base64-encoded to the Ollama ``/api/generate`` endpoint
    (base URL from ``OLLAMA_BASE_URL``, model from ``OLLAMA_HTR_MODEL``,
    default ``qwen2.5vl:32b``). The model is asked to reply with a JSON object
    containing ``shear_degrees`` and ``confidence``.

    Args:
        image_bytes: Encoded image to send to the model.

    Returns:
        A dict with ``method``, ``shear_degrees`` (clamped to [-3.0, 3.0]) and
        ``confidence`` (clamped to [0.0, 1.0]). If the request fails, the reply
        cannot be parsed, or it contains non-finite numbers, ``shear_degrees``
        and ``confidence`` are both 0.0.
    """
    import base64
    import json
    import math
    import re

    import httpx

    ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
    model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")

    prompt = (
        "This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. "
        "Are they perfectly vertical, or do they tilt slightly? "
        "If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). "
        "Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} "
        "Use confidence 0.0-1.0 based on how clearly you can see the tilt. "
        "If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}"
    )

    img_b64 = base64.b64encode(image_bytes).decode("utf-8")
    payload = {
        "model": model,
        "prompt": prompt,
        "images": [img_b64],
        "stream": False,
    }

    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.post(f"{ollama_base}/api/generate", json=payload)
            resp.raise_for_status()
            text = resp.json().get("response", "")

        # Parse JSON from response (may have surrounding text)
        match = re.search(r'\{[^}]+\}', text)
        if match:
            data = json.loads(match.group(0))
            shear = float(data.get("shear_degrees", 0.0))
            conf = float(data.get("confidence", 0.0))
            # json.loads accepts NaN/Infinity. Comparisons with NaN are always
            # false, so the clamp below would turn a NaN into the maximum
            # shear/confidence. Treat non-finite values as a failed detection.
            if math.isfinite(shear) and math.isfinite(conf):
                # Clamp to reasonable range
                shear = max(-3.0, min(3.0, shear))
                conf = max(0.0, min(1.0, conf))
                return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)}
            logger.warning(f"VLM dewarp returned non-finite values: {match.group(0)}")
    except Exception as e:
        # Best effort: any failure degrades to the zero-confidence result below.
        logger.warning(f"VLM dewarp failed: {e}")

    return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0}
@router.post("/sessions/{session_id}/dewarp")
async def auto_dewarp(
    session_id: str,
    method: str = Query("ensemble", description="Detection method: ensemble | vlm | cv"),
):
    """Estimate the vertical shear of the deskewed image and correct it.

    Methods:
    - **ensemble** (default): 3-method CV ensemble (vertical edges + projection + Hough)
    - **cv**: CV ensemble only (same as ensemble)
    - **vlm**: Ask qwen2.5vl:32b to estimate the shear angle visually

    Raises:
        HTTPException: 400 for an unknown method or when no deskewed image is cached.
    """
    allowed_methods = ("ensemble", "cv", "vlm")
    if method not in allowed_methods:
        raise HTTPException(status_code=400, detail="method must be one of: ensemble, cv, vlm")

    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    source_bgr = cached.get("deskewed_bgr")
    if source_bgr is None:
        raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")

    started = time.time()

    if method != "vlm":
        # "ensemble" and "cv" both take the CV path.
        dewarped_bgr, dewarp_info = dewarp_image(source_bgr)
    else:
        # The VLM receives the deskewed image as PNG bytes.
        encoded_ok, encoded_buf = cv2.imencode(".png", source_bgr)
        vlm_det = await _detect_shear_with_vlm(encoded_buf.tobytes() if encoded_ok else b"")
        shear_deg = vlm_det["shear_degrees"]
        # Correct only when the estimate is both noticeable and reasonably confident.
        apply_correction = abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3
        dewarped_bgr = _apply_shear(source_bgr, -shear_deg) if apply_correction else source_bgr
        dewarp_info = {
            "method": vlm_det["method"],
            "shear_degrees": shear_deg,
            "confidence": vlm_det["confidence"],
            "detections": [vlm_det],
        }

    duration = time.time() - started

    # PNG of the corrected image (empty bytes if encoding fails).
    png_ok, png_buf = cv2.imencode(".png", dewarped_bgr)
    dewarped_png = png_buf.tobytes() if png_ok else b""

    detections = dewarp_info.get("detections", [])
    dewarp_result = {
        "method_used": dewarp_info["method"],
        "shear_degrees": dewarp_info["shear_degrees"],
        "confidence": dewarp_info["confidence"],
        "duration_seconds": round(duration, 2),
        "detections": detections,
    }

    # Cache first, then the database.
    cached["dewarped_bgr"] = dewarped_bgr
    cached["dewarp_result"] = dewarp_result

    await update_session_db(
        session_id,
        dewarped_png=dewarped_png,
        dewarp_result=dewarp_result,
        auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
        current_step=4,
    )

    logger.info(f"OCR Pipeline: dewarp session {session_id}: "
                f"method={dewarp_info['method']} shear={dewarp_info['shear_degrees']:.3f} "
                f"conf={dewarp_info['confidence']:.2f} ({duration:.2f}s)")

    log_payload = {
        "shear_degrees": dewarp_info["shear_degrees"],
        "confidence": dewarp_info["confidence"],
        "method": dewarp_info["method"],
        "ensemble_methods": [det.get("method", "") for det in dewarp_info.get("detections", [])],
    }
    await _append_pipeline_log(session_id, "dewarp", log_payload,
                               duration_ms=int(duration * 1000))

    response = {"session_id": session_id}
    response.update(dewarp_result)
    response["dewarped_image_url"] = f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped"
    return response
@router.post("/sessions/{session_id}/dewarp/manual")
async def manual_dewarp(session_id: str, req: ManualDewarpRequest):
    """Shear-correct the deskewed image by a user-supplied angle.

    The angle is clamped to [-2.0, 2.0] degrees. Angles smaller than 0.001 in
    magnitude leave the image unchanged.

    Raises:
        HTTPException: 400 when no deskewed image is cached.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    source_bgr = cached.get("deskewed_bgr")
    if source_bgr is None:
        raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")

    shear_deg = max(-2.0, min(2.0, req.shear_degrees))
    shear_rounded = round(shear_deg, 3)

    # A negligible angle passes the deskewed image through untouched.
    dewarped_bgr = (
        source_bgr
        if abs(shear_deg) < 0.001
        else dewarp_image_manual(source_bgr, shear_deg)
    )

    encoded_ok, encoded_buf = cv2.imencode(".png", dewarped_bgr)
    dewarped_png = encoded_buf.tobytes() if encoded_ok else b""

    # Start from the previously stored result and override the manual fields.
    dewarp_result = dict(cached.get("dewarp_result") or {})
    dewarp_result["method_used"] = "manual"
    dewarp_result["shear_degrees"] = shear_rounded

    cached["dewarped_bgr"] = dewarped_bgr
    cached["dewarp_result"] = dewarp_result

    await update_session_db(
        session_id,
        dewarped_png=dewarped_png,
        dewarp_result=dewarp_result,
    )

    logger.info(f"OCR Pipeline: manual dewarp session {session_id}: shear={shear_deg:.3f}")

    return {
        "session_id": session_id,
        "shear_degrees": shear_rounded,
        "method_used": "manual",
        "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
    }
@router.post("/sessions/{session_id}/adjust-combined")
async def adjust_combined(session_id: str, req: CombinedAdjustRequest):
    """Apply rotation + shear combined to the original image.

    Used by the fine-tuning sliders to apply arbitrary rotation/shear
    combinations without re-running the full deskew/dewarp pipeline.
    Rotation is clamped to [-15, 15] degrees and shear to [-5, 5] degrees.

    The result is stored as BOTH the deskewed and the dewarped image, in the
    cache and in the database, so this endpoint persists its output rather
    than only previewing it.

    Raises:
        HTTPException: 400 when the original image is not cached.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    img_bgr = cached.get("original_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Original image not available")

    rotation = max(-15.0, min(15.0, req.rotation_degrees))
    shear_deg = max(-5.0, min(5.0, req.shear_degrees))

    h, w = img_bgr.shape[:2]
    result_bgr = img_bgr

    # Step 1: Apply rotation (skipped for negligible angles)
    if abs(rotation) >= 0.001:
        center = (w // 2, h // 2)
        M = cv2.getRotationMatrix2D(center, rotation, 1.0)
        result_bgr = cv2.warpAffine(result_bgr, M, (w, h),
                                    flags=cv2.INTER_LINEAR,
                                    borderMode=cv2.BORDER_REPLICATE)

    # Step 2: Apply shear (skipped for negligible angles)
    if abs(shear_deg) >= 0.001:
        result_bgr = dewarp_image_manual(result_bgr, shear_deg)

    # Encode
    success, png_buf = cv2.imencode(".png", result_bgr)
    dewarped_png = png_buf.tobytes() if success else b""

    # Binarization is best effort, but a failure is logged (as auto_deskew
    # does) instead of being swallowed silently.
    binarized_png = None
    try:
        binarized = create_ocr_image(result_bgr)
        success_bin, bin_buf = cv2.imencode(".png", binarized)
        binarized_png = bin_buf.tobytes() if success_bin else None
    except Exception as e:
        logger.warning(f"Binarization failed: {e}")

    # Build combined result dicts on top of whatever was stored before
    deskew_result = {
        **(cached.get("deskew_result") or {}),
        "angle_applied": round(rotation, 3),
        "method_used": "manual_combined",
    }
    dewarp_result = {
        **(cached.get("dewarp_result") or {}),
        "method_used": "manual_combined",
        "shear_degrees": round(shear_deg, 3),
    }

    # Update cache
    cached["deskewed_bgr"] = result_bgr
    cached["dewarped_bgr"] = result_bgr
    cached["deskew_result"] = deskew_result
    cached["dewarp_result"] = dewarp_result

    # Persist to DB
    db_update = {
        "dewarped_png": dewarped_png,
        "deskew_result": deskew_result,
        "dewarp_result": dewarp_result,
    }
    if binarized_png:
        db_update["binarized_png"] = binarized_png
        # Keep the cached binarized image in sync with what is persisted;
        # otherwise the cache would keep the binarization of the previous geometry.
        cached["binarized_png"] = binarized_png
    db_update["deskewed_png"] = dewarped_png
    await update_session_db(session_id, **db_update)

    logger.info(f"OCR Pipeline: combined adjust session {session_id}: "
                f"rotation={rotation:.3f} shear={shear_deg:.3f}")

    return {
        "session_id": session_id,
        "rotation_degrees": round(rotation, 3),
        "shear_degrees": round(shear_deg, 3),
        "method_used": "manual_combined",
        "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
    }
@router.post("/sessions/{session_id}/ground-truth/dewarp")
async def save_dewarp_ground_truth(session_id: str, req: DewarpGroundTruthRequest):
    """Store reviewer feedback for the dewarp step under ``ground_truth["dewarp"]``.

    The stored entry also snapshots the session's current ``dewarp_result``.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    feedback = {
        "is_correct": req.is_correct,
        "corrected_shear": req.corrected_shear,
        "notes": req.notes,
        # Naive UTC timestamp in ISO format.
        "saved_at": datetime.utcnow().isoformat(),
        "dewarp_result": session.get("dewarp_result"),
    }

    # Feedback for other steps stored in the same dict is left untouched.
    all_ground_truth = session.get("ground_truth") or {}
    all_ground_truth["dewarp"] = feedback
    await update_session_db(session_id, ground_truth=all_ground_truth)

    # Mirror the change into the cache if the session is loaded.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = all_ground_truth

    logger.info(f"OCR Pipeline: ground truth dewarp session {session_id}: "
                f"correct={req.is_correct}, corrected_shear={req.corrected_shear}")

    return {"session_id": session_id, "ground_truth": feedback}
# ---------------------------------------------------------------------------
# Structure Detection Endpoint
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/detect-structure")
async def detect_structure(session_id: str):
    """Detect document structure: boxes, zones, and color regions.

    Runs box detection (line + shading) and color analysis on the cropped
    image (falling back to the dewarped image). Returns structured JSON with
    all detected elements for the structure visualization step.

    Steps, in order: derive content bounds from OCR words (or a 3% margin),
    detect boxes, split the page into zones, sample each box's background
    color, count colored pixels per configured color, detect graphic
    elements, and filter border-ghost words from the OCR result. Exclude
    regions from a previous ``structure_result`` are carried over.

    Raises:
        HTTPException: 400 when neither a cropped nor a dewarped image is cached.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)
    # Prefer the cropped image; fall back to the dewarped one.
    img_bgr = (
        cached.get("cropped_bgr")
        if cached.get("cropped_bgr") is not None
        else cached.get("dewarped_bgr")
    )
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first")
    t0 = time.time()
    h, w = img_bgr.shape[:2]
    # --- Content bounds from word result (if available) or full image ---
    word_result = cached.get("word_result")
    words: List[Dict] = []
    # First choice: word boxes attached to the recognized cells.
    if word_result and word_result.get("cells"):
        for cell in word_result["cells"]:
            for wb in (cell.get("word_boxes") or []):
                words.append(wb)
    # Fallback: use raw OCR words if cell word_boxes are empty
    # (the first non-empty list in the key order below wins)
    if not words and word_result:
        for key in ("raw_paddle_words_split", "raw_tesseract_words", "raw_paddle_words"):
            raw = word_result.get(key, [])
            if raw:
                words = raw
                logger.info("detect-structure: using %d words from %s (no cell word_boxes)", len(words), key)
                break
    # If no words yet, use image dimensions with small margin
    if words:
        # Bounding rectangle of all words, clipped to the image.
        # Each word is expected to carry left/top/width/height keys.
        content_x = max(0, min(int(wb["left"]) for wb in words))
        content_y = max(0, min(int(wb["top"]) for wb in words))
        content_r = min(w, max(int(wb["left"] + wb["width"]) for wb in words))
        content_b = min(h, max(int(wb["top"] + wb["height"]) for wb in words))
        content_w_px = content_r - content_x
        content_h_px = content_b - content_y
    else:
        # 3% of the shorter image side on every edge.
        margin = int(min(w, h) * 0.03)
        content_x, content_y = margin, margin
        content_w_px = w - 2 * margin
        content_h_px = h - 2 * margin
    # --- Box detection ---
    boxes = detect_boxes(
        img_bgr,
        content_x=content_x,
        content_w=content_w_px,
        content_y=content_y,
        content_h=content_h_px,
    )
    # --- Zone splitting ---
    from cv_box_detect import split_page_into_zones as _split_zones
    zones = _split_zones(content_x, content_y, content_w_px, content_h_px, boxes)
    # --- Color region sampling ---
    # Sample background shading in each detected box
    hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)
    # One entry per box, in the same order as `boxes` (indexed by position below).
    box_colors = []
    for box in boxes:
        # Sample the center region of each box (middle half in both directions)
        cy1 = box.y + box.height // 4
        cy2 = box.y + 3 * box.height // 4
        cx1 = box.x + box.width // 4
        cx2 = box.x + 3 * box.width // 4
        # Clip the sample window to the image.
        cy1 = max(0, min(cy1, h - 1))
        cy2 = max(0, min(cy2, h - 1))
        cx1 = max(0, min(cx1, w - 1))
        cx2 = max(0, min(cx2, w - 1))
        if cy2 > cy1 and cx2 > cx1:
            roi_hsv = hsv[cy1:cy2, cx1:cx2]
            med_h = float(np.median(roi_hsv[:, :, 0]))
            med_s = float(np.median(roi_hsv[:, :, 1]))
            med_v = float(np.median(roi_hsv[:, :, 2]))
            if med_s > 15:
                # Saturated enough to count as colored: name it by median hue.
                from cv_color_detect import _hue_to_color_name
                bg_name = _hue_to_color_name(med_h)
                bg_hex = _COLOR_HEX.get(bg_name, "#6b7280")
            else:
                # Unsaturated: distinguish gray from white by median brightness.
                bg_name = "gray" if med_v < 220 else "white"
                bg_hex = "#6b7280" if bg_name == "gray" else "#ffffff"
        else:
            # Degenerate sample window after clipping.
            bg_name = "unknown"
            bg_hex = "#6b7280"
        box_colors.append({"color_name": bg_name, "color_hex": bg_hex})
    # --- Color text detection overview ---
    # Quick scan for colored text regions across the page
    color_summary: Dict[str, int] = {}
    for color_name, ranges in _COLOR_RANGES.items():
        # Union of all HSV ranges configured for this color.
        mask = np.zeros((h, w), dtype=np.uint8)
        for lower, upper in ranges:
            mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper))
        pixel_count = int(np.sum(mask > 0))
        if pixel_count > 50:  # minimum threshold
            color_summary[color_name] = pixel_count
    # --- Graphic element detection ---
    box_dicts = [
        {"x": b.x, "y": b.y, "w": b.width, "h": b.height}
        for b in boxes
    ]
    graphics = detect_graphic_elements(
        img_bgr, words,
        detected_boxes=box_dicts,
    )
    # --- Filter border-ghost words from OCR result ---
    ghost_count = 0
    if boxes and word_result:
        # NOTE(review): presumably mutates word_result in place and returns the
        # number of removed words; confirm against _filter_border_ghost_words.
        ghost_count = _filter_border_ghost_words(word_result, boxes)
        if ghost_count:
            logger.info("detect-structure: removed %d border-ghost words", ghost_count)
            await update_session_db(session_id, word_result=word_result)
            cached["word_result"] = word_result
    duration = time.time() - t0
    # Preserve user-drawn exclude regions from previous run
    prev_sr = cached.get("structure_result") or {}
    prev_exclude = prev_sr.get("exclude_regions", [])
    result_dict = {
        "image_width": w,
        "image_height": h,
        "content_bounds": {
            "x": content_x, "y": content_y,
            "w": content_w_px, "h": content_h_px,
        },
        "boxes": [
            {
                "x": b.x, "y": b.y, "w": b.width, "h": b.height,
                "confidence": b.confidence,
                "border_thickness": b.border_thickness,
                "bg_color_name": box_colors[i]["color_name"],
                "bg_color_hex": box_colors[i]["color_hex"],
            }
            for i, b in enumerate(boxes)
        ],
        "zones": [
            {
                "index": z.index,
                "zone_type": z.zone_type,
                "y": z.y, "h": z.height,
                "x": z.x, "w": z.width,
            }
            for z in zones
        ],
        "graphics": [
            {
                "x": g.x, "y": g.y, "w": g.width, "h": g.height,
                "area": g.area,
                "shape": g.shape,
                "color_name": g.color_name,
                "color_hex": g.color_hex,
                "confidence": round(g.confidence, 2),
            }
            for g in graphics
        ],
        "exclude_regions": prev_exclude,
        "color_pixel_counts": color_summary,
        # `words` was collected before ghost filtering ran.
        "has_words": len(words) > 0,
        "word_count": len(words),
        "border_ghosts_removed": ghost_count,
        "duration_seconds": round(duration, 2),
    }
    # Persist to session
    await update_session_db(session_id, structure_result=result_dict)
    cached["structure_result"] = result_dict
    logger.info("detect-structure session %s: %d boxes, %d zones, %d graphics, %.2fs",
                session_id, len(boxes), len(zones), len(graphics), duration)
    return {"session_id": session_id, **result_dict}
# ---------------------------------------------------------------------------
# Exclude Regions — user-drawn rectangles to exclude from OCR results
# ---------------------------------------------------------------------------
class _ExcludeRegionIn(BaseModel):
    """Request model for a single user-drawn exclude rectangle."""
    # Rectangle position and size.
    # NOTE(review): presumably image pixel coordinates (top-left x/y plus
    # width/height); confirm against the frontend. No range validation is done here.
    x: int
    y: int
    w: int
    h: int
    # Optional free-text label; defaults to an empty string.
    label: str = ""
class _ExcludeRegionsBatchIn(BaseModel):
    """Request body for replacing all exclude regions of a session at once."""
    # Complete new list of regions; an empty list clears all regions.
    regions: list[_ExcludeRegionIn]
@router.put("/sessions/{session_id}/exclude-regions")
async def set_exclude_regions(session_id: str, body: _ExcludeRegionsBatchIn):
    """Overwrite the session's exclude regions with the submitted list.

    Regions are stored inside ``structure_result.exclude_regions``. The stored
    ``grid_editor_result`` is cleared, in the database and in the cache, so
    that the grid is rebuilt with the new exclusions.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    structure = session.get("structure_result") or {}
    new_regions = [region.model_dump() for region in body.regions]
    structure["exclude_regions"] = new_regions

    # Persist the regions and drop the stale grid in one update.
    await update_session_db(session_id, structure_result=structure, grid_editor_result=None)

    # Mirror both changes into the cache if the session is loaded.
    if session_id in _cache:
        cache_entry = _cache[session_id]
        cache_entry["structure_result"] = structure
        cache_entry.pop("grid_editor_result", None)

    return {
        "session_id": session_id,
        "exclude_regions": new_regions,
        "count": len(new_regions),
    }
@router.delete("/sessions/{session_id}/exclude-regions/{region_index}")
async def delete_exclude_region(session_id: str, region_index: int):
    """Delete the exclude region at ``region_index`` (0-based).

    The stored ``grid_editor_result`` is cleared, in the database and in the
    cache, so that the grid is rebuilt without the removed region.

    Raises:
        HTTPException: 404 if the session does not exist or the index is out
            of range.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    structure = session.get("structure_result") or {}
    regions = structure.get("exclude_regions", [])

    # Negative indices are rejected rather than counted from the end.
    if not (0 <= region_index < len(regions)):
        raise HTTPException(status_code=404, detail="Region index out of range")

    removed = regions.pop(region_index)
    structure["exclude_regions"] = regions

    # Persist the shortened list and drop the stale grid in one update.
    await update_session_db(session_id, structure_result=structure, grid_editor_result=None)

    # Mirror both changes into the cache if the session is loaded.
    if session_id in _cache:
        cache_entry = _cache[session_id]
        cache_entry["structure_result"] = structure
        cache_entry.pop("grid_editor_result", None)

    return {
        "session_id": session_id,
        "removed": removed,
        "remaining": len(regions),
    }
# ---------------------------------------------------------------------------
# Column Detection Endpoints (Step 5)
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/columns")
async def detect_columns(session_id: str):
    """Run column detection on the cropped (or dewarped) image.

    Two paths:

    * Sub-sessions (the session has a ``parent_session_id``): no detection is
      run. A single pseudo-column spanning the whole image is stored, and
      Tesseract words plus the inverted binarized image are cached for the
      row-detection step.
    * Regular sessions: zone-aware geometry detection followed by
      content-based column classification, with a projection-based layout
      analysis as fallback when zoned detection returns ``None``.

    Both paths reset ``row_result`` and ``word_result`` and set
    ``current_step`` to 6.

    Raises:
        HTTPException: 400 when neither a cropped nor a dewarped image is cached.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)
    # Prefer the cropped image; fall back to the dewarped one.
    img_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before column detection")
    # -----------------------------------------------------------------------
    # Sub-sessions (box crops): skip column detection entirely.
    # Instead, create a single pseudo-column spanning the full image width.
    # Also run Tesseract + binarization here so that the row detection step
    # can reuse the cached intermediates (_word_dicts, _inv, _content_bounds)
    # instead of falling back to detect_column_geometry() which may fail
    # on small box images with < 5 words.
    # -----------------------------------------------------------------------
    session = await get_session_db(session_id)
    if session and session.get("parent_session_id"):
        h, w = img_bgr.shape[:2]
        # Binarize + invert for row detection (horizontal projection profile)
        ocr_img = create_ocr_image(img_bgr)
        inv = cv2.bitwise_not(ocr_img)
        # Run Tesseract to get word bounding boxes.
        # Word positions are relative to the full image (no ROI crop needed
        # because the sub-session image IS the cropped box already).
        # detect_row_geometry expects word positions relative to content ROI,
        # so with content_bounds = (0, w, 0, h) the coordinates are correct.
        try:
            from PIL import Image as PILImage
            pil_img = PILImage.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
            import pytesseract
            data = pytesseract.image_to_data(pil_img, lang='eng+deu', output_type=pytesseract.Output.DICT)
            word_dicts = []
            for i in range(len(data['text'])):
                # Confidence values that are not plain (optionally negative)
                # integers are mapped to -1.
                conf = int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1
                text = str(data['text'][i]).strip()
                # Keep only non-empty words with confidence >= 30.
                if conf < 30 or not text:
                    continue
                word_dicts.append({
                    'text': text, 'conf': conf,
                    'left': int(data['left'][i]),
                    'top': int(data['top'][i]),
                    'width': int(data['width'][i]),
                    'height': int(data['height'][i]),
                })
            # Log all words including low-confidence ones for debugging
            all_count = sum(1 for i in range(len(data['text']))
                            if str(data['text'][i]).strip())
            # Non-empty words with a confidence in [0, 30), as (text, conf) pairs.
            low_conf = [(str(data['text'][i]).strip(), int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1)
                        for i in range(len(data['text']))
                        if str(data['text'][i]).strip()
                        and (int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1) < 30
                        and (int(data['conf'][i]) if str(data['conf'][i]).lstrip('-').isdigit() else -1) >= 0]
            if low_conf:
                logger.info(f"OCR Pipeline: sub-session {session_id}: {len(low_conf)} words below conf 30: {low_conf[:20]}")
            logger.info(f"OCR Pipeline: sub-session {session_id}: Tesseract found {len(word_dicts)}/{all_count} words (conf>=30)")
        except Exception as e:
            # Best effort: continue with no words if OCR (or its imports) fails.
            logger.warning(f"OCR Pipeline: sub-session {session_id}: Tesseract failed: {e}")
            word_dicts = []
        # Cache intermediates for row detection (detect_rows reuses these)
        cached["_word_dicts"] = word_dicts
        cached["_inv"] = inv
        cached["_content_bounds"] = (0, w, 0, h)
        column_result = {
            "columns": [{
                "type": "column_text",
                "x": 0, "y": 0,
                "width": w, "height": h,
            }],
            "zones": None,
            "boxes_detected": 0,
            "duration_seconds": 0,
            "method": "sub_session_pseudo_column",
        }
        # Persist and invalidate downstream results (rows, words).
        await update_session_db(
            session_id,
            column_result=column_result,
            row_result=None,
            word_result=None,
            current_step=6,
        )
        cached["column_result"] = column_result
        cached.pop("row_result", None)
        cached.pop("word_result", None)
        logger.info(f"OCR Pipeline: sub-session {session_id}: pseudo-column {w}x{h}px")
        return {"session_id": session_id, **column_result}
    t0 = time.time()
    # Binarized image for layout analysis
    ocr_img = create_ocr_image(img_bgr)
    h, w = ocr_img.shape[:2]
    # Phase A: Zone-aware geometry detection
    zoned_result = detect_column_geometry_zoned(ocr_img, img_bgr)
    if zoned_result is None:
        # Fallback to projection-based layout
        # (no intermediates are cached for row detection on this path)
        layout_img = create_layout_image(img_bgr)
        regions = analyze_layout(layout_img, ocr_img)
        zones_data = None
        boxes_detected = 0
    else:
        geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv, zones_data, boxes = zoned_result
        content_w = right_x - left_x
        boxes_detected = len(boxes)
        # Cache intermediates for row detection (avoids second Tesseract run)
        cached["_word_dicts"] = word_dicts
        cached["_inv"] = inv
        cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
        cached["_zones_data"] = zones_data
        cached["_boxes_detected"] = boxes_detected
        # Detect header/footer early so sub-column clustering ignores them
        header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None)
        # Split sub-columns (e.g. page references) before classification
        geometries = _detect_sub_columns(geometries, content_w, left_x=left_x,
                                         top_y=top_y, header_y=header_y, footer_y=footer_y)
        # Expand narrow columns (sub-columns are often very narrow)
        geometries = expand_narrow_columns(geometries, content_w, left_x, word_dicts)
        # Phase B: Content-based classification
        regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y,
                                        left_x=left_x, right_x=right_x, inv=inv)
    duration = time.time() - t0
    # Regions are dataclass instances; convert them to plain dicts.
    columns = [asdict(r) for r in regions]
    # Determine classification methods used
    # (deduplicated through a set, so the list order is not defined)
    methods = list(set(
        c.get("classification_method", "") for c in columns
        if c.get("classification_method")
    ))
    column_result = {
        "columns": columns,
        "classification_methods": methods,
        "duration_seconds": round(duration, 2),
        "boxes_detected": boxes_detected,
    }
    # Add zone data when boxes are present
    if zones_data and boxes_detected > 0:
        column_result["zones"] = zones_data
    # Persist to DB — also invalidate downstream results (rows, words)
    await update_session_db(
        session_id,
        column_result=column_result,
        row_result=None,
        word_result=None,
        current_step=6,
    )
    # Update cache
    cached["column_result"] = column_result
    cached.pop("row_result", None)
    cached.pop("word_result", None)
    # Only regions whose type starts with "column" are counted in the log line.
    col_count = len([c for c in columns if c["type"].startswith("column")])
    logger.info(f"OCR Pipeline: columns session {session_id}: "
                f"{col_count} columns detected, {boxes_detected} box(es) ({duration:.2f}s)")
    img_w = img_bgr.shape[1]
    await _append_pipeline_log(session_id, "columns", {
        "total_columns": len(columns),
        # Column widths as a percentage of the image width.
        "column_widths_pct": [round(c["width"] / img_w * 100, 1) for c in columns],
        "column_types": [c["type"] for c in columns],
        "boxes_detected": boxes_detected,
    }, duration_ms=int(duration * 1000))
    return {
        "session_id": session_id,
        **column_result,
    }
@router.post("/sessions/{session_id}/columns/manual")
async def set_manual_columns(session_id: str, req: ManualColumnsRequest):
    """Replace the detected columns with user-provided definitions.

    Downstream results (``row_result`` and ``word_result``) are reset in the
    database and removed from the cache.
    """
    manual_columns = req.columns
    column_result = {
        "columns": manual_columns,
        "duration_seconds": 0,
        "method": "manual",
    }

    await update_session_db(
        session_id,
        column_result=column_result,
        row_result=None,
        word_result=None,
    )

    # Mirror the change into the cache if the session is loaded.
    if session_id in _cache:
        cache_entry = _cache[session_id]
        cache_entry["column_result"] = column_result
        for stale_key in ("row_result", "word_result"):
            cache_entry.pop(stale_key, None)

    logger.info(f"OCR Pipeline: manual columns session {session_id}: "
                f"{len(req.columns)} columns set")

    response = {"session_id": session_id}
    response.update(column_result)
    return response
@router.post("/sessions/{session_id}/ground-truth/columns")
async def save_column_ground_truth(session_id: str, req: ColumnGroundTruthRequest):
    """Store reviewer feedback for column detection under ``ground_truth["columns"]``.

    The stored entry also snapshots the session's current ``column_result``.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    feedback = {
        "is_correct": req.is_correct,
        "corrected_columns": req.corrected_columns,
        "notes": req.notes,
        # Naive UTC timestamp in ISO format.
        "saved_at": datetime.utcnow().isoformat(),
        "column_result": session.get("column_result"),
    }

    # Feedback for other steps stored in the same dict is left untouched.
    all_ground_truth = session.get("ground_truth") or {}
    all_ground_truth["columns"] = feedback
    await update_session_db(session_id, ground_truth=all_ground_truth)

    # Mirror the change into the cache if the session is loaded.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = all_ground_truth

    return {"session_id": session_id, "ground_truth": feedback}
@router.get("/sessions/{session_id}/ground-truth/columns")
async def get_column_ground_truth(session_id: str):
    """Return the saved column ground truth next to the automatic result.

    The response carries both ``columns_gt`` and ``columns_auto`` so that the
    caller can compare them; no diff is computed here.

    Raises:
        HTTPException: 404 if the session does not exist or no column ground
            truth has been saved.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    columns_gt = (session.get("ground_truth") or {}).get("columns")
    if not columns_gt:
        raise HTTPException(status_code=404, detail="No column ground truth saved")

    response = {"session_id": session_id}
    response["columns_gt"] = columns_gt
    response["columns_auto"] = session.get("column_result")
    return response