Files
breakpilot-lehrer/klausur-service/backend/ocr_pipeline_regression.py
Benjamin Admin a1e079b911
Some checks failed
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 28s
CI / test-go-edu-search (push) Successful in 27s
CI / test-python-klausur (push) Failing after 1m55s
CI / test-python-agent-core (push) Successful in 16s
CI / test-nodejs-website (push) Successful in 19s
feat: Sprint 1 — IPA hardening, regression framework, ground-truth review
Track A (Backend):
- Compound word IPA decomposition (schoolbag→school+bag)
- Trailing garbled IPA fragment removal after brackets (R21 fix)
- Regression runner with DB persistence, history endpoints
- Page crop determinism verified with tests

Track B (Frontend):
- OCR Regression dashboard (/ai/ocr-regression)
- Ground Truth Review workflow (/ai/ocr-ground-truth)
  with split-view, confidence highlighting, inline edit,
  batch mark, progress tracking

Track C (Docs):
- OCR-Pipeline.md v5.0 (Steps 5e-5h)
- Regression testing guide
- mkdocs.yml nav update

Track D (Infra):
- TrOCR baseline benchmark script
- run-regression.sh shell script
- Migration 008: regression_runs table

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-23 09:21:27 +01:00

537 lines
18 KiB
Python

"""
OCR Pipeline Regression Tests — Ground Truth comparison system.
Allows marking sessions as "ground truth" and re-running build_grid()
to detect regressions after code changes.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import os
import time
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, HTTPException, Query
from grid_editor_api import _build_grid_core
from ocr_pipeline_session_store import (
get_pool,
get_session_db,
list_ground_truth_sessions_db,
update_session_db,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["regression"])
# ---------------------------------------------------------------------------
# DB persistence for regression runs
# ---------------------------------------------------------------------------
async def _init_regression_table():
    """Ensure the regression_runs table exists by executing migration 008.

    Reads ``migrations/008_regression_runs.sql`` (resolved relative to this
    module) and executes it on a pooled connection. The function runs on
    every call, so the SQL script itself is expected to be idempotent.

    If the migration file does not exist, nothing is executed and a warning
    is logged; no database connection is acquired in that case.
    """
    migration_path = os.path.join(
        os.path.dirname(__file__),
        "migrations/008_regression_runs.sql",
    )

    # Read the script before touching the pool so that no connection is held
    # during filesystem I/O. The encoding is explicit so non-ASCII characters
    # in the SQL file do not depend on the process locale.
    try:
        with open(migration_path, "r", encoding="utf-8") as f:
            sql = f.read()
    except FileNotFoundError:
        logger.warning("Regression migration not found: %s", migration_path)
        return

    pool = await get_pool()
    async with pool.acquire() as conn:
        await conn.execute(sql)
async def _persist_regression_run(
    status: str,
    summary: dict,
    results: list,
    duration_ms: int,
    triggered_by: str = "manual",
) -> str:
    """Store one regression run in the ``regression_runs`` table.

    Persistence is best-effort: any exception is logged as a warning and
    an empty string is returned instead of propagating the error.

    Args:
        status: Overall outcome of the run.
        summary: Counters dict; the keys ``total``, ``passed``, ``failed``
            and ``errors`` are read, each defaulting to 0 when absent.
        results: Per-session results, serialised to JSON for the jsonb column.
        duration_ms: Duration of the run in milliseconds.
        triggered_by: Label describing who or what started the run.

    Returns:
        The newly generated run ID, or ``""`` if the run could not be saved.
    """
    counter_keys = ("total", "passed", "failed", "errors")
    try:
        await _init_regression_table()
        db_pool = await get_pool()
        new_id = str(uuid.uuid4())
        counters = [summary.get(key, 0) for key in counter_keys]
        async with db_pool.acquire() as connection:
            await connection.execute(
                """
                INSERT INTO regression_runs
                (id, status, total, passed, failed, errors, duration_ms, results, triggered_by)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8::jsonb, $9)
                """,
                new_id,
                status,
                *counters,
                duration_ms,
                json.dumps(results),
                triggered_by,
            )
        logger.info("Regression run %s persisted: %s", new_id, status)
        return new_id
    except Exception as e:
        logger.warning("Failed to persist regression run: %s", e)
        return ""
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _extract_cells_for_comparison(grid_result: dict) -> List[Dict[str, Any]]:
"""Extract a flat list of cells from a grid_editor_result for comparison.
Only keeps fields relevant for comparison: cell_id, row_index, col_index,
col_type, text. Ignores confidence, bbox, word_boxes, duration, is_bold.
"""
cells = []
for zone in grid_result.get("zones", []):
for cell in zone.get("cells", []):
cells.append({
"cell_id": cell.get("cell_id", ""),
"row_index": cell.get("row_index"),
"col_index": cell.get("col_index"),
"col_type": cell.get("col_type", ""),
"text": cell.get("text", ""),
})
return cells
def _build_reference_snapshot(
    grid_result: dict,
    pipeline: Optional[str] = None,
) -> dict:
    """Create a ground-truth snapshot dict from a grid result.

    The snapshot holds the flattened comparison cells, a summary with the
    zone count and the per-zone sums of columns and rows, the current UTC
    timestamp, a fixed format version of 1 and the given pipeline label.

    Args:
        grid_result: Dict with an optional ``zones`` list.
        pipeline: Optional pipeline label stored verbatim in the snapshot.

    Returns:
        Dict with keys ``saved_at``, ``version``, ``pipeline``, ``summary``
        and ``cells``.
    """
    flat_cells = _extract_cells_for_comparison(grid_result)
    zones = grid_result.get("zones", [])

    counts = {
        "total_zones": len(zones),
        "total_columns": sum(len(zone.get("columns", [])) for zone in zones),
        "total_rows": sum(len(zone.get("rows", [])) for zone in zones),
        "total_cells": len(flat_cells),
    }

    return {
        "saved_at": datetime.now(timezone.utc).isoformat(),
        "version": 1,
        "pipeline": pipeline,
        "summary": counts,
        "cells": flat_cells,
    }
def compare_grids(reference: dict, current: dict) -> dict:
    """Diff a stored reference snapshot against a freshly built snapshot.

    Cells are matched by ``cell_id``. The report lists cells that vanished,
    cells that appeared, and, for cells present on both sides, changes of
    ``text`` and ``col_type``. Summary counters that differ between the two
    snapshots are reported as structural diffs.

    Args:
        reference: Ground-truth snapshot (``summary`` and ``cells`` keys).
        current: Newly computed snapshot of the same shape.

    Returns:
        Dict with ``status`` (``"pass"`` only if there are no diffs at all,
        otherwise ``"fail"``), ``structural_diffs``, ``cell_diffs`` and a
        ``summary`` of per-category counts.
    """
    expected_counts = reference.get("summary", {})
    actual_counts = current.get("summary", {})

    structural_diffs = [
        {
            "field": field,
            "reference": expected_counts.get(field, 0),
            "current": actual_counts.get(field, 0),
        }
        for field in ("total_zones", "total_columns", "total_rows", "total_cells")
        if expected_counts.get(field, 0) != actual_counts.get(field, 0)
    ]

    # Index both sides by cell_id for direct lookups.
    expected = {cell["cell_id"]: cell for cell in reference.get("cells", [])}
    actual = {cell["cell_id"]: cell for cell in current.get("cells", [])}

    # Present in the reference, gone from the current result.
    missing = [
        {
            "type": "cell_missing",
            "cell_id": cid,
            "reference_text": cell.get("text", ""),
        }
        for cid, cell in expected.items()
        if cid not in actual
    ]

    # Present in the current result, unknown to the reference.
    added = [
        {
            "type": "cell_added",
            "cell_id": cid,
            "current_text": cell.get("text", ""),
        }
        for cid, cell in actual.items()
        if cid not in expected
    ]

    # Field-level changes on cells that exist on both sides.
    watched_fields = (("text", "text_change"), ("col_type", "col_type_change"))
    changed = []
    for cid, before in expected.items():
        if cid not in actual:
            continue
        after = actual[cid]
        for attr, diff_kind in watched_fields:
            old_value = before.get(attr, "")
            new_value = after.get(attr, "")
            if old_value != new_value:
                changed.append({
                    "type": diff_kind,
                    "cell_id": cid,
                    "reference": old_value,
                    "current": new_value,
                })

    cell_diffs: List[Dict[str, Any]] = missing + added + changed

    tally = dict.fromkeys(
        ("cell_missing", "cell_added", "text_change", "col_type_change"), 0
    )
    for diff in cell_diffs:
        tally[diff["type"]] += 1

    return {
        "status": "fail" if structural_diffs or cell_diffs else "pass",
        "structural_diffs": structural_diffs,
        "cell_diffs": cell_diffs,
        "summary": {
            "structural_changes": len(structural_diffs),
            "cells_missing": tally["cell_missing"],
            "cells_added": tally["cell_added"],
            "text_changes": tally["text_change"],
            "col_type_changes": tally["col_type_change"],
        },
    }
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/mark-ground-truth")
async def mark_ground_truth(
    session_id: str,
    pipeline: Optional[str] = Query(None, description="Pipeline used: kombi, pipeline, paddle-direct"),
):
    """Store the session's current grid result as its ground-truth reference.

    Responds with 404 if the session does not exist and with 400 if the
    session has no grid result containing zones. When no pipeline is given,
    it is derived from the OCR engine recorded in the session's word result.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    grid_result = session.get("grid_editor_result")
    has_zones = bool(grid_result) and bool(grid_result.get("zones"))
    if not has_zones:
        raise HTTPException(
            status_code=400,
            detail="No grid_editor_result found. Run build-grid first.",
        )

    if not pipeline:
        # Fall back to the engine stored alongside the word result.
        engine = (session.get("word_result") or {}).get("ocr_engine", "")
        if engine == "paddle_direct":
            pipeline = "paddle-direct"
        elif engine in ("kombi", "rapid_kombi"):
            pipeline = "kombi"
        else:
            pipeline = "pipeline"

    reference = _build_reference_snapshot(grid_result, pipeline=pipeline)
    cell_count = len(reference["cells"])

    # Keep any other entries of the existing ground_truth JSONB intact.
    ground_truth = session.get("ground_truth") or {}
    ground_truth["build_grid_reference"] = reference
    await update_session_db(session_id, ground_truth=ground_truth)

    logger.info(
        "Ground truth marked for session %s: %d cells",
        session_id, cell_count,
    )
    return {
        "status": "ok",
        "session_id": session_id,
        "cells_saved": cell_count,
        "summary": reference["summary"],
    }
@router.delete("/sessions/{session_id}/mark-ground-truth")
async def unmark_ground_truth(session_id: str):
    """Delete the ground-truth reference stored on a session.

    Responds with 404 if the session does not exist or if it carries no
    ``build_grid_reference`` entry.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    if "build_grid_reference" not in ground_truth:
        raise HTTPException(status_code=404, detail="No ground truth reference found")

    # Drop only the reference; other ground_truth entries are written back.
    ground_truth.pop("build_grid_reference")
    await update_session_db(session_id, ground_truth=ground_truth)

    logger.info("Ground truth removed for session %s", session_id)
    return {"status": "ok", "session_id": session_id}
@router.get("/ground-truth-sessions")
async def list_ground_truth_sessions():
    """Return every session that carries a ground-truth reference.

    Each entry exposes the session's identifying fields together with the
    pipeline, timestamp and summary stored in its reference snapshot.
    """

    def _describe(row):
        # Pull the stored snapshot (if any) out of the ground_truth JSONB.
        snapshot = (row.get("ground_truth") or {}).get("build_grid_reference", {})
        return {
            "session_id": row["id"],
            "name": row.get("name", ""),
            "filename": row.get("filename", ""),
            "document_category": row.get("document_category"),
            "pipeline": snapshot.get("pipeline"),
            "saved_at": snapshot.get("saved_at"),
            "summary": snapshot.get("summary", {}),
        }

    rows = await list_ground_truth_sessions_db()
    described = [_describe(row) for row in rows]
    return {"sessions": described, "count": len(described)}
@router.post("/sessions/{session_id}/regression/run")
async def run_single_regression(session_id: str):
    """Re-run build_grid for a single session and compare to ground truth.

    Responds with 404 if the session does not exist and with 400 if it has
    no ground-truth reference. A failure while rebuilding the grid is not
    raised; it is logged and reported in the response body with
    ``status: "error"``. On success the response contains the full diff and
    both summaries.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    gt = session.get("ground_truth") or {}
    reference = gt.get("build_grid_reference")
    if not reference:
        raise HTTPException(
            status_code=400,
            detail="No ground truth reference found for this session",
        )

    # Re-compute the grid; the result is only compared here, not stored.
    try:
        new_result = await _build_grid_core(session_id, session)
    except Exception as e:
        # Report the failure to the client, but also leave a server-side
        # trace so that broken sessions can be diagnosed from the logs.
        logger.warning("Regression re-build failed for session %s: %s", session_id, e)
        return {
            "session_id": session_id,
            "name": session.get("name", ""),
            "status": "error",
            "error": str(e),
        }

    new_snapshot = _build_reference_snapshot(new_result)
    diff = compare_grids(reference, new_snapshot)

    diff_summary = diff["summary"]
    # Everything except the structural counter is a per-cell category.
    cell_diff_count = sum(
        count for key, count in diff_summary.items() if key != "structural_changes"
    )
    logger.info(
        "Regression test session %s: %s (%d structural, %d cell diffs)",
        session_id, diff["status"],
        diff_summary["structural_changes"],
        cell_diff_count,
    )

    return {
        "session_id": session_id,
        "name": session.get("name", ""),
        "status": diff["status"],
        "diff": diff,
        "reference_summary": reference.get("summary", {}),
        "current_summary": new_snapshot.get("summary", {}),
    }
@router.post("/regression/run")
async def run_all_regressions(
    triggered_by: str = Query("manual", description="Who triggered: manual, script, ci"),
):
    """Re-run build_grid for ALL ground-truth sessions and compare.

    Each session with a stored reference is reloaded, rebuilt and diffed.
    Sessions that cannot be reloaded or whose rebuild raises are counted as
    errors and do not abort the suite. The overall status is ``"pass"`` only
    when there are neither failures nor errors. The run is then persisted
    best-effort and its ID (``""`` if persistence failed) is returned along
    with the per-session results.

    When no ground-truth sessions exist, a ``"pass"`` response with empty
    results is returned immediately and nothing is persisted.
    """
    start_time = time.monotonic()
    sessions = await list_ground_truth_sessions_db()
    if not sessions:
        return {
            "status": "pass",
            "message": "No ground truth sessions found",
            "results": [],
            "summary": {"total": 0, "passed": 0, "failed": 0, "errors": 0},
        }

    results = []
    passed = 0
    failed = 0
    errors = 0

    for s in sessions:
        session_id = s["id"]
        name = s.get("name", "")
        gt = s.get("ground_truth") or {}
        reference = gt.get("build_grid_reference")
        if not reference:
            continue

        # Re-load full session (list query may not include all JSONB fields)
        full_session = await get_session_db(session_id)
        if not full_session:
            results.append({
                "session_id": session_id,
                "name": name,
                "status": "error",
                "error": "Session not found during re-load",
            })
            errors += 1
            continue

        try:
            new_result = await _build_grid_core(session_id, full_session)
        except Exception as e:
            # One broken session must not abort the suite; record it as an
            # error and keep a server-side trace of the cause.
            logger.warning(
                "Regression re-build failed for session %s: %s", session_id, e,
            )
            results.append({
                "session_id": session_id,
                "name": name,
                "status": "error",
                "error": str(e),
            })
            errors += 1
            continue

        new_snapshot = _build_reference_snapshot(new_result)
        diff = compare_grids(reference, new_snapshot)
        entry = {
            "session_id": session_id,
            "name": name,
            "status": diff["status"],
            "diff_summary": diff["summary"],
            "reference_summary": reference.get("summary", {}),
            "current_summary": new_snapshot.get("summary", {}),
        }
        # Include full diffs only for failures (keep response compact)
        if diff["status"] == "fail":
            entry["structural_diffs"] = diff["structural_diffs"]
            entry["cell_diffs"] = diff["cell_diffs"]
            failed += 1
        else:
            passed += 1
        results.append(entry)

    overall = "pass" if failed == 0 and errors == 0 else "fail"
    duration_ms = int((time.monotonic() - start_time) * 1000)
    summary = {
        "total": len(results),
        "passed": passed,
        "failed": failed,
        "errors": errors,
    }

    # Separator between status and first counter; without it the message
    # rendered as e.g. "pass3 passed".
    logger.info(
        "Regression suite: %s — %d passed, %d failed, %d errors (of %d) in %dms",
        overall, passed, failed, errors, len(results), duration_ms,
    )

    # Persist to DB
    run_id = await _persist_regression_run(
        status=overall,
        summary=summary,
        results=results,
        duration_ms=duration_ms,
        triggered_by=triggered_by,
    )

    return {
        "status": overall,
        "run_id": run_id,
        "duration_ms": duration_ms,
        "results": results,
        "summary": summary,
    }
@router.get("/regression/history")
async def get_regression_history(
    limit: int = Query(20, ge=1, le=100),
):
    """Return the most recent regression runs, newest first.

    Only the run metadata is returned, not the per-session results. If the
    lookup fails, the error is logged and an empty list is returned together
    with an ``error`` message instead of raising.
    """
    plain_columns = (
        "status", "total", "passed", "failed", "errors",
        "duration_ms", "triggered_by",
    )
    try:
        await _init_regression_table()
        db_pool = await get_pool()
        async with db_pool.acquire() as connection:
            records = await connection.fetch(
                """
                SELECT id, run_at, status, total, passed, failed, errors,
                duration_ms, triggered_by
                FROM regression_runs
                ORDER BY run_at DESC
                LIMIT $1
                """,
                limit,
            )

        runs = []
        for record in records:
            started = record["run_at"]
            item = {
                "id": str(record["id"]),
                "run_at": started.isoformat() if started else None,
            }
            # Remaining columns are passed through unchanged.
            for column in plain_columns:
                item[column] = record[column]
            runs.append(item)

        return {"runs": runs, "count": len(records)}
    except Exception as e:
        logger.warning("Failed to fetch regression history: %s", e)
        return {"runs": [], "count": 0, "error": str(e)}
@router.get("/regression/history/{run_id}")
async def get_regression_run_detail(run_id: str):
    """Get detailed results of a specific regression run.

    Responds with 404 if ``run_id`` is not a valid UUID string or if no run
    with that ID exists, and with 500 for any other failure. The response
    contains the run metadata plus the stored per-session ``results``.
    """
    # Run IDs are generated with uuid4 when a run is persisted, so anything
    # that does not parse as a UUID cannot refer to a stored run. Rejecting
    # it here avoids a driver/database error surfacing as a 500.
    # NOTE(review): assumes no rows with non-UUID ids exist — confirm
    # against migration 008.
    try:
        uuid.UUID(run_id)
    except ValueError:
        raise HTTPException(status_code=404, detail="Run not found") from None

    try:
        await _init_regression_table()
        pool = await get_pool()
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                "SELECT * FROM regression_runs WHERE id = $1",
                run_id,
            )
        if not row:
            raise HTTPException(status_code=404, detail="Run not found")

        # Depending on how the pool is configured, the driver may hand back
        # the jsonb column as raw JSON text or as an already-decoded object;
        # accept both instead of failing in json.loads.
        raw_results = row["results"]
        if isinstance(raw_results, (str, bytes, bytearray)):
            results = json.loads(raw_results) if raw_results else []
        else:
            results = raw_results or []

        return {
            "id": str(row["id"]),
            "run_at": row["run_at"].isoformat() if row["run_at"] else None,
            "status": row["status"],
            "total": row["total"],
            "passed": row["passed"],
            "failed": row["failed"],
            "errors": row["errors"],
            "duration_ms": row["duration_ms"],
            "triggered_by": row["triggered_by"],
            "results": results,
        }
    except HTTPException:
        raise
    except Exception as e:
        # Chain the cause so the original traceback is preserved in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e