Three bugs in the post-processing pipeline were overwriting correct streaming results with wrong ones: 1. _split_comma_entries was splitting "Maus, Mäuse" into two separate entries. Disabled — word forms belong together. 2. _attach_example_sentences treated "Ei" (2 chars) as OCR noise due to `len(de) > 2` threshold. Lowered to `len(de) > 1`. 3. _attach_example_sentences wrongly classified rows with EN text but no DE (like "stand ...") as example sentences, merging them into the previous entry. Now only treats rows as examples if they also have no text in the example column. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1593 lines
54 KiB
Python
1593 lines
54 KiB
Python
"""
|
||
OCR Pipeline API - Schrittweise Seitenrekonstruktion.
|
||
|
||
Zerlegt den OCR-Prozess in 8 einzelne Schritte:
|
||
1. Deskewing - Scan begradigen
|
||
2. Dewarping - Buchwoelbung entzerren
|
||
3. Spaltenerkennung - Unsichtbare Spalten finden
|
||
4. Zeilenerkennung - Horizontale Zeilen + Kopf-/Fusszeilen
|
||
5. Worterkennung - OCR mit Bounding Boxes
|
||
6. Koordinatenzuweisung - Exakte Positionen
|
||
7. Seitenrekonstruktion - Seite nachbauen
|
||
8. Ground Truth Validierung - Gesamtpruefung
|
||
|
||
Lizenz: Apache 2.0
|
||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||
"""
|
||
|
||
import json
|
||
import logging
|
||
import time
|
||
import uuid
|
||
from dataclasses import asdict
|
||
from datetime import datetime
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
import cv2
|
||
import numpy as np
|
||
from fastapi import APIRouter, File, Form, HTTPException, Request, UploadFile
|
||
from fastapi.responses import Response, StreamingResponse
|
||
from pydantic import BaseModel
|
||
|
||
from cv_vocab_pipeline import (
|
||
PageRegion,
|
||
RowGeometry,
|
||
_cells_to_vocab_entries,
|
||
_fix_character_confusion,
|
||
_fix_phonetic_brackets,
|
||
_split_comma_entries,
|
||
_attach_example_sentences,
|
||
analyze_layout,
|
||
analyze_layout_by_words,
|
||
build_cell_grid,
|
||
build_cell_grid_streaming,
|
||
build_word_grid,
|
||
classify_column_types,
|
||
create_layout_image,
|
||
create_ocr_image,
|
||
deskew_image,
|
||
deskew_image_by_word_alignment,
|
||
detect_column_geometry,
|
||
detect_row_geometry,
|
||
dewarp_image,
|
||
dewarp_image_manual,
|
||
render_image_high_res,
|
||
render_pdf_high_res,
|
||
)
|
||
from ocr_pipeline_session_store import (
|
||
create_session_db,
|
||
delete_session_db,
|
||
get_session_db,
|
||
get_session_image,
|
||
init_ocr_pipeline_tables,
|
||
list_sessions_db,
|
||
update_session_db,
|
||
)
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# In-memory cache for active sessions (BGR numpy arrays for processing)
|
||
# DB is source of truth, cache holds BGR arrays during active processing.
|
||
# ---------------------------------------------------------------------------
|
||
|
||
_cache: Dict[str, Dict[str, Any]] = {}
|
||
|
||
|
||
async def _load_session_to_cache(session_id: str) -> Dict[str, Any]:
    """Return the cache entry for a session, building it from the DB if needed.

    The session row is always fetched first, so an unknown session yields a
    404 even when a cache entry for that id still exists. For a cache miss
    the stored PNGs ("original", "deskewed", "dewarped") are decoded into
    arrays with ``cv2.IMREAD_COLOR`` and stored under ``<type>_bgr``.

    Raises:
        HTTPException: 404 when the session does not exist in the DB.
    """
    db_row = await get_session_db(session_id)
    if not db_row:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    try:
        return _cache[session_id]
    except KeyError:
        pass  # not cached yet — build the entry below

    entry: Dict[str, Any] = {"id": session_id}
    entry.update(db_row)
    for slot in ("original_bgr", "deskewed_bgr", "dewarped_bgr"):
        entry[slot] = None

    # Decode whichever images the DB already holds; missing ones stay None.
    for kind in ("original", "deskewed", "dewarped"):
        raw_png = await get_session_image(session_id, kind)
        if not raw_png:
            continue
        pixels = np.frombuffer(raw_png, dtype=np.uint8)
        entry[f"{kind}_bgr"] = cv2.imdecode(pixels, cv2.IMREAD_COLOR)

    _cache[session_id] = entry
    return entry
|
||
|
||
|
||
def _get_cached(session_id: str) -> Dict[str, Any]:
    """Return the in-memory cache entry for *session_id*.

    Raises:
        HTTPException: 404 when there is no (or an empty) cache entry.
    """
    cached_entry = _cache.get(session_id)
    if cached_entry:
        return cached_entry
    raise HTTPException(status_code=404, detail=f"Session {session_id} not in cache — reload first")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Pydantic Models
|
||
# ---------------------------------------------------------------------------
|
||
|
||
# Request body for the manual deskew endpoint.
class ManualDeskewRequest(BaseModel):
    # Rotation angle; manual_deskew clamps it to [-5.0, 5.0] before passing it
    # to cv2.getRotationMatrix2D.
    angle: float
|
||
|
||
|
||
# Request body for saving reviewer feedback on the deskew step.
# The fields are stored under ground_truth["deskew"] by save_deskew_ground_truth.
class DeskewGroundTruthRequest(BaseModel):
    # Whether the reviewer accepted the applied deskew.
    is_correct: bool
    # Optional angle supplied by the reviewer; stored as given.
    corrected_angle: Optional[float] = None
    # Optional free-text remark.
    notes: Optional[str] = None
|
||
|
||
|
||
# Request body for the manual dewarp endpoint.
class ManualDewarpRequest(BaseModel):
    # Shear angle; manual_dewarp clamps it to [-2.0, 2.0] and treats
    # magnitudes below 0.001 as "no correction".
    shear_degrees: float
|
||
|
||
|
||
# Request body for saving reviewer feedback on the dewarp step.
# The fields are stored under ground_truth["dewarp"] by save_dewarp_ground_truth.
class DewarpGroundTruthRequest(BaseModel):
    # Whether the reviewer accepted the applied dewarp.
    is_correct: bool
    # Optional shear value supplied by the reviewer; stored as given.
    corrected_shear: Optional[float] = None
    # Optional free-text remark.
    notes: Optional[str] = None
|
||
|
||
|
||
# Request body for renaming a session.
class RenameSessionRequest(BaseModel):
    # New display name; passed to update_session_db as `name`.
    name: str
|
||
|
||
|
||
# Request body for overriding detected columns with manual definitions.
class ManualColumnsRequest(BaseModel):
    # Column dicts, stored verbatim as column_result["columns"].
    # The columns overlay reads "x", "y", "width", "height" and optionally
    # "type" / "classification_confidence" from each dict.
    columns: List[Dict[str, Any]]
|
||
|
||
|
||
# Request body for saving reviewer feedback on the column detection step.
# The fields are stored under ground_truth["columns"] by save_column_ground_truth.
class ColumnGroundTruthRequest(BaseModel):
    # Whether the reviewer accepted the detected columns.
    is_correct: bool
    # Optional corrected column dicts; stored as given.
    corrected_columns: Optional[List[Dict[str, Any]]] = None
    # Optional free-text remark.
    notes: Optional[str] = None
|
||
|
||
|
||
# Request body carrying manually defined rows.
# NOTE(review): the consuming endpoint is outside this excerpt — presumably
# it mirrors ManualColumnsRequest for the row step; confirm.
class ManualRowsRequest(BaseModel):
    # Row dicts; schema is not validated here.
    rows: List[Dict[str, Any]]
|
||
|
||
|
||
# Request body for saving reviewer feedback on the row detection step.
# NOTE(review): the consuming endpoint is outside this excerpt — presumably
# it mirrors the deskew/dewarp/column ground-truth endpoints; confirm.
class RowGroundTruthRequest(BaseModel):
    # Whether the reviewer accepted the detected rows.
    is_correct: bool
    # Optional corrected row dicts.
    corrected_rows: Optional[List[Dict[str, Any]]] = None
    # Optional free-text remark.
    notes: Optional[str] = None
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Session Management Endpoints
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@router.get("/sessions")
async def list_sessions():
    """Return every OCR pipeline session known to the DB under the key "sessions"."""
    return {"sessions": await list_sessions_db()}
|
||
|
||
|
||
@router.post("/sessions")
async def create_session(
    file: UploadFile = File(...),
    name: Optional[str] = Form(None),
):
    """Create a pipeline session from an uploaded PDF or image.

    A PDF (detected by content type or ".pdf" suffix) is rendered from page
    index 0 at zoom 3.0; anything else goes through ``render_image_high_res``.
    The rendered image is stored as PNG in the DB and kept in the in-memory
    cache for the following steps.

    Raises:
        HTTPException: 400 when the upload cannot be rendered, 500 when the
            rendered image cannot be PNG-encoded.
    """
    payload = await file.read()
    upload_name = file.filename or "upload"
    mime_type = file.content_type or ""

    new_id = str(uuid.uuid4())
    looks_like_pdf = mime_type == "application/pdf" or upload_name.lower().endswith(".pdf")

    try:
        page_bgr = (
            render_pdf_high_res(payload, page_number=0, zoom=3.0)
            if looks_like_pdf
            else render_image_high_res(payload)
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Could not process file: {e}")

    encoded_ok, encoded = cv2.imencode(".png", page_bgr)
    if not encoded_ok:
        raise HTTPException(status_code=500, detail="Failed to encode image")

    display_name = name or upload_name

    # The DB is the source of truth; the cache only mirrors it for processing.
    await create_session_db(
        session_id=new_id,
        name=display_name,
        filename=upload_name,
        original_png=encoded.tobytes(),
    )

    fresh_entry: Dict[str, Any] = dict(
        id=new_id,
        filename=upload_name,
        name=display_name,
        original_bgr=page_bgr,
        deskewed_bgr=None,
        dewarped_bgr=None,
        deskew_result=None,
        dewarp_result=None,
        ground_truth={},
        current_step=1,
    )
    _cache[new_id] = fresh_entry

    height, width = page_bgr.shape[0], page_bgr.shape[1]
    logger.info(f"OCR Pipeline: created session {new_id} from {upload_name} "
                f"({width}x{height})")

    return {
        "session_id": new_id,
        "filename": upload_name,
        "name": display_name,
        "image_width": width,
        "image_height": height,
        "original_image_url": f"/api/v1/ocr-pipeline/sessions/{new_id}/image/original",
    }
|
||
|
||
|
||
@router.get("/sessions/{session_id}")
async def get_session_info(session_id: str):
    """Get session info including step results for step navigation.

    Returns the session's id, filename, name, current step and the pixel
    dimensions of the original image, plus whichever of ``deskew_result``,
    ``dewarp_result``, ``column_result``, ``row_result`` and ``word_result``
    are already stored.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # Get image dimensions from original PNG. Default to 0x0 when the image is
    # missing OR cannot be decoded (cv2.imdecode returns None in that case).
    # The previous one-liner applied the conditional only to the height, so an
    # undecodable image raised AttributeError on `img.shape[1]`.
    img_w, img_h = 0, 0
    original_png = await get_session_image(session_id, "original")
    if original_png:
        arr = np.frombuffer(original_png, dtype=np.uint8)
        img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
        if img is not None:
            img_h, img_w = img.shape[:2]

    result = {
        "session_id": session["id"],
        "filename": session.get("filename", ""),
        "name": session.get("name", ""),
        "image_width": img_w,
        "image_height": img_h,
        "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
        "current_step": session.get("current_step", 1),
    }

    # Only expose step results that have actually been produced.
    for step_key in (
        "deskew_result",
        "dewarp_result",
        "column_result",
        "row_result",
        "word_result",
    ):
        if session.get(step_key):
            result[step_key] = session[step_key]

    return result
|
||
|
||
|
||
@router.put("/sessions/{session_id}")
async def rename_session(session_id: str, req: RenameSessionRequest):
    """Store a new display name for the session.

    Raises:
        HTTPException: 404 when the DB update reports no matching session.
    """
    new_name = req.name
    if await update_session_db(session_id, name=new_name):
        return {"session_id": session_id, "name": new_name}
    raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||
|
||
|
||
@router.delete("/sessions/{session_id}")
async def delete_session(session_id: str):
    """Remove a session from the in-memory cache and from the DB.

    The cache entry is evicted before the DB call, so it is gone even when
    the DB reports that the session did not exist.

    Raises:
        HTTPException: 404 when the DB had no such session.
    """
    if session_id in _cache:
        del _cache[session_id]

    was_deleted = await delete_session_db(session_id)
    if was_deleted:
        return {"session_id": session_id, "deleted": True}
    raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Image Endpoints
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@router.get("/sessions/{session_id}/image/{image_type}")
async def get_image(session_id: str, image_type: str):
    """Serve a session image as PNG.

    Supported ``image_type`` values: original, deskewed, dewarped, binarized,
    columns-overlay, rows-overlay and words-overlay. Overlays are rendered on
    demand by their helper; the other types are read from the DB, except that
    a binarized PNG held in the in-memory cache is served directly.

    Raises:
        HTTPException: 400 for an unknown image type, 404 when the requested
            image has not been produced yet.
    """
    valid_types = {"original", "deskewed", "dewarped", "binarized", "columns-overlay", "rows-overlay", "words-overlay"}
    if image_type not in valid_types:
        raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")

    if image_type == "columns-overlay":
        return await _get_columns_overlay(session_id)

    if image_type == "rows-overlay":
        return await _get_rows_overlay(session_id)

    if image_type == "words-overlay":
        return await _get_words_overlay(session_id)

    # The binarized image is the only one kept PNG-encoded in the cache, so it
    # is the only type that can skip the DB round-trip.
    if image_type == "binarized":
        cached = _cache.get(session_id)
        if cached and cached.get("binarized_png"):
            return Response(content=cached["binarized_png"], media_type="image/png")

    # Load from DB
    data = await get_session_image(session_id, image_type)
    if not data:
        raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet")

    return Response(content=data, media_type="image/png")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Deskew Endpoints
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@router.post("/sessions/{session_id}/deskew")
async def auto_deskew(session_id: str):
    """Run both deskew methods and pick the best one.

    Hough-line deskew and word-alignment deskew are both run on the original
    image. Word alignment wins when its angle magnitude is at least the Hough
    magnitude or when the Hough angle is below 0.1; if its output cannot be
    decoded, the Hough result is used instead. The chosen image is stored
    (cache + DB) together with a binarized version, and the session advances
    to step 2.

    Raises:
        HTTPException: 404 for an unknown session, 400 when the original
            image is unavailable, 500 when the deskewed image cannot be
            PNG-encoded.
    """
    # Ensure session is in cache
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    img_bgr = cached.get("original_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Original image not available")

    t0 = time.time()

    # Method 1: Hough Lines
    try:
        deskewed_hough, angle_hough = deskew_image(img_bgr.copy())
    except Exception as e:
        logger.warning(f"Hough deskew failed: {e}")
        deskewed_hough, angle_hough = img_bgr, 0.0

    # Method 2: Word Alignment (needs image bytes)
    success_enc, png_orig = cv2.imencode(".png", img_bgr)
    orig_bytes = png_orig.tobytes() if success_enc else b""

    try:
        deskewed_wa_bytes, angle_wa = deskew_image_by_word_alignment(orig_bytes)
    except Exception as e:
        logger.warning(f"Word alignment deskew failed: {e}")
        deskewed_wa_bytes, angle_wa = orig_bytes, 0.0

    duration = time.time() - t0

    # Pick best method
    prefer_word_alignment = abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1

    deskewed_bgr = None
    if prefer_word_alignment and deskewed_wa_bytes:
        # Guard against an empty buffer: cv2.imdecode raises on empty input
        # instead of returning None, which would bypass the Hough fallback.
        wa_array = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8)
        deskewed_bgr = cv2.imdecode(wa_array, cv2.IMREAD_COLOR)

    if deskewed_bgr is not None:
        method_used = "word_alignment"
        angle_applied = angle_wa
    else:
        method_used = "hough"
        angle_applied = angle_hough
        deskewed_bgr = deskewed_hough

    # Encode as PNG. Fail loudly rather than persisting an empty image, the
    # same way create_session treats an encoding failure.
    success, deskewed_png_buf = cv2.imencode(".png", deskewed_bgr)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode image")
    deskewed_png = deskewed_png_buf.tobytes()

    # Create binarized version (best effort — a failure is only logged)
    binarized_png = None
    try:
        binarized = create_ocr_image(deskewed_bgr)
        success_bin, bin_buf = cv2.imencode(".png", binarized)
        binarized_png = bin_buf.tobytes() if success_bin else None
    except Exception as e:
        logger.warning(f"Binarization failed: {e}")

    # Heuristic: confidence falls linearly with the applied angle, floor 0.5.
    confidence = max(0.5, 1.0 - abs(angle_applied) / 5.0)

    deskew_result = {
        "angle_hough": round(angle_hough, 3),
        "angle_word_alignment": round(angle_wa, 3),
        "angle_applied": round(angle_applied, 3),
        "method_used": method_used,
        "confidence": round(confidence, 2),
        "duration_seconds": round(duration, 2),
    }

    # Update cache
    cached["deskewed_bgr"] = deskewed_bgr
    cached["binarized_png"] = binarized_png
    cached["deskew_result"] = deskew_result

    # Persist to DB
    db_update = {
        "deskewed_png": deskewed_png,
        "deskew_result": deskew_result,
        "current_step": 2,
    }
    if binarized_png:
        db_update["binarized_png"] = binarized_png
    await update_session_db(session_id, **db_update)

    logger.info(f"OCR Pipeline: deskew session {session_id}: "
                f"hough={angle_hough:.2f} wa={angle_wa:.2f} -> {method_used} {angle_applied:.2f}")

    return {
        "session_id": session_id,
        **deskew_result,
        "deskewed_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed",
        "binarized_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/binarized",
    }
|
||
|
||
|
||
@router.post("/sessions/{session_id}/deskew/manual")
async def manual_deskew(session_id: str, req: ManualDeskewRequest):
    """Apply a manual rotation angle to the original image.

    The requested angle is clamped to [-5.0, 5.0]. The image is rotated about
    its centre with replicated borders, then stored (cache + DB) together
    with a binarized version. Any earlier ``deskew_result`` fields are kept;
    only ``angle_applied`` and ``method_used`` are overwritten.

    Raises:
        HTTPException: 404 for an unknown session, 400 when the original
            image is unavailable, 500 when the rotated image cannot be
            PNG-encoded.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    img_bgr = cached.get("original_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Original image not available")

    angle = max(-5.0, min(5.0, req.angle))

    h, w = img_bgr.shape[:2]
    center = (w // 2, h // 2)
    M = cv2.getRotationMatrix2D(center, angle, 1.0)
    rotated = cv2.warpAffine(img_bgr, M, (w, h),
                             flags=cv2.INTER_LINEAR,
                             borderMode=cv2.BORDER_REPLICATE)

    # Fail loudly rather than persisting an empty image, the same way
    # create_session treats an encoding failure.
    success, png_buf = cv2.imencode(".png", rotated)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode image")
    deskewed_png = png_buf.tobytes()

    # Binarize (best effort). The failure is logged, as auto_deskew does,
    # instead of being silently swallowed.
    binarized_png = None
    try:
        binarized = create_ocr_image(rotated)
        success_bin, bin_buf = cv2.imencode(".png", binarized)
        binarized_png = bin_buf.tobytes() if success_bin else None
    except Exception as e:
        logger.warning(f"Binarization failed: {e}")

    deskew_result = {
        **(cached.get("deskew_result") or {}),
        "angle_applied": round(angle, 3),
        "method_used": "manual",
    }

    # Update cache
    cached["deskewed_bgr"] = rotated
    cached["binarized_png"] = binarized_png
    cached["deskew_result"] = deskew_result

    # Persist to DB
    db_update = {
        "deskewed_png": deskewed_png,
        "deskew_result": deskew_result,
    }
    if binarized_png:
        db_update["binarized_png"] = binarized_png
    await update_session_db(session_id, **db_update)

    logger.info(f"OCR Pipeline: manual deskew session {session_id}: {angle:.2f}")

    return {
        "session_id": session_id,
        "angle_applied": round(angle, 3),
        "method_used": "manual",
        "deskewed_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed",
    }
|
||
|
||
|
||
@router.post("/sessions/{session_id}/ground-truth/deskew")
async def save_deskew_ground_truth(session_id: str, req: DeskewGroundTruthRequest):
    """Record reviewer feedback for the deskew step.

    The feedback, a UTC timestamp and a snapshot of the current
    ``deskew_result`` are written to ``ground_truth["deskew"]`` in the DB and
    mirrored into the cache entry when one exists.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    record = await get_session_db(session_id)
    if not record:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    feedback = dict(
        is_correct=req.is_correct,
        corrected_angle=req.corrected_angle,
        notes=req.notes,
        saved_at=datetime.utcnow().isoformat(),
        deskew_result=record.get("deskew_result"),
    )

    all_feedback = record.get("ground_truth") or {}
    all_feedback["deskew"] = feedback
    await update_session_db(session_id, ground_truth=all_feedback)

    # Keep an active cache entry in step with the DB.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = all_feedback

    logger.info(f"OCR Pipeline: ground truth deskew session {session_id}: "
                f"correct={req.is_correct}, corrected_angle={req.corrected_angle}")

    return {"session_id": session_id, "ground_truth": feedback}
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Dewarp Endpoints
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@router.post("/sessions/{session_id}/dewarp")
async def auto_dewarp(session_id: str):
    """Detect and correct vertical shear on the deskewed image.

    Runs ``dewarp_image`` on the deskewed image, stores the result (cache +
    DB) and advances the session to step 3. OCR intermediates cached by
    column/row detection are dropped because they were computed from the
    previous dewarped image.

    Raises:
        HTTPException: 404 for an unknown session, 400 when deskew has not
            been run, 500 when the dewarped image cannot be PNG-encoded.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    deskewed_bgr = cached.get("deskewed_bgr")
    if deskewed_bgr is None:
        raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")

    t0 = time.time()
    dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr)
    duration = time.time() - t0

    # Encode as PNG. Fail loudly rather than persisting an empty image, the
    # same way create_session treats an encoding failure.
    success, png_buf = cv2.imencode(".png", dewarped_bgr)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode image")
    dewarped_png = png_buf.tobytes()

    dewarp_result = {
        "method_used": dewarp_info["method"],
        "shear_degrees": dewarp_info["shear_degrees"],
        "confidence": dewarp_info["confidence"],
        "duration_seconds": round(duration, 2),
    }

    # Update cache
    cached["dewarped_bgr"] = dewarped_bgr
    cached["dewarp_result"] = dewarp_result
    # The geometry intermediates belong to the old dewarped image; drop them
    # so row detection recomputes them instead of reusing stale positions.
    for stale_key in ("_word_dicts", "_inv", "_content_bounds"):
        cached.pop(stale_key, None)

    # Persist to DB
    await update_session_db(
        session_id,
        dewarped_png=dewarped_png,
        dewarp_result=dewarp_result,
        auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
        current_step=3,
    )

    logger.info(f"OCR Pipeline: dewarp session {session_id}: "
                f"method={dewarp_info['method']} shear={dewarp_info['shear_degrees']:.3f} "
                f"conf={dewarp_info['confidence']:.2f} ({duration:.2f}s)")

    return {
        "session_id": session_id,
        **dewarp_result,
        "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
    }
|
||
|
||
|
||
@router.post("/sessions/{session_id}/dewarp/manual")
async def manual_dewarp(session_id: str, req: ManualDewarpRequest):
    """Apply shear correction with a manual angle.

    The requested shear is clamped to [-2.0, 2.0]; a magnitude below 0.001
    leaves the deskewed image unchanged. Any earlier ``dewarp_result`` fields
    are kept; only ``method_used`` and ``shear_degrees`` are overwritten. OCR
    intermediates cached by column/row detection are dropped because they
    were computed from the previous dewarped image.

    Raises:
        HTTPException: 404 for an unknown session, 400 when deskew has not
            been run, 500 when the dewarped image cannot be PNG-encoded.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    deskewed_bgr = cached.get("deskewed_bgr")
    if deskewed_bgr is None:
        raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")

    shear_deg = max(-2.0, min(2.0, req.shear_degrees))

    if abs(shear_deg) < 0.001:
        dewarped_bgr = deskewed_bgr
    else:
        dewarped_bgr = dewarp_image_manual(deskewed_bgr, shear_deg)

    # Fail loudly rather than persisting an empty image, the same way
    # create_session treats an encoding failure.
    success, png_buf = cv2.imencode(".png", dewarped_bgr)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode image")
    dewarped_png = png_buf.tobytes()

    dewarp_result = {
        **(cached.get("dewarp_result") or {}),
        "method_used": "manual",
        "shear_degrees": round(shear_deg, 3),
    }

    # Update cache
    cached["dewarped_bgr"] = dewarped_bgr
    cached["dewarp_result"] = dewarp_result
    # The geometry intermediates belong to the old dewarped image; drop them
    # so row detection recomputes them instead of reusing stale positions.
    for stale_key in ("_word_dicts", "_inv", "_content_bounds"):
        cached.pop(stale_key, None)

    # Persist to DB
    await update_session_db(
        session_id,
        dewarped_png=dewarped_png,
        dewarp_result=dewarp_result,
    )

    logger.info(f"OCR Pipeline: manual dewarp session {session_id}: shear={shear_deg:.3f}")

    return {
        "session_id": session_id,
        "shear_degrees": round(shear_deg, 3),
        "method_used": "manual",
        "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
    }
|
||
|
||
|
||
@router.post("/sessions/{session_id}/ground-truth/dewarp")
async def save_dewarp_ground_truth(session_id: str, req: DewarpGroundTruthRequest):
    """Record reviewer feedback for the dewarp step.

    The feedback, a UTC timestamp and a snapshot of the current
    ``dewarp_result`` are written to ``ground_truth["dewarp"]`` in the DB and
    mirrored into the cache entry when one exists.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    record = await get_session_db(session_id)
    if not record:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    feedback = dict(
        is_correct=req.is_correct,
        corrected_shear=req.corrected_shear,
        notes=req.notes,
        saved_at=datetime.utcnow().isoformat(),
        dewarp_result=record.get("dewarp_result"),
    )

    all_feedback = record.get("ground_truth") or {}
    all_feedback["dewarp"] = feedback
    await update_session_db(session_id, ground_truth=all_feedback)

    # Keep an active cache entry in step with the DB.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = all_feedback

    logger.info(f"OCR Pipeline: ground truth dewarp session {session_id}: "
                f"correct={req.is_correct}, corrected_shear={req.corrected_shear}")

    return {"session_id": session_id, "ground_truth": feedback}
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Column Detection Endpoints (Step 3)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@router.post("/sessions/{session_id}/columns")
async def detect_columns(session_id: str):
    """Run column detection on the dewarped image.

    Geometry detection is tried first; its word dicts, inverted image and
    content bounds are cached for row detection. When geometry detection
    returns None the projection-based ``analyze_layout`` is used instead and
    any previously cached intermediates are discarded. Stored row and word
    results are invalidated because they depend on the columns.

    Raises:
        HTTPException: 404 for an unknown session, 400 when dewarp has not
            been run.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    dewarped_bgr = cached.get("dewarped_bgr")
    if dewarped_bgr is None:
        raise HTTPException(status_code=400, detail="Dewarp must be completed before column detection")

    t0 = time.time()

    # Binarized image for layout analysis
    ocr_img = create_ocr_image(dewarped_bgr)
    h, w = ocr_img.shape[:2]

    # Phase A: Geometry detection (returns word_dicts + inv for reuse)
    geo_result = detect_column_geometry(ocr_img, dewarped_bgr)

    if geo_result is None:
        # Intermediates left over from an earlier run no longer describe this
        # image; drop them so row detection cannot reuse stale data.
        for stale_key in ("_word_dicts", "_inv", "_content_bounds"):
            cached.pop(stale_key, None)

        # Fallback to projection-based layout
        layout_img = create_layout_image(dewarped_bgr)
        regions = analyze_layout(layout_img, ocr_img)
    else:
        geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
        content_w = right_x - left_x

        # Cache intermediates for row detection (avoids second Tesseract run)
        cached["_word_dicts"] = word_dicts
        cached["_inv"] = inv
        cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)

        # Phase B: Content-based classification
        regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y)

    duration = time.time() - t0

    columns = [asdict(r) for r in regions]

    # Distinct, non-empty classification methods used across all regions
    methods = list({
        c["classification_method"] for c in columns
        if c.get("classification_method")
    })

    column_result = {
        "columns": columns,
        "classification_methods": methods,
        "duration_seconds": round(duration, 2),
    }

    # Persist to DB — also invalidate downstream results (rows, words)
    await update_session_db(
        session_id,
        column_result=column_result,
        row_result=None,
        word_result=None,
        current_step=3,
    )

    # Update cache
    cached["column_result"] = column_result
    cached.pop("row_result", None)
    cached.pop("word_result", None)

    col_count = sum(1 for c in columns if c["type"].startswith("column"))
    logger.info(f"OCR Pipeline: columns session {session_id}: "
                f"{col_count} columns detected ({duration:.2f}s)")

    return {
        "session_id": session_id,
        **column_result,
    }
|
||
|
||
|
||
@router.post("/sessions/{session_id}/columns/manual")
async def set_manual_columns(session_id: str, req: ManualColumnsRequest):
    """Override detected columns with manual definitions.

    The supplied columns are stored verbatim as the session's column result
    and the downstream row and word results are invalidated.

    Raises:
        HTTPException: 404 when the DB update reports no matching session.
    """
    column_result = {
        "columns": req.columns,
        "duration_seconds": 0,
        "method": "manual",
    }

    # Report an unknown session instead of pretending the override succeeded;
    # rename_session relies on the same return value for its 404.
    updated = await update_session_db(session_id, column_result=column_result,
                                      row_result=None, word_result=None)
    if not updated:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    if session_id in _cache:
        _cache[session_id]["column_result"] = column_result
        _cache[session_id].pop("row_result", None)
        _cache[session_id].pop("word_result", None)

    logger.info(f"OCR Pipeline: manual columns session {session_id}: "
                f"{len(req.columns)} columns set")

    return {"session_id": session_id, **column_result}
|
||
|
||
|
||
@router.post("/sessions/{session_id}/ground-truth/columns")
async def save_column_ground_truth(session_id: str, req: ColumnGroundTruthRequest):
    """Record reviewer feedback for the column detection step.

    The feedback, a UTC timestamp and a snapshot of the current
    ``column_result`` are written to ``ground_truth["columns"]`` in the DB
    and mirrored into the cache entry when one exists.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    record = await get_session_db(session_id)
    if not record:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    feedback = dict(
        is_correct=req.is_correct,
        corrected_columns=req.corrected_columns,
        notes=req.notes,
        saved_at=datetime.utcnow().isoformat(),
        column_result=record.get("column_result"),
    )

    all_feedback = record.get("ground_truth") or {}
    all_feedback["columns"] = feedback
    await update_session_db(session_id, ground_truth=all_feedback)

    # Keep an active cache entry in step with the DB.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = all_feedback

    return {"session_id": session_id, "ground_truth": feedback}
|
||
|
||
|
||
@router.get("/sessions/{session_id}/ground-truth/columns")
async def get_column_ground_truth(session_id: str):
    """Return the saved column ground truth next to the automatic result.

    Raises:
        HTTPException: 404 when the session does not exist or no column
            ground truth has been saved for it.
    """
    record = await get_session_db(session_id)
    if not record:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    saved_gt = (record.get("ground_truth") or {}).get("columns")
    if saved_gt:
        return {
            "session_id": session_id,
            "columns_gt": saved_gt,
            "columns_auto": record.get("column_result"),
        }
    raise HTTPException(status_code=404, detail="No column ground truth saved")
|
||
|
||
|
||
async def _get_columns_overlay(session_id: str) -> Response:
    """Render the dewarped image with the stored column regions drawn on it.

    Each region gets a filled rectangle (blended in at 20% opacity), a solid
    3 px border and a text label built from its type, with the
    classification confidence appended when it is below 1.0.

    Raises:
        HTTPException: 404 when the session, its column data or its dewarped
            image is missing; 500 when the image cannot be decoded or the
            overlay cannot be encoded.
    """
    record = await get_session_db(session_id)
    if not record:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    detected = record.get("column_result")
    if not (detected and detected.get("columns")):
        raise HTTPException(status_code=404, detail="No column data available")

    png_bytes = await get_session_image(session_id, "dewarped")
    if not png_bytes:
        raise HTTPException(status_code=404, detail="Dewarped image not available")

    canvas = cv2.imdecode(np.frombuffer(png_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
    if canvas is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    # Region type -> drawing colour, in BGR channel order.
    palette = {
        "column_en": (255, 180, 0),        # blue
        "column_de": (0, 200, 0),          # green
        "column_example": (0, 140, 255),   # orange
        "column_text": (200, 200, 0),      # cyan/turquoise
        "page_ref": (200, 0, 200),         # purple
        "column_marker": (0, 0, 220),      # red
        "column_ignore": (180, 180, 180),  # light gray
        "header": (128, 128, 128),         # gray
        "footer": (128, 128, 128),         # gray
    }
    fallback_color = (200, 200, 200)

    fill_layer = canvas.copy()
    for region in detected["columns"]:
        left, top = region["x"], region["y"]
        right, bottom = left + region["width"], top + region["height"]
        bgr = palette.get(region.get("type", ""), fallback_color)

        # Fill goes on the separate layer, border and label on the canvas.
        cv2.rectangle(fill_layer, (left, top), (right, bottom), bgr, -1)
        cv2.rectangle(canvas, (left, top), (right, bottom), bgr, 3)

        caption = region.get("type", "unknown").replace("column_", "").upper()
        confidence = region.get("classification_confidence")
        if confidence is not None and confidence < 1.0:
            caption = f"{caption} {int(confidence * 100)}%"
        cv2.putText(canvas, caption, (left + 10, top + 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, bgr, 2)

    # Blend the fill layer into the canvas in place (20% fill, 80% canvas).
    cv2.addWeighted(fill_layer, 0.2, canvas, 0.8, 0, canvas)

    encoded_ok, encoded = cv2.imencode(".png", canvas)
    if not encoded_ok:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")

    return Response(content=encoded.tobytes(), media_type="image/png")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Row Detection Endpoints
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@router.post("/sessions/{session_id}/rows")
async def detect_rows(session_id: str):
    """Detect rows on the dewarped page and store them as pipeline step 4.

    The intermediates left in the session cache by column detection
    (``_word_dicts``, ``_inv``, ``_content_bounds``) are reused when all
    three are present; otherwise column geometry is recomputed and the
    intermediates are cached. The new row result is persisted and any
    previous word result is cleared because it was built on the old rows.

    Raises:
        HTTPException: 400 when no dewarped image is cached or when column
            geometry detection returns nothing.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    session_cache = _get_cached(session_id)

    page_bgr = session_cache.get("dewarped_bgr")
    if page_bgr is None:
        raise HTTPException(status_code=400, detail="Dewarp must be completed before row detection")

    started = time.time()

    word_dicts = session_cache.get("_word_dicts")
    inv = session_cache.get("_inv")
    bounds = session_cache.get("_content_bounds")

    # Identity checks only: the cached values may be arrays, for which
    # equality-based membership tests are ambiguous.
    if any(item is None for item in (word_dicts, inv, bounds)):
        # Nothing usable in the cache: recompute the column geometry.
        geometry = detect_column_geometry(create_ocr_image(page_bgr), page_bgr)
        if geometry is None:
            raise HTTPException(status_code=400, detail="Column geometry detection failed — cannot detect rows")
        _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geometry
        session_cache["_word_dicts"] = word_dicts
        session_cache["_inv"] = inv
        session_cache["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
    else:
        left_x, right_x, top_y, bottom_y = bounds

    rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
    duration = time.time() - started

    # Serializable view of the rows; the per-row word lists are left out
    # to keep the payload small.
    rows_data = [
        {
            "index": row.index,
            "x": row.x,
            "y": row.y,
            "width": row.width,
            "height": row.height,
            "word_count": row.word_count,
            "row_type": row.row_type,
            "gap_before": row.gap_before,
        }
        for row in rows
    ]

    # Number of rows per row_type.
    type_counts = {}
    for row in rows:
        type_counts.setdefault(row.row_type, 0)
        type_counts[row.row_type] += 1

    row_result = {
        "rows": rows_data,
        "summary": type_counts,
        "total_rows": len(rows),
        "duration_seconds": round(duration, 2),
    }

    # Rows changed, so the stored word result is reset together with them.
    await update_session_db(
        session_id,
        row_result=row_result,
        word_result=None,
        current_step=4,
    )
    session_cache["row_result"] = row_result
    session_cache.pop("word_result", None)

    logger.info(f"OCR Pipeline: rows session {session_id}: "
                f"{len(rows)} rows detected ({duration:.2f}s): {type_counts}")

    response = {"session_id": session_id}
    response.update(row_result)
    return response
|
||
|
||
|
||
@router.post("/sessions/{session_id}/rows/manual")
async def set_manual_rows(session_id: str, req: ManualRowsRequest):
    """Replace the detected rows with the rows supplied in the request.

    The manual row result is persisted, the stored word result is cleared,
    and the in-memory cache entry (if the session has one) is updated to
    match.
    """
    manual_rows = req.rows
    row_result = {
        "rows": manual_rows,
        "total_rows": len(manual_rows),
        "duration_seconds": 0,
        "method": "manual",
    }

    await update_session_db(session_id, row_result=row_result, word_result=None)

    if session_id in _cache:
        session_cache = _cache[session_id]
        session_cache["row_result"] = row_result
        session_cache.pop("word_result", None)

    logger.info(f"OCR Pipeline: manual rows session {session_id}: "
                f"{len(req.rows)} rows set")

    response = {"session_id": session_id}
    response.update(row_result)
    return response
|
||
|
||
|
||
@router.post("/sessions/{session_id}/ground-truth/rows")
async def save_row_ground_truth(session_id: str, req: RowGroundTruthRequest):
    """Store reviewer feedback for the row detection step.

    The feedback is saved under the ``"rows"`` key of the session's ground
    truth together with a snapshot of the current ``row_result``.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    rows_feedback = {
        "is_correct": req.is_correct,
        "corrected_rows": req.corrected_rows,
        "notes": req.notes,
        "saved_at": datetime.utcnow().isoformat(),
        "row_result": session.get("row_result"),
    }

    # Keep feedback already recorded for other steps.
    ground_truth = session.get("ground_truth") or {}
    ground_truth["rows"] = rows_feedback

    await update_session_db(session_id, ground_truth=ground_truth)

    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    return {"session_id": session_id, "ground_truth": rows_feedback}
|
||
|
||
|
||
@router.get("/sessions/{session_id}/ground-truth/rows")
async def get_row_ground_truth(session_id: str):
    """Return the saved row ground truth next to the automatic row result.

    Raises:
        HTTPException: 404 when the session does not exist or has no row
            ground truth stored.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    stored = session.get("ground_truth") or {}
    rows_gt = stored.get("rows")
    if not rows_gt:
        raise HTTPException(status_code=404, detail="No row ground truth saved")

    response = {"session_id": session_id}
    response["rows_gt"] = rows_gt
    response["rows_auto"] = session.get("row_result")
    return response
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Word Recognition Endpoints (Step 5)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@router.post("/sessions/{session_id}/words")
async def detect_words(
    session_id: str,
    request: Request,
    engine: str = "auto",
    pronunciation: str = "british",
    stream: bool = False,
):
    """Build the word grid from columns x rows and OCR every cell.

    Query params:
        engine: 'auto' (default), 'tesseract', or 'rapid'
        pronunciation: 'british' (default) or 'american' - used for the IPA
            dictionary lookup
        stream: false (default) for a JSON response, true for SSE streaming

    Raises:
        HTTPException: 400 when dewarp, column detection or row detection
            has not been completed; 404 when the session does not exist.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    dewarped_bgr = cached.get("dewarped_bgr")
    if dewarped_bgr is None:
        raise HTTPException(status_code=400, detail="Dewarp must be completed before word detection")

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    column_result = session.get("column_result")
    row_result = session.get("row_result")
    if not column_result or not column_result.get("columns"):
        raise HTTPException(status_code=400, detail="Column detection must be completed first")
    if not row_result or not row_result.get("rows"):
        raise HTTPException(status_code=400, detail="Row detection must be completed first")

    # Rebuild PageRegion objects from the stored column dicts.
    col_regions = []
    for col in column_result["columns"]:
        col_regions.append(PageRegion(
            type=col["type"],
            x=col["x"], y=col["y"],
            width=col["width"], height=col["height"],
            classification_confidence=col.get("classification_confidence", 1.0),
            classification_method=col.get("classification_method", ""),
        ))

    # Rebuild RowGeometry objects from the stored row dicts; the word
    # lists are filled in below.
    row_geoms = []
    for stored_row in row_result["rows"]:
        row_geoms.append(RowGeometry(
            index=stored_row["index"],
            x=stored_row["x"], y=stored_row["y"],
            width=stored_row["width"], height=stored_row["height"],
            word_count=stored_row.get("word_count", 0),
            words=[],
            row_type=stored_row.get("row_type", "content"),
            gap_before=stored_row.get("gap_before", 0),
        ))

    # Re-populate row.words from the cached full-page words so the
    # per-cell word lookup does not have to re-run OCR.
    word_dicts = cached.get("_word_dicts")
    if word_dicts is None:
        geometry = detect_column_geometry(create_ocr_image(dewarped_bgr), dewarped_bgr)
        if geometry is not None:
            _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geometry
            cached["_word_dicts"] = word_dicts
            cached["_inv"] = inv
            cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)

    if word_dicts:
        # words['top'] is relative to the content-ROI top_y while row.y is
        # absolute, so rows are shifted by top_y before comparing.
        content_bounds = cached.get("_content_bounds")
        if content_bounds:
            _lx, _rx, top_y, _by = content_bounds
        elif row_geoms:
            top_y = min(geom.y for geom in row_geoms)
        else:
            top_y = 0

        for geom in row_geoms:
            rel_top = geom.y - top_y
            rel_bottom = rel_top + geom.height
            matched = []
            for word in word_dicts:
                # A word belongs to the row containing its vertical centre.
                centre_y = word['top'] + word['height'] / 2
                if rel_top <= centre_y < rel_bottom:
                    matched.append(word)
            geom.words = matched
            geom.word_count = len(matched)

    if stream:
        sse_headers = {
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        }
        event_source = _word_stream_generator(
            session_id, cached, col_regions, row_geoms,
            dewarped_bgr, engine, pronunciation, request,
        )
        return StreamingResponse(
            event_source,
            media_type="text/event-stream",
            headers=sse_headers,
        )

    # --- Non-streaming path ---
    started = time.time()

    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    cells, columns_meta = build_cell_grid(
        ocr_img, col_regions, row_geoms, img_w, img_h,
        ocr_engine=engine, img_bgr=dewarped_bgr,
    )
    duration = time.time() - started

    # The layout is "vocab" as soon as an EN or DE column is in use.
    is_vocab = any(meta['type'] in ('column_en', 'column_de') for meta in columns_meta)
    layout = "vocab" if is_vocab else "generic"

    n_content_rows = sum(1 for geom in row_geoms if geom.row_type == 'content')
    n_cols = len(columns_meta)

    # Engine reported by the first cell, falling back to the requested one.
    if cells:
        used_engine = cells[0].get("ocr_engine", "tesseract")
    else:
        used_engine = engine

    summary = {
        "total_cells": len(cells),
        "non_empty_cells": len([c for c in cells if c.get("text")]),
        "low_confidence": len([c for c in cells if 0 < c.get("confidence", 0) < 50]),
    }
    word_result = {
        "cells": cells,
        "grid_shape": {
            "rows": n_content_rows,
            "cols": n_cols,
            "total_cells": len(cells),
        },
        "columns_used": columns_meta,
        "layout": layout,
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": summary,
    }

    if is_vocab:
        # Vocab layout: add post-processed entries. _split_comma_entries is
        # intentionally not applied - word forms such as "mouse, mice" /
        # "Maus, Mäuse" belong together in one entry.
        entries = _cells_to_vocab_entries(cells, columns_meta)
        entries = _fix_character_confusion(entries)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        entries = _attach_example_sentences(entries)
        # "entries" duplicates "vocab_entries" for backwards compatibility.
        word_result.update(
            vocab_entries=entries,
            entries=entries,
            entry_count=len(entries),
        )
        summary.update(
            total_entries=len(entries),
            with_english=sum(1 for e in entries if e.get("english")),
            with_german=sum(1 for e in entries if e.get("german")),
        )

    await update_session_db(
        session_id,
        word_result=word_result,
        current_step=5,
    )
    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline: words session {session_id}: "
                f"layout={word_result['layout']}, "
                f"{len(cells)} cells ({duration:.2f}s), summary: {word_result['summary']}")

    response = {"session_id": session_id}
    response.update(word_result)
    return response
|
||
|
||
|
||
async def _word_stream_generator(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    request: Request,
):
    """Yield Server-Sent Events while the cell grid is OCR'd cell by cell.

    Event sequence: one ``meta`` event, one ``columns`` event (sent with the
    first cell), one ``cell`` event per cell, and a final ``complete`` event
    after the result has been persisted and cached. If the client
    disconnects mid-stream the generator stops without persisting anything.
    """

    def _sse(payload: Dict[str, Any]) -> str:
        # One SSE frame: "data: <json>" followed by a blank line.
        return f"data: {json.dumps(payload)}\n\n"

    started = time.time()

    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    # Grid shape and layout are derived up front for the meta event.
    skipped_types = {'column_ignore', 'header', 'footer', 'page_ref'}
    active_types = [region.type for region in col_regions if region.type not in skipped_types]
    n_content_rows = sum(1 for geom in row_geoms if geom.row_type == 'content')
    n_cols = len(active_types)
    is_vocab = not {'column_en', 'column_de'}.isdisjoint(active_types)
    layout = "vocab" if is_vocab else "generic"

    yield _sse({
        "type": "meta",
        "grid_shape": {
            "rows": n_content_rows,
            "cols": n_cols,
            "total_cells": n_content_rows * n_cols,
        },
        "layout": layout,
    })

    columns_meta = None  # taken from the first item the grid builder yields
    all_cells: List[Dict[str, Any]] = []

    cell_source = build_cell_grid_streaming(
        ocr_img, col_regions, row_geoms, img_w, img_h,
        ocr_engine=engine, img_bgr=dewarped_bgr,
    )
    for cell, cols_meta, total in cell_source:
        if await request.is_disconnected():
            logger.info(f"SSE: client disconnected during streaming for {session_id}")
            return

        if columns_meta is None:
            columns_meta = cols_meta
            yield _sse({"type": "columns", "columns_used": cols_meta})

        all_cells.append(cell)
        yield _sse({
            "type": "cell",
            "cell": cell,
            "progress": {"current": len(all_cells), "total": total},
        })

    # All cells done - assemble the final result.
    duration = time.time() - started
    if columns_meta is None:
        columns_meta = []

    if all_cells:
        used_engine = all_cells[0].get("ocr_engine", "tesseract")
    else:
        used_engine = engine

    summary = {
        "total_cells": len(all_cells),
        "non_empty_cells": len([c for c in all_cells if c.get("text")]),
        "low_confidence": len([c for c in all_cells if 0 < c.get("confidence", 0) < 50]),
    }
    word_result = {
        "cells": all_cells,
        "grid_shape": {
            "rows": n_content_rows,
            "cols": n_cols,
            "total_cells": len(all_cells),
        },
        "columns_used": columns_meta,
        "layout": layout,
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": summary,
    }

    vocab_entries = None
    if is_vocab:
        # _split_comma_entries is intentionally not applied - word forms
        # such as "mouse, mice" / "Maus, Mäuse" belong together in one entry.
        vocab_entries = _cells_to_vocab_entries(all_cells, columns_meta)
        vocab_entries = _fix_character_confusion(vocab_entries)
        vocab_entries = _fix_phonetic_brackets(vocab_entries, pronunciation=pronunciation)
        vocab_entries = _attach_example_sentences(vocab_entries)
        word_result.update(
            vocab_entries=vocab_entries,
            entries=vocab_entries,
            entry_count=len(vocab_entries),
        )
        summary.update(
            total_entries=len(vocab_entries),
            with_english=sum(1 for e in vocab_entries if e.get("english")),
            with_german=sum(1 for e in vocab_entries if e.get("german")),
        )

    await update_session_db(
        session_id,
        word_result=word_result,
        current_step=5,
    )
    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline SSE: words session {session_id}: "
                f"layout={word_result['layout']}, "
                f"{len(all_cells)} cells ({duration:.2f}s)")

    complete_event = {
        "type": "complete",
        "summary": word_result["summary"],
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
    }
    if vocab_entries is not None:
        complete_event["vocab_entries"] = vocab_entries
    yield _sse(complete_event)
|
||
|
||
|
||
class WordGroundTruthRequest(BaseModel):
    """Request body for saving ground-truth feedback on the word recognition step."""

    # Reviewer verdict; stored as "is_correct" in the ground-truth record.
    is_correct: bool
    # Optional corrected entries; stored unchanged as "corrected_entries".
    corrected_entries: Optional[List[Dict[str, Any]]] = None
    # Optional free-text notes; stored as "notes".
    notes: Optional[str] = None
|
||
|
||
|
||
@router.post("/sessions/{session_id}/ground-truth/words")
async def save_word_ground_truth(session_id: str, req: WordGroundTruthRequest):
    """Store reviewer feedback for the word recognition step.

    The feedback is saved under the ``"words"`` key of the session's ground
    truth together with a snapshot of the current ``word_result``.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    words_feedback = {
        "is_correct": req.is_correct,
        "corrected_entries": req.corrected_entries,
        "notes": req.notes,
        "saved_at": datetime.utcnow().isoformat(),
        "word_result": session.get("word_result"),
    }

    # Keep feedback already recorded for other steps.
    ground_truth = session.get("ground_truth") or {}
    ground_truth["words"] = words_feedback

    await update_session_db(session_id, ground_truth=ground_truth)

    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    return {"session_id": session_id, "ground_truth": words_feedback}
|
||
|
||
|
||
@router.get("/sessions/{session_id}/ground-truth/words")
async def get_word_ground_truth(session_id: str):
    """Return the saved word ground truth next to the automatic word result.

    Raises:
        HTTPException: 404 when the session does not exist or has no word
            ground truth stored.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    stored = session.get("ground_truth") or {}
    words_gt = stored.get("words")
    if not words_gt:
        raise HTTPException(status_code=404, detail="No word ground truth saved")

    response = {"session_id": session_id}
    response["words_gt"] = words_gt
    response["words_auto"] = session.get("word_result")
    return response
|
||
|
||
|
||
async def _get_rows_overlay(session_id: str) -> Response:
    """Render the dewarped image with the detected row bands as a PNG.

    Each row gets a solid border, a label (index, type and - when non-zero -
    the word count) and a semi-transparent fill blended in at 15% opacity.

    Raises:
        HTTPException: 404 when the session, its row data or the dewarped
            image is missing; 500 when the image cannot be decoded or
            encoded.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    rows = (session.get("row_result") or {}).get("rows")
    if not rows:
        raise HTTPException(status_code=404, detail="No row data available")

    dewarped_png = await get_session_image(session_id, "dewarped")
    if not dewarped_png:
        raise HTTPException(status_code=404, detail="Dewarped image not available")

    canvas = cv2.imdecode(np.frombuffer(dewarped_png, dtype=np.uint8), cv2.IMREAD_COLOR)
    if canvas is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    # BGR colours per row type; unknown types fall back to light gray.
    gray = (128, 128, 128)
    row_colors = {
        "content": (255, 180, 0),  # Blue
        "header": gray,
        "footer": gray,
    }
    fallback_color = (200, 200, 200)

    # Fills go onto a copy so they can be blended in translucently later.
    fill_layer = canvas.copy()
    for row in rows:
        row_type = row.get("row_type", "content")
        color = row_colors.get(row_type, fallback_color)
        top_left = (row["x"], row["y"])
        bottom_right = (row["x"] + row["width"], row["y"] + row["height"])

        cv2.rectangle(fill_layer, top_left, bottom_right, color, -1)
        cv2.rectangle(canvas, top_left, bottom_right, color, 2)

        word_count = row.get("word_count", 0)
        suffix = f" ({word_count}w)" if word_count else ""
        label = f"R{row.get('index', 0)} {row_type.upper()}{suffix}"
        cv2.putText(canvas, label, (row["x"] + 5, row["y"] + 18),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

    # Blend the fill layer into the annotated image at 15% opacity.
    cv2.addWeighted(fill_layer, 0.15, canvas, 0.85, 0, canvas)

    encoded_ok, png_buffer = cv2.imencode(".png", canvas)
    if not encoded_ok:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")

    return Response(content=png_buffer.tobytes(), media_type="image/png")
|
||
|
||
|
||
def _overlay_confidence_color(conf) -> tuple:
    """Return the BGR label colour for an OCR confidence value.

    Confidence >= 70 is drawn green, >= 50 amber, anything lower red.
    """
    if conf >= 70:
        return (0, 180, 0)
    if conf >= 50:
        return (0, 180, 220)
    return (0, 0, 220)


def _draw_cells_overlay(img, overlay, cells) -> None:
    """Draw the cell-based grid: borders and labels on *img*, fills on *overlay*.

    Cells are coloured by ``col_index`` (cycling through a fixed palette).
    Cells whose bounding box has no positive width or height are skipped.
    """
    col_palette = [
        (255, 180, 0),    # Blue (BGR)
        (0, 200, 0),      # Green
        (0, 140, 255),    # Orange
        (200, 100, 200),  # Purple
        (200, 200, 0),    # Cyan
        (100, 200, 200),  # Yellow-ish
    ]

    for cell in cells:
        bbox = cell.get("bbox_px", {})
        cx = bbox.get("x", 0)
        cy = bbox.get("y", 0)
        cw = bbox.get("w", 0)
        ch = bbox.get("h", 0)
        if cw <= 0 or ch <= 0:
            continue

        color = col_palette[cell.get("col_index", 0) % len(col_palette)]

        # Thin border on the image, solid fill on the blend layer.
        cv2.rectangle(img, (cx, cy), (cx + cw, cy + ch), color, 1)
        cv2.rectangle(overlay, (cx, cy), (cx + cw, cy + ch), color, -1)

        # Cell-ID label in the top-left corner.
        cv2.putText(img, cell.get("cell_id", ""), (cx + 2, cy + 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.28, color, 1)

        # Recognised text along the bottom edge, coloured by confidence.
        text = cell.get("text", "")
        if text:
            text_color = _overlay_confidence_color(cell.get("confidence", 0))
            label = text.replace('\n', ' ')[:30]
            cv2.putText(img, label, (cx + 3, cy + ch - 4),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)


def _draw_legacy_entries_overlay(img, overlay, session, entries) -> None:
    """Draw the entry-based overlay used for sessions without cell data.

    The grid is reconstructed from the session's stored columns (types
    starting with ``column_``) and content rows; entry texts are looked up
    by ``row_index`` and drawn into the matching column.
    """
    col_colors = {
        "column_en": (255, 180, 0),
        "column_de": (0, 200, 0),
        "column_example": (0, 140, 255),
    }
    # Column type -> entry field holding that column's text.
    field_by_col_type = {
        "column_en": "english",
        "column_de": "german",
        "column_example": "example",
    }

    column_result = session.get("column_result")
    row_result = session.get("row_result")

    columns = []
    if column_result and column_result.get("columns"):
        columns = [c for c in column_result["columns"]
                   if c.get("type", "").startswith("column_")]

    content_rows = []
    if row_result and row_result.get("rows"):
        content_rows = [r for r in row_result["rows"]
                        if r.get("row_type") == "content"]

    # Grid: one rectangle per (column, content row) pair.
    for col in columns:
        color = col_colors.get(col.get("type", ""), (200, 200, 200))
        cx, cw = col["x"], col["width"]
        for row in content_rows:
            ry, rh = row["y"], row["height"]
            cv2.rectangle(img, (cx, ry), (cx + cw, ry + rh), color, 1)
            cv2.rectangle(overlay, (cx, ry), (cx + cw, ry + rh), color, -1)

    # Later entries win when several share a row_index.
    entry_by_row: Dict[int, Dict] = {}
    for entry in entries:
        entry_by_row[entry.get("row_index", -1)] = entry

    for row_idx, row in enumerate(content_rows):
        entry = entry_by_row.get(row_idx)
        if not entry:
            continue
        text_color = _overlay_confidence_color(entry.get("confidence", 0))
        ry, rh = row["y"], row["height"]
        for col in columns:
            field = field_by_col_type.get(col.get("type", ""), "")
            text = entry.get(field, "") if field else ""
            if text:
                label = text.replace('\n', ' ')[:30]
                cv2.putText(img, label, (col["x"] + 3, ry + rh - 4),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)


async def _get_words_overlay(session_id: str) -> Response:
    """Render the dewarped image with the word/cell grid as a PNG.

    Uses the cell-based word result when ``cells`` is present and falls
    back to the legacy entry-based format otherwise. Fills are blended in
    at 10% opacity.

    Raises:
        HTTPException: 404 when the session, its word data or the dewarped
            image is missing; 500 when the image cannot be decoded or
            encoded.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=404, detail="No word data available")

    # Support both the cell-based and the legacy entry-based formats.
    cells = word_result.get("cells")
    if not cells and not word_result.get("entries"):
        raise HTTPException(status_code=404, detail="No word data available")

    dewarped_png = await get_session_image(session_id, "dewarped")
    if not dewarped_png:
        raise HTTPException(status_code=404, detail="Dewarped image not available")

    arr = np.frombuffer(dewarped_png, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    # Fills are drawn on a copy and blended in afterwards.
    overlay = img.copy()

    if cells:
        _draw_cells_overlay(img, overlay, cells)
    else:
        # The guard above guarantees a non-empty "entries" list here.
        _draw_legacy_entries_overlay(img, overlay, session, word_result["entries"])

    # Blend overlay at 10% opacity
    cv2.addWeighted(overlay, 0.1, img, 0.9, 0, img)

    success, result_png = cv2.imencode(".png", img)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")

    return Response(content=result_png.tobytes(), media_type="image/png")
|