""" OCR Pipeline Regression Tests — Ground Truth comparison system. Allows marking sessions as "ground truth" and re-running build_grid() to detect regressions after code changes. Lizenz: Apache 2.0 DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. """ import logging from datetime import datetime, timezone from typing import Any, Dict, List, Optional from fastapi import APIRouter, HTTPException, Query from grid_editor_api import _build_grid_core from ocr_pipeline_session_store import ( get_session_db, list_ground_truth_sessions_db, update_session_db, ) logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["regression"]) # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _extract_cells_for_comparison(grid_result: dict) -> List[Dict[str, Any]]: """Extract a flat list of cells from a grid_editor_result for comparison. Only keeps fields relevant for comparison: cell_id, row_index, col_index, col_type, text. Ignores confidence, bbox, word_boxes, duration, is_bold. """ cells = [] for zone in grid_result.get("zones", []): for cell in zone.get("cells", []): cells.append({ "cell_id": cell.get("cell_id", ""), "row_index": cell.get("row_index"), "col_index": cell.get("col_index"), "col_type": cell.get("col_type", ""), "text": cell.get("text", ""), }) return cells def _build_reference_snapshot( grid_result: dict, pipeline: Optional[str] = None, ) -> dict: """Build a ground-truth reference snapshot from a grid_editor_result.""" cells = _extract_cells_for_comparison(grid_result) total_zones = len(grid_result.get("zones", [])) total_columns = sum(len(z.get("columns", [])) for z in grid_result.get("zones", [])) total_rows = sum(len(z.get("rows", [])) for z in grid_result.get("zones", [])) snapshot = { "saved_at": datetime.now(timezone.utc).isoformat(), "version": 1, "pipeline": pipeline, "summary": { "total_zones": total_zones, "total_columns": total_columns, "total_rows": total_rows, "total_cells": len(cells), }, "cells": cells, } return snapshot def compare_grids(reference: dict, current: dict) -> dict: """Compare a reference grid snapshot with a newly computed one. Returns a diff report with: - status: "pass" or "fail" - structural_diffs: changes in zone/row/column counts - cell_diffs: list of individual cell changes """ ref_summary = reference.get("summary", {}) cur_summary = current.get("summary", {}) structural_diffs = [] for key in ("total_zones", "total_columns", "total_rows", "total_cells"): ref_val = ref_summary.get(key, 0) cur_val = cur_summary.get(key, 0) if ref_val != cur_val: structural_diffs.append({ "field": key, "reference": ref_val, "current": cur_val, }) # Build cell lookup by cell_id ref_cells = {c["cell_id"]: c for c in reference.get("cells", [])} cur_cells = {c["cell_id"]: c for c in current.get("cells", [])} cell_diffs: List[Dict[str, Any]] = [] # Check for missing cells (in reference but not in current) for cell_id in ref_cells: if cell_id not in cur_cells: cell_diffs.append({ "type": "cell_missing", "cell_id": cell_id, "reference_text": ref_cells[cell_id].get("text", ""), }) # Check for added cells (in current but not in reference) for cell_id in cur_cells: if cell_id not in ref_cells: cell_diffs.append({ "type": "cell_added", "cell_id": cell_id, "current_text": cur_cells[cell_id].get("text", ""), }) # Check for changes in shared cells for cell_id in ref_cells: if cell_id not in cur_cells: continue ref_cell = ref_cells[cell_id] cur_cell = cur_cells[cell_id] if ref_cell.get("text", "") != cur_cell.get("text", ""): cell_diffs.append({ "type": "text_change", "cell_id": cell_id, "reference": ref_cell.get("text", ""), "current": cur_cell.get("text", ""), }) if ref_cell.get("col_type", "") != cur_cell.get("col_type", ""): cell_diffs.append({ "type": "col_type_change", "cell_id": cell_id, "reference": ref_cell.get("col_type", ""), "current": cur_cell.get("col_type", ""), }) status = "pass" if not structural_diffs and not cell_diffs else "fail" return { "status": status, "structural_diffs": structural_diffs, "cell_diffs": cell_diffs, "summary": { "structural_changes": len(structural_diffs), "cells_missing": sum(1 for d in cell_diffs if d["type"] == "cell_missing"), "cells_added": sum(1 for d in cell_diffs if d["type"] == "cell_added"), "text_changes": sum(1 for d in cell_diffs if d["type"] == "text_change"), "col_type_changes": sum(1 for d in cell_diffs if d["type"] == "col_type_change"), }, } # --------------------------------------------------------------------------- # Endpoints # --------------------------------------------------------------------------- @router.post("/sessions/{session_id}/mark-ground-truth") async def mark_ground_truth( session_id: str, pipeline: Optional[str] = Query(None, description="Pipeline used: kombi, pipeline, paddle-direct"), ): """Save the current build-grid result as ground-truth reference.""" session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") grid_result = session.get("grid_editor_result") if not grid_result or not grid_result.get("zones"): raise HTTPException( status_code=400, detail="No grid_editor_result found. Run build-grid first.", ) # Auto-detect pipeline from word_result if not provided if not pipeline: wr = session.get("word_result") or {} engine = wr.get("ocr_engine", "") if engine in ("kombi", "rapid_kombi"): pipeline = "kombi" elif engine == "paddle_direct": pipeline = "paddle-direct" else: pipeline = "pipeline" reference = _build_reference_snapshot(grid_result, pipeline=pipeline) # Merge into existing ground_truth JSONB gt = session.get("ground_truth") or {} gt["build_grid_reference"] = reference await update_session_db(session_id, ground_truth=gt) logger.info( "Ground truth marked for session %s: %d cells", session_id, len(reference["cells"]), ) return { "status": "ok", "session_id": session_id, "cells_saved": len(reference["cells"]), "summary": reference["summary"], } @router.delete("/sessions/{session_id}/mark-ground-truth") async def unmark_ground_truth(session_id: str): """Remove the ground-truth reference from a session.""" session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") gt = session.get("ground_truth") or {} if "build_grid_reference" not in gt: raise HTTPException(status_code=404, detail="No ground truth reference found") del gt["build_grid_reference"] await update_session_db(session_id, ground_truth=gt) logger.info("Ground truth removed for session %s", session_id) return {"status": "ok", "session_id": session_id} @router.get("/ground-truth-sessions") async def list_ground_truth_sessions(): """List all sessions that have a ground-truth reference.""" sessions = await list_ground_truth_sessions_db() result = [] for s in sessions: gt = s.get("ground_truth") or {} ref = gt.get("build_grid_reference", {}) result.append({ "session_id": s["id"], "name": s.get("name", ""), "filename": s.get("filename", ""), "document_category": s.get("document_category"), "pipeline": ref.get("pipeline"), "saved_at": ref.get("saved_at"), "summary": ref.get("summary", {}), }) return {"sessions": result, "count": len(result)} @router.post("/sessions/{session_id}/regression/run") async def run_single_regression(session_id: str): """Re-run build_grid for a single session and compare to ground truth.""" session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") gt = session.get("ground_truth") or {} reference = gt.get("build_grid_reference") if not reference: raise HTTPException( status_code=400, detail="No ground truth reference found for this session", ) # Re-compute grid without persisting try: new_result = await _build_grid_core(session_id, session) except (ValueError, Exception) as e: return { "session_id": session_id, "name": session.get("name", ""), "status": "error", "error": str(e), } new_snapshot = _build_reference_snapshot(new_result) diff = compare_grids(reference, new_snapshot) logger.info( "Regression test session %s: %s (%d structural, %d cell diffs)", session_id, diff["status"], diff["summary"]["structural_changes"], sum(v for k, v in diff["summary"].items() if k != "structural_changes"), ) return { "session_id": session_id, "name": session.get("name", ""), "status": diff["status"], "diff": diff, "reference_summary": reference.get("summary", {}), "current_summary": new_snapshot.get("summary", {}), } @router.post("/regression/run") async def run_all_regressions(): """Re-run build_grid for ALL ground-truth sessions and compare.""" sessions = await list_ground_truth_sessions_db() if not sessions: return { "status": "pass", "message": "No ground truth sessions found", "results": [], "summary": {"total": 0, "passed": 0, "failed": 0, "errors": 0}, } results = [] passed = 0 failed = 0 errors = 0 for s in sessions: session_id = s["id"] gt = s.get("ground_truth") or {} reference = gt.get("build_grid_reference") if not reference: continue # Re-load full session (list query may not include all JSONB fields) full_session = await get_session_db(session_id) if not full_session: results.append({ "session_id": session_id, "name": s.get("name", ""), "status": "error", "error": "Session not found during re-load", }) errors += 1 continue try: new_result = await _build_grid_core(session_id, full_session) except (ValueError, Exception) as e: results.append({ "session_id": session_id, "name": s.get("name", ""), "status": "error", "error": str(e), }) errors += 1 continue new_snapshot = _build_reference_snapshot(new_result) diff = compare_grids(reference, new_snapshot) entry = { "session_id": session_id, "name": s.get("name", ""), "status": diff["status"], "diff_summary": diff["summary"], "reference_summary": reference.get("summary", {}), "current_summary": new_snapshot.get("summary", {}), } # Include full diffs only for failures (keep response compact) if diff["status"] == "fail": entry["structural_diffs"] = diff["structural_diffs"] entry["cell_diffs"] = diff["cell_diffs"] failed += 1 else: passed += 1 results.append(entry) overall = "pass" if failed == 0 and errors == 0 else "fail" logger.info( "Regression suite: %s — %d passed, %d failed, %d errors (of %d)", overall, passed, failed, errors, len(results), ) return { "status": overall, "results": results, "summary": { "total": len(results), "passed": passed, "failed": failed, "errors": errors, }, }