Restructure: Move ocr_pipeline + labeling + crop into ocr/ package
Some checks failed
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 29s
CI / test-go-edu-search (push) Successful in 29s
CI / test-python-klausur (push) Failing after 2m25s
CI / test-python-agent-core (push) Successful in 19s
CI / test-nodejs-website (push) Successful in 20s
Some checks failed
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 29s
CI / test-go-edu-search (push) Successful in 29s
CI / test-python-klausur (push) Failing after 2m25s
CI / test-python-agent-core (push) Successful in 19s
CI / test-nodejs-website (push) Successful in 20s
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,388 +1,4 @@
|
||||
"""
|
||||
OCR Pipeline Session Store - PostgreSQL persistence for OCR pipeline sessions.
|
||||
|
||||
Replaces in-memory storage with database persistence.
|
||||
See migrations/002_ocr_pipeline_sessions.sql for schema.
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
import logging
|
||||
import json
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
import asyncpg
|
||||
|
||||
# Module-level logger; callers configure handlers/levels externally.
logger = logging.getLogger(__name__)

# Database configuration (same as vocab_session_store)
DATABASE_URL = os.getenv(
    "DATABASE_URL",
    "postgresql://breakpilot:breakpilot@postgres:5432/breakpilot_db"
)

# Connection pool (initialized lazily by get_pool(); shared module-wide)
_pool: Optional[asyncpg.Pool] = None
|
||||
|
||||
|
||||
async def get_pool() -> asyncpg.Pool:
    """Get or create the database connection pool.

    Lazily creates the module-level asyncpg pool on first call
    (min_size=2, max_size=10) and returns the same pool thereafter.

    Returns:
        The shared asyncpg connection pool.
    """
    global _pool
    # NOTE(review): if two coroutines call this concurrently before _pool is
    # set, both can await create_pool() and one pool leaks -- confirm whether
    # first-call concurrency is possible for callers (e.g. app startup).
    if _pool is None:
        _pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
    return _pool
|
||||
|
||||
|
||||
async def init_ocr_pipeline_tables():
    """Initialize OCR pipeline tables if they don't exist.

    Idempotent: safe to call on every startup. Steps:
      1. Check information_schema for the 'ocr_pipeline_sessions' table.
      2. If missing, run migrations/002_ocr_pipeline_sessions.sql (resolved
         relative to this file); a missing migration file only logs a warning.
      3. Apply idempotent ALTER TABLE ... ADD COLUMN IF NOT EXISTS statements
         for columns added after the original migration.
      4. Create a partial index for document-group lookups.
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        tables_exist = await conn.fetchval("""
            SELECT EXISTS (
                SELECT FROM information_schema.tables
                WHERE table_name = 'ocr_pipeline_sessions'
            )
        """)

        if not tables_exist:
            logger.info("Creating OCR pipeline tables...")
            # Migration file lives next to this module.
            migration_path = os.path.join(
                os.path.dirname(__file__),
                "migrations/002_ocr_pipeline_sessions.sql"
            )
            if os.path.exists(migration_path):
                with open(migration_path, "r") as f:
                    sql = f.read()
                await conn.execute(sql)
                logger.info("OCR pipeline tables created successfully")
            else:
                # Best-effort: warn instead of raising so startup continues.
                logger.warning(f"Migration file not found: {migration_path}")
        else:
            logger.debug("OCR pipeline tables already exist")

        # Ensure new columns exist (idempotent ALTER TABLE)
        await conn.execute("""
            ALTER TABLE ocr_pipeline_sessions
            ADD COLUMN IF NOT EXISTS clean_png BYTEA,
            ADD COLUMN IF NOT EXISTS handwriting_removal_meta JSONB,
            ADD COLUMN IF NOT EXISTS doc_type VARCHAR(50),
            ADD COLUMN IF NOT EXISTS doc_type_result JSONB,
            ADD COLUMN IF NOT EXISTS document_category VARCHAR(50),
            ADD COLUMN IF NOT EXISTS pipeline_log JSONB,
            ADD COLUMN IF NOT EXISTS oriented_png BYTEA,
            ADD COLUMN IF NOT EXISTS cropped_png BYTEA,
            ADD COLUMN IF NOT EXISTS orientation_result JSONB,
            ADD COLUMN IF NOT EXISTS crop_result JSONB,
            ADD COLUMN IF NOT EXISTS parent_session_id UUID REFERENCES ocr_pipeline_sessions(id) ON DELETE CASCADE,
            ADD COLUMN IF NOT EXISTS box_index INT,
            ADD COLUMN IF NOT EXISTS grid_editor_result JSONB,
            ADD COLUMN IF NOT EXISTS structure_result JSONB,
            ADD COLUMN IF NOT EXISTS document_group_id UUID,
            ADD COLUMN IF NOT EXISTS page_number INT
        """)

        # Index for document group lookups (partial: only grouped sessions)
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_ocr_sessions_document_group
            ON ocr_pipeline_sessions (document_group_id)
            WHERE document_group_id IS NOT NULL
        """)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SESSION CRUD
|
||||
# =============================================================================
|
||||
|
||||
async def create_session_db(
    session_id: str,
    name: str,
    filename: str,
    original_png: bytes,
    parent_session_id: Optional[str] = None,
    box_index: Optional[int] = None,
    document_group_id: Optional[str] = None,
    page_number: Optional[int] = None,
) -> Dict[str, Any]:
    """Create a new OCR pipeline session.

    The row is inserted with status 'active' and current_step 1.

    Args:
        session_id: UUID string used as the primary key.
        name: Display name for the session.
        filename: Original upload filename.
        original_png: Raw PNG bytes stored in the original_png column.
        parent_session_id: If set, this is a sub-session for a box region.
        box_index: 0-based index of the box this sub-session represents.
        document_group_id: Groups multi-page uploads into one document.
        page_number: 1-based page index within the document group.

    Returns:
        The created session as a JSON-serializable dict (metadata only;
        the RETURNING list excludes image columns).

    Raises:
        ValueError: If any provided id string is not a valid UUID.
    """
    pool = await get_pool()
    # asyncpg expects UUID objects (or None) for UUID columns.
    parent_uuid = uuid.UUID(parent_session_id) if parent_session_id else None
    group_uuid = uuid.UUID(document_group_id) if document_group_id else None
    async with pool.acquire() as conn:
        row = await conn.fetchrow("""
            INSERT INTO ocr_pipeline_sessions (
                id, name, filename, original_png, status, current_step,
                parent_session_id, box_index, document_group_id, page_number
            ) VALUES ($1, $2, $3, $4, 'active', 1, $5, $6, $7, $8)
            RETURNING id, name, filename, status, current_step,
                      orientation_result, crop_result,
                      deskew_result, dewarp_result, column_result, row_result,
                      word_result, ground_truth, auto_shear_degrees,
                      doc_type, doc_type_result,
                      document_category, pipeline_log,
                      grid_editor_result, structure_result,
                      parent_session_id, box_index,
                      document_group_id, page_number,
                      created_at, updated_at
        """, uuid.UUID(session_id), name, filename, original_png,
            parent_uuid, box_index, group_uuid, page_number)

    return _row_to_dict(row)
|
||||
|
||||
|
||||
async def get_session_db(session_id: str) -> Optional[Dict[str, Any]]:
    """Fetch a single session's metadata; image columns are not selected.

    Returns the session as a JSON-serializable dict, or None when no row
    matches the given id.
    """
    db_pool = await get_pool()
    async with db_pool.acquire() as conn:
        record = await conn.fetchrow("""
            SELECT id, name, filename, status, current_step,
                   orientation_result, crop_result,
                   deskew_result, dewarp_result, column_result, row_result,
                   word_result, ground_truth, auto_shear_degrees,
                   doc_type, doc_type_result,
                   document_category, pipeline_log,
                   grid_editor_result, structure_result,
                   parent_session_id, box_index,
                   document_group_id, page_number,
                   created_at, updated_at
            FROM ocr_pipeline_sessions WHERE id = $1
        """, uuid.UUID(session_id))

    return _row_to_dict(record) if record else None
|
||||
|
||||
|
||||
async def get_session_image(session_id: str, image_type: str) -> Optional[bytes]:
    """Load a single image (BYTEA) from the session.

    Maps a logical image name (e.g. "original", "deskewed") to its column;
    returns None for unknown image types or missing rows.
    """
    # Whitelist mapping also guards the f-string interpolation below.
    column = {
        "original": "original_png",
        "oriented": "oriented_png",
        "cropped": "cropped_png",
        "deskewed": "deskewed_png",
        "binarized": "binarized_png",
        "dewarped": "dewarped_png",
        "clean": "clean_png",
    }.get(image_type)
    if column is None:
        return None

    db_pool = await get_pool()
    query = f"SELECT {column} FROM ocr_pipeline_sessions WHERE id = $1"
    async with db_pool.acquire() as conn:
        return await conn.fetchval(query, uuid.UUID(session_id))
|
||||
|
||||
|
||||
async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any]]:
    """Update session fields dynamically.

    Only keys present in the allow-list below are applied; unknown kwargs
    are silently ignored. JSONB fields given as Python objects are encoded
    to JSON text before being bound (asyncpg expects str for JSONB here).

    Args:
        session_id: UUID string of the session to update.
        **kwargs: Column/value pairs to set.

    Returns:
        The updated session as a dict, or None when no row matched.
        If no valid field was supplied, returns the current session state.

    Raises:
        ValueError: If session_id is not a valid UUID string.
    """
    pool = await get_pool()

    fields = []
    values = []
    param_idx = 1

    # Allow-list of updatable columns. Keys are interpolated into the SQL
    # below, so restricting to this fixed set prevents SQL injection.
    allowed_fields = {
        'name', 'filename', 'status', 'current_step',
        'original_png', 'oriented_png', 'cropped_png',
        'deskewed_png', 'binarized_png', 'dewarped_png',
        'clean_png', 'handwriting_removal_meta',
        'orientation_result', 'crop_result',
        'deskew_result', 'dewarp_result', 'column_result', 'row_result',
        'word_result', 'ground_truth', 'auto_shear_degrees',
        'doc_type', 'doc_type_result',
        'document_category', 'pipeline_log',
        'grid_editor_result', 'structure_result',
        'parent_session_id', 'box_index',
        'document_group_id', 'page_number',
    }

    # Columns stored as JSONB; dict/list values must be serialized to text.
    jsonb_fields = {'orientation_result', 'crop_result', 'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'handwriting_removal_meta', 'doc_type_result', 'pipeline_log', 'grid_editor_result', 'structure_result'}

    for key, value in kwargs.items():
        if key in allowed_fields:
            fields.append(f"{key} = ${param_idx}")
            if key in jsonb_fields and value is not None and not isinstance(value, str):
                value = json.dumps(value)
            values.append(value)
            param_idx += 1

    if not fields:
        # Nothing updatable was passed; report the current state instead.
        return await get_session_db(session_id)

    # Always update updated_at (plain string -- the original used an
    # f-string with no placeholders, flagged by ruff F541).
    fields.append("updated_at = NOW()")

    # The session id is the final positional parameter ($param_idx).
    values.append(uuid.UUID(session_id))

    async with pool.acquire() as conn:
        row = await conn.fetchrow(f"""
            UPDATE ocr_pipeline_sessions
            SET {', '.join(fields)}
            WHERE id = ${param_idx}
            RETURNING id, name, filename, status, current_step,
                      orientation_result, crop_result,
                      deskew_result, dewarp_result, column_result, row_result,
                      word_result, ground_truth, auto_shear_degrees,
                      doc_type, doc_type_result,
                      document_category, pipeline_log,
                      grid_editor_result, structure_result,
                      parent_session_id, box_index,
                      document_group_id, page_number,
                      created_at, updated_at
        """, *values)

    if row:
        return _row_to_dict(row)
    return None
|
||||
|
||||
|
||||
async def list_sessions_db(
    limit: int = 50,
    include_sub_sessions: bool = False,
) -> List[Dict[str, Any]]:
    """List sessions (metadata only, no images).

    By default, sub-sessions (those with parent_session_id) and sessions
    whose status is 'split' are excluded. Pass include_sub_sessions=True
    to list everything. Each returned dict carries a derived boolean
    'is_ground_truth' instead of the raw ground_truth payload.
    """
    db_pool = await get_pool()
    filter_clause = (
        ""
        if include_sub_sessions
        else "WHERE parent_session_id IS NULL AND (status IS NULL OR status != 'split')"
    )
    async with db_pool.acquire() as conn:
        records = await conn.fetch(f"""
            SELECT id, name, filename, status, current_step,
                   document_category, doc_type,
                   parent_session_id, box_index,
                   document_group_id, page_number,
                   created_at, updated_at,
                   ground_truth
            FROM ocr_pipeline_sessions
            {filter_clause}
            ORDER BY created_at DESC
            LIMIT $1
        """, limit)

    sessions = []
    for record in records:
        session = _row_to_dict(record)
        # Derive is_ground_truth flag from JSONB, then drop the heavy field
        ground_truth = session.pop("ground_truth", None) or {}
        session["is_ground_truth"] = bool(ground_truth.get("build_grid_reference"))
        sessions.append(session)
    return sessions
|
||||
|
||||
|
||||
async def get_sub_sessions(parent_session_id: str) -> List[Dict[str, Any]]:
    """Get all sub-sessions for a parent session, ordered by box_index.

    Returns metadata-only dicts (no image columns are selected).
    """
    db_pool = await get_pool()
    async with db_pool.acquire() as conn:
        records = await conn.fetch("""
            SELECT id, name, filename, status, current_step,
                   document_category, doc_type,
                   parent_session_id, box_index,
                   document_group_id, page_number,
                   created_at, updated_at
            FROM ocr_pipeline_sessions
            WHERE parent_session_id = $1
            ORDER BY box_index ASC
        """, uuid.UUID(parent_session_id))

    return [_row_to_dict(record) for record in records]
|
||||
|
||||
|
||||
async def get_document_group_sessions(document_group_id: str) -> List[Dict[str, Any]]:
    """Get all sessions in a document group, ordered by page_number.

    Returns metadata-only dicts (no image columns are selected).
    """
    db_pool = await get_pool()
    group_uuid = uuid.UUID(document_group_id)
    async with db_pool.acquire() as conn:
        records = await conn.fetch("""
            SELECT id, name, filename, status, current_step,
                   document_category, doc_type,
                   parent_session_id, box_index,
                   document_group_id, page_number,
                   created_at, updated_at
            FROM ocr_pipeline_sessions
            WHERE document_group_id = $1
            ORDER BY page_number ASC
        """, group_uuid)

    return [_row_to_dict(record) for record in records]
|
||||
|
||||
|
||||
async def list_ground_truth_sessions_db() -> List[Dict[str, Any]]:
    """List sessions that have a build_grid_reference in ground_truth.

    Only top-level sessions (parent_session_id IS NULL) are returned,
    newest first. The full ground_truth JSONB is included in each dict.
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        # NOTE(review): the textual LIKE on ground_truth::text matches the
        # substring anywhere in the JSON (including nested values or string
        # content) -- confirm whether a jsonb key check (ground_truth ?
        # 'build_grid_reference') is the intended semantics before changing.
        rows = await conn.fetch("""
            SELECT id, name, filename, status, current_step,
                   document_category, doc_type,
                   ground_truth,
                   parent_session_id, box_index,
                   created_at, updated_at
            FROM ocr_pipeline_sessions
            WHERE ground_truth IS NOT NULL
              AND ground_truth::text LIKE '%build_grid_reference%'
              AND parent_session_id IS NULL
            ORDER BY created_at DESC
        """)

    return [_row_to_dict(row) for row in rows]
|
||||
|
||||
|
||||
async def delete_session_db(session_id: str) -> bool:
    """Delete a session.

    Returns True when exactly one row was removed, False otherwise.
    Sub-sessions referencing it are removed by ON DELETE CASCADE.
    """
    db_pool = await get_pool()
    async with db_pool.acquire() as conn:
        status_tag = await conn.execute("""
            DELETE FROM ocr_pipeline_sessions WHERE id = $1
        """, uuid.UUID(session_id))
    # asyncpg returns a command tag such as "DELETE 1".
    return status_tag == "DELETE 1"
|
||||
|
||||
|
||||
async def delete_all_sessions_db() -> int:
    """Delete all sessions. Returns number of deleted rows."""
    db_pool = await get_pool()
    async with db_pool.acquire() as conn:
        status_tag = await conn.execute("DELETE FROM ocr_pipeline_sessions")

    # The command tag looks like "DELETE 5"; take the trailing token.
    count_token = status_tag.rsplit(" ", 1)[-1]
    try:
        return int(count_token)
    except ValueError:
        # Unexpected tag shape -- report zero rather than raising.
        return 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# HELPER
|
||||
# =============================================================================
|
||||
|
||||
def _row_to_dict(row: asyncpg.Record) -> Dict[str, Any]:
    """Convert asyncpg Record to JSON-serializable dict.

    UUIDs become strings, datetimes become ISO-8601 strings, and JSONB
    columns returned as text are parsed back into Python objects.
    A None record yields an empty dict.
    """
    if row is None:
        return {}

    out = dict(row)

    uuid_keys = ('id', 'session_id', 'parent_session_id', 'document_group_id')
    timestamp_keys = ('created_at', 'updated_at')
    jsonb_keys = (
        'orientation_result', 'crop_result', 'deskew_result', 'dewarp_result',
        'column_result', 'row_result', 'word_result', 'ground_truth',
        'doc_type_result', 'pipeline_log', 'grid_editor_result',
        'structure_result',
    )

    # UUID -> string
    for key in uuid_keys:
        value = out.get(key)
        if value is not None:
            out[key] = str(value)

    # datetime -> ISO string
    for key in timestamp_keys:
        value = out.get(key)
        if value is not None:
            out[key] = value.isoformat()

    # JSONB -> parsed (asyncpg returns str for JSONB)
    for key in jsonb_keys:
        value = out.get(key)
        if isinstance(value, str):
            out[key] = json.loads(value)

    return out
|
||||
# Backward-compat shim -- module moved to ocr/pipeline/session_store.py
# Replacing this module's sys.modules entry makes any subsequent
# `import <old name>` resolve to the relocated module, so existing import
# paths keep working. Aliased imports (_importlib/_sys) avoid polluting the
# re-exported namespace.
import importlib as _importlib
import sys as _sys

_sys.modules[__name__] = _importlib.import_module("ocr.pipeline.session_store")
|
||||
|
||||
Reference in New Issue
Block a user