feat: cell-first OCR + document type detection + dynamic pipeline steps
Cell-First OCR (v2): Each cell is cropped and OCR'd in isolation, eliminating neighbour bleeding (e.g. "to", "ps" in marker columns). Uses ThreadPoolExecutor for parallel Tesseract calls. Document type detection: Classifies pages as vocab_table, full_text, or generic_table using projection profiles (<2s, no OCR needed). Frontend dynamically skips columns/rows steps for full-text pages. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -11,7 +11,7 @@ import { StepWordRecognition } from '@/components/ocr-pipeline/StepWordRecogniti
|
||||
import { StepLlmReview } from '@/components/ocr-pipeline/StepLlmReview'
|
||||
import { StepReconstruction } from '@/components/ocr-pipeline/StepReconstruction'
|
||||
import { StepGroundTruth } from '@/components/ocr-pipeline/StepGroundTruth'
|
||||
import { PIPELINE_STEPS, type PipelineStep, type SessionListItem } from './types'
|
||||
import { PIPELINE_STEPS, type PipelineStep, type SessionListItem, type DocumentTypeResult } from './types'
|
||||
|
||||
const KLAUSUR_API = '/klausur-api'
|
||||
|
||||
@@ -23,6 +23,7 @@ export default function OcrPipelinePage() {
|
||||
const [loadingSessions, setLoadingSessions] = useState(true)
|
||||
const [editingName, setEditingName] = useState<string | null>(null)
|
||||
const [editNameValue, setEditNameValue] = useState('')
|
||||
const [docTypeResult, setDocTypeResult] = useState<DocumentTypeResult | null>(null)
|
||||
const [steps, setSteps] = useState<PipelineStep[]>(
|
||||
PIPELINE_STEPS.map((s, i) => ({
|
||||
...s,
|
||||
@@ -59,16 +60,23 @@ export default function OcrPipelinePage() {
|
||||
setSessionId(sid)
|
||||
setSessionName(data.name || data.filename || '')
|
||||
|
||||
// Restore doc type result if available
|
||||
const savedDocType: DocumentTypeResult | null = data.doc_type_result || null
|
||||
setDocTypeResult(savedDocType)
|
||||
|
||||
// Determine which step to jump to based on current_step
|
||||
const dbStep = data.current_step || 1
|
||||
// Steps: 1=deskew, 2=dewarp, 3=columns, ...
|
||||
// UI steps are 0-indexed: 0=deskew, 1=dewarp, 2=columns, ...
|
||||
const uiStep = Math.max(0, dbStep - 1)
|
||||
const skipSteps = savedDocType?.skip_steps || []
|
||||
|
||||
setSteps(
|
||||
PIPELINE_STEPS.map((s, i) => ({
|
||||
...s,
|
||||
status: i < uiStep ? 'completed' : i === uiStep ? 'active' : 'pending',
|
||||
status: skipSteps.includes(s.id)
|
||||
? 'skipped'
|
||||
: i < uiStep ? 'completed' : i === uiStep ? 'active' : 'pending',
|
||||
})),
|
||||
)
|
||||
setCurrentStep(uiStep)
|
||||
@@ -84,6 +92,7 @@ export default function OcrPipelinePage() {
|
||||
if (sessionId === sid) {
|
||||
setSessionId(null)
|
||||
setCurrentStep(0)
|
||||
setDocTypeResult(null)
|
||||
setSteps(PIPELINE_STEPS.map((s, i) => ({ ...s, status: i === 0 ? 'active' : 'pending' })))
|
||||
}
|
||||
} catch (e) {
|
||||
@@ -123,16 +132,28 @@ export default function OcrPipelinePage() {
|
||||
}
|
||||
|
||||
const handleNext = () => {
|
||||
if (currentStep < steps.length - 1) {
|
||||
if (currentStep >= steps.length - 1) return
|
||||
|
||||
// Find the next non-skipped step
|
||||
const skipSteps = docTypeResult?.skip_steps || []
|
||||
let nextStep = currentStep + 1
|
||||
while (nextStep < steps.length && skipSteps.includes(PIPELINE_STEPS[nextStep]?.id)) {
|
||||
nextStep++
|
||||
}
|
||||
if (nextStep >= steps.length) nextStep = steps.length - 1
|
||||
|
||||
setSteps((prev) =>
|
||||
prev.map((s, i) => {
|
||||
if (i === currentStep) return { ...s, status: 'completed' }
|
||||
if (i === currentStep + 1) return { ...s, status: 'active' }
|
||||
if (i === nextStep) return { ...s, status: 'active' }
|
||||
// Mark skipped steps between current and next
|
||||
if (i > currentStep && i < nextStep && skipSteps.includes(PIPELINE_STEPS[i]?.id)) {
|
||||
return { ...s, status: 'skipped' }
|
||||
}
|
||||
return s
|
||||
}),
|
||||
)
|
||||
setCurrentStep((prev) => prev + 1)
|
||||
}
|
||||
setCurrentStep(nextStep)
|
||||
}
|
||||
|
||||
const handleDeskewComplete = (sid: string) => {
|
||||
@@ -142,10 +163,69 @@ export default function OcrPipelinePage() {
|
||||
handleNext()
|
||||
}
|
||||
|
||||
const handleDewarpNext = async () => {
|
||||
// Auto-detect document type after dewarp, then advance
|
||||
if (sessionId) {
|
||||
try {
|
||||
const res = await fetch(
|
||||
`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/detect-type`,
|
||||
{ method: 'POST' },
|
||||
)
|
||||
if (res.ok) {
|
||||
const data: DocumentTypeResult = await res.json()
|
||||
setDocTypeResult(data)
|
||||
|
||||
// Mark skipped steps immediately
|
||||
const skipSteps = data.skip_steps || []
|
||||
if (skipSteps.length > 0) {
|
||||
setSteps((prev) =>
|
||||
prev.map((s) =>
|
||||
skipSteps.includes(s.id) ? { ...s, status: 'skipped' } : s,
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Doc type detection failed:', e)
|
||||
// Not critical — continue without it
|
||||
}
|
||||
}
|
||||
handleNext()
|
||||
}
|
||||
|
||||
const handleDocTypeChange = (newDocType: DocumentTypeResult['doc_type']) => {
|
||||
if (!docTypeResult) return
|
||||
|
||||
// Build new skip_steps based on doc type
|
||||
let skipSteps: string[] = []
|
||||
if (newDocType === 'full_text') {
|
||||
skipSteps = ['columns', 'rows']
|
||||
}
|
||||
// vocab_table and generic_table: no skips
|
||||
|
||||
const updated: DocumentTypeResult = {
|
||||
...docTypeResult,
|
||||
doc_type: newDocType,
|
||||
skip_steps: skipSteps,
|
||||
pipeline: newDocType === 'full_text' ? 'full_page' : 'cell_first',
|
||||
}
|
||||
setDocTypeResult(updated)
|
||||
|
||||
// Update step statuses
|
||||
setSteps((prev) =>
|
||||
prev.map((s) => {
|
||||
if (skipSteps.includes(s.id)) return { ...s, status: 'skipped' as const }
|
||||
if (s.status === 'skipped') return { ...s, status: 'pending' as const }
|
||||
return s
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
const handleNewSession = () => {
|
||||
setSessionId(null)
|
||||
setSessionName('')
|
||||
setCurrentStep(0)
|
||||
setDocTypeResult(null)
|
||||
setSteps(PIPELINE_STEPS.map((s, i) => ({ ...s, status: i === 0 ? 'active' : 'pending' })))
|
||||
}
|
||||
|
||||
@@ -188,7 +268,7 @@ export default function OcrPipelinePage() {
|
||||
case 0:
|
||||
return <StepDeskew sessionId={sessionId} onNext={handleDeskewComplete} />
|
||||
case 1:
|
||||
return <StepDewarp sessionId={sessionId} onNext={handleNext} />
|
||||
return <StepDewarp sessionId={sessionId} onNext={handleDewarpNext} />
|
||||
case 2:
|
||||
return <StepColumnDetection sessionId={sessionId} onNext={handleNext} />
|
||||
case 3:
|
||||
@@ -314,7 +394,14 @@ export default function OcrPipelinePage() {
|
||||
</div>
|
||||
)}
|
||||
|
||||
<PipelineStepper steps={steps} currentStep={currentStep} onStepClick={handleStepClick} onReprocess={sessionId ? reprocessFromStep : undefined} />
|
||||
<PipelineStepper
|
||||
steps={steps}
|
||||
currentStep={currentStep}
|
||||
onStepClick={handleStepClick}
|
||||
onReprocess={sessionId ? reprocessFromStep : undefined}
|
||||
docTypeResult={docTypeResult}
|
||||
onDocTypeChange={handleDocTypeChange}
|
||||
/>
|
||||
|
||||
<div className="min-h-[400px]">{renderStep()}</div>
|
||||
</div>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
export type PipelineStepStatus = 'pending' | 'active' | 'completed' | 'failed'
|
||||
export type PipelineStepStatus = 'pending' | 'active' | 'completed' | 'failed' | 'skipped'
|
||||
|
||||
export interface PipelineStep {
|
||||
id: string
|
||||
@@ -17,6 +17,15 @@ export interface SessionListItem {
|
||||
updated_at?: string
|
||||
}
|
||||
|
||||
export interface DocumentTypeResult {
|
||||
doc_type: 'vocab_table' | 'full_text' | 'generic_table'
|
||||
confidence: number
|
||||
pipeline: 'cell_first' | 'full_page'
|
||||
skip_steps: string[]
|
||||
features?: Record<string, unknown>
|
||||
duration_seconds?: number
|
||||
}
|
||||
|
||||
export interface SessionInfo {
|
||||
session_id: string
|
||||
filename: string
|
||||
@@ -30,6 +39,7 @@ export interface SessionInfo {
|
||||
column_result?: ColumnResult
|
||||
row_result?: RowResult
|
||||
word_result?: GridResult
|
||||
doc_type_result?: DocumentTypeResult
|
||||
}
|
||||
|
||||
export interface DeskewResult {
|
||||
|
||||
@@ -1,29 +1,48 @@
|
||||
'use client'
|
||||
|
||||
import { PipelineStep } from '@/app/(admin)/ai/ocr-pipeline/types'
|
||||
import { PipelineStep, DocumentTypeResult } from '@/app/(admin)/ai/ocr-pipeline/types'
|
||||
|
||||
const DOC_TYPE_LABELS: Record<string, string> = {
|
||||
vocab_table: 'Vokabeltabelle',
|
||||
full_text: 'Volltext',
|
||||
generic_table: 'Tabelle',
|
||||
}
|
||||
|
||||
interface PipelineStepperProps {
|
||||
steps: PipelineStep[]
|
||||
currentStep: number
|
||||
onStepClick: (index: number) => void
|
||||
onReprocess?: (index: number) => void
|
||||
docTypeResult?: DocumentTypeResult | null
|
||||
onDocTypeChange?: (docType: DocumentTypeResult['doc_type']) => void
|
||||
}
|
||||
|
||||
export function PipelineStepper({ steps, currentStep, onStepClick, onReprocess }: PipelineStepperProps) {
|
||||
export function PipelineStepper({
|
||||
steps,
|
||||
currentStep,
|
||||
onStepClick,
|
||||
onReprocess,
|
||||
docTypeResult,
|
||||
onDocTypeChange,
|
||||
}: PipelineStepperProps) {
|
||||
return (
|
||||
<div className="space-y-2">
|
||||
<div className="flex items-center justify-between px-4 py-3 bg-white dark:bg-gray-800 rounded-lg border border-gray-200 dark:border-gray-700">
|
||||
{steps.map((step, index) => {
|
||||
const isActive = index === currentStep
|
||||
const isCompleted = step.status === 'completed'
|
||||
const isFailed = step.status === 'failed'
|
||||
const isClickable = index <= currentStep || isCompleted
|
||||
const isSkipped = step.status === 'skipped'
|
||||
const isClickable = (index <= currentStep || isCompleted) && !isSkipped
|
||||
|
||||
return (
|
||||
<div key={step.id} className="flex items-center">
|
||||
{index > 0 && (
|
||||
<div
|
||||
className={`h-0.5 w-8 mx-1 ${
|
||||
index <= currentStep ? 'bg-teal-400' : 'bg-gray-300 dark:bg-gray-600'
|
||||
isSkipped
|
||||
? 'bg-gray-200 dark:bg-gray-700 border-t border-dashed border-gray-400'
|
||||
: index <= currentStep ? 'bg-teal-400' : 'bg-gray-300 dark:bg-gray-600'
|
||||
}`}
|
||||
/>
|
||||
)}
|
||||
@@ -32,7 +51,9 @@ export function PipelineStepper({ steps, currentStep, onStepClick, onReprocess }
|
||||
onClick={() => isClickable && onStepClick(index)}
|
||||
disabled={!isClickable}
|
||||
className={`flex items-center gap-1.5 px-3 py-1.5 rounded-full text-sm font-medium transition-all ${
|
||||
isActive
|
||||
isSkipped
|
||||
? 'bg-gray-100 text-gray-400 dark:bg-gray-800 dark:text-gray-600 line-through'
|
||||
: isActive
|
||||
? 'bg-teal-100 text-teal-700 dark:bg-teal-900/40 dark:text-teal-300 ring-2 ring-teal-400'
|
||||
: isCompleted
|
||||
? 'bg-green-100 text-green-700 dark:bg-green-900/40 dark:text-green-300'
|
||||
@@ -42,7 +63,7 @@ export function PipelineStepper({ steps, currentStep, onStepClick, onReprocess }
|
||||
} ${isClickable ? 'cursor-pointer hover:opacity-80' : 'cursor-default'}`}
|
||||
>
|
||||
<span className="text-base">
|
||||
{isCompleted ? '\u2713' : isFailed ? '\u2717' : step.icon}
|
||||
{isSkipped ? '-' : isCompleted ? '\u2713' : isFailed ? '\u2717' : step.icon}
|
||||
</span>
|
||||
<span className="hidden sm:inline">{step.name}</span>
|
||||
<span className="sm:hidden">{index + 1}</span>
|
||||
@@ -62,5 +83,33 @@ export function PipelineStepper({ steps, currentStep, onStepClick, onReprocess }
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
|
||||
{/* Document type badge */}
|
||||
{docTypeResult && (
|
||||
<div className="flex items-center gap-2 px-4 py-2 bg-blue-50 dark:bg-blue-900/20 rounded-lg border border-blue-200 dark:border-blue-800 text-sm">
|
||||
<span className="text-blue-600 dark:text-blue-400 font-medium">
|
||||
Dokumenttyp:
|
||||
</span>
|
||||
{onDocTypeChange ? (
|
||||
<select
|
||||
value={docTypeResult.doc_type}
|
||||
onChange={(e) => onDocTypeChange(e.target.value as DocumentTypeResult['doc_type'])}
|
||||
className="bg-white dark:bg-gray-800 border border-blue-300 dark:border-blue-700 rounded px-2 py-0.5 text-sm text-blue-700 dark:text-blue-300"
|
||||
>
|
||||
<option value="vocab_table">Vokabeltabelle</option>
|
||||
<option value="generic_table">Tabelle (generisch)</option>
|
||||
<option value="full_text">Volltext</option>
|
||||
</select>
|
||||
) : (
|
||||
<span className="text-blue-700 dark:text-blue-300">
|
||||
{DOC_TYPE_LABELS[docTypeResult.doc_type] || docTypeResult.doc_type}
|
||||
</span>
|
||||
)}
|
||||
<span className="text-blue-400 dark:text-blue-500 text-xs">
|
||||
({Math.round(docTypeResult.confidence * 100)}% Konfidenz)
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
import io
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple
|
||||
|
||||
@@ -159,6 +160,16 @@ class PipelineResult:
|
||||
image_height: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocumentTypeResult:
|
||||
"""Result of automatic document type detection."""
|
||||
doc_type: str # 'vocab_table' | 'full_text' | 'generic_table'
|
||||
confidence: float # 0.0-1.0
|
||||
pipeline: str # 'cell_first' | 'full_page'
|
||||
skip_steps: List[str] = field(default_factory=list) # e.g. ['columns', 'rows']
|
||||
features: Dict[str, Any] = field(default_factory=dict) # debug info
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Stage 1: High-Resolution PDF Rendering
|
||||
# =============================================================================
|
||||
@@ -966,6 +977,164 @@ def dewarp_image_manual(img: np.ndarray, shear_degrees: float) -> np.ndarray:
|
||||
return _apply_shear(img, -shear_degrees)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Document Type Detection
|
||||
# =============================================================================
|
||||
|
||||
def detect_document_type(ocr_img: np.ndarray, img_bgr: np.ndarray) -> DocumentTypeResult:
|
||||
"""Detect whether the page is a vocab table, generic table, or full text.
|
||||
|
||||
Uses projection profiles and text density analysis — no OCR required.
|
||||
Runs in < 2 seconds.
|
||||
|
||||
Args:
|
||||
ocr_img: Binarized grayscale image (for projection profiles).
|
||||
img_bgr: BGR color image.
|
||||
|
||||
Returns:
|
||||
DocumentTypeResult with doc_type, confidence, pipeline, skip_steps.
|
||||
"""
|
||||
if ocr_img is None or ocr_img.size == 0:
|
||||
return DocumentTypeResult(
|
||||
doc_type='full_text', confidence=0.5, pipeline='full_page',
|
||||
skip_steps=['columns', 'rows'],
|
||||
features={'error': 'empty image'},
|
||||
)
|
||||
|
||||
h, w = ocr_img.shape[:2]
|
||||
|
||||
# --- 1. Vertical projection profile → detect column gaps ---
|
||||
# Sum dark pixels along each column (x-axis). Gaps = valleys in the profile.
|
||||
# Invert: dark pixels on white background → high values = text.
|
||||
vert_proj = np.sum(ocr_img < 128, axis=0).astype(float)
|
||||
|
||||
# Smooth the profile to avoid noise spikes
|
||||
kernel_size = max(3, w // 100)
|
||||
if kernel_size % 2 == 0:
|
||||
kernel_size += 1
|
||||
vert_smooth = np.convolve(vert_proj, np.ones(kernel_size) / kernel_size, mode='same')
|
||||
|
||||
# Find significant vertical gaps (columns of near-zero text density)
|
||||
# A gap must be at least 1% of image width and have < 5% of max density
|
||||
max_density = max(vert_smooth.max(), 1)
|
||||
gap_threshold = max_density * 0.05
|
||||
min_gap_width = max(5, w // 100)
|
||||
|
||||
in_gap = False
|
||||
gap_count = 0
|
||||
gap_start = 0
|
||||
vert_gaps = []
|
||||
|
||||
for x in range(w):
|
||||
if vert_smooth[x] < gap_threshold:
|
||||
if not in_gap:
|
||||
in_gap = True
|
||||
gap_start = x
|
||||
else:
|
||||
if in_gap:
|
||||
gap_width = x - gap_start
|
||||
if gap_width >= min_gap_width:
|
||||
gap_count += 1
|
||||
vert_gaps.append((gap_start, x, gap_width))
|
||||
in_gap = False
|
||||
|
||||
# Filter out margin gaps (within 10% of image edges)
|
||||
margin_threshold = w * 0.10
|
||||
internal_gaps = [g for g in vert_gaps if g[0] > margin_threshold and g[1] < w - margin_threshold]
|
||||
internal_gap_count = len(internal_gaps)
|
||||
|
||||
# --- 2. Horizontal projection profile → detect row gaps ---
|
||||
horiz_proj = np.sum(ocr_img < 128, axis=1).astype(float)
|
||||
h_kernel = max(3, h // 200)
|
||||
if h_kernel % 2 == 0:
|
||||
h_kernel += 1
|
||||
horiz_smooth = np.convolve(horiz_proj, np.ones(h_kernel) / h_kernel, mode='same')
|
||||
|
||||
h_max = max(horiz_smooth.max(), 1)
|
||||
h_gap_threshold = h_max * 0.05
|
||||
min_row_gap = max(3, h // 200)
|
||||
|
||||
row_gap_count = 0
|
||||
in_gap = False
|
||||
for y in range(h):
|
||||
if horiz_smooth[y] < h_gap_threshold:
|
||||
if not in_gap:
|
||||
in_gap = True
|
||||
gap_start = y
|
||||
else:
|
||||
if in_gap:
|
||||
if y - gap_start >= min_row_gap:
|
||||
row_gap_count += 1
|
||||
in_gap = False
|
||||
|
||||
# --- 3. Text density distribution (4×4 grid) ---
|
||||
grid_rows, grid_cols = 4, 4
|
||||
cell_h, cell_w = h // grid_rows, w // grid_cols
|
||||
densities = []
|
||||
for gr in range(grid_rows):
|
||||
for gc in range(grid_cols):
|
||||
cell = ocr_img[gr * cell_h:(gr + 1) * cell_h,
|
||||
gc * cell_w:(gc + 1) * cell_w]
|
||||
if cell.size > 0:
|
||||
d = float(np.count_nonzero(cell < 128)) / cell.size
|
||||
densities.append(d)
|
||||
|
||||
density_std = float(np.std(densities)) if densities else 0
|
||||
density_mean = float(np.mean(densities)) if densities else 0
|
||||
|
||||
features = {
|
||||
'vertical_gaps': gap_count,
|
||||
'internal_vertical_gaps': internal_gap_count,
|
||||
'vertical_gap_details': [(g[0], g[1], g[2]) for g in vert_gaps[:10]],
|
||||
'row_gaps': row_gap_count,
|
||||
'density_mean': round(density_mean, 4),
|
||||
'density_std': round(density_std, 4),
|
||||
'image_size': (w, h),
|
||||
}
|
||||
|
||||
# --- 4. Decision tree ---
|
||||
# Use internal_gap_count (excludes margin gaps) for column detection.
|
||||
if internal_gap_count >= 2 and row_gap_count >= 5:
|
||||
# Multiple internal vertical gaps + many row gaps → table
|
||||
confidence = min(0.95, 0.7 + internal_gap_count * 0.05 + row_gap_count * 0.005)
|
||||
return DocumentTypeResult(
|
||||
doc_type='vocab_table',
|
||||
confidence=round(confidence, 2),
|
||||
pipeline='cell_first',
|
||||
skip_steps=[],
|
||||
features=features,
|
||||
)
|
||||
elif internal_gap_count >= 1 and row_gap_count >= 3:
|
||||
# Some internal structure, likely a table
|
||||
confidence = min(0.85, 0.5 + internal_gap_count * 0.1 + row_gap_count * 0.01)
|
||||
return DocumentTypeResult(
|
||||
doc_type='generic_table',
|
||||
confidence=round(confidence, 2),
|
||||
pipeline='cell_first',
|
||||
skip_steps=[],
|
||||
features=features,
|
||||
)
|
||||
elif internal_gap_count == 0:
|
||||
# No internal column gaps → full text (regardless of density)
|
||||
confidence = min(0.95, 0.8 + (1 - min(density_std, 0.1)) * 0.15)
|
||||
return DocumentTypeResult(
|
||||
doc_type='full_text',
|
||||
confidence=round(confidence, 2),
|
||||
pipeline='full_page',
|
||||
skip_steps=['columns', 'rows'],
|
||||
features=features,
|
||||
)
|
||||
else:
|
||||
# Ambiguous — default to vocab_table (most common use case)
|
||||
return DocumentTypeResult(
|
||||
doc_type='vocab_table',
|
||||
confidence=0.5,
|
||||
pipeline='cell_first',
|
||||
skip_steps=[],
|
||||
features=features,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Stage 4: Dual Image Preparation
|
||||
# =============================================================================
|
||||
@@ -4481,8 +4650,395 @@ def _clean_cell_text(text: str) -> str:
|
||||
return ' '.join(tokens)
|
||||
|
||||
|
||||
def _clean_cell_text_lite(text: str) -> str:
|
||||
"""Simplified noise filter for cell-first OCR (isolated cell crops).
|
||||
|
||||
Since each cell is OCR'd in isolation (no neighbour content visible),
|
||||
trailing-noise stripping is unnecessary. Only 2 filters remain:
|
||||
|
||||
1. No real alphabetic word (>= 2 letters) and not a known abbreviation → empty.
|
||||
2. Entire text is garbage (no dictionary word) → empty.
|
||||
"""
|
||||
stripped = text.strip()
|
||||
if not stripped:
|
||||
return ''
|
||||
|
||||
# --- Filter 1: No real word at all ---
|
||||
if not _RE_REAL_WORD.search(stripped):
|
||||
alpha_only = ''.join(_RE_ALPHA.findall(stripped)).lower()
|
||||
if alpha_only not in _KNOWN_ABBREVIATIONS:
|
||||
return ''
|
||||
|
||||
# --- Filter 2: Entire text is garbage ---
|
||||
if _is_garbage_text(stripped):
|
||||
return ''
|
||||
|
||||
return stripped
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Narrow-column OCR helpers (Proposal B)
|
||||
# Cell-First OCR (v2) — each cell cropped and OCR'd in isolation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _ocr_cell_crop(
|
||||
row_idx: int,
|
||||
col_idx: int,
|
||||
row: RowGeometry,
|
||||
col: PageRegion,
|
||||
ocr_img: np.ndarray,
|
||||
img_bgr: Optional[np.ndarray],
|
||||
img_w: int,
|
||||
img_h: int,
|
||||
engine_name: str,
|
||||
lang: str,
|
||||
lang_map: Dict[str, str],
|
||||
) -> Dict[str, Any]:
|
||||
"""OCR a single cell by cropping the exact column×row intersection.
|
||||
|
||||
No padding beyond cell boundaries → no neighbour bleeding.
|
||||
"""
|
||||
# Display bbox: exact column × row intersection
|
||||
disp_x = col.x
|
||||
disp_y = row.y
|
||||
disp_w = col.width
|
||||
disp_h = row.height
|
||||
|
||||
# Crop boundaries (clamped to image)
|
||||
cx = max(0, disp_x)
|
||||
cy = max(0, disp_y)
|
||||
cw = min(disp_w, img_w - cx)
|
||||
ch = min(disp_h, img_h - cy)
|
||||
|
||||
empty_cell = {
|
||||
'cell_id': f"R{row_idx:02d}_C{col_idx}",
|
||||
'row_index': row_idx,
|
||||
'col_index': col_idx,
|
||||
'col_type': col.type,
|
||||
'text': '',
|
||||
'confidence': 0.0,
|
||||
'bbox_px': {'x': disp_x, 'y': disp_y, 'w': disp_w, 'h': disp_h},
|
||||
'bbox_pct': {
|
||||
'x': round(disp_x / img_w * 100, 2) if img_w else 0,
|
||||
'y': round(disp_y / img_h * 100, 2) if img_h else 0,
|
||||
'w': round(disp_w / img_w * 100, 2) if img_w else 0,
|
||||
'h': round(disp_h / img_h * 100, 2) if img_h else 0,
|
||||
},
|
||||
'ocr_engine': 'cell_crop_v2',
|
||||
}
|
||||
|
||||
if cw <= 0 or ch <= 0:
|
||||
return empty_cell
|
||||
|
||||
# --- Pixel-density check: skip truly empty cells ---
|
||||
if ocr_img is not None:
|
||||
crop = ocr_img[cy:cy + ch, cx:cx + cw]
|
||||
if crop.size > 0:
|
||||
dark_ratio = float(np.count_nonzero(crop < 180)) / crop.size
|
||||
if dark_ratio < 0.005:
|
||||
return empty_cell
|
||||
|
||||
# --- Prepare crop for OCR ---
|
||||
cell_lang = lang_map.get(col.type, lang)
|
||||
psm = _select_psm_for_column(col.type, col.width, row.height)
|
||||
text = ''
|
||||
avg_conf = 0.0
|
||||
used_engine = 'cell_crop_v2'
|
||||
|
||||
if engine_name in ("trocr-printed", "trocr-handwritten") and img_bgr is not None:
|
||||
cell_region = PageRegion(type=col.type, x=cx, y=cy, width=cw, height=ch)
|
||||
words = ocr_region_trocr(img_bgr, cell_region,
|
||||
handwritten=(engine_name == "trocr-handwritten"))
|
||||
elif engine_name == "lighton" and img_bgr is not None:
|
||||
cell_region = PageRegion(type=col.type, x=cx, y=cy, width=cw, height=ch)
|
||||
words = ocr_region_lighton(img_bgr, cell_region)
|
||||
elif engine_name == "rapid" and img_bgr is not None:
|
||||
cell_region = PageRegion(type=col.type, x=cx, y=cy, width=cw, height=ch)
|
||||
words = ocr_region_rapid(img_bgr, cell_region)
|
||||
else:
|
||||
# Tesseract: upscale tiny crops for better recognition
|
||||
if ocr_img is not None:
|
||||
crop_slice = ocr_img[cy:cy + ch, cx:cx + cw]
|
||||
upscaled = _ensure_minimum_crop_size(crop_slice)
|
||||
up_h, up_w = upscaled.shape[:2]
|
||||
tmp_region = PageRegion(type=col.type, x=0, y=0, width=up_w, height=up_h)
|
||||
words = ocr_region(upscaled, tmp_region, lang=cell_lang, psm=psm)
|
||||
# Remap word positions back to original image coordinates
|
||||
if words and (up_w != cw or up_h != ch):
|
||||
sx = cw / max(up_w, 1)
|
||||
sy = ch / max(up_h, 1)
|
||||
for w in words:
|
||||
w['left'] = int(w['left'] * sx) + cx
|
||||
w['top'] = int(w['top'] * sy) + cy
|
||||
w['width'] = int(w['width'] * sx)
|
||||
w['height'] = int(w['height'] * sy)
|
||||
elif words:
|
||||
for w in words:
|
||||
w['left'] += cx
|
||||
w['top'] += cy
|
||||
else:
|
||||
words = []
|
||||
|
||||
# Filter low-confidence words
|
||||
_MIN_WORD_CONF = 30
|
||||
if words:
|
||||
words = [w for w in words if w.get('conf', 0) >= _MIN_WORD_CONF]
|
||||
|
||||
if words:
|
||||
y_tol = max(15, ch)
|
||||
text = _words_to_reading_order_text(words, y_tolerance_px=y_tol)
|
||||
avg_conf = round(sum(w['conf'] for w in words) / len(words), 1)
|
||||
|
||||
# --- PSM 7 fallback for still-empty Tesseract cells ---
|
||||
if not text.strip() and engine_name == "tesseract" and ocr_img is not None:
|
||||
crop_slice = ocr_img[cy:cy + ch, cx:cx + cw]
|
||||
upscaled = _ensure_minimum_crop_size(crop_slice)
|
||||
up_h, up_w = upscaled.shape[:2]
|
||||
tmp_region = PageRegion(type=col.type, x=0, y=0, width=up_w, height=up_h)
|
||||
psm7_words = ocr_region(upscaled, tmp_region, lang=cell_lang, psm=7)
|
||||
if psm7_words:
|
||||
psm7_words = [w for w in psm7_words if w.get('conf', 0) >= _MIN_WORD_CONF]
|
||||
if psm7_words:
|
||||
p7_text = _words_to_reading_order_text(psm7_words, y_tolerance_px=10)
|
||||
if p7_text.strip():
|
||||
text = p7_text
|
||||
avg_conf = round(
|
||||
sum(w['conf'] for w in psm7_words) / len(psm7_words), 1
|
||||
)
|
||||
used_engine = 'cell_crop_v2_psm7'
|
||||
|
||||
# --- Noise filter ---
|
||||
if text.strip():
|
||||
text = _clean_cell_text_lite(text)
|
||||
if not text:
|
||||
avg_conf = 0.0
|
||||
|
||||
result = dict(empty_cell)
|
||||
result['text'] = text
|
||||
result['confidence'] = avg_conf
|
||||
result['ocr_engine'] = used_engine
|
||||
return result
|
||||
|
||||
|
||||
def build_cell_grid_v2(
|
||||
ocr_img: np.ndarray,
|
||||
column_regions: List[PageRegion],
|
||||
row_geometries: List[RowGeometry],
|
||||
img_w: int,
|
||||
img_h: int,
|
||||
lang: str = "eng+deu",
|
||||
ocr_engine: str = "auto",
|
||||
img_bgr: Optional[np.ndarray] = None,
|
||||
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||
"""Cell-First Grid: crop each cell in isolation, then OCR.
|
||||
|
||||
Drop-in replacement for build_cell_grid() — same signature & return type.
|
||||
No full-page word assignment; each cell is OCR'd from its own crop.
|
||||
"""
|
||||
# Resolve engine
|
||||
use_rapid = False
|
||||
if ocr_engine in ("trocr-printed", "trocr-handwritten", "lighton"):
|
||||
engine_name = ocr_engine
|
||||
elif ocr_engine == "auto":
|
||||
use_rapid = RAPIDOCR_AVAILABLE and img_bgr is not None
|
||||
engine_name = "rapid" if use_rapid else "tesseract"
|
||||
elif ocr_engine == "rapid":
|
||||
if not RAPIDOCR_AVAILABLE:
|
||||
logger.warning("RapidOCR requested but not available, falling back to Tesseract")
|
||||
else:
|
||||
use_rapid = True
|
||||
engine_name = "rapid" if use_rapid else "tesseract"
|
||||
else:
|
||||
engine_name = "tesseract"
|
||||
|
||||
logger.info(f"build_cell_grid_v2: using OCR engine '{engine_name}'")
|
||||
|
||||
# Filter to content rows only
|
||||
content_rows = [r for r in row_geometries if r.row_type == 'content']
|
||||
if not content_rows:
|
||||
logger.warning("build_cell_grid_v2: no content rows found")
|
||||
return [], []
|
||||
|
||||
# Filter phantom rows (word_count=0) and artifact rows
|
||||
before = len(content_rows)
|
||||
content_rows = [r for r in content_rows if r.word_count > 0]
|
||||
skipped = before - len(content_rows)
|
||||
if skipped > 0:
|
||||
logger.info(f"build_cell_grid_v2: skipped {skipped} phantom rows (word_count=0)")
|
||||
if not content_rows:
|
||||
logger.warning("build_cell_grid_v2: no content rows with words found")
|
||||
return [], []
|
||||
|
||||
before_art = len(content_rows)
|
||||
content_rows = [r for r in content_rows if not _is_artifact_row(r)]
|
||||
artifact_skipped = before_art - len(content_rows)
|
||||
if artifact_skipped > 0:
|
||||
logger.info(f"build_cell_grid_v2: skipped {artifact_skipped} artifact rows")
|
||||
if not content_rows:
|
||||
logger.warning("build_cell_grid_v2: no content rows after artifact filtering")
|
||||
return [], []
|
||||
|
||||
# Filter columns
|
||||
_skip_types = {'column_ignore', 'header', 'footer', 'margin_top',
|
||||
'margin_bottom', 'margin_left', 'margin_right'}
|
||||
relevant_cols = [c for c in column_regions if c.type not in _skip_types]
|
||||
if not relevant_cols:
|
||||
logger.warning("build_cell_grid_v2: no usable columns found")
|
||||
return [], []
|
||||
|
||||
# Heal row gaps
|
||||
_heal_row_gaps(
|
||||
content_rows,
|
||||
top_bound=min(c.y for c in relevant_cols),
|
||||
bottom_bound=max(c.y + c.height for c in relevant_cols),
|
||||
)
|
||||
|
||||
relevant_cols.sort(key=lambda c: c.x)
|
||||
|
||||
columns_meta = [
|
||||
{'index': ci, 'type': c.type, 'x': c.x, 'width': c.width}
|
||||
for ci, c in enumerate(relevant_cols)
|
||||
]
|
||||
|
||||
lang_map = {
|
||||
'column_en': 'eng',
|
||||
'column_de': 'deu',
|
||||
'column_example': 'eng+deu',
|
||||
}
|
||||
|
||||
# --- Parallel OCR with ThreadPoolExecutor ---
|
||||
# Tesseract is single-threaded per call, so we benefit from parallelism.
|
||||
# ~40 rows × 4 cols = 160 cells, ~50% empty (density skip) → ~80 OCR calls.
|
||||
cells: List[Dict[str, Any]] = []
|
||||
cell_tasks = []
|
||||
|
||||
for row_idx, row in enumerate(content_rows):
|
||||
for col_idx, col in enumerate(relevant_cols):
|
||||
cell_tasks.append((row_idx, col_idx, row, col))
|
||||
|
||||
max_workers = 4 if engine_name == "tesseract" else 2
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
futures = {
|
||||
pool.submit(
|
||||
_ocr_cell_crop,
|
||||
ri, ci, row, col,
|
||||
ocr_img, img_bgr, img_w, img_h,
|
||||
engine_name, lang, lang_map,
|
||||
): (ri, ci)
|
||||
for ri, ci, row, col in cell_tasks
|
||||
}
|
||||
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
cell = future.result()
|
||||
cells.append(cell)
|
||||
except Exception as e:
|
||||
ri, ci = futures[future]
|
||||
logger.error(f"build_cell_grid_v2: cell R{ri:02d}_C{ci} failed: {e}")
|
||||
|
||||
# Sort cells by (row_index, col_index) since futures complete out of order
|
||||
cells.sort(key=lambda c: (c['row_index'], c['col_index']))
|
||||
|
||||
# Remove all-empty rows
|
||||
rows_with_text: set = set()
|
||||
for cell in cells:
|
||||
if cell['text'].strip():
|
||||
rows_with_text.add(cell['row_index'])
|
||||
before_filter = len(cells)
|
||||
cells = [c for c in cells if c['row_index'] in rows_with_text]
|
||||
empty_rows_removed = (before_filter - len(cells)) // max(len(relevant_cols), 1)
|
||||
if empty_rows_removed > 0:
|
||||
logger.info(f"build_cell_grid_v2: removed {empty_rows_removed} all-empty rows")
|
||||
|
||||
logger.info(f"build_cell_grid_v2: {len(cells)} cells from "
|
||||
f"{len(content_rows)} rows × {len(relevant_cols)} columns, "
|
||||
f"engine={engine_name}")
|
||||
|
||||
return cells, columns_meta
|
||||
|
||||
|
||||
def build_cell_grid_v2_streaming(
|
||||
ocr_img: np.ndarray,
|
||||
column_regions: List[PageRegion],
|
||||
row_geometries: List[RowGeometry],
|
||||
img_w: int,
|
||||
img_h: int,
|
||||
lang: str = "eng+deu",
|
||||
ocr_engine: str = "auto",
|
||||
img_bgr: Optional[np.ndarray] = None,
|
||||
) -> Generator[Tuple[Dict[str, Any], List[Dict[str, Any]], int], None, None]:
|
||||
"""Streaming variant of build_cell_grid_v2 — yields each cell as OCR'd.
|
||||
|
||||
Yields:
|
||||
(cell_dict, columns_meta, total_cells)
|
||||
"""
|
||||
# Resolve engine
|
||||
use_rapid = False
|
||||
if ocr_engine in ("trocr-printed", "trocr-handwritten", "lighton"):
|
||||
engine_name = ocr_engine
|
||||
elif ocr_engine == "auto":
|
||||
use_rapid = RAPIDOCR_AVAILABLE and img_bgr is not None
|
||||
engine_name = "rapid" if use_rapid else "tesseract"
|
||||
elif ocr_engine == "rapid":
|
||||
if not RAPIDOCR_AVAILABLE:
|
||||
logger.warning("RapidOCR requested but not available, falling back to Tesseract")
|
||||
else:
|
||||
use_rapid = True
|
||||
engine_name = "rapid" if use_rapid else "tesseract"
|
||||
else:
|
||||
engine_name = "tesseract"
|
||||
|
||||
content_rows = [r for r in row_geometries if r.row_type == 'content']
|
||||
if not content_rows:
|
||||
return
|
||||
|
||||
content_rows = [r for r in content_rows if r.word_count > 0]
|
||||
if not content_rows:
|
||||
return
|
||||
|
||||
_skip_types = {'column_ignore', 'header', 'footer', 'margin_top',
|
||||
'margin_bottom', 'margin_left', 'margin_right'}
|
||||
relevant_cols = [c for c in column_regions if c.type not in _skip_types]
|
||||
if not relevant_cols:
|
||||
return
|
||||
|
||||
content_rows = [r for r in content_rows if not _is_artifact_row(r)]
|
||||
if not content_rows:
|
||||
return
|
||||
|
||||
_heal_row_gaps(
|
||||
content_rows,
|
||||
top_bound=min(c.y for c in relevant_cols),
|
||||
bottom_bound=max(c.y + c.height for c in relevant_cols),
|
||||
)
|
||||
|
||||
relevant_cols.sort(key=lambda c: c.x)
|
||||
|
||||
columns_meta = [
|
||||
{'index': ci, 'type': c.type, 'x': c.x, 'width': c.width}
|
||||
for ci, c in enumerate(relevant_cols)
|
||||
]
|
||||
|
||||
lang_map = {
|
||||
'column_en': 'eng',
|
||||
'column_de': 'deu',
|
||||
'column_example': 'eng+deu',
|
||||
}
|
||||
|
||||
total_cells = len(content_rows) * len(relevant_cols)
|
||||
|
||||
for row_idx, row in enumerate(content_rows):
|
||||
for col_idx, col in enumerate(relevant_cols):
|
||||
cell = _ocr_cell_crop(
|
||||
row_idx, col_idx, row, col,
|
||||
ocr_img, img_bgr, img_w, img_h,
|
||||
engine_name, lang, lang_map,
|
||||
)
|
||||
yield cell, columns_meta, total_cells
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Narrow-column OCR helpers (Proposal B) — DEPRECATED (kept for legacy build_cell_grid)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _compute_cell_padding(col_width: int, img_w: int) -> int:
|
||||
|
||||
@@ -32,6 +32,7 @@ from pydantic import BaseModel
|
||||
|
||||
from cv_vocab_pipeline import (
|
||||
OLLAMA_REVIEW_MODEL,
|
||||
DocumentTypeResult,
|
||||
PageRegion,
|
||||
RowGeometry,
|
||||
_cells_to_vocab_entries,
|
||||
@@ -43,6 +44,8 @@ from cv_vocab_pipeline import (
|
||||
analyze_layout_by_words,
|
||||
build_cell_grid,
|
||||
build_cell_grid_streaming,
|
||||
build_cell_grid_v2,
|
||||
build_cell_grid_v2_streaming,
|
||||
build_word_grid,
|
||||
classify_column_types,
|
||||
create_layout_image,
|
||||
@@ -50,6 +53,7 @@ from cv_vocab_pipeline import (
|
||||
deskew_image,
|
||||
deskew_image_by_word_alignment,
|
||||
detect_column_geometry,
|
||||
detect_document_type,
|
||||
detect_row_geometry,
|
||||
expand_narrow_columns,
|
||||
_apply_shear,
|
||||
@@ -759,6 +763,54 @@ async def save_dewarp_ground_truth(session_id: str, req: DewarpGroundTruthReques
|
||||
return {"session_id": session_id, "ground_truth": gt}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Document Type Detection (between Dewarp and Columns)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/detect-type")
|
||||
async def detect_type(session_id: str):
|
||||
"""Detect document type (vocab_table, full_text, generic_table).
|
||||
|
||||
Should be called after dewarp (clean image available).
|
||||
Stores result in session for frontend to decide pipeline flow.
|
||||
"""
|
||||
if session_id not in _cache:
|
||||
await _load_session_to_cache(session_id)
|
||||
cached = _get_cached(session_id)
|
||||
|
||||
dewarped_bgr = cached.get("dewarped_bgr")
|
||||
if dewarped_bgr is None:
|
||||
raise HTTPException(status_code=400, detail="Dewarp must be completed first")
|
||||
|
||||
t0 = time.time()
|
||||
ocr_img = create_ocr_image(dewarped_bgr)
|
||||
result = detect_document_type(ocr_img, dewarped_bgr)
|
||||
duration = time.time() - t0
|
||||
|
||||
result_dict = {
|
||||
"doc_type": result.doc_type,
|
||||
"confidence": result.confidence,
|
||||
"pipeline": result.pipeline,
|
||||
"skip_steps": result.skip_steps,
|
||||
"features": result.features,
|
||||
"duration_seconds": round(duration, 2),
|
||||
}
|
||||
|
||||
# Persist to DB
|
||||
await update_session_db(
|
||||
session_id,
|
||||
doc_type=result.doc_type,
|
||||
doc_type_result=result_dict,
|
||||
)
|
||||
|
||||
cached["doc_type_result"] = result_dict
|
||||
|
||||
logger.info(f"OCR Pipeline: detect-type session {session_id}: "
|
||||
f"{result.doc_type} (confidence={result.confidence}, {duration:.2f}s)")
|
||||
|
||||
return {"session_id": session_id, **result_dict}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Column Detection Endpoints (Step 3)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -1196,8 +1248,10 @@ async def detect_words(
|
||||
for r in row_result["rows"]
|
||||
]
|
||||
|
||||
# Re-populate row.words from cached full-page Tesseract words.
|
||||
# Word-lookup in _ocr_single_cell needs these to avoid re-running OCR.
|
||||
# Cell-First OCR (v2): no full-page word re-population needed.
|
||||
# Each cell is cropped and OCR'd in isolation → no neighbour bleeding.
|
||||
# We still need word_count > 0 for row filtering in build_cell_grid_v2,
|
||||
# so populate from cached words if available (just for counting).
|
||||
word_dicts = cached.get("_word_dicts")
|
||||
if word_dicts is None:
|
||||
ocr_img_tmp = create_ocr_image(dewarped_bgr)
|
||||
@@ -1209,8 +1263,6 @@ async def detect_words(
|
||||
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
|
||||
|
||||
if word_dicts:
|
||||
# words['top'] is relative to content-ROI top_y.
|
||||
# row.y is absolute. Convert: row_y_rel = row.y - top_y.
|
||||
content_bounds = cached.get("_content_bounds")
|
||||
if content_bounds:
|
||||
_lx, _rx, top_y, _by = content_bounds
|
||||
@@ -1240,15 +1292,15 @@ async def detect_words(
|
||||
},
|
||||
)
|
||||
|
||||
# --- Non-streaming path (unchanged) ---
|
||||
# --- Non-streaming path ---
|
||||
t0 = time.time()
|
||||
|
||||
# Create binarized OCR image (for Tesseract)
|
||||
ocr_img = create_ocr_image(dewarped_bgr)
|
||||
img_h, img_w = dewarped_bgr.shape[:2]
|
||||
|
||||
# Build generic cell grid
|
||||
cells, columns_meta = build_cell_grid(
|
||||
# Build cell grid using Cell-First OCR (v2) — each cell cropped in isolation
|
||||
cells, columns_meta = build_cell_grid_v2(
|
||||
ocr_img, col_regions, row_geoms, img_w, img_h,
|
||||
ocr_engine=engine, img_bgr=dewarped_bgr,
|
||||
)
|
||||
@@ -1358,7 +1410,7 @@ async def _word_stream_generator(
|
||||
all_cells: List[Dict[str, Any]] = []
|
||||
cell_idx = 0
|
||||
|
||||
for cell, cols_meta, total in build_cell_grid_streaming(
|
||||
for cell, cols_meta, total in build_cell_grid_v2_streaming(
|
||||
ocr_img, col_regions, row_geoms, img_w, img_h,
|
||||
ocr_engine=engine, img_bgr=dewarped_bgr,
|
||||
):
|
||||
|
||||
@@ -64,7 +64,9 @@ async def init_ocr_pipeline_tables():
|
||||
await conn.execute("""
|
||||
ALTER TABLE ocr_pipeline_sessions
|
||||
ADD COLUMN IF NOT EXISTS clean_png BYTEA,
|
||||
ADD COLUMN IF NOT EXISTS handwriting_removal_meta JSONB
|
||||
ADD COLUMN IF NOT EXISTS handwriting_removal_meta JSONB,
|
||||
ADD COLUMN IF NOT EXISTS doc_type VARCHAR(50),
|
||||
ADD COLUMN IF NOT EXISTS doc_type_result JSONB
|
||||
""")
|
||||
|
||||
|
||||
@@ -88,6 +90,7 @@ async def create_session_db(
|
||||
RETURNING id, name, filename, status, current_step,
|
||||
deskew_result, dewarp_result, column_result, row_result,
|
||||
word_result, ground_truth, auto_shear_degrees,
|
||||
doc_type, doc_type_result,
|
||||
created_at, updated_at
|
||||
""", uuid.UUID(session_id), name, filename, original_png)
|
||||
|
||||
@@ -102,6 +105,7 @@ async def get_session_db(session_id: str) -> Optional[Dict[str, Any]]:
|
||||
SELECT id, name, filename, status, current_step,
|
||||
deskew_result, dewarp_result, column_result, row_result,
|
||||
word_result, ground_truth, auto_shear_degrees,
|
||||
doc_type, doc_type_result,
|
||||
created_at, updated_at
|
||||
FROM ocr_pipeline_sessions WHERE id = $1
|
||||
""", uuid.UUID(session_id))
|
||||
@@ -146,9 +150,10 @@ async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any
|
||||
'clean_png', 'handwriting_removal_meta',
|
||||
'deskew_result', 'dewarp_result', 'column_result', 'row_result',
|
||||
'word_result', 'ground_truth', 'auto_shear_degrees',
|
||||
'doc_type', 'doc_type_result',
|
||||
}
|
||||
|
||||
jsonb_fields = {'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'handwriting_removal_meta'}
|
||||
jsonb_fields = {'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'handwriting_removal_meta', 'doc_type_result'}
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if key in allowed_fields:
|
||||
@@ -174,6 +179,7 @@ async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any
|
||||
RETURNING id, name, filename, status, current_step,
|
||||
deskew_result, dewarp_result, column_result, row_result,
|
||||
word_result, ground_truth, auto_shear_degrees,
|
||||
doc_type, doc_type_result,
|
||||
created_at, updated_at
|
||||
""", *values)
|
||||
|
||||
@@ -229,7 +235,7 @@ def _row_to_dict(row: asyncpg.Record) -> Dict[str, Any]:
|
||||
result[key] = result[key].isoformat()
|
||||
|
||||
# JSONB → parsed (asyncpg returns str for JSONB)
|
||||
for key in ['deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth']:
|
||||
for key in ['deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'doc_type_result']:
|
||||
if key in result and result[key] is not None:
|
||||
if isinstance(result[key], str):
|
||||
result[key] = json.loads(result[key])
|
||||
|
||||
@@ -25,7 +25,9 @@ from dataclasses import asdict
|
||||
# Import module under test
|
||||
from cv_vocab_pipeline import (
|
||||
ColumnGeometry,
|
||||
DocumentTypeResult,
|
||||
PageRegion,
|
||||
RowGeometry,
|
||||
VocabRow,
|
||||
PipelineResult,
|
||||
deskew_image,
|
||||
@@ -48,9 +50,12 @@ from cv_vocab_pipeline import (
|
||||
CV_PIPELINE_AVAILABLE,
|
||||
_is_noise_tail_token,
|
||||
_clean_cell_text,
|
||||
_clean_cell_text_lite,
|
||||
_is_phonetic_only_text,
|
||||
_merge_phonetic_continuation_rows,
|
||||
_merge_continuation_rows,
|
||||
_ocr_cell_crop,
|
||||
detect_document_type,
|
||||
)
|
||||
|
||||
|
||||
@@ -1566,6 +1571,167 @@ class TestCellsToVocabEntriesPageRef:
|
||||
assert entries[0]['source_page'] == 'p.59'
|
||||
|
||||
|
||||
# =============================================
|
||||
# CELL-FIRST OCR (v2) TESTS
|
||||
# =============================================
|
||||
|
||||
class TestCleanCellTextLite:
|
||||
"""Tests for _clean_cell_text_lite() — simplified noise filter."""
|
||||
|
||||
def test_empty_string(self):
|
||||
assert _clean_cell_text_lite('') == ''
|
||||
|
||||
def test_whitespace_only(self):
|
||||
assert _clean_cell_text_lite(' ') == ''
|
||||
|
||||
def test_real_word_passes(self):
|
||||
assert _clean_cell_text_lite('hello') == 'hello'
|
||||
|
||||
def test_sentence_passes(self):
|
||||
assert _clean_cell_text_lite('to have dinner') == 'to have dinner'
|
||||
|
||||
def test_garbage_text_cleared(self):
|
||||
"""Garbage text (no dictionary words) should be cleared."""
|
||||
assert _clean_cell_text_lite('xqzjk') == ''
|
||||
|
||||
def test_no_real_word_cleared(self):
|
||||
"""Single chars with no real word (2+ letters) cleared."""
|
||||
assert _clean_cell_text_lite('3') == ''
|
||||
assert _clean_cell_text_lite('|') == ''
|
||||
|
||||
def test_known_abbreviation_kept(self):
|
||||
"""Known abbreviations should pass through."""
|
||||
assert _clean_cell_text_lite('sth') == 'sth'
|
||||
assert _clean_cell_text_lite('eg') == 'eg'
|
||||
|
||||
def test_no_trailing_noise_stripping(self):
|
||||
"""Unlike _clean_cell_text, lite does NOT strip trailing tokens.
|
||||
Since cells are isolated, all tokens are legitimate."""
|
||||
result = _clean_cell_text_lite('apple tree')
|
||||
assert result == 'apple tree'
|
||||
|
||||
def test_page_reference(self):
|
||||
"""Page references like p.60 should pass."""
|
||||
# 'p' is a known abbreviation
|
||||
assert _clean_cell_text_lite('p.60') != ''
|
||||
|
||||
|
||||
class TestOcrCellCrop:
|
||||
"""Tests for _ocr_cell_crop() — isolated cell OCR."""
|
||||
|
||||
def test_empty_cell_pixel_density(self):
|
||||
"""Cells with very few dark pixels should return empty text."""
|
||||
# All white image → no text
|
||||
ocr_img = np.ones((400, 600), dtype=np.uint8) * 255
|
||||
row = RowGeometry(index=0, x=0, y=50, width=600, height=30,
|
||||
word_count=1, words=[{'text': 'a'}])
|
||||
col = PageRegion(type='column_en', x=50, y=0, width=200, height=400)
|
||||
|
||||
result = _ocr_cell_crop(
|
||||
0, 0, row, col, ocr_img, None, 600, 400,
|
||||
'tesseract', 'eng+deu', {'column_en': 'eng'},
|
||||
)
|
||||
assert result['text'] == ''
|
||||
assert result['cell_id'] == 'R00_C0'
|
||||
assert result['col_type'] == 'column_en'
|
||||
|
||||
def test_zero_width_cell(self):
|
||||
"""Zero-width cells should return empty."""
|
||||
ocr_img = np.ones((400, 600), dtype=np.uint8) * 255
|
||||
row = RowGeometry(index=0, x=0, y=50, width=600, height=30,
|
||||
word_count=1, words=[])
|
||||
col = PageRegion(type='column_en', x=50, y=0, width=0, height=400)
|
||||
|
||||
result = _ocr_cell_crop(
|
||||
0, 0, row, col, ocr_img, None, 600, 400,
|
||||
'tesseract', 'eng+deu', {},
|
||||
)
|
||||
assert result['text'] == ''
|
||||
|
||||
def test_bbox_calculation(self):
|
||||
"""Check bbox_px and bbox_pct are correct."""
|
||||
ocr_img = np.ones((1000, 2000), dtype=np.uint8) * 255
|
||||
row = RowGeometry(index=0, x=0, y=100, width=2000, height=50,
|
||||
word_count=1, words=[{'text': 'test'}])
|
||||
col = PageRegion(type='column_de', x=400, y=0, width=600, height=1000)
|
||||
|
||||
result = _ocr_cell_crop(
|
||||
0, 0, row, col, ocr_img, None, 2000, 1000,
|
||||
'tesseract', 'eng+deu', {'column_de': 'deu'},
|
||||
)
|
||||
assert result['bbox_px'] == {'x': 400, 'y': 100, 'w': 600, 'h': 50}
|
||||
assert result['bbox_pct']['x'] == 20.0 # 400/2000*100
|
||||
assert result['bbox_pct']['y'] == 10.0 # 100/1000*100
|
||||
|
||||
|
||||
class TestDetectDocumentType:
|
||||
"""Tests for detect_document_type() — image-based classification."""
|
||||
|
||||
def test_empty_image(self):
|
||||
"""Empty image should default to full_text."""
|
||||
empty = np.array([], dtype=np.uint8).reshape(0, 0)
|
||||
result = detect_document_type(empty, empty)
|
||||
assert result.doc_type == 'full_text'
|
||||
assert result.pipeline == 'full_page'
|
||||
|
||||
def test_table_image_detected(self):
|
||||
"""Image with clear column gaps and row gaps → table."""
|
||||
# Create 600x400 binary image with 3 columns separated by white gaps
|
||||
img = np.ones((400, 600), dtype=np.uint8) * 255
|
||||
# Column 1: x=20..170
|
||||
for y in range(30, 370, 20):
|
||||
img[y:y+10, 20:170] = 0
|
||||
# Gap: x=170..210 (white)
|
||||
# Column 2: x=210..370
|
||||
for y in range(30, 370, 20):
|
||||
img[y:y+10, 210:370] = 0
|
||||
# Gap: x=370..410 (white)
|
||||
# Column 3: x=410..580
|
||||
for y in range(30, 370, 20):
|
||||
img[y:y+10, 410:580] = 0
|
||||
|
||||
bgr = np.stack([img, img, img], axis=-1)
|
||||
result = detect_document_type(img, bgr)
|
||||
assert result.doc_type in ('vocab_table', 'generic_table')
|
||||
assert result.pipeline == 'cell_first'
|
||||
assert result.confidence >= 0.5
|
||||
|
||||
def test_fulltext_image_detected(self):
|
||||
"""Uniform text without column gaps → full_text."""
|
||||
img = np.ones((400, 600), dtype=np.uint8) * 255
|
||||
# Uniform text lines across full width (no column gaps)
|
||||
for y in range(30, 370, 15):
|
||||
img[y:y+8, 30:570] = 0
|
||||
|
||||
bgr = np.stack([img, img, img], axis=-1)
|
||||
result = detect_document_type(img, bgr)
|
||||
assert result.doc_type == 'full_text'
|
||||
assert result.pipeline == 'full_page'
|
||||
assert 'columns' in result.skip_steps
|
||||
assert 'rows' in result.skip_steps
|
||||
|
||||
def test_result_has_features(self):
|
||||
"""Result should contain debug features."""
|
||||
img = np.ones((200, 300), dtype=np.uint8) * 255
|
||||
bgr = np.stack([img, img, img], axis=-1)
|
||||
result = detect_document_type(img, bgr)
|
||||
assert 'vertical_gaps' in result.features
|
||||
assert 'row_gaps' in result.features
|
||||
assert 'density_mean' in result.features
|
||||
assert 'density_std' in result.features
|
||||
|
||||
def test_document_type_result_dataclass(self):
|
||||
"""DocumentTypeResult dataclass should initialize correctly."""
|
||||
r = DocumentTypeResult(
|
||||
doc_type='vocab_table',
|
||||
confidence=0.9,
|
||||
pipeline='cell_first',
|
||||
)
|
||||
assert r.doc_type == 'vocab_table'
|
||||
assert r.skip_steps == []
|
||||
assert r.features == {}
|
||||
|
||||
|
||||
# =============================================
|
||||
# RUN TESTS
|
||||
# =============================================
|
||||
|
||||
Reference in New Issue
Block a user