feat(ocr-pipeline): add SSE streaming for word recognition (Step 5)
Cells now appear one-by-one in the UI as they are OCR'd, with a live
progress bar, instead of waiting for the full result.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -15,6 +15,7 @@ Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
@@ -24,8 +25,8 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
|
||||
from fastapi.responses import Response
|
||||
from fastapi import APIRouter, File, Form, HTTPException, Request, UploadFile
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cv_vocab_pipeline import (
|
||||
@@ -39,6 +40,7 @@ from cv_vocab_pipeline import (
|
||||
analyze_layout,
|
||||
analyze_layout_by_words,
|
||||
build_cell_grid,
|
||||
build_cell_grid_streaming,
|
||||
build_word_grid,
|
||||
classify_column_types,
|
||||
create_layout_image,
|
||||
@@ -1023,12 +1025,19 @@ async def get_row_ground_truth(session_id: str):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/words")
|
||||
async def detect_words(session_id: str, engine: str = "auto", pronunciation: str = "british"):
|
||||
async def detect_words(
|
||||
session_id: str,
|
||||
request: Request,
|
||||
engine: str = "auto",
|
||||
pronunciation: str = "british",
|
||||
stream: bool = False,
|
||||
):
|
||||
"""Build word grid from columns × rows, OCR each cell.
|
||||
|
||||
Query params:
|
||||
engine: 'auto' (default), 'tesseract', or 'rapid'
|
||||
pronunciation: 'british' (default) or 'american' — for IPA dictionary lookup
|
||||
stream: false (default) for JSON response, true for SSE streaming
|
||||
"""
|
||||
if session_id not in _cache:
|
||||
await _load_session_to_cache(session_id)
|
||||
@@ -1049,12 +1058,6 @@ async def detect_words(session_id: str, engine: str = "auto", pronunciation: str
|
||||
if not row_result or not row_result.get("rows"):
|
||||
raise HTTPException(status_code=400, detail="Row detection must be completed first")
|
||||
|
||||
t0 = time.time()
|
||||
|
||||
# Create binarized OCR image (for Tesseract)
|
||||
ocr_img = create_ocr_image(dewarped_bgr)
|
||||
img_h, img_w = dewarped_bgr.shape[:2]
|
||||
|
||||
# Convert column dicts back to PageRegion objects
|
||||
col_regions = [
|
||||
PageRegion(
|
||||
@@ -1081,6 +1084,27 @@ async def detect_words(session_id: str, engine: str = "auto", pronunciation: str
|
||||
for r in row_result["rows"]
|
||||
]
|
||||
|
||||
if stream:
|
||||
return StreamingResponse(
|
||||
_word_stream_generator(
|
||||
session_id, cached, col_regions, row_geoms,
|
||||
dewarped_bgr, engine, pronunciation, request,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
# --- Non-streaming path (unchanged) ---
|
||||
t0 = time.time()
|
||||
|
||||
# Create binarized OCR image (for Tesseract)
|
||||
ocr_img = create_ocr_image(dewarped_bgr)
|
||||
img_h, img_w = dewarped_bgr.shape[:2]
|
||||
|
||||
# Build generic cell grid
|
||||
cells, columns_meta = build_cell_grid(
|
||||
ocr_img, col_regions, row_geoms, img_w, img_h,
|
||||
@@ -1154,6 +1178,140 @@ async def detect_words(session_id: str, engine: str = "auto", pronunciation: str
|
||||
}
|
||||
|
||||
|
||||
async def _word_stream_generator(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    request: Request,
):
    """SSE generator that yields cell-by-cell OCR progress.

    Each yield is one SSE frame of the form ``data: <json>\\n\\n``. The
    event sequence is:
      1. ``meta``     — estimated grid shape and layout kind.
      2. ``columns``  — columns_used metadata, emitted once alongside the
                        first streamed cell.
      3. ``cell``     — one event per OCR'd cell, with running progress.
      4. ``complete`` — final summary (plus vocab_entries for vocab layouts).

    Args:
        session_id: Session key used for DB persistence and logging.
        cached: In-process session cache; ``word_result`` is written into it.
        col_regions: Column regions of the dewarped page.
        row_geoms: Row geometry objects; only rows with ``row_type ==
            'content'`` count toward the grid-shape estimate.
        dewarped_bgr: Dewarped page image (BGR), also passed to the OCR grid
            builder for engines that want color input.
        engine: OCR engine selector forwarded to build_cell_grid_streaming.
        pronunciation: IPA dictionary variant used in vocab post-processing.
        request: Incoming request, polled for client disconnects.

    Side effects: persists the assembled word_result to the session DB and
    the in-process cache before the ``complete`` event is sent. If the
    client disconnects mid-stream, the generator returns early and nothing
    is persisted.
    """
    t0 = time.time()

    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    # Compute grid shape upfront for the meta event
    n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
    _skip_types = {'column_ignore', 'header', 'footer', 'page_ref'}
    n_cols = len([c for c in col_regions if c.type not in _skip_types])

    # Determine layout: any EN/DE column marks this as a vocab page
    col_types = {c.type for c in col_regions if c.type not in _skip_types}
    is_vocab = bool(col_types & {'column_en', 'column_de'})

    # Start streaming — first event: meta.
    # NOTE: total_cells here is the rows × cols estimate; the per-cell
    # events below carry the generator's authoritative `total` instead.
    columns_meta = None  # will be set from first yield
    total_cells = n_content_rows * n_cols

    meta_event = {
        "type": "meta",
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
        "layout": "vocab" if is_vocab else "generic",
    }
    yield f"data: {json.dumps(meta_event)}\n\n"

    # Stream cells one by one
    all_cells: List[Dict[str, Any]] = []
    cell_idx = 0

    for cell, cols_meta, total in build_cell_grid_streaming(
        ocr_img, col_regions, row_geoms, img_w, img_h,
        ocr_engine=engine, img_bgr=dewarped_bgr,
    ):
        # Bail out early if the browser navigated away; nothing is
        # persisted in that case.
        if await request.is_disconnected():
            logger.info(f"SSE: client disconnected during streaming for {session_id}")
            return

        if columns_meta is None:
            columns_meta = cols_meta
            # Send columns_used as part of first cell or update meta
            meta_update = {
                "type": "columns",
                "columns_used": cols_meta,
            }
            yield f"data: {json.dumps(meta_update)}\n\n"

        all_cells.append(cell)
        cell_idx += 1

        cell_event = {
            "type": "cell",
            "cell": cell,
            "progress": {"current": cell_idx, "total": total},
        }
        yield f"data: {json.dumps(cell_event)}\n\n"

    # All cells done — build final result
    duration = time.time() - t0
    if columns_meta is None:
        # Generator yielded nothing (empty grid) — keep columns_used valid
        columns_meta = []

    # Report whichever engine actually ran (first cell's tag), falling
    # back to the requested engine when no cells were produced
    used_engine = all_cells[0].get("ocr_engine", "tesseract") if all_cells else engine

    word_result = {
        "cells": all_cells,
        "grid_shape": {
            "rows": n_content_rows,
            "cols": n_cols,
            "total_cells": len(all_cells),
        },
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": {
            "total_cells": len(all_cells),
            "non_empty_cells": sum(1 for c in all_cells if c.get("text")),
            # confidence 0 is treated as "unknown", not "low"
            "low_confidence": sum(1 for c in all_cells if 0 < c.get("confidence", 0) < 50),
        },
    }

    # Vocab post-processing: cleanup pipeline over the raw cell grid
    vocab_entries = None
    if is_vocab:
        entries = _cells_to_vocab_entries(all_cells, columns_meta)
        entries = _fix_character_confusion(entries)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        entries = _split_comma_entries(entries)
        entries = _attach_example_sentences(entries)
        # "vocab_entries" and "entries" are intentionally duplicated —
        # presumably for backward compatibility with older clients; verify
        # against consumers before removing either key.
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        word_result["summary"]["total_entries"] = len(entries)
        word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
        word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
        vocab_entries = entries

    # Persist to DB before telling the client we are done
    await update_session_db(
        session_id,
        word_result=word_result,
        current_step=5,
    )
    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline SSE: words session {session_id}: "
                f"layout={word_result['layout']}, "
                f"{len(all_cells)} cells ({duration:.2f}s)")

    # Final complete event
    complete_event = {
        "type": "complete",
        "summary": word_result["summary"],
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
    }
    if vocab_entries is not None:
        complete_event["vocab_entries"] = vocab_entries
    yield f"data: {json.dumps(complete_event)}\n\n"
|
||||
|
||||
|
||||
class WordGroundTruthRequest(BaseModel):
    """Request body for submitting ground truth on word-detection results.

    NOTE(review): presumably consumed by a words ground-truth endpoint
    (cf. get_row_ground_truth above) — verify against the router.
    """

    # Whether the OCR result is confirmed correct as-is.
    is_correct: bool
    # Corrected entries supplied by the caller when the result was wrong;
    # None when no corrections accompany the verdict.
    corrected_entries: Optional[List[Dict[str, Any]]] = None
|
||||
|
||||
Reference in New Issue
Block a user