Files
breakpilot-lehrer/klausur-service/backend/ocr_pipeline_regression.py
Benjamin Admin 410d36f3de feat: save automatic grid snapshot before manual edits for GT comparison
- build-grid now saves the automatic OCR result as ground_truth.auto_grid_snapshot
- mark-ground-truth includes a correction_diff comparing auto vs corrected
- New endpoint GET /correction-diff returns detailed diff with per-col_type
  accuracy breakdown (english, german, ipa, etc.)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-24 13:16:44 +01:00

608 lines
20 KiB
Python

"""
OCR Pipeline Regression Tests — Ground Truth comparison system.
Allows marking sessions as "ground truth" and re-running build_grid()
to detect regressions after code changes.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import os
import time
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, HTTPException, Query
from grid_editor_api import _build_grid_core
from ocr_pipeline_session_store import (
get_pool,
get_session_db,
list_ground_truth_sessions_db,
update_session_db,
)
# Module-level logger for the regression endpoints.
logger = logging.getLogger(__name__)

# Router for all regression/ground-truth endpoints, mounted under the OCR
# pipeline API prefix.
router = APIRouter(tags=["regression"], prefix="/api/v1/ocr-pipeline")
# ---------------------------------------------------------------------------
# DB persistence for regression runs
# ---------------------------------------------------------------------------
async def _init_regression_table():
    """Ensure the ``regression_runs`` table exists by running its migration.

    Executes ``migrations/008_regression_runs.sql`` (located next to this
    module) on a pooled connection. If the migration file is missing, a
    warning is logged and no SQL is executed.

    NOTE(review): this is called before every persist/history query, so it
    relies on the migration SQL being idempotent (e.g. ``CREATE TABLE IF NOT
    EXISTS``) — confirm against the migration file.
    """
    migration_path = os.path.join(
        os.path.dirname(__file__),
        "migrations",
        "008_regression_runs.sql",
    )
    if not os.path.exists(migration_path):
        logger.warning("Regression migration not found: %s", migration_path)
        return

    # Read the SQL before taking a connection so no pool slot is held during
    # file I/O; an explicit encoding avoids locale-dependent decoding.
    with open(migration_path, "r", encoding="utf-8") as f:
        sql = f.read()

    pool = await get_pool()
    async with pool.acquire() as conn:
        await conn.execute(sql)
async def _persist_regression_run(
    status: str,
    summary: dict,
    results: list,
    duration_ms: int,
    triggered_by: str = "manual",
) -> str:
    """Store one regression run in the ``regression_runs`` table.

    Args:
        status: Overall run status (e.g. "pass" / "fail").
        summary: Counter dict; ``total``, ``passed``, ``failed`` and
            ``errors`` are read, each defaulting to 0 when absent.
        results: Per-session result entries, stored JSON-encoded as jsonb.
        duration_ms: Duration of the run in milliseconds.
        triggered_by: Label describing who started the run.

    Returns:
        The new run's UUID4 string, or ``""`` if anything failed. Saving is
        best-effort: errors are logged as warnings and never raised.
    """
    try:
        await _init_regression_table()
        pool = await get_pool()
        new_run_id = str(uuid.uuid4())
        total, passed, failed, errors = (
            summary.get(field, 0) for field in ("total", "passed", "failed", "errors")
        )
        async with pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO regression_runs
                (id, status, total, passed, failed, errors, duration_ms, results, triggered_by)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8::jsonb, $9)
                """,
                new_run_id,
                status,
                total,
                passed,
                failed,
                errors,
                duration_ms,
                json.dumps(results),
                triggered_by,
            )
        logger.info("Regression run %s persisted: %s", new_run_id, status)
        return new_run_id
    except Exception as exc:
        logger.warning("Failed to persist regression run: %s", exc)
        return ""
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _extract_cells_for_comparison(grid_result: dict) -> List[Dict[str, Any]]:
"""Extract a flat list of cells from a grid_editor_result for comparison.
Only keeps fields relevant for comparison: cell_id, row_index, col_index,
col_type, text. Ignores confidence, bbox, word_boxes, duration, is_bold.
"""
cells = []
for zone in grid_result.get("zones", []):
for cell in zone.get("cells", []):
cells.append({
"cell_id": cell.get("cell_id", ""),
"row_index": cell.get("row_index"),
"col_index": cell.get("col_index"),
"col_type": cell.get("col_type", ""),
"text": cell.get("text", ""),
})
return cells
def _build_reference_snapshot(
    grid_result: dict,
    pipeline: Optional[str] = None,
) -> dict:
    """Create a version-1 ground-truth snapshot of a grid_editor_result.

    The snapshot holds a UTC ISO-8601 ``saved_at`` timestamp, the optional
    ``pipeline`` label, zone/column/row/cell counts under ``summary``, and
    the flattened comparison cells under ``cells``.
    """
    zones = grid_result.get("zones", [])
    cells = _extract_cells_for_comparison(grid_result)

    counts = {
        "total_zones": len(zones),
        "total_columns": sum(len(zone.get("columns", [])) for zone in zones),
        "total_rows": sum(len(zone.get("rows", [])) for zone in zones),
        "total_cells": len(cells),
    }

    return {
        "saved_at": datetime.now(timezone.utc).isoformat(),
        "version": 1,
        "pipeline": pipeline,
        "summary": counts,
        "cells": cells,
    }
def compare_grids(reference: dict, current: dict) -> dict:
    """Diff a reference grid snapshot against a newly computed one.

    Both arguments are snapshot dicts with a ``summary`` of counters and a
    ``cells`` list whose entries are matched by ``cell_id``.

    Returns:
        A report dict with:
        - ``status``: ``"pass"`` when nothing differs, else ``"fail"``;
        - ``structural_diffs``: zone/column/row/cell counters that differ;
        - ``cell_diffs``: ``cell_missing`` entries (reference order), then
          ``cell_added`` entries (current order), then ``text_change`` /
          ``col_type_change`` entries for shared cells (reference order);
        - ``summary``: how many diffs of each category were found.
    """
    ref_summary = reference.get("summary", {})
    cur_summary = current.get("summary", {})
    structural_diffs = [
        {
            "field": field,
            "reference": ref_summary.get(field, 0),
            "current": cur_summary.get(field, 0),
        }
        for field in ("total_zones", "total_columns", "total_rows", "total_cells")
        if ref_summary.get(field, 0) != cur_summary.get(field, 0)
    ]

    # Index both sides by cell_id (a later duplicate id overrides an earlier one).
    ref_by_id = {cell["cell_id"]: cell for cell in reference.get("cells", [])}
    cur_by_id = {cell["cell_id"]: cell for cell in current.get("cells", [])}

    # Cells that only exist in the reference.
    cell_diffs: List[Dict[str, Any]] = [
        {"type": "cell_missing", "cell_id": cid, "reference_text": cell.get("text", "")}
        for cid, cell in ref_by_id.items()
        if cid not in cur_by_id
    ]
    # Cells that only exist in the current grid.
    cell_diffs.extend(
        {"type": "cell_added", "cell_id": cid, "current_text": cell.get("text", "")}
        for cid, cell in cur_by_id.items()
        if cid not in ref_by_id
    )
    # Field-level changes for cells present on both sides.
    for cid, ref_cell in ref_by_id.items():
        if cid not in cur_by_id:
            continue
        cur_cell = cur_by_id[cid]
        for diff_type, attr in (("text_change", "text"), ("col_type_change", "col_type")):
            before = ref_cell.get(attr, "")
            after = cur_cell.get(attr, "")
            if before != after:
                cell_diffs.append({
                    "type": diff_type,
                    "cell_id": cid,
                    "reference": before,
                    "current": after,
                })

    per_type: Dict[str, int] = {}
    for entry in cell_diffs:
        per_type[entry["type"]] = per_type.get(entry["type"], 0) + 1

    unchanged = not (structural_diffs or cell_diffs)
    return {
        "status": "pass" if unchanged else "fail",
        "structural_diffs": structural_diffs,
        "cell_diffs": cell_diffs,
        "summary": {
            "structural_changes": len(structural_diffs),
            "cells_missing": per_type.get("cell_missing", 0),
            "cells_added": per_type.get("cell_added", 0),
            "text_changes": per_type.get("text_change", 0),
            "col_type_changes": per_type.get("col_type_change", 0),
        },
    }
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/mark-ground-truth")
async def mark_ground_truth(
    session_id: str,
    pipeline: Optional[str] = Query(None, description="Pipeline used: kombi, pipeline, paddle-direct"),
):
    """Store the session's current build-grid output as its ground truth.

    When ``pipeline`` is omitted it is derived from
    ``word_result.ocr_engine``: "kombi"/"rapid_kombi" -> "kombi",
    "paddle_direct" -> "paddle-direct", anything else -> "pipeline".
    The snapshot is merged into the session's ``ground_truth`` JSON under
    ``build_grid_reference``. If an ``auto_grid_snapshot`` is present, the
    response also contains a ``correction_diff`` (auto snapshot vs. new
    reference) showing what the user corrected.

    Raises:
        HTTPException 404: the session does not exist.
        HTTPException 400: no grid_editor_result with zones is available.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    grid_result = session.get("grid_editor_result")
    if not (grid_result and grid_result.get("zones")):
        raise HTTPException(
            status_code=400,
            detail="No grid_editor_result found. Run build-grid first.",
        )

    if not pipeline:
        engine = (session.get("word_result") or {}).get("ocr_engine", "")
        pipeline = (
            "kombi" if engine in ("kombi", "rapid_kombi")
            else "paddle-direct" if engine == "paddle_direct"
            else "pipeline"
        )

    reference = _build_reference_snapshot(grid_result, pipeline=pipeline)

    # Merge into the existing ground_truth JSON instead of replacing it.
    ground_truth = session.get("ground_truth") or {}
    ground_truth["build_grid_reference"] = reference
    await update_session_db(session_id, ground_truth=ground_truth, current_step=11)

    auto_snapshot = ground_truth.get("auto_grid_snapshot")
    correction_diff = compare_grids(auto_snapshot, reference) if auto_snapshot else None

    cell_count = len(reference["cells"])
    logger.info(
        "Ground truth marked for session %s: %d cells (corrections: %s)",
        session_id,
        cell_count,
        correction_diff["summary"] if correction_diff else "no auto-snapshot",
    )
    return {
        "status": "ok",
        "session_id": session_id,
        "cells_saved": cell_count,
        "summary": reference["summary"],
        "correction_diff": correction_diff,
    }
@router.delete("/sessions/{session_id}/mark-ground-truth")
async def unmark_ground_truth(session_id: str):
    """Drop the ``build_grid_reference`` entry from a session's ground truth.

    Other keys in the ``ground_truth`` JSON are left in place.

    Raises:
        HTTPException 404: the session does not exist or has no reference.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    if "build_grid_reference" not in ground_truth:
        raise HTTPException(status_code=404, detail="No ground truth reference found")

    ground_truth.pop("build_grid_reference")
    await update_session_db(session_id, ground_truth=ground_truth)

    logger.info("Ground truth removed for session %s", session_id)
    return {"status": "ok", "session_id": session_id}
@router.get("/sessions/{session_id}/correction-diff")
async def get_correction_diff(session_id: str):
    """Compare the automatic OCR grid with the manually corrected ground truth.

    Returns the ``compare_grids`` report (auto snapshot as reference, ground
    truth as current), enriched with ``col_type_breakdown``. For every
    ``col_type`` found in the ground truth (english, german, ipa, ...) the
    breakdown gives the number of cells (``total``), how many had their text
    corrected (``corrected``), and ``accuracy_pct``, the percentage of
    uncorrected cells rounded to one decimal.

    Raises:
        HTTPException 404: the session, its auto snapshot or its ground-truth
            reference is missing.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    gt = session.get("ground_truth") or {}
    auto_snapshot = gt.get("auto_grid_snapshot")
    reference = gt.get("build_grid_reference")

    if not auto_snapshot:
        raise HTTPException(
            status_code=404,
            detail="No auto_grid_snapshot found. Re-run build-grid to create one.",
        )
    if not reference:
        raise HTTPException(
            status_code=404,
            detail="No ground truth reference found. Mark as ground truth first.",
        )

    diff = compare_grids(auto_snapshot, reference)
    ref_cells = reference.get("cells", [])

    # Index reference cells once instead of linearly scanning the whole list
    # for every text change (previously O(changes * cells)). setdefault keeps
    # the first cell per id, matching the former first-match scan.
    ref_by_id: Dict[str, dict] = {}
    for cell in ref_cells:
        ref_by_id.setdefault(cell["cell_id"], cell)

    col_type_stats: Dict[str, Dict[str, Any]] = {}

    # Only text changes count as user corrections; missing/added cells and
    # col_type changes are not counted here.
    for cell_diff in diff.get("cell_diffs", []):
        if cell_diff["type"] != "text_change":
            continue
        ref_cell = ref_by_id.get(cell_diff["cell_id"])
        ct = ref_cell.get("col_type", "unknown") if ref_cell else "unknown"
        col_type_stats.setdefault(ct, {"total": 0, "corrected": 0})["corrected"] += 1

    # Count total cells per col_type from the reference.
    for cell in ref_cells:
        ct = cell.get("col_type", "unknown")
        col_type_stats.setdefault(ct, {"total": 0, "corrected": 0})["total"] += 1

    # Accuracy = share of cells whose text was NOT corrected.
    for stats in col_type_stats.values():
        total = stats["total"]
        corrected = stats["corrected"]
        stats["accuracy_pct"] = round((total - corrected) / total * 100, 1) if total > 0 else 100.0

    diff["col_type_breakdown"] = col_type_stats
    return diff
@router.get("/ground-truth-sessions")
async def list_ground_truth_sessions():
    """List all sessions that have a ground-truth reference.

    Each entry contains the session's id, name, filename and
    document_category, plus the reference's pipeline label, ``saved_at``
    timestamp and summary counts.
    """
    sessions = await list_ground_truth_sessions_db()
    result = []
    for s in sessions:
        gt = s.get("ground_truth") or {}
        # ``or {}`` (instead of a .get default) also covers a reference stored
        # as JSON null, which would otherwise raise AttributeError below.
        ref = gt.get("build_grid_reference") or {}
        result.append({
            "session_id": s["id"],
            "name": s.get("name", ""),
            "filename": s.get("filename", ""),
            "document_category": s.get("document_category"),
            "pipeline": ref.get("pipeline"),
            "saved_at": ref.get("saved_at"),
            "summary": ref.get("summary", {}),
        })
    return {"sessions": result, "count": len(result)}
@router.post("/sessions/{session_id}/regression/run")
async def run_single_regression(session_id: str):
    """Re-run build_grid for one session and compare it to its ground truth.

    The grid is recomputed with ``_build_grid_core`` and diffed against the
    stored ``build_grid_reference`` via ``compare_grids``. If the rebuild
    raises, the failure is logged and reported in the response body with
    ``status: "error"`` instead of as an HTTP error.

    Raises:
        HTTPException 404: the session does not exist.
        HTTPException 400: the session has no ground-truth reference.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    gt = session.get("ground_truth") or {}
    reference = gt.get("build_grid_reference")
    if not reference:
        raise HTTPException(
            status_code=400,
            detail="No ground truth reference found for this session",
        )

    # Re-compute grid without persisting
    try:
        new_result = await _build_grid_core(session_id, session)
    except Exception as e:  # ValueError included; it is an Exception subclass
        logger.warning(
            "Regression rebuild failed for session %s: %s",
            session_id, e, exc_info=True,
        )
        return {
            "session_id": session_id,
            "name": session.get("name", ""),
            "status": "error",
            "error": str(e),
        }

    new_snapshot = _build_reference_snapshot(new_result)
    diff = compare_grids(reference, new_snapshot)

    logger.info(
        "Regression test session %s: %s (%d structural, %d cell diffs)",
        session_id, diff["status"],
        diff["summary"]["structural_changes"],
        len(diff["cell_diffs"]),
    )

    return {
        "session_id": session_id,
        "name": session.get("name", ""),
        "status": diff["status"],
        "diff": diff,
        "reference_summary": reference.get("summary", {}),
        "current_summary": new_snapshot.get("summary", {}),
    }
@router.post("/regression/run")
async def run_all_regressions(
    triggered_by: str = Query("manual", description="Who triggered: manual, script, ci"),
):
    """Re-run build_grid for ALL ground-truth sessions and compare.

    For every session returned by ``list_ground_truth_sessions_db`` that has
    a ``build_grid_reference``, the full session is re-loaded, its grid is
    rebuilt with ``_build_grid_core``, and the result is diffed against the
    reference. Sessions that cannot be re-loaded or rebuilt count as errors.
    The run is then saved (best-effort) via ``_persist_regression_run``.

    Returns:
        ``status`` ("pass" only with zero failures and zero errors),
        ``run_id`` (empty string if saving failed), ``duration_ms``,
        per-session ``results`` and a ``summary`` of counts. Full diffs are
        included only for failing sessions. When no ground-truth sessions
        exist, a short "pass" response is returned and nothing is persisted.
    """
    start_time = time.monotonic()
    sessions = await list_ground_truth_sessions_db()

    if not sessions:
        return {
            "status": "pass",
            "message": "No ground truth sessions found",
            "results": [],
            "summary": {"total": 0, "passed": 0, "failed": 0, "errors": 0},
        }

    results = []
    passed = 0
    failed = 0
    errors = 0

    for s in sessions:
        session_id = s["id"]
        gt = s.get("ground_truth") or {}
        reference = gt.get("build_grid_reference")
        if not reference:
            continue

        # Re-load full session (list query may not include all JSONB fields)
        full_session = await get_session_db(session_id)
        if not full_session:
            results.append({
                "session_id": session_id,
                "name": s.get("name", ""),
                "status": "error",
                "error": "Session not found during re-load",
            })
            errors += 1
            continue

        try:
            new_result = await _build_grid_core(session_id, full_session)
        except Exception as e:  # ValueError included; it is an Exception subclass
            logger.warning(
                "Regression rebuild failed for session %s: %s",
                session_id, e, exc_info=True,
            )
            results.append({
                "session_id": session_id,
                "name": s.get("name", ""),
                "status": "error",
                "error": str(e),
            })
            errors += 1
            continue

        new_snapshot = _build_reference_snapshot(new_result)
        diff = compare_grids(reference, new_snapshot)

        entry = {
            "session_id": session_id,
            "name": s.get("name", ""),
            "status": diff["status"],
            "diff_summary": diff["summary"],
            "reference_summary": reference.get("summary", {}),
            "current_summary": new_snapshot.get("summary", {}),
        }

        # Include full diffs only for failures (keep response compact)
        if diff["status"] == "fail":
            entry["structural_diffs"] = diff["structural_diffs"]
            entry["cell_diffs"] = diff["cell_diffs"]
            failed += 1
        else:
            passed += 1

        results.append(entry)

    overall = "pass" if failed == 0 and errors == 0 else "fail"
    duration_ms = int((time.monotonic() - start_time) * 1000)

    summary = {
        "total": len(results),
        "passed": passed,
        "failed": failed,
        "errors": errors,
    }

    # Separator between status and counts: previously "%s%d" produced e.g.
    # "pass3 passed".
    logger.info(
        "Regression suite: %s — %d passed, %d failed, %d errors (of %d) in %dms",
        overall, passed, failed, errors, len(results), duration_ms,
    )

    # Persist to DB (best-effort; run_id is "" on failure)
    run_id = await _persist_regression_run(
        status=overall,
        summary=summary,
        results=results,
        duration_ms=duration_ms,
        triggered_by=triggered_by,
    )

    return {
        "status": overall,
        "run_id": run_id,
        "duration_ms": duration_ms,
        "results": results,
        "summary": summary,
    }
@router.get("/regression/history")
async def get_regression_history(
    limit: int = Query(20, ge=1, le=100),
):
    """Return up to ``limit`` regression runs, newest ``run_at`` first.

    Only the summary columns are returned; per-session results are omitted.
    If anything raises, a warning is logged and an empty run list is
    returned together with the error message.
    """
    history_sql = """
        SELECT id, run_at, status, total, passed, failed, errors,
        duration_ms, triggered_by
        FROM regression_runs
        ORDER BY run_at DESC
        LIMIT $1
    """
    try:
        await _init_regression_table()
        pool = await get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(history_sql, limit)

        runs = []
        for row in rows:
            run_at = row["run_at"]
            runs.append({
                "id": str(row["id"]),
                "run_at": run_at.isoformat() if run_at else None,
                "status": row["status"],
                "total": row["total"],
                "passed": row["passed"],
                "failed": row["failed"],
                "errors": row["errors"],
                "duration_ms": row["duration_ms"],
                "triggered_by": row["triggered_by"],
            })
        return {"runs": runs, "count": len(rows)}
    except Exception as exc:
        logger.warning("Failed to fetch regression history: %s", exc)
        return {"runs": [], "count": 0, "error": str(exc)}
@router.get("/regression/history/{run_id}")
async def get_regression_run_detail(run_id: str):
    """Get detailed results of a specific regression run.

    Returns the stored summary columns plus the per-session ``results``
    list decoded from the jsonb column (empty list when nothing is stored).

    Raises:
        HTTPException 404: ``run_id`` is not a valid UUID or no such run exists.
        HTTPException 500: any other failure while loading the run.
    """
    # Run IDs are created as str(uuid.uuid4()) by _persist_regression_run, so
    # a malformed ID can never match. Reject it up front. Presumably the id
    # column is a UUID type (confirm in migration 008); if so, the DB driver
    # would otherwise raise on the bad value and it would surface as a 500.
    try:
        uuid.UUID(run_id)
    except ValueError:
        raise HTTPException(status_code=404, detail="Run not found") from None

    try:
        await _init_regression_table()
        pool = await get_pool()
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                "SELECT * FROM regression_runs WHERE id = $1",
                run_id,
            )
        if not row:
            raise HTTPException(status_code=404, detail="Run not found")

        # Drivers may return jsonb either as raw JSON text (asyncpg's default,
        # presumably the driver here) or already decoded when a type codec is
        # registered. Accept both forms.
        raw_results = row["results"]
        if not raw_results:
            results = []
        elif isinstance(raw_results, (str, bytes, bytearray)):
            results = json.loads(raw_results)
        else:
            results = raw_results

        return {
            "id": str(row["id"]),
            "run_at": row["run_at"].isoformat() if row["run_at"] else None,
            "status": row["status"],
            "total": row["total"],
            "passed": row["passed"],
            "failed": row["failed"],
            "errors": row["errors"],
            "duration_ms": row["duration_ms"],
            "triggered_by": row["triggered_by"],
            "results": results,
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.warning("Failed to load regression run %s: %s", run_id, e)
        raise HTTPException(status_code=500, detail=str(e))