feat(ocr-pipeline): add Step 5 word recognition (grid from columns × rows)
Backend: build_word_grid() intersects column regions with content rows, OCRs each cell with language-specific Tesseract, and returns vocabulary entries with percent-based bounding boxes. New endpoints: POST /words, GET /image/words-overlay, ground-truth save/retrieve for words. Frontend: StepWordRecognition with overview + step-through labeling modes, goToStep callback for row correction feedback loop. MkDocs: OCR Pipeline documentation added. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -2169,6 +2169,142 @@ def analyze_layout_by_words(ocr_img: np.ndarray, dewarped_bgr: np.ndarray) -> Li
|
||||
return regions
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pipeline Step 5: Word Grid from Columns × Rows
|
||||
# =============================================================================
|
||||
|
||||
def build_word_grid(
    ocr_img: np.ndarray,
    column_regions: List[PageRegion],
    row_geometries: List[RowGeometry],
    img_w: int,
    img_h: int,
    lang: str = "eng+deu",
) -> List[Dict[str, Any]]:
    """Intersect vocabulary columns with content rows and OCR every resulting cell.

    Args:
        ocr_img: Binarized full-page image.
        column_regions: Classified columns from Step 3 (PageRegion list).
        row_geometries: Rows from Step 4 (RowGeometry list).
        img_w: Image width in pixels.
        img_h: Image height in pixels.
        lang: Default Tesseract language (fallback for unmapped column types).

    Returns:
        List of entry dicts with english/german/example text and bbox info (percent).
    """
    # Header/footer rows carry no vocabulary — keep content rows only.
    content_rows = [r for r in row_geometries if r.row_type == 'content']
    if not content_rows:
        logger.warning("build_word_grid: no content rows found")
        return []

    # Only these column roles contribute to a vocabulary entry.
    vocab_types = {'column_en', 'column_de', 'column_example'}
    vocab_cols = sorted(
        (c for c in column_regions if c.type in vocab_types),
        key=lambda c: c.x,  # left-to-right
    )
    if not vocab_cols:
        logger.warning("build_word_grid: no relevant vocabulary columns found")
        return []

    # OCR language per column role; anything else falls back to `lang`.
    lang_map = {
        'column_en': 'eng',
        'column_de': 'deu',
        'column_example': 'eng+deu',
    }
    # Column role -> (text field, bbox field) in the entry dict.
    field_map = {
        'column_en': ('english', 'bbox_en'),
        'column_de': ('german', 'bbox_de'),
        'column_example': ('example', 'bbox_ex'),
    }

    entries: List[Dict[str, Any]] = []

    for row_idx, row in enumerate(content_rows):
        entry: Dict[str, Any] = {
            'row_index': row_idx,
            'english': '',
            'german': '',
            'example': '',
            'confidence': 0.0,
            # Full-row bounding box in percent of the page.
            'bbox': {
                'x': round(row.x / img_w * 100, 2),
                'y': round(row.y / img_h * 100, 2),
                'w': round(row.width / img_w * 100, 2),
                'h': round(row.height / img_h * 100, 2),
            },
            'bbox_en': None,
            'bbox_de': None,
            'bbox_ex': None,
        }

        cell_confs: List[float] = []

        for col in vocab_cols:
            # Cell = column's horizontal span × row's vertical span,
            # clamped to the image bounds.
            x0 = max(0, col.x)
            y0 = max(0, row.y)
            cw = min(col.width, img_w - x0)
            ch = min(row.height, img_h - y0)
            if cw <= 0 or ch <= 0:
                continue  # degenerate cell entirely outside the image

            cell = PageRegion(
                type=col.type,
                x=x0, y=y0,
                width=cw, height=ch,
            )
            words = ocr_region(ocr_img, cell, lang=lang_map.get(col.type, lang), psm=7)

            # Join words in left-to-right reading order.
            text = ' '.join(
                wd['text'] for wd in sorted(words, key=lambda wd: wd['left'])
            )
            if words:
                cell_confs.append(sum(wd['conf'] for wd in words) / len(words))

            # Cell bounding box in percent of the page.
            cell_bbox = {
                'x': round(x0 / img_w * 100, 2),
                'y': round(y0 / img_h * 100, 2),
                'w': round(cw / img_w * 100, 2),
                'h': round(ch / img_h * 100, 2),
            }

            text_key, bbox_key = field_map[col.type]
            entry[text_key] = text
            entry[bbox_key] = cell_bbox

        # Average confidence over all cells that produced words.
        if cell_confs:
            entry['confidence'] = round(sum(cell_confs) / len(cell_confs), 1)

        # Drop rows where OCR produced no text in any column.
        if entry['english'] or entry['german'] or entry['example']:
            entries.append(entry)

    logger.info(f"build_word_grid: {len(entries)} entries from "
                f"{len(content_rows)} content rows × {len(vocab_cols)} columns")

    return entries
|
||||
# =============================================================================
|
||||
# Stage 6: Multi-Pass OCR
|
||||
# =============================================================================
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
-- Migration 004: Add word_result column for OCR Pipeline Step 5
-- Stores the word recognition grid result (entries with english/german/example + bboxes)

-- IF NOT EXISTS keeps the migration idempotent when re-applied.
ALTER TABLE ocr_pipeline_sessions ADD COLUMN IF NOT EXISTS word_result JSONB;
|
||||
@@ -29,8 +29,11 @@ from fastapi.responses import Response
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cv_vocab_pipeline import (
|
||||
PageRegion,
|
||||
RowGeometry,
|
||||
analyze_layout,
|
||||
analyze_layout_by_words,
|
||||
build_word_grid,
|
||||
classify_column_types,
|
||||
create_layout_image,
|
||||
create_ocr_image,
|
||||
@@ -261,6 +264,10 @@ async def get_session_info(session_id: str):
|
||||
result["dewarp_result"] = session["dewarp_result"]
|
||||
if session.get("column_result"):
|
||||
result["column_result"] = session["column_result"]
|
||||
if session.get("row_result"):
|
||||
result["row_result"] = session["row_result"]
|
||||
if session.get("word_result"):
|
||||
result["word_result"] = session["word_result"]
|
||||
|
||||
return result
|
||||
|
||||
@@ -291,7 +298,7 @@ async def delete_session(session_id: str):
|
||||
@router.get("/sessions/{session_id}/image/{image_type}")
|
||||
async def get_image(session_id: str, image_type: str):
|
||||
"""Serve session images: original, deskewed, dewarped, binarized, columns-overlay, or rows-overlay."""
|
||||
valid_types = {"original", "deskewed", "dewarped", "binarized", "columns-overlay", "rows-overlay"}
|
||||
valid_types = {"original", "deskewed", "dewarped", "binarized", "columns-overlay", "rows-overlay", "words-overlay"}
|
||||
if image_type not in valid_types:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")
|
||||
|
||||
@@ -301,6 +308,9 @@ async def get_image(session_id: str, image_type: str):
|
||||
if image_type == "rows-overlay":
|
||||
return await _get_rows_overlay(session_id)
|
||||
|
||||
if image_type == "words-overlay":
|
||||
return await _get_words_overlay(session_id)
|
||||
|
||||
# Try cache first for fast serving
|
||||
cached = _cache.get(session_id)
|
||||
if cached:
|
||||
@@ -992,6 +1002,153 @@ async def get_row_ground_truth(session_id: str):
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Word Recognition Endpoints (Step 5)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/words")
|
||||
async def detect_words(session_id: str):
|
||||
"""Build word grid from columns × rows, OCR each cell."""
|
||||
if session_id not in _cache:
|
||||
await _load_session_to_cache(session_id)
|
||||
cached = _get_cached(session_id)
|
||||
|
||||
dewarped_bgr = cached.get("dewarped_bgr")
|
||||
if dewarped_bgr is None:
|
||||
raise HTTPException(status_code=400, detail="Dewarp must be completed before word detection")
|
||||
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
column_result = session.get("column_result")
|
||||
row_result = session.get("row_result")
|
||||
if not column_result or not column_result.get("columns"):
|
||||
raise HTTPException(status_code=400, detail="Column detection must be completed first")
|
||||
if not row_result or not row_result.get("rows"):
|
||||
raise HTTPException(status_code=400, detail="Row detection must be completed first")
|
||||
|
||||
t0 = time.time()
|
||||
|
||||
# Create binarized OCR image
|
||||
ocr_img = create_ocr_image(dewarped_bgr)
|
||||
img_h, img_w = dewarped_bgr.shape[:2]
|
||||
|
||||
# Convert column dicts back to PageRegion objects
|
||||
col_regions = [
|
||||
PageRegion(
|
||||
type=c["type"],
|
||||
x=c["x"], y=c["y"],
|
||||
width=c["width"], height=c["height"],
|
||||
classification_confidence=c.get("classification_confidence", 1.0),
|
||||
classification_method=c.get("classification_method", ""),
|
||||
)
|
||||
for c in column_result["columns"]
|
||||
]
|
||||
|
||||
# Convert row dicts back to RowGeometry objects
|
||||
row_geoms = [
|
||||
RowGeometry(
|
||||
index=r["index"],
|
||||
x=r["x"], y=r["y"],
|
||||
width=r["width"], height=r["height"],
|
||||
word_count=r.get("word_count", 0),
|
||||
words=[],
|
||||
row_type=r.get("row_type", "content"),
|
||||
gap_before=r.get("gap_before", 0),
|
||||
)
|
||||
for r in row_result["rows"]
|
||||
]
|
||||
|
||||
# Build word grid
|
||||
entries = build_word_grid(ocr_img, col_regions, row_geoms, img_w, img_h)
|
||||
duration = time.time() - t0
|
||||
|
||||
# Build summary
|
||||
summary = {
|
||||
"total_entries": len(entries),
|
||||
"with_english": sum(1 for e in entries if e.get("english")),
|
||||
"with_german": sum(1 for e in entries if e.get("german")),
|
||||
"low_confidence": sum(1 for e in entries if e.get("confidence", 0) < 50),
|
||||
}
|
||||
|
||||
word_result = {
|
||||
"entries": entries,
|
||||
"entry_count": len(entries),
|
||||
"image_width": img_w,
|
||||
"image_height": img_h,
|
||||
"duration_seconds": round(duration, 2),
|
||||
"summary": summary,
|
||||
}
|
||||
|
||||
# Persist to DB
|
||||
await update_session_db(
|
||||
session_id,
|
||||
word_result=word_result,
|
||||
current_step=5,
|
||||
)
|
||||
|
||||
cached["word_result"] = word_result
|
||||
|
||||
logger.info(f"OCR Pipeline: words session {session_id}: "
|
||||
f"{len(entries)} entries ({duration:.2f}s), summary: {summary}")
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
**word_result,
|
||||
}
|
||||
|
||||
|
||||
class WordGroundTruthRequest(BaseModel):
    """Payload for saving ground truth feedback on the word recognition step."""

    # Whether the automatic word grid was fully correct.
    is_correct: bool
    # Manually corrected entries when the automatic result was wrong.
    corrected_entries: Optional[List[Dict[str, Any]]] = None
    # Free-form reviewer notes.
    notes: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/ground-truth/words")
|
||||
async def save_word_ground_truth(session_id: str, req: WordGroundTruthRequest):
|
||||
"""Save ground truth feedback for the word recognition step."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
ground_truth = session.get("ground_truth") or {}
|
||||
gt = {
|
||||
"is_correct": req.is_correct,
|
||||
"corrected_entries": req.corrected_entries,
|
||||
"notes": req.notes,
|
||||
"saved_at": datetime.utcnow().isoformat(),
|
||||
"word_result": session.get("word_result"),
|
||||
}
|
||||
ground_truth["words"] = gt
|
||||
|
||||
await update_session_db(session_id, ground_truth=ground_truth)
|
||||
|
||||
if session_id in _cache:
|
||||
_cache[session_id]["ground_truth"] = ground_truth
|
||||
|
||||
return {"session_id": session_id, "ground_truth": gt}
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/ground-truth/words")
|
||||
async def get_word_ground_truth(session_id: str):
|
||||
"""Retrieve saved ground truth for word recognition."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
ground_truth = session.get("ground_truth") or {}
|
||||
words_gt = ground_truth.get("words")
|
||||
if not words_gt:
|
||||
raise HTTPException(status_code=404, detail="No word ground truth saved")
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"words_gt": words_gt,
|
||||
"words_auto": session.get("word_result"),
|
||||
}
|
||||
|
||||
|
||||
async def _get_rows_overlay(session_id: str) -> Response:
|
||||
"""Generate dewarped image with row bands drawn on it."""
|
||||
session = await get_session_db(session_id)
|
||||
@@ -1049,3 +1206,106 @@ async def _get_rows_overlay(session_id: str) -> Response:
|
||||
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
|
||||
|
||||
return Response(content=result_png.tobytes(), media_type="image/png")
|
||||
|
||||
|
||||
async def _get_words_overlay(session_id: str) -> Response:
    """Generate dewarped image with word grid cells drawn on it."""
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result or not word_result.get("entries"):
        raise HTTPException(status_code=404, detail="No word data available")

    column_result = session.get("column_result")
    row_result = session.get("row_result")

    # Decode the stored dewarped PNG into a BGR array.
    dewarped_png = await get_session_image(session_id, "dewarped")
    if not dewarped_png:
        raise HTTPException(status_code=404, detail="Dewarped image not available")

    img = cv2.imdecode(np.frombuffer(dewarped_png, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    img_h, img_w = img.shape[:2]

    # Column role -> BGR color.
    cell_colors = {
        "column_en": (255, 180, 0),  # Blue
        "column_de": (0, 200, 0),  # Green
        "column_example": (0, 140, 255),  # Orange
    }

    # Separate layer for the semi-transparent cell fills; blended in at the end.
    overlay = img.copy()

    # Vertical lines at each vocabulary column's left and right edges.
    if column_result and column_result.get("columns"):
        for col in column_result["columns"]:
            line_color = cell_colors.get(col.get("type", ""))
            if line_color is None:
                continue
            for cx in (col["x"], col["x"] + col["width"]):
                cv2.line(img, (cx, 0), (cx, img_h), line_color, 1)

    # Horizontal lines at the top edge of every content row.
    if row_result and row_result.get("rows"):
        for row in row_result["rows"]:
            if row.get("row_type") == "content":
                cv2.line(img, (0, row["y"]), (img_w, row["y"]), (180, 180, 180), 1)

    # Cell rectangles plus recognized text, colored by entry confidence.
    for entry in word_result["entries"]:
        conf = entry.get("confidence", 0)
        if conf >= 70:
            text_color = (0, 180, 0)  # green: confident
        elif conf >= 50:
            text_color = (0, 180, 220)  # yellow: borderline
        else:
            text_color = (0, 0, 220)  # red: low confidence

        for bbox_key, field_key, col_type in (
            ("bbox_en", "english", "column_en"),
            ("bbox_de", "german", "column_de"),
            ("bbox_ex", "example", "column_example"),
        ):
            bbox = entry.get(bbox_key)
            text = entry.get(field_key, "")
            if not bbox or not text:
                continue

            # Percent coordinates -> pixel coordinates.
            bx = int(bbox["x"] / 100 * img_w)
            by = int(bbox["y"] / 100 * img_h)
            bw = int(bbox["w"] / 100 * img_w)
            bh = int(bbox["h"] / 100 * img_h)

            fill = cell_colors.get(col_type, (200, 200, 200))

            # Filled cell on the overlay layer, thin border on the base image.
            cv2.rectangle(overlay, (bx, by), (bx + bw, by + bh), fill, -1)
            cv2.rectangle(img, (bx, by), (bx + bw, by + bh), text_color, 1)

            # OCR text inside the cell, truncated to keep labels readable.
            cv2.putText(img, text[:30], (bx + 3, by + bh - 4),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)

    # Merge the fill layer at 10% opacity.
    cv2.addWeighted(overlay, 0.1, img, 0.9, 0, img)

    success, result_png = cv2.imencode(".png", img)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")

    return Response(content=result_png.tobytes(), media_type="image/png")
|
||||
|
||||
@@ -80,7 +80,7 @@ async def create_session_db(
|
||||
) VALUES ($1, $2, $3, $4, 'active', 1)
|
||||
RETURNING id, name, filename, status, current_step,
|
||||
deskew_result, dewarp_result, column_result, row_result,
|
||||
ground_truth, auto_shear_degrees,
|
||||
word_result, ground_truth, auto_shear_degrees,
|
||||
created_at, updated_at
|
||||
""", uuid.UUID(session_id), name, filename, original_png)
|
||||
|
||||
@@ -94,7 +94,7 @@ async def get_session_db(session_id: str) -> Optional[Dict[str, Any]]:
|
||||
row = await conn.fetchrow("""
|
||||
SELECT id, name, filename, status, current_step,
|
||||
deskew_result, dewarp_result, column_result, row_result,
|
||||
ground_truth, auto_shear_degrees,
|
||||
word_result, ground_truth, auto_shear_degrees,
|
||||
created_at, updated_at
|
||||
FROM ocr_pipeline_sessions WHERE id = $1
|
||||
""", uuid.UUID(session_id))
|
||||
@@ -136,10 +136,10 @@ async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any
|
||||
'name', 'filename', 'status', 'current_step',
|
||||
'original_png', 'deskewed_png', 'binarized_png', 'dewarped_png',
|
||||
'deskew_result', 'dewarp_result', 'column_result', 'row_result',
|
||||
'ground_truth', 'auto_shear_degrees',
|
||||
'word_result', 'ground_truth', 'auto_shear_degrees',
|
||||
}
|
||||
|
||||
jsonb_fields = {'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'ground_truth'}
|
||||
jsonb_fields = {'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth'}
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if key in allowed_fields:
|
||||
@@ -164,7 +164,7 @@ async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any
|
||||
WHERE id = ${param_idx}
|
||||
RETURNING id, name, filename, status, current_step,
|
||||
deskew_result, dewarp_result, column_result, row_result,
|
||||
ground_truth, auto_shear_degrees,
|
||||
word_result, ground_truth, auto_shear_degrees,
|
||||
created_at, updated_at
|
||||
""", *values)
|
||||
|
||||
@@ -220,7 +220,7 @@ def _row_to_dict(row: asyncpg.Record) -> Dict[str, Any]:
|
||||
result[key] = result[key].isoformat()
|
||||
|
||||
# JSONB → parsed (asyncpg returns str for JSONB)
|
||||
for key in ['deskew_result', 'dewarp_result', 'column_result', 'row_result', 'ground_truth']:
|
||||
for key in ['deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth']:
|
||||
if key in result and result[key] is not None:
|
||||
if isinstance(result[key], str):
|
||||
result[key] = json.loads(result[key])
|
||||
|
||||
Reference in New Issue
Block a user