feat: Persistente Sessions (PostgreSQL) + Spaltenerkennung (Step 3)
Sessions werden jetzt in PostgreSQL gespeichert statt in-memory. Neue Session-Liste mit Name, Datum, Schritt. Sessions ueberleben Browser-Refresh und Container-Neustart. Step 3 nutzt analyze_layout() fuer automatische Spaltenerkennung mit farbigem Overlay. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
'use client'
|
||||
|
||||
import { useState } from 'react'
|
||||
import { useCallback, useEffect, useState } from 'react'
|
||||
import { PagePurpose } from '@/components/common/PagePurpose'
|
||||
import { PipelineStepper } from '@/components/ocr-pipeline/PipelineStepper'
|
||||
import { StepDeskew } from '@/components/ocr-pipeline/StepDeskew'
|
||||
@@ -10,11 +10,18 @@ import { StepWordRecognition } from '@/components/ocr-pipeline/StepWordRecogniti
|
||||
import { StepCoordinates } from '@/components/ocr-pipeline/StepCoordinates'
|
||||
import { StepReconstruction } from '@/components/ocr-pipeline/StepReconstruction'
|
||||
import { StepGroundTruth } from '@/components/ocr-pipeline/StepGroundTruth'
|
||||
import { PIPELINE_STEPS, type PipelineStep } from './types'
|
||||
import { PIPELINE_STEPS, type PipelineStep, type SessionListItem } from './types'
|
||||
|
||||
const KLAUSUR_API = '/klausur-api'
|
||||
|
||||
export default function OcrPipelinePage() {
|
||||
const [currentStep, setCurrentStep] = useState(0)
|
||||
const [sessionId, setSessionId] = useState<string | null>(null)
|
||||
const [sessionName, setSessionName] = useState<string>('')
|
||||
const [sessions, setSessions] = useState<SessionListItem[]>([])
|
||||
const [loadingSessions, setLoadingSessions] = useState(true)
|
||||
const [editingName, setEditingName] = useState<string | null>(null)
|
||||
const [editNameValue, setEditNameValue] = useState('')
|
||||
const [steps, setSteps] = useState<PipelineStep[]>(
|
||||
PIPELINE_STEPS.map((s, i) => ({
|
||||
...s,
|
||||
@@ -22,6 +29,82 @@ export default function OcrPipelinePage() {
|
||||
})),
|
||||
)
|
||||
|
||||
// Load session list on mount
|
||||
useEffect(() => {
|
||||
loadSessions()
|
||||
}, [])
|
||||
|
||||
const loadSessions = async () => {
|
||||
setLoadingSessions(true)
|
||||
try {
|
||||
const res = await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions`)
|
||||
if (res.ok) {
|
||||
const data = await res.json()
|
||||
setSessions(data.sessions || [])
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to load sessions:', e)
|
||||
} finally {
|
||||
setLoadingSessions(false)
|
||||
}
|
||||
}
|
||||
|
||||
const openSession = useCallback(async (sid: string) => {
|
||||
try {
|
||||
const res = await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sid}`)
|
||||
if (!res.ok) return
|
||||
const data = await res.json()
|
||||
|
||||
setSessionId(sid)
|
||||
setSessionName(data.name || data.filename || '')
|
||||
|
||||
// Determine which step to jump to based on current_step
|
||||
const dbStep = data.current_step || 1
|
||||
// Steps: 1=deskew, 2=dewarp, 3=columns, ...
|
||||
// UI steps are 0-indexed: 0=deskew, 1=dewarp, 2=columns, ...
|
||||
const uiStep = Math.max(0, dbStep - 1)
|
||||
|
||||
setSteps(
|
||||
PIPELINE_STEPS.map((s, i) => ({
|
||||
...s,
|
||||
status: i < uiStep ? 'completed' : i === uiStep ? 'active' : 'pending',
|
||||
})),
|
||||
)
|
||||
setCurrentStep(uiStep)
|
||||
} catch (e) {
|
||||
console.error('Failed to open session:', e)
|
||||
}
|
||||
}, [])
|
||||
|
||||
const deleteSession = useCallback(async (sid: string) => {
|
||||
try {
|
||||
await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sid}`, { method: 'DELETE' })
|
||||
setSessions((prev) => prev.filter((s) => s.id !== sid))
|
||||
if (sessionId === sid) {
|
||||
setSessionId(null)
|
||||
setCurrentStep(0)
|
||||
setSteps(PIPELINE_STEPS.map((s, i) => ({ ...s, status: i === 0 ? 'active' : 'pending' })))
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to delete session:', e)
|
||||
}
|
||||
}, [sessionId])
|
||||
|
||||
const renameSession = useCallback(async (sid: string, newName: string) => {
|
||||
try {
|
||||
await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sid}`, {
|
||||
method: 'PUT',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ name: newName }),
|
||||
})
|
||||
setSessions((prev) => prev.map((s) => (s.id === sid ? { ...s, name: newName } : s)))
|
||||
if (sessionId === sid) setSessionName(newName)
|
||||
} catch (e) {
|
||||
console.error('Failed to rename session:', e)
|
||||
}
|
||||
setEditingName(null)
|
||||
}, [sessionId])
|
||||
|
||||
const handleStepClick = (index: number) => {
|
||||
if (index <= currentStep || steps[index].status === 'completed') {
|
||||
setCurrentStep(index)
|
||||
@@ -43,9 +126,28 @@ export default function OcrPipelinePage() {
|
||||
|
||||
const handleDeskewComplete = (sid: string) => {
|
||||
setSessionId(sid)
|
||||
// Reload session list to show the new session
|
||||
loadSessions()
|
||||
handleNext()
|
||||
}
|
||||
|
||||
const handleNewSession = () => {
|
||||
setSessionId(null)
|
||||
setSessionName('')
|
||||
setCurrentStep(0)
|
||||
setSteps(PIPELINE_STEPS.map((s, i) => ({ ...s, status: i === 0 ? 'active' : 'pending' })))
|
||||
}
|
||||
|
||||
const stepNames: Record<number, string> = {
|
||||
1: 'Begradigung',
|
||||
2: 'Entzerrung',
|
||||
3: 'Spalten',
|
||||
4: 'Woerter',
|
||||
5: 'Koordinaten',
|
||||
6: 'Rekonstruktion',
|
||||
7: 'Validierung',
|
||||
}
|
||||
|
||||
const renderStep = () => {
|
||||
switch (currentStep) {
|
||||
case 0:
|
||||
@@ -53,7 +155,7 @@ export default function OcrPipelinePage() {
|
||||
case 1:
|
||||
return <StepDewarp sessionId={sessionId} onNext={handleNext} />
|
||||
case 2:
|
||||
return <StepColumnDetection />
|
||||
return <StepColumnDetection sessionId={sessionId} onNext={handleNext} />
|
||||
case 3:
|
||||
return <StepWordRecognition />
|
||||
case 4:
|
||||
@@ -75,7 +177,7 @@ export default function OcrPipelinePage() {
|
||||
audience={['Entwickler', 'Data Scientists']}
|
||||
architecture={{
|
||||
services: ['klausur-service (FastAPI)', 'OpenCV', 'Tesseract'],
|
||||
databases: ['In-Memory Sessions'],
|
||||
databases: ['PostgreSQL Sessions'],
|
||||
}}
|
||||
relatedPages={[
|
||||
{ name: 'OCR Vergleich', href: '/ai/ocr-compare', description: 'Methoden-Vergleich' },
|
||||
@@ -84,6 +186,97 @@ export default function OcrPipelinePage() {
|
||||
defaultCollapsed
|
||||
/>
|
||||
|
||||
{/* Session List */}
|
||||
<div className="bg-white dark:bg-gray-800 rounded-xl border border-gray-200 dark:border-gray-700 p-4">
|
||||
<div className="flex items-center justify-between mb-3">
|
||||
<h3 className="text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||
Sessions
|
||||
</h3>
|
||||
<button
|
||||
onClick={handleNewSession}
|
||||
className="text-xs px-3 py-1.5 bg-teal-600 text-white rounded-lg hover:bg-teal-700 transition-colors"
|
||||
>
|
||||
+ Neue Session
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{loadingSessions ? (
|
||||
<div className="text-sm text-gray-400 py-2">Lade Sessions...</div>
|
||||
) : sessions.length === 0 ? (
|
||||
<div className="text-sm text-gray-400 py-2">Noch keine Sessions vorhanden.</div>
|
||||
) : (
|
||||
<div className="space-y-1 max-h-48 overflow-y-auto">
|
||||
{sessions.map((s) => (
|
||||
<div
|
||||
key={s.id}
|
||||
className={`flex items-center gap-2 px-3 py-2 rounded-lg text-sm transition-colors cursor-pointer ${
|
||||
sessionId === s.id
|
||||
? 'bg-teal-50 dark:bg-teal-900/30 border border-teal-200 dark:border-teal-700'
|
||||
: 'hover:bg-gray-50 dark:hover:bg-gray-700/50'
|
||||
}`}
|
||||
>
|
||||
<div className="flex-1 min-w-0" onClick={() => openSession(s.id)}>
|
||||
{editingName === s.id ? (
|
||||
<input
|
||||
autoFocus
|
||||
value={editNameValue}
|
||||
onChange={(e) => setEditNameValue(e.target.value)}
|
||||
onBlur={() => renameSession(s.id, editNameValue)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter') renameSession(s.id, editNameValue)
|
||||
if (e.key === 'Escape') setEditingName(null)
|
||||
}}
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
className="w-full px-1 py-0.5 text-sm border rounded dark:bg-gray-700 dark:border-gray-600"
|
||||
/>
|
||||
) : (
|
||||
<div className="truncate font-medium text-gray-700 dark:text-gray-300">
|
||||
{s.name || s.filename}
|
||||
</div>
|
||||
)}
|
||||
<div className="text-xs text-gray-400 flex gap-2">
|
||||
<span>{new Date(s.created_at).toLocaleDateString('de-DE', { day: '2-digit', month: '2-digit', year: '2-digit', hour: '2-digit', minute: '2-digit' })}</span>
|
||||
<span>Schritt {s.current_step}: {stepNames[s.current_step] || '?'}</span>
|
||||
</div>
|
||||
</div>
|
||||
<button
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
setEditNameValue(s.name || s.filename)
|
||||
setEditingName(s.id)
|
||||
}}
|
||||
className="p-1 text-gray-400 hover:text-gray-600 dark:hover:text-gray-300"
|
||||
title="Umbenennen"
|
||||
>
|
||||
<svg className="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" strokeWidth={2}>
|
||||
<path strokeLinecap="round" strokeLinejoin="round" d="M15.232 5.232l3.536 3.536m-2.036-5.036a2.5 2.5 0 113.536 3.536L6.5 21.036H3v-3.572L16.732 3.732z" />
|
||||
</svg>
|
||||
</button>
|
||||
<button
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
if (confirm('Session loeschen?')) deleteSession(s.id)
|
||||
}}
|
||||
className="p-1 text-gray-400 hover:text-red-500"
|
||||
title="Loeschen"
|
||||
>
|
||||
<svg className="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" strokeWidth={2}>
|
||||
<path strokeLinecap="round" strokeLinejoin="round" d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Active session name */}
|
||||
{sessionId && sessionName && (
|
||||
<div className="text-sm text-gray-500 dark:text-gray-400">
|
||||
Aktive Session: <span className="font-medium text-gray-700 dark:text-gray-300">{sessionName}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<PipelineStepper steps={steps} currentStep={currentStep} onStepClick={handleStepClick} />
|
||||
|
||||
<div className="min-h-[400px]">{renderStep()}</div>
|
||||
|
||||
@@ -7,14 +7,27 @@ export interface PipelineStep {
|
||||
status: PipelineStepStatus
|
||||
}
|
||||
|
||||
export interface SessionListItem {
|
||||
id: string
|
||||
name: string
|
||||
filename: string
|
||||
status: string
|
||||
current_step: number
|
||||
created_at: string
|
||||
updated_at?: string
|
||||
}
|
||||
|
||||
export interface SessionInfo {
|
||||
session_id: string
|
||||
filename: string
|
||||
name?: string
|
||||
image_width: number
|
||||
image_height: number
|
||||
original_image_url: string
|
||||
current_step?: number
|
||||
deskew_result?: DeskewResult
|
||||
dewarp_result?: DewarpResult
|
||||
column_result?: ColumnResult
|
||||
}
|
||||
|
||||
export interface DeskewResult {
|
||||
@@ -50,6 +63,24 @@ export interface DewarpGroundTruth {
|
||||
notes?: string
|
||||
}
|
||||
|
||||
export interface PageRegion {
|
||||
type: 'column_en' | 'column_de' | 'column_example' | 'header' | 'footer'
|
||||
x: number
|
||||
y: number
|
||||
width: number
|
||||
height: number
|
||||
}
|
||||
|
||||
export interface ColumnResult {
|
||||
columns: PageRegion[]
|
||||
duration_seconds: number
|
||||
}
|
||||
|
||||
export interface ColumnGroundTruth {
|
||||
is_correct: boolean
|
||||
notes?: string
|
||||
}
|
||||
|
||||
export const PIPELINE_STEPS: PipelineStep[] = [
|
||||
{ id: 'deskew', name: 'Begradigung', icon: '📐', status: 'pending' },
|
||||
{ id: 'dewarp', name: 'Entzerrung', icon: '🔧', status: 'pending' },
|
||||
|
||||
119
admin-lehrer/components/ocr-pipeline/ColumnControls.tsx
Normal file
119
admin-lehrer/components/ocr-pipeline/ColumnControls.tsx
Normal file
@@ -0,0 +1,119 @@
|
||||
'use client'
|
||||
|
||||
import { useState } from 'react'
|
||||
import type { ColumnResult, ColumnGroundTruth, PageRegion } from '@/app/(admin)/ai/ocr-pipeline/types'
|
||||
|
||||
interface ColumnControlsProps {
|
||||
columnResult: ColumnResult | null
|
||||
onRerun: () => void
|
||||
onGroundTruth: (gt: ColumnGroundTruth) => void
|
||||
onNext: () => void
|
||||
isDetecting: boolean
|
||||
}
|
||||
|
||||
const TYPE_COLORS: Record<string, string> = {
|
||||
column_en: 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-400',
|
||||
column_de: 'bg-green-100 text-green-700 dark:bg-green-900/30 dark:text-green-400',
|
||||
column_example: 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400',
|
||||
header: 'bg-gray-100 text-gray-600 dark:bg-gray-700/50 dark:text-gray-400',
|
||||
footer: 'bg-gray-100 text-gray-600 dark:bg-gray-700/50 dark:text-gray-400',
|
||||
}
|
||||
|
||||
const TYPE_LABELS: Record<string, string> = {
|
||||
column_en: 'EN',
|
||||
column_de: 'DE',
|
||||
column_example: 'Beispiel',
|
||||
header: 'Header',
|
||||
footer: 'Footer',
|
||||
}
|
||||
|
||||
export function ColumnControls({ columnResult, onRerun, onGroundTruth, onNext, isDetecting }: ColumnControlsProps) {
|
||||
const [gtSaved, setGtSaved] = useState(false)
|
||||
|
||||
if (!columnResult) return null
|
||||
|
||||
const columns = columnResult.columns.filter((c: PageRegion) => c.type.startsWith('column'))
|
||||
const headerFooter = columnResult.columns.filter((c: PageRegion) => !c.type.startsWith('column'))
|
||||
|
||||
const handleGt = (isCorrect: boolean) => {
|
||||
onGroundTruth({ is_correct: isCorrect })
|
||||
setGtSaved(true)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="bg-white dark:bg-gray-800 rounded-xl border border-gray-200 dark:border-gray-700 p-4 space-y-4">
|
||||
{/* Summary */}
|
||||
<div className="flex items-center gap-3">
|
||||
<div className="text-sm text-gray-600 dark:text-gray-400">
|
||||
<span className="font-medium text-gray-800 dark:text-gray-200">{columns.length} Spalten</span> erkannt
|
||||
{columnResult.duration_seconds > 0 && (
|
||||
<span className="ml-2 text-xs">({columnResult.duration_seconds}s)</span>
|
||||
)}
|
||||
</div>
|
||||
<button
|
||||
onClick={onRerun}
|
||||
disabled={isDetecting}
|
||||
className="text-xs px-2 py-1 bg-gray-100 dark:bg-gray-700 rounded hover:bg-gray-200 dark:hover:bg-gray-600 transition-colors disabled:opacity-50"
|
||||
>
|
||||
Erneut erkennen
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Column list */}
|
||||
<div className="space-y-2">
|
||||
{columns.map((col: PageRegion, i: number) => (
|
||||
<div key={i} className="flex items-center gap-3 text-sm">
|
||||
<span className={`px-2 py-0.5 rounded text-xs font-medium ${TYPE_COLORS[col.type] || ''}`}>
|
||||
{TYPE_LABELS[col.type] || col.type}
|
||||
</span>
|
||||
<span className="text-gray-500 dark:text-gray-400 text-xs font-mono">
|
||||
x={col.x} y={col.y} {col.width}x{col.height}px
|
||||
</span>
|
||||
</div>
|
||||
))}
|
||||
{headerFooter.map((r: PageRegion, i: number) => (
|
||||
<div key={`hf-${i}`} className="flex items-center gap-3 text-sm">
|
||||
<span className={`px-2 py-0.5 rounded text-xs font-medium ${TYPE_COLORS[r.type] || ''}`}>
|
||||
{TYPE_LABELS[r.type] || r.type}
|
||||
</span>
|
||||
<span className="text-gray-500 dark:text-gray-400 text-xs font-mono">
|
||||
x={r.x} y={r.y} {r.width}x{r.height}px
|
||||
</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{/* Ground Truth + Navigation */}
|
||||
<div className="flex items-center justify-between pt-2 border-t border-gray-100 dark:border-gray-700">
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-sm text-gray-500 dark:text-gray-400">Spalten korrekt?</span>
|
||||
{gtSaved ? (
|
||||
<span className="text-xs text-green-600 dark:text-green-400">Gespeichert</span>
|
||||
) : (
|
||||
<>
|
||||
<button
|
||||
onClick={() => handleGt(true)}
|
||||
className="text-xs px-3 py-1 bg-green-100 text-green-700 dark:bg-green-900/30 dark:text-green-400 rounded hover:bg-green-200 dark:hover:bg-green-900/50 transition-colors"
|
||||
>
|
||||
Ja
|
||||
</button>
|
||||
<button
|
||||
onClick={() => handleGt(false)}
|
||||
className="text-xs px-3 py-1 bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-400 rounded hover:bg-red-200 dark:hover:bg-red-900/50 transition-colors"
|
||||
>
|
||||
Nein
|
||||
</button>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<button
|
||||
onClick={onNext}
|
||||
className="px-4 py-2 bg-teal-600 text-white rounded-lg hover:bg-teal-700 transition-colors text-sm font-medium"
|
||||
>
|
||||
Weiter
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -1,19 +1,168 @@
|
||||
'use client'
|
||||
|
||||
export function StepColumnDetection() {
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center py-16 text-center">
|
||||
<div className="text-5xl mb-4">📊</div>
|
||||
<h3 className="text-lg font-medium text-gray-700 dark:text-gray-300 mb-2">
|
||||
Schritt 3: Spaltenerkennung
|
||||
</h3>
|
||||
<p className="text-gray-500 dark:text-gray-400 max-w-md">
|
||||
Erkennung unsichtbarer Spaltentrennungen in der Vokabelseite.
|
||||
Dieser Schritt wird in einer zukuenftigen Version implementiert.
|
||||
</p>
|
||||
<div className="mt-6 px-4 py-2 bg-amber-100 dark:bg-amber-900/30 text-amber-700 dark:text-amber-400 rounded-full text-sm font-medium">
|
||||
Kommt bald
|
||||
import { useCallback, useEffect, useState } from 'react'
|
||||
import type { ColumnResult, ColumnGroundTruth } from '@/app/(admin)/ai/ocr-pipeline/types'
|
||||
import { ColumnControls } from './ColumnControls'
|
||||
|
||||
const KLAUSUR_API = '/klausur-api'
|
||||
|
||||
interface StepColumnDetectionProps {
|
||||
sessionId: string | null
|
||||
onNext: () => void
|
||||
}
|
||||
|
||||
export function StepColumnDetection({ sessionId, onNext }: StepColumnDetectionProps) {
|
||||
const [columnResult, setColumnResult] = useState<ColumnResult | null>(null)
|
||||
const [detecting, setDetecting] = useState(false)
|
||||
const [error, setError] = useState<string | null>(null)
|
||||
|
||||
// Auto-trigger column detection on mount
|
||||
useEffect(() => {
|
||||
if (!sessionId || columnResult) return
|
||||
|
||||
const runDetection = async () => {
|
||||
setDetecting(true)
|
||||
setError(null)
|
||||
try {
|
||||
// First check if columns already detected (reload case)
|
||||
const infoRes = await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}`)
|
||||
if (infoRes.ok) {
|
||||
const info = await infoRes.json()
|
||||
if (info.column_result) {
|
||||
setColumnResult(info.column_result)
|
||||
setDetecting(false)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Run detection
|
||||
const res = await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/columns`, {
|
||||
method: 'POST',
|
||||
})
|
||||
if (!res.ok) {
|
||||
const err = await res.json().catch(() => ({ detail: res.statusText }))
|
||||
throw new Error(err.detail || 'Spaltenerkennung fehlgeschlagen')
|
||||
}
|
||||
const data: ColumnResult = await res.json()
|
||||
setColumnResult(data)
|
||||
} catch (e) {
|
||||
setError(e instanceof Error ? e.message : 'Unbekannter Fehler')
|
||||
} finally {
|
||||
setDetecting(false)
|
||||
}
|
||||
}
|
||||
|
||||
runDetection()
|
||||
}, [sessionId, columnResult])
|
||||
|
||||
const handleRerun = useCallback(async () => {
|
||||
if (!sessionId) return
|
||||
setDetecting(true)
|
||||
setError(null)
|
||||
try {
|
||||
const res = await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/columns`, {
|
||||
method: 'POST',
|
||||
})
|
||||
if (!res.ok) throw new Error('Spaltenerkennung fehlgeschlagen')
|
||||
const data: ColumnResult = await res.json()
|
||||
setColumnResult(data)
|
||||
} catch (e) {
|
||||
setError(e instanceof Error ? e.message : 'Fehler')
|
||||
} finally {
|
||||
setDetecting(false)
|
||||
}
|
||||
}, [sessionId])
|
||||
|
||||
const handleGroundTruth = useCallback(async (gt: ColumnGroundTruth) => {
|
||||
if (!sessionId) return
|
||||
try {
|
||||
await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/ground-truth/columns`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(gt),
|
||||
})
|
||||
} catch (e) {
|
||||
console.error('Ground truth save failed:', e)
|
||||
}
|
||||
}, [sessionId])
|
||||
|
||||
if (!sessionId) {
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center py-16 text-center">
|
||||
<div className="text-5xl mb-4">📊</div>
|
||||
<h3 className="text-lg font-medium text-gray-700 dark:text-gray-300 mb-2">
|
||||
Schritt 3: Spaltenerkennung
|
||||
</h3>
|
||||
<p className="text-gray-500 dark:text-gray-400 max-w-md">
|
||||
Bitte zuerst Schritt 1 und 2 abschliessen.
|
||||
</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const dewarpedUrl = `${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/image/dewarped`
|
||||
const overlayUrl = `${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/image/columns-overlay`
|
||||
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
{/* Loading indicator */}
|
||||
{detecting && (
|
||||
<div className="flex items-center gap-2 text-teal-600 dark:text-teal-400 text-sm">
|
||||
<div className="animate-spin w-4 h-4 border-2 border-teal-500 border-t-transparent rounded-full" />
|
||||
Spaltenerkennung laeuft...
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Image comparison: overlay (left) vs clean (right) */}
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div>
|
||||
<div className="text-xs font-medium text-gray-500 dark:text-gray-400 mb-1">
|
||||
Mit Spalten-Overlay
|
||||
</div>
|
||||
<div className="border rounded-lg overflow-hidden dark:border-gray-700 bg-gray-50 dark:bg-gray-900">
|
||||
{columnResult ? (
|
||||
// eslint-disable-next-line @next/next/no-img-element
|
||||
<img
|
||||
src={`${overlayUrl}?t=${Date.now()}`}
|
||||
alt="Spalten-Overlay"
|
||||
className="w-full h-auto"
|
||||
/>
|
||||
) : (
|
||||
<div className="aspect-[3/4] flex items-center justify-center text-gray-400 text-sm">
|
||||
{detecting ? 'Erkenne Spalten...' : 'Keine Daten'}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<div className="text-xs font-medium text-gray-500 dark:text-gray-400 mb-1">
|
||||
Entzerrtes Bild
|
||||
</div>
|
||||
<div className="border rounded-lg overflow-hidden dark:border-gray-700 bg-gray-50 dark:bg-gray-900">
|
||||
{/* eslint-disable-next-line @next/next/no-img-element */}
|
||||
<img
|
||||
src={dewarpedUrl}
|
||||
alt="Entzerrt"
|
||||
className="w-full h-auto"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Controls */}
|
||||
<ColumnControls
|
||||
columnResult={columnResult}
|
||||
onRerun={handleRerun}
|
||||
onGroundTruth={handleGroundTruth}
|
||||
onNext={onNext}
|
||||
isDetecting={detecting}
|
||||
/>
|
||||
|
||||
{error && (
|
||||
<div className="p-3 bg-red-50 dark:bg-red-900/20 text-red-600 dark:text-red-400 rounded-lg text-sm">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ export function StepDeskew({ sessionId: existingSessionId, onNext }: StepDeskewP
|
||||
const [showGrid, setShowGrid] = useState(true)
|
||||
const [error, setError] = useState<string | null>(null)
|
||||
const [dragOver, setDragOver] = useState(false)
|
||||
const [sessionName, setSessionName] = useState('')
|
||||
|
||||
// Reload session data when navigating back from a later step
|
||||
useEffect(() => {
|
||||
@@ -67,6 +68,9 @@ export function StepDeskew({ sessionId: existingSessionId, onNext }: StepDeskewP
|
||||
try {
|
||||
const formData = new FormData()
|
||||
formData.append('file', file)
|
||||
if (sessionName.trim()) {
|
||||
formData.append('name', sessionName.trim())
|
||||
}
|
||||
|
||||
const res = await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions`, {
|
||||
method: 'POST',
|
||||
@@ -167,6 +171,20 @@ export function StepDeskew({ sessionId: existingSessionId, onNext }: StepDeskewP
|
||||
if (!session) {
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
{/* Session name input */}
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-gray-600 dark:text-gray-400 mb-1">
|
||||
Session-Name (optional)
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={sessionName}
|
||||
onChange={(e) => setSessionName(e.target.value)}
|
||||
placeholder="z.B. Unit 3 Seite 42"
|
||||
className="w-full max-w-sm px-3 py-2 text-sm border rounded-lg dark:bg-gray-800 dark:border-gray-600 dark:text-gray-200 focus:outline-none focus:ring-2 focus:ring-teal-500"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div
|
||||
onDragOver={(e) => { e.preventDefault(); setDragOver(true) }}
|
||||
onDragLeave={() => setDragOver(false)}
|
||||
|
||||
@@ -43,6 +43,7 @@ except ImportError:
|
||||
trocr_router = None
|
||||
from vocab_worksheet_api import router as vocab_router, set_db_pool as set_vocab_db_pool, _init_vocab_table, _load_all_sessions, DATABASE_URL as VOCAB_DATABASE_URL
|
||||
from ocr_pipeline_api import router as ocr_pipeline_router
|
||||
from ocr_pipeline_session_store import init_ocr_pipeline_tables
|
||||
try:
|
||||
from dsfa_rag_api import router as dsfa_rag_router, set_db_pool as set_dsfa_db_pool
|
||||
from dsfa_corpus_ingestion import DSFAQdrantService, DATABASE_URL as DSFA_DATABASE_URL
|
||||
@@ -76,6 +77,13 @@ async def lifespan(app: FastAPI):
|
||||
except Exception as e:
|
||||
print(f"Warning: Vocab sessions database initialization failed: {e}")
|
||||
|
||||
# Initialize OCR Pipeline session tables
|
||||
try:
|
||||
await init_ocr_pipeline_tables()
|
||||
print("OCR Pipeline session tables initialized")
|
||||
except Exception as e:
|
||||
print(f"Warning: OCR Pipeline tables initialization failed: {e}")
|
||||
|
||||
# Initialize database pool for DSFA RAG
|
||||
dsfa_db_pool = None
|
||||
if DSFA_DATABASE_URL and set_dsfa_db_pool:
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
-- OCR Pipeline Sessions - Persistent session storage
|
||||
-- Applied automatically by ocr_pipeline_session_store.init_ocr_pipeline_tables()
|
||||
|
||||
CREATE TABLE IF NOT EXISTS ocr_pipeline_sessions (
|
||||
id UUID PRIMARY KEY,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
filename VARCHAR(255),
|
||||
status VARCHAR(50) DEFAULT 'active',
|
||||
current_step INT DEFAULT 1,
|
||||
original_png BYTEA,
|
||||
deskewed_png BYTEA,
|
||||
binarized_png BYTEA,
|
||||
dewarped_png BYTEA,
|
||||
deskew_result JSONB,
|
||||
dewarp_result JSONB,
|
||||
column_result JSONB,
|
||||
ground_truth JSONB DEFAULT '{}',
|
||||
auto_shear_degrees FLOAT,
|
||||
created_at TIMESTAMP DEFAULT NOW(),
|
||||
updated_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
-- Index for listing sessions
|
||||
CREATE INDEX IF NOT EXISTS idx_ocr_pipeline_sessions_created
|
||||
ON ocr_pipeline_sessions (created_at DESC);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_ocr_pipeline_sessions_status
|
||||
ON ocr_pipeline_sessions (status);
|
||||
@@ -14,20 +14,21 @@ Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Optional
|
||||
from dataclasses import asdict
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, File, HTTPException, UploadFile
|
||||
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cv_vocab_pipeline import (
|
||||
analyze_layout,
|
||||
create_ocr_image,
|
||||
deskew_image,
|
||||
deskew_image_by_word_alignment,
|
||||
@@ -36,34 +37,67 @@ from cv_vocab_pipeline import (
|
||||
render_image_high_res,
|
||||
render_pdf_high_res,
|
||||
)
|
||||
from ocr_pipeline_session_store import (
|
||||
create_session_db,
|
||||
delete_session_db,
|
||||
get_session_db,
|
||||
get_session_image,
|
||||
init_ocr_pipeline_tables,
|
||||
list_sessions_db,
|
||||
update_session_db,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# In-memory session store (24h TTL)
|
||||
# In-memory cache for active sessions (BGR numpy arrays for processing)
|
||||
# DB is source of truth, cache holds BGR arrays during active processing.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_sessions: Dict[str, Dict[str, Any]] = {}
|
||||
SESSION_TTL_HOURS = 24
|
||||
_cache: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
def _cleanup_expired():
|
||||
"""Remove sessions older than TTL."""
|
||||
cutoff = datetime.utcnow() - timedelta(hours=SESSION_TTL_HOURS)
|
||||
expired = [sid for sid, s in _sessions.items() if s.get("created_at", datetime.utcnow()) < cutoff]
|
||||
for sid in expired:
|
||||
del _sessions[sid]
|
||||
logger.info(f"OCR Pipeline: expired session {sid}")
|
||||
|
||||
|
||||
def _get_session(session_id: str) -> Dict[str, Any]:
|
||||
"""Get session or raise 404."""
|
||||
session = _sessions.get(session_id)
|
||||
async def _load_session_to_cache(session_id: str) -> Dict[str, Any]:
|
||||
"""Load session from DB into cache, decoding PNGs to BGR arrays."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
return session
|
||||
|
||||
if session_id in _cache:
|
||||
return _cache[session_id]
|
||||
|
||||
cache_entry: Dict[str, Any] = {
|
||||
"id": session_id,
|
||||
**session,
|
||||
"original_bgr": None,
|
||||
"deskewed_bgr": None,
|
||||
"dewarped_bgr": None,
|
||||
}
|
||||
|
||||
# Decode images from DB into BGR numpy arrays
|
||||
for img_type, bgr_key in [
|
||||
("original", "original_bgr"),
|
||||
("deskewed", "deskewed_bgr"),
|
||||
("dewarped", "dewarped_bgr"),
|
||||
]:
|
||||
png_data = await get_session_image(session_id, img_type)
|
||||
if png_data:
|
||||
arr = np.frombuffer(png_data, dtype=np.uint8)
|
||||
bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||
cache_entry[bgr_key] = bgr
|
||||
|
||||
_cache[session_id] = cache_entry
|
||||
return cache_entry
|
||||
|
||||
|
||||
def _get_cached(session_id: str) -> Dict[str, Any]:
|
||||
"""Get from cache or raise 404."""
|
||||
entry = _cache.get(session_id)
|
||||
if not entry:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not in cache — reload first")
|
||||
return entry
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -90,15 +124,36 @@ class DewarpGroundTruthRequest(BaseModel):
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
class RenameSessionRequest(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class ManualColumnsRequest(BaseModel):
|
||||
columns: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class ColumnGroundTruthRequest(BaseModel):
|
||||
is_correct: bool
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# Session Management Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/sessions")
|
||||
async def list_sessions():
|
||||
"""List all OCR pipeline sessions."""
|
||||
sessions = await list_sessions_db()
|
||||
return {"sessions": sessions}
|
||||
|
||||
|
||||
@router.post("/sessions")
|
||||
async def create_session(file: UploadFile = File(...)):
|
||||
async def create_session(
|
||||
file: UploadFile = File(...),
|
||||
name: Optional[str] = Form(None),
|
||||
):
|
||||
"""Upload a PDF or image file and create a pipeline session."""
|
||||
_cleanup_expired()
|
||||
|
||||
file_data = await file.read()
|
||||
filename = file.filename or "upload"
|
||||
content_type = file.content_type or ""
|
||||
@@ -114,25 +169,32 @@ async def create_session(file: UploadFile = File(...)):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Could not process file: {e}")
|
||||
|
||||
# Encode original as PNG bytes for serving
|
||||
# Encode original as PNG bytes
|
||||
success, png_buf = cv2.imencode(".png", img_bgr)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to encode image")
|
||||
|
||||
_sessions[session_id] = {
|
||||
original_png = png_buf.tobytes()
|
||||
session_name = name or filename
|
||||
|
||||
# Persist to DB
|
||||
await create_session_db(
|
||||
session_id=session_id,
|
||||
name=session_name,
|
||||
filename=filename,
|
||||
original_png=original_png,
|
||||
)
|
||||
|
||||
# Cache BGR array for immediate processing
|
||||
_cache[session_id] = {
|
||||
"id": session_id,
|
||||
"filename": filename,
|
||||
"created_at": datetime.utcnow(),
|
||||
"name": session_name,
|
||||
"original_bgr": img_bgr,
|
||||
"original_png": png_buf.tobytes(),
|
||||
"deskewed_bgr": None,
|
||||
"deskewed_png": None,
|
||||
"binarized_png": None,
|
||||
"deskew_result": None,
|
||||
"dewarped_bgr": None,
|
||||
"dewarped_png": None,
|
||||
"deskew_result": None,
|
||||
"dewarp_result": None,
|
||||
"auto_shear_degrees": None,
|
||||
"ground_truth": {},
|
||||
"current_step": 1,
|
||||
}
|
||||
@@ -143,6 +205,7 @@ async def create_session(file: UploadFile = File(...)):
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"filename": filename,
|
||||
"name": session_name,
|
||||
"image_width": img_bgr.shape[1],
|
||||
"image_height": img_bgr.shape[0],
|
||||
"original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
|
||||
@@ -151,35 +214,106 @@ async def create_session(file: UploadFile = File(...)):
|
||||
|
||||
@router.get("/sessions/{session_id}")
|
||||
async def get_session_info(session_id: str):
|
||||
"""Get session info including deskew/dewarp results for step navigation."""
|
||||
session = _get_session(session_id)
|
||||
img_bgr = session["original_bgr"]
|
||||
"""Get session info including deskew/dewarp/column results for step navigation."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
# Get image dimensions from original PNG
|
||||
original_png = await get_session_image(session_id, "original")
|
||||
if original_png:
|
||||
arr = np.frombuffer(original_png, dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||
img_w, img_h = img.shape[1], img.shape[0] if img is not None else (0, 0)
|
||||
else:
|
||||
img_w, img_h = 0, 0
|
||||
|
||||
result = {
|
||||
"session_id": session["id"],
|
||||
"filename": session["filename"],
|
||||
"image_width": img_bgr.shape[1],
|
||||
"image_height": img_bgr.shape[0],
|
||||
"filename": session.get("filename", ""),
|
||||
"name": session.get("name", ""),
|
||||
"image_width": img_w,
|
||||
"image_height": img_h,
|
||||
"original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
|
||||
"current_step": session.get("current_step", 1),
|
||||
}
|
||||
|
||||
# Include deskew result if available
|
||||
if session.get("deskew_result"):
|
||||
result["deskew_result"] = session["deskew_result"]
|
||||
|
||||
# Include dewarp result if available
|
||||
if session.get("dewarp_result"):
|
||||
result["dewarp_result"] = session["dewarp_result"]
|
||||
if session.get("column_result"):
|
||||
result["column_result"] = session["column_result"]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.put("/sessions/{session_id}")
|
||||
async def rename_session(session_id: str, req: RenameSessionRequest):
|
||||
"""Rename a session."""
|
||||
updated = await update_session_db(session_id, name=req.name)
|
||||
if not updated:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
return {"session_id": session_id, "name": req.name}
|
||||
|
||||
|
||||
@router.delete("/sessions/{session_id}")
|
||||
async def delete_session(session_id: str):
|
||||
"""Delete a session."""
|
||||
_cache.pop(session_id, None)
|
||||
deleted = await delete_session_db(session_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
return {"session_id": session_id, "deleted": True}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/sessions/{session_id}/image/{image_type}")
|
||||
async def get_image(session_id: str, image_type: str):
|
||||
"""Serve session images: original, deskewed, dewarped, binarized, or columns-overlay."""
|
||||
valid_types = {"original", "deskewed", "dewarped", "binarized", "columns-overlay"}
|
||||
if image_type not in valid_types:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")
|
||||
|
||||
if image_type == "columns-overlay":
|
||||
return await _get_columns_overlay(session_id)
|
||||
|
||||
# Try cache first for fast serving
|
||||
cached = _cache.get(session_id)
|
||||
if cached:
|
||||
png_key = f"{image_type}_png" if image_type != "original" else None
|
||||
bgr_key = f"{image_type}_bgr" if image_type != "binarized" else None
|
||||
|
||||
# For binarized, check if we have it cached as PNG
|
||||
if image_type == "binarized" and cached.get("binarized_png"):
|
||||
return Response(content=cached["binarized_png"], media_type="image/png")
|
||||
|
||||
# Load from DB
|
||||
data = await get_session_image(session_id, image_type)
|
||||
if not data:
|
||||
raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet")
|
||||
|
||||
return Response(content=data, media_type="image/png")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Deskew Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/deskew")
|
||||
async def auto_deskew(session_id: str):
|
||||
"""Run both deskew methods and pick the best one."""
|
||||
session = _get_session(session_id)
|
||||
img_bgr = session["original_bgr"]
|
||||
# Ensure session is in cache
|
||||
if session_id not in _cache:
|
||||
await _load_session_to_cache(session_id)
|
||||
cached = _get_cached(session_id)
|
||||
|
||||
img_bgr = cached.get("original_bgr")
|
||||
if img_bgr is None:
|
||||
raise HTTPException(status_code=400, detail="Original image not available")
|
||||
|
||||
t0 = time.time()
|
||||
|
||||
@@ -202,12 +336,10 @@ async def auto_deskew(session_id: str):
|
||||
|
||||
duration = time.time() - t0
|
||||
|
||||
# Pick method with larger detected angle (more correction needed = more skew found)
|
||||
# If both are ~0, prefer word alignment as it's more robust
|
||||
# Pick best method
|
||||
if abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1:
|
||||
method_used = "word_alignment"
|
||||
angle_applied = angle_wa
|
||||
# Decode word alignment result to BGR
|
||||
wa_array = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8)
|
||||
deskewed_bgr = cv2.imdecode(wa_array, cv2.IMREAD_COLOR)
|
||||
if deskewed_bgr is None:
|
||||
@@ -219,20 +351,19 @@ async def auto_deskew(session_id: str):
|
||||
angle_applied = angle_hough
|
||||
deskewed_bgr = deskewed_hough
|
||||
|
||||
# Encode deskewed as PNG
|
||||
# Encode as PNG
|
||||
success, deskewed_png_buf = cv2.imencode(".png", deskewed_bgr)
|
||||
deskewed_png = deskewed_png_buf.tobytes() if success else session["original_png"]
|
||||
deskewed_png = deskewed_png_buf.tobytes() if success else b""
|
||||
|
||||
# Create binarized version
|
||||
binarized_png = None
|
||||
try:
|
||||
binarized = create_ocr_image(deskewed_bgr)
|
||||
success_bin, bin_buf = cv2.imencode(".png", binarized)
|
||||
binarized_png = bin_buf.tobytes() if success_bin else None
|
||||
except Exception as e:
|
||||
logger.warning(f"Binarization failed: {e}")
|
||||
binarized_png = None
|
||||
|
||||
# Confidence: higher angle = lower confidence that we got it right
|
||||
confidence = max(0.5, 1.0 - abs(angle_applied) / 5.0)
|
||||
|
||||
deskew_result = {
|
||||
@@ -244,13 +375,23 @@ async def auto_deskew(session_id: str):
|
||||
"duration_seconds": round(duration, 2),
|
||||
}
|
||||
|
||||
session["deskewed_bgr"] = deskewed_bgr
|
||||
session["deskewed_png"] = deskewed_png
|
||||
session["binarized_png"] = binarized_png
|
||||
session["deskew_result"] = deskew_result
|
||||
# Update cache
|
||||
cached["deskewed_bgr"] = deskewed_bgr
|
||||
cached["binarized_png"] = binarized_png
|
||||
cached["deskew_result"] = deskew_result
|
||||
|
||||
# Persist to DB
|
||||
db_update = {
|
||||
"deskewed_png": deskewed_png,
|
||||
"deskew_result": deskew_result,
|
||||
"current_step": 2,
|
||||
}
|
||||
if binarized_png:
|
||||
db_update["binarized_png"] = binarized_png
|
||||
await update_session_db(session_id, **db_update)
|
||||
|
||||
logger.info(f"OCR Pipeline: deskew session {session_id}: "
|
||||
f"hough={angle_hough:.2f}° wa={angle_wa:.2f}° → {method_used} {angle_applied:.2f}°")
|
||||
f"hough={angle_hough:.2f} wa={angle_wa:.2f} -> {method_used} {angle_applied:.2f}")
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
@@ -263,8 +404,14 @@ async def auto_deskew(session_id: str):
|
||||
@router.post("/sessions/{session_id}/deskew/manual")
|
||||
async def manual_deskew(session_id: str, req: ManualDeskewRequest):
|
||||
"""Apply a manual rotation angle to the original image."""
|
||||
session = _get_session(session_id)
|
||||
img_bgr = session["original_bgr"]
|
||||
if session_id not in _cache:
|
||||
await _load_session_to_cache(session_id)
|
||||
cached = _get_cached(session_id)
|
||||
|
||||
img_bgr = cached.get("original_bgr")
|
||||
if img_bgr is None:
|
||||
raise HTTPException(status_code=400, detail="Original image not available")
|
||||
|
||||
angle = max(-5.0, min(5.0, req.angle))
|
||||
|
||||
h, w = img_bgr.shape[:2]
|
||||
@@ -275,26 +422,38 @@ async def manual_deskew(session_id: str, req: ManualDeskewRequest):
|
||||
borderMode=cv2.BORDER_REPLICATE)
|
||||
|
||||
success, png_buf = cv2.imencode(".png", rotated)
|
||||
deskewed_png = png_buf.tobytes() if success else session["original_png"]
|
||||
deskewed_png = png_buf.tobytes() if success else b""
|
||||
|
||||
# Binarize
|
||||
binarized_png = None
|
||||
try:
|
||||
binarized = create_ocr_image(rotated)
|
||||
success_bin, bin_buf = cv2.imencode(".png", binarized)
|
||||
binarized_png = bin_buf.tobytes() if success_bin else None
|
||||
except Exception:
|
||||
binarized_png = None
|
||||
pass
|
||||
|
||||
session["deskewed_bgr"] = rotated
|
||||
session["deskewed_png"] = deskewed_png
|
||||
session["binarized_png"] = binarized_png
|
||||
session["deskew_result"] = {
|
||||
**(session.get("deskew_result") or {}),
|
||||
deskew_result = {
|
||||
**(cached.get("deskew_result") or {}),
|
||||
"angle_applied": round(angle, 3),
|
||||
"method_used": "manual",
|
||||
}
|
||||
|
||||
logger.info(f"OCR Pipeline: manual deskew session {session_id}: {angle:.2f}°")
|
||||
# Update cache
|
||||
cached["deskewed_bgr"] = rotated
|
||||
cached["binarized_png"] = binarized_png
|
||||
cached["deskew_result"] = deskew_result
|
||||
|
||||
# Persist to DB
|
||||
db_update = {
|
||||
"deskewed_png": deskewed_png,
|
||||
"deskew_result": deskew_result,
|
||||
}
|
||||
if binarized_png:
|
||||
db_update["binarized_png"] = binarized_png
|
||||
await update_session_db(session_id, **db_update)
|
||||
|
||||
logger.info(f"OCR Pipeline: manual deskew session {session_id}: {angle:.2f}")
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
@@ -304,33 +463,14 @@ async def manual_deskew(session_id: str, req: ManualDeskewRequest):
|
||||
}
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/image/{image_type}")
|
||||
async def get_image(session_id: str, image_type: str):
|
||||
"""Serve session images: original, deskewed, dewarped, or binarized."""
|
||||
session = _get_session(session_id)
|
||||
|
||||
if image_type == "original":
|
||||
data = session.get("original_png")
|
||||
elif image_type == "deskewed":
|
||||
data = session.get("deskewed_png")
|
||||
elif image_type == "dewarped":
|
||||
data = session.get("dewarped_png")
|
||||
elif image_type == "binarized":
|
||||
data = session.get("binarized_png")
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")
|
||||
|
||||
if not data:
|
||||
raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet")
|
||||
|
||||
return Response(content=data, media_type="image/png")
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/ground-truth/deskew")
|
||||
async def save_deskew_ground_truth(session_id: str, req: DeskewGroundTruthRequest):
|
||||
"""Save ground truth feedback for the deskew step."""
|
||||
session = _get_session(session_id)
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
ground_truth = session.get("ground_truth") or {}
|
||||
gt = {
|
||||
"is_correct": req.is_correct,
|
||||
"corrected_angle": req.corrected_angle,
|
||||
@@ -338,7 +478,13 @@ async def save_deskew_ground_truth(session_id: str, req: DeskewGroundTruthReques
|
||||
"saved_at": datetime.utcnow().isoformat(),
|
||||
"deskew_result": session.get("deskew_result"),
|
||||
}
|
||||
session["ground_truth"]["deskew"] = gt
|
||||
ground_truth["deskew"] = gt
|
||||
|
||||
await update_session_db(session_id, ground_truth=ground_truth)
|
||||
|
||||
# Update cache
|
||||
if session_id in _cache:
|
||||
_cache[session_id]["ground_truth"] = ground_truth
|
||||
|
||||
logger.info(f"OCR Pipeline: ground truth deskew session {session_id}: "
|
||||
f"correct={req.is_correct}, corrected_angle={req.corrected_angle}")
|
||||
@@ -353,8 +499,11 @@ async def save_deskew_ground_truth(session_id: str, req: DeskewGroundTruthReques
|
||||
@router.post("/sessions/{session_id}/dewarp")
|
||||
async def auto_dewarp(session_id: str):
|
||||
"""Detect and correct vertical shear on the deskewed image."""
|
||||
session = _get_session(session_id)
|
||||
deskewed_bgr = session.get("deskewed_bgr")
|
||||
if session_id not in _cache:
|
||||
await _load_session_to_cache(session_id)
|
||||
cached = _get_cached(session_id)
|
||||
|
||||
deskewed_bgr = cached.get("deskewed_bgr")
|
||||
if deskewed_bgr is None:
|
||||
raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")
|
||||
|
||||
@@ -362,30 +511,37 @@ async def auto_dewarp(session_id: str):
|
||||
dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr)
|
||||
duration = time.time() - t0
|
||||
|
||||
# Encode dewarped as PNG
|
||||
# Encode as PNG
|
||||
success, png_buf = cv2.imencode(".png", dewarped_bgr)
|
||||
dewarped_png = png_buf.tobytes() if success else session["deskewed_png"]
|
||||
dewarped_png = png_buf.tobytes() if success else b""
|
||||
|
||||
session["dewarped_bgr"] = dewarped_bgr
|
||||
session["dewarped_png"] = dewarped_png
|
||||
session["auto_shear_degrees"] = dewarp_info.get("shear_degrees", 0.0)
|
||||
session["dewarp_result"] = {
|
||||
dewarp_result = {
|
||||
"method_used": dewarp_info["method"],
|
||||
"shear_degrees": dewarp_info["shear_degrees"],
|
||||
"confidence": dewarp_info["confidence"],
|
||||
"duration_seconds": round(duration, 2),
|
||||
}
|
||||
|
||||
# Update cache
|
||||
cached["dewarped_bgr"] = dewarped_bgr
|
||||
cached["dewarp_result"] = dewarp_result
|
||||
|
||||
# Persist to DB
|
||||
await update_session_db(
|
||||
session_id,
|
||||
dewarped_png=dewarped_png,
|
||||
dewarp_result=dewarp_result,
|
||||
auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
|
||||
current_step=3,
|
||||
)
|
||||
|
||||
logger.info(f"OCR Pipeline: dewarp session {session_id}: "
|
||||
f"method={dewarp_info['method']} shear={dewarp_info['shear_degrees']:.3f}° "
|
||||
f"method={dewarp_info['method']} shear={dewarp_info['shear_degrees']:.3f} "
|
||||
f"conf={dewarp_info['confidence']:.2f} ({duration:.2f}s)")
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"method_used": dewarp_info["method"],
|
||||
"shear_degrees": dewarp_info["shear_degrees"],
|
||||
"confidence": dewarp_info["confidence"],
|
||||
"duration_seconds": round(duration, 2),
|
||||
**dewarp_result,
|
||||
"dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
|
||||
}
|
||||
|
||||
@@ -393,9 +549,11 @@ async def auto_dewarp(session_id: str):
|
||||
@router.post("/sessions/{session_id}/dewarp/manual")
|
||||
async def manual_dewarp(session_id: str, req: ManualDewarpRequest):
|
||||
"""Apply shear correction with a manual angle."""
|
||||
session = _get_session(session_id)
|
||||
deskewed_bgr = session.get("deskewed_bgr")
|
||||
if session_id not in _cache:
|
||||
await _load_session_to_cache(session_id)
|
||||
cached = _get_cached(session_id)
|
||||
|
||||
deskewed_bgr = cached.get("deskewed_bgr")
|
||||
if deskewed_bgr is None:
|
||||
raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")
|
||||
|
||||
@@ -407,17 +565,26 @@ async def manual_dewarp(session_id: str, req: ManualDewarpRequest):
|
||||
dewarped_bgr = dewarp_image_manual(deskewed_bgr, shear_deg)
|
||||
|
||||
success, png_buf = cv2.imencode(".png", dewarped_bgr)
|
||||
dewarped_png = png_buf.tobytes() if success else session.get("deskewed_png")
|
||||
dewarped_png = png_buf.tobytes() if success else b""
|
||||
|
||||
session["dewarped_bgr"] = dewarped_bgr
|
||||
session["dewarped_png"] = dewarped_png
|
||||
session["dewarp_result"] = {
|
||||
**(session.get("dewarp_result") or {}),
|
||||
dewarp_result = {
|
||||
**(cached.get("dewarp_result") or {}),
|
||||
"method_used": "manual",
|
||||
"shear_degrees": round(shear_deg, 3),
|
||||
}
|
||||
|
||||
logger.info(f"OCR Pipeline: manual dewarp session {session_id}: shear={shear_deg:.3f}°")
|
||||
# Update cache
|
||||
cached["dewarped_bgr"] = dewarped_bgr
|
||||
cached["dewarp_result"] = dewarp_result
|
||||
|
||||
# Persist to DB
|
||||
await update_session_db(
|
||||
session_id,
|
||||
dewarped_png=dewarped_png,
|
||||
dewarp_result=dewarp_result,
|
||||
)
|
||||
|
||||
logger.info(f"OCR Pipeline: manual dewarp session {session_id}: shear={shear_deg:.3f}")
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
@@ -430,8 +597,11 @@ async def manual_dewarp(session_id: str, req: ManualDewarpRequest):
|
||||
@router.post("/sessions/{session_id}/ground-truth/dewarp")
|
||||
async def save_dewarp_ground_truth(session_id: str, req: DewarpGroundTruthRequest):
|
||||
"""Save ground truth feedback for the dewarp step."""
|
||||
session = _get_session(session_id)
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
ground_truth = session.get("ground_truth") or {}
|
||||
gt = {
|
||||
"is_correct": req.is_correct,
|
||||
"corrected_shear": req.corrected_shear,
|
||||
@@ -439,9 +609,168 @@ async def save_dewarp_ground_truth(session_id: str, req: DewarpGroundTruthReques
|
||||
"saved_at": datetime.utcnow().isoformat(),
|
||||
"dewarp_result": session.get("dewarp_result"),
|
||||
}
|
||||
session["ground_truth"]["dewarp"] = gt
|
||||
ground_truth["dewarp"] = gt
|
||||
|
||||
await update_session_db(session_id, ground_truth=ground_truth)
|
||||
|
||||
if session_id in _cache:
|
||||
_cache[session_id]["ground_truth"] = ground_truth
|
||||
|
||||
logger.info(f"OCR Pipeline: ground truth dewarp session {session_id}: "
|
||||
f"correct={req.is_correct}, corrected_shear={req.corrected_shear}")
|
||||
|
||||
return {"session_id": session_id, "ground_truth": gt}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Column Detection Endpoints (Step 3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/columns")
|
||||
async def detect_columns(session_id: str):
|
||||
"""Run column detection on the dewarped image."""
|
||||
if session_id not in _cache:
|
||||
await _load_session_to_cache(session_id)
|
||||
cached = _get_cached(session_id)
|
||||
|
||||
dewarped_bgr = cached.get("dewarped_bgr")
|
||||
if dewarped_bgr is None:
|
||||
raise HTTPException(status_code=400, detail="Dewarp must be completed before column detection")
|
||||
|
||||
t0 = time.time()
|
||||
|
||||
# Prepare images for analyze_layout
|
||||
gray = cv2.cvtColor(dewarped_bgr, cv2.COLOR_BGR2GRAY)
|
||||
# CLAHE-enhanced for layout analysis
|
||||
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
||||
layout_img = clahe.apply(gray)
|
||||
# Binarized for text density
|
||||
ocr_img = create_ocr_image(dewarped_bgr)
|
||||
|
||||
regions = analyze_layout(layout_img, ocr_img)
|
||||
duration = time.time() - t0
|
||||
|
||||
columns = [asdict(r) for r in regions]
|
||||
column_result = {
|
||||
"columns": columns,
|
||||
"duration_seconds": round(duration, 2),
|
||||
}
|
||||
|
||||
# Persist to DB
|
||||
await update_session_db(
|
||||
session_id,
|
||||
column_result=column_result,
|
||||
current_step=3,
|
||||
)
|
||||
|
||||
# Update cache
|
||||
cached["column_result"] = column_result
|
||||
|
||||
col_count = len([c for c in columns if c["type"].startswith("column")])
|
||||
logger.info(f"OCR Pipeline: columns session {session_id}: "
|
||||
f"{col_count} columns detected ({duration:.2f}s)")
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
**column_result,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/columns/manual")
|
||||
async def set_manual_columns(session_id: str, req: ManualColumnsRequest):
|
||||
"""Override detected columns with manual definitions."""
|
||||
column_result = {
|
||||
"columns": req.columns,
|
||||
"duration_seconds": 0,
|
||||
"method": "manual",
|
||||
}
|
||||
|
||||
await update_session_db(session_id, column_result=column_result)
|
||||
|
||||
if session_id in _cache:
|
||||
_cache[session_id]["column_result"] = column_result
|
||||
|
||||
logger.info(f"OCR Pipeline: manual columns session {session_id}: "
|
||||
f"{len(req.columns)} columns set")
|
||||
|
||||
return {"session_id": session_id, **column_result}
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/ground-truth/columns")
|
||||
async def save_column_ground_truth(session_id: str, req: ColumnGroundTruthRequest):
|
||||
"""Save ground truth feedback for the column detection step."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
ground_truth = session.get("ground_truth") or {}
|
||||
gt = {
|
||||
"is_correct": req.is_correct,
|
||||
"notes": req.notes,
|
||||
"saved_at": datetime.utcnow().isoformat(),
|
||||
"column_result": session.get("column_result"),
|
||||
}
|
||||
ground_truth["columns"] = gt
|
||||
|
||||
await update_session_db(session_id, ground_truth=ground_truth)
|
||||
|
||||
if session_id in _cache:
|
||||
_cache[session_id]["ground_truth"] = ground_truth
|
||||
|
||||
return {"session_id": session_id, "ground_truth": gt}
|
||||
|
||||
|
||||
async def _get_columns_overlay(session_id: str) -> Response:
|
||||
"""Generate dewarped image with column borders drawn on it."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
column_result = session.get("column_result")
|
||||
if not column_result or not column_result.get("columns"):
|
||||
raise HTTPException(status_code=404, detail="No column data available")
|
||||
|
||||
# Load dewarped image
|
||||
dewarped_png = await get_session_image(session_id, "dewarped")
|
||||
if not dewarped_png:
|
||||
raise HTTPException(status_code=404, detail="Dewarped image not available")
|
||||
|
||||
arr = np.frombuffer(dewarped_png, dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
raise HTTPException(status_code=500, detail="Failed to decode image")
|
||||
|
||||
# Color map for region types
|
||||
colors = {
|
||||
"column_en": (255, 180, 0), # Blue (BGR)
|
||||
"column_de": (0, 200, 0), # Green
|
||||
"column_example": (0, 140, 255), # Orange
|
||||
"header": (128, 128, 128), # Gray
|
||||
"footer": (128, 128, 128), # Gray
|
||||
}
|
||||
|
||||
overlay = img.copy()
|
||||
for col in column_result["columns"]:
|
||||
x, y = col["x"], col["y"]
|
||||
w, h = col["width"], col["height"]
|
||||
color = colors.get(col.get("type", ""), (200, 200, 200))
|
||||
|
||||
# Semi-transparent fill
|
||||
cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1)
|
||||
|
||||
# Solid border
|
||||
cv2.rectangle(img, (x, y), (x + w, y + h), color, 3)
|
||||
|
||||
# Label
|
||||
label = col.get("type", "unknown").replace("column_", "").upper()
|
||||
cv2.putText(img, label, (x + 10, y + 30),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2)
|
||||
|
||||
# Blend overlay at 20% opacity
|
||||
cv2.addWeighted(overlay, 0.2, img, 0.8, 0, img)
|
||||
|
||||
success, result_png = cv2.imencode(".png", img)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
|
||||
|
||||
return Response(content=result_png.tobytes(), media_type="image/png")
|
||||
|
||||
228
klausur-service/backend/ocr_pipeline_session_store.py
Normal file
228
klausur-service/backend/ocr_pipeline_session_store.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
OCR Pipeline Session Store - PostgreSQL persistence for OCR pipeline sessions.
|
||||
|
||||
Replaces in-memory storage with database persistence.
|
||||
See migrations/002_ocr_pipeline_sessions.sql for schema.
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
import logging
|
||||
import json
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
import asyncpg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Database configuration (same as vocab_session_store)
|
||||
DATABASE_URL = os.getenv(
|
||||
"DATABASE_URL",
|
||||
"postgresql://breakpilot:breakpilot@postgres:5432/breakpilot_db"
|
||||
)
|
||||
|
||||
# Connection pool (initialized lazily)
|
||||
_pool: Optional[asyncpg.Pool] = None
|
||||
|
||||
|
||||
async def get_pool() -> asyncpg.Pool:
|
||||
"""Get or create the database connection pool."""
|
||||
global _pool
|
||||
if _pool is None:
|
||||
_pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
|
||||
return _pool
|
||||
|
||||
|
||||
async def init_ocr_pipeline_tables():
|
||||
"""Initialize OCR pipeline tables if they don't exist."""
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
tables_exist = await conn.fetchval("""
|
||||
SELECT EXISTS (
|
||||
SELECT FROM information_schema.tables
|
||||
WHERE table_name = 'ocr_pipeline_sessions'
|
||||
)
|
||||
""")
|
||||
|
||||
if not tables_exist:
|
||||
logger.info("Creating OCR pipeline tables...")
|
||||
migration_path = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
"migrations/002_ocr_pipeline_sessions.sql"
|
||||
)
|
||||
if os.path.exists(migration_path):
|
||||
with open(migration_path, "r") as f:
|
||||
sql = f.read()
|
||||
await conn.execute(sql)
|
||||
logger.info("OCR pipeline tables created successfully")
|
||||
else:
|
||||
logger.warning(f"Migration file not found: {migration_path}")
|
||||
else:
|
||||
logger.debug("OCR pipeline tables already exist")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SESSION CRUD
|
||||
# =============================================================================
|
||||
|
||||
async def create_session_db(
|
||||
session_id: str,
|
||||
name: str,
|
||||
filename: str,
|
||||
original_png: bytes,
|
||||
) -> Dict[str, Any]:
|
||||
"""Create a new OCR pipeline session."""
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow("""
|
||||
INSERT INTO ocr_pipeline_sessions (
|
||||
id, name, filename, original_png, status, current_step
|
||||
) VALUES ($1, $2, $3, $4, 'active', 1)
|
||||
RETURNING id, name, filename, status, current_step,
|
||||
deskew_result, dewarp_result, column_result,
|
||||
ground_truth, auto_shear_degrees,
|
||||
created_at, updated_at
|
||||
""", uuid.UUID(session_id), name, filename, original_png)
|
||||
|
||||
return _row_to_dict(row)
|
||||
|
||||
|
||||
async def get_session_db(session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get session metadata (without images)."""
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow("""
|
||||
SELECT id, name, filename, status, current_step,
|
||||
deskew_result, dewarp_result, column_result,
|
||||
ground_truth, auto_shear_degrees,
|
||||
created_at, updated_at
|
||||
FROM ocr_pipeline_sessions WHERE id = $1
|
||||
""", uuid.UUID(session_id))
|
||||
|
||||
if row:
|
||||
return _row_to_dict(row)
|
||||
return None
|
||||
|
||||
|
||||
async def get_session_image(session_id: str, image_type: str) -> Optional[bytes]:
|
||||
"""Load a single image (BYTEA) from the session."""
|
||||
column_map = {
|
||||
"original": "original_png",
|
||||
"deskewed": "deskewed_png",
|
||||
"binarized": "binarized_png",
|
||||
"dewarped": "dewarped_png",
|
||||
}
|
||||
column = column_map.get(image_type)
|
||||
if not column:
|
||||
return None
|
||||
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
return await conn.fetchval(
|
||||
f"SELECT {column} FROM ocr_pipeline_sessions WHERE id = $1",
|
||||
uuid.UUID(session_id)
|
||||
)
|
||||
|
||||
|
||||
async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any]]:
|
||||
"""Update session fields dynamically."""
|
||||
pool = await get_pool()
|
||||
|
||||
fields = []
|
||||
values = []
|
||||
param_idx = 1
|
||||
|
||||
allowed_fields = {
|
||||
'name', 'filename', 'status', 'current_step',
|
||||
'original_png', 'deskewed_png', 'binarized_png', 'dewarped_png',
|
||||
'deskew_result', 'dewarp_result', 'column_result',
|
||||
'ground_truth', 'auto_shear_degrees',
|
||||
}
|
||||
|
||||
jsonb_fields = {'deskew_result', 'dewarp_result', 'column_result', 'ground_truth'}
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if key in allowed_fields:
|
||||
fields.append(f"{key} = ${param_idx}")
|
||||
if key in jsonb_fields and value is not None and not isinstance(value, str):
|
||||
value = json.dumps(value)
|
||||
values.append(value)
|
||||
param_idx += 1
|
||||
|
||||
if not fields:
|
||||
return await get_session_db(session_id)
|
||||
|
||||
# Always update updated_at
|
||||
fields.append(f"updated_at = NOW()")
|
||||
|
||||
values.append(uuid.UUID(session_id))
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow(f"""
|
||||
UPDATE ocr_pipeline_sessions
|
||||
SET {', '.join(fields)}
|
||||
WHERE id = ${param_idx}
|
||||
RETURNING id, name, filename, status, current_step,
|
||||
deskew_result, dewarp_result, column_result,
|
||||
ground_truth, auto_shear_degrees,
|
||||
created_at, updated_at
|
||||
""", *values)
|
||||
|
||||
if row:
|
||||
return _row_to_dict(row)
|
||||
return None
|
||||
|
||||
|
||||
async def list_sessions_db(limit: int = 50) -> List[Dict[str, Any]]:
|
||||
"""List all sessions (metadata only, no images)."""
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch("""
|
||||
SELECT id, name, filename, status, current_step,
|
||||
created_at, updated_at
|
||||
FROM ocr_pipeline_sessions
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $1
|
||||
""", limit)
|
||||
|
||||
return [_row_to_dict(row) for row in rows]
|
||||
|
||||
|
||||
async def delete_session_db(session_id: str) -> bool:
|
||||
"""Delete a session."""
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
result = await conn.execute("""
|
||||
DELETE FROM ocr_pipeline_sessions WHERE id = $1
|
||||
""", uuid.UUID(session_id))
|
||||
return result == "DELETE 1"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# HELPER
|
||||
# =============================================================================
|
||||
|
||||
def _row_to_dict(row: asyncpg.Record) -> Dict[str, Any]:
|
||||
"""Convert asyncpg Record to JSON-serializable dict."""
|
||||
if row is None:
|
||||
return {}
|
||||
|
||||
result = dict(row)
|
||||
|
||||
# UUID → string
|
||||
for key in ['id', 'session_id']:
|
||||
if key in result and result[key] is not None:
|
||||
result[key] = str(result[key])
|
||||
|
||||
# datetime → ISO string
|
||||
for key in ['created_at', 'updated_at']:
|
||||
if key in result and result[key] is not None:
|
||||
result[key] = result[key].isoformat()
|
||||
|
||||
# JSONB → parsed (asyncpg returns str for JSONB)
|
||||
for key in ['deskew_result', 'dewarp_result', 'column_result', 'ground_truth']:
|
||||
if key in result and result[key] is not None:
|
||||
if isinstance(result[key], str):
|
||||
result[key] = json.loads(result[key])
|
||||
|
||||
return result
|
||||
Reference in New Issue
Block a user