feat(ocr-pipeline): add Step 5 word recognition (grid from columns × rows)

Backend: build_word_grid() intersects column regions with content rows,
OCRs each cell with language-specific Tesseract, and returns vocabulary
entries with percent-based bounding boxes. New endpoints: POST /words,
GET /image/words-overlay, ground-truth save/retrieve for words.
Frontend: StepWordRecognition with overview + step-through labeling modes,
goToStep callback for row correction feedback loop.
MkDocs: OCR Pipeline documentation added.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-02-28 02:18:29 +01:00
parent 47dc2e6f7a
commit 954103cdf2
9 changed files with 1429 additions and 21 deletions

View File

@@ -2169,6 +2169,142 @@ def analyze_layout_by_words(ocr_img: np.ndarray, dewarped_bgr: np.ndarray) -> Li
return regions
# =============================================================================
# Pipeline Step 5: Word Grid from Columns × Rows
# =============================================================================
def build_word_grid(
    ocr_img: np.ndarray,
    column_regions: List[PageRegion],
    row_geometries: List[RowGeometry],
    img_w: int,
    img_h: int,
    lang: str = "eng+deu",
) -> List[Dict[str, Any]]:
    """Build a word grid by intersecting columns and rows, then OCR each cell.

    For every content row, one cell per relevant vocabulary column is formed
    (column x/width crossed with row y/height), OCRed with a Tesseract
    language matched to the column type, and merged into a single vocabulary
    entry dict. All bounding boxes are returned in percent of the full image.

    Args:
        ocr_img: Binarized full-page image.
        column_regions: Classified columns from Step 3 (PageRegion list).
        row_geometries: Rows from Step 4 (RowGeometry list).
        img_w: Image width in pixels.
        img_h: Image height in pixels.
        lang: Fallback Tesseract language for unmapped column types.

    Returns:
        List of entry dicts with english/german/example text and bbox info (percent).
    """
    # Filter to content rows only (skip header/footer)
    content_rows = [r for r in row_geometries if r.row_type == 'content']
    if not content_rows:
        logger.warning("build_word_grid: no content rows found")
        return []
    # Only these column types carry vocabulary data
    VOCAB_COLUMN_TYPES = {'column_en', 'column_de', 'column_example'}
    relevant_cols = [c for c in column_regions if c.type in VOCAB_COLUMN_TYPES]
    if not relevant_cols:
        logger.warning("build_word_grid: no relevant vocabulary columns found")
        return []
    # Sort columns left-to-right
    relevant_cols.sort(key=lambda c: c.x)
    # Choose OCR language per column type
    lang_map = {
        'column_en': 'eng',
        'column_de': 'deu',
        'column_example': 'eng+deu',
    }
    entries: List[Dict[str, Any]] = []
    for row_idx, row in enumerate(content_rows):
        entry: Dict[str, Any] = {
            'row_index': row_idx,
            'english': '',
            'german': '',
            'example': '',
            'confidence': 0.0,
            'bbox': {
                'x': round(row.x / img_w * 100, 2),
                'y': round(row.y / img_h * 100, 2),
                'w': round(row.width / img_w * 100, 2),
                'h': round(row.height / img_h * 100, 2),
            },
            'bbox_en': None,
            'bbox_de': None,
            'bbox_ex': None,
        }
        confidences: List[float] = []
        for col in relevant_cols:
            # Compute cell region: column x/width, row y/height
            cell_x = col.x
            cell_y = row.y
            cell_w = col.width
            cell_h = row.height
            # Clamp to image bounds. Bugfix: when the origin is negative,
            # shrink the size by the clipped amount so the cell does not
            # silently extend past the original region after the clamp.
            if cell_x < 0:
                cell_w += cell_x  # cell_x is negative: reduce width
                cell_x = 0
            if cell_y < 0:
                cell_h += cell_y  # cell_y is negative: reduce height
                cell_y = 0
            if cell_x + cell_w > img_w:
                cell_w = img_w - cell_x
            if cell_y + cell_h > img_h:
                cell_h = img_h - cell_y
            if cell_w <= 0 or cell_h <= 0:
                continue
            cell_region = PageRegion(
                type=col.type,
                x=cell_x, y=cell_y,
                width=cell_w, height=cell_h,
            )
            cell_lang = lang_map.get(col.type, lang)
            # PSM 7: treat the cell as a single text line
            words = ocr_region(ocr_img, cell_region, lang=cell_lang, psm=7)
            # Sort words by x position, join to text
            words.sort(key=lambda w: w['left'])
            text = ' '.join(w['text'] for w in words)
            if words:
                avg_conf = sum(w['conf'] for w in words) / len(words)
                confidences.append(avg_conf)
            # Bbox in percent
            cell_bbox = {
                'x': round(cell_x / img_w * 100, 2),
                'y': round(cell_y / img_h * 100, 2),
                'w': round(cell_w / img_w * 100, 2),
                'h': round(cell_h / img_h * 100, 2),
            }
            if col.type == 'column_en':
                entry['english'] = text
                entry['bbox_en'] = cell_bbox
            elif col.type == 'column_de':
                entry['german'] = text
                entry['bbox_de'] = cell_bbox
            elif col.type == 'column_example':
                entry['example'] = text
                entry['bbox_ex'] = cell_bbox
        # Entry confidence = mean of per-cell mean confidences
        entry['confidence'] = round(
            sum(confidences) / len(confidences), 1
        ) if confidences else 0.0
        # Only include if at least one field has text
        if entry['english'] or entry['german'] or entry['example']:
            entries.append(entry)
    logger.info(f"build_word_grid: {len(entries)} entries from "
                f"{len(content_rows)} content rows × {len(relevant_cols)} columns")
    return entries
# =============================================================================
# Stage 6: Multi-Pass OCR
# =============================================================================

View File

@@ -0,0 +1,4 @@
-- Migration 004: Add word_result column for OCR Pipeline Step 5
-- Stores the word recognition grid result (entries with english/german/example + bboxes)
-- Idempotent: IF NOT EXISTS allows this migration to be re-run safely.
ALTER TABLE ocr_pipeline_sessions ADD COLUMN IF NOT EXISTS word_result JSONB;

View File

@@ -29,8 +29,11 @@ from fastapi.responses import Response
from pydantic import BaseModel
from cv_vocab_pipeline import (
PageRegion,
RowGeometry,
analyze_layout,
analyze_layout_by_words,
build_word_grid,
classify_column_types,
create_layout_image,
create_ocr_image,
@@ -261,6 +264,10 @@ async def get_session_info(session_id: str):
result["dewarp_result"] = session["dewarp_result"]
if session.get("column_result"):
result["column_result"] = session["column_result"]
if session.get("row_result"):
result["row_result"] = session["row_result"]
if session.get("word_result"):
result["word_result"] = session["word_result"]
return result
@@ -291,7 +298,7 @@ async def delete_session(session_id: str):
@router.get("/sessions/{session_id}/image/{image_type}")
async def get_image(session_id: str, image_type: str):
"""Serve session images: original, deskewed, dewarped, binarized, columns-overlay, or rows-overlay."""
valid_types = {"original", "deskewed", "dewarped", "binarized", "columns-overlay", "rows-overlay"}
valid_types = {"original", "deskewed", "dewarped", "binarized", "columns-overlay", "rows-overlay", "words-overlay"}
if image_type not in valid_types:
raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")
@@ -301,6 +308,9 @@ async def get_image(session_id: str, image_type: str):
if image_type == "rows-overlay":
return await _get_rows_overlay(session_id)
if image_type == "words-overlay":
return await _get_words_overlay(session_id)
# Try cache first for fast serving
cached = _cache.get(session_id)
if cached:
@@ -992,6 +1002,153 @@ async def get_row_ground_truth(session_id: str):
}
# ---------------------------------------------------------------------------
# Word Recognition Endpoints (Step 5)
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/words")
async def detect_words(session_id: str):
    """Build word grid from columns × rows, OCR each cell."""
    # Make sure the session's working images are loaded into the process cache.
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cache_entry = _get_cached(session_id)
    dewarped = cache_entry.get("dewarped_bgr")
    if dewarped is None:
        raise HTTPException(status_code=400, detail="Dewarp must be completed before word detection")
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    # Both prior pipeline steps must have produced results.
    column_result = session.get("column_result")
    row_result = session.get("row_result")
    if not column_result or not column_result.get("columns"):
        raise HTTPException(status_code=400, detail="Column detection must be completed first")
    if not row_result or not row_result.get("rows"):
        raise HTTPException(status_code=400, detail="Row detection must be completed first")
    started = time.time()
    # Binarize the dewarped page for Tesseract.
    binarized = create_ocr_image(dewarped)
    img_h, img_w = dewarped.shape[:2]
    # Rehydrate the persisted JSON dicts into the pipeline's region objects.
    columns: List[PageRegion] = []
    for c in column_result["columns"]:
        columns.append(PageRegion(
            type=c["type"],
            x=c["x"], y=c["y"],
            width=c["width"], height=c["height"],
            classification_confidence=c.get("classification_confidence", 1.0),
            classification_method=c.get("classification_method", ""),
        ))
    rows: List[RowGeometry] = []
    for r in row_result["rows"]:
        rows.append(RowGeometry(
            index=r["index"],
            x=r["x"], y=r["y"],
            width=r["width"], height=r["height"],
            word_count=r.get("word_count", 0),
            words=[],
            row_type=r.get("row_type", "content"),
            gap_before=r.get("gap_before", 0),
        ))
    # Intersect columns × rows and OCR every cell.
    entries = build_word_grid(binarized, columns, rows, img_w, img_h)
    duration = time.time() - started
    # Aggregate counts for the UI summary panel.
    summary = {
        "total_entries": len(entries),
        "with_english": len([e for e in entries if e.get("english")]),
        "with_german": len([e for e in entries if e.get("german")]),
        "low_confidence": len([e for e in entries if e.get("confidence", 0) < 50]),
    }
    word_result = {
        "entries": entries,
        "entry_count": len(entries),
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "summary": summary,
    }
    # Persist to DB and keep the cache in sync.
    await update_session_db(
        session_id,
        word_result=word_result,
        current_step=5,
    )
    cache_entry["word_result"] = word_result
    logger.info(f"OCR Pipeline: words session {session_id}: "
                f"{len(entries)} entries ({duration:.2f}s), summary: {summary}")
    response = {"session_id": session_id}
    response.update(word_result)
    return response
class WordGroundTruthRequest(BaseModel):
    """Payload for saving ground-truth feedback on the word recognition step."""
    # Whether the automatically detected word grid was correct as-is.
    is_correct: bool
    # Entries as corrected by the reviewer; None when no corrections were made.
    corrected_entries: Optional[List[Dict[str, Any]]] = None
    # Optional free-form reviewer notes.
    notes: Optional[str] = None
@router.post("/sessions/{session_id}/ground-truth/words")
async def save_word_ground_truth(session_id: str, req: WordGroundTruthRequest):
    """Save ground truth feedback for the word recognition step."""
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    stored = session.get("ground_truth") or {}
    # Snapshot the automatic result next to the human verdict so the two can
    # be compared later.
    record = {
        "is_correct": req.is_correct,
        "corrected_entries": req.corrected_entries,
        "notes": req.notes,
        "saved_at": datetime.utcnow().isoformat(),
        "word_result": session.get("word_result"),
    }
    stored["words"] = record
    await update_session_db(session_id, ground_truth=stored)
    cache_entry = _cache.get(session_id)
    if cache_entry is not None:
        cache_entry["ground_truth"] = stored
    return {"session_id": session_id, "ground_truth": record}
@router.get("/sessions/{session_id}/ground-truth/words")
async def get_word_ground_truth(session_id: str):
    """Retrieve saved ground truth for word recognition."""
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    saved = (session.get("ground_truth") or {}).get("words")
    if not saved:
        raise HTTPException(status_code=404, detail="No word ground truth saved")
    # Return the saved labels alongside the automatic result for comparison.
    return {
        "session_id": session_id,
        "words_gt": saved,
        "words_auto": session.get("word_result"),
    }
async def _get_rows_overlay(session_id: str) -> Response:
"""Generate dewarped image with row bands drawn on it."""
session = await get_session_db(session_id)
@@ -1049,3 +1206,106 @@ async def _get_rows_overlay(session_id: str) -> Response:
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
return Response(content=result_png.tobytes(), media_type="image/png")
async def _get_words_overlay(session_id: str) -> Response:
    """Generate dewarped image with word grid cells drawn on it.

    Draw order matters:
      1. ``overlay`` is copied from the clean image; only the solid cell fills
         go on it, and it is blended back at 10% opacity at the very end.
      2. Column divider lines (vertical, colored per column type) and row
         divider lines (horizontal, gray) are drawn directly on the base
         image, so they are only faded by the final 90/10 blend.
      3. Per-cell borders and text labels are also drawn on the base image,
         colored by entry confidence.

    Raises:
        HTTPException: 404 when the session, word data, or dewarped image is
            missing; 500 when the PNG cannot be decoded or re-encoded.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    word_result = session.get("word_result")
    if not word_result or not word_result.get("entries"):
        raise HTTPException(status_code=404, detail="No word data available")
    # Columns/rows are optional here: dividers are skipped when absent.
    column_result = session.get("column_result")
    row_result = session.get("row_result")
    # Load dewarped image
    dewarped_png = await get_session_image(session_id, "dewarped")
    if not dewarped_png:
        raise HTTPException(status_code=404, detail="Dewarped image not available")
    arr = np.frombuffer(dewarped_png, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")
    img_h, img_w = img.shape[:2]
    # Color map for cell types (BGR)
    cell_colors = {
        "column_en": (255, 180, 0),  # Blue
        "column_de": (0, 200, 0),  # Green
        "column_example": (0, 140, 255),  # Orange
    }
    # Copied BEFORE any drawing: holds only the semi-transparent cell fills.
    overlay = img.copy()
    # Draw column divider lines (vertical)
    if column_result and column_result.get("columns"):
        for col in column_result["columns"]:
            col_type = col.get("type", "")
            if col_type in cell_colors:
                # Left and right edge of each vocabulary column.
                cx = col["x"]
                cv2.line(img, (cx, 0), (cx, img_h), cell_colors[col_type], 1)
                cx_end = col["x"] + col["width"]
                cv2.line(img, (cx_end, 0), (cx_end, img_h), cell_colors[col_type], 1)
    # Draw row divider lines (horizontal) for content rows
    if row_result and row_result.get("rows"):
        for row in row_result["rows"]:
            if row.get("row_type") == "content":
                ry = row["y"]
                cv2.line(img, (0, ry), (img_w, ry), (180, 180, 180), 1)
    # Draw entry cells with text labels
    entries = word_result["entries"]
    for entry in entries:
        conf = entry.get("confidence", 0)
        # Color by confidence: green > 70, yellow 50-70, red < 50
        if conf >= 70:
            text_color = (0, 180, 0)
        elif conf >= 50:
            text_color = (0, 180, 220)
        else:
            text_color = (0, 0, 220)
        # One pass per cell type: (bbox key, text field, column type).
        for bbox_key, field_key, col_type in [
            ("bbox_en", "english", "column_en"),
            ("bbox_de", "german", "column_de"),
            ("bbox_ex", "example", "column_example"),
        ]:
            bbox = entry.get(bbox_key)
            text = entry.get(field_key, "")
            if not bbox or not text:
                continue
            # Convert percent to pixels
            bx = int(bbox["x"] / 100 * img_w)
            by = int(bbox["y"] / 100 * img_h)
            bw = int(bbox["w"] / 100 * img_w)
            bh = int(bbox["h"] / 100 * img_h)
            color = cell_colors.get(col_type, (200, 200, 200))
            # Semi-transparent fill
            cv2.rectangle(overlay, (bx, by), (bx + bw, by + bh), color, -1)
            # Border
            cv2.rectangle(img, (bx, by), (bx + bw, by + bh), text_color, 1)
            # Text label (truncate if too long)
            label = text[:30] if len(text) > 30 else text
            font_scale = 0.35
            cv2.putText(img, label, (bx + 3, by + bh - 4),
                        cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_color, 1)
    # Blend overlay at 10% opacity
    cv2.addWeighted(overlay, 0.1, img, 0.9, 0, img)
    success, result_png = cv2.imencode(".png", img)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")
    return Response(content=result_png.tobytes(), media_type="image/png")

View File

@@ -80,7 +80,7 @@ async def create_session_db(
) VALUES ($1, $2, $3, $4, 'active', 1)
RETURNING id, name, filename, status, current_step,
deskew_result, dewarp_result, column_result, row_result,
ground_truth, auto_shear_degrees,
word_result, ground_truth, auto_shear_degrees,
created_at, updated_at
""", uuid.UUID(session_id), name, filename, original_png)
@@ -94,7 +94,7 @@ async def get_session_db(session_id: str) -> Optional[Dict[str, Any]]:
row = await conn.fetchrow("""
SELECT id, name, filename, status, current_step,
deskew_result, dewarp_result, column_result, row_result,
ground_truth, auto_shear_degrees,
word_result, ground_truth, auto_shear_degrees,
created_at, updated_at
FROM ocr_pipeline_sessions WHERE id = $1
""", uuid.UUID(session_id))
@@ -136,10 +136,10 @@ async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any
'name', 'filename', 'status', 'current_step',
'original_png', 'deskewed_png', 'binarized_png', 'dewarped_png',
'deskew_result', 'dewarp_result', 'column_result', 'row_result',
'ground_truth', 'auto_shear_degrees',
'word_result', 'ground_truth', 'auto_shear_degrees',
}
jsonb_fields = {'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'ground_truth'}
jsonb_fields = {'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth'}
for key, value in kwargs.items():
if key in allowed_fields:
@@ -164,7 +164,7 @@ async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any
WHERE id = ${param_idx}
RETURNING id, name, filename, status, current_step,
deskew_result, dewarp_result, column_result, row_result,
ground_truth, auto_shear_degrees,
word_result, ground_truth, auto_shear_degrees,
created_at, updated_at
""", *values)
@@ -220,7 +220,7 @@ def _row_to_dict(row: asyncpg.Record) -> Dict[str, Any]:
result[key] = result[key].isoformat()
# JSONB → parsed (asyncpg returns str for JSONB)
for key in ['deskew_result', 'dewarp_result', 'column_result', 'row_result', 'ground_truth']:
for key in ['deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth']:
if key in result and result[key] is not None:
if isinstance(result[key], str):
result[key] = json.loads(result[key])