Some checks failed
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 25s
CI / test-go-edu-search (push) Successful in 27s
CI / test-python-klausur (push) Failing after 1m59s
CI / test-python-agent-core (push) Successful in 17s
CI / test-nodejs-website (push) Successful in 17s
When exclude regions are saved or deleted, the cached grid result is cleared so the grid rebuilds with updated exclusions on the next step. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1106 lines
40 KiB
Python
1106 lines
40 KiB
Python
"""
|
|
OCR Pipeline Geometry API - Deskew, Dewarp, Structure Detection, Column Detection.
|
|
|
|
Extracted from ocr_pipeline_api.py to keep modules focused.
|
|
Each endpoint group handles a geometric correction or detection step:
|
|
- Deskew (Step 2): Correct scan rotation
|
|
- Dewarp (Step 3): Correct vertical shear / book warp
|
|
- Structure Detection: Boxes, zones, color regions, graphics
|
|
- Column Detection (Step 5): Find invisible columns
|
|
|
|
Lizenz: Apache 2.0
|
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
"""
|
|
|
|
import logging
import os
import time
from dataclasses import asdict
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import cv2
import numpy as np
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel

from cv_vocab_pipeline import (
    _apply_shear,
    _detect_header_footer_gaps,
    _detect_sub_columns,
    classify_column_types,
    create_layout_image,
    create_ocr_image,
    analyze_layout,
    deskew_image,
    deskew_image_by_word_alignment,
    deskew_two_pass,
    detect_column_geometry_zoned,
    dewarp_image,
    dewarp_image_manual,
    expand_narrow_columns,
)
from cv_box_detect import detect_boxes
from cv_color_detect import _COLOR_RANGES, _COLOR_HEX
from cv_graphic_detect import detect_graphic_elements
from ocr_pipeline_session_store import (
    get_session_db,
    update_session_db,
)
from ocr_pipeline_common import (
    _cache,
    _load_session_to_cache,
    _get_cached,
    _get_base_image_png,
    _append_pipeline_log,
    _filter_border_ghost_words,
    ManualDeskewRequest,
    DeskewGroundTruthRequest,
    ManualDewarpRequest,
    CombinedAdjustRequest,
    DewarpGroundTruthRequest,
    ManualColumnsRequest,
    ColumnGroundTruthRequest,
)
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# All geometry endpoints live under the shared OCR-pipeline prefix so they
# mount alongside the main pipeline API.
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])

# ---------------------------------------------------------------------------
# Deskew Endpoints (Step 2)
# ---------------------------------------------------------------------------
|
|
|
|
@router.post("/sessions/{session_id}/deskew")
async def auto_deskew(session_id: str):
    """Two-pass deskew: iterative projection (wide range) + word-alignment residual.

    Runs ``deskew_two_pass`` as the authoritative correction, then also runs
    the individual Hough and word-alignment detectors purely for reporting.
    Persists the deskewed (and, when possible, binarized) PNG to the session
    DB and advances the session to step 3.

    Raises:
        HTTPException: 400 if no oriented/original image is available.
    """
    # Ensure session is in cache
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    # Deskew runs right after orientation — use oriented image, fall back to original
    img_bgr = next((v for k in ("oriented_bgr", "original_bgr")
                    if (v := cached.get(k)) is not None), None)
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="No image available for deskewing")

    t0 = time.time()

    # Two-pass deskew: iterative (±5°) + word-alignment residual check.
    # .copy() so the cached source image is never mutated in place.
    deskewed_bgr, angle_applied, two_pass_debug = deskew_two_pass(img_bgr.copy())

    # Also run individual methods for reporting (non-authoritative)
    try:
        _, angle_hough = deskew_image(img_bgr.copy())
    except Exception:
        angle_hough = 0.0

    # Word-alignment detector takes encoded PNG bytes, not a BGR array.
    success_enc, png_orig = cv2.imencode(".png", img_bgr)
    orig_bytes = png_orig.tobytes() if success_enc else b""
    try:
        _, angle_wa = deskew_image_by_word_alignment(orig_bytes)
    except Exception:
        angle_wa = 0.0

    # Per-pass angles reported by deskew_two_pass (0.0 when a pass was a no-op).
    angle_iterative = two_pass_debug.get("pass1_angle", 0.0)
    angle_residual = two_pass_debug.get("pass2_angle", 0.0)
    angle_textline = two_pass_debug.get("pass3_angle", 0.0)

    duration = time.time() - t0

    # Label which passes actually contributed (threshold 0.01°).
    method_used = "three_pass" if abs(angle_textline) >= 0.01 else (
        "two_pass" if abs(angle_residual) >= 0.01 else "iterative"
    )

    # Encode as PNG
    success, deskewed_png_buf = cv2.imencode(".png", deskewed_bgr)
    deskewed_png = deskewed_png_buf.tobytes() if success else b""

    # Create binarized version for OCR; failure here is non-fatal.
    binarized_png = None
    try:
        binarized = create_ocr_image(deskewed_bgr)
        success_bin, bin_buf = cv2.imencode(".png", binarized)
        binarized_png = bin_buf.tobytes() if success_bin else None
    except Exception as e:
        logger.warning(f"Binarization failed: {e}")

    # Heuristic confidence: large corrections are less trustworthy; floor at 0.5.
    confidence = max(0.5, 1.0 - abs(angle_applied) / 5.0)

    deskew_result = {
        "angle_hough": round(angle_hough, 3),
        "angle_word_alignment": round(angle_wa, 3),
        "angle_iterative": round(angle_iterative, 3),
        "angle_residual": round(angle_residual, 3),
        "angle_textline": round(angle_textline, 3),
        "angle_applied": round(angle_applied, 3),
        "method_used": method_used,
        "confidence": round(confidence, 2),
        "duration_seconds": round(duration, 2),
        "two_pass_debug": two_pass_debug,
    }

    # Update cache
    cached["deskewed_bgr"] = deskewed_bgr
    cached["binarized_png"] = binarized_png
    cached["deskew_result"] = deskew_result

    # Persist to DB (binarized PNG only when it was produced)
    db_update = {
        "deskewed_png": deskewed_png,
        "deskew_result": deskew_result,
        "current_step": 3,
    }
    if binarized_png:
        db_update["binarized_png"] = binarized_png
    await update_session_db(session_id, **db_update)

    logger.info(f"OCR Pipeline: deskew session {session_id}: "
                f"hough={angle_hough:.2f} wa={angle_wa:.2f} "
                f"iter={angle_iterative:.2f} residual={angle_residual:.2f} "
                f"textline={angle_textline:.2f} "
                f"-> {method_used} total={angle_applied:.2f}")

    await _append_pipeline_log(session_id, "deskew", {
        "angle_applied": round(angle_applied, 3),
        "angle_iterative": round(angle_iterative, 3),
        "angle_residual": round(angle_residual, 3),
        "angle_textline": round(angle_textline, 3),
        "confidence": round(confidence, 2),
        "method": method_used,
    }, duration_ms=int(duration * 1000))

    return {
        "session_id": session_id,
        **deskew_result,
        "deskewed_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed",
        "binarized_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/binarized",
    }
|
|
|
|
|
|
@router.post("/sessions/{session_id}/deskew/manual")
async def manual_deskew(session_id: str, req: ManualDeskewRequest):
    """Apply a manual rotation angle to the oriented image.

    The requested angle is clamped to ±5°, matching the range the automatic
    deskew searches. Updates the cached deskewed image, persists the PNG (and
    a binarized preview when possible) and records ``method_used = "manual"``.

    Raises:
        HTTPException: 400 if no oriented/original image is available.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    # Manual deskew operates on the oriented image, falling back to the original.
    img_bgr = next((v for k in ("oriented_bgr", "original_bgr")
                    if (v := cached.get(k)) is not None), None)
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="No image available for deskewing")

    # Clamp to the ±5° range used by the automatic deskew.
    angle = max(-5.0, min(5.0, req.angle))

    h, w = img_bgr.shape[:2]
    center = (w // 2, h // 2)
    M = cv2.getRotationMatrix2D(center, angle, 1.0)
    rotated = cv2.warpAffine(img_bgr, M, (w, h),
                             flags=cv2.INTER_LINEAR,
                             borderMode=cv2.BORDER_REPLICATE)

    success, png_buf = cv2.imencode(".png", rotated)
    deskewed_png = png_buf.tobytes() if success else b""

    # Binarize for the OCR preview. Failure is non-fatal, but log it
    # (consistent with auto_deskew) instead of silently swallowing the error.
    binarized_png = None
    try:
        binarized = create_ocr_image(rotated)
        success_bin, bin_buf = cv2.imencode(".png", binarized)
        binarized_png = bin_buf.tobytes() if success_bin else None
    except Exception as e:
        logger.warning(f"Binarization failed: {e}")

    # Merge the manual override into any previous deskew metadata.
    deskew_result = {
        **(cached.get("deskew_result") or {}),
        "angle_applied": round(angle, 3),
        "method_used": "manual",
    }

    # Update cache
    cached["deskewed_bgr"] = rotated
    cached["binarized_png"] = binarized_png
    cached["deskew_result"] = deskew_result

    # Persist to DB (binarized PNG only when it was produced)
    db_update = {
        "deskewed_png": deskewed_png,
        "deskew_result": deskew_result,
    }
    if binarized_png:
        db_update["binarized_png"] = binarized_png
    await update_session_db(session_id, **db_update)

    logger.info(f"OCR Pipeline: manual deskew session {session_id}: {angle:.2f}")

    return {
        "session_id": session_id,
        "angle_applied": round(angle, 3),
        "method_used": "manual",
        "deskewed_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed",
    }
|
|
|
|
|
|
@router.post("/sessions/{session_id}/ground-truth/deskew")
async def save_deskew_ground_truth(session_id: str, req: DeskewGroundTruthRequest):
    """Save ground truth feedback for the deskew step.

    Stores the reviewer's verdict, optional corrected angle/notes and a
    snapshot of the automatic deskew result under ``ground_truth["deskew"]``
    in the session record, and mirrors the update into the in-memory cache.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    gt = {
        "is_correct": req.is_correct,
        "corrected_angle": req.corrected_angle,
        "notes": req.notes,
        # Timezone-aware UTC timestamp; datetime.utcnow() is deprecated
        # (Python 3.12+) and returns a naive datetime.
        "saved_at": datetime.now(timezone.utc).isoformat(),
        # Snapshot of the automatic result this feedback refers to.
        "deskew_result": session.get("deskew_result"),
    }
    ground_truth["deskew"] = gt

    await update_session_db(session_id, ground_truth=ground_truth)

    # Keep the in-memory cache consistent with the DB.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    logger.info(f"OCR Pipeline: ground truth deskew session {session_id}: "
                f"correct={req.is_correct}, corrected_angle={req.corrected_angle}")

    return {"session_id": session_id, "ground_truth": gt}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Dewarp Endpoints
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def _detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]:
    """Ask qwen2.5vl:32b to estimate the vertical shear angle of a scanned page.

    The VLM is shown the image and asked whether the column/table borders are
    tilted and, if so, by how many degrees.  Returns a dict with
    ``shear_degrees`` (clamped to ±3°) and ``confidence`` (clamped to [0, 1]).
    Confidence is 0.0 if Ollama is unavailable or parsing fails.
    """
    import base64
    import json
    import re

    import httpx

    ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
    model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")

    prompt = (
        "This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. "
        "Are they perfectly vertical, or do they tilt slightly? "
        "If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). "
        "Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} "
        "Use confidence 0.0-1.0 based on how clearly you can see the tilt. "
        "If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}"
    )

    request_body = {
        "model": model,
        "prompt": prompt,
        "images": [base64.b64encode(image_bytes).decode("utf-8")],
        "stream": False,
    }

    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.post(f"{ollama_base}/api/generate", json=request_body)
            resp.raise_for_status()
            reply = resp.json().get("response", "")

        # The model may wrap the JSON in prose; grab the first {...} span.
        found = re.search(r'\{[^}]+\}', reply)
        if found:
            parsed = json.loads(found.group(0))
            # Clamp both values to their sane ranges.
            shear = max(-3.0, min(3.0, float(parsed.get("shear_degrees", 0.0))))
            conf = max(0.0, min(1.0, float(parsed.get("confidence", 0.0))))
            return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)}
    except Exception as e:
        logger.warning(f"VLM dewarp failed: {e}")

    # Fallback: no usable answer from the VLM.
    return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0}
|
|
|
|
|
|
@router.post("/sessions/{session_id}/dewarp")
async def auto_dewarp(
    session_id: str,
    method: str = Query("ensemble", description="Detection method: ensemble | vlm | cv"),
):
    """Detect and correct vertical shear on the deskewed image.

    Methods:
    - **ensemble** (default): 3-method CV ensemble (vertical edges + projection + Hough)
    - **cv**: CV ensemble only (same as ensemble)
    - **vlm**: Ask qwen2.5vl:32b to estimate the shear angle visually

    Persists the dewarped PNG and result metadata to the session DB and
    advances the session to step 4.

    Raises:
        HTTPException: 400 for an unknown method, or if deskew has not run yet.
    """
    if method not in ("ensemble", "cv", "vlm"):
        raise HTTPException(status_code=400, detail="method must be one of: ensemble, cv, vlm")

    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    # Dewarp always operates on the deskewed image — deskew is a hard prerequisite.
    deskewed_bgr = cached.get("deskewed_bgr")
    if deskewed_bgr is None:
        raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")

    t0 = time.time()

    if method == "vlm":
        # Encode deskewed image to PNG for VLM
        success, png_buf = cv2.imencode(".png", deskewed_bgr)
        img_bytes = png_buf.tobytes() if success else b""
        vlm_det = await _detect_shear_with_vlm(img_bytes)
        shear_deg = vlm_det["shear_degrees"]
        # Only apply the shear if it is non-trivial AND the VLM is at least
        # somewhat confident; note the sign flip (counter-shear to correct).
        if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3:
            dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg)
        else:
            dewarped_bgr = deskewed_bgr
        dewarp_info = {
            "method": vlm_det["method"],
            "shear_degrees": shear_deg,
            "confidence": vlm_det["confidence"],
            "detections": [vlm_det],
        }
    else:
        # "ensemble" and "cv" both use the CV ensemble in dewarp_image.
        dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr)

    duration = time.time() - t0

    # Encode as PNG
    success, png_buf = cv2.imencode(".png", dewarped_bgr)
    dewarped_png = png_buf.tobytes() if success else b""

    dewarp_result = {
        "method_used": dewarp_info["method"],
        "shear_degrees": dewarp_info["shear_degrees"],
        "confidence": dewarp_info["confidence"],
        "duration_seconds": round(duration, 2),
        "detections": dewarp_info.get("detections", []),
    }

    # Update cache
    cached["dewarped_bgr"] = dewarped_bgr
    cached["dewarp_result"] = dewarp_result

    # Persist to DB; auto_shear_degrees is stored separately so manual
    # adjustments later can still reference the automatic estimate.
    await update_session_db(
        session_id,
        dewarped_png=dewarped_png,
        dewarp_result=dewarp_result,
        auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
        current_step=4,
    )

    logger.info(f"OCR Pipeline: dewarp session {session_id}: "
                f"method={dewarp_info['method']} shear={dewarp_info['shear_degrees']:.3f} "
                f"conf={dewarp_info['confidence']:.2f} ({duration:.2f}s)")

    await _append_pipeline_log(session_id, "dewarp", {
        "shear_degrees": dewarp_info["shear_degrees"],
        "confidence": dewarp_info["confidence"],
        "method": dewarp_info["method"],
        "ensemble_methods": [d.get("method", "") for d in dewarp_info.get("detections", [])],
    }, duration_ms=int(duration * 1000))

    return {
        "session_id": session_id,
        **dewarp_result,
        "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
    }
|
|
|
|
|
|
@router.post("/sessions/{session_id}/dewarp/manual")
async def manual_dewarp(session_id: str, req: ManualDewarpRequest):
    """Apply shear correction with a manual angle."""
    # Make sure the session's working images are loaded into memory.
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    state = _get_cached(session_id)

    base_img = state.get("deskewed_bgr")
    if base_img is None:
        raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")

    # Clamp the requested shear to the supported ±2° range.
    shear_deg = min(2.0, max(-2.0, req.shear_degrees))

    # Near-zero shear: skip the warp entirely and reuse the deskewed image.
    corrected = (
        base_img if abs(shear_deg) < 0.001
        else dewarp_image_manual(base_img, shear_deg)
    )

    ok, encoded = cv2.imencode(".png", corrected)
    dewarped_png = encoded.tobytes() if ok else b""

    # Merge the manual override into any previous dewarp metadata.
    dewarp_result = dict(state.get("dewarp_result") or {})
    dewarp_result["method_used"] = "manual"
    dewarp_result["shear_degrees"] = round(shear_deg, 3)

    # Refresh the in-memory cache.
    state["dewarped_bgr"] = corrected
    state["dewarp_result"] = dewarp_result

    # Persist the corrected image and metadata.
    await update_session_db(
        session_id,
        dewarped_png=dewarped_png,
        dewarp_result=dewarp_result,
    )

    logger.info(f"OCR Pipeline: manual dewarp session {session_id}: shear={shear_deg:.3f}")

    return {
        "session_id": session_id,
        "shear_degrees": round(shear_deg, 3),
        "method_used": "manual",
        "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
    }
|
|
|
|
|
|
@router.post("/sessions/{session_id}/adjust-combined")
async def adjust_combined(session_id: str, req: CombinedAdjustRequest):
    """Apply rotation + shear combined to the original image.

    Used by the fine-tuning sliders to preview arbitrary rotation/shear
    combinations without re-running the full deskew/dewarp pipeline.
    Rotation is clamped to ±15°, shear to ±5°. The result is written to both
    the deskewed and dewarped cache slots so downstream steps see it.

    Raises:
        HTTPException: 400 if the original image is not available.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    img_bgr = cached.get("original_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Original image not available")

    # Clamp both adjustments to their slider ranges.
    rotation = max(-15.0, min(15.0, req.rotation_degrees))
    shear_deg = max(-5.0, min(5.0, req.shear_degrees))

    h, w = img_bgr.shape[:2]
    result_bgr = img_bgr

    # Step 1: Apply rotation (skipped for effectively-zero angles)
    if abs(rotation) >= 0.001:
        center = (w // 2, h // 2)
        M = cv2.getRotationMatrix2D(center, rotation, 1.0)
        result_bgr = cv2.warpAffine(result_bgr, M, (w, h),
                                    flags=cv2.INTER_LINEAR,
                                    borderMode=cv2.BORDER_REPLICATE)

    # Step 2: Apply shear
    if abs(shear_deg) >= 0.001:
        result_bgr = dewarp_image_manual(result_bgr, shear_deg)

    # Encode
    success, png_buf = cv2.imencode(".png", result_bgr)
    dewarped_png = png_buf.tobytes() if success else b""

    # Binarize for the OCR preview. Failure is non-fatal, but log it
    # (consistent with auto_deskew) instead of silently swallowing the error.
    binarized_png = None
    try:
        binarized = create_ocr_image(result_bgr)
        success_bin, bin_buf = cv2.imencode(".png", binarized)
        binarized_png = bin_buf.tobytes() if success_bin else None
    except Exception as e:
        logger.warning(f"Binarization failed: {e}")

    # Build combined result dicts, preserving earlier metadata.
    deskew_result = {
        **(cached.get("deskew_result") or {}),
        "angle_applied": round(rotation, 3),
        "method_used": "manual_combined",
    }
    dewarp_result = {
        **(cached.get("dewarp_result") or {}),
        "method_used": "manual_combined",
        "shear_degrees": round(shear_deg, 3),
    }

    # Update cache — the combined image stands in for both pipeline stages.
    cached["deskewed_bgr"] = result_bgr
    cached["dewarped_bgr"] = result_bgr
    cached["deskew_result"] = deskew_result
    cached["dewarp_result"] = dewarp_result

    # Persist to DB
    db_update = {
        "dewarped_png": dewarped_png,
        "deskew_result": deskew_result,
        "dewarp_result": dewarp_result,
    }
    if binarized_png:
        db_update["binarized_png"] = binarized_png
        db_update["deskewed_png"] = dewarped_png
    await update_session_db(session_id, **db_update)

    logger.info(f"OCR Pipeline: combined adjust session {session_id}: "
                f"rotation={rotation:.3f} shear={shear_deg:.3f}")

    return {
        "session_id": session_id,
        "rotation_degrees": round(rotation, 3),
        "shear_degrees": round(shear_deg, 3),
        "method_used": "manual_combined",
        "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
    }
|
|
|
|
|
|
@router.post("/sessions/{session_id}/ground-truth/dewarp")
async def save_dewarp_ground_truth(session_id: str, req: DewarpGroundTruthRequest):
    """Save ground truth feedback for the dewarp step.

    Stores the reviewer's verdict, optional corrected shear/notes and a
    snapshot of the automatic dewarp result under ``ground_truth["dewarp"]``
    in the session record, and mirrors the update into the in-memory cache.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    gt = {
        "is_correct": req.is_correct,
        "corrected_shear": req.corrected_shear,
        "notes": req.notes,
        # Timezone-aware UTC timestamp; datetime.utcnow() is deprecated
        # (Python 3.12+) and returns a naive datetime.
        "saved_at": datetime.now(timezone.utc).isoformat(),
        # Snapshot of the automatic result this feedback refers to.
        "dewarp_result": session.get("dewarp_result"),
    }
    ground_truth["dewarp"] = gt

    await update_session_db(session_id, ground_truth=ground_truth)

    # Keep the in-memory cache consistent with the DB.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    logger.info(f"OCR Pipeline: ground truth dewarp session {session_id}: "
                f"correct={req.is_correct}, corrected_shear={req.corrected_shear}")

    return {"session_id": session_id, "ground_truth": gt}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Structure Detection Endpoint
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@router.post("/sessions/{session_id}/detect-structure")
async def detect_structure(session_id: str):
    """Detect document structure: boxes, zones, and color regions.

    Runs box detection (line + shading) and color analysis on the cropped
    image. Returns structured JSON with all detected elements for the
    structure visualization step.

    Pipeline: content bounds → box detection → zone split → per-box
    background color sampling → page-wide color scan → graphic elements →
    border-ghost word filtering. The result is persisted to the session DB.

    Raises:
        HTTPException: 400 if neither a cropped nor a dewarped image exists.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    # Prefer the cropped image; fall back to the dewarped one.
    img_bgr = (
        cached.get("cropped_bgr")
        if cached.get("cropped_bgr") is not None
        else cached.get("dewarped_bgr")
    )
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first")

    t0 = time.time()
    h, w = img_bgr.shape[:2]

    # --- Content bounds from word result (if available) or full image ---
    word_result = cached.get("word_result")
    words: List[Dict] = []
    if word_result and word_result.get("cells"):
        for cell in word_result["cells"]:
            for wb in (cell.get("word_boxes") or []):
                words.append(wb)
    # Fallback: use raw OCR words if cell word_boxes are empty
    # (first non-empty source wins, in priority order).
    if not words and word_result:
        for key in ("raw_paddle_words_split", "raw_tesseract_words", "raw_paddle_words"):
            raw = word_result.get(key, [])
            if raw:
                words = raw
                logger.info("detect-structure: using %d words from %s (no cell word_boxes)", len(words), key)
                break
    # If no words yet, use image dimensions with small margin
    if words:
        # Tight bounding box around all word boxes, clamped to the image.
        content_x = max(0, min(int(wb["left"]) for wb in words))
        content_y = max(0, min(int(wb["top"]) for wb in words))
        content_r = min(w, max(int(wb["left"] + wb["width"]) for wb in words))
        content_b = min(h, max(int(wb["top"] + wb["height"]) for wb in words))
        content_w_px = content_r - content_x
        content_h_px = content_b - content_y
    else:
        # 3% margin on every side when no word geometry is available.
        margin = int(min(w, h) * 0.03)
        content_x, content_y = margin, margin
        content_w_px = w - 2 * margin
        content_h_px = h - 2 * margin

    # --- Box detection ---
    boxes = detect_boxes(
        img_bgr,
        content_x=content_x,
        content_w=content_w_px,
        content_y=content_y,
        content_h=content_h_px,
    )

    # --- Zone splitting ---
    from cv_box_detect import split_page_into_zones as _split_zones
    zones = _split_zones(content_x, content_y, content_w_px, content_h_px, boxes)

    # --- Color region sampling ---
    # Sample background shading in each detected box
    hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)
    box_colors = []
    for box in boxes:
        # Sample the center region of each box (inner 50% in each dimension,
        # to avoid the border pixels and any text near the edges).
        cy1 = box.y + box.height // 4
        cy2 = box.y + 3 * box.height // 4
        cx1 = box.x + box.width // 4
        cx2 = box.x + 3 * box.width // 4
        # Clamp the ROI to the image bounds.
        cy1 = max(0, min(cy1, h - 1))
        cy2 = max(0, min(cy2, h - 1))
        cx1 = max(0, min(cx1, w - 1))
        cx2 = max(0, min(cx2, w - 1))
        if cy2 > cy1 and cx2 > cx1:
            roi_hsv = hsv[cy1:cy2, cx1:cx2]
            # Median HSV is robust against text pixels inside the box.
            med_h = float(np.median(roi_hsv[:, :, 0]))
            med_s = float(np.median(roi_hsv[:, :, 1]))
            med_v = float(np.median(roi_hsv[:, :, 2]))
            if med_s > 15:
                # Saturated enough to name a hue.
                from cv_color_detect import _hue_to_color_name
                bg_name = _hue_to_color_name(med_h)
                bg_hex = _COLOR_HEX.get(bg_name, "#6b7280")
            else:
                # Unsaturated: distinguish gray vs white by brightness.
                bg_name = "gray" if med_v < 220 else "white"
                bg_hex = "#6b7280" if bg_name == "gray" else "#ffffff"
        else:
            # Degenerate (zero-area) sampling region.
            bg_name = "unknown"
            bg_hex = "#6b7280"
        box_colors.append({"color_name": bg_name, "color_hex": bg_hex})

    # --- Color text detection overview ---
    # Quick scan for colored text regions across the page
    color_summary: Dict[str, int] = {}
    for color_name, ranges in _COLOR_RANGES.items():
        mask = np.zeros((h, w), dtype=np.uint8)
        for lower, upper in ranges:
            mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper))
        pixel_count = int(np.sum(mask > 0))
        if pixel_count > 50:  # minimum threshold
            color_summary[color_name] = pixel_count

    # --- Graphic element detection ---
    box_dicts = [
        {"x": b.x, "y": b.y, "w": b.width, "h": b.height}
        for b in boxes
    ]
    graphics = detect_graphic_elements(
        img_bgr, words,
        detected_boxes=box_dicts,
    )

    # --- Filter border-ghost words from OCR result ---
    # Mutates word_result in place; persist only when something was removed.
    ghost_count = 0
    if boxes and word_result:
        ghost_count = _filter_border_ghost_words(word_result, boxes)
        if ghost_count:
            logger.info("detect-structure: removed %d border-ghost words", ghost_count)
            await update_session_db(session_id, word_result=word_result)
            cached["word_result"] = word_result

    duration = time.time() - t0

    # Preserve user-drawn exclude regions from previous run
    prev_sr = cached.get("structure_result") or {}
    prev_exclude = prev_sr.get("exclude_regions", [])

    result_dict = {
        "image_width": w,
        "image_height": h,
        "content_bounds": {
            "x": content_x, "y": content_y,
            "w": content_w_px, "h": content_h_px,
        },
        "boxes": [
            {
                "x": b.x, "y": b.y, "w": b.width, "h": b.height,
                "confidence": b.confidence,
                "border_thickness": b.border_thickness,
                "bg_color_name": box_colors[i]["color_name"],
                "bg_color_hex": box_colors[i]["color_hex"],
            }
            for i, b in enumerate(boxes)
        ],
        "zones": [
            {
                "index": z.index,
                "zone_type": z.zone_type,
                "y": z.y, "h": z.height,
                "x": z.x, "w": z.width,
            }
            for z in zones
        ],
        "graphics": [
            {
                "x": g.x, "y": g.y, "w": g.width, "h": g.height,
                "area": g.area,
                "shape": g.shape,
                "color_name": g.color_name,
                "color_hex": g.color_hex,
                "confidence": round(g.confidence, 2),
            }
            for g in graphics
        ],
        "exclude_regions": prev_exclude,
        "color_pixel_counts": color_summary,
        "has_words": len(words) > 0,
        "word_count": len(words),
        "border_ghosts_removed": ghost_count,
        "duration_seconds": round(duration, 2),
    }

    # Persist to session
    await update_session_db(session_id, structure_result=result_dict)
    cached["structure_result"] = result_dict

    logger.info("detect-structure session %s: %d boxes, %d zones, %d graphics, %.2fs",
                session_id, len(boxes), len(zones), len(graphics), duration)

    return {"session_id": session_id, **result_dict}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Exclude Regions — user-drawn rectangles to exclude from OCR results
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Request schema for a single user-drawn exclude rectangle.
class _ExcludeRegionIn(BaseModel):
    # Axis-aligned rectangle; values appear to be image pixel coordinates
    # (consistent with structure_result geometry) — confirm against caller.
    x: int
    y: int
    w: int
    h: int
    # Optional free-text label for the region.
    label: str = ""
|
|
|
|
|
|
# Request schema for replacing the full set of exclude regions at once.
class _ExcludeRegionsBatchIn(BaseModel):
    # Complete replacement list; an empty list clears all regions.
    regions: list[_ExcludeRegionIn]
|
|
|
|
|
|
@router.put("/sessions/{session_id}/exclude-regions")
async def set_exclude_regions(session_id: str, body: _ExcludeRegionsBatchIn):
    """Replace all exclude regions for a session.

    Regions are stored inside ``structure_result.exclude_regions``.
    """
    record = await get_session_db(session_id)
    if not record:
        raise HTTPException(status_code=404, detail="Session not found")

    structure = record.get("structure_result") or {}
    structure["exclude_regions"] = [region.model_dump() for region in body.regions]

    # Invalidate grid so it rebuilds with new exclude regions
    await update_session_db(session_id, structure_result=structure, grid_editor_result=None)

    # Mirror the change into the in-memory cache, if the session is loaded.
    if session_id in _cache:
        entry = _cache[session_id]
        entry["structure_result"] = structure
        entry.pop("grid_editor_result", None)

    region_list = structure["exclude_regions"]
    return {
        "session_id": session_id,
        "exclude_regions": region_list,
        "count": len(region_list),
    }
|
|
|
|
|
|
@router.delete("/sessions/{session_id}/exclude-regions/{region_index}")
async def delete_exclude_region(session_id: str, region_index: int):
    """Remove a single exclude region by index."""
    record = await get_session_db(session_id)
    if not record:
        raise HTTPException(status_code=404, detail="Session not found")

    structure = record.get("structure_result") or {}
    regions = structure.get("exclude_regions", [])

    # Reject out-of-range indices (negative indices are not supported here).
    if not (0 <= region_index < len(regions)):
        raise HTTPException(status_code=404, detail="Region index out of range")

    removed = regions.pop(region_index)
    structure["exclude_regions"] = regions

    # Invalidate grid so it rebuilds with new exclude regions
    await update_session_db(session_id, structure_result=structure, grid_editor_result=None)

    # Mirror the change into the in-memory cache, if the session is loaded.
    if session_id in _cache:
        entry = _cache[session_id]
        entry["structure_result"] = structure
        entry.pop("grid_editor_result", None)

    return {
        "session_id": session_id,
        "removed": removed,
        "remaining": len(regions),
    }
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Column Detection Endpoints (Step 3)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _parse_tesseract_conf(raw) -> int:
    """Parse a Tesseract confidence value into an int.

    Tesseract emits confidences as strings/ints, using '-1' for
    non-word entries; anything non-numeric also maps to -1 so it is
    rejected by the conf >= 30 filter.
    """
    s = str(raw)
    return int(s) if s.lstrip('-').isdigit() else -1


def _extract_word_dicts(img_bgr, session_id: str) -> list:
    """Run Tesseract on a BGR image and return word boxes with conf >= 30.

    Returns a list of ``{'text', 'conf', 'left', 'top', 'width', 'height'}``
    dicts. Low-confidence words (0 <= conf < 30) are logged for debugging.
    A Tesseract failure is logged as a warning and yields an empty list so
    the caller can proceed without word data.
    """
    try:
        from PIL import Image as PILImage
        import pytesseract

        pil_img = PILImage.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
        data = pytesseract.image_to_data(
            pil_img, lang='eng+deu', output_type=pytesseract.Output.DICT)

        # Parse each (text, conf) pair exactly once; the previous version
        # re-evaluated the conf-parsing expression up to three times per word.
        parsed = [(str(t).strip(), _parse_tesseract_conf(c))
                  for t, c in zip(data['text'], data['conf'])]

        word_dicts = [
            {
                'text': text, 'conf': conf,
                'left': int(data['left'][i]),
                'top': int(data['top'][i]),
                'width': int(data['width'][i]),
                'height': int(data['height'][i]),
            }
            for i, (text, conf) in enumerate(parsed)
            if text and conf >= 30
        ]

        # Log all words including low-confidence ones for debugging
        all_count = sum(1 for text, _ in parsed if text)
        low_conf = [(text, conf) for text, conf in parsed
                    if text and 0 <= conf < 30]
        if low_conf:
            logger.info(f"OCR Pipeline: sub-session {session_id}: {len(low_conf)} words below conf 30: {low_conf[:20]}")
        logger.info(f"OCR Pipeline: sub-session {session_id}: Tesseract found {len(word_dicts)}/{all_count} words (conf>=30)")
        return word_dicts
    except Exception as e:
        logger.warning(f"OCR Pipeline: sub-session {session_id}: Tesseract failed: {e}")
        return []


@router.post("/sessions/{session_id}/columns")
async def detect_columns(session_id: str):
    """Run column detection on the cropped (or dewarped) image.

    Regular sessions run zone-aware geometry detection (Phase A) followed
    by content-based classification (Phase B), with a projection-based
    fallback when Phase A yields nothing. Sub-sessions (box crops) skip
    detection and store a single pseudo-column spanning the whole image.

    Side effects: persists ``column_result`` to the DB, invalidates
    downstream row/word results, advances ``current_step`` to 6, and
    caches Tesseract/binarization intermediates so the row-detection
    step can reuse them.

    Raises:
        HTTPException(400): if neither crop nor dewarp has been completed.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    # Prefer the cropped image; fall back to the dewarped one.
    img_bgr = cached.get("cropped_bgr")
    if img_bgr is None:
        img_bgr = cached.get("dewarped_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before column detection")

    # -----------------------------------------------------------------------
    # Sub-sessions (box crops): skip column detection entirely.
    # Instead, create a single pseudo-column spanning the full image width.
    # Also run Tesseract + binarization here so that the row detection step
    # can reuse the cached intermediates (_word_dicts, _inv, _content_bounds)
    # instead of falling back to detect_column_geometry() which may fail
    # on small box images with < 5 words.
    # -----------------------------------------------------------------------
    session = await get_session_db(session_id)
    if session and session.get("parent_session_id"):
        h, w = img_bgr.shape[:2]

        # Binarize + invert for row detection (horizontal projection profile)
        ocr_img = create_ocr_image(img_bgr)
        inv = cv2.bitwise_not(ocr_img)

        # Run Tesseract to get word bounding boxes.
        # Word positions are relative to the full image (no ROI crop needed
        # because the sub-session image IS the cropped box already).
        # detect_row_geometry expects word positions relative to content ROI,
        # so with content_bounds = (0, w, 0, h) the coordinates are correct.
        word_dicts = _extract_word_dicts(img_bgr, session_id)

        # Cache intermediates for row detection (detect_rows reuses these)
        cached["_word_dicts"] = word_dicts
        cached["_inv"] = inv
        cached["_content_bounds"] = (0, w, 0, h)

        column_result = {
            "columns": [{
                "type": "column_text",
                "x": 0, "y": 0,
                "width": w, "height": h,
            }],
            "zones": None,
            "boxes_detected": 0,
            "duration_seconds": 0,
            "method": "sub_session_pseudo_column",
        }
        await update_session_db(
            session_id,
            column_result=column_result,
            row_result=None,
            word_result=None,
            current_step=6,
        )
        cached["column_result"] = column_result
        cached.pop("row_result", None)
        cached.pop("word_result", None)
        logger.info(f"OCR Pipeline: sub-session {session_id}: pseudo-column {w}x{h}px")
        return {"session_id": session_id, **column_result}

    t0 = time.time()

    # Binarized image for layout analysis
    ocr_img = create_ocr_image(img_bgr)
    h, w = ocr_img.shape[:2]

    # Phase A: Zone-aware geometry detection
    zoned_result = detect_column_geometry_zoned(ocr_img, img_bgr)

    if zoned_result is None:
        # Fallback to projection-based layout
        layout_img = create_layout_image(img_bgr)
        regions = analyze_layout(layout_img, ocr_img)
        zones_data = None
        boxes_detected = 0
    else:
        (geometries, left_x, right_x, top_y, bottom_y,
         word_dicts, inv, zones_data, boxes) = zoned_result
        content_w = right_x - left_x
        boxes_detected = len(boxes)

        # Cache intermediates for row detection (avoids second Tesseract run)
        cached["_word_dicts"] = word_dicts
        cached["_inv"] = inv
        cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
        cached["_zones_data"] = zones_data
        cached["_boxes_detected"] = boxes_detected

        # Detect header/footer early so sub-column clustering ignores them
        header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None)

        # Split sub-columns (e.g. page references) before classification
        geometries = _detect_sub_columns(geometries, content_w, left_x=left_x,
                                         top_y=top_y, header_y=header_y, footer_y=footer_y)

        # Expand narrow columns (sub-columns are often very narrow)
        geometries = expand_narrow_columns(geometries, content_w, left_x, word_dicts)

        # Phase B: Content-based classification
        regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y,
                                        left_x=left_x, right_x=right_x, inv=inv)

    duration = time.time() - t0

    columns = [asdict(r) for r in regions]

    # Determine classification methods used (deduplicated)
    methods = list({
        c["classification_method"] for c in columns
        if c.get("classification_method")
    })

    column_result = {
        "columns": columns,
        "classification_methods": methods,
        "duration_seconds": round(duration, 2),
        "boxes_detected": boxes_detected,
    }

    # Add zone data when boxes are present
    if zones_data and boxes_detected > 0:
        column_result["zones"] = zones_data

    # Persist to DB — also invalidate downstream results (rows, words)
    await update_session_db(
        session_id,
        column_result=column_result,
        row_result=None,
        word_result=None,
        current_step=6,
    )

    # Update cache
    cached["column_result"] = column_result
    cached.pop("row_result", None)
    cached.pop("word_result", None)

    col_count = len([c for c in columns if c["type"].startswith("column")])
    logger.info(f"OCR Pipeline: columns session {session_id}: "
                f"{col_count} columns detected, {boxes_detected} box(es) ({duration:.2f}s)")

    img_w = img_bgr.shape[1]
    await _append_pipeline_log(session_id, "columns", {
        "total_columns": len(columns),
        "column_widths_pct": [round(c["width"] / img_w * 100, 1) for c in columns],
        "column_types": [c["type"] for c in columns],
        "boxes_detected": boxes_detected,
    }, duration_ms=int(duration * 1000))

    return {
        "session_id": session_id,
        **column_result,
    }
|
|
|
|
|
|
@router.post("/sessions/{session_id}/columns/manual")
async def set_manual_columns(session_id: str, req: ManualColumnsRequest):
    """Override detected columns with manual definitions."""
    # Manual overrides carry no detection cost or zone/box metadata.
    column_result = {
        "columns": req.columns,
        "duration_seconds": 0,
        "method": "manual",
    }

    # Replacing the columns invalidates downstream row/word results.
    await update_session_db(session_id, column_result=column_result,
                            row_result=None, word_result=None)

    cached = _cache.get(session_id)
    if cached is not None:
        cached["column_result"] = column_result
        cached.pop("row_result", None)
        cached.pop("word_result", None)

    logger.info(f"OCR Pipeline: manual columns session {session_id}: "
                f"{len(req.columns)} columns set")

    return {"session_id": session_id, **column_result}
|
|
|
|
|
|
@router.post("/sessions/{session_id}/ground-truth/columns")
async def save_column_ground_truth(session_id: str, req: ColumnGroundTruthRequest):
    """Save ground truth feedback for the column detection step."""
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # Snapshot the automatic result alongside the correction so an
    # auto-vs-GT diff stays possible even after detection is re-run.
    gt = {
        "is_correct": req.is_correct,
        "corrected_columns": req.corrected_columns,
        "notes": req.notes,
        "saved_at": datetime.utcnow().isoformat(),
        "column_result": session.get("column_result"),
    }

    ground_truth = session.get("ground_truth") or {}
    ground_truth["columns"] = gt

    await update_session_db(session_id, ground_truth=ground_truth)

    cached = _cache.get(session_id)
    if cached is not None:
        cached["ground_truth"] = ground_truth

    return {"session_id": session_id, "ground_truth": gt}
|
|
|
|
|
|
@router.get("/sessions/{session_id}/ground-truth/columns")
async def get_column_ground_truth(session_id: str):
    """Retrieve saved ground truth for column detection, including auto vs GT diff."""
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # Missing GT is reported as 404 so clients can distinguish
    # "never saved" from an empty correction.
    columns_gt = (session.get("ground_truth") or {}).get("columns")
    if not columns_gt:
        raise HTTPException(status_code=404, detail="No column ground truth saved")

    return {
        "session_id": session_id,
        "columns_gt": columns_gt,
        "columns_auto": session.get("column_result"),
    }
|