- build-grid now saves the automatic OCR result as ground_truth.auto_grid_snapshot - mark-ground-truth includes a correction_diff comparing auto vs corrected - New endpoint GET /correction-diff returns detailed diff with per-col_type accuracy breakdown (english, german, ipa, etc.) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
608 lines
20 KiB
Python
608 lines
20 KiB
Python
"""
|
|
OCR Pipeline Regression Tests — Ground Truth comparison system.
|
|
|
|
Allows marking sessions as "ground truth" and re-running build_grid()
|
|
to detect regressions after code changes.
|
|
|
|
Lizenz: Apache 2.0
|
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import time
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from fastapi import APIRouter, HTTPException, Query
|
|
|
|
from grid_editor_api import _build_grid_core
|
|
from ocr_pipeline_session_store import (
|
|
get_pool,
|
|
get_session_db,
|
|
list_ground_truth_sessions_db,
|
|
update_session_db,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["regression"])
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# DB persistence for regression runs
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def _init_regression_table():
    """Ensure the regression_runs table exists (idempotent).

    Reads migrations/008_regression_runs.sql (relative to this file) and
    executes it on a pooled connection. The migration itself is expected
    to be idempotent (CREATE TABLE IF NOT EXISTS style). Does nothing if
    the migration file is absent.
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        migration_path = os.path.join(
            os.path.dirname(__file__),
            "migrations/008_regression_runs.sql",
        )
        if os.path.exists(migration_path):
            # Explicit encoding: migration files may contain non-ASCII
            # comments; relying on the locale default is fragile.
            with open(migration_path, "r", encoding="utf-8") as f:
                sql = f.read()
            await conn.execute(sql)
|
|
|
|
|
|
async def _persist_regression_run(
    status: str,
    summary: dict,
    results: list,
    duration_ms: int,
    triggered_by: str = "manual",
) -> str:
    """Save a regression run to the database. Returns the run ID.

    Best-effort persistence: any failure is logged as a warning and
    swallowed so that a broken persistence layer never fails the
    regression run whose results are being recorded.

    Args:
        status: Overall run status ("pass" or "fail").
        summary: Aggregate counters with keys total/passed/failed/errors.
        results: Per-session result entries; serialized to JSONB.
        duration_ms: Wall-clock duration of the run in milliseconds.
        triggered_by: Origin of the run ("manual", "script", "ci").

    Returns:
        The generated run UUID as a string, or "" if persistence failed.
    """
    try:
        # Table creation is idempotent, so this is safe on every call.
        await _init_regression_table()
        pool = await get_pool()
        run_id = str(uuid.uuid4())
        async with pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO regression_runs
                (id, status, total, passed, failed, errors, duration_ms, results, triggered_by)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8::jsonb, $9)
                """,
                run_id,
                status,
                summary.get("total", 0),
                summary.get("passed", 0),
                summary.get("failed", 0),
                summary.get("errors", 0),
                duration_ms,
                # asyncpg expects a JSON string for the ::jsonb cast.
                json.dumps(results),
                triggered_by,
            )
        logger.info("Regression run %s persisted: %s", run_id, status)
        return run_id
    except Exception as e:
        # Deliberately broad: persistence is optional, the run result
        # has already been computed and will still be returned to the caller.
        logger.warning("Failed to persist regression run: %s", e)
        return ""
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _extract_cells_for_comparison(grid_result: dict) -> List[Dict[str, Any]]:
|
|
"""Extract a flat list of cells from a grid_editor_result for comparison.
|
|
|
|
Only keeps fields relevant for comparison: cell_id, row_index, col_index,
|
|
col_type, text. Ignores confidence, bbox, word_boxes, duration, is_bold.
|
|
"""
|
|
cells = []
|
|
for zone in grid_result.get("zones", []):
|
|
for cell in zone.get("cells", []):
|
|
cells.append({
|
|
"cell_id": cell.get("cell_id", ""),
|
|
"row_index": cell.get("row_index"),
|
|
"col_index": cell.get("col_index"),
|
|
"col_type": cell.get("col_type", ""),
|
|
"text": cell.get("text", ""),
|
|
})
|
|
return cells
|
|
|
|
|
|
def _build_reference_snapshot(
    grid_result: dict,
    pipeline: Optional[str] = None,
) -> dict:
    """Build a ground-truth reference snapshot from a grid_editor_result.

    The snapshot carries a UTC timestamp, a schema version, the pipeline
    label (if known), aggregate structure counts and the flattened cells.
    """
    zones = grid_result.get("zones", [])
    cells = _extract_cells_for_comparison(grid_result)

    return {
        "saved_at": datetime.now(timezone.utc).isoformat(),
        "version": 1,
        "pipeline": pipeline,
        "summary": {
            "total_zones": len(zones),
            "total_columns": sum(len(z.get("columns", [])) for z in zones),
            "total_rows": sum(len(z.get("rows", [])) for z in zones),
            "total_cells": len(cells),
        },
        "cells": cells,
    }
|
|
|
|
|
|
def compare_grids(reference: dict, current: dict) -> dict:
    """Compare a reference grid snapshot with a newly computed one.

    Returns a diff report with:
    - status: "pass" (no differences) or "fail"
    - structural_diffs: changes in zone/row/column/cell counts
    - cell_diffs: per-cell changes (missing, added, text, col_type)
    - summary: counts of each diff category
    """
    ref_summary = reference.get("summary", {})
    cur_summary = current.get("summary", {})

    # Structural comparison: aggregate counts must match exactly.
    structural_diffs = [
        {"field": field, "reference": ref_summary.get(field, 0), "current": cur_summary.get(field, 0)}
        for field in ("total_zones", "total_columns", "total_rows", "total_cells")
        if ref_summary.get(field, 0) != cur_summary.get(field, 0)
    ]

    # Index both sides by cell_id for O(1) membership and lookup.
    ref_by_id = {c["cell_id"]: c for c in reference.get("cells", [])}
    cur_by_id = {c["cell_id"]: c for c in current.get("cells", [])}

    cell_diffs: List[Dict[str, Any]] = []

    # Cells present in the reference but gone from the new result.
    cell_diffs.extend(
        {"type": "cell_missing", "cell_id": cid, "reference_text": cell.get("text", "")}
        for cid, cell in ref_by_id.items()
        if cid not in cur_by_id
    )

    # Cells that newly appeared in the current result.
    cell_diffs.extend(
        {"type": "cell_added", "cell_id": cid, "current_text": cell.get("text", "")}
        for cid, cell in cur_by_id.items()
        if cid not in ref_by_id
    )

    # Shared cells: detect text and col_type drift (text first, matching
    # the historical ordering of diff entries).
    for cid, before in ref_by_id.items():
        after = cur_by_id.get(cid)
        if after is None:
            continue
        for attr, diff_type in (("text", "text_change"), ("col_type", "col_type_change")):
            if before.get(attr, "") != after.get(attr, ""):
                cell_diffs.append({
                    "type": diff_type,
                    "cell_id": cid,
                    "reference": before.get(attr, ""),
                    "current": after.get(attr, ""),
                })

    # Tally diff categories for the summary block.
    tally = {"cell_missing": 0, "cell_added": 0, "text_change": 0, "col_type_change": 0}
    for d in cell_diffs:
        tally[d["type"]] += 1

    return {
        "status": "pass" if not structural_diffs and not cell_diffs else "fail",
        "structural_diffs": structural_diffs,
        "cell_diffs": cell_diffs,
        "summary": {
            "structural_changes": len(structural_diffs),
            "cells_missing": tally["cell_missing"],
            "cells_added": tally["cell_added"],
            "text_changes": tally["text_change"],
            "col_type_changes": tally["col_type_change"],
        },
    }
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Endpoints
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@router.post("/sessions/{session_id}/mark-ground-truth")
|
|
async def mark_ground_truth(
|
|
session_id: str,
|
|
pipeline: Optional[str] = Query(None, description="Pipeline used: kombi, pipeline, paddle-direct"),
|
|
):
|
|
"""Save the current build-grid result as ground-truth reference."""
|
|
session = await get_session_db(session_id)
|
|
if not session:
|
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
|
|
|
grid_result = session.get("grid_editor_result")
|
|
if not grid_result or not grid_result.get("zones"):
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail="No grid_editor_result found. Run build-grid first.",
|
|
)
|
|
|
|
# Auto-detect pipeline from word_result if not provided
|
|
if not pipeline:
|
|
wr = session.get("word_result") or {}
|
|
engine = wr.get("ocr_engine", "")
|
|
if engine in ("kombi", "rapid_kombi"):
|
|
pipeline = "kombi"
|
|
elif engine == "paddle_direct":
|
|
pipeline = "paddle-direct"
|
|
else:
|
|
pipeline = "pipeline"
|
|
|
|
reference = _build_reference_snapshot(grid_result, pipeline=pipeline)
|
|
|
|
# Merge into existing ground_truth JSONB
|
|
gt = session.get("ground_truth") or {}
|
|
gt["build_grid_reference"] = reference
|
|
await update_session_db(session_id, ground_truth=gt, current_step=11)
|
|
|
|
# Compare with auto-snapshot if available (shows what the user corrected)
|
|
auto_snapshot = gt.get("auto_grid_snapshot")
|
|
correction_diff = None
|
|
if auto_snapshot:
|
|
correction_diff = compare_grids(auto_snapshot, reference)
|
|
|
|
logger.info(
|
|
"Ground truth marked for session %s: %d cells (corrections: %s)",
|
|
session_id,
|
|
len(reference["cells"]),
|
|
correction_diff["summary"] if correction_diff else "no auto-snapshot",
|
|
)
|
|
|
|
return {
|
|
"status": "ok",
|
|
"session_id": session_id,
|
|
"cells_saved": len(reference["cells"]),
|
|
"summary": reference["summary"],
|
|
"correction_diff": correction_diff,
|
|
}
|
|
|
|
|
|
@router.delete("/sessions/{session_id}/mark-ground-truth")
async def unmark_ground_truth(session_id: str):
    """Remove the ground-truth reference from a session.

    Raises:
        HTTPException 404: session missing, or no reference stored.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    gt = session.get("ground_truth") or {}
    if "build_grid_reference" not in gt:
        raise HTTPException(status_code=404, detail="No ground truth reference found")

    # Drop only the reference; other ground_truth keys (e.g. the
    # auto snapshot) remain untouched.
    gt.pop("build_grid_reference")
    await update_session_db(session_id, ground_truth=gt)

    logger.info("Ground truth removed for session %s", session_id)
    return {"status": "ok", "session_id": session_id}
|
|
|
|
|
|
@router.get("/sessions/{session_id}/correction-diff")
|
|
async def get_correction_diff(session_id: str):
|
|
"""Compare automatic OCR grid with manually corrected ground truth.
|
|
|
|
Returns a diff showing exactly which cells the user corrected,
|
|
broken down by col_type (english, german, ipa, etc.).
|
|
"""
|
|
session = await get_session_db(session_id)
|
|
if not session:
|
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
|
|
|
gt = session.get("ground_truth") or {}
|
|
auto_snapshot = gt.get("auto_grid_snapshot")
|
|
reference = gt.get("build_grid_reference")
|
|
|
|
if not auto_snapshot:
|
|
raise HTTPException(
|
|
status_code=404,
|
|
detail="No auto_grid_snapshot found. Re-run build-grid to create one.",
|
|
)
|
|
if not reference:
|
|
raise HTTPException(
|
|
status_code=404,
|
|
detail="No ground truth reference found. Mark as ground truth first.",
|
|
)
|
|
|
|
diff = compare_grids(auto_snapshot, reference)
|
|
|
|
# Enrich with per-col_type breakdown
|
|
col_type_stats: Dict[str, Dict[str, int]] = {}
|
|
for cell_diff in diff.get("cell_diffs", []):
|
|
if cell_diff["type"] != "text_change":
|
|
continue
|
|
# Find col_type from reference cells
|
|
cell_id = cell_diff["cell_id"]
|
|
ref_cell = next(
|
|
(c for c in reference.get("cells", []) if c["cell_id"] == cell_id),
|
|
None,
|
|
)
|
|
ct = ref_cell.get("col_type", "unknown") if ref_cell else "unknown"
|
|
if ct not in col_type_stats:
|
|
col_type_stats[ct] = {"total": 0, "corrected": 0}
|
|
col_type_stats[ct]["corrected"] += 1
|
|
|
|
# Count total cells per col_type from reference
|
|
for cell in reference.get("cells", []):
|
|
ct = cell.get("col_type", "unknown")
|
|
if ct not in col_type_stats:
|
|
col_type_stats[ct] = {"total": 0, "corrected": 0}
|
|
col_type_stats[ct]["total"] += 1
|
|
|
|
# Calculate accuracy per col_type
|
|
for ct, stats in col_type_stats.items():
|
|
total = stats["total"]
|
|
corrected = stats["corrected"]
|
|
stats["accuracy_pct"] = round((total - corrected) / total * 100, 1) if total > 0 else 100.0
|
|
|
|
diff["col_type_breakdown"] = col_type_stats
|
|
|
|
return diff
|
|
|
|
|
|
@router.get("/ground-truth-sessions")
|
|
async def list_ground_truth_sessions():
|
|
"""List all sessions that have a ground-truth reference."""
|
|
sessions = await list_ground_truth_sessions_db()
|
|
|
|
result = []
|
|
for s in sessions:
|
|
gt = s.get("ground_truth") or {}
|
|
ref = gt.get("build_grid_reference", {})
|
|
result.append({
|
|
"session_id": s["id"],
|
|
"name": s.get("name", ""),
|
|
"filename": s.get("filename", ""),
|
|
"document_category": s.get("document_category"),
|
|
"pipeline": ref.get("pipeline"),
|
|
"saved_at": ref.get("saved_at"),
|
|
"summary": ref.get("summary", {}),
|
|
})
|
|
|
|
return {"sessions": result, "count": len(result)}
|
|
|
|
|
|
@router.post("/sessions/{session_id}/regression/run")
|
|
async def run_single_regression(session_id: str):
|
|
"""Re-run build_grid for a single session and compare to ground truth."""
|
|
session = await get_session_db(session_id)
|
|
if not session:
|
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
|
|
|
gt = session.get("ground_truth") or {}
|
|
reference = gt.get("build_grid_reference")
|
|
if not reference:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail="No ground truth reference found for this session",
|
|
)
|
|
|
|
# Re-compute grid without persisting
|
|
try:
|
|
new_result = await _build_grid_core(session_id, session)
|
|
except (ValueError, Exception) as e:
|
|
return {
|
|
"session_id": session_id,
|
|
"name": session.get("name", ""),
|
|
"status": "error",
|
|
"error": str(e),
|
|
}
|
|
|
|
new_snapshot = _build_reference_snapshot(new_result)
|
|
diff = compare_grids(reference, new_snapshot)
|
|
|
|
logger.info(
|
|
"Regression test session %s: %s (%d structural, %d cell diffs)",
|
|
session_id, diff["status"],
|
|
diff["summary"]["structural_changes"],
|
|
sum(v for k, v in diff["summary"].items() if k != "structural_changes"),
|
|
)
|
|
|
|
return {
|
|
"session_id": session_id,
|
|
"name": session.get("name", ""),
|
|
"status": diff["status"],
|
|
"diff": diff,
|
|
"reference_summary": reference.get("summary", {}),
|
|
"current_summary": new_snapshot.get("summary", {}),
|
|
}
|
|
|
|
|
|
@router.post("/regression/run")
|
|
async def run_all_regressions(
|
|
triggered_by: str = Query("manual", description="Who triggered: manual, script, ci"),
|
|
):
|
|
"""Re-run build_grid for ALL ground-truth sessions and compare."""
|
|
start_time = time.monotonic()
|
|
sessions = await list_ground_truth_sessions_db()
|
|
|
|
if not sessions:
|
|
return {
|
|
"status": "pass",
|
|
"message": "No ground truth sessions found",
|
|
"results": [],
|
|
"summary": {"total": 0, "passed": 0, "failed": 0, "errors": 0},
|
|
}
|
|
|
|
results = []
|
|
passed = 0
|
|
failed = 0
|
|
errors = 0
|
|
|
|
for s in sessions:
|
|
session_id = s["id"]
|
|
gt = s.get("ground_truth") or {}
|
|
reference = gt.get("build_grid_reference")
|
|
if not reference:
|
|
continue
|
|
|
|
# Re-load full session (list query may not include all JSONB fields)
|
|
full_session = await get_session_db(session_id)
|
|
if not full_session:
|
|
results.append({
|
|
"session_id": session_id,
|
|
"name": s.get("name", ""),
|
|
"status": "error",
|
|
"error": "Session not found during re-load",
|
|
})
|
|
errors += 1
|
|
continue
|
|
|
|
try:
|
|
new_result = await _build_grid_core(session_id, full_session)
|
|
except (ValueError, Exception) as e:
|
|
results.append({
|
|
"session_id": session_id,
|
|
"name": s.get("name", ""),
|
|
"status": "error",
|
|
"error": str(e),
|
|
})
|
|
errors += 1
|
|
continue
|
|
|
|
new_snapshot = _build_reference_snapshot(new_result)
|
|
diff = compare_grids(reference, new_snapshot)
|
|
|
|
entry = {
|
|
"session_id": session_id,
|
|
"name": s.get("name", ""),
|
|
"status": diff["status"],
|
|
"diff_summary": diff["summary"],
|
|
"reference_summary": reference.get("summary", {}),
|
|
"current_summary": new_snapshot.get("summary", {}),
|
|
}
|
|
|
|
# Include full diffs only for failures (keep response compact)
|
|
if diff["status"] == "fail":
|
|
entry["structural_diffs"] = diff["structural_diffs"]
|
|
entry["cell_diffs"] = diff["cell_diffs"]
|
|
failed += 1
|
|
else:
|
|
passed += 1
|
|
|
|
results.append(entry)
|
|
|
|
overall = "pass" if failed == 0 and errors == 0 else "fail"
|
|
duration_ms = int((time.monotonic() - start_time) * 1000)
|
|
|
|
summary = {
|
|
"total": len(results),
|
|
"passed": passed,
|
|
"failed": failed,
|
|
"errors": errors,
|
|
}
|
|
|
|
logger.info(
|
|
"Regression suite: %s — %d passed, %d failed, %d errors (of %d) in %dms",
|
|
overall, passed, failed, errors, len(results), duration_ms,
|
|
)
|
|
|
|
# Persist to DB
|
|
run_id = await _persist_regression_run(
|
|
status=overall,
|
|
summary=summary,
|
|
results=results,
|
|
duration_ms=duration_ms,
|
|
triggered_by=triggered_by,
|
|
)
|
|
|
|
return {
|
|
"status": overall,
|
|
"run_id": run_id,
|
|
"duration_ms": duration_ms,
|
|
"results": results,
|
|
"summary": summary,
|
|
}
|
|
|
|
|
|
@router.get("/regression/history")
|
|
async def get_regression_history(
|
|
limit: int = Query(20, ge=1, le=100),
|
|
):
|
|
"""Get recent regression run history from the database."""
|
|
try:
|
|
await _init_regression_table()
|
|
pool = await get_pool()
|
|
async with pool.acquire() as conn:
|
|
rows = await conn.fetch(
|
|
"""
|
|
SELECT id, run_at, status, total, passed, failed, errors,
|
|
duration_ms, triggered_by
|
|
FROM regression_runs
|
|
ORDER BY run_at DESC
|
|
LIMIT $1
|
|
""",
|
|
limit,
|
|
)
|
|
return {
|
|
"runs": [
|
|
{
|
|
"id": str(row["id"]),
|
|
"run_at": row["run_at"].isoformat() if row["run_at"] else None,
|
|
"status": row["status"],
|
|
"total": row["total"],
|
|
"passed": row["passed"],
|
|
"failed": row["failed"],
|
|
"errors": row["errors"],
|
|
"duration_ms": row["duration_ms"],
|
|
"triggered_by": row["triggered_by"],
|
|
}
|
|
for row in rows
|
|
],
|
|
"count": len(rows),
|
|
}
|
|
except Exception as e:
|
|
logger.warning("Failed to fetch regression history: %s", e)
|
|
return {"runs": [], "count": 0, "error": str(e)}
|
|
|
|
|
|
@router.get("/regression/history/{run_id}")
|
|
async def get_regression_run_detail(run_id: str):
|
|
"""Get detailed results of a specific regression run."""
|
|
try:
|
|
await _init_regression_table()
|
|
pool = await get_pool()
|
|
async with pool.acquire() as conn:
|
|
row = await conn.fetchrow(
|
|
"SELECT * FROM regression_runs WHERE id = $1",
|
|
run_id,
|
|
)
|
|
if not row:
|
|
raise HTTPException(status_code=404, detail="Run not found")
|
|
return {
|
|
"id": str(row["id"]),
|
|
"run_at": row["run_at"].isoformat() if row["run_at"] else None,
|
|
"status": row["status"],
|
|
"total": row["total"],
|
|
"passed": row["passed"],
|
|
"failed": row["failed"],
|
|
"errors": row["errors"],
|
|
"duration_ms": row["duration_ms"],
|
|
"triggered_by": row["triggered_by"],
|
|
"results": json.loads(row["results"]) if row["results"] else [],
|
|
}
|
|
except HTTPException:
|
|
raise
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=str(e))
|