diff --git a/admin-lehrer/app/(admin)/ai/ocr-pipeline/page.tsx b/admin-lehrer/app/(admin)/ai/ocr-pipeline/page.tsx index 981cd76..baaad9c 100644 --- a/admin-lehrer/app/(admin)/ai/ocr-pipeline/page.tsx +++ b/admin-lehrer/app/(admin)/ai/ocr-pipeline/page.tsx @@ -1,6 +1,6 @@ 'use client' -import { useState } from 'react' +import { useCallback, useEffect, useState } from 'react' import { PagePurpose } from '@/components/common/PagePurpose' import { PipelineStepper } from '@/components/ocr-pipeline/PipelineStepper' import { StepDeskew } from '@/components/ocr-pipeline/StepDeskew' @@ -10,11 +10,18 @@ import { StepWordRecognition } from '@/components/ocr-pipeline/StepWordRecogniti import { StepCoordinates } from '@/components/ocr-pipeline/StepCoordinates' import { StepReconstruction } from '@/components/ocr-pipeline/StepReconstruction' import { StepGroundTruth } from '@/components/ocr-pipeline/StepGroundTruth' -import { PIPELINE_STEPS, type PipelineStep } from './types' +import { PIPELINE_STEPS, type PipelineStep, type SessionListItem } from './types' + +const KLAUSUR_API = '/klausur-api' export default function OcrPipelinePage() { const [currentStep, setCurrentStep] = useState(0) const [sessionId, setSessionId] = useState(null) + const [sessionName, setSessionName] = useState('') + const [sessions, setSessions] = useState([]) + const [loadingSessions, setLoadingSessions] = useState(true) + const [editingName, setEditingName] = useState(null) + const [editNameValue, setEditNameValue] = useState('') const [steps, setSteps] = useState( PIPELINE_STEPS.map((s, i) => ({ ...s, @@ -22,6 +29,82 @@ export default function OcrPipelinePage() { })), ) + // Load session list on mount + useEffect(() => { + loadSessions() + }, []) + + const loadSessions = async () => { + setLoadingSessions(true) + try { + const res = await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions`) + if (res.ok) { + const data = await res.json() + setSessions(data.sessions || []) + } + } catch (e) { + console.error('Failed to load sessions:', e) + } finally { + setLoadingSessions(false) + } + } + + const openSession = useCallback(async (sid: string) => { + try { + const res = await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sid}`) + if (!res.ok) return + const data = await res.json() + + setSessionId(sid) + setSessionName(data.name || data.filename || '') + + // Determine which step to jump to based on current_step + const dbStep = data.current_step || 1 + // Steps: 1=deskew, 2=dewarp, 3=columns, ... + // UI steps are 0-indexed: 0=deskew, 1=dewarp, 2=columns, ... + const uiStep = Math.max(0, dbStep - 1) + + setSteps( + PIPELINE_STEPS.map((s, i) => ({ + ...s, + status: i < uiStep ? 'completed' : i === uiStep ? 'active' : 'pending', + })), + ) + setCurrentStep(uiStep) + } catch (e) { + console.error('Failed to open session:', e) + } + }, []) + + const deleteSession = useCallback(async (sid: string) => { + try { + await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sid}`, { method: 'DELETE' }) + setSessions((prev) => prev.filter((s) => s.id !== sid)) + if (sessionId === sid) { + setSessionId(null) + setCurrentStep(0) + setSteps(PIPELINE_STEPS.map((s, i) => ({ ...s, status: i === 0 ? 'active' : 'pending' }))) + } + } catch (e) { + console.error('Failed to delete session:', e) + } + }, [sessionId]) + + const renameSession = useCallback(async (sid: string, newName: string) => { + try { + await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sid}`, { + method: 'PUT', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ name: newName }), + }) + setSessions((prev) => prev.map((s) => (s.id === sid ? { ...s, name: newName } : s))) + if (sessionId === sid) setSessionName(newName) + } catch (e) { + console.error('Failed to rename session:', e) + } + setEditingName(null) + }, [sessionId]) + const handleStepClick = (index: number) => { if (index <= currentStep || steps[index].status === 'completed') { setCurrentStep(index) @@ -43,9 +126,28 @@ export default function OcrPipelinePage() { const handleDeskewComplete = (sid: string) => { setSessionId(sid) + // Reload session list to show the new session + loadSessions() handleNext() } + const handleNewSession = () => { + setSessionId(null) + setSessionName('') + setCurrentStep(0) + setSteps(PIPELINE_STEPS.map((s, i) => ({ ...s, status: i === 0 ? 'active' : 'pending' }))) + } + + const stepNames: Record = { + 1: 'Begradigung', + 2: 'Entzerrung', + 3: 'Spalten', + 4: 'Woerter', + 5: 'Koordinaten', + 6: 'Rekonstruktion', + 7: 'Validierung', + } + const renderStep = () => { switch (currentStep) { case 0: @@ -53,7 +155,7 @@ export default function OcrPipelinePage() { case 1: return case 2: - return + return case 3: return case 4: @@ -75,7 +177,7 @@ export default function OcrPipelinePage() { audience={['Entwickler', 'Data Scientists']} architecture={{ services: ['klausur-service (FastAPI)', 'OpenCV', 'Tesseract'], - databases: ['In-Memory Sessions'], + databases: ['PostgreSQL Sessions'], }} relatedPages={[ { name: 'OCR Vergleich', href: '/ai/ocr-compare', description: 'Methoden-Vergleich' }, @@ -84,6 +186,97 @@ export default function OcrPipelinePage() { defaultCollapsed /> + {/* Session List */} +
+
+

+ Sessions +

+ +
+ + {loadingSessions ? ( +
Lade Sessions...
+ ) : sessions.length === 0 ? ( +
Noch keine Sessions vorhanden.
+ ) : ( +
+ {sessions.map((s) => ( +
+
openSession(s.id)}> + {editingName === s.id ? ( + setEditNameValue(e.target.value)} + onBlur={() => renameSession(s.id, editNameValue)} + onKeyDown={(e) => { + if (e.key === 'Enter') renameSession(s.id, editNameValue) + if (e.key === 'Escape') setEditingName(null) + }} + onClick={(e) => e.stopPropagation()} + className="w-full px-1 py-0.5 text-sm border rounded dark:bg-gray-700 dark:border-gray-600" + /> + ) : ( +
+ {s.name || s.filename} +
+ )} +
+ {new Date(s.created_at).toLocaleDateString('de-DE', { day: '2-digit', month: '2-digit', year: '2-digit', hour: '2-digit', minute: '2-digit' })} + Schritt {s.current_step}: {stepNames[s.current_step] || '?'} +
+
+ + +
+ ))} +
+ )} +
+ + {/* Active session name */} + {sessionId && sessionName && ( +
+ Aktive Session: {sessionName} +
+ )} +
{renderStep()}
diff --git a/admin-lehrer/app/(admin)/ai/ocr-pipeline/types.ts b/admin-lehrer/app/(admin)/ai/ocr-pipeline/types.ts index 63ba8e9..9f2f66d 100644 --- a/admin-lehrer/app/(admin)/ai/ocr-pipeline/types.ts +++ b/admin-lehrer/app/(admin)/ai/ocr-pipeline/types.ts @@ -7,14 +7,27 @@ export interface PipelineStep { status: PipelineStepStatus } +export interface SessionListItem { + id: string + name: string + filename: string + status: string + current_step: number + created_at: string + updated_at?: string +} + export interface SessionInfo { session_id: string filename: string + name?: string image_width: number image_height: number original_image_url: string + current_step?: number deskew_result?: DeskewResult dewarp_result?: DewarpResult + column_result?: ColumnResult } export interface DeskewResult { @@ -50,6 +63,24 @@ export interface DewarpGroundTruth { notes?: string } +export interface PageRegion { + type: 'column_en' | 'column_de' | 'column_example' | 'header' | 'footer' + x: number + y: number + width: number + height: number +} + +export interface ColumnResult { + columns: PageRegion[] + duration_seconds: number +} + +export interface ColumnGroundTruth { + is_correct: boolean + notes?: string +} + export const PIPELINE_STEPS: PipelineStep[] = [ { id: 'deskew', name: 'Begradigung', icon: '📐', status: 'pending' }, { id: 'dewarp', name: 'Entzerrung', icon: '🔧', status: 'pending' }, diff --git a/admin-lehrer/components/ocr-pipeline/ColumnControls.tsx b/admin-lehrer/components/ocr-pipeline/ColumnControls.tsx new file mode 100644 index 0000000..52a808c --- /dev/null +++ b/admin-lehrer/components/ocr-pipeline/ColumnControls.tsx @@ -0,0 +1,119 @@ +'use client' + +import { useState } from 'react' +import type { ColumnResult, ColumnGroundTruth, PageRegion } from '@/app/(admin)/ai/ocr-pipeline/types' + +interface ColumnControlsProps { + columnResult: ColumnResult | null + onRerun: () => void + onGroundTruth: (gt: ColumnGroundTruth) => void + onNext: () => void + isDetecting: boolean +} + +const TYPE_COLORS: Record = { + column_en: 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-400', + column_de: 'bg-green-100 text-green-700 dark:bg-green-900/30 dark:text-green-400', + column_example: 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400', + header: 'bg-gray-100 text-gray-600 dark:bg-gray-700/50 dark:text-gray-400', + footer: 'bg-gray-100 text-gray-600 dark:bg-gray-700/50 dark:text-gray-400', +} + +const TYPE_LABELS: Record = { + column_en: 'EN', + column_de: 'DE', + column_example: 'Beispiel', + header: 'Header', + footer: 'Footer', +} + +export function ColumnControls({ columnResult, onRerun, onGroundTruth, onNext, isDetecting }: ColumnControlsProps) { + const [gtSaved, setGtSaved] = useState(false) + + if (!columnResult) return null + + const columns = columnResult.columns.filter((c: PageRegion) => c.type.startsWith('column')) + const headerFooter = columnResult.columns.filter((c: PageRegion) => !c.type.startsWith('column')) + + const handleGt = (isCorrect: boolean) => { + onGroundTruth({ is_correct: isCorrect }) + setGtSaved(true) + } + + return ( +
+ {/* Summary */} +
+
+ {columns.length} Spalten erkannt + {columnResult.duration_seconds > 0 && ( + ({columnResult.duration_seconds}s) + )} +
+ +
+ + {/* Column list */} +
+ {columns.map((col: PageRegion, i: number) => ( +
+ + {TYPE_LABELS[col.type] || col.type} + + + x={col.x} y={col.y} {col.width}x{col.height}px + +
+ ))} + {headerFooter.map((r: PageRegion, i: number) => ( +
+ + {TYPE_LABELS[r.type] || r.type} + + + x={r.x} y={r.y} {r.width}x{r.height}px + +
+ ))} +
+ + {/* Ground Truth + Navigation */} +
+
+ Spalten korrekt? + {gtSaved ? ( + Gespeichert + ) : ( + <> + + + + )} +
+ + +
+
+ ) +} diff --git a/admin-lehrer/components/ocr-pipeline/StepColumnDetection.tsx b/admin-lehrer/components/ocr-pipeline/StepColumnDetection.tsx index b2b9fe2..6f76e18 100644 --- a/admin-lehrer/components/ocr-pipeline/StepColumnDetection.tsx +++ b/admin-lehrer/components/ocr-pipeline/StepColumnDetection.tsx @@ -1,19 +1,168 @@ 'use client' -export function StepColumnDetection() { - return ( -
-
📊
-

- Schritt 3: Spaltenerkennung -

-

- Erkennung unsichtbarer Spaltentrennungen in der Vokabelseite. - Dieser Schritt wird in einer zukuenftigen Version implementiert. -

-
- Kommt bald +import { useCallback, useEffect, useState } from 'react' +import type { ColumnResult, ColumnGroundTruth } from '@/app/(admin)/ai/ocr-pipeline/types' +import { ColumnControls } from './ColumnControls' + +const KLAUSUR_API = '/klausur-api' + +interface StepColumnDetectionProps { + sessionId: string | null + onNext: () => void +} + +export function StepColumnDetection({ sessionId, onNext }: StepColumnDetectionProps) { + const [columnResult, setColumnResult] = useState(null) + const [detecting, setDetecting] = useState(false) + const [error, setError] = useState(null) + + // Auto-trigger column detection on mount + useEffect(() => { + if (!sessionId || columnResult) return + + const runDetection = async () => { + setDetecting(true) + setError(null) + try { + // First check if columns already detected (reload case) + const infoRes = await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}`) + if (infoRes.ok) { + const info = await infoRes.json() + if (info.column_result) { + setColumnResult(info.column_result) + setDetecting(false) + return + } + } + + // Run detection + const res = await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/columns`, { + method: 'POST', + }) + if (!res.ok) { + const err = await res.json().catch(() => ({ detail: res.statusText })) + throw new Error(err.detail || 'Spaltenerkennung fehlgeschlagen') + } + const data: ColumnResult = await res.json() + setColumnResult(data) + } catch (e) { + setError(e instanceof Error ? e.message : 'Unbekannter Fehler') + } finally { + setDetecting(false) + } + } + + runDetection() + }, [sessionId, columnResult]) + + const handleRerun = useCallback(async () => { + if (!sessionId) return + setDetecting(true) + setError(null) + try { + const res = await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/columns`, { + method: 'POST', + }) + if (!res.ok) throw new Error('Spaltenerkennung fehlgeschlagen') + const data: ColumnResult = await res.json() + setColumnResult(data) + } catch (e) { + setError(e instanceof Error ? e.message : 'Fehler') + } finally { + setDetecting(false) + } + }, [sessionId]) + + const handleGroundTruth = useCallback(async (gt: ColumnGroundTruth) => { + if (!sessionId) return + try { + await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/ground-truth/columns`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(gt), + }) + } catch (e) { + console.error('Ground truth save failed:', e) + } + }, [sessionId]) + + if (!sessionId) { + return ( +
+
📊
+

+ Schritt 3: Spaltenerkennung +

+

+ Bitte zuerst Schritt 1 und 2 abschliessen. +

+ ) + } + + const dewarpedUrl = `${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/image/dewarped` + const overlayUrl = `${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/image/columns-overlay` + + return ( +
+ {/* Loading indicator */} + {detecting && ( +
+
+ Spaltenerkennung laeuft... +
+ )} + + {/* Image comparison: overlay (left) vs clean (right) */} +
+
+
+ Mit Spalten-Overlay +
+
+ {columnResult ? ( + // eslint-disable-next-line @next/next/no-img-element + Spalten-Overlay + ) : ( +
+ {detecting ? 'Erkenne Spalten...' : 'Keine Daten'} +
+ )} +
+
+
+
+ Entzerrtes Bild +
+
+ {/* eslint-disable-next-line @next/next/no-img-element */} + Entzerrt +
+
+
+ + {/* Controls */} + + + {error && ( +
+ {error} +
+ )}
) } diff --git a/admin-lehrer/components/ocr-pipeline/StepDeskew.tsx b/admin-lehrer/components/ocr-pipeline/StepDeskew.tsx index 14786ce..ed1845c 100644 --- a/admin-lehrer/components/ocr-pipeline/StepDeskew.tsx +++ b/admin-lehrer/components/ocr-pipeline/StepDeskew.tsx @@ -22,6 +22,7 @@ export function StepDeskew({ sessionId: existingSessionId, onNext }: StepDeskewP const [showGrid, setShowGrid] = useState(true) const [error, setError] = useState(null) const [dragOver, setDragOver] = useState(false) + const [sessionName, setSessionName] = useState('') // Reload session data when navigating back from a later step useEffect(() => { @@ -67,6 +68,9 @@ export function StepDeskew({ sessionId: existingSessionId, onNext }: StepDeskewP try { const formData = new FormData() formData.append('file', file) + if (sessionName.trim()) { + formData.append('name', sessionName.trim()) + } const res = await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions`, { method: 'POST', @@ -167,6 +171,20 @@ export function StepDeskew({ sessionId: existingSessionId, onNext }: StepDeskewP if (!session) { return (
+ {/* Session name input */} +
+ + setSessionName(e.target.value)} + placeholder="z.B. Unit 3 Seite 42" + className="w-full max-w-sm px-3 py-2 text-sm border rounded-lg dark:bg-gray-800 dark:border-gray-600 dark:text-gray-200 focus:outline-none focus:ring-2 focus:ring-teal-500" + /> +
+
{ e.preventDefault(); setDragOver(true) }} onDragLeave={() => setDragOver(false)} diff --git a/klausur-service/backend/main.py b/klausur-service/backend/main.py index 1705520..51887c1 100644 --- a/klausur-service/backend/main.py +++ b/klausur-service/backend/main.py @@ -43,6 +43,7 @@ except ImportError: trocr_router = None from vocab_worksheet_api import router as vocab_router, set_db_pool as set_vocab_db_pool, _init_vocab_table, _load_all_sessions, DATABASE_URL as VOCAB_DATABASE_URL from ocr_pipeline_api import router as ocr_pipeline_router +from ocr_pipeline_session_store import init_ocr_pipeline_tables try: from dsfa_rag_api import router as dsfa_rag_router, set_db_pool as set_dsfa_db_pool from dsfa_corpus_ingestion import DSFAQdrantService, DATABASE_URL as DSFA_DATABASE_URL @@ -76,6 +77,13 @@ async def lifespan(app: FastAPI): except Exception as e: print(f"Warning: Vocab sessions database initialization failed: {e}") + # Initialize OCR Pipeline session tables + try: + await init_ocr_pipeline_tables() + print("OCR Pipeline session tables initialized") + except Exception as e: + print(f"Warning: OCR Pipeline tables initialization failed: {e}") + # Initialize database pool for DSFA RAG dsfa_db_pool = None if DSFA_DATABASE_URL and set_dsfa_db_pool: diff --git a/klausur-service/backend/migrations/002_ocr_pipeline_sessions.sql b/klausur-service/backend/migrations/002_ocr_pipeline_sessions.sql new file mode 100644 index 0000000..c073ea4 --- /dev/null +++ b/klausur-service/backend/migrations/002_ocr_pipeline_sessions.sql @@ -0,0 +1,28 @@ +-- OCR Pipeline Sessions - Persistent session storage +-- Applied automatically by ocr_pipeline_session_store.init_ocr_pipeline_tables() + +CREATE TABLE IF NOT EXISTS ocr_pipeline_sessions ( + id UUID PRIMARY KEY, + name VARCHAR(255) NOT NULL, + filename VARCHAR(255), + status VARCHAR(50) DEFAULT 'active', + current_step INT DEFAULT 1, + original_png BYTEA, + deskewed_png BYTEA, + binarized_png BYTEA, + dewarped_png BYTEA, + deskew_result JSONB, + dewarp_result JSONB, + column_result JSONB, + ground_truth JSONB DEFAULT '{}', + auto_shear_degrees FLOAT, + created_at TIMESTAMP DEFAULT NOW(), + updated_at TIMESTAMP DEFAULT NOW() +); + +-- Index for listing sessions +CREATE INDEX IF NOT EXISTS idx_ocr_pipeline_sessions_created + ON ocr_pipeline_sessions (created_at DESC); + +CREATE INDEX IF NOT EXISTS idx_ocr_pipeline_sessions_status + ON ocr_pipeline_sessions (status); diff --git a/klausur-service/backend/ocr_pipeline_api.py b/klausur-service/backend/ocr_pipeline_api.py index 4e37cf8..bcf155f 100644 --- a/klausur-service/backend/ocr_pipeline_api.py +++ b/klausur-service/backend/ocr_pipeline_api.py @@ -14,20 +14,21 @@ Lizenz: Apache 2.0 DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. """ -import io import logging import time import uuid -from datetime import datetime, timedelta -from typing import Any, Dict, Optional +from dataclasses import asdict +from datetime import datetime +from typing import Any, Dict, List, Optional import cv2 import numpy as np -from fastapi import APIRouter, File, HTTPException, UploadFile +from fastapi import APIRouter, File, Form, HTTPException, UploadFile from fastapi.responses import Response from pydantic import BaseModel from cv_vocab_pipeline import ( + analyze_layout, create_ocr_image, deskew_image, deskew_image_by_word_alignment, @@ -36,34 +37,67 @@ from cv_vocab_pipeline import ( render_image_high_res, render_pdf_high_res, ) +from ocr_pipeline_session_store import ( + create_session_db, + delete_session_db, + get_session_db, + get_session_image, + init_ocr_pipeline_tables, + list_sessions_db, + update_session_db, +) logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) # --------------------------------------------------------------------------- -# In-memory session store (24h TTL) +# In-memory cache for active sessions (BGR numpy arrays for processing) +# DB is source of truth, cache holds BGR arrays during active processing. # --------------------------------------------------------------------------- -_sessions: Dict[str, Dict[str, Any]] = {} -SESSION_TTL_HOURS = 24 +_cache: Dict[str, Dict[str, Any]] = {} -def _cleanup_expired(): - """Remove sessions older than TTL.""" - cutoff = datetime.utcnow() - timedelta(hours=SESSION_TTL_HOURS) - expired = [sid for sid, s in _sessions.items() if s.get("created_at", datetime.utcnow()) < cutoff] - for sid in expired: - del _sessions[sid] - logger.info(f"OCR Pipeline: expired session {sid}") - - -def _get_session(session_id: str) -> Dict[str, Any]: - """Get session or raise 404.""" - session = _sessions.get(session_id) +async def _load_session_to_cache(session_id: str) -> Dict[str, Any]: + """Load session from DB into cache, decoding PNGs to BGR arrays.""" + session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - return session + + if session_id in _cache: + return _cache[session_id] + + cache_entry: Dict[str, Any] = { + "id": session_id, + **session, + "original_bgr": None, + "deskewed_bgr": None, + "dewarped_bgr": None, + } + + # Decode images from DB into BGR numpy arrays + for img_type, bgr_key in [ + ("original", "original_bgr"), + ("deskewed", "deskewed_bgr"), + ("dewarped", "dewarped_bgr"), + ]: + png_data = await get_session_image(session_id, img_type) + if png_data: + arr = np.frombuffer(png_data, dtype=np.uint8) + bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR) + cache_entry[bgr_key] = bgr + + _cache[session_id] = cache_entry + return cache_entry + + +def _get_cached(session_id: str) -> Dict[str, Any]: + """Get from cache or raise 404.""" + entry = _cache.get(session_id) + if not entry: + raise HTTPException(status_code=404, detail=f"Session {session_id} not in cache — reload first") + return entry # --------------------------------------------------------------------------- @@ -90,15 +124,36 @@ class DewarpGroundTruthRequest(BaseModel): notes: Optional[str] = None +class RenameSessionRequest(BaseModel): + name: str + + +class ManualColumnsRequest(BaseModel): + columns: List[Dict[str, Any]] + + +class ColumnGroundTruthRequest(BaseModel): + is_correct: bool + notes: Optional[str] = None + + # --------------------------------------------------------------------------- -# Endpoints +# Session Management Endpoints # --------------------------------------------------------------------------- +@router.get("/sessions") +async def list_sessions(): + """List all OCR pipeline sessions.""" + sessions = await list_sessions_db() + return {"sessions": sessions} + + @router.post("/sessions") -async def create_session(file: UploadFile = File(...)): +async def create_session( + file: UploadFile = File(...), + name: Optional[str] = Form(None), +): """Upload a PDF or image file and create a pipeline session.""" - _cleanup_expired() - file_data = await file.read() filename = file.filename or "upload" content_type = file.content_type or "" @@ -114,25 +169,32 @@ async def create_session(file: UploadFile = File(...)): except Exception as e: raise HTTPException(status_code=400, detail=f"Could not process file: {e}") - # Encode original as PNG bytes for serving + # Encode original as PNG bytes success, png_buf = cv2.imencode(".png", img_bgr) if not success: raise HTTPException(status_code=500, detail="Failed to encode image") - _sessions[session_id] = { + original_png = png_buf.tobytes() + session_name = name or filename + + # Persist to DB + await create_session_db( + session_id=session_id, + name=session_name, + filename=filename, + original_png=original_png, + ) + + # Cache BGR array for immediate processing + _cache[session_id] = { "id": session_id, "filename": filename, - "created_at": datetime.utcnow(), + "name": session_name, "original_bgr": img_bgr, - "original_png": png_buf.tobytes(), "deskewed_bgr": None, - "deskewed_png": None, - "binarized_png": None, - "deskew_result": None, "dewarped_bgr": None, - "dewarped_png": None, + "deskew_result": None, "dewarp_result": None, - "auto_shear_degrees": None, "ground_truth": {}, "current_step": 1, } @@ -143,6 +205,7 @@ async def create_session(file: UploadFile = File(...)): return { "session_id": session_id, "filename": filename, + "name": session_name, "image_width": img_bgr.shape[1], "image_height": img_bgr.shape[0], "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original", @@ -151,35 +214,106 @@ async def create_session(file: UploadFile = File(...)): @router.get("/sessions/{session_id}") async def get_session_info(session_id: str): - """Get session info including deskew/dewarp results for step navigation.""" - session = _get_session(session_id) - img_bgr = session["original_bgr"] + """Get session info including deskew/dewarp/column results for step navigation.""" + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + # Get image dimensions from original PNG + original_png = await get_session_image(session_id, "original") + if original_png: + arr = np.frombuffer(original_png, dtype=np.uint8) + img = cv2.imdecode(arr, cv2.IMREAD_COLOR) + img_w, img_h = img.shape[1], img.shape[0] if img is not None else (0, 0) + else: + img_w, img_h = 0, 0 result = { "session_id": session["id"], - "filename": session["filename"], - "image_width": img_bgr.shape[1], - "image_height": img_bgr.shape[0], + "filename": session.get("filename", ""), + "name": session.get("name", ""), + "image_width": img_w, + "image_height": img_h, "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original", "current_step": session.get("current_step", 1), } - # Include deskew result if available if session.get("deskew_result"): result["deskew_result"] = session["deskew_result"] - - # Include dewarp result if available if session.get("dewarp_result"): result["dewarp_result"] = session["dewarp_result"] + if session.get("column_result"): + result["column_result"] = session["column_result"] return result +@router.put("/sessions/{session_id}") +async def rename_session(session_id: str, req: RenameSessionRequest): + """Rename a session.""" + updated = await update_session_db(session_id, name=req.name) + if not updated: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + return {"session_id": session_id, "name": req.name} + + +@router.delete("/sessions/{session_id}") +async def delete_session(session_id: str): + """Delete a session.""" + _cache.pop(session_id, None) + deleted = await delete_session_db(session_id) + if not deleted: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + return {"session_id": session_id, "deleted": True} + + +# --------------------------------------------------------------------------- +# Image Endpoints +# --------------------------------------------------------------------------- + +@router.get("/sessions/{session_id}/image/{image_type}") +async def get_image(session_id: str, image_type: str): + """Serve session images: original, deskewed, dewarped, binarized, or columns-overlay.""" + valid_types = {"original", "deskewed", "dewarped", "binarized", "columns-overlay"} + if image_type not in valid_types: + raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}") + + if image_type == "columns-overlay": + return await _get_columns_overlay(session_id) + + # Try cache first for fast serving + cached = _cache.get(session_id) + if cached: + png_key = f"{image_type}_png" if image_type != "original" else None + bgr_key = f"{image_type}_bgr" if image_type != "binarized" else None + + # For binarized, check if we have it cached as PNG + if image_type == "binarized" and cached.get("binarized_png"): + return Response(content=cached["binarized_png"], media_type="image/png") + + # Load from DB + data = await get_session_image(session_id, image_type) + if not data: + raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet") + + return Response(content=data, media_type="image/png") + + +# --------------------------------------------------------------------------- +# Deskew Endpoints +# --------------------------------------------------------------------------- + @router.post("/sessions/{session_id}/deskew") async def auto_deskew(session_id: str): """Run both deskew methods and pick the best one.""" - session = _get_session(session_id) - img_bgr = session["original_bgr"] + # Ensure session is in cache + if session_id not in _cache: + await _load_session_to_cache(session_id) + cached = _get_cached(session_id) + + img_bgr = cached.get("original_bgr") + if img_bgr is None: + raise HTTPException(status_code=400, detail="Original image not available") t0 = time.time() @@ -202,12 +336,10 @@ async def auto_deskew(session_id: str): duration = time.time() - t0 - # Pick method with larger detected angle (more correction needed = more skew found) - # If both are ~0, prefer word alignment as it's more robust + # Pick best method if abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1: method_used = "word_alignment" angle_applied = angle_wa - # Decode word alignment result to BGR wa_array = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8) deskewed_bgr = cv2.imdecode(wa_array, cv2.IMREAD_COLOR) if deskewed_bgr is None: @@ -219,20 +351,19 @@ async def auto_deskew(session_id: str): angle_applied = angle_hough deskewed_bgr = deskewed_hough - # Encode deskewed as PNG + # Encode as PNG success, deskewed_png_buf = cv2.imencode(".png", deskewed_bgr) - deskewed_png = deskewed_png_buf.tobytes() if success else session["original_png"] + deskewed_png = deskewed_png_buf.tobytes() if success else b"" # Create binarized version + binarized_png = None try: binarized = create_ocr_image(deskewed_bgr) success_bin, bin_buf = cv2.imencode(".png", binarized) binarized_png = bin_buf.tobytes() if success_bin else None except Exception as e: logger.warning(f"Binarization failed: {e}") - binarized_png = None - # Confidence: higher angle = lower confidence that we got it right confidence = max(0.5, 1.0 - abs(angle_applied) / 5.0) deskew_result = { @@ -244,13 +375,23 @@ async def auto_deskew(session_id: str): "duration_seconds": round(duration, 2), } - session["deskewed_bgr"] = deskewed_bgr - session["deskewed_png"] = deskewed_png - session["binarized_png"] = binarized_png - session["deskew_result"] = deskew_result + # Update cache + cached["deskewed_bgr"] = deskewed_bgr + cached["binarized_png"] = binarized_png + cached["deskew_result"] = deskew_result + + # Persist to DB + db_update = { + "deskewed_png": deskewed_png, + "deskew_result": deskew_result, + "current_step": 2, + } + if binarized_png: + db_update["binarized_png"] = binarized_png + await update_session_db(session_id, **db_update) logger.info(f"OCR Pipeline: deskew session {session_id}: " - f"hough={angle_hough:.2f}° wa={angle_wa:.2f}° → {method_used} {angle_applied:.2f}°") + f"hough={angle_hough:.2f} wa={angle_wa:.2f} -> {method_used} {angle_applied:.2f}") return { "session_id": session_id, @@ -263,8 +404,14 @@ async def auto_deskew(session_id: str): @router.post("/sessions/{session_id}/deskew/manual") async def manual_deskew(session_id: str, req: ManualDeskewRequest): """Apply a manual rotation angle to the original image.""" - session = _get_session(session_id) - img_bgr = session["original_bgr"] + if session_id not in _cache: + await _load_session_to_cache(session_id) + cached = _get_cached(session_id) + + img_bgr = cached.get("original_bgr") + if img_bgr is None: + raise HTTPException(status_code=400, detail="Original image not available") + angle = max(-5.0, min(5.0, req.angle)) h, w = img_bgr.shape[:2] @@ -275,26 +422,38 @@ async def manual_deskew(session_id: str, req: ManualDeskewRequest): borderMode=cv2.BORDER_REPLICATE) success, png_buf = cv2.imencode(".png", rotated) - deskewed_png = png_buf.tobytes() if success else session["original_png"] + deskewed_png = png_buf.tobytes() if success else b"" # Binarize + binarized_png = None try: binarized = create_ocr_image(rotated) success_bin, bin_buf = cv2.imencode(".png", binarized) binarized_png = bin_buf.tobytes() if success_bin else None except Exception: - binarized_png = None + pass - session["deskewed_bgr"] = rotated - session["deskewed_png"] = deskewed_png - session["binarized_png"] = binarized_png - session["deskew_result"] = { - **(session.get("deskew_result") or {}), + deskew_result = { + **(cached.get("deskew_result") or {}), "angle_applied": round(angle, 3), "method_used": "manual", } - logger.info(f"OCR Pipeline: manual deskew session {session_id}: {angle:.2f}°") + # Update cache + cached["deskewed_bgr"] = rotated + cached["binarized_png"] = binarized_png + cached["deskew_result"] = deskew_result + + # Persist to DB + db_update = { + "deskewed_png": deskewed_png, + "deskew_result": deskew_result, + } + if binarized_png: + db_update["binarized_png"] = binarized_png + await update_session_db(session_id, **db_update) + + logger.info(f"OCR Pipeline: manual deskew session {session_id}: {angle:.2f}") return { "session_id": session_id, @@ -304,33 +463,14 @@ async def manual_deskew(session_id: str, req: ManualDeskewRequest): } -@router.get("/sessions/{session_id}/image/{image_type}") -async def get_image(session_id: str, image_type: str): - """Serve session images: original, deskewed, dewarped, or binarized.""" - session = _get_session(session_id) - - if image_type == "original": - data = session.get("original_png") - elif image_type == "deskewed": - data = session.get("deskewed_png") - elif image_type == "dewarped": - data = session.get("dewarped_png") - elif image_type == "binarized": - data = session.get("binarized_png") - else: - raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}") - - if not data: - raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet") - - return Response(content=data, media_type="image/png") - - @router.post("/sessions/{session_id}/ground-truth/deskew") async def save_deskew_ground_truth(session_id: str, req: DeskewGroundTruthRequest): """Save ground truth feedback for the deskew step.""" - session = _get_session(session_id) + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + ground_truth = session.get("ground_truth") or {} gt = { "is_correct": req.is_correct, "corrected_angle": req.corrected_angle, @@ -338,7 +478,13 @@ async def save_deskew_ground_truth(session_id: str, req: DeskewGroundTruthReques "saved_at": datetime.utcnow().isoformat(), "deskew_result": session.get("deskew_result"), } - session["ground_truth"]["deskew"] = gt + ground_truth["deskew"] = gt + + await update_session_db(session_id, ground_truth=ground_truth) + + # Update cache + if session_id in _cache: + _cache[session_id]["ground_truth"] = ground_truth logger.info(f"OCR Pipeline: ground truth deskew session {session_id}: " f"correct={req.is_correct}, corrected_angle={req.corrected_angle}") @@ -353,8 +499,11 @@ async def save_deskew_ground_truth(session_id: str, req: DeskewGroundTruthReques @router.post("/sessions/{session_id}/dewarp") async def auto_dewarp(session_id: str): """Detect and correct vertical shear on the deskewed image.""" - session = _get_session(session_id) - deskewed_bgr = session.get("deskewed_bgr") + if session_id not in _cache: + await _load_session_to_cache(session_id) + cached = _get_cached(session_id) + + deskewed_bgr = cached.get("deskewed_bgr") if deskewed_bgr is None: raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp") @@ -362,30 +511,37 @@ async def auto_dewarp(session_id: str): dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr) duration = time.time() - t0 - # Encode dewarped as PNG + # Encode as PNG success, png_buf = cv2.imencode(".png", dewarped_bgr) - dewarped_png = png_buf.tobytes() if success else session["deskewed_png"] + dewarped_png = png_buf.tobytes() if success else b"" - session["dewarped_bgr"] = dewarped_bgr - session["dewarped_png"] = dewarped_png - session["auto_shear_degrees"] = dewarp_info.get("shear_degrees", 0.0) - session["dewarp_result"] = { + dewarp_result = { "method_used": dewarp_info["method"], "shear_degrees": dewarp_info["shear_degrees"], "confidence": dewarp_info["confidence"], "duration_seconds": round(duration, 2), } + # Update cache + cached["dewarped_bgr"] = dewarped_bgr + cached["dewarp_result"] = dewarp_result + + # Persist to DB + await update_session_db( + session_id, + dewarped_png=dewarped_png, + dewarp_result=dewarp_result, + auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0), + current_step=3, + ) + logger.info(f"OCR Pipeline: dewarp session {session_id}: " - f"method={dewarp_info['method']} shear={dewarp_info['shear_degrees']:.3f}° " + f"method={dewarp_info['method']} shear={dewarp_info['shear_degrees']:.3f} " f"conf={dewarp_info['confidence']:.2f} ({duration:.2f}s)") return { "session_id": session_id, - "method_used": dewarp_info["method"], - "shear_degrees": dewarp_info["shear_degrees"], - "confidence": dewarp_info["confidence"], - "duration_seconds": round(duration, 2), + **dewarp_result, "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped", } @@ -393,9 +549,11 @@ async def auto_dewarp(session_id: str): @router.post("/sessions/{session_id}/dewarp/manual") async def manual_dewarp(session_id: str, req: ManualDewarpRequest): """Apply shear correction with a manual angle.""" - session = _get_session(session_id) - deskewed_bgr = session.get("deskewed_bgr") + if session_id not in _cache: + await _load_session_to_cache(session_id) + cached = _get_cached(session_id) + deskewed_bgr = cached.get("deskewed_bgr") if deskewed_bgr is None: raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp") @@ -407,17 +565,26 @@ async def manual_dewarp(session_id: str, req: ManualDewarpRequest): dewarped_bgr = dewarp_image_manual(deskewed_bgr, shear_deg) success, png_buf = cv2.imencode(".png", dewarped_bgr) - dewarped_png = png_buf.tobytes() if success else session.get("deskewed_png") + dewarped_png = png_buf.tobytes() if success else b"" - session["dewarped_bgr"] = dewarped_bgr - session["dewarped_png"] = dewarped_png - session["dewarp_result"] = { - **(session.get("dewarp_result") or {}), + dewarp_result = { + **(cached.get("dewarp_result") or {}), "method_used": "manual", "shear_degrees": round(shear_deg, 3), } - logger.info(f"OCR Pipeline: manual dewarp session {session_id}: shear={shear_deg:.3f}°") + # Update cache + cached["dewarped_bgr"] = dewarped_bgr + cached["dewarp_result"] = dewarp_result + + # Persist to DB + await update_session_db( + session_id, + dewarped_png=dewarped_png, + dewarp_result=dewarp_result, + ) + + logger.info(f"OCR Pipeline: manual dewarp session {session_id}: shear={shear_deg:.3f}") return { "session_id": session_id, @@ -430,8 +597,11 @@ async def manual_dewarp(session_id: str, req: ManualDewarpRequest): @router.post("/sessions/{session_id}/ground-truth/dewarp") async def save_dewarp_ground_truth(session_id: str, req: DewarpGroundTruthRequest): """Save ground truth feedback for the dewarp step.""" - session = _get_session(session_id) + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + ground_truth = session.get("ground_truth") or {} gt = { "is_correct": req.is_correct, "corrected_shear": req.corrected_shear, @@ -439,9 +609,168 @@ async def save_dewarp_ground_truth(session_id: str, req: DewarpGroundTruthReques "saved_at": datetime.utcnow().isoformat(), "dewarp_result": session.get("dewarp_result"), } - session["ground_truth"]["dewarp"] = gt + ground_truth["dewarp"] = gt + + await update_session_db(session_id, ground_truth=ground_truth) + + if session_id in _cache: + _cache[session_id]["ground_truth"] = ground_truth logger.info(f"OCR Pipeline: ground truth dewarp session {session_id}: " f"correct={req.is_correct}, corrected_shear={req.corrected_shear}") return {"session_id": session_id, "ground_truth": gt} + + +# --------------------------------------------------------------------------- +# Column Detection Endpoints (Step 3) +# --------------------------------------------------------------------------- + +@router.post("/sessions/{session_id}/columns") +async def detect_columns(session_id: str): + """Run column detection on the dewarped image.""" + if session_id not in _cache: + await _load_session_to_cache(session_id) + cached = _get_cached(session_id) + + dewarped_bgr = cached.get("dewarped_bgr") + if dewarped_bgr is None: + raise HTTPException(status_code=400, detail="Dewarp must be completed before column detection") + + t0 = time.time() + + # Prepare images for analyze_layout + gray = cv2.cvtColor(dewarped_bgr, cv2.COLOR_BGR2GRAY) + # CLAHE-enhanced for layout analysis + clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) + layout_img = clahe.apply(gray) + # Binarized for text density + ocr_img = create_ocr_image(dewarped_bgr) + + regions = analyze_layout(layout_img, ocr_img) + duration = time.time() - t0 + + columns = [asdict(r) for r in regions] + column_result = { + "columns": columns, + "duration_seconds": round(duration, 2), + } + + # Persist to DB + await update_session_db( + session_id, + column_result=column_result, + current_step=3, + ) + + # Update cache + cached["column_result"] = column_result + + col_count = len([c for c in columns if c["type"].startswith("column")]) + logger.info(f"OCR Pipeline: columns session {session_id}: " + f"{col_count} columns detected ({duration:.2f}s)") + + return { + "session_id": session_id, + **column_result, + } + + +@router.post("/sessions/{session_id}/columns/manual") +async def set_manual_columns(session_id: str, req: ManualColumnsRequest): + """Override detected columns with manual definitions.""" + column_result = { + "columns": req.columns, + "duration_seconds": 0, + "method": "manual", + } + + await update_session_db(session_id, column_result=column_result) + + if session_id in _cache: + _cache[session_id]["column_result"] = column_result + + logger.info(f"OCR Pipeline: manual columns session {session_id}: " + f"{len(req.columns)} columns set") + + return {"session_id": session_id, **column_result} + + +@router.post("/sessions/{session_id}/ground-truth/columns") +async def save_column_ground_truth(session_id: str, req: ColumnGroundTruthRequest): + """Save ground truth feedback for the column detection step.""" + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + ground_truth = session.get("ground_truth") or {} + gt = { + "is_correct": req.is_correct, + "notes": req.notes, + "saved_at": datetime.utcnow().isoformat(), + "column_result": session.get("column_result"), + } + ground_truth["columns"] = gt + + await update_session_db(session_id, ground_truth=ground_truth) + + if session_id in _cache: + _cache[session_id]["ground_truth"] = ground_truth + + return {"session_id": session_id, "ground_truth": gt} + + +async def _get_columns_overlay(session_id: str) -> Response: + """Generate dewarped image with column borders drawn on it.""" + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + column_result = session.get("column_result") + if not column_result or not column_result.get("columns"): + raise HTTPException(status_code=404, detail="No column data available") + + # Load dewarped image + dewarped_png = await get_session_image(session_id, "dewarped") + if not dewarped_png: + raise HTTPException(status_code=404, detail="Dewarped image not available") + + arr = np.frombuffer(dewarped_png, dtype=np.uint8) + img = cv2.imdecode(arr, cv2.IMREAD_COLOR) + if img is None: + raise HTTPException(status_code=500, detail="Failed to decode image") + + # Color map for region types + colors = { + "column_en": (255, 180, 0), # Blue (BGR) + "column_de": (0, 200, 0), # Green + "column_example": (0, 140, 255), # Orange + "header": (128, 128, 128), # Gray + "footer": (128, 128, 128), # Gray + } + + overlay = img.copy() + for col in column_result["columns"]: + x, y = col["x"], col["y"] + w, h = col["width"], col["height"] + color = colors.get(col.get("type", ""), (200, 200, 200)) + + # Semi-transparent fill + cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1) + + # Solid border + cv2.rectangle(img, (x, y), (x + w, y + h), color, 3) + + # Label + label = col.get("type", "unknown").replace("column_", "").upper() + cv2.putText(img, label, (x + 10, y + 30), + cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2) + + # Blend overlay at 20% opacity + cv2.addWeighted(overlay, 0.2, img, 0.8, 0, img) + + success, result_png = cv2.imencode(".png", img) + if not success: + raise HTTPException(status_code=500, detail="Failed to encode overlay image") + + return Response(content=result_png.tobytes(), media_type="image/png") diff --git a/klausur-service/backend/ocr_pipeline_session_store.py b/klausur-service/backend/ocr_pipeline_session_store.py new file mode 100644 index 0000000..f83583c --- /dev/null +++ b/klausur-service/backend/ocr_pipeline_session_store.py @@ -0,0 +1,228 @@ +""" +OCR Pipeline Session Store - PostgreSQL persistence for OCR pipeline sessions. + +Replaces in-memory storage with database persistence. +See migrations/002_ocr_pipeline_sessions.sql for schema. +""" + +import os +import uuid +import logging +import json +from typing import Optional, List, Dict, Any + +import asyncpg + +logger = logging.getLogger(__name__) + +# Database configuration (same as vocab_session_store) +DATABASE_URL = os.getenv( + "DATABASE_URL", + "postgresql://breakpilot:breakpilot@postgres:5432/breakpilot_db" +) + +# Connection pool (initialized lazily) +_pool: Optional[asyncpg.Pool] = None + + +async def get_pool() -> asyncpg.Pool: + """Get or create the database connection pool.""" + global _pool + if _pool is None: + _pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10) + return _pool + + +async def init_ocr_pipeline_tables(): + """Initialize OCR pipeline tables if they don't exist.""" + pool = await get_pool() + async with pool.acquire() as conn: + tables_exist = await conn.fetchval(""" + SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_name = 'ocr_pipeline_sessions' + ) + """) + + if not tables_exist: + logger.info("Creating OCR pipeline tables...") + migration_path = os.path.join( + os.path.dirname(__file__), + "migrations/002_ocr_pipeline_sessions.sql" + ) + if os.path.exists(migration_path): + with open(migration_path, "r") as f: + sql = f.read() + await conn.execute(sql) + logger.info("OCR pipeline tables created successfully") + else: + logger.warning(f"Migration file not found: {migration_path}") + else: + logger.debug("OCR pipeline tables already exist") + + +# ============================================================================= +# SESSION CRUD +# ============================================================================= + +async def create_session_db( + session_id: str, + name: str, + filename: str, + original_png: bytes, +) -> Dict[str, Any]: + """Create a new OCR pipeline session.""" + pool = await get_pool() + async with pool.acquire() as conn: + row = await conn.fetchrow(""" + INSERT INTO ocr_pipeline_sessions ( + id, name, filename, original_png, status, current_step + ) VALUES ($1, $2, $3, $4, 'active', 1) + RETURNING id, name, filename, status, current_step, + deskew_result, dewarp_result, column_result, + ground_truth, auto_shear_degrees, + created_at, updated_at + """, uuid.UUID(session_id), name, filename, original_png) + + return _row_to_dict(row) + + +async def get_session_db(session_id: str) -> Optional[Dict[str, Any]]: + """Get session metadata (without images).""" + pool = await get_pool() + async with pool.acquire() as conn: + row = await conn.fetchrow(""" + SELECT id, name, filename, status, current_step, + deskew_result, dewarp_result, column_result, + ground_truth, auto_shear_degrees, + created_at, updated_at + FROM ocr_pipeline_sessions WHERE id = $1 + """, uuid.UUID(session_id)) + + if row: + return _row_to_dict(row) + return None + + +async def get_session_image(session_id: str, image_type: str) -> Optional[bytes]: + """Load a single image (BYTEA) from the session.""" + column_map = { + "original": "original_png", + "deskewed": "deskewed_png", + "binarized": "binarized_png", + "dewarped": "dewarped_png", + } + column = column_map.get(image_type) + if not column: + return None + + pool = await get_pool() + async with pool.acquire() as conn: + return await conn.fetchval( + f"SELECT {column} FROM ocr_pipeline_sessions WHERE id = $1", + uuid.UUID(session_id) + ) + + +async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any]]: + """Update session fields dynamically.""" + pool = await get_pool() + + fields = [] + values = [] + param_idx = 1 + + allowed_fields = { + 'name', 'filename', 'status', 'current_step', + 'original_png', 'deskewed_png', 'binarized_png', 'dewarped_png', + 'deskew_result', 'dewarp_result', 'column_result', + 'ground_truth', 'auto_shear_degrees', + } + + jsonb_fields = {'deskew_result', 'dewarp_result', 'column_result', 'ground_truth'} + + for key, value in kwargs.items(): + if key in allowed_fields: + fields.append(f"{key} = ${param_idx}") + if key in jsonb_fields and value is not None and not isinstance(value, str): + value = json.dumps(value) + values.append(value) + param_idx += 1 + + if not fields: + return await get_session_db(session_id) + + # Always update updated_at + fields.append(f"updated_at = NOW()") + + values.append(uuid.UUID(session_id)) + + async with pool.acquire() as conn: + row = await conn.fetchrow(f""" + UPDATE ocr_pipeline_sessions + SET {', '.join(fields)} + WHERE id = ${param_idx} + RETURNING id, name, filename, status, current_step, + deskew_result, dewarp_result, column_result, + ground_truth, auto_shear_degrees, + created_at, updated_at + """, *values) + + if row: + return _row_to_dict(row) + return None + + +async def list_sessions_db(limit: int = 50) -> List[Dict[str, Any]]: + """List all sessions (metadata only, no images).""" + pool = await get_pool() + async with pool.acquire() as conn: + rows = await conn.fetch(""" + SELECT id, name, filename, status, current_step, + created_at, updated_at + FROM ocr_pipeline_sessions + ORDER BY created_at DESC + LIMIT $1 + """, limit) + + return [_row_to_dict(row) for row in rows] + + +async def delete_session_db(session_id: str) -> bool: + """Delete a session.""" + pool = await get_pool() + async with pool.acquire() as conn: + result = await conn.execute(""" + DELETE FROM ocr_pipeline_sessions WHERE id = $1 + """, uuid.UUID(session_id)) + return result == "DELETE 1" + + +# ============================================================================= +# HELPER +# ============================================================================= + +def _row_to_dict(row: asyncpg.Record) -> Dict[str, Any]: + """Convert asyncpg Record to JSON-serializable dict.""" + if row is None: + return {} + + result = dict(row) + + # UUID → string + for key in ['id', 'session_id']: + if key in result and result[key] is not None: + result[key] = str(result[key]) + + # datetime → ISO string + for key in ['created_at', 'updated_at']: + if key in result and result[key] is not None: + result[key] = result[key].isoformat() + + # JSONB → parsed (asyncpg returns str for JSONB) + for key in ['deskew_result', 'dewarp_result', 'column_result', 'ground_truth']: + if key in result and result[key] is not None: + if isinstance(result[key], str): + result[key] = json.loads(result[key]) + + return result