feat: Box-Zonen durch gesamte Pipeline + Sub-Sessions fuer Box-Inhalt
Some checks failed
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 29s
CI / test-go-edu-search (push) Successful in 27s
CI / test-python-klausur (push) Failing after 2m0s
CI / test-python-agent-core (push) Successful in 18s
CI / test-nodejs-website (push) Successful in 19s

- Rote semi-transparente Box-Markierung in allen Overlays (Spalten, Zeilen, Woerter)
- Zeilenerkennung: Combined-Image-Ansatz schliesst Box-Bereiche aus
- Woerter-Erkennung: Zeilen innerhalb von Box-Zonen werden gefiltert
- Sub-Sessions: parent_session_id/box_index in DB-Schema
- POST /sessions/{id}/create-box-sessions erstellt Sub-Sessions aus Box-Regionen
- Session-Info zeigt Sub-Sessions bzw. Parent-Verknuepfung
- Sessions-Liste blendet Sub-Sessions per Default aus
- Rekonstruktion: Fabric-JSON merged Sub-Session-Zellen an Box-Positionen
- Save-Reconstruction routet box{N}_* Updates an Sub-Sessions
- GET /sessions/{id}/vocab-entries/merged fuer zusammengefuehrte Eintraege

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-03-09 18:24:34 +01:00
parent 4610137ecc
commit 256efef3ea
2 changed files with 485 additions and 25 deletions

View File

@@ -20,6 +20,7 @@ DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
import json import json
import logging import logging
import os import os
import re
import time import time
import uuid import uuid
from dataclasses import asdict from dataclasses import asdict
@@ -75,6 +76,7 @@ from ocr_pipeline_session_store import (
delete_session_db, delete_session_db,
get_session_db, get_session_db,
get_session_image, get_session_image,
get_sub_sessions,
init_ocr_pipeline_tables, init_ocr_pipeline_tables,
list_sessions_db, list_sessions_db,
update_session_db, update_session_db,
@@ -209,9 +211,13 @@ class RemoveHandwritingRequest(BaseModel):
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.get("/sessions") @router.get("/sessions")
async def list_sessions(): async def list_sessions(include_sub_sessions: bool = False):
"""List all OCR pipeline sessions.""" """List OCR pipeline sessions.
sessions = await list_sessions_db()
By default, sub-sessions (box regions) are hidden.
Pass ?include_sub_sessions=true to show them.
"""
sessions = await list_sessions_db(include_sub_sessions=include_sub_sessions)
return {"sessions": sessions} return {"sessions": sessions}
@@ -328,6 +334,19 @@ async def get_session_info(session_id: str):
if session.get("doc_type_result"): if session.get("doc_type_result"):
result["doc_type_result"] = session["doc_type_result"] result["doc_type_result"] = session["doc_type_result"]
# Sub-session info
if session.get("parent_session_id"):
result["parent_session_id"] = session["parent_session_id"]
result["box_index"] = session.get("box_index")
else:
# Check for sub-sessions
subs = await get_sub_sessions(session_id)
if subs:
result["sub_sessions"] = [
{"id": s["id"], "name": s.get("name"), "box_index": s.get("box_index")}
for s in subs
]
return result return result
@@ -370,6 +389,118 @@ async def delete_all_sessions():
return {"deleted_count": count} return {"deleted_count": count}
@router.post("/sessions/{session_id}/create-box-sessions")
async def create_box_sessions(session_id: str):
    """Create sub-sessions for each detected box region.

    Crops box regions from the cropped/dewarped image and creates
    independent sub-sessions that can be processed through the pipeline.

    Args:
        session_id: ID of the parent session whose box zones are split off.

    Returns:
        Dict with the parent ``session_id`` and the list of created
        ``sub_sessions`` (id, name, box_index, box geometry, crop size).
        If sub-sessions already exist, returns them instead of creating new ones.

    Raises:
        HTTPException 404: unknown session.
        HTTPException 400: column detection not completed, or no base image.
        HTTPException 500: base image bytes could not be decoded.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    column_result = session.get("column_result")
    if not column_result:
        raise HTTPException(status_code=400, detail="Column detection must be completed first")

    zones = column_result.get("zones", [])
    # Only zones explicitly typed "box" AND carrying a bounding box are split off.
    box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
    if not box_zones:
        return {"session_id": session_id, "sub_sessions": [], "message": "No boxes detected"}

    # Check for existing sub-sessions — makes the endpoint idempotent.
    existing = await get_sub_sessions(session_id)
    if existing:
        return {
            "session_id": session_id,
            "sub_sessions": [{"id": s["id"], "box_index": s.get("box_index")} for s in existing],
            "message": f"{len(existing)} sub-session(s) already exist",
        }

    # Load base image: prefer the cropped stage, fall back to dewarped.
    # NOTE(review): box coordinates presumably refer to this image's pixel
    # space — confirm against the column-detection step.
    base_png = await get_session_image(session_id, "cropped")
    if not base_png:
        base_png = await get_session_image(session_id, "dewarped")
    if not base_png:
        raise HTTPException(status_code=400, detail="No base image available")

    arr = np.frombuffer(base_png, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    parent_name = session.get("name", "Session")
    created = []
    for i, zone in enumerate(box_zones):
        box = zone["box"]
        bx, by = box["x"], box["y"]
        bw, bh = box["width"], box["height"]

        # Crop box region with small padding, clamped to the image bounds.
        pad = 5
        y1 = max(0, by - pad)
        y2 = min(img.shape[0], by + bh + pad)
        x1 = max(0, bx - pad)
        x2 = min(img.shape[1], bx + bw + pad)
        crop = img[y1:y2, x1:x2]

        # Encode as PNG; a box that fails to encode is skipped but logged.
        success, png_buf = cv2.imencode(".png", crop)
        if not success:
            logger.warning(f"Failed to encode box {i} crop for session {session_id}")
            continue

        sub_id = str(uuid.uuid4())
        sub_name = f"{parent_name} — Box {i + 1}"
        await create_session_db(
            session_id=sub_id,
            name=sub_name,
            filename=session.get("filename", "box-crop.png"),
            original_png=png_buf.tobytes(),
            parent_session_id=session_id,
            box_index=i,
        )

        # Cache the BGR for immediate processing; mirrors the cache entry
        # shape used for regular sessions (sub-session starts at step 1
        # with no pipeline results yet).
        _cache[sub_id] = {
            "id": sub_id,
            "filename": session.get("filename", "box-crop.png"),
            "name": sub_name,
            "original_bgr": crop.copy(),
            "oriented_bgr": None,
            "cropped_bgr": None,
            "deskewed_bgr": None,
            "dewarped_bgr": None,
            "orientation_result": None,
            "crop_result": None,
            "deskew_result": None,
            "dewarp_result": None,
            "ground_truth": {},
            "current_step": 1,
        }
        created.append({
            "id": sub_id,
            "name": sub_name,
            "box_index": i,
            "box": box,
            "image_width": crop.shape[1],
            "image_height": crop.shape[0],
        })
        logger.info(f"Created box sub-session {sub_id} for session {session_id} "
                    f"(box {i}, {crop.shape[1]}x{crop.shape[0]})")

    return {
        "session_id": session_id,
        "sub_sessions": created,
        "total": len(created),
    }
@router.get("/sessions/{session_id}/thumbnail") @router.get("/sessions/{session_id}/thumbnail")
async def get_session_thumbnail(session_id: str, size: int = Query(default=80, ge=16, le=400)): async def get_session_thumbnail(session_id: str, size: int = Query(default=80, ge=16, le=400)):
"""Return a small thumbnail of the original image.""" """Return a small thumbnail of the original image."""
@@ -1218,6 +1349,36 @@ async def get_column_ground_truth(session_id: str):
} }
def _draw_box_exclusion_overlay(
    img: np.ndarray,
    zones: List[Dict],
    *,
    label: str = "BOX — separat verarbeitet",
) -> None:
    """Mark every box zone on *img* with a red exclusion overlay (in-place).

    Shared by the column, row, and word overlay renderers so box regions
    look identical in all three views: a ~25 % red tint, a solid red border,
    and a white caption inside the box.
    """
    red = (0, 0, 200)  # BGR
    boxes = [z["box"] for z in zones
             if z.get("zone_type") == "box" and z.get("box")]
    for rect in boxes:
        x0, y0 = rect["x"], rect["y"]
        x1, y1 = x0 + rect["width"], y0 + rect["height"]
        # Semi-transparent fill: blend a filled copy back at 25 % opacity.
        tinted = img.copy()
        cv2.rectangle(tinted, (x0, y0), (x1, y1), red, -1)
        cv2.addWeighted(tinted, 0.25, img, 0.75, 0, img)
        # Solid border drawn on top of the blend.
        cv2.rectangle(img, (x0, y0), (x1, y1), red, 2)
        # Caption near the lower-left corner of the box.
        cv2.putText(img, label, (x0 + 10, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
async def _get_columns_overlay(session_id: str) -> Response: async def _get_columns_overlay(session_id: str) -> Response:
"""Generate cropped (or dewarped) image with column borders drawn on it.""" """Generate cropped (or dewarped) image with column borders drawn on it."""
session = await get_session_db(session_id) session = await get_session_db(session_id)
@@ -1299,6 +1460,9 @@ async def _get_columns_overlay(session_id: str) -> Response:
cv2.putText(img, "BOX", (bx + 10, by + bh - 10), cv2.putText(img, "BOX", (bx + 10, by + bh - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, box_color, 2) cv2.FONT_HERSHEY_SIMPLEX, 0.7, box_color, 2)
# Red semi-transparent overlay for box zones
_draw_box_exclusion_overlay(img, zones)
success, result_png = cv2.imencode(".png", img) success, result_png = cv2.imencode(".png", img)
if not success: if not success:
raise HTTPException(status_code=500, detail="Failed to encode overlay image") raise HTTPException(status_code=500, detail="Failed to encode overlay image")
@@ -1341,13 +1505,101 @@ async def detect_rows(session_id: str):
else: else:
left_x, right_x, top_y, bottom_y = content_bounds left_x, right_x, top_y, bottom_y = content_bounds
# Run row detection # Read zones from column_result to exclude box regions
rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y) session = await get_session_db(session_id)
column_result = (session or {}).get("column_result") or {}
zones = column_result.get("zones", [])
# Collect box y-ranges for filtering
box_ranges = [] # [(y_start, y_end)]
for zone in zones:
if zone.get("zone_type") == "box" and zone.get("box"):
box = zone["box"]
box_ranges.append((box["y"], box["y"] + box["height"]))
if box_ranges and inv is not None:
# Combined-image approach: strip box regions from inv image,
# run row detection on the combined image, then remap y-coords back.
content_strips = [] # [(y_start, y_end)] in absolute coords
# Build content strips by subtracting box ranges from [top_y, bottom_y]
sorted_boxes = sorted(box_ranges, key=lambda r: r[0])
strip_start = top_y
for by_start, by_end in sorted_boxes:
if by_start > strip_start:
content_strips.append((strip_start, by_start))
strip_start = max(strip_start, by_end)
if strip_start < bottom_y:
content_strips.append((strip_start, bottom_y))
# Filter to strips with meaningful height
content_strips = [(ys, ye) for ys, ye in content_strips if ye - ys >= 20]
if content_strips:
# Stack content strips vertically
inv_strips = [inv[ys:ye, :] for ys, ye in content_strips]
combined_inv = np.vstack(inv_strips)
# Filter word_dicts to only include words from content strips
combined_words = []
cum_y = 0
strip_offsets = [] # (combined_y_start, strip_height, abs_y_start)
for ys, ye in content_strips:
h = ye - ys
strip_offsets.append((cum_y, h, ys))
for w in word_dicts:
w_abs_y = w['top'] + top_y # word y is relative to content top
w_center = w_abs_y + w['height'] / 2
if ys <= w_center < ye:
# Remap to combined coordinates
w_copy = dict(w)
w_copy['top'] = cum_y + (w_abs_y - ys)
combined_words.append(w_copy)
cum_y += h
# Run row detection on combined image
combined_h = combined_inv.shape[0]
rows = detect_row_geometry(
combined_inv, combined_words, left_x, right_x, 0, combined_h,
)
# Remap y-coordinates back to absolute page coords
def _combined_y_to_abs(cy: int) -> int:
for c_start, s_h, abs_start in strip_offsets:
if cy < c_start + s_h:
return abs_start + (cy - c_start)
last_c, last_h, last_abs = strip_offsets[-1]
return last_abs + last_h
for r in rows:
abs_y = _combined_y_to_abs(r.y)
abs_y_end = _combined_y_to_abs(r.y + r.height)
r.y = abs_y
r.height = abs_y_end - abs_y
else:
rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
else:
# No boxes — standard row detection
rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
duration = time.time() - t0 duration = time.time() - t0
# Assign zone_index based on which content zone each row falls in
# Build content zone list with indices
content_zones = [(i, z) for i, z in enumerate(zones) if z.get("zone_type") == "content"] if zones else []
# Build serializable result (exclude words to keep payload small) # Build serializable result (exclude words to keep payload small)
rows_data = [] rows_data = []
for r in rows: for r in rows:
# Determine zone_index
zone_idx = 0
row_center_y = r.y + r.height / 2
for zi, zone in content_zones:
zy = zone["y"]
zh = zone["height"]
if zy <= row_center_y < zy + zh:
zone_idx = zi
break
rd = { rd = {
"index": r.index, "index": r.index,
"x": r.x, "x": r.x,
@@ -1357,7 +1609,7 @@ async def detect_rows(session_id: str):
"word_count": r.word_count, "word_count": r.word_count,
"row_type": r.row_type, "row_type": r.row_type,
"gap_before": r.gap_before, "gap_before": r.gap_before,
"zone_index": 0, "zone_index": zone_idx,
} }
rows_data.append(rd) rows_data.append(rd)
@@ -1564,6 +1816,25 @@ async def detect_words(
] ]
row.word_count = len(row.words) row.word_count = len(row.words)
# Exclude rows that fall within box zones
zones = column_result.get("zones", [])
box_ranges = []
for zone in zones:
if zone.get("zone_type") == "box" and zone.get("box"):
box = zone["box"]
box_ranges.append((box["y"], box["y"] + box["height"]))
if box_ranges:
def _row_in_box(r):
center_y = r.y + r.height / 2
return any(by_s <= center_y < by_e for by_s, by_e in box_ranges)
before_count = len(row_geoms)
row_geoms = [r for r in row_geoms if not _row_in_box(r)]
excluded = before_count - len(row_geoms)
if excluded:
logger.info(f"detect_words: excluded {excluded} rows inside box zones")
if stream: if stream:
# Cell-First OCR v2: use batch-then-stream approach instead of # Cell-First OCR v2: use batch-then-stream approach instead of
# per-cell streaming. The parallel ThreadPoolExecutor in # per-cell streaming. The parallel ThreadPoolExecutor in
@@ -2205,12 +2476,24 @@ async def save_reconstruction(session_id: str, request: Request):
# Build update map: cell_id -> new text # Build update map: cell_id -> new text
update_map = {c["cell_id"]: c["text"] for c in cell_updates} update_map = {c["cell_id"]: c["text"] for c in cell_updates}
# Update cells # Separate sub-session updates (cell_ids prefixed with "box{N}_")
sub_updates: Dict[int, Dict[str, str]] = {} # box_index -> {original_cell_id: text}
main_updates: Dict[str, str] = {}
for cell_id, text in update_map.items():
m = re.match(r'^box(\d+)_(.+)$', cell_id)
if m:
bi = int(m.group(1))
original_id = m.group(2)
sub_updates.setdefault(bi, {})[original_id] = text
else:
main_updates[cell_id] = text
# Update main session cells
cells = word_result.get("cells", []) cells = word_result.get("cells", [])
updated_count = 0 updated_count = 0
for cell in cells: for cell in cells:
if cell["cell_id"] in update_map: if cell["cell_id"] in main_updates:
cell["text"] = update_map[cell["cell_id"]] cell["text"] = main_updates[cell["cell_id"]]
cell["status"] = "edited" cell["status"] = "edited"
updated_count += 1 updated_count += 1
@@ -2227,7 +2510,7 @@ async def save_reconstruction(session_id: str, request: Request):
cell_id = f"R{row_idx:02d}_C{col_idx}" cell_id = f"R{row_idx:02d}_C{col_idx}"
# Also try without zero-padding # Also try without zero-padding
cell_id_alt = f"R{row_idx}_C{col_idx}" cell_id_alt = f"R{row_idx}_C{col_idx}"
new_text = update_map.get(cell_id) or update_map.get(cell_id_alt) new_text = main_updates.get(cell_id) or main_updates.get(cell_id_alt)
if new_text is not None: if new_text is not None:
entry[field_name] = new_text entry[field_name] = new_text
@@ -2240,17 +2523,51 @@ async def save_reconstruction(session_id: str, request: Request):
if session_id in _cache: if session_id in _cache:
_cache[session_id]["word_result"] = word_result _cache[session_id]["word_result"] = word_result
logger.info(f"Reconstruction saved for session {session_id}: {updated_count} cells updated") # Route sub-session updates
sub_updated = 0
if sub_updates:
subs = await get_sub_sessions(session_id)
sub_by_index = {s.get("box_index"): s["id"] for s in subs}
for bi, updates in sub_updates.items():
sub_id = sub_by_index.get(bi)
if not sub_id:
continue
sub_session = await get_session_db(sub_id)
if not sub_session:
continue
sub_word = sub_session.get("word_result")
if not sub_word:
continue
sub_cells = sub_word.get("cells", [])
for cell in sub_cells:
if cell["cell_id"] in updates:
cell["text"] = updates[cell["cell_id"]]
cell["status"] = "edited"
sub_updated += 1
sub_word["cells"] = sub_cells
await update_session_db(sub_id, word_result=sub_word)
if sub_id in _cache:
_cache[sub_id]["word_result"] = sub_word
total_updated = updated_count + sub_updated
logger.info(f"Reconstruction saved for session {session_id}: "
f"{updated_count} main + {sub_updated} sub-session cells updated")
return { return {
"session_id": session_id, "session_id": session_id,
"updated": updated_count, "updated": total_updated,
"main_updated": updated_count,
"sub_updated": sub_updated,
} }
@router.get("/sessions/{session_id}/reconstruction/fabric-json") @router.get("/sessions/{session_id}/reconstruction/fabric-json")
async def get_fabric_json(session_id: str): async def get_fabric_json(session_id: str):
"""Return cell grid as Fabric.js-compatible JSON for the canvas editor.""" """Return cell grid as Fabric.js-compatible JSON for the canvas editor.
If the session has sub-sessions (box regions), their cells are merged
into the result at the correct Y positions.
"""
session = await get_session_db(session_id) session = await get_session_db(session_id)
if not session: if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found") raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
@@ -2259,16 +2576,108 @@ async def get_fabric_json(session_id: str):
if not word_result: if not word_result:
raise HTTPException(status_code=400, detail="No word result found") raise HTTPException(status_code=400, detail="No word result found")
cells = word_result.get("cells", []) cells = list(word_result.get("cells", []))
img_w = word_result.get("image_width", 800) img_w = word_result.get("image_width", 800)
img_h = word_result.get("image_height", 600) img_h = word_result.get("image_height", 600)
# Merge sub-session cells at box positions
subs = await get_sub_sessions(session_id)
if subs:
column_result = session.get("column_result") or {}
zones = column_result.get("zones", [])
box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
for sub in subs:
sub_session = await get_session_db(sub["id"])
if not sub_session:
continue
sub_word = sub_session.get("word_result")
if not sub_word or not sub_word.get("cells"):
continue
bi = sub.get("box_index", 0)
if bi < len(box_zones):
box = box_zones[bi]["box"]
box_y, box_x = box["y"], box["x"]
else:
box_y, box_x = 0, 0
# Offset sub-session cells to absolute page coordinates
for cell in sub_word["cells"]:
cell_copy = dict(cell)
# Prefix cell_id with box index
cell_copy["cell_id"] = f"box{bi}_{cell_copy.get('cell_id', '')}"
cell_copy["source"] = f"box_{bi}"
# Offset bbox_px
bbox = cell_copy.get("bbox_px", {})
if bbox:
bbox = dict(bbox)
bbox["x"] = bbox.get("x", 0) + box_x
bbox["y"] = bbox.get("y", 0) + box_y
cell_copy["bbox_px"] = bbox
cells.append(cell_copy)
from services.layout_reconstruction_service import cells_to_fabric_json from services.layout_reconstruction_service import cells_to_fabric_json
fabric_json = cells_to_fabric_json(cells, img_w, img_h) fabric_json = cells_to_fabric_json(cells, img_w, img_h)
return fabric_json return fabric_json
@router.get("/sessions/{session_id}/vocab-entries/merged")
async def get_merged_vocab_entries(session_id: str):
    """Return vocab entries from main session + all sub-sessions, sorted by Y position.

    Each entry is tagged with a ``source`` ("main" or "box_{i}") and the
    combined list is sorted by approximate vertical position on the page.

    Raises:
        HTTPException 404: unknown session.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result") or {}
    # Accept both payload keys ("vocab_entries" preferred, "entries" legacy).
    entries = list(word_result.get("vocab_entries") or word_result.get("entries") or [])

    # Tag main entries. NOTE(review): setdefault mutates the dicts that are
    # shared with word_result — presumably intentional; verify nothing else
    # depends on the entries being untagged.
    for e in entries:
        e.setdefault("source", "main")

    # Merge sub-session entries, offset-tagged with their box's page Y.
    subs = await get_sub_sessions(session_id)
    if subs:
        column_result = session.get("column_result") or {}
        zones = column_result.get("zones", [])
        box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
        for sub in subs:
            sub_session = await get_session_db(sub["id"])
            if not sub_session:
                continue
            sub_word = sub_session.get("word_result") or {}
            sub_entries = sub_word.get("vocab_entries") or sub_word.get("entries") or []
            bi = sub.get("box_index", 0)
            box_y = 0
            if bi < len(box_zones):
                box_y = box_zones[bi]["box"]["y"]
            for e in sub_entries:
                e_copy = dict(e)
                e_copy["source"] = f"box_{bi}"
                e_copy["source_y"] = box_y  # for sorting
                entries.append(e_copy)

    # Sort by approximate Y position.
    # NOTE(review): main entries sort on row_index*100 while box entries sort
    # on pixel y*100 + row_index — the two scales differ, so cross-source
    # ordering is only approximate; confirm the intended interleaving.
    def _sort_key(e):
        if e.get("source", "main") == "main":
            return e.get("row_index", 0) * 100  # main entries by row index
        return e.get("source_y", 0) * 100 + e.get("row_index", 0)

    entries.sort(key=_sort_key)

    return {
        "session_id": session_id,
        "entries": entries,
        "total": len(entries),
        "sources": list(set(e.get("source", "main") for e in entries)),
    }
@router.get("/sessions/{session_id}/reconstruction/export/pdf") @router.get("/sessions/{session_id}/reconstruction/export/pdf")
async def export_reconstruction_pdf(session_id: str): async def export_reconstruction_pdf(session_id: str):
"""Export the reconstructed cell grid as a PDF table.""" """Export the reconstructed cell grid as a PDF table."""
@@ -2804,6 +3213,9 @@ async def _get_rows_overlay(session_id: str) -> Response:
ex = min(sx + dash_len, img_w_px) ex = min(sx + dash_len, img_w_px)
cv2.line(img, (sx, line_y), (ex, line_y), zone_color, 2) cv2.line(img, (sx, line_y), (ex, line_y), zone_color, 2)
# Red semi-transparent overlay for box zones
_draw_box_exclusion_overlay(img, zones)
success, result_png = cv2.imencode(".png", img) success, result_png = cv2.imencode(".png", img)
if not success: if not success:
raise HTTPException(status_code=500, detail="Failed to encode overlay image") raise HTTPException(status_code=500, detail="Failed to encode overlay image")
@@ -2943,6 +3355,11 @@ async def _get_words_overlay(session_id: str) -> Response:
# Blend overlay at 10% opacity # Blend overlay at 10% opacity
cv2.addWeighted(overlay, 0.1, img, 0.9, 0, img) cv2.addWeighted(overlay, 0.1, img, 0.9, 0, img)
# Red semi-transparent overlay for box zones
column_result = session.get("column_result") or {}
zones = column_result.get("zones", [])
_draw_box_exclusion_overlay(img, zones)
success, result_png = cv2.imencode(".png", img) success, result_png = cv2.imencode(".png", img)
if not success: if not success:
raise HTTPException(status_code=500, detail="Failed to encode overlay image") raise HTTPException(status_code=500, detail="Failed to encode overlay image")

View File

@@ -72,7 +72,9 @@ async def init_ocr_pipeline_tables():
ADD COLUMN IF NOT EXISTS oriented_png BYTEA, ADD COLUMN IF NOT EXISTS oriented_png BYTEA,
ADD COLUMN IF NOT EXISTS cropped_png BYTEA, ADD COLUMN IF NOT EXISTS cropped_png BYTEA,
ADD COLUMN IF NOT EXISTS orientation_result JSONB, ADD COLUMN IF NOT EXISTS orientation_result JSONB,
ADD COLUMN IF NOT EXISTS crop_result JSONB ADD COLUMN IF NOT EXISTS crop_result JSONB,
ADD COLUMN IF NOT EXISTS parent_session_id UUID REFERENCES ocr_pipeline_sessions(id) ON DELETE CASCADE,
ADD COLUMN IF NOT EXISTS box_index INT
""") """)
@@ -85,22 +87,33 @@ async def create_session_db(
name: str, name: str,
filename: str, filename: str,
original_png: bytes, original_png: bytes,
parent_session_id: Optional[str] = None,
box_index: Optional[int] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Create a new OCR pipeline session.""" """Create a new OCR pipeline session.
Args:
parent_session_id: If set, this is a sub-session for a box region.
box_index: 0-based index of the box this sub-session represents.
"""
pool = await get_pool() pool = await get_pool()
parent_uuid = uuid.UUID(parent_session_id) if parent_session_id else None
async with pool.acquire() as conn: async with pool.acquire() as conn:
row = await conn.fetchrow(""" row = await conn.fetchrow("""
INSERT INTO ocr_pipeline_sessions ( INSERT INTO ocr_pipeline_sessions (
id, name, filename, original_png, status, current_step id, name, filename, original_png, status, current_step,
) VALUES ($1, $2, $3, $4, 'active', 1) parent_session_id, box_index
) VALUES ($1, $2, $3, $4, 'active', 1, $5, $6)
RETURNING id, name, filename, status, current_step, RETURNING id, name, filename, status, current_step,
orientation_result, crop_result, orientation_result, crop_result,
deskew_result, dewarp_result, column_result, row_result, deskew_result, dewarp_result, column_result, row_result,
word_result, ground_truth, auto_shear_degrees, word_result, ground_truth, auto_shear_degrees,
doc_type, doc_type_result, doc_type, doc_type_result,
document_category, pipeline_log, document_category, pipeline_log,
parent_session_id, box_index,
created_at, updated_at created_at, updated_at
""", uuid.UUID(session_id), name, filename, original_png) """, uuid.UUID(session_id), name, filename, original_png,
parent_uuid, box_index)
return _row_to_dict(row) return _row_to_dict(row)
@@ -116,6 +129,7 @@ async def get_session_db(session_id: str) -> Optional[Dict[str, Any]]:
word_result, ground_truth, auto_shear_degrees, word_result, ground_truth, auto_shear_degrees,
doc_type, doc_type_result, doc_type, doc_type_result,
document_category, pipeline_log, document_category, pipeline_log,
parent_session_id, box_index,
created_at, updated_at created_at, updated_at
FROM ocr_pipeline_sessions WHERE id = $1 FROM ocr_pipeline_sessions WHERE id = $1
""", uuid.UUID(session_id)) """, uuid.UUID(session_id))
@@ -166,6 +180,7 @@ async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any
'word_result', 'ground_truth', 'auto_shear_degrees', 'word_result', 'ground_truth', 'auto_shear_degrees',
'doc_type', 'doc_type_result', 'doc_type', 'doc_type_result',
'document_category', 'pipeline_log', 'document_category', 'pipeline_log',
'parent_session_id', 'box_index',
} }
jsonb_fields = {'orientation_result', 'crop_result', 'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'handwriting_removal_meta', 'doc_type_result', 'pipeline_log'} jsonb_fields = {'orientation_result', 'crop_result', 'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'handwriting_removal_meta', 'doc_type_result', 'pipeline_log'}
@@ -197,6 +212,7 @@ async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any
word_result, ground_truth, auto_shear_degrees, word_result, ground_truth, auto_shear_degrees,
doc_type, doc_type_result, doc_type, doc_type_result,
document_category, pipeline_log, document_category, pipeline_log,
parent_session_id, box_index,
created_at, updated_at created_at, updated_at
""", *values) """, *values)
@@ -205,18 +221,45 @@ async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any
return None return None
async def list_sessions_db(
    limit: int = 50,
    include_sub_sessions: bool = False,
) -> List[Dict[str, Any]]:
    """List sessions (metadata only, no images).

    By default, sub-sessions (those with parent_session_id) are excluded.
    Pass include_sub_sessions=True to include them.

    Args:
        limit: Maximum number of rows returned (newest first).
        include_sub_sessions: Also return box-region sub-sessions.
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        # WHERE clause is built from a fixed literal (no user input) — safe
        # to interpolate into the query string.
        where = "" if include_sub_sessions else "WHERE parent_session_id IS NULL"
        rows = await conn.fetch(f"""
            SELECT id, name, filename, status, current_step,
                   document_category, doc_type,
                   parent_session_id, box_index,
                   created_at, updated_at
            FROM ocr_pipeline_sessions
            {where}
            ORDER BY created_at DESC
            LIMIT $1
        """, limit)
        return [_row_to_dict(row) for row in rows]
async def get_sub_sessions(parent_session_id: str) -> List[Dict[str, Any]]:
    """Get all sub-sessions for a parent session, ordered by box_index.

    Args:
        parent_session_id: UUID string of the parent session.

    Returns:
        Metadata dicts (no image bytes) for each sub-session; empty list
        when the session has no box-region sub-sessions.
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        rows = await conn.fetch("""
            SELECT id, name, filename, status, current_step,
                   document_category, doc_type,
                   parent_session_id, box_index,
                   created_at, updated_at
            FROM ocr_pipeline_sessions
            WHERE parent_session_id = $1
            ORDER BY box_index ASC
        """, uuid.UUID(parent_session_id))
    return [_row_to_dict(row) for row in rows]
@@ -255,7 +298,7 @@ def _row_to_dict(row: asyncpg.Record) -> Dict[str, Any]:
result = dict(row) result = dict(row)
# UUID → string # UUID → string
for key in ['id', 'session_id']: for key in ['id', 'session_id', 'parent_session_id']:
if key in result and result[key] is not None: if key in result and result[key] is not None:
result[key] = str(result[key]) result[key] = str(result[key])