feat: save automatic grid snapshot before manual edits for GT comparison
- build-grid now saves the automatic OCR result as ground_truth.auto_grid_snapshot - mark-ground-truth includes a correction_diff comparing auto vs corrected - New endpoint GET /correction-diff returns detailed diff with per-col_type accuracy breakdown (english, german, ipa, etc.) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -31,6 +31,7 @@ from ocr_pipeline_session_store import (
|
||||
get_session_image,
|
||||
update_session_db,
|
||||
)
|
||||
from ocr_pipeline_regression import _build_reference_snapshot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -2814,8 +2815,22 @@ async def build_grid(session_id: str):
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
# Save automatic grid snapshot for later comparison with manual corrections
|
||||
wr = session.get("word_result") or {}
|
||||
engine = wr.get("ocr_engine", "")
|
||||
if engine in ("kombi", "rapid_kombi"):
|
||||
auto_pipeline = "kombi"
|
||||
elif engine == "paddle_direct":
|
||||
auto_pipeline = "paddle-direct"
|
||||
else:
|
||||
auto_pipeline = "pipeline"
|
||||
auto_snapshot = _build_reference_snapshot(result, pipeline=auto_pipeline)
|
||||
|
||||
gt = session.get("ground_truth") or {}
|
||||
gt["auto_grid_snapshot"] = auto_snapshot
|
||||
|
||||
# Persist to DB and advance current_step to 11 (reconstruction complete)
|
||||
await update_session_db(session_id, grid_editor_result=result, current_step=11)
|
||||
await update_session_db(session_id, grid_editor_result=result, ground_truth=gt, current_step=11)
|
||||
|
||||
logger.info(
|
||||
"build-grid session %s: %d zones, %d cols, %d rows, %d cells, "
|
||||
|
||||
@@ -258,9 +258,17 @@ async def mark_ground_truth(
|
||||
gt["build_grid_reference"] = reference
|
||||
await update_session_db(session_id, ground_truth=gt, current_step=11)
|
||||
|
||||
# Compare with auto-snapshot if available (shows what the user corrected)
|
||||
auto_snapshot = gt.get("auto_grid_snapshot")
|
||||
correction_diff = None
|
||||
if auto_snapshot:
|
||||
correction_diff = compare_grids(auto_snapshot, reference)
|
||||
|
||||
logger.info(
|
||||
"Ground truth marked for session %s: %d cells",
|
||||
session_id, len(reference["cells"]),
|
||||
"Ground truth marked for session %s: %d cells (corrections: %s)",
|
||||
session_id,
|
||||
len(reference["cells"]),
|
||||
correction_diff["summary"] if correction_diff else "no auto-snapshot",
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -268,6 +276,7 @@ async def mark_ground_truth(
|
||||
"session_id": session_id,
|
||||
"cells_saved": len(reference["cells"]),
|
||||
"summary": reference["summary"],
|
||||
"correction_diff": correction_diff,
|
||||
}
|
||||
|
||||
|
||||
@@ -289,6 +298,68 @@ async def unmark_ground_truth(session_id: str):
|
||||
return {"status": "ok", "session_id": session_id}
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/correction-diff")
|
||||
async def get_correction_diff(session_id: str):
|
||||
"""Compare automatic OCR grid with manually corrected ground truth.
|
||||
|
||||
Returns a diff showing exactly which cells the user corrected,
|
||||
broken down by col_type (english, german, ipa, etc.).
|
||||
"""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
gt = session.get("ground_truth") or {}
|
||||
auto_snapshot = gt.get("auto_grid_snapshot")
|
||||
reference = gt.get("build_grid_reference")
|
||||
|
||||
if not auto_snapshot:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No auto_grid_snapshot found. Re-run build-grid to create one.",
|
||||
)
|
||||
if not reference:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No ground truth reference found. Mark as ground truth first.",
|
||||
)
|
||||
|
||||
diff = compare_grids(auto_snapshot, reference)
|
||||
|
||||
# Enrich with per-col_type breakdown
|
||||
col_type_stats: Dict[str, Dict[str, int]] = {}
|
||||
for cell_diff in diff.get("cell_diffs", []):
|
||||
if cell_diff["type"] != "text_change":
|
||||
continue
|
||||
# Find col_type from reference cells
|
||||
cell_id = cell_diff["cell_id"]
|
||||
ref_cell = next(
|
||||
(c for c in reference.get("cells", []) if c["cell_id"] == cell_id),
|
||||
None,
|
||||
)
|
||||
ct = ref_cell.get("col_type", "unknown") if ref_cell else "unknown"
|
||||
if ct not in col_type_stats:
|
||||
col_type_stats[ct] = {"total": 0, "corrected": 0}
|
||||
col_type_stats[ct]["corrected"] += 1
|
||||
|
||||
# Count total cells per col_type from reference
|
||||
for cell in reference.get("cells", []):
|
||||
ct = cell.get("col_type", "unknown")
|
||||
if ct not in col_type_stats:
|
||||
col_type_stats[ct] = {"total": 0, "corrected": 0}
|
||||
col_type_stats[ct]["total"] += 1
|
||||
|
||||
# Calculate accuracy per col_type
|
||||
for ct, stats in col_type_stats.items():
|
||||
total = stats["total"]
|
||||
corrected = stats["corrected"]
|
||||
stats["accuracy_pct"] = round((total - corrected) / total * 100, 1) if total > 0 else 100.0
|
||||
|
||||
diff["col_type_breakdown"] = col_type_stats
|
||||
|
||||
return diff
|
||||
|
||||
|
||||
@router.get("/ground-truth-sessions")
|
||||
async def list_ground_truth_sessions():
|
||||
"""List all sessions that have a ground-truth reference."""
|
||||
|
||||
Reference in New Issue
Block a user