Files
breakpilot-lehrer/klausur-service/backend/ocr_pipeline_api.py
Benjamin Admin 2ad391e4e4
Some checks failed
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 27s
CI / test-go-edu-search (push) Successful in 28s
CI / test-python-klausur (push) Failing after 2m1s
CI / test-python-agent-core (push) Successful in 16s
CI / test-nodejs-website (push) Successful in 18s
feat: Feinabstimmung mit 7 Schiebereglern fuer Deskew/Dewarp
Neues aufklappbares Panel unter Entzerrung mit individuellen Reglern:
- 3 Rotations-Regler (P1 Iterative, P2 Word-Alignment, P3 Textline)
- 4 Scherungs-Regler (A-D Methoden) mit Radio-Auswahl
- Kombinierte Vorschau und Ground-Truth-Speicherung
- Backend: POST /sessions/{id}/adjust-combined Endpoint

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 18:22:33 +01:00

3450 lines
129 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
OCR Pipeline API - Schrittweise Seitenrekonstruktion.
Zerlegt den OCR-Prozess in 8 einzelne Schritte:
1. Deskewing - Scan begradigen
2. Dewarping - Buchwoelbung entzerren
3. Spaltenerkennung - Unsichtbare Spalten finden
4. Zeilenerkennung - Horizontale Zeilen + Kopf-/Fusszeilen
5. Worterkennung - OCR mit Bounding Boxes
6. LLM-Korrektur - OCR-Fehler per LLM korrigieren
7. Seitenrekonstruktion - Seite nachbauen
8. Ground Truth Validierung - Gesamtpruefung
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import os
import time
import uuid
from dataclasses import asdict
from datetime import datetime
from typing import Any, Dict, List, Optional
import cv2
import numpy as np
from fastapi import APIRouter, File, Form, HTTPException, Query, Request, UploadFile
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel
from cv_vocab_pipeline import (
OLLAMA_REVIEW_MODEL,
DocumentTypeResult,
PageRegion,
RowGeometry,
_cells_to_vocab_entries,
_detect_header_footer_gaps,
_detect_sub_columns,
_fix_character_confusion,
_fix_phonetic_brackets,
analyze_layout,
analyze_layout_by_words,
build_cell_grid,
build_cell_grid_streaming,
build_cell_grid_v2,
build_cell_grid_v2_streaming,
build_word_grid,
classify_column_types,
create_layout_image,
create_ocr_image,
deskew_image,
deskew_image_by_word_alignment,
deskew_image_iterative,
deskew_two_pass,
detect_column_geometry,
detect_document_type,
detect_row_geometry,
expand_narrow_columns,
_apply_shear,
dewarp_image,
dewarp_image_manual,
llm_review_entries,
llm_review_entries_streaming,
render_image_high_res,
render_pdf_high_res,
)
from ocr_pipeline_session_store import (
create_session_db,
delete_all_sessions_db,
delete_session_db,
get_session_db,
get_session_image,
init_ocr_pipeline_tables,
list_sessions_db,
update_session_db,
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Every endpoint declared on this router is served under /api/v1/ocr-pipeline.
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# In-memory cache for active sessions (BGR numpy arrays for processing)
# DB is source of truth, cache holds BGR arrays during active processing.
# ---------------------------------------------------------------------------
# Keyed by session_id; each value is a per-session dict holding decoded images
# (e.g. "original_bgr", "deskewed_bgr", "dewarped_bgr") and step results.
_cache: Dict[str, Dict[str, Any]] = {}
async def _load_session_to_cache(session_id: str) -> Dict[str, Any]:
    """Return the cache entry for a session, building it from the DB if needed.

    The session row is always looked up first, so a session missing from the
    DB yields a 404 even if a cache entry exists. When no cache entry exists,
    one is assembled from the DB row plus the stored original/deskewed/dewarped
    PNGs decoded into BGR arrays (missing images stay ``None``).

    Raises:
        HTTPException: 404 when the session does not exist in the DB.
    """
    db_row = await get_session_db(session_id)
    if not db_row:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    if session_id in _cache:
        return _cache[session_id]

    entry: Dict[str, Any] = {"id": session_id}
    entry.update(db_row)
    entry.update(original_bgr=None, deskewed_bgr=None, dewarped_bgr=None)

    # Decode each stored PNG into the matching "<kind>_bgr" slot.
    for image_kind in ("original", "deskewed", "dewarped"):
        png_bytes = await get_session_image(session_id, image_kind)
        if not png_bytes:
            continue
        raw = np.frombuffer(png_bytes, dtype=np.uint8)
        entry[f"{image_kind}_bgr"] = cv2.imdecode(raw, cv2.IMREAD_COLOR)

    _cache[session_id] = entry
    return entry
def _get_cached(session_id: str) -> Dict[str, Any]:
    """Return the cache entry for ``session_id``.

    Raises:
        HTTPException: 404 when the session has no (non-empty) cache entry.
    """
    cached_entry = _cache.get(session_id)
    if cached_entry:
        return cached_entry
    raise HTTPException(status_code=404, detail=f"Session {session_id} not in cache — reload first")
# ---------------------------------------------------------------------------
# Pydantic Models
# ---------------------------------------------------------------------------
class ManualDeskewRequest(BaseModel):
    """Request body for the manual deskew endpoint."""

    # Rotation angle; the manual deskew handler clamps it to [-5.0, 5.0].
    angle: float
class DeskewGroundTruthRequest(BaseModel):
    """Request body for saving ground-truth feedback on the deskew step."""

    # Whether the reviewer considers the deskew result correct.
    is_correct: bool
    # Optional reviewer-supplied angle; stored as given.
    corrected_angle: Optional[float] = None
    # Optional free-text remarks.
    notes: Optional[str] = None
class ManualDewarpRequest(BaseModel):
    """Request body for the manual dewarp endpoint."""

    # Shear angle; the manual dewarp handler clamps it to [-2.0, 2.0].
    shear_degrees: float
class CombinedAdjustRequest(BaseModel):
    """Request body for the combined rotation + shear adjustment endpoint."""

    # Rotation; the adjust-combined handler clamps it to [-15.0, 15.0].
    rotation_degrees: float = 0.0
    # Shear; the adjust-combined handler clamps it to [-5.0, 5.0].
    shear_degrees: float = 0.0
class DewarpGroundTruthRequest(BaseModel):
    """Request body for saving ground-truth feedback on the dewarp step."""

    # Whether the reviewer considers the dewarp result correct.
    is_correct: bool
    # Optional reviewer-supplied shear value; stored as given.
    corrected_shear: Optional[float] = None
    # Optional free-text remarks.
    notes: Optional[str] = None
# Closed set of category labels accepted when updating a session; also
# returned (sorted) by the /categories endpoint.
VALID_DOCUMENT_CATEGORIES = {
    'vokabelseite', 'buchseite', 'arbeitsblatt', 'klausurseite',
    'mathearbeit', 'statistik', 'zeitung', 'formular', 'handschrift', 'sonstiges',
}
class UpdateSessionRequest(BaseModel):
    """Request body for updating a session; fields left as None are not changed."""

    # New display name for the session.
    name: Optional[str] = None
    # New category; validated against VALID_DOCUMENT_CATEGORIES by the handler.
    document_category: Optional[str] = None
class ManualColumnsRequest(BaseModel):
    """Request body carrying a manually specified list of columns."""

    # One dict per column. NOTE(review): the expected keys are defined by the
    # consuming endpoint, which is not visible in this part of the file.
    columns: List[Dict[str, Any]]
class ColumnGroundTruthRequest(BaseModel):
    """Request body for ground-truth feedback on column detection."""

    # Whether the reviewer considers the detected columns correct.
    is_correct: bool
    # Optional corrected column list. NOTE(review): dict schema not visible here.
    corrected_columns: Optional[List[Dict[str, Any]]] = None
    # Optional free-text remarks.
    notes: Optional[str] = None
class ManualRowsRequest(BaseModel):
    """Request body carrying a manually specified list of rows."""

    # One dict per row. NOTE(review): the expected keys are defined by the
    # consuming endpoint, which is not visible in this part of the file.
    rows: List[Dict[str, Any]]
class RowGroundTruthRequest(BaseModel):
    """Request body for ground-truth feedback on row detection."""

    # Whether the reviewer considers the detected rows correct.
    is_correct: bool
    # Optional corrected row list. NOTE(review): dict schema not visible here.
    corrected_rows: Optional[List[Dict[str, Any]]] = None
    # Optional free-text remarks.
    notes: Optional[str] = None
class RemoveHandwritingRequest(BaseModel):
    """Options for a handwriting-removal request.

    NOTE(review): the endpoint consuming this model is not visible in this part
    of the file; the allowed values below come from the inline field comments.
    """

    method: str = "auto"  # "auto" | "telea" | "ns"
    target_ink: str = "all"  # "all" | "colored" | "pencil"
    dilation: int = 2  # mask dilation iterations (0-5)
    use_source: str = "auto"  # "original" | "deskewed" | "auto"
# ---------------------------------------------------------------------------
# Session Management Endpoints
# ---------------------------------------------------------------------------
@router.get("/sessions")
async def list_sessions():
    """Return all OCR pipeline sessions from the DB under the ``sessions`` key."""
    return {"sessions": await list_sessions_db()}
@router.post("/sessions")
async def create_session(
    file: UploadFile = File(...),
    name: Optional[str] = Form(None),
):
    """Create a pipeline session from an uploaded PDF or image.

    A PDF (detected by content type or ``.pdf`` suffix) is rendered from its
    first page; anything else is treated as an image. The rendered page is
    stored as PNG in the DB and kept as a BGR array in the in-memory cache.

    Raises:
        HTTPException: 400 if the upload cannot be rendered, 500 if PNG
            encoding fails.
    """
    upload_bytes = await file.read()
    filename = file.filename or "upload"
    mime = file.content_type or ""
    session_id = str(uuid.uuid4())

    looks_like_pdf = mime == "application/pdf" or filename.lower().endswith(".pdf")
    try:
        img_bgr = (
            render_pdf_high_res(upload_bytes, page_number=0, zoom=3.0)
            if looks_like_pdf
            else render_image_high_res(upload_bytes)
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Could not process file: {e}")

    # The original is persisted as PNG bytes.
    encoded_ok, png_buf = cv2.imencode(".png", img_bgr)
    if not encoded_ok:
        raise HTTPException(status_code=500, detail="Failed to encode image")

    session_name = name or filename
    await create_session_db(
        session_id=session_id,
        name=session_name,
        filename=filename,
        original_png=png_buf.tobytes(),
    )

    # Keep the decoded array around so the first step needs no DB round trip.
    _cache[session_id] = {
        "id": session_id,
        "filename": filename,
        "name": session_name,
        "original_bgr": img_bgr,
        "deskewed_bgr": None,
        "dewarped_bgr": None,
        "deskew_result": None,
        "dewarp_result": None,
        "ground_truth": {},
        "current_step": 1,
    }

    height, width = img_bgr.shape[0], img_bgr.shape[1]
    logger.info(f"OCR Pipeline: created session {session_id} from {filename} "
                f"({width}x{height})")

    return {
        "session_id": session_id,
        "filename": filename,
        "name": session_name,
        "image_width": width,
        "image_height": height,
        "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
    }
@router.get("/sessions/{session_id}")
async def get_session_info(session_id: str):
    """Get session info including deskew/dewarp/column results for step navigation.

    Image dimensions are read from the stored original PNG and reported as
    0 x 0 when the image is missing or cannot be decoded. Per-step result
    fields are included only when the session row holds a non-empty value.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # Get image dimensions from original PNG
    img_w, img_h = 0, 0
    original_png = await get_session_image(session_id, "original")
    if original_png:
        arr = np.frombuffer(original_png, dtype=np.uint8)
        img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
        # imdecode returns None for undecodable data; keep the 0 x 0 default
        # in that case instead of dereferencing None.
        if img is not None:
            img_h, img_w = img.shape[:2]

    result = {
        "session_id": session["id"],
        "filename": session.get("filename", ""),
        "name": session.get("name", ""),
        "image_width": img_w,
        "image_height": img_h,
        "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
        "current_step": session.get("current_step", 1),
        "document_category": session.get("document_category"),
        "doc_type": session.get("doc_type"),
    }

    # Optional step results: only present once the step has produced output.
    for key in (
        "deskew_result",
        "dewarp_result",
        "column_result",
        "row_result",
        "word_result",
        "doc_type_result",
    ):
        value = session.get(key)
        if value:
            result[key] = value
    return result
@router.put("/sessions/{session_id}")
async def update_session(session_id: str, req: UpdateSessionRequest):
    """Update session name and/or document category.

    Only fields that are not None in the request are written. A cached copy
    of the session, if any, is updated as well so it does not keep stale
    values.

    Raises:
        HTTPException: 400 for an unknown category or an empty update,
            404 when the session does not exist.
    """
    kwargs: Dict[str, Any] = {}
    if req.name is not None:
        kwargs["name"] = req.name
    if req.document_category is not None:
        if req.document_category not in VALID_DOCUMENT_CATEGORIES:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid category '{req.document_category}'. Valid: {sorted(VALID_DOCUMENT_CATEGORIES)}",
            )
        kwargs["document_category"] = req.document_category
    if not kwargs:
        raise HTTPException(status_code=400, detail="Nothing to update")

    updated = await update_session_db(session_id, **kwargs)
    if not updated:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # Mirror the change into the in-memory entry, as the ground-truth
    # endpoints do for their fields.
    cached = _cache.get(session_id)
    if cached is not None:
        cached.update(kwargs)

    return {"session_id": session_id, **kwargs}
@router.delete("/sessions/{session_id}")
async def delete_session(session_id: str):
    """Remove a session from the cache and the DB.

    The cache entry is dropped first, regardless of whether the DB row exists.

    Raises:
        HTTPException: 404 when the DB reports nothing was deleted.
    """
    _cache.pop(session_id, None)
    if not await delete_session_db(session_id):
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    return {"session_id": session_id, "deleted": True}
@router.delete("/sessions")
async def delete_all_sessions():
    """Empty the in-memory cache and delete every session from the DB (cleanup)."""
    _cache.clear()
    return {"deleted_count": await delete_all_sessions_db()}
@router.get("/sessions/{session_id}/thumbnail")
async def get_session_thumbnail(session_id: str, size: int = Query(default=80, ge=16, le=400)):
    """Return a small PNG thumbnail of the original image.

    The image is scaled so that its longer side equals ``size`` pixels; the
    shorter side keeps the aspect ratio but is never smaller than 1 pixel.

    Raises:
        HTTPException: 404 when no original image is stored, 500 when the
            image cannot be decoded or the thumbnail cannot be encoded.
    """
    original_png = await get_session_image(session_id, "original")
    if not original_png:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found or no image")
    arr = np.frombuffer(original_png, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    h, w = img.shape[:2]
    scale = size / max(h, w)
    # int() truncation can yield 0 for very elongated images, which
    # cv2.resize rejects; clamp both sides to at least one pixel.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    thumb = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

    encoded_ok, png_bytes = cv2.imencode(".png", thumb)
    if not encoded_ok:
        raise HTTPException(status_code=500, detail="Failed to encode image")
    return Response(content=png_bytes.tobytes(), media_type="image/png",
                    headers={"Cache-Control": "public, max-age=3600"})
@router.get("/sessions/{session_id}/pipeline-log")
async def get_pipeline_log(session_id: str):
    """Return the stored pipeline log of a session.

    A missing or empty log is reported as ``{"steps": []}``.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    pipeline_log = session.get("pipeline_log") or {"steps": []}
    return {"session_id": session_id, "pipeline_log": pipeline_log}
@router.get("/categories")
async def list_categories():
    """Return the valid document categories in sorted order."""
    ordered = sorted(VALID_DOCUMENT_CATEGORIES)
    return {"categories": ordered}
# ---------------------------------------------------------------------------
# Pipeline Log Helper
# ---------------------------------------------------------------------------
async def _append_pipeline_log(
    session_id: str,
    step_name: str,
    metrics: Dict[str, Any],
    success: bool = True,
    duration_ms: Optional[int] = None,
):
    """Add one step record to the session's stored ``pipeline_log``.

    Does nothing when the session does not exist. A stored log that is empty
    or not a dict is replaced by a fresh ``{"steps": []}`` before appending.

    Args:
        session_id: Session whose log is extended.
        step_name: Stored under the ``step`` key of the record.
        metrics: Stored as-is under the ``metrics`` key.
        success: Stored under the ``success`` key.
        duration_ms: Added as ``duration_ms`` only when not None.
    """
    session = await get_session_db(session_id)
    if not session:
        return

    stored = session.get("pipeline_log")
    log = stored if stored and isinstance(stored, dict) else {"steps": []}

    record = {
        "step": step_name,
        "completed_at": datetime.utcnow().isoformat(),
        "success": success,
        "metrics": metrics,
    }
    if duration_ms is not None:
        record["duration_ms"] = duration_ms

    steps = log.setdefault("steps", [])
    steps.append(record)
    await update_session_db(session_id, pipeline_log=log)
# ---------------------------------------------------------------------------
# Image Endpoints
# ---------------------------------------------------------------------------
@router.get("/sessions/{session_id}/image/{image_type}")
async def get_image(session_id: str, image_type: str):
    """Serve a session image as PNG.

    Accepted ``image_type`` values: original, deskewed, dewarped, binarized,
    clean, columns-overlay, rows-overlay and words-overlay. Overlay types are
    delegated to their dedicated overlay helpers; a binarized image held in
    the in-memory cache is served directly; everything else is read from
    the DB.

    Raises:
        HTTPException: 400 for an unknown image type, 404 when the requested
            image has not been stored yet.
    """
    valid_types = {"original", "deskewed", "dewarped", "binarized", "columns-overlay", "rows-overlay", "words-overlay", "clean"}
    if image_type not in valid_types:
        raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")

    if image_type == "columns-overlay":
        return await _get_columns_overlay(session_id)
    if image_type == "rows-overlay":
        return await _get_rows_overlay(session_id)
    if image_type == "words-overlay":
        return await _get_words_overlay(session_id)

    # Fast path: the binarized PNG is the only image kept in the cache in
    # encoded form, so it is the only one that can be served without the DB.
    if image_type == "binarized":
        cached = _cache.get(session_id)
        if cached and cached.get("binarized_png"):
            return Response(content=cached["binarized_png"], media_type="image/png")

    # Load from DB
    data = await get_session_image(session_id, image_type)
    if not data:
        raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet")
    return Response(content=data, media_type="image/png")
# ---------------------------------------------------------------------------
# Deskew Endpoints
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/deskew")
async def auto_deskew(session_id: str):
    """Two-pass deskew: iterative projection (wide range) + word-alignment residual.

    Runs ``deskew_two_pass`` on a copy of the original image and uses its
    output as the deskewed image. The Hough-based and word-alignment methods
    are additionally run on the original purely for reporting; if either one
    raises, its angle is reported as 0.0. The deskewed PNG, an optional
    binarized PNG and the result dict are written to both the cache and the
    DB, the session's ``current_step`` is set to 2, and a "deskew" entry is
    appended to the pipeline log.

    Raises:
        HTTPException: 404 when the session is unknown, 400 when no original
            image is available.
    """
    # Ensure session is in cache
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)
    img_bgr = cached.get("original_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Original image not available")
    # The reported duration covers the authoritative pass and both reporting passes.
    t0 = time.time()
    # Two-pass deskew: iterative (±5°) + word-alignment residual check
    deskewed_bgr, angle_applied, two_pass_debug = deskew_two_pass(img_bgr.copy())
    # Also run individual methods for reporting (non-authoritative)
    try:
        _, angle_hough = deskew_image(img_bgr.copy())
    except Exception:
        angle_hough = 0.0
    # The word-alignment method takes encoded PNG bytes, not an array.
    success_enc, png_orig = cv2.imencode(".png", img_bgr)
    orig_bytes = png_orig.tobytes() if success_enc else b""
    try:
        _, angle_wa = deskew_image_by_word_alignment(orig_bytes)
    except Exception:
        angle_wa = 0.0
    # Per-pass angles reported by deskew_two_pass; absent keys count as 0.0.
    angle_iterative = two_pass_debug.get("pass1_angle", 0.0)
    angle_residual = two_pass_debug.get("pass2_angle", 0.0)
    angle_textline = two_pass_debug.get("pass3_angle", 0.0)
    duration = time.time() - t0
    # Label by the last pass that reported an angle of at least 0.01.
    method_used = "three_pass" if abs(angle_textline) >= 0.01 else (
        "two_pass" if abs(angle_residual) >= 0.01 else "iterative"
    )
    # Encode as PNG
    success, deskewed_png_buf = cv2.imencode(".png", deskewed_bgr)
    deskewed_png = deskewed_png_buf.tobytes() if success else b""
    # Create binarized version
    binarized_png = None
    try:
        binarized = create_ocr_image(deskewed_bgr)
        success_bin, bin_buf = cv2.imencode(".png", binarized)
        binarized_png = bin_buf.tobytes() if success_bin else None
    except Exception as e:
        # Best-effort: a failed binarization does not fail the deskew step.
        logger.warning(f"Binarization failed: {e}")
    # Heuristic: decreases linearly with the applied angle, floored at 0.5.
    confidence = max(0.5, 1.0 - abs(angle_applied) / 5.0)
    deskew_result = {
        "angle_hough": round(angle_hough, 3),
        "angle_word_alignment": round(angle_wa, 3),
        "angle_iterative": round(angle_iterative, 3),
        "angle_residual": round(angle_residual, 3),
        "angle_textline": round(angle_textline, 3),
        "angle_applied": round(angle_applied, 3),
        "method_used": method_used,
        "confidence": round(confidence, 2),
        "duration_seconds": round(duration, 2),
        "two_pass_debug": two_pass_debug,
    }
    # Update cache
    cached["deskewed_bgr"] = deskewed_bgr
    cached["binarized_png"] = binarized_png
    cached["deskew_result"] = deskew_result
    # Persist to DB
    db_update = {
        "deskewed_png": deskewed_png,
        "deskew_result": deskew_result,
        "current_step": 2,
    }
    # Only written when binarization succeeded; a previously stored binarized
    # image is otherwise left untouched.
    if binarized_png:
        db_update["binarized_png"] = binarized_png
    await update_session_db(session_id, **db_update)
    logger.info(f"OCR Pipeline: deskew session {session_id}: "
                f"hough={angle_hough:.2f} wa={angle_wa:.2f} "
                f"iter={angle_iterative:.2f} residual={angle_residual:.2f} "
                f"textline={angle_textline:.2f} "
                f"-> {method_used} total={angle_applied:.2f}")
    await _append_pipeline_log(session_id, "deskew", {
        "angle_applied": round(angle_applied, 3),
        "angle_iterative": round(angle_iterative, 3),
        "angle_residual": round(angle_residual, 3),
        "angle_textline": round(angle_textline, 3),
        "confidence": round(confidence, 2),
        "method": method_used,
    }, duration_ms=int(duration * 1000))
    return {
        "session_id": session_id,
        **deskew_result,
        "deskewed_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed",
        "binarized_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/binarized",
    }
@router.post("/sessions/{session_id}/deskew/manual")
async def manual_deskew(session_id: str, req: ManualDeskewRequest):
    """Apply a manual rotation angle to the original image.

    The requested angle is clamped to [-5.0, 5.0] and applied as a rotation
    about the image centre. The rotated image replaces the deskewed image in
    both cache and DB; a binarized version is produced on a best-effort basis.

    Raises:
        HTTPException: 404 when the session is unknown, 400 when no original
            image is available.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)
    img_bgr = cached.get("original_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Original image not available")

    angle = max(-5.0, min(5.0, req.angle))
    h, w = img_bgr.shape[:2]
    center = (w // 2, h // 2)
    M = cv2.getRotationMatrix2D(center, angle, 1.0)
    # Output keeps the input size; exposed borders repeat the edge pixels.
    rotated = cv2.warpAffine(img_bgr, M, (w, h),
                             flags=cv2.INTER_LINEAR,
                             borderMode=cv2.BORDER_REPLICATE)
    success, png_buf = cv2.imencode(".png", rotated)
    deskewed_png = png_buf.tobytes() if success else b""

    # Binarize
    binarized_png = None
    try:
        binarized = create_ocr_image(rotated)
        success_bin, bin_buf = cv2.imencode(".png", binarized)
        binarized_png = bin_buf.tobytes() if success_bin else None
    except Exception as e:
        # Still best-effort, but leave a trace like the automatic deskew does
        # instead of discarding the error silently.
        logger.warning(f"Binarization failed: {e}")

    # Keep any earlier per-method angles and override only what changed.
    deskew_result = {
        **(cached.get("deskew_result") or {}),
        "angle_applied": round(angle, 3),
        "method_used": "manual",
    }

    # Update cache
    cached["deskewed_bgr"] = rotated
    cached["binarized_png"] = binarized_png
    cached["deskew_result"] = deskew_result

    # Persist to DB
    db_update = {
        "deskewed_png": deskewed_png,
        "deskew_result": deskew_result,
    }
    if binarized_png:
        db_update["binarized_png"] = binarized_png
    await update_session_db(session_id, **db_update)

    logger.info(f"OCR Pipeline: manual deskew session {session_id}: {angle:.2f}")
    return {
        "session_id": session_id,
        "angle_applied": round(angle, 3),
        "method_used": "manual",
        "deskewed_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed",
    }
@router.post("/sessions/{session_id}/ground-truth/deskew")
async def save_deskew_ground_truth(session_id: str, req: DeskewGroundTruthRequest):
    """Store reviewer feedback for the deskew step.

    The feedback is saved under the ``deskew`` key of the session's
    ``ground_truth`` together with a snapshot of the current deskew result,
    and mirrored into the cache entry when one exists.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    all_ground_truth = session.get("ground_truth") or {}
    feedback = {
        "is_correct": req.is_correct,
        "corrected_angle": req.corrected_angle,
        "notes": req.notes,
        "saved_at": datetime.utcnow().isoformat(),
        "deskew_result": session.get("deskew_result"),
    }
    all_ground_truth["deskew"] = feedback
    await update_session_db(session_id, ground_truth=all_ground_truth)

    # Keep the in-memory copy in step with the DB.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = all_ground_truth

    logger.info(f"OCR Pipeline: ground truth deskew session {session_id}: "
                f"correct={req.is_correct}, corrected_angle={req.corrected_angle}")
    return {"session_id": session_id, "ground_truth": feedback}
# ---------------------------------------------------------------------------
# Dewarp Endpoints
# ---------------------------------------------------------------------------
async def _detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]:
    """Ask a vision-language model to estimate the vertical shear of a scanned page.

    The image is sent to the Ollama ``/api/generate`` endpoint with a prompt
    asking whether the table column borders are tilted and by how many
    degrees. The base URL and model come from ``OLLAMA_BASE_URL`` and
    ``OLLAMA_HTR_MODEL`` (default model ``qwen2.5vl:32b``); the returned
    ``method`` label is fixed regardless of the configured model.

    Args:
        image_bytes: Encoded image, sent base64-encoded to the model.

    Returns:
        Dict with ``method``, ``shear_degrees`` (clamped to [-3.0, 3.0]) and
        ``confidence`` (clamped to [0.0, 1.0]). Shear and confidence are both
        0.0 when the request fails, no JSON object is found in the reply, or
        the reply contains non-finite numbers.
    """
    import base64
    import math
    import re

    import httpx

    ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
    model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
    prompt = (
        "This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. "
        "Are they perfectly vertical, or do they tilt slightly? "
        "If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). "
        "Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} "
        "Use confidence 0.0-1.0 based on how clearly you can see the tilt. "
        "If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}"
    )
    img_b64 = base64.b64encode(image_bytes).decode("utf-8")
    payload = {
        "model": model,
        "prompt": prompt,
        "images": [img_b64],
        "stream": False,
    }
    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.post(f"{ollama_base}/api/generate", json=payload)
            resp.raise_for_status()
            text = resp.json().get("response", "")
        # Parse JSON from response (may have surrounding text)
        match = re.search(r'\{[^}]+\}', text)
        if match:
            data = json.loads(match.group(0))
            shear = float(data.get("shear_degrees", 0.0))
            conf = float(data.get("confidence", 0.0))
            # json.loads accepts NaN/Infinity, and min()/max() clamping would
            # turn NaN into the range limits (e.g. confidence 1.0). Treat
            # such replies as unusable instead.
            if math.isfinite(shear) and math.isfinite(conf):
                # Clamp to reasonable range
                shear = max(-3.0, min(3.0, shear))
                conf = max(0.0, min(1.0, conf))
                return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)}
            logger.warning(f"VLM dewarp returned non-finite values: {match.group(0)}")
    except Exception as e:
        # Covers an unavailable Ollama, HTTP errors and malformed replies.
        logger.warning(f"VLM dewarp failed: {e}")
    return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0}
@router.post("/sessions/{session_id}/dewarp")
async def auto_dewarp(
    session_id: str,
    method: str = Query("ensemble", description="Detection method: ensemble | vlm | cv"),
):
    """Detect and correct vertical shear on the deskewed image.

    Methods:
    - **ensemble** (default): 3-method CV ensemble (vertical edges + projection + Hough)
    - **cv**: CV ensemble only (same as ensemble)
    - **vlm**: Ask qwen2.5vl:32b to estimate the shear angle visually

    The dewarped PNG and result dict are written to cache and DB, the
    session's ``current_step`` is set to 3, and a "dewarp" entry is appended
    to the pipeline log.

    Raises:
        HTTPException: 400 for an unknown method or when deskew has not been
            run yet, 404 when the session is unknown.
    """
    if method not in ("ensemble", "cv", "vlm"):
        raise HTTPException(status_code=400, detail="method must be one of: ensemble, cv, vlm")
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)
    deskewed_bgr = cached.get("deskewed_bgr")
    if deskewed_bgr is None:
        raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")
    t0 = time.time()
    if method == "vlm":
        # Encode deskewed image to PNG for VLM
        success, png_buf = cv2.imencode(".png", deskewed_bgr)
        img_bytes = png_buf.tobytes() if success else b""
        vlm_det = await _detect_shear_with_vlm(img_bytes)
        shear_deg = vlm_det["shear_degrees"]
        # Apply the correction only for a noticeable, reasonably confident
        # estimate; otherwise the deskewed image is passed through unchanged.
        if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3:
            dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg)
        else:
            dewarped_bgr = deskewed_bgr
        # Same shape as the info dict used for the CV path below; the reported
        # shear is the detected value even when it was not applied.
        dewarp_info = {
            "method": vlm_det["method"],
            "shear_degrees": shear_deg,
            "confidence": vlm_det["confidence"],
            "detections": [vlm_det],
        }
    else:
        # "ensemble" and "cv" both use the CV-based detector.
        dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr)
    duration = time.time() - t0
    # Encode as PNG
    success, png_buf = cv2.imencode(".png", dewarped_bgr)
    dewarped_png = png_buf.tobytes() if success else b""
    dewarp_result = {
        "method_used": dewarp_info["method"],
        "shear_degrees": dewarp_info["shear_degrees"],
        "confidence": dewarp_info["confidence"],
        "duration_seconds": round(duration, 2),
        "detections": dewarp_info.get("detections", []),
    }
    # Update cache
    cached["dewarped_bgr"] = dewarped_bgr
    cached["dewarp_result"] = dewarp_result
    # Persist to DB
    await update_session_db(
        session_id,
        dewarped_png=dewarped_png,
        dewarp_result=dewarp_result,
        auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
        current_step=3,
    )
    logger.info(f"OCR Pipeline: dewarp session {session_id}: "
                f"method={dewarp_info['method']} shear={dewarp_info['shear_degrees']:.3f} "
                f"conf={dewarp_info['confidence']:.2f} ({duration:.2f}s)")
    await _append_pipeline_log(session_id, "dewarp", {
        "shear_degrees": dewarp_info["shear_degrees"],
        "confidence": dewarp_info["confidence"],
        "method": dewarp_info["method"],
        "ensemble_methods": [d.get("method", "") for d in dewarp_info.get("detections", [])],
    }, duration_ms=int(duration * 1000))
    return {
        "session_id": session_id,
        **dewarp_result,
        "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
    }
@router.post("/sessions/{session_id}/dewarp/manual")
async def manual_dewarp(session_id: str, req: ManualDewarpRequest):
    """Apply a user-chosen shear correction to the deskewed image.

    The requested shear is clamped to [-2.0, 2.0]; a value with magnitude
    below 0.001 leaves the deskewed image unchanged. The result replaces the
    dewarped image in cache and DB.

    Raises:
        HTTPException: 404 when the session is unknown, 400 when deskew has
            not been run yet.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    deskewed_bgr = cached.get("deskewed_bgr")
    if deskewed_bgr is None:
        raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")

    shear_deg = max(-2.0, min(2.0, req.shear_degrees))
    dewarped_bgr = (
        deskewed_bgr
        if abs(shear_deg) < 0.001
        else dewarp_image_manual(deskewed_bgr, shear_deg)
    )

    encoded_ok, png_buf = cv2.imencode(".png", dewarped_bgr)
    dewarped_png = png_buf.tobytes() if encoded_ok else b""

    # Earlier detection details are kept; method and shear are overridden.
    previous_result = cached.get("dewarp_result") or {}
    dewarp_result = {
        **previous_result,
        "method_used": "manual",
        "shear_degrees": round(shear_deg, 3),
    }

    cached["dewarped_bgr"] = dewarped_bgr
    cached["dewarp_result"] = dewarp_result

    await update_session_db(
        session_id,
        dewarped_png=dewarped_png,
        dewarp_result=dewarp_result,
    )

    logger.info(f"OCR Pipeline: manual dewarp session {session_id}: shear={shear_deg:.3f}")
    return {
        "session_id": session_id,
        "shear_degrees": round(shear_deg, 3),
        "method_used": "manual",
        "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
    }
@router.post("/sessions/{session_id}/adjust-combined")
async def adjust_combined(session_id: str, req: CombinedAdjustRequest):
    """Apply rotation + shear combined to the original image.

    Used by the fine-tuning sliders to preview arbitrary rotation/shear
    combinations without re-running the full deskew/dewarp pipeline.

    Rotation is clamped to [-15.0, 15.0] and shear to [-5.0, 5.0]; either
    transform is skipped when its magnitude is below 0.001. The single
    resulting image is stored as both the deskewed and the dewarped image,
    in the cache and in the DB, and the cached binarized image is refreshed
    so that it matches the new geometry.

    Raises:
        HTTPException: 404 when the session is unknown, 400 when no original
            image is available.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)
    img_bgr = cached.get("original_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Original image not available")

    rotation = max(-15.0, min(15.0, req.rotation_degrees))
    shear_deg = max(-5.0, min(5.0, req.shear_degrees))
    h, w = img_bgr.shape[:2]
    result_bgr = img_bgr

    # Step 1: Apply rotation
    if abs(rotation) >= 0.001:
        center = (w // 2, h // 2)
        M = cv2.getRotationMatrix2D(center, rotation, 1.0)
        result_bgr = cv2.warpAffine(result_bgr, M, (w, h),
                                    flags=cv2.INTER_LINEAR,
                                    borderMode=cv2.BORDER_REPLICATE)

    # Step 2: Apply shear
    if abs(shear_deg) >= 0.001:
        result_bgr = dewarp_image_manual(result_bgr, shear_deg)

    # Encode
    success, png_buf = cv2.imencode(".png", result_bgr)
    dewarped_png = png_buf.tobytes() if success else b""

    # Binarize
    binarized_png = None
    try:
        binarized = create_ocr_image(result_bgr)
        success_bin, bin_buf = cv2.imencode(".png", binarized)
        binarized_png = bin_buf.tobytes() if success_bin else None
    except Exception as e:
        # Best-effort, but leave a trace like the automatic deskew does.
        logger.warning(f"Binarization failed: {e}")

    # Build combined result dicts
    deskew_result = {
        **(cached.get("deskew_result") or {}),
        "angle_applied": round(rotation, 3),
        "method_used": "manual_combined",
    }
    dewarp_result = {
        **(cached.get("dewarp_result") or {}),
        "method_used": "manual_combined",
        "shear_degrees": round(shear_deg, 3),
    }

    # Update cache
    cached["deskewed_bgr"] = result_bgr
    cached["dewarped_bgr"] = result_bgr
    # The image endpoint serves the binarized PNG straight from the cache, so
    # it must be replaced here or a stale image from an earlier step is shown.
    cached["binarized_png"] = binarized_png
    cached["deskew_result"] = deskew_result
    cached["dewarp_result"] = dewarp_result

    # Persist to DB. The same PNG is stored for both stages, mirroring the
    # cache update above.
    db_update = {
        "deskewed_png": dewarped_png,
        "dewarped_png": dewarped_png,
        "deskew_result": deskew_result,
        "dewarp_result": dewarp_result,
    }
    if binarized_png:
        db_update["binarized_png"] = binarized_png
    await update_session_db(session_id, **db_update)

    logger.info(f"OCR Pipeline: combined adjust session {session_id}: "
                f"rotation={rotation:.3f} shear={shear_deg:.3f}")
    return {
        "session_id": session_id,
        "rotation_degrees": round(rotation, 3),
        "shear_degrees": round(shear_deg, 3),
        "method_used": "manual_combined",
        "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
    }
@router.post("/sessions/{session_id}/ground-truth/dewarp")
async def save_dewarp_ground_truth(session_id: str, req: DewarpGroundTruthRequest):
    """Store reviewer feedback for the dewarp step.

    The feedback is saved under the ``dewarp`` key of the session's
    ``ground_truth`` together with a snapshot of the current dewarp result,
    and mirrored into the cache entry when one exists.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    all_ground_truth = session.get("ground_truth") or {}
    feedback = {
        "is_correct": req.is_correct,
        "corrected_shear": req.corrected_shear,
        "notes": req.notes,
        "saved_at": datetime.utcnow().isoformat(),
        "dewarp_result": session.get("dewarp_result"),
    }
    all_ground_truth["dewarp"] = feedback
    await update_session_db(session_id, ground_truth=all_ground_truth)

    # Keep the in-memory copy in step with the DB.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = all_ground_truth

    logger.info(f"OCR Pipeline: ground truth dewarp session {session_id}: "
                f"correct={req.is_correct}, corrected_shear={req.corrected_shear}")
    return {"session_id": session_id, "ground_truth": feedback}
# ---------------------------------------------------------------------------
# Document Type Detection (between Dewarp and Columns)
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/detect-type")
async def detect_type(session_id: str):
    """Classify the document type of the dewarped page.

    Must run after dewarp. The detection result is written to the DB and the
    cache entry so the frontend can decide how the pipeline continues, and a
    "detect_type" entry is appended to the pipeline log.

    Raises:
        HTTPException: 404 when the session is unknown, 400 when no dewarped
            image is available.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    dewarped_bgr = cached.get("dewarped_bgr")
    if dewarped_bgr is None:
        raise HTTPException(status_code=400, detail="Dewarp must be completed first")

    started = time.time()
    detection = detect_document_type(create_ocr_image(dewarped_bgr), dewarped_bgr)
    duration = time.time() - started

    result_dict = {
        "doc_type": detection.doc_type,
        "confidence": detection.confidence,
        "pipeline": detection.pipeline,
        "skip_steps": detection.skip_steps,
        "features": detection.features,
        "duration_seconds": round(duration, 2),
    }

    await update_session_db(
        session_id,
        doc_type=detection.doc_type,
        doc_type_result=result_dict,
    )
    cached["doc_type_result"] = result_dict

    logger.info(f"OCR Pipeline: detect-type session {session_id}: "
                f"{detection.doc_type} (confidence={detection.confidence}, {duration:.2f}s)")

    # Only scalar feature values are copied into the log metrics.
    scalar_features = {
        key: value
        for key, value in (detection.features or {}).items()
        if isinstance(value, (int, float, str, bool))
    }
    log_metrics = {
        "doc_type": detection.doc_type,
        "pipeline": detection.pipeline,
        "confidence": detection.confidence,
        **scalar_features,
    }
    await _append_pipeline_log(
        session_id, "detect_type", log_metrics, duration_ms=int(duration * 1000)
    )

    return {"session_id": session_id, **result_dict}
# ---------------------------------------------------------------------------
# Column Detection Endpoints (Step 3)
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/columns")
async def detect_columns(session_id: str):
    """Detect and classify the columns of the dewarped page.

    Geometry detection runs first. When it yields a result, its word boxes,
    inverted image and content bounds are cached for the row step, the
    geometries are split into sub-columns, widened and then classified.
    When it yields nothing, the projection-based layout analysis is used.

    Stored row and word results are cleared because they depend on the
    columns.

    Raises:
        HTTPException: 400 when the session has no dewarped image yet.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    session_cache = _get_cached(session_id)

    dewarped_bgr = session_cache.get("dewarped_bgr")
    if dewarped_bgr is None:
        raise HTTPException(status_code=400, detail="Dewarp must be completed before column detection")

    started = time.time()

    # Binarized image used for the layout analysis.
    ocr_img = create_ocr_image(dewarped_bgr)
    h, w = ocr_img.shape[:2]

    # Phase A: geometry detection (also returns word_dicts + inv for reuse).
    geo_result = detect_column_geometry(ocr_img, dewarped_bgr)

    if geo_result is not None:
        geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
        content_w = right_x - left_x

        # Keep intermediates so row detection can skip a second geometry run.
        session_cache["_word_dicts"] = word_dicts
        session_cache["_inv"] = inv
        session_cache["_content_bounds"] = (left_x, right_x, top_y, bottom_y)

        # Header/footer are located early so sub-column clustering ignores them.
        if inv is not None:
            header_y, footer_y = _detect_header_footer_gaps(inv, w, h)
        else:
            header_y, footer_y = None, None

        # Split sub-columns (e.g. page references) before classification.
        geometries = _detect_sub_columns(
            geometries, content_w, left_x=left_x,
            top_y=top_y, header_y=header_y, footer_y=footer_y,
        )
        # Sub-columns are often very narrow; widen them.
        geometries = expand_narrow_columns(geometries, content_w, left_x, word_dicts)

        # Phase B: content-based classification.
        regions = classify_column_types(
            geometries, content_w, top_y, w, h, bottom_y,
            left_x=left_x, right_x=right_x, inv=inv,
        )
    else:
        # Fallback: projection-based layout.
        layout_img = create_layout_image(dewarped_bgr)
        regions = analyze_layout(layout_img, ocr_img)

    duration = time.time() - started

    columns = [asdict(region) for region in regions]

    # Distinct, non-empty classification methods that were used.
    methods = list({
        col["classification_method"]
        for col in columns
        if col.get("classification_method")
    })

    column_result = {
        "columns": columns,
        "classification_methods": methods,
        "duration_seconds": round(duration, 2),
    }

    # Persist and invalidate downstream results (rows, words).
    await update_session_db(
        session_id,
        column_result=column_result,
        row_result=None,
        word_result=None,
        current_step=3,
    )

    session_cache["column_result"] = column_result
    for stale_key in ("row_result", "word_result"):
        session_cache.pop(stale_key, None)

    col_count = sum(1 for col in columns if col["type"].startswith("column"))
    logger.info(f"OCR Pipeline: columns session {session_id}: "
                f"{col_count} columns detected ({duration:.2f}s)")

    img_w = dewarped_bgr.shape[1]
    log_payload = {
        "total_columns": len(columns),
        "column_widths_pct": [round(col["width"] / img_w * 100, 1) for col in columns],
        "column_types": [col["type"] for col in columns],
    }
    await _append_pipeline_log(
        session_id, "columns", log_payload, duration_ms=int(duration * 1000),
    )

    response = {"session_id": session_id}
    response.update(column_result)
    return response
@router.post("/sessions/{session_id}/columns/manual")
async def set_manual_columns(session_id: str, req: ManualColumnsRequest):
    """Replace the detected columns with the caller's manual definitions.

    The manual result is written to the DB and, if the session is cached,
    to the cache. Row and word results are cleared in both places.
    """
    column_result = {
        "columns": req.columns,
        "duration_seconds": 0,
        "method": "manual",
    }

    await update_session_db(
        session_id,
        column_result=column_result,
        row_result=None,
        word_result=None,
    )

    if session_id in _cache:
        entry = _cache[session_id]
        entry["column_result"] = column_result
        for stale_key in ("row_result", "word_result"):
            entry.pop(stale_key, None)

    logger.info(f"OCR Pipeline: manual columns session {session_id}: "
                f"{len(req.columns)} columns set")

    response = {"session_id": session_id}
    response.update(column_result)
    return response
@router.post("/sessions/{session_id}/ground-truth/columns")
async def save_column_ground_truth(session_id: str, req: ColumnGroundTruthRequest):
    """Store reviewer feedback for the column detection step.

    The feedback is saved under ``ground_truth["columns"]`` together with a
    snapshot of the session's current ``column_result``.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    all_ground_truth = session.get("ground_truth") or {}
    columns_feedback = {
        "is_correct": req.is_correct,
        "corrected_columns": req.corrected_columns,
        "notes": req.notes,
        "saved_at": datetime.utcnow().isoformat(),
        "column_result": session.get("column_result"),
    }
    all_ground_truth["columns"] = columns_feedback

    await update_session_db(session_id, ground_truth=all_ground_truth)

    # Keep an already-cached session in sync with the DB.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = all_ground_truth

    return {"session_id": session_id, "ground_truth": columns_feedback}
@router.get("/sessions/{session_id}/ground-truth/columns")
async def get_column_ground_truth(session_id: str):
    """Return the saved column ground truth next to the automatic result.

    Raises:
        HTTPException: 404 when the session does not exist or has no column
            ground truth.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    columns_gt = (session.get("ground_truth") or {}).get("columns")
    if not columns_gt:
        raise HTTPException(status_code=404, detail="No column ground truth saved")

    payload = {"session_id": session_id}
    payload["columns_gt"] = columns_gt
    payload["columns_auto"] = session.get("column_result")
    return payload
async def _get_columns_overlay(session_id: str) -> Response:
    """Render the dewarped page with the stored column regions drawn on it.

    Each region gets a colour-coded fill (blended at 20% opacity), a solid
    border and a label; the label carries the classification confidence
    when it is below 1.0. The result is returned as a PNG response.

    Raises:
        HTTPException: 404 when the session, its column data or the dewarped
            image is missing; 500 when decoding or encoding the image fails.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    column_result = session.get("column_result")
    if not (column_result and column_result.get("columns")):
        raise HTTPException(status_code=404, detail="No column data available")

    # Load and decode the dewarped image.
    dewarped_png = await get_session_image(session_id, "dewarped")
    if not dewarped_png:
        raise HTTPException(status_code=404, detail="Dewarped image not available")

    canvas = cv2.imdecode(np.frombuffer(dewarped_png, dtype=np.uint8), cv2.IMREAD_COLOR)
    if canvas is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    # Region type -> BGR colour.
    gray, dark_gray = (128, 128, 128), (100, 100, 100)
    palette = {
        "column_en": (255, 180, 0),       # Blue
        "column_de": (0, 200, 0),         # Green
        "column_example": (0, 140, 255),  # Orange
        "column_text": (200, 200, 0),     # Cyan/Turquoise
        "page_ref": (200, 0, 200),        # Purple
        "column_marker": (0, 0, 220),     # Red
        "column_ignore": (180, 180, 180), # Light Gray
        "header": gray,
        "footer": gray,
        "margin_top": dark_gray,
        "margin_bottom": dark_gray,
    }
    fallback_color = (200, 200, 200)

    # Fills go on a copy so they can be blended in afterwards; borders and
    # labels are drawn directly on the canvas.
    fill_layer = canvas.copy()

    for region in column_result["columns"]:
        top_left = (region["x"], region["y"])
        bottom_right = (region["x"] + region["width"], region["y"] + region["height"])
        color = palette.get(region.get("type", ""), fallback_color)

        cv2.rectangle(fill_layer, top_left, bottom_right, color, -1)
        cv2.rectangle(canvas, top_left, bottom_right, color, 3)

        label = region.get("type", "unknown").replace("column_", "").upper()
        confidence = region.get("classification_confidence")
        if confidence is not None and confidence < 1.0:
            label = f"{label} {int(confidence * 100)}%"
        label_origin = (region["x"] + 10, region["y"] + 30)
        cv2.putText(canvas, label, label_origin,
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2)

    # Blend the fill layer in at 20% opacity (in place).
    cv2.addWeighted(fill_layer, 0.2, canvas, 0.8, 0, canvas)

    encoded_ok, png_buffer = cv2.imencode(".png", canvas)
    if not encoded_ok:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")

    return Response(content=png_buffer.tobytes(), media_type="image/png")
# ---------------------------------------------------------------------------
# Row Detection Endpoints (Step 4)
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/rows")
async def detect_rows(session_id: str):
    """Detect rows on the dewarped page using horizontal gap analysis.

    Word boxes, the inverted image and the content bounds are taken from the
    session cache when column detection left them there; otherwise column
    geometry detection is run to produce (and cache) them. The stored word
    result is cleared because it depends on the rows.

    Raises:
        HTTPException: 400 when the session has no dewarped image or when
            column geometry detection returns nothing.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    session_cache = _get_cached(session_id)

    dewarped_bgr = session_cache.get("dewarped_bgr")
    if dewarped_bgr is None:
        raise HTTPException(status_code=400, detail="Dewarp must be completed before row detection")

    started = time.time()

    # Intermediates possibly left behind by column detection.
    word_dicts = session_cache.get("_word_dicts")
    inv = session_cache.get("_inv")
    content_bounds = session_cache.get("_content_bounds")

    if word_dicts is not None and inv is not None and content_bounds is not None:
        left_x, right_x, top_y, bottom_y = content_bounds
    else:
        # Nothing cached: run column geometry to obtain the intermediates.
        ocr_img = create_ocr_image(dewarped_bgr)
        geo_result = detect_column_geometry(ocr_img, dewarped_bgr)
        if geo_result is None:
            raise HTTPException(status_code=400, detail="Column geometry detection failed — cannot detect rows")
        _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
        session_cache["_word_dicts"] = word_dicts
        session_cache["_inv"] = inv
        session_cache["_content_bounds"] = (left_x, right_x, top_y, bottom_y)

    rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
    duration = time.time() - started

    # Serializable view of the rows; the words are left out to keep the
    # payload small.
    rows_data = [
        {
            "index": row.index,
            "x": row.x,
            "y": row.y,
            "width": row.width,
            "height": row.height,
            "word_count": row.word_count,
            "row_type": row.row_type,
            "gap_before": row.gap_before,
        }
        for row in rows
    ]

    # Number of rows per row_type, in first-seen order.
    type_counts = {}
    for row in rows:
        type_counts.setdefault(row.row_type, 0)
        type_counts[row.row_type] += 1

    row_result = {
        "rows": rows_data,
        "summary": type_counts,
        "total_rows": len(rows),
        "duration_seconds": round(duration, 2),
    }

    # Persist and invalidate the word result, which depends on the rows.
    await update_session_db(
        session_id,
        row_result=row_result,
        word_result=None,
        current_step=4,
    )
    session_cache["row_result"] = row_result
    session_cache.pop("word_result", None)

    logger.info(f"OCR Pipeline: rows session {session_id}: "
                f"{len(rows)} rows detected ({duration:.2f}s): {type_counts}")

    content_rows = len([row for row in rows if row.row_type == "content"])
    if rows:
        avg_height = round(sum(row.height for row in rows) / len(rows))
    else:
        avg_height = 0

    log_payload = {
        "total_rows": len(rows),
        "content_rows": content_rows,
        "artifact_rows_removed": type_counts.get("header", 0) + type_counts.get("footer", 0),
        "avg_row_height_px": avg_height,
    }
    await _append_pipeline_log(
        session_id, "rows", log_payload, duration_ms=int(duration * 1000),
    )

    response = {"session_id": session_id}
    response.update(row_result)
    return response
@router.post("/sessions/{session_id}/rows/manual")
async def set_manual_rows(session_id: str, req: ManualRowsRequest):
    """Replace the detected rows with the caller's manual definitions.

    The manual result is written to the DB and, if the session is cached,
    to the cache. The word result is cleared in both places.
    """
    manual_rows = req.rows
    row_result = {
        "rows": manual_rows,
        "total_rows": len(manual_rows),
        "duration_seconds": 0,
        "method": "manual",
    }

    await update_session_db(session_id, row_result=row_result, word_result=None)

    if session_id in _cache:
        entry = _cache[session_id]
        entry["row_result"] = row_result
        entry.pop("word_result", None)

    logger.info(f"OCR Pipeline: manual rows session {session_id}: "
                f"{len(req.rows)} rows set")

    response = {"session_id": session_id}
    response.update(row_result)
    return response
@router.post("/sessions/{session_id}/ground-truth/rows")
async def save_row_ground_truth(session_id: str, req: RowGroundTruthRequest):
    """Store reviewer feedback for the row detection step.

    The feedback is saved under ``ground_truth["rows"]`` together with a
    snapshot of the session's current ``row_result``.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    all_ground_truth = session.get("ground_truth") or {}
    rows_feedback = {
        "is_correct": req.is_correct,
        "corrected_rows": req.corrected_rows,
        "notes": req.notes,
        "saved_at": datetime.utcnow().isoformat(),
        "row_result": session.get("row_result"),
    }
    all_ground_truth["rows"] = rows_feedback

    await update_session_db(session_id, ground_truth=all_ground_truth)

    # Keep an already-cached session in sync with the DB.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = all_ground_truth

    return {"session_id": session_id, "ground_truth": rows_feedback}
@router.get("/sessions/{session_id}/ground-truth/rows")
async def get_row_ground_truth(session_id: str):
    """Return the saved row ground truth next to the automatic result.

    Raises:
        HTTPException: 404 when the session does not exist or has no row
            ground truth.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    rows_gt = (session.get("ground_truth") or {}).get("rows")
    if not rows_gt:
        raise HTTPException(status_code=404, detail="No row ground truth saved")

    payload = {"session_id": session_id}
    payload["rows_gt"] = rows_gt
    payload["rows_auto"] = session.get("row_result")
    return payload
# ---------------------------------------------------------------------------
# Word Recognition Endpoints (Step 5)
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/words")
async def detect_words(
    session_id: str,
    request: Request,
    engine: str = "auto",
    pronunciation: str = "british",
    stream: bool = False,
):
    """Build the word grid from columns x rows and OCR each cell.

    Query params:
        engine: 'auto' (default), 'tesseract', or 'rapid'
        pronunciation: 'british' (default) or 'american' — for IPA dictionary lookup
        stream: false (default) for JSON response, true for SSE streaming

    Column and row results are read from the DB session and converted back
    into ``PageRegion`` / ``RowGeometry`` objects. With ``stream=True`` the
    work is delegated to ``_word_batch_stream_generator``; otherwise the grid
    is built here, stored on the session and returned as JSON.

    The grid build runs in the default executor so the event loop stays
    responsive while OCR is in progress (the streaming path runs the same
    call the same way).

    Raises:
        HTTPException: 400 when dewarp, column or row results are missing;
            404 when the session does not exist.
    """
    import asyncio

    if session_id not in _cache:
        logger.info("detect_words: session %s not in cache, loading from DB", session_id)
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    dewarped_bgr = cached.get("dewarped_bgr")
    if dewarped_bgr is None:
        logger.warning("detect_words: dewarped_bgr is None for session %s (cache keys: %s)",
                       session_id, [k for k in cached.keys() if k.endswith('_bgr')])
        raise HTTPException(status_code=400, detail="Dewarp must be completed before word detection")

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    column_result = session.get("column_result")
    row_result = session.get("row_result")
    if not column_result or not column_result.get("columns"):
        raise HTTPException(status_code=400, detail="Column detection must be completed first")
    if not row_result or not row_result.get("rows"):
        raise HTTPException(status_code=400, detail="Row detection must be completed first")

    # Convert column dicts back to PageRegion objects
    col_regions = [
        PageRegion(
            type=c["type"],
            x=c["x"], y=c["y"],
            width=c["width"], height=c["height"],
            classification_confidence=c.get("classification_confidence", 1.0),
            classification_method=c.get("classification_method", ""),
        )
        for c in column_result["columns"]
    ]

    # Convert row dicts back to RowGeometry objects
    row_geoms = [
        RowGeometry(
            index=r["index"],
            x=r["x"], y=r["y"],
            width=r["width"], height=r["height"],
            word_count=r.get("word_count", 0),
            words=[],
            row_type=r.get("row_type", "content"),
            gap_before=r.get("gap_before", 0),
        )
        for r in row_result["rows"]
    ]

    # Cell-First OCR (v2): no full-page word re-population needed.
    # Each cell is cropped and OCR'd in isolation, so neighbours cannot bleed
    # in. Rows still need word_count > 0 for the row filtering in
    # build_cell_grid_v2, so populate them from cached words (for counting).
    word_dicts = cached.get("_word_dicts")
    if word_dicts is None:
        ocr_img_tmp = create_ocr_image(dewarped_bgr)
        geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr)
        if geo_result is not None:
            _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
            cached["_word_dicts"] = word_dicts
            cached["_inv"] = inv
            cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)

    if word_dicts:
        content_bounds = cached.get("_content_bounds")
        if content_bounds:
            _lx, _rx, top_y, _by = content_bounds
        else:
            top_y = min(r.y for r in row_geoms) if row_geoms else 0

        # Word 'top' values are compared against row positions shifted by
        # top_y; a word belongs to the row containing its vertical centre.
        for row in row_geoms:
            row_y_rel = row.y - top_y
            row_bottom_rel = row_y_rel + row.height
            row.words = [
                w for w in word_dicts
                if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
            ]
            row.word_count = len(row.words)

    if stream:
        # Cell-First OCR v2: batch-then-stream instead of per-cell streaming.
        return StreamingResponse(
            _word_batch_stream_generator(
                session_id, cached, col_regions, row_geoms,
                dewarped_bgr, engine, pronunciation, request,
            ),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    # --- Non-streaming path ---
    t0 = time.time()

    # Create binarized OCR image (for Tesseract)
    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    # Build cell grid using Cell-First OCR (v2) — each cell cropped in
    # isolation. Run it off the event loop: calling it inline would stall
    # every other request handled by this worker until OCR finishes.
    loop = asyncio.get_running_loop()
    cells, columns_meta = await loop.run_in_executor(
        None,
        lambda: build_cell_grid_v2(
            ocr_img, col_regions, row_geoms, img_w, img_h,
            ocr_engine=engine, img_bgr=dewarped_bgr,
        ),
    )
    duration = time.time() - t0

    # Layout detection
    col_types = {c['type'] for c in columns_meta}
    is_vocab = bool(col_types & {'column_en', 'column_de'})

    # Count content rows and columns for grid_shape
    n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
    n_cols = len(columns_meta)

    # Determine which engine was actually used
    used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine

    # Grid result (always generic)
    word_result = {
        "cells": cells,
        "grid_shape": {
            "rows": n_content_rows,
            "cols": n_cols,
            "total_cells": len(cells),
        },
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
        },
    }

    # For vocab layout: map cells 1:1 to vocab entries (row -> entry).
    # No content shuffling — each cell stays at its detected position.
    if is_vocab:
        entries = _cells_to_vocab_entries(cells, columns_meta)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        word_result["summary"]["total_entries"] = len(entries)
        word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
        word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))

    # Persist to DB
    await update_session_db(
        session_id,
        word_result=word_result,
        current_step=5,
    )
    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline: words session {session_id}: "
                f"layout={word_result['layout']}, "
                f"{len(cells)} cells ({duration:.2f}s), summary: {word_result['summary']}")

    await _append_pipeline_log(session_id, "words", {
        "total_cells": len(cells),
        "non_empty_cells": word_result["summary"]["non_empty_cells"],
        "low_confidence_count": word_result["summary"]["low_confidence"],
        "ocr_engine": used_engine,
        "layout": word_result["layout"],
        "entry_count": word_result.get("entry_count", 0),
    }, duration_ms=int(duration * 1000))

    return {
        "session_id": session_id,
        **word_result,
    }
async def _word_batch_stream_generator(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    request: Request,
):
    """SSE generator that runs batch OCR in an executor, then streams results.

    Event sequence: ``meta`` -> ``preparing`` -> zero or more ``keepalive``
    (one every 5 s while OCR runs) -> ``columns`` (when column metadata is
    non-empty) -> one ``cell`` per cell -> ``complete``. The final word
    result is written to the DB and to ``cached`` before ``complete`` is
    sent.

    If the OCR call raises, a single ``error`` event (same shape as the one
    emitted by the LLM review stream) is sent and the stream ends; nothing
    is persisted in that case. If the client disconnects, the generator
    stops without persisting.
    """
    import asyncio

    t0 = time.time()
    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
    n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
    n_cols = len([c for c in col_regions if c.type not in _skip_types])
    col_types = {c.type for c in col_regions if c.type not in _skip_types}
    is_vocab = bool(col_types & {'column_en', 'column_de'})
    total_cells = n_content_rows * n_cols

    # 1. Send meta event immediately
    meta_event = {
        "type": "meta",
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
        "layout": "vocab" if is_vocab else "generic",
    }
    yield f"data: {json.dumps(meta_event)}\n\n"

    # 2. Send preparing event (keepalive for proxy)
    yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR laeuft parallel...'})}\n\n"

    # 3. Run batch OCR in the default executor with periodic keepalive
    #    events, so proxies do not drop the idle SSE connection.
    loop = asyncio.get_running_loop()
    ocr_future = loop.run_in_executor(
        None,
        lambda: build_cell_grid_v2(
            ocr_img, col_regions, row_geoms, img_w, img_h,
            ocr_engine=engine, img_bgr=dewarped_bgr,
        ),
    )

    while not ocr_future.done():
        try:
            # shield() keeps the OCR future alive when wait_for times out.
            await asyncio.wait_for(asyncio.shield(ocr_future), timeout=5.0)
        except asyncio.TimeoutError:
            elapsed = int(time.time() - t0)
            keepalive_event = {
                'type': 'keepalive',
                'elapsed': elapsed,
                'message': f'OCR laeuft... ({elapsed}s)',
            }
            yield f"data: {json.dumps(keepalive_event)}\n\n"
            if await request.is_disconnected():
                logger.info(f"SSE batch: client disconnected during OCR for {session_id}")
                # Best effort: a call already running in the executor thread
                # is not interrupted by cancel().
                ocr_future.cancel()
                return
        except Exception:
            # OCR itself failed; the error is reported from result() below.
            break

    try:
        cells, columns_meta = ocr_future.result()
    except Exception as e:
        logger.exception(f"SSE batch: OCR failed for {session_id}: {type(e).__name__}: {e}")
        error_event = {"type": "error", "detail": f"{type(e).__name__}: {e}"}
        yield f"data: {json.dumps(error_event)}\n\n"
        return

    if await request.is_disconnected():
        logger.info(f"SSE batch: client disconnected after OCR for {session_id}")
        return

    # 4. Send columns meta
    if columns_meta:
        yield f"data: {json.dumps({'type': 'columns', 'columns_used': columns_meta})}\n\n"

    # 5. Stream all cells
    for idx, cell in enumerate(cells):
        cell_event = {
            "type": "cell",
            "cell": cell,
            "progress": {"current": idx + 1, "total": len(cells)},
        }
        yield f"data: {json.dumps(cell_event)}\n\n"

    # 6. Build final result and persist
    duration = time.time() - t0
    used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine

    word_result = {
        "cells": cells,
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
        },
    }

    vocab_entries = None
    if is_vocab:
        entries = _cells_to_vocab_entries(cells, columns_meta)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        word_result["summary"]["total_entries"] = len(entries)
        word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
        word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
        vocab_entries = entries

    await update_session_db(session_id, word_result=word_result, current_step=5)
    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline SSE batch: words session {session_id}: "
                f"layout={word_result['layout']}, {len(cells)} cells ({duration:.2f}s)")

    # 7. Send complete event
    complete_event = {
        "type": "complete",
        "summary": word_result["summary"],
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
    }
    if vocab_entries is not None:
        complete_event["vocab_entries"] = vocab_entries
    yield f"data: {json.dumps(complete_event)}\n\n"
async def _word_stream_generator(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    request: Request,
):
    """SSE generator that yields cell-by-cell OCR progress.

    Event sequence: ``meta`` -> ``preparing`` -> ``columns`` (once, with the
    first cell) -> one ``cell`` per cell -> ``complete``. After the last
    cell, rows whose cells are all empty are dropped, the word result is
    written to the DB and to ``cached``, and ``complete`` is sent. A client
    disconnect ends the stream without persisting.

    NOTE(review): ``build_cell_grid_v2_streaming`` is consumed with a plain
    ``for`` loop, so whatever work it does per item presumably runs on the
    event loop — confirm before using this on long pages.
    """
    t0 = time.time()
    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    # Compute grid shape upfront for the meta event
    n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
    _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
    n_cols = len([c for c in col_regions if c.type not in _skip_types])

    # Determine layout
    col_types = {c.type for c in col_regions if c.type not in _skip_types}
    is_vocab = bool(col_types & {'column_en', 'column_de'})

    # Start streaming — first event: meta
    columns_meta = None  # set from the first yielded cell
    total_cells = n_content_rows * n_cols
    meta_event = {
        "type": "meta",
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
        "layout": "vocab" if is_vocab else "generic",
    }
    yield f"data: {json.dumps(meta_event)}\n\n"

    # Keepalive: send preparing event so proxy doesn't timeout during OCR init
    yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR wird initialisiert...'})}\n\n"

    # Stream cells one by one
    all_cells: List[Dict[str, Any]] = []
    grid_stream = build_cell_grid_v2_streaming(
        ocr_img, col_regions, row_geoms, img_w, img_h,
        ocr_engine=engine, img_bgr=dewarped_bgr,
    )
    for cell_idx, (cell, cols_meta, total) in enumerate(grid_stream, start=1):
        if await request.is_disconnected():
            logger.info(f"SSE: client disconnected during streaming for {session_id}")
            return

        if columns_meta is None:
            columns_meta = cols_meta
            # Send columns_used once, ahead of the first cell
            meta_update = {
                "type": "columns",
                "columns_used": cols_meta,
            }
            yield f"data: {json.dumps(meta_update)}\n\n"

        all_cells.append(cell)
        cell_event = {
            "type": "cell",
            "cell": cell,
            "progress": {"current": cell_idx, "total": total},
        }
        yield f"data: {json.dumps(cell_event)}\n\n"

    # All cells done — build final result
    duration = time.time() - t0
    if columns_meta is None:
        columns_meta = []

    # Post-OCR: remove rows where ALL cells are empty (inter-row gaps
    # that had stray Tesseract artifacts giving word_count > 0).
    # `or ""` keeps a cell with text=None from crashing the filter; the
    # summary below already treats such cells as empty.
    rows_with_text = {
        c["row_index"] for c in all_cells if (c.get("text") or "").strip()
    }
    before_filter = len(all_cells)
    all_cells = [c for c in all_cells if c["row_index"] in rows_with_text]
    empty_rows_removed = (before_filter - len(all_cells)) // max(n_cols, 1)
    if empty_rows_removed > 0:
        logger.info(f"SSE: removed {empty_rows_removed} all-empty rows after OCR")

    used_engine = all_cells[0].get("ocr_engine", "tesseract") if all_cells else engine

    word_result = {
        "cells": all_cells,
        "grid_shape": {
            "rows": n_content_rows,
            "cols": n_cols,
            "total_cells": len(all_cells),
        },
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": {
            "total_cells": len(all_cells),
            "non_empty_cells": sum(1 for c in all_cells if c.get("text")),
            "low_confidence": sum(1 for c in all_cells if 0 < c.get("confidence", 0) < 50),
        },
    }

    # For vocab layout: map cells 1:1 to vocab entries (row -> entry).
    # No content shuffling — each cell stays at its detected position.
    vocab_entries = None
    if is_vocab:
        entries = _cells_to_vocab_entries(all_cells, columns_meta)
        entries = _fix_character_confusion(entries)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        word_result["summary"]["total_entries"] = len(entries)
        word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
        word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
        vocab_entries = entries

    # Persist to DB
    await update_session_db(
        session_id,
        word_result=word_result,
        current_step=5,
    )
    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline SSE: words session {session_id}: "
                f"layout={word_result['layout']}, "
                f"{len(all_cells)} cells ({duration:.2f}s)")

    # Final complete event
    complete_event = {
        "type": "complete",
        "summary": word_result["summary"],
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
    }
    if vocab_entries is not None:
        complete_event["vocab_entries"] = vocab_entries
    yield f"data: {json.dumps(complete_event)}\n\n"
# Request body for POST /sessions/{session_id}/ground-truth/words.
class WordGroundTruthRequest(BaseModel):
    # Reviewer verdict on the word recognition result.
    is_correct: bool
    # Optional corrected entries supplied by the reviewer; stored as given.
    corrected_entries: Optional[List[Dict[str, Any]]] = None
    # Optional free-text remarks.
    notes: Optional[str] = None
@router.post("/sessions/{session_id}/ground-truth/words")
async def save_word_ground_truth(session_id: str, req: WordGroundTruthRequest):
    """Store reviewer feedback for the word recognition step.

    The feedback is saved under ``ground_truth["words"]`` together with a
    snapshot of the session's current ``word_result``.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    all_ground_truth = session.get("ground_truth") or {}
    words_feedback = {
        "is_correct": req.is_correct,
        "corrected_entries": req.corrected_entries,
        "notes": req.notes,
        "saved_at": datetime.utcnow().isoformat(),
        "word_result": session.get("word_result"),
    }
    all_ground_truth["words"] = words_feedback

    await update_session_db(session_id, ground_truth=all_ground_truth)

    # Keep an already-cached session in sync with the DB.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = all_ground_truth

    return {"session_id": session_id, "ground_truth": words_feedback}
@router.get("/sessions/{session_id}/ground-truth/words")
async def get_word_ground_truth(session_id: str):
    """Return the saved word ground truth next to the automatic result.

    Raises:
        HTTPException: 404 when the session does not exist or has no word
            ground truth.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    words_gt = (session.get("ground_truth") or {}).get("words")
    if not words_gt:
        raise HTTPException(status_code=404, detail="No word ground truth saved")

    payload = {"session_id": session_id}
    payload["words_gt"] = words_gt
    payload["words_auto"] = session.get("word_result")
    return payload
# ---------------------------------------------------------------------------
# LLM Review Endpoints (Step 6)
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/llm-review")
async def run_llm_review(session_id: str, request: Request, stream: bool = False):
    """Run LLM-based correction on vocab entries from Step 5.

    Query params:
        stream: false (default) for JSON response, true for SSE streaming

    Request body (optional JSON object):
        model: overrides ``OLLAMA_REVIEW_MODEL`` when set to a truthy value.
        A missing, unparsable or non-object body is treated as "no override".

    In the non-streaming path the review outcome is stored under
    ``word_result["llm_review"]`` in the DB (and in the cache if the session
    is cached), and the session advances to step 6.

    Raises:
        HTTPException: 404 when the session does not exist; 400 when there
            is no word result or no vocab entries; 502 when the LLM review
            call fails.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found — run Step 5 first")

    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    if not entries:
        raise HTTPException(status_code=400, detail="No vocab entries found — run Step 5 first")

    # Optional model override from request body. Reading the body is
    # deliberately best-effort: the endpoint is usable without one.
    try:
        body = await request.json()
    except Exception:
        body = None
    # Valid JSON that is not an object (e.g. [] or "x") has no .get();
    # treat it like an absent body instead of failing with AttributeError.
    if not isinstance(body, dict):
        body = {}
    model = body.get("model") or OLLAMA_REVIEW_MODEL

    if stream:
        return StreamingResponse(
            _llm_review_stream_generator(session_id, entries, word_result, model, request),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
        )

    # Non-streaming path
    try:
        result = await llm_review_entries(entries, model=model)
    except Exception as e:
        # logger.exception appends the active traceback to the message.
        logger.exception(f"LLM review failed for session {session_id}: {type(e).__name__}: {e}")
        raise HTTPException(
            status_code=502,
            detail=f"LLM review failed ({type(e).__name__}): {e}",
        ) from e

    # Store result inside word_result as a sub-key
    word_result["llm_review"] = {
        "changes": result["changes"],
        "model_used": result["model_used"],
        "duration_ms": result["duration_ms"],
        "entries_corrected": result["entries_corrected"],
    }
    await update_session_db(session_id, word_result=word_result, current_step=6)
    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    logger.info(f"LLM review session {session_id}: {len(result['changes'])} changes, "
                f"{result['duration_ms']}ms, model={result['model_used']}")

    await _append_pipeline_log(session_id, "correction", {
        "engine": "llm",
        "model": result["model_used"],
        "total_entries": len(entries),
        "corrections_proposed": len(result["changes"]),
    }, duration_ms=result["duration_ms"])

    return {
        "session_id": session_id,
        "changes": result["changes"],
        "model_used": result["model_used"],
        "duration_ms": result["duration_ms"],
        "total_entries": len(entries),
        "corrections_found": len(result["changes"]),
    }
async def _llm_review_stream_generator(
    session_id: str,
    entries: List[Dict],
    word_result: Dict,
    model: str,
    request: Request,
):
    """SSE generator that relays LLM review events batch by batch.

    Every event from ``llm_review_entries_streaming`` is forwarded as one
    SSE ``data:`` line. After a ``complete`` event has been forwarded, its
    outcome is stored under ``word_result["llm_review"]`` in the DB (and in
    the cache if the session is cached). Any exception is logged and
    reported to the client as an ``error`` event. A client disconnect ends
    the stream.
    """
    try:
        async for event in llm_review_entries_streaming(entries, model=model):
            if await request.is_disconnected():
                logger.info(f"SSE: client disconnected during LLM review for {session_id}")
                return

            payload = json.dumps(event, ensure_ascii=False)
            yield f"data: {payload}\n\n"

            # Only the final event carries something to persist.
            if event.get("type") != "complete":
                continue

            review_keys = ("changes", "model_used", "duration_ms", "entries_corrected")
            word_result["llm_review"] = {key: event[key] for key in review_keys}
            await update_session_db(session_id, word_result=word_result, current_step=6)
            if session_id in _cache:
                _cache[session_id]["word_result"] = word_result

            logger.info(f"LLM review SSE session {session_id}: {event['corrections_found']} changes, "
                        f"{event['duration_ms']}ms, skipped={event['skipped']}, model={event['model_used']}")
    except Exception as exc:
        import traceback
        logger.error(f"LLM review SSE failed for {session_id}: {type(exc).__name__}: {exc}\n{traceback.format_exc()}")
        failure = {"type": "error", "detail": f"{type(exc).__name__}: {exc}"}
        yield f"data: {json.dumps(failure)}\n\n"
@router.post("/sessions/{session_id}/llm-review/apply")
async def apply_llm_corrections(session_id: str, request: Request):
    """Apply the user-accepted subset of LLM corrections to the vocab entries.

    The request body carries ``accepted_indices``: positions in the stored
    ``llm_review["changes"]`` list. Each accepted change replaces the matching
    ``(row_index, field)`` value in the entries and marks the entry with
    ``llm_corrected``. ``applied_count`` is the number of accepted changes.

    Raises:
        HTTPException: 404 if the session is unknown, 400 if there is no word
            result or no stored LLM review.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    llm_review = word_result.get("llm_review")
    if not llm_review:
        raise HTTPException(status_code=400, detail="No LLM review found — run /llm-review first")

    payload = await request.json()
    accepted = set(payload.get("accepted_indices", []))  # positions in changes[]

    changes = llm_review.get("changes", [])
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []

    # (row_index, field) -> replacement text; a later accepted change for the
    # same target overrides an earlier one.
    corrections = {
        (chg["row_index"], chg["field"]): chg["new"]
        for pos, chg in enumerate(changes)
        if pos in accepted
    }
    applied_count = sum(1 for pos in range(len(changes)) if pos in accepted)

    for entry in entries:
        row = entry.get("row_index", -1)
        for field_name in ("english", "german", "example"):
            target = (row, field_name)
            if target not in corrections:
                continue
            entry[field_name] = corrections[target]
            entry["llm_corrected"] = True

    # Write the entries back under both key names and record the apply step.
    word_result["vocab_entries"] = entries
    word_result["entries"] = entries
    review_record = word_result["llm_review"]
    review_record["applied_count"] = applied_count
    review_record["applied_at"] = datetime.utcnow().isoformat()

    await update_session_db(session_id, word_result=word_result)
    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    logger.info(f"Applied {applied_count}/{len(changes)} LLM corrections for session {session_id}")
    return {
        "session_id": session_id,
        "applied_count": applied_count,
        "total_changes": len(changes),
    }
@router.post("/sessions/{session_id}/reconstruction")
async def save_reconstruction(session_id: str, request: Request):
    """Save edited cell texts from the reconstruction step.

    The request body carries ``cells``: a list of ``{"cell_id", "text"}``
    objects. Matching cells in ``word_result["cells"]`` get the new text and
    ``status = "edited"``. The same edits are mirrored into the vocab entries,
    where columns 0/1/2 of a row map to ``english``/``german``/``example``.
    The session is advanced to ``current_step=7`` even when nothing was sent.

    Returns:
        ``{"session_id", "updated"}`` where ``updated`` counts changed cells.

    Raises:
        HTTPException: 404 if the session is unknown, 400 if it has no word
            result.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    body = await request.json()
    cell_updates = body.get("cells", [])

    if not cell_updates:
        await update_session_db(session_id, current_step=7)
        return {"session_id": session_id, "updated": 0}

    # cell_id -> new text
    update_map = {c["cell_id"]: c["text"] for c in cell_updates}

    cells = word_result.get("cells", [])
    updated_count = 0
    for cell in cells:
        cell_id = cell.get("cell_id")
        if cell_id in update_map:
            cell["text"] = update_map[cell_id]
            cell["status"] = "edited"
            updated_count += 1
    word_result["cells"] = cells

    # Mirror the edits into the vocab entries, if present.
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    if entries:
        for entry in entries:
            row_idx = entry.get("row_index", -1)
            for col_idx, field_name in enumerate(("english", "german", "example")):
                # Cell IDs may be zero-padded ("R05_C0") or not ("R5_C0").
                # Look the ID up by membership instead of chaining with `or`,
                # so that an edit to the empty string (clearing a cell) is
                # applied rather than being mistaken for "no update".
                for candidate in (f"R{row_idx:02d}_C{col_idx}", f"R{row_idx}_C{col_idx}"):
                    new_text = update_map.get(candidate)
                    if new_text is not None:
                        entry[field_name] = new_text
                        break

        word_result["vocab_entries"] = entries
        if "entries" in word_result:
            word_result["entries"] = entries

    await update_session_db(session_id, word_result=word_result, current_step=7)
    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    logger.info(f"Reconstruction saved for session {session_id}: {updated_count} cells updated")
    return {
        "session_id": session_id,
        "updated": updated_count,
    }
@router.get("/sessions/{session_id}/reconstruction/fabric-json")
async def get_fabric_json(session_id: str):
    """Return the session's cell grid converted for the Fabric.js canvas editor.

    The conversion is delegated to ``cells_to_fabric_json``; image dimensions
    default to 800x600 when the word result does not record them.

    Raises:
        HTTPException: 404 if the session is unknown, 400 if it has no word
            result.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    grid_cells = word_result.get("cells", [])
    page_width = word_result.get("image_width", 800)
    page_height = word_result.get("image_height", 600)

    # Imported lazily, only once the request has been validated.
    from services.layout_reconstruction_service import cells_to_fabric_json

    return cells_to_fabric_json(grid_cells, page_width, page_height)
@router.get("/sessions/{session_id}/reconstruction/export/pdf")
async def export_reconstruction_pdf(session_id: str):
    """Export the reconstructed cell grid as a PDF table.

    The first table row is a header taken from ``columns_used`` (label, then
    type, then ``Col <i>``); the remaining rows hold the text of the cells
    ``R<row:02d>_C<col>`` for the grid described by ``grid_shape``.

    Returns:
        A ``StreamingResponse`` carrying the PDF as an attachment.

    Raises:
        HTTPException: 404 if the session is unknown, 400 if there is no word
            result or nothing to export, 501 if reportlab is not installed.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    cells = word_result.get("cells", [])
    columns_used = word_result.get("columns_used", [])
    grid_shape = word_result.get("grid_shape", {})
    n_rows = grid_shape.get("rows", 0)
    n_cols = grid_shape.get("cols", 0)

    # Index the cells once instead of scanning the whole list per grid
    # position. setdefault keeps the first cell for a duplicated ID, which is
    # what the previous linear search returned.
    cell_by_id: Dict[Any, Dict] = {}
    for c in cells:
        cell_by_id.setdefault(c.get("cell_id"), c)

    # Build table data: header row followed by rows x columns.
    header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
    if not header:
        header = [f"Col {i}" for i in range(n_cols)]
    table_data: List[List[str]] = [header]

    for r in range(n_rows):
        row_texts = []
        for ci in range(n_cols):
            cell = cell_by_id.get(f"R{r:02d}_C{ci}")
            row_texts.append(cell.get("text", "") if cell else "")
        table_data.append(row_texts)

    # columns_used and grid_shape are stored independently and may disagree;
    # pad every row to a common width so the table is never ragged.
    n_table_cols = max(len(row) for row in table_data)
    for row in table_data:
        row.extend([""] * (n_table_cols - len(row)))

    # Only the optional dependency is guarded, so an unrelated ImportError
    # raised while building is not misreported as "reportlab not installed".
    try:
        from reportlab.lib import colors
        from reportlab.lib.pagesizes import A4
        from reportlab.platypus import SimpleDocTemplate, Table, TableStyle
    except ImportError:
        raise HTTPException(status_code=501, detail="reportlab not installed")

    if not table_data[0]:
        raise HTTPException(status_code=400, detail="No data to export")

    import io as _io
    from fastapi.responses import StreamingResponse

    buf = _io.BytesIO()
    doc = SimpleDocTemplate(buf, pagesize=A4)

    t = Table(table_data)
    t.setStyle(TableStyle([
        ('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#0d9488')),
        ('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
        ('FONTSIZE', (0, 0), (-1, -1), 9),
        ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
        ('VALIGN', (0, 0), (-1, -1), 'TOP'),
        ('WORDWRAP', (0, 0), (-1, -1), True),
    ]))
    doc.build([t])
    buf.seek(0)

    return StreamingResponse(
        buf,
        media_type="application/pdf",
        headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.pdf"'},
    )
@router.get("/sessions/{session_id}/reconstruction/export/docx")
async def export_reconstruction_docx(session_id: str):
    """Export the reconstructed cell grid as a DOCX table.

    The document gets a heading with the first 8 characters of the session ID
    and one table: a header row taken from ``columns_used`` (label, then type,
    then ``Col <i>``) followed by the text of cells ``R<row:02d>_C<col>``.

    Returns:
        A ``StreamingResponse`` carrying the DOCX file as an attachment.

    Raises:
        HTTPException: 404 if the session is unknown, 400 if there is no word
            result, 501 if python-docx is not installed.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    cells = word_result.get("cells", [])
    columns_used = word_result.get("columns_used", [])
    grid_shape = word_result.get("grid_shape", {})
    n_rows = grid_shape.get("rows", 0)
    n_cols = grid_shape.get("cols", 0)

    # Only the optional dependency is guarded, so an unrelated ImportError
    # is not misreported as "python-docx not installed".
    try:
        from docx import Document
    except ImportError:
        raise HTTPException(status_code=501, detail="python-docx not installed")

    import io as _io
    from fastapi.responses import StreamingResponse

    # Index the cells once; setdefault keeps the first cell for a duplicated
    # ID, which is what the previous per-position linear search returned.
    cell_by_id: Dict[Any, Dict] = {}
    for c in cells:
        cell_by_id.setdefault(c.get("cell_id"), c)

    doc = Document()
    doc.add_heading(f'Rekonstruktion Session {session_id[:8]}', level=1)

    # Build header
    header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
    if not header:
        header = [f"Col {i}" for i in range(n_cols)]

    # The table must be wide enough for the header as well as the grid:
    # columns_used can list more columns than grid_shape reports (or
    # grid_shape can be missing), which previously raised IndexError while
    # writing the header row.
    n_table_cols = max(n_cols, len(header), 1)
    table = doc.add_table(rows=1 + n_rows, cols=n_table_cols)
    table.style = 'Table Grid'

    # Header row
    header_cells = table.rows[0].cells
    for ci, h in enumerate(header):
        header_cells[ci].text = h

    # Data rows
    for r in range(n_rows):
        row_cells = table.rows[r + 1].cells
        for ci in range(n_cols):
            cell = cell_by_id.get(f"R{r:02d}_C{ci}")
            row_cells[ci].text = cell.get("text", "") if cell else ""

    buf = _io.BytesIO()
    doc.save(buf)
    buf.seek(0)

    return StreamingResponse(
        buf,
        media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.docx"'},
    )
# ---------------------------------------------------------------------------
# Step 8: Validation — Original vs. Reconstruction
# ---------------------------------------------------------------------------
# Style name -> text appended to the user's prompt when generating a
# replacement image; "educational" also serves as the fallback style.
STYLE_SUFFIXES: Dict[str, str] = dict(
    educational="educational illustration, textbook style, clear, colorful",
    cartoon="cartoon, child-friendly, simple shapes",
    sketch="pencil sketch, hand-drawn, black and white",
    clipart="clipart, flat vector style, simple",
    realistic="photorealistic, high detail",
)
class ValidationRequest(BaseModel):
    """Request body for saving the final (step 8) validation result."""
    # Optional free-text notes; stored as-is in the validation record.
    notes: Optional[str] = None
    # Optional score; stored as-is (this model applies no range check).
    score: Optional[int] = None
class GenerateImageRequest(BaseModel):
    """Request body for generating a replacement image for one detected region."""
    # Position in the session's stored list of detected image regions.
    region_index: int
    # Text prompt; the style suffix is appended to it before generation.
    prompt: str
    # Key into STYLE_SUFFIXES; unknown values fall back to the "educational" suffix.
    style: str = "educational"
@router.post("/sessions/{session_id}/reconstruction/detect-images")
async def detect_image_regions(session_id: str):
    """Detect illustration/image regions in the original scan using a VLM.

    Sends the original image to the Ollama model named by ``OLLAMA_HTR_MODEL``
    (default ``qwen2.5vl:32b``) and asks for non-text, non-table image areas as
    a JSON array of bounding boxes (in % of the page) with descriptions. The
    normalized regions are stored under ``ground_truth["validation"]``.

    Returns:
        ``{"regions": [...], "count": n}``. If the VLM is unreachable or any
        other error occurs, the regions are empty, an ``"error"`` message is
        included and nothing is persisted (best effort, no HTTP error).

    Raises:
        HTTPException: 404 if the session is unknown, 400 if it has no
            original image.
    """
    import base64
    import httpx
    import re

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # Get original image bytes
    original_png = await get_session_image(session_id, "original")
    if not original_png:
        raise HTTPException(status_code=400, detail="No original image found")

    # Build context from vocab entries for richer descriptions
    word_result = session.get("word_result") or {}
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    vocab_context = ""
    if entries:
        sample = entries[:10]
        words = [f"{e.get('english', '')} / {e.get('german', '')}" for e in sample if e.get('english')]
        if words:
            vocab_context = f"\nContext: This is a vocabulary page with words like: {', '.join(words)}"

    ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
    model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")

    prompt = (
        "Analyze this scanned page. Find ALL illustration/image/picture regions "
        "(NOT text, NOT table cells, NOT blank areas). "
        "For each image region found, return its bounding box as percentage of page dimensions "
        "and a short English description of what the image shows. "
        "Reply with ONLY a JSON array like: "
        '[{"x": 10, "y": 20, "w": 30, "h": 25, "description": "drawing of a cat"}] '
        "where x, y, w, h are percentages (0-100) of the page width/height. "
        "If there are NO images on the page, return an empty array: []"
        f"{vocab_context}"
    )

    img_b64 = base64.b64encode(original_png).decode("utf-8")
    payload = {
        "model": model,
        "prompt": prompt,
        "images": [img_b64],
        "stream": False,
    }

    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            resp = await client.post(f"{ollama_base}/api/generate", json=payload)
            resp.raise_for_status()
            text = resp.json().get("response", "")

        # Parse the JSON array from the response. Try the outermost [...]
        # span first, so a "]" inside a description or a nested array does not
        # cut the match short. If that span is not valid JSON (e.g. the model
        # appended text containing brackets), fall back to the shortest match.
        raw_regions = []
        start, end = text.find("["), text.rfind("]")
        if 0 <= start < end:
            try:
                raw_regions = json.loads(text[start:end + 1])
            except json.JSONDecodeError:
                match = re.search(r'\[.*?\]', text, re.DOTALL)
                raw_regions = json.loads(match.group(0)) if match else []

        # Normalize to ImageRegion format. A single malformed item is skipped
        # instead of discarding every region the model reported.
        regions = []
        for r in raw_regions:
            if not isinstance(r, dict):
                logger.warning(f"Skipping non-object image region for session {session_id}: {r!r}")
                continue
            try:
                bbox_pct = {
                    "x": max(0, min(100, float(r.get("x", 0)))),
                    "y": max(0, min(100, float(r.get("y", 0)))),
                    "w": max(1, min(100, float(r.get("w", 10)))),
                    "h": max(1, min(100, float(r.get("h", 10)))),
                }
            except (TypeError, ValueError):
                logger.warning(f"Skipping image region with non-numeric bbox for session {session_id}: {r!r}")
                continue
            regions.append({
                "bbox_pct": bbox_pct,
                "description": r.get("description", ""),
                "prompt": r.get("description", ""),
                "image_b64": None,
                "style": "educational",
            })

        # Enrich prompts with nearby vocab context
        # NOTE(review): this compares entry bbox "y" directly with the region's
        # percentage coordinates — confirm both use the same unit.
        if entries:
            for region in regions:
                ry = region["bbox_pct"]["y"]
                rh = region["bbox_pct"]["h"]
                nearby = [
                    e for e in entries
                    if e.get("bbox") and abs(e["bbox"].get("y", 0) - ry) < rh + 10
                ]
                if nearby:
                    en_words = [e.get("english", "") for e in nearby if e.get("english")]
                    de_words = [e.get("german", "") for e in nearby if e.get("german")]
                    if en_words or de_words:
                        context = f" (vocabulary context: {', '.join(en_words[:5])}"
                        if de_words:
                            context += f" / {', '.join(de_words[:5])}"
                        context += ")"
                        region["prompt"] = region["description"] + context

        # Save to ground_truth JSONB
        ground_truth = session.get("ground_truth") or {}
        validation = ground_truth.get("validation") or {}
        validation["image_regions"] = regions
        validation["detected_at"] = datetime.utcnow().isoformat()
        ground_truth["validation"] = validation
        await update_session_db(session_id, ground_truth=ground_truth)

        if session_id in _cache:
            _cache[session_id]["ground_truth"] = ground_truth

        logger.info(f"Detected {len(regions)} image regions for session {session_id}")
        return {"regions": regions, "count": len(regions)}

    except httpx.ConnectError:
        logger.warning(f"VLM not available at {ollama_base} for image detection")
        return {"regions": [], "count": 0, "error": "VLM not available"}
    except Exception as e:
        # Deliberate best effort: report the failure in the payload.
        logger.error(f"Image detection failed for {session_id}: {e}")
        return {"regions": [], "count": 0, "error": str(e)}
@router.post("/sessions/{session_id}/reconstruction/generate-image")
async def generate_image_for_region(session_id: str, req: GenerateImageRequest):
    """Generate a replacement image for one detected region via mflux-service.

    The prompt plus the style suffix is posted to ``MFLUX_URL``. The output
    size is 768x512, 512x768 or 512x512 depending on the region's aspect
    ratio. On success the image, prompt and style are stored on the region in
    ``ground_truth["validation"]["image_regions"]``.

    Returns:
        ``{"image_b64", "success"}``; on failure ``success`` is False and an
        ``"error"`` message is included (best effort, no HTTP error).

    Raises:
        HTTPException: 404 if the session is unknown, 400 if ``region_index``
            is out of range.
    """
    import httpx

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    validation = ground_truth.get("validation") or {}
    regions = validation.get("image_regions") or []

    index_ok = 0 <= req.region_index < len(regions)
    if not index_ok:
        raise HTTPException(status_code=400, detail=f"Invalid region_index {req.region_index}, have {len(regions)} regions")

    mflux_url = os.getenv("MFLUX_URL", "http://host.docker.internal:8095")
    style_suffix = STYLE_SUFFIXES.get(req.style, STYLE_SUFFIXES["educational"])
    full_prompt = f"{req.prompt}, {style_suffix}"

    # Pick landscape / portrait / square output from the region's aspect ratio.
    target = regions[req.region_index]
    box = target["bbox_pct"]
    ratio = box["w"] / max(box["h"], 1)
    width, height = 512, 512
    if ratio > 1.3:
        width = 768
    elif ratio < 0.7:
        height = 768

    request_body = {
        "prompt": full_prompt,
        "width": width,
        "height": height,
        "steps": 4,
    }

    try:
        async with httpx.AsyncClient(timeout=300.0) as client:
            resp = await client.post(f"{mflux_url}/generate", json=request_body)
            resp.raise_for_status()
            data = resp.json()

        image_b64 = data.get("image_b64")
        if not image_b64:
            return {"image_b64": None, "success": False, "error": "No image returned"}

        # Record the generated image on the region and persist the whole list.
        target.update(image_b64=image_b64, prompt=req.prompt, style=req.style)
        validation["image_regions"] = regions
        ground_truth["validation"] = validation
        await update_session_db(session_id, ground_truth=ground_truth)

        if session_id in _cache:
            _cache[session_id]["ground_truth"] = ground_truth

        logger.info(f"Generated image for session {session_id} region {req.region_index}")
        return {"image_b64": image_b64, "success": True}

    except httpx.ConnectError:
        logger.warning(f"mflux-service not available at {mflux_url}")
        return {"image_b64": None, "success": False, "error": f"mflux-service not available at {mflux_url}"}
    except Exception as exc:
        logger.error(f"Image generation failed for {session_id}: {exc}")
        return {"image_b64": None, "success": False, "error": str(exc)}
@router.post("/sessions/{session_id}/reconstruction/validate")
async def save_validation(session_id: str, req: ValidationRequest):
    """Store the final validation result and mark the pipeline as complete.

    Notes, score and a timestamp are merged into the existing
    ``ground_truth["validation"]`` record, so previously detected or generated
    image regions are kept. The session is set to ``current_step=8``.

    Raises:
        HTTPException: 404 if the session is unknown.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    validation = ground_truth.get("validation") or {}
    validation.update({
        "validated_at": datetime.utcnow().isoformat(),
        "notes": req.notes,
        "score": req.score,
    })
    ground_truth["validation"] = validation

    await update_session_db(session_id, ground_truth=ground_truth, current_step=8)
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    logger.info(f"Validation saved for session {session_id}: score={req.score}")
    return {"session_id": session_id, "validation": validation}
@router.get("/sessions/{session_id}/reconstruction/validation")
async def get_validation(session_id: str):
    """Return the saved step-8 validation record together with the word result.

    ``validation`` is ``None`` when nothing has been saved yet.

    Raises:
        HTTPException: 404 if the session is unknown.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    stored_truth = session.get("ground_truth") or {}
    response = {"session_id": session_id}
    response["validation"] = stored_truth.get("validation")
    response["word_result"] = session.get("word_result")
    return response
@router.post("/sessions/{session_id}/reprocess")
async def reprocess_session(session_id: str, request: Request):
    """Reset a session to an earlier step and clear the downstream results.

    Body: {"from_step": 5}  (1-indexed step number, 1..7, default 1)

    What gets cleared:
    - from_step <= 1: deskew_result, dewarp_result, column_result, row_result, word_result
    - from_step <= 2: dewarp_result, column_result, row_result, word_result
    - from_step <= 3: column_result, row_result, word_result
    - from_step <= 4: row_result, word_result
    - from_step <= 5: word_result (cells, vocab_entries)
    - from_step == 6: only the LLM review inside word_result
    - from_step == 7: nothing; only current_step is reset

    Raises:
        HTTPException: 404 if the session is unknown, 400 if ``from_step`` is
            not an integer between 1 and 7.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    body = await request.json()
    from_step = body.get("from_step", 1)
    if not (isinstance(from_step, int) and from_step >= 1 and from_step <= 7):
        raise HTTPException(status_code=400, detail="from_step must be between 1 and 7")

    update_kwargs: Dict[str, Any] = {"current_step": from_step}

    # word_result is dropped entirely up to step 5; at step 6 only the LLM
    # review parts are stripped from it.
    if from_step <= 5:
        update_kwargs["word_result"] = None
    elif from_step == 6:
        word_result = session.get("word_result")
        if word_result:
            for review_key in ("llm_review", "llm_corrections"):
                word_result.pop(review_key, None)
            update_kwargs["word_result"] = word_result

    # Each earlier result is cleared when reprocessing starts at or before
    # the step that produces it.
    earlier_results = (
        ("row_result", 4),
        ("column_result", 3),
        ("dewarp_result", 2),
        ("deskew_result", 1),
    )
    for result_key, last_step in earlier_results:
        if from_step <= last_step:
            update_kwargs[result_key] = None

    await update_session_db(session_id, **update_kwargs)

    cleared = [key for key in update_kwargs if key != "current_step"]

    # Mirror the same changes into the in-memory cache.
    if session_id in _cache:
        for key in cleared:
            _cache[session_id][key] = update_kwargs[key]
        _cache[session_id]["current_step"] = from_step

    logger.info(f"Session {session_id} reprocessing from step {from_step}")
    return {
        "session_id": session_id,
        "from_step": from_step,
        "cleared": cleared,
    }
async def _get_rows_overlay(session_id: str) -> Response:
    """Render the dewarped image with the detected row bands drawn on it.

    Each row gets a solid border, a label (``R<index> <TYPE>`` plus the word
    count when non-zero) and a fill blended in at 15% opacity; the color
    depends on the row type.

    Raises:
        HTTPException: 404 if the session, its row data or the dewarped image
            is missing; 500 if the image cannot be decoded or encoded.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    row_result = session.get("row_result")
    if not (row_result and row_result.get("rows")):
        raise HTTPException(status_code=404, detail="No row data available")

    dewarped_png = await get_session_image(session_id, "dewarped")
    if not dewarped_png:
        raise HTTPException(status_code=404, detail="Dewarped image not available")

    img = cv2.imdecode(np.frombuffer(dewarped_png, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    # Row type -> BGR color
    gray = (128, 128, 128)
    dark_gray = (100, 100, 100)
    row_colors = {
        "content": (255, 180, 0),  # blue
        "header": gray,
        "footer": gray,
        "margin_top": dark_gray,
        "margin_bottom": dark_gray,
    }
    fallback_color = (200, 200, 200)

    # Fills go onto a copy so they can be blended in afterwards.
    overlay = img.copy()

    for row in row_result["rows"]:
        x, y = row["x"], row["y"]
        w, h = row["width"], row["height"]
        top_left, bottom_right = (x, y), (x + w, y + h)

        row_type = row.get("row_type", "content")
        color = row_colors.get(row_type, fallback_color)

        cv2.rectangle(overlay, top_left, bottom_right, color, -1)  # fill
        cv2.rectangle(img, top_left, bottom_right, color, 2)  # border

        idx = row.get("index", 0)
        wc = row.get("word_count", 0)
        word_suffix = f" ({wc}w)" if wc else ""
        label = f"R{idx} {row_type.upper()}{word_suffix}"
        cv2.putText(img, label, (x + 5, y + 18),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

    # Blend the fills in at 15% opacity.
    cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img)

    success, result_png = cv2.imencode(".png", img)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")

    return Response(content=result_png.tobytes(), media_type="image/png")
async def _get_words_overlay(session_id: str) -> Response:
    """Render the dewarped image with the recognized cell grid drawn on it.

    Sessions with ``word_result["cells"]`` get one box per cell, colored by
    column index, with the cell ID and the (truncated) text colored by
    confidence. Older sessions without cells fall back to a grid built from
    the column and row results plus the legacy ``entries`` list. Fills are
    blended in at 10% opacity.

    Raises:
        HTTPException: 404 if the session, its word data or the dewarped image
            is missing; 500 if the image cannot be decoded or encoded.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=404, detail="No word data available")

    # Either the cell-based format or the legacy entry-based format is needed.
    cells = word_result.get("cells")
    if not (cells or word_result.get("entries")):
        raise HTTPException(status_code=404, detail="No word data available")

    dewarped_png = await get_session_image(session_id, "dewarped")
    if not dewarped_png:
        raise HTTPException(status_code=404, detail="Dewarped image not available")

    img = cv2.imdecode(np.frombuffer(dewarped_png, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    font = cv2.FONT_HERSHEY_SIMPLEX

    def _confidence_color(conf):
        """Return the BGR text color for a confidence value (>=70, >=50, below)."""
        if conf >= 70:
            return (0, 180, 0)
        if conf >= 50:
            return (0, 180, 220)
        return (0, 0, 220)

    # Fills go onto a copy so they can be blended in afterwards.
    overlay = img.copy()

    if cells:
        # Cell-based overlay: color by column index (BGR).
        col_palette = [
            (255, 180, 0),    # blue
            (0, 200, 0),      # green
            (0, 140, 255),    # orange
            (200, 100, 200),  # purple
            (200, 200, 0),    # cyan
            (100, 200, 200),  # yellow-ish
        ]

        for cell in cells:
            bbox = cell.get("bbox_px", {})
            cx, cy = bbox.get("x", 0), bbox.get("y", 0)
            cw, ch = bbox.get("w", 0), bbox.get("h", 0)
            if cw <= 0 or ch <= 0:
                continue

            color = col_palette[cell.get("col_index", 0) % len(col_palette)]
            top_left, bottom_right = (cx, cy), (cx + cw, cy + ch)

            cv2.rectangle(img, top_left, bottom_right, color, 1)  # border
            cv2.rectangle(overlay, top_left, bottom_right, color, -1)  # fill

            # Cell ID in the top-left corner
            cv2.putText(img, cell.get("cell_id", ""), (cx + 2, cy + 10),
                        font, 0.28, color, 1)

            # Recognized text along the bottom edge
            text = cell.get("text", "")
            if text:
                text_color = _confidence_color(cell.get("confidence", 0))
                label = text.replace('\n', ' ')[:30]
                cv2.putText(img, label, (cx + 3, cy + ch - 4),
                            font, 0.35, text_color, 1)
    else:
        # Legacy fallback for old sessions: entry-based overlay.
        column_result = session.get("column_result")
        row_result = session.get("row_result")

        col_colors = {
            "column_en": (255, 180, 0),
            "column_de": (0, 200, 0),
            "column_example": (0, 140, 255),
        }
        field_by_col_type = {
            "column_en": "english",
            "column_de": "german",
            "column_example": "example",
        }

        columns = []
        if column_result and column_result.get("columns"):
            columns = [c for c in column_result["columns"]
                       if c.get("type", "").startswith("column_")]

        content_rows_data = []
        if row_result and row_result.get("rows"):
            content_rows_data = [r for r in row_result["rows"]
                                 if r.get("row_type") == "content"]

        # Grid: one box per (column, content row) pair.
        for col in columns:
            color = col_colors.get(col.get("type", ""), (200, 200, 200))
            cx, cw = col["x"], col["width"]
            for row in content_rows_data:
                ry, rh = row["y"], row["height"]
                cv2.rectangle(img, (cx, ry), (cx + cw, ry + rh), color, 1)
                cv2.rectangle(overlay, (cx, ry), (cx + cw, ry + rh), color, -1)

        # Texts: entries are looked up by their row_index; the last entry for
        # a given index wins.
        entries = word_result["entries"]
        entry_by_row: Dict[int, Dict] = {e.get("row_index", -1): e for e in entries}

        for row_idx, row in enumerate(content_rows_data):
            entry = entry_by_row.get(row_idx)
            if not entry:
                continue

            text_color = _confidence_color(entry.get("confidence", 0))
            ry, rh = row["y"], row["height"]

            for col in columns:
                cx, cw = col["x"], col["width"]
                field = field_by_col_type.get(col.get("type", ""), "")
                text = entry.get(field, "") if field else ""
                if text:
                    label = text.replace('\n', ' ')[:30]
                    cv2.putText(img, label, (cx + 3, ry + rh - 4),
                                font, 0.35, text_color, 1)

    # Blend the fills in at 10% opacity.
    cv2.addWeighted(overlay, 0.1, img, 0.9, 0, img)

    success, result_png = cv2.imencode(".png", img)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")

    return Response(content=result_png.tobytes(), media_type="image/png")
# ---------------------------------------------------------------------------
# Handwriting Removal Endpoint
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/remove-handwriting")
async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingRequest):
    """
    Remove handwriting from a session image using inpainting.

    Steps:
    1. Load source image (auto → deskewed if available, else original)
    2. Detect handwriting mask (filtered by target_ink)
    3. Dilate mask to cover stroke edges
    4. Inpaint the image
    5. Store result as clean_png in the session

    Returns metadata (method used, handwriting ratio, detection confidence,
    processing time, ...) including the URL to fetch the clean image.

    Raises:
        HTTPException: 404 if the session or the source image is missing,
            500 if inpainting reports failure.
    """
    t0 = time.monotonic()

    from services.handwriting_detection import detect_handwriting
    from services.inpainting_service import inpaint_image, dilate_mask as _dilate_mask, InpaintingMethod, image_to_png

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # 1. Determine source image. In auto mode the deskewed image fetched for
    #    the availability check is reused instead of being loaded a second time.
    source = req.use_source
    image_bytes = None
    if source == "auto":
        image_bytes = await get_session_image(session_id, "deskewed")
        source = "deskewed" if image_bytes else "original"
    if not image_bytes:
        image_bytes = await get_session_image(session_id, source)
    if not image_bytes:
        raise HTTPException(status_code=404, detail=f"Source image '{source}' not available")

    # 2. Detect handwriting mask
    detection = detect_handwriting(image_bytes, target_ink=req.target_ink)

    # 3. Convert mask to PNG bytes and dilate
    import io
    from PIL import Image as _PILImage
    mask_img = _PILImage.fromarray(detection.mask)
    mask_buf = io.BytesIO()
    mask_img.save(mask_buf, format="PNG")
    mask_bytes = mask_buf.getvalue()

    if req.dilation > 0:
        mask_bytes = _dilate_mask(mask_bytes, iterations=req.dilation)

    # 4. Inpaint (unknown method names fall back to AUTO)
    method_map = {
        "telea": InpaintingMethod.OPENCV_TELEA,
        "ns": InpaintingMethod.OPENCV_NS,
        "auto": InpaintingMethod.AUTO,
    }
    inpaint_method = method_map.get(req.method, InpaintingMethod.AUTO)

    result = inpaint_image(image_bytes, mask_bytes, method=inpaint_method)
    if not result.success:
        raise HTTPException(status_code=500, detail="Inpainting failed")

    elapsed_ms = int((time.monotonic() - t0) * 1000)

    meta = {
        "method_used": result.method_used.value if hasattr(result.method_used, "value") else str(result.method_used),
        "handwriting_ratio": round(detection.handwriting_ratio, 4),
        "detection_confidence": round(detection.confidence, 4),
        "target_ink": req.target_ink,
        "dilation": req.dilation,
        "source_image": source,
        "processing_time_ms": elapsed_ms,
    }

    # 5. Persist clean image (convert BGR ndarray → PNG bytes)
    clean_png_bytes = image_to_png(result.image)
    await update_session_db(session_id, clean_png=clean_png_bytes, handwriting_removal_meta=meta)

    return {
        **meta,
        "image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean",
        "session_id": session_id,
    }
# ---------------------------------------------------------------------------
# Auto-Mode Endpoint (Improvement 3)
# ---------------------------------------------------------------------------
class RunAutoRequest(BaseModel):
    """Request body for the run-auto endpoint (automatic pipeline run)."""
    # First step to (re)run; the endpoint rejects values outside 1-6.
    from_step: int = 1  # 1=deskew, 2=dewarp, 3=columns, 4=rows, 5=words, 6=llm-review
    ocr_engine: str = "auto"  # "auto" | "rapid" | "tesseract"
    pronunciation: str = "british"
    skip_llm_review: bool = False
    # The endpoint rejects any value other than the three listed here.
    dewarp_method: str = "ensemble"  # "ensemble" | "vlm" | "cv"
async def _auto_sse_event(step: str, status: str, data: Dict[str, Any]) -> str:
    """Build one SSE frame: ``data: <json>`` followed by a blank line.

    The JSON object starts with ``step`` and ``status``; the keys of ``data``
    are merged in afterwards and override them on a name clash. The function
    awaits nothing but is a coroutine because callers ``await`` it.
    """
    event: Dict[str, Any] = {"step": step, "status": status}
    event.update(data)
    return "data: " + json.dumps(event) + "\n\n"
def _auto_encode_png(img_bgr: np.ndarray, label: str) -> bytes:
    """Encode an image as PNG bytes, raising instead of returning empty data."""
    ok, buf = cv2.imencode(".png", img_bgr)
    if not ok:
        raise ValueError(f"PNG encoding of {label} image failed")
    return buf.tobytes()


@router.post("/sessions/{session_id}/run-auto")
async def run_auto(session_id: str, req: RunAutoRequest, request: Request):
    """Run the OCR pipeline automatically from a given step, streaming SSE progress.

    Steps:
      1. Deskew — straighten the scan
      2. Dewarp — correct vertical shear (ensemble CV or VLM)
      3. Columns — detect column layout
      4. Rows — detect row layout
      5. Words — OCR each cell
      6. LLM review — correct OCR errors (optional)

    Steps numbered below ``req.from_step`` are reported as "skipped"; every
    step from ``from_step`` onwards is (re)run, whether or not a result from
    an earlier run exists. Skipped steps rely on the images and results
    already present in the session cache / database.

    Args:
        session_id: Pipeline session to process.
        req: Auto-mode options (start step, OCR engine, dewarp method, ...).
        request: Incoming request; currently unused.

    Returns:
        A ``text/event-stream`` response. Each event has the form
            data: {"step": "deskew", "status": "start"|"done"|"skipped"|"error", ...}
        A failure in steps 1-5 ends the stream with
            data: {"step": "complete", "status": "error", "error_step": "<step>"}
        A failure in the LLM review is reported with ``"fatal": false`` and
        does not abort the run. On success the final event is
            data: {"step": "complete", "status": "done", "steps_run": [...], "steps_skipped": [...]}

    Raises:
        HTTPException: 400 when ``from_step`` or ``dewarp_method`` is invalid.
    """
    if not 1 <= req.from_step <= 6:
        raise HTTPException(status_code=400, detail="from_step must be 1-6")
    if req.dewarp_method not in ("ensemble", "vlm", "cv"):
        raise HTTPException(status_code=400, detail="dewarp_method must be: ensemble, vlm, cv")

    if session_id not in _cache:
        await _load_session_to_cache(session_id)

    async def _generate():
        steps_run: List[str] = []
        steps_skipped: List[str] = []
        error_step: Optional[str] = None

        session = await get_session_db(session_id)
        if not session:
            yield await _auto_sse_event("error", "error", {"message": f"Session {session_id} not found"})
            return

        cached = _get_cached(session_id)

        # -----------------------------------------------------------------
        # Step 1: Deskew
        # -----------------------------------------------------------------
        if req.from_step <= 1:
            yield await _auto_sse_event("deskew", "start", {})
            try:
                t0 = time.time()
                orig_bgr = cached.get("original_bgr")
                if orig_bgr is None:
                    raise ValueError("Original image not loaded")

                # Method 1: Hough lines (best effort; falls back to "no rotation")
                try:
                    deskewed_hough, angle_hough = deskew_image(orig_bgr.copy())
                except Exception:
                    logger.warning(f"Auto-mode Hough deskew failed for {session_id}", exc_info=True)
                    deskewed_hough, angle_hough = orig_bgr, 0.0

                # Method 2: Word alignment (best effort; works on PNG bytes)
                success_enc, png_orig = cv2.imencode(".png", orig_bgr)
                orig_bytes = png_orig.tobytes() if success_enc else b""
                try:
                    deskewed_wa_bytes, angle_wa = deskew_image_by_word_alignment(orig_bytes)
                except Exception:
                    logger.warning(f"Auto-mode word-alignment deskew failed for {session_id}", exc_info=True)
                    deskewed_wa_bytes, angle_wa = orig_bytes, 0.0

                # Pick best method: word alignment wins when it found the larger
                # angle or when Hough found practically nothing.
                prefer_wa = abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1
                deskewed_bgr = None
                if prefer_wa:
                    wa_arr = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8)
                    # An empty buffer cannot be decoded; skip straight to the
                    # Hough fallback instead of handing it to cv2.imdecode.
                    if wa_arr.size:
                        deskewed_bgr = cv2.imdecode(wa_arr, cv2.IMREAD_COLOR)
                if deskewed_bgr is not None:
                    method_used = "word_alignment"
                    angle_applied = angle_wa
                else:
                    method_used = "hough"
                    angle_applied = angle_hough
                    deskewed_bgr = deskewed_hough

                # Fail the step rather than persisting an empty image.
                deskewed_png = _auto_encode_png(deskewed_bgr, "deskewed")

                deskew_result = {
                    "method_used": method_used,
                    "rotation_degrees": round(float(angle_applied), 3),
                    "duration_seconds": round(time.time() - t0, 2),
                }

                cached["deskewed_bgr"] = deskewed_bgr
                cached["deskew_result"] = deskew_result
                await update_session_db(
                    session_id,
                    deskewed_png=deskewed_png,
                    deskew_result=deskew_result,
                    auto_rotation_degrees=float(angle_applied),
                    current_step=2,
                )

                steps_run.append("deskew")
                yield await _auto_sse_event("deskew", "done", deskew_result)
            except Exception as e:
                logger.exception(f"Auto-mode deskew failed for {session_id}: {e}")
                error_step = "deskew"
                yield await _auto_sse_event("deskew", "error", {"message": str(e)})
                yield await _auto_sse_event("complete", "error", {"error_step": error_step})
                return
        else:
            steps_skipped.append("deskew")
            yield await _auto_sse_event("deskew", "skipped", {"reason": "from_step > 1"})

        # -----------------------------------------------------------------
        # Step 2: Dewarp
        # -----------------------------------------------------------------
        if req.from_step <= 2:
            yield await _auto_sse_event("dewarp", "start", {"method": req.dewarp_method})
            try:
                t0 = time.time()
                deskewed_bgr = cached.get("deskewed_bgr")
                if deskewed_bgr is None:
                    raise ValueError("Deskewed image not available")

                if req.dewarp_method == "vlm":
                    img_bytes = _auto_encode_png(deskewed_bgr, "deskewed")
                    vlm_det = await _detect_shear_with_vlm(img_bytes)
                    shear_deg = vlm_det["shear_degrees"]
                    # Only correct when the VLM reports a noticeable shear
                    # with at least minimal confidence.
                    if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3:
                        dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg)
                    else:
                        dewarped_bgr = deskewed_bgr
                    dewarp_info = {
                        "method": vlm_det["method"],
                        "shear_degrees": shear_deg,
                        "confidence": vlm_det["confidence"],
                        "detections": [vlm_det],
                    }
                else:
                    dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr)

                dewarped_png = _auto_encode_png(dewarped_bgr, "dewarped")

                dewarp_result = {
                    "method_used": dewarp_info["method"],
                    "shear_degrees": dewarp_info["shear_degrees"],
                    "confidence": dewarp_info["confidence"],
                    "duration_seconds": round(time.time() - t0, 2),
                    "detections": dewarp_info.get("detections", []),
                }

                cached["dewarped_bgr"] = dewarped_bgr
                cached["dewarp_result"] = dewarp_result
                await update_session_db(
                    session_id,
                    dewarped_png=dewarped_png,
                    dewarp_result=dewarp_result,
                    auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
                    current_step=3,
                )

                steps_run.append("dewarp")
                yield await _auto_sse_event("dewarp", "done", dewarp_result)
            except Exception as e:
                logger.exception(f"Auto-mode dewarp failed for {session_id}: {e}")
                error_step = "dewarp"
                yield await _auto_sse_event("dewarp", "error", {"message": str(e)})
                yield await _auto_sse_event("complete", "error", {"error_step": error_step})
                return
        else:
            steps_skipped.append("dewarp")
            yield await _auto_sse_event("dewarp", "skipped", {"reason": "from_step > 2"})

        # -----------------------------------------------------------------
        # Step 3: Columns
        # -----------------------------------------------------------------
        if req.from_step <= 3:
            yield await _auto_sse_event("columns", "start", {})
            try:
                t0 = time.time()
                dewarped_bgr = cached.get("dewarped_bgr")
                if dewarped_bgr is None:
                    raise ValueError("Dewarped image not available")

                ocr_img = create_ocr_image(dewarped_bgr)
                h, w = ocr_img.shape[:2]

                geo_result = detect_column_geometry(ocr_img, dewarped_bgr)
                if geo_result is None:
                    # Geometry detection gave nothing: use the layout analysis
                    # and clear the intermediates the later steps look for.
                    layout_img = create_layout_image(dewarped_bgr)
                    regions = analyze_layout(layout_img, ocr_img)
                    cached["_word_dicts"] = None
                    cached["_inv"] = None
                    cached["_content_bounds"] = None
                else:
                    geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
                    content_w = right_x - left_x
                    # Keep the intermediates so rows/words need not recompute them.
                    cached["_word_dicts"] = word_dicts
                    cached["_inv"] = inv
                    cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)

                    header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None)
                    geometries = _detect_sub_columns(geometries, content_w, left_x=left_x,
                                                     top_y=top_y, header_y=header_y, footer_y=footer_y)
                    regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y,
                                                    left_x=left_x, right_x=right_x, inv=inv)

                columns = [asdict(r) for r in regions]
                column_result = {
                    "columns": columns,
                    "classification_methods": list({c.get("classification_method", "") for c in columns if c.get("classification_method")}),
                    "duration_seconds": round(time.time() - t0, 2),
                }

                cached["column_result"] = column_result
                # New columns invalidate any row/word results from earlier runs.
                await update_session_db(session_id, column_result=column_result,
                                        row_result=None, word_result=None, current_step=4)

                steps_run.append("columns")
                yield await _auto_sse_event("columns", "done", {
                    "column_count": len(columns),
                    "duration_seconds": column_result["duration_seconds"],
                })
            except Exception as e:
                logger.exception(f"Auto-mode columns failed for {session_id}: {e}")
                error_step = "columns"
                yield await _auto_sse_event("columns", "error", {"message": str(e)})
                yield await _auto_sse_event("complete", "error", {"error_step": error_step})
                return
        else:
            steps_skipped.append("columns")
            yield await _auto_sse_event("columns", "skipped", {"reason": "from_step > 3"})

        # -----------------------------------------------------------------
        # Step 4: Rows
        # -----------------------------------------------------------------
        if req.from_step <= 4:
            yield await _auto_sse_event("rows", "start", {})
            try:
                t0 = time.time()
                dewarped_bgr = cached.get("dewarped_bgr")
                session = await get_session_db(session_id)
                column_result = session.get("column_result") or cached.get("column_result")
                if not column_result or not column_result.get("columns"):
                    raise ValueError("Column detection must complete first")

                word_dicts = cached.get("_word_dicts")
                inv = cached.get("_inv")
                content_bounds = cached.get("_content_bounds")

                if word_dicts is None or inv is None or content_bounds is None:
                    # Intermediates are missing (e.g. the columns step was
                    # skipped in this run): recompute them from the image.
                    if dewarped_bgr is None:
                        raise ValueError("Dewarped image not available")
                    ocr_img_tmp = create_ocr_image(dewarped_bgr)
                    geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr)
                    if geo_result is None:
                        raise ValueError("Column geometry detection failed — cannot detect rows")
                    _g, lx, rx, ty, by, word_dicts, inv = geo_result
                    cached["_word_dicts"] = word_dicts
                    cached["_inv"] = inv
                    cached["_content_bounds"] = (lx, rx, ty, by)
                    content_bounds = (lx, rx, ty, by)

                left_x, right_x, top_y, bottom_y = content_bounds
                row_geoms = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)

                row_list = [
                    {
                        "index": r.index, "x": r.x, "y": r.y,
                        "width": r.width, "height": r.height,
                        "word_count": r.word_count,
                        "row_type": r.row_type,
                        "gap_before": r.gap_before,
                    }
                    for r in row_geoms
                ]
                row_result = {
                    "rows": row_list,
                    "row_count": len(row_list),
                    "content_rows": len([r for r in row_geoms if r.row_type == "content"]),
                    "duration_seconds": round(time.time() - t0, 2),
                }

                cached["row_result"] = row_result
                await update_session_db(session_id, row_result=row_result, current_step=5)

                steps_run.append("rows")
                yield await _auto_sse_event("rows", "done", {
                    "row_count": len(row_list),
                    "content_rows": row_result["content_rows"],
                    "duration_seconds": row_result["duration_seconds"],
                })
            except Exception as e:
                logger.exception(f"Auto-mode rows failed for {session_id}: {e}")
                error_step = "rows"
                yield await _auto_sse_event("rows", "error", {"message": str(e)})
                yield await _auto_sse_event("complete", "error", {"error_step": error_step})
                return
        else:
            steps_skipped.append("rows")
            yield await _auto_sse_event("rows", "skipped", {"reason": "from_step > 4"})

        # -----------------------------------------------------------------
        # Step 5: Words (OCR)
        # -----------------------------------------------------------------
        if req.from_step <= 5:
            yield await _auto_sse_event("words", "start", {"engine": req.ocr_engine})
            try:
                t0 = time.time()
                dewarped_bgr = cached.get("dewarped_bgr")
                if dewarped_bgr is None:
                    raise ValueError("Dewarped image not available")
                session = await get_session_db(session_id)

                column_result = session.get("column_result") or cached.get("column_result")
                row_result = session.get("row_result") or cached.get("row_result")
                if not column_result or "columns" not in column_result:
                    raise ValueError("Column detection must complete first")
                if not row_result or "rows" not in row_result:
                    raise ValueError("Row detection must complete first")

                col_regions = [
                    PageRegion(
                        type=c["type"], x=c["x"], y=c["y"],
                        width=c["width"], height=c["height"],
                        classification_confidence=c.get("classification_confidence", 1.0),
                        classification_method=c.get("classification_method", ""),
                    )
                    for c in column_result["columns"]
                ]
                row_geoms = [
                    RowGeometry(
                        index=r["index"], x=r["x"], y=r["y"],
                        width=r["width"], height=r["height"],
                        word_count=r.get("word_count", 0), words=[],
                        row_type=r.get("row_type", "content"),
                        gap_before=r.get("gap_before", 0),
                    )
                    for r in row_result["rows"]
                ]

                word_dicts = cached.get("_word_dicts")
                if word_dicts is not None:
                    content_bounds = cached.get("_content_bounds")
                    # default=0 keeps an empty row list from raising; the
                    # value is unused in that case.
                    top_y = content_bounds[2] if content_bounds else min((r.y for r in row_geoms), default=0)
                    for row in row_geoms:
                        # Word boxes are matched against row positions shifted
                        # by top_y; a word belongs to the row containing its
                        # vertical centre.
                        row_y_rel = row.y - top_y
                        row_bottom_rel = row_y_rel + row.height
                        row.words = [
                            wd for wd in word_dicts
                            if row_y_rel <= wd['top'] + wd['height'] / 2 < row_bottom_rel
                        ]
                        row.word_count = len(row.words)

                ocr_img = create_ocr_image(dewarped_bgr)
                img_h, img_w = dewarped_bgr.shape[:2]

                cells, columns_meta = build_cell_grid(
                    ocr_img, col_regions, row_geoms, img_w, img_h,
                    ocr_engine=req.ocr_engine, img_bgr=dewarped_bgr,
                )
                duration = time.time() - t0

                col_types = {c['type'] for c in columns_meta}
                is_vocab = bool(col_types & {'column_en', 'column_de'})
                n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
                used_engine = cells[0].get("ocr_engine", "tesseract") if cells else req.ocr_engine

                word_result_data = {
                    "cells": cells,
                    "grid_shape": {
                        "rows": n_content_rows,
                        "cols": len(columns_meta),
                        "total_cells": len(cells),
                    },
                    "columns_used": columns_meta,
                    "layout": "vocab" if is_vocab else "generic",
                    "image_width": img_w,
                    "image_height": img_h,
                    "duration_seconds": round(duration, 2),
                    "ocr_engine": used_engine,
                    "summary": {
                        "total_cells": len(cells),
                        "non_empty_cells": sum(1 for c in cells if c.get("text")),
                        "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
                    },
                }

                if is_vocab:
                    entries = _cells_to_vocab_entries(cells, columns_meta)
                    entries = _fix_character_confusion(entries)
                    entries = _fix_phonetic_brackets(entries, pronunciation=req.pronunciation)
                    # The same list is stored under both keys.
                    word_result_data["vocab_entries"] = entries
                    word_result_data["entries"] = entries
                    word_result_data["entry_count"] = len(entries)
                    word_result_data["summary"]["total_entries"] = len(entries)

                await update_session_db(session_id, word_result=word_result_data, current_step=6)
                cached["word_result"] = word_result_data

                steps_run.append("words")
                yield await _auto_sse_event("words", "done", {
                    "total_cells": len(cells),
                    "layout": word_result_data["layout"],
                    "duration_seconds": round(duration, 2),
                    "ocr_engine": used_engine,
                    "summary": word_result_data["summary"],
                })
            except Exception as e:
                logger.exception(f"Auto-mode words failed for {session_id}: {e}")
                error_step = "words"
                yield await _auto_sse_event("words", "error", {"message": str(e)})
                yield await _auto_sse_event("complete", "error", {"error_step": error_step})
                return
        else:
            steps_skipped.append("words")
            yield await _auto_sse_event("words", "skipped", {"reason": "from_step > 5"})

        # -----------------------------------------------------------------
        # Step 6: LLM Review (optional)
        # -----------------------------------------------------------------
        if req.from_step <= 6 and not req.skip_llm_review:
            yield await _auto_sse_event("llm_review", "start", {"model": OLLAMA_REVIEW_MODEL})
            try:
                session = await get_session_db(session_id)
                word_result = session.get("word_result") or cached.get("word_result")
                if not word_result:
                    raise ValueError("Word recognition must complete first")
                entries = word_result.get("entries") or word_result.get("vocab_entries") or []

                if not entries:
                    yield await _auto_sse_event("llm_review", "skipped", {"reason": "no entries"})
                    steps_skipped.append("llm_review")
                else:
                    reviewed = await llm_review_entries(entries)

                    # Re-read after the review to merge into the latest stored
                    # result; fall back to the result the entries came from so
                    # the cells are never dropped when only the cache had them.
                    session = await get_session_db(session_id)
                    word_result_updated = dict(session.get("word_result") or word_result)
                    word_result_updated["entries"] = reviewed
                    word_result_updated["vocab_entries"] = reviewed
                    word_result_updated["llm_reviewed"] = True
                    word_result_updated["llm_model"] = OLLAMA_REVIEW_MODEL

                    await update_session_db(session_id, word_result=word_result_updated, current_step=7)
                    cached["word_result"] = word_result_updated

                    steps_run.append("llm_review")
                    yield await _auto_sse_event("llm_review", "done", {
                        "entries_reviewed": len(reviewed),
                        "model": OLLAMA_REVIEW_MODEL,
                    })
            except Exception as e:
                # The review is optional: report the failure, keep the run going.
                logger.warning(f"Auto-mode llm_review failed for {session_id} (non-fatal): {e}")
                yield await _auto_sse_event("llm_review", "error", {"message": str(e), "fatal": False})
                steps_skipped.append("llm_review")
        else:
            steps_skipped.append("llm_review")
            reason = "skipped by request" if req.skip_llm_review else "from_step > 6"
            yield await _auto_sse_event("llm_review", "skipped", {"reason": reason})

        # -----------------------------------------------------------------
        # Final event
        # -----------------------------------------------------------------
        yield await _auto_sse_event("complete", "done", {
            "steps_run": steps_run,
            "steps_skipped": steps_skipped,
        })

    return StreamingResponse(
        _generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Disables response buffering in nginx-style proxies so events
            # reach the client as they are produced.
            "X-Accel-Buffering": "no",
        },
    )