feat(ocr-pipeline): line grouping fix + RapidOCR integration
Fix A: Use _group_words_into_lines() with an adaptive Y-tolerance (half the average word height, minimum 10 px) so that words in multi-line cells are joined in the correct reading order (fixes the word reordering bug). RapidOCR: Add RapidOCR as an alternative OCR engine (PaddleOCR models on ONNX Runtime, native ARM64). The engine is selectable via a dropdown in the UI or the ?engine= query parameter ('auto', 'tesseract' or 'rapid'). Auto mode prefers RapidOCR when it is available and falls back to Tesseract otherwise. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -143,6 +143,7 @@ export interface WordResult {
|
|||||||
image_width: number
|
image_width: number
|
||||||
image_height: number
|
image_height: number
|
||||||
duration_seconds: number
|
duration_seconds: number
|
||||||
|
ocr_engine?: string
|
||||||
summary: {
|
summary: {
|
||||||
total_entries: number
|
total_entries: number
|
||||||
with_english: number
|
with_english: number
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ export function StepWordRecognition({ sessionId, onNext, goToStep }: StepWordRec
|
|||||||
const [activeIndex, setActiveIndex] = useState(0)
|
const [activeIndex, setActiveIndex] = useState(0)
|
||||||
const [editedEntries, setEditedEntries] = useState<WordEntry[]>([])
|
const [editedEntries, setEditedEntries] = useState<WordEntry[]>([])
|
||||||
const [mode, setMode] = useState<'overview' | 'labeling'>('overview')
|
const [mode, setMode] = useState<'overview' | 'labeling'>('overview')
|
||||||
|
const [ocrEngine, setOcrEngine] = useState<'auto' | 'tesseract' | 'rapid'>('auto')
|
||||||
|
const [usedEngine, setUsedEngine] = useState<string>('')
|
||||||
|
|
||||||
const enRef = useRef<HTMLInputElement>(null)
|
const enRef = useRef<HTMLInputElement>(null)
|
||||||
|
|
||||||
@@ -35,6 +37,7 @@ export function StepWordRecognition({ sessionId, onNext, goToStep }: StepWordRec
|
|||||||
const info = await res.json()
|
const info = await res.json()
|
||||||
if (info.word_result) {
|
if (info.word_result) {
|
||||||
setWordResult(info.word_result)
|
setWordResult(info.word_result)
|
||||||
|
setUsedEngine(info.word_result.ocr_engine || '')
|
||||||
initEntries(info.word_result.entries)
|
initEntries(info.word_result.entries)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -54,27 +57,29 @@ export function StepWordRecognition({ sessionId, onNext, goToStep }: StepWordRec
|
|||||||
setActiveIndex(0)
|
setActiveIndex(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Runs word detection for the current session and stores the result in state.
// `engine` optionally overrides the engine chosen in the dropdown (`ocrEngine`);
// the value is sent to the backend as the `engine` query parameter.
const runAutoDetection = useCallback(async (engine?: string) => {
  if (!sessionId) return
  // An explicit argument wins; otherwise use the dropdown selection.
  const eng = engine || ocrEngine
  setDetecting(true)
  setError(null)
  try {
    // Encode the engine name so an unexpected value cannot alter the query string.
    const res = await fetch(
      `${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/words?engine=${encodeURIComponent(eng)}`,
      { method: 'POST' },
    )
    if (!res.ok) {
      // The error body may not be JSON; fall back to the HTTP status text.
      const err = await res.json().catch(() => ({ detail: res.statusText }))
      throw new Error(err.detail || 'Worterkennung fehlgeschlagen')
    }
    // Keep the response typed: WordResult declares the optional `ocr_engine` field.
    const data: WordResult = await res.json()
    setWordResult(data)
    // Show the engine the backend reports; fall back to the requested one.
    setUsedEngine(data.ocr_engine || eng)
    initEntries(data.entries)
  } catch (e) {
    setError(e instanceof Error ? e.message : 'Unbekannter Fehler')
  } finally {
    setDetecting(false)
  }
}, [sessionId, ocrEngine])
|
||||||
|
|
||||||
const handleGroundTruth = useCallback(async (isCorrect: boolean) => {
|
const handleGroundTruth = useCallback(async (isCorrect: boolean) => {
|
||||||
if (!sessionId) return
|
if (!sessionId) return
|
||||||
@@ -512,6 +517,17 @@ export function StepWordRecognition({ sessionId, onNext, goToStep }: StepWordRec
|
|||||||
{wordResult && (
|
{wordResult && (
|
||||||
<div className="bg-white dark:bg-gray-800 rounded-xl border border-gray-200 dark:border-gray-700 p-4 space-y-3">
|
<div className="bg-white dark:bg-gray-800 rounded-xl border border-gray-200 dark:border-gray-700 p-4 space-y-3">
|
||||||
<div className="flex items-center gap-3 flex-wrap">
|
<div className="flex items-center gap-3 flex-wrap">
|
||||||
|
{/* OCR Engine selector */}
|
||||||
|
<select
|
||||||
|
value={ocrEngine}
|
||||||
|
onChange={(e) => setOcrEngine(e.target.value as 'auto' | 'tesseract' | 'rapid')}
|
||||||
|
className="px-2 py-1.5 text-xs border rounded-lg dark:bg-gray-700 dark:border-gray-600"
|
||||||
|
>
|
||||||
|
<option value="auto">Auto (RapidOCR wenn verfuegbar)</option>
|
||||||
|
<option value="rapid">RapidOCR (ONNX)</option>
|
||||||
|
<option value="tesseract">Tesseract</option>
|
||||||
|
</select>
|
||||||
|
|
||||||
<button
|
<button
|
||||||
onClick={() => runAutoDetection()}
|
onClick={() => runAutoDetection()}
|
||||||
disabled={detecting}
|
disabled={detecting}
|
||||||
@@ -520,6 +536,17 @@ export function StepWordRecognition({ sessionId, onNext, goToStep }: StepWordRec
|
|||||||
Erneut erkennen
|
Erneut erkennen
|
||||||
</button>
|
</button>
|
||||||
|
|
||||||
|
{/* Show which engine was used */}
|
||||||
|
{usedEngine && (
|
||||||
|
<span className={`px-2 py-0.5 rounded text-[10px] uppercase font-semibold ${
|
||||||
|
usedEngine === 'rapid'
|
||||||
|
? 'bg-purple-100 dark:bg-purple-900/30 text-purple-700 dark:text-purple-300'
|
||||||
|
: 'bg-gray-100 dark:bg-gray-700 text-gray-600 dark:text-gray-400'
|
||||||
|
}`}>
|
||||||
|
{usedEngine}
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
|
||||||
<button
|
<button
|
||||||
onClick={() => goToStep(3)}
|
onClick={() => goToStep(3)}
|
||||||
className="px-3 py-1.5 text-xs border rounded-lg hover:bg-gray-50 dark:hover:bg-gray-700 dark:border-gray-600 text-orange-600 dark:text-orange-400 border-orange-300 dark:border-orange-700"
|
className="px-3 py-1.5 text-xs border rounded-lg hover:bg-gray-50 dark:hover:bg-gray-700 dark:border-gray-600 text-orange-600 dark:text-orange-400 border-orange-300 dark:border-orange-700"
|
||||||
|
|||||||
@@ -2173,6 +2173,101 @@ def analyze_layout_by_words(ocr_img: np.ndarray, dewarped_bgr: np.ndarray) -> Li
|
|||||||
# Pipeline Step 5: Word Grid from Columns × Rows
|
# Pipeline Step 5: Word Grid from Columns × Rows
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
def _words_to_reading_order_text(words: List[Dict], y_tolerance_px: int = 15) -> str:
    """Flatten OCR word dicts into a single space-separated string.

    The words are passed to ``_group_words_into_lines`` together with
    ``y_tolerance_px``. The ``'text'`` values of each returned line are
    joined with single spaces, and the lines are then joined with single
    spaces in the order the helper returned them.

    Args:
        words: OCR word dicts; each must provide a ``'text'`` key.
        y_tolerance_px: Vertical tolerance forwarded to the line grouper.

    Returns:
        The joined text, or ``''`` when ``words`` is empty.
    """
    if words:
        grouped = _group_words_into_lines(words, y_tolerance_px=y_tolerance_px)
        return ' '.join(
            ' '.join(word['text'] for word in visual_line)
            for visual_line in grouped
        )
    return ''
|
||||||
|
|
||||||
|
|
||||||
|
# --- RapidOCR integration (PaddleOCR models on ONNX Runtime) ---
|
||||||
|
|
||||||
|
_rapid_engine = None
|
||||||
|
RAPIDOCR_AVAILABLE = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
from rapidocr import RapidOCR as _RapidOCRClass
|
||||||
|
RAPIDOCR_AVAILABLE = True
|
||||||
|
logger.info("RapidOCR available — can be used as alternative to Tesseract")
|
||||||
|
except ImportError:
|
||||||
|
logger.info("RapidOCR not installed — using Tesseract only")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_rapid_engine():
    """Return the shared RapidOCR engine, creating it on the first call.

    The instance is cached in the module-level ``_rapid_engine`` and reused
    by later calls.
    NOTE(review): constructing the engine presumably downloads model files
    on first use — confirm against the RapidOCR documentation.

    Returns:
        The cached RapidOCR engine instance.

    Raises:
        RuntimeError: If the ``rapidocr`` package could not be imported
            (``RAPIDOCR_AVAILABLE`` is False).
    """
    global _rapid_engine
    if _rapid_engine is None:
        # Without this guard a failed import surfaces as an opaque NameError
        # on _RapidOCRClass, which is only bound when the import succeeded.
        if not RAPIDOCR_AVAILABLE:
            raise RuntimeError(
                "RapidOCR is not installed — cannot initialize the RapidOCR engine"
            )
        _rapid_engine = _RapidOCRClass()
        logger.info("RapidOCR engine initialized")
    return _rapid_engine
|
||||||
|
|
||||||
|
|
||||||
|
def ocr_region_rapid(
    img_bgr: np.ndarray,
    region: PageRegion,
) -> List[Dict[str, Any]]:
    """Run RapidOCR on one region and return Tesseract-style word dicts.

    RapidOCR reports one box per detected text line, so a single returned
    dict may contain several words in its ``'text'`` value.

    Args:
        img_bgr: Full-page BGR image (NOT binarized — RapidOCR works on color/gray).
        region: Region to crop and OCR; its ``x``, ``y``, ``width``,
            ``height`` and ``type`` attributes are read.

    Returns:
        List of dicts with ``text``, ``left``, ``top``, ``width``, ``height``,
        ``conf`` (0-100) and ``region_type``. ``left``/``top`` are absolute
        page coordinates. Empty when the crop is empty or nothing was
        recognized.
    """
    # Crop first: an empty crop needs no OCR, so the engine is not
    # initialized for it.
    crop = img_bgr[region.y:region.y + region.height,
                   region.x:region.x + region.width]
    if crop.size == 0:
        return []

    result = _get_rapid_engine()(crop)
    if result is None or result.boxes is None or result.txts is None:
        return []

    txts = result.txts
    scores = result.scores
    if scores is None:
        # Keep the recognized text but mark it as zero confidence instead of
        # failing in zip() on a missing score list.
        scores = [0.0] * len(txts)

    words: List[Dict[str, Any]] = []
    # Each box holds 4 corner points [[x1,y1],[x2,y2],[x3,y3],[x4,y4]].
    for box, txt, score in zip(result.boxes, txts, scores):
        text = txt.strip() if txt else ''
        if not text:
            continue

        # Reduce the quadrilateral to its axis-aligned bounding rectangle.
        xs = [p[0] for p in box]
        ys = [p[1] for p in box]
        left = int(min(xs))
        top = int(min(ys))

        words.append({
            'text': text,
            'left': left + region.x,  # crop-relative -> absolute page coords
            'top': top + region.y,
            'width': int(max(xs) - left),
            'height': int(max(ys) - top),
            'conf': int(score * 100),  # 0-100 like Tesseract
            'region_type': region.type,
        })

    return words
|
||||||
|
|
||||||
|
|
||||||
def build_word_grid(
|
def build_word_grid(
|
||||||
ocr_img: np.ndarray,
|
ocr_img: np.ndarray,
|
||||||
column_regions: List[PageRegion],
|
column_regions: List[PageRegion],
|
||||||
@@ -2180,20 +2275,37 @@ def build_word_grid(
|
|||||||
img_w: int,
|
img_w: int,
|
||||||
img_h: int,
|
img_h: int,
|
||||||
lang: str = "eng+deu",
|
lang: str = "eng+deu",
|
||||||
|
ocr_engine: str = "auto",
|
||||||
|
img_bgr: Optional[np.ndarray] = None,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""Build a word grid by intersecting columns and rows, then OCR each cell.
|
"""Build a word grid by intersecting columns and rows, then OCR each cell.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ocr_img: Binarized full-page image.
|
ocr_img: Binarized full-page image (for Tesseract).
|
||||||
column_regions: Classified columns from Step 3 (PageRegion list).
|
column_regions: Classified columns from Step 3 (PageRegion list).
|
||||||
row_geometries: Rows from Step 4 (RowGeometry list).
|
row_geometries: Rows from Step 4 (RowGeometry list).
|
||||||
img_w: Image width in pixels.
|
img_w: Image width in pixels.
|
||||||
img_h: Image height in pixels.
|
img_h: Image height in pixels.
|
||||||
lang: Default Tesseract language.
|
lang: Default Tesseract language.
|
||||||
|
ocr_engine: 'tesseract', 'rapid', or 'auto' (rapid if available, else tesseract).
|
||||||
|
img_bgr: BGR color image (required for RapidOCR).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of entry dicts with english/german/example text and bbox info (percent).
|
List of entry dicts with english/german/example text and bbox info (percent).
|
||||||
"""
|
"""
|
||||||
|
# Resolve engine choice
|
||||||
|
use_rapid = False
|
||||||
|
if ocr_engine == "auto":
|
||||||
|
use_rapid = RAPIDOCR_AVAILABLE and img_bgr is not None
|
||||||
|
elif ocr_engine == "rapid":
|
||||||
|
if not RAPIDOCR_AVAILABLE:
|
||||||
|
logger.warning("RapidOCR requested but not available, falling back to Tesseract")
|
||||||
|
else:
|
||||||
|
use_rapid = True
|
||||||
|
|
||||||
|
engine_name = "rapid" if use_rapid else "tesseract"
|
||||||
|
logger.info(f"build_word_grid: using OCR engine '{engine_name}'")
|
||||||
|
|
||||||
# Filter to content rows only (skip header/footer)
|
# Filter to content rows only (skip header/footer)
|
||||||
content_rows = [r for r in row_geometries if r.row_type == 'content']
|
content_rows = [r for r in row_geometries if r.row_type == 'content']
|
||||||
if not content_rows:
|
if not content_rows:
|
||||||
@@ -2210,7 +2322,7 @@ def build_word_grid(
|
|||||||
# Sort columns left-to-right
|
# Sort columns left-to-right
|
||||||
relevant_cols.sort(key=lambda c: c.x)
|
relevant_cols.sort(key=lambda c: c.x)
|
||||||
|
|
||||||
# Choose OCR language per column type
|
# Choose OCR language per column type (Tesseract only)
|
||||||
lang_map = {
|
lang_map = {
|
||||||
'column_en': 'eng',
|
'column_en': 'eng',
|
||||||
'column_de': 'deu',
|
'column_de': 'deu',
|
||||||
@@ -2235,6 +2347,7 @@ def build_word_grid(
|
|||||||
'bbox_en': None,
|
'bbox_en': None,
|
||||||
'bbox_de': None,
|
'bbox_de': None,
|
||||||
'bbox_ex': None,
|
'bbox_ex': None,
|
||||||
|
'ocr_engine': engine_name,
|
||||||
}
|
}
|
||||||
|
|
||||||
confidences: List[float] = []
|
confidences: List[float] = []
|
||||||
@@ -2263,12 +2376,22 @@ def build_word_grid(
|
|||||||
width=cell_w, height=cell_h,
|
width=cell_w, height=cell_h,
|
||||||
)
|
)
|
||||||
|
|
||||||
cell_lang = lang_map.get(col.type, lang)
|
# OCR the cell
|
||||||
words = ocr_region(ocr_img, cell_region, lang=cell_lang, psm=6)
|
if use_rapid:
|
||||||
|
words = ocr_region_rapid(img_bgr, cell_region)
|
||||||
|
else:
|
||||||
|
cell_lang = lang_map.get(col.type, lang)
|
||||||
|
words = ocr_region(ocr_img, cell_region, lang=cell_lang, psm=6)
|
||||||
|
|
||||||
|
# Group into lines, then join in reading order (Fix A)
|
||||||
|
# Use half of average word height as Y-tolerance
|
||||||
|
if words:
|
||||||
|
avg_h = sum(w['height'] for w in words) / len(words)
|
||||||
|
y_tol = max(10, int(avg_h * 0.5))
|
||||||
|
else:
|
||||||
|
y_tol = 15
|
||||||
|
text = _words_to_reading_order_text(words, y_tolerance_px=y_tol)
|
||||||
|
|
||||||
# Sort words by Y then X (reading order for multi-line cells)
|
|
||||||
words.sort(key=lambda w: (w['top'], w['left']))
|
|
||||||
text = ' '.join(w['text'] for w in words)
|
|
||||||
if words:
|
if words:
|
||||||
avg_conf = sum(w['conf'] for w in words) / len(words)
|
avg_conf = sum(w['conf'] for w in words) / len(words)
|
||||||
confidences.append(avg_conf)
|
confidences.append(avg_conf)
|
||||||
@@ -2300,7 +2423,8 @@ def build_word_grid(
|
|||||||
entries.append(entry)
|
entries.append(entry)
|
||||||
|
|
||||||
logger.info(f"build_word_grid: {len(entries)} entries from "
|
logger.info(f"build_word_grid: {len(entries)} entries from "
|
||||||
f"{len(content_rows)} content rows × {len(relevant_cols)} columns")
|
f"{len(content_rows)} content rows × {len(relevant_cols)} columns "
|
||||||
|
f"(engine={engine_name})")
|
||||||
|
|
||||||
return entries
|
return entries
|
||||||
|
|
||||||
|
|||||||
@@ -1007,8 +1007,12 @@ async def get_row_ground_truth(session_id: str):
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/words")
|
@router.post("/sessions/{session_id}/words")
|
||||||
async def detect_words(session_id: str):
|
async def detect_words(session_id: str, engine: str = "auto"):
|
||||||
"""Build word grid from columns × rows, OCR each cell."""
|
"""Build word grid from columns × rows, OCR each cell.
|
||||||
|
|
||||||
|
Query params:
|
||||||
|
engine: 'auto' (default), 'tesseract', or 'rapid'
|
||||||
|
"""
|
||||||
if session_id not in _cache:
|
if session_id not in _cache:
|
||||||
await _load_session_to_cache(session_id)
|
await _load_session_to_cache(session_id)
|
||||||
cached = _get_cached(session_id)
|
cached = _get_cached(session_id)
|
||||||
@@ -1030,7 +1034,7 @@ async def detect_words(session_id: str):
|
|||||||
|
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
|
||||||
# Create binarized OCR image
|
# Create binarized OCR image (for Tesseract)
|
||||||
ocr_img = create_ocr_image(dewarped_bgr)
|
ocr_img = create_ocr_image(dewarped_bgr)
|
||||||
img_h, img_w = dewarped_bgr.shape[:2]
|
img_h, img_w = dewarped_bgr.shape[:2]
|
||||||
|
|
||||||
@@ -1060,8 +1064,11 @@ async def detect_words(session_id: str):
|
|||||||
for r in row_result["rows"]
|
for r in row_result["rows"]
|
||||||
]
|
]
|
||||||
|
|
||||||
# Build word grid
|
# Build word grid — pass both binarized (for Tesseract) and BGR (for RapidOCR)
|
||||||
entries = build_word_grid(ocr_img, col_regions, row_geoms, img_w, img_h)
|
entries = build_word_grid(
|
||||||
|
ocr_img, col_regions, row_geoms, img_w, img_h,
|
||||||
|
ocr_engine=engine, img_bgr=dewarped_bgr,
|
||||||
|
)
|
||||||
duration = time.time() - t0
|
duration = time.time() - t0
|
||||||
|
|
||||||
# Build summary
|
# Build summary
|
||||||
@@ -1072,6 +1079,9 @@ async def detect_words(session_id: str):
|
|||||||
"low_confidence": sum(1 for e in entries if e.get("confidence", 0) < 50),
|
"low_confidence": sum(1 for e in entries if e.get("confidence", 0) < 50),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Determine which engine was actually used
|
||||||
|
used_engine = entries[0].get("ocr_engine", "tesseract") if entries else engine
|
||||||
|
|
||||||
word_result = {
|
word_result = {
|
||||||
"entries": entries,
|
"entries": entries,
|
||||||
"entry_count": len(entries),
|
"entry_count": len(entries),
|
||||||
@@ -1079,6 +1089,7 @@ async def detect_words(session_id: str):
|
|||||||
"image_height": img_h,
|
"image_height": img_h,
|
||||||
"duration_seconds": round(duration, 2),
|
"duration_seconds": round(duration, 2),
|
||||||
"summary": summary,
|
"summary": summary,
|
||||||
|
"ocr_engine": used_engine,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Persist to DB
|
# Persist to DB
|
||||||
|
|||||||
Reference in New Issue
Block a user