feat: save automatic grid snapshot before manual edits for GT comparison

- build-grid now saves the automatic OCR result as ground_truth.auto_grid_snapshot
- mark-ground-truth includes a correction_diff comparing auto vs corrected
- New endpoint GET /correction-diff returns detailed diff with per-col_type
  accuracy breakdown (english, german, ipa, etc.)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-03-24 13:16:44 +01:00
parent 72ce4420cb
commit 410d36f3de
2 changed files with 89 additions and 3 deletions

View File

@@ -31,6 +31,7 @@ from ocr_pipeline_session_store import (
get_session_image,
update_session_db,
)
from ocr_pipeline_regression import _build_reference_snapshot
logger = logging.getLogger(__name__)
@@ -2814,8 +2815,22 @@ async def build_grid(session_id: str):
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
# Save automatic grid snapshot for later comparison with manual corrections
wr = session.get("word_result") or {}
engine = wr.get("ocr_engine", "")
if engine in ("kombi", "rapid_kombi"):
auto_pipeline = "kombi"
elif engine == "paddle_direct":
auto_pipeline = "paddle-direct"
else:
auto_pipeline = "pipeline"
auto_snapshot = _build_reference_snapshot(result, pipeline=auto_pipeline)
gt = session.get("ground_truth") or {}
gt["auto_grid_snapshot"] = auto_snapshot
# Persist to DB and advance current_step to 11 (reconstruction complete)
await update_session_db(session_id, grid_editor_result=result, current_step=11)
await update_session_db(session_id, grid_editor_result=result, ground_truth=gt, current_step=11)
logger.info(
"build-grid session %s: %d zones, %d cols, %d rows, %d cells, "

View File

@@ -258,9 +258,17 @@ async def mark_ground_truth(
gt["build_grid_reference"] = reference
await update_session_db(session_id, ground_truth=gt, current_step=11)
# Compare with auto-snapshot if available (shows what the user corrected)
auto_snapshot = gt.get("auto_grid_snapshot")
correction_diff = None
if auto_snapshot:
correction_diff = compare_grids(auto_snapshot, reference)
logger.info(
"Ground truth marked for session %s: %d cells",
session_id, len(reference["cells"]),
"Ground truth marked for session %s: %d cells (corrections: %s)",
session_id,
len(reference["cells"]),
correction_diff["summary"] if correction_diff else "no auto-snapshot",
)
return {
@@ -268,6 +276,7 @@ async def mark_ground_truth(
"session_id": session_id,
"cells_saved": len(reference["cells"]),
"summary": reference["summary"],
"correction_diff": correction_diff,
}
@@ -289,6 +298,68 @@ async def unmark_ground_truth(session_id: str):
return {"status": "ok", "session_id": session_id}
@router.get("/sessions/{session_id}/correction-diff")
async def get_correction_diff(session_id: str):
"""Compare automatic OCR grid with manually corrected ground truth.
Returns a diff showing exactly which cells the user corrected,
broken down by col_type (english, german, ipa, etc.).
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
gt = session.get("ground_truth") or {}
auto_snapshot = gt.get("auto_grid_snapshot")
reference = gt.get("build_grid_reference")
if not auto_snapshot:
raise HTTPException(
status_code=404,
detail="No auto_grid_snapshot found. Re-run build-grid to create one.",
)
if not reference:
raise HTTPException(
status_code=404,
detail="No ground truth reference found. Mark as ground truth first.",
)
diff = compare_grids(auto_snapshot, reference)
# Enrich with per-col_type breakdown
col_type_stats: Dict[str, Dict[str, int]] = {}
for cell_diff in diff.get("cell_diffs", []):
if cell_diff["type"] != "text_change":
continue
# Find col_type from reference cells
cell_id = cell_diff["cell_id"]
ref_cell = next(
(c for c in reference.get("cells", []) if c["cell_id"] == cell_id),
None,
)
ct = ref_cell.get("col_type", "unknown") if ref_cell else "unknown"
if ct not in col_type_stats:
col_type_stats[ct] = {"total": 0, "corrected": 0}
col_type_stats[ct]["corrected"] += 1
# Count total cells per col_type from reference
for cell in reference.get("cells", []):
ct = cell.get("col_type", "unknown")
if ct not in col_type_stats:
col_type_stats[ct] = {"total": 0, "corrected": 0}
col_type_stats[ct]["total"] += 1
# Calculate accuracy per col_type
for ct, stats in col_type_stats.items():
total = stats["total"]
corrected = stats["corrected"]
stats["accuracy_pct"] = round((total - corrected) / total * 100, 1) if total > 0 else 100.0
diff["col_type_breakdown"] = col_type_stats
return diff
@router.get("/ground-truth-sessions")
async def list_ground_truth_sessions():
"""List all sessions that have a ground-truth reference."""