refactor: split ocr_pipeline_api.py (5426 lines) into 8 modules
Each module is under 1050 lines: - ocr_pipeline_common.py (354) - shared state, cache, models, helpers - ocr_pipeline_sessions.py (483) - session CRUD, image serving, doc-type - ocr_pipeline_geometry.py (1025) - deskew, dewarp, structure, columns - ocr_pipeline_rows.py (348) - row detection, box-overlay helper - ocr_pipeline_words.py (876) - word detection (SSE), paddle-direct - ocr_pipeline_ocr_merge.py (615) - merge helpers, kombi endpoints - ocr_pipeline_postprocess.py (929) - LLM review, reconstruction, export - ocr_pipeline_auto.py (705) - auto-mode orchestrator, reprocess ocr_pipeline_api.py is now a 61-line thin wrapper that re-exports router, _cache, and test-imported symbols for backward compatibility. No changes needed in main.py or tests. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
483
klausur-service/backend/ocr_pipeline_sessions.py
Normal file
483
klausur-service/backend/ocr_pipeline_sessions.py
Normal file
@@ -0,0 +1,483 @@
|
||||
"""
|
||||
OCR Pipeline Sessions API - Session management and image serving endpoints.
|
||||
|
||||
Extracted from ocr_pipeline_api.py for modularity.
|
||||
Handles: CRUD for sessions, thumbnails, pipeline logs, categories,
|
||||
image serving (with overlay dispatch), and document type detection.
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
|
||||
from fastapi.responses import Response
|
||||
|
||||
from cv_vocab_pipeline import (
|
||||
create_ocr_image,
|
||||
detect_document_type,
|
||||
render_image_high_res,
|
||||
render_pdf_high_res,
|
||||
)
|
||||
from ocr_pipeline_common import (
|
||||
VALID_DOCUMENT_CATEGORIES,
|
||||
UpdateSessionRequest,
|
||||
_append_pipeline_log,
|
||||
_cache,
|
||||
_get_base_image_png,
|
||||
_get_cached,
|
||||
_load_session_to_cache,
|
||||
)
|
||||
from ocr_pipeline_overlays import render_overlay
|
||||
from ocr_pipeline_session_store import (
|
||||
create_session_db,
|
||||
delete_all_sessions_db,
|
||||
delete_session_db,
|
||||
get_session_db,
|
||||
get_session_image,
|
||||
get_sub_sessions,
|
||||
list_sessions_db,
|
||||
update_session_db,
|
||||
)
|
||||
|
||||
# Module-level logger; named after this module for filterable log output.
logger = logging.getLogger(__name__)

# All session/image endpoints in this module share the OCR-pipeline prefix.
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session Management Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/sessions")
async def list_sessions(include_sub_sessions: bool = False):
    """Return all OCR pipeline sessions.

    Sub-sessions (per-box crops) are excluded by default; callers opt in
    with ?include_sub_sessions=true.
    """
    rows = await list_sessions_db(include_sub_sessions=include_sub_sessions)
    return {"sessions": rows}
|
||||
|
||||
|
||||
@router.post("/sessions")
async def create_session(
    file: UploadFile = File(...),
    name: Optional[str] = Form(None),
):
    """Upload a PDF or image file and create a pipeline session.

    Renders the upload to a high-resolution BGR image, persists it as PNG,
    and primes the in-memory cache so the first pipeline step can run
    without a DB round-trip.

    Args:
        file: Uploaded PDF or image.
        name: Optional display name; defaults to the upload's filename.

    Raises:
        HTTPException 400 if the file cannot be rendered,
        HTTPException 500 if PNG encoding fails.
    """
    file_data = await file.read()
    filename = file.filename or "upload"
    content_type = file.content_type or ""

    session_id = str(uuid.uuid4())
    # Detect PDFs via content type first, extension as fallback.
    is_pdf = content_type == "application/pdf" or filename.lower().endswith(".pdf")

    try:
        if is_pdf:
            # Only the first page is used; zoom=3.0 yields a high-res render.
            img_bgr = render_pdf_high_res(file_data, page_number=0, zoom=3.0)
        else:
            img_bgr = render_image_high_res(file_data)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Could not process file: {e}")

    # Encode original as PNG bytes
    success, png_buf = cv2.imencode(".png", img_bgr)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode image")

    original_png = png_buf.tobytes()
    session_name = name or filename

    # Persist to DB
    await create_session_db(
        session_id=session_id,
        name=session_name,
        filename=filename,
        original_png=original_png,
    )

    # Cache BGR array for immediate processing; downstream steps fill in
    # the *_bgr / *_result slots as the pipeline advances.
    _cache[session_id] = {
        "id": session_id,
        "filename": filename,
        "name": session_name,
        "original_bgr": img_bgr,
        "oriented_bgr": None,
        "cropped_bgr": None,
        "deskewed_bgr": None,
        "dewarped_bgr": None,
        "orientation_result": None,
        "crop_result": None,
        "deskew_result": None,
        "dewarp_result": None,
        "ground_truth": {},
        "current_step": 1,
    }

    # Fixed: the log previously interpolated the literal "(unknown)"
    # instead of the uploaded filename.
    logger.info(f"OCR Pipeline: created session {session_id} from {filename} "
                f"({img_bgr.shape[1]}x{img_bgr.shape[0]})")

    return {
        "session_id": session_id,
        "filename": filename,
        "name": session_name,
        "image_width": img_bgr.shape[1],
        "image_height": img_bgr.shape[0],
        "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
    }
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}")
async def get_session_info(session_id: str):
    """Get session info including deskew/dewarp/column results for step navigation.

    Returns base metadata, image dimensions derived from the stored
    original PNG, any completed step results, and parent/sub-session links.

    Raises:
        HTTPException 404 if the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # Get image dimensions from original PNG.
    # Fixed precedence bug: the old one-liner
    #   img_w, img_h = img.shape[1], img.shape[0] if img is not None else (0, 0)
    # evaluated img.shape[1] unconditionally (AttributeError when decode
    # failed) and assigned a tuple to img_h in the fallback case.
    img_w, img_h = 0, 0
    original_png = await get_session_image(session_id, "original")
    if original_png:
        arr = np.frombuffer(original_png, dtype=np.uint8)
        img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
        if img is not None:
            img_w, img_h = img.shape[1], img.shape[0]

    result = {
        "session_id": session["id"],
        "filename": session.get("filename", ""),
        "name": session.get("name", ""),
        "image_width": img_w,
        "image_height": img_h,
        "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
        "current_step": session.get("current_step", 1),
        "document_category": session.get("document_category"),
        "doc_type": session.get("doc_type"),
    }

    # Copy over whichever step results exist so the frontend can resume
    # at the right pipeline step.
    for key in ("orientation_result", "crop_result", "deskew_result",
                "dewarp_result", "column_result", "row_result",
                "word_result", "doc_type_result"):
        if session.get(key):
            result[key] = session[key]

    # Sub-session info: a box crop links to its parent; a parent lists
    # its box sub-sessions.
    if session.get("parent_session_id"):
        result["parent_session_id"] = session["parent_session_id"]
        result["box_index"] = session.get("box_index")
    else:
        subs = await get_sub_sessions(session_id)
        if subs:
            result["sub_sessions"] = [
                {"id": s["id"], "name": s.get("name"), "box_index": s.get("box_index")}
                for s in subs
            ]

    return result
|
||||
|
||||
|
||||
@router.put("/sessions/{session_id}")
async def update_session(session_id: str, req: UpdateSessionRequest):
    """Rename a session and/or set its document category."""
    changes: Dict[str, Any] = {}

    if req.name is not None:
        changes["name"] = req.name

    category = req.document_category
    if category is not None:
        # Reject categories outside the known vocabulary.
        if category not in VALID_DOCUMENT_CATEGORIES:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid category '{req.document_category}'. Valid: {sorted(VALID_DOCUMENT_CATEGORIES)}",
            )
        changes["document_category"] = category

    if not changes:
        raise HTTPException(status_code=400, detail="Nothing to update")

    if not await update_session_db(session_id, **changes):
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    return {"session_id": session_id, **changes}
|
||||
|
||||
|
||||
@router.delete("/sessions/{session_id}")
async def delete_session(session_id: str):
    """Delete one session from both the in-memory cache and the database."""
    # Evict cached state first; a no-op when nothing is cached.
    _cache.pop(session_id, None)
    if not await delete_session_db(session_id):
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    return {"session_id": session_id, "deleted": True}
|
||||
|
||||
|
||||
@router.delete("/sessions")
async def delete_all_sessions():
    """Remove every session — cache and database (cleanup helper)."""
    _cache.clear()
    removed = await delete_all_sessions_db()
    return {"deleted_count": removed}
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/create-box-sessions")
async def create_box_sessions(session_id: str):
    """Create sub-sessions for each detected box region.

    Crops box regions from the cropped/dewarped image and creates
    independent sub-sessions that can be processed through the pipeline.

    Requires a completed column-detection step (box zones come from its
    stored `column_result`). Idempotent: when sub-sessions already exist
    they are returned unchanged instead of being recreated.

    Raises:
        HTTPException 404 if the session does not exist,
        HTTPException 400 if column detection has not run or no base
        image is available, HTTPException 500 if the base image cannot
        be decoded.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # Box zones originate in column detection; without it there is nothing to crop.
    column_result = session.get("column_result")
    if not column_result:
        raise HTTPException(status_code=400, detail="Column detection must be completed first")

    # Only zones classified as "box" that carry actual geometry qualify.
    zones = column_result.get("zones") or []
    box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
    if not box_zones:
        return {"session_id": session_id, "sub_sessions": [], "message": "No boxes detected"}

    # Check for existing sub-sessions (makes the endpoint idempotent).
    existing = await get_sub_sessions(session_id)
    if existing:
        return {
            "session_id": session_id,
            "sub_sessions": [{"id": s["id"], "box_index": s.get("box_index")} for s in existing],
            "message": f"{len(existing)} sub-session(s) already exist",
        }

    # Load base image — prefer the cropped stage, fall back to dewarped.
    base_png = await get_session_image(session_id, "cropped")
    if not base_png:
        base_png = await get_session_image(session_id, "dewarped")
    if not base_png:
        raise HTTPException(status_code=400, detail="No base image available")

    arr = np.frombuffer(base_png, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    parent_name = session.get("name", "Session")
    created = []

    for i, zone in enumerate(box_zones):
        # NOTE(review): box coordinates are assumed to be integer pixel
        # values in the base image's coordinate system — confirm upstream.
        box = zone["box"]
        bx, by = box["x"], box["y"]
        bw, bh = box["width"], box["height"]

        # Crop box region with small padding, clamped to image bounds.
        pad = 5
        y1 = max(0, by - pad)
        y2 = min(img.shape[0], by + bh + pad)
        x1 = max(0, bx - pad)
        x2 = min(img.shape[1], bx + bw + pad)
        crop = img[y1:y2, x1:x2]

        # Encode as PNG
        success, png_buf = cv2.imencode(".png", crop)
        if not success:
            # Skip this box but keep processing the remaining ones.
            logger.warning(f"Failed to encode box {i} crop for session {session_id}")
            continue

        sub_id = str(uuid.uuid4())
        sub_name = f"{parent_name} — Box {i + 1}"

        # Persist the sub-session, linked to its parent via
        # parent_session_id / box_index.
        await create_session_db(
            session_id=sub_id,
            name=sub_name,
            filename=session.get("filename", "box-crop.png"),
            original_png=png_buf.tobytes(),
            parent_session_id=session_id,
            box_index=i,
        )

        # Cache the BGR for immediate processing
        # Promote original to cropped so column/row/word detection finds it
        box_bgr = crop.copy()
        _cache[sub_id] = {
            "id": sub_id,
            "filename": session.get("filename", "box-crop.png"),
            "name": sub_name,
            "parent_session_id": session_id,
            "original_bgr": box_bgr,
            "oriented_bgr": None,
            "cropped_bgr": box_bgr,
            "deskewed_bgr": None,
            "dewarped_bgr": None,
            "orientation_result": None,
            "crop_result": None,
            "deskew_result": None,
            "dewarp_result": None,
            "ground_truth": {},
            "current_step": 1,
        }

        created.append({
            "id": sub_id,
            "name": sub_name,
            "box_index": i,
            "box": box,
            "image_width": crop.shape[1],
            "image_height": crop.shape[0],
        })

        logger.info(f"Created box sub-session {sub_id} for session {session_id} "
                    f"(box {i}, {crop.shape[1]}x{crop.shape[0]})")

    return {
        "session_id": session_id,
        "sub_sessions": created,
        "total": len(created),
    }
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/thumbnail")
async def get_session_thumbnail(session_id: str, size: int = Query(default=80, ge=16, le=400)):
    """Return a small thumbnail of the original image.

    The longer edge is scaled to `size` pixels, preserving aspect ratio.

    Raises:
        HTTPException 404 if the session or its image is missing,
        HTTPException 500 if the stored PNG cannot be decoded.
    """
    original_png = await get_session_image(session_id, "original")
    if not original_png:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found or no image")
    arr = np.frombuffer(original_png, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")
    h, w = img.shape[:2]
    scale = size / max(h, w)
    # Clamp each dimension to at least 1 px: extreme aspect ratios could
    # otherwise round the short edge down to 0 and make cv2.resize raise.
    new_w, new_h = max(1, int(w * scale)), max(1, int(h * scale))
    thumb = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
    _, png_bytes = cv2.imencode(".png", thumb)
    # Thumbnails are derived from immutable originals, so clients may cache.
    return Response(content=png_bytes.tobytes(), media_type="image/png",
                    headers={"Cache-Control": "public, max-age=3600"})
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/pipeline-log")
async def get_pipeline_log(session_id: str):
    """Return the recorded pipeline execution log for a session."""
    record = await get_session_db(session_id)
    if not record:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    # Sessions that have not run any step yet get an empty log skeleton.
    log = record.get("pipeline_log") or {"steps": []}
    return {"session_id": session_id, "pipeline_log": log}
|
||||
|
||||
|
||||
@router.get("/categories")
async def list_categories():
    """Expose the valid document categories in sorted order."""
    categories = sorted(VALID_DOCUMENT_CATEGORIES)
    return {"categories": categories}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/sessions/{session_id}/image/{image_type}")
async def get_image(session_id: str, image_type: str):
    """Serve session images: original, deskewed, dewarped, binarized, structure-overlay, columns-overlay, or rows-overlay.

    Overlay types are rendered on demand; plain images come from the
    in-memory cache when available (binarized only) or from the DB.

    Raises:
        HTTPException 400 for an unknown image type,
        HTTPException 404 when the requested image is not available yet.
    """
    valid_types = {"original", "oriented", "cropped", "deskewed", "dewarped", "binarized", "structure-overlay", "columns-overlay", "rows-overlay", "words-overlay", "clean"}
    if image_type not in valid_types:
        raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")

    # Overlays are never stored — dispatch to the renderer. The overlay
    # name is the type with its "-overlay" suffix removed (structure,
    # columns, rows, words).
    if image_type.endswith("-overlay"):
        return await render_overlay(image_type[: -len("-overlay")], session_id)

    # Fast path: the binarized PNG may already sit in the memory cache.
    # (Removed dead code: the old version also computed unused
    # png_key/bgr_key locals here.)
    cached = _cache.get(session_id)
    if cached and image_type == "binarized" and cached.get("binarized_png"):
        return Response(content=cached["binarized_png"], media_type="image/png")

    # Load from DB — for cropped/dewarped, fall back through the chain
    if image_type in ("cropped", "dewarped"):
        data = await _get_base_image_png(session_id)
    else:
        data = await get_session_image(session_id, image_type)
    if not data:
        raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet")

    return Response(content=data, media_type="image/png")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Document Type Detection (between Dewarp and Columns)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/detect-type")
async def detect_type(session_id: str):
    """Detect document type (vocab_table, full_text, generic_table).

    Runs on the cropped image when available, otherwise on the dewarped
    one; should be called after crop. The result is persisted to the DB
    and mirrored into the cache so the frontend can decide which
    pipeline flow to take.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    # Prefer the cropped image; fall back to the dewarped stage.
    img_bgr = cached.get("cropped_bgr")
    if img_bgr is None:
        img_bgr = cached.get("dewarped_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first")

    started = time.time()
    ocr_img = create_ocr_image(img_bgr)
    detection = detect_document_type(ocr_img, img_bgr)
    elapsed = time.time() - started

    result_dict = {
        "doc_type": detection.doc_type,
        "confidence": detection.confidence,
        "pipeline": detection.pipeline,
        "skip_steps": detection.skip_steps,
        "features": detection.features,
        "duration_seconds": round(elapsed, 2),
    }

    # Persist to DB, then mirror into the cache.
    await update_session_db(
        session_id,
        doc_type=detection.doc_type,
        doc_type_result=result_dict,
    )
    cached["doc_type_result"] = result_dict

    logger.info(f"OCR Pipeline: detect-type session {session_id}: "
                f"{detection.doc_type} (confidence={detection.confidence}, {elapsed:.2f}s)")

    # Only scalar feature values go into the pipeline log.
    scalar_features = {
        k: v
        for k, v in (detection.features or {}).items()
        if isinstance(v, (int, float, str, bool))
    }
    await _append_pipeline_log(session_id, "detect_type", {
        "doc_type": detection.doc_type,
        "pipeline": detection.pipeline,
        "confidence": detection.confidence,
        **scalar_features,
    }, duration_ms=int(elapsed * 1000))

    return {"session_id": session_id, **result_dict}
|
||||
Reference in New Issue
Block a user