feat(ocr-pipeline): add SSE streaming for word recognition (Step 5)

Cells now appear one-by-one in the UI as they are OCR'd, with a live
progress bar, instead of waiting for the full result.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-03-01 17:54:20 +01:00
parent a666e883da
commit 7f27783008
3 changed files with 506 additions and 93 deletions

View File

@@ -15,6 +15,7 @@ Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import time
import uuid
@@ -24,8 +25,8 @@ from typing import Any, Dict, List, Optional
import cv2
import numpy as np
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from fastapi.responses import Response
from fastapi import APIRouter, File, Form, HTTPException, Request, UploadFile
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel
from cv_vocab_pipeline import (
@@ -39,6 +40,7 @@ from cv_vocab_pipeline import (
analyze_layout,
analyze_layout_by_words,
build_cell_grid,
build_cell_grid_streaming,
build_word_grid,
classify_column_types,
create_layout_image,
@@ -1023,12 +1025,19 @@ async def get_row_ground_truth(session_id: str):
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/words")
async def detect_words(session_id: str, engine: str = "auto", pronunciation: str = "british"):
async def detect_words(
session_id: str,
request: Request,
engine: str = "auto",
pronunciation: str = "british",
stream: bool = False,
):
"""Build word grid from columns × rows, OCR each cell.
Query params:
engine: 'auto' (default), 'tesseract', or 'rapid'
pronunciation: 'british' (default) or 'american' — for IPA dictionary lookup
stream: false (default) for JSON response, true for SSE streaming
"""
if session_id not in _cache:
await _load_session_to_cache(session_id)
@@ -1049,12 +1058,6 @@ async def detect_words(session_id: str, engine: str = "auto", pronunciation: str
if not row_result or not row_result.get("rows"):
raise HTTPException(status_code=400, detail="Row detection must be completed first")
t0 = time.time()
# Create binarized OCR image (for Tesseract)
ocr_img = create_ocr_image(dewarped_bgr)
img_h, img_w = dewarped_bgr.shape[:2]
# Convert column dicts back to PageRegion objects
col_regions = [
PageRegion(
@@ -1081,6 +1084,27 @@ async def detect_words(session_id: str, engine: str = "auto", pronunciation: str
for r in row_result["rows"]
]
if stream:
return StreamingResponse(
_word_stream_generator(
session_id, cached, col_regions, row_geoms,
dewarped_bgr, engine, pronunciation, request,
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
# --- Non-streaming path (unchanged) ---
t0 = time.time()
# Create binarized OCR image (for Tesseract)
ocr_img = create_ocr_image(dewarped_bgr)
img_h, img_w = dewarped_bgr.shape[:2]
# Build generic cell grid
cells, columns_meta = build_cell_grid(
ocr_img, col_regions, row_geoms, img_w, img_h,
@@ -1154,6 +1178,140 @@ async def detect_words(session_id: str, engine: str = "auto", pronunciation: str
}
async def _word_stream_generator(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    request: Request,
):
    """SSE generator that yields cell-by-cell OCR progress.

    Event sequence (each serialized as ``data: <json>\\n\\n``):
      1. ``meta``     — grid shape and layout, sent before any OCR work.
      2. ``columns``  — columns_used metadata, sent once alongside the first cell.
      3. ``cell``     — one event per OCR'd cell, with running progress.
      4. ``complete`` — summary, duration, engine used (+ vocab entries for
         vocab layouts).
      On an exception mid-stream an ``error`` event is emitted as the terminal
      event instead of silently dropping the connection.

    Args:
        session_id: Session whose dewarped page is being OCR'd.
        cached: In-memory session cache entry; ``word_result`` is written back.
        col_regions: Column regions reconstructed from the layout step.
        row_geoms: Row geometries from the row-detection step.
        dewarped_bgr: Dewarped page image (BGR, as produced upstream).
        engine: OCR engine selector ('auto', 'tesseract', or 'rapid').
        pronunciation: 'british' or 'american' — IPA dictionary variant.
        request: Incoming request, polled for client disconnect between cells.
    """
    t0 = time.time()
    # Binarized image for Tesseract; the color image is still passed through
    # for engines that want it.
    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    # Compute grid shape upfront for the meta event.
    _skip_types = {'column_ignore', 'header', 'footer', 'page_ref'}
    n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
    n_cols = len([c for c in col_regions if c.type not in _skip_types])

    # Determine layout: a page with dedicated EN/DE columns is a vocab page.
    col_types = {c.type for c in col_regions if c.type not in _skip_types}
    is_vocab = bool(col_types & {'column_en', 'column_de'})

    # First event: meta — sent before any OCR work so the UI can draw the
    # empty grid and progress bar immediately.
    columns_meta = None  # populated from the first streamed cell
    total_cells = n_content_rows * n_cols
    meta_event = {
        "type": "meta",
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
        "layout": "vocab" if is_vocab else "generic",
    }
    yield f"data: {json.dumps(meta_event)}\n\n"

    try:
        # Stream cells one by one as the OCR backend produces them.
        all_cells: List[Dict[str, Any]] = []
        cell_idx = 0
        for cell, cols_meta, total in build_cell_grid_streaming(
            ocr_img, col_regions, row_geoms, img_w, img_h,
            ocr_engine=engine, img_bgr=dewarped_bgr,
        ):
            # Stop OCR work (and skip persistence) if the client went away.
            if await request.is_disconnected():
                logger.info(f"SSE: client disconnected during streaming for {session_id}")
                return
            if columns_meta is None:
                columns_meta = cols_meta
                # Send columns_used once, with the first cell.
                meta_update = {
                    "type": "columns",
                    "columns_used": cols_meta,
                }
                yield f"data: {json.dumps(meta_update)}\n\n"
            all_cells.append(cell)
            cell_idx += 1
            cell_event = {
                "type": "cell",
                "cell": cell,
                "progress": {"current": cell_idx, "total": total},
            }
            yield f"data: {json.dumps(cell_event)}\n\n"

        # All cells done — build final result (mirrors the non-streaming path).
        duration = time.time() - t0
        if columns_meta is None:
            columns_meta = []  # no cells streamed at all
        used_engine = all_cells[0].get("ocr_engine", "tesseract") if all_cells else engine
        word_result = {
            "cells": all_cells,
            "grid_shape": {
                "rows": n_content_rows,
                "cols": n_cols,
                "total_cells": len(all_cells),
            },
            "columns_used": columns_meta,
            "layout": "vocab" if is_vocab else "generic",
            "image_width": img_w,
            "image_height": img_h,
            "duration_seconds": round(duration, 2),
            "ocr_engine": used_engine,
            "summary": {
                "total_cells": len(all_cells),
                "non_empty_cells": sum(1 for c in all_cells if c.get("text")),
                "low_confidence": sum(1 for c in all_cells if 0 < c.get("confidence", 0) < 50),
            },
        }

        # Vocab post-processing: cells → entries, then cleanup passes.
        vocab_entries = None
        if is_vocab:
            entries = _cells_to_vocab_entries(all_cells, columns_meta)
            entries = _fix_character_confusion(entries)
            entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
            entries = _split_comma_entries(entries)
            entries = _attach_example_sentences(entries)
            word_result["vocab_entries"] = entries
            word_result["entries"] = entries
            word_result["entry_count"] = len(entries)
            word_result["summary"]["total_entries"] = len(entries)
            word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
            word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
            vocab_entries = entries

        # Persist to DB and refresh the in-memory cache.
        await update_session_db(
            session_id,
            word_result=word_result,
            current_step=5,
        )
        cached["word_result"] = word_result
        logger.info(f"OCR Pipeline SSE: words session {session_id}: "
                    f"layout={word_result['layout']}, "
                    f"{len(all_cells)} cells ({duration:.2f}s)")

        # Final complete event.
        complete_event = {
            "type": "complete",
            "summary": word_result["summary"],
            "duration_seconds": round(duration, 2),
            "ocr_engine": used_engine,
        }
        if vocab_entries is not None:
            complete_event["vocab_entries"] = vocab_entries
        yield f"data: {json.dumps(complete_event)}\n\n"
    except Exception:
        # Without this, any OCR/DB failure mid-stream silently kills the
        # generator and the SSE client hangs with no terminal event.
        logger.exception(f"SSE: word streaming failed for {session_id}")
        error_event = {
            "type": "error",
            "detail": "OCR streaming failed",
        }
        yield f"data: {json.dumps(error_event)}\n\n"
class WordGroundTruthRequest(BaseModel):
    """Request body for submitting ground truth on a session's word/OCR result."""
    # True if the OCR output was correct as-is.
    is_correct: bool
    # Manually corrected entries when is_correct is False — presumably the same
    # dict shape as the generated vocab entries; verify against the caller.
    corrected_entries: Optional[List[Dict[str, Any]]] = None