Files
breakpilot-lehrer/klausur-service/backend/ocr_pipeline_regression.py
Benjamin Admin a3e2a7f994
Some checks failed
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 27s
CI / test-go-edu-search (push) Successful in 26s
CI / test-python-klausur (push) Failing after 1m51s
CI / test-python-agent-core (push) Successful in 15s
CI / test-nodejs-website (push) Successful in 18s
Add GT button to OCR overlay, prominent category picker, track pipeline
- Ground Truth button on last step of Pipeline/Kombi modes in ocr-overlay
- Prominent category picker in active session info bar (pulses when unset)
- GT badge shown when session has ground truth reference
- Backend: auto-detect pipeline from ocr_engine, store in GT snapshot
- Pipeline info shown in GT session list and regression reports
- Also pass pipeline param from ocr-pipeline StepGroundTruth

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 14:49:02 +01:00

389 lines
13 KiB
Python

"""
OCR Pipeline Regression Tests — Ground Truth comparison system.
Allows marking sessions as "ground truth" and re-running build_grid()
to detect regressions after code changes.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, HTTPException, Query
from grid_editor_api import _build_grid_core
from ocr_pipeline_session_store import (
get_session_db,
list_ground_truth_sessions_db,
update_session_db,
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Router for the ground-truth / regression endpoints below; every route is
# mounted under /api/v1/ocr-pipeline and tagged "regression".
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["regression"])
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _extract_cells_for_comparison(grid_result: dict) -> List[Dict[str, Any]]:
"""Extract a flat list of cells from a grid_editor_result for comparison.
Only keeps fields relevant for comparison: cell_id, row_index, col_index,
col_type, text. Ignores confidence, bbox, word_boxes, duration, is_bold.
"""
cells = []
for zone in grid_result.get("zones", []):
for cell in zone.get("cells", []):
cells.append({
"cell_id": cell.get("cell_id", ""),
"row_index": cell.get("row_index"),
"col_index": cell.get("col_index"),
"col_type": cell.get("col_type", ""),
"text": cell.get("text", ""),
})
return cells
def _build_reference_snapshot(
    grid_result: dict,
    pipeline: Optional[str] = None,
) -> dict:
    """Create a ground-truth snapshot dict from a grid_editor_result.

    The snapshot contains:
      - ``saved_at``: current UTC time as an ISO-8601 string
      - ``version``: snapshot format version (always 1 here)
      - ``pipeline``: the label passed in (may be None)
      - ``summary``: zone / column / row / cell totals
      - ``cells``: the flat list from _extract_cells_for_comparison()
    """
    zones = grid_result.get("zones", [])
    flat_cells = _extract_cells_for_comparison(grid_result)
    return {
        "saved_at": datetime.now(timezone.utc).isoformat(),
        "version": 1,
        "pipeline": pipeline,
        "summary": {
            "total_zones": len(zones),
            "total_columns": sum(len(zone.get("columns", [])) for zone in zones),
            "total_rows": sum(len(zone.get("rows", [])) for zone in zones),
            "total_cells": len(flat_cells),
        },
        "cells": flat_cells,
    }
def compare_grids(reference: dict, current: dict) -> dict:
    """Diff a reference grid snapshot against a newly computed one.

    Both arguments are snapshot dicts with ``summary`` and ``cells`` keys (as
    built by _build_reference_snapshot); absent keys are treated as empty.

    Cells are matched by ``cell_id``. For cells present on both sides only
    ``text`` and ``col_type`` are compared.

    Returns a report dict:
      - status: "pass" when there are no structural and no cell diffs,
        otherwise "fail"
      - structural_diffs: summary totals (zones/columns/rows/cells) that differ
      - cell_diffs: per-cell changes, in the order missing, added, then
        text/col_type changes
      - summary: counts of each diff kind
    """
    ref_totals = reference.get("summary", {})
    cur_totals = current.get("summary", {})
    structural_diffs = [
        {
            "field": field,
            "reference": ref_totals.get(field, 0),
            "current": cur_totals.get(field, 0),
        }
        for field in ("total_zones", "total_columns", "total_rows", "total_cells")
        if ref_totals.get(field, 0) != cur_totals.get(field, 0)
    ]

    # Index both sides by cell_id; if an id repeats, the last cell wins.
    ref_by_id = {entry["cell_id"]: entry for entry in reference.get("cells", [])}
    cur_by_id = {entry["cell_id"]: entry for entry in current.get("cells", [])}

    # Cells only in the reference.
    cell_diffs: List[Dict[str, Any]] = [
        {
            "type": "cell_missing",
            "cell_id": cid,
            "reference_text": ref_cell.get("text", ""),
        }
        for cid, ref_cell in ref_by_id.items()
        if cid not in cur_by_id
    ]
    # Cells only in the current grid.
    cell_diffs.extend(
        {
            "type": "cell_added",
            "cell_id": cid,
            "current_text": cur_cell.get("text", ""),
        }
        for cid, cur_cell in cur_by_id.items()
        if cid not in ref_by_id
    )
    # Field-level changes on shared cells; text is checked before col_type.
    for cid, ref_cell in ref_by_id.items():
        if cid not in cur_by_id:
            continue
        cur_cell = cur_by_id[cid]
        for field, change_type in (("text", "text_change"), ("col_type", "col_type_change")):
            before = ref_cell.get(field, "")
            after = cur_cell.get(field, "")
            if before != after:
                cell_diffs.append({
                    "type": change_type,
                    "cell_id": cid,
                    "reference": before,
                    "current": after,
                })

    counts = dict.fromkeys(("cell_missing", "cell_added", "text_change", "col_type_change"), 0)
    for entry in cell_diffs:
        counts[entry["type"]] += 1

    return {
        "status": "fail" if structural_diffs or cell_diffs else "pass",
        "structural_diffs": structural_diffs,
        "cell_diffs": cell_diffs,
        "summary": {
            "structural_changes": len(structural_diffs),
            "cells_missing": counts["cell_missing"],
            "cells_added": counts["cell_added"],
            "text_changes": counts["text_change"],
            "col_type_changes": counts["col_type_change"],
        },
    }
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/mark-ground-truth")
async def mark_ground_truth(
    session_id: str,
    pipeline: Optional[str] = Query(None, description="Pipeline used: kombi, pipeline, paddle-direct"),
):
    """Store the session's current build-grid result as its ground-truth reference.

    Query parameters:
        pipeline: label saved with the snapshot. When omitted or empty it is
            derived from ``word_result["ocr_engine"]``: "kombi" and
            "rapid_kombi" map to "kombi", "paddle_direct" maps to
            "paddle-direct", and anything else maps to "pipeline".

    Raises:
        HTTPException(404): the session does not exist.
        HTTPException(400): there is no grid_editor_result, or it has no zones.

    Returns the status, session id, number of saved cells and the snapshot summary.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    grid_result = session.get("grid_editor_result")
    if not (grid_result and grid_result.get("zones")):
        raise HTTPException(
            status_code=400,
            detail="No grid_editor_result found. Run build-grid first.",
        )

    # Caller did not say which pipeline produced this grid: infer it from the OCR engine.
    if not pipeline:
        engine = (session.get("word_result") or {}).get("ocr_engine", "")
        pipeline = (
            "kombi" if engine in ("kombi", "rapid_kombi")
            else "paddle-direct" if engine == "paddle_direct"
            else "pipeline"
        )

    reference = _build_reference_snapshot(grid_result, pipeline=pipeline)

    # Merge into the existing ground_truth JSONB so other keys stored there survive.
    ground_truth = session.get("ground_truth") or {}
    ground_truth["build_grid_reference"] = reference
    await update_session_db(session_id, ground_truth=ground_truth)

    saved_count = len(reference["cells"])
    logger.info(
        "Ground truth marked for session %s: %d cells",
        session_id, saved_count,
    )
    return {
        "status": "ok",
        "session_id": session_id,
        "cells_saved": saved_count,
        "summary": reference["summary"],
    }
@router.delete("/sessions/{session_id}/mark-ground-truth")
async def unmark_ground_truth(session_id: str):
    """Delete the build-grid ground-truth reference from a session.

    Other keys inside the session's ``ground_truth`` JSONB are kept.

    Raises:
        HTTPException(404): the session does not exist, or it has no
            ``build_grid_reference`` to remove.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    if "build_grid_reference" not in ground_truth:
        raise HTTPException(status_code=404, detail="No ground truth reference found")
    ground_truth.pop("build_grid_reference")

    await update_session_db(session_id, ground_truth=ground_truth)
    logger.info("Ground truth removed for session %s", session_id)
    return {"status": "ok", "session_id": session_id}
@router.get("/ground-truth-sessions")
async def list_ground_truth_sessions():
    """List all sessions that have a ground-truth reference.

    Each entry contains the session's id, name, filename and
    document_category. It also contains the reference's pipeline label,
    save timestamp (``saved_at``) and summary totals.

    Returns ``{"sessions": [...], "count": <number of entries>}``.
    """
    sessions = await list_ground_truth_sessions_db()
    result = []
    for s in sessions:
        gt = s.get("ground_truth") or {}
        # `or {}` (rather than a .get default) also covers a stored JSON null,
        # which would otherwise make the ref.get(...) calls below raise.
        ref = gt.get("build_grid_reference") or {}
        result.append({
            "session_id": s["id"],
            "name": s.get("name", ""),
            "filename": s.get("filename", ""),
            "document_category": s.get("document_category"),
            "pipeline": ref.get("pipeline"),
            "saved_at": ref.get("saved_at"),
            "summary": ref.get("summary", {}),
        })
    return {"sessions": result, "count": len(result)}
@router.post("/sessions/{session_id}/regression/run")
async def run_single_regression(session_id: str):
    """Re-run build_grid for one session and compare the result to its ground truth.

    Raises:
        HTTPException(404): the session does not exist.
        HTTPException(400): the session has no ``build_grid_reference``.

    Returns a report with the session id, name, the reference's pipeline
    label and the diff status. If recomputing the grid raises, the report
    has ``status == "error"`` and the exception text instead of a diff.
    The exception is logged with its traceback, not re-raised.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    gt = session.get("ground_truth") or {}
    reference = gt.get("build_grid_reference")
    if not reference:
        raise HTTPException(
            status_code=400,
            detail="No ground truth reference found for this session",
        )

    # Re-compute the grid; this endpoint only compares the result and never
    # writes it back to the session.
    try:
        new_result = await _build_grid_core(session_id, session)
    except Exception as e:
        # Surface the failure in the report, but keep the traceback in the
        # logs instead of silently discarding it.
        logger.exception("Regression test session %s: build_grid failed", session_id)
        return {
            "session_id": session_id,
            "name": session.get("name", ""),
            "pipeline": reference.get("pipeline"),
            "status": "error",
            "error": str(e),
        }

    new_snapshot = _build_reference_snapshot(new_result)
    diff = compare_grids(reference, new_snapshot)

    logger.info(
        "Regression test session %s: %s (%d structural, %d cell diffs)",
        session_id, diff["status"],
        diff["summary"]["structural_changes"],
        len(diff["cell_diffs"]),
    )

    return {
        "session_id": session_id,
        "name": session.get("name", ""),
        "pipeline": reference.get("pipeline"),
        "status": diff["status"],
        "diff": diff,
        "reference_summary": reference.get("summary", {}),
        "current_summary": new_snapshot.get("summary", {}),
    }
@router.post("/regression/run")
async def run_all_regressions():
    """Re-run build_grid for ALL ground-truth sessions and compare each result.

    Sessions returned by list_ground_truth_sessions_db() that carry no
    ``build_grid_reference`` are skipped and not counted. For every other
    session:

    1. The full session is re-loaded.
    2. Its grid is recomputed with _build_grid_core().
    3. The result is diffed against the stored reference.

    A session that cannot be re-loaded, or whose recompute raises, is
    reported with ``status == "error"``. Recompute exceptions are logged
    with their traceback.

    Returns:
        status: "pass" only when there are no failures and no errors.
        results: one entry per evaluated session. Full diffs are included
            only for failures.
        summary: total/passed/failed/errors counts.
    """
    sessions = await list_ground_truth_sessions_db()
    if not sessions:
        return {
            "status": "pass",
            "message": "No ground truth sessions found",
            "results": [],
            "summary": {"total": 0, "passed": 0, "failed": 0, "errors": 0},
        }

    results: List[Dict[str, Any]] = []
    passed = 0
    failed = 0
    errors = 0

    for s in sessions:
        session_id = s["id"]
        gt = s.get("ground_truth") or {}
        reference = gt.get("build_grid_reference")
        if not reference:
            continue
        pipeline = reference.get("pipeline")

        # Re-load full session (list query may not include all JSONB fields)
        full_session = await get_session_db(session_id)
        if not full_session:
            results.append({
                "session_id": session_id,
                "name": s.get("name", ""),
                "pipeline": pipeline,
                "status": "error",
                "error": "Session not found during re-load",
            })
            errors += 1
            continue

        try:
            new_result = await _build_grid_core(session_id, full_session)
        except Exception as e:
            # Keep going with the rest of the suite, but do not lose the traceback.
            logger.exception(
                "Regression suite: build_grid failed for session %s", session_id,
            )
            results.append({
                "session_id": session_id,
                "name": s.get("name", ""),
                "pipeline": pipeline,
                "status": "error",
                "error": str(e),
            })
            errors += 1
            continue

        new_snapshot = _build_reference_snapshot(new_result)
        diff = compare_grids(reference, new_snapshot)

        entry = {
            "session_id": session_id,
            "name": s.get("name", ""),
            "pipeline": pipeline,
            "status": diff["status"],
            "diff_summary": diff["summary"],
            "reference_summary": reference.get("summary", {}),
            "current_summary": new_snapshot.get("summary", {}),
        }
        # Include full diffs only for failures (keep response compact)
        if diff["status"] == "fail":
            entry["structural_diffs"] = diff["structural_diffs"]
            entry["cell_diffs"] = diff["cell_diffs"]
            failed += 1
        else:
            passed += 1
        results.append(entry)

    overall = "pass" if failed == 0 and errors == 0 else "fail"
    # The original format string lacked a separator ("pass3 passed").
    logger.info(
        "Regression suite: %s — %d passed, %d failed, %d errors (of %d)",
        overall, passed, failed, errors, len(results),
    )

    return {
        "status": overall,
        "results": results,
        "summary": {
            "total": len(results),
            "passed": passed,
            "failed": failed,
            "errors": errors,
        },
    }