feat(ocr-pipeline): add SSE streaming for word recognition (Step 5)
Cells now appear one-by-one in the UI as they are OCR'd, with a live progress bar, instead of waiting for the full result. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -62,7 +62,11 @@ export function StepWordRecognition({ sessionId, onNext, goToStep }: StepWordRec
|
||||
const [usedEngine, setUsedEngine] = useState<string>('')
|
||||
const [pronunciation, setPronunciation] = useState<'british' | 'american'>('british')
|
||||
|
||||
// Streaming progress state
|
||||
const [streamProgress, setStreamProgress] = useState<{ current: number; total: number } | null>(null)
|
||||
|
||||
const enRef = useRef<HTMLInputElement>(null)
|
||||
const tableEndRef = useRef<HTMLDivElement>(null)
|
||||
|
||||
const isVocab = gridResult?.layout === 'vocab'
|
||||
|
||||
@@ -110,16 +114,107 @@ export function StepWordRecognition({ sessionId, onNext, goToStep }: StepWordRec
|
||||
const eng = engine || ocrEngine
|
||||
setDetecting(true)
|
||||
setError(null)
|
||||
setStreamProgress(null)
|
||||
setEditedCells([])
|
||||
setEditedEntries([])
|
||||
setGridResult(null)
|
||||
|
||||
try {
|
||||
const res = await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/words?engine=${eng}&pronunciation=${pronunciation}`, {
|
||||
method: 'POST',
|
||||
})
|
||||
const res = await fetch(
|
||||
`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/words?stream=true&engine=${eng}&pronunciation=${pronunciation}`,
|
||||
{ method: 'POST' },
|
||||
)
|
||||
if (!res.ok) {
|
||||
const err = await res.json().catch(() => ({ detail: res.statusText }))
|
||||
throw new Error(err.detail || 'Worterkennung fehlgeschlagen')
|
||||
}
|
||||
const data = await res.json()
|
||||
applyGridResult(data)
|
||||
|
||||
const reader = res.body!.getReader()
|
||||
const decoder = new TextDecoder()
|
||||
let buffer = ''
|
||||
let streamLayout: string | null = null
|
||||
let streamColumnsUsed: GridResult['columns_used'] = []
|
||||
let streamGridShape: GridResult['grid_shape'] | null = null
|
||||
let streamCells: GridCell[] = []
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
buffer += decoder.decode(value, { stream: true })
|
||||
|
||||
// Parse SSE events (separated by \n\n)
|
||||
while (buffer.includes('\n\n')) {
|
||||
const idx = buffer.indexOf('\n\n')
|
||||
const chunk = buffer.slice(0, idx).trim()
|
||||
buffer = buffer.slice(idx + 2)
|
||||
|
||||
if (!chunk.startsWith('data: ')) continue
|
||||
const dataStr = chunk.slice(6) // strip "data: "
|
||||
|
||||
let event: any
|
||||
try {
|
||||
event = JSON.parse(dataStr)
|
||||
} catch {
|
||||
continue
|
||||
}
|
||||
|
||||
if (event.type === 'meta') {
|
||||
streamLayout = event.layout || 'generic'
|
||||
streamGridShape = event.grid_shape || null
|
||||
// Show partial grid result so UI renders structure
|
||||
setGridResult(prev => ({
|
||||
...prev,
|
||||
layout: event.layout || 'generic',
|
||||
grid_shape: event.grid_shape,
|
||||
columns_used: [],
|
||||
cells: [],
|
||||
summary: { total_cells: event.grid_shape?.total_cells || 0, non_empty_cells: 0, low_confidence: 0 },
|
||||
duration_seconds: 0,
|
||||
ocr_engine: '',
|
||||
} as GridResult))
|
||||
}
|
||||
|
||||
if (event.type === 'columns') {
|
||||
streamColumnsUsed = event.columns_used || []
|
||||
setGridResult(prev => prev ? { ...prev, columns_used: streamColumnsUsed } : prev)
|
||||
}
|
||||
|
||||
if (event.type === 'cell') {
|
||||
const cell: GridCell = { ...event.cell, status: 'pending' }
|
||||
streamCells = [...streamCells, cell]
|
||||
setEditedCells(streamCells)
|
||||
setStreamProgress(event.progress)
|
||||
// Auto-scroll table to bottom
|
||||
setTimeout(() => tableEndRef.current?.scrollIntoView({ behavior: 'smooth', block: 'nearest' }), 16)
|
||||
}
|
||||
|
||||
if (event.type === 'complete') {
|
||||
// Build final GridResult
|
||||
const finalResult: GridResult = {
|
||||
cells: streamCells,
|
||||
grid_shape: streamGridShape || { rows: 0, cols: 0, total_cells: streamCells.length },
|
||||
columns_used: streamColumnsUsed,
|
||||
layout: streamLayout || 'generic',
|
||||
image_width: 0,
|
||||
image_height: 0,
|
||||
duration_seconds: event.duration_seconds || 0,
|
||||
ocr_engine: event.ocr_engine || '',
|
||||
summary: event.summary || {},
|
||||
}
|
||||
|
||||
// If vocab: apply post-processed entries from complete event
|
||||
if (event.vocab_entries) {
|
||||
finalResult.entries = event.vocab_entries
|
||||
finalResult.vocab_entries = event.vocab_entries
|
||||
finalResult.entry_count = event.vocab_entries.length
|
||||
}
|
||||
|
||||
applyGridResult(finalResult)
|
||||
setUsedEngine(event.ocr_engine || '')
|
||||
setStreamProgress(null)
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
setError(e instanceof Error ? e.message : 'Unbekannter Fehler')
|
||||
} finally {
|
||||
@@ -288,11 +383,23 @@ export function StepWordRecognition({ sessionId, onNext, goToStep }: StepWordRec
|
||||
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
{/* Loading */}
|
||||
{/* Loading with streaming progress */}
|
||||
{detecting && (
|
||||
<div className="flex items-center gap-2 text-teal-600 dark:text-teal-400 text-sm">
|
||||
<div className="animate-spin w-4 h-4 border-2 border-teal-500 border-t-transparent rounded-full" />
|
||||
Worterkennung laeuft...
|
||||
<div className="space-y-1">
|
||||
<div className="flex items-center gap-2 text-teal-600 dark:text-teal-400 text-sm">
|
||||
<div className="animate-spin w-4 h-4 border-2 border-teal-500 border-t-transparent rounded-full" />
|
||||
{streamProgress
|
||||
? `Zelle ${streamProgress.current}/${streamProgress.total} erkannt...`
|
||||
: 'Worterkennung startet...'}
|
||||
</div>
|
||||
{streamProgress && streamProgress.total > 0 && (
|
||||
<div className="w-full bg-gray-200 dark:bg-gray-700 rounded-full h-1.5">
|
||||
<div
|
||||
className="bg-teal-500 h-1.5 rounded-full transition-all duration-150"
|
||||
style={{ width: `${(streamProgress.current / streamProgress.total) * 100}%` }}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -378,8 +485,8 @@ export function StepWordRecognition({ sessionId, onNext, goToStep }: StepWordRec
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Result summary */}
|
||||
{gridResult && summary && (
|
||||
{/* Result summary (only after streaming completes) */}
|
||||
{gridResult && summary && !detecting && (
|
||||
<div className="bg-white dark:bg-gray-800 rounded-xl border border-gray-200 dark:border-gray-700 p-4 space-y-3">
|
||||
<div className="flex items-center justify-between">
|
||||
<h4 className="text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||
@@ -511,6 +618,67 @@ export function StepWordRecognition({ sessionId, onNext, goToStep }: StepWordRec
|
||||
</tbody>
|
||||
</table>
|
||||
)}
|
||||
<div ref={tableEndRef} />
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Streaming cell table (shown while detecting, before complete) */}
|
||||
{detecting && editedCells.length > 0 && !gridResult?.summary?.non_empty_cells && (
|
||||
<div className="bg-white dark:bg-gray-800 rounded-xl border border-gray-200 dark:border-gray-700 p-4 space-y-3">
|
||||
<h4 className="text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||
Live: {editedCells.length} Zellen erkannt...
|
||||
</h4>
|
||||
<div className="max-h-80 overflow-y-auto">
|
||||
<table className="w-full text-xs">
|
||||
<thead className="sticky top-0 bg-white dark:bg-gray-800">
|
||||
<tr className="text-left text-gray-500 dark:text-gray-400 border-b dark:border-gray-700">
|
||||
<th className="py-1 pr-2 w-12">Zelle</th>
|
||||
{columnsUsed.map((col, i) => (
|
||||
<th key={i} className={`py-1 pr-2 ${colTypeColor(col.type)}`}>
|
||||
{colTypeLabel(col.type)}
|
||||
</th>
|
||||
))}
|
||||
<th className="py-1 w-12 text-right">Conf</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{(() => {
|
||||
const liveByRow: Map<number, GridCell[]> = new Map()
|
||||
for (const cell of editedCells) {
|
||||
const existing = liveByRow.get(cell.row_index) || []
|
||||
existing.push(cell)
|
||||
liveByRow.set(cell.row_index, existing)
|
||||
}
|
||||
const liveSorted = [...liveByRow.keys()].sort((a, b) => a - b)
|
||||
return liveSorted.map(rowIdx => {
|
||||
const rowCells = liveByRow.get(rowIdx) || []
|
||||
const avgConf = rowCells.length
|
||||
? Math.round(rowCells.reduce((s, c) => s + c.confidence, 0) / rowCells.length)
|
||||
: 0
|
||||
return (
|
||||
<tr key={rowIdx} className="border-b dark:border-gray-700/50 animate-fade-in">
|
||||
<td className="py-1 pr-2 text-gray-400 font-mono text-[10px]">
|
||||
R{String(rowIdx).padStart(2, '0')}
|
||||
</td>
|
||||
{columnsUsed.map((col) => {
|
||||
const cell = rowCells.find(c => c.col_index === col.index)
|
||||
return (
|
||||
<td key={col.index} className="py-1 pr-2 font-mono text-gray-700 dark:text-gray-300">
|
||||
<MultilineText text={cell?.text || ''} />
|
||||
</td>
|
||||
)
|
||||
})}
|
||||
<td className={`py-1 text-right font-mono ${confColor(avgConf)}`}>
|
||||
{avgConf}%
|
||||
</td>
|
||||
</tr>
|
||||
)
|
||||
})
|
||||
})()}
|
||||
</tbody>
|
||||
</table>
|
||||
<div ref={tableEndRef} />
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -19,7 +19,7 @@ import io
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -3009,6 +3009,94 @@ def _replace_phonetics_in_text(text: str, pronunciation: str = 'british') -> str
|
||||
return _PHONETIC_BRACKET_RE.sub(replacer, text)
|
||||
|
||||
|
||||
def _ocr_single_cell(
|
||||
row_idx: int,
|
||||
col_idx: int,
|
||||
row: RowGeometry,
|
||||
col: PageRegion,
|
||||
ocr_img: np.ndarray,
|
||||
img_bgr: Optional[np.ndarray],
|
||||
img_w: int,
|
||||
img_h: int,
|
||||
use_rapid: bool,
|
||||
engine_name: str,
|
||||
lang: str,
|
||||
lang_map: Dict[str, str],
|
||||
) -> Dict[str, Any]:
|
||||
"""OCR a single cell (column × row intersection) and return its dict."""
|
||||
pad = 8 # pixels
|
||||
cell_x = max(0, col.x - pad)
|
||||
cell_y = max(0, row.y - pad)
|
||||
cell_w = col.width + 2 * pad
|
||||
cell_h = row.height + 2 * pad
|
||||
|
||||
# Clamp to image bounds
|
||||
if cell_x + cell_w > img_w:
|
||||
cell_w = img_w - cell_x
|
||||
if cell_y + cell_h > img_h:
|
||||
cell_h = img_h - cell_y
|
||||
|
||||
if cell_w <= 0 or cell_h <= 0:
|
||||
return {
|
||||
'cell_id': f"R{row_idx:02d}_C{col_idx}",
|
||||
'row_index': row_idx,
|
||||
'col_index': col_idx,
|
||||
'col_type': col.type,
|
||||
'text': '',
|
||||
'confidence': 0.0,
|
||||
'bbox_px': {'x': col.x, 'y': row.y, 'w': col.width, 'h': row.height},
|
||||
'bbox_pct': {
|
||||
'x': round(col.x / img_w * 100, 2),
|
||||
'y': round(row.y / img_h * 100, 2),
|
||||
'w': round(col.width / img_w * 100, 2),
|
||||
'h': round(row.height / img_h * 100, 2),
|
||||
},
|
||||
'ocr_engine': engine_name,
|
||||
}
|
||||
|
||||
cell_region = PageRegion(
|
||||
type=col.type,
|
||||
x=cell_x, y=cell_y,
|
||||
width=cell_w, height=cell_h,
|
||||
)
|
||||
|
||||
# OCR the cell
|
||||
if use_rapid:
|
||||
words = ocr_region_rapid(img_bgr, cell_region)
|
||||
else:
|
||||
cell_lang = lang_map.get(col.type, lang)
|
||||
words = ocr_region(ocr_img, cell_region, lang=cell_lang, psm=6)
|
||||
|
||||
# Group into lines, then join in reading order
|
||||
if words:
|
||||
avg_h = sum(w['height'] for w in words) / len(words)
|
||||
y_tol = max(10, int(avg_h * 0.5))
|
||||
else:
|
||||
y_tol = 15
|
||||
text = _words_to_reading_order_text(words, y_tolerance_px=y_tol)
|
||||
|
||||
avg_conf = 0.0
|
||||
if words:
|
||||
avg_conf = round(sum(w['conf'] for w in words) / len(words), 1)
|
||||
|
||||
return {
|
||||
'cell_id': f"R{row_idx:02d}_C{col_idx}",
|
||||
'row_index': row_idx,
|
||||
'col_index': col_idx,
|
||||
'col_type': col.type,
|
||||
'text': text,
|
||||
'confidence': avg_conf,
|
||||
'bbox_px': {'x': cell_x, 'y': cell_y, 'w': cell_w, 'h': cell_h},
|
||||
'bbox_pct': {
|
||||
'x': round(cell_x / img_w * 100, 2),
|
||||
'y': round(cell_y / img_h * 100, 2),
|
||||
'w': round(cell_w / img_w * 100, 2),
|
||||
'h': round(cell_h / img_h * 100, 2),
|
||||
},
|
||||
'ocr_engine': engine_name,
|
||||
}
|
||||
|
||||
|
||||
def build_cell_grid(
|
||||
ocr_img: np.ndarray,
|
||||
column_regions: List[PageRegion],
|
||||
@@ -3089,79 +3177,12 @@ def build_cell_grid(
|
||||
|
||||
for row_idx, row in enumerate(content_rows):
|
||||
for col_idx, col in enumerate(relevant_cols):
|
||||
# Compute cell region: column x/width, row y/height
|
||||
pad = 8 # pixels
|
||||
cell_x = max(0, col.x - pad)
|
||||
cell_y = max(0, row.y - pad)
|
||||
cell_w = col.width + 2 * pad
|
||||
cell_h = row.height + 2 * pad
|
||||
|
||||
# Clamp to image bounds
|
||||
if cell_x + cell_w > img_w:
|
||||
cell_w = img_w - cell_x
|
||||
if cell_y + cell_h > img_h:
|
||||
cell_h = img_h - cell_y
|
||||
|
||||
if cell_w <= 0 or cell_h <= 0:
|
||||
cells.append({
|
||||
'cell_id': f"R{row_idx:02d}_C{col_idx}",
|
||||
'row_index': row_idx,
|
||||
'col_index': col_idx,
|
||||
'col_type': col.type,
|
||||
'text': '',
|
||||
'confidence': 0.0,
|
||||
'bbox_px': {'x': col.x, 'y': row.y, 'w': col.width, 'h': row.height},
|
||||
'bbox_pct': {
|
||||
'x': round(col.x / img_w * 100, 2),
|
||||
'y': round(row.y / img_h * 100, 2),
|
||||
'w': round(col.width / img_w * 100, 2),
|
||||
'h': round(row.height / img_h * 100, 2),
|
||||
},
|
||||
'ocr_engine': engine_name,
|
||||
})
|
||||
continue
|
||||
|
||||
cell_region = PageRegion(
|
||||
type=col.type,
|
||||
x=cell_x, y=cell_y,
|
||||
width=cell_w, height=cell_h,
|
||||
cell = _ocr_single_cell(
|
||||
row_idx, col_idx, row, col,
|
||||
ocr_img, img_bgr, img_w, img_h,
|
||||
use_rapid, engine_name, lang, lang_map,
|
||||
)
|
||||
|
||||
# OCR the cell
|
||||
if use_rapid:
|
||||
words = ocr_region_rapid(img_bgr, cell_region)
|
||||
else:
|
||||
cell_lang = lang_map.get(col.type, lang)
|
||||
words = ocr_region(ocr_img, cell_region, lang=cell_lang, psm=6)
|
||||
|
||||
# Group into lines, then join in reading order
|
||||
if words:
|
||||
avg_h = sum(w['height'] for w in words) / len(words)
|
||||
y_tol = max(10, int(avg_h * 0.5))
|
||||
else:
|
||||
y_tol = 15
|
||||
text = _words_to_reading_order_text(words, y_tolerance_px=y_tol)
|
||||
|
||||
avg_conf = 0.0
|
||||
if words:
|
||||
avg_conf = round(sum(w['conf'] for w in words) / len(words), 1)
|
||||
|
||||
cells.append({
|
||||
'cell_id': f"R{row_idx:02d}_C{col_idx}",
|
||||
'row_index': row_idx,
|
||||
'col_index': col_idx,
|
||||
'col_type': col.type,
|
||||
'text': text,
|
||||
'confidence': avg_conf,
|
||||
'bbox_px': {'x': cell_x, 'y': cell_y, 'w': cell_w, 'h': cell_h},
|
||||
'bbox_pct': {
|
||||
'x': round(cell_x / img_w * 100, 2),
|
||||
'y': round(cell_y / img_h * 100, 2),
|
||||
'w': round(cell_w / img_w * 100, 2),
|
||||
'h': round(cell_h / img_h * 100, 2),
|
||||
},
|
||||
'ocr_engine': engine_name,
|
||||
})
|
||||
cells.append(cell)
|
||||
|
||||
logger.info(f"build_cell_grid: {len(cells)} cells from "
|
||||
f"{len(content_rows)} rows × {len(relevant_cols)} columns, "
|
||||
@@ -3170,6 +3191,72 @@ def build_cell_grid(
|
||||
return cells, columns_meta
|
||||
|
||||
|
||||
def build_cell_grid_streaming(
|
||||
ocr_img: np.ndarray,
|
||||
column_regions: List[PageRegion],
|
||||
row_geometries: List[RowGeometry],
|
||||
img_w: int,
|
||||
img_h: int,
|
||||
lang: str = "eng+deu",
|
||||
ocr_engine: str = "auto",
|
||||
img_bgr: Optional[np.ndarray] = None,
|
||||
) -> Generator[Tuple[Dict[str, Any], List[Dict[str, Any]], int], None, None]:
|
||||
"""Like build_cell_grid(), but yields each cell as it is OCR'd.
|
||||
|
||||
Yields:
|
||||
(cell_dict, columns_meta, total_cells) for each cell.
|
||||
"""
|
||||
# Resolve engine choice (same as build_cell_grid)
|
||||
use_rapid = False
|
||||
if ocr_engine == "auto":
|
||||
use_rapid = RAPIDOCR_AVAILABLE and img_bgr is not None
|
||||
elif ocr_engine == "rapid":
|
||||
if not RAPIDOCR_AVAILABLE:
|
||||
logger.warning("RapidOCR requested but not available, falling back to Tesseract")
|
||||
else:
|
||||
use_rapid = True
|
||||
|
||||
engine_name = "rapid" if use_rapid else "tesseract"
|
||||
|
||||
content_rows = [r for r in row_geometries if r.row_type == 'content']
|
||||
if not content_rows:
|
||||
return
|
||||
|
||||
_skip_types = {'column_ignore', 'header', 'footer', 'page_ref'}
|
||||
relevant_cols = [c for c in column_regions if c.type not in _skip_types]
|
||||
if not relevant_cols:
|
||||
return
|
||||
|
||||
relevant_cols.sort(key=lambda c: c.x)
|
||||
|
||||
columns_meta = [
|
||||
{
|
||||
'index': col_idx,
|
||||
'type': col.type,
|
||||
'x': col.x,
|
||||
'width': col.width,
|
||||
}
|
||||
for col_idx, col in enumerate(relevant_cols)
|
||||
]
|
||||
|
||||
lang_map = {
|
||||
'column_en': 'eng',
|
||||
'column_de': 'deu',
|
||||
'column_example': 'eng+deu',
|
||||
}
|
||||
|
||||
total_cells = len(content_rows) * len(relevant_cols)
|
||||
|
||||
for row_idx, row in enumerate(content_rows):
|
||||
for col_idx, col in enumerate(relevant_cols):
|
||||
cell = _ocr_single_cell(
|
||||
row_idx, col_idx, row, col,
|
||||
ocr_img, img_bgr, img_w, img_h,
|
||||
use_rapid, engine_name, lang, lang_map,
|
||||
)
|
||||
yield cell, columns_meta, total_cells
|
||||
|
||||
|
||||
def _cells_to_vocab_entries(
|
||||
cells: List[Dict[str, Any]],
|
||||
columns_meta: List[Dict[str, Any]],
|
||||
|
||||
@@ -15,6 +15,7 @@ Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
@@ -24,8 +25,8 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
|
||||
from fastapi.responses import Response
|
||||
from fastapi import APIRouter, File, Form, HTTPException, Request, UploadFile
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cv_vocab_pipeline import (
|
||||
@@ -39,6 +40,7 @@ from cv_vocab_pipeline import (
|
||||
analyze_layout,
|
||||
analyze_layout_by_words,
|
||||
build_cell_grid,
|
||||
build_cell_grid_streaming,
|
||||
build_word_grid,
|
||||
classify_column_types,
|
||||
create_layout_image,
|
||||
@@ -1023,12 +1025,19 @@ async def get_row_ground_truth(session_id: str):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/words")
|
||||
async def detect_words(session_id: str, engine: str = "auto", pronunciation: str = "british"):
|
||||
async def detect_words(
|
||||
session_id: str,
|
||||
request: Request,
|
||||
engine: str = "auto",
|
||||
pronunciation: str = "british",
|
||||
stream: bool = False,
|
||||
):
|
||||
"""Build word grid from columns × rows, OCR each cell.
|
||||
|
||||
Query params:
|
||||
engine: 'auto' (default), 'tesseract', or 'rapid'
|
||||
pronunciation: 'british' (default) or 'american' — for IPA dictionary lookup
|
||||
stream: false (default) for JSON response, true for SSE streaming
|
||||
"""
|
||||
if session_id not in _cache:
|
||||
await _load_session_to_cache(session_id)
|
||||
@@ -1049,12 +1058,6 @@ async def detect_words(session_id: str, engine: str = "auto", pronunciation: str
|
||||
if not row_result or not row_result.get("rows"):
|
||||
raise HTTPException(status_code=400, detail="Row detection must be completed first")
|
||||
|
||||
t0 = time.time()
|
||||
|
||||
# Create binarized OCR image (for Tesseract)
|
||||
ocr_img = create_ocr_image(dewarped_bgr)
|
||||
img_h, img_w = dewarped_bgr.shape[:2]
|
||||
|
||||
# Convert column dicts back to PageRegion objects
|
||||
col_regions = [
|
||||
PageRegion(
|
||||
@@ -1081,6 +1084,27 @@ async def detect_words(session_id: str, engine: str = "auto", pronunciation: str
|
||||
for r in row_result["rows"]
|
||||
]
|
||||
|
||||
if stream:
|
||||
return StreamingResponse(
|
||||
_word_stream_generator(
|
||||
session_id, cached, col_regions, row_geoms,
|
||||
dewarped_bgr, engine, pronunciation, request,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
# --- Non-streaming path (unchanged) ---
|
||||
t0 = time.time()
|
||||
|
||||
# Create binarized OCR image (for Tesseract)
|
||||
ocr_img = create_ocr_image(dewarped_bgr)
|
||||
img_h, img_w = dewarped_bgr.shape[:2]
|
||||
|
||||
# Build generic cell grid
|
||||
cells, columns_meta = build_cell_grid(
|
||||
ocr_img, col_regions, row_geoms, img_w, img_h,
|
||||
@@ -1154,6 +1178,140 @@ async def detect_words(session_id: str, engine: str = "auto", pronunciation: str
|
||||
}
|
||||
|
||||
|
||||
async def _word_stream_generator(
|
||||
session_id: str,
|
||||
cached: Dict[str, Any],
|
||||
col_regions: List[PageRegion],
|
||||
row_geoms: List[RowGeometry],
|
||||
dewarped_bgr: np.ndarray,
|
||||
engine: str,
|
||||
pronunciation: str,
|
||||
request: Request,
|
||||
):
|
||||
"""SSE generator that yields cell-by-cell OCR progress."""
|
||||
t0 = time.time()
|
||||
|
||||
ocr_img = create_ocr_image(dewarped_bgr)
|
||||
img_h, img_w = dewarped_bgr.shape[:2]
|
||||
|
||||
# Compute grid shape upfront for the meta event
|
||||
n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
|
||||
_skip_types = {'column_ignore', 'header', 'footer', 'page_ref'}
|
||||
n_cols = len([c for c in col_regions if c.type not in _skip_types])
|
||||
|
||||
# Determine layout
|
||||
col_types = {c.type for c in col_regions if c.type not in _skip_types}
|
||||
is_vocab = bool(col_types & {'column_en', 'column_de'})
|
||||
|
||||
# Start streaming — first event: meta
|
||||
columns_meta = None # will be set from first yield
|
||||
total_cells = n_content_rows * n_cols
|
||||
|
||||
meta_event = {
|
||||
"type": "meta",
|
||||
"grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
|
||||
"layout": "vocab" if is_vocab else "generic",
|
||||
}
|
||||
yield f"data: {json.dumps(meta_event)}\n\n"
|
||||
|
||||
# Stream cells one by one
|
||||
all_cells: List[Dict[str, Any]] = []
|
||||
cell_idx = 0
|
||||
|
||||
for cell, cols_meta, total in build_cell_grid_streaming(
|
||||
ocr_img, col_regions, row_geoms, img_w, img_h,
|
||||
ocr_engine=engine, img_bgr=dewarped_bgr,
|
||||
):
|
||||
if await request.is_disconnected():
|
||||
logger.info(f"SSE: client disconnected during streaming for {session_id}")
|
||||
return
|
||||
|
||||
if columns_meta is None:
|
||||
columns_meta = cols_meta
|
||||
# Send columns_used as part of first cell or update meta
|
||||
meta_update = {
|
||||
"type": "columns",
|
||||
"columns_used": cols_meta,
|
||||
}
|
||||
yield f"data: {json.dumps(meta_update)}\n\n"
|
||||
|
||||
all_cells.append(cell)
|
||||
cell_idx += 1
|
||||
|
||||
cell_event = {
|
||||
"type": "cell",
|
||||
"cell": cell,
|
||||
"progress": {"current": cell_idx, "total": total},
|
||||
}
|
||||
yield f"data: {json.dumps(cell_event)}\n\n"
|
||||
|
||||
# All cells done — build final result
|
||||
duration = time.time() - t0
|
||||
if columns_meta is None:
|
||||
columns_meta = []
|
||||
|
||||
used_engine = all_cells[0].get("ocr_engine", "tesseract") if all_cells else engine
|
||||
|
||||
word_result = {
|
||||
"cells": all_cells,
|
||||
"grid_shape": {
|
||||
"rows": n_content_rows,
|
||||
"cols": n_cols,
|
||||
"total_cells": len(all_cells),
|
||||
},
|
||||
"columns_used": columns_meta,
|
||||
"layout": "vocab" if is_vocab else "generic",
|
||||
"image_width": img_w,
|
||||
"image_height": img_h,
|
||||
"duration_seconds": round(duration, 2),
|
||||
"ocr_engine": used_engine,
|
||||
"summary": {
|
||||
"total_cells": len(all_cells),
|
||||
"non_empty_cells": sum(1 for c in all_cells if c.get("text")),
|
||||
"low_confidence": sum(1 for c in all_cells if 0 < c.get("confidence", 0) < 50),
|
||||
},
|
||||
}
|
||||
|
||||
# Vocab post-processing
|
||||
vocab_entries = None
|
||||
if is_vocab:
|
||||
entries = _cells_to_vocab_entries(all_cells, columns_meta)
|
||||
entries = _fix_character_confusion(entries)
|
||||
entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
|
||||
entries = _split_comma_entries(entries)
|
||||
entries = _attach_example_sentences(entries)
|
||||
word_result["vocab_entries"] = entries
|
||||
word_result["entries"] = entries
|
||||
word_result["entry_count"] = len(entries)
|
||||
word_result["summary"]["total_entries"] = len(entries)
|
||||
word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
|
||||
word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
|
||||
vocab_entries = entries
|
||||
|
||||
# Persist to DB
|
||||
await update_session_db(
|
||||
session_id,
|
||||
word_result=word_result,
|
||||
current_step=5,
|
||||
)
|
||||
cached["word_result"] = word_result
|
||||
|
||||
logger.info(f"OCR Pipeline SSE: words session {session_id}: "
|
||||
f"layout={word_result['layout']}, "
|
||||
f"{len(all_cells)} cells ({duration:.2f}s)")
|
||||
|
||||
# Final complete event
|
||||
complete_event = {
|
||||
"type": "complete",
|
||||
"summary": word_result["summary"],
|
||||
"duration_seconds": round(duration, 2),
|
||||
"ocr_engine": used_engine,
|
||||
}
|
||||
if vocab_entries is not None:
|
||||
complete_event["vocab_entries"] = vocab_entries
|
||||
yield f"data: {json.dumps(complete_event)}\n\n"
|
||||
|
||||
|
||||
class WordGroundTruthRequest(BaseModel):
|
||||
is_correct: bool
|
||||
corrected_entries: Optional[List[Dict[str, Any]]] = None
|
||||
|
||||
Reference in New Issue
Block a user