feat: Persistente Sessions (PostgreSQL) + Spaltenerkennung (Step 3)

Sessions werden jetzt in PostgreSQL gespeichert statt in-memory.
Neue Session-Liste mit Name, Datum, Schritt. Sessions ueberleben
Browser-Refresh und Container-Neustart. Step 3 nutzt analyze_layout()
fuer automatische Spaltenerkennung mit farbigem Overlay.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-02-26 22:16:37 +01:00
parent 09b820efbe
commit aa06ae0f61
9 changed files with 1233 additions and 130 deletions

View File

@@ -1,6 +1,6 @@
'use client'
import { useState } from 'react'
import { useCallback, useEffect, useState } from 'react'
import { PagePurpose } from '@/components/common/PagePurpose'
import { PipelineStepper } from '@/components/ocr-pipeline/PipelineStepper'
import { StepDeskew } from '@/components/ocr-pipeline/StepDeskew'
@@ -10,11 +10,18 @@ import { StepWordRecognition } from '@/components/ocr-pipeline/StepWordRecogniti
import { StepCoordinates } from '@/components/ocr-pipeline/StepCoordinates'
import { StepReconstruction } from '@/components/ocr-pipeline/StepReconstruction'
import { StepGroundTruth } from '@/components/ocr-pipeline/StepGroundTruth'
import { PIPELINE_STEPS, type PipelineStep } from './types'
import { PIPELINE_STEPS, type PipelineStep, type SessionListItem } from './types'
const KLAUSUR_API = '/klausur-api'
export default function OcrPipelinePage() {
const [currentStep, setCurrentStep] = useState(0)
const [sessionId, setSessionId] = useState<string | null>(null)
const [sessionName, setSessionName] = useState<string>('')
const [sessions, setSessions] = useState<SessionListItem[]>([])
const [loadingSessions, setLoadingSessions] = useState(true)
const [editingName, setEditingName] = useState<string | null>(null)
const [editNameValue, setEditNameValue] = useState('')
const [steps, setSteps] = useState<PipelineStep[]>(
PIPELINE_STEPS.map((s, i) => ({
...s,
@@ -22,6 +29,82 @@ export default function OcrPipelinePage() {
})),
)
// Load session list on mount
useEffect(() => {
loadSessions()
}, [])
const loadSessions = async () => {
setLoadingSessions(true)
try {
const res = await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions`)
if (res.ok) {
const data = await res.json()
setSessions(data.sessions || [])
}
} catch (e) {
console.error('Failed to load sessions:', e)
} finally {
setLoadingSessions(false)
}
}
const openSession = useCallback(async (sid: string) => {
try {
const res = await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sid}`)
if (!res.ok) return
const data = await res.json()
setSessionId(sid)
setSessionName(data.name || data.filename || '')
// Determine which step to jump to based on current_step
const dbStep = data.current_step || 1
// Steps: 1=deskew, 2=dewarp, 3=columns, ...
// UI steps are 0-indexed: 0=deskew, 1=dewarp, 2=columns, ...
const uiStep = Math.max(0, dbStep - 1)
setSteps(
PIPELINE_STEPS.map((s, i) => ({
...s,
status: i < uiStep ? 'completed' : i === uiStep ? 'active' : 'pending',
})),
)
setCurrentStep(uiStep)
} catch (e) {
console.error('Failed to open session:', e)
}
}, [])
const deleteSession = useCallback(async (sid: string) => {
try {
await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sid}`, { method: 'DELETE' })
setSessions((prev) => prev.filter((s) => s.id !== sid))
if (sessionId === sid) {
setSessionId(null)
setCurrentStep(0)
setSteps(PIPELINE_STEPS.map((s, i) => ({ ...s, status: i === 0 ? 'active' : 'pending' })))
}
} catch (e) {
console.error('Failed to delete session:', e)
}
}, [sessionId])
const renameSession = useCallback(async (sid: string, newName: string) => {
try {
await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sid}`, {
method: 'PUT',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ name: newName }),
})
setSessions((prev) => prev.map((s) => (s.id === sid ? { ...s, name: newName } : s)))
if (sessionId === sid) setSessionName(newName)
} catch (e) {
console.error('Failed to rename session:', e)
}
setEditingName(null)
}, [sessionId])
const handleStepClick = (index: number) => {
if (index <= currentStep || steps[index].status === 'completed') {
setCurrentStep(index)
@@ -43,9 +126,28 @@ export default function OcrPipelinePage() {
const handleDeskewComplete = (sid: string) => {
setSessionId(sid)
// Reload session list to show the new session
loadSessions()
handleNext()
}
const handleNewSession = () => {
setSessionId(null)
setSessionName('')
setCurrentStep(0)
setSteps(PIPELINE_STEPS.map((s, i) => ({ ...s, status: i === 0 ? 'active' : 'pending' })))
}
const stepNames: Record<number, string> = {
1: 'Begradigung',
2: 'Entzerrung',
3: 'Spalten',
4: 'Woerter',
5: 'Koordinaten',
6: 'Rekonstruktion',
7: 'Validierung',
}
const renderStep = () => {
switch (currentStep) {
case 0:
@@ -53,7 +155,7 @@ export default function OcrPipelinePage() {
case 1:
return <StepDewarp sessionId={sessionId} onNext={handleNext} />
case 2:
return <StepColumnDetection />
return <StepColumnDetection sessionId={sessionId} onNext={handleNext} />
case 3:
return <StepWordRecognition />
case 4:
@@ -75,7 +177,7 @@ export default function OcrPipelinePage() {
audience={['Entwickler', 'Data Scientists']}
architecture={{
services: ['klausur-service (FastAPI)', 'OpenCV', 'Tesseract'],
databases: ['In-Memory Sessions'],
databases: ['PostgreSQL Sessions'],
}}
relatedPages={[
{ name: 'OCR Vergleich', href: '/ai/ocr-compare', description: 'Methoden-Vergleich' },
@@ -84,6 +186,97 @@ export default function OcrPipelinePage() {
defaultCollapsed
/>
{/* Session List */}
<div className="bg-white dark:bg-gray-800 rounded-xl border border-gray-200 dark:border-gray-700 p-4">
<div className="flex items-center justify-between mb-3">
<h3 className="text-sm font-medium text-gray-700 dark:text-gray-300">
Sessions
</h3>
<button
onClick={handleNewSession}
className="text-xs px-3 py-1.5 bg-teal-600 text-white rounded-lg hover:bg-teal-700 transition-colors"
>
+ Neue Session
</button>
</div>
{loadingSessions ? (
<div className="text-sm text-gray-400 py-2">Lade Sessions...</div>
) : sessions.length === 0 ? (
<div className="text-sm text-gray-400 py-2">Noch keine Sessions vorhanden.</div>
) : (
<div className="space-y-1 max-h-48 overflow-y-auto">
{sessions.map((s) => (
<div
key={s.id}
className={`flex items-center gap-2 px-3 py-2 rounded-lg text-sm transition-colors cursor-pointer ${
sessionId === s.id
? 'bg-teal-50 dark:bg-teal-900/30 border border-teal-200 dark:border-teal-700'
: 'hover:bg-gray-50 dark:hover:bg-gray-700/50'
}`}
>
<div className="flex-1 min-w-0" onClick={() => openSession(s.id)}>
{editingName === s.id ? (
<input
autoFocus
value={editNameValue}
onChange={(e) => setEditNameValue(e.target.value)}
onBlur={() => renameSession(s.id, editNameValue)}
onKeyDown={(e) => {
if (e.key === 'Enter') renameSession(s.id, editNameValue)
if (e.key === 'Escape') setEditingName(null)
}}
onClick={(e) => e.stopPropagation()}
className="w-full px-1 py-0.5 text-sm border rounded dark:bg-gray-700 dark:border-gray-600"
/>
) : (
<div className="truncate font-medium text-gray-700 dark:text-gray-300">
{s.name || s.filename}
</div>
)}
<div className="text-xs text-gray-400 flex gap-2">
<span>{new Date(s.created_at).toLocaleDateString('de-DE', { day: '2-digit', month: '2-digit', year: '2-digit', hour: '2-digit', minute: '2-digit' })}</span>
<span>Schritt {s.current_step}: {stepNames[s.current_step] || '?'}</span>
</div>
</div>
<button
onClick={(e) => {
e.stopPropagation()
setEditNameValue(s.name || s.filename)
setEditingName(s.id)
}}
className="p-1 text-gray-400 hover:text-gray-600 dark:hover:text-gray-300"
title="Umbenennen"
>
<svg className="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" strokeWidth={2}>
<path strokeLinecap="round" strokeLinejoin="round" d="M15.232 5.232l3.536 3.536m-2.036-5.036a2.5 2.5 0 113.536 3.536L6.5 21.036H3v-3.572L16.732 3.732z" />
</svg>
</button>
<button
onClick={(e) => {
e.stopPropagation()
if (confirm('Session loeschen?')) deleteSession(s.id)
}}
className="p-1 text-gray-400 hover:text-red-500"
title="Loeschen"
>
<svg className="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" strokeWidth={2}>
<path strokeLinecap="round" strokeLinejoin="round" d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16" />
</svg>
</button>
</div>
))}
</div>
)}
</div>
{/* Active session name */}
{sessionId && sessionName && (
<div className="text-sm text-gray-500 dark:text-gray-400">
Aktive Session: <span className="font-medium text-gray-700 dark:text-gray-300">{sessionName}</span>
</div>
)}
<PipelineStepper steps={steps} currentStep={currentStep} onStepClick={handleStepClick} />
<div className="min-h-[400px]">{renderStep()}</div>

View File

@@ -7,14 +7,27 @@ export interface PipelineStep {
status: PipelineStepStatus
}
export interface SessionListItem {
id: string
name: string
filename: string
status: string
current_step: number
created_at: string
updated_at?: string
}
export interface SessionInfo {
session_id: string
filename: string
name?: string
image_width: number
image_height: number
original_image_url: string
current_step?: number
deskew_result?: DeskewResult
dewarp_result?: DewarpResult
column_result?: ColumnResult
}
export interface DeskewResult {
@@ -50,6 +63,24 @@ export interface DewarpGroundTruth {
notes?: string
}
export interface PageRegion {
type: 'column_en' | 'column_de' | 'column_example' | 'header' | 'footer'
x: number
y: number
width: number
height: number
}
export interface ColumnResult {
columns: PageRegion[]
duration_seconds: number
}
export interface ColumnGroundTruth {
is_correct: boolean
notes?: string
}
export const PIPELINE_STEPS: PipelineStep[] = [
{ id: 'deskew', name: 'Begradigung', icon: '📐', status: 'pending' },
{ id: 'dewarp', name: 'Entzerrung', icon: '🔧', status: 'pending' },

View File

@@ -0,0 +1,119 @@
'use client'
import { useState } from 'react'
import type { ColumnResult, ColumnGroundTruth, PageRegion } from '@/app/(admin)/ai/ocr-pipeline/types'
interface ColumnControlsProps {
columnResult: ColumnResult | null
onRerun: () => void
onGroundTruth: (gt: ColumnGroundTruth) => void
onNext: () => void
isDetecting: boolean
}
const TYPE_COLORS: Record<string, string> = {
column_en: 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-400',
column_de: 'bg-green-100 text-green-700 dark:bg-green-900/30 dark:text-green-400',
column_example: 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400',
header: 'bg-gray-100 text-gray-600 dark:bg-gray-700/50 dark:text-gray-400',
footer: 'bg-gray-100 text-gray-600 dark:bg-gray-700/50 dark:text-gray-400',
}
const TYPE_LABELS: Record<string, string> = {
column_en: 'EN',
column_de: 'DE',
column_example: 'Beispiel',
header: 'Header',
footer: 'Footer',
}
export function ColumnControls({ columnResult, onRerun, onGroundTruth, onNext, isDetecting }: ColumnControlsProps) {
const [gtSaved, setGtSaved] = useState(false)
if (!columnResult) return null
const columns = columnResult.columns.filter((c: PageRegion) => c.type.startsWith('column'))
const headerFooter = columnResult.columns.filter((c: PageRegion) => !c.type.startsWith('column'))
const handleGt = (isCorrect: boolean) => {
onGroundTruth({ is_correct: isCorrect })
setGtSaved(true)
}
return (
<div className="bg-white dark:bg-gray-800 rounded-xl border border-gray-200 dark:border-gray-700 p-4 space-y-4">
{/* Summary */}
<div className="flex items-center gap-3">
<div className="text-sm text-gray-600 dark:text-gray-400">
<span className="font-medium text-gray-800 dark:text-gray-200">{columns.length} Spalten</span> erkannt
{columnResult.duration_seconds > 0 && (
<span className="ml-2 text-xs">({columnResult.duration_seconds}s)</span>
)}
</div>
<button
onClick={onRerun}
disabled={isDetecting}
className="text-xs px-2 py-1 bg-gray-100 dark:bg-gray-700 rounded hover:bg-gray-200 dark:hover:bg-gray-600 transition-colors disabled:opacity-50"
>
Erneut erkennen
</button>
</div>
{/* Column list */}
<div className="space-y-2">
{columns.map((col: PageRegion, i: number) => (
<div key={i} className="flex items-center gap-3 text-sm">
<span className={`px-2 py-0.5 rounded text-xs font-medium ${TYPE_COLORS[col.type] || ''}`}>
{TYPE_LABELS[col.type] || col.type}
</span>
<span className="text-gray-500 dark:text-gray-400 text-xs font-mono">
x={col.x} y={col.y} {col.width}x{col.height}px
</span>
</div>
))}
{headerFooter.map((r: PageRegion, i: number) => (
<div key={`hf-${i}`} className="flex items-center gap-3 text-sm">
<span className={`px-2 py-0.5 rounded text-xs font-medium ${TYPE_COLORS[r.type] || ''}`}>
{TYPE_LABELS[r.type] || r.type}
</span>
<span className="text-gray-500 dark:text-gray-400 text-xs font-mono">
x={r.x} y={r.y} {r.width}x{r.height}px
</span>
</div>
))}
</div>
{/* Ground Truth + Navigation */}
<div className="flex items-center justify-between pt-2 border-t border-gray-100 dark:border-gray-700">
<div className="flex items-center gap-2">
<span className="text-sm text-gray-500 dark:text-gray-400">Spalten korrekt?</span>
{gtSaved ? (
<span className="text-xs text-green-600 dark:text-green-400">Gespeichert</span>
) : (
<>
<button
onClick={() => handleGt(true)}
className="text-xs px-3 py-1 bg-green-100 text-green-700 dark:bg-green-900/30 dark:text-green-400 rounded hover:bg-green-200 dark:hover:bg-green-900/50 transition-colors"
>
Ja
</button>
<button
onClick={() => handleGt(false)}
className="text-xs px-3 py-1 bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-400 rounded hover:bg-red-200 dark:hover:bg-red-900/50 transition-colors"
>
Nein
</button>
</>
)}
</div>
<button
onClick={onNext}
className="px-4 py-2 bg-teal-600 text-white rounded-lg hover:bg-teal-700 transition-colors text-sm font-medium"
>
Weiter
</button>
</div>
</div>
)
}

View File

@@ -1,19 +1,168 @@
'use client'
export function StepColumnDetection() {
return (
<div className="flex flex-col items-center justify-center py-16 text-center">
<div className="text-5xl mb-4">📊</div>
<h3 className="text-lg font-medium text-gray-700 dark:text-gray-300 mb-2">
Schritt 3: Spaltenerkennung
</h3>
<p className="text-gray-500 dark:text-gray-400 max-w-md">
Erkennung unsichtbarer Spaltentrennungen in der Vokabelseite.
Dieser Schritt wird in einer zukuenftigen Version implementiert.
</p>
<div className="mt-6 px-4 py-2 bg-amber-100 dark:bg-amber-900/30 text-amber-700 dark:text-amber-400 rounded-full text-sm font-medium">
Kommt bald
import { useCallback, useEffect, useState } from 'react'
import type { ColumnResult, ColumnGroundTruth } from '@/app/(admin)/ai/ocr-pipeline/types'
import { ColumnControls } from './ColumnControls'
const KLAUSUR_API = '/klausur-api'
interface StepColumnDetectionProps {
sessionId: string | null
onNext: () => void
}
export function StepColumnDetection({ sessionId, onNext }: StepColumnDetectionProps) {
const [columnResult, setColumnResult] = useState<ColumnResult | null>(null)
const [detecting, setDetecting] = useState(false)
const [error, setError] = useState<string | null>(null)
// Auto-trigger column detection on mount
useEffect(() => {
if (!sessionId || columnResult) return
const runDetection = async () => {
setDetecting(true)
setError(null)
try {
// First check if columns already detected (reload case)
const infoRes = await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}`)
if (infoRes.ok) {
const info = await infoRes.json()
if (info.column_result) {
setColumnResult(info.column_result)
setDetecting(false)
return
}
}
// Run detection
const res = await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/columns`, {
method: 'POST',
})
if (!res.ok) {
const err = await res.json().catch(() => ({ detail: res.statusText }))
throw new Error(err.detail || 'Spaltenerkennung fehlgeschlagen')
}
const data: ColumnResult = await res.json()
setColumnResult(data)
} catch (e) {
setError(e instanceof Error ? e.message : 'Unbekannter Fehler')
} finally {
setDetecting(false)
}
}
runDetection()
}, [sessionId, columnResult])
const handleRerun = useCallback(async () => {
if (!sessionId) return
setDetecting(true)
setError(null)
try {
const res = await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/columns`, {
method: 'POST',
})
if (!res.ok) throw new Error('Spaltenerkennung fehlgeschlagen')
const data: ColumnResult = await res.json()
setColumnResult(data)
} catch (e) {
setError(e instanceof Error ? e.message : 'Fehler')
} finally {
setDetecting(false)
}
}, [sessionId])
const handleGroundTruth = useCallback(async (gt: ColumnGroundTruth) => {
if (!sessionId) return
try {
await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/ground-truth/columns`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(gt),
})
} catch (e) {
console.error('Ground truth save failed:', e)
}
}, [sessionId])
if (!sessionId) {
return (
<div className="flex flex-col items-center justify-center py-16 text-center">
<div className="text-5xl mb-4">📊</div>
<h3 className="text-lg font-medium text-gray-700 dark:text-gray-300 mb-2">
Schritt 3: Spaltenerkennung
</h3>
<p className="text-gray-500 dark:text-gray-400 max-w-md">
Bitte zuerst Schritt 1 und 2 abschliessen.
</p>
</div>
)
}
const dewarpedUrl = `${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/image/dewarped`
const overlayUrl = `${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/image/columns-overlay`
return (
<div className="space-y-4">
{/* Loading indicator */}
{detecting && (
<div className="flex items-center gap-2 text-teal-600 dark:text-teal-400 text-sm">
<div className="animate-spin w-4 h-4 border-2 border-teal-500 border-t-transparent rounded-full" />
Spaltenerkennung laeuft...
</div>
)}
{/* Image comparison: overlay (left) vs clean (right) */}
<div className="grid grid-cols-2 gap-4">
<div>
<div className="text-xs font-medium text-gray-500 dark:text-gray-400 mb-1">
Mit Spalten-Overlay
</div>
<div className="border rounded-lg overflow-hidden dark:border-gray-700 bg-gray-50 dark:bg-gray-900">
{columnResult ? (
// eslint-disable-next-line @next/next/no-img-element
<img
src={`${overlayUrl}?t=${Date.now()}`}
alt="Spalten-Overlay"
className="w-full h-auto"
/>
) : (
<div className="aspect-[3/4] flex items-center justify-center text-gray-400 text-sm">
{detecting ? 'Erkenne Spalten...' : 'Keine Daten'}
</div>
)}
</div>
</div>
<div>
<div className="text-xs font-medium text-gray-500 dark:text-gray-400 mb-1">
Entzerrtes Bild
</div>
<div className="border rounded-lg overflow-hidden dark:border-gray-700 bg-gray-50 dark:bg-gray-900">
{/* eslint-disable-next-line @next/next/no-img-element */}
<img
src={dewarpedUrl}
alt="Entzerrt"
className="w-full h-auto"
/>
</div>
</div>
</div>
{/* Controls */}
<ColumnControls
columnResult={columnResult}
onRerun={handleRerun}
onGroundTruth={handleGroundTruth}
onNext={onNext}
isDetecting={detecting}
/>
{error && (
<div className="p-3 bg-red-50 dark:bg-red-900/20 text-red-600 dark:text-red-400 rounded-lg text-sm">
{error}
</div>
)}
</div>
)
}

View File

@@ -22,6 +22,7 @@ export function StepDeskew({ sessionId: existingSessionId, onNext }: StepDeskewP
const [showGrid, setShowGrid] = useState(true)
const [error, setError] = useState<string | null>(null)
const [dragOver, setDragOver] = useState(false)
const [sessionName, setSessionName] = useState('')
// Reload session data when navigating back from a later step
useEffect(() => {
@@ -67,6 +68,9 @@ export function StepDeskew({ sessionId: existingSessionId, onNext }: StepDeskewP
try {
const formData = new FormData()
formData.append('file', file)
if (sessionName.trim()) {
formData.append('name', sessionName.trim())
}
const res = await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions`, {
method: 'POST',
@@ -167,6 +171,20 @@ export function StepDeskew({ sessionId: existingSessionId, onNext }: StepDeskewP
if (!session) {
return (
<div className="space-y-4">
{/* Session name input */}
<div>
<label className="block text-sm font-medium text-gray-600 dark:text-gray-400 mb-1">
Session-Name (optional)
</label>
<input
type="text"
value={sessionName}
onChange={(e) => setSessionName(e.target.value)}
placeholder="z.B. Unit 3 Seite 42"
className="w-full max-w-sm px-3 py-2 text-sm border rounded-lg dark:bg-gray-800 dark:border-gray-600 dark:text-gray-200 focus:outline-none focus:ring-2 focus:ring-teal-500"
/>
</div>
<div
onDragOver={(e) => { e.preventDefault(); setDragOver(true) }}
onDragLeave={() => setDragOver(false)}

View File

@@ -43,6 +43,7 @@ except ImportError:
trocr_router = None
from vocab_worksheet_api import router as vocab_router, set_db_pool as set_vocab_db_pool, _init_vocab_table, _load_all_sessions, DATABASE_URL as VOCAB_DATABASE_URL
from ocr_pipeline_api import router as ocr_pipeline_router
from ocr_pipeline_session_store import init_ocr_pipeline_tables
try:
from dsfa_rag_api import router as dsfa_rag_router, set_db_pool as set_dsfa_db_pool
from dsfa_corpus_ingestion import DSFAQdrantService, DATABASE_URL as DSFA_DATABASE_URL
@@ -76,6 +77,13 @@ async def lifespan(app: FastAPI):
except Exception as e:
print(f"Warning: Vocab sessions database initialization failed: {e}")
# Initialize OCR Pipeline session tables
try:
await init_ocr_pipeline_tables()
print("OCR Pipeline session tables initialized")
except Exception as e:
print(f"Warning: OCR Pipeline tables initialization failed: {e}")
# Initialize database pool for DSFA RAG
dsfa_db_pool = None
if DSFA_DATABASE_URL and set_dsfa_db_pool:

View File

@@ -0,0 +1,28 @@
-- OCR Pipeline Sessions - Persistent session storage
-- Applied automatically by ocr_pipeline_session_store.init_ocr_pipeline_tables()
CREATE TABLE IF NOT EXISTS ocr_pipeline_sessions (
id UUID PRIMARY KEY,
name VARCHAR(255) NOT NULL,
filename VARCHAR(255),
status VARCHAR(50) DEFAULT 'active',
current_step INT DEFAULT 1,
original_png BYTEA,
deskewed_png BYTEA,
binarized_png BYTEA,
dewarped_png BYTEA,
deskew_result JSONB,
dewarp_result JSONB,
column_result JSONB,
ground_truth JSONB DEFAULT '{}',
auto_shear_degrees FLOAT,
created_at TIMESTAMP DEFAULT NOW(),
updated_at TIMESTAMP DEFAULT NOW()
);
-- Index for listing sessions
CREATE INDEX IF NOT EXISTS idx_ocr_pipeline_sessions_created
ON ocr_pipeline_sessions (created_at DESC);
CREATE INDEX IF NOT EXISTS idx_ocr_pipeline_sessions_status
ON ocr_pipeline_sessions (status);

View File

@@ -14,20 +14,21 @@ Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import io
import logging
import time
import uuid
from datetime import datetime, timedelta
from typing import Any, Dict, Optional
from dataclasses import asdict
from datetime import datetime
from typing import Any, Dict, List, Optional
import cv2
import numpy as np
from fastapi import APIRouter, File, HTTPException, UploadFile
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from fastapi.responses import Response
from pydantic import BaseModel
from cv_vocab_pipeline import (
analyze_layout,
create_ocr_image,
deskew_image,
deskew_image_by_word_alignment,
@@ -36,34 +37,67 @@ from cv_vocab_pipeline import (
render_image_high_res,
render_pdf_high_res,
)
from ocr_pipeline_session_store import (
create_session_db,
delete_session_db,
get_session_db,
get_session_image,
init_ocr_pipeline_tables,
list_sessions_db,
update_session_db,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# In-memory session store (24h TTL)
# In-memory cache for active sessions (BGR numpy arrays for processing)
# DB is source of truth, cache holds BGR arrays during active processing.
# ---------------------------------------------------------------------------
_sessions: Dict[str, Dict[str, Any]] = {}
SESSION_TTL_HOURS = 24
_cache: Dict[str, Dict[str, Any]] = {}
def _cleanup_expired():
"""Remove sessions older than TTL."""
cutoff = datetime.utcnow() - timedelta(hours=SESSION_TTL_HOURS)
expired = [sid for sid, s in _sessions.items() if s.get("created_at", datetime.utcnow()) < cutoff]
for sid in expired:
del _sessions[sid]
logger.info(f"OCR Pipeline: expired session {sid}")
def _get_session(session_id: str) -> Dict[str, Any]:
"""Get session or raise 404."""
session = _sessions.get(session_id)
async def _load_session_to_cache(session_id: str) -> Dict[str, Any]:
"""Load session from DB into cache, decoding PNGs to BGR arrays."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
return session
if session_id in _cache:
return _cache[session_id]
cache_entry: Dict[str, Any] = {
"id": session_id,
**session,
"original_bgr": None,
"deskewed_bgr": None,
"dewarped_bgr": None,
}
# Decode images from DB into BGR numpy arrays
for img_type, bgr_key in [
("original", "original_bgr"),
("deskewed", "deskewed_bgr"),
("dewarped", "dewarped_bgr"),
]:
png_data = await get_session_image(session_id, img_type)
if png_data:
arr = np.frombuffer(png_data, dtype=np.uint8)
bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR)
cache_entry[bgr_key] = bgr
_cache[session_id] = cache_entry
return cache_entry
def _get_cached(session_id: str) -> Dict[str, Any]:
"""Get from cache or raise 404."""
entry = _cache.get(session_id)
if not entry:
raise HTTPException(status_code=404, detail=f"Session {session_id} not in cache — reload first")
return entry
# ---------------------------------------------------------------------------
@@ -90,15 +124,36 @@ class DewarpGroundTruthRequest(BaseModel):
notes: Optional[str] = None
class RenameSessionRequest(BaseModel):
name: str
class ManualColumnsRequest(BaseModel):
columns: List[Dict[str, Any]]
class ColumnGroundTruthRequest(BaseModel):
is_correct: bool
notes: Optional[str] = None
# ---------------------------------------------------------------------------
# Endpoints
# Session Management Endpoints
# ---------------------------------------------------------------------------
@router.get("/sessions")
async def list_sessions():
"""List all OCR pipeline sessions."""
sessions = await list_sessions_db()
return {"sessions": sessions}
@router.post("/sessions")
async def create_session(file: UploadFile = File(...)):
async def create_session(
file: UploadFile = File(...),
name: Optional[str] = Form(None),
):
"""Upload a PDF or image file and create a pipeline session."""
_cleanup_expired()
file_data = await file.read()
filename = file.filename or "upload"
content_type = file.content_type or ""
@@ -114,25 +169,32 @@ async def create_session(file: UploadFile = File(...)):
except Exception as e:
raise HTTPException(status_code=400, detail=f"Could not process file: {e}")
# Encode original as PNG bytes for serving
# Encode original as PNG bytes
success, png_buf = cv2.imencode(".png", img_bgr)
if not success:
raise HTTPException(status_code=500, detail="Failed to encode image")
_sessions[session_id] = {
original_png = png_buf.tobytes()
session_name = name or filename
# Persist to DB
await create_session_db(
session_id=session_id,
name=session_name,
filename=filename,
original_png=original_png,
)
# Cache BGR array for immediate processing
_cache[session_id] = {
"id": session_id,
"filename": filename,
"created_at": datetime.utcnow(),
"name": session_name,
"original_bgr": img_bgr,
"original_png": png_buf.tobytes(),
"deskewed_bgr": None,
"deskewed_png": None,
"binarized_png": None,
"deskew_result": None,
"dewarped_bgr": None,
"dewarped_png": None,
"deskew_result": None,
"dewarp_result": None,
"auto_shear_degrees": None,
"ground_truth": {},
"current_step": 1,
}
@@ -143,6 +205,7 @@ async def create_session(file: UploadFile = File(...)):
return {
"session_id": session_id,
"filename": filename,
"name": session_name,
"image_width": img_bgr.shape[1],
"image_height": img_bgr.shape[0],
"original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
@@ -151,35 +214,106 @@ async def create_session(file: UploadFile = File(...)):
@router.get("/sessions/{session_id}")
async def get_session_info(session_id: str):
"""Get session info including deskew/dewarp results for step navigation."""
session = _get_session(session_id)
img_bgr = session["original_bgr"]
"""Get session info including deskew/dewarp/column results for step navigation."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
# Get image dimensions from original PNG
original_png = await get_session_image(session_id, "original")
if original_png:
arr = np.frombuffer(original_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
img_w, img_h = img.shape[1], img.shape[0] if img is not None else (0, 0)
else:
img_w, img_h = 0, 0
result = {
"session_id": session["id"],
"filename": session["filename"],
"image_width": img_bgr.shape[1],
"image_height": img_bgr.shape[0],
"filename": session.get("filename", ""),
"name": session.get("name", ""),
"image_width": img_w,
"image_height": img_h,
"original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
"current_step": session.get("current_step", 1),
}
# Include deskew result if available
if session.get("deskew_result"):
result["deskew_result"] = session["deskew_result"]
# Include dewarp result if available
if session.get("dewarp_result"):
result["dewarp_result"] = session["dewarp_result"]
if session.get("column_result"):
result["column_result"] = session["column_result"]
return result
@router.put("/sessions/{session_id}")
async def rename_session(session_id: str, req: RenameSessionRequest):
"""Rename a session."""
updated = await update_session_db(session_id, name=req.name)
if not updated:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
return {"session_id": session_id, "name": req.name}
@router.delete("/sessions/{session_id}")
async def delete_session(session_id: str):
"""Delete a session."""
_cache.pop(session_id, None)
deleted = await delete_session_db(session_id)
if not deleted:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
return {"session_id": session_id, "deleted": True}
# ---------------------------------------------------------------------------
# Image Endpoints
# ---------------------------------------------------------------------------
@router.get("/sessions/{session_id}/image/{image_type}")
async def get_image(session_id: str, image_type: str):
"""Serve session images: original, deskewed, dewarped, binarized, or columns-overlay."""
valid_types = {"original", "deskewed", "dewarped", "binarized", "columns-overlay"}
if image_type not in valid_types:
raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")
if image_type == "columns-overlay":
return await _get_columns_overlay(session_id)
# Try cache first for fast serving
cached = _cache.get(session_id)
if cached:
png_key = f"{image_type}_png" if image_type != "original" else None
bgr_key = f"{image_type}_bgr" if image_type != "binarized" else None
# For binarized, check if we have it cached as PNG
if image_type == "binarized" and cached.get("binarized_png"):
return Response(content=cached["binarized_png"], media_type="image/png")
# Load from DB
data = await get_session_image(session_id, image_type)
if not data:
raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet")
return Response(content=data, media_type="image/png")
# ---------------------------------------------------------------------------
# Deskew Endpoints
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/deskew")
async def auto_deskew(session_id: str):
"""Run both deskew methods and pick the best one."""
session = _get_session(session_id)
img_bgr = session["original_bgr"]
# Ensure session is in cache
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
img_bgr = cached.get("original_bgr")
if img_bgr is None:
raise HTTPException(status_code=400, detail="Original image not available")
t0 = time.time()
@@ -202,12 +336,10 @@ async def auto_deskew(session_id: str):
duration = time.time() - t0
# Pick method with larger detected angle (more correction needed = more skew found)
# If both are ~0, prefer word alignment as it's more robust
# Pick best method
if abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1:
method_used = "word_alignment"
angle_applied = angle_wa
# Decode word alignment result to BGR
wa_array = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8)
deskewed_bgr = cv2.imdecode(wa_array, cv2.IMREAD_COLOR)
if deskewed_bgr is None:
@@ -219,20 +351,19 @@ async def auto_deskew(session_id: str):
angle_applied = angle_hough
deskewed_bgr = deskewed_hough
# Encode deskewed as PNG
# Encode as PNG
success, deskewed_png_buf = cv2.imencode(".png", deskewed_bgr)
deskewed_png = deskewed_png_buf.tobytes() if success else session["original_png"]
deskewed_png = deskewed_png_buf.tobytes() if success else b""
# Create binarized version
binarized_png = None
try:
binarized = create_ocr_image(deskewed_bgr)
success_bin, bin_buf = cv2.imencode(".png", binarized)
binarized_png = bin_buf.tobytes() if success_bin else None
except Exception as e:
logger.warning(f"Binarization failed: {e}")
binarized_png = None
# Confidence: higher angle = lower confidence that we got it right
confidence = max(0.5, 1.0 - abs(angle_applied) / 5.0)
deskew_result = {
@@ -244,13 +375,23 @@ async def auto_deskew(session_id: str):
"duration_seconds": round(duration, 2),
}
session["deskewed_bgr"] = deskewed_bgr
session["deskewed_png"] = deskewed_png
session["binarized_png"] = binarized_png
session["deskew_result"] = deskew_result
# Update cache
cached["deskewed_bgr"] = deskewed_bgr
cached["binarized_png"] = binarized_png
cached["deskew_result"] = deskew_result
# Persist to DB
db_update = {
"deskewed_png": deskewed_png,
"deskew_result": deskew_result,
"current_step": 2,
}
if binarized_png:
db_update["binarized_png"] = binarized_png
await update_session_db(session_id, **db_update)
logger.info(f"OCR Pipeline: deskew session {session_id}: "
f"hough={angle_hough:.2f}° wa={angle_wa:.2f}° → {method_used} {angle_applied:.2f}°")
f"hough={angle_hough:.2f} wa={angle_wa:.2f} -> {method_used} {angle_applied:.2f}")
return {
"session_id": session_id,
@@ -263,8 +404,14 @@ async def auto_deskew(session_id: str):
@router.post("/sessions/{session_id}/deskew/manual")
async def manual_deskew(session_id: str, req: ManualDeskewRequest):
"""Apply a manual rotation angle to the original image."""
session = _get_session(session_id)
img_bgr = session["original_bgr"]
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
img_bgr = cached.get("original_bgr")
if img_bgr is None:
raise HTTPException(status_code=400, detail="Original image not available")
angle = max(-5.0, min(5.0, req.angle))
h, w = img_bgr.shape[:2]
@@ -275,26 +422,38 @@ async def manual_deskew(session_id: str, req: ManualDeskewRequest):
borderMode=cv2.BORDER_REPLICATE)
success, png_buf = cv2.imencode(".png", rotated)
deskewed_png = png_buf.tobytes() if success else session["original_png"]
deskewed_png = png_buf.tobytes() if success else b""
# Binarize
binarized_png = None
try:
binarized = create_ocr_image(rotated)
success_bin, bin_buf = cv2.imencode(".png", binarized)
binarized_png = bin_buf.tobytes() if success_bin else None
except Exception:
binarized_png = None
pass
session["deskewed_bgr"] = rotated
session["deskewed_png"] = deskewed_png
session["binarized_png"] = binarized_png
session["deskew_result"] = {
**(session.get("deskew_result") or {}),
deskew_result = {
**(cached.get("deskew_result") or {}),
"angle_applied": round(angle, 3),
"method_used": "manual",
}
logger.info(f"OCR Pipeline: manual deskew session {session_id}: {angle:.2f}°")
# Update cache
cached["deskewed_bgr"] = rotated
cached["binarized_png"] = binarized_png
cached["deskew_result"] = deskew_result
# Persist to DB
db_update = {
"deskewed_png": deskewed_png,
"deskew_result": deskew_result,
}
if binarized_png:
db_update["binarized_png"] = binarized_png
await update_session_db(session_id, **db_update)
logger.info(f"OCR Pipeline: manual deskew session {session_id}: {angle:.2f}")
return {
"session_id": session_id,
@@ -304,33 +463,14 @@ async def manual_deskew(session_id: str, req: ManualDeskewRequest):
}
@router.get("/sessions/{session_id}/image/{image_type}")
async def get_image(session_id: str, image_type: str):
"""Serve session images: original, deskewed, dewarped, or binarized."""
session = _get_session(session_id)
if image_type == "original":
data = session.get("original_png")
elif image_type == "deskewed":
data = session.get("deskewed_png")
elif image_type == "dewarped":
data = session.get("dewarped_png")
elif image_type == "binarized":
data = session.get("binarized_png")
else:
raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")
if not data:
raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet")
return Response(content=data, media_type="image/png")
@router.post("/sessions/{session_id}/ground-truth/deskew")
async def save_deskew_ground_truth(session_id: str, req: DeskewGroundTruthRequest):
"""Save ground truth feedback for the deskew step."""
session = _get_session(session_id)
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
gt = {
"is_correct": req.is_correct,
"corrected_angle": req.corrected_angle,
@@ -338,7 +478,13 @@ async def save_deskew_ground_truth(session_id: str, req: DeskewGroundTruthReques
"saved_at": datetime.utcnow().isoformat(),
"deskew_result": session.get("deskew_result"),
}
session["ground_truth"]["deskew"] = gt
ground_truth["deskew"] = gt
await update_session_db(session_id, ground_truth=ground_truth)
# Update cache
if session_id in _cache:
_cache[session_id]["ground_truth"] = ground_truth
logger.info(f"OCR Pipeline: ground truth deskew session {session_id}: "
f"correct={req.is_correct}, corrected_angle={req.corrected_angle}")
@@ -353,8 +499,11 @@ async def save_deskew_ground_truth(session_id: str, req: DeskewGroundTruthReques
@router.post("/sessions/{session_id}/dewarp")
async def auto_dewarp(session_id: str):
"""Detect and correct vertical shear on the deskewed image."""
session = _get_session(session_id)
deskewed_bgr = session.get("deskewed_bgr")
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
deskewed_bgr = cached.get("deskewed_bgr")
if deskewed_bgr is None:
raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")
@@ -362,30 +511,37 @@ async def auto_dewarp(session_id: str):
dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr)
duration = time.time() - t0
# Encode dewarped as PNG
# Encode as PNG
success, png_buf = cv2.imencode(".png", dewarped_bgr)
dewarped_png = png_buf.tobytes() if success else session["deskewed_png"]
dewarped_png = png_buf.tobytes() if success else b""
session["dewarped_bgr"] = dewarped_bgr
session["dewarped_png"] = dewarped_png
session["auto_shear_degrees"] = dewarp_info.get("shear_degrees", 0.0)
session["dewarp_result"] = {
dewarp_result = {
"method_used": dewarp_info["method"],
"shear_degrees": dewarp_info["shear_degrees"],
"confidence": dewarp_info["confidence"],
"duration_seconds": round(duration, 2),
}
# Update cache
cached["dewarped_bgr"] = dewarped_bgr
cached["dewarp_result"] = dewarp_result
# Persist to DB
await update_session_db(
session_id,
dewarped_png=dewarped_png,
dewarp_result=dewarp_result,
auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
current_step=3,
)
logger.info(f"OCR Pipeline: dewarp session {session_id}: "
f"method={dewarp_info['method']} shear={dewarp_info['shear_degrees']:.3f}° "
f"method={dewarp_info['method']} shear={dewarp_info['shear_degrees']:.3f} "
f"conf={dewarp_info['confidence']:.2f} ({duration:.2f}s)")
return {
"session_id": session_id,
"method_used": dewarp_info["method"],
"shear_degrees": dewarp_info["shear_degrees"],
"confidence": dewarp_info["confidence"],
"duration_seconds": round(duration, 2),
**dewarp_result,
"dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
}
@@ -393,9 +549,11 @@ async def auto_dewarp(session_id: str):
@router.post("/sessions/{session_id}/dewarp/manual")
async def manual_dewarp(session_id: str, req: ManualDewarpRequest):
"""Apply shear correction with a manual angle."""
session = _get_session(session_id)
deskewed_bgr = session.get("deskewed_bgr")
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
deskewed_bgr = cached.get("deskewed_bgr")
if deskewed_bgr is None:
raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")
@@ -407,17 +565,26 @@ async def manual_dewarp(session_id: str, req: ManualDewarpRequest):
dewarped_bgr = dewarp_image_manual(deskewed_bgr, shear_deg)
success, png_buf = cv2.imencode(".png", dewarped_bgr)
dewarped_png = png_buf.tobytes() if success else session.get("deskewed_png")
dewarped_png = png_buf.tobytes() if success else b""
session["dewarped_bgr"] = dewarped_bgr
session["dewarped_png"] = dewarped_png
session["dewarp_result"] = {
**(session.get("dewarp_result") or {}),
dewarp_result = {
**(cached.get("dewarp_result") or {}),
"method_used": "manual",
"shear_degrees": round(shear_deg, 3),
}
logger.info(f"OCR Pipeline: manual dewarp session {session_id}: shear={shear_deg:.3f}°")
# Update cache
cached["dewarped_bgr"] = dewarped_bgr
cached["dewarp_result"] = dewarp_result
# Persist to DB
await update_session_db(
session_id,
dewarped_png=dewarped_png,
dewarp_result=dewarp_result,
)
logger.info(f"OCR Pipeline: manual dewarp session {session_id}: shear={shear_deg:.3f}")
return {
"session_id": session_id,
@@ -430,8 +597,11 @@ async def manual_dewarp(session_id: str, req: ManualDewarpRequest):
@router.post("/sessions/{session_id}/ground-truth/dewarp")
async def save_dewarp_ground_truth(session_id: str, req: DewarpGroundTruthRequest):
"""Save ground truth feedback for the dewarp step."""
session = _get_session(session_id)
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
gt = {
"is_correct": req.is_correct,
"corrected_shear": req.corrected_shear,
@@ -439,9 +609,168 @@ async def save_dewarp_ground_truth(session_id: str, req: DewarpGroundTruthReques
"saved_at": datetime.utcnow().isoformat(),
"dewarp_result": session.get("dewarp_result"),
}
session["ground_truth"]["dewarp"] = gt
ground_truth["dewarp"] = gt
await update_session_db(session_id, ground_truth=ground_truth)
if session_id in _cache:
_cache[session_id]["ground_truth"] = ground_truth
logger.info(f"OCR Pipeline: ground truth dewarp session {session_id}: "
f"correct={req.is_correct}, corrected_shear={req.corrected_shear}")
return {"session_id": session_id, "ground_truth": gt}
# ---------------------------------------------------------------------------
# Column Detection Endpoints (Step 3)
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/columns")
async def detect_columns(session_id: str):
"""Run column detection on the dewarped image."""
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
dewarped_bgr = cached.get("dewarped_bgr")
if dewarped_bgr is None:
raise HTTPException(status_code=400, detail="Dewarp must be completed before column detection")
t0 = time.time()
# Prepare images for analyze_layout
gray = cv2.cvtColor(dewarped_bgr, cv2.COLOR_BGR2GRAY)
# CLAHE-enhanced for layout analysis
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
layout_img = clahe.apply(gray)
# Binarized for text density
ocr_img = create_ocr_image(dewarped_bgr)
regions = analyze_layout(layout_img, ocr_img)
duration = time.time() - t0
columns = [asdict(r) for r in regions]
column_result = {
"columns": columns,
"duration_seconds": round(duration, 2),
}
# Persist to DB
await update_session_db(
session_id,
column_result=column_result,
current_step=3,
)
# Update cache
cached["column_result"] = column_result
col_count = len([c for c in columns if c["type"].startswith("column")])
logger.info(f"OCR Pipeline: columns session {session_id}: "
f"{col_count} columns detected ({duration:.2f}s)")
return {
"session_id": session_id,
**column_result,
}
@router.post("/sessions/{session_id}/columns/manual")
async def set_manual_columns(session_id: str, req: ManualColumnsRequest):
"""Override detected columns with manual definitions."""
column_result = {
"columns": req.columns,
"duration_seconds": 0,
"method": "manual",
}
await update_session_db(session_id, column_result=column_result)
if session_id in _cache:
_cache[session_id]["column_result"] = column_result
logger.info(f"OCR Pipeline: manual columns session {session_id}: "
f"{len(req.columns)} columns set")
return {"session_id": session_id, **column_result}
@router.post("/sessions/{session_id}/ground-truth/columns")
async def save_column_ground_truth(session_id: str, req: ColumnGroundTruthRequest):
"""Save ground truth feedback for the column detection step."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
gt = {
"is_correct": req.is_correct,
"notes": req.notes,
"saved_at": datetime.utcnow().isoformat(),
"column_result": session.get("column_result"),
}
ground_truth["columns"] = gt
await update_session_db(session_id, ground_truth=ground_truth)
if session_id in _cache:
_cache[session_id]["ground_truth"] = ground_truth
return {"session_id": session_id, "ground_truth": gt}
async def _get_columns_overlay(session_id: str) -> Response:
"""Generate dewarped image with column borders drawn on it."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
column_result = session.get("column_result")
if not column_result or not column_result.get("columns"):
raise HTTPException(status_code=404, detail="No column data available")
# Load dewarped image
dewarped_png = await get_session_image(session_id, "dewarped")
if not dewarped_png:
raise HTTPException(status_code=404, detail="Dewarped image not available")
arr = np.frombuffer(dewarped_png, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=500, detail="Failed to decode image")
# Color map for region types
colors = {
"column_en": (255, 180, 0), # Blue (BGR)
"column_de": (0, 200, 0), # Green
"column_example": (0, 140, 255), # Orange
"header": (128, 128, 128), # Gray
"footer": (128, 128, 128), # Gray
}
overlay = img.copy()
for col in column_result["columns"]:
x, y = col["x"], col["y"]
w, h = col["width"], col["height"]
color = colors.get(col.get("type", ""), (200, 200, 200))
# Semi-transparent fill
cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1)
# Solid border
cv2.rectangle(img, (x, y), (x + w, y + h), color, 3)
# Label
label = col.get("type", "unknown").replace("column_", "").upper()
cv2.putText(img, label, (x + 10, y + 30),
cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2)
# Blend overlay at 20% opacity
cv2.addWeighted(overlay, 0.2, img, 0.8, 0, img)
success, result_png = cv2.imencode(".png", img)
if not success:
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
return Response(content=result_png.tobytes(), media_type="image/png")

View File

@@ -0,0 +1,228 @@
"""
OCR Pipeline Session Store - PostgreSQL persistence for OCR pipeline sessions.
Replaces in-memory storage with database persistence.
See migrations/002_ocr_pipeline_sessions.sql for schema.
"""
import os
import uuid
import logging
import json
from typing import Optional, List, Dict, Any
import asyncpg
logger = logging.getLogger(__name__)
# Database configuration (same as vocab_session_store)
DATABASE_URL = os.getenv(
"DATABASE_URL",
"postgresql://breakpilot:breakpilot@postgres:5432/breakpilot_db"
)
# Connection pool (initialized lazily)
_pool: Optional[asyncpg.Pool] = None
async def get_pool() -> asyncpg.Pool:
"""Get or create the database connection pool."""
global _pool
if _pool is None:
_pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
return _pool
async def init_ocr_pipeline_tables():
"""Initialize OCR pipeline tables if they don't exist."""
pool = await get_pool()
async with pool.acquire() as conn:
tables_exist = await conn.fetchval("""
SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_name = 'ocr_pipeline_sessions'
)
""")
if not tables_exist:
logger.info("Creating OCR pipeline tables...")
migration_path = os.path.join(
os.path.dirname(__file__),
"migrations/002_ocr_pipeline_sessions.sql"
)
if os.path.exists(migration_path):
with open(migration_path, "r") as f:
sql = f.read()
await conn.execute(sql)
logger.info("OCR pipeline tables created successfully")
else:
logger.warning(f"Migration file not found: {migration_path}")
else:
logger.debug("OCR pipeline tables already exist")
# =============================================================================
# SESSION CRUD
# =============================================================================
async def create_session_db(
session_id: str,
name: str,
filename: str,
original_png: bytes,
) -> Dict[str, Any]:
"""Create a new OCR pipeline session."""
pool = await get_pool()
async with pool.acquire() as conn:
row = await conn.fetchrow("""
INSERT INTO ocr_pipeline_sessions (
id, name, filename, original_png, status, current_step
) VALUES ($1, $2, $3, $4, 'active', 1)
RETURNING id, name, filename, status, current_step,
deskew_result, dewarp_result, column_result,
ground_truth, auto_shear_degrees,
created_at, updated_at
""", uuid.UUID(session_id), name, filename, original_png)
return _row_to_dict(row)
async def get_session_db(session_id: str) -> Optional[Dict[str, Any]]:
"""Get session metadata (without images)."""
pool = await get_pool()
async with pool.acquire() as conn:
row = await conn.fetchrow("""
SELECT id, name, filename, status, current_step,
deskew_result, dewarp_result, column_result,
ground_truth, auto_shear_degrees,
created_at, updated_at
FROM ocr_pipeline_sessions WHERE id = $1
""", uuid.UUID(session_id))
if row:
return _row_to_dict(row)
return None
async def get_session_image(session_id: str, image_type: str) -> Optional[bytes]:
"""Load a single image (BYTEA) from the session."""
column_map = {
"original": "original_png",
"deskewed": "deskewed_png",
"binarized": "binarized_png",
"dewarped": "dewarped_png",
}
column = column_map.get(image_type)
if not column:
return None
pool = await get_pool()
async with pool.acquire() as conn:
return await conn.fetchval(
f"SELECT {column} FROM ocr_pipeline_sessions WHERE id = $1",
uuid.UUID(session_id)
)
async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any]]:
"""Update session fields dynamically."""
pool = await get_pool()
fields = []
values = []
param_idx = 1
allowed_fields = {
'name', 'filename', 'status', 'current_step',
'original_png', 'deskewed_png', 'binarized_png', 'dewarped_png',
'deskew_result', 'dewarp_result', 'column_result',
'ground_truth', 'auto_shear_degrees',
}
jsonb_fields = {'deskew_result', 'dewarp_result', 'column_result', 'ground_truth'}
for key, value in kwargs.items():
if key in allowed_fields:
fields.append(f"{key} = ${param_idx}")
if key in jsonb_fields and value is not None and not isinstance(value, str):
value = json.dumps(value)
values.append(value)
param_idx += 1
if not fields:
return await get_session_db(session_id)
# Always update updated_at
fields.append(f"updated_at = NOW()")
values.append(uuid.UUID(session_id))
async with pool.acquire() as conn:
row = await conn.fetchrow(f"""
UPDATE ocr_pipeline_sessions
SET {', '.join(fields)}
WHERE id = ${param_idx}
RETURNING id, name, filename, status, current_step,
deskew_result, dewarp_result, column_result,
ground_truth, auto_shear_degrees,
created_at, updated_at
""", *values)
if row:
return _row_to_dict(row)
return None
async def list_sessions_db(limit: int = 50) -> List[Dict[str, Any]]:
"""List all sessions (metadata only, no images)."""
pool = await get_pool()
async with pool.acquire() as conn:
rows = await conn.fetch("""
SELECT id, name, filename, status, current_step,
created_at, updated_at
FROM ocr_pipeline_sessions
ORDER BY created_at DESC
LIMIT $1
""", limit)
return [_row_to_dict(row) for row in rows]
async def delete_session_db(session_id: str) -> bool:
"""Delete a session."""
pool = await get_pool()
async with pool.acquire() as conn:
result = await conn.execute("""
DELETE FROM ocr_pipeline_sessions WHERE id = $1
""", uuid.UUID(session_id))
return result == "DELETE 1"
# =============================================================================
# HELPER
# =============================================================================
def _row_to_dict(row: asyncpg.Record) -> Dict[str, Any]:
"""Convert asyncpg Record to JSON-serializable dict."""
if row is None:
return {}
result = dict(row)
# UUID → string
for key in ['id', 'session_id']:
if key in result and result[key] is not None:
result[key] = str(result[key])
# datetime → ISO string
for key in ['created_at', 'updated_at']:
if key in result and result[key] is not None:
result[key] = result[key].isoformat()
# JSONB → parsed (asyncpg returns str for JSONB)
for key in ['deskew_result', 'dewarp_result', 'column_result', 'ground_truth']:
if key in result and result[key] is not None:
if isinstance(result[key], str):
result[key] = json.loads(result[key])
return result