feat(ocr-pipeline): add row detection step with horizontal gap analysis
Add Step 4 (row detection) between column detection and word recognition. Uses horizontal projection profiles + whitespace gaps (same method as columns). Includes header/footer classification via gap-size heuristics. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,14 +1,15 @@
|
||||
"""
|
||||
OCR Pipeline API - Schrittweise Seitenrekonstruktion.
|
||||
|
||||
Zerlegt den OCR-Prozess in 7 einzelne Schritte:
|
||||
Zerlegt den OCR-Prozess in 8 einzelne Schritte:
|
||||
1. Deskewing - Scan begradigen
|
||||
2. Dewarping - Buchwoelbung entzerren
|
||||
3. Spaltenerkennung - Unsichtbare Spalten finden
|
||||
4. Worterkennung - OCR mit Bounding Boxes
|
||||
5. Koordinatenzuweisung - Exakte Positionen
|
||||
6. Seitenrekonstruktion - Seite nachbauen
|
||||
7. Ground Truth Validierung - Gesamtpruefung
|
||||
4. Zeilenerkennung - Horizontale Zeilen + Kopf-/Fusszeilen
|
||||
5. Worterkennung - OCR mit Bounding Boxes
|
||||
6. Koordinatenzuweisung - Exakte Positionen
|
||||
7. Seitenrekonstruktion - Seite nachbauen
|
||||
8. Ground Truth Validierung - Gesamtpruefung
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
@@ -30,9 +31,13 @@ from pydantic import BaseModel
|
||||
from cv_vocab_pipeline import (
|
||||
analyze_layout,
|
||||
analyze_layout_by_words,
|
||||
classify_column_types,
|
||||
create_layout_image,
|
||||
create_ocr_image,
|
||||
deskew_image,
|
||||
deskew_image_by_word_alignment,
|
||||
detect_column_geometry,
|
||||
detect_row_geometry,
|
||||
dewarp_image,
|
||||
dewarp_image_manual,
|
||||
render_image_high_res,
|
||||
@@ -139,6 +144,16 @@ class ColumnGroundTruthRequest(BaseModel):
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
class ManualRowsRequest(BaseModel):
|
||||
rows: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class RowGroundTruthRequest(BaseModel):
|
||||
is_correct: bool
|
||||
corrected_rows: Optional[List[Dict[str, Any]]] = None
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session Management Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -275,14 +290,17 @@ async def delete_session(session_id: str):
|
||||
|
||||
@router.get("/sessions/{session_id}/image/{image_type}")
|
||||
async def get_image(session_id: str, image_type: str):
|
||||
"""Serve session images: original, deskewed, dewarped, binarized, or columns-overlay."""
|
||||
valid_types = {"original", "deskewed", "dewarped", "binarized", "columns-overlay"}
|
||||
"""Serve session images: original, deskewed, dewarped, binarized, columns-overlay, or rows-overlay."""
|
||||
valid_types = {"original", "deskewed", "dewarped", "binarized", "columns-overlay", "rows-overlay"}
|
||||
if image_type not in valid_types:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")
|
||||
|
||||
if image_type == "columns-overlay":
|
||||
return await _get_columns_overlay(session_id)
|
||||
|
||||
if image_type == "rows-overlay":
|
||||
return await _get_rows_overlay(session_id)
|
||||
|
||||
# Try cache first for fast serving
|
||||
cached = _cache.get(session_id)
|
||||
if cached:
|
||||
@@ -643,9 +661,27 @@ async def detect_columns(session_id: str):
|
||||
|
||||
# Binarized image for layout analysis
|
||||
ocr_img = create_ocr_image(dewarped_bgr)
|
||||
h, w = ocr_img.shape[:2]
|
||||
|
||||
# Phase A: Geometry detection (returns word_dicts + inv for reuse)
|
||||
geo_result = detect_column_geometry(ocr_img, dewarped_bgr)
|
||||
|
||||
if geo_result is None:
|
||||
# Fallback to projection-based layout
|
||||
layout_img = create_layout_image(dewarped_bgr)
|
||||
regions = analyze_layout(layout_img, ocr_img)
|
||||
else:
|
||||
geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
|
||||
content_w = right_x - left_x
|
||||
|
||||
# Cache intermediates for row detection (avoids second Tesseract run)
|
||||
cached["_word_dicts"] = word_dicts
|
||||
cached["_inv"] = inv
|
||||
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
|
||||
|
||||
# Phase B: Content-based classification
|
||||
regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y)
|
||||
|
||||
# Word-based detection (with automatic fallback to projection profiles)
|
||||
regions = analyze_layout_by_words(ocr_img, dewarped_bgr)
|
||||
duration = time.time() - t0
|
||||
|
||||
columns = [asdict(r) for r in regions]
|
||||
@@ -807,3 +843,209 @@ async def _get_columns_overlay(session_id: str) -> Response:
|
||||
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
|
||||
|
||||
return Response(content=result_png.tobytes(), media_type="image/png")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Row Detection Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/rows")
|
||||
async def detect_rows(session_id: str):
|
||||
"""Run row detection on the dewarped image using horizontal gap analysis."""
|
||||
if session_id not in _cache:
|
||||
await _load_session_to_cache(session_id)
|
||||
cached = _get_cached(session_id)
|
||||
|
||||
dewarped_bgr = cached.get("dewarped_bgr")
|
||||
if dewarped_bgr is None:
|
||||
raise HTTPException(status_code=400, detail="Dewarp must be completed before row detection")
|
||||
|
||||
t0 = time.time()
|
||||
|
||||
# Try to reuse cached word_dicts and inv from column detection
|
||||
word_dicts = cached.get("_word_dicts")
|
||||
inv = cached.get("_inv")
|
||||
content_bounds = cached.get("_content_bounds")
|
||||
|
||||
if word_dicts is None or inv is None or content_bounds is None:
|
||||
# Not cached — run column geometry to get intermediates
|
||||
ocr_img = create_ocr_image(dewarped_bgr)
|
||||
geo_result = detect_column_geometry(ocr_img, dewarped_bgr)
|
||||
if geo_result is None:
|
||||
raise HTTPException(status_code=400, detail="Column geometry detection failed — cannot detect rows")
|
||||
_geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
|
||||
cached["_word_dicts"] = word_dicts
|
||||
cached["_inv"] = inv
|
||||
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
|
||||
else:
|
||||
left_x, right_x, top_y, bottom_y = content_bounds
|
||||
|
||||
# Run row detection
|
||||
rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
|
||||
duration = time.time() - t0
|
||||
|
||||
# Build serializable result (exclude words to keep payload small)
|
||||
rows_data = []
|
||||
for r in rows:
|
||||
rows_data.append({
|
||||
"index": r.index,
|
||||
"x": r.x,
|
||||
"y": r.y,
|
||||
"width": r.width,
|
||||
"height": r.height,
|
||||
"word_count": r.word_count,
|
||||
"row_type": r.row_type,
|
||||
"gap_before": r.gap_before,
|
||||
})
|
||||
|
||||
type_counts = {}
|
||||
for r in rows:
|
||||
type_counts[r.row_type] = type_counts.get(r.row_type, 0) + 1
|
||||
|
||||
row_result = {
|
||||
"rows": rows_data,
|
||||
"summary": type_counts,
|
||||
"total_rows": len(rows),
|
||||
"duration_seconds": round(duration, 2),
|
||||
}
|
||||
|
||||
# Persist to DB
|
||||
await update_session_db(
|
||||
session_id,
|
||||
row_result=row_result,
|
||||
current_step=4,
|
||||
)
|
||||
|
||||
cached["row_result"] = row_result
|
||||
|
||||
logger.info(f"OCR Pipeline: rows session {session_id}: "
|
||||
f"{len(rows)} rows detected ({duration:.2f}s): {type_counts}")
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
**row_result,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/rows/manual")
|
||||
async def set_manual_rows(session_id: str, req: ManualRowsRequest):
|
||||
"""Override detected rows with manual definitions."""
|
||||
row_result = {
|
||||
"rows": req.rows,
|
||||
"total_rows": len(req.rows),
|
||||
"duration_seconds": 0,
|
||||
"method": "manual",
|
||||
}
|
||||
|
||||
await update_session_db(session_id, row_result=row_result)
|
||||
|
||||
if session_id in _cache:
|
||||
_cache[session_id]["row_result"] = row_result
|
||||
|
||||
logger.info(f"OCR Pipeline: manual rows session {session_id}: "
|
||||
f"{len(req.rows)} rows set")
|
||||
|
||||
return {"session_id": session_id, **row_result}
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/ground-truth/rows")
|
||||
async def save_row_ground_truth(session_id: str, req: RowGroundTruthRequest):
|
||||
"""Save ground truth feedback for the row detection step."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
ground_truth = session.get("ground_truth") or {}
|
||||
gt = {
|
||||
"is_correct": req.is_correct,
|
||||
"corrected_rows": req.corrected_rows,
|
||||
"notes": req.notes,
|
||||
"saved_at": datetime.utcnow().isoformat(),
|
||||
"row_result": session.get("row_result"),
|
||||
}
|
||||
ground_truth["rows"] = gt
|
||||
|
||||
await update_session_db(session_id, ground_truth=ground_truth)
|
||||
|
||||
if session_id in _cache:
|
||||
_cache[session_id]["ground_truth"] = ground_truth
|
||||
|
||||
return {"session_id": session_id, "ground_truth": gt}
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/ground-truth/rows")
|
||||
async def get_row_ground_truth(session_id: str):
|
||||
"""Retrieve saved ground truth for row detection."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
ground_truth = session.get("ground_truth") or {}
|
||||
rows_gt = ground_truth.get("rows")
|
||||
if not rows_gt:
|
||||
raise HTTPException(status_code=404, detail="No row ground truth saved")
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"rows_gt": rows_gt,
|
||||
"rows_auto": session.get("row_result"),
|
||||
}
|
||||
|
||||
|
||||
async def _get_rows_overlay(session_id: str) -> Response:
|
||||
"""Generate dewarped image with row bands drawn on it."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
row_result = session.get("row_result")
|
||||
if not row_result or not row_result.get("rows"):
|
||||
raise HTTPException(status_code=404, detail="No row data available")
|
||||
|
||||
# Load dewarped image
|
||||
dewarped_png = await get_session_image(session_id, "dewarped")
|
||||
if not dewarped_png:
|
||||
raise HTTPException(status_code=404, detail="Dewarped image not available")
|
||||
|
||||
arr = np.frombuffer(dewarped_png, dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
raise HTTPException(status_code=500, detail="Failed to decode image")
|
||||
|
||||
# Color map for row types (BGR)
|
||||
row_colors = {
|
||||
"content": (255, 180, 0), # Blue
|
||||
"header": (128, 128, 128), # Gray
|
||||
"footer": (128, 128, 128), # Gray
|
||||
}
|
||||
|
||||
overlay = img.copy()
|
||||
for row in row_result["rows"]:
|
||||
x, y = row["x"], row["y"]
|
||||
w, h = row["width"], row["height"]
|
||||
row_type = row.get("row_type", "content")
|
||||
color = row_colors.get(row_type, (200, 200, 200))
|
||||
|
||||
# Semi-transparent fill
|
||||
cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1)
|
||||
|
||||
# Solid border
|
||||
cv2.rectangle(img, (x, y), (x + w, y + h), color, 2)
|
||||
|
||||
# Label
|
||||
idx = row.get("index", 0)
|
||||
label = f"R{idx} {row_type.upper()}"
|
||||
wc = row.get("word_count", 0)
|
||||
if wc:
|
||||
label = f"{label} ({wc}w)"
|
||||
cv2.putText(img, label, (x + 5, y + 18),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
|
||||
|
||||
# Blend overlay at 15% opacity
|
||||
cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img)
|
||||
|
||||
success, result_png = cv2.imencode(".png", img)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
|
||||
|
||||
return Response(content=result_png.tobytes(), media_type="image/png")
|
||||
|
||||
Reference in New Issue
Block a user