""" OCR Pipeline Regression Endpoints — FastAPI routes for ground truth and regression. Extracted from ocr_pipeline_regression.py for modularity. Lizenz: Apache 2.0 DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. """ import json import logging import time from typing import Any, Dict, Optional from fastapi import APIRouter, HTTPException, Query from grid_editor_api import _build_grid_core from ocr_pipeline_session_store import ( get_session_db, list_ground_truth_sessions_db, update_session_db, ) from ocr_pipeline_regression_helpers import ( _build_reference_snapshot, _init_regression_table, _persist_regression_run, compare_grids, get_pool, ) logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["regression"]) # --------------------------------------------------------------------------- # Endpoints # --------------------------------------------------------------------------- @router.post("/sessions/{session_id}/mark-ground-truth") async def mark_ground_truth( session_id: str, pipeline: Optional[str] = Query(None, description="Pipeline used: kombi, pipeline, paddle-direct"), ): """Save the current build-grid result as ground-truth reference.""" session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") grid_result = session.get("grid_editor_result") if not grid_result or not grid_result.get("zones"): raise HTTPException( status_code=400, detail="No grid_editor_result found. Run build-grid first.", ) # Auto-detect pipeline from word_result if not provided if not pipeline: wr = session.get("word_result") or {} engine = wr.get("ocr_engine", "") if engine in ("kombi", "rapid_kombi"): pipeline = "kombi" elif engine == "paddle_direct": pipeline = "paddle-direct" else: pipeline = "pipeline" reference = _build_reference_snapshot(grid_result, pipeline=pipeline) # Merge into existing ground_truth JSONB gt = session.get("ground_truth") or {} gt["build_grid_reference"] = reference await update_session_db(session_id, ground_truth=gt, current_step=11) # Compare with auto-snapshot if available (shows what the user corrected) auto_snapshot = gt.get("auto_grid_snapshot") correction_diff = None if auto_snapshot: correction_diff = compare_grids(auto_snapshot, reference) logger.info( "Ground truth marked for session %s: %d cells (corrections: %s)", session_id, len(reference["cells"]), correction_diff["summary"] if correction_diff else "no auto-snapshot", ) return { "status": "ok", "session_id": session_id, "cells_saved": len(reference["cells"]), "summary": reference["summary"], "correction_diff": correction_diff, } @router.delete("/sessions/{session_id}/mark-ground-truth") async def unmark_ground_truth(session_id: str): """Remove the ground-truth reference from a session.""" session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") gt = session.get("ground_truth") or {} if "build_grid_reference" not in gt: raise HTTPException(status_code=404, detail="No ground truth reference found") del gt["build_grid_reference"] await update_session_db(session_id, ground_truth=gt) logger.info("Ground truth removed for session %s", session_id) return {"status": "ok", "session_id": session_id} @router.get("/sessions/{session_id}/correction-diff") async def get_correction_diff(session_id: str): """Compare automatic OCR grid with manually corrected ground truth. Returns a diff showing exactly which cells the user corrected, broken down by col_type (english, german, ipa, etc.). """ session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") gt = session.get("ground_truth") or {} auto_snapshot = gt.get("auto_grid_snapshot") reference = gt.get("build_grid_reference") if not auto_snapshot: raise HTTPException( status_code=404, detail="No auto_grid_snapshot found. Re-run build-grid to create one.", ) if not reference: raise HTTPException( status_code=404, detail="No ground truth reference found. Mark as ground truth first.", ) diff = compare_grids(auto_snapshot, reference) # Enrich with per-col_type breakdown col_type_stats: Dict[str, Dict[str, int]] = {} for cell_diff in diff.get("cell_diffs", []): if cell_diff["type"] != "text_change": continue # Find col_type from reference cells cell_id = cell_diff["cell_id"] ref_cell = next( (c for c in reference.get("cells", []) if c["cell_id"] == cell_id), None, ) ct = ref_cell.get("col_type", "unknown") if ref_cell else "unknown" if ct not in col_type_stats: col_type_stats[ct] = {"total": 0, "corrected": 0} col_type_stats[ct]["corrected"] += 1 # Count total cells per col_type from reference for cell in reference.get("cells", []): ct = cell.get("col_type", "unknown") if ct not in col_type_stats: col_type_stats[ct] = {"total": 0, "corrected": 0} col_type_stats[ct]["total"] += 1 # Calculate accuracy per col_type for ct, stats in col_type_stats.items(): total = stats["total"] corrected = stats["corrected"] stats["accuracy_pct"] = round((total - corrected) / total * 100, 1) if total > 0 else 100.0 diff["col_type_breakdown"] = col_type_stats return diff @router.get("/ground-truth-sessions") async def list_ground_truth_sessions(): """List all sessions that have a ground-truth reference.""" sessions = await list_ground_truth_sessions_db() result = [] for s in sessions: gt = s.get("ground_truth") or {} ref = gt.get("build_grid_reference", {}) result.append({ "session_id": s["id"], "name": s.get("name", ""), "filename": s.get("filename", ""), "document_category": s.get("document_category"), "pipeline": ref.get("pipeline"), "saved_at": ref.get("saved_at"), "summary": ref.get("summary", {}), }) return {"sessions": result, "count": len(result)} @router.post("/sessions/{session_id}/regression/run") async def run_single_regression(session_id: str): """Re-run build_grid for a single session and compare to ground truth.""" session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") gt = session.get("ground_truth") or {} reference = gt.get("build_grid_reference") if not reference: raise HTTPException( status_code=400, detail="No ground truth reference found for this session", ) # Re-compute grid without persisting try: new_result = await _build_grid_core(session_id, session) except (ValueError, Exception) as e: return { "session_id": session_id, "name": session.get("name", ""), "status": "error", "error": str(e), } new_snapshot = _build_reference_snapshot(new_result) diff = compare_grids(reference, new_snapshot) logger.info( "Regression test session %s: %s (%d structural, %d cell diffs)", session_id, diff["status"], diff["summary"]["structural_changes"], sum(v for k, v in diff["summary"].items() if k != "structural_changes"), ) return { "session_id": session_id, "name": session.get("name", ""), "status": diff["status"], "diff": diff, "reference_summary": reference.get("summary", {}), "current_summary": new_snapshot.get("summary", {}), } @router.post("/regression/run") async def run_all_regressions( triggered_by: str = Query("manual", description="Who triggered: manual, script, ci"), ): """Re-run build_grid for ALL ground-truth sessions and compare.""" start_time = time.monotonic() sessions = await list_ground_truth_sessions_db() if not sessions: return { "status": "pass", "message": "No ground truth sessions found", "results": [], "summary": {"total": 0, "passed": 0, "failed": 0, "errors": 0}, } results = [] passed = 0 failed = 0 errors = 0 for s in sessions: session_id = s["id"] gt = s.get("ground_truth") or {} reference = gt.get("build_grid_reference") if not reference: continue # Re-load full session (list query may not include all JSONB fields) full_session = await get_session_db(session_id) if not full_session: results.append({ "session_id": session_id, "name": s.get("name", ""), "status": "error", "error": "Session not found during re-load", }) errors += 1 continue try: new_result = await _build_grid_core(session_id, full_session) except (ValueError, Exception) as e: results.append({ "session_id": session_id, "name": s.get("name", ""), "status": "error", "error": str(e), }) errors += 1 continue new_snapshot = _build_reference_snapshot(new_result) diff = compare_grids(reference, new_snapshot) entry = { "session_id": session_id, "name": s.get("name", ""), "status": diff["status"], "diff_summary": diff["summary"], "reference_summary": reference.get("summary", {}), "current_summary": new_snapshot.get("summary", {}), } # Include full diffs only for failures (keep response compact) if diff["status"] == "fail": entry["structural_diffs"] = diff["structural_diffs"] entry["cell_diffs"] = diff["cell_diffs"] failed += 1 else: passed += 1 results.append(entry) overall = "pass" if failed == 0 and errors == 0 else "fail" duration_ms = int((time.monotonic() - start_time) * 1000) summary = { "total": len(results), "passed": passed, "failed": failed, "errors": errors, } logger.info( "Regression suite: %s — %d passed, %d failed, %d errors (of %d) in %dms", overall, passed, failed, errors, len(results), duration_ms, ) # Persist to DB run_id = await _persist_regression_run( status=overall, summary=summary, results=results, duration_ms=duration_ms, triggered_by=triggered_by, ) return { "status": overall, "run_id": run_id, "duration_ms": duration_ms, "results": results, "summary": summary, } @router.get("/regression/history") async def get_regression_history( limit: int = Query(20, ge=1, le=100), ): """Get recent regression run history from the database.""" try: await _init_regression_table() pool = await get_pool() async with pool.acquire() as conn: rows = await conn.fetch( """ SELECT id, run_at, status, total, passed, failed, errors, duration_ms, triggered_by FROM regression_runs ORDER BY run_at DESC LIMIT $1 """, limit, ) return { "runs": [ { "id": str(row["id"]), "run_at": row["run_at"].isoformat() if row["run_at"] else None, "status": row["status"], "total": row["total"], "passed": row["passed"], "failed": row["failed"], "errors": row["errors"], "duration_ms": row["duration_ms"], "triggered_by": row["triggered_by"], } for row in rows ], "count": len(rows), } except Exception as e: logger.warning("Failed to fetch regression history: %s", e) return {"runs": [], "count": 0, "error": str(e)} @router.get("/regression/history/{run_id}") async def get_regression_run_detail(run_id: str): """Get detailed results of a specific regression run.""" try: await _init_regression_table() pool = await get_pool() async with pool.acquire() as conn: row = await conn.fetchrow( "SELECT * FROM regression_runs WHERE id = $1", run_id, ) if not row: raise HTTPException(status_code=404, detail="Run not found") return { "id": str(row["id"]), "run_at": row["run_at"].isoformat() if row["run_at"] else None, "status": row["status"], "total": row["total"], "passed": row["passed"], "failed": row["failed"], "errors": row["errors"], "duration_ms": row["duration_ms"], "triggered_by": row["triggered_by"], "results": json.loads(row["results"]) if row["results"] else [], } except HTTPException: raise except Exception as e: raise HTTPException(status_code=500, detail=str(e))