"""
OCR Pipeline Session Store - PostgreSQL persistence for OCR pipeline sessions.

Replaces in-memory storage with database persistence.
See migrations/002_ocr_pipeline_sessions.sql for schema.
"""

import asyncio
import json
import logging
import os
import uuid
from typing import Any, Dict, List, Optional

import asyncpg

logger = logging.getLogger(__name__)

# Database configuration (same as vocab_session_store)
DATABASE_URL = os.getenv(
    "DATABASE_URL",
    "postgresql://breakpilot:breakpilot@postgres:5432/breakpilot_db",
)

# Connection pool (initialized lazily by get_pool()).
_pool: Optional[asyncpg.Pool] = None
# Serializes pool creation. Created lazily inside get_pool() (i.e. inside a
# running event loop) rather than at import time, because on Python < 3.10
# asyncio.Lock() binds to whatever loop is current when it is constructed.
_pool_lock: Optional[asyncio.Lock] = None

# Metadata columns returned by create/get/update (image BYTEA columns are
# deliberately excluded; load those with get_session_image()).
_SESSION_COLUMNS = """
    id, name, filename, status, current_step,
    orientation_result, crop_result,
    deskew_result, dewarp_result, column_result, row_result,
    word_result, ground_truth, auto_shear_degrees,
    doc_type, doc_type_result,
    document_category, pipeline_log,
    grid_editor_result,
    parent_session_id, box_index,
    created_at, updated_at
"""

# Reduced column set used by the listing queries.
_SESSION_LIST_COLUMNS = """
    id, name, filename, status, current_step,
    document_category, doc_type,
    parent_session_id, box_index,
    created_at, updated_at
"""

# Public image name -> BYTEA column, used by get_session_image().
_IMAGE_COLUMNS = {
    "original": "original_png",
    "oriented": "oriented_png",
    "cropped": "cropped_png",
    "deskewed": "deskewed_png",
    "binarized": "binarized_png",
    "dewarped": "dewarped_png",
    "clean": "clean_png",
}

# Columns update_session_db() may write. Column names are interpolated into
# SQL, so only names from this whitelist may ever reach the query text.
_UPDATABLE_FIELDS = frozenset({
    'name', 'filename', 'status', 'current_step',
    'original_png', 'oriented_png', 'cropped_png',
    'deskewed_png', 'binarized_png', 'dewarped_png',
    'clean_png', 'handwriting_removal_meta',
    'orientation_result', 'crop_result',
    'deskew_result', 'dewarp_result', 'column_result', 'row_result',
    'word_result', 'ground_truth', 'auto_shear_degrees',
    'doc_type', 'doc_type_result',
    'document_category', 'pipeline_log',
    'grid_editor_result',
    'parent_session_id', 'box_index',
})

# JSONB columns: non-str values are json.dumps()-ed on write, and str values
# read back are json.loads()-ed by _row_to_dict(). One shared set keeps the
# write and read sides in agreement.
_JSONB_FIELDS = frozenset({
    'orientation_result', 'crop_result',
    'deskew_result', 'dewarp_result', 'column_result', 'row_result',
    'word_result', 'ground_truth', 'handwriting_removal_meta',
    'doc_type_result', 'pipeline_log', 'grid_editor_result',
})


async def get_pool() -> asyncpg.Pool:
    """Get or create the shared database connection pool.

    Creation is serialized with a lock: without it, two coroutines that both
    observe ``_pool is None`` would each await ``asyncpg.create_pool`` and one
    of the resulting pools would be leaked.
    """
    global _pool, _pool_lock
    if _pool is not None:
        return _pool
    if _pool_lock is None:
        # No await between the check and the assignment, so this is race-free.
        _pool_lock = asyncio.Lock()
    async with _pool_lock:
        # Re-check: another coroutine may have created the pool while we waited.
        if _pool is None:
            _pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
    return _pool


async def init_ocr_pipeline_tables():
    """Initialize OCR pipeline tables if they don't exist.

    If ``ocr_pipeline_sessions`` is missing, runs
    migrations/002_ocr_pipeline_sessions.sql (relative to this module). Then
    adds newer columns with idempotent ``ADD COLUMN IF NOT EXISTS``.

    If the table is missing and the migration file cannot be found, a warning
    is logged and the column upgrade is skipped, since ALTER TABLE would fail
    on a nonexistent table.
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        # to_regclass resolves the name through search_path exactly like the
        # unqualified ALTER TABLE below; information_schema.tables filtered
        # only by table_name would also match a same-named table in an
        # unrelated schema.
        tables_exist = await conn.fetchval(
            "SELECT to_regclass('ocr_pipeline_sessions') IS NOT NULL"
        )

        if not tables_exist:
            logger.info("Creating OCR pipeline tables...")
            migration_path = os.path.join(
                os.path.dirname(__file__),
                "migrations/002_ocr_pipeline_sessions.sql",
            )
            if not os.path.exists(migration_path):
                logger.warning("Migration file not found: %s", migration_path)
                return
            with open(migration_path, "r", encoding="utf-8") as f:
                sql = f.read()
            await conn.execute(sql)
            logger.info("OCR pipeline tables created successfully")
        else:
            logger.debug("OCR pipeline tables already exist")

        # Ensure new columns exist (idempotent ALTER TABLE)
        await conn.execute("""
            ALTER TABLE ocr_pipeline_sessions
            ADD COLUMN IF NOT EXISTS clean_png BYTEA,
            ADD COLUMN IF NOT EXISTS handwriting_removal_meta JSONB,
            ADD COLUMN IF NOT EXISTS doc_type VARCHAR(50),
            ADD COLUMN IF NOT EXISTS doc_type_result JSONB,
            ADD COLUMN IF NOT EXISTS document_category VARCHAR(50),
            ADD COLUMN IF NOT EXISTS pipeline_log JSONB,
            ADD COLUMN IF NOT EXISTS oriented_png BYTEA,
            ADD COLUMN IF NOT EXISTS cropped_png BYTEA,
            ADD COLUMN IF NOT EXISTS orientation_result JSONB,
            ADD COLUMN IF NOT EXISTS crop_result JSONB,
            ADD COLUMN IF NOT EXISTS parent_session_id UUID REFERENCES ocr_pipeline_sessions(id) ON DELETE CASCADE,
            ADD COLUMN IF NOT EXISTS box_index INT,
            ADD COLUMN IF NOT EXISTS grid_editor_result JSONB
        """)


# =============================================================================
# SESSION CRUD
# =============================================================================

async def create_session_db(
    session_id: str,
    name: str,
    filename: str,
    original_png: bytes,
    parent_session_id: Optional[str] = None,
    box_index: Optional[int] = None,
) -> Dict[str, Any]:
    """Create a new OCR pipeline session.

    The row is inserted with status 'active' and current_step 1.

    Args:
        session_id: UUID string for the new session (ValueError if malformed).
        name: Display name.
        filename: Original upload filename.
        original_png: Image bytes stored in ``original_png``.
        parent_session_id: If set, this is a sub-session for a box region.
        box_index: 0-based index of the box this sub-session represents.

    Returns:
        The new session's metadata (no image columns), as a dict.
    """
    pool = await get_pool()
    parent_uuid = uuid.UUID(parent_session_id) if parent_session_id else None
    async with pool.acquire() as conn:
        row = await conn.fetchrow(f"""
            INSERT INTO ocr_pipeline_sessions (
                id, name, filename, original_png, status, current_step,
                parent_session_id, box_index
            ) VALUES ($1, $2, $3, $4, 'active', 1, $5, $6)
            RETURNING {_SESSION_COLUMNS}
        """, uuid.UUID(session_id), name, filename, original_png,
            parent_uuid, box_index)
    return _row_to_dict(row)


async def get_session_db(session_id: str) -> Optional[Dict[str, Any]]:
    """Get session metadata (without images), or None if no such session."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        row = await conn.fetchrow(f"""
            SELECT {_SESSION_COLUMNS}
            FROM ocr_pipeline_sessions WHERE id = $1
        """, uuid.UUID(session_id))
    return _row_to_dict(row) if row else None


async def get_session_image(session_id: str, image_type: str) -> Optional[bytes]:
    """Load a single image (BYTEA) from the session.

    Args:
        session_id: Session UUID string.
        image_type: One of the keys of ``_IMAGE_COLUMNS`` ("original",
            "oriented", "cropped", "deskewed", "binarized", "dewarped",
            "clean").

    Returns:
        The image bytes, or None if ``image_type`` is unknown, the session
        does not exist, or the column is NULL.
    """
    column = _IMAGE_COLUMNS.get(image_type)
    if not column:
        return None

    pool = await get_pool()
    async with pool.acquire() as conn:
        # `column` comes from the fixed whitelist above, never from the caller.
        return await conn.fetchval(
            f"SELECT {column} FROM ocr_pipeline_sessions WHERE id = $1",
            uuid.UUID(session_id),
        )


async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any]]:
    """Update session fields dynamically.

    Only keys in ``_UPDATABLE_FIELDS`` are written; any other keyword is
    ignored (and logged). Values for JSONB columns that are not already
    strings are serialized with ``json.dumps``. ``updated_at`` is always
    bumped when at least one field is written.

    Returns:
        The updated session metadata, or None if the session does not exist.
        If no updatable field was given, returns the current metadata
        unchanged (via get_session_db).
    """
    assignments: List[str] = []
    values: List[Any] = []
    ignored: List[str] = []

    for key, value in kwargs.items():
        if key not in _UPDATABLE_FIELDS:
            ignored.append(key)
            continue
        if key in _JSONB_FIELDS and value is not None and not isinstance(value, str):
            value = json.dumps(value)
        values.append(value)
        # Safe to interpolate: `key` is a whitelisted column name.
        assignments.append(f"{key} = ${len(values)}")

    if ignored:
        logger.warning("update_session_db: ignoring unknown fields %s", ignored)

    if not assignments:
        return await get_session_db(session_id)

    # Always update updated_at
    assignments.append("updated_at = NOW()")
    values.append(uuid.UUID(session_id))
    id_param = len(values)

    pool = await get_pool()
    async with pool.acquire() as conn:
        row = await conn.fetchrow(f"""
            UPDATE ocr_pipeline_sessions
            SET {', '.join(assignments)}
            WHERE id = ${id_param}
            RETURNING {_SESSION_COLUMNS}
        """, *values)
    return _row_to_dict(row) if row else None


async def list_sessions_db(
    limit: int = 50,
    include_sub_sessions: bool = False,
) -> List[Dict[str, Any]]:
    """List sessions (metadata only, no images), newest first.

    By default, sub-sessions (those with parent_session_id) are excluded.
    Pass include_sub_sessions=True to include them.
    """
    where = "" if include_sub_sessions else "WHERE parent_session_id IS NULL"
    pool = await get_pool()
    async with pool.acquire() as conn:
        rows = await conn.fetch(f"""
            SELECT {_SESSION_LIST_COLUMNS}
            FROM ocr_pipeline_sessions
            {where}
            ORDER BY created_at DESC
            LIMIT $1
        """, limit)
    return [_row_to_dict(row) for row in rows]


async def get_sub_sessions(parent_session_id: str) -> List[Dict[str, Any]]:
    """Get all sub-sessions for a parent session, ordered by box_index."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        rows = await conn.fetch(f"""
            SELECT {_SESSION_LIST_COLUMNS}
            FROM ocr_pipeline_sessions
            WHERE parent_session_id = $1
            ORDER BY box_index ASC
        """, uuid.UUID(parent_session_id))
    return [_row_to_dict(row) for row in rows]


async def delete_session_db(session_id: str) -> bool:
    """Delete a session. Returns True if exactly one row was deleted."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        result = await conn.execute("""
            DELETE FROM ocr_pipeline_sessions WHERE id = $1
        """, uuid.UUID(session_id))
    # asyncpg returns the command status tag, e.g. "DELETE 1".
    return result == "DELETE 1"


async def delete_all_sessions_db() -> int:
    """Delete all sessions. Returns number of deleted rows (0 if unparsable)."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        result = await conn.execute("DELETE FROM ocr_pipeline_sessions")
    # result is the command status tag, e.g. "DELETE 5"
    try:
        return int(result.split()[-1])
    except (ValueError, IndexError):
        return 0


# =============================================================================
# HELPER
# =============================================================================

def _row_to_dict(row: Optional[asyncpg.Record]) -> Dict[str, Any]:
    """Convert an asyncpg Record to a JSON-serializable dict.

    UUID columns become strings, timestamps become ISO-8601 strings, and
    JSONB columns delivered as text are parsed. ``None`` yields ``{}``.
    """
    if row is None:
        return {}

    result = dict(row)

    # UUID → string
    for key in ('id', 'session_id', 'parent_session_id'):
        if result.get(key) is not None:
            result[key] = str(result[key])

    # datetime → ISO string
    for key in ('created_at', 'updated_at'):
        if result.get(key) is not None:
            result[key] = result[key].isoformat()

    # JSONB → parsed (asyncpg returns str for JSONB unless a codec is set)
    for key in _JSONB_FIELDS:
        if isinstance(result.get(key), str):
            result[key] = json.loads(result[key])

    return result