Add Ground Truth regression test system for OCR pipeline
Some checks failed
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 35s
CI / test-go-edu-search (push) Successful in 26s
CI / test-python-klausur (push) Failing after 1m47s
CI / test-python-agent-core (push) Successful in 15s
CI / test-nodejs-website (push) Successful in 22s
Some checks failed
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 35s
CI / test-go-edu-search (push) Successful in 26s
CI / test-python-klausur (push) Failing after 1m47s
CI / test-python-agent-core (push) Successful in 15s
CI / test-nodejs-website (push) Successful in 22s
Extract _build_grid_core() from build_grid() endpoint for reuse. New ocr_pipeline_regression.py with endpoints to mark sessions as ground truth, list them, and run regression comparisons after code changes. Frontend button in StepGroundTruth.tsx to mark/update GT. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -43,6 +43,9 @@ export function StepGroundTruth({ sessionId, onNext }: StepGroundTruthProps) {
|
|||||||
const [drawingRegion, setDrawingRegion] = useState(false)
|
const [drawingRegion, setDrawingRegion] = useState(false)
|
||||||
const [dragStart, setDragStart] = useState<{ x: number; y: number } | null>(null)
|
const [dragStart, setDragStart] = useState<{ x: number; y: number } | null>(null)
|
||||||
const [dragEnd, setDragEnd] = useState<{ x: number; y: number } | null>(null)
|
const [dragEnd, setDragEnd] = useState<{ x: number; y: number } | null>(null)
|
||||||
|
const [isGroundTruth, setIsGroundTruth] = useState(false)
|
||||||
|
const [gtSaving, setGtSaving] = useState(false)
|
||||||
|
const [gtMessage, setGtMessage] = useState('')
|
||||||
|
|
||||||
const leftPanelRef = useRef<HTMLDivElement>(null)
|
const leftPanelRef = useRef<HTMLDivElement>(null)
|
||||||
const rightPanelRef = useRef<HTMLDivElement>(null)
|
const rightPanelRef = useRef<HTMLDivElement>(null)
|
||||||
@@ -86,6 +89,10 @@ export function StepGroundTruth({ sessionId, onNext }: StepGroundTruthProps) {
|
|||||||
: `${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/image/original`,
|
: `${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/image/original`,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Check if session has ground truth reference
|
||||||
|
const gt = data.ground_truth
|
||||||
|
setIsGroundTruth(!!gt?.build_grid_reference)
|
||||||
|
|
||||||
// Load existing validation data
|
// Load existing validation data
|
||||||
const valResp = await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/reconstruction/validation`)
|
const valResp = await fetch(`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/reconstruction/validation`)
|
||||||
if (valResp.ok) {
|
if (valResp.ok) {
|
||||||
@@ -196,6 +203,31 @@ export function StepGroundTruth({ sessionId, onNext }: StepGroundTruthProps) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Mark/update ground truth reference
|
||||||
|
const handleMarkGroundTruth = async () => {
|
||||||
|
if (!sessionId) return
|
||||||
|
setGtSaving(true)
|
||||||
|
setGtMessage('')
|
||||||
|
try {
|
||||||
|
const resp = await fetch(
|
||||||
|
`${KLAUSUR_API}/api/v1/ocr-pipeline/sessions/${sessionId}/mark-ground-truth`,
|
||||||
|
{ method: 'POST' }
|
||||||
|
)
|
||||||
|
if (!resp.ok) {
|
||||||
|
const body = await resp.text().catch(() => '')
|
||||||
|
throw new Error(`Ground Truth fehlgeschlagen (${resp.status}): ${body}`)
|
||||||
|
}
|
||||||
|
const data = await resp.json()
|
||||||
|
setIsGroundTruth(true)
|
||||||
|
setGtMessage(`Ground Truth gespeichert (${data.cells_saved} Zellen)`)
|
||||||
|
setTimeout(() => setGtMessage(''), 5000)
|
||||||
|
} catch (e) {
|
||||||
|
setGtMessage(e instanceof Error ? e.message : String(e))
|
||||||
|
} finally {
|
||||||
|
setGtSaving(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Handle manual region drawing on reconstruction
|
// Handle manual region drawing on reconstruction
|
||||||
const handleReconMouseDown = (e: React.MouseEvent<HTMLDivElement>) => {
|
const handleReconMouseDown = (e: React.MouseEvent<HTMLDivElement>) => {
|
||||||
if (!drawingRegion) return
|
if (!drawingRegion) return
|
||||||
@@ -570,8 +602,20 @@ export function StepGroundTruth({ sessionId, onNext }: StepGroundTruthProps) {
|
|||||||
<div className="text-sm text-gray-500 dark:text-gray-400">
|
<div className="text-sm text-gray-500 dark:text-gray-400">
|
||||||
{status === 'saved' && <span className="text-green-600 dark:text-green-400">Validierung gespeichert</span>}
|
{status === 'saved' && <span className="text-green-600 dark:text-green-400">Validierung gespeichert</span>}
|
||||||
{status === 'saving' && <span>Speichere...</span>}
|
{status === 'saving' && <span>Speichere...</span>}
|
||||||
|
{gtMessage && (
|
||||||
|
<span className={gtMessage.includes('fehlgeschlagen') ? 'text-red-500' : 'text-amber-600 dark:text-amber-400'}>
|
||||||
|
{gtMessage}
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
<div className="flex items-center gap-3">
|
<div className="flex items-center gap-3">
|
||||||
|
<button
|
||||||
|
onClick={handleMarkGroundTruth}
|
||||||
|
disabled={gtSaving || status === 'saving'}
|
||||||
|
className="px-4 py-2 text-sm bg-amber-600 text-white rounded hover:bg-amber-700 disabled:opacity-50"
|
||||||
|
>
|
||||||
|
{gtSaving ? 'Speichere...' : isGroundTruth ? 'Ground Truth aktualisieren' : 'Als Ground Truth markieren'}
|
||||||
|
</button>
|
||||||
<button
|
<button
|
||||||
onClick={handleSave}
|
onClick={handleSave}
|
||||||
disabled={status === 'saving'}
|
disabled={status === 'saving'}
|
||||||
|
|||||||
@@ -745,42 +745,38 @@ def _filter_footer_words(
|
|||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Endpoints
|
# Core computation (used by build-grid endpoint and regression tests)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/build-grid")
|
async def _build_grid_core(session_id: str, session: dict) -> dict:
|
||||||
async def build_grid(session_id: str):
|
"""Core grid building logic — pure computation, no HTTP or DB side effects.
|
||||||
"""Build a structured, zone-aware grid from existing Kombi word results.
|
|
||||||
|
|
||||||
Requires that paddle-kombi or rapid-kombi has already been run on the session.
|
Args:
|
||||||
Uses the image for box detection and the word positions for grid structuring.
|
session_id: Session identifier (for logging and image loading).
|
||||||
|
session: Full session dict from get_session_db().
|
||||||
|
|
||||||
Returns a StructuredGrid with zones, each containing their own
|
Returns:
|
||||||
columns, rows, and cells — ready for the frontend Excel-like editor.
|
StructuredGrid result dict.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If session data is incomplete.
|
||||||
"""
|
"""
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
|
||||||
# 1. Load session and word results
|
# 1. Validate and load word results
|
||||||
session = await get_session_db(session_id)
|
|
||||||
if not session:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
|
||||||
|
|
||||||
word_result = session.get("word_result")
|
word_result = session.get("word_result")
|
||||||
if not word_result or not word_result.get("cells"):
|
if not word_result or not word_result.get("cells"):
|
||||||
raise HTTPException(
|
raise ValueError("No word results found. Run paddle-kombi or rapid-kombi first.")
|
||||||
status_code=400,
|
|
||||||
detail="No word results found. Run paddle-kombi or rapid-kombi first.",
|
|
||||||
)
|
|
||||||
|
|
||||||
img_w = word_result.get("image_width", 0)
|
img_w = word_result.get("image_width", 0)
|
||||||
img_h = word_result.get("image_height", 0)
|
img_h = word_result.get("image_height", 0)
|
||||||
if not img_w or not img_h:
|
if not img_w or not img_h:
|
||||||
raise HTTPException(status_code=400, detail="Missing image dimensions in word_result")
|
raise ValueError("Missing image dimensions in word_result")
|
||||||
|
|
||||||
# 2. Flatten all word boxes from cells
|
# 2. Flatten all word boxes from cells
|
||||||
all_words = _flatten_word_boxes(word_result["cells"])
|
all_words = _flatten_word_boxes(word_result["cells"])
|
||||||
if not all_words:
|
if not all_words:
|
||||||
raise HTTPException(status_code=400, detail="No word boxes found in cells")
|
raise ValueError("No word boxes found in cells")
|
||||||
|
|
||||||
logger.info("build-grid session %s: %d words from %d cells",
|
logger.info("build-grid session %s: %d words from %d cells",
|
||||||
session_id, len(all_words), len(word_result["cells"]))
|
session_id, len(all_words), len(word_result["cells"]))
|
||||||
@@ -1313,14 +1309,45 @@ async def build_grid(session_id: str):
|
|||||||
"duration_seconds": round(duration, 2),
|
"duration_seconds": round(duration, 2),
|
||||||
}
|
}
|
||||||
|
|
||||||
# 7. Persist to DB
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/build-grid")
|
||||||
|
async def build_grid(session_id: str):
|
||||||
|
"""Build a structured, zone-aware grid from existing Kombi word results.
|
||||||
|
|
||||||
|
Requires that paddle-kombi or rapid-kombi has already been run on the session.
|
||||||
|
Uses the image for box detection and the word positions for grid structuring.
|
||||||
|
|
||||||
|
Returns a StructuredGrid with zones, each containing their own
|
||||||
|
columns, rows, and cells — ready for the frontend Excel-like editor.
|
||||||
|
"""
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await _build_grid_core(session_id, session)
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
|
# Persist to DB
|
||||||
await update_session_db(session_id, grid_editor_result=result)
|
await update_session_db(session_id, grid_editor_result=result)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"build-grid session %s: %d zones, %d cols, %d rows, %d cells, "
|
"build-grid session %s: %d zones, %d cols, %d rows, %d cells, "
|
||||||
"%d boxes in %.2fs",
|
"%d boxes in %.2fs",
|
||||||
session_id, len(zones_data), total_columns, total_rows,
|
session_id,
|
||||||
total_cells, boxes_detected, duration,
|
len(result.get("zones", [])),
|
||||||
|
result.get("summary", {}).get("total_columns", 0),
|
||||||
|
result.get("summary", {}).get("total_rows", 0),
|
||||||
|
result.get("summary", {}).get("total_cells", 0),
|
||||||
|
result.get("boxes_detected", 0),
|
||||||
|
result.get("duration_seconds", 0),
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ from ocr_pipeline_ocr_merge import (
|
|||||||
)
|
)
|
||||||
from ocr_pipeline_postprocess import router as _postprocess_router
|
from ocr_pipeline_postprocess import router as _postprocess_router
|
||||||
from ocr_pipeline_auto import router as _auto_router
|
from ocr_pipeline_auto import router as _auto_router
|
||||||
|
from ocr_pipeline_regression import router as _regression_router
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Composite router (used by main.py)
|
# Composite router (used by main.py)
|
||||||
@@ -59,3 +60,4 @@ router.include_router(_words_router)
|
|||||||
router.include_router(_ocr_merge_router)
|
router.include_router(_ocr_merge_router)
|
||||||
router.include_router(_postprocess_router)
|
router.include_router(_postprocess_router)
|
||||||
router.include_router(_auto_router)
|
router.include_router(_auto_router)
|
||||||
|
router.include_router(_regression_router)
|
||||||
|
|||||||
367
klausur-service/backend/ocr_pipeline_regression.py
Normal file
367
klausur-service/backend/ocr_pipeline_regression.py
Normal file
@@ -0,0 +1,367 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Regression Tests — Ground Truth comparison system.
|
||||||
|
|
||||||
|
Allows marking sessions as "ground truth" and re-running build_grid()
|
||||||
|
to detect regressions after code changes.
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
|
||||||
|
from grid_editor_api import _build_grid_core
|
||||||
|
from ocr_pipeline_session_store import (
|
||||||
|
get_session_db,
|
||||||
|
list_ground_truth_sessions_db,
|
||||||
|
update_session_db,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["regression"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _extract_cells_for_comparison(grid_result: dict) -> List[Dict[str, Any]]:
|
||||||
|
"""Extract a flat list of cells from a grid_editor_result for comparison.
|
||||||
|
|
||||||
|
Only keeps fields relevant for comparison: cell_id, row_index, col_index,
|
||||||
|
col_type, text. Ignores confidence, bbox, word_boxes, duration, is_bold.
|
||||||
|
"""
|
||||||
|
cells = []
|
||||||
|
for zone in grid_result.get("zones", []):
|
||||||
|
for cell in zone.get("cells", []):
|
||||||
|
cells.append({
|
||||||
|
"cell_id": cell.get("cell_id", ""),
|
||||||
|
"row_index": cell.get("row_index"),
|
||||||
|
"col_index": cell.get("col_index"),
|
||||||
|
"col_type": cell.get("col_type", ""),
|
||||||
|
"text": cell.get("text", ""),
|
||||||
|
})
|
||||||
|
return cells
|
||||||
|
|
||||||
|
|
||||||
|
def _build_reference_snapshot(grid_result: dict) -> dict:
|
||||||
|
"""Build a ground-truth reference snapshot from a grid_editor_result."""
|
||||||
|
cells = _extract_cells_for_comparison(grid_result)
|
||||||
|
|
||||||
|
total_zones = len(grid_result.get("zones", []))
|
||||||
|
total_columns = sum(len(z.get("columns", [])) for z in grid_result.get("zones", []))
|
||||||
|
total_rows = sum(len(z.get("rows", [])) for z in grid_result.get("zones", []))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"saved_at": datetime.now(timezone.utc).isoformat(),
|
||||||
|
"version": 1,
|
||||||
|
"summary": {
|
||||||
|
"total_zones": total_zones,
|
||||||
|
"total_columns": total_columns,
|
||||||
|
"total_rows": total_rows,
|
||||||
|
"total_cells": len(cells),
|
||||||
|
},
|
||||||
|
"cells": cells,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def compare_grids(reference: dict, current: dict) -> dict:
|
||||||
|
"""Compare a reference grid snapshot with a newly computed one.
|
||||||
|
|
||||||
|
Returns a diff report with:
|
||||||
|
- status: "pass" or "fail"
|
||||||
|
- structural_diffs: changes in zone/row/column counts
|
||||||
|
- cell_diffs: list of individual cell changes
|
||||||
|
"""
|
||||||
|
ref_summary = reference.get("summary", {})
|
||||||
|
cur_summary = current.get("summary", {})
|
||||||
|
|
||||||
|
structural_diffs = []
|
||||||
|
for key in ("total_zones", "total_columns", "total_rows", "total_cells"):
|
||||||
|
ref_val = ref_summary.get(key, 0)
|
||||||
|
cur_val = cur_summary.get(key, 0)
|
||||||
|
if ref_val != cur_val:
|
||||||
|
structural_diffs.append({
|
||||||
|
"field": key,
|
||||||
|
"reference": ref_val,
|
||||||
|
"current": cur_val,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Build cell lookup by cell_id
|
||||||
|
ref_cells = {c["cell_id"]: c for c in reference.get("cells", [])}
|
||||||
|
cur_cells = {c["cell_id"]: c for c in current.get("cells", [])}
|
||||||
|
|
||||||
|
cell_diffs: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
# Check for missing cells (in reference but not in current)
|
||||||
|
for cell_id in ref_cells:
|
||||||
|
if cell_id not in cur_cells:
|
||||||
|
cell_diffs.append({
|
||||||
|
"type": "cell_missing",
|
||||||
|
"cell_id": cell_id,
|
||||||
|
"reference_text": ref_cells[cell_id].get("text", ""),
|
||||||
|
})
|
||||||
|
|
||||||
|
# Check for added cells (in current but not in reference)
|
||||||
|
for cell_id in cur_cells:
|
||||||
|
if cell_id not in ref_cells:
|
||||||
|
cell_diffs.append({
|
||||||
|
"type": "cell_added",
|
||||||
|
"cell_id": cell_id,
|
||||||
|
"current_text": cur_cells[cell_id].get("text", ""),
|
||||||
|
})
|
||||||
|
|
||||||
|
# Check for changes in shared cells
|
||||||
|
for cell_id in ref_cells:
|
||||||
|
if cell_id not in cur_cells:
|
||||||
|
continue
|
||||||
|
ref_cell = ref_cells[cell_id]
|
||||||
|
cur_cell = cur_cells[cell_id]
|
||||||
|
|
||||||
|
if ref_cell.get("text", "") != cur_cell.get("text", ""):
|
||||||
|
cell_diffs.append({
|
||||||
|
"type": "text_change",
|
||||||
|
"cell_id": cell_id,
|
||||||
|
"reference": ref_cell.get("text", ""),
|
||||||
|
"current": cur_cell.get("text", ""),
|
||||||
|
})
|
||||||
|
|
||||||
|
if ref_cell.get("col_type", "") != cur_cell.get("col_type", ""):
|
||||||
|
cell_diffs.append({
|
||||||
|
"type": "col_type_change",
|
||||||
|
"cell_id": cell_id,
|
||||||
|
"reference": ref_cell.get("col_type", ""),
|
||||||
|
"current": cur_cell.get("col_type", ""),
|
||||||
|
})
|
||||||
|
|
||||||
|
status = "pass" if not structural_diffs and not cell_diffs else "fail"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": status,
|
||||||
|
"structural_diffs": structural_diffs,
|
||||||
|
"cell_diffs": cell_diffs,
|
||||||
|
"summary": {
|
||||||
|
"structural_changes": len(structural_diffs),
|
||||||
|
"cells_missing": sum(1 for d in cell_diffs if d["type"] == "cell_missing"),
|
||||||
|
"cells_added": sum(1 for d in cell_diffs if d["type"] == "cell_added"),
|
||||||
|
"text_changes": sum(1 for d in cell_diffs if d["type"] == "text_change"),
|
||||||
|
"col_type_changes": sum(1 for d in cell_diffs if d["type"] == "col_type_change"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/mark-ground-truth")
|
||||||
|
async def mark_ground_truth(session_id: str):
|
||||||
|
"""Save the current build-grid result as ground-truth reference."""
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
grid_result = session.get("grid_editor_result")
|
||||||
|
if not grid_result or not grid_result.get("zones"):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="No grid_editor_result found. Run build-grid first.",
|
||||||
|
)
|
||||||
|
|
||||||
|
reference = _build_reference_snapshot(grid_result)
|
||||||
|
|
||||||
|
# Merge into existing ground_truth JSONB
|
||||||
|
gt = session.get("ground_truth") or {}
|
||||||
|
gt["build_grid_reference"] = reference
|
||||||
|
await update_session_db(session_id, ground_truth=gt)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Ground truth marked for session %s: %d cells",
|
||||||
|
session_id, len(reference["cells"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "ok",
|
||||||
|
"session_id": session_id,
|
||||||
|
"cells_saved": len(reference["cells"]),
|
||||||
|
"summary": reference["summary"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/sessions/{session_id}/mark-ground-truth")
|
||||||
|
async def unmark_ground_truth(session_id: str):
|
||||||
|
"""Remove the ground-truth reference from a session."""
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
gt = session.get("ground_truth") or {}
|
||||||
|
if "build_grid_reference" not in gt:
|
||||||
|
raise HTTPException(status_code=404, detail="No ground truth reference found")
|
||||||
|
|
||||||
|
del gt["build_grid_reference"]
|
||||||
|
await update_session_db(session_id, ground_truth=gt)
|
||||||
|
|
||||||
|
logger.info("Ground truth removed for session %s", session_id)
|
||||||
|
return {"status": "ok", "session_id": session_id}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/ground-truth-sessions")
|
||||||
|
async def list_ground_truth_sessions():
|
||||||
|
"""List all sessions that have a ground-truth reference."""
|
||||||
|
sessions = await list_ground_truth_sessions_db()
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for s in sessions:
|
||||||
|
gt = s.get("ground_truth") or {}
|
||||||
|
ref = gt.get("build_grid_reference", {})
|
||||||
|
result.append({
|
||||||
|
"session_id": s["id"],
|
||||||
|
"name": s.get("name", ""),
|
||||||
|
"filename": s.get("filename", ""),
|
||||||
|
"saved_at": ref.get("saved_at"),
|
||||||
|
"summary": ref.get("summary", {}),
|
||||||
|
})
|
||||||
|
|
||||||
|
return {"sessions": result, "count": len(result)}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/regression/run")
|
||||||
|
async def run_single_regression(session_id: str):
|
||||||
|
"""Re-run build_grid for a single session and compare to ground truth."""
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
gt = session.get("ground_truth") or {}
|
||||||
|
reference = gt.get("build_grid_reference")
|
||||||
|
if not reference:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="No ground truth reference found for this session",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Re-compute grid without persisting
|
||||||
|
try:
|
||||||
|
new_result = await _build_grid_core(session_id, session)
|
||||||
|
except (ValueError, Exception) as e:
|
||||||
|
return {
|
||||||
|
"session_id": session_id,
|
||||||
|
"name": session.get("name", ""),
|
||||||
|
"status": "error",
|
||||||
|
"error": str(e),
|
||||||
|
}
|
||||||
|
|
||||||
|
new_snapshot = _build_reference_snapshot(new_result)
|
||||||
|
diff = compare_grids(reference, new_snapshot)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Regression test session %s: %s (%d structural, %d cell diffs)",
|
||||||
|
session_id, diff["status"],
|
||||||
|
diff["summary"]["structural_changes"],
|
||||||
|
sum(v for k, v in diff["summary"].items() if k != "structural_changes"),
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"session_id": session_id,
|
||||||
|
"name": session.get("name", ""),
|
||||||
|
"status": diff["status"],
|
||||||
|
"diff": diff,
|
||||||
|
"reference_summary": reference.get("summary", {}),
|
||||||
|
"current_summary": new_snapshot.get("summary", {}),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/regression/run")
|
||||||
|
async def run_all_regressions():
|
||||||
|
"""Re-run build_grid for ALL ground-truth sessions and compare."""
|
||||||
|
sessions = await list_ground_truth_sessions_db()
|
||||||
|
|
||||||
|
if not sessions:
|
||||||
|
return {
|
||||||
|
"status": "pass",
|
||||||
|
"message": "No ground truth sessions found",
|
||||||
|
"results": [],
|
||||||
|
"summary": {"total": 0, "passed": 0, "failed": 0, "errors": 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
results = []
|
||||||
|
passed = 0
|
||||||
|
failed = 0
|
||||||
|
errors = 0
|
||||||
|
|
||||||
|
for s in sessions:
|
||||||
|
session_id = s["id"]
|
||||||
|
gt = s.get("ground_truth") or {}
|
||||||
|
reference = gt.get("build_grid_reference")
|
||||||
|
if not reference:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Re-load full session (list query may not include all JSONB fields)
|
||||||
|
full_session = await get_session_db(session_id)
|
||||||
|
if not full_session:
|
||||||
|
results.append({
|
||||||
|
"session_id": session_id,
|
||||||
|
"name": s.get("name", ""),
|
||||||
|
"status": "error",
|
||||||
|
"error": "Session not found during re-load",
|
||||||
|
})
|
||||||
|
errors += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
new_result = await _build_grid_core(session_id, full_session)
|
||||||
|
except (ValueError, Exception) as e:
|
||||||
|
results.append({
|
||||||
|
"session_id": session_id,
|
||||||
|
"name": s.get("name", ""),
|
||||||
|
"status": "error",
|
||||||
|
"error": str(e),
|
||||||
|
})
|
||||||
|
errors += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
new_snapshot = _build_reference_snapshot(new_result)
|
||||||
|
diff = compare_grids(reference, new_snapshot)
|
||||||
|
|
||||||
|
entry = {
|
||||||
|
"session_id": session_id,
|
||||||
|
"name": s.get("name", ""),
|
||||||
|
"status": diff["status"],
|
||||||
|
"diff_summary": diff["summary"],
|
||||||
|
"reference_summary": reference.get("summary", {}),
|
||||||
|
"current_summary": new_snapshot.get("summary", {}),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Include full diffs only for failures (keep response compact)
|
||||||
|
if diff["status"] == "fail":
|
||||||
|
entry["structural_diffs"] = diff["structural_diffs"]
|
||||||
|
entry["cell_diffs"] = diff["cell_diffs"]
|
||||||
|
failed += 1
|
||||||
|
else:
|
||||||
|
passed += 1
|
||||||
|
|
||||||
|
results.append(entry)
|
||||||
|
|
||||||
|
overall = "pass" if failed == 0 and errors == 0 else "fail"
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Regression suite: %s — %d passed, %d failed, %d errors (of %d)",
|
||||||
|
overall, passed, failed, errors, len(results),
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": overall,
|
||||||
|
"results": results,
|
||||||
|
"summary": {
|
||||||
|
"total": len(results),
|
||||||
|
"passed": passed,
|
||||||
|
"failed": failed,
|
||||||
|
"errors": errors,
|
||||||
|
},
|
||||||
|
}
|
||||||
@@ -270,6 +270,26 @@ async def get_sub_sessions(parent_session_id: str) -> List[Dict[str, Any]]:
|
|||||||
return [_row_to_dict(row) for row in rows]
|
return [_row_to_dict(row) for row in rows]
|
||||||
|
|
||||||
|
|
||||||
|
async def list_ground_truth_sessions_db() -> List[Dict[str, Any]]:
|
||||||
|
"""List sessions that have a build_grid_reference in ground_truth."""
|
||||||
|
pool = await get_pool()
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
rows = await conn.fetch("""
|
||||||
|
SELECT id, name, filename, status, current_step,
|
||||||
|
document_category, doc_type,
|
||||||
|
ground_truth,
|
||||||
|
parent_session_id, box_index,
|
||||||
|
created_at, updated_at
|
||||||
|
FROM ocr_pipeline_sessions
|
||||||
|
WHERE ground_truth IS NOT NULL
|
||||||
|
AND ground_truth::text LIKE '%build_grid_reference%'
|
||||||
|
AND parent_session_id IS NULL
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
""")
|
||||||
|
|
||||||
|
return [_row_to_dict(row) for row in rows]
|
||||||
|
|
||||||
|
|
||||||
async def delete_session_db(session_id: str) -> bool:
|
async def delete_session_db(session_id: str) -> bool:
|
||||||
"""Delete a session."""
|
"""Delete a session."""
|
||||||
pool = await get_pool()
|
pool = await get_pool()
|
||||||
|
|||||||
Reference in New Issue
Block a user