breakpilot-lehrer/klausur-service/backend/ocr_pipeline_session_store.py
Benjamin Admin 293e7914d8
feat: improved OCR pipeline session manager with categories, thumbnails, pipeline logging
- Add document_category (10 types) and pipeline_log JSONB columns
- Session list: thumbnails, copyable IDs, category/doc_type badges
- Inline category dropdown, bulk delete, pipeline step logging
- New endpoints: thumbnail, delete-all, pipeline-log, categories
- Cleared all 22 old test sessions

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 09:44:38 +01:00


"""
OCR Pipeline Session Store - PostgreSQL persistence for OCR pipeline sessions.
Replaces in-memory storage with database persistence.
See migrations/002_ocr_pipeline_sessions.sql for schema.
"""
import os
import uuid
import logging
import json
from typing import Optional, List, Dict, Any
import asyncpg

logger = logging.getLogger(__name__)

# Database configuration (same as vocab_session_store)
DATABASE_URL = os.getenv(
    "DATABASE_URL",
    "postgresql://breakpilot:breakpilot@postgres:5432/breakpilot_db"
)

# Connection pool (initialized lazily)
_pool: Optional[asyncpg.Pool] = None


async def get_pool() -> asyncpg.Pool:
    """Get or create the database connection pool."""
    global _pool
    if _pool is None:
        _pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
    return _pool


async def init_ocr_pipeline_tables():
    """Initialize OCR pipeline tables if they don't exist."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        tables_exist = await conn.fetchval("""
            SELECT EXISTS (
                SELECT FROM information_schema.tables
                WHERE table_name = 'ocr_pipeline_sessions'
            )
        """)
        if not tables_exist:
            logger.info("Creating OCR pipeline tables...")
            migration_path = os.path.join(
                os.path.dirname(__file__),
                "migrations/002_ocr_pipeline_sessions.sql"
            )
            if os.path.exists(migration_path):
                with open(migration_path, "r") as f:
                    sql = f.read()
                await conn.execute(sql)
                logger.info("OCR pipeline tables created successfully")
            else:
                logger.warning(f"Migration file not found: {migration_path}")
        else:
            logger.debug("OCR pipeline tables already exist")

        # Ensure new columns exist (idempotent ALTER TABLE)
        await conn.execute("""
            ALTER TABLE ocr_pipeline_sessions
            ADD COLUMN IF NOT EXISTS clean_png BYTEA,
            ADD COLUMN IF NOT EXISTS handwriting_removal_meta JSONB,
            ADD COLUMN IF NOT EXISTS doc_type VARCHAR(50),
            ADD COLUMN IF NOT EXISTS doc_type_result JSONB,
            ADD COLUMN IF NOT EXISTS document_category VARCHAR(50),
            ADD COLUMN IF NOT EXISTS pipeline_log JSONB
        """)
# =============================================================================
# SESSION CRUD
# =============================================================================

async def create_session_db(
    session_id: str,
    name: str,
    filename: str,
    original_png: bytes,
) -> Dict[str, Any]:
    """Create a new OCR pipeline session."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        row = await conn.fetchrow("""
            INSERT INTO ocr_pipeline_sessions (
                id, name, filename, original_png, status, current_step
            ) VALUES ($1, $2, $3, $4, 'active', 1)
            RETURNING id, name, filename, status, current_step,
                      deskew_result, dewarp_result, column_result, row_result,
                      word_result, ground_truth, auto_shear_degrees,
                      doc_type, doc_type_result,
                      document_category, pipeline_log,
                      created_at, updated_at
        """, uuid.UUID(session_id), name, filename, original_png)
        return _row_to_dict(row)
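
# Hypothetical caller (e.g. an upload endpoint); the variable names and PNG
# bytes are assumptions, only create_session_db() is from this module:
#
#     sid = str(uuid.uuid4())
#     session = await create_session_db(
#         session_id=sid,
#         name="Klausur Seite 1",
#         filename="scan_001.png",
#         original_png=png_bytes,   # raw PNG bytes from the upload
#     )
#     # New sessions start with status='active' and current_step=1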


async def get_session_db(session_id: str) -> Optional[Dict[str, Any]]:
    """Get session metadata (without images)."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        row = await conn.fetchrow("""
            SELECT id, name, filename, status, current_step,
                   deskew_result, dewarp_result, column_result, row_result,
                   word_result, ground_truth, auto_shear_degrees,
                   doc_type, doc_type_result,
                   document_category, pipeline_log,
                   created_at, updated_at
            FROM ocr_pipeline_sessions WHERE id = $1
        """, uuid.UUID(session_id))
        if row:
            return _row_to_dict(row)
        return None


async def get_session_image(session_id: str, image_type: str) -> Optional[bytes]:
    """Load a single image (BYTEA) from the session."""
    column_map = {
        "original": "original_png",
        "deskewed": "deskewed_png",
        "binarized": "binarized_png",
        "dewarped": "dewarped_png",
        "clean": "clean_png",
    }
    column = column_map.get(image_type)
    if not column:
        return None
    pool = await get_pool()
    async with pool.acquire() as conn:
        # The column name is taken from the whitelist above, so interpolating
        # it into the query is safe; the session id stays a bound parameter.
        return await conn.fetchval(
            f"SELECT {column} FROM ocr_pipeline_sessions WHERE id = $1",
            uuid.UUID(session_id)
        )
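
# Sketch of a thumbnail/image endpoint built on get_session_image(); the
# route, app object, and FastAPI usage are assumptions, not part of this file:
#
#     from fastapi import FastAPI, HTTPException
#     from fastapi.responses import Response
#
#     app = FastAPI()
#
#     @app.get("/ocr-pipeline/sessions/{session_id}/image/{image_type}")
#     async def session_image(session_id: str, image_type: str):
#         png = await get_session_image(session_id, image_type)
#         if png is None:
#             raise HTTPException(status_code=404, detail="image not found")
#         return Response(content=png, media_type="image/png")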


async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any]]:
    """Update session fields dynamically."""
    pool = await get_pool()
    fields = []
    values = []
    param_idx = 1
    allowed_fields = {
        'name', 'filename', 'status', 'current_step',
        'original_png', 'deskewed_png', 'binarized_png', 'dewarped_png',
        'clean_png', 'handwriting_removal_meta',
        'deskew_result', 'dewarp_result', 'column_result', 'row_result',
        'word_result', 'ground_truth', 'auto_shear_degrees',
        'doc_type', 'doc_type_result',
        'document_category', 'pipeline_log',
    }
    jsonb_fields = {
        'deskew_result', 'dewarp_result', 'column_result', 'row_result',
        'word_result', 'ground_truth', 'handwriting_removal_meta',
        'doc_type_result', 'pipeline_log',
    }
    for key, value in kwargs.items():
        if key in allowed_fields:
            fields.append(f"{key} = ${param_idx}")
            # JSONB columns expect a JSON string; serialize dicts/lists here.
            if key in jsonb_fields and value is not None and not isinstance(value, str):
                value = json.dumps(value)
            values.append(value)
            param_idx += 1
    if not fields:
        return await get_session_db(session_id)
    # Always update updated_at
    fields.append("updated_at = NOW()")
    values.append(uuid.UUID(session_id))
    async with pool.acquire() as conn:
        row = await conn.fetchrow(f"""
            UPDATE ocr_pipeline_sessions
            SET {', '.join(fields)}
            WHERE id = ${param_idx}
            RETURNING id, name, filename, status, current_step,
                      deskew_result, dewarp_result, column_result, row_result,
                      word_result, ground_truth, auto_shear_degrees,
                      doc_type, doc_type_result,
                      document_category, pipeline_log,
                      created_at, updated_at
        """, *values)
        if row:
            return _row_to_dict(row)
        return None
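
# pipeline_log is written as one JSONB value, so appending a step is a
# read-modify-write on the caller's side. Minimal sketch (the step dict's
# shape is an assumption; only the two store functions are from this module):
#
#     session = await get_session_db(session_id)
#     log = (session or {}).get("pipeline_log") or []
#     log.append({"step": "deskew", "status": "ok", "duration_ms": 412})
#     await update_session_db(session_id, pipeline_log=log)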


async def list_sessions_db(limit: int = 50) -> List[Dict[str, Any]]:
    """List all sessions (metadata only, no images)."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        rows = await conn.fetch("""
            SELECT id, name, filename, status, current_step,
                   document_category, doc_type,
                   created_at, updated_at
            FROM ocr_pipeline_sessions
            ORDER BY created_at DESC
            LIMIT $1
        """, limit)
        return [_row_to_dict(row) for row in rows]


async def delete_session_db(session_id: str) -> bool:
    """Delete a session."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        result = await conn.execute("""
            DELETE FROM ocr_pipeline_sessions WHERE id = $1
        """, uuid.UUID(session_id))
        return result == "DELETE 1"


async def delete_all_sessions_db() -> int:
    """Delete all sessions. Returns number of deleted rows."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        result = await conn.execute("DELETE FROM ocr_pipeline_sessions")
        # result is e.g. "DELETE 5"
        try:
            return int(result.split()[-1])
        except (ValueError, IndexError):
            return 0
# =============================================================================
# HELPER
# =============================================================================

def _row_to_dict(row: asyncpg.Record) -> Dict[str, Any]:
    """Convert asyncpg Record to JSON-serializable dict."""
    if row is None:
        return {}
    result = dict(row)
    # UUID → string
    for key in ['id', 'session_id']:
        if key in result and result[key] is not None:
            result[key] = str(result[key])
    # datetime → ISO string
    for key in ['created_at', 'updated_at']:
        if key in result and result[key] is not None:
            result[key] = result[key].isoformat()
    # JSONB → parsed (asyncpg returns str for JSONB)
    for key in ['deskew_result', 'dewarp_result', 'column_result', 'row_result',
                'word_result', 'ground_truth', 'doc_type_result', 'pipeline_log']:
        if key in result and result[key] is not None:
            if isinstance(result[key], str):
                result[key] = json.loads(result[key])
    return result
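

# =============================================================================
# MANUAL SMOKE TEST (sketch)
# =============================================================================
# Assumption: only meaningful with a reachable DATABASE_URL; the category
# value and dummy PNG bytes below are illustrative, not real data.
if __name__ == "__main__":
    import asyncio

    async def _smoke_test() -> None:
        await init_ocr_pipeline_tables()
        sid = str(uuid.uuid4())
        await create_session_db(sid, "smoke-test", "smoke.png", b"\x89PNG\r\n\x1a\n")
        await update_session_db(
            sid,
            document_category="klausur",          # illustrative category value
            pipeline_log=[{"step": "created"}],   # illustrative log entry
        )
        print(len(await list_sessions_db(limit=5)), "session(s) listed")
        print("deleted:", await delete_session_db(sid))

    asyncio.run(_smoke_test())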