Some checks failed
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 25s
CI / test-go-edu-search (push) Successful in 25s
CI / test-python-klausur (push) Failing after 1m48s
CI / test-python-agent-core (push) Successful in 17s
CI / test-nodejs-website (push) Successful in 16s
Replace the trivial top_y/bottom_y threshold check with horizontal projection gap analysis that finds large whitespace gaps separating header/footer content from the main body. This correctly detects headers (e.g. "VOCABULARY" banners) and footers (page numbers) even when _find_content_bounds includes them in the content area. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1884 lines
65 KiB
Python
1884 lines
65 KiB
Python
"""
|
||
OCR Pipeline API - step-by-step page reconstruction.

Splits the OCR process into 8 individual steps:
1. Deskewing - straighten the scan
2. Dewarping - correct the book-curvature distortion
3. Column detection - find invisible columns
4. Row detection - horizontal rows plus headers/footers
5. Word recognition - OCR with bounding boxes
6. LLM correction - fix OCR errors with an LLM
7. Page reconstruction - rebuild the page
8. Ground truth validation - overall check

License: Apache 2.0
PRIVACY: All processing happens locally.
|
||
"""
|
||
|
||
import json
|
||
import logging
|
||
import time
|
||
import uuid
|
||
from dataclasses import asdict
|
||
from datetime import datetime
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
import cv2
|
||
import numpy as np
|
||
from fastapi import APIRouter, File, Form, HTTPException, Request, UploadFile
|
||
from fastapi.responses import Response, StreamingResponse
|
||
from pydantic import BaseModel
|
||
|
||
from cv_vocab_pipeline import (
|
||
OLLAMA_REVIEW_MODEL,
|
||
PageRegion,
|
||
RowGeometry,
|
||
_cells_to_vocab_entries,
|
||
_fix_character_confusion,
|
||
_fix_phonetic_brackets,
|
||
analyze_layout,
|
||
analyze_layout_by_words,
|
||
build_cell_grid,
|
||
build_cell_grid_streaming,
|
||
build_word_grid,
|
||
classify_column_types,
|
||
create_layout_image,
|
||
create_ocr_image,
|
||
deskew_image,
|
||
deskew_image_by_word_alignment,
|
||
detect_column_geometry,
|
||
detect_row_geometry,
|
||
dewarp_image,
|
||
dewarp_image_manual,
|
||
llm_review_entries,
|
||
llm_review_entries_streaming,
|
||
render_image_high_res,
|
||
render_pdf_high_res,
|
||
)
|
||
from ocr_pipeline_session_store import (
|
||
create_session_db,
|
||
delete_session_db,
|
||
get_session_db,
|
||
get_session_image,
|
||
init_ocr_pipeline_tables,
|
||
list_sessions_db,
|
||
update_session_db,
|
||
)
|
||
|
||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Every route declared in this file is registered on this router.
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])

# ---------------------------------------------------------------------------
# In-memory cache for active sessions (BGR numpy arrays for processing)
# DB is source of truth, cache holds BGR arrays during active processing.
# ---------------------------------------------------------------------------

# Maps session id -> per-session dict of DB fields and decoded images.
_cache: Dict[str, Dict[str, Any]] = {}
|
||
|
||
|
||
async def _load_session_to_cache(session_id: str) -> Dict[str, Any]:
    """Return the cache entry for a session, building it from the DB on a miss.

    The session row is looked up first, so an unknown id raises
    ``HTTPException(404)`` even if a cache entry exists. On a cache miss the
    stored original/deskewed/dewarped PNGs are decoded into BGR arrays and
    stored next to the session's DB fields.
    """
    db_row = await get_session_db(session_id)
    if not db_row:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    try:
        return _cache[session_id]
    except KeyError:
        pass

    image_kinds = ("original", "deskewed", "dewarped")

    entry: Dict[str, Any] = {"id": session_id, **db_row}
    # Image slots start empty and are filled only for images the DB holds.
    for kind in image_kinds:
        entry[f"{kind}_bgr"] = None

    for kind in image_kinds:
        png_bytes = await get_session_image(session_id, kind)
        if not png_bytes:
            continue
        encoded = np.frombuffer(png_bytes, dtype=np.uint8)
        entry[f"{kind}_bgr"] = cv2.imdecode(encoded, cv2.IMREAD_COLOR)

    _cache[session_id] = entry
    return entry
|
||
|
||
|
||
def _get_cached(session_id: str) -> Dict[str, Any]:
    """Return the cached entry for ``session_id``.

    Raises ``HTTPException(404)`` when the entry is missing or empty.
    """
    cached_entry = _cache.get(session_id)
    if cached_entry:
        return cached_entry
    raise HTTPException(status_code=404, detail=f"Session {session_id} not in cache — reload first")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Pydantic Models
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class ManualDeskewRequest(BaseModel):
    # Request body for the manual deskew endpoint.
    # Rotation angle; the handler clamps it to [-5.0, 5.0] before use.
    angle: float
|
||
|
||
|
||
class DeskewGroundTruthRequest(BaseModel):
    # Request body for saving ground-truth feedback on the deskew step.
    is_correct: bool
    # Optional reviewer-supplied angle; stored as given.
    corrected_angle: Optional[float] = None
    # Optional free-text remark; stored as given.
    notes: Optional[str] = None
|
||
|
||
|
||
class ManualDewarpRequest(BaseModel):
    # Request body for the manual dewarp endpoint.
    # Shear angle; the handler clamps it to [-2.0, 2.0] before use.
    shear_degrees: float
|
||
|
||
|
||
class DewarpGroundTruthRequest(BaseModel):
    # Request body for saving ground-truth feedback on the dewarp step.
    is_correct: bool
    # Optional reviewer-supplied shear value; stored as given.
    corrected_shear: Optional[float] = None
    # Optional free-text remark; stored as given.
    notes: Optional[str] = None
|
||
|
||
|
||
class RenameSessionRequest(BaseModel):
    # Request body for renaming a session.
    # New display name written to the session record.
    name: str
|
||
|
||
|
||
class ManualColumnsRequest(BaseModel):
    # Request body for overriding detected columns.
    # Column definitions as plain dicts; stored verbatim as the column result.
    columns: List[Dict[str, Any]]
|
||
|
||
|
||
class ColumnGroundTruthRequest(BaseModel):
    # Request body for saving ground-truth feedback on column detection.
    is_correct: bool
    # Optional reviewer-supplied column definitions; stored as given.
    corrected_columns: Optional[List[Dict[str, Any]]] = None
    # Optional free-text remark; stored as given.
    notes: Optional[str] = None
|
||
|
||
|
||
class ManualRowsRequest(BaseModel):
    # Request body carrying manually defined rows.
    # Row definitions as plain dicts (schema not enforced here).
    rows: List[Dict[str, Any]]
|
||
|
||
|
||
class RowGroundTruthRequest(BaseModel):
    # Request body for ground-truth feedback on row detection.
    is_correct: bool
    # Optional reviewer-supplied row definitions.
    corrected_rows: Optional[List[Dict[str, Any]]] = None
    # Optional free-text remark.
    notes: Optional[str] = None
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Session Management Endpoints
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@router.get("/sessions")
async def list_sessions():
    """Return every stored OCR pipeline session under the ``sessions`` key."""
    return {"sessions": await list_sessions_db()}
|
||
|
||
|
||
@router.post("/sessions")
async def create_session(
    file: UploadFile = File(...),
    name: Optional[str] = Form(None),
):
    """Create a pipeline session from an uploaded PDF or image.

    A PDF (by content type or ``.pdf`` suffix) is rendered from page 0 at
    zoom 3.0; anything else goes through the image renderer. The result is
    PNG-encoded and persisted, and the BGR array is cached for the next
    pipeline steps.

    Raises:
        HTTPException: 400 if the upload cannot be rendered, 500 if the
            rendered image cannot be PNG-encoded.
    """
    payload = await file.read()
    filename = file.filename or "upload"
    mime = file.content_type or ""

    session_id = str(uuid.uuid4())
    looks_like_pdf = filename.lower().endswith(".pdf") or mime == "application/pdf"

    try:
        img_bgr = (
            render_pdf_high_res(payload, page_number=0, zoom=3.0)
            if looks_like_pdf
            else render_image_high_res(payload)
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Could not process file: {e}")

    # The DB stores the original as PNG bytes.
    encoded_ok, png_buf = cv2.imencode(".png", img_bgr)
    if not encoded_ok:
        raise HTTPException(status_code=500, detail="Failed to encode image")

    session_name = name or filename

    await create_session_db(
        session_id=session_id,
        name=session_name,
        filename=filename,
        original_png=png_buf.tobytes(),
    )

    # Keep the decoded array around so step 1 can start without a DB read.
    _cache[session_id] = dict(
        id=session_id,
        filename=filename,
        name=session_name,
        original_bgr=img_bgr,
        deskewed_bgr=None,
        dewarped_bgr=None,
        deskew_result=None,
        dewarp_result=None,
        ground_truth={},
        current_step=1,
    )

    height, width = img_bgr.shape[0], img_bgr.shape[1]
    logger.info(f"OCR Pipeline: created session {session_id} from {filename} ({width}x{height})")

    return {
        "session_id": session_id,
        "filename": filename,
        "name": session_name,
        "image_width": width,
        "image_height": height,
        "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
    }
|
||
|
||
|
||
@router.get("/sessions/{session_id}")
async def get_session_info(session_id: str):
    """Return session metadata plus any stored step results.

    Image dimensions are read from the stored original PNG and reported as
    0x0 when the image is missing or cannot be decoded. The deskew, dewarp,
    column, row and word results are included only when present.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # Default to 0x0; overwritten only when the original decodes successfully.
    # (The previous one-line conditional bound only to the second tuple item,
    # so an undecodable PNG crashed on ``None.shape``.)
    img_w, img_h = 0, 0
    original_png = await get_session_image(session_id, "original")
    if original_png:
        arr = np.frombuffer(original_png, dtype=np.uint8)
        img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
        if img is not None:
            img_h, img_w = img.shape[:2]

    result = {
        "session_id": session["id"],
        "filename": session.get("filename", ""),
        "name": session.get("name", ""),
        "image_width": img_w,
        "image_height": img_h,
        "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
        "current_step": session.get("current_step", 1),
    }

    # Copy over only the step results that have been produced so far.
    for key in ("deskew_result", "dewarp_result", "column_result", "row_result", "word_result"):
        value = session.get(key)
        if value:
            result[key] = value

    return result
|
||
|
||
|
||
@router.put("/sessions/{session_id}")
async def rename_session(session_id: str, req: RenameSessionRequest):
    """Rename a session in the DB and in the in-memory cache.

    Raises:
        HTTPException: 404 if the DB update reports no such session.
    """
    updated = await update_session_db(session_id, name=req.name)
    if not updated:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # Keep an already-cached entry consistent with the DB.
    if session_id in _cache:
        _cache[session_id]["name"] = req.name

    return {"session_id": session_id, "name": req.name}
|
||
|
||
|
||
@router.delete("/sessions/{session_id}")
async def delete_session(session_id: str):
    """Drop a session from the cache and delete it from the DB.

    The cache entry is removed first, regardless of whether the DB row
    exists. Raises ``HTTPException(404)`` when the DB reports nothing deleted.
    """
    _cache.pop(session_id, None)
    if await delete_session_db(session_id):
        return {"session_id": session_id, "deleted": True}
    raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Image Endpoints
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@router.get("/sessions/{session_id}/image/{image_type}")
async def get_image(session_id: str, image_type: str):
    """Serve a session image as PNG.

    Supported ``image_type`` values: original, deskewed, dewarped, binarized,
    columns-overlay, rows-overlay and words-overlay. Overlays are rendered on
    demand; the binarized image is served from the cache when available;
    everything else is read from the DB.

    Raises:
        HTTPException: 400 for an unknown image type, 404 when the requested
            image is not stored yet.
    """
    valid_types = {"original", "deskewed", "dewarped", "binarized", "columns-overlay", "rows-overlay", "words-overlay"}
    if image_type not in valid_types:
        raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")

    if image_type == "columns-overlay":
        return await _get_columns_overlay(session_id)

    if image_type == "rows-overlay":
        return await _get_rows_overlay(session_id)

    if image_type == "words-overlay":
        return await _get_words_overlay(session_id)

    # Fast path: the binarized PNG is kept in the cache after deskewing.
    if image_type == "binarized":
        cached = _cache.get(session_id)
        if cached and cached.get("binarized_png"):
            return Response(content=cached["binarized_png"], media_type="image/png")

    # Load from DB
    data = await get_session_image(session_id, image_type)
    if not data:
        raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet")

    return Response(content=data, media_type="image/png")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Deskew Endpoints
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@router.post("/sessions/{session_id}/deskew")
async def auto_deskew(session_id: str):
    """Run both deskew methods on the original image and keep the preferred one.

    Hough-line and word-alignment deskewing are both attempted; a failure in
    either is logged and treated as a 0.0 degree result. Word alignment is
    preferred when its angle magnitude is at least the Hough one, or when the
    Hough angle is below 0.1; if its output cannot be decoded the Hough result
    is used instead. The deskewed image and, when possible, a binarized
    version are cached and persisted, and the session moves to step 2.

    Raises:
        HTTPException: 404 if the session is unknown, 400 if it has no
            original image.
    """
    # Ensure session is in cache
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    img_bgr = cached.get("original_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Original image not available")

    t0 = time.time()

    # Method 1: Hough Lines
    try:
        deskewed_hough, angle_hough = deskew_image(img_bgr.copy())
    except Exception as e:
        logger.warning(f"Hough deskew failed: {e}")
        deskewed_hough, angle_hough = img_bgr, 0.0

    # Method 2: Word Alignment (needs image bytes)
    success_enc, png_orig = cv2.imencode(".png", img_bgr)
    orig_bytes = png_orig.tobytes() if success_enc else b""

    try:
        deskewed_wa_bytes, angle_wa = deskew_image_by_word_alignment(orig_bytes)
    except Exception as e:
        logger.warning(f"Word alignment deskew failed: {e}")
        deskewed_wa_bytes, angle_wa = orig_bytes, 0.0

    duration = time.time() - t0

    # Pick best method
    prefer_word_alignment = abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1
    deskewed_bgr = None
    if prefer_word_alignment:
        wa_array = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8)
        # cv2.imdecode raises on an empty buffer instead of returning None,
        # so skip decoding and let the Hough fallback below take over.
        if wa_array.size:
            deskewed_bgr = cv2.imdecode(wa_array, cv2.IMREAD_COLOR)

    if deskewed_bgr is not None:
        method_used = "word_alignment"
        angle_applied = angle_wa
    else:
        # Either Hough was preferred or the word-alignment output was unusable.
        method_used = "hough"
        angle_applied = angle_hough
        deskewed_bgr = deskewed_hough

    # Encode as PNG
    success, deskewed_png_buf = cv2.imencode(".png", deskewed_bgr)
    deskewed_png = deskewed_png_buf.tobytes() if success else b""

    # Create binarized version (best effort; a failure is only logged)
    binarized_png = None
    try:
        binarized = create_ocr_image(deskewed_bgr)
        success_bin, bin_buf = cv2.imencode(".png", binarized)
        binarized_png = bin_buf.tobytes() if success_bin else None
    except Exception as e:
        logger.warning(f"Binarization failed: {e}")

    # Heuristic: confidence drops linearly with the applied angle, floor 0.5.
    confidence = max(0.5, 1.0 - abs(angle_applied) / 5.0)

    deskew_result = {
        "angle_hough": round(angle_hough, 3),
        "angle_word_alignment": round(angle_wa, 3),
        "angle_applied": round(angle_applied, 3),
        "method_used": method_used,
        "confidence": round(confidence, 2),
        "duration_seconds": round(duration, 2),
    }

    # Update cache
    cached["deskewed_bgr"] = deskewed_bgr
    cached["binarized_png"] = binarized_png
    cached["deskew_result"] = deskew_result

    # Persist to DB
    db_update = {
        "deskewed_png": deskewed_png,
        "deskew_result": deskew_result,
        "current_step": 2,
    }
    if binarized_png:
        db_update["binarized_png"] = binarized_png
    await update_session_db(session_id, **db_update)

    logger.info(f"OCR Pipeline: deskew session {session_id}: "
                f"hough={angle_hough:.2f} wa={angle_wa:.2f} -> {method_used} {angle_applied:.2f}")

    return {
        "session_id": session_id,
        **deskew_result,
        "deskewed_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed",
        "binarized_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/binarized",
    }
|
||
|
||
|
||
@router.post("/sessions/{session_id}/deskew/manual")
async def manual_deskew(session_id: str, req: ManualDeskewRequest):
    """Rotate the original image by a user-supplied angle and store the result.

    The requested angle is clamped to [-5.0, 5.0]. The image is rotated about
    its centre at the original size with replicated borders. The rotated
    image and, when possible, a binarized version are cached and persisted.
    Any earlier deskew result is kept, with ``angle_applied`` and
    ``method_used`` overwritten.

    Raises:
        HTTPException: 404 if the session is unknown, 400 if it has no
            original image.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    img_bgr = cached.get("original_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Original image not available")

    angle = max(-5.0, min(5.0, req.angle))

    h, w = img_bgr.shape[:2]
    center = (w // 2, h // 2)
    M = cv2.getRotationMatrix2D(center, angle, 1.0)
    rotated = cv2.warpAffine(img_bgr, M, (w, h),
                             flags=cv2.INTER_LINEAR,
                             borderMode=cv2.BORDER_REPLICATE)

    success, png_buf = cv2.imencode(".png", rotated)
    deskewed_png = png_buf.tobytes() if success else b""

    # Binarize (best effort). Failures are logged, as in the automatic deskew
    # endpoint, instead of being silently discarded.
    binarized_png = None
    try:
        binarized = create_ocr_image(rotated)
        success_bin, bin_buf = cv2.imencode(".png", binarized)
        binarized_png = bin_buf.tobytes() if success_bin else None
    except Exception as e:
        logger.warning(f"Binarization failed: {e}")

    deskew_result = {
        **(cached.get("deskew_result") or {}),
        "angle_applied": round(angle, 3),
        "method_used": "manual",
    }

    # Update cache
    cached["deskewed_bgr"] = rotated
    cached["binarized_png"] = binarized_png
    cached["deskew_result"] = deskew_result

    # Persist to DB
    db_update = {
        "deskewed_png": deskewed_png,
        "deskew_result": deskew_result,
    }
    if binarized_png:
        db_update["binarized_png"] = binarized_png
    await update_session_db(session_id, **db_update)

    logger.info(f"OCR Pipeline: manual deskew session {session_id}: {angle:.2f}")

    return {
        "session_id": session_id,
        "angle_applied": round(angle, 3),
        "method_used": "manual",
        "deskewed_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed",
    }
|
||
|
||
|
||
@router.post("/sessions/{session_id}/ground-truth/deskew")
async def save_deskew_ground_truth(session_id: str, req: DeskewGroundTruthRequest):
    """Store reviewer feedback for the deskew step.

    The feedback is saved under the ``deskew`` key of the session's ground
    truth, together with a UTC timestamp and the deskew result currently in
    the DB. Raises ``HTTPException(404)`` for an unknown session.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    feedback = dict(
        is_correct=req.is_correct,
        corrected_angle=req.corrected_angle,
        notes=req.notes,
        saved_at=datetime.utcnow().isoformat(),
        deskew_result=session.get("deskew_result"),
    )

    all_feedback = session.get("ground_truth") or {}
    all_feedback["deskew"] = feedback
    await update_session_db(session_id, ground_truth=all_feedback)

    # Mirror the change into an already-cached session.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = all_feedback

    logger.info(
        f"OCR Pipeline: ground truth deskew session {session_id}: "
        f"correct={req.is_correct}, corrected_angle={req.corrected_angle}"
    )

    return {"session_id": session_id, "ground_truth": feedback}
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Dewarp Endpoints
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@router.post("/sessions/{session_id}/dewarp")
async def auto_dewarp(session_id: str):
    """Detect and correct vertical shear on the deskewed image.

    The dewarped image and its result summary are cached and persisted, and
    the session moves to step 3. Cached layout intermediates from an earlier
    column detection are dropped because they describe the previous dewarped
    image.

    Raises:
        HTTPException: 404 if the session is unknown, 400 if deskew has not
            produced an image yet.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    deskewed_bgr = cached.get("deskewed_bgr")
    if deskewed_bgr is None:
        raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")

    t0 = time.time()
    dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr)
    duration = time.time() - t0

    # Encode as PNG
    success, png_buf = cv2.imencode(".png", dewarped_bgr)
    dewarped_png = png_buf.tobytes() if success else b""

    dewarp_result = {
        "method_used": dewarp_info["method"],
        "shear_degrees": dewarp_info["shear_degrees"],
        "confidence": dewarp_info["confidence"],
        "duration_seconds": round(duration, 2),
    }

    # Update cache
    cached["dewarped_bgr"] = dewarped_bgr
    cached["dewarp_result"] = dewarp_result
    # Word boxes, inverted image and content bounds were computed from the old
    # dewarped image; row detection must not reuse them for the new one.
    for stale_key in ("_word_dicts", "_inv", "_content_bounds"):
        cached.pop(stale_key, None)

    # Persist to DB
    await update_session_db(
        session_id,
        dewarped_png=dewarped_png,
        dewarp_result=dewarp_result,
        auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
        current_step=3,
    )

    logger.info(f"OCR Pipeline: dewarp session {session_id}: "
                f"method={dewarp_info['method']} shear={dewarp_info['shear_degrees']:.3f} "
                f"conf={dewarp_info['confidence']:.2f} ({duration:.2f}s)")

    return {
        "session_id": session_id,
        **dewarp_result,
        "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
    }
|
||
|
||
|
||
@router.post("/sessions/{session_id}/dewarp/manual")
async def manual_dewarp(session_id: str, req: ManualDewarpRequest):
    """Apply shear correction with a user-supplied angle.

    The requested shear is clamped to [-2.0, 2.0] degrees; a magnitude below
    0.001 leaves the deskewed image unchanged. The result is cached and
    persisted. Any earlier dewarp result is kept, with ``method_used`` and
    ``shear_degrees`` overwritten. Cached layout intermediates from an earlier
    column detection are dropped because they describe the previous dewarped
    image.

    Raises:
        HTTPException: 404 if the session is unknown, 400 if deskew has not
            produced an image yet.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    deskewed_bgr = cached.get("deskewed_bgr")
    if deskewed_bgr is None:
        raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")

    shear_deg = max(-2.0, min(2.0, req.shear_degrees))

    if abs(shear_deg) < 0.001:
        # Effectively zero shear: reuse the deskewed image as-is.
        dewarped_bgr = deskewed_bgr
    else:
        dewarped_bgr = dewarp_image_manual(deskewed_bgr, shear_deg)

    success, png_buf = cv2.imencode(".png", dewarped_bgr)
    dewarped_png = png_buf.tobytes() if success else b""

    dewarp_result = {
        **(cached.get("dewarp_result") or {}),
        "method_used": "manual",
        "shear_degrees": round(shear_deg, 3),
    }

    # Update cache
    cached["dewarped_bgr"] = dewarped_bgr
    cached["dewarp_result"] = dewarp_result
    # Word boxes, inverted image and content bounds were computed from the old
    # dewarped image; row detection must not reuse them for the new one.
    for stale_key in ("_word_dicts", "_inv", "_content_bounds"):
        cached.pop(stale_key, None)

    # Persist to DB
    await update_session_db(
        session_id,
        dewarped_png=dewarped_png,
        dewarp_result=dewarp_result,
    )

    logger.info(f"OCR Pipeline: manual dewarp session {session_id}: shear={shear_deg:.3f}")

    return {
        "session_id": session_id,
        "shear_degrees": round(shear_deg, 3),
        "method_used": "manual",
        "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
    }
|
||
|
||
|
||
@router.post("/sessions/{session_id}/ground-truth/dewarp")
async def save_dewarp_ground_truth(session_id: str, req: DewarpGroundTruthRequest):
    """Store reviewer feedback for the dewarp step.

    The feedback is saved under the ``dewarp`` key of the session's ground
    truth, together with a UTC timestamp and the dewarp result currently in
    the DB. Raises ``HTTPException(404)`` for an unknown session.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    feedback = dict(
        is_correct=req.is_correct,
        corrected_shear=req.corrected_shear,
        notes=req.notes,
        saved_at=datetime.utcnow().isoformat(),
        dewarp_result=session.get("dewarp_result"),
    )

    all_feedback = session.get("ground_truth") or {}
    all_feedback["dewarp"] = feedback
    await update_session_db(session_id, ground_truth=all_feedback)

    # Mirror the change into an already-cached session.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = all_feedback

    logger.info(
        f"OCR Pipeline: ground truth dewarp session {session_id}: "
        f"correct={req.is_correct}, corrected_shear={req.corrected_shear}"
    )

    return {"session_id": session_id, "ground_truth": feedback}
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Column Detection Endpoints (Step 3)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@router.post("/sessions/{session_id}/columns")
async def detect_columns(session_id: str):
    """Run column detection on the dewarped image.

    Geometry-based detection is tried first; its intermediates (word boxes,
    inverted image, content bounds) are cached for row detection. When it
    returns nothing, the projection-based layout analysis is used instead and
    any previously cached intermediates are discarded. Stored row and word
    results are invalidated because they depend on the columns.

    Raises:
        HTTPException: 404 if the session is unknown, 400 if dewarp has not
            produced an image yet.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    dewarped_bgr = cached.get("dewarped_bgr")
    if dewarped_bgr is None:
        raise HTTPException(status_code=400, detail="Dewarp must be completed before column detection")

    t0 = time.time()

    # Binarized image for layout analysis
    ocr_img = create_ocr_image(dewarped_bgr)
    h, w = ocr_img.shape[:2]

    # Phase A: Geometry detection (returns word_dicts + inv for reuse)
    geo_result = detect_column_geometry(ocr_img, dewarped_bgr)

    if geo_result is None:
        # No fresh intermediates exist; drop leftovers from an earlier run so
        # row detection cannot pick up data from a different image.
        for stale_key in ("_word_dicts", "_inv", "_content_bounds"):
            cached.pop(stale_key, None)

        # Fallback to projection-based layout
        layout_img = create_layout_image(dewarped_bgr)
        regions = analyze_layout(layout_img, ocr_img)
    else:
        geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
        content_w = right_x - left_x

        # Cache intermediates for row detection (avoids second Tesseract run)
        cached["_word_dicts"] = word_dicts
        cached["_inv"] = inv
        cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)

        # Phase B: Content-based classification
        regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y,
                                        left_x=left_x, right_x=right_x, inv=inv)

    duration = time.time() - t0

    columns = [asdict(r) for r in regions]

    # Distinct classification methods used, sorted so the response is stable.
    methods = sorted({
        c["classification_method"] for c in columns
        if c.get("classification_method")
    })

    column_result = {
        "columns": columns,
        "classification_methods": methods,
        "duration_seconds": round(duration, 2),
    }

    # Persist to DB — also invalidate downstream results (rows, words)
    await update_session_db(
        session_id,
        column_result=column_result,
        row_result=None,
        word_result=None,
        current_step=3,
    )

    # Update cache
    cached["column_result"] = column_result
    cached.pop("row_result", None)
    cached.pop("word_result", None)

    col_count = sum(1 for c in columns if c["type"].startswith("column"))
    logger.info(f"OCR Pipeline: columns session {session_id}: "
                f"{col_count} columns detected ({duration:.2f}s)")

    return {
        "session_id": session_id,
        **column_result,
    }
|
||
|
||
|
||
@router.post("/sessions/{session_id}/columns/manual")
async def set_manual_columns(session_id: str, req: ManualColumnsRequest):
    """Override detected columns with manual definitions.

    The supplied columns are stored verbatim as the column result, and the
    stored row and word results are invalidated because they depend on the
    columns.

    Raises:
        HTTPException: 404 if the DB update reports no such session.
    """
    column_result = {
        "columns": req.columns,
        "duration_seconds": 0,
        "method": "manual",
    }

    updated = await update_session_db(session_id, column_result=column_result,
                                      row_result=None, word_result=None)
    # Mirror the rename endpoint: a falsy update result means the session does
    # not exist, so do not report success for it.
    if not updated:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    if session_id in _cache:
        _cache[session_id]["column_result"] = column_result
        _cache[session_id].pop("row_result", None)
        _cache[session_id].pop("word_result", None)

    logger.info(f"OCR Pipeline: manual columns session {session_id}: "
                f"{len(req.columns)} columns set")

    return {"session_id": session_id, **column_result}
|
||
|
||
|
||
@router.post("/sessions/{session_id}/ground-truth/columns")
async def save_column_ground_truth(session_id: str, req: ColumnGroundTruthRequest):
    """Store reviewer feedback for the column detection step.

    The feedback is saved under the ``columns`` key of the session's ground
    truth, together with a UTC timestamp and the column result currently in
    the DB. Raises ``HTTPException(404)`` for an unknown session.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    feedback = dict(
        is_correct=req.is_correct,
        corrected_columns=req.corrected_columns,
        notes=req.notes,
        saved_at=datetime.utcnow().isoformat(),
        column_result=session.get("column_result"),
    )

    all_feedback = session.get("ground_truth") or {}
    all_feedback["columns"] = feedback
    await update_session_db(session_id, ground_truth=all_feedback)

    # Mirror the change into an already-cached session.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = all_feedback

    return {"session_id": session_id, "ground_truth": feedback}
|
||
|
||
|
||
@router.get("/sessions/{session_id}/ground-truth/columns")
async def get_column_ground_truth(session_id: str):
    """Return the saved column ground truth next to the stored column result.

    Raises ``HTTPException(404)`` when the session does not exist or has no
    column ground truth.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    saved_feedback = (session.get("ground_truth") or {}).get("columns")
    if not saved_feedback:
        raise HTTPException(status_code=404, detail="No column ground truth saved")

    return dict(
        session_id=session_id,
        columns_gt=saved_feedback,
        columns_auto=session.get("column_result"),
    )
|
||
|
||
|
||
async def _get_columns_overlay(session_id: str) -> Response:
    """Render the dewarped image with the stored column regions drawn on it.

    Each region gets a solid border and a label (type, plus confidence when
    below 100%), and a fill that is blended in at 20% opacity.

    Raises:
        HTTPException: 404 if the session, its column data or its dewarped
            image is missing; 500 if the image cannot be decoded or encoded.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    column_result = session.get("column_result")
    if not (column_result and column_result.get("columns")):
        raise HTTPException(status_code=404, detail="No column data available")

    dewarped_png = await get_session_image(session_id, "dewarped")
    if not dewarped_png:
        raise HTTPException(status_code=404, detail="Dewarped image not available")

    img = cv2.imdecode(np.frombuffer(dewarped_png, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    # Region type -> BGR colour.
    palette = {
        "column_en": (255, 180, 0),        # blue
        "column_de": (0, 200, 0),          # green
        "column_example": (0, 140, 255),   # orange
        "column_text": (200, 200, 0),      # cyan/turquoise
        "page_ref": (200, 0, 200),         # purple
        "column_marker": (0, 0, 220),      # red
        "column_ignore": (180, 180, 180),  # light gray
        "header": (128, 128, 128),         # gray
        "footer": (128, 128, 128),         # gray
    }
    fallback_color = (200, 200, 200)

    # Fills go on a separate layer so they can be blended in afterwards.
    fill_layer = img.copy()
    for region in column_result["columns"]:
        left, top = region["x"], region["y"]
        right = left + region["width"]
        bottom = top + region["height"]
        color = palette.get(region.get("type", ""), fallback_color)

        cv2.rectangle(fill_layer, (left, top), (right, bottom), color, -1)
        cv2.rectangle(img, (left, top), (right, bottom), color, 3)

        text = region.get("type", "unknown").replace("column_", "").upper()
        confidence = region.get("classification_confidence")
        if confidence is not None and confidence < 1.0:
            text = f"{text} {int(confidence * 100)}%"
        cv2.putText(img, text, (left + 10, top + 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2)

    # Blend the fill layer at 20% opacity, writing the result back into img.
    cv2.addWeighted(fill_layer, 0.2, img, 0.8, 0, img)

    encoded_ok, overlay_png = cv2.imencode(".png", img)
    if not encoded_ok:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")

    return Response(content=overlay_png.tobytes(), media_type="image/png")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Row Detection Endpoints
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@router.post("/sessions/{session_id}/rows")
async def detect_rows(session_id: str):
    """Detect rows on the dewarped page and store them as the step-4 result.

    The ``_word_dicts``, ``_inv`` and ``_content_bounds`` intermediates are
    taken from the session cache when all three are present; otherwise column
    geometry detection is re-run and its outputs are cached. The row result is
    written to the DB and the cache, and any stored word result is dropped
    because it was built on the previous rows.

    Returns:
        A dict with ``session_id`` plus the row result (``rows``, ``summary``,
        ``total_rows``, ``duration_seconds``).

    Raises:
        HTTPException: 400 when no dewarped image is cached, or when column
            geometry detection returns ``None``.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    page_bgr = cached.get("dewarped_bgr")
    if page_bgr is None:
        raise HTTPException(status_code=400, detail="Dewarp must be completed before row detection")

    started = time.time()

    # Intermediates that column detection may already have left in the cache.
    word_dicts = cached.get("_word_dicts")
    inv = cached.get("_inv")
    bounds = cached.get("_content_bounds")

    # Identity checks only: `inv` may be an array, so never compare with ==.
    missing_intermediate = any(item is None for item in (word_dicts, inv, bounds))
    if not missing_intermediate:
        left_x, right_x, top_y, bottom_y = bounds
    else:
        # Rebuild the intermediates by running column geometry detection.
        geometry = detect_column_geometry(create_ocr_image(page_bgr), page_bgr)
        if geometry is None:
            raise HTTPException(status_code=400, detail="Column geometry detection failed — cannot detect rows")
        _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geometry
        cached["_word_dicts"] = word_dicts
        cached["_inv"] = inv
        cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)

    rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
    duration = time.time() - started

    # Serialize geometry only; the per-row word lists are left out to keep
    # the payload small.
    exported_fields = (
        "index", "x", "y", "width", "height",
        "word_count", "row_type", "gap_before",
    )
    rows_data = [
        {name: getattr(row, name) for name in exported_fields}
        for row in rows
    ]

    # Tally rows per row_type, keyed in order of first appearance.
    type_counts: Dict[str, int] = {}
    for row in rows:
        type_counts.setdefault(row.row_type, 0)
        type_counts[row.row_type] += 1

    row_result = {
        "rows": rows_data,
        "summary": type_counts,
        "total_rows": len(rows),
        "duration_seconds": round(duration, 2),
    }

    # New rows invalidate whatever word grid was built on the old ones.
    await update_session_db(
        session_id,
        row_result=row_result,
        word_result=None,
        current_step=4,
    )
    cached["row_result"] = row_result
    cached.pop("word_result", None)

    logger.info(f"OCR Pipeline: rows session {session_id}: "
                f"{len(rows)} rows detected ({duration:.2f}s): {type_counts}")

    return {"session_id": session_id, **row_result}
|
||
|
||
|
||
@router.post("/sessions/{session_id}/rows/manual")
async def set_manual_rows(session_id: str, req: ManualRowsRequest):
    """Replace the session's row result with rows supplied by the caller.

    The rows from the request are stored as-is with ``method`` set to
    ``"manual"``. Any stored word result is cleared in the DB and, when the
    session is cached, in the cache as well.

    Returns:
        A dict with ``session_id`` plus the stored row result.
    """
    row_result = {
        "rows": req.rows,
        "total_rows": len(req.rows),
        "duration_seconds": 0,
        "method": "manual",
    }

    await update_session_db(session_id, row_result=row_result, word_result=None)

    # Mirror the change into the in-memory cache only if the session is loaded.
    cache_entry = _cache[session_id] if session_id in _cache else None
    if cache_entry is not None:
        cache_entry["row_result"] = row_result
        cache_entry.pop("word_result", None)

    logger.info(f"OCR Pipeline: manual rows session {session_id}: "
                f"{len(req.rows)} rows set")

    response = {"session_id": session_id}
    response.update(row_result)
    return response
|
||
|
||
|
||
@router.post("/sessions/{session_id}/ground-truth/rows")
async def save_row_ground_truth(session_id: str, req: RowGroundTruthRequest):
    """Store reviewer feedback for the row detection step.

    The feedback is written under the ``rows`` key of the session's
    ``ground_truth`` dict, together with a snapshot of the current
    ``row_result`` and a ``datetime.utcnow()`` ISO timestamp.

    Returns:
        A dict with ``session_id`` and the saved ``ground_truth`` entry.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    gt = {
        "is_correct": req.is_correct,
        "corrected_rows": req.corrected_rows,
        "notes": req.notes,
        "saved_at": datetime.utcnow().isoformat(),
        # Snapshot of the automatic result the feedback refers to.
        "row_result": session.get("row_result"),
    }

    # Keep ground truth saved for other steps; only the "rows" slot changes.
    ground_truth = session.get("ground_truth") or {}
    ground_truth["rows"] = gt

    await update_session_db(session_id, ground_truth=ground_truth)

    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    return {"session_id": session_id, "ground_truth": gt}
|
||
|
||
|
||
@router.get("/sessions/{session_id}/ground-truth/rows")
async def get_row_ground_truth(session_id: str):
    """Return the saved row ground truth next to the automatic row result.

    Raises:
        HTTPException: 404 when the session does not exist or has no row
            ground truth saved.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    saved = (session.get("ground_truth") or {}).get("rows")
    if not saved:
        raise HTTPException(status_code=404, detail="No row ground truth saved")

    response = {"session_id": session_id}
    response["rows_gt"] = saved
    response["rows_auto"] = session.get("row_result")
    return response
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Word Recognition Endpoints (Step 5)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@router.post("/sessions/{session_id}/words")
async def detect_words(
    session_id: str,
    request: Request,
    engine: str = "auto",
    pronunciation: str = "british",
    stream: bool = False,
):
    """Build the columns × rows cell grid and OCR every cell (step 5).

    Query params:
        engine: 'auto' (default), 'tesseract', or 'rapid'
        pronunciation: 'british' (default) or 'american' — for IPA dictionary lookup
        stream: false (default) for JSON response, true for SSE streaming

    Returns:
        With ``stream`` set, a ``StreamingResponse`` driven by
        ``_word_stream_generator``; otherwise a dict with ``session_id`` plus
        the word result, which is also persisted to the DB and the cache.

    Raises:
        HTTPException: 400 when the dewarped image, the column result or the
            row result is missing; 404 when the session does not exist.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    dewarped_bgr = cached.get("dewarped_bgr")
    if dewarped_bgr is None:
        raise HTTPException(status_code=400, detail="Dewarp must be completed before word detection")

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    column_result = session.get("column_result")
    row_result = session.get("row_result")
    if not (column_result and column_result.get("columns")):
        raise HTTPException(status_code=400, detail="Column detection must be completed first")
    if not (row_result and row_result.get("rows")):
        raise HTTPException(status_code=400, detail="Row detection must be completed first")

    # Rebuild PageRegion objects from the stored column dicts.
    col_regions = []
    for col in column_result["columns"]:
        col_regions.append(
            PageRegion(
                type=col["type"],
                x=col["x"], y=col["y"],
                width=col["width"], height=col["height"],
                classification_confidence=col.get("classification_confidence", 1.0),
                classification_method=col.get("classification_method", ""),
            )
        )

    # Rebuild RowGeometry objects from the stored row dicts; the word lists
    # are not stored and get re-attached below.
    row_geoms = []
    for stored in row_result["rows"]:
        row_geoms.append(
            RowGeometry(
                index=stored["index"],
                x=stored["x"], y=stored["y"],
                width=stored["width"], height=stored["height"],
                word_count=stored.get("word_count", 0),
                words=[],
                row_type=stored.get("row_type", "content"),
                gap_before=stored.get("gap_before", 0),
            )
        )

    # Re-populate row.words from cached full-page Tesseract words.
    # Word-lookup in _ocr_single_cell needs these to avoid re-running OCR.
    word_dicts = cached.get("_word_dicts")
    if word_dicts is None:
        geometry = detect_column_geometry(create_ocr_image(dewarped_bgr), dewarped_bgr)
        if geometry is not None:
            _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geometry
            cached["_word_dicts"] = word_dicts
            cached["_inv"] = inv
            cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)

    if word_dicts:
        # words['top'] is relative to content-ROI top_y.
        # row.y is absolute. Convert: row_y_rel = row.y - top_y.
        bounds = cached.get("_content_bounds")
        if bounds:
            _lx, _rx, origin_y, _by = bounds
        elif row_geoms:
            origin_y = min(r.y for r in row_geoms)
        else:
            origin_y = 0

        for row in row_geoms:
            rel_top = row.y - origin_y
            rel_bottom = rel_top + row.height
            in_row = []
            for word in word_dicts:
                # A word belongs to the row containing its vertical center.
                center_y = word['top'] + word['height'] / 2
                if rel_top <= center_y < rel_bottom:
                    in_row.append(word)
            row.words = in_row
            row.word_count = len(in_row)

    if stream:
        sse_headers = {
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        }
        generator = _word_stream_generator(
            session_id, cached, col_regions, row_geoms,
            dewarped_bgr, engine, pronunciation, request,
        )
        return StreamingResponse(
            generator,
            media_type="text/event-stream",
            headers=sse_headers,
        )

    # --- Non-streaming path ---
    started = time.time()

    # Create binarized OCR image (for Tesseract)
    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    # Build generic cell grid
    cells, columns_meta = build_cell_grid(
        ocr_img, col_regions, row_geoms, img_w, img_h,
        ocr_engine=engine, img_bgr=dewarped_bgr,
    )
    duration = time.time() - started

    # The layout is "vocab" as soon as an EN or DE column was used.
    used_col_types = {meta['type'] for meta in columns_meta}
    is_vocab = bool(used_col_types & {'column_en', 'column_de'})

    # Count content rows and columns for grid_shape
    n_content_rows = sum(1 for r in row_geoms if r.row_type == 'content')
    n_cols = len(columns_meta)

    # Determine which engine was actually used
    if cells:
        used_engine = cells[0].get("ocr_engine", "tesseract")
    else:
        used_engine = engine

    summary = {
        "total_cells": len(cells),
        "non_empty_cells": len([c for c in cells if c.get("text")]),
        "low_confidence": len([c for c in cells if 0 < c.get("confidence", 0) < 50]),
    }

    # Grid result (always generic)
    word_result = {
        "cells": cells,
        "grid_shape": {
            "rows": n_content_rows,
            "cols": n_cols,
            "total_cells": len(cells),
        },
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": summary,
    }

    # For vocab layout: map cells 1:1 to vocab entries (row→entry).
    # No content shuffling — each cell stays at its detected position.
    if is_vocab:
        entries = _cells_to_vocab_entries(cells, columns_meta)
        entries = _fix_character_confusion(entries)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        summary["total_entries"] = len(entries)
        summary["with_english"] = len([e for e in entries if e.get("english")])
        summary["with_german"] = len([e for e in entries if e.get("german")])

    # Persist to DB
    await update_session_db(
        session_id,
        word_result=word_result,
        current_step=5,
    )
    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline: words session {session_id}: "
                f"layout={word_result['layout']}, "
                f"{len(cells)} cells ({duration:.2f}s), summary: {word_result['summary']}")

    return {"session_id": session_id, **word_result}
|
||
|
||
|
||
async def _word_stream_generator(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    request: Request,
):
    """SSE generator that yields cell-by-cell OCR progress.

    Event sequence (each sent as one ``data: <json>`` SSE frame):
      1. ``meta``     — estimated grid shape and layout, before any OCR.
      2. ``columns``  — the columns actually used, sent once with the first cell.
      3. ``cell``     — one per OCR'd cell, with a progress counter.
      4. ``complete`` — summary (and vocab entries for vocab layouts), sent
         after the word result has been persisted.
      5. ``error``    — sent instead of ``complete`` when processing fails.

    Args:
        session_id: Session whose word result is being built.
        cached: The session's cache entry; receives ``word_result`` at the end.
        col_regions: Column regions to build the grid from.
        row_geoms: Row geometries to build the grid from.
        dewarped_bgr: Dewarped page image.
        engine: OCR engine name passed through to the grid builder.
        pronunciation: Passed to the phonetic bracket fix for vocab layouts.
        request: Incoming request, polled to stop early on client disconnect.

    If the client disconnects mid-stream the generator stops without
    persisting anything.
    """
    try:
        t0 = time.time()

        ocr_img = create_ocr_image(dewarped_bgr)
        img_h, img_w = dewarped_bgr.shape[:2]

        # Compute grid shape upfront for the meta event
        n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
        _skip_types = {'column_ignore', 'header', 'footer', 'page_ref'}
        n_cols = len([c for c in col_regions if c.type not in _skip_types])

        # Determine layout
        col_types = {c.type for c in col_regions if c.type not in _skip_types}
        is_vocab = bool(col_types & {'column_en', 'column_de'})

        # Start streaming — first event: meta
        columns_meta = None  # set from the first item the grid builder yields
        total_cells = n_content_rows * n_cols

        meta_event = {
            "type": "meta",
            "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
            "layout": "vocab" if is_vocab else "generic",
        }
        yield f"data: {json.dumps(meta_event)}\n\n"

        # Stream cells one by one
        all_cells: List[Dict[str, Any]] = []

        for cell_idx, (cell, cols_meta, total) in enumerate(
            build_cell_grid_streaming(
                ocr_img, col_regions, row_geoms, img_w, img_h,
                ocr_engine=engine, img_bgr=dewarped_bgr,
            ),
            start=1,
        ):
            if await request.is_disconnected():
                logger.info(f"SSE: client disconnected during streaming for {session_id}")
                return

            if columns_meta is None:
                columns_meta = cols_meta
                # Send columns_used once, ahead of the first cell
                meta_update = {
                    "type": "columns",
                    "columns_used": cols_meta,
                }
                yield f"data: {json.dumps(meta_update)}\n\n"

            all_cells.append(cell)

            cell_event = {
                "type": "cell",
                "cell": cell,
                "progress": {"current": cell_idx, "total": total},
            }
            yield f"data: {json.dumps(cell_event)}\n\n"

        # All cells done — build final result
        duration = time.time() - t0
        if columns_meta is None:
            columns_meta = []

        used_engine = all_cells[0].get("ocr_engine", "tesseract") if all_cells else engine

        word_result = {
            "cells": all_cells,
            "grid_shape": {
                "rows": n_content_rows,
                "cols": n_cols,
                "total_cells": len(all_cells),
            },
            "columns_used": columns_meta,
            "layout": "vocab" if is_vocab else "generic",
            "image_width": img_w,
            "image_height": img_h,
            "duration_seconds": round(duration, 2),
            "ocr_engine": used_engine,
            "summary": {
                "total_cells": len(all_cells),
                "non_empty_cells": sum(1 for c in all_cells if c.get("text")),
                "low_confidence": sum(1 for c in all_cells if 0 < c.get("confidence", 0) < 50),
            },
        }

        # For vocab layout: map cells 1:1 to vocab entries (row→entry).
        # No content shuffling — each cell stays at its detected position.
        vocab_entries = None
        if is_vocab:
            entries = _cells_to_vocab_entries(all_cells, columns_meta)
            entries = _fix_character_confusion(entries)
            entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
            word_result["vocab_entries"] = entries
            word_result["entries"] = entries
            word_result["entry_count"] = len(entries)
            word_result["summary"]["total_entries"] = len(entries)
            word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
            word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
            vocab_entries = entries

        # Persist before announcing completion, so a client that closes the
        # connection on "complete" cannot interrupt the write.
        await update_session_db(
            session_id,
            word_result=word_result,
            current_step=5,
        )
        cached["word_result"] = word_result

        logger.info(f"OCR Pipeline SSE: words session {session_id}: "
                    f"layout={word_result['layout']}, "
                    f"{len(all_cells)} cells ({duration:.2f}s)")

        # Final complete event
        complete_event = {
            "type": "complete",
            "summary": word_result["summary"],
            "duration_seconds": round(duration, 2),
            "ocr_engine": used_engine,
        }
        if vocab_entries is not None:
            complete_event["vocab_entries"] = vocab_entries
        yield f"data: {json.dumps(complete_event)}\n\n"

    except Exception as e:
        # The response has already started, so an HTTP error status is no
        # longer possible. Report the failure in-band, using the same event
        # shape as the LLM review stream, instead of silently truncating.
        logger.exception(f"OCR Pipeline SSE: words failed for {session_id}: {type(e).__name__}: {e}")
        error_event = {"type": "error", "detail": f"{type(e).__name__}: {e}"}
        yield f"data: {json.dumps(error_event)}\n\n"
|
||
|
||
|
||
class WordGroundTruthRequest(BaseModel):
    """Request body for saving ground truth on the word recognition step."""

    # Reviewer verdict on the automatic word result.
    is_correct: bool
    # Optional corrected entries supplied by the reviewer.
    corrected_entries: Optional[List[Dict[str, Any]]] = None
    # Optional free-text remarks.
    notes: Optional[str] = None
|
||
|
||
|
||
@router.post("/sessions/{session_id}/ground-truth/words")
async def save_word_ground_truth(session_id: str, req: WordGroundTruthRequest):
    """Store reviewer feedback for the word recognition step.

    The feedback is written under the ``words`` key of the session's
    ``ground_truth`` dict, together with a snapshot of the current
    ``word_result`` and a ``datetime.utcnow()`` ISO timestamp.

    Returns:
        A dict with ``session_id`` and the saved ``ground_truth`` entry.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    gt = {
        "is_correct": req.is_correct,
        "corrected_entries": req.corrected_entries,
        "notes": req.notes,
        "saved_at": datetime.utcnow().isoformat(),
        # Snapshot of the automatic result the feedback refers to.
        "word_result": session.get("word_result"),
    }

    # Keep ground truth saved for other steps; only the "words" slot changes.
    ground_truth = session.get("ground_truth") or {}
    ground_truth["words"] = gt

    await update_session_db(session_id, ground_truth=ground_truth)

    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    return {"session_id": session_id, "ground_truth": gt}
|
||
|
||
|
||
@router.get("/sessions/{session_id}/ground-truth/words")
async def get_word_ground_truth(session_id: str):
    """Return the saved word ground truth next to the automatic word result.

    Raises:
        HTTPException: 404 when the session does not exist or has no word
            ground truth saved.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    saved = (session.get("ground_truth") or {}).get("words")
    if not saved:
        raise HTTPException(status_code=404, detail="No word ground truth saved")

    response = {"session_id": session_id}
    response["words_gt"] = saved
    response["words_auto"] = session.get("word_result")
    return response
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# LLM Review Endpoints (Step 6)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
@router.post("/sessions/{session_id}/llm-review")
async def run_llm_review(session_id: str, request: Request, stream: bool = False):
    """Run LLM-based correction on vocab entries from Step 5.

    Query params:
        stream: false (default) for JSON response, true for SSE streaming

    Request body (optional JSON object):
        model: name of the model to use instead of ``OLLAMA_REVIEW_MODEL``.
        A missing, unparsable or non-object body is treated as empty.

    Returns:
        With ``stream`` set, a ``StreamingResponse`` driven by
        ``_llm_review_stream_generator``; otherwise a dict with the proposed
        changes and timing. The review is stored under
        ``word_result["llm_review"]``.

    Raises:
        HTTPException: 404 when the session does not exist; 400 when there is
            no word result or no vocab entries; 502 when the LLM call fails.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found — run Step 5 first")

    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    if not entries:
        raise HTTPException(status_code=400, detail="No vocab entries found — run Step 5 first")

    # Optional model override from request body. The body is optional, so a
    # parse failure (e.g. an empty body) deliberately falls back to defaults.
    body: Any = {}
    try:
        body = await request.json()
    except Exception:
        body = {}
    # Valid JSON that is not an object (list, string, number, null) has no
    # "model" key to read; treat it like an empty body instead of crashing.
    if not isinstance(body, dict):
        body = {}
    model = body.get("model") or OLLAMA_REVIEW_MODEL

    if stream:
        return StreamingResponse(
            _llm_review_stream_generator(session_id, entries, word_result, model, request),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
        )

    # Non-streaming path
    try:
        result = await llm_review_entries(entries, model=model)
    except Exception as e:
        # logger.exception appends the traceback to the log record.
        logger.exception(f"LLM review failed for session {session_id}: {type(e).__name__}: {e}")
        raise HTTPException(status_code=502, detail=f"LLM review failed ({type(e).__name__}): {e}") from e

    # Store result inside word_result as a sub-key
    word_result["llm_review"] = {
        "changes": result["changes"],
        "model_used": result["model_used"],
        "duration_ms": result["duration_ms"],
        "entries_corrected": result["entries_corrected"],
    }
    await update_session_db(session_id, word_result=word_result, current_step=6)

    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    logger.info(f"LLM review session {session_id}: {len(result['changes'])} changes, "
                f"{result['duration_ms']}ms, model={result['model_used']}")

    return {
        "session_id": session_id,
        "changes": result["changes"],
        "model_used": result["model_used"],
        "duration_ms": result["duration_ms"],
        "total_entries": len(entries),
        "corrections_found": len(result["changes"]),
    }
|
||
|
||
|
||
async def _llm_review_stream_generator(
    session_id: str,
    entries: List[Dict],
    word_result: Dict,
    model: str,
    request: Request,
):
    """SSE generator that yields batch-by-batch LLM review progress.

    Every event produced by ``llm_review_entries_streaming`` is forwarded as
    one ``data: <json>`` SSE frame. When the ``complete`` event arrives, the
    review is stored under ``word_result["llm_review"]`` and persisted BEFORE
    the event is forwarded, so a client that closes the connection as soon as
    it sees ``complete`` cannot interrupt the write.

    Args:
        session_id: Session being reviewed.
        entries: Vocab entries handed to the LLM.
        word_result: The session's word result; receives the ``llm_review`` key.
        model: Model name passed to the review function.
        request: Incoming request, polled to stop early on client disconnect.

    On failure an ``error`` event is sent instead of raising, because the
    response has already started. If the client disconnects mid-stream the
    generator stops without persisting anything.
    """
    try:
        async for event in llm_review_entries_streaming(entries, model=model):
            if await request.is_disconnected():
                logger.info(f"SSE: client disconnected during LLM review for {session_id}")
                return

            # On complete: persist to DB first, then tell the client.
            if event.get("type") == "complete":
                word_result["llm_review"] = {
                    "changes": event["changes"],
                    "model_used": event["model_used"],
                    "duration_ms": event["duration_ms"],
                    "entries_corrected": event["entries_corrected"],
                }
                await update_session_db(session_id, word_result=word_result, current_step=6)
                if session_id in _cache:
                    _cache[session_id]["word_result"] = word_result

                logger.info(f"LLM review SSE session {session_id}: {event['corrections_found']} changes, "
                            f"{event['duration_ms']}ms, skipped={event['skipped']}, model={event['model_used']}")

            yield f"data: {json.dumps(event, ensure_ascii=False)}\n\n"

    except Exception as e:
        # logger.exception appends the traceback to the log record.
        logger.exception(f"LLM review SSE failed for {session_id}: {type(e).__name__}: {e}")
        error_event = {"type": "error", "detail": f"{type(e).__name__}: {e}"}
        yield f"data: {json.dumps(error_event)}\n\n"
|
||
|
||
|
||
@router.post("/sessions/{session_id}/llm-review/apply")
async def apply_llm_corrections(session_id: str, request: Request):
    """Write the accepted LLM corrections into the vocab entries.

    The request body carries ``accepted_indices``: positions in the stored
    ``llm_review["changes"]`` list. Each accepted change is matched to an
    entry by ``row_index`` and applied to its ``english``, ``german`` or
    ``example`` field; touched entries are flagged ``llm_corrected``.

    Returns:
        A dict with ``session_id``, ``applied_count`` (number of accepted
        changes) and ``total_changes``.

    Raises:
        HTTPException: 404 when the session does not exist; 400 when there is
            no word result or no LLM review.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    llm_review = word_result.get("llm_review")
    if not llm_review:
        raise HTTPException(status_code=400, detail="No LLM review found — run /llm-review first")

    body = await request.json()
    accepted_indices = set(body.get("accepted_indices", []))  # indices into changes[]

    changes = llm_review.get("changes", [])
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []

    # (row_index, field) -> new value. A later accepted change for the same
    # key overrides an earlier one.
    corrections = {}
    applied_count = 0
    for position, change in enumerate(changes):
        if position not in accepted_indices:
            continue
        corrections[(change["row_index"], change["field"])] = change["new"]
        applied_count += 1

    # Apply corrections to entries
    correctable_fields = ("english", "german", "example")
    for entry in entries:
        row_idx = entry.get("row_index", -1)
        for field_name in correctable_fields:
            lookup = (row_idx, field_name)
            if lookup not in corrections:
                continue
            entry[field_name] = corrections[lookup]
            entry["llm_corrected"] = True

    # Update word_result
    word_result["vocab_entries"] = entries
    word_result["entries"] = entries
    review_meta = word_result["llm_review"]
    review_meta["applied_count"] = applied_count
    review_meta["applied_at"] = datetime.utcnow().isoformat()

    await update_session_db(session_id, word_result=word_result)

    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    logger.info(f"Applied {applied_count}/{len(changes)} LLM corrections for session {session_id}")

    return {
        "session_id": session_id,
        "applied_count": applied_count,
        "total_changes": len(changes),
    }
|
||
|
||
|
||
@router.post("/sessions/{session_id}/reconstruction")
async def save_reconstruction(session_id: str, request: Request):
    """Save edited cell texts from the reconstruction step (step 7).

    Request body: ``{"cells": [{"cell_id": ..., "text": ...}, ...]}``.
    Matching cells get the new text and ``status = "edited"``. When vocab
    entries are present, their ``english`` / ``german`` / ``example`` fields
    are updated from the same edits — including edits that clear a cell to an
    empty string.

    Returns:
        A dict with ``session_id`` and ``updated`` (number of cells changed).
        With no cell updates, only ``current_step`` is advanced.

    Raises:
        HTTPException: 404 when the session does not exist; 400 when there is
            no word result.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    body = await request.json()
    cell_updates = body.get("cells", [])

    if not cell_updates:
        await update_session_db(session_id, current_step=7)
        return {"session_id": session_id, "updated": 0}

    # Build update map: cell_id -> new text
    update_map = {c["cell_id"]: c["text"] for c in cell_updates}

    # Update cells
    cells = word_result.get("cells", [])
    updated_count = 0
    for cell in cells:
        if cell["cell_id"] in update_map:
            cell["text"] = update_map[cell["cell_id"]]
            cell["status"] = "edited"
            updated_count += 1

    word_result["cells"] = cells

    # Also update vocab_entries if present
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    if entries:
        # Map cell_id pattern "R{row}_C{col}" to entry fields.
        # NOTE(review): this assumes english/german/example sit in columns
        # 0/1/2 of the grid — confirm against the cell-id scheme.
        for entry in entries:
            row_idx = entry.get("row_index", -1)
            for col_idx, field_name in enumerate(["english", "german", "example"]):
                # Try the zero-padded id first, then the unpadded variant.
                # Compare against None explicitly: an empty string is a valid
                # edit (the user cleared the cell) and must not be skipped.
                new_text = update_map.get(f"R{row_idx:02d}_C{col_idx}")
                if new_text is None:
                    new_text = update_map.get(f"R{row_idx}_C{col_idx}")
                if new_text is not None:
                    entry[field_name] = new_text

        word_result["vocab_entries"] = entries
        if "entries" in word_result:
            word_result["entries"] = entries

    await update_session_db(session_id, word_result=word_result, current_step=7)

    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    logger.info(f"Reconstruction saved for session {session_id}: {updated_count} cells updated")

    return {
        "session_id": session_id,
        "updated": updated_count,
    }
|
||
|
||
|
||
@router.post("/sessions/{session_id}/reprocess")
async def reprocess_session(session_id: str, request: Request):
    """Re-run pipeline from a specific step, clearing downstream data.

    Body: {"from_step": 5}  (1-indexed step number, integer 1..7; default 1)

    Clears downstream results:
      - from_step <= 1: deskew_result, dewarp_result, column_result, row_result, word_result
      - from_step <= 2: dewarp_result, column_result, row_result, word_result
      - from_step <= 3: column_result, row_result, word_result
      - from_step <= 4: row_result, word_result
      - from_step <= 5: word_result (cells, vocab_entries)
      - from_step == 6: word_result.llm_review only
      - from_step == 7: nothing; only current_step is reset

    Returns:
        A dict with ``session_id``, ``from_step`` and ``cleared`` (names of
        the result fields that were written).

    Raises:
        HTTPException: 404 when the session does not exist; 400 when the body
            is not a JSON object or ``from_step`` is not an integer in 1..7.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    body = await request.json()
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    from_step = body.get("from_step", 1)
    # bool is a subclass of int: without the explicit check, JSON `true`
    # would be accepted as step 1 and wipe every result.
    is_valid_step = (
        isinstance(from_step, int)
        and not isinstance(from_step, bool)
        and 1 <= from_step <= 7
    )
    if not is_valid_step:
        raise HTTPException(status_code=400, detail="from_step must be between 1 and 7")

    update_kwargs: Dict[str, Any] = {"current_step": from_step}

    # Clear downstream data based on from_step
    if from_step <= 5:
        update_kwargs["word_result"] = None
    elif from_step == 6:
        # Only clear LLM review from word_result
        word_result = session.get("word_result")
        if word_result:
            word_result.pop("llm_review", None)
            word_result.pop("llm_corrections", None)
            update_kwargs["word_result"] = word_result

    if from_step <= 4:
        update_kwargs["row_result"] = None
    if from_step <= 3:
        update_kwargs["column_result"] = None
    if from_step <= 2:
        update_kwargs["dewarp_result"] = None
    if from_step <= 1:
        update_kwargs["deskew_result"] = None

    await update_session_db(session_id, **update_kwargs)

    # Mirror the same fields (including current_step) into the cache
    if session_id in _cache:
        _cache[session_id].update(update_kwargs)

    logger.info(f"Session {session_id} reprocessing from step {from_step}")

    return {
        "session_id": session_id,
        "from_step": from_step,
        "cleared": [k for k in update_kwargs if k != "current_step"],
    }
|
||
|
||
|
||
async def _get_rows_overlay(session_id: str) -> Response:
    """Render the dewarped image with the detected row bands drawn on it.

    Each row gets a filled band (blended in at 15% opacity), a solid border
    and a text label carrying its index, row type and, when non-zero, its
    word count.

    Returns:
        A PNG ``Response``.

    Raises:
        HTTPException: 404 when the session, its row data or the dewarped
            image is missing; 500 when the image cannot be decoded or encoded.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    row_result = session.get("row_result")
    if not (row_result and row_result.get("rows")):
        raise HTTPException(status_code=404, detail="No row data available")

    # Load dewarped image
    dewarped_png = await get_session_image(session_id, "dewarped")
    if not dewarped_png:
        raise HTTPException(status_code=404, detail="Dewarped image not available")

    img = cv2.imdecode(np.frombuffer(dewarped_png, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    # BGR colors per row type; unknown types fall back to light gray.
    fallback_color = (200, 200, 200)
    gray = (128, 128, 128)
    row_colors = {
        "content": (255, 180, 0),  # Blue
        "header": gray,
        "footer": gray,
    }

    # Fills go on a copy so they can be blended in; borders and labels are
    # drawn directly on the image.
    overlay = img.copy()
    for row in row_result["rows"]:
        top_left = (row["x"], row["y"])
        bottom_right = (row["x"] + row["width"], row["y"] + row["height"])
        row_type = row.get("row_type", "content")
        color = row_colors.get(row_type, fallback_color)

        # Semi-transparent fill
        cv2.rectangle(overlay, top_left, bottom_right, color, -1)

        # Solid border
        cv2.rectangle(img, top_left, bottom_right, color, 2)

        # Label
        idx = row.get("index", 0)
        label = f"R{idx} {row_type.upper()}"
        wc = row.get("word_count", 0)
        if wc:
            label = f"{label} ({wc}w)"
        cv2.putText(img, label, (row["x"] + 5, row["y"] + 18),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

    # Blend overlay at 15% opacity
    cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img)

    success, result_png = cv2.imencode(".png", img)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")

    return Response(content=result_png.tobytes(), media_type="image/png")
|
||
|
||
|
||
async def _get_words_overlay(session_id: str) -> Response:
    """Render the session's dewarped image with the word/cell grid drawn on it.

    Two data layouts stored under the session's ``word_result`` are supported:

    * cell-based (``word_result["cells"]``): every cell with a positive-size
      ``bbox_px`` gets a border, a translucent fill coloured by its
      ``col_index``, its ``cell_id`` in the top-left corner and its
      (truncated) text along the bottom edge.
    * legacy entry-based (``word_result["entries"]``): the grid is rebuilt
      from the session's ``column_result`` / ``row_result`` and the entry
      texts are written into the matching row/column intersections.

    Args:
        session_id: Identifier passed to ``get_session_db`` and
            ``get_session_image``.

    Returns:
        A ``Response`` with media type ``image/png`` containing the annotated
        image.

    Raises:
        HTTPException: 404 if the session, the word data or the dewarped
            image is missing; 500 if the image cannot be decoded or the
            result cannot be encoded as PNG.
    """

    def _confidence_color(conf):
        """Map a confidence value to a BGR text colour (>=70, >=50, else)."""
        # A stored null confidence is treated like "no confidence" (0).
        conf = conf or 0
        if conf >= 70:
            return (0, 180, 0)
        if conf >= 50:
            return (0, 180, 220)
        return (0, 0, 220)

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=404, detail="No word data available")

    # Support both new cell-based and legacy entry-based formats
    cells = word_result.get("cells")
    if not cells and not word_result.get("entries"):
        raise HTTPException(status_code=404, detail="No word data available")

    # Load dewarped image
    dewarped_png = await get_session_image(session_id, "dewarped")
    if not dewarped_png:
        raise HTTPException(status_code=404, detail="Dewarped image not available")

    arr = np.frombuffer(dewarped_png, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    # Fills are drawn on a copy and blended back at the end so they appear
    # translucent; borders and labels go directly onto ``img``.
    overlay = img.copy()

    if cells:
        # New cell-based overlay: color by column index
        col_palette = [
            (255, 180, 0),    # Blue (BGR)
            (0, 200, 0),      # Green
            (0, 140, 255),    # Orange
            (200, 100, 200),  # Purple
            (200, 200, 0),    # Cyan
            (100, 200, 200),  # Yellow-ish
        ]

        for cell in cells:
            # ``or {}`` also covers a key that is present but null, so such a
            # cell is skipped below instead of raising AttributeError.
            bbox = cell.get("bbox_px") or {}
            cx = bbox.get("x", 0)
            cy = bbox.get("y", 0)
            cw = bbox.get("w", 0)
            ch = bbox.get("h", 0)
            if cw <= 0 or ch <= 0:
                continue

            col_idx = cell.get("col_index", 0)
            # Palette wraps around when there are more columns than colours.
            color = col_palette[col_idx % len(col_palette)]

            # Cell rectangle border
            cv2.rectangle(img, (cx, cy), (cx + cw, cy + ch), color, 1)
            # Semi-transparent fill
            cv2.rectangle(overlay, (cx, cy), (cx + cw, cy + ch), color, -1)

            # Cell-ID label (top-left corner)
            cell_id = cell.get("cell_id", "")
            cv2.putText(img, cell_id, (cx + 2, cy + 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.28, color, 1)

            # Text label (bottom of cell)
            text = cell.get("text", "")
            if text:
                text_color = _confidence_color(cell.get("confidence", 0))
                label = text.replace('\n', ' ')[:30]
                cv2.putText(img, label, (cx + 3, cy + ch - 4),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)
    else:
        # Legacy fallback: entry-based overlay (for old sessions)
        column_result = session.get("column_result")
        row_result = session.get("row_result")
        col_colors = {
            "column_en": (255, 180, 0),
            "column_de": (0, 200, 0),
            "column_example": (0, 140, 255),
        }
        # Which entry field holds the text for each column type.
        field_by_col_type = {
            "column_en": "english",
            "column_de": "german",
            "column_example": "example",
        }

        columns = []
        if column_result and column_result.get("columns"):
            columns = [c for c in column_result["columns"]
                       if c.get("type", "").startswith("column_")]

        content_rows_data = []
        if row_result and row_result.get("rows"):
            content_rows_data = [r for r in row_result["rows"]
                                 if r.get("row_type") == "content"]

        # Draw the grid: one rectangle per (column, content row) pair.
        for col in columns:
            col_type = col.get("type", "")
            color = col_colors.get(col_type, (200, 200, 200))
            cx, cw = col["x"], col["width"]
            for row in content_rows_data:
                ry, rh = row["y"], row["height"]
                cv2.rectangle(img, (cx, ry), (cx + cw, ry + rh), color, 1)
                cv2.rectangle(overlay, (cx, ry), (cx + cw, ry + rh), color, -1)

        # The guard above ensures "entries" is non-empty when cells is empty.
        entries = word_result["entries"]
        entry_by_row: Dict[int, Dict] = {}
        for entry in entries:
            entry_by_row[entry.get("row_index", -1)] = entry

        # Entries are looked up by the row's position among the content rows.
        for row_idx, row in enumerate(content_rows_data):
            entry = entry_by_row.get(row_idx)
            if not entry:
                continue
            text_color = _confidence_color(entry.get("confidence", 0))
            ry, rh = row["y"], row["height"]
            for col in columns:
                col_type = col.get("type", "")
                cx, cw = col["x"], col["width"]
                field = field_by_col_type.get(col_type, "")
                text = entry.get(field, "") if field else ""
                if text:
                    label = text.replace('\n', ' ')[:30]
                    cv2.putText(img, label, (cx + 3, ry + rh - 4),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)

    # Blend overlay at 10% opacity
    cv2.addWeighted(overlay, 0.1, img, 0.9, 0, img)

    success, result_png = cv2.imencode(".png", img)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")

    return Response(content=result_png.tobytes(), media_type="image/png")
|