feat(ocr-pipeline): add SSE streaming for word recognition (Step 5)

Cells now appear one-by-one in the UI as they are OCR'd, with a live
progress bar, instead of waiting for the full result.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-03-01 17:54:20 +01:00
parent a666e883da
commit 7f27783008
3 changed files with 506 additions and 93 deletions

View File

@@ -62,7 +62,11 @@ export function StepWordRecognition({ sessionId, onNext, goToStep }: StepWordRec
const [usedEngine, setUsedEngine] = useState<string>('')
const [pronunciation, setPronunciation] = useState<'british' | 'american'>('british')
// Streaming progress state
const [streamProgress, setStreamProgress] = useState<{ current: number; total: number } | null>(null)
const enRef = useRef<HTMLInputElement>(null)
const tableEndRef = useRef<HTMLDivElement>(null)
const isVocab = gridResult?.layout === 'vocab'
@@ -110,16 +114,107 @@ export function StepWordRecognition({ sessionId, onNext, goToStep }: StepWordRec
const eng = engine || ocrEngine
setDetecting(true)
setError(null)
setStreamProgress(null)
setEditedCells([])
setEditedEntries([])
setGridResult(null)
try {
const res = await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/words?engine=${eng}&pronunciation=${pronunciation}`, {
method: 'POST',
})
const res = await fetch(
`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/words?stream=true&engine=${eng}&pronunciation=${pronunciation}`,
{ method: 'POST' },
)
if (!res.ok) {
const err = await res.json().catch(() => ({ detail: res.statusText }))
throw new Error(err.detail || 'Worterkennung fehlgeschlagen')
}
const data = await res.json()
applyGridResult(data)
const reader = res.body!.getReader()
const decoder = new TextDecoder()
let buffer = ''
let streamLayout: string | null = null
let streamColumnsUsed: GridResult['columns_used'] = []
let streamGridShape: GridResult['grid_shape'] | null = null
let streamCells: GridCell[] = []
while (true) {
const { done, value } = await reader.read()
if (done) break
buffer += decoder.decode(value, { stream: true })
// Parse SSE events (separated by \n\n)
while (buffer.includes('\n\n')) {
const idx = buffer.indexOf('\n\n')
const chunk = buffer.slice(0, idx).trim()
buffer = buffer.slice(idx + 2)
if (!chunk.startsWith('data: ')) continue
const dataStr = chunk.slice(6) // strip "data: "
let event: any
try {
event = JSON.parse(dataStr)
} catch {
continue
}
if (event.type === 'meta') {
streamLayout = event.layout || 'generic'
streamGridShape = event.grid_shape || null
// Show partial grid result so UI renders structure
setGridResult(prev => ({
...prev,
layout: event.layout || 'generic',
grid_shape: event.grid_shape,
columns_used: [],
cells: [],
summary: { total_cells: event.grid_shape?.total_cells || 0, non_empty_cells: 0, low_confidence: 0 },
duration_seconds: 0,
ocr_engine: '',
} as GridResult))
}
if (event.type === 'columns') {
streamColumnsUsed = event.columns_used || []
setGridResult(prev => prev ? { ...prev, columns_used: streamColumnsUsed } : prev)
}
if (event.type === 'cell') {
const cell: GridCell = { ...event.cell, status: 'pending' }
streamCells = [...streamCells, cell]
setEditedCells(streamCells)
setStreamProgress(event.progress)
// Auto-scroll table to bottom
setTimeout(() => tableEndRef.current?.scrollIntoView({ behavior: 'smooth', block: 'nearest' }), 16)
}
if (event.type === 'complete') {
// Build final GridResult
const finalResult: GridResult = {
cells: streamCells,
grid_shape: streamGridShape || { rows: 0, cols: 0, total_cells: streamCells.length },
columns_used: streamColumnsUsed,
layout: streamLayout || 'generic',
image_width: 0,
image_height: 0,
duration_seconds: event.duration_seconds || 0,
ocr_engine: event.ocr_engine || '',
summary: event.summary || {},
}
// If vocab: apply post-processed entries from complete event
if (event.vocab_entries) {
finalResult.entries = event.vocab_entries
finalResult.vocab_entries = event.vocab_entries
finalResult.entry_count = event.vocab_entries.length
}
applyGridResult(finalResult)
setUsedEngine(event.ocr_engine || '')
setStreamProgress(null)
}
}
}
} catch (e) {
setError(e instanceof Error ? e.message : 'Unbekannter Fehler')
} finally {
@@ -288,11 +383,23 @@ export function StepWordRecognition({ sessionId, onNext, goToStep }: StepWordRec
return (
<div className="space-y-4">
{/* Loading */}
{/* Loading with streaming progress */}
{detecting && (
<div className="flex items-center gap-2 text-teal-600 dark:text-teal-400 text-sm">
<div className="animate-spin w-4 h-4 border-2 border-teal-500 border-t-transparent rounded-full" />
Worterkennung laeuft...
<div className="space-y-1">
<div className="flex items-center gap-2 text-teal-600 dark:text-teal-400 text-sm">
<div className="animate-spin w-4 h-4 border-2 border-teal-500 border-t-transparent rounded-full" />
{streamProgress
? `Zelle ${streamProgress.current}/${streamProgress.total} erkannt...`
: 'Worterkennung startet...'}
</div>
{streamProgress && streamProgress.total > 0 && (
<div className="w-full bg-gray-200 dark:bg-gray-700 rounded-full h-1.5">
<div
className="bg-teal-500 h-1.5 rounded-full transition-all duration-150"
style={{ width: `${(streamProgress.current / streamProgress.total) * 100}%` }}
/>
</div>
)}
</div>
)}
@@ -378,8 +485,8 @@ export function StepWordRecognition({ sessionId, onNext, goToStep }: StepWordRec
</div>
</div>
{/* Result summary */}
{gridResult && summary && (
{/* Result summary (only after streaming completes) */}
{gridResult && summary && !detecting && (
<div className="bg-white dark:bg-gray-800 rounded-xl border border-gray-200 dark:border-gray-700 p-4 space-y-3">
<div className="flex items-center justify-between">
<h4 className="text-sm font-medium text-gray-700 dark:text-gray-300">
@@ -511,6 +618,67 @@ export function StepWordRecognition({ sessionId, onNext, goToStep }: StepWordRec
</tbody>
</table>
)}
<div ref={tableEndRef} />
</div>
</div>
)}
{/* Streaming cell table (shown while detecting, before complete) */}
{detecting && editedCells.length > 0 && !gridResult?.summary?.non_empty_cells && (
<div className="bg-white dark:bg-gray-800 rounded-xl border border-gray-200 dark:border-gray-700 p-4 space-y-3">
<h4 className="text-sm font-medium text-gray-700 dark:text-gray-300">
Live: {editedCells.length} Zellen erkannt...
</h4>
<div className="max-h-80 overflow-y-auto">
<table className="w-full text-xs">
<thead className="sticky top-0 bg-white dark:bg-gray-800">
<tr className="text-left text-gray-500 dark:text-gray-400 border-b dark:border-gray-700">
<th className="py-1 pr-2 w-12">Zelle</th>
{columnsUsed.map((col, i) => (
<th key={i} className={`py-1 pr-2 ${colTypeColor(col.type)}`}>
{colTypeLabel(col.type)}
</th>
))}
<th className="py-1 w-12 text-right">Conf</th>
</tr>
</thead>
<tbody>
{(() => {
const liveByRow: Map<number, GridCell[]> = new Map()
for (const cell of editedCells) {
const existing = liveByRow.get(cell.row_index) || []
existing.push(cell)
liveByRow.set(cell.row_index, existing)
}
const liveSorted = [...liveByRow.keys()].sort((a, b) => a - b)
return liveSorted.map(rowIdx => {
const rowCells = liveByRow.get(rowIdx) || []
const avgConf = rowCells.length
? Math.round(rowCells.reduce((s, c) => s + c.confidence, 0) / rowCells.length)
: 0
return (
<tr key={rowIdx} className="border-b dark:border-gray-700/50 animate-fade-in">
<td className="py-1 pr-2 text-gray-400 font-mono text-[10px]">
R{String(rowIdx).padStart(2, '0')}
</td>
{columnsUsed.map((col) => {
const cell = rowCells.find(c => c.col_index === col.index)
return (
<td key={col.index} className="py-1 pr-2 font-mono text-gray-700 dark:text-gray-300">
<MultilineText text={cell?.text || ''} />
</td>
)
})}
<td className={`py-1 text-right font-mono ${confColor(avgConf)}`}>
{avgConf}%
</td>
</tr>
)
})
})()}
</tbody>
</table>
<div ref={tableEndRef} />
</div>
</div>
)}

View File

@@ -19,7 +19,7 @@ import io
import logging
import time
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional, Tuple
from typing import Any, Dict, Generator, List, Optional, Tuple
import numpy as np
@@ -3009,6 +3009,94 @@ def _replace_phonetics_in_text(text: str, pronunciation: str = 'british') -> str
return _PHONETIC_BRACKET_RE.sub(replacer, text)
def _ocr_single_cell(
row_idx: int,
col_idx: int,
row: RowGeometry,
col: PageRegion,
ocr_img: np.ndarray,
img_bgr: Optional[np.ndarray],
img_w: int,
img_h: int,
use_rapid: bool,
engine_name: str,
lang: str,
lang_map: Dict[str, str],
) -> Dict[str, Any]:
"""OCR a single cell (column × row intersection) and return its dict."""
pad = 8 # pixels
cell_x = max(0, col.x - pad)
cell_y = max(0, row.y - pad)
cell_w = col.width + 2 * pad
cell_h = row.height + 2 * pad
# Clamp to image bounds
if cell_x + cell_w > img_w:
cell_w = img_w - cell_x
if cell_y + cell_h > img_h:
cell_h = img_h - cell_y
if cell_w <= 0 or cell_h <= 0:
return {
'cell_id': f"R{row_idx:02d}_C{col_idx}",
'row_index': row_idx,
'col_index': col_idx,
'col_type': col.type,
'text': '',
'confidence': 0.0,
'bbox_px': {'x': col.x, 'y': row.y, 'w': col.width, 'h': row.height},
'bbox_pct': {
'x': round(col.x / img_w * 100, 2),
'y': round(row.y / img_h * 100, 2),
'w': round(col.width / img_w * 100, 2),
'h': round(row.height / img_h * 100, 2),
},
'ocr_engine': engine_name,
}
cell_region = PageRegion(
type=col.type,
x=cell_x, y=cell_y,
width=cell_w, height=cell_h,
)
# OCR the cell
if use_rapid:
words = ocr_region_rapid(img_bgr, cell_region)
else:
cell_lang = lang_map.get(col.type, lang)
words = ocr_region(ocr_img, cell_region, lang=cell_lang, psm=6)
# Group into lines, then join in reading order
if words:
avg_h = sum(w['height'] for w in words) / len(words)
y_tol = max(10, int(avg_h * 0.5))
else:
y_tol = 15
text = _words_to_reading_order_text(words, y_tolerance_px=y_tol)
avg_conf = 0.0
if words:
avg_conf = round(sum(w['conf'] for w in words) / len(words), 1)
return {
'cell_id': f"R{row_idx:02d}_C{col_idx}",
'row_index': row_idx,
'col_index': col_idx,
'col_type': col.type,
'text': text,
'confidence': avg_conf,
'bbox_px': {'x': cell_x, 'y': cell_y, 'w': cell_w, 'h': cell_h},
'bbox_pct': {
'x': round(cell_x / img_w * 100, 2),
'y': round(cell_y / img_h * 100, 2),
'w': round(cell_w / img_w * 100, 2),
'h': round(cell_h / img_h * 100, 2),
},
'ocr_engine': engine_name,
}
def build_cell_grid(
ocr_img: np.ndarray,
column_regions: List[PageRegion],
@@ -3089,79 +3177,12 @@ def build_cell_grid(
for row_idx, row in enumerate(content_rows):
for col_idx, col in enumerate(relevant_cols):
# Compute cell region: column x/width, row y/height
pad = 8 # pixels
cell_x = max(0, col.x - pad)
cell_y = max(0, row.y - pad)
cell_w = col.width + 2 * pad
cell_h = row.height + 2 * pad
# Clamp to image bounds
if cell_x + cell_w > img_w:
cell_w = img_w - cell_x
if cell_y + cell_h > img_h:
cell_h = img_h - cell_y
if cell_w <= 0 or cell_h <= 0:
cells.append({
'cell_id': f"R{row_idx:02d}_C{col_idx}",
'row_index': row_idx,
'col_index': col_idx,
'col_type': col.type,
'text': '',
'confidence': 0.0,
'bbox_px': {'x': col.x, 'y': row.y, 'w': col.width, 'h': row.height},
'bbox_pct': {
'x': round(col.x / img_w * 100, 2),
'y': round(row.y / img_h * 100, 2),
'w': round(col.width / img_w * 100, 2),
'h': round(row.height / img_h * 100, 2),
},
'ocr_engine': engine_name,
})
continue
cell_region = PageRegion(
type=col.type,
x=cell_x, y=cell_y,
width=cell_w, height=cell_h,
cell = _ocr_single_cell(
row_idx, col_idx, row, col,
ocr_img, img_bgr, img_w, img_h,
use_rapid, engine_name, lang, lang_map,
)
# OCR the cell
if use_rapid:
words = ocr_region_rapid(img_bgr, cell_region)
else:
cell_lang = lang_map.get(col.type, lang)
words = ocr_region(ocr_img, cell_region, lang=cell_lang, psm=6)
# Group into lines, then join in reading order
if words:
avg_h = sum(w['height'] for w in words) / len(words)
y_tol = max(10, int(avg_h * 0.5))
else:
y_tol = 15
text = _words_to_reading_order_text(words, y_tolerance_px=y_tol)
avg_conf = 0.0
if words:
avg_conf = round(sum(w['conf'] for w in words) / len(words), 1)
cells.append({
'cell_id': f"R{row_idx:02d}_C{col_idx}",
'row_index': row_idx,
'col_index': col_idx,
'col_type': col.type,
'text': text,
'confidence': avg_conf,
'bbox_px': {'x': cell_x, 'y': cell_y, 'w': cell_w, 'h': cell_h},
'bbox_pct': {
'x': round(cell_x / img_w * 100, 2),
'y': round(cell_y / img_h * 100, 2),
'w': round(cell_w / img_w * 100, 2),
'h': round(cell_h / img_h * 100, 2),
},
'ocr_engine': engine_name,
})
cells.append(cell)
logger.info(f"build_cell_grid: {len(cells)} cells from "
f"{len(content_rows)} rows × {len(relevant_cols)} columns, "
@@ -3170,6 +3191,72 @@ def build_cell_grid(
return cells, columns_meta
def build_cell_grid_streaming(
ocr_img: np.ndarray,
column_regions: List[PageRegion],
row_geometries: List[RowGeometry],
img_w: int,
img_h: int,
lang: str = "eng+deu",
ocr_engine: str = "auto",
img_bgr: Optional[np.ndarray] = None,
) -> Generator[Tuple[Dict[str, Any], List[Dict[str, Any]], int], None, None]:
"""Like build_cell_grid(), but yields each cell as it is OCR'd.
Yields:
(cell_dict, columns_meta, total_cells) for each cell.
"""
# Resolve engine choice (same as build_cell_grid)
use_rapid = False
if ocr_engine == "auto":
use_rapid = RAPIDOCR_AVAILABLE and img_bgr is not None
elif ocr_engine == "rapid":
if not RAPIDOCR_AVAILABLE:
logger.warning("RapidOCR requested but not available, falling back to Tesseract")
else:
use_rapid = True
engine_name = "rapid" if use_rapid else "tesseract"
content_rows = [r for r in row_geometries if r.row_type == 'content']
if not content_rows:
return
_skip_types = {'column_ignore', 'header', 'footer', 'page_ref'}
relevant_cols = [c for c in column_regions if c.type not in _skip_types]
if not relevant_cols:
return
relevant_cols.sort(key=lambda c: c.x)
columns_meta = [
{
'index': col_idx,
'type': col.type,
'x': col.x,
'width': col.width,
}
for col_idx, col in enumerate(relevant_cols)
]
lang_map = {
'column_en': 'eng',
'column_de': 'deu',
'column_example': 'eng+deu',
}
total_cells = len(content_rows) * len(relevant_cols)
for row_idx, row in enumerate(content_rows):
for col_idx, col in enumerate(relevant_cols):
cell = _ocr_single_cell(
row_idx, col_idx, row, col,
ocr_img, img_bgr, img_w, img_h,
use_rapid, engine_name, lang, lang_map,
)
yield cell, columns_meta, total_cells
def _cells_to_vocab_entries(
cells: List[Dict[str, Any]],
columns_meta: List[Dict[str, Any]],

View File

@@ -15,6 +15,7 @@ Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import time
import uuid
@@ -24,8 +25,8 @@ from typing import Any, Dict, List, Optional
import cv2
import numpy as np
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from fastapi.responses import Response
from fastapi import APIRouter, File, Form, HTTPException, Request, UploadFile
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel
from cv_vocab_pipeline import (
@@ -39,6 +40,7 @@ from cv_vocab_pipeline import (
analyze_layout,
analyze_layout_by_words,
build_cell_grid,
build_cell_grid_streaming,
build_word_grid,
classify_column_types,
create_layout_image,
@@ -1023,12 +1025,19 @@ async def get_row_ground_truth(session_id: str):
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/words")
async def detect_words(session_id: str, engine: str = "auto", pronunciation: str = "british"):
async def detect_words(
session_id: str,
request: Request,
engine: str = "auto",
pronunciation: str = "british",
stream: bool = False,
):
"""Build word grid from columns × rows, OCR each cell.
Query params:
engine: 'auto' (default), 'tesseract', or 'rapid'
pronunciation: 'british' (default) or 'american' — for IPA dictionary lookup
stream: false (default) for JSON response, true for SSE streaming
"""
if session_id not in _cache:
await _load_session_to_cache(session_id)
@@ -1049,12 +1058,6 @@ async def detect_words(session_id: str, engine: str = "auto", pronunciation: str
if not row_result or not row_result.get("rows"):
raise HTTPException(status_code=400, detail="Row detection must be completed first")
t0 = time.time()
# Create binarized OCR image (for Tesseract)
ocr_img = create_ocr_image(dewarped_bgr)
img_h, img_w = dewarped_bgr.shape[:2]
# Convert column dicts back to PageRegion objects
col_regions = [
PageRegion(
@@ -1081,6 +1084,27 @@ async def detect_words(session_id: str, engine: str = "auto", pronunciation: str
for r in row_result["rows"]
]
if stream:
return StreamingResponse(
_word_stream_generator(
session_id, cached, col_regions, row_geoms,
dewarped_bgr, engine, pronunciation, request,
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
# --- Non-streaming path (unchanged) ---
t0 = time.time()
# Create binarized OCR image (for Tesseract)
ocr_img = create_ocr_image(dewarped_bgr)
img_h, img_w = dewarped_bgr.shape[:2]
# Build generic cell grid
cells, columns_meta = build_cell_grid(
ocr_img, col_regions, row_geoms, img_w, img_h,
@@ -1154,6 +1178,140 @@ async def detect_words(session_id: str, engine: str = "auto", pronunciation: str
}
async def _word_stream_generator(
session_id: str,
cached: Dict[str, Any],
col_regions: List[PageRegion],
row_geoms: List[RowGeometry],
dewarped_bgr: np.ndarray,
engine: str,
pronunciation: str,
request: Request,
):
"""SSE generator that yields cell-by-cell OCR progress."""
t0 = time.time()
ocr_img = create_ocr_image(dewarped_bgr)
img_h, img_w = dewarped_bgr.shape[:2]
# Compute grid shape upfront for the meta event
n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
_skip_types = {'column_ignore', 'header', 'footer', 'page_ref'}
n_cols = len([c for c in col_regions if c.type not in _skip_types])
# Determine layout
col_types = {c.type for c in col_regions if c.type not in _skip_types}
is_vocab = bool(col_types & {'column_en', 'column_de'})
# Start streaming — first event: meta
columns_meta = None # will be set from first yield
total_cells = n_content_rows * n_cols
meta_event = {
"type": "meta",
"grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
"layout": "vocab" if is_vocab else "generic",
}
yield f"data: {json.dumps(meta_event)}\n\n"
# Stream cells one by one
all_cells: List[Dict[str, Any]] = []
cell_idx = 0
for cell, cols_meta, total in build_cell_grid_streaming(
ocr_img, col_regions, row_geoms, img_w, img_h,
ocr_engine=engine, img_bgr=dewarped_bgr,
):
if await request.is_disconnected():
logger.info(f"SSE: client disconnected during streaming for {session_id}")
return
if columns_meta is None:
columns_meta = cols_meta
# Send columns_used as part of first cell or update meta
meta_update = {
"type": "columns",
"columns_used": cols_meta,
}
yield f"data: {json.dumps(meta_update)}\n\n"
all_cells.append(cell)
cell_idx += 1
cell_event = {
"type": "cell",
"cell": cell,
"progress": {"current": cell_idx, "total": total},
}
yield f"data: {json.dumps(cell_event)}\n\n"
# All cells done — build final result
duration = time.time() - t0
if columns_meta is None:
columns_meta = []
used_engine = all_cells[0].get("ocr_engine", "tesseract") if all_cells else engine
word_result = {
"cells": all_cells,
"grid_shape": {
"rows": n_content_rows,
"cols": n_cols,
"total_cells": len(all_cells),
},
"columns_used": columns_meta,
"layout": "vocab" if is_vocab else "generic",
"image_width": img_w,
"image_height": img_h,
"duration_seconds": round(duration, 2),
"ocr_engine": used_engine,
"summary": {
"total_cells": len(all_cells),
"non_empty_cells": sum(1 for c in all_cells if c.get("text")),
"low_confidence": sum(1 for c in all_cells if 0 < c.get("confidence", 0) < 50),
},
}
# Vocab post-processing
vocab_entries = None
if is_vocab:
entries = _cells_to_vocab_entries(all_cells, columns_meta)
entries = _fix_character_confusion(entries)
entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
entries = _split_comma_entries(entries)
entries = _attach_example_sentences(entries)
word_result["vocab_entries"] = entries
word_result["entries"] = entries
word_result["entry_count"] = len(entries)
word_result["summary"]["total_entries"] = len(entries)
word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
vocab_entries = entries
# Persist to DB
await update_session_db(
session_id,
word_result=word_result,
current_step=5,
)
cached["word_result"] = word_result
logger.info(f"OCR Pipeline SSE: words session {session_id}: "
f"layout={word_result['layout']}, "
f"{len(all_cells)} cells ({duration:.2f}s)")
# Final complete event
complete_event = {
"type": "complete",
"summary": word_result["summary"],
"duration_seconds": round(duration, 2),
"ocr_engine": used_engine,
}
if vocab_entries is not None:
complete_event["vocab_entries"] = vocab_entries
yield f"data: {json.dumps(complete_event)}\n\n"
class WordGroundTruthRequest(BaseModel):
is_correct: bool
corrected_entries: Optional[List[Dict[str, Any]]] = None