feat: save automatic grid snapshot before manual edits for GT comparison
- build-grid now saves the automatic OCR result as ground_truth.auto_grid_snapshot - mark-ground-truth includes a correction_diff comparing auto vs corrected - New endpoint GET /correction-diff returns detailed diff with per-col_type accuracy breakdown (english, german, ipa, etc.) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -31,6 +31,7 @@ from ocr_pipeline_session_store import (
|
|||||||
get_session_image,
|
get_session_image,
|
||||||
update_session_db,
|
update_session_db,
|
||||||
)
|
)
|
||||||
|
from ocr_pipeline_regression import _build_reference_snapshot
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -2814,8 +2815,22 @@ async def build_grid(session_id: str):
|
|||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
|
# Save automatic grid snapshot for later comparison with manual corrections
|
||||||
|
wr = session.get("word_result") or {}
|
||||||
|
engine = wr.get("ocr_engine", "")
|
||||||
|
if engine in ("kombi", "rapid_kombi"):
|
||||||
|
auto_pipeline = "kombi"
|
||||||
|
elif engine == "paddle_direct":
|
||||||
|
auto_pipeline = "paddle-direct"
|
||||||
|
else:
|
||||||
|
auto_pipeline = "pipeline"
|
||||||
|
auto_snapshot = _build_reference_snapshot(result, pipeline=auto_pipeline)
|
||||||
|
|
||||||
|
gt = session.get("ground_truth") or {}
|
||||||
|
gt["auto_grid_snapshot"] = auto_snapshot
|
||||||
|
|
||||||
# Persist to DB and advance current_step to 11 (reconstruction complete)
|
# Persist to DB and advance current_step to 11 (reconstruction complete)
|
||||||
await update_session_db(session_id, grid_editor_result=result, current_step=11)
|
await update_session_db(session_id, grid_editor_result=result, ground_truth=gt, current_step=11)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"build-grid session %s: %d zones, %d cols, %d rows, %d cells, "
|
"build-grid session %s: %d zones, %d cols, %d rows, %d cells, "
|
||||||
|
|||||||
@@ -258,9 +258,17 @@ async def mark_ground_truth(
|
|||||||
gt["build_grid_reference"] = reference
|
gt["build_grid_reference"] = reference
|
||||||
await update_session_db(session_id, ground_truth=gt, current_step=11)
|
await update_session_db(session_id, ground_truth=gt, current_step=11)
|
||||||
|
|
||||||
|
# Compare with auto-snapshot if available (shows what the user corrected)
|
||||||
|
auto_snapshot = gt.get("auto_grid_snapshot")
|
||||||
|
correction_diff = None
|
||||||
|
if auto_snapshot:
|
||||||
|
correction_diff = compare_grids(auto_snapshot, reference)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Ground truth marked for session %s: %d cells",
|
"Ground truth marked for session %s: %d cells (corrections: %s)",
|
||||||
session_id, len(reference["cells"]),
|
session_id,
|
||||||
|
len(reference["cells"]),
|
||||||
|
correction_diff["summary"] if correction_diff else "no auto-snapshot",
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -268,6 +276,7 @@ async def mark_ground_truth(
|
|||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
"cells_saved": len(reference["cells"]),
|
"cells_saved": len(reference["cells"]),
|
||||||
"summary": reference["summary"],
|
"summary": reference["summary"],
|
||||||
|
"correction_diff": correction_diff,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -289,6 +298,68 @@ async def unmark_ground_truth(session_id: str):
|
|||||||
return {"status": "ok", "session_id": session_id}
|
return {"status": "ok", "session_id": session_id}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}/correction-diff")
|
||||||
|
async def get_correction_diff(session_id: str):
|
||||||
|
"""Compare automatic OCR grid with manually corrected ground truth.
|
||||||
|
|
||||||
|
Returns a diff showing exactly which cells the user corrected,
|
||||||
|
broken down by col_type (english, german, ipa, etc.).
|
||||||
|
"""
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
gt = session.get("ground_truth") or {}
|
||||||
|
auto_snapshot = gt.get("auto_grid_snapshot")
|
||||||
|
reference = gt.get("build_grid_reference")
|
||||||
|
|
||||||
|
if not auto_snapshot:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail="No auto_grid_snapshot found. Re-run build-grid to create one.",
|
||||||
|
)
|
||||||
|
if not reference:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail="No ground truth reference found. Mark as ground truth first.",
|
||||||
|
)
|
||||||
|
|
||||||
|
diff = compare_grids(auto_snapshot, reference)
|
||||||
|
|
||||||
|
# Enrich with per-col_type breakdown
|
||||||
|
col_type_stats: Dict[str, Dict[str, int]] = {}
|
||||||
|
for cell_diff in diff.get("cell_diffs", []):
|
||||||
|
if cell_diff["type"] != "text_change":
|
||||||
|
continue
|
||||||
|
# Find col_type from reference cells
|
||||||
|
cell_id = cell_diff["cell_id"]
|
||||||
|
ref_cell = next(
|
||||||
|
(c for c in reference.get("cells", []) if c["cell_id"] == cell_id),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
ct = ref_cell.get("col_type", "unknown") if ref_cell else "unknown"
|
||||||
|
if ct not in col_type_stats:
|
||||||
|
col_type_stats[ct] = {"total": 0, "corrected": 0}
|
||||||
|
col_type_stats[ct]["corrected"] += 1
|
||||||
|
|
||||||
|
# Count total cells per col_type from reference
|
||||||
|
for cell in reference.get("cells", []):
|
||||||
|
ct = cell.get("col_type", "unknown")
|
||||||
|
if ct not in col_type_stats:
|
||||||
|
col_type_stats[ct] = {"total": 0, "corrected": 0}
|
||||||
|
col_type_stats[ct]["total"] += 1
|
||||||
|
|
||||||
|
# Calculate accuracy per col_type
|
||||||
|
for ct, stats in col_type_stats.items():
|
||||||
|
total = stats["total"]
|
||||||
|
corrected = stats["corrected"]
|
||||||
|
stats["accuracy_pct"] = round((total - corrected) / total * 100, 1) if total > 0 else 100.0
|
||||||
|
|
||||||
|
diff["col_type_breakdown"] = col_type_stats
|
||||||
|
|
||||||
|
return diff
|
||||||
|
|
||||||
|
|
||||||
@router.get("/ground-truth-sessions")
|
@router.get("/ground-truth-sessions")
|
||||||
async def list_ground_truth_sessions():
|
async def list_ground_truth_sessions():
|
||||||
"""List all sessions that have a ground-truth reference."""
|
"""List all sessions that have a ground-truth reference."""
|
||||||
|
|||||||
Reference in New Issue
Block a user