Files
breakpilot-lehrer/klausur-service/backend/ocr_pipeline_regression_endpoints.py
Benjamin Admin bd4b956e3c [split-required] Split final 43 files (500-668 LOC) to complete refactoring
klausur-service (11 files):
- cv_gutter_repair, ocr_pipeline_regression, upload_api
- ocr_pipeline_sessions, smart_spell, nru_worksheet_generator
- ocr_pipeline_overlays, mail/aggregator, zeugnis_api
- cv_syllable_detect, self_rag

backend-lehrer (17 files):
- classroom_engine/suggestions, generators/quiz_generator
- worksheets_api, llm_gateway/comparison, state_engine_api
- classroom/models (→ 4 submodules), services/file_processor
- alerts_agent/api/wizard+digests+routes, content_generators/pdf
- classroom/routes/sessions, llm_gateway/inference
- classroom_engine/analytics, auth/keycloak_auth
- alerts_agent/processing/rule_engine, ai_processor/print_versions

agent-core (5 files):
- brain/memory_store, brain/knowledge_graph, brain/context_manager
- orchestrator/supervisor, sessions/session_manager

admin-lehrer (5 components):
- GridOverlay, StepGridReview, DevOpsPipelineSidebar
- DataFlowDiagram, sbom/wizard/page

website (2 files):
- DependencyMap, lehrer/abitur-archiv

Other: nibis_ingestion, grid_detection_service, export-doclayout-onnx

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-25 09:41:42 +02:00

422 lines
14 KiB
Python

"""
OCR Pipeline Regression Endpoints — FastAPI routes for ground truth and regression.
Extracted from ocr_pipeline_regression.py for modularity.

Routes: mark/unmark a session's build-grid result as ground truth, diff the
automatic grid against the corrected ground truth, list ground-truth sessions,
run regression checks (single session or all), and read regression run history.

License: Apache 2.0
Data protection: all processing happens locally.
"""
import json
import logging
import time
from typing import Any, Dict, Optional
from fastapi import APIRouter, HTTPException, Query
from grid_editor_api import _build_grid_core
from ocr_pipeline_session_store import (
get_session_db,
list_ground_truth_sessions_db,
update_session_db,
)
from ocr_pipeline_regression_helpers import (
_build_reference_snapshot,
_init_regression_table,
_persist_regression_run,
compare_grids,
get_pool,
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Every route below is mounted under /api/v1/ocr-pipeline and tagged "regression".
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["regression"])
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
def _pipeline_from_ocr_engine(engine: Any) -> str:
    """Translate the ``ocr_engine`` recorded in ``word_result`` into a pipeline label."""
    if engine in ("kombi", "rapid_kombi"):
        return "kombi"
    if engine == "paddle_direct":
        return "paddle-direct"
    return "pipeline"


@router.post("/sessions/{session_id}/mark-ground-truth")
async def mark_ground_truth(
    session_id: str,
    pipeline: Optional[str] = Query(None, description="Pipeline used: kombi, pipeline, paddle-direct"),
):
    """Save the current build-grid result as ground-truth reference.

    The snapshot is stored under ``ground_truth["build_grid_reference"]`` and the
    session's ``current_step`` is set to 11. When ``pipeline`` is empty it is
    derived from ``word_result["ocr_engine"]``. If an ``auto_grid_snapshot`` is
    present, the response also carries its diff against the new reference
    (i.e. what the user corrected).

    Raises:
        HTTPException 404: session not found.
        HTTPException 400: no ``grid_editor_result`` with zones yet.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    grid_result = session.get("grid_editor_result")
    if not (grid_result and grid_result.get("zones")):
        raise HTTPException(
            status_code=400,
            detail="No grid_editor_result found. Run build-grid first.",
        )

    if not pipeline:
        word_result = session.get("word_result") or {}
        pipeline = _pipeline_from_ocr_engine(word_result.get("ocr_engine", ""))

    reference = _build_reference_snapshot(grid_result, pipeline=pipeline)

    # Merge into the existing ground_truth JSONB so other keys survive.
    ground_truth = session.get("ground_truth") or {}
    ground_truth["build_grid_reference"] = reference
    await update_session_db(session_id, ground_truth=ground_truth, current_step=11)

    auto_snapshot = ground_truth.get("auto_grid_snapshot")
    correction_diff = compare_grids(auto_snapshot, reference) if auto_snapshot else None

    cell_count = len(reference["cells"])
    logger.info(
        "Ground truth marked for session %s: %d cells (corrections: %s)",
        session_id,
        cell_count,
        correction_diff["summary"] if correction_diff else "no auto-snapshot",
    )
    return {
        "status": "ok",
        "session_id": session_id,
        "cells_saved": cell_count,
        "summary": reference["summary"],
        "correction_diff": correction_diff,
    }
@router.delete("/sessions/{session_id}/mark-ground-truth")
async def unmark_ground_truth(session_id: str):
    """Drop ``build_grid_reference`` from the session's ground-truth data.

    Any other keys stored under ``ground_truth`` are written back unchanged.

    Raises:
        HTTPException 404: the session or its ground-truth reference is missing.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    has_reference = "build_grid_reference" in ground_truth
    if not has_reference:
        raise HTTPException(status_code=404, detail="No ground truth reference found")

    ground_truth.pop("build_grid_reference")
    await update_session_db(session_id, ground_truth=ground_truth)

    logger.info("Ground truth removed for session %s", session_id)
    return {"status": "ok", "session_id": session_id}
@router.get("/sessions/{session_id}/correction-diff")
async def get_correction_diff(session_id: str):
    """Compare automatic OCR grid with manually corrected ground truth.

    Returns the ``compare_grids(auto_grid_snapshot, build_grid_reference)`` diff,
    enriched with ``col_type_breakdown``: for each col_type (english, german,
    ipa, ...) the number of reference cells (``total``), the number of
    ``text_change`` diffs (``corrected``) and ``accuracy_pct``.

    Raises:
        HTTPException 404: session, auto snapshot, or ground-truth reference missing.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    gt = session.get("ground_truth") or {}
    auto_snapshot = gt.get("auto_grid_snapshot")
    reference = gt.get("build_grid_reference")
    if not auto_snapshot:
        raise HTTPException(
            status_code=404,
            detail="No auto_grid_snapshot found. Re-run build-grid to create one.",
        )
    if not reference:
        raise HTTPException(
            status_code=404,
            detail="No ground truth reference found. Mark as ground truth first.",
        )

    diff = compare_grids(auto_snapshot, reference)
    ref_cells = reference.get("cells", [])

    # Index reference cells by id once instead of scanning the whole list for
    # every diff entry. setdefault keeps the first cell per id, matching what a
    # linear first-match search returned. Cells without an id cannot be matched.
    cells_by_id: Dict[Any, Dict[str, Any]] = {}
    for cell in ref_cells:
        if "cell_id" in cell:
            cells_by_id.setdefault(cell["cell_id"], cell)

    col_type_stats: Dict[str, Dict[str, Any]] = {}

    def _stats_for(col_type: str) -> Dict[str, Any]:
        """Return (creating on first use) the counters for one col_type."""
        return col_type_stats.setdefault(col_type, {"total": 0, "corrected": 0})

    # Count text corrections per col_type (col_type taken from the reference cell).
    for cell_diff in diff.get("cell_diffs", []):
        if cell_diff["type"] != "text_change":
            continue
        ref_cell = cells_by_id.get(cell_diff["cell_id"])
        col_type = ref_cell.get("col_type", "unknown") if ref_cell else "unknown"
        _stats_for(col_type)["corrected"] += 1

    # Count total reference cells per col_type.
    for cell in ref_cells:
        _stats_for(cell.get("col_type", "unknown"))["total"] += 1

    # Share of cells left untouched by the user, per col_type.
    for stats in col_type_stats.values():
        total = stats["total"]
        corrected = stats["corrected"]
        stats["accuracy_pct"] = round((total - corrected) / total * 100, 1) if total > 0 else 100.0

    diff["col_type_breakdown"] = col_type_stats
    return diff
@router.get("/ground-truth-sessions")
async def list_ground_truth_sessions():
    """List all sessions that have a ground-truth reference.

    Returns ``{"sessions": [...], "count": n}``. Each entry carries the session
    id, name, filename, document_category, plus the reference's pipeline,
    saved_at and summary.
    """
    sessions = await list_ground_truth_sessions_db()
    result = []
    for s in sessions:
        gt = s.get("ground_truth") or {}
        # ``or {}`` (not ``.get(key, {})``) so a stored JSON null does not
        # become None and crash on ``ref.get`` below.
        ref = gt.get("build_grid_reference") or {}
        result.append({
            "session_id": s["id"],
            "name": s.get("name", ""),
            "filename": s.get("filename", ""),
            "document_category": s.get("document_category"),
            "pipeline": ref.get("pipeline"),
            "saved_at": ref.get("saved_at"),
            "summary": ref.get("summary", {}),
        })
    return {"sessions": result, "count": len(result)}
@router.post("/sessions/{session_id}/regression/run")
async def run_single_regression(session_id: str):
    """Re-run build_grid for a single session and compare to ground truth.

    Returns the ``compare_grids(reference, new_snapshot)`` status and diff plus
    both summaries. If rebuilding the grid raises, an ``"status": "error"``
    payload is returned (HTTP 200) and the failure is logged with traceback.

    Raises:
        HTTPException 404: session not found.
        HTTPException 400: session has no ground-truth reference.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    gt = session.get("ground_truth") or {}
    reference = gt.get("build_grid_reference")
    if not reference:
        raise HTTPException(
            status_code=400,
            detail="No ground truth reference found for this session",
        )

    # Re-compute the grid. This endpoint only compares the result; it does not
    # write it back to the session.
    try:
        new_result = await _build_grid_core(session_id, session)
    except Exception as e:  # any build failure is reported as an "error" result
        logger.exception("Regression test session %s: build_grid failed", session_id)
        return {
            "session_id": session_id,
            "name": session.get("name", ""),
            "status": "error",
            "error": str(e),
        }

    new_snapshot = _build_reference_snapshot(new_result)
    diff = compare_grids(reference, new_snapshot)
    logger.info(
        "Regression test session %s: %s (%d structural, %d cell diffs)",
        session_id, diff["status"],
        diff["summary"]["structural_changes"],
        # every summary counter except structural_changes counts cell-level diffs
        sum(v for k, v in diff["summary"].items() if k != "structural_changes"),
    )
    return {
        "session_id": session_id,
        "name": session.get("name", ""),
        "status": diff["status"],
        "diff": diff,
        "reference_summary": reference.get("summary", {}),
        "current_summary": new_snapshot.get("summary", {}),
    }
@router.post("/regression/run")
async def run_all_regressions(
    triggered_by: str = Query("manual", description="Who triggered: manual, script, ci"),
):
    """Re-run build_grid for ALL ground-truth sessions and compare.

    Sessions without a ``build_grid_reference`` are skipped. Every other session
    yields one entry in ``results``. It is counted as an error when the
    session cannot be re-loaded or build_grid raises, as failed when
    ``compare_grids`` reports ``"fail"`` (full diffs attached), and as passed
    otherwise. The run is persisted via ``_persist_regression_run`` and its id
    is returned. The early "no sessions" response is not persisted.
    """
    start_time = time.monotonic()
    sessions = await list_ground_truth_sessions_db()
    if not sessions:
        return {
            "status": "pass",
            "message": "No ground truth sessions found",
            "results": [],
            "summary": {"total": 0, "passed": 0, "failed": 0, "errors": 0},
        }

    results = []
    passed = 0
    failed = 0
    errors = 0
    for s in sessions:
        session_id = s["id"]
        gt = s.get("ground_truth") or {}
        reference = gt.get("build_grid_reference")
        if not reference:
            continue

        # Re-load full session (list query may not include all JSONB fields)
        full_session = await get_session_db(session_id)
        if not full_session:
            results.append({
                "session_id": session_id,
                "name": s.get("name", ""),
                "status": "error",
                "error": "Session not found during re-load",
            })
            errors += 1
            continue

        try:
            new_result = await _build_grid_core(session_id, full_session)
        except Exception as e:  # one broken session must not abort the suite
            logger.exception("Regression suite: build_grid failed for session %s", session_id)
            results.append({
                "session_id": session_id,
                "name": s.get("name", ""),
                "status": "error",
                "error": str(e),
            })
            errors += 1
            continue

        new_snapshot = _build_reference_snapshot(new_result)
        diff = compare_grids(reference, new_snapshot)
        entry = {
            "session_id": session_id,
            "name": s.get("name", ""),
            "status": diff["status"],
            "diff_summary": diff["summary"],
            "reference_summary": reference.get("summary", {}),
            "current_summary": new_snapshot.get("summary", {}),
        }
        # Include full diffs only for failures (keep response compact)
        if diff["status"] == "fail":
            entry["structural_diffs"] = diff["structural_diffs"]
            entry["cell_diffs"] = diff["cell_diffs"]
            failed += 1
        else:
            passed += 1
        results.append(entry)

    overall = "pass" if failed == 0 and errors == 0 else "fail"
    duration_ms = int((time.monotonic() - start_time) * 1000)
    summary = {
        "total": len(results),
        "passed": passed,
        "failed": failed,
        "errors": errors,
    }
    logger.info(
        "Regression suite %s: %d passed, %d failed, %d errors (of %d) in %dms",
        overall, passed, failed, errors, len(results), duration_ms,
    )

    # Persist to DB
    run_id = await _persist_regression_run(
        status=overall,
        summary=summary,
        results=results,
        duration_ms=duration_ms,
        triggered_by=triggered_by,
    )
    return {
        "status": overall,
        "run_id": run_id,
        "duration_ms": duration_ms,
        "results": results,
        "summary": summary,
    }
_REGRESSION_HISTORY_SQL = """
SELECT id, run_at, status, total, passed, failed, errors,
duration_ms, triggered_by
FROM regression_runs
ORDER BY run_at DESC
LIMIT $1
"""

# regression_runs columns copied verbatim into each history entry, in response order.
_HISTORY_PASSTHROUGH_COLUMNS = (
    "status",
    "total",
    "passed",
    "failed",
    "errors",
    "duration_ms",
    "triggered_by",
)


@router.get("/regression/history")
async def get_regression_history(
    limit: int = Query(20, ge=1, le=100),
):
    """Return the newest ``limit`` regression runs (summary columns only).

    Best-effort: any failure (table init, pool, query) is logged as a warning
    and answered with an empty list plus the error text instead of raising.
    """
    try:
        await _init_regression_table()
        pool = await get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(_REGRESSION_HISTORY_SQL, limit)

        runs = []
        for row in rows:
            run_at = row["run_at"]
            entry = {
                "id": str(row["id"]),
                "run_at": run_at.isoformat() if run_at else None,
            }
            entry.update((column, row[column]) for column in _HISTORY_PASSTHROUGH_COLUMNS)
            runs.append(entry)
        return {"runs": runs, "count": len(rows)}
    except Exception as e:
        logger.warning("Failed to fetch regression history: %s", e)
        return {"runs": [], "count": 0, "error": str(e)}
@router.get("/regression/history/{run_id}")
async def get_regression_run_detail(run_id: str):
    """Get detailed results of a specific regression run.

    Returns the run's summary columns plus the per-session ``results`` list
    (``[]`` when none are stored).

    Raises:
        HTTPException 404: no run with this id.
        HTTPException 500: any other failure; logged with traceback.
    """
    try:
        await _init_regression_table()
        pool = await get_pool()
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                "SELECT * FROM regression_runs WHERE id = $1",
                run_id,
            )
        if not row:
            raise HTTPException(status_code=404, detail="Run not found")

        raw_results = row["results"]
        # asyncpg hands json/jsonb back as text unless a type codec is
        # registered on the pool; accept both encoded and decoded values.
        if not raw_results:
            results = []
        elif isinstance(raw_results, (str, bytes, bytearray)):
            results = json.loads(raw_results) or []
        else:
            results = raw_results

        return {
            "id": str(row["id"]),
            "run_at": row["run_at"].isoformat() if row["run_at"] else None,
            "status": row["status"],
            "total": row["total"],
            "passed": row["passed"],
            "failed": row["failed"],
            "errors": row["errors"],
            "duration_ms": row["duration_ms"],
            "triggered_by": row["triggered_by"],
            "results": results,
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to fetch regression run %s", run_id)
        raise HTTPException(status_code=500, detail=str(e)) from e