feat: cell-first OCR + document type detection + dynamic pipeline steps

Cell-First OCR (v2): Each cell is cropped and OCR'd in isolation,
eliminating neighbour bleeding (e.g. "to", "ps" in marker columns).
Uses ThreadPoolExecutor for parallel Tesseract calls.

Document type detection: Classifies pages as vocab_table, full_text,
or generic_table using projection profiles (<2s, no OCR needed).
Frontend dynamically skips columns/rows steps for full-text pages.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-03-04 13:52:38 +01:00
parent 00a74b3144
commit 29c74a9962
7 changed files with 1001 additions and 75 deletions

View File

@@ -11,7 +11,7 @@ import { StepWordRecognition } from '@/components/ocr-pipeline/StepWordRecogniti
import { StepLlmReview } from '@/components/ocr-pipeline/StepLlmReview'
import { StepReconstruction } from '@/components/ocr-pipeline/StepReconstruction'
import { StepGroundTruth } from '@/components/ocr-pipeline/StepGroundTruth'
import { PIPELINE_STEPS, type PipelineStep, type SessionListItem } from './types'
import { PIPELINE_STEPS, type PipelineStep, type SessionListItem, type DocumentTypeResult } from './types'
const KLAUSUR_API = '/klausur-api'
@@ -23,6 +23,7 @@ export default function OcrPipelinePage() {
const [loadingSessions, setLoadingSessions] = useState(true)
const [editingName, setEditingName] = useState<string | null>(null)
const [editNameValue, setEditNameValue] = useState('')
const [docTypeResult, setDocTypeResult] = useState<DocumentTypeResult | null>(null)
const [steps, setSteps] = useState<PipelineStep[]>(
PIPELINE_STEPS.map((s, i) => ({
...s,
@@ -59,16 +60,23 @@ export default function OcrPipelinePage() {
setSessionId(sid)
setSessionName(data.name || data.filename || '')
// Restore doc type result if available
const savedDocType: DocumentTypeResult | null = data.doc_type_result || null
setDocTypeResult(savedDocType)
// Determine which step to jump to based on current_step
const dbStep = data.current_step || 1
// Steps: 1=deskew, 2=dewarp, 3=columns, ...
// UI steps are 0-indexed: 0=deskew, 1=dewarp, 2=columns, ...
const uiStep = Math.max(0, dbStep - 1)
const skipSteps = savedDocType?.skip_steps || []
setSteps(
PIPELINE_STEPS.map((s, i) => ({
...s,
status: i < uiStep ? 'completed' : i === uiStep ? 'active' : 'pending',
status: skipSteps.includes(s.id)
? 'skipped'
: i < uiStep ? 'completed' : i === uiStep ? 'active' : 'pending',
})),
)
setCurrentStep(uiStep)
@@ -84,6 +92,7 @@ export default function OcrPipelinePage() {
if (sessionId === sid) {
setSessionId(null)
setCurrentStep(0)
setDocTypeResult(null)
setSteps(PIPELINE_STEPS.map((s, i) => ({ ...s, status: i === 0 ? 'active' : 'pending' })))
}
} catch (e) {
@@ -123,16 +132,28 @@ export default function OcrPipelinePage() {
}
const handleNext = () => {
if (currentStep < steps.length - 1) {
if (currentStep >= steps.length - 1) return
// Find the next non-skipped step
const skipSteps = docTypeResult?.skip_steps || []
let nextStep = currentStep + 1
while (nextStep < steps.length && skipSteps.includes(PIPELINE_STEPS[nextStep]?.id)) {
nextStep++
}
if (nextStep >= steps.length) nextStep = steps.length - 1
setSteps((prev) =>
prev.map((s, i) => {
if (i === currentStep) return { ...s, status: 'completed' }
if (i === currentStep + 1) return { ...s, status: 'active' }
if (i === nextStep) return { ...s, status: 'active' }
// Mark skipped steps between current and next
if (i > currentStep && i < nextStep && skipSteps.includes(PIPELINE_STEPS[i]?.id)) {
return { ...s, status: 'skipped' }
}
return s
}),
)
setCurrentStep((prev) => prev + 1)
}
setCurrentStep(nextStep)
}
const handleDeskewComplete = (sid: string) => {
@@ -142,10 +163,69 @@ export default function OcrPipelinePage() {
handleNext()
}
const handleDewarpNext = async () => {
// Auto-detect document type after dewarp, then advance
if (sessionId) {
try {
const res = await fetch(
`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/detect-type`,
{ method: 'POST' },
)
if (res.ok) {
const data: DocumentTypeResult = await res.json()
setDocTypeResult(data)
// Mark skipped steps immediately
const skipSteps = data.skip_steps || []
if (skipSteps.length > 0) {
setSteps((prev) =>
prev.map((s) =>
skipSteps.includes(s.id) ? { ...s, status: 'skipped' } : s,
),
)
}
}
} catch (e) {
console.error('Doc type detection failed:', e)
// Not critical — continue without it
}
}
handleNext()
}
const handleDocTypeChange = (newDocType: DocumentTypeResult['doc_type']) => {
  if (!docTypeResult) return
  // Full-text pages have no table structure → columns/rows steps are skipped.
  // vocab_table and generic_table keep the full step sequence.
  const skipSteps: string[] = newDocType === 'full_text' ? ['columns', 'rows'] : []
  setDocTypeResult({
    ...docTypeResult,
    doc_type: newDocType,
    skip_steps: skipSteps,
    pipeline: newDocType === 'full_text' ? 'full_page' : 'cell_first',
  })
  // Re-derive stepper statuses: newly skipped steps are struck out,
  // previously skipped ones fall back to pending.
  setSteps((prev) =>
    prev.map((s) => {
      if (skipSteps.includes(s.id)) return { ...s, status: 'skipped' as const }
      return s.status === 'skipped' ? { ...s, status: 'pending' as const } : s
    }),
  )
}
const handleNewSession = () => {
  // Reset every piece of per-session state back to a blank pipeline.
  setSessionId(null)
  setSessionName('')
  setCurrentStep(0)
  setDocTypeResult(null)
  setSteps(
    PIPELINE_STEPS.map((step, idx) => ({
      ...step,
      status: idx === 0 ? 'active' : 'pending',
    })),
  )
}
@@ -188,7 +268,7 @@ export default function OcrPipelinePage() {
case 0:
return <StepDeskew sessionId={sessionId} onNext={handleDeskewComplete} />
case 1:
return <StepDewarp sessionId={sessionId} onNext={handleNext} />
return <StepDewarp sessionId={sessionId} onNext={handleDewarpNext} />
case 2:
return <StepColumnDetection sessionId={sessionId} onNext={handleNext} />
case 3:
@@ -314,7 +394,14 @@ export default function OcrPipelinePage() {
</div>
)}
<PipelineStepper steps={steps} currentStep={currentStep} onStepClick={handleStepClick} onReprocess={sessionId ? reprocessFromStep : undefined} />
<PipelineStepper
steps={steps}
currentStep={currentStep}
onStepClick={handleStepClick}
onReprocess={sessionId ? reprocessFromStep : undefined}
docTypeResult={docTypeResult}
onDocTypeChange={handleDocTypeChange}
/>
<div className="min-h-[400px]">{renderStep()}</div>
</div>

View File

@@ -1,4 +1,4 @@
export type PipelineStepStatus = 'pending' | 'active' | 'completed' | 'failed'
export type PipelineStepStatus = 'pending' | 'active' | 'completed' | 'failed' | 'skipped'
export interface PipelineStep {
id: string
@@ -17,6 +17,15 @@ export interface SessionListItem {
updated_at?: string
}
/** Result of server-side document type detection (runs after the dewarp step). */
export interface DocumentTypeResult {
  /** Detected page layout class. */
  doc_type: 'vocab_table' | 'full_text' | 'generic_table'
  /** Detection confidence in [0, 1]. */
  confidence: number
  /** OCR pipeline variant the backend recommends for this page. */
  pipeline: 'cell_first' | 'full_page'
  /** Step ids the frontend should skip (e.g. ['columns', 'rows'] for full text). */
  skip_steps: string[]
  /** Raw detection features for debugging — shape defined by the backend. */
  features?: Record<string, unknown>
  /** Server-side detection duration in seconds. */
  duration_seconds?: number
}
export interface SessionInfo {
session_id: string
filename: string
@@ -30,6 +39,7 @@ export interface SessionInfo {
column_result?: ColumnResult
row_result?: RowResult
word_result?: GridResult
doc_type_result?: DocumentTypeResult
}
export interface DeskewResult {

View File

@@ -1,29 +1,48 @@
'use client'
import { PipelineStep } from '@/app/(admin)/ai/ocr-pipeline/types'
import { PipelineStep, DocumentTypeResult } from '@/app/(admin)/ai/ocr-pipeline/types'
const DOC_TYPE_LABELS: Record<string, string> = {
vocab_table: 'Vokabeltabelle',
full_text: 'Volltext',
generic_table: 'Tabelle',
}
interface PipelineStepperProps {
steps: PipelineStep[]
currentStep: number
onStepClick: (index: number) => void
onReprocess?: (index: number) => void
docTypeResult?: DocumentTypeResult | null
onDocTypeChange?: (docType: DocumentTypeResult['doc_type']) => void
}
export function PipelineStepper({ steps, currentStep, onStepClick, onReprocess }: PipelineStepperProps) {
export function PipelineStepper({
steps,
currentStep,
onStepClick,
onReprocess,
docTypeResult,
onDocTypeChange,
}: PipelineStepperProps) {
return (
<div className="space-y-2">
<div className="flex items-center justify-between px-4 py-3 bg-white dark:bg-gray-800 rounded-lg border border-gray-200 dark:border-gray-700">
{steps.map((step, index) => {
const isActive = index === currentStep
const isCompleted = step.status === 'completed'
const isFailed = step.status === 'failed'
const isClickable = index <= currentStep || isCompleted
const isSkipped = step.status === 'skipped'
const isClickable = (index <= currentStep || isCompleted) && !isSkipped
return (
<div key={step.id} className="flex items-center">
{index > 0 && (
<div
className={`h-0.5 w-8 mx-1 ${
index <= currentStep ? 'bg-teal-400' : 'bg-gray-300 dark:bg-gray-600'
isSkipped
? 'bg-gray-200 dark:bg-gray-700 border-t border-dashed border-gray-400'
: index <= currentStep ? 'bg-teal-400' : 'bg-gray-300 dark:bg-gray-600'
}`}
/>
)}
@@ -32,7 +51,9 @@ export function PipelineStepper({ steps, currentStep, onStepClick, onReprocess }
onClick={() => isClickable && onStepClick(index)}
disabled={!isClickable}
className={`flex items-center gap-1.5 px-3 py-1.5 rounded-full text-sm font-medium transition-all ${
isActive
isSkipped
? 'bg-gray-100 text-gray-400 dark:bg-gray-800 dark:text-gray-600 line-through'
: isActive
? 'bg-teal-100 text-teal-700 dark:bg-teal-900/40 dark:text-teal-300 ring-2 ring-teal-400'
: isCompleted
? 'bg-green-100 text-green-700 dark:bg-green-900/40 dark:text-green-300'
@@ -42,7 +63,7 @@ export function PipelineStepper({ steps, currentStep, onStepClick, onReprocess }
} ${isClickable ? 'cursor-pointer hover:opacity-80' : 'cursor-default'}`}
>
<span className="text-base">
{isCompleted ? '\u2713' : isFailed ? '\u2717' : step.icon}
{isSkipped ? '-' : isCompleted ? '\u2713' : isFailed ? '\u2717' : step.icon}
</span>
<span className="hidden sm:inline">{step.name}</span>
<span className="sm:hidden">{index + 1}</span>
@@ -62,5 +83,33 @@ export function PipelineStepper({ steps, currentStep, onStepClick, onReprocess }
)
})}
</div>
{/* Document type badge */}
{docTypeResult && (
<div className="flex items-center gap-2 px-4 py-2 bg-blue-50 dark:bg-blue-900/20 rounded-lg border border-blue-200 dark:border-blue-800 text-sm">
<span className="text-blue-600 dark:text-blue-400 font-medium">
Dokumenttyp:
</span>
{onDocTypeChange ? (
<select
value={docTypeResult.doc_type}
onChange={(e) => onDocTypeChange(e.target.value as DocumentTypeResult['doc_type'])}
className="bg-white dark:bg-gray-800 border border-blue-300 dark:border-blue-700 rounded px-2 py-0.5 text-sm text-blue-700 dark:text-blue-300"
>
<option value="vocab_table">Vokabeltabelle</option>
<option value="generic_table">Tabelle (generisch)</option>
<option value="full_text">Volltext</option>
</select>
) : (
<span className="text-blue-700 dark:text-blue-300">
{DOC_TYPE_LABELS[docTypeResult.doc_type] || docTypeResult.doc_type}
</span>
)}
<span className="text-blue-400 dark:text-blue-500 text-xs">
({Math.round(docTypeResult.confidence * 100)}% Konfidenz)
</span>
</div>
)}
</div>
)
}

View File

@@ -18,6 +18,7 @@ DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
import io
import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from typing import Any, Dict, Generator, List, Optional, Tuple
@@ -159,6 +160,16 @@ class PipelineResult:
image_height: int = 0
@dataclass
class DocumentTypeResult:
"""Result of automatic document type detection."""
doc_type: str # 'vocab_table' | 'full_text' | 'generic_table'
confidence: float # 0.0-1.0
pipeline: str # 'cell_first' | 'full_page'
skip_steps: List[str] = field(default_factory=list) # e.g. ['columns', 'rows']
features: Dict[str, Any] = field(default_factory=dict) # debug info
# =============================================================================
# Stage 1: High-Resolution PDF Rendering
# =============================================================================
@@ -966,6 +977,164 @@ def dewarp_image_manual(img: np.ndarray, shear_degrees: float) -> np.ndarray:
return _apply_shear(img, -shear_degrees)
# =============================================================================
# Document Type Detection
# =============================================================================
def detect_document_type(ocr_img: np.ndarray, img_bgr: np.ndarray) -> DocumentTypeResult:
    """Detect whether the page is a vocab table, generic table, or full text.

    Uses projection profiles and text density analysis — no OCR required.
    Runs in < 2 seconds.

    Args:
        ocr_img: Binarized grayscale image (for projection profiles).
        img_bgr: BGR color image (currently unused in the body; kept in the
            signature for future heuristics and API symmetry).

    Returns:
        DocumentTypeResult with doc_type, confidence, pipeline, skip_steps.
    """
    # Degenerate input: fall back to the safest assumption (full text,
    # table-structure steps skipped) instead of raising.
    if ocr_img is None or ocr_img.size == 0:
        return DocumentTypeResult(
            doc_type='full_text', confidence=0.5, pipeline='full_page',
            skip_steps=['columns', 'rows'],
            features={'error': 'empty image'},
        )
    h, w = ocr_img.shape[:2]
    # --- 1. Vertical projection profile → detect column gaps ---
    # Sum dark pixels along each column (x-axis). Gaps = valleys in the profile.
    # Invert: dark pixels on white background → high values = text.
    vert_proj = np.sum(ocr_img < 128, axis=0).astype(float)
    # Smooth the profile to avoid noise spikes (box filter ≈ 1% of width).
    kernel_size = max(3, w // 100)
    if kernel_size % 2 == 0:
        kernel_size += 1
    vert_smooth = np.convolve(vert_proj, np.ones(kernel_size) / kernel_size, mode='same')
    # Find significant vertical gaps (columns of near-zero text density)
    # A gap must be at least 1% of image width and have < 5% of max density
    max_density = max(vert_smooth.max(), 1)
    gap_threshold = max_density * 0.05
    min_gap_width = max(5, w // 100)
    in_gap = False
    gap_count = 0
    gap_start = 0
    vert_gaps = []
    for x in range(w):
        if vert_smooth[x] < gap_threshold:
            if not in_gap:
                in_gap = True
                gap_start = x
        else:
            if in_gap:
                gap_width = x - gap_start
                if gap_width >= min_gap_width:
                    gap_count += 1
                    vert_gaps.append((gap_start, x, gap_width))
                in_gap = False
    # NOTE(review): a gap running to the right image edge is never closed by
    # the loop above; such a gap would be discarded as a margin gap below
    # anyway, so this does not affect internal_gap_count.
    # Filter out margin gaps (within 10% of image edges)
    margin_threshold = w * 0.10
    internal_gaps = [g for g in vert_gaps if g[0] > margin_threshold and g[1] < w - margin_threshold]
    internal_gap_count = len(internal_gaps)
    # --- 2. Horizontal projection profile → detect row gaps ---
    # Same technique along the y-axis; only the count of gaps is kept.
    horiz_proj = np.sum(ocr_img < 128, axis=1).astype(float)
    h_kernel = max(3, h // 200)
    if h_kernel % 2 == 0:
        h_kernel += 1
    horiz_smooth = np.convolve(horiz_proj, np.ones(h_kernel) / h_kernel, mode='same')
    h_max = max(horiz_smooth.max(), 1)
    h_gap_threshold = h_max * 0.05
    min_row_gap = max(3, h // 200)
    row_gap_count = 0
    in_gap = False
    for y in range(h):
        if horiz_smooth[y] < h_gap_threshold:
            if not in_gap:
                in_gap = True
                gap_start = y
        else:
            if in_gap:
                if y - gap_start >= min_row_gap:
                    row_gap_count += 1
                in_gap = False
    # --- 3. Text density distribution (4×4 grid) ---
    # density_std does not select a branch below; it only modulates the
    # confidence of the full_text result.
    grid_rows, grid_cols = 4, 4
    cell_h, cell_w = h // grid_rows, w // grid_cols
    densities = []
    for gr in range(grid_rows):
        for gc in range(grid_cols):
            cell = ocr_img[gr * cell_h:(gr + 1) * cell_h,
                           gc * cell_w:(gc + 1) * cell_w]
            if cell.size > 0:
                d = float(np.count_nonzero(cell < 128)) / cell.size
                densities.append(d)
    density_std = float(np.std(densities)) if densities else 0
    density_mean = float(np.mean(densities)) if densities else 0
    # Debug payload returned with every result (first 10 gaps only).
    features = {
        'vertical_gaps': gap_count,
        'internal_vertical_gaps': internal_gap_count,
        'vertical_gap_details': [(g[0], g[1], g[2]) for g in vert_gaps[:10]],
        'row_gaps': row_gap_count,
        'density_mean': round(density_mean, 4),
        'density_std': round(density_std, 4),
        'image_size': (w, h),
    }
    # --- 4. Decision tree ---
    # Use internal_gap_count (excludes margin gaps) for column detection.
    if internal_gap_count >= 2 and row_gap_count >= 5:
        # Multiple internal vertical gaps + many row gaps → table
        confidence = min(0.95, 0.7 + internal_gap_count * 0.05 + row_gap_count * 0.005)
        return DocumentTypeResult(
            doc_type='vocab_table',
            confidence=round(confidence, 2),
            pipeline='cell_first',
            skip_steps=[],
            features=features,
        )
    elif internal_gap_count >= 1 and row_gap_count >= 3:
        # Some internal structure, likely a table
        confidence = min(0.85, 0.5 + internal_gap_count * 0.1 + row_gap_count * 0.01)
        return DocumentTypeResult(
            doc_type='generic_table',
            confidence=round(confidence, 2),
            pipeline='cell_first',
            skip_steps=[],
            features=features,
        )
    elif internal_gap_count == 0:
        # No internal column gaps → full text (regardless of density)
        confidence = min(0.95, 0.8 + (1 - min(density_std, 0.1)) * 0.15)
        return DocumentTypeResult(
            doc_type='full_text',
            confidence=round(confidence, 2),
            pipeline='full_page',
            skip_steps=['columns', 'rows'],
            features=features,
        )
    else:
        # Reached only for internal_gap_count >= 1 with row_gap_count < 3.
        # Ambiguous — default to vocab_table (most common use case)
        return DocumentTypeResult(
            doc_type='vocab_table',
            confidence=0.5,
            pipeline='cell_first',
            skip_steps=[],
            features=features,
        )
# =============================================================================
# Stage 4: Dual Image Preparation
# =============================================================================
@@ -4481,8 +4650,395 @@ def _clean_cell_text(text: str) -> str:
return ' '.join(tokens)
def _clean_cell_text_lite(text: str) -> str:
    """Simplified noise filter for cell-first OCR (isolated cell crops).

    Since each cell is OCR'd in isolation (no neighbour content visible),
    trailing-noise stripping is unnecessary. Only two filters remain:

    1. No real alphabetic word (>= 2 letters) and not a known abbreviation
       → return empty.
    2. Entire text is garbage (no dictionary word) → return empty.
    """
    cleaned = text.strip()
    if not cleaned:
        return ''
    # Filter 1: require at least one real word, unless the letters alone
    # spell out a known abbreviation.
    if _RE_REAL_WORD.search(cleaned) is None:
        letters = ''.join(_RE_ALPHA.findall(cleaned)).lower()
        if letters not in _KNOWN_ABBREVIATIONS:
            return ''
    # Filter 2: drop anything the garbage detector rejects outright.
    return '' if _is_garbage_text(cleaned) else cleaned
# ---------------------------------------------------------------------------
# Narrow-column OCR helpers (Proposal B)
# Cell-First OCR (v2) — each cell cropped and OCR'd in isolation
# ---------------------------------------------------------------------------
def _ocr_cell_crop(
    row_idx: int,
    col_idx: int,
    row: RowGeometry,
    col: PageRegion,
    ocr_img: np.ndarray,
    img_bgr: Optional[np.ndarray],
    img_w: int,
    img_h: int,
    engine_name: str,
    lang: str,
    lang_map: Dict[str, str],
) -> Dict[str, Any]:
    """OCR a single cell by cropping the exact column×row intersection.

    No padding beyond cell boundaries → no neighbour bleeding.

    Args:
        row_idx, col_idx: Grid position; used for cell_id and later ordering.
        row, col: Row geometry and column region whose intersection is the cell.
        ocr_img: Binarized grayscale page (Tesseract path + density check).
        img_bgr: Original BGR page (TrOCR / LightOn / RapidOCR paths).
        img_w, img_h: Page dimensions, for clamping and percentage bboxes.
        engine_name: Resolved engine name ('tesseract', 'rapid', 'lighton',
            'trocr-printed', 'trocr-handwritten').
        lang: Default Tesseract language string.
        lang_map: Per-column-type overrides for the Tesseract language.

    Returns:
        Cell dict: id, grid position, text, confidence, pixel + percentage
        bounding boxes, and the engine tag that produced the text.
    """
    # Display bbox: exact column × row intersection
    disp_x = col.x
    disp_y = row.y
    disp_w = col.width
    disp_h = row.height
    # Crop boundaries (clamped to image)
    cx = max(0, disp_x)
    cy = max(0, disp_y)
    cw = min(disp_w, img_w - cx)
    ch = min(disp_h, img_h - cy)
    # Template returned for empty/degenerate cells; also the base for results.
    empty_cell = {
        'cell_id': f"R{row_idx:02d}_C{col_idx}",
        'row_index': row_idx,
        'col_index': col_idx,
        'col_type': col.type,
        'text': '',
        'confidence': 0.0,
        'bbox_px': {'x': disp_x, 'y': disp_y, 'w': disp_w, 'h': disp_h},
        'bbox_pct': {
            'x': round(disp_x / img_w * 100, 2) if img_w else 0,
            'y': round(disp_y / img_h * 100, 2) if img_h else 0,
            'w': round(disp_w / img_w * 100, 2) if img_w else 0,
            'h': round(disp_h / img_h * 100, 2) if img_h else 0,
        },
        'ocr_engine': 'cell_crop_v2',
    }
    if cw <= 0 or ch <= 0:
        return empty_cell
    # --- Pixel-density check: skip truly empty cells ---
    # Fewer than 0.5% dark pixels → treat as blank without an OCR call.
    if ocr_img is not None:
        crop = ocr_img[cy:cy + ch, cx:cx + cw]
        if crop.size > 0:
            dark_ratio = float(np.count_nonzero(crop < 180)) / crop.size
            if dark_ratio < 0.005:
                return empty_cell
    # --- Prepare crop for OCR ---
    cell_lang = lang_map.get(col.type, lang)
    psm = _select_psm_for_column(col.type, col.width, row.height)
    text = ''
    avg_conf = 0.0
    used_engine = 'cell_crop_v2'
    if engine_name in ("trocr-printed", "trocr-handwritten") and img_bgr is not None:
        # NOTE(review): the trocr/lighton/rapid helpers receive a region in
        # absolute page coordinates and no remapping happens afterwards, so
        # presumably they return page-absolute word positions — confirm
        # against the ocr_region_* implementations.
        cell_region = PageRegion(type=col.type, x=cx, y=cy, width=cw, height=ch)
        words = ocr_region_trocr(img_bgr, cell_region,
                                 handwritten=(engine_name == "trocr-handwritten"))
    elif engine_name == "lighton" and img_bgr is not None:
        cell_region = PageRegion(type=col.type, x=cx, y=cy, width=cw, height=ch)
        words = ocr_region_lighton(img_bgr, cell_region)
    elif engine_name == "rapid" and img_bgr is not None:
        cell_region = PageRegion(type=col.type, x=cx, y=cy, width=cw, height=ch)
        words = ocr_region_rapid(img_bgr, cell_region)
    else:
        # Tesseract: upscale tiny crops for better recognition
        if ocr_img is not None:
            crop_slice = ocr_img[cy:cy + ch, cx:cx + cw]
            upscaled = _ensure_minimum_crop_size(crop_slice)
            up_h, up_w = upscaled.shape[:2]
            # Region at (0, 0): Tesseract sees the standalone crop, not the page.
            tmp_region = PageRegion(type=col.type, x=0, y=0, width=up_w, height=up_h)
            words = ocr_region(upscaled, tmp_region, lang=cell_lang, psm=psm)
            # Remap word positions back to original image coordinates
            if words and (up_w != cw or up_h != ch):
                # Scale from the upscaled crop back to true cell size,
                # then translate by the crop origin.
                sx = cw / max(up_w, 1)
                sy = ch / max(up_h, 1)
                for w in words:
                    w['left'] = int(w['left'] * sx) + cx
                    w['top'] = int(w['top'] * sy) + cy
                    w['width'] = int(w['width'] * sx)
                    w['height'] = int(w['height'] * sy)
            elif words:
                # No upscaling happened — only translate crop → page coords.
                for w in words:
                    w['left'] += cx
                    w['top'] += cy
        else:
            words = []
    # Filter low-confidence words
    _MIN_WORD_CONF = 30
    if words:
        words = [w for w in words if w.get('conf', 0) >= _MIN_WORD_CONF]
    if words:
        # y-tolerance spans the whole cell height so all words in this
        # single-cell crop are merged in one reading-order pass.
        y_tol = max(15, ch)
        text = _words_to_reading_order_text(words, y_tolerance_px=y_tol)
        avg_conf = round(sum(w['conf'] for w in words) / len(words), 1)
    # --- PSM 7 fallback for still-empty Tesseract cells ---
    # PSM 7 treats the image as a single text line (Tesseract docs); retries
    # cells that the column-specific PSM produced nothing for.
    if not text.strip() and engine_name == "tesseract" and ocr_img is not None:
        crop_slice = ocr_img[cy:cy + ch, cx:cx + cw]
        upscaled = _ensure_minimum_crop_size(crop_slice)
        up_h, up_w = upscaled.shape[:2]
        tmp_region = PageRegion(type=col.type, x=0, y=0, width=up_w, height=up_h)
        psm7_words = ocr_region(upscaled, tmp_region, lang=cell_lang, psm=7)
        if psm7_words:
            psm7_words = [w for w in psm7_words if w.get('conf', 0) >= _MIN_WORD_CONF]
        if psm7_words:
            p7_text = _words_to_reading_order_text(psm7_words, y_tolerance_px=10)
            if p7_text.strip():
                # NOTE(review): the fallback's word coordinates are NOT
                # remapped to page coords here (only the text is kept).
                text = p7_text
                avg_conf = round(
                    sum(w['conf'] for w in psm7_words) / len(psm7_words), 1
                )
                used_engine = 'cell_crop_v2_psm7'
    # --- Noise filter ---
    if text.strip():
        text = _clean_cell_text_lite(text)
        if not text:
            avg_conf = 0.0
    result = dict(empty_cell)
    result['text'] = text
    result['confidence'] = avg_conf
    result['ocr_engine'] = used_engine
    return result
def build_cell_grid_v2(
    ocr_img: np.ndarray,
    column_regions: List[PageRegion],
    row_geometries: List[RowGeometry],
    img_w: int,
    img_h: int,
    lang: str = "eng+deu",
    ocr_engine: str = "auto",
    img_bgr: Optional[np.ndarray] = None,
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Cell-First Grid: crop each cell in isolation, then OCR.

    Drop-in replacement for build_cell_grid() — same signature & return type.
    No full-page word assignment; each cell is OCR'd from its own crop.
    """
    # --- Resolve which OCR engine to run ---
    engine_name = "tesseract"
    if ocr_engine in ("trocr-printed", "trocr-handwritten", "lighton"):
        engine_name = ocr_engine
    elif ocr_engine == "auto":
        if RAPIDOCR_AVAILABLE and img_bgr is not None:
            engine_name = "rapid"
    elif ocr_engine == "rapid":
        if RAPIDOCR_AVAILABLE:
            engine_name = "rapid"
        else:
            logger.warning("RapidOCR requested but not available, falling back to Tesseract")
    logger.info(f"build_cell_grid_v2: using OCR engine '{engine_name}'")
    # --- Row filtering: content rows → rows with words → non-artifact rows ---
    content_rows = [r for r in row_geometries if r.row_type == 'content']
    if not content_rows:
        logger.warning("build_cell_grid_v2: no content rows found")
        return [], []
    with_words = [r for r in content_rows if r.word_count > 0]
    skipped = len(content_rows) - len(with_words)
    if skipped > 0:
        logger.info(f"build_cell_grid_v2: skipped {skipped} phantom rows (word_count=0)")
    if not with_words:
        logger.warning("build_cell_grid_v2: no content rows with words found")
        return [], []
    real_rows = [r for r in with_words if not _is_artifact_row(r)]
    artifact_skipped = len(with_words) - len(real_rows)
    if artifact_skipped > 0:
        logger.info(f"build_cell_grid_v2: skipped {artifact_skipped} artifact rows")
    if not real_rows:
        logger.warning("build_cell_grid_v2: no content rows after artifact filtering")
        return [], []
    # --- Column filtering: drop non-content region types ---
    excluded_types = {'column_ignore', 'header', 'footer', 'margin_top',
                      'margin_bottom', 'margin_left', 'margin_right'}
    relevant_cols = [c for c in column_regions if c.type not in excluded_types]
    if not relevant_cols:
        logger.warning("build_cell_grid_v2: no usable columns found")
        return [], []
    # Heal row gaps within the vertical span of the columns.
    _heal_row_gaps(
        real_rows,
        top_bound=min(c.y for c in relevant_cols),
        bottom_bound=max(c.y + c.height for c in relevant_cols),
    )
    relevant_cols.sort(key=lambda c: c.x)
    columns_meta = [
        {'index': ci, 'type': c.type, 'x': c.x, 'width': c.width}
        for ci, c in enumerate(relevant_cols)
    ]
    # Per-column-type Tesseract language overrides.
    lang_map = {
        'column_en': 'eng',
        'column_de': 'deu',
        'column_example': 'eng+deu',
    }
    # --- Parallel OCR with ThreadPoolExecutor ---
    # Tesseract is single-threaded per call, so parallel calls pay off;
    # heavier engines get fewer workers.
    worker_count = 4 if engine_name == "tesseract" else 2
    cells: List[Dict[str, Any]] = []
    with ThreadPoolExecutor(max_workers=worker_count) as pool:
        future_to_pos = {}
        for row_idx, row in enumerate(real_rows):
            for col_idx, col in enumerate(relevant_cols):
                fut = pool.submit(
                    _ocr_cell_crop,
                    row_idx, col_idx, row, col,
                    ocr_img, img_bgr, img_w, img_h,
                    engine_name, lang, lang_map,
                )
                future_to_pos[fut] = (row_idx, col_idx)
        for fut in as_completed(future_to_pos):
            try:
                cells.append(fut.result())
            except Exception as e:
                ri, ci = future_to_pos[fut]
                logger.error(f"build_cell_grid_v2: cell R{ri:02d}_C{ci} failed: {e}")
    # Futures complete out of order — restore deterministic grid order.
    cells.sort(key=lambda c: (c['row_index'], c['col_index']))
    # --- Drop rows where every cell came back empty ---
    rows_with_text = {c['row_index'] for c in cells if c['text'].strip()}
    before_filter = len(cells)
    cells = [c for c in cells if c['row_index'] in rows_with_text]
    empty_rows_removed = (before_filter - len(cells)) // max(len(relevant_cols), 1)
    if empty_rows_removed > 0:
        logger.info(f"build_cell_grid_v2: removed {empty_rows_removed} all-empty rows")
    logger.info(f"build_cell_grid_v2: {len(cells)} cells from "
                f"{len(real_rows)} rows × {len(relevant_cols)} columns, "
                f"engine={engine_name}")
    return cells, columns_meta
def build_cell_grid_v2_streaming(
    ocr_img: np.ndarray,
    column_regions: List[PageRegion],
    row_geometries: List[RowGeometry],
    img_w: int,
    img_h: int,
    lang: str = "eng+deu",
    ocr_engine: str = "auto",
    img_bgr: Optional[np.ndarray] = None,
) -> Generator[Tuple[Dict[str, Any], List[Dict[str, Any]], int], None, None]:
    """Streaming variant of build_cell_grid_v2 — yields each cell as OCR'd.

    Cells are produced sequentially (no thread pool) so callers can stream
    incremental progress to the client.

    Yields:
        (cell_dict, columns_meta, total_cells)
    """
    # Resolve the OCR engine (same rules as build_cell_grid_v2).
    engine_name = "tesseract"
    if ocr_engine in ("trocr-printed", "trocr-handwritten", "lighton"):
        engine_name = ocr_engine
    elif ocr_engine == "auto":
        if RAPIDOCR_AVAILABLE and img_bgr is not None:
            engine_name = "rapid"
    elif ocr_engine == "rapid":
        if RAPIDOCR_AVAILABLE:
            engine_name = "rapid"
        else:
            logger.warning("RapidOCR requested but not available, falling back to Tesseract")
    # Row filtering mirrors build_cell_grid_v2 (no logging in the generator).
    rows = [r for r in row_geometries if r.row_type == 'content' and r.word_count > 0]
    if not rows:
        return
    excluded_types = {'column_ignore', 'header', 'footer', 'margin_top',
                      'margin_bottom', 'margin_left', 'margin_right'}
    cols = [c for c in column_regions if c.type not in excluded_types]
    if not cols:
        return
    rows = [r for r in rows if not _is_artifact_row(r)]
    if not rows:
        return
    # Heal row gaps within the vertical span of the columns.
    _heal_row_gaps(
        rows,
        top_bound=min(c.y for c in cols),
        bottom_bound=max(c.y + c.height for c in cols),
    )
    cols.sort(key=lambda c: c.x)
    columns_meta = [
        {'index': ci, 'type': c.type, 'x': c.x, 'width': c.width}
        for ci, c in enumerate(cols)
    ]
    lang_map = {
        'column_en': 'eng',
        'column_de': 'deu',
        'column_example': 'eng+deu',
    }
    total_cells = len(rows) * len(cols)
    for row_idx, row in enumerate(rows):
        for col_idx, col in enumerate(cols):
            cell = _ocr_cell_crop(
                row_idx, col_idx, row, col,
                ocr_img, img_bgr, img_w, img_h,
                engine_name, lang, lang_map,
            )
            yield cell, columns_meta, total_cells
# ---------------------------------------------------------------------------
# Narrow-column OCR helpers (Proposal B) — DEPRECATED (kept for legacy build_cell_grid)
# ---------------------------------------------------------------------------
def _compute_cell_padding(col_width: int, img_w: int) -> int:

View File

@@ -32,6 +32,7 @@ from pydantic import BaseModel
from cv_vocab_pipeline import (
OLLAMA_REVIEW_MODEL,
DocumentTypeResult,
PageRegion,
RowGeometry,
_cells_to_vocab_entries,
@@ -43,6 +44,8 @@ from cv_vocab_pipeline import (
analyze_layout_by_words,
build_cell_grid,
build_cell_grid_streaming,
build_cell_grid_v2,
build_cell_grid_v2_streaming,
build_word_grid,
classify_column_types,
create_layout_image,
@@ -50,6 +53,7 @@ from cv_vocab_pipeline import (
deskew_image,
deskew_image_by_word_alignment,
detect_column_geometry,
detect_document_type,
detect_row_geometry,
expand_narrow_columns,
_apply_shear,
@@ -759,6 +763,54 @@ async def save_dewarp_ground_truth(session_id: str, req: DewarpGroundTruthReques
return {"session_id": session_id, "ground_truth": gt}
# ---------------------------------------------------------------------------
# Document Type Detection (between Dewarp and Columns)
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/detect-type")
async def detect_type(session_id: str):
    """Detect document type (vocab_table, full_text, generic_table).

    Should be called after dewarp (a clean image must be available).
    Stores the result in the session so the frontend can decide the
    pipeline flow (which steps to skip).
    """
    # Lazily hydrate the in-memory cache from the DB on first access.
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)
    dewarped_bgr = cached.get("dewarped_bgr")
    if dewarped_bgr is None:
        raise HTTPException(status_code=400, detail="Dewarp must be completed first")
    started = time.time()
    detection = detect_document_type(create_ocr_image(dewarped_bgr), dewarped_bgr)
    elapsed = time.time() - started
    result_dict = {
        "doc_type": detection.doc_type,
        "confidence": detection.confidence,
        "pipeline": detection.pipeline,
        "skip_steps": detection.skip_steps,
        "features": detection.features,
        "duration_seconds": round(elapsed, 2),
    }
    # Persist so a reloaded session restores skip-step handling.
    await update_session_db(
        session_id,
        doc_type=detection.doc_type,
        doc_type_result=result_dict,
    )
    cached["doc_type_result"] = result_dict
    logger.info(f"OCR Pipeline: detect-type session {session_id}: "
                f"{detection.doc_type} (confidence={detection.confidence}, {elapsed:.2f}s)")
    return {"session_id": session_id, **result_dict}
# ---------------------------------------------------------------------------
# Column Detection Endpoints (Step 3)
# ---------------------------------------------------------------------------
@@ -1196,8 +1248,10 @@ async def detect_words(
for r in row_result["rows"]
]
# Re-populate row.words from cached full-page Tesseract words.
# Word-lookup in _ocr_single_cell needs these to avoid re-running OCR.
# Cell-First OCR (v2): no full-page word re-population needed.
# Each cell is cropped and OCR'd in isolation → no neighbour bleeding.
# We still need word_count > 0 for row filtering in build_cell_grid_v2,
# so populate from cached words if available (just for counting).
word_dicts = cached.get("_word_dicts")
if word_dicts is None:
ocr_img_tmp = create_ocr_image(dewarped_bgr)
@@ -1209,8 +1263,6 @@ async def detect_words(
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
if word_dicts:
# words['top'] is relative to content-ROI top_y.
# row.y is absolute. Convert: row_y_rel = row.y - top_y.
content_bounds = cached.get("_content_bounds")
if content_bounds:
_lx, _rx, top_y, _by = content_bounds
@@ -1240,15 +1292,15 @@ async def detect_words(
},
)
# --- Non-streaming path (unchanged) ---
# --- Non-streaming path ---
t0 = time.time()
# Create binarized OCR image (for Tesseract)
ocr_img = create_ocr_image(dewarped_bgr)
img_h, img_w = dewarped_bgr.shape[:2]
# Build generic cell grid
cells, columns_meta = build_cell_grid(
# Build cell grid using Cell-First OCR (v2) — each cell cropped in isolation
cells, columns_meta = build_cell_grid_v2(
ocr_img, col_regions, row_geoms, img_w, img_h,
ocr_engine=engine, img_bgr=dewarped_bgr,
)
@@ -1358,7 +1410,7 @@ async def _word_stream_generator(
all_cells: List[Dict[str, Any]] = []
cell_idx = 0
for cell, cols_meta, total in build_cell_grid_streaming(
for cell, cols_meta, total in build_cell_grid_v2_streaming(
ocr_img, col_regions, row_geoms, img_w, img_h,
ocr_engine=engine, img_bgr=dewarped_bgr,
):

View File

@@ -64,7 +64,9 @@ async def init_ocr_pipeline_tables():
await conn.execute("""
ALTER TABLE ocr_pipeline_sessions
ADD COLUMN IF NOT EXISTS clean_png BYTEA,
ADD COLUMN IF NOT EXISTS handwriting_removal_meta JSONB
ADD COLUMN IF NOT EXISTS handwriting_removal_meta JSONB,
ADD COLUMN IF NOT EXISTS doc_type VARCHAR(50),
ADD COLUMN IF NOT EXISTS doc_type_result JSONB
""")
@@ -88,6 +90,7 @@ async def create_session_db(
RETURNING id, name, filename, status, current_step,
deskew_result, dewarp_result, column_result, row_result,
word_result, ground_truth, auto_shear_degrees,
doc_type, doc_type_result,
created_at, updated_at
""", uuid.UUID(session_id), name, filename, original_png)
@@ -102,6 +105,7 @@ async def get_session_db(session_id: str) -> Optional[Dict[str, Any]]:
SELECT id, name, filename, status, current_step,
deskew_result, dewarp_result, column_result, row_result,
word_result, ground_truth, auto_shear_degrees,
doc_type, doc_type_result,
created_at, updated_at
FROM ocr_pipeline_sessions WHERE id = $1
""", uuid.UUID(session_id))
@@ -146,9 +150,10 @@ async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any
'clean_png', 'handwriting_removal_meta',
'deskew_result', 'dewarp_result', 'column_result', 'row_result',
'word_result', 'ground_truth', 'auto_shear_degrees',
'doc_type', 'doc_type_result',
}
jsonb_fields = {'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'handwriting_removal_meta'}
jsonb_fields = {'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'handwriting_removal_meta', 'doc_type_result'}
for key, value in kwargs.items():
if key in allowed_fields:
@@ -174,6 +179,7 @@ async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any
RETURNING id, name, filename, status, current_step,
deskew_result, dewarp_result, column_result, row_result,
word_result, ground_truth, auto_shear_degrees,
doc_type, doc_type_result,
created_at, updated_at
""", *values)
@@ -229,7 +235,7 @@ def _row_to_dict(row: asyncpg.Record) -> Dict[str, Any]:
result[key] = result[key].isoformat()
# JSONB → parsed (asyncpg returns str for JSONB)
for key in ['deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth']:
for key in ['deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'doc_type_result']:
if key in result and result[key] is not None:
if isinstance(result[key], str):
result[key] = json.loads(result[key])

View File

@@ -25,7 +25,9 @@ from dataclasses import asdict
# Import module under test
from cv_vocab_pipeline import (
ColumnGeometry,
DocumentTypeResult,
PageRegion,
RowGeometry,
VocabRow,
PipelineResult,
deskew_image,
@@ -48,9 +50,12 @@ from cv_vocab_pipeline import (
CV_PIPELINE_AVAILABLE,
_is_noise_tail_token,
_clean_cell_text,
_clean_cell_text_lite,
_is_phonetic_only_text,
_merge_phonetic_continuation_rows,
_merge_continuation_rows,
_ocr_cell_crop,
detect_document_type,
)
@@ -1566,6 +1571,167 @@ class TestCellsToVocabEntriesPageRef:
assert entries[0]['source_page'] == 'p.59'
# =============================================
# CELL-FIRST OCR (v2) TESTS
# =============================================
class TestCleanCellTextLite:
    """Tests for _clean_cell_text_lite() — the simplified noise filter.

    Cell-First OCR crops each cell in isolation, so this filter only has
    to reject outright garbage; it must not strip legitimate trailing
    tokens the way the aggressive _clean_cell_text does.
    """

    def test_empty_string(self):
        result = _clean_cell_text_lite('')
        assert result == ''

    def test_whitespace_only(self):
        result = _clean_cell_text_lite(' ')
        assert result == ''

    def test_real_word_passes(self):
        result = _clean_cell_text_lite('hello')
        assert result == 'hello'

    def test_sentence_passes(self):
        result = _clean_cell_text_lite('to have dinner')
        assert result == 'to have dinner'

    def test_garbage_text_cleared(self):
        """Garbage text (no dictionary words) should be cleared."""
        assert _clean_cell_text_lite('xqzjk') == ''

    def test_no_real_word_cleared(self):
        """Single chars with no real word (2+ letters) cleared."""
        for noise in ('3', '|'):
            assert _clean_cell_text_lite(noise) == ''

    def test_known_abbreviation_kept(self):
        """Known abbreviations should pass through."""
        for abbrev in ('sth', 'eg'):
            assert _clean_cell_text_lite(abbrev) == abbrev

    def test_no_trailing_noise_stripping(self):
        """Unlike _clean_cell_text, lite does NOT strip trailing tokens.
        Since cells are isolated, all tokens are legitimate."""
        assert _clean_cell_text_lite('apple tree') == 'apple tree'

    def test_page_reference(self):
        """Page references like p.60 should pass ('p' is a known abbreviation)."""
        assert _clean_cell_text_lite('p.60') != ''
class TestOcrCellCrop:
    """Tests for _ocr_cell_crop() — isolated cell OCR.

    Each test builds a synthetic page image plus one row/column geometry
    pair and checks the returned cell dict without invoking real OCR
    (blank images short-circuit on pixel density before Tesseract runs).
    """

    def test_empty_cell_pixel_density(self):
        """Cells with very few dark pixels should return empty text."""
        # Pure-white page → density check rejects the cell before OCR.
        blank_page = np.full((400, 600), 255, dtype=np.uint8)
        row = RowGeometry(index=0, x=0, y=50, width=600, height=30,
                          word_count=1, words=[{'text': 'a'}])
        col = PageRegion(type='column_en', x=50, y=0, width=200, height=400)

        cell = _ocr_cell_crop(
            0, 0, row, col, blank_page, None, 600, 400,
            'tesseract', 'eng+deu', {'column_en': 'eng'},
        )

        assert cell['text'] == ''
        assert cell['cell_id'] == 'R00_C0'
        assert cell['col_type'] == 'column_en'

    def test_zero_width_cell(self):
        """Zero-width cells should return empty."""
        blank_page = np.full((400, 600), 255, dtype=np.uint8)
        row = RowGeometry(index=0, x=0, y=50, width=600, height=30,
                          word_count=1, words=[])
        degenerate_col = PageRegion(type='column_en', x=50, y=0, width=0, height=400)

        cell = _ocr_cell_crop(
            0, 0, row, degenerate_col, blank_page, None, 600, 400,
            'tesseract', 'eng+deu', {},
        )

        assert cell['text'] == ''

    def test_bbox_calculation(self):
        """Check bbox_px and bbox_pct are correct."""
        blank_page = np.full((1000, 2000), 255, dtype=np.uint8)
        row = RowGeometry(index=0, x=0, y=100, width=2000, height=50,
                          word_count=1, words=[{'text': 'test'}])
        col = PageRegion(type='column_de', x=400, y=0, width=600, height=1000)

        cell = _ocr_cell_crop(
            0, 0, row, col, blank_page, None, 2000, 1000,
            'tesseract', 'eng+deu', {'column_de': 'deu'},
        )

        # Pixel bbox is the row/column intersection.
        assert cell['bbox_px'] == {'x': 400, 'y': 100, 'w': 600, 'h': 50}
        # Percent bbox is relative to full page dimensions.
        assert cell['bbox_pct']['x'] == 20.0  # 400 / 2000 * 100
        assert cell['bbox_pct']['y'] == 10.0  # 100 / 1000 * 100
class TestDetectDocumentType:
    """Tests for detect_document_type() — image-based classification.

    The classifier works on projection profiles of the binarized page,
    so these tests paint synthetic black text bands on white canvases:
    gapped column bands → table, full-width uniform lines → full text.
    """

    def test_empty_image(self):
        """Empty image should default to full_text."""
        empty = np.zeros((0, 0), dtype=np.uint8)
        result = detect_document_type(empty, empty)
        assert result.doc_type == 'full_text'
        assert result.pipeline == 'full_page'

    def test_table_image_detected(self):
        """Image with clear column gaps and row gaps → table."""
        # 600x400 white canvas; paint three column bands separated by
        # white vertical gaps (170..210 and 370..410 stay untouched).
        canvas = np.full((400, 600), 255, dtype=np.uint8)
        column_spans = ((20, 170), (210, 370), (410, 580))
        for band_top in range(30, 370, 20):
            for x_start, x_end in column_spans:
                canvas[band_top:band_top + 10, x_start:x_end] = 0
        bgr = np.dstack((canvas, canvas, canvas))

        result = detect_document_type(canvas, bgr)

        assert result.doc_type in ('vocab_table', 'generic_table')
        assert result.pipeline == 'cell_first'
        assert result.confidence >= 0.5

    def test_fulltext_image_detected(self):
        """Uniform text without column gaps → full_text."""
        canvas = np.full((400, 600), 255, dtype=np.uint8)
        # Text lines spanning nearly the full width — no vertical gaps.
        for band_top in range(30, 370, 15):
            canvas[band_top:band_top + 8, 30:570] = 0
        bgr = np.dstack((canvas, canvas, canvas))

        result = detect_document_type(canvas, bgr)

        assert result.doc_type == 'full_text'
        assert result.pipeline == 'full_page'
        # Full-text pages skip the table-specific pipeline steps.
        assert 'columns' in result.skip_steps
        assert 'rows' in result.skip_steps

    def test_result_has_features(self):
        """Result should contain debug features."""
        canvas = np.full((200, 300), 255, dtype=np.uint8)
        bgr = np.dstack((canvas, canvas, canvas))
        result = detect_document_type(canvas, bgr)
        for feature_key in ('vertical_gaps', 'row_gaps', 'density_mean', 'density_std'):
            assert feature_key in result.features

    def test_document_type_result_dataclass(self):
        """DocumentTypeResult dataclass should initialize correctly."""
        r = DocumentTypeResult(
            doc_type='vocab_table',
            confidence=0.9,
            pipeline='cell_first',
        )
        assert r.doc_type == 'vocab_table'
        # Optional fields default to empty containers.
        assert r.skip_steps == []
        assert r.features == {}
# =============================================
# RUN TESTS
# =============================================