From 29c74a99625129400f575db9d62d17994abcf2d3 Mon Sep 17 00:00:00 2001 From: Benjamin Admin Date: Wed, 4 Mar 2026 13:52:38 +0100 Subject: [PATCH] feat: cell-first OCR + document type detection + dynamic pipeline steps Cell-First OCR (v2): Each cell is cropped and OCR'd in isolation, eliminating neighbour bleeding (e.g. "to", "ps" in marker columns). Uses ThreadPoolExecutor for parallel Tesseract calls. Document type detection: Classifies pages as vocab_table, full_text, or generic_table using projection profiles (<2s, no OCR needed). Frontend dynamically skips columns/rows steps for full-text pages. Co-Authored-By: Claude Opus 4.6 --- .../app/(admin)/ai/ocr-pipeline/page.tsx | 113 +++- .../app/(admin)/ai/ocr-pipeline/types.ts | 12 +- .../ocr-pipeline/PipelineStepper.tsx | 147 +++-- klausur-service/backend/cv_vocab_pipeline.py | 558 +++++++++++++++++- klausur-service/backend/ocr_pipeline_api.py | 68 ++- .../backend/ocr_pipeline_session_store.py | 12 +- .../backend/tests/test_cv_vocab_pipeline.py | 166 ++++++ 7 files changed, 1001 insertions(+), 75 deletions(-) diff --git a/admin-lehrer/app/(admin)/ai/ocr-pipeline/page.tsx b/admin-lehrer/app/(admin)/ai/ocr-pipeline/page.tsx index 7383107..bbfba67 100644 --- a/admin-lehrer/app/(admin)/ai/ocr-pipeline/page.tsx +++ b/admin-lehrer/app/(admin)/ai/ocr-pipeline/page.tsx @@ -11,7 +11,7 @@ import { StepWordRecognition } from '@/components/ocr-pipeline/StepWordRecogniti import { StepLlmReview } from '@/components/ocr-pipeline/StepLlmReview' import { StepReconstruction } from '@/components/ocr-pipeline/StepReconstruction' import { StepGroundTruth } from '@/components/ocr-pipeline/StepGroundTruth' -import { PIPELINE_STEPS, type PipelineStep, type SessionListItem } from './types' +import { PIPELINE_STEPS, type PipelineStep, type SessionListItem, type DocumentTypeResult } from './types' const KLAUSUR_API = '/klausur-api' @@ -23,6 +23,7 @@ export default function OcrPipelinePage() { const [loadingSessions, setLoadingSessions] = useState(true) const [editingName, setEditingName] = useState(null) const [editNameValue, setEditNameValue] = useState('') + const [docTypeResult, setDocTypeResult] = useState(null) const [steps, setSteps] = useState( PIPELINE_STEPS.map((s, i) => ({ ...s, @@ -59,16 +60,23 @@ export default function OcrPipelinePage() { setSessionId(sid) setSessionName(data.name || data.filename || '') + // Restore doc type result if available + const savedDocType: DocumentTypeResult | null = data.doc_type_result || null + setDocTypeResult(savedDocType) + // Determine which step to jump to based on current_step const dbStep = data.current_step || 1 // Steps: 1=deskew, 2=dewarp, 3=columns, ... // UI steps are 0-indexed: 0=deskew, 1=dewarp, 2=columns, ... const uiStep = Math.max(0, dbStep - 1) + const skipSteps = savedDocType?.skip_steps || [] setSteps( PIPELINE_STEPS.map((s, i) => ({ ...s, - status: i < uiStep ? 'completed' : i === uiStep ? 'active' : 'pending', + status: skipSteps.includes(s.id) + ? 'skipped' + : i < uiStep ? 'completed' : i === uiStep ? 'active' : 'pending', })), ) setCurrentStep(uiStep) @@ -84,6 +92,7 @@ export default function OcrPipelinePage() { if (sessionId === sid) { setSessionId(null) setCurrentStep(0) + setDocTypeResult(null) setSteps(PIPELINE_STEPS.map((s, i) => ({ ...s, status: i === 0 ? 'active' : 'pending' }))) } } catch (e) { @@ -123,16 +132,28 @@ export default function OcrPipelinePage() { } const handleNext = () => { - if (currentStep < steps.length - 1) { - setSteps((prev) => - prev.map((s, i) => { - if (i === currentStep) return { ...s, status: 'completed' } - if (i === currentStep + 1) return { ...s, status: 'active' } - return s - }), - ) - setCurrentStep((prev) => prev + 1) + if (currentStep >= steps.length - 1) return + + // Find the next non-skipped step + const skipSteps = docTypeResult?.skip_steps || [] + let nextStep = currentStep + 1 + while (nextStep < steps.length && skipSteps.includes(PIPELINE_STEPS[nextStep]?.id)) { + nextStep++ } + if (nextStep >= steps.length) nextStep = steps.length - 1 + + setSteps((prev) => + prev.map((s, i) => { + if (i === currentStep) return { ...s, status: 'completed' } + if (i === nextStep) return { ...s, status: 'active' } + // Mark skipped steps between current and next + if (i > currentStep && i < nextStep && skipSteps.includes(PIPELINE_STEPS[i]?.id)) { + return { ...s, status: 'skipped' } + } + return s + }), + ) + setCurrentStep(nextStep) } const handleDeskewComplete = (sid: string) => { @@ -142,10 +163,69 @@ export default function OcrPipelinePage() { handleNext() } + const handleDewarpNext = async () => { + // Auto-detect document type after dewarp, then advance + if (sessionId) { + try { + const res = await fetch( + `${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/detect-type`, + { method: 'POST' }, + ) + if (res.ok) { + const data: DocumentTypeResult = await res.json() + setDocTypeResult(data) + + // Mark skipped steps immediately + const skipSteps = data.skip_steps || [] + if (skipSteps.length > 0) { + setSteps((prev) => + prev.map((s) => + skipSteps.includes(s.id) ? { ...s, status: 'skipped' } : s, + ), + ) + } + } + } catch (e) { + console.error('Doc type detection failed:', e) + // Not critical — continue without it + } + } + handleNext() + } + + const handleDocTypeChange = (newDocType: DocumentTypeResult['doc_type']) => { + if (!docTypeResult) return + + // Build new skip_steps based on doc type + let skipSteps: string[] = [] + if (newDocType === 'full_text') { + skipSteps = ['columns', 'rows'] + } + // vocab_table and generic_table: no skips + + const updated: DocumentTypeResult = { + ...docTypeResult, + doc_type: newDocType, + skip_steps: skipSteps, + pipeline: newDocType === 'full_text' ? 'full_page' : 'cell_first', + } + setDocTypeResult(updated) + + // Update step statuses + setSteps((prev) => + prev.map((s) => { + if (skipSteps.includes(s.id)) return { ...s, status: 'skipped' as const } + if (s.status === 'skipped') return { ...s, status: 'pending' as const } + return s + }), + ) + } + const handleNewSession = () => { setSessionId(null) setSessionName('') setCurrentStep(0) + setDocTypeResult(null) setSteps(PIPELINE_STEPS.map((s, i) => ({ ...s, status: i === 0 ? 'active' : 'pending' }))) } @@ -188,7 +268,7 @@ export default function OcrPipelinePage() { case 0: return case 1: - return + return case 2: return case 3: @@ -314,7 +394,14 @@ export default function OcrPipelinePage() { )} - +
{renderStep()}
diff --git a/admin-lehrer/app/(admin)/ai/ocr-pipeline/types.ts b/admin-lehrer/app/(admin)/ai/ocr-pipeline/types.ts index 7213df0..4b1c86f 100644 --- a/admin-lehrer/app/(admin)/ai/ocr-pipeline/types.ts +++ b/admin-lehrer/app/(admin)/ai/ocr-pipeline/types.ts @@ -1,4 +1,4 @@ -export type PipelineStepStatus = 'pending' | 'active' | 'completed' | 'failed' +export type PipelineStepStatus = 'pending' | 'active' | 'completed' | 'failed' | 'skipped' export interface PipelineStep { id: string @@ -17,6 +17,15 @@ export interface SessionListItem { updated_at?: string } +export interface DocumentTypeResult { + doc_type: 'vocab_table' | 'full_text' | 'generic_table' + confidence: number + pipeline: 'cell_first' | 'full_page' + skip_steps: string[] + features?: Record + duration_seconds?: number +} + export interface SessionInfo { session_id: string filename: string @@ -30,6 +39,7 @@ export interface SessionInfo { column_result?: ColumnResult row_result?: RowResult word_result?: GridResult + doc_type_result?: DocumentTypeResult } export interface DeskewResult { diff --git a/admin-lehrer/components/ocr-pipeline/PipelineStepper.tsx b/admin-lehrer/components/ocr-pipeline/PipelineStepper.tsx index e72b8cf..da2785b 100644 --- a/admin-lehrer/components/ocr-pipeline/PipelineStepper.tsx +++ b/admin-lehrer/components/ocr-pipeline/PipelineStepper.tsx @@ -1,66 +1,115 @@ 'use client' -import { PipelineStep } from '@/app/(admin)/ai/ocr-pipeline/types' +import { PipelineStep, DocumentTypeResult } from '@/app/(admin)/ai/ocr-pipeline/types' + +const DOC_TYPE_LABELS: Record = { + vocab_table: 'Vokabeltabelle', + full_text: 'Volltext', + generic_table: 'Tabelle', +} interface PipelineStepperProps { steps: PipelineStep[] currentStep: number onStepClick: (index: number) => void onReprocess?: (index: number) => void + docTypeResult?: DocumentTypeResult | null + onDocTypeChange?: (docType: DocumentTypeResult['doc_type']) => void } -export function PipelineStepper({ steps, currentStep, onStepClick, onReprocess }: PipelineStepperProps) { +export function PipelineStepper({ + steps, + currentStep, + onStepClick, + onReprocess, + docTypeResult, + onDocTypeChange, +}: PipelineStepperProps) { return ( -
- {steps.map((step, index) => { - const isActive = index === currentStep - const isCompleted = step.status === 'completed' - const isFailed = step.status === 'failed' - const isClickable = index <= currentStep || isCompleted +
+
+ {steps.map((step, index) => { + const isActive = index === currentStep + const isCompleted = step.status === 'completed' + const isFailed = step.status === 'failed' + const isSkipped = step.status === 'skipped' + const isClickable = (index <= currentStep || isCompleted) && !isSkipped - return ( -
- {index > 0 && ( -
- )} -
- - {/* Reprocess button — shown on completed steps on hover */} - {isCompleted && onReprocess && ( - + return ( +
+ {index > 0 && ( +
)} +
+ + {/* Reprocess button — shown on completed steps on hover */} + {isCompleted && onReprocess && ( + + )} +
-
- ) - })} + ) + })} +
+ + {/* Document type badge */} + {docTypeResult && ( +
+ + Dokumenttyp: + + {onDocTypeChange ? ( + + ) : ( + + {DOC_TYPE_LABELS[docTypeResult.doc_type] || docTypeResult.doc_type} + + )} + + ({Math.round(docTypeResult.confidence * 100)}% Konfidenz) + +
+ )}
) } diff --git a/klausur-service/backend/cv_vocab_pipeline.py b/klausur-service/backend/cv_vocab_pipeline.py index a10a0c7..1c0012e 100644 --- a/klausur-service/backend/cv_vocab_pipeline.py +++ b/klausur-service/backend/cv_vocab_pipeline.py @@ -18,6 +18,7 @@ DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. import io import logging import time +from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field from typing import Any, Dict, Generator, List, Optional, Tuple @@ -159,6 +160,16 @@ class PipelineResult: image_height: int = 0 +@dataclass +class DocumentTypeResult: + """Result of automatic document type detection.""" + doc_type: str # 'vocab_table' | 'full_text' | 'generic_table' + confidence: float # 0.0-1.0 + pipeline: str # 'cell_first' | 'full_page' + skip_steps: List[str] = field(default_factory=list) # e.g. ['columns', 'rows'] + features: Dict[str, Any] = field(default_factory=dict) # debug info + + # ============================================================================= # Stage 1: High-Resolution PDF Rendering # ============================================================================= @@ -966,6 +977,164 @@ def dewarp_image_manual(img: np.ndarray, shear_degrees: float) -> np.ndarray: return _apply_shear(img, -shear_degrees) +# ============================================================================= +# Document Type Detection +# ============================================================================= + +def detect_document_type(ocr_img: np.ndarray, img_bgr: np.ndarray) -> DocumentTypeResult: + """Detect whether the page is a vocab table, generic table, or full text. + + Uses projection profiles and text density analysis — no OCR required. + Runs in < 2 seconds. + + Args: + ocr_img: Binarized grayscale image (for projection profiles). + img_bgr: BGR color image. + + Returns: + DocumentTypeResult with doc_type, confidence, pipeline, skip_steps. + """ + if ocr_img is None or ocr_img.size == 0: + return DocumentTypeResult( + doc_type='full_text', confidence=0.5, pipeline='full_page', + skip_steps=['columns', 'rows'], + features={'error': 'empty image'}, + ) + + h, w = ocr_img.shape[:2] + + # --- 1. Vertical projection profile → detect column gaps --- + # Sum dark pixels along each column (x-axis). Gaps = valleys in the profile. + # Invert: dark pixels on white background → high values = text. + vert_proj = np.sum(ocr_img < 128, axis=0).astype(float) + + # Smooth the profile to avoid noise spikes + kernel_size = max(3, w // 100) + if kernel_size % 2 == 0: + kernel_size += 1 + vert_smooth = np.convolve(vert_proj, np.ones(kernel_size) / kernel_size, mode='same') + + # Find significant vertical gaps (columns of near-zero text density) + # A gap must be at least 1% of image width and have < 5% of max density + max_density = max(vert_smooth.max(), 1) + gap_threshold = max_density * 0.05 + min_gap_width = max(5, w // 100) + + in_gap = False + gap_count = 0 + gap_start = 0 + vert_gaps = [] + + for x in range(w): + if vert_smooth[x] < gap_threshold: + if not in_gap: + in_gap = True + gap_start = x + else: + if in_gap: + gap_width = x - gap_start + if gap_width >= min_gap_width: + gap_count += 1 + vert_gaps.append((gap_start, x, gap_width)) + in_gap = False + + # Filter out margin gaps (within 10% of image edges) + margin_threshold = w * 0.10 + internal_gaps = [g for g in vert_gaps if g[0] > margin_threshold and g[1] < w - margin_threshold] + internal_gap_count = len(internal_gaps) + + # --- 2. Horizontal projection profile → detect row gaps --- + horiz_proj = np.sum(ocr_img < 128, axis=1).astype(float) + h_kernel = max(3, h // 200) + if h_kernel % 2 == 0: + h_kernel += 1 + horiz_smooth = np.convolve(horiz_proj, np.ones(h_kernel) / h_kernel, mode='same') + + h_max = max(horiz_smooth.max(), 1) + h_gap_threshold = h_max * 0.05 + min_row_gap = max(3, h // 200) + + row_gap_count = 0 + in_gap = False + for y in range(h): + if horiz_smooth[y] < h_gap_threshold: + if not in_gap: + in_gap = True + gap_start = y + else: + if in_gap: + if y - gap_start >= min_row_gap: + row_gap_count += 1 + in_gap = False + + # --- 3. Text density distribution (4×4 grid) --- + grid_rows, grid_cols = 4, 4 + cell_h, cell_w = h // grid_rows, w // grid_cols + densities = [] + for gr in range(grid_rows): + for gc in range(grid_cols): + cell = ocr_img[gr * cell_h:(gr + 1) * cell_h, + gc * cell_w:(gc + 1) * cell_w] + if cell.size > 0: + d = float(np.count_nonzero(cell < 128)) / cell.size + densities.append(d) + + density_std = float(np.std(densities)) if densities else 0 + density_mean = float(np.mean(densities)) if densities else 0 + + features = { + 'vertical_gaps': gap_count, + 'internal_vertical_gaps': internal_gap_count, + 'vertical_gap_details': [(g[0], g[1], g[2]) for g in vert_gaps[:10]], + 'row_gaps': row_gap_count, + 'density_mean': round(density_mean, 4), + 'density_std': round(density_std, 4), + 'image_size': (w, h), + } + + # --- 4. Decision tree --- + # Use internal_gap_count (excludes margin gaps) for column detection. + if internal_gap_count >= 2 and row_gap_count >= 5: + # Multiple internal vertical gaps + many row gaps → table + confidence = min(0.95, 0.7 + internal_gap_count * 0.05 + row_gap_count * 0.005) + return DocumentTypeResult( + doc_type='vocab_table', + confidence=round(confidence, 2), + pipeline='cell_first', + skip_steps=[], + features=features, + ) + elif internal_gap_count >= 1 and row_gap_count >= 3: + # Some internal structure, likely a table + confidence = min(0.85, 0.5 + internal_gap_count * 0.1 + row_gap_count * 0.01) + return DocumentTypeResult( + doc_type='generic_table', + confidence=round(confidence, 2), + pipeline='cell_first', + skip_steps=[], + features=features, + ) + elif internal_gap_count == 0: + # No internal column gaps → full text (regardless of density) + confidence = min(0.95, 0.8 + (1 - min(density_std, 0.1)) * 0.15) + return DocumentTypeResult( + doc_type='full_text', + confidence=round(confidence, 2), + pipeline='full_page', + skip_steps=['columns', 'rows'], + features=features, + ) + else: + # Ambiguous — default to vocab_table (most common use case) + return DocumentTypeResult( + doc_type='vocab_table', + confidence=0.5, + pipeline='cell_first', + skip_steps=[], + features=features, + ) + + # ============================================================================= # Stage 4: Dual Image Preparation # ============================================================================= @@ -4481,8 +4650,395 @@ def _clean_cell_text(text: str) -> str: return ' '.join(tokens) +def _clean_cell_text_lite(text: str) -> str: + """Simplified noise filter for cell-first OCR (isolated cell crops). + + Since each cell is OCR'd in isolation (no neighbour content visible), + trailing-noise stripping is unnecessary. Only 2 filters remain: + + 1. No real alphabetic word (>= 2 letters) and not a known abbreviation → empty. + 2. Entire text is garbage (no dictionary word) → empty. + """ + stripped = text.strip() + if not stripped: + return '' + + # --- Filter 1: No real word at all --- + if not _RE_REAL_WORD.search(stripped): + alpha_only = ''.join(_RE_ALPHA.findall(stripped)).lower() + if alpha_only not in _KNOWN_ABBREVIATIONS: + return '' + + # --- Filter 2: Entire text is garbage --- + if _is_garbage_text(stripped): + return '' + + return stripped + + # --------------------------------------------------------------------------- -# Narrow-column OCR helpers (Proposal B) +# Cell-First OCR (v2) — each cell cropped and OCR'd in isolation +# --------------------------------------------------------------------------- + +def _ocr_cell_crop( + row_idx: int, + col_idx: int, + row: RowGeometry, + col: PageRegion, + ocr_img: np.ndarray, + img_bgr: Optional[np.ndarray], + img_w: int, + img_h: int, + engine_name: str, + lang: str, + lang_map: Dict[str, str], +) -> Dict[str, Any]: + """OCR a single cell by cropping the exact column×row intersection. + + No padding beyond cell boundaries → no neighbour bleeding. + """ + # Display bbox: exact column × row intersection + disp_x = col.x + disp_y = row.y + disp_w = col.width + disp_h = row.height + + # Crop boundaries (clamped to image) + cx = max(0, disp_x) + cy = max(0, disp_y) + cw = min(disp_w, img_w - cx) + ch = min(disp_h, img_h - cy) + + empty_cell = { + 'cell_id': f"R{row_idx:02d}_C{col_idx}", + 'row_index': row_idx, + 'col_index': col_idx, + 'col_type': col.type, + 'text': '', + 'confidence': 0.0, + 'bbox_px': {'x': disp_x, 'y': disp_y, 'w': disp_w, 'h': disp_h}, + 'bbox_pct': { + 'x': round(disp_x / img_w * 100, 2) if img_w else 0, + 'y': round(disp_y / img_h * 100, 2) if img_h else 0, + 'w': round(disp_w / img_w * 100, 2) if img_w else 0, + 'h': round(disp_h / img_h * 100, 2) if img_h else 0, + }, + 'ocr_engine': 'cell_crop_v2', + } + + if cw <= 0 or ch <= 0: + return empty_cell + + # --- Pixel-density check: skip truly empty cells --- + if ocr_img is not None: + crop = ocr_img[cy:cy + ch, cx:cx + cw] + if crop.size > 0: + dark_ratio = float(np.count_nonzero(crop < 180)) / crop.size + if dark_ratio < 0.005: + return empty_cell + + # --- Prepare crop for OCR --- + cell_lang = lang_map.get(col.type, lang) + psm = _select_psm_for_column(col.type, col.width, row.height) + text = '' + avg_conf = 0.0 + used_engine = 'cell_crop_v2' + + if engine_name in ("trocr-printed", "trocr-handwritten") and img_bgr is not None: + cell_region = PageRegion(type=col.type, x=cx, y=cy, width=cw, height=ch) + words = ocr_region_trocr(img_bgr, cell_region, + handwritten=(engine_name == "trocr-handwritten")) + elif engine_name == "lighton" and img_bgr is not None: + cell_region = PageRegion(type=col.type, x=cx, y=cy, width=cw, height=ch) + words = ocr_region_lighton(img_bgr, cell_region) + elif engine_name == "rapid" and img_bgr is not None: + cell_region = PageRegion(type=col.type, x=cx, y=cy, width=cw, height=ch) + words = ocr_region_rapid(img_bgr, cell_region) + else: + # Tesseract: upscale tiny crops for better recognition + if ocr_img is not None: + crop_slice = ocr_img[cy:cy + ch, cx:cx + cw] + upscaled = _ensure_minimum_crop_size(crop_slice) + up_h, up_w = upscaled.shape[:2] + tmp_region = PageRegion(type=col.type, x=0, y=0, width=up_w, height=up_h) + words = ocr_region(upscaled, tmp_region, lang=cell_lang, psm=psm) + # Remap word positions back to original image coordinates + if words and (up_w != cw or up_h != ch): + sx = cw / max(up_w, 1) + sy = ch / max(up_h, 1) + for w in words: + w['left'] = int(w['left'] * sx) + cx + w['top'] = int(w['top'] * sy) + cy + w['width'] = int(w['width'] * sx) + w['height'] = int(w['height'] * sy) + elif words: + for w in words: + w['left'] += cx + w['top'] += cy + else: + words = [] + + # Filter low-confidence words + _MIN_WORD_CONF = 30 + if words: + words = [w for w in words if w.get('conf', 0) >= _MIN_WORD_CONF] + + if words: + y_tol = max(15, ch) + text = _words_to_reading_order_text(words, y_tolerance_px=y_tol) + avg_conf = round(sum(w['conf'] for w in words) / len(words), 1) + + # --- PSM 7 fallback for still-empty Tesseract cells --- + if not text.strip() and engine_name == "tesseract" and ocr_img is not None: + crop_slice = ocr_img[cy:cy + ch, cx:cx + cw] + upscaled = _ensure_minimum_crop_size(crop_slice) + up_h, up_w = upscaled.shape[:2] + tmp_region = PageRegion(type=col.type, x=0, y=0, width=up_w, height=up_h) + psm7_words = ocr_region(upscaled, tmp_region, lang=cell_lang, psm=7) + if psm7_words: + psm7_words = [w for w in psm7_words if w.get('conf', 0) >= _MIN_WORD_CONF] + if psm7_words: + p7_text = _words_to_reading_order_text(psm7_words, y_tolerance_px=10) + if p7_text.strip(): + text = p7_text + avg_conf = round( + sum(w['conf'] for w in psm7_words) / len(psm7_words), 1 + ) + used_engine = 'cell_crop_v2_psm7' + + # --- Noise filter --- + if text.strip(): + text = _clean_cell_text_lite(text) + if not text: + avg_conf = 0.0 + + result = dict(empty_cell) + result['text'] = text + result['confidence'] = avg_conf + result['ocr_engine'] = used_engine + return result + + +def build_cell_grid_v2( + ocr_img: np.ndarray, + column_regions: List[PageRegion], + row_geometries: List[RowGeometry], + img_w: int, + img_h: int, + lang: str = "eng+deu", + ocr_engine: str = "auto", + img_bgr: Optional[np.ndarray] = None, +) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + """Cell-First Grid: crop each cell in isolation, then OCR. + + Drop-in replacement for build_cell_grid() — same signature & return type. + No full-page word assignment; each cell is OCR'd from its own crop. + """ + # Resolve engine + use_rapid = False + if ocr_engine in ("trocr-printed", "trocr-handwritten", "lighton"): + engine_name = ocr_engine + elif ocr_engine == "auto": + use_rapid = RAPIDOCR_AVAILABLE and img_bgr is not None + engine_name = "rapid" if use_rapid else "tesseract" + elif ocr_engine == "rapid": + if not RAPIDOCR_AVAILABLE: + logger.warning("RapidOCR requested but not available, falling back to Tesseract") + else: + use_rapid = True + engine_name = "rapid" if use_rapid else "tesseract" + else: + engine_name = "tesseract" + + logger.info(f"build_cell_grid_v2: using OCR engine '{engine_name}'") + + # Filter to content rows only + content_rows = [r for r in row_geometries if r.row_type == 'content'] + if not content_rows: + logger.warning("build_cell_grid_v2: no content rows found") + return [], [] + + # Filter phantom rows (word_count=0) and artifact rows + before = len(content_rows) + content_rows = [r for r in content_rows if r.word_count > 0] + skipped = before - len(content_rows) + if skipped > 0: + logger.info(f"build_cell_grid_v2: skipped {skipped} phantom rows (word_count=0)") + if not content_rows: + logger.warning("build_cell_grid_v2: no content rows with words found") + return [], [] + + before_art = len(content_rows) + content_rows = [r for r in content_rows if not _is_artifact_row(r)] + artifact_skipped = before_art - len(content_rows) + if artifact_skipped > 0: + logger.info(f"build_cell_grid_v2: skipped {artifact_skipped} artifact rows") + if not content_rows: + logger.warning("build_cell_grid_v2: no content rows after artifact filtering") + return [], [] + + # Filter columns + _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', + 'margin_bottom', 'margin_left', 'margin_right'} + relevant_cols = [c for c in column_regions if c.type not in _skip_types] + if not relevant_cols: + logger.warning("build_cell_grid_v2: no usable columns found") + return [], [] + + # Heal row gaps + _heal_row_gaps( + content_rows, + top_bound=min(c.y for c in relevant_cols), + bottom_bound=max(c.y + c.height for c in relevant_cols), + ) + + relevant_cols.sort(key=lambda c: c.x) + + columns_meta = [ + {'index': ci, 'type': c.type, 'x': c.x, 'width': c.width} + for ci, c in enumerate(relevant_cols) + ] + + lang_map = { + 'column_en': 'eng', + 'column_de': 'deu', + 'column_example': 'eng+deu', + } + + # --- Parallel OCR with ThreadPoolExecutor --- + # Tesseract is single-threaded per call, so we benefit from parallelism. + # ~40 rows × 4 cols = 160 cells, ~50% empty (density skip) → ~80 OCR calls. + cells: List[Dict[str, Any]] = [] + cell_tasks = [] + + for row_idx, row in enumerate(content_rows): + for col_idx, col in enumerate(relevant_cols): + cell_tasks.append((row_idx, col_idx, row, col)) + + max_workers = 4 if engine_name == "tesseract" else 2 + + with ThreadPoolExecutor(max_workers=max_workers) as pool: + futures = { + pool.submit( + _ocr_cell_crop, + ri, ci, row, col, + ocr_img, img_bgr, img_w, img_h, + engine_name, lang, lang_map, + ): (ri, ci) + for ri, ci, row, col in cell_tasks + } + + for future in as_completed(futures): + try: + cell = future.result() + cells.append(cell) + except Exception as e: + ri, ci = futures[future] + logger.error(f"build_cell_grid_v2: cell R{ri:02d}_C{ci} failed: {e}") + + # Sort cells by (row_index, col_index) since futures complete out of order + cells.sort(key=lambda c: (c['row_index'], c['col_index'])) + + # Remove all-empty rows + rows_with_text: set = set() + for cell in cells: + if cell['text'].strip(): + rows_with_text.add(cell['row_index']) + before_filter = len(cells) + cells = [c for c in cells if c['row_index'] in rows_with_text] + empty_rows_removed = (before_filter - len(cells)) // max(len(relevant_cols), 1) + if empty_rows_removed > 0: + logger.info(f"build_cell_grid_v2: removed {empty_rows_removed} all-empty rows") + + logger.info(f"build_cell_grid_v2: {len(cells)} cells from " + f"{len(content_rows)} rows × {len(relevant_cols)} columns, " + f"engine={engine_name}") + + return cells, columns_meta + + +def build_cell_grid_v2_streaming( + ocr_img: np.ndarray, + column_regions: List[PageRegion], + row_geometries: List[RowGeometry], + img_w: int, + img_h: int, + lang: str = "eng+deu", + ocr_engine: str = "auto", + img_bgr: Optional[np.ndarray] = None, +) -> Generator[Tuple[Dict[str, Any], List[Dict[str, Any]], int], None, None]: + """Streaming variant of build_cell_grid_v2 — yields each cell as OCR'd. + + Yields: + (cell_dict, columns_meta, total_cells) + """ + # Resolve engine + use_rapid = False + if ocr_engine in ("trocr-printed", "trocr-handwritten", "lighton"): + engine_name = ocr_engine + elif ocr_engine == "auto": + use_rapid = RAPIDOCR_AVAILABLE and img_bgr is not None + engine_name = "rapid" if use_rapid else "tesseract" + elif ocr_engine == "rapid": + if not RAPIDOCR_AVAILABLE: + logger.warning("RapidOCR requested but not available, falling back to Tesseract") + else: + use_rapid = True + engine_name = "rapid" if use_rapid else "tesseract" + else: + engine_name = "tesseract" + + content_rows = [r for r in row_geometries if r.row_type == 'content'] + if not content_rows: + return + + content_rows = [r for r in content_rows if r.word_count > 0] + if not content_rows: + return + + _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', + 'margin_bottom', 'margin_left', 'margin_right'} + relevant_cols = [c for c in column_regions if c.type not in _skip_types] + if not relevant_cols: + return + + content_rows = [r for r in content_rows if not _is_artifact_row(r)] + if not content_rows: + return + + _heal_row_gaps( + content_rows, + top_bound=min(c.y for c in relevant_cols), + bottom_bound=max(c.y + c.height for c in relevant_cols), + ) + + relevant_cols.sort(key=lambda c: c.x) + + columns_meta = [ + {'index': ci, 'type': c.type, 'x': c.x, 'width': c.width} + for ci, c in enumerate(relevant_cols) + ] + + lang_map = { + 'column_en': 'eng', + 'column_de': 'deu', + 'column_example': 'eng+deu', + } + + total_cells = len(content_rows) * len(relevant_cols) + + for row_idx, row in enumerate(content_rows): + for col_idx, col in enumerate(relevant_cols): + cell = _ocr_cell_crop( + row_idx, col_idx, row, col, + ocr_img, img_bgr, img_w, img_h, + engine_name, lang, lang_map, + ) + yield cell, columns_meta, total_cells + + +# --------------------------------------------------------------------------- +# Narrow-column OCR helpers (Proposal B) — DEPRECATED (kept for legacy build_cell_grid) # --------------------------------------------------------------------------- def _compute_cell_padding(col_width: int, img_w: int) -> int: diff --git a/klausur-service/backend/ocr_pipeline_api.py b/klausur-service/backend/ocr_pipeline_api.py index c02f36e..b1636e5 100644 --- a/klausur-service/backend/ocr_pipeline_api.py +++ b/klausur-service/backend/ocr_pipeline_api.py @@ -32,6 +32,7 @@ from pydantic import BaseModel from cv_vocab_pipeline import ( OLLAMA_REVIEW_MODEL, + DocumentTypeResult, PageRegion, RowGeometry, _cells_to_vocab_entries, @@ -43,6 +44,8 @@ from cv_vocab_pipeline import ( analyze_layout_by_words, build_cell_grid, build_cell_grid_streaming, + build_cell_grid_v2, + build_cell_grid_v2_streaming, build_word_grid, classify_column_types, create_layout_image, @@ -50,6 +53,7 @@ from cv_vocab_pipeline import ( deskew_image, deskew_image_by_word_alignment, detect_column_geometry, + detect_document_type, detect_row_geometry, expand_narrow_columns, _apply_shear, @@ -759,6 +763,54 @@ async def save_dewarp_ground_truth(session_id: str, req: DewarpGroundTruthReques return {"session_id": session_id, "ground_truth": gt} +# --------------------------------------------------------------------------- +# Document Type Detection (between Dewarp and Columns) +# --------------------------------------------------------------------------- + +@router.post("/sessions/{session_id}/detect-type") +async def detect_type(session_id: str): + """Detect document type (vocab_table, full_text, generic_table). + + Should be called after dewarp (clean image available). + Stores result in session for frontend to decide pipeline flow. + """ + if session_id not in _cache: + await _load_session_to_cache(session_id) + cached = _get_cached(session_id) + + dewarped_bgr = cached.get("dewarped_bgr") + if dewarped_bgr is None: + raise HTTPException(status_code=400, detail="Dewarp must be completed first") + + t0 = time.time() + ocr_img = create_ocr_image(dewarped_bgr) + result = detect_document_type(ocr_img, dewarped_bgr) + duration = time.time() - t0 + + result_dict = { + "doc_type": result.doc_type, + "confidence": result.confidence, + "pipeline": result.pipeline, + "skip_steps": result.skip_steps, + "features": result.features, + "duration_seconds": round(duration, 2), + } + + # Persist to DB + await update_session_db( + session_id, + doc_type=result.doc_type, + doc_type_result=result_dict, + ) + + cached["doc_type_result"] = result_dict + + logger.info(f"OCR Pipeline: detect-type session {session_id}: " + f"{result.doc_type} (confidence={result.confidence}, {duration:.2f}s)") + + return {"session_id": session_id, **result_dict} + + # --------------------------------------------------------------------------- # Column Detection Endpoints (Step 3) # --------------------------------------------------------------------------- @@ -1196,8 +1248,10 @@ async def detect_words( for r in row_result["rows"] ] - # Re-populate row.words from cached full-page Tesseract words. - # Word-lookup in _ocr_single_cell needs these to avoid re-running OCR. + # Cell-First OCR (v2): no full-page word re-population needed. + # Each cell is cropped and OCR'd in isolation → no neighbour bleeding. + # We still need word_count > 0 for row filtering in build_cell_grid_v2, + # so populate from cached words if available (just for counting). word_dicts = cached.get("_word_dicts") if word_dicts is None: ocr_img_tmp = create_ocr_image(dewarped_bgr) @@ -1209,8 +1263,6 @@ async def detect_words( cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y) if word_dicts: - # words['top'] is relative to content-ROI top_y. - # row.y is absolute. Convert: row_y_rel = row.y - top_y. content_bounds = cached.get("_content_bounds") if content_bounds: _lx, _rx, top_y, _by = content_bounds @@ -1240,15 +1292,15 @@ async def detect_words( }, ) - # --- Non-streaming path (unchanged) --- + # --- Non-streaming path --- t0 = time.time() # Create binarized OCR image (for Tesseract) ocr_img = create_ocr_image(dewarped_bgr) img_h, img_w = dewarped_bgr.shape[:2] - # Build generic cell grid - cells, columns_meta = build_cell_grid( + # Build cell grid using Cell-First OCR (v2) — each cell cropped in isolation + cells, columns_meta = build_cell_grid_v2( ocr_img, col_regions, row_geoms, img_w, img_h, ocr_engine=engine, img_bgr=dewarped_bgr, ) @@ -1358,7 +1410,7 @@ async def _word_stream_generator( all_cells: List[Dict[str, Any]] = [] cell_idx = 0 - for cell, cols_meta, total in build_cell_grid_streaming( + for cell, cols_meta, total in build_cell_grid_v2_streaming( ocr_img, col_regions, row_geoms, img_w, img_h, ocr_engine=engine, img_bgr=dewarped_bgr, ): diff --git a/klausur-service/backend/ocr_pipeline_session_store.py b/klausur-service/backend/ocr_pipeline_session_store.py index 8c58def..254d662 100644 --- a/klausur-service/backend/ocr_pipeline_session_store.py +++ b/klausur-service/backend/ocr_pipeline_session_store.py @@ -64,7 +64,9 @@ async def init_ocr_pipeline_tables(): await conn.execute(""" ALTER TABLE ocr_pipeline_sessions ADD COLUMN IF NOT EXISTS clean_png BYTEA, - ADD COLUMN IF NOT EXISTS handwriting_removal_meta JSONB + ADD COLUMN IF NOT EXISTS handwriting_removal_meta JSONB, + ADD COLUMN IF NOT EXISTS doc_type VARCHAR(50), + ADD COLUMN IF NOT EXISTS doc_type_result JSONB """) @@ -88,6 +90,7 @@ async def create_session_db( RETURNING id, name, filename, status, current_step, deskew_result, dewarp_result, column_result, row_result, word_result, ground_truth, auto_shear_degrees, + doc_type, doc_type_result, created_at, updated_at """, uuid.UUID(session_id), name, filename, original_png) @@ -102,6 +105,7 @@ async def get_session_db(session_id: str) -> Optional[Dict[str, Any]]: SELECT id, name, filename, status, current_step, deskew_result, dewarp_result, column_result, row_result, word_result, ground_truth, auto_shear_degrees, + doc_type, doc_type_result, created_at, updated_at FROM ocr_pipeline_sessions WHERE id = $1 """, uuid.UUID(session_id)) @@ -146,9 +150,10 @@ async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any 'clean_png', 'handwriting_removal_meta', 'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'auto_shear_degrees', + 'doc_type', 'doc_type_result', } - jsonb_fields = {'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'handwriting_removal_meta'} + jsonb_fields = {'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'handwriting_removal_meta', 'doc_type_result'} for key, value in kwargs.items(): if key in allowed_fields: @@ -174,6 +179,7 @@ async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any RETURNING id, name, filename, status, current_step, deskew_result, dewarp_result, column_result, row_result, word_result, ground_truth, auto_shear_degrees, + doc_type, doc_type_result, created_at, updated_at """, *values) @@ -229,7 +235,7 @@ def _row_to_dict(row: asyncpg.Record) -> Dict[str, Any]: result[key] = result[key].isoformat() # JSONB → parsed (asyncpg returns str for JSONB) - for key in ['deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth']: + for key in ['deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'doc_type_result']: if key in result and result[key] is not None: if isinstance(result[key], str): result[key] = json.loads(result[key]) diff --git a/klausur-service/backend/tests/test_cv_vocab_pipeline.py b/klausur-service/backend/tests/test_cv_vocab_pipeline.py index 1a602c2..67ca888 100644 --- a/klausur-service/backend/tests/test_cv_vocab_pipeline.py +++ b/klausur-service/backend/tests/test_cv_vocab_pipeline.py @@ -25,7 +25,9 @@ from dataclasses import asdict # Import module under test from cv_vocab_pipeline import ( ColumnGeometry, + DocumentTypeResult, PageRegion, + RowGeometry, VocabRow, PipelineResult, deskew_image, @@ -48,9 +50,12 @@ from cv_vocab_pipeline import ( CV_PIPELINE_AVAILABLE, _is_noise_tail_token, _clean_cell_text, + _clean_cell_text_lite, _is_phonetic_only_text, _merge_phonetic_continuation_rows, _merge_continuation_rows, + _ocr_cell_crop, + detect_document_type, ) @@ -1566,6 +1571,167 @@ class TestCellsToVocabEntriesPageRef: assert entries[0]['source_page'] == 'p.59' +# ============================================= +# CELL-FIRST OCR (v2) TESTS +# ============================================= + +class TestCleanCellTextLite: + """Tests for _clean_cell_text_lite() — simplified noise filter.""" + + def test_empty_string(self): + assert _clean_cell_text_lite('') == '' + + def test_whitespace_only(self): + assert _clean_cell_text_lite(' ') == '' + + def test_real_word_passes(self): + assert _clean_cell_text_lite('hello') == 'hello' + + def test_sentence_passes(self): + assert _clean_cell_text_lite('to have dinner') == 'to have dinner' + + def test_garbage_text_cleared(self): + """Garbage text (no dictionary words) should be cleared.""" + assert _clean_cell_text_lite('xqzjk') == '' + + def test_no_real_word_cleared(self): + """Single chars with no real word (2+ letters) cleared.""" + assert _clean_cell_text_lite('3') == '' + assert _clean_cell_text_lite('|') == '' + + def test_known_abbreviation_kept(self): + """Known abbreviations should pass through.""" + assert _clean_cell_text_lite('sth') == 'sth' + assert _clean_cell_text_lite('eg') == 'eg' + + def test_no_trailing_noise_stripping(self): + """Unlike _clean_cell_text, lite does NOT strip trailing tokens. + Since cells are isolated, all tokens are legitimate.""" + result = _clean_cell_text_lite('apple tree') + assert result == 'apple tree' + + def test_page_reference(self): + """Page references like p.60 should pass.""" + # 'p' is a known abbreviation + assert _clean_cell_text_lite('p.60') != '' + + +class TestOcrCellCrop: + """Tests for _ocr_cell_crop() — isolated cell OCR.""" + + def test_empty_cell_pixel_density(self): + """Cells with very few dark pixels should return empty text.""" + # All white image → no text + ocr_img = np.ones((400, 600), dtype=np.uint8) * 255 + row = RowGeometry(index=0, x=0, y=50, width=600, height=30, + word_count=1, words=[{'text': 'a'}]) + col = PageRegion(type='column_en', x=50, y=0, width=200, height=400) + + result = _ocr_cell_crop( + 0, 0, row, col, ocr_img, None, 600, 400, + 'tesseract', 'eng+deu', {'column_en': 'eng'}, + ) + assert result['text'] == '' + assert result['cell_id'] == 'R00_C0' + assert result['col_type'] == 'column_en' + + def test_zero_width_cell(self): + """Zero-width cells should return empty.""" + ocr_img = np.ones((400, 600), dtype=np.uint8) * 255 + row = RowGeometry(index=0, x=0, y=50, width=600, height=30, + word_count=1, words=[]) + col = PageRegion(type='column_en', x=50, y=0, width=0, height=400) + + result = _ocr_cell_crop( + 0, 0, row, col, ocr_img, None, 600, 400, + 'tesseract', 'eng+deu', {}, + ) + assert result['text'] == '' + + def test_bbox_calculation(self): + """Check bbox_px and bbox_pct are correct.""" + ocr_img = np.ones((1000, 2000), dtype=np.uint8) * 255 + row = RowGeometry(index=0, x=0, y=100, width=2000, height=50, + word_count=1, words=[{'text': 'test'}]) + col = PageRegion(type='column_de', x=400, y=0, width=600, height=1000) + + result = _ocr_cell_crop( + 0, 0, row, col, ocr_img, None, 2000, 1000, + 'tesseract', 'eng+deu', {'column_de': 'deu'}, + ) + assert result['bbox_px'] == {'x': 400, 'y': 100, 'w': 600, 'h': 50} + assert result['bbox_pct']['x'] == 20.0 # 400/2000*100 + assert result['bbox_pct']['y'] == 10.0 # 100/1000*100 + + +class TestDetectDocumentType: + """Tests for detect_document_type() — image-based classification.""" + + def test_empty_image(self): + """Empty image should default to full_text.""" + empty = np.array([], dtype=np.uint8).reshape(0, 0) + result = detect_document_type(empty, empty) + assert result.doc_type == 'full_text' + assert result.pipeline == 'full_page' + + def test_table_image_detected(self): + """Image with clear column gaps and row gaps → table.""" + # Create 600x400 binary image with 3 columns separated by white gaps + img = np.ones((400, 600), dtype=np.uint8) * 255 + # Column 1: x=20..170 + for y in range(30, 370, 20): + img[y:y+10, 20:170] = 0 + # Gap: x=170..210 (white) + # Column 2: x=210..370 + for y in range(30, 370, 20): + img[y:y+10, 210:370] = 0 + # Gap: x=370..410 (white) + # Column 3: x=410..580 + for y in range(30, 370, 20): + img[y:y+10, 410:580] = 0 + + bgr = np.stack([img, img, img], axis=-1) + result = detect_document_type(img, bgr) + assert result.doc_type in ('vocab_table', 'generic_table') + assert result.pipeline == 'cell_first' + assert result.confidence >= 0.5 + + def test_fulltext_image_detected(self): + """Uniform text without column gaps → full_text.""" + img = np.ones((400, 600), dtype=np.uint8) * 255 + # Uniform text lines across full width (no column gaps) + for y in range(30, 370, 15): + img[y:y+8, 30:570] = 0 + + bgr = np.stack([img, img, img], axis=-1) + result = detect_document_type(img, bgr) + assert result.doc_type == 'full_text' + assert result.pipeline == 'full_page' + assert 'columns' in result.skip_steps + assert 'rows' in result.skip_steps + + def test_result_has_features(self): + """Result should contain debug features.""" + img = np.ones((200, 300), dtype=np.uint8) * 255 + bgr = np.stack([img, img, img], axis=-1) + result = detect_document_type(img, bgr) + assert 'vertical_gaps' in result.features + assert 'row_gaps' in result.features + assert 'density_mean' in result.features + assert 'density_std' in result.features + + def test_document_type_result_dataclass(self): + """DocumentTypeResult dataclass should initialize correctly.""" + r = DocumentTypeResult( + doc_type='vocab_table', + confidence=0.9, + pipeline='cell_first', + ) + assert r.doc_type == 'vocab_table' + assert r.skip_steps == [] + assert r.features == {} + + # ============================================= # RUN TESTS # =============================================