Backend: build_word_grid() intersects column regions with content rows, OCRs each cell with language-specific Tesseract, and returns vocabulary entries with percent-based bounding boxes. New endpoints: POST /words, GET /image/words-overlay, ground-truth save/retrieve for words. Frontend: StepWordRecognition with overview + step-through labeling modes, goToStep callback for row correction feedback loop. MkDocs: OCR Pipeline documentation added. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
229 lines · 7.4 KiB · Python
"""
|
|
OCR Pipeline Session Store - PostgreSQL persistence for OCR pipeline sessions.
|
|
|
|
Replaces in-memory storage with database persistence.
|
|
See migrations/002_ocr_pipeline_sessions.sql for schema.
|
|
"""
|
|
|
|
import asyncio
import json
import logging
import os
import uuid
from typing import Any, Dict, List, Optional

import asyncpg
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Database configuration (same as vocab_session_store)
# Falls back to the in-cluster postgres service DSN when DATABASE_URL is unset.
DATABASE_URL = os.getenv(
    "DATABASE_URL",
    "postgresql://breakpilot:breakpilot@postgres:5432/breakpilot_db"
)

# Connection pool (initialized lazily)
# Process-wide singleton; created on first call to get_pool().
_pool: Optional[asyncpg.Pool] = None
|
|
|
|
|
|
# Serializes first-time pool creation. Without it, two coroutines racing
# through the `_pool is None` check would each create a pool and leak one.
_pool_lock = asyncio.Lock()


async def get_pool() -> asyncpg.Pool:
    """Get or create the shared database connection pool.

    The pool is created lazily on first use. Creation is guarded by an
    asyncio.Lock with a double-check so concurrent first callers share a
    single pool instead of each opening their own.

    Returns:
        The process-wide asyncpg pool (min_size=2, max_size=10).
    """
    global _pool
    if _pool is None:
        async with _pool_lock:
            # Re-check after acquiring the lock: another coroutine may have
            # created the pool while we were waiting.
            if _pool is None:
                _pool = await asyncpg.create_pool(
                    DATABASE_URL, min_size=2, max_size=10
                )
    return _pool
|
|
|
|
|
|
async def init_ocr_pipeline_tables():
    """Initialize OCR pipeline tables if they don't exist.

    Probes information_schema for the 'ocr_pipeline_sessions' table and, when
    it is missing, executes the SQL migration file shipped next to this
    module (migrations/002_ocr_pipeline_sessions.sql). No-op when the table
    already exists; logs a warning when the migration file cannot be found.
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        tables_exist = await conn.fetchval("""
            SELECT EXISTS (
                SELECT FROM information_schema.tables
                WHERE table_name = 'ocr_pipeline_sessions'
            )
        """)

        if not tables_exist:
            logger.info("Creating OCR pipeline tables...")
            migration_path = os.path.join(
                os.path.dirname(__file__),
                "migrations/002_ocr_pipeline_sessions.sql"
            )
            if os.path.exists(migration_path):
                # Be explicit about the encoding so reading the migration does
                # not depend on the platform's locale default.
                with open(migration_path, "r", encoding="utf-8") as f:
                    sql = f.read()
                await conn.execute(sql)
                logger.info("OCR pipeline tables created successfully")
            else:
                # Lazy %-style args: no string formatting unless WARNING is on.
                logger.warning("Migration file not found: %s", migration_path)
        else:
            logger.debug("OCR pipeline tables already exist")
|
|
|
|
|
|
# =============================================================================
|
|
# SESSION CRUD
|
|
# =============================================================================
|
|
|
|
async def create_session_db(
    session_id: str,
    name: str,
    filename: str,
    original_png: bytes,
) -> Dict[str, Any]:
    """Insert a new OCR pipeline session and return its metadata.

    The session starts with status 'active' at pipeline step 1. The
    returned dict carries metadata and result columns only — no image
    BYTEA columns.
    """
    db = await get_pool()
    async with db.acquire() as connection:
        inserted = await connection.fetchrow("""
            INSERT INTO ocr_pipeline_sessions (
                id, name, filename, original_png, status, current_step
            ) VALUES ($1, $2, $3, $4, 'active', 1)
            RETURNING id, name, filename, status, current_step,
                      deskew_result, dewarp_result, column_result, row_result,
                      word_result, ground_truth, auto_shear_degrees,
                      created_at, updated_at
        """, uuid.UUID(session_id), name, filename, original_png)
        return _row_to_dict(inserted)
|
|
|
|
|
|
async def get_session_db(session_id: str) -> Optional[Dict[str, Any]]:
    """Fetch session metadata by id (image columns excluded).

    Returns None when no session with the given id exists.
    """
    db = await get_pool()
    async with db.acquire() as connection:
        record = await connection.fetchrow("""
            SELECT id, name, filename, status, current_step,
                   deskew_result, dewarp_result, column_result, row_result,
                   word_result, ground_truth, auto_shear_degrees,
                   created_at, updated_at
            FROM ocr_pipeline_sessions WHERE id = $1
        """, uuid.UUID(session_id))
        return _row_to_dict(record) if record else None
|
|
|
|
|
|
async def get_session_image(session_id: str, image_type: str) -> Optional[bytes]:
    """Load a single image (BYTEA) from the session.

    Args:
        session_id: Session UUID string.
        image_type: One of 'original', 'deskewed', 'binarized', 'dewarped'.

    Returns:
        The raw image bytes, or None for an unknown image_type or when the
        row/column holds no data.
    """
    # Whitelist lookup: image_type is mapped to a fixed column name, so the
    # f-string below never interpolates caller-controlled SQL.
    column = {
        "original": "original_png",
        "deskewed": "deskewed_png",
        "binarized": "binarized_png",
        "dewarped": "dewarped_png",
    }.get(image_type)
    if not column:
        return None

    db = await get_pool()
    async with db.acquire() as connection:
        return await connection.fetchval(
            f"SELECT {column} FROM ocr_pipeline_sessions WHERE id = $1",
            uuid.UUID(session_id)
        )
|
|
|
|
|
|
async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any]]:
    """Update session fields dynamically.

    Only keys in the allow-list below are applied; unrecognized kwargs are
    silently ignored. Non-string values destined for JSONB columns are
    serialized with json.dumps before binding.

    Returns:
        The updated metadata dict; the current (unchanged) metadata when no
        valid fields were supplied; or None if the session id does not exist.
    """
    pool = await get_pool()

    fields = []
    values = []
    param_idx = 1

    # Allow-list prevents SQL injection via kwargs keys: only these names are
    # ever interpolated into the SET clause.
    allowed_fields = {
        'name', 'filename', 'status', 'current_step',
        'original_png', 'deskewed_png', 'binarized_png', 'dewarped_png',
        'deskew_result', 'dewarp_result', 'column_result', 'row_result',
        'word_result', 'ground_truth', 'auto_shear_degrees',
    }

    jsonb_fields = {'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth'}

    for key, value in kwargs.items():
        if key in allowed_fields:
            fields.append(f"{key} = ${param_idx}")
            # asyncpg expects JSONB parameters as text; encode containers here.
            if key in jsonb_fields and value is not None and not isinstance(value, str):
                value = json.dumps(value)
            values.append(value)
            param_idx += 1

    if not fields:
        # Nothing to apply — return the current state for a consistent contract.
        return await get_session_db(session_id)

    # Always bump updated_at (plain string; no placeholder, so no f-string).
    fields.append("updated_at = NOW()")

    values.append(uuid.UUID(session_id))

    async with pool.acquire() as conn:
        row = await conn.fetchrow(f"""
            UPDATE ocr_pipeline_sessions
            SET {', '.join(fields)}
            WHERE id = ${param_idx}
            RETURNING id, name, filename, status, current_step,
                      deskew_result, dewarp_result, column_result, row_result,
                      word_result, ground_truth, auto_shear_degrees,
                      created_at, updated_at
        """, *values)

    if row:
        return _row_to_dict(row)
    return None
|
|
|
|
|
|
async def list_sessions_db(limit: int = 50) -> List[Dict[str, Any]]:
    """Return up to `limit` sessions, newest first (metadata only, no images)."""
    db = await get_pool()
    async with db.acquire() as connection:
        records = await connection.fetch("""
            SELECT id, name, filename, status, current_step,
                   created_at, updated_at
            FROM ocr_pipeline_sessions
            ORDER BY created_at DESC
            LIMIT $1
        """, limit)
        return [_row_to_dict(record) for record in records]
|
|
|
|
|
|
async def delete_session_db(session_id: str) -> bool:
    """Delete a session row.

    Returns True iff exactly one row was removed.
    """
    db = await get_pool()
    async with db.acquire() as connection:
        status = await connection.execute("""
            DELETE FROM ocr_pipeline_sessions WHERE id = $1
        """, uuid.UUID(session_id))
        # asyncpg returns the command tag, e.g. 'DELETE 1'.
        return status == "DELETE 1"
|
|
|
|
|
|
# =============================================================================
|
|
# HELPER
|
|
# =============================================================================
|
|
|
|
def _row_to_dict(row: asyncpg.Record) -> Dict[str, Any]:
    """Convert an asyncpg Record into a JSON-serializable dict.

    UUID columns become strings, datetimes become ISO-8601 strings, and
    JSONB columns (which asyncpg returns as str) are parsed back into
    Python objects. A None row yields an empty dict.
    """
    if row is None:
        return {}

    out = dict(row)

    # UUID → string
    for key in ('id', 'session_id'):
        if out.get(key) is not None:
            out[key] = str(out[key])

    # datetime → ISO string
    for key in ('created_at', 'updated_at'):
        if out.get(key) is not None:
            out[key] = out[key].isoformat()

    # JSONB → parsed (asyncpg returns str for JSONB)
    jsonb_keys = (
        'deskew_result', 'dewarp_result', 'column_result',
        'row_result', 'word_result', 'ground_truth',
    )
    for key in jsonb_keys:
        value = out.get(key)
        if isinstance(value, str):
            out[key] = json.loads(value)

    return out
|