feat: cell-first OCR + document type detection + dynamic pipeline steps
Cell-first OCR (v2): each cell is cropped and OCR'd in isolation, which eliminates bleeding from neighbouring cells (e.g. stray "to" or "ps" in marker columns). Tesseract calls run in parallel via ThreadPoolExecutor.

Document type detection: pages are classified as vocab_table, full_text, or generic_table from projection profiles (<2s, no OCR needed). The frontend dynamically skips the columns/rows steps for full-text pages.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -11,7 +11,7 @@ import { StepWordRecognition } from '@/components/ocr-pipeline/StepWordRecogniti
|
|||||||
import { StepLlmReview } from '@/components/ocr-pipeline/StepLlmReview'
|
import { StepLlmReview } from '@/components/ocr-pipeline/StepLlmReview'
|
||||||
import { StepReconstruction } from '@/components/ocr-pipeline/StepReconstruction'
|
import { StepReconstruction } from '@/components/ocr-pipeline/StepReconstruction'
|
||||||
import { StepGroundTruth } from '@/components/ocr-pipeline/StepGroundTruth'
|
import { StepGroundTruth } from '@/components/ocr-pipeline/StepGroundTruth'
|
||||||
import { PIPELINE_STEPS, type PipelineStep, type SessionListItem } from './types'
|
import { PIPELINE_STEPS, type PipelineStep, type SessionListItem, type DocumentTypeResult } from './types'
|
||||||
|
|
||||||
const KLAUSUR_API = '/klausur-api'
|
const KLAUSUR_API = '/klausur-api'
|
||||||
|
|
||||||
@@ -23,6 +23,7 @@ export default function OcrPipelinePage() {
|
|||||||
const [loadingSessions, setLoadingSessions] = useState(true)
|
const [loadingSessions, setLoadingSessions] = useState(true)
|
||||||
const [editingName, setEditingName] = useState<string | null>(null)
|
const [editingName, setEditingName] = useState<string | null>(null)
|
||||||
const [editNameValue, setEditNameValue] = useState('')
|
const [editNameValue, setEditNameValue] = useState('')
|
||||||
|
const [docTypeResult, setDocTypeResult] = useState<DocumentTypeResult | null>(null)
|
||||||
const [steps, setSteps] = useState<PipelineStep[]>(
|
const [steps, setSteps] = useState<PipelineStep[]>(
|
||||||
PIPELINE_STEPS.map((s, i) => ({
|
PIPELINE_STEPS.map((s, i) => ({
|
||||||
...s,
|
...s,
|
||||||
@@ -59,16 +60,23 @@ export default function OcrPipelinePage() {
|
|||||||
setSessionId(sid)
|
setSessionId(sid)
|
||||||
setSessionName(data.name || data.filename || '')
|
setSessionName(data.name || data.filename || '')
|
||||||
|
|
||||||
|
// Restore doc type result if available
|
||||||
|
const savedDocType: DocumentTypeResult | null = data.doc_type_result || null
|
||||||
|
setDocTypeResult(savedDocType)
|
||||||
|
|
||||||
// Determine which step to jump to based on current_step
|
// Determine which step to jump to based on current_step
|
||||||
const dbStep = data.current_step || 1
|
const dbStep = data.current_step || 1
|
||||||
// Steps: 1=deskew, 2=dewarp, 3=columns, ...
|
// Steps: 1=deskew, 2=dewarp, 3=columns, ...
|
||||||
// UI steps are 0-indexed: 0=deskew, 1=dewarp, 2=columns, ...
|
// UI steps are 0-indexed: 0=deskew, 1=dewarp, 2=columns, ...
|
||||||
const uiStep = Math.max(0, dbStep - 1)
|
const uiStep = Math.max(0, dbStep - 1)
|
||||||
|
const skipSteps = savedDocType?.skip_steps || []
|
||||||
|
|
||||||
setSteps(
|
setSteps(
|
||||||
PIPELINE_STEPS.map((s, i) => ({
|
PIPELINE_STEPS.map((s, i) => ({
|
||||||
...s,
|
...s,
|
||||||
status: i < uiStep ? 'completed' : i === uiStep ? 'active' : 'pending',
|
status: skipSteps.includes(s.id)
|
||||||
|
? 'skipped'
|
||||||
|
: i < uiStep ? 'completed' : i === uiStep ? 'active' : 'pending',
|
||||||
})),
|
})),
|
||||||
)
|
)
|
||||||
setCurrentStep(uiStep)
|
setCurrentStep(uiStep)
|
||||||
@@ -84,6 +92,7 @@ export default function OcrPipelinePage() {
|
|||||||
if (sessionId === sid) {
|
if (sessionId === sid) {
|
||||||
setSessionId(null)
|
setSessionId(null)
|
||||||
setCurrentStep(0)
|
setCurrentStep(0)
|
||||||
|
setDocTypeResult(null)
|
||||||
setSteps(PIPELINE_STEPS.map((s, i) => ({ ...s, status: i === 0 ? 'active' : 'pending' })))
|
setSteps(PIPELINE_STEPS.map((s, i) => ({ ...s, status: i === 0 ? 'active' : 'pending' })))
|
||||||
}
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
@@ -123,16 +132,28 @@ export default function OcrPipelinePage() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
 * Advance the wizard to the next step that is not on the current skip list.
 *
 * The skip list is read from the `docTypeResult` state captured by this render.
 * The current step is marked completed, the target step active, and every
 * skip-listed step lying between the two is marked skipped. Does nothing when
 * the wizard is already on its last step.
 */
const handleNext = () => {
  const lastIndex = steps.length - 1
  if (currentStep >= lastIndex) return

  const skippedIds = docTypeResult?.skip_steps || []
  const isSkippedAt = (index: number) => skippedIds.includes(PIPELINE_STEPS[index]?.id)

  // Walk forward past every skip-listed step.
  let candidate = currentStep + 1
  while (candidate < steps.length && isSkippedAt(candidate)) {
    candidate += 1
  }
  // Ran off the end: fall back to the final step.
  const target = candidate >= steps.length ? lastIndex : candidate

  setSteps((prev) =>
    prev.map((step, index) => {
      if (index === currentStep) return { ...step, status: 'completed' }
      if (index === target) return { ...step, status: 'active' }
      const liesBetween = index > currentStep && index < target
      if (liesBetween && isSkippedAt(index)) {
        return { ...step, status: 'skipped' }
      }
      return step
    }),
  )
  setCurrentStep(target)
}
|
||||||
|
|
||||||
const handleDeskewComplete = (sid: string) => {
|
const handleDeskewComplete = (sid: string) => {
|
||||||
@@ -142,10 +163,69 @@ export default function OcrPipelinePage() {
|
|||||||
handleNext()
|
handleNext()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Completion handler for the dewarp step: auto-detect the document type, then
 * advance to the next applicable step.
 *
 * Detection is best-effort. If there is no session, the request fails, or the
 * server answers with a non-OK status, the wizard simply advances via
 * `handleNext()`.
 *
 * On success the advance is computed here from the *freshly returned*
 * `skip_steps` instead of delegating to `handleNext()`: `setDocTypeResult`
 * does not update the `docTypeResult` value captured by this render, so
 * `handleNext()` would still see the old (usually empty) skip list and would
 * re-activate a step that was just marked as skipped.
 */
const handleDewarpNext = async () => {
  // Skip list reported by the backend; stays null when detection did not succeed.
  let detectedSkips: string[] | null = null

  if (sessionId) {
    try {
      const res = await fetch(
        `${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/detect-type`,
        { method: 'POST' },
      )
      if (res.ok) {
        const data: DocumentTypeResult = await res.json()
        setDocTypeResult(data)
        detectedSkips = data.skip_steps || []
      } else {
        console.warn('Doc type detection returned HTTP status', res.status)
      }
    } catch (e) {
      console.error('Doc type detection failed:', e)
      // Not critical — continue without it
    }
  }

  if (detectedSkips === null) {
    handleNext()
    return
  }

  const skipIds = detectedSkips
  const canAdvance = currentStep < steps.length - 1

  // Find the next step that the detected document type does not skip.
  let candidate = currentStep
  if (canAdvance) {
    candidate = currentStep + 1
    while (candidate < steps.length && skipIds.includes(PIPELINE_STEPS[candidate]?.id)) {
      candidate += 1
    }
    if (candidate >= steps.length) candidate = steps.length - 1
  }
  const nextStep = candidate

  setSteps((prev) =>
    prev.map((s, i) => {
      if (canAdvance && i === currentStep) return { ...s, status: 'completed' }
      if (canAdvance && i === nextStep) return { ...s, status: 'active' }
      if (skipIds.includes(s.id)) return { ...s, status: 'skipped' }
      // A step skipped by an earlier detection that is no longer skip-listed becomes available again.
      if (s.status === 'skipped') return { ...s, status: 'pending' }
      return s
    }),
  )
  if (canAdvance) setCurrentStep(nextStep)
}
|
||||||
|
|
||||||
|
/**
 * Apply a manual override of the detected document type.
 *
 * Rebuilds `skip_steps` and `pipeline` for the chosen type ('full_text' skips
 * the 'columns' and 'rows' steps and uses the 'full_page' pipeline; the table
 * types skip nothing and use 'cell_first'), then re-derives the step statuses:
 * skip-listed steps become 'skipped', previously skipped steps that are no
 * longer skip-listed become 'pending'.
 *
 * If the step currently on screen is itself skip-listed by the new type, the
 * wizard moves forward to the next step that is not skipped, so the UI never
 * keeps rendering a step that the stepper shows as skipped.
 *
 * Does nothing when no detection result exists yet.
 *
 * @param newDocType - Document type selected by the user.
 */
const handleDocTypeChange = (newDocType: DocumentTypeResult['doc_type']) => {
  if (!docTypeResult) return

  const isFullText = newDocType === 'full_text'
  // vocab_table and generic_table: no skips
  const skipSteps: string[] = isFullText ? ['columns', 'rows'] : []

  const updated: DocumentTypeResult = {
    ...docTypeResult,
    doc_type: newDocType,
    skip_steps: skipSteps,
    pipeline: isFullText ? 'full_page' : 'cell_first',
  }
  setDocTypeResult(updated)

  // If the visible step is now skipped, look for the next step that is not.
  let focus = currentStep
  while (focus < steps.length && skipSteps.includes(PIPELINE_STEPS[focus]?.id)) {
    focus += 1
  }
  // No later step is available: stay where we are.
  if (focus >= steps.length) focus = currentStep
  const moved = focus !== currentStep

  setSteps((prev) =>
    prev.map((s, i) => {
      if (skipSteps.includes(s.id)) return { ...s, status: 'skipped' as const }
      if (moved && i === focus) return { ...s, status: 'active' as const }
      if (s.status === 'skipped') return { ...s, status: 'pending' as const }
      return s
    }),
  )
  if (moved) setCurrentStep(focus)
}
|
||||||
|
|
||||||
/**
 * Reset the page to a blank session: no session id or name, no document type
 * result, and the wizard back on its first step with every later step pending.
 */
const handleNewSession = () => {
  setSessionId(null)
  setSessionName('')
  setCurrentStep(0)
  setDocTypeResult(null)
  setSteps(
    PIPELINE_STEPS.map((step, index) => {
      const status = index === 0 ? ('active' as const) : ('pending' as const)
      return { ...step, status }
    }),
  )
}
|
||||||
|
|
||||||
@@ -188,7 +268,7 @@ export default function OcrPipelinePage() {
|
|||||||
case 0:
|
case 0:
|
||||||
return <StepDeskew sessionId={sessionId} onNext={handleDeskewComplete} />
|
return <StepDeskew sessionId={sessionId} onNext={handleDeskewComplete} />
|
||||||
case 1:
|
case 1:
|
||||||
return <StepDewarp sessionId={sessionId} onNext={handleNext} />
|
return <StepDewarp sessionId={sessionId} onNext={handleDewarpNext} />
|
||||||
case 2:
|
case 2:
|
||||||
return <StepColumnDetection sessionId={sessionId} onNext={handleNext} />
|
return <StepColumnDetection sessionId={sessionId} onNext={handleNext} />
|
||||||
case 3:
|
case 3:
|
||||||
@@ -314,7 +394,14 @@ export default function OcrPipelinePage() {
|
|||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
<PipelineStepper steps={steps} currentStep={currentStep} onStepClick={handleStepClick} onReprocess={sessionId ? reprocessFromStep : undefined} />
|
<PipelineStepper
|
||||||
|
steps={steps}
|
||||||
|
currentStep={currentStep}
|
||||||
|
onStepClick={handleStepClick}
|
||||||
|
onReprocess={sessionId ? reprocessFromStep : undefined}
|
||||||
|
docTypeResult={docTypeResult}
|
||||||
|
onDocTypeChange={handleDocTypeChange}
|
||||||
|
/>
|
||||||
|
|
||||||
<div className="min-h-[400px]">{renderStep()}</div>
|
<div className="min-h-[400px]">{renderStep()}</div>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
/**
 * Display state of a single pipeline step.
 *
 * 'skipped' marks a step whose id is listed in the document type result's
 * `skip_steps`; the stepper renders such steps struck through and not clickable.
 */
export type PipelineStepStatus = 'pending' | 'active' | 'completed' | 'failed' | 'skipped'
|
||||||
|
|
||||||
export interface PipelineStep {
|
export interface PipelineStep {
|
||||||
id: string
|
id: string
|
||||||
@@ -17,6 +17,15 @@ export interface SessionListItem {
|
|||||||
updated_at?: string
|
updated_at?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Result of the automatic document type detection for a session. */
export interface DocumentTypeResult {
  /** Detected (or manually overridden) page classification. */
  doc_type: 'vocab_table' | 'full_text' | 'generic_table'
  /** Detection confidence as a fraction; the stepper displays it as `confidence * 100` percent. */
  confidence: number
  /** OCR pipeline variant to use for this document type. */
  pipeline: 'cell_first' | 'full_page'
  /** Ids of pipeline steps that do not apply to this document type, e.g. ['columns', 'rows']. */
  skip_steps: string[]
  /** Raw detector measurements — presumably debug information only; confirm before relying on keys. */
  features?: Record<string, unknown>
  /** NOTE(review): presumably the detection runtime in seconds as reported by the backend — confirm. */
  duration_seconds?: number
}
|
||||||
|
|
||||||
export interface SessionInfo {
|
export interface SessionInfo {
|
||||||
session_id: string
|
session_id: string
|
||||||
filename: string
|
filename: string
|
||||||
@@ -30,6 +39,7 @@ export interface SessionInfo {
|
|||||||
column_result?: ColumnResult
|
column_result?: ColumnResult
|
||||||
row_result?: RowResult
|
row_result?: RowResult
|
||||||
word_result?: GridResult
|
word_result?: GridResult
|
||||||
|
doc_type_result?: DocumentTypeResult
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface DeskewResult {
|
export interface DeskewResult {
|
||||||
|
|||||||
@@ -1,29 +1,48 @@
|
|||||||
'use client'
|
'use client'
|
||||||
|
|
||||||
import { PipelineStep } from '@/app/(admin)/ai/ocr-pipeline/types'
|
import { PipelineStep, DocumentTypeResult } from '@/app/(admin)/ai/ocr-pipeline/types'
|
||||||
|
|
||||||
|
/**
 * German display labels for the document types, keyed by `doc_type`.
 * Used for the read-only badge; callers fall back to the raw `doc_type`
 * string when a key is missing.
 */
const DOC_TYPE_LABELS: Record<string, string> = {
  vocab_table: 'Vokabeltabelle',
  full_text: 'Volltext',
  generic_table: 'Tabelle',
}
|
||||||
|
|
||||||
/** Props of the `PipelineStepper` component. */
interface PipelineStepperProps {
  /** Steps to render, in display order, each carrying its current status. */
  steps: PipelineStep[]
  /** Zero-based index of the step that is currently shown. */
  currentStep: number
  /** Invoked with the step index when a clickable step is clicked. */
  onStepClick: (index: number) => void
  /** Optional reprocess callback; the page passes it only while a session is open. */
  onReprocess?: (index: number) => void
  /** When set, a document type badge with the confidence is rendered below the steps. */
  docTypeResult?: DocumentTypeResult | null
  /** When provided, the badge becomes a dropdown and this is called with the newly selected type. */
  onDocTypeChange?: (docType: DocumentTypeResult['doc_type']) => void
}
|
||||||
|
|
||||||
export function PipelineStepper({ steps, currentStep, onStepClick, onReprocess }: PipelineStepperProps) {
|
export function PipelineStepper({
|
||||||
|
steps,
|
||||||
|
currentStep,
|
||||||
|
onStepClick,
|
||||||
|
onReprocess,
|
||||||
|
docTypeResult,
|
||||||
|
onDocTypeChange,
|
||||||
|
}: PipelineStepperProps) {
|
||||||
return (
|
return (
|
||||||
|
<div className="space-y-2">
|
||||||
<div className="flex items-center justify-between px-4 py-3 bg-white dark:bg-gray-800 rounded-lg border border-gray-200 dark:border-gray-700">
|
<div className="flex items-center justify-between px-4 py-3 bg-white dark:bg-gray-800 rounded-lg border border-gray-200 dark:border-gray-700">
|
||||||
{steps.map((step, index) => {
|
{steps.map((step, index) => {
|
||||||
const isActive = index === currentStep
|
const isActive = index === currentStep
|
||||||
const isCompleted = step.status === 'completed'
|
const isCompleted = step.status === 'completed'
|
||||||
const isFailed = step.status === 'failed'
|
const isFailed = step.status === 'failed'
|
||||||
const isClickable = index <= currentStep || isCompleted
|
const isSkipped = step.status === 'skipped'
|
||||||
|
const isClickable = (index <= currentStep || isCompleted) && !isSkipped
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div key={step.id} className="flex items-center">
|
<div key={step.id} className="flex items-center">
|
||||||
{index > 0 && (
|
{index > 0 && (
|
||||||
<div
|
<div
|
||||||
className={`h-0.5 w-8 mx-1 ${
|
className={`h-0.5 w-8 mx-1 ${
|
||||||
index <= currentStep ? 'bg-teal-400' : 'bg-gray-300 dark:bg-gray-600'
|
isSkipped
|
||||||
|
? 'bg-gray-200 dark:bg-gray-700 border-t border-dashed border-gray-400'
|
||||||
|
: index <= currentStep ? 'bg-teal-400' : 'bg-gray-300 dark:bg-gray-600'
|
||||||
}`}
|
}`}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
@@ -32,7 +51,9 @@ export function PipelineStepper({ steps, currentStep, onStepClick, onReprocess }
|
|||||||
onClick={() => isClickable && onStepClick(index)}
|
onClick={() => isClickable && onStepClick(index)}
|
||||||
disabled={!isClickable}
|
disabled={!isClickable}
|
||||||
className={`flex items-center gap-1.5 px-3 py-1.5 rounded-full text-sm font-medium transition-all ${
|
className={`flex items-center gap-1.5 px-3 py-1.5 rounded-full text-sm font-medium transition-all ${
|
||||||
isActive
|
isSkipped
|
||||||
|
? 'bg-gray-100 text-gray-400 dark:bg-gray-800 dark:text-gray-600 line-through'
|
||||||
|
: isActive
|
||||||
? 'bg-teal-100 text-teal-700 dark:bg-teal-900/40 dark:text-teal-300 ring-2 ring-teal-400'
|
? 'bg-teal-100 text-teal-700 dark:bg-teal-900/40 dark:text-teal-300 ring-2 ring-teal-400'
|
||||||
: isCompleted
|
: isCompleted
|
||||||
? 'bg-green-100 text-green-700 dark:bg-green-900/40 dark:text-green-300'
|
? 'bg-green-100 text-green-700 dark:bg-green-900/40 dark:text-green-300'
|
||||||
@@ -42,7 +63,7 @@ export function PipelineStepper({ steps, currentStep, onStepClick, onReprocess }
|
|||||||
} ${isClickable ? 'cursor-pointer hover:opacity-80' : 'cursor-default'}`}
|
} ${isClickable ? 'cursor-pointer hover:opacity-80' : 'cursor-default'}`}
|
||||||
>
|
>
|
||||||
<span className="text-base">
|
<span className="text-base">
|
||||||
{isCompleted ? '\u2713' : isFailed ? '\u2717' : step.icon}
|
{isSkipped ? '-' : isCompleted ? '\u2713' : isFailed ? '\u2717' : step.icon}
|
||||||
</span>
|
</span>
|
||||||
<span className="hidden sm:inline">{step.name}</span>
|
<span className="hidden sm:inline">{step.name}</span>
|
||||||
<span className="sm:hidden">{index + 1}</span>
|
<span className="sm:hidden">{index + 1}</span>
|
||||||
@@ -62,5 +83,33 @@ export function PipelineStepper({ steps, currentStep, onStepClick, onReprocess }
|
|||||||
)
|
)
|
||||||
})}
|
})}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
{/* Document type badge */}
|
||||||
|
{docTypeResult && (
|
||||||
|
<div className="flex items-center gap-2 px-4 py-2 bg-blue-50 dark:bg-blue-900/20 rounded-lg border border-blue-200 dark:border-blue-800 text-sm">
|
||||||
|
<span className="text-blue-600 dark:text-blue-400 font-medium">
|
||||||
|
Dokumenttyp:
|
||||||
|
</span>
|
||||||
|
{onDocTypeChange ? (
|
||||||
|
<select
|
||||||
|
value={docTypeResult.doc_type}
|
||||||
|
onChange={(e) => onDocTypeChange(e.target.value as DocumentTypeResult['doc_type'])}
|
||||||
|
className="bg-white dark:bg-gray-800 border border-blue-300 dark:border-blue-700 rounded px-2 py-0.5 text-sm text-blue-700 dark:text-blue-300"
|
||||||
|
>
|
||||||
|
<option value="vocab_table">Vokabeltabelle</option>
|
||||||
|
<option value="generic_table">Tabelle (generisch)</option>
|
||||||
|
<option value="full_text">Volltext</option>
|
||||||
|
</select>
|
||||||
|
) : (
|
||||||
|
<span className="text-blue-700 dark:text-blue-300">
|
||||||
|
{DOC_TYPE_LABELS[docTypeResult.doc_type] || docTypeResult.doc_type}
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
<span className="text-blue-400 dark:text-blue-500 text-xs">
|
||||||
|
({Math.round(docTypeResult.confidence * 100)}% Konfidenz)
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Dict, Generator, List, Optional, Tuple
|
from typing import Any, Dict, Generator, List, Optional, Tuple
|
||||||
|
|
||||||
@@ -159,6 +160,16 @@ class PipelineResult:
|
|||||||
image_height: int = 0
|
image_height: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class DocumentTypeResult:
    """Result of automatic document type detection.

    Returned by ``detect_document_type()``. ``skip_steps`` and ``features``
    default to fresh empty containers per instance.
    """
    # Page classification: 'vocab_table' | 'full_text' | 'generic_table'
    doc_type: str
    # Heuristic confidence of the classification, 0.0-1.0
    confidence: float
    # OCR pipeline variant to use: 'cell_first' | 'full_page'
    pipeline: str
    # Pipeline step ids that do not apply to this document type, e.g. ['columns', 'rows']
    skip_steps: List[str] = field(default_factory=list)
    # Measurements the decision was based on (debug info)
    features: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Stage 1: High-Resolution PDF Rendering
|
# Stage 1: High-Resolution PDF Rendering
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -966,6 +977,164 @@ def dewarp_image_manual(img: np.ndarray, shear_degrees: float) -> np.ndarray:
|
|||||||
return _apply_shear(img, -shear_degrees)
|
return _apply_shear(img, -shear_degrees)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Document Type Detection
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def _smooth_profile(profile: np.ndarray, kernel_size: int) -> np.ndarray:
    """Box-filter a 1-D projection profile; an even kernel size is bumped to the next odd value."""
    kernel_size |= 1
    return np.convolve(profile, np.ones(kernel_size) / kernel_size, mode='same')


def _find_profile_gaps(
    profile: np.ndarray,
    length: int,
    threshold: float,
    min_width: int,
) -> List[Tuple[int, int, int]]:
    """Return (start, end, width) for each run of ``profile`` values below ``threshold``.

    Only runs at least ``min_width`` long are kept. A run that is still open
    when position ``length`` is reached is not reported.
    """
    gaps: List[Tuple[int, int, int]] = []
    run_start: Optional[int] = None
    for pos in range(length):
        if profile[pos] < threshold:
            if run_start is None:
                run_start = pos
            continue
        if run_start is not None:
            run_width = pos - run_start
            if run_width >= min_width:
                gaps.append((run_start, pos, run_width))
            run_start = None
    return gaps


def detect_document_type(ocr_img: np.ndarray, img_bgr: np.ndarray) -> DocumentTypeResult:
    """Classify a page as vocab table, generic table, or full text.

    The decision is based on projection profiles and a coarse text density
    grid; no OCR is run.

    Args:
        ocr_img: Binarized grayscale image; pixels below 128 count as text.
        img_bgr: BGR color image. Not used by the current implementation.

    Returns:
        DocumentTypeResult with doc_type, confidence, pipeline, skip_steps and
        the measured features. An empty or missing ``ocr_img`` yields
        'full_text' with confidence 0.5.
    """
    if ocr_img is None or ocr_img.size == 0:
        return DocumentTypeResult(
            doc_type='full_text', confidence=0.5, pipeline='full_page',
            skip_steps=['columns', 'rows'],
            features={'error': 'empty image'},
        )

    h, w = ocr_img.shape[:2]
    dark = ocr_img < 128  # True where a pixel counts as text

    # --- 1. Vertical projection profile -> column gaps ---
    # Count dark pixels per image column; valleys in the smoothed profile are gaps.
    vert_smooth = _smooth_profile(np.sum(dark, axis=0).astype(float), max(3, w // 100))
    # A gap is below 5% of the peak density and at least ~1% of the image width (min. 5 px).
    vert_gaps = _find_profile_gaps(
        vert_smooth,
        w,
        threshold=max(vert_smooth.max(), 1) * 0.05,
        min_width=max(5, w // 100),
    )
    gap_count = len(vert_gaps)

    # Gaps that start or end within 10% of the left/right edge are page margins, not column gaps.
    margin = w * 0.10
    internal_gap_count = sum(
        1 for start, end, _ in vert_gaps if start > margin and end < w - margin
    )

    # --- 2. Horizontal projection profile -> row gaps ---
    horiz_smooth = _smooth_profile(np.sum(dark, axis=1).astype(float), max(3, h // 200))
    row_gap_count = len(_find_profile_gaps(
        horiz_smooth,
        h,
        threshold=max(horiz_smooth.max(), 1) * 0.05,
        min_width=max(3, h // 200),
    ))

    # --- 3. Text density distribution (4x4 grid) ---
    grid_rows, grid_cols = 4, 4
    cell_h, cell_w = h // grid_rows, w // grid_cols
    grid_cells = (
        dark[gr * cell_h:(gr + 1) * cell_h, gc * cell_w:(gc + 1) * cell_w]
        for gr in range(grid_rows)
        for gc in range(grid_cols)
    )
    densities = [
        float(np.count_nonzero(cell)) / cell.size
        for cell in grid_cells
        if cell.size > 0
    ]
    density_std = float(np.std(densities)) if densities else 0
    density_mean = float(np.mean(densities)) if densities else 0

    features = {
        'vertical_gaps': gap_count,
        'internal_vertical_gaps': internal_gap_count,
        'vertical_gap_details': list(vert_gaps[:10]),
        'row_gaps': row_gap_count,
        'density_mean': round(density_mean, 4),
        'density_std': round(density_std, 4),
        'image_size': (w, h),
    }

    # --- 4. Decision tree (uses internal gaps only, so margins never count as columns) ---
    skip_steps: List[str] = []
    pipeline = 'cell_first'
    if internal_gap_count >= 2 and row_gap_count >= 5:
        # Several internal column gaps plus many row gaps -> vocabulary table
        doc_type = 'vocab_table'
        confidence = min(0.95, 0.7 + internal_gap_count * 0.05 + row_gap_count * 0.005)
    elif internal_gap_count >= 1 and row_gap_count >= 3:
        # Some internal structure -> likely a table
        doc_type = 'generic_table'
        confidence = min(0.85, 0.5 + internal_gap_count * 0.1 + row_gap_count * 0.01)
    elif internal_gap_count == 0:
        # No internal column gaps -> full text, regardless of density.
        # NOTE(review): this expression only ranges over roughly 0.94-0.95, so
        # density_std has almost no influence on the reported confidence.
        doc_type = 'full_text'
        pipeline = 'full_page'
        skip_steps = ['columns', 'rows']
        confidence = min(0.95, 0.8 + (1 - min(density_std, 0.1)) * 0.15)
    else:
        # Ambiguous -> default to vocab_table (most common use case)
        doc_type = 'vocab_table'
        confidence = 0.5

    return DocumentTypeResult(
        doc_type=doc_type,
        confidence=round(confidence, 2),
        pipeline=pipeline,
        skip_steps=skip_steps,
        features=features,
    )
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Stage 4: Dual Image Preparation
|
# Stage 4: Dual Image Preparation
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -4481,8 +4650,395 @@ def _clean_cell_text(text: str) -> str:
|
|||||||
return ' '.join(tokens)
|
return ' '.join(tokens)
|
||||||
|
|
||||||
|
|
||||||
|
def _clean_cell_text_lite(text: str) -> str:
    """Reduced noise filter for text OCR'd from an isolated cell crop.

    Cell crops contain no neighbouring content, so no trailing-noise stripping
    is done here. The whitespace-stripped text is returned unless one of two
    checks rejects it, in which case the result is the empty string:

    1. ``_RE_REAL_WORD`` finds no match and the letters extracted with
       ``_RE_ALPHA`` (lower-cased, concatenated) are not a known abbreviation.
    2. ``_is_garbage_text`` flags the whole text as garbage.
    """
    candidate = text.strip()
    if not candidate:
        return ''

    # --- Check 1: no real word and not a known abbreviation ---
    has_real_word = _RE_REAL_WORD.search(candidate) is not None
    if not has_real_word:
        letters = ''.join(_RE_ALPHA.findall(candidate)).lower()
        if letters not in _KNOWN_ABBREVIATIONS:
            return ''

    # --- Check 2: the entire text is garbage ---
    return '' if _is_garbage_text(candidate) else candidate
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Narrow-column OCR helpers (Proposal B)
|
# Cell-First OCR (v2) — each cell cropped and OCR'd in isolation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _ocr_cell_crop(
    row_idx: int,
    col_idx: int,
    row: RowGeometry,
    col: PageRegion,
    ocr_img: np.ndarray,
    img_bgr: Optional[np.ndarray],
    img_w: int,
    img_h: int,
    engine_name: str,
    lang: str,
    lang_map: Dict[str, str],
) -> Dict[str, Any]:
    """OCR a single cell by cropping the exact column x row intersection.

    No padding beyond cell boundaries -> no neighbour bleeding.

    Args:
        row_idx: Row index, used for ``cell_id`` and ``row_index``.
        col_idx: Column index, used for ``cell_id`` and ``col_index``.
        row: Row geometry; supplies the cell's ``y`` and ``height``.
        col: Column region; supplies the cell's ``x``, ``width`` and ``type``.
        ocr_img: Image used for the emptiness check and for Tesseract. May be
            None, in which case the Tesseract path yields no words.
        img_bgr: Color image required by the TrOCR / LightOn / RapidOCR engines.
            When it is None those engines are not used and the Tesseract path
            runs instead.
        img_w: Image width in pixels (crop clamping and percentage bbox).
        img_h: Image height in pixels (crop clamping and percentage bbox).
        engine_name: 'trocr-printed', 'trocr-handwritten', 'lighton', 'rapid',
            or anything else for Tesseract.
        lang: Default OCR language string.
        lang_map: Per-column-type language override; falls back to ``lang``.

    Returns:
        Cell dict with ``cell_id``, ``row_index``, ``col_index``, ``col_type``,
        ``text``, ``confidence``, ``bbox_px``, ``bbox_pct`` and ``ocr_engine``.
        ``text`` is '' and ``confidence`` 0.0 for empty or rejected cells.
    """
    # Display bbox: exact column x row intersection (reported as-is, unclamped)
    disp_x = col.x
    disp_y = row.y
    disp_w = col.width
    disp_h = row.height

    # Crop boundaries clamped to the image. The far edge is clamped from the
    # cell's own right/bottom edge so that a cell starting at a negative offset
    # does not grow past its real boundary into the neighbouring cell.
    cx = max(0, disp_x)
    cy = max(0, disp_y)
    cw = min(disp_x + disp_w, img_w) - cx
    ch = min(disp_y + disp_h, img_h) - cy

    empty_cell = {
        'cell_id': f"R{row_idx:02d}_C{col_idx}",
        'row_index': row_idx,
        'col_index': col_idx,
        'col_type': col.type,
        'text': '',
        'confidence': 0.0,
        'bbox_px': {'x': disp_x, 'y': disp_y, 'w': disp_w, 'h': disp_h},
        'bbox_pct': {
            'x': round(disp_x / img_w * 100, 2) if img_w else 0,
            'y': round(disp_y / img_h * 100, 2) if img_h else 0,
            'w': round(disp_w / img_w * 100, 2) if img_w else 0,
            'h': round(disp_h / img_h * 100, 2) if img_h else 0,
        },
        'ocr_engine': 'cell_crop_v2',
    }

    if cw <= 0 or ch <= 0:
        return empty_cell

    # --- Pixel-density check: skip truly empty cells (< 0.5% pixels darker than 180) ---
    if ocr_img is not None:
        crop = ocr_img[cy:cy + ch, cx:cx + cw]
        if crop.size > 0:
            dark_ratio = float(np.count_nonzero(crop < 180)) / crop.size
            if dark_ratio < 0.005:
                return empty_cell

    # --- Run the selected engine on the crop ---
    cell_lang = lang_map.get(col.type, lang)
    psm = _select_psm_for_column(col.type, col.width, row.height)
    text = ''
    avg_conf = 0.0
    used_engine = 'cell_crop_v2'
    upscaled = None  # Tesseract input; kept so the PSM 7 fallback can reuse it

    if engine_name in ("trocr-printed", "trocr-handwritten") and img_bgr is not None:
        cell_region = PageRegion(type=col.type, x=cx, y=cy, width=cw, height=ch)
        words = ocr_region_trocr(img_bgr, cell_region,
                                 handwritten=(engine_name == "trocr-handwritten"))
    elif engine_name == "lighton" and img_bgr is not None:
        cell_region = PageRegion(type=col.type, x=cx, y=cy, width=cw, height=ch)
        words = ocr_region_lighton(img_bgr, cell_region)
    elif engine_name == "rapid" and img_bgr is not None:
        cell_region = PageRegion(type=col.type, x=cx, y=cy, width=cw, height=ch)
        words = ocr_region_rapid(img_bgr, cell_region)
    elif ocr_img is not None:
        # Tesseract: upscale tiny crops for better recognition
        upscaled = _ensure_minimum_crop_size(ocr_img[cy:cy + ch, cx:cx + cw])
        up_h, up_w = upscaled.shape[:2]
        tmp_region = PageRegion(type=col.type, x=0, y=0, width=up_w, height=up_h)
        words = ocr_region(upscaled, tmp_region, lang=cell_lang, psm=psm)
        # Remap word positions from crop space back to original image coordinates
        if words and (up_w != cw or up_h != ch):
            sx = cw / max(up_w, 1)
            sy = ch / max(up_h, 1)
            for word in words:
                word['left'] = int(word['left'] * sx) + cx
                word['top'] = int(word['top'] * sy) + cy
                word['width'] = int(word['width'] * sx)
                word['height'] = int(word['height'] * sy)
        elif words:
            for word in words:
                word['left'] += cx
                word['top'] += cy
    else:
        words = []

    # Filter low-confidence words
    _MIN_WORD_CONF = 30
    if words:
        words = [word for word in words if word.get('conf', 0) >= _MIN_WORD_CONF]

    if words:
        # Tolerance of at least the full cell height.
        y_tol = max(15, ch)
        text = _words_to_reading_order_text(words, y_tolerance_px=y_tol)
        avg_conf = round(sum(word['conf'] for word in words) / len(words), 1)

    # --- PSM 7 fallback for still-empty Tesseract cells ---
    # Skipped when the primary pass already used PSM 7: repeating the identical
    # call on the same crop cannot produce a different result.
    if (
        not text.strip()
        and engine_name == "tesseract"
        and ocr_img is not None
        and psm != 7
    ):
        if upscaled is None:
            upscaled = _ensure_minimum_crop_size(ocr_img[cy:cy + ch, cx:cx + cw])
        up_h, up_w = upscaled.shape[:2]
        tmp_region = PageRegion(type=col.type, x=0, y=0, width=up_w, height=up_h)
        psm7_words = ocr_region(upscaled, tmp_region, lang=cell_lang, psm=7)
        if psm7_words:
            psm7_words = [word for word in psm7_words if word.get('conf', 0) >= _MIN_WORD_CONF]
        if psm7_words:
            p7_text = _words_to_reading_order_text(psm7_words, y_tolerance_px=10)
            if p7_text.strip():
                text = p7_text
                avg_conf = round(
                    sum(word['conf'] for word in psm7_words) / len(psm7_words), 1
                )
                used_engine = 'cell_crop_v2_psm7'

    # --- Noise filter ---
    if text.strip():
        text = _clean_cell_text_lite(text)
        if not text:
            avg_conf = 0.0

    result = dict(empty_cell)
    result['text'] = text
    result['confidence'] = avg_conf
    result['ocr_engine'] = used_engine
    return result
|
||||||
|
|
||||||
|
|
||||||
|
def build_cell_grid_v2(
    ocr_img: np.ndarray,
    column_regions: List[PageRegion],
    row_geometries: List[RowGeometry],
    img_w: int,
    img_h: int,
    lang: str = "eng+deu",
    ocr_engine: str = "auto",
    img_bgr: Optional[np.ndarray] = None,
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Cell-First Grid: crop each cell in isolation, then OCR.

    Drop-in replacement for build_cell_grid() — same signature & return type.
    No full-page word assignment; each cell is OCR'd from its own crop.

    Args:
        ocr_img: Image handed to ``_ocr_cell_crop`` for every cell.
        column_regions: Detected page regions; header/footer/margin/ignored
            regions are dropped, the rest become grid columns (sorted by x).
        row_geometries: Detected rows; only rows with ``row_type == 'content'``,
            ``word_count > 0`` and not flagged by ``_is_artifact_row`` are used.
        img_w: Passed through to ``_ocr_cell_crop``.
        img_h: Passed through to ``_ocr_cell_crop``.
        lang: Default OCR language string, passed through to ``_ocr_cell_crop``.
        ocr_engine: ``"auto"``, ``"rapid"``, ``"tesseract"``, ``"trocr-printed"``,
            ``"trocr-handwritten"`` or ``"lighton"``. Unknown values fall back
            to Tesseract.
        img_bgr: Optional colour image; ``"auto"`` only selects RapidOCR when
            this is provided.

    Returns:
        ``(cells, columns_meta)``. ``cells`` is sorted by
        ``(row_index, col_index)`` and contains only rows in which at least one
        cell has non-blank text. ``columns_meta`` has one
        ``{'index', 'type', 'x', 'width'}`` dict per used column. Both lists
        are empty when no usable rows or columns remain after filtering.

    Note:
        A cell whose OCR task raises is logged (with traceback) and omitted
        from the result; the remaining cells are still returned.
    """
    # Resolve engine
    if ocr_engine in ("trocr-printed", "trocr-handwritten", "lighton"):
        engine_name = ocr_engine
    elif ocr_engine == "auto":
        # RapidOCR is only chosen automatically when the colour image exists.
        engine_name = "rapid" if (RAPIDOCR_AVAILABLE and img_bgr is not None) else "tesseract"
    elif ocr_engine == "rapid":
        # NOTE(review): an explicit "rapid" request does not check img_bgr —
        # confirm _ocr_cell_crop copes with img_bgr=None in that case.
        if RAPIDOCR_AVAILABLE:
            engine_name = "rapid"
        else:
            logger.warning("RapidOCR requested but not available, falling back to Tesseract")
            engine_name = "tesseract"
    else:
        engine_name = "tesseract"

    logger.info(f"build_cell_grid_v2: using OCR engine '{engine_name}'")

    # Filter to content rows only
    content_rows = [r for r in row_geometries if r.row_type == 'content']
    if not content_rows:
        logger.warning("build_cell_grid_v2: no content rows found")
        return [], []

    # Filter phantom rows (word_count=0) and artifact rows
    before = len(content_rows)
    content_rows = [r for r in content_rows if r.word_count > 0]
    skipped = before - len(content_rows)
    if skipped > 0:
        logger.info(f"build_cell_grid_v2: skipped {skipped} phantom rows (word_count=0)")
    if not content_rows:
        logger.warning("build_cell_grid_v2: no content rows with words found")
        return [], []

    before_art = len(content_rows)
    content_rows = [r for r in content_rows if not _is_artifact_row(r)]
    artifact_skipped = before_art - len(content_rows)
    if artifact_skipped > 0:
        logger.info(f"build_cell_grid_v2: skipped {artifact_skipped} artifact rows")
    if not content_rows:
        logger.warning("build_cell_grid_v2: no content rows after artifact filtering")
        return [], []

    # Filter columns
    _skip_types = {'column_ignore', 'header', 'footer', 'margin_top',
                   'margin_bottom', 'margin_left', 'margin_right'}
    relevant_cols = [c for c in column_regions if c.type not in _skip_types]
    if not relevant_cols:
        logger.warning("build_cell_grid_v2: no usable columns found")
        return [], []

    # Heal row gaps, bounded by the vertical extent of the usable columns
    _heal_row_gaps(
        content_rows,
        top_bound=min(c.y for c in relevant_cols),
        bottom_bound=max(c.y + c.height for c in relevant_cols),
    )

    # Column indices in the result follow left-to-right order.
    relevant_cols.sort(key=lambda c: c.x)

    columns_meta = [
        {'index': ci, 'type': c.type, 'x': c.x, 'width': c.width}
        for ci, c in enumerate(relevant_cols)
    ]

    # Column type -> OCR language, handed to _ocr_cell_crop together with the
    # default `lang`.
    lang_map = {
        'column_en': 'eng',
        'column_de': 'deu',
        'column_example': 'eng+deu',
    }

    # --- Parallel OCR with ThreadPoolExecutor ---
    # Tesseract is single-threaded per call, so we benefit from parallelism.
    # ~40 rows × 4 cols = 160 cells, ~50% empty (density skip) → ~80 OCR calls.
    cells: List[Dict[str, Any]] = []
    cell_tasks = [
        (row_idx, col_idx, row, col)
        for row_idx, row in enumerate(content_rows)
        for col_idx, col in enumerate(relevant_cols)
    ]

    max_workers = 4 if engine_name == "tesseract" else 2

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = {
            pool.submit(
                _ocr_cell_crop,
                ri, ci, row, col,
                ocr_img, img_bgr, img_w, img_h,
                engine_name, lang, lang_map,
            ): (ri, ci)
            for ri, ci, row, col in cell_tasks
        }

        for future in as_completed(futures):
            try:
                cells.append(future.result())
            except Exception as e:
                # Best effort: one broken cell must not abort the whole page.
                # exc_info keeps the traceback so the failure is diagnosable.
                ri, ci = futures[future]
                logger.error(
                    f"build_cell_grid_v2: cell R{ri:02d}_C{ci} failed: {e}",
                    exc_info=True,
                )

    # Sort cells by (row_index, col_index) since futures complete out of order
    cells.sort(key=lambda c: (c['row_index'], c['col_index']))

    # Remove all-empty rows
    rows_seen: set = {c['row_index'] for c in cells}
    rows_with_text: set = {c['row_index'] for c in cells if c['text'].strip()}
    cells = [c for c in cells if c['row_index'] in rows_with_text]
    # Count rows directly instead of dividing the removed-cell count by the
    # column count: that division is wrong when a failed cell left a row short.
    empty_rows_removed = len(rows_seen) - len(rows_with_text)
    if empty_rows_removed > 0:
        logger.info(f"build_cell_grid_v2: removed {empty_rows_removed} all-empty rows")

    logger.info(f"build_cell_grid_v2: {len(cells)} cells from "
                f"{len(content_rows)} rows × {len(relevant_cols)} columns, "
                f"engine={engine_name}")

    return cells, columns_meta
|
||||||
|
|
||||||
|
|
||||||
|
def build_cell_grid_v2_streaming(
    ocr_img: np.ndarray,
    column_regions: List[PageRegion],
    row_geometries: List[RowGeometry],
    img_w: int,
    img_h: int,
    lang: str = "eng+deu",
    ocr_engine: str = "auto",
    img_bgr: Optional[np.ndarray] = None,
) -> Generator[Tuple[Dict[str, Any], List[Dict[str, Any]], int], None, None]:
    """Generator flavour of build_cell_grid_v2: OCR one cell at a time.

    Cells are processed sequentially, row by row and (after sorting the usable
    columns by x) left to right; each one is yielded as soon as
    ``_ocr_cell_crop`` returns. Nothing is yielded when no usable rows or
    columns remain after filtering.

    Yields:
        (cell_dict, columns_meta, total_cells) — ``total_cells`` is
        ``rows × columns`` after filtering.
    """
    # Pick the OCR engine name that is handed to every cell.
    if ocr_engine in ("trocr-printed", "trocr-handwritten", "lighton"):
        engine_name = ocr_engine
    elif ocr_engine == "auto":
        rapid_ok = RAPIDOCR_AVAILABLE and img_bgr is not None
        engine_name = "rapid" if rapid_ok else "tesseract"
    elif ocr_engine == "rapid":
        if RAPIDOCR_AVAILABLE:
            engine_name = "rapid"
        else:
            logger.warning("RapidOCR requested but not available, falling back to Tesseract")
            engine_name = "tesseract"
    else:
        engine_name = "tesseract"

    # Content rows that actually carry words.
    rows = [
        geom for geom in row_geometries
        if geom.row_type == 'content' and geom.word_count > 0
    ]
    if not rows:
        return

    # Columns that take part in the grid (no margins / header / footer / ignored).
    excluded_types = {'column_ignore', 'header', 'footer', 'margin_top',
                      'margin_bottom', 'margin_left', 'margin_right'}
    cols = [region for region in column_regions if region.type not in excluded_types]
    if not cols:
        return

    # Drop artifact rows only once we know there are columns to work with.
    rows = [geom for geom in rows if not _is_artifact_row(geom)]
    if not rows:
        return

    upper = min(region.y for region in cols)
    lower = max(region.y + region.height for region in cols)
    _heal_row_gaps(rows, top_bound=upper, bottom_bound=lower)

    cols.sort(key=lambda region: region.x)

    columns_meta = []
    for position, region in enumerate(cols):
        columns_meta.append(
            {'index': position, 'type': region.type, 'x': region.x, 'width': region.width}
        )

    lang_map = {
        'column_en': 'eng',
        'column_de': 'deu',
        'column_example': 'eng+deu',
    }

    total_cells = len(rows) * len(cols)

    for r_pos, geom in enumerate(rows):
        for c_pos, region in enumerate(cols):
            yield (
                _ocr_cell_crop(
                    r_pos, c_pos, geom, region,
                    ocr_img, img_bgr, img_w, img_h,
                    engine_name, lang, lang_map,
                ),
                columns_meta,
                total_cells,
            )
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Narrow-column OCR helpers (Proposal B) — DEPRECATED (kept for legacy build_cell_grid)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def _compute_cell_padding(col_width: int, img_w: int) -> int:
|
def _compute_cell_padding(col_width: int, img_w: int) -> int:
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
from cv_vocab_pipeline import (
|
from cv_vocab_pipeline import (
|
||||||
OLLAMA_REVIEW_MODEL,
|
OLLAMA_REVIEW_MODEL,
|
||||||
|
DocumentTypeResult,
|
||||||
PageRegion,
|
PageRegion,
|
||||||
RowGeometry,
|
RowGeometry,
|
||||||
_cells_to_vocab_entries,
|
_cells_to_vocab_entries,
|
||||||
@@ -43,6 +44,8 @@ from cv_vocab_pipeline import (
|
|||||||
analyze_layout_by_words,
|
analyze_layout_by_words,
|
||||||
build_cell_grid,
|
build_cell_grid,
|
||||||
build_cell_grid_streaming,
|
build_cell_grid_streaming,
|
||||||
|
build_cell_grid_v2,
|
||||||
|
build_cell_grid_v2_streaming,
|
||||||
build_word_grid,
|
build_word_grid,
|
||||||
classify_column_types,
|
classify_column_types,
|
||||||
create_layout_image,
|
create_layout_image,
|
||||||
@@ -50,6 +53,7 @@ from cv_vocab_pipeline import (
|
|||||||
deskew_image,
|
deskew_image,
|
||||||
deskew_image_by_word_alignment,
|
deskew_image_by_word_alignment,
|
||||||
detect_column_geometry,
|
detect_column_geometry,
|
||||||
|
detect_document_type,
|
||||||
detect_row_geometry,
|
detect_row_geometry,
|
||||||
expand_narrow_columns,
|
expand_narrow_columns,
|
||||||
_apply_shear,
|
_apply_shear,
|
||||||
@@ -759,6 +763,54 @@ async def save_dewarp_ground_truth(session_id: str, req: DewarpGroundTruthReques
|
|||||||
return {"session_id": session_id, "ground_truth": gt}
|
return {"session_id": session_id, "ground_truth": gt}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Document Type Detection (between Dewarp and Columns)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/detect-type")
async def detect_type(session_id: str):
    """Detect document type (vocab_table, full_text, generic_table).

    Should be called after dewarp (clean image available).
    Stores result in session for frontend to decide pipeline flow.

    Args:
        session_id: Pipeline session to classify; loaded into the in-memory
            cache first if it is not cached yet.

    Returns:
        Dict with ``session_id`` plus ``doc_type``, ``confidence``,
        ``pipeline``, ``skip_steps``, ``features`` and ``duration_seconds``.

    Raises:
        HTTPException: 400 when the session has no dewarped image yet.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    dewarped_bgr = cached.get("dewarped_bgr")
    if dewarped_bgr is None:
        raise HTTPException(status_code=400, detail="Dewarp must be completed first")

    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # (or go negative) by system clock adjustments the way time.time() can.
    t0 = time.perf_counter()
    ocr_img = create_ocr_image(dewarped_bgr)
    result = detect_document_type(ocr_img, dewarped_bgr)
    duration = time.perf_counter() - t0

    result_dict = {
        "doc_type": result.doc_type,
        "confidence": result.confidence,
        "pipeline": result.pipeline,
        "skip_steps": result.skip_steps,
        "features": result.features,
        "duration_seconds": round(duration, 2),
    }

    # Persist to DB
    await update_session_db(
        session_id,
        doc_type=result.doc_type,
        doc_type_result=result_dict,
    )

    # Keep the in-memory session in sync with what was just persisted.
    cached["doc_type_result"] = result_dict

    logger.info(f"OCR Pipeline: detect-type session {session_id}: "
                f"{result.doc_type} (confidence={result.confidence}, {duration:.2f}s)")

    return {"session_id": session_id, **result_dict}
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Column Detection Endpoints (Step 3)
|
# Column Detection Endpoints (Step 3)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -1196,8 +1248,10 @@ async def detect_words(
|
|||||||
for r in row_result["rows"]
|
for r in row_result["rows"]
|
||||||
]
|
]
|
||||||
|
|
||||||
# Re-populate row.words from cached full-page Tesseract words.
|
# Cell-First OCR (v2): no full-page word re-population needed.
|
||||||
# Word-lookup in _ocr_single_cell needs these to avoid re-running OCR.
|
# Each cell is cropped and OCR'd in isolation → no neighbour bleeding.
|
||||||
|
# We still need word_count > 0 for row filtering in build_cell_grid_v2,
|
||||||
|
# so populate from cached words if available (just for counting).
|
||||||
word_dicts = cached.get("_word_dicts")
|
word_dicts = cached.get("_word_dicts")
|
||||||
if word_dicts is None:
|
if word_dicts is None:
|
||||||
ocr_img_tmp = create_ocr_image(dewarped_bgr)
|
ocr_img_tmp = create_ocr_image(dewarped_bgr)
|
||||||
@@ -1209,8 +1263,6 @@ async def detect_words(
|
|||||||
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
|
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
|
||||||
|
|
||||||
if word_dicts:
|
if word_dicts:
|
||||||
# words['top'] is relative to content-ROI top_y.
|
|
||||||
# row.y is absolute. Convert: row_y_rel = row.y - top_y.
|
|
||||||
content_bounds = cached.get("_content_bounds")
|
content_bounds = cached.get("_content_bounds")
|
||||||
if content_bounds:
|
if content_bounds:
|
||||||
_lx, _rx, top_y, _by = content_bounds
|
_lx, _rx, top_y, _by = content_bounds
|
||||||
@@ -1240,15 +1292,15 @@ async def detect_words(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- Non-streaming path (unchanged) ---
|
# --- Non-streaming path ---
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
|
||||||
# Create binarized OCR image (for Tesseract)
|
# Create binarized OCR image (for Tesseract)
|
||||||
ocr_img = create_ocr_image(dewarped_bgr)
|
ocr_img = create_ocr_image(dewarped_bgr)
|
||||||
img_h, img_w = dewarped_bgr.shape[:2]
|
img_h, img_w = dewarped_bgr.shape[:2]
|
||||||
|
|
||||||
# Build generic cell grid
|
# Build cell grid using Cell-First OCR (v2) — each cell cropped in isolation
|
||||||
cells, columns_meta = build_cell_grid(
|
cells, columns_meta = build_cell_grid_v2(
|
||||||
ocr_img, col_regions, row_geoms, img_w, img_h,
|
ocr_img, col_regions, row_geoms, img_w, img_h,
|
||||||
ocr_engine=engine, img_bgr=dewarped_bgr,
|
ocr_engine=engine, img_bgr=dewarped_bgr,
|
||||||
)
|
)
|
||||||
@@ -1358,7 +1410,7 @@ async def _word_stream_generator(
|
|||||||
all_cells: List[Dict[str, Any]] = []
|
all_cells: List[Dict[str, Any]] = []
|
||||||
cell_idx = 0
|
cell_idx = 0
|
||||||
|
|
||||||
for cell, cols_meta, total in build_cell_grid_streaming(
|
for cell, cols_meta, total in build_cell_grid_v2_streaming(
|
||||||
ocr_img, col_regions, row_geoms, img_w, img_h,
|
ocr_img, col_regions, row_geoms, img_w, img_h,
|
||||||
ocr_engine=engine, img_bgr=dewarped_bgr,
|
ocr_engine=engine, img_bgr=dewarped_bgr,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -64,7 +64,9 @@ async def init_ocr_pipeline_tables():
|
|||||||
await conn.execute("""
|
await conn.execute("""
|
||||||
ALTER TABLE ocr_pipeline_sessions
|
ALTER TABLE ocr_pipeline_sessions
|
||||||
ADD COLUMN IF NOT EXISTS clean_png BYTEA,
|
ADD COLUMN IF NOT EXISTS clean_png BYTEA,
|
||||||
ADD COLUMN IF NOT EXISTS handwriting_removal_meta JSONB
|
ADD COLUMN IF NOT EXISTS handwriting_removal_meta JSONB,
|
||||||
|
ADD COLUMN IF NOT EXISTS doc_type VARCHAR(50),
|
||||||
|
ADD COLUMN IF NOT EXISTS doc_type_result JSONB
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
@@ -88,6 +90,7 @@ async def create_session_db(
|
|||||||
RETURNING id, name, filename, status, current_step,
|
RETURNING id, name, filename, status, current_step,
|
||||||
deskew_result, dewarp_result, column_result, row_result,
|
deskew_result, dewarp_result, column_result, row_result,
|
||||||
word_result, ground_truth, auto_shear_degrees,
|
word_result, ground_truth, auto_shear_degrees,
|
||||||
|
doc_type, doc_type_result,
|
||||||
created_at, updated_at
|
created_at, updated_at
|
||||||
""", uuid.UUID(session_id), name, filename, original_png)
|
""", uuid.UUID(session_id), name, filename, original_png)
|
||||||
|
|
||||||
@@ -102,6 +105,7 @@ async def get_session_db(session_id: str) -> Optional[Dict[str, Any]]:
|
|||||||
SELECT id, name, filename, status, current_step,
|
SELECT id, name, filename, status, current_step,
|
||||||
deskew_result, dewarp_result, column_result, row_result,
|
deskew_result, dewarp_result, column_result, row_result,
|
||||||
word_result, ground_truth, auto_shear_degrees,
|
word_result, ground_truth, auto_shear_degrees,
|
||||||
|
doc_type, doc_type_result,
|
||||||
created_at, updated_at
|
created_at, updated_at
|
||||||
FROM ocr_pipeline_sessions WHERE id = $1
|
FROM ocr_pipeline_sessions WHERE id = $1
|
||||||
""", uuid.UUID(session_id))
|
""", uuid.UUID(session_id))
|
||||||
@@ -146,9 +150,10 @@ async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any
|
|||||||
'clean_png', 'handwriting_removal_meta',
|
'clean_png', 'handwriting_removal_meta',
|
||||||
'deskew_result', 'dewarp_result', 'column_result', 'row_result',
|
'deskew_result', 'dewarp_result', 'column_result', 'row_result',
|
||||||
'word_result', 'ground_truth', 'auto_shear_degrees',
|
'word_result', 'ground_truth', 'auto_shear_degrees',
|
||||||
|
'doc_type', 'doc_type_result',
|
||||||
}
|
}
|
||||||
|
|
||||||
jsonb_fields = {'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'handwriting_removal_meta'}
|
jsonb_fields = {'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'handwriting_removal_meta', 'doc_type_result'}
|
||||||
|
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
if key in allowed_fields:
|
if key in allowed_fields:
|
||||||
@@ -174,6 +179,7 @@ async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any
|
|||||||
RETURNING id, name, filename, status, current_step,
|
RETURNING id, name, filename, status, current_step,
|
||||||
deskew_result, dewarp_result, column_result, row_result,
|
deskew_result, dewarp_result, column_result, row_result,
|
||||||
word_result, ground_truth, auto_shear_degrees,
|
word_result, ground_truth, auto_shear_degrees,
|
||||||
|
doc_type, doc_type_result,
|
||||||
created_at, updated_at
|
created_at, updated_at
|
||||||
""", *values)
|
""", *values)
|
||||||
|
|
||||||
@@ -229,7 +235,7 @@ def _row_to_dict(row: asyncpg.Record) -> Dict[str, Any]:
|
|||||||
result[key] = result[key].isoformat()
|
result[key] = result[key].isoformat()
|
||||||
|
|
||||||
# JSONB → parsed (asyncpg returns str for JSONB)
|
# JSONB → parsed (asyncpg returns str for JSONB)
|
||||||
for key in ['deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth']:
|
for key in ['deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'doc_type_result']:
|
||||||
if key in result and result[key] is not None:
|
if key in result and result[key] is not None:
|
||||||
if isinstance(result[key], str):
|
if isinstance(result[key], str):
|
||||||
result[key] = json.loads(result[key])
|
result[key] = json.loads(result[key])
|
||||||
|
|||||||
@@ -25,7 +25,9 @@ from dataclasses import asdict
|
|||||||
# Import module under test
|
# Import module under test
|
||||||
from cv_vocab_pipeline import (
|
from cv_vocab_pipeline import (
|
||||||
ColumnGeometry,
|
ColumnGeometry,
|
||||||
|
DocumentTypeResult,
|
||||||
PageRegion,
|
PageRegion,
|
||||||
|
RowGeometry,
|
||||||
VocabRow,
|
VocabRow,
|
||||||
PipelineResult,
|
PipelineResult,
|
||||||
deskew_image,
|
deskew_image,
|
||||||
@@ -48,9 +50,12 @@ from cv_vocab_pipeline import (
|
|||||||
CV_PIPELINE_AVAILABLE,
|
CV_PIPELINE_AVAILABLE,
|
||||||
_is_noise_tail_token,
|
_is_noise_tail_token,
|
||||||
_clean_cell_text,
|
_clean_cell_text,
|
||||||
|
_clean_cell_text_lite,
|
||||||
_is_phonetic_only_text,
|
_is_phonetic_only_text,
|
||||||
_merge_phonetic_continuation_rows,
|
_merge_phonetic_continuation_rows,
|
||||||
_merge_continuation_rows,
|
_merge_continuation_rows,
|
||||||
|
_ocr_cell_crop,
|
||||||
|
detect_document_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1566,6 +1571,167 @@ class TestCellsToVocabEntriesPageRef:
|
|||||||
assert entries[0]['source_page'] == 'p.59'
|
assert entries[0]['source_page'] == 'p.59'
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================
|
||||||
|
# CELL-FIRST OCR (v2) TESTS
|
||||||
|
# =============================================
|
||||||
|
|
||||||
|
class TestCleanCellTextLite:
    """Checks for the simplified noise filter _clean_cell_text_lite()."""

    def test_empty_string(self):
        cleaned = _clean_cell_text_lite('')
        assert cleaned == ''

    def test_whitespace_only(self):
        cleaned = _clean_cell_text_lite(' ')
        assert cleaned == ''

    def test_real_word_passes(self):
        word = 'hello'
        assert _clean_cell_text_lite(word) == word

    def test_sentence_passes(self):
        phrase = 'to have dinner'
        assert _clean_cell_text_lite(phrase) == phrase

    def test_garbage_text_cleared(self):
        """Text without any dictionary word is wiped."""
        cleaned = _clean_cell_text_lite('xqzjk')
        assert cleaned == ''

    def test_no_real_word_cleared(self):
        """Lone characters without a real (2+ letter) word are wiped."""
        digit_result = _clean_cell_text_lite('3')
        assert digit_result == ''
        pipe_result = _clean_cell_text_lite('|')
        assert pipe_result == ''

    def test_known_abbreviation_kept(self):
        """Known abbreviations survive the filter unchanged."""
        sth_result = _clean_cell_text_lite('sth')
        assert sth_result == 'sth'
        eg_result = _clean_cell_text_lite('eg')
        assert eg_result == 'eg'

    def test_no_trailing_noise_stripping(self):
        """The lite variant keeps trailing tokens, unlike _clean_cell_text.

        Cells are OCR'd in isolation, so every token belongs to the cell."""
        assert _clean_cell_text_lite('apple tree') == 'apple tree'

    def test_page_reference(self):
        """A page reference such as p.60 must not be wiped."""
        # 'p' counts as a known abbreviation
        cleaned = _clean_cell_text_lite('p.60')
        assert cleaned != ''
|
||||||
|
|
||||||
|
|
||||||
|
class TestOcrCellCrop:
    """Checks for _ocr_cell_crop(), the per-cell isolated OCR routine."""

    def test_empty_cell_pixel_density(self):
        """A cell with (almost) no dark pixels yields empty text."""
        # Completely white page → nothing to read
        blank_page = np.full((400, 600), 255, dtype=np.uint8)
        row_geom = RowGeometry(index=0, x=0, y=50, width=600, height=30,
                               word_count=1, words=[{'text': 'a'}])
        column = PageRegion(type='column_en', x=50, y=0, width=200, height=400)
        languages = {'column_en': 'eng'}

        cell = _ocr_cell_crop(
            0, 0, row_geom, column, blank_page, None, 600, 400,
            'tesseract', 'eng+deu', languages,
        )

        assert cell['text'] == ''
        assert cell['cell_id'] == 'R00_C0'
        assert cell['col_type'] == 'column_en'

    def test_zero_width_cell(self):
        """A column of width zero yields empty text."""
        blank_page = np.full((400, 600), 255, dtype=np.uint8)
        row_geom = RowGeometry(index=0, x=0, y=50, width=600, height=30,
                               word_count=1, words=[])
        column = PageRegion(type='column_en', x=50, y=0, width=0, height=400)

        cell = _ocr_cell_crop(
            0, 0, row_geom, column, blank_page, None, 600, 400,
            'tesseract', 'eng+deu', {},
        )

        assert cell['text'] == ''

    def test_bbox_calculation(self):
        """bbox_px and bbox_pct are derived from the column x/width and row y/height."""
        blank_page = np.full((1000, 2000), 255, dtype=np.uint8)
        row_geom = RowGeometry(index=0, x=0, y=100, width=2000, height=50,
                               word_count=1, words=[{'text': 'test'}])
        column = PageRegion(type='column_de', x=400, y=0, width=600, height=1000)
        languages = {'column_de': 'deu'}

        cell = _ocr_cell_crop(
            0, 0, row_geom, column, blank_page, None, 2000, 1000,
            'tesseract', 'eng+deu', languages,
        )

        expected_px = {'x': 400, 'y': 100, 'w': 600, 'h': 50}
        assert cell['bbox_px'] == expected_px
        pct = cell['bbox_pct']
        assert pct['x'] == 20.0  # 400/2000*100
        assert pct['y'] == 10.0  # 100/1000*100
|
||||||
|
|
||||||
|
|
||||||
|
class TestDetectDocumentType:
    """Checks for detect_document_type(), the image-based page classifier."""

    def test_empty_image(self):
        """A 0×0 image falls back to full_text."""
        nothing = np.zeros((0, 0), dtype=np.uint8)
        outcome = detect_document_type(nothing, nothing)
        assert outcome.doc_type == 'full_text'
        assert outcome.pipeline == 'full_page'

    def test_table_image_detected(self):
        """Clear vertical column gaps plus row gaps → classified as a table."""
        # 600x400 binary page: three ink columns with white gutters at
        # x=170..210 and x=370..410, one 10px ink band every 20px.
        page = np.full((400, 600), 255, dtype=np.uint8)
        column_spans = ((20, 170), (210, 370), (410, 580))
        for x_start, x_end in column_spans:
            for band_top in range(30, 370, 20):
                page[band_top:band_top + 10, x_start:x_end] = 0

        colour = np.stack((page, page, page), axis=-1)
        outcome = detect_document_type(page, colour)
        assert outcome.doc_type in ('vocab_table', 'generic_table')
        assert outcome.pipeline == 'cell_first'
        assert outcome.confidence >= 0.5

    def test_fulltext_image_detected(self):
        """Evenly spread text lines with no column gutters → full_text."""
        page = np.full((400, 600), 255, dtype=np.uint8)
        # Lines run across the whole text width, so there is no column gap
        for line_top in range(30, 370, 15):
            page[line_top:line_top + 8, 30:570] = 0

        colour = np.stack((page, page, page), axis=-1)
        outcome = detect_document_type(page, colour)
        assert outcome.doc_type == 'full_text'
        assert outcome.pipeline == 'full_page'
        for step in ('columns', 'rows'):
            assert step in outcome.skip_steps

    def test_result_has_features(self):
        """The result exposes the debug feature values."""
        page = np.full((200, 300), 255, dtype=np.uint8)
        colour = np.stack((page, page, page), axis=-1)
        outcome = detect_document_type(page, colour)
        for feature_name in ('vertical_gaps', 'row_gaps', 'density_mean', 'density_std'):
            assert feature_name in outcome.features

    def test_document_type_result_dataclass(self):
        """DocumentTypeResult fills skip_steps/features with empty defaults."""
        instance = DocumentTypeResult(
            doc_type='vocab_table',
            confidence=0.9,
            pipeline='cell_first',
        )
        assert instance.doc_type == 'vocab_table'
        assert instance.skip_steps == []
        assert instance.features == {}
|
||||||
|
|
||||||
|
|
||||||
# =============================================
|
# =============================================
|
||||||
# RUN TESTS
|
# RUN TESTS
|
||||||
# =============================================
|
# =============================================
|
||||||
|
|||||||
Reference in New Issue
Block a user