Some checks failed
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 28s
CI / test-go-edu-search (push) Successful in 28s
CI / test-python-klausur (push) Failing after 1m58s
CI / test-python-agent-core (push) Successful in 18s
CI / test-nodejs-website (push) Successful in 19s
Add cv_graphic_detect.py for detecting non-text visual elements (arrows, circles, lines, exclamation marks, icons, illustrations). Draw detected graphics on structure overlay image and display them in the frontend StepStructureDetection component with shape counts and individual listings. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
5200 lines
198 KiB
Python
5200 lines
198 KiB
Python
"""
|
||
OCR Pipeline API - step-by-step page reconstruction.

Splits the OCR process into 10 individual steps:
1. Orientation - correct 90/180/270° rotations (orientation_crop_api.py)
2. Deskew - straighten the scan
3. Dewarp - flatten book-page curvature
4. Cropping - remove scanner borders / book spine (orientation_crop_api.py)
5. Column detection - find invisible columns
6. Row detection - horizontal rows plus header/footer lines
7. Word detection - OCR with bounding boxes
8. LLM correction - fix OCR errors via LLM
9. Page reconstruction - rebuild the page
10. Ground truth validation - overall check

License: Apache 2.0
PRIVACY: All processing happens locally.
|
||
"""
|
||
|
||
import json
|
||
import logging
|
||
import os
|
||
import re
|
||
import time
|
||
import uuid
|
||
from dataclasses import asdict
|
||
from datetime import datetime
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
import cv2
|
||
import numpy as np
|
||
from fastapi import APIRouter, File, Form, HTTPException, Query, Request, UploadFile
|
||
from fastapi.responses import Response, StreamingResponse
|
||
from pydantic import BaseModel
|
||
|
||
from cv_vocab_pipeline import (
|
||
OLLAMA_REVIEW_MODEL,
|
||
DocumentTypeResult,
|
||
PageRegion,
|
||
RowGeometry,
|
||
_cells_to_vocab_entries,
|
||
_detect_header_footer_gaps,
|
||
_detect_sub_columns,
|
||
_fix_character_confusion,
|
||
_fix_phonetic_brackets,
|
||
fix_cell_phonetics,
|
||
analyze_layout,
|
||
analyze_layout_by_words,
|
||
build_cell_grid,
|
||
build_cell_grid_streaming,
|
||
build_cell_grid_v2,
|
||
build_cell_grid_v2_streaming,
|
||
build_word_grid,
|
||
classify_column_types,
|
||
create_layout_image,
|
||
create_ocr_image,
|
||
deskew_image,
|
||
deskew_image_by_word_alignment,
|
||
deskew_image_iterative,
|
||
deskew_two_pass,
|
||
detect_column_geometry,
|
||
detect_column_geometry_zoned,
|
||
detect_document_type,
|
||
detect_row_geometry,
|
||
expand_narrow_columns,
|
||
_apply_shear,
|
||
dewarp_image,
|
||
dewarp_image_manual,
|
||
llm_review_entries,
|
||
llm_review_entries_streaming,
|
||
render_image_high_res,
|
||
render_pdf_high_res,
|
||
)
|
||
from cv_box_detect import detect_boxes, split_page_into_zones
|
||
from cv_color_detect import detect_word_colors, recover_colored_text, _COLOR_RANGES, _COLOR_HEX
|
||
from cv_graphic_detect import detect_graphic_elements
|
||
from cv_words_first import build_grid_from_words
|
||
from ocr_pipeline_session_store import (
|
||
create_session_db,
|
||
delete_all_sessions_db,
|
||
delete_session_db,
|
||
get_session_db,
|
||
get_session_image,
|
||
get_sub_sessions,
|
||
init_ocr_pipeline_tables,
|
||
list_sessions_db,
|
||
update_session_db,
|
||
)
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Every endpoint in this file is registered under /api/v1/ocr-pipeline.
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])

# ---------------------------------------------------------------------------
# In-memory cache for active sessions (BGR numpy arrays for processing).
# The DB is the source of truth; the cache only holds decoded BGR arrays
# (plus session metadata) while a session is actively being processed.
# ---------------------------------------------------------------------------

# Maps session id -> cache entry dict.
_cache: Dict[str, Dict[str, Any]] = {}
|
||
|
||
|
||
async def _get_base_image_png(session_id: str) -> Optional[bytes]:
    """Return the PNG bytes of the best stored base image for a session.

    Stored images are tried in the order cropped, dewarped, original and
    the first non-empty one is returned.  ``None`` is returned when none
    of them is available.
    """
    for stage in ("cropped", "dewarped", "original"):
        if png := await get_session_image(session_id, stage):
            return png
    return None
|
||
|
||
|
||
async def _load_session_to_cache(session_id: str) -> Dict[str, Any]:
    """Load a session from the DB into the in-memory cache.

    The session row is always fetched first so that an unknown id yields a
    404 even if a stale cache entry exists.  If the session is already
    cached, the existing entry is returned unchanged.  Otherwise every
    stored PNG stage is decoded into a BGR array and the new entry is
    stored in ``_cache``.

    Raises:
        HTTPException: 404 if the session does not exist in the DB.
    """
    db_row = await get_session_db(session_id)
    if not db_row:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    if session_id in _cache:
        return _cache[session_id]

    stages = ("original", "oriented", "cropped", "deskewed", "dewarped")

    entry: Dict[str, Any] = {"id": session_id, **db_row}
    entry.update({f"{stage}_bgr": None for stage in stages})

    # Decode whichever stage images exist in the DB into BGR numpy arrays.
    for stage in stages:
        png = await get_session_image(session_id, stage)
        if not png:
            continue
        entry[f"{stage}_bgr"] = cv2.imdecode(
            np.frombuffer(png, dtype=np.uint8), cv2.IMREAD_COLOR
        )

    # Sub-sessions: the original image IS the cropped box region, so promote
    # it to cropped_bgr to let downstream steps find it.
    is_sub_session = bool(db_row.get("parent_session_id"))
    no_later_stage = entry["cropped_bgr"] is None and entry["dewarped_bgr"] is None
    if is_sub_session and entry["original_bgr"] is not None and no_later_stage:
        entry["cropped_bgr"] = entry["original_bgr"]

    _cache[session_id] = entry
    return entry
|
||
|
||
|
||
def _get_cached(session_id: str) -> Dict[str, Any]:
    """Return the cache entry for a session.

    Raises:
        HTTPException: 404 if the session has no (non-empty) cache entry.
    """
    if cached := _cache.get(session_id):
        return cached
    raise HTTPException(status_code=404, detail=f"Session {session_id} not in cache — reload first")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Pydantic Models
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class ManualDeskewRequest(BaseModel):
    """Request body for the manual deskew endpoint."""

    # Rotation angle to apply; the endpoint clamps it to [-5.0, 5.0].
    angle: float
|
||
|
||
|
||
class DeskewGroundTruthRequest(BaseModel):
    """Request body for ground-truth feedback on the deskew step."""

    # Whether the applied deskew was judged correct.
    is_correct: bool
    # Optional angle supplied by the reviewer as the correct value.
    corrected_angle: Optional[float] = None
    # Optional free-text remarks.
    notes: Optional[str] = None
|
||
|
||
|
||
class ManualDewarpRequest(BaseModel):
    """Request body carrying a manually chosen shear angle.

    NOTE(review): presumably consumed by a manual dewarp endpoint that is
    defined outside this excerpt — confirm.
    """

    # Shear angle in degrees.
    shear_degrees: float
|
||
|
||
|
||
class CombinedAdjustRequest(BaseModel):
    """Request body carrying a rotation and a shear; both default to 0.0.

    NOTE(review): the consuming endpoint is not visible in this excerpt.
    """

    # Rotation in degrees (0.0 = no rotation).
    rotation_degrees: float = 0.0
    # Shear in degrees (0.0 = no shear).
    shear_degrees: float = 0.0
|
||
|
||
|
||
class DewarpGroundTruthRequest(BaseModel):
    """Request body for ground-truth feedback on the dewarp step.

    NOTE(review): the consuming endpoint is not visible in this excerpt.
    """

    # Whether the applied dewarp was judged correct.
    is_correct: bool
    # Optional shear value supplied by the reviewer as the correct one.
    corrected_shear: Optional[float] = None
    # Optional free-text remarks.
    notes: Optional[str] = None
|
||
|
||
|
||
# Closed set of document categories a session may be tagged with.
VALID_DOCUMENT_CATEGORIES = {
    "vokabelseite",
    "buchseite",
    "arbeitsblatt",
    "klausurseite",
    "mathearbeit",
    "statistik",
    "zeitung",
    "formular",
    "handschrift",
    "sonstiges",
}
|
||
|
||
|
||
class UpdateSessionRequest(BaseModel):
    """Request body for updating a session; omitted fields are left unchanged."""

    # New display name for the session.
    name: Optional[str] = None
    # New category; the update endpoint validates it against
    # VALID_DOCUMENT_CATEGORIES.
    document_category: Optional[str] = None
|
||
|
||
|
||
class ManualColumnsRequest(BaseModel):
    """Request body carrying manually defined columns.

    NOTE(review): the per-column dict schema is not visible in this
    excerpt — confirm against the consuming endpoint.
    """

    # One dict per column.
    columns: List[Dict[str, Any]]
|
||
|
||
|
||
class ColumnGroundTruthRequest(BaseModel):
    """Request body for ground-truth feedback on column detection.

    NOTE(review): the consuming endpoint is not visible in this excerpt.
    """

    # Whether the detected columns were judged correct.
    is_correct: bool
    # Optional reviewer-supplied replacement columns, one dict per column.
    corrected_columns: Optional[List[Dict[str, Any]]] = None
    # Optional free-text remarks.
    notes: Optional[str] = None
|
||
|
||
|
||
class ManualRowsRequest(BaseModel):
    """Request body carrying manually defined rows.

    NOTE(review): the per-row dict schema is not visible in this
    excerpt — confirm against the consuming endpoint.
    """

    # One dict per row.
    rows: List[Dict[str, Any]]
|
||
|
||
|
||
class RowGroundTruthRequest(BaseModel):
    """Request body for ground-truth feedback on row detection.

    NOTE(review): the consuming endpoint is not visible in this excerpt.
    """

    # Whether the detected rows were judged correct.
    is_correct: bool
    # Optional reviewer-supplied replacement rows, one dict per row.
    corrected_rows: Optional[List[Dict[str, Any]]] = None
    # Optional free-text remarks.
    notes: Optional[str] = None
|
||
|
||
|
||
class RemoveHandwritingRequest(BaseModel):
    """Options for a handwriting-removal request.

    The accepted values listed next to each field are documentation only;
    this model does not enforce them.
    """

    # "auto" | "telea" | "ns"
    method: str = "auto"
    # "all" | "colored" | "pencil"
    target_ink: str = "all"
    # mask dilation iterations (0-5)
    dilation: int = 2
    # "original" | "deskewed" | "auto"
    use_source: str = "auto"
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Session Management Endpoints
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@router.get("/sessions")
async def list_sessions(include_sub_sessions: bool = False):
    """List OCR pipeline sessions.

    Sub-sessions (box regions) are hidden unless the caller passes
    ``?include_sub_sessions=true``.
    """
    return {
        "sessions": await list_sessions_db(include_sub_sessions=include_sub_sessions),
    }
|
||
|
||
|
||
@router.post("/sessions")
async def create_session(
    file: UploadFile = File(...),
    name: Optional[str] = Form(None),
):
    """Create a pipeline session from an uploaded PDF or image.

    A PDF (detected by content type or a ``.pdf`` file name) is rendered
    from page 0 at 3x zoom; anything else goes through the image renderer.
    The rendered page is persisted as the session's original PNG and its
    BGR array is cached for the following pipeline steps.

    Raises:
        HTTPException: 400 if the upload cannot be rendered, 500 if the
            rendered image cannot be PNG-encoded.
    """
    raw = await file.read()
    filename = file.filename or "upload"
    mime_type = file.content_type or ""

    session_id = str(uuid.uuid4())
    treat_as_pdf = mime_type == "application/pdf" or filename.lower().endswith(".pdf")

    try:
        img_bgr = (
            render_pdf_high_res(raw, page_number=0, zoom=3.0)
            if treat_as_pdf
            else render_image_high_res(raw)
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Could not process file: {e}")

    # The original is stored in the DB as PNG bytes.
    encoded_ok, png_buf = cv2.imencode(".png", img_bgr)
    if not encoded_ok:
        raise HTTPException(status_code=500, detail="Failed to encode image")

    session_name = name or filename

    await create_session_db(
        session_id=session_id,
        name=session_name,
        filename=filename,
        original_png=png_buf.tobytes(),
    )

    # Keep the decoded array around so the next step can start immediately.
    entry: Dict[str, Any] = {
        "id": session_id,
        "filename": filename,
        "name": session_name,
        "original_bgr": img_bgr,
    }
    for pending_key in (
        "oriented_bgr",
        "cropped_bgr",
        "deskewed_bgr",
        "dewarped_bgr",
        "orientation_result",
        "crop_result",
        "deskew_result",
        "dewarp_result",
    ):
        entry[pending_key] = None
    entry["ground_truth"] = {}
    entry["current_step"] = 1
    _cache[session_id] = entry

    logger.info(f"OCR Pipeline: created session {session_id} from {filename} "
                f"({img_bgr.shape[1]}x{img_bgr.shape[0]})")

    width, height = img_bgr.shape[1], img_bgr.shape[0]
    return {
        "session_id": session_id,
        "filename": filename,
        "name": session_name,
        "image_width": width,
        "image_height": height,
        "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
    }
|
||
|
||
|
||
@router.get("/sessions/{session_id}")
async def get_session_info(session_id: str):
    """Get session info including per-step results for step navigation.

    The response always contains the basic session fields and the original
    image dimensions (0x0 when no original image is stored or it cannot be
    decoded).  Step results are included only when present.  For a
    sub-session the parent id and box index are added; for a parent
    session its sub-sessions are listed, if any.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # Image dimensions come from the stored original PNG.  The None check
    # must guard BOTH shape accesses: the previous one-line form
    # ``a, b = img.shape[1], img.shape[0] if img is not None else (0, 0)``
    # applied the conditional to the second element only and crashed on an
    # undecodable image.
    img_w, img_h = 0, 0
    original_png = await get_session_image(session_id, "original")
    if original_png:
        arr = np.frombuffer(original_png, dtype=np.uint8)
        img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
        if img is not None:
            img_h, img_w = img.shape[:2]

    result = {
        "session_id": session["id"],
        "filename": session.get("filename", ""),
        "name": session.get("name", ""),
        "image_width": img_w,
        "image_height": img_h,
        "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
        "current_step": session.get("current_step", 1),
        "document_category": session.get("document_category"),
        "doc_type": session.get("doc_type"),
    }

    # Copy over only the step results that exist (and are non-empty).
    for key in (
        "orientation_result",
        "crop_result",
        "deskew_result",
        "dewarp_result",
        "column_result",
        "row_result",
        "word_result",
        "doc_type_result",
    ):
        if session.get(key):
            result[key] = session[key]

    if session.get("parent_session_id"):
        # This is a sub-session: point back at its parent.
        result["parent_session_id"] = session["parent_session_id"]
        result["box_index"] = session.get("box_index")
    else:
        # This is a parent session: list its sub-sessions, if any.
        subs = await get_sub_sessions(session_id)
        if subs:
            result["sub_sessions"] = [
                {"id": s["id"], "name": s.get("name"), "box_index": s.get("box_index")}
                for s in subs
            ]

    return result
|
||
|
||
|
||
@router.put("/sessions/{session_id}")
async def update_session(session_id: str, req: UpdateSessionRequest):
    """Update a session's name and/or document category.

    Only fields that are present in the request are written.

    Raises:
        HTTPException: 400 for an unknown category or an empty update,
            404 if the session does not exist.
    """
    changes: Dict[str, Any] = {}

    if req.name is not None:
        changes["name"] = req.name

    category = req.document_category
    if category is not None:
        if category not in VALID_DOCUMENT_CATEGORIES:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid category '{req.document_category}'. Valid: {sorted(VALID_DOCUMENT_CATEGORIES)}",
            )
        changes["document_category"] = category

    if not changes:
        raise HTTPException(status_code=400, detail="Nothing to update")

    if not await update_session_db(session_id, **changes):
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    return {"session_id": session_id, **changes}
|
||
|
||
|
||
@router.delete("/sessions/{session_id}")
async def delete_session(session_id: str):
    """Delete one session from the cache and the DB.

    Raises:
        HTTPException: 404 if the DB reports that nothing was deleted.
    """
    # Evict first so a stale cache entry never outlives the DB row.
    _cache.pop(session_id, None)
    if not await delete_session_db(session_id):
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    return {"session_id": session_id, "deleted": True}
|
||
|
||
|
||
@router.delete("/sessions")
async def delete_all_sessions():
    """Delete ALL sessions (cleanup) and report how many the DB removed."""
    _cache.clear()
    return {"deleted_count": await delete_all_sessions_db()}
|
||
|
||
|
||
@router.post("/sessions/{session_id}/create-box-sessions")
async def create_box_sessions(session_id: str):
    """Create sub-sessions for each detected box region.

    Box zones are read from the session's column result.  Each box is
    cropped (with a small padding) from the cropped image, falling back to
    the dewarped image, and stored as an independent sub-session that can
    be processed through the pipeline.  If sub-sessions already exist they
    are returned instead of being created again.

    Boxes that fall completely outside the base image, or whose crop
    cannot be encoded, are skipped with a warning.

    Raises:
        HTTPException: 404 if the session does not exist, 400 if column
            detection has not run or no base image is stored, 500 if the
            base image cannot be decoded.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    column_result = session.get("column_result")
    if not column_result:
        raise HTTPException(status_code=400, detail="Column detection must be completed first")

    zones = column_result.get("zones") or []
    box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
    if not box_zones:
        return {"session_id": session_id, "sub_sessions": [], "message": "No boxes detected"}

    # Never create a second set of sub-sessions for the same parent.
    existing = await get_sub_sessions(session_id)
    if existing:
        return {
            "session_id": session_id,
            "sub_sessions": [{"id": s["id"], "box_index": s.get("box_index")} for s in existing],
            "message": f"{len(existing)} sub-session(s) already exist",
        }

    # Load base image: cropped preferred, dewarped as fallback.
    base_png = await get_session_image(session_id, "cropped")
    if not base_png:
        base_png = await get_session_image(session_id, "dewarped")
    if not base_png:
        raise HTTPException(status_code=400, detail="No base image available")

    arr = np.frombuffer(base_png, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    img_h, img_w = img.shape[:2]
    parent_name = session.get("name", "Session")
    filename = session.get("filename", "box-crop.png")
    pad = 5  # pixels of context kept around each box
    created = []

    for i, zone in enumerate(box_zones):
        box = zone["box"]
        # Slice indices must be ints.  The box comes from stored JSON, so
        # coerce defensively in case the coordinates were saved as floats.
        bx, by = int(round(box["x"])), int(round(box["y"]))
        bw, bh = int(round(box["width"])), int(round(box["height"]))

        # Crop box region with small padding, clamped to the image bounds.
        y1 = max(0, by - pad)
        y2 = min(img_h, by + bh + pad)
        x1 = max(0, bx - pad)
        x2 = min(img_w, bx + bw + pad)
        crop = img[y1:y2, x1:x2]

        # A box outside the image yields an empty crop, which cannot be
        # encoded; skip it instead of aborting with a partial result.
        if crop.size == 0:
            logger.warning(f"Box {i} of session {session_id} is empty after clamping; skipping")
            continue

        success, png_buf = cv2.imencode(".png", crop)
        if not success:
            logger.warning(f"Failed to encode box {i} crop for session {session_id}")
            continue

        sub_id = str(uuid.uuid4())
        sub_name = f"{parent_name} — Box {i + 1}"

        await create_session_db(
            session_id=sub_id,
            name=sub_name,
            filename=filename,
            original_png=png_buf.tobytes(),
            parent_session_id=session_id,
            box_index=i,
        )

        # Cache the BGR for immediate processing.  The original is also
        # registered as cropped so column/row/word detection finds it.
        box_bgr = crop.copy()
        _cache[sub_id] = {
            "id": sub_id,
            "filename": filename,
            "name": sub_name,
            "parent_session_id": session_id,
            "original_bgr": box_bgr,
            "oriented_bgr": None,
            "cropped_bgr": box_bgr,
            "deskewed_bgr": None,
            "dewarped_bgr": None,
            "orientation_result": None,
            "crop_result": None,
            "deskew_result": None,
            "dewarp_result": None,
            "ground_truth": {},
            "current_step": 1,
        }

        created.append({
            "id": sub_id,
            "name": sub_name,
            "box_index": i,
            "box": box,
            "image_width": crop.shape[1],
            "image_height": crop.shape[0],
        })

        logger.info(f"Created box sub-session {sub_id} for session {session_id} "
                    f"(box {i}, {crop.shape[1]}x{crop.shape[0]})")

    return {
        "session_id": session_id,
        "sub_sessions": created,
        "total": len(created),
    }
|
||
|
||
|
||
@router.get("/sessions/{session_id}/thumbnail")
async def get_session_thumbnail(session_id: str, size: int = Query(default=80, ge=16, le=400)):
    """Return a small PNG thumbnail of the original image.

    The image is scaled so that its longer side equals ``size`` pixels;
    the shorter side keeps the aspect ratio but is never allowed to drop
    below 1 pixel.

    Raises:
        HTTPException: 404 if no original image is stored, 500 if the
            image cannot be decoded or the thumbnail cannot be encoded.
    """
    original_png = await get_session_image(session_id, "original")
    if not original_png:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found or no image")

    arr = np.frombuffer(original_png, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    h, w = img.shape[:2]
    scale = size / max(h, w)
    # For extreme aspect ratios int() would truncate the short side to 0,
    # which cv2.resize rejects; clamp both sides to at least one pixel.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    thumb = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

    encoded_ok, png_bytes = cv2.imencode(".png", thumb)
    if not encoded_ok:
        raise HTTPException(status_code=500, detail="Failed to encode image")

    return Response(
        content=png_bytes.tobytes(),
        media_type="image/png",
        headers={"Cache-Control": "public, max-age=3600"},
    )
|
||
|
||
|
||
@router.get("/sessions/{session_id}/pipeline-log")
async def get_pipeline_log(session_id: str):
    """Return the pipeline execution log of a session.

    A session without a stored log yields an empty ``{"steps": []}`` log.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    pipeline_log = session.get("pipeline_log") or {"steps": []}
    return {"session_id": session_id, "pipeline_log": pipeline_log}
|
||
|
||
|
||
@router.get("/categories")
async def list_categories():
    """Return the valid document categories in sorted order."""
    categories = sorted(VALID_DOCUMENT_CATEGORIES)
    return {"categories": categories}
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Pipeline Log Helper
|
||
# ---------------------------------------------------------------------------
|
||
|
||
async def _append_pipeline_log(
    session_id: str,
    step_name: str,
    metrics: Dict[str, Any],
    success: bool = True,
    duration_ms: Optional[int] = None,
):
    """Append one step entry to the session's stored pipeline log.

    Does nothing when the session does not exist.  A missing, empty or
    non-dict stored log is replaced by a fresh ``{"steps": []}`` log
    before the entry is appended.

    Args:
        session_id: Session whose log is extended.
        step_name: Name recorded under ``"step"``.
        metrics: Step metrics recorded under ``"metrics"``.
        success: Recorded under ``"success"``.
        duration_ms: Recorded under ``"duration_ms"`` only when given.
    """
    session = await get_session_db(session_id)
    if not session:
        return

    stored = session.get("pipeline_log")
    log = stored if stored and isinstance(stored, dict) else {"steps": []}

    entry = dict(
        step=step_name,
        completed_at=datetime.utcnow().isoformat(),
        success=success,
        metrics=metrics,
    )
    if duration_ms is not None:
        entry["duration_ms"] = duration_ms

    steps = log.setdefault("steps", [])
    steps.append(entry)

    await update_session_db(session_id, pipeline_log=log)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Image Endpoints
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@router.get("/sessions/{session_id}/image/{image_type}")
async def get_image(session_id: str, image_type: str):
    """Serve a session image as PNG.

    Accepted ``image_type`` values: original, oriented, cropped, deskewed,
    dewarped, binarized, clean, structure-overlay, columns-overlay,
    rows-overlay and words-overlay.  Overlay types are rendered by their
    dedicated helpers.  A binarized image is served from the cache when
    available.  Everything else is read from the DB, where ``cropped`` and
    ``dewarped`` fall back through cropped > dewarped > original.

    Raises:
        HTTPException: 400 for an unknown image type, 404 if the requested
            image is not available yet.
    """
    valid_types = {
        "original", "oriented", "cropped", "deskewed", "dewarped", "binarized",
        "structure-overlay", "columns-overlay", "rows-overlay", "words-overlay",
        "clean",
    }
    if image_type not in valid_types:
        raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")

    if image_type == "structure-overlay":
        return await _get_structure_overlay(session_id)
    if image_type == "columns-overlay":
        return await _get_columns_overlay(session_id)
    if image_type == "rows-overlay":
        return await _get_rows_overlay(session_id)
    if image_type == "words-overlay":
        return await _get_words_overlay(session_id)

    # The binarized PNG is the only image kept in the cache as encoded
    # bytes, so it is the only type that can be served without a DB read.
    if image_type == "binarized":
        cached = _cache.get(session_id)
        if cached and cached.get("binarized_png"):
            return Response(content=cached["binarized_png"], media_type="image/png")

    # Load from DB — for cropped/dewarped, fall back through the chain.
    if image_type in ("cropped", "dewarped"):
        data = await _get_base_image_png(session_id)
    else:
        data = await get_session_image(session_id, image_type)

    if not data:
        raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet")

    return Response(content=data, media_type="image/png")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Deskew Endpoints
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@router.post("/sessions/{session_id}/deskew")
async def auto_deskew(session_id: str):
    """Automatically deskew the oriented (or original) image of a session.

    The authoritative correction comes from ``deskew_two_pass``.  Its debug
    info is read for the iterative, residual and text-line pass angles,
    which also determine the reported ``method_used``.  The Hough and
    word-alignment angles are computed for reporting only and default to
    0.0 when their method fails.  The deskewed image, a binarized version
    (best effort) and the result dict are written to cache and DB, and the
    session advances to step 3.

    Raises:
        HTTPException: 404 if the session is unknown, 400 if no input
            image is available, 500 if the deskewed image cannot be
            PNG-encoded.
    """
    # Ensure session is in cache
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    # Deskew runs right after orientation — use oriented image, fall back to original
    img_bgr = cached.get("oriented_bgr")
    if img_bgr is None:
        img_bgr = cached.get("original_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="No image available for deskewing")

    t0 = time.time()

    # Authoritative result.
    deskewed_bgr, angle_applied, two_pass_debug = deskew_two_pass(img_bgr.copy())

    # Also run individual methods for reporting (non-authoritative); a
    # failure here must never fail the request.
    try:
        _, angle_hough = deskew_image(img_bgr.copy())
    except Exception as e:
        logger.debug(f"Hough deskew (reporting only) failed: {e}")
        angle_hough = 0.0

    success_enc, png_orig = cv2.imencode(".png", img_bgr)
    orig_bytes = png_orig.tobytes() if success_enc else b""
    try:
        _, angle_wa = deskew_image_by_word_alignment(orig_bytes)
    except Exception as e:
        logger.debug(f"Word-alignment deskew (reporting only) failed: {e}")
        angle_wa = 0.0

    angle_iterative = two_pass_debug.get("pass1_angle", 0.0)
    angle_residual = two_pass_debug.get("pass2_angle", 0.0)
    angle_textline = two_pass_debug.get("pass3_angle", 0.0)

    duration = time.time() - t0

    # Name the result after the last pass that contributed a non-negligible angle.
    if abs(angle_textline) >= 0.01:
        method_used = "three_pass"
    elif abs(angle_residual) >= 0.01:
        method_used = "two_pass"
    else:
        method_used = "iterative"

    # Encode as PNG.  Fail loudly instead of persisting an empty image and
    # advancing the session to the next step.
    success, deskewed_png_buf = cv2.imencode(".png", deskewed_bgr)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode deskewed image")
    deskewed_png = deskewed_png_buf.tobytes()

    # Create binarized version (best effort)
    binarized_png = None
    try:
        binarized = create_ocr_image(deskewed_bgr)
        success_bin, bin_buf = cv2.imencode(".png", binarized)
        binarized_png = bin_buf.tobytes() if success_bin else None
    except Exception as e:
        logger.warning(f"Binarization failed: {e}")

    # Heuristic: confidence falls linearly with the applied angle, floored at 0.5.
    confidence = max(0.5, 1.0 - abs(angle_applied) / 5.0)

    deskew_result = {
        "angle_hough": round(angle_hough, 3),
        "angle_word_alignment": round(angle_wa, 3),
        "angle_iterative": round(angle_iterative, 3),
        "angle_residual": round(angle_residual, 3),
        "angle_textline": round(angle_textline, 3),
        "angle_applied": round(angle_applied, 3),
        "method_used": method_used,
        "confidence": round(confidence, 2),
        "duration_seconds": round(duration, 2),
        "two_pass_debug": two_pass_debug,
    }

    # Update cache
    cached["deskewed_bgr"] = deskewed_bgr
    cached["binarized_png"] = binarized_png
    cached["deskew_result"] = deskew_result

    # Persist to DB
    db_update = {
        "deskewed_png": deskewed_png,
        "deskew_result": deskew_result,
        "current_step": 3,
    }
    if binarized_png:
        db_update["binarized_png"] = binarized_png
    await update_session_db(session_id, **db_update)

    logger.info(f"OCR Pipeline: deskew session {session_id}: "
                f"hough={angle_hough:.2f} wa={angle_wa:.2f} "
                f"iter={angle_iterative:.2f} residual={angle_residual:.2f} "
                f"textline={angle_textline:.2f} "
                f"-> {method_used} total={angle_applied:.2f}")

    await _append_pipeline_log(session_id, "deskew", {
        "angle_applied": round(angle_applied, 3),
        "angle_iterative": round(angle_iterative, 3),
        "angle_residual": round(angle_residual, 3),
        "angle_textline": round(angle_textline, 3),
        "confidence": round(confidence, 2),
        "method": method_used,
    }, duration_ms=int(duration * 1000))

    return {
        "session_id": session_id,
        **deskew_result,
        "deskewed_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed",
        "binarized_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/binarized",
    }
|
||
|
||
|
||
@router.post("/sessions/{session_id}/deskew/manual")
async def manual_deskew(session_id: str, req: ManualDeskewRequest):
    """Apply a manual rotation angle to the oriented (or original) image.

    The requested angle is clamped to [-5.0, 5.0] and the image is rotated
    around its centre with replicated borders.  The rotated image, a
    binarized version (best effort) and the updated deskew result are
    written to cache and DB.

    Raises:
        HTTPException: 404 if the session is unknown, 400 if no input
            image is available, 500 if the rotated image cannot be
            PNG-encoded.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    img_bgr = cached.get("oriented_bgr")
    if img_bgr is None:
        img_bgr = cached.get("original_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="No image available for deskewing")

    angle = max(-5.0, min(5.0, req.angle))

    h, w = img_bgr.shape[:2]
    center = (w // 2, h // 2)
    M = cv2.getRotationMatrix2D(center, angle, 1.0)
    rotated = cv2.warpAffine(img_bgr, M, (w, h),
                             flags=cv2.INTER_LINEAR,
                             borderMode=cv2.BORDER_REPLICATE)

    # Fail loudly instead of persisting an empty image.
    success, png_buf = cv2.imencode(".png", rotated)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode deskewed image")
    deskewed_png = png_buf.tobytes()

    # Binarize (best effort); log failures the same way auto_deskew does
    # instead of swallowing them silently.
    binarized_png = None
    try:
        binarized = create_ocr_image(rotated)
        success_bin, bin_buf = cv2.imencode(".png", binarized)
        binarized_png = bin_buf.tobytes() if success_bin else None
    except Exception as e:
        logger.warning(f"Binarization failed: {e}")

    # Keep the angles reported by a previous automatic run, but override
    # what was actually applied.
    deskew_result = {
        **(cached.get("deskew_result") or {}),
        "angle_applied": round(angle, 3),
        "method_used": "manual",
    }

    # Update cache
    cached["deskewed_bgr"] = rotated
    cached["binarized_png"] = binarized_png
    cached["deskew_result"] = deskew_result

    # Persist to DB
    db_update = {
        "deskewed_png": deskewed_png,
        "deskew_result": deskew_result,
    }
    if binarized_png:
        db_update["binarized_png"] = binarized_png
    await update_session_db(session_id, **db_update)

    logger.info(f"OCR Pipeline: manual deskew session {session_id}: {angle:.2f}")

    return {
        "session_id": session_id,
        "angle_applied": round(angle, 3),
        "method_used": "manual",
        "deskewed_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/deskewed",
    }
|
||
|
||
|
||
@router.post("/sessions/{session_id}/ground-truth/deskew")
async def save_deskew_ground_truth(session_id: str, req: DeskewGroundTruthRequest):
    """Store reviewer feedback for the deskew step.

    The feedback is saved under the ``"deskew"`` key of the session's
    ground truth together with a snapshot of the current deskew result.
    A cached copy of the session, if any, is kept in sync.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    feedback = dict(
        is_correct=req.is_correct,
        corrected_angle=req.corrected_angle,
        notes=req.notes,
        saved_at=datetime.utcnow().isoformat(),
        deskew_result=session.get("deskew_result"),
    )

    ground_truth = session.get("ground_truth") or {}
    ground_truth["deskew"] = feedback

    await update_session_db(session_id, ground_truth=ground_truth)

    # Mirror the change into the cache entry when the session is loaded.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    logger.info(f"OCR Pipeline: ground truth deskew session {session_id}: "
                f"correct={req.is_correct}, corrected_angle={req.corrected_angle}")

    return {"session_id": session_id, "ground_truth": feedback}
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Dewarp Endpoints
|
||
# ---------------------------------------------------------------------------
|
||
|
||
async def _detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]:
    """Ask a vision-language model to estimate the vertical shear of a page.

    The image is sent to Ollama's ``/api/generate`` endpoint with a prompt
    asking whether the column/table borders are tilted and by how many
    degrees.  The server URL comes from ``OLLAMA_BASE_URL`` and the model
    from ``OLLAMA_HTR_MODEL`` (default ``qwen2.5vl:32b``).

    Returns:
        A dict with ``method``, ``shear_degrees`` (clamped to [-3, 3]) and
        ``confidence`` (clamped to [0, 1]).  Confidence and shear are 0.0
        when Ollama is unavailable, the reply contains no parsable JSON
        object, or the reported numbers are not finite.
    """
    import base64
    import math

    import httpx

    ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
    model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")

    prompt = (
        "This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. "
        "Are they perfectly vertical, or do they tilt slightly? "
        "If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). "
        "Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} "
        "Use confidence 0.0-1.0 based on how clearly you can see the tilt. "
        "If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}"
    )

    img_b64 = base64.b64encode(image_bytes).decode("utf-8")
    payload = {
        "model": model,
        "prompt": prompt,
        "images": [img_b64],
        "stream": False,
    }

    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.post(f"{ollama_base}/api/generate", json=payload)
            resp.raise_for_status()
            text = resp.json().get("response", "")

        # Parse JSON from response (may have surrounding text)
        match = re.search(r'\{[^}]+\}', text)
        if match:
            data = json.loads(match.group(0))
            shear = float(data.get("shear_degrees", 0.0))
            conf = float(data.get("confidence", 0.0))
            # json.loads accepts NaN/Infinity, and NaN slips through the
            # min/max clamps below as the upper bound (shear 3.0,
            # confidence 1.0).  Treat non-finite numbers as unparsable.
            if math.isfinite(shear) and math.isfinite(conf):
                # Clamp to reasonable range
                shear = max(-3.0, min(3.0, shear))
                conf = max(0.0, min(1.0, conf))
                return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)}
            logger.warning(f"VLM dewarp returned non-finite values: {data}")
    except Exception as e:
        # Best effort: any transport or parsing problem means "no estimate".
        logger.warning(f"VLM dewarp failed: {e}")

    return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0}
|
||
|
||
|
||
@router.post("/sessions/{session_id}/dewarp")
async def auto_dewarp(
    session_id: str,
    method: str = Query("ensemble", description="Detection method: ensemble | vlm | cv"),
):
    """Detect and correct vertical shear on the deskewed image.

    Methods:
    - **ensemble** (default): 3-method CV ensemble (vertical edges + projection + Hough)
    - **cv**: CV ensemble only (same as ensemble)
    - **vlm**: Ask qwen2.5vl:32b to estimate the shear angle visually

    Returns the dewarp result (method, shear, confidence, duration,
    detections) together with the URL of the dewarped image.

    Raises:
        HTTPException: 400 for an unknown method or when deskew has not run
            yet; 500 when an image cannot be PNG-encoded.
    """
    if method not in ("ensemble", "cv", "vlm"):
        raise HTTPException(status_code=400, detail="method must be one of: ensemble, cv, vlm")

    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    deskewed_bgr = cached.get("deskewed_bgr")
    if deskewed_bgr is None:
        raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")

    t0 = time.time()

    if method == "vlm":
        # Encode deskewed image to PNG for VLM
        success, png_buf = cv2.imencode(".png", deskewed_bgr)
        if not success:
            # Sending an empty image would only waste a model round trip.
            raise HTTPException(status_code=500, detail="Failed to encode deskewed image")
        vlm_det = await _detect_shear_with_vlm(png_buf.tobytes())
        shear_deg = vlm_det["shear_degrees"]
        # Only correct when the estimate is both noticeable and trusted.
        if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3:
            dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg)
        else:
            dewarped_bgr = deskewed_bgr
        dewarp_info = {
            "method": vlm_det["method"],
            "shear_degrees": shear_deg,
            "confidence": vlm_det["confidence"],
            "detections": [vlm_det],
        }
    else:
        dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr)

    duration = time.time() - t0

    # Encode as PNG. Fail loudly instead of persisting an empty blob that
    # would overwrite any previously stored dewarped image.
    success, png_buf = cv2.imencode(".png", dewarped_bgr)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode dewarped image")
    dewarped_png = png_buf.tobytes()

    dewarp_result = {
        "method_used": dewarp_info["method"],
        "shear_degrees": dewarp_info["shear_degrees"],
        "confidence": dewarp_info["confidence"],
        "duration_seconds": round(duration, 2),
        "detections": dewarp_info.get("detections", []),
    }

    # Update cache
    cached["dewarped_bgr"] = dewarped_bgr
    cached["dewarp_result"] = dewarp_result

    # Persist to DB
    await update_session_db(
        session_id,
        dewarped_png=dewarped_png,
        dewarp_result=dewarp_result,
        auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
        current_step=4,
    )

    logger.info(f"OCR Pipeline: dewarp session {session_id}: "
                f"method={dewarp_info['method']} shear={dewarp_info['shear_degrees']:.3f} "
                f"conf={dewarp_info['confidence']:.2f} ({duration:.2f}s)")

    await _append_pipeline_log(session_id, "dewarp", {
        "shear_degrees": dewarp_info["shear_degrees"],
        "confidence": dewarp_info["confidence"],
        "method": dewarp_info["method"],
        "ensemble_methods": [d.get("method", "") for d in dewarp_info.get("detections", [])],
    }, duration_ms=int(duration * 1000))

    return {
        "session_id": session_id,
        **dewarp_result,
        "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
    }
|
||
|
||
|
||
@router.post("/sessions/{session_id}/dewarp/manual")
async def manual_dewarp(session_id: str, req: ManualDewarpRequest):
    """Apply shear correction with a manual angle.

    The requested angle is clamped to [-2.0, 2.0] degrees; angles below
    0.001 degrees in magnitude leave the deskewed image unchanged.

    Raises:
        HTTPException: 400 when deskew has not run yet; 500 when the result
            cannot be PNG-encoded.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    deskewed_bgr = cached.get("deskewed_bgr")
    if deskewed_bgr is None:
        raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")

    shear_deg = max(-2.0, min(2.0, req.shear_degrees))

    if abs(shear_deg) < 0.001:
        dewarped_bgr = deskewed_bgr
    else:
        dewarped_bgr = dewarp_image_manual(deskewed_bgr, shear_deg)

    # Fail before touching cache/DB: persisting an empty blob would overwrite
    # the previously stored dewarped image with unusable data.
    success, png_buf = cv2.imencode(".png", dewarped_bgr)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode dewarped image")
    dewarped_png = png_buf.tobytes()

    # Keep fields of an earlier result and override the manual ones.
    dewarp_result = {
        **(cached.get("dewarp_result") or {}),
        "method_used": "manual",
        "shear_degrees": round(shear_deg, 3),
    }

    # Update cache
    cached["dewarped_bgr"] = dewarped_bgr
    cached["dewarp_result"] = dewarp_result

    # Persist to DB
    await update_session_db(
        session_id,
        dewarped_png=dewarped_png,
        dewarp_result=dewarp_result,
    )

    logger.info(f"OCR Pipeline: manual dewarp session {session_id}: shear={shear_deg:.3f}")

    return {
        "session_id": session_id,
        "shear_degrees": round(shear_deg, 3),
        "method_used": "manual",
        "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
    }
|
||
|
||
|
||
@router.post("/sessions/{session_id}/adjust-combined")
async def adjust_combined(session_id: str, req: CombinedAdjustRequest):
    """Apply rotation + shear combined to the original image.

    Used by the fine-tuning sliders to preview arbitrary rotation/shear
    combinations without re-running the full deskew/dewarp pipeline.

    Rotation is clamped to [-15, 15] degrees and shear to [-5, 5] degrees.
    The result replaces both the deskewed and the dewarped image of the
    session (cache and DB).

    Raises:
        HTTPException: 400 when the original image is not available; 500
            when the adjusted image cannot be PNG-encoded.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    img_bgr = cached.get("original_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Original image not available")

    rotation = max(-15.0, min(15.0, req.rotation_degrees))
    shear_deg = max(-5.0, min(5.0, req.shear_degrees))

    h, w = img_bgr.shape[:2]
    result_bgr = img_bgr

    # Step 1: Apply rotation
    if abs(rotation) >= 0.001:
        center = (w // 2, h // 2)
        M = cv2.getRotationMatrix2D(center, rotation, 1.0)
        result_bgr = cv2.warpAffine(result_bgr, M, (w, h),
                                    flags=cv2.INTER_LINEAR,
                                    borderMode=cv2.BORDER_REPLICATE)

    # Step 2: Apply shear
    if abs(shear_deg) >= 0.001:
        result_bgr = dewarp_image_manual(result_bgr, shear_deg)

    # Encode. Fail before touching cache/DB rather than persisting an empty blob.
    success, png_buf = cv2.imencode(".png", result_bgr)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode adjusted image")
    dewarped_png = png_buf.tobytes()

    # Binarize (best effort: the adjustment still succeeds without it, but
    # the failure is logged instead of being swallowed silently)
    binarized_png = None
    try:
        binarized = create_ocr_image(result_bgr)
        success_bin, bin_buf = cv2.imencode(".png", binarized)
        binarized_png = bin_buf.tobytes() if success_bin else None
    except Exception as e:
        logger.warning(f"OCR Pipeline: combined adjust session {session_id}: "
                       f"binarization failed: {e}")

    # Build combined result dicts
    deskew_result = {
        **(cached.get("deskew_result") or {}),
        "angle_applied": round(rotation, 3),
        "method_used": "manual_combined",
    }
    dewarp_result = {
        **(cached.get("dewarp_result") or {}),
        "method_used": "manual_combined",
        "shear_degrees": round(shear_deg, 3),
    }

    # Update cache
    cached["deskewed_bgr"] = result_bgr
    cached["dewarped_bgr"] = result_bgr
    cached["deskew_result"] = deskew_result
    cached["dewarp_result"] = dewarp_result

    # Persist to DB. The deskewed image is always written so the stored
    # state matches the cache, which replaces deskewed_bgr unconditionally.
    db_update = {
        "dewarped_png": dewarped_png,
        "deskewed_png": dewarped_png,
        "deskew_result": deskew_result,
        "dewarp_result": dewarp_result,
    }
    if binarized_png:
        db_update["binarized_png"] = binarized_png
    await update_session_db(session_id, **db_update)

    logger.info(f"OCR Pipeline: combined adjust session {session_id}: "
                f"rotation={rotation:.3f} shear={shear_deg:.3f}")

    return {
        "session_id": session_id,
        "rotation_degrees": round(rotation, 3),
        "shear_degrees": round(shear_deg, 3),
        "method_used": "manual_combined",
        "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
    }
|
||
|
||
|
||
@router.post("/sessions/{session_id}/ground-truth/dewarp")
async def save_dewarp_ground_truth(session_id: str, req: DewarpGroundTruthRequest):
    """Store reviewer feedback for the dewarp step.

    The feedback is saved under the ``"dewarp"`` key of the session's
    ground truth, together with a snapshot of the current dewarp result.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # Feedback entry: the reviewer's verdict plus what the pipeline produced.
    gt = {
        "is_correct": req.is_correct,
        "corrected_shear": req.corrected_shear,
        "notes": req.notes,
        "saved_at": datetime.utcnow().isoformat(),
        "dewarp_result": session.get("dewarp_result"),
    }

    ground_truth = session.get("ground_truth") or {}
    ground_truth["dewarp"] = gt

    await update_session_db(session_id, ground_truth=ground_truth)

    # Keep an already loaded cache entry in sync with the DB.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    logger.info(f"OCR Pipeline: ground truth dewarp session {session_id}: "
                f"correct={req.is_correct}, corrected_shear={req.corrected_shear}")

    return {"session_id": session_id, "ground_truth": gt}
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Document Type Detection (between Dewarp and Columns)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@router.post("/sessions/{session_id}/detect-type")
async def detect_type(session_id: str):
    """Detect document type (vocab_table, full_text, generic_table).

    Should be called after crop (clean image available).
    Falls back to dewarped if crop was skipped.
    Stores result in session for frontend to decide pipeline flow.

    Raises:
        HTTPException: 400 when neither a cropped nor a dewarped image exists.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    # Prefer the cropped image; use the dewarped one when crop was skipped.
    img_bgr = cached.get("cropped_bgr")
    if img_bgr is None:
        img_bgr = cached.get("dewarped_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first")

    started = time.time()
    binarized = create_ocr_image(img_bgr)
    result = detect_document_type(binarized, img_bgr)
    duration = time.time() - started

    result_dict = {
        "doc_type": result.doc_type,
        "confidence": result.confidence,
        "pipeline": result.pipeline,
        "skip_steps": result.skip_steps,
        "features": result.features,
        "duration_seconds": round(duration, 2),
    }

    # Persist to DB
    await update_session_db(
        session_id,
        doc_type=result.doc_type,
        doc_type_result=result_dict,
    )

    cached["doc_type_result"] = result_dict

    logger.info(f"OCR Pipeline: detect-type session {session_id}: "
                f"{result.doc_type} (confidence={result.confidence}, {duration:.2f}s)")

    # Only scalar features go into the pipeline log.
    scalar_features = {
        key: value
        for key, value in (result.features or {}).items()
        if isinstance(value, (int, float, str, bool))
    }
    log_payload = {
        "doc_type": result.doc_type,
        "pipeline": result.pipeline,
        "confidence": result.confidence,
        **scalar_features,
    }
    await _append_pipeline_log(session_id, "detect_type", log_payload,
                               duration_ms=int(duration * 1000))

    return {"session_id": session_id, **result_dict}
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Structure Detection Endpoint
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@router.post("/sessions/{session_id}/detect-structure")
async def detect_structure(session_id: str):
    """Detect document structure: boxes, zones, and color regions.

    Runs box detection (line + shading) and color analysis on the cropped
    image (dewarped image when no crop exists). Returns structured JSON with
    all detected elements for the structure visualization step and stores
    the same data as ``structure_result`` on the session.

    Raises:
        HTTPException: 400 when neither a cropped nor a dewarped image exists.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    img_bgr = cached.get("cropped_bgr")
    if img_bgr is None:
        img_bgr = cached.get("dewarped_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first")

    t0 = time.time()
    h, w = img_bgr.shape[:2]

    # --- Content bounds from word result (if available) or full image ---
    word_result = cached.get("word_result")
    words: List[Dict] = []
    if word_result and word_result.get("cells"):
        words = [
            wb
            for cell in word_result["cells"]
            for wb in (cell.get("word_boxes") or [])
        ]

    if words:
        # Bounding rectangle of all word boxes, clipped to the image
        content_x = max(0, min(int(wb["left"]) for wb in words))
        content_y = max(0, min(int(wb["top"]) for wb in words))
        content_r = min(w, max(int(wb["left"] + wb["width"]) for wb in words))
        content_b = min(h, max(int(wb["top"] + wb["height"]) for wb in words))
        content_w_px = content_r - content_x
        content_h_px = content_b - content_y
    else:
        # If no words yet, use image dimensions with small margin
        margin = int(min(w, h) * 0.03)
        content_x, content_y = margin, margin
        content_w_px = w - 2 * margin
        content_h_px = h - 2 * margin

    # --- Box detection ---
    boxes = detect_boxes(
        img_bgr,
        content_x=content_x,
        content_w=content_w_px,
        content_y=content_y,
        content_h=content_h_px,
    )

    # --- Zone splitting ---
    from cv_box_detect import split_page_into_zones as _split_zones
    zones = _split_zones(content_x, content_y, content_w_px, content_h_px, boxes)

    # --- Color region sampling ---
    # Sample background shading in each detected box
    from cv_color_detect import _hue_to_color_name  # imported once, not per box

    hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)
    box_colors = []
    for box in boxes:
        # Sample the center region of each box, clipped to the image
        cy1 = max(0, min(box.y + box.height // 4, h - 1))
        cy2 = max(0, min(box.y + 3 * box.height // 4, h - 1))
        cx1 = max(0, min(box.x + box.width // 4, w - 1))
        cx2 = max(0, min(box.x + 3 * box.width // 4, w - 1))

        bg_name = "unknown"
        bg_hex = "#6b7280"
        if cy2 > cy1 and cx2 > cx1:
            roi_hsv = hsv[cy1:cy2, cx1:cx2]
            med_h = float(np.median(roi_hsv[:, :, 0]))
            med_s = float(np.median(roi_hsv[:, :, 1]))
            med_v = float(np.median(roi_hsv[:, :, 2]))
            if med_s > 15:
                # Saturated enough to name the hue
                bg_name = _hue_to_color_name(med_h)
                bg_hex = _COLOR_HEX.get(bg_name, "#6b7280")
            else:
                # Unsaturated: gray vs. white decided by brightness
                bg_name = "gray" if med_v < 220 else "white"
                bg_hex = "#6b7280" if bg_name == "gray" else "#ffffff"
        box_colors.append({"color_name": bg_name, "color_hex": bg_hex})

    # --- Color text detection overview ---
    # Quick scan for colored text regions across the page
    color_summary: Dict[str, int] = {}
    for color_name, ranges in _COLOR_RANGES.items():
        mask = np.zeros((h, w), dtype=np.uint8)
        for lower, upper in ranges:
            mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper))
        pixel_count = int(cv2.countNonZero(mask))
        if pixel_count > 50:  # minimum threshold
            color_summary[color_name] = pixel_count

    # --- Graphic element detection ---
    box_dicts = [
        {"x": b.x, "y": b.y, "w": b.width, "h": b.height}
        for b in boxes
    ]
    graphics = detect_graphic_elements(
        img_bgr, words,
        detected_boxes=box_dicts,
    )

    duration = time.time() - t0

    result_dict = {
        "image_width": w,
        "image_height": h,
        "content_bounds": {
            "x": content_x, "y": content_y,
            "w": content_w_px, "h": content_h_px,
        },
        "boxes": [
            {
                "x": b.x, "y": b.y, "w": b.width, "h": b.height,
                "confidence": b.confidence,
                "border_thickness": b.border_thickness,
                "bg_color_name": bg["color_name"],
                "bg_color_hex": bg["color_hex"],
            }
            for b, bg in zip(boxes, box_colors)
        ],
        "zones": [
            {
                "index": z.index,
                "zone_type": z.zone_type,
                "y": z.y, "h": z.height,
                "x": z.x, "w": z.width,
            }
            for z in zones
        ],
        "graphics": [
            {
                "x": g.x, "y": g.y, "w": g.width, "h": g.height,
                "area": g.area,
                "shape": g.shape,
                "color_name": g.color_name,
                "color_hex": g.color_hex,
                "confidence": round(g.confidence, 2),
            }
            for g in graphics
        ],
        "color_pixel_counts": color_summary,
        "has_words": len(words) > 0,
        "word_count": len(words),
        "duration_seconds": round(duration, 2),
    }

    # Persist to session
    await update_session_db(session_id, structure_result=result_dict)
    cached["structure_result"] = result_dict

    logger.info("detect-structure session %s: %d boxes, %d zones, %d graphics, %.2fs",
                session_id, len(boxes), len(zones), len(graphics), duration)

    return {"session_id": session_id, **result_dict}
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Column Detection Endpoints (Step 5)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _tesseract_words_for_sub_session(img_bgr: np.ndarray, session_id: str) -> List[Dict[str, Any]]:
    """Run Tesseract on a sub-session image and return its confident words.

    Args:
        img_bgr: BGR image of the sub-session (converted to RGB for Tesseract).
        session_id: Only used in log messages.

    Returns:
        One dict per non-empty word with confidence >= 30, holding ``text``,
        ``conf``, ``left``, ``top``, ``width`` and ``height``. An empty list
        is returned when Tesseract (or its imports) fail; the failure is
        logged as a warning.
    """
    try:
        from PIL import Image as PILImage
        import pytesseract

        pil_img = PILImage.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
        data = pytesseract.image_to_data(pil_img, lang='eng+deu', output_type=pytesseract.Output.DICT)

        word_dicts: List[Dict[str, Any]] = []
        low_conf = []   # (text, conf) of words with 0 <= conf < 30, for debugging
        all_count = 0   # every non-empty word, regardless of confidence
        for i, raw_text in enumerate(data['text']):
            raw_conf = data['conf'][i]
            # Confidence is parsed once per word; non-numeric values count as -1.
            conf = int(raw_conf) if str(raw_conf).lstrip('-').isdigit() else -1
            text = str(raw_text).strip()
            if not text:
                continue
            all_count += 1
            if conf >= 30:
                word_dicts.append({
                    'text': text, 'conf': conf,
                    'left': int(data['left'][i]),
                    'top': int(data['top'][i]),
                    'width': int(data['width'][i]),
                    'height': int(data['height'][i]),
                })
            elif conf >= 0:
                low_conf.append((text, conf))

        # Log all words including low-confidence ones for debugging
        if low_conf:
            logger.info(f"OCR Pipeline: sub-session {session_id}: {len(low_conf)} words below conf 30: {low_conf[:20]}")
        logger.info(f"OCR Pipeline: sub-session {session_id}: Tesseract found {len(word_dicts)}/{all_count} words (conf>=30)")
        return word_dicts
    except Exception as e:
        logger.warning(f"OCR Pipeline: sub-session {session_id}: Tesseract failed: {e}")
        return []


@router.post("/sessions/{session_id}/columns")
async def detect_columns(session_id: str):
    """Run column detection on the cropped (or dewarped) image.

    Sub-sessions (sessions with a ``parent_session_id``) skip detection and
    get a single pseudo-column covering the whole image. For all other
    sessions the zone-aware detector runs first, with the projection-based
    layout analysis as fallback. Stored row and word results are cleared in
    both cases because they depend on the columns.

    Raises:
        HTTPException: 400 when neither a cropped nor a dewarped image exists.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    img_bgr = cached.get("cropped_bgr")
    if img_bgr is None:
        img_bgr = cached.get("dewarped_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before column detection")

    # -----------------------------------------------------------------------
    # Sub-sessions (box crops): skip column detection entirely.
    # Instead, create a single pseudo-column spanning the full image width.
    # Also run Tesseract + binarization here so that the row detection step
    # can reuse the cached intermediates (_word_dicts, _inv, _content_bounds)
    # instead of falling back to detect_column_geometry() which may fail
    # on small box images with < 5 words.
    # -----------------------------------------------------------------------
    session = await get_session_db(session_id)
    if session and session.get("parent_session_id"):
        h, w = img_bgr.shape[:2]

        # Binarize + invert for row detection (horizontal projection profile)
        ocr_img = create_ocr_image(img_bgr)
        inv = cv2.bitwise_not(ocr_img)

        # Run Tesseract to get word bounding boxes.
        # Word positions are relative to the full image (no ROI crop needed
        # because the sub-session image IS the cropped box already).
        # detect_row_geometry expects word positions relative to content ROI,
        # so with content_bounds = (0, w, 0, h) the coordinates are correct.
        word_dicts = _tesseract_words_for_sub_session(img_bgr, session_id)

        # Cache intermediates for row detection (detect_rows reuses these)
        cached["_word_dicts"] = word_dicts
        cached["_inv"] = inv
        cached["_content_bounds"] = (0, w, 0, h)

        column_result = {
            "columns": [{
                "type": "column_text",
                "x": 0, "y": 0,
                "width": w, "height": h,
            }],
            "zones": None,
            "boxes_detected": 0,
            "duration_seconds": 0,
            "method": "sub_session_pseudo_column",
        }
        await update_session_db(
            session_id,
            column_result=column_result,
            row_result=None,
            word_result=None,
            current_step=6,
        )
        cached["column_result"] = column_result
        cached.pop("row_result", None)
        cached.pop("word_result", None)
        logger.info(f"OCR Pipeline: sub-session {session_id}: pseudo-column {w}x{h}px")
        return {"session_id": session_id, **column_result}

    t0 = time.time()

    # Binarized image for layout analysis
    ocr_img = create_ocr_image(img_bgr)
    h, w = ocr_img.shape[:2]

    # Phase A: Zone-aware geometry detection
    zoned_result = detect_column_geometry_zoned(ocr_img, img_bgr)

    if zoned_result is None:
        # Fallback to projection-based layout
        layout_img = create_layout_image(img_bgr)
        regions = analyze_layout(layout_img, ocr_img)
        zones_data = None
        boxes_detected = 0
    else:
        geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv, zones_data, boxes = zoned_result
        content_w = right_x - left_x
        boxes_detected = len(boxes)

        # Cache intermediates for row detection (avoids second Tesseract run)
        cached["_word_dicts"] = word_dicts
        cached["_inv"] = inv
        cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
        cached["_zones_data"] = zones_data
        cached["_boxes_detected"] = boxes_detected

        # Detect header/footer early so sub-column clustering ignores them
        header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None)

        # Split sub-columns (e.g. page references) before classification
        geometries = _detect_sub_columns(geometries, content_w, left_x=left_x,
                                         top_y=top_y, header_y=header_y, footer_y=footer_y)

        # Expand narrow columns (sub-columns are often very narrow)
        geometries = expand_narrow_columns(geometries, content_w, left_x, word_dicts)

        # Phase B: Content-based classification
        regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y,
                                        left_x=left_x, right_x=right_x, inv=inv)

    duration = time.time() - t0

    columns = [asdict(r) for r in regions]

    # Determine classification methods used
    methods = list({
        c.get("classification_method", "") for c in columns
        if c.get("classification_method")
    })

    column_result = {
        "columns": columns,
        "classification_methods": methods,
        "duration_seconds": round(duration, 2),
        "boxes_detected": boxes_detected,
    }

    # Add zone data when boxes are present
    if zones_data and boxes_detected > 0:
        column_result["zones"] = zones_data

    # Persist to DB — also invalidate downstream results (rows, words)
    await update_session_db(
        session_id,
        column_result=column_result,
        row_result=None,
        word_result=None,
        current_step=6,
    )

    # Update cache
    cached["column_result"] = column_result
    cached.pop("row_result", None)
    cached.pop("word_result", None)

    col_count = sum(1 for c in columns if c["type"].startswith("column"))
    logger.info(f"OCR Pipeline: columns session {session_id}: "
                f"{col_count} columns detected, {boxes_detected} box(es) ({duration:.2f}s)")

    img_w = img_bgr.shape[1]
    await _append_pipeline_log(session_id, "columns", {
        "total_columns": len(columns),
        "column_widths_pct": [round(c["width"] / img_w * 100, 1) for c in columns],
        "column_types": [c["type"] for c in columns],
        "boxes_detected": boxes_detected,
    }, duration_ms=int(duration * 1000))

    return {
        "session_id": session_id,
        **column_result,
    }
|
||
|
||
|
||
@router.post("/sessions/{session_id}/columns/manual")
async def set_manual_columns(session_id: str, req: ManualColumnsRequest):
    """Replace the detected columns with the columns supplied in the request.

    Stored row and word results are cleared (DB and cache) because they
    were derived from the previous columns.
    """
    manual_columns = req.columns
    column_result = {
        "columns": manual_columns,
        "duration_seconds": 0,
        "method": "manual",
    }

    await update_session_db(
        session_id,
        column_result=column_result,
        row_result=None,
        word_result=None,
    )

    # Mirror the change in the cache, but only for sessions already loaded.
    if session_id in _cache:
        entry = _cache[session_id]
        entry["column_result"] = column_result
        for stale_key in ("row_result", "word_result"):
            entry.pop(stale_key, None)

    logger.info(f"OCR Pipeline: manual columns session {session_id}: "
                f"{len(req.columns)} columns set")

    return {"session_id": session_id, **column_result}
|
||
|
||
|
||
@router.post("/sessions/{session_id}/ground-truth/columns")
async def save_column_ground_truth(session_id: str, req: ColumnGroundTruthRequest):
    """Store reviewer feedback for the column detection step.

    The feedback is saved under the ``"columns"`` key of the session's
    ground truth, together with a snapshot of the current column result.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # Feedback entry: the reviewer's verdict plus what the pipeline produced.
    gt = {
        "is_correct": req.is_correct,
        "corrected_columns": req.corrected_columns,
        "notes": req.notes,
        "saved_at": datetime.utcnow().isoformat(),
        "column_result": session.get("column_result"),
    }

    ground_truth = session.get("ground_truth") or {}
    ground_truth["columns"] = gt

    await update_session_db(session_id, ground_truth=ground_truth)

    # Keep an already loaded cache entry in sync with the DB.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    return {"session_id": session_id, "ground_truth": gt}
|
||
|
||
|
||
@router.get("/sessions/{session_id}/ground-truth/columns")
async def get_column_ground_truth(session_id: str):
    """Return the saved column ground truth next to the automatic result.

    Raises:
        HTTPException: 404 when the session does not exist or no column
            ground truth has been saved for it.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    columns_gt = (session.get("ground_truth") or {}).get("columns")
    if not columns_gt:
        raise HTTPException(status_code=404, detail="No column ground truth saved")

    columns_auto = session.get("column_result")
    return {
        "session_id": session_id,
        "columns_gt": columns_gt,
        "columns_auto": columns_auto,
    }
|
||
|
||
|
||
def _draw_box_exclusion_overlay(
    img: np.ndarray,
    zones: List[Dict],
    *,
    label: str = "BOX — separat verarbeitet",
) -> None:
    """Mark every box zone on ``img`` with a red tint, border and label (in-place).

    Zones whose ``zone_type`` is not ``"box"`` or that carry no ``box``
    entry are ignored. Reusable for columns, rows, and words overlays.

    Args:
        img: Image to draw on; modified in place.
        zones: Zone dicts; box zones hold ``x``/``y``/``width``/``height``
            under the ``box`` key.
        label: Text written near the bottom-left corner of each box.
    """
    red = (0, 0, 200)  # BGR

    box_zones = (
        zone for zone in zones
        if zone.get("zone_type") == "box" and zone.get("box")
    )
    for zone in box_zones:
        rect = zone["box"]
        left, top = rect["x"], rect["y"]
        right = left + rect["width"]
        bottom = top + rect["height"]

        # Red semi-transparent fill (~25 %): paint a copy, then blend it back
        tinted = img.copy()
        cv2.rectangle(tinted, (left, top), (right, bottom), red, -1)
        cv2.addWeighted(tinted, 0.25, img, 0.75, 0, img)

        # Border
        cv2.rectangle(img, (left, top), (right, bottom), red, 2)

        # Label
        cv2.putText(img, label, (left + 10, bottom - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
|
||
|
||
|
||
async def _get_structure_overlay(session_id: str) -> Response:
    """Generate overlay image showing detected boxes, zones, and color regions.

    Uses the session's stored ``structure_result``; when none exists, boxes
    and zones are detected on the fly with a 3 % image margin as content
    bounds (no graphics and no box background colors in that case).

    Returns:
        PNG response with zones, boxes, colored-region contours and graphic
        elements drawn on the base image.

    Raises:
        HTTPException: 404 when no base image is available; 500 when the
            image cannot be decoded or the overlay cannot be encoded.
    """
    base_png = await _get_base_image_png(session_id)
    if not base_png:
        raise HTTPException(status_code=404, detail="No base image available")

    arr = np.frombuffer(base_png, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    h, w = img.shape[:2]

    # HSV of the untouched image, taken before anything is drawn on `img`,
    # so the color masks below never pick up overlay strokes (and the PNG
    # does not have to be decoded a second time).
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)

    # Get structure result (run detection if not cached)
    session = await get_session_db(session_id)
    structure = (session or {}).get("structure_result")

    if not structure:
        # Run detection on-the-fly
        margin = int(min(w, h) * 0.03)
        content_x, content_y = margin, margin
        content_w_px = w - 2 * margin
        content_h_px = h - 2 * margin
        boxes = detect_boxes(img, content_x, content_w_px, content_y, content_h_px)
        zones = split_page_into_zones(content_x, content_y, content_w_px, content_h_px, boxes)
        structure = {
            "boxes": [
                {"x": b.x, "y": b.y, "w": b.width, "h": b.height,
                 "confidence": b.confidence, "border_thickness": b.border_thickness}
                for b in boxes
            ],
            "zones": [
                {"index": z.index, "zone_type": z.zone_type,
                 "y": z.y, "h": z.height, "x": z.x, "w": z.width}
                for z in zones
            ],
        }

    # Copy taken before any drawing: it only receives the filled boxes and
    # is blended onto `img` further down.
    overlay = img.copy()

    # --- Draw zone boundaries ---
    zone_colors = {
        "content": (200, 200, 200),   # light gray
        "box": (255, 180, 0),         # blue-ish (BGR)
    }
    for zone in structure.get("zones", []):
        zx = zone["x"]
        zy = zone["y"]
        zw = zone["w"]
        zh = zone["h"]
        color = zone_colors.get(zone["zone_type"], (200, 200, 200))

        # Draw zone boundary as dashed line (top and bottom edge)
        dash_len = 12
        for edge_x in range(zx, zx + zw, dash_len * 2):
            end_x = min(edge_x + dash_len, zx + zw)
            cv2.line(img, (edge_x, zy), (end_x, zy), color, 1)
            cv2.line(img, (edge_x, zy + zh), (end_x, zy + zh), color, 1)

        # Zone label
        zone_label = f"Zone {zone['index']} ({zone['zone_type']})"
        cv2.putText(img, zone_label, (zx + 5, zy + 15),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.45, color, 1)

    # --- Draw detected boxes ---
    # Color map for box backgrounds (BGR)
    bg_hex_to_bgr = {
        "#dc2626": (38, 38, 220),     # red
        "#2563eb": (235, 99, 37),     # blue
        "#16a34a": (74, 163, 22),     # green
        "#ea580c": (12, 88, 234),     # orange
        "#9333ea": (234, 51, 147),    # purple
        "#ca8a04": (4, 138, 202),     # yellow
        "#6b7280": (128, 114, 107),   # gray
    }

    for box_data in structure.get("boxes", []):
        bx = box_data["x"]
        by = box_data["y"]
        bw = box_data["w"]
        bh = box_data["h"]
        conf = box_data.get("confidence", 0)
        thickness = box_data.get("border_thickness", 0)
        bg_hex = box_data.get("bg_color_hex", "#6b7280")
        bg_name = box_data.get("bg_color_name", "")

        # Box fill color
        fill_bgr = bg_hex_to_bgr.get(bg_hex, (128, 114, 107))

        # Semi-transparent fill (drawn on the overlay copy, blended below)
        cv2.rectangle(overlay, (bx, by), (bx + bw, by + bh), fill_bgr, -1)

        # Solid border
        border_color = fill_bgr
        cv2.rectangle(img, (bx, by), (bx + bw, by + bh), border_color, 3)

        # Label
        label = "BOX"
        if bg_name and bg_name not in ("unknown", "white"):
            label += f" ({bg_name})"
        if thickness > 0:
            label += f" border={thickness}px"
        label += f" {int(conf * 100)}%"
        # White outline first, colored text on top for readability
        cv2.putText(img, label, (bx + 8, by + 22),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 2)
        cv2.putText(img, label, (bx + 8, by + 22),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.55, border_color, 1)

    # Blend overlay at 15% opacity
    cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img)

    # --- Draw color regions (HSV masks) ---
    color_bgr_map = {
        "red": (0, 0, 255),
        "orange": (0, 140, 255),
        "yellow": (0, 200, 255),
        "green": (0, 200, 0),
        "blue": (255, 150, 0),
        "purple": (200, 0, 200),
    }
    for color_name, ranges in _COLOR_RANGES.items():
        mask = np.zeros((h, w), dtype=np.uint8)
        for lower, upper in ranges:
            mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper))
        # Only draw if there are significant colored pixels
        if cv2.countNonZero(mask) < 100:
            continue
        # Draw colored contours
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        draw_color = color_bgr_map.get(color_name, (200, 200, 200))
        for cnt in contours:
            if cv2.contourArea(cnt) < 20:
                continue
            cv2.drawContours(img, [cnt], -1, draw_color, 2)

    # --- Draw graphic elements ---
    graphics_data = structure.get("graphics", [])
    shape_icons = {
        "arrow": "ARROW",
        "circle": "CIRCLE",
        "line": "LINE",
        "exclamation": "!",
        "dot": "DOT",
        "icon": "ICON",
        "illustration": "ILLUST",
    }
    for gfx in graphics_data:
        gx, gy = gfx["x"], gfx["y"]
        gw, gh = gfx["w"], gfx["h"]
        shape = gfx.get("shape", "icon")
        color_hex = gfx.get("color_hex", "#6b7280")
        conf = gfx.get("confidence", 0)

        # Pick draw color based on element color (BGR)
        gfx_bgr = bg_hex_to_bgr.get(color_hex, (128, 114, 107))

        # Draw bounding box (dashed style via short segments)
        dash = 6
        for seg_x in range(gx, gx + gw, dash * 2):
            end_x = min(seg_x + dash, gx + gw)
            cv2.line(img, (seg_x, gy), (end_x, gy), gfx_bgr, 2)
            cv2.line(img, (seg_x, gy + gh), (end_x, gy + gh), gfx_bgr, 2)
        for seg_y in range(gy, gy + gh, dash * 2):
            end_y = min(seg_y + dash, gy + gh)
            cv2.line(img, (gx, seg_y), (gx, end_y), gfx_bgr, 2)
            cv2.line(img, (gx + gw, seg_y), (gx + gw, end_y), gfx_bgr, 2)

        # Label
        icon = shape_icons.get(shape, shape.upper()[:5])
        label = f"{icon} {int(conf * 100)}%"
        # White background for readability
        (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)
        lx = gx + 2
        ly = max(gy - 4, th + 4)  # keep the label inside the image at the top edge
        cv2.rectangle(img, (lx - 1, ly - th - 2), (lx + tw + 2, ly + 3), (255, 255, 255), -1)
        cv2.putText(img, label, (lx, ly), cv2.FONT_HERSHEY_SIMPLEX, 0.4, gfx_bgr, 1)

    # Encode result
    success, png_buf = cv2.imencode(".png", img)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")
    return Response(content=png_buf.tobytes(), media_type="image/png")
|
||
|
||
|
||
async def _get_columns_overlay(session_id: str) -> Response:
    """Render the stored column detection result on the session's base image.

    Every column is drawn with a tinted fill (20% opacity), a solid border and
    a label built from its type (plus the classification confidence when it is
    below 1.0). Zones of type "box" are outlined with dashed rectangles and
    then passed to ``_draw_box_exclusion_overlay``.

    Returns:
        A PNG ``Response`` containing the annotated image.

    Raises:
        HTTPException: 404 if the session, its column data or its base image
            is missing; 500 if the image cannot be decoded or encoded.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    column_result = session.get("column_result")
    if not (column_result and column_result.get("columns")):
        raise HTTPException(status_code=404, detail="No column data available")

    # Best available base image (cropped > dewarped > original)
    base_png = await _get_base_image_png(session_id)
    if not base_png:
        raise HTTPException(status_code=404, detail="No base image available")

    img = cv2.imdecode(np.frombuffer(base_png, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    # Region type -> draw color (BGR)
    gray = (128, 128, 128)
    dark_gray = (100, 100, 100)
    fallback_color = (200, 200, 200)
    region_colors = {
        "column_en": (255, 180, 0),       # Blue
        "column_de": (0, 200, 0),         # Green
        "column_example": (0, 140, 255),  # Orange
        "column_text": (200, 200, 0),     # Cyan/Turquoise
        "page_ref": (200, 0, 200),        # Purple
        "column_marker": (0, 0, 220),     # Red
        "column_ignore": (180, 180, 180), # Light Gray
        "header": gray,
        "footer": gray,
        "margin_top": dark_gray,
        "margin_bottom": dark_gray,
    }

    def _dashed_rect(left, top, right, bottom, color, dash):
        """Draw a dashed rectangle on ``img`` out of short line segments."""
        stride = dash * 2
        for seg_start in range(left, right, stride):
            seg_end = min(seg_start + dash, right)
            cv2.line(img, (seg_start, top), (seg_end, top), color, 2)
            cv2.line(img, (seg_start, bottom), (seg_end, bottom), color, 2)
        for seg_start in range(top, bottom, stride):
            seg_end = min(seg_start + dash, bottom)
            cv2.line(img, (left, seg_start), (left, seg_end), color, 2)
            cv2.line(img, (right, seg_start), (right, seg_end), color, 2)

    # Filled rectangles go onto a copy so they can be blended in afterwards;
    # borders and labels are drawn directly onto the image.
    tint = img.copy()
    for col in column_result["columns"]:
        left, top = col["x"], col["y"]
        right, bottom = left + col["width"], top + col["height"]
        color = region_colors.get(col.get("type", ""), fallback_color)

        cv2.rectangle(tint, (left, top), (right, bottom), color, -1)
        cv2.rectangle(img, (left, top), (right, bottom), color, 3)

        text = col.get("type", "unknown").replace("column_", "").upper()
        confidence = col.get("classification_confidence")
        if confidence is not None and confidence < 1.0:
            text = f"{text} {int(confidence * 100)}%"
        cv2.putText(img, text, (left + 10, top + 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2)

    # Blend the fills in at 20% opacity (result is written back into img)
    cv2.addWeighted(tint, 0.2, img, 0.8, 0, img)

    # Dashed outlines for detected box zones
    zones = column_result.get("zones") or []
    box_color = (0, 200, 255)  # Yellow (BGR)
    for zone in zones:
        if zone.get("zone_type") != "box" or not zone.get("box"):
            continue
        box = zone["box"]
        bx, by = box["x"], box["y"]
        bw, bh = box["width"], box["height"]
        _dashed_rect(bx, by, bx + bw, by + bh, box_color, 15)
        cv2.putText(img, "BOX", (bx + 10, by + bh - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, box_color, 2)

    # Red semi-transparent overlay for box zones
    _draw_box_exclusion_overlay(img, zones)

    encoded_ok, encoded_png = cv2.imencode(".png", img)
    if not encoded_ok:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")

    return Response(content=encoded_png.tobytes(), media_type="image/png")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Row Detection Endpoints
# ---------------------------------------------------------------------------
|
||
|
||
@router.post("/sessions/{session_id}/rows")
async def detect_rows(session_id: str):
    """Run row detection on the cropped (or dewarped) image using horizontal gap analysis.

    Flow:
      1. Reuse the word boxes, inverted image and content bounds cached by
         column detection, or recompute them via ``detect_column_geometry``.
      2. Sub-sessions (sessions with a ``parent_session_id``) that have words
         use word-grouping; all other sessions use ``detect_row_geometry``.
         If the column result contains box zones, the box interiors are cut
         out of the image first and the row coordinates are mapped back to
         page coordinates afterwards.
      3. Each row gets the index of the content zone containing its centre
         (0 if none matches).
      4. The result is stored in the DB and cache; any stored ``word_result``
         is cleared because it depends on the rows.

    Returns:
        Dict with ``session_id``, ``rows``, ``summary`` (row count per
        row_type), ``total_rows`` and ``duration_seconds``.

    Raises:
        HTTPException: 400 if no cropped/dewarped image is cached or column
            geometry detection fails.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    # Prefer the cropped image, fall back to the dewarped one.
    cropped_bgr = cached.get("cropped_bgr")
    dewarped_bgr = cropped_bgr if cropped_bgr is not None else cached.get("dewarped_bgr")
    if dewarped_bgr is None:
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before row detection")

    t0 = time.time()

    # Try to reuse cached word_dicts and inv from column detection
    word_dicts = cached.get("_word_dicts")
    inv = cached.get("_inv")
    content_bounds = cached.get("_content_bounds")

    if word_dicts is None or inv is None or content_bounds is None:
        # Not cached — run column geometry to get intermediates
        ocr_img = create_ocr_image(dewarped_bgr)
        geo_result = detect_column_geometry(ocr_img, dewarped_bgr)
        if geo_result is None:
            raise HTTPException(status_code=400, detail="Column geometry detection failed — cannot detect rows")
        _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
        cached["_word_dicts"] = word_dicts
        cached["_inv"] = inv
        cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
    else:
        left_x, right_x, top_y, bottom_y = content_bounds

    # Read zones from column_result to exclude box regions
    session = await get_session_db(session_id)
    column_result = (session or {}).get("column_result") or {}
    is_sub_session = bool((session or {}).get("parent_session_id"))
    zones = column_result.get("zones") or []  # zones can be None for sub-sessions

    # Sub-sessions (box crops): use word-grouping instead of gap-based
    # row detection. Box images are small with complex internal layouts
    # (headings, sub-columns) where the horizontal projection approach
    # merges rows. Word-grouping directly clusters words by Y proximity,
    # which is more robust for these cases.
    if is_sub_session and word_dicts:
        from cv_layout import _build_rows_from_word_grouping
        rows = _build_rows_from_word_grouping(
            word_dicts, left_x, right_x, top_y, bottom_y,
            right_x - left_x, bottom_y - top_y,
        )
        logger.info(f"OCR Pipeline: sub-session {session_id}: word-grouping found {len(rows)} rows")
    else:
        # Collect the inner y-range of every box zone. The range is shrunk by
        # the border thickness (at least 5px): the border pixels belong
        # visually to the box frame, but text rows above/below the box may
        # overlap with the border area and must not be clipped.
        box_ranges_inner = []  # [(y_start + border, y_end - border)]
        for zone in zones:
            if zone.get("zone_type") == "box" and zone.get("box"):
                box = zone["box"]
                bt = max(box.get("border_thickness", 0), 5)
                box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt))

        content_strips = []  # [(y_start, y_end)] in absolute page coords
        if box_ranges_inner and inv is not None:
            # Build content strips by subtracting the box inner ranges from
            # [top_y, bottom_y]. Because inner ranges are used, the border
            # area stays part of the content strips.
            strip_start = top_y
            for by_start, by_end in sorted(box_ranges_inner, key=lambda rng: rng[0]):
                if by_start > strip_start:
                    content_strips.append((strip_start, by_start))
                strip_start = max(strip_start, by_end)
            if strip_start < bottom_y:
                content_strips.append((strip_start, bottom_y))

            # Keep only strips with meaningful height
            content_strips = [(ys, ye) for ys, ye in content_strips if ye - ys >= 20]

        if content_strips:
            # Combined-image approach: stack the content strips vertically,
            # run row detection on the stacked image, then remap y-coords.
            combined_inv = np.vstack([inv[ys:ye, :] for ys, ye in content_strips])

            combined_words = []
            strip_offsets = []  # (combined_y_start, strip_height, abs_y_start)
            cum_y = 0
            for ys, ye in content_strips:
                strip_h = ye - ys
                strip_offsets.append((cum_y, strip_h, ys))
                for word in word_dicts:
                    w_abs_y = word['top'] + top_y  # word y is relative to content top
                    w_center = w_abs_y + word['height'] / 2
                    if ys <= w_center < ye:
                        # Remap to combined coordinates (copy: the cached
                        # word dicts must stay untouched)
                        w_copy = dict(word)
                        w_copy['top'] = cum_y + (w_abs_y - ys)
                        combined_words.append(w_copy)
                cum_y += strip_h

            combined_h = combined_inv.shape[0]
            rows = detect_row_geometry(
                combined_inv, combined_words, left_x, right_x, 0, combined_h,
            )

            def _combined_y_to_abs(cy: int, is_end: bool = False) -> int:
                """Map a y of the stacked image back to absolute page coords.

                With ``is_end=True`` ``cy`` is an exclusive end coordinate: a
                value exactly on a strip boundary then belongs to the strip
                that ends there, not to the one that starts there. Otherwise
                a row ending at the bottom of a strip would be stretched
                across the excluded box to the start of the next strip.
                """
                for c_start, s_h, abs_start in strip_offsets:
                    c_end = c_start + s_h
                    if cy < c_end or (is_end and cy == c_end):
                        return abs_start + (cy - c_start)
                _last_c, last_h, last_abs = strip_offsets[-1]
                return last_abs + last_h

            for r in rows:
                abs_y = _combined_y_to_abs(r.y)
                abs_y_end = _combined_y_to_abs(r.y + r.height, is_end=True)
                r.y = abs_y
                # max() guards a zero-height row sitting exactly on a boundary
                r.height = max(abs_y_end - abs_y, 0)
        else:
            # No boxes (or no usable strip) — standard row detection
            rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)

    duration = time.time() - t0

    # Content zones with their index in the full zone list, used to assign
    # zone_index to every row.
    content_zones = [(i, z) for i, z in enumerate(zones) if z.get("zone_type") == "content"]

    # Build serializable result (exclude words to keep payload small)
    rows_data = []
    type_counts = {}
    for r in rows:
        zone_idx = 0
        row_center_y = r.y + r.height / 2
        for zi, zone in content_zones:
            if zone["y"] <= row_center_y < zone["y"] + zone["height"]:
                zone_idx = zi
                break

        rows_data.append({
            "index": r.index,
            "x": r.x,
            "y": r.y,
            "width": r.width,
            "height": r.height,
            "word_count": r.word_count,
            "row_type": r.row_type,
            "gap_before": r.gap_before,
            "zone_index": zone_idx,
        })
        type_counts[r.row_type] = type_counts.get(r.row_type, 0) + 1

    row_result = {
        "rows": rows_data,
        "summary": type_counts,
        "total_rows": len(rows),
        "duration_seconds": round(duration, 2),
    }

    # Persist to DB — also invalidate word_result since rows changed
    await update_session_db(
        session_id,
        row_result=row_result,
        word_result=None,
        current_step=7,
    )

    cached["row_result"] = row_result
    cached.pop("word_result", None)

    logger.info(f"OCR Pipeline: rows session {session_id}: "
                f"{len(rows)} rows detected ({duration:.2f}s): {type_counts}")

    content_rows = sum(1 for r in rows if r.row_type == "content")
    avg_height = round(sum(r.height for r in rows) / len(rows)) if rows else 0
    await _append_pipeline_log(session_id, "rows", {
        "total_rows": len(rows),
        "content_rows": content_rows,
        "artifact_rows_removed": type_counts.get("header", 0) + type_counts.get("footer", 0),
        "avg_row_height_px": avg_height,
    }, duration_ms=int(duration * 1000))

    return {
        "session_id": session_id,
        **row_result,
    }
|
||
|
||
|
||
@router.post("/sessions/{session_id}/rows/manual")
async def set_manual_rows(session_id: str, req: ManualRowsRequest):
    """Override detected rows with manual definitions.

    The rows from the request are stored as the session's ``row_result``
    (marked with ``method="manual"``) and the stored ``word_result`` is
    cleared, both in the database and in the in-memory cache entry if the
    session is currently cached.

    Returns:
        Dict with ``session_id`` plus the stored row result fields.
    """
    manual_rows = req.rows
    row_result = dict(
        rows=manual_rows,
        total_rows=len(manual_rows),
        duration_seconds=0,
        method="manual",
    )

    # A word grid computed from the previous rows no longer applies.
    await update_session_db(session_id, row_result=row_result, word_result=None)

    if session_id in _cache:
        cache_entry = _cache[session_id]
        cache_entry["row_result"] = row_result
        cache_entry.pop("word_result", None)

    logger.info(f"OCR Pipeline: manual rows session {session_id}: "
                f"{len(manual_rows)} rows set")

    response = {"session_id": session_id}
    response.update(row_result)
    return response
|
||
|
||
|
||
@router.post("/sessions/{session_id}/ground-truth/rows")
async def save_row_ground_truth(session_id: str, req: RowGroundTruthRequest):
    """Save ground truth feedback for the row detection step.

    The feedback is stored under the ``"rows"`` key of the session's
    ``ground_truth`` dict together with a snapshot of the current
    ``row_result`` and a UTC timestamp.

    Returns:
        Dict with ``session_id`` and the stored ground-truth entry.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    rows_feedback = dict(
        is_correct=req.is_correct,
        corrected_rows=req.corrected_rows,
        notes=req.notes,
        saved_at=datetime.utcnow().isoformat(),
        # Snapshot of the automatic result the feedback refers to
        row_result=session.get("row_result"),
    )

    # Keep ground truth of other steps; only the "rows" entry is replaced.
    ground_truth = session.get("ground_truth") or {}
    ground_truth["rows"] = rows_feedback

    await update_session_db(session_id, ground_truth=ground_truth)

    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    return dict(session_id=session_id, ground_truth=rows_feedback)
|
||
|
||
|
||
@router.get("/sessions/{session_id}/ground-truth/rows")
async def get_row_ground_truth(session_id: str):
    """Retrieve saved ground truth for row detection.

    Returns:
        Dict with ``session_id``, the saved ground truth (``rows_gt``) and the
        session's current automatic row result (``rows_auto``).

    Raises:
        HTTPException: 404 if the session does not exist or no row ground
            truth has been saved for it.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    rows_gt = (session.get("ground_truth") or {}).get("rows")
    if not rows_gt:
        raise HTTPException(status_code=404, detail="No row ground truth saved")

    return dict(
        session_id=session_id,
        rows_gt=rows_gt,
        rows_auto=session.get("row_result"),
    )
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Word Recognition Endpoints (Step 7)
# ---------------------------------------------------------------------------
|
||
|
||
@router.post("/sessions/{session_id}/words")
async def detect_words(
    session_id: str,
    request: Request,
    engine: str = "auto",
    pronunciation: str = "british",
    stream: bool = False,
    skip_heal_gaps: bool = False,
    grid_method: str = "v2",
):
    """Build word grid from columns × rows, OCR each cell.

    Query params:
        engine: 'auto' (default), 'tesseract', 'rapid', or 'paddle'
        pronunciation: 'british' (default) or 'american' — for IPA dictionary lookup
        stream: false (default) for JSON response, true for SSE streaming
        skip_heal_gaps: false (default). When true, cells keep exact row geometry
            positions without gap-healing expansion. Better for overlay rendering.
        grid_method: 'v2' (default) or 'words_first' — grid construction strategy.
            'v2' uses pre-detected columns/rows (top-down).
            'words_first' clusters words bottom-up (no column/row detection needed).

    Returns:
        For ``stream=false`` a dict with ``session_id`` and the word result
        (cells, grid shape, columns, layout, summary and — for vocab or
        single text-column layouts — the vocab entries). For ``stream=true``
        with the v2 grid a ``StreamingResponse`` of SSE events.

    Raises:
        HTTPException: 400 if no cropped/dewarped image is available, if row
            detection is missing for the v2 grid, or if words_first finds no
            words; 404 if the session does not exist.
    """
    # PaddleOCR is full-page remote OCR → force words_first grid method
    if engine == "paddle" and grid_method != "words_first":
        logger.info("detect_words: engine=paddle requires words_first, overriding grid_method=%s", grid_method)
        grid_method = "words_first"

    if session_id not in _cache:
        logger.info("detect_words: session %s not in cache, loading from DB", session_id)
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    # Prefer the cropped image, fall back to the dewarped one.
    cropped_bgr = cached.get("cropped_bgr")
    dewarped_bgr = cropped_bgr if cropped_bgr is not None else cached.get("dewarped_bgr")
    if dewarped_bgr is None:
        logger.warning("detect_words: no cropped/dewarped image for session %s (cache keys: %s)",
                       session_id, [k for k in cached.keys() if k.endswith('_bgr')])
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before word detection")

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    column_result = session.get("column_result")
    row_result = session.get("row_result")
    if not column_result or not column_result.get("columns"):
        # No column detection — synthesize a single full-page pseudo-column.
        # This enables the overlay pipeline which skips column detection.
        img_h_tmp, img_w_tmp = dewarped_bgr.shape[:2]
        column_result = {
            "columns": [{
                "type": "column_text",
                "x": 0, "y": 0,
                "width": img_w_tmp, "height": img_h_tmp,
                "classification_confidence": 1.0,
                "classification_method": "full_page_fallback",
            }],
            "zones": [],
            "duration_seconds": 0,
        }
        logger.info("detect_words: no column_result — using full-page pseudo-column %dx%d", img_w_tmp, img_h_tmp)
    if grid_method != "words_first" and (not row_result or not row_result.get("rows")):
        raise HTTPException(status_code=400, detail="Row detection must be completed first")

    # Convert column dicts back to PageRegion objects
    col_regions = [
        PageRegion(
            type=c["type"],
            x=c["x"], y=c["y"],
            width=c["width"], height=c["height"],
            classification_confidence=c.get("classification_confidence", 1.0),
            classification_method=c.get("classification_method", ""),
        )
        for c in column_result["columns"]
    ]

    # Convert row dicts back to RowGeometry objects.
    # words_first skips the row-detection precondition above, so row_result
    # may be missing (or have no rows) here; treat that as "no rows" instead
    # of failing on a None/KeyError subscript.
    stored_rows = (row_result or {}).get("rows") or []
    row_geoms = [
        RowGeometry(
            index=r["index"],
            x=r["x"], y=r["y"],
            width=r["width"], height=r["height"],
            word_count=r.get("word_count", 0),
            words=[],
            row_type=r.get("row_type", "content"),
            gap_before=r.get("gap_before", 0),
        )
        for r in stored_rows
    ]

    # Cell-First OCR (v2): no full-page word re-population needed.
    # Each cell is cropped and OCR'd in isolation → no neighbour bleeding.
    # We still need word_count > 0 for row filtering in build_cell_grid_v2,
    # so populate from cached words if available (just for counting).
    word_dicts = cached.get("_word_dicts")
    if word_dicts is None:
        ocr_img_tmp = create_ocr_image(dewarped_bgr)
        geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr)
        if geo_result is not None:
            _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
            cached["_word_dicts"] = word_dicts
            cached["_inv"] = inv
            cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)

    if word_dicts:
        # Word 'top' values are relative to the content top, so rows are
        # shifted into the same coordinate system before matching.
        content_bounds = cached.get("_content_bounds")
        if content_bounds:
            _lx, _rx, top_y, _by = content_bounds
        else:
            top_y = min(r.y for r in row_geoms) if row_geoms else 0

        for row in row_geoms:
            row_y_rel = row.y - top_y
            row_bottom_rel = row_y_rel + row.height
            row.words = [
                w for w in word_dicts
                if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
            ]
            row.word_count = len(row.words)

    # Exclude rows that fall within box zones.
    # Use inner box range (shrunk by border_thickness) so that rows at
    # the boundary (overlapping with the box border) are NOT excluded.
    zones = column_result.get("zones") or []
    box_ranges_inner = []
    for zone in zones:
        if zone.get("zone_type") == "box" and zone.get("box"):
            box = zone["box"]
            bt = max(box.get("border_thickness", 0), 5)  # minimum 5px margin
            box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt))

    if box_ranges_inner:
        def _row_in_box(r):
            center_y = r.y + r.height / 2
            return any(by_s <= center_y < by_e for by_s, by_e in box_ranges_inner)

        before_count = len(row_geoms)
        row_geoms = [r for r in row_geoms if not _row_in_box(r)]
        excluded = before_count - len(row_geoms)
        if excluded:
            logger.info(f"detect_words: excluded {excluded} rows inside box zones")

    # --- Words-First path: bottom-up grid from word boxes ---
    if grid_method == "words_first":
        t0 = time.time()
        img_h, img_w = dewarped_bgr.shape[:2]

        # For paddle engine: run remote PaddleOCR full-page instead of Tesseract
        if engine == "paddle":
            from cv_ocr_engines import ocr_region_paddle

            wf_word_dicts = await ocr_region_paddle(dewarped_bgr, region=None)
            # PaddleOCR returns absolute coordinates, no content_bounds offset needed
            cached["_paddle_word_dicts"] = wf_word_dicts
        else:
            # Get word_dicts from cache or run Tesseract full-page
            wf_word_dicts = cached.get("_word_dicts")
            if wf_word_dicts is None:
                ocr_img_tmp = create_ocr_image(dewarped_bgr)
                geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr)
                if geo_result is not None:
                    _geoms, left_x, right_x, top_y, bottom_y, wf_word_dicts, inv = geo_result
                    cached["_word_dicts"] = wf_word_dicts
                    cached["_inv"] = inv
                    cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)

        if not wf_word_dicts:
            raise HTTPException(status_code=400, detail="No words detected — cannot build words-first grid")

        # Convert word coordinates to absolute image coordinates if needed
        # (detect_column_geometry returns words relative to content ROI)
        # PaddleOCR already returns absolute coordinates — skip offset.
        if engine != "paddle":
            content_bounds = cached.get("_content_bounds")
            if content_bounds:
                lx, _rx, ty, _by = content_bounds
                # New dicts are built so the cached words stay relative.
                wf_word_dicts = [
                    {**w, 'left': w['left'] + lx, 'top': w['top'] + ty}
                    for w in wf_word_dicts
                ]

        cells, columns_meta = build_grid_from_words(wf_word_dicts, img_w, img_h)
        duration = time.time() - t0

        # Apply IPA phonetic fixes
        fix_cell_phonetics(cells, pronunciation=pronunciation)

        # Add zone_index for backward compat
        for cell in cells:
            cell.setdefault("zone_index", 0)

        col_types = {c['type'] for c in columns_meta}
        is_vocab = bool(col_types & {'column_en', 'column_de'})
        n_rows = len(set(c['row_index'] for c in cells)) if cells else 0
        n_cols = len(columns_meta)
        used_engine = "paddle" if engine == "paddle" else "words_first"

        word_result = {
            "cells": cells,
            "grid_shape": {
                "rows": n_rows,
                "cols": n_cols,
                "total_cells": len(cells),
            },
            "columns_used": columns_meta,
            "layout": "vocab" if is_vocab else "generic",
            "image_width": img_w,
            "image_height": img_h,
            "duration_seconds": round(duration, 2),
            "ocr_engine": used_engine,
            "grid_method": "words_first",
            "summary": {
                "total_cells": len(cells),
                "non_empty_cells": sum(1 for c in cells if c.get("text")),
                "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
            },
        }

        if is_vocab or 'column_text' in col_types:
            entries = _cells_to_vocab_entries(cells, columns_meta)
            entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
            word_result["vocab_entries"] = entries
            word_result["entries"] = entries
            word_result["entry_count"] = len(entries)
            word_result["summary"]["total_entries"] = len(entries)
            word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
            word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))

        await update_session_db(session_id, word_result=word_result, current_step=8)
        cached["word_result"] = word_result

        logger.info(f"OCR Pipeline: words-first session {session_id}: "
                    f"{len(cells)} cells ({duration:.2f}s), {n_rows} rows, {n_cols} cols")

        await _append_pipeline_log(session_id, "words", {
            "grid_method": "words_first",
            "total_cells": len(cells),
            "non_empty_cells": word_result["summary"]["non_empty_cells"],
            "ocr_engine": used_engine,
            "layout": word_result["layout"],
        }, duration_ms=int(duration * 1000))

        return {"session_id": session_id, **word_result}

    if stream:
        # Cell-First OCR v2: use batch-then-stream approach instead of
        # per-cell streaming. The parallel ThreadPoolExecutor in
        # build_cell_grid_v2 is much faster than sequential streaming.
        return StreamingResponse(
            _word_batch_stream_generator(
                session_id, cached, col_regions, row_geoms,
                dewarped_bgr, engine, pronunciation, request,
                skip_heal_gaps=skip_heal_gaps,
            ),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    # --- Non-streaming path (grid_method=v2) ---
    t0 = time.time()

    # Create binarized OCR image (for Tesseract)
    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    # Build cell grid using Cell-First OCR (v2) — each cell cropped in isolation
    cells, columns_meta = build_cell_grid_v2(
        ocr_img, col_regions, row_geoms, img_w, img_h,
        ocr_engine=engine, img_bgr=dewarped_bgr,
        skip_heal_gaps=skip_heal_gaps,
    )
    duration = time.time() - t0

    # Add zone_index to each cell (default 0 for backward compatibility)
    for cell in cells:
        cell.setdefault("zone_index", 0)

    # Layout detection
    col_types = {c['type'] for c in columns_meta}
    is_vocab = bool(col_types & {'column_en', 'column_de'})

    # Count content rows and columns for grid_shape
    n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
    n_cols = len(columns_meta)

    # Determine which engine was actually used
    used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine

    # Apply IPA phonetic fixes directly to cell texts (for overlay mode)
    fix_cell_phonetics(cells, pronunciation=pronunciation)

    # Grid result (always generic)
    word_result = {
        "cells": cells,
        "grid_shape": {
            "rows": n_content_rows,
            "cols": n_cols,
            "total_cells": len(cells),
        },
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
        },
    }

    # For vocab layout or single-column (box sub-sessions): map cells 1:1
    # to vocab entries (row→entry).
    has_text_col = 'column_text' in col_types
    if is_vocab or has_text_col:
        entries = _cells_to_vocab_entries(cells, columns_meta)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        word_result["summary"]["total_entries"] = len(entries)
        word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
        word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))

    # Persist to DB
    await update_session_db(
        session_id,
        word_result=word_result,
        current_step=8,
    )

    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline: words session {session_id}: "
                f"layout={word_result['layout']}, "
                f"{len(cells)} cells ({duration:.2f}s), summary: {word_result['summary']}")

    await _append_pipeline_log(session_id, "words", {
        "total_cells": len(cells),
        "non_empty_cells": word_result["summary"]["non_empty_cells"],
        "low_confidence_count": word_result["summary"]["low_confidence"],
        "ocr_engine": used_engine,
        "layout": word_result["layout"],
        "entry_count": word_result.get("entry_count", 0),
    }, duration_ms=int(duration * 1000))

    return {
        "session_id": session_id,
        **word_result,
    }
|
||
|
||
|
||
async def _word_batch_stream_generator(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    request: Request,
    skip_heal_gaps: bool = False,
):
    """SSE generator that runs batch OCR (parallel) then streams results.

    Unlike the old per-cell streaming, this uses build_cell_grid_v2 with
    ThreadPoolExecutor for parallel OCR, then emits all cells as SSE events.
    The 'preparing' event keeps the connection alive during OCR processing.

    Event sequence: ``meta`` → ``preparing`` → zero or more ``keepalive`` →
    ``columns`` (only if column metadata is non-empty) → one ``cell`` per
    cell → ``complete``. The word result is written to the DB and to
    ``cached`` before the ``complete`` event is sent. If the client
    disconnects during or right after OCR the generator stops without
    persisting anything.

    Yields:
        SSE frames of the form ``"data: <json>\\n\\n"``.
    """
    import asyncio

    t0 = time.time()
    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    # Grid shape estimate for the meta event: content rows × OCR-able columns
    _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
    n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
    n_cols = len([c for c in col_regions if c.type not in _skip_types])
    col_types = {c.type for c in col_regions if c.type not in _skip_types}
    is_vocab = bool(col_types & {'column_en', 'column_de'})
    total_cells = n_content_rows * n_cols

    # 1. Send meta event immediately
    meta_event = {
        "type": "meta",
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
        "layout": "vocab" if is_vocab else "generic",
    }
    yield f"data: {json.dumps(meta_event)}\n\n"

    # 2. Send preparing event (keepalive for proxy)
    yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR laeuft parallel...'})}\n\n"

    # 3. Run batch OCR in thread pool with periodic keepalive events.
    #    The OCR takes 30-60s and proxy servers (Nginx) may drop idle SSE
    #    connections after 30-60s. Send keepalive every 5s to prevent this.
    #    get_running_loop() is the documented way to obtain the loop from
    #    inside a coroutine.
    loop = asyncio.get_running_loop()
    ocr_future = loop.run_in_executor(
        None,
        lambda: build_cell_grid_v2(
            ocr_img, col_regions, row_geoms, img_w, img_h,
            ocr_engine=engine, img_bgr=dewarped_bgr,
            skip_heal_gaps=skip_heal_gaps,
        ),
    )

    # Wait in 5s slices. shield() keeps the timeout from cancelling the OCR
    # future itself. The loop condition re-checks done() on purpose: if the
    # OCR call itself fails with a timeout error, the next check leaves the
    # loop and result() re-raises it instead of emitting keepalives forever.
    while not ocr_future.done():
        try:
            cells, columns_meta = await asyncio.wait_for(
                asyncio.shield(ocr_future), timeout=5.0,
            )
            break  # OCR finished
        except asyncio.TimeoutError:
            elapsed = int(time.time() - t0)
            keepalive_event = {
                'type': 'keepalive',
                'elapsed': elapsed,
                'message': f'OCR laeuft... ({elapsed}s)',
            }
            yield f"data: {json.dumps(keepalive_event)}\n\n"
            if await request.is_disconnected():
                logger.info(f"SSE batch: client disconnected during OCR for {session_id}")
                ocr_future.cancel()
                return
    else:
        # Future completed between two waits — fetch (or re-raise) its outcome
        cells, columns_meta = ocr_future.result()

    if await request.is_disconnected():
        logger.info(f"SSE batch: client disconnected after OCR for {session_id}")
        return

    # 4. Apply IPA phonetic fixes directly to cell texts (for overlay mode)
    fix_cell_phonetics(cells, pronunciation=pronunciation)

    # 5. Send columns meta
    if columns_meta:
        yield f"data: {json.dumps({'type': 'columns', 'columns_used': columns_meta})}\n\n"

    # 6. Stream all cells
    for idx, cell in enumerate(cells):
        cell_event = {
            "type": "cell",
            "cell": cell,
            "progress": {"current": idx + 1, "total": len(cells)},
        }
        yield f"data: {json.dumps(cell_event)}\n\n"

    # 7. Build final result and persist
    duration = time.time() - t0
    used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine

    word_result = {
        "cells": cells,
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
        },
    }

    # Vocab layouts and single text-column layouts additionally get entries
    vocab_entries = None
    has_text_col = 'column_text' in col_types
    if is_vocab or has_text_col:
        entries = _cells_to_vocab_entries(cells, columns_meta)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        word_result["summary"]["total_entries"] = len(entries)
        word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
        word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
        vocab_entries = entries

    await update_session_db(session_id, word_result=word_result, current_step=8)
    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline SSE batch: words session {session_id}: "
                f"layout={word_result['layout']}, {len(cells)} cells ({duration:.2f}s)")

    # 8. Send complete event
    complete_event = {
        "type": "complete",
        "summary": word_result["summary"],
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
    }
    if vocab_entries is not None:
        complete_event["vocab_entries"] = vocab_entries
    yield f"data: {json.dumps(complete_event)}\n\n"
|
||
|
||
|
||
async def _word_stream_generator(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    request: Request,
):
    """Stream word-level OCR results as Server-Sent Events, one cell at a time.

    Events are emitted in this order: ``meta`` (grid shape and layout),
    ``preparing`` (keepalive while OCR starts), ``columns`` (once, together
    with the first cell), one ``cell`` event per recognised cell, and finally
    ``complete``.  Before ``complete`` is sent, the assembled ``word_result``
    is written through ``update_session_db`` and stored in ``cached``.

    If the client disconnects while cells are being streamed, the generator
    returns immediately and nothing is persisted.
    """
    def _sse(payload: Dict[str, Any]) -> str:
        # One SSE frame: "data: <json>" followed by a blank line.
        return f"data: {json.dumps(payload)}\n\n"

    started = time.time()

    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    # Grid shape is computed up front so the client can lay out the grid
    # before the first cell arrives.
    _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
    usable_cols = [c for c in col_regions if c.type not in _skip_types]
    n_content_rows = sum(1 for r in row_geoms if r.row_type == 'content')
    n_cols = len(usable_cols)

    # A page counts as a vocabulary layout when it has an EN or DE column.
    col_types = {c.type for c in usable_cols}
    is_vocab = bool(col_types & {'column_en', 'column_de'})
    layout = "vocab" if is_vocab else "generic"

    yield _sse({
        "type": "meta",
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": n_content_rows * n_cols},
        "layout": layout,
    })

    # Keepalive: send preparing event so proxy doesn't timeout during OCR init
    yield _sse({'type': 'preparing', 'message': 'Cell-First OCR wird initialisiert...'})

    columns_meta = None  # filled from the first item the grid builder yields
    all_cells: List[Dict[str, Any]] = []

    for cell, cols_meta, total in build_cell_grid_v2_streaming(
        ocr_img, col_regions, row_geoms, img_w, img_h,
        ocr_engine=engine, img_bgr=dewarped_bgr,
    ):
        if await request.is_disconnected():
            logger.info(f"SSE: client disconnected during streaming for {session_id}")
            return

        if columns_meta is None:
            columns_meta = cols_meta
            yield _sse({"type": "columns", "columns_used": cols_meta})

        all_cells.append(cell)
        yield _sse({
            "type": "cell",
            "cell": cell,
            "progress": {"current": len(all_cells), "total": total},
        })

    # All cells done — build the final result.
    duration = time.time() - started
    if columns_meta is None:
        columns_meta = []

    # Drop rows in which every cell is empty (inter-row gaps where stray
    # Tesseract artifacts produced a word count > 0).
    rows_with_text = {c["row_index"] for c in all_cells if c.get("text", "").strip()}
    cells_before = len(all_cells)
    all_cells = [c for c in all_cells if c["row_index"] in rows_with_text]
    empty_rows_removed = (cells_before - len(all_cells)) // max(n_cols, 1)
    if empty_rows_removed > 0:
        logger.info(f"SSE: removed {empty_rows_removed} all-empty rows after OCR")

    used_engine = all_cells[0].get("ocr_engine", "tesseract") if all_cells else engine

    # Apply IPA phonetic fixes directly to the cell texts (for overlay mode).
    fix_cell_phonetics(all_cells, pronunciation=pronunciation)

    summary: Dict[str, Any] = {
        "total_cells": len(all_cells),
        "non_empty_cells": sum(1 for c in all_cells if c.get("text")),
        "low_confidence": sum(1 for c in all_cells if 0 < c.get("confidence", 0) < 50),
    }
    word_result = {
        "cells": all_cells,
        "grid_shape": {
            "rows": n_content_rows,
            "cols": n_cols,
            "total_cells": len(all_cells),
        },
        "columns_used": columns_meta,
        "layout": layout,
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": summary,
    }

    # Vocab layouts and plain text columns (box sub-sessions) additionally
    # get their cells mapped row by row to vocab entries.
    vocab_entries = None
    if is_vocab or 'column_text' in col_types:
        entries = _cells_to_vocab_entries(all_cells, columns_meta)
        entries = _fix_character_confusion(entries)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        summary["total_entries"] = len(entries)
        summary["with_english"] = sum(1 for e in entries if e.get("english"))
        summary["with_german"] = sum(1 for e in entries if e.get("german"))
        vocab_entries = entries

    # Persist to DB and keep the in-memory session copy in step.
    await update_session_db(
        session_id,
        word_result=word_result,
        current_step=8,
    )
    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline SSE: words session {session_id}: "
                f"layout={word_result['layout']}, "
                f"{len(all_cells)} cells ({duration:.2f}s)")

    # Final complete event
    complete_event = {
        "type": "complete",
        "summary": word_result["summary"],
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
    }
    if vocab_entries is not None:
        complete_event["vocab_entries"] = vocab_entries
    yield _sse(complete_event)
|
||
|
||
|
||
@router.post("/sessions/{session_id}/paddle-direct")
async def paddle_direct(session_id: str):
    """Run PaddleOCR on the preprocessed image and build a word grid directly.

    Expects orientation/deskew/dewarp/crop to be done already.  The cropped
    image is used when present, otherwise the dewarped one, otherwise the
    original.  Whichever image was used is stored as ``cropped_png`` so
    OverlayReconstruction can display it as the background.

    Returns the session id merged with the ``word_result`` dict.

    Raises:
        HTTPException 404: no image is stored for the session.
        HTTPException 400: the image cannot be decoded, or PaddleOCR
            returned no words.
    """
    # Try preprocessed images first (crop > dewarp > original).
    img_png = None
    for image_kind in ("cropped", "dewarped", "original"):
        img_png = await get_session_image(session_id, image_kind)
        if img_png:
            break
    if not img_png:
        raise HTTPException(status_code=404, detail="No image found for this session")

    img_arr = np.frombuffer(img_png, dtype=np.uint8)
    img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
    if img_bgr is None:
        # The decoded image is not necessarily the original upload, so the
        # message no longer claims it is (matches the kombi endpoints).
        raise HTTPException(status_code=400, detail="Failed to decode image")

    img_h, img_w = img_bgr.shape[:2]

    from cv_ocr_engines import ocr_region_paddle

    t0 = time.time()
    word_dicts = await ocr_region_paddle(img_bgr, region=None)
    if not word_dicts:
        raise HTTPException(status_code=400, detail="PaddleOCR returned no words")

    # Reuse build_grid_from_words — the same function the regular pipeline
    # uses with PaddleOCR (engine=paddle, grid_method=words_first).
    cells, columns_meta = build_grid_from_words(word_dicts, img_w, img_h)
    duration = time.time() - t0

    # Tag cells as paddle_direct
    for cell in cells:
        cell["ocr_engine"] = "paddle_direct"

    n_rows = len({c["row_index"] for c in cells}) if cells else 0
    n_cols = len(columns_meta)
    col_types = {c.get("type") for c in columns_meta}
    is_vocab = bool(col_types & {"column_en", "column_de"})

    word_result = {
        "cells": cells,
        "grid_shape": {
            "rows": n_rows,
            "cols": n_cols,
            "total_cells": len(cells),
        },
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": "paddle_direct",
        "grid_method": "paddle_direct",
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
        },
    }

    # Store preprocessed image as cropped_png so OverlayReconstruction shows it
    await update_session_db(
        session_id,
        word_result=word_result,
        cropped_png=img_png,
        current_step=8,
    )

    # Keep the in-memory session copy in step with the DB, as the other
    # endpoints that rewrite word_result do; otherwise a cached session
    # would keep serving the previous result.
    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    logger.info(
        "paddle_direct session %s: %d cells (%d rows, %d cols) in %.2fs",
        session_id, len(cells), n_rows, n_cols, duration,
    )

    await _append_pipeline_log(session_id, "paddle_direct", {
        "total_cells": len(cells),
        "non_empty_cells": word_result["summary"]["non_empty_cells"],
        "ocr_engine": "paddle_direct",
    }, duration_ms=int(duration * 1000))

    return {"session_id": session_id, **word_result}
|
||
|
||
|
||
def _split_paddle_multi_words(words: list) -> list:
|
||
"""Split PaddleOCR multi-word boxes into individual word boxes.
|
||
|
||
PaddleOCR often returns entire phrases as a single box, e.g.
|
||
"More than 200 singers took part in the" with one bounding box.
|
||
This splits them into individual words with proportional widths.
|
||
Also handles leading "!" (e.g. "!Betonung" → ["!", "Betonung"])
|
||
and IPA brackets (e.g. "badge[bxd3]" → ["badge", "[bxd3]"]).
|
||
"""
|
||
import re
|
||
|
||
result = []
|
||
for w in words:
|
||
raw_text = w.get("text", "").strip()
|
||
if not raw_text:
|
||
continue
|
||
# Split on whitespace, before "[" (IPA), and after "!" before letter
|
||
tokens = re.split(
|
||
r'\s+|(?=\[)|(?<=!)(?=[A-Za-z\u00c0-\u024f])', raw_text
|
||
)
|
||
tokens = [t for t in tokens if t]
|
||
|
||
if len(tokens) <= 1:
|
||
result.append(w)
|
||
else:
|
||
# Split proportionally by character count
|
||
total_chars = sum(len(t) for t in tokens)
|
||
if total_chars == 0:
|
||
continue
|
||
n_gaps = len(tokens) - 1
|
||
gap_px = w["width"] * 0.02
|
||
usable_w = w["width"] - gap_px * n_gaps
|
||
cursor = w["left"]
|
||
for t in tokens:
|
||
token_w = max(1, usable_w * len(t) / total_chars)
|
||
result.append({
|
||
"text": t,
|
||
"left": round(cursor),
|
||
"top": w["top"],
|
||
"width": round(token_w),
|
||
"height": w["height"],
|
||
"conf": w.get("conf", 0),
|
||
})
|
||
cursor += token_w + gap_px
|
||
return result
|
||
|
||
|
||
def _group_words_into_rows(words: list, row_gap: int = 12) -> list:
|
||
"""Group words into rows by Y-position clustering.
|
||
|
||
Words whose vertical centers are within `row_gap` pixels are on the same row.
|
||
Returns list of rows, each row is a list of words sorted left-to-right.
|
||
"""
|
||
if not words:
|
||
return []
|
||
# Sort by vertical center
|
||
sorted_words = sorted(words, key=lambda w: w["top"] + w.get("height", 0) / 2)
|
||
rows: list = []
|
||
current_row: list = [sorted_words[0]]
|
||
current_cy = sorted_words[0]["top"] + sorted_words[0].get("height", 0) / 2
|
||
|
||
for w in sorted_words[1:]:
|
||
cy = w["top"] + w.get("height", 0) / 2
|
||
if abs(cy - current_cy) <= row_gap:
|
||
current_row.append(w)
|
||
else:
|
||
# Sort current row left-to-right before saving
|
||
rows.append(sorted(current_row, key=lambda w: w["left"]))
|
||
current_row = [w]
|
||
current_cy = cy
|
||
if current_row:
|
||
rows.append(sorted(current_row, key=lambda w: w["left"]))
|
||
return rows
|
||
|
||
|
||
def _row_center_y(row: list) -> float:
|
||
"""Average vertical center of a row of words."""
|
||
if not row:
|
||
return 0.0
|
||
return sum(w["top"] + w.get("height", 0) / 2 for w in row) / len(row)
|
||
|
||
|
||
def _merge_row_sequences(paddle_row: list, tess_row: list) -> list:
|
||
"""Merge two word sequences from the same row using sequence alignment.
|
||
|
||
Both sequences are sorted left-to-right. Walk through both simultaneously:
|
||
- If words match (same/similar text): take Paddle text with averaged coords
|
||
- If they don't match: the extra word is unique to one engine, include it
|
||
|
||
This prevents duplicates because both engines produce words in the same order.
|
||
"""
|
||
merged = []
|
||
pi, ti = 0, 0
|
||
|
||
while pi < len(paddle_row) and ti < len(tess_row):
|
||
pw = paddle_row[pi]
|
||
tw = tess_row[ti]
|
||
|
||
# Check if these are the same word
|
||
pt = pw.get("text", "").lower().strip()
|
||
tt = tw.get("text", "").lower().strip()
|
||
|
||
# Same text or one contains the other
|
||
is_same = (pt == tt) or (len(pt) > 1 and len(tt) > 1 and (pt in tt or tt in pt))
|
||
|
||
# Spatial overlap check: if words overlap >= 50% horizontally,
|
||
# they're the same physical word regardless of OCR text differences
|
||
if not is_same:
|
||
overlap_left = max(pw["left"], tw["left"])
|
||
overlap_right = min(
|
||
pw["left"] + pw.get("width", 0),
|
||
tw["left"] + tw.get("width", 0),
|
||
)
|
||
overlap_w = max(0, overlap_right - overlap_left)
|
||
min_w = min(pw.get("width", 1), tw.get("width", 1))
|
||
if min_w > 0 and overlap_w / min_w >= 0.5:
|
||
is_same = True
|
||
|
||
if is_same:
|
||
# Matched — average coordinates weighted by confidence
|
||
pc = pw.get("conf", 80)
|
||
tc = tw.get("conf", 50)
|
||
total = pc + tc
|
||
if total == 0:
|
||
total = 1
|
||
merged.append({
|
||
"text": pw["text"], # Paddle text preferred
|
||
"left": round((pw["left"] * pc + tw["left"] * tc) / total),
|
||
"top": round((pw["top"] * pc + tw["top"] * tc) / total),
|
||
"width": round((pw["width"] * pc + tw["width"] * tc) / total),
|
||
"height": round((pw["height"] * pc + tw["height"] * tc) / total),
|
||
"conf": max(pc, tc),
|
||
})
|
||
pi += 1
|
||
ti += 1
|
||
else:
|
||
# Different text — one engine found something extra
|
||
# Look ahead: is the current Paddle word somewhere in Tesseract ahead?
|
||
paddle_ahead = any(
|
||
tess_row[t].get("text", "").lower().strip() == pt
|
||
for t in range(ti + 1, min(ti + 4, len(tess_row)))
|
||
)
|
||
# Is the current Tesseract word somewhere in Paddle ahead?
|
||
tess_ahead = any(
|
||
paddle_row[p].get("text", "").lower().strip() == tt
|
||
for p in range(pi + 1, min(pi + 4, len(paddle_row)))
|
||
)
|
||
|
||
if paddle_ahead and not tess_ahead:
|
||
# Tesseract has an extra word (e.g. "!" or bullet) → include it
|
||
if tw.get("conf", 0) >= 30:
|
||
merged.append(tw)
|
||
ti += 1
|
||
elif tess_ahead and not paddle_ahead:
|
||
# Paddle has an extra word → include it
|
||
merged.append(pw)
|
||
pi += 1
|
||
else:
|
||
# Both have unique words or neither found ahead → take leftmost first
|
||
if pw["left"] <= tw["left"]:
|
||
merged.append(pw)
|
||
pi += 1
|
||
else:
|
||
if tw.get("conf", 0) >= 30:
|
||
merged.append(tw)
|
||
ti += 1
|
||
|
||
# Remaining words from either engine
|
||
while pi < len(paddle_row):
|
||
merged.append(paddle_row[pi])
|
||
pi += 1
|
||
while ti < len(tess_row):
|
||
tw = tess_row[ti]
|
||
if tw.get("conf", 0) >= 30:
|
||
merged.append(tw)
|
||
ti += 1
|
||
|
||
return merged
|
||
|
||
|
||
def _merge_paddle_tesseract(paddle_words: list, tess_words: list) -> list:
|
||
"""Merge word boxes from PaddleOCR and Tesseract using row-based sequence alignment.
|
||
|
||
Strategy:
|
||
1. Group each engine's words into rows (by Y-position clustering)
|
||
2. Match rows between engines (by vertical center proximity)
|
||
3. Within each matched row: merge sequences left-to-right, deduplicating
|
||
words that appear in both engines at the same sequence position
|
||
4. Unmatched rows from either engine: keep as-is
|
||
|
||
This prevents:
|
||
- Cross-line averaging (words from different lines being merged)
|
||
- Duplicate words (same word from both engines shown twice)
|
||
"""
|
||
if not paddle_words and not tess_words:
|
||
return []
|
||
if not paddle_words:
|
||
return [w for w in tess_words if w.get("conf", 0) >= 40]
|
||
if not tess_words:
|
||
return list(paddle_words)
|
||
|
||
# Step 1: Group into rows
|
||
paddle_rows = _group_words_into_rows(paddle_words)
|
||
tess_rows = _group_words_into_rows(tess_words)
|
||
|
||
# Step 2: Match rows between engines by vertical center proximity
|
||
used_tess_rows: set = set()
|
||
merged_all: list = []
|
||
|
||
for pr in paddle_rows:
|
||
pr_cy = _row_center_y(pr)
|
||
best_dist, best_tri = float("inf"), -1
|
||
for tri, tr in enumerate(tess_rows):
|
||
if tri in used_tess_rows:
|
||
continue
|
||
tr_cy = _row_center_y(tr)
|
||
dist = abs(pr_cy - tr_cy)
|
||
if dist < best_dist:
|
||
best_dist, best_tri = dist, tri
|
||
|
||
# Row height threshold — rows must be within ~1.5x typical line height
|
||
max_row_dist = max(
|
||
max((w.get("height", 20) for w in pr), default=20),
|
||
15,
|
||
)
|
||
|
||
if best_tri >= 0 and best_dist <= max_row_dist:
|
||
# Matched row — merge sequences
|
||
tr = tess_rows[best_tri]
|
||
used_tess_rows.add(best_tri)
|
||
merged_all.extend(_merge_row_sequences(pr, tr))
|
||
else:
|
||
# No matching Tesseract row — keep Paddle row as-is
|
||
merged_all.extend(pr)
|
||
|
||
# Add unmatched Tesseract rows
|
||
for tri, tr in enumerate(tess_rows):
|
||
if tri not in used_tess_rows:
|
||
for tw in tr:
|
||
if tw.get("conf", 0) >= 40:
|
||
merged_all.append(tw)
|
||
|
||
return merged_all
|
||
|
||
|
||
@router.post("/sessions/{session_id}/paddle-kombi")
async def paddle_kombi(session_id: str):
    """Run PaddleOCR + Tesseract on the preprocessed image and merge results.

    Both engines run on the same image (cropped, else dewarped, else
    original).  Paddle phrase boxes are first split into single words, then
    both word lists are merged by ``_merge_paddle_tesseract`` (row-based
    sequence alignment; matched boxes are averaged weighted by confidence,
    sufficiently confident Tesseract-only words such as bullets or symbols
    are added).  The merged words are turned into a grid and persisted as
    ``word_result`` together with the raw words of both engines.

    Raises:
        HTTPException 404: no image is stored for the session.
        HTTPException 400: the image cannot be decoded, or neither engine
            returned any word.
    """
    img_png = await get_session_image(session_id, "cropped")
    if not img_png:
        img_png = await get_session_image(session_id, "dewarped")
    if not img_png:
        img_png = await get_session_image(session_id, "original")
    if not img_png:
        raise HTTPException(status_code=404, detail="No image found for this session")

    img_arr = np.frombuffer(img_png, dtype=np.uint8)
    img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Failed to decode image")

    img_h, img_w = img_bgr.shape[:2]

    from cv_ocr_engines import ocr_region_paddle

    t0 = time.time()

    # --- PaddleOCR ---
    paddle_words = await ocr_region_paddle(img_bgr, region=None)
    if not paddle_words:
        paddle_words = []

    # --- Tesseract ---
    from PIL import Image
    import pytesseract

    pil_img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    data = pytesseract.image_to_data(
        pil_img, lang="eng+deu",
        config="--psm 6 --oem 3",
        output_type=pytesseract.Output.DICT,
    )
    tess_words = []
    for i, raw_text in enumerate(data["text"]):
        text = str(raw_text).strip()
        # Confidence may arrive as int, float or numeric string depending on
        # the Tesseract/pytesseract versions in use.  Parsing through float
        # keeps float-formatted values ("96.27") instead of treating them as
        # invalid and dropping the word.
        try:
            conf = int(float(data["conf"][i]))
        except (TypeError, ValueError, OverflowError):
            conf = -1
        if not text or conf < 20:
            continue
        tess_words.append({
            "text": text,
            "left": data["left"][i],
            "top": data["top"][i],
            "width": data["width"][i],
            "height": data["height"][i],
            "conf": conf,
        })

    # --- Split multi-word Paddle boxes into individual words ---
    paddle_words_split = _split_paddle_multi_words(paddle_words)
    logger.info(
        "paddle_kombi: split %d paddle boxes → %d individual words",
        len(paddle_words), len(paddle_words_split),
    )

    # --- Merge ---
    if not paddle_words_split and not tess_words:
        raise HTTPException(status_code=400, detail="Both OCR engines returned no words")

    merged_words = _merge_paddle_tesseract(paddle_words_split, tess_words)

    cells, columns_meta = build_grid_from_words(merged_words, img_w, img_h)
    duration = time.time() - t0

    for cell in cells:
        cell["ocr_engine"] = "kombi"

    n_rows = len({c["row_index"] for c in cells}) if cells else 0
    n_cols = len(columns_meta)
    col_types = {c.get("type") for c in columns_meta}
    is_vocab = bool(col_types & {"column_en", "column_de"})

    word_result = {
        "cells": cells,
        "grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": "kombi",
        "grid_method": "kombi",
        "raw_paddle_words": paddle_words,
        "raw_paddle_words_split": paddle_words_split,
        "raw_tesseract_words": tess_words,
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
            "paddle_words": len(paddle_words),
            "paddle_words_split": len(paddle_words_split),
            "tesseract_words": len(tess_words),
            "merged_words": len(merged_words),
        },
    }

    await update_session_db(
        session_id,
        word_result=word_result,
        cropped_png=img_png,
        current_step=8,
    )

    # Keep the in-memory session copy in step with the DB so a cached
    # session does not keep serving the previous word_result.
    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    logger.info(
        "paddle_kombi session %s: %d cells (%d rows, %d cols) in %.2fs "
        "[paddle=%d, tess=%d, merged=%d]",
        session_id, len(cells), n_rows, n_cols, duration,
        len(paddle_words), len(tess_words), len(merged_words),
    )

    await _append_pipeline_log(session_id, "paddle_kombi", {
        "total_cells": len(cells),
        "non_empty_cells": word_result["summary"]["non_empty_cells"],
        "paddle_words": len(paddle_words),
        "tesseract_words": len(tess_words),
        "merged_words": len(merged_words),
        "ocr_engine": "kombi",
    }, duration_ms=int(duration * 1000))

    return {"session_id": session_id, **word_result}
|
||
|
||
|
||
@router.post("/sessions/{session_id}/rapid-kombi")
async def rapid_kombi(session_id: str):
    """Run RapidOCR + Tesseract on the preprocessed image and merge results.

    Same merge logic as paddle-kombi, but uses local RapidOCR (ONNX Runtime)
    instead of the remote PaddleOCR service.  RapidOCR is run on a region
    covering the full image (cropped, else dewarped, else original).

    Raises:
        HTTPException 404: no image is stored for the session.
        HTTPException 400: the image cannot be decoded, or neither engine
            returned any word.
    """
    img_png = await get_session_image(session_id, "cropped")
    if not img_png:
        img_png = await get_session_image(session_id, "dewarped")
    if not img_png:
        img_png = await get_session_image(session_id, "original")
    if not img_png:
        raise HTTPException(status_code=404, detail="No image found for this session")

    img_arr = np.frombuffer(img_png, dtype=np.uint8)
    img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Failed to decode image")

    img_h, img_w = img_bgr.shape[:2]

    from cv_ocr_engines import ocr_region_rapid
    from cv_vocab_types import PageRegion

    t0 = time.time()

    # --- RapidOCR (local, synchronous) ---
    full_region = PageRegion(
        type="full_page", x=0, y=0, width=img_w, height=img_h,
    )
    rapid_words = ocr_region_rapid(img_bgr, full_region)
    if not rapid_words:
        rapid_words = []

    # --- Tesseract ---
    from PIL import Image
    import pytesseract

    pil_img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    data = pytesseract.image_to_data(
        pil_img, lang="eng+deu",
        config="--psm 6 --oem 3",
        output_type=pytesseract.Output.DICT,
    )
    tess_words = []
    for i, raw_text in enumerate(data["text"]):
        text = str(raw_text).strip()
        # Confidence may arrive as int, float or numeric string depending on
        # the Tesseract/pytesseract versions in use.  Parsing through float
        # keeps float-formatted values ("96.27") instead of treating them as
        # invalid and dropping the word.
        try:
            conf = int(float(data["conf"][i]))
        except (TypeError, ValueError, OverflowError):
            conf = -1
        if not text or conf < 20:
            continue
        tess_words.append({
            "text": text,
            "left": data["left"][i],
            "top": data["top"][i],
            "width": data["width"][i],
            "height": data["height"][i],
            "conf": conf,
        })

    # --- Split multi-word RapidOCR boxes into individual words ---
    rapid_words_split = _split_paddle_multi_words(rapid_words)
    logger.info(
        "rapid_kombi: split %d rapid boxes → %d individual words",
        len(rapid_words), len(rapid_words_split),
    )

    # --- Merge ---
    if not rapid_words_split and not tess_words:
        raise HTTPException(status_code=400, detail="Both OCR engines returned no words")

    merged_words = _merge_paddle_tesseract(rapid_words_split, tess_words)

    cells, columns_meta = build_grid_from_words(merged_words, img_w, img_h)
    duration = time.time() - t0

    for cell in cells:
        cell["ocr_engine"] = "rapid_kombi"

    n_rows = len({c["row_index"] for c in cells}) if cells else 0
    n_cols = len(columns_meta)
    col_types = {c.get("type") for c in columns_meta}
    is_vocab = bool(col_types & {"column_en", "column_de"})

    word_result = {
        "cells": cells,
        "grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": "rapid_kombi",
        "grid_method": "rapid_kombi",
        "raw_rapid_words": rapid_words,
        "raw_rapid_words_split": rapid_words_split,
        "raw_tesseract_words": tess_words,
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
            "rapid_words": len(rapid_words),
            "rapid_words_split": len(rapid_words_split),
            "tesseract_words": len(tess_words),
            "merged_words": len(merged_words),
        },
    }

    await update_session_db(
        session_id,
        word_result=word_result,
        cropped_png=img_png,
        current_step=8,
    )

    # Keep the in-memory session copy in step with the DB so a cached
    # session does not keep serving the previous word_result.
    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    logger.info(
        "rapid_kombi session %s: %d cells (%d rows, %d cols) in %.2fs "
        "[rapid=%d, tess=%d, merged=%d]",
        session_id, len(cells), n_rows, n_cols, duration,
        len(rapid_words), len(tess_words), len(merged_words),
    )

    await _append_pipeline_log(session_id, "rapid_kombi", {
        "total_cells": len(cells),
        "non_empty_cells": word_result["summary"]["non_empty_cells"],
        "rapid_words": len(rapid_words),
        "tesseract_words": len(tess_words),
        "merged_words": len(merged_words),
        "ocr_engine": "rapid_kombi",
    }, duration_ms=int(duration * 1000))

    return {"session_id": session_id, **word_result}
|
||
|
||
|
||
class WordGroundTruthRequest(BaseModel):
    """Request body for saving reviewer feedback on the word recognition step."""

    # Reviewer verdict on the automatically produced word result.
    is_correct: bool
    # Optional corrected entries supplied by the reviewer; stored as given.
    corrected_entries: Optional[List[Dict[str, Any]]] = None
    # Optional free-text remarks.
    notes: Optional[str] = None
|
||
|
||
|
||
@router.post("/sessions/{session_id}/ground-truth/words")
async def save_word_ground_truth(session_id: str, req: WordGroundTruthRequest):
    """Store reviewer feedback for the word recognition step.

    The feedback is saved under the ``"words"`` key of the session's
    ``ground_truth`` dict, together with a UTC timestamp and a snapshot of
    the session's current ``word_result``.  A cached copy of the session,
    if present, is updated as well.

    Raises:
        HTTPException 404: the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    words_feedback = dict(
        is_correct=req.is_correct,
        corrected_entries=req.corrected_entries,
        notes=req.notes,
        saved_at=datetime.utcnow().isoformat(),
        word_result=session.get("word_result"),
    )

    # Other ground-truth sections of the session are left untouched.
    ground_truth = session.get("ground_truth") or {}
    ground_truth["words"] = words_feedback

    await update_session_db(session_id, ground_truth=ground_truth)

    cached_session = _cache.get(session_id) if session_id in _cache else None
    if cached_session is not None:
        cached_session["ground_truth"] = ground_truth

    return {"session_id": session_id, "ground_truth": words_feedback}
|
||
|
||
|
||
@router.get("/sessions/{session_id}/ground-truth/words")
async def get_word_ground_truth(session_id: str):
    """Return the saved word ground truth next to the automatic word result.

    Raises:
        HTTPException 404: the session does not exist, or no word ground
            truth has been saved for it.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    words_gt = (session.get("ground_truth") or {}).get("words")
    if not words_gt:
        raise HTTPException(status_code=404, detail="No word ground truth saved")

    response = {"session_id": session_id}
    response["words_gt"] = words_gt
    response["words_auto"] = session.get("word_result")
    return response
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# LLM Review Endpoints (Step 6)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
@router.post("/sessions/{session_id}/llm-review")
async def run_llm_review(session_id: str, request: Request, stream: bool = False):
    """Run LLM-based correction on the vocab entries of a session.

    Query params:
        stream: false (default) for a JSON response, true for SSE streaming.

    Request body (optional JSON object):
        model: name of the model to use instead of ``OLLAMA_REVIEW_MODEL``.
        A missing, unparsable or non-object body means "no override".

    In the non-streaming path the proposed changes are stored under
    ``word_result["llm_review"]``, the session is advanced to step 9 and a
    summary of the review is returned.

    Raises:
        HTTPException 404: the session does not exist.
        HTTPException 400: the session has no word result or no vocab entries.
        HTTPException 502: the LLM review itself failed (non-streaming path).
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found — run Step 5 first")

    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    if not entries:
        raise HTTPException(status_code=400, detail="No vocab entries found — run Step 5 first")

    # Optional model override from the request body.  The body is optional,
    # so any failure to read or parse it is deliberately ignored.
    body: Any = {}
    try:
        body = await request.json()
    except Exception:
        body = {}
    # Valid JSON that is not an object (list, string, number, null) has no
    # "model" key either; without this guard body.get() would raise.
    if not isinstance(body, dict):
        body = {}
    model = body.get("model") or OLLAMA_REVIEW_MODEL

    if stream:
        return StreamingResponse(
            _llm_review_stream_generator(session_id, entries, word_result, model, request),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
        )

    # Non-streaming path
    try:
        result = await llm_review_entries(entries, model=model)
    except Exception as e:
        import traceback
        logger.error(f"LLM review failed for session {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}")
        raise HTTPException(status_code=502, detail=f"LLM review failed ({type(e).__name__}): {e}") from e

    # Store result inside word_result as a sub-key
    word_result["llm_review"] = {
        "changes": result["changes"],
        "model_used": result["model_used"],
        "duration_ms": result["duration_ms"],
        "entries_corrected": result["entries_corrected"],
    }
    await update_session_db(session_id, word_result=word_result, current_step=9)

    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    logger.info(f"LLM review session {session_id}: {len(result['changes'])} changes, "
                f"{result['duration_ms']}ms, model={result['model_used']}")

    await _append_pipeline_log(session_id, "correction", {
        "engine": "llm",
        "model": result["model_used"],
        "total_entries": len(entries),
        "corrections_proposed": len(result["changes"]),
    }, duration_ms=result["duration_ms"])

    return {
        "session_id": session_id,
        "changes": result["changes"],
        "model_used": result["model_used"],
        "duration_ms": result["duration_ms"],
        "total_entries": len(entries),
        "corrections_found": len(result["changes"]),
    }
|
||
|
||
|
||
async def _llm_review_stream_generator(
    session_id: str,
    entries: List[Dict],
    word_result: Dict,
    model: str,
    request: Request,
):
    """Relay LLM review progress to the client as Server-Sent Events.

    Every event produced by ``llm_review_entries_streaming`` is forwarded
    as one SSE frame.  When the ``complete`` event has been forwarded, its
    result is stored under ``word_result["llm_review"]``, persisted with
    step 9 and mirrored into the session cache if the session is cached.

    The stream ends silently when the client disconnects.  Any exception is
    logged and reported to the client as a single ``error`` event.
    """
    review_keys = ("changes", "model_used", "duration_ms", "entries_corrected")
    try:
        async for event in llm_review_entries_streaming(entries, model=model):
            if await request.is_disconnected():
                logger.info(f"SSE: client disconnected during LLM review for {session_id}")
                return

            yield f"data: {json.dumps(event, ensure_ascii=False)}\n\n"

            if event.get("type") != "complete":
                continue

            # Final event: persist the review next to the word result.
            word_result["llm_review"] = {key: event[key] for key in review_keys}
            await update_session_db(session_id, word_result=word_result, current_step=9)
            if session_id in _cache:
                _cache[session_id]["word_result"] = word_result

            logger.info(f"LLM review SSE session {session_id}: {event['corrections_found']} changes, "
                        f"{event['duration_ms']}ms, skipped={event['skipped']}, model={event['model_used']}")

    except Exception as e:
        import traceback
        logger.error(f"LLM review SSE failed for {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}")
        error_event = {"type": "error", "detail": f"{type(e).__name__}: {e}"}
        yield f"data: {json.dumps(error_event)}\n\n"
|
||
|
||
|
||
@router.post("/sessions/{session_id}/llm-review/apply")
async def apply_llm_corrections(session_id: str, request: Request):
    """Apply the user-accepted LLM corrections to the session's vocab entries.

    Body: ``{"accepted_indices": [0, 2, ...]}`` where each index refers to a
    position in ``word_result["llm_review"]["changes"]``.  For every accepted
    change the matching entry (same ``row_index``) gets its ``english``,
    ``german`` or ``example`` field replaced by the change's ``new`` value and
    is flagged with ``llm_corrected = True``.  The updated word result is
    written to the DB and mirrored into the in-memory cache.

    Returns:
        ``session_id``, ``applied_count`` (number of accepted changes) and
        ``total_changes``.

    Raises:
        HTTPException: 404 for an unknown session; 400 when the session has no
            word result or no LLM review, or when the request body is not a
            JSON object with a list under ``accepted_indices``.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    llm_review = word_result.get("llm_review")
    if not llm_review:
        raise HTTPException(status_code=400, detail="No LLM review found — run /llm-review first")

    # A malformed body is a client error, not a server crash.
    try:
        body = await request.json()
    except ValueError as exc:
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from exc
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    raw_indices = body.get("accepted_indices")
    if raw_indices is None:
        raw_indices = []
    if not isinstance(raw_indices, list):
        raise HTTPException(status_code=400, detail="accepted_indices must be a list")
    try:
        accepted_indices = set(raw_indices)  # indices into changes[]
    except TypeError as exc:  # unhashable items such as nested lists/objects
        raise HTTPException(status_code=400, detail="accepted_indices must be a list of integers") from exc

    changes = llm_review.get("changes", [])
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []

    # Lookup (row_index, field) -> new value for the accepted changes.
    # A later change to the same cell overrides an earlier one.
    accepted = [change for idx, change in enumerate(changes) if idx in accepted_indices]
    corrections = {(change["row_index"], change["field"]): change["new"] for change in accepted}
    applied_count = len(accepted)

    for entry in entries:
        row_idx = entry.get("row_index", -1)
        for field_name in ("english", "german", "example"):
            key = (row_idx, field_name)
            if key in corrections:
                entry[field_name] = corrections[key]
                entry["llm_corrected"] = True

    # Store the corrected entries under both keys used across the pipeline.
    word_result["vocab_entries"] = entries
    word_result["entries"] = entries
    word_result["llm_review"]["applied_count"] = applied_count
    word_result["llm_review"]["applied_at"] = datetime.utcnow().isoformat()

    await update_session_db(session_id, word_result=word_result)

    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    logger.info(f"Applied {applied_count}/{len(changes)} LLM corrections for session {session_id}")

    return {
        "session_id": session_id,
        "applied_count": applied_count,
        "total_changes": len(changes),
    }
|
||
|
||
|
||
@router.post("/sessions/{session_id}/reconstruction")
async def save_reconstruction(session_id: str, request: Request):
    """Persist cell texts edited in the reconstruction step.

    Body: ``{"cells": [{"cell_id": ..., "text": ...}, ...]}``.  Cell ids of the
    form ``box{N}_{id}`` are routed to the sub-session whose ``box_index`` is
    ``N``; every other id is applied to the main session's ``cells``.  Edited
    cells get ``status = "edited"``.  Main-session edits whose id matches
    ``R{row}_C{col}`` (row zero-padded to two digits, or unpadded) are also
    mirrored into the vocab entry with that ``row_index``; columns 0/1/2 map
    to ``english``/``german``/``example``.

    The main session is saved with ``current_step=10``.

    Returns:
        ``session_id`` plus the ``updated``, ``main_updated`` and
        ``sub_updated`` cell counts (only ``updated`` when the body carries no
        cells).

    Raises:
        HTTPException: 404 for an unknown session, 400 without a word result.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    body = await request.json()
    cell_updates = body.get("cells", [])

    if not cell_updates:
        # Nothing edited: only advance the pipeline step.
        await update_session_db(session_id, current_step=10)
        return {"session_id": session_id, "updated": 0}

    # Build update map: cell_id -> new text (last update for an id wins)
    update_map = {c["cell_id"]: c["text"] for c in cell_updates}

    # Separate sub-session updates (cell_ids prefixed with "box{N}_")
    sub_updates: Dict[int, Dict[str, str]] = {}  # box_index -> {original_cell_id: text}
    main_updates: Dict[str, str] = {}
    for cell_id, text in update_map.items():
        m = re.match(r'^box(\d+)_(.+)$', cell_id)
        if m:
            sub_updates.setdefault(int(m.group(1)), {})[m.group(2)] = text
        else:
            main_updates[cell_id] = text

    # Update main session cells
    cells = word_result.get("cells", [])
    updated_count = 0
    for cell in cells:
        if cell["cell_id"] in main_updates:
            cell["text"] = main_updates[cell["cell_id"]]
            cell["status"] = "edited"
            updated_count += 1

    word_result["cells"] = cells

    # Also update vocab_entries if present
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    if entries:
        for entry in entries:
            row_idx = entry.get("row_index", -1)
            for col_idx, field_name in enumerate(["english", "german", "example"]):
                # Try the zero-padded id first, then the unpadded one.  Test
                # against None explicitly (not truthiness) so that an edit
                # which clears a cell ("") still reaches the vocab entry.
                new_text = main_updates.get(f"R{row_idx:02d}_C{col_idx}")
                if new_text is None:
                    new_text = main_updates.get(f"R{row_idx}_C{col_idx}")
                if new_text is not None:
                    entry[field_name] = new_text

        word_result["vocab_entries"] = entries
        if "entries" in word_result:
            word_result["entries"] = entries

    await update_session_db(session_id, word_result=word_result, current_step=10)

    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    # Route sub-session updates
    sub_updated = 0
    if sub_updates:
        subs = await get_sub_sessions(session_id)
        sub_by_index = {s.get("box_index"): s["id"] for s in subs}
        for bi, updates in sub_updates.items():
            sub_id = sub_by_index.get(bi)
            if not sub_id:
                continue
            sub_session = await get_session_db(sub_id)
            if not sub_session:
                continue
            sub_word = sub_session.get("word_result")
            if not sub_word:
                continue
            sub_cells = sub_word.get("cells", [])
            for cell in sub_cells:
                if cell["cell_id"] in updates:
                    cell["text"] = updates[cell["cell_id"]]
                    cell["status"] = "edited"
                    sub_updated += 1
            sub_word["cells"] = sub_cells
            await update_session_db(sub_id, word_result=sub_word)
            if sub_id in _cache:
                _cache[sub_id]["word_result"] = sub_word

    total_updated = updated_count + sub_updated
    logger.info(f"Reconstruction saved for session {session_id}: "
                f"{updated_count} main + {sub_updated} sub-session cells updated")

    return {
        "session_id": session_id,
        "updated": total_updated,
        "main_updated": updated_count,
        "sub_updated": sub_updated,
    }
|
||
|
||
|
||
@router.get("/sessions/{session_id}/reconstruction/fabric-json")
async def get_fabric_json(session_id: str):
    """Build the Fabric.js JSON document for the reconstruction canvas.

    The main session's cells are combined with the cells of every sub-session
    (box region).  Sub-session cells are copied, their ``cell_id`` is prefixed
    with ``box{N}_``, a ``source`` of ``box_{N}`` is added and their
    ``bbox_px`` is shifted by the x/y origin of the matching box zone, so they
    land at page coordinates.  When no box zone exists for a sub-session's
    index the offset is (0, 0).

    Raises:
        HTTPException: 404 for an unknown session, 400 without a word result.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    # Work on a new list so appended sub-session cells never touch the session.
    merged_cells = list(word_result.get("cells", []))
    page_w = word_result.get("image_width", 800)
    page_h = word_result.get("image_height", 600)

    sub_sessions = await get_sub_sessions(session_id)
    if sub_sessions:
        all_zones = (session.get("column_result") or {}).get("zones") or []
        box_zones = [
            zone for zone in all_zones
            if zone.get("zone_type") == "box" and zone.get("box")
        ]

        for sub in sub_sessions:
            sub_session = await get_session_db(sub["id"])
            if not sub_session:
                continue
            sub_word = sub_session.get("word_result")
            if not sub_word or not sub_word.get("cells"):
                continue

            bi = sub.get("box_index", 0)
            offset_x = offset_y = 0
            if bi < len(box_zones):
                zone_box = box_zones[bi]["box"]
                offset_y, offset_x = zone_box["y"], zone_box["x"]

            for sub_cell in sub_word["cells"]:
                # Copy, then namespace the id so edits can be routed back.
                placed = {
                    **sub_cell,
                    "cell_id": f"box{bi}_{sub_cell.get('cell_id', '')}",
                    "source": f"box_{bi}",
                }
                local_bbox = placed.get("bbox_px", {})
                if local_bbox:
                    # Translate from box-local to page coordinates.
                    placed["bbox_px"] = {
                        **local_bbox,
                        "x": local_bbox.get("x", 0) + offset_x,
                        "y": local_bbox.get("y", 0) + offset_y,
                    }
                merged_cells.append(placed)

    from services.layout_reconstruction_service import cells_to_fabric_json

    return cells_to_fabric_json(merged_cells, page_w, page_h)
|
||
|
||
|
||
@router.get("/sessions/{session_id}/vocab-entries/merged")
async def get_merged_vocab_entries(session_id: str):
    """Return vocab entries of the main session plus all sub-sessions.

    Main entries are tagged ``source = "main"`` (unless they already carry a
    source); sub-session entries are tagged ``source = "box_{N}"`` and get a
    ``source_y`` taken from the matching box zone (0 when none matches).  All
    entries are returned as copies, so the stored session data is never
    modified by this read-only endpoint.

    The list is sorted by an approximate vertical position: main entries by
    ``row_index * 100``, box entries by ``source_y * 100 + row_index``.

    Returns:
        ``session_id``, the merged ``entries``, their ``total`` and the
        distinct ``sources`` in first-seen order.

    Raises:
        HTTPException: 404 for an unknown session.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result") or {}
    main_entries = word_result.get("vocab_entries") or word_result.get("entries") or []

    # Tag copies of the main entries; tagging in place would leak the
    # "source" key into the session's own (possibly cached) entry dicts.
    entries = []
    for e in main_entries:
        e_copy = dict(e)
        e_copy.setdefault("source", "main")
        entries.append(e_copy)

    # Merge sub-session entries
    subs = await get_sub_sessions(session_id)
    if subs:
        column_result = session.get("column_result") or {}
        zones = column_result.get("zones") or []
        box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]

        for sub in subs:
            sub_session = await get_session_db(sub["id"])
            if not sub_session:
                continue
            sub_word = sub_session.get("word_result") or {}
            sub_entries = sub_word.get("vocab_entries") or sub_word.get("entries") or []

            bi = sub.get("box_index", 0)
            box_y = 0
            if bi < len(box_zones):
                box_y = box_zones[bi]["box"]["y"]

            for e in sub_entries:
                e_copy = dict(e)
                e_copy["source"] = f"box_{bi}"
                e_copy["source_y"] = box_y  # for sorting
                entries.append(e_copy)

    # Sort by approximate Y position
    def _sort_key(e):
        if e.get("source", "main") == "main":
            return e.get("row_index", 0) * 100  # main entries by row index
        return e.get("source_y", 0) * 100 + e.get("row_index", 0)

    entries.sort(key=_sort_key)

    return {
        "session_id": session_id,
        "entries": entries,
        "total": len(entries),
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        "sources": list(dict.fromkeys(e.get("source", "main") for e in entries)),
    }
|
||
|
||
|
||
@router.get("/sessions/{session_id}/reconstruction/export/pdf")
async def export_reconstruction_pdf(session_id: str):
    """Export the reconstructed cell grid as a PDF table.

    The header row is taken from ``columns_used`` (``label``, falling back to
    ``type``, then ``Col {i}``) or, if that is empty, generated from the grid
    width.  Data rows are filled from the cells whose id is
    ``R{row:02d}_C{col}``; missing cells are rendered empty.

    Returns:
        A ``StreamingResponse`` carrying the PDF as an attachment.

    Raises:
        HTTPException: 404 for an unknown session; 400 without a word result
            or when there is no column to export; 501 when reportlab is not
            installed.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    cells = word_result.get("cells", [])
    columns_used = word_result.get("columns_used", [])
    grid_shape = word_result.get("grid_shape", {})
    n_rows = grid_shape.get("rows", 0)
    n_cols = grid_shape.get("cols", 0)

    header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
    if not header:
        header = [f"Col {i}" for i in range(n_cols)]
    if not header:
        # Validate before any PDF machinery is touched.
        raise HTTPException(status_code=400, detail="No data to export")

    # Index the cells once instead of scanning the list for every grid
    # position; setdefault keeps the first cell when an id occurs twice.
    cell_by_id: Dict[Any, Dict] = {}
    for c in cells:
        cell_by_id.setdefault(c.get("cell_id"), c)

    # Build table data: rows × columns
    table_data: list[list[str]] = [header]
    for r in range(n_rows):
        row_texts = []
        for ci in range(n_cols):
            cell = cell_by_id.get(f"R{r:02d}_C{ci}")
            row_texts.append(cell.get("text", "") if cell else "")
        table_data.append(row_texts)

    # reportlab is an optional dependency; only its import is guarded.
    try:
        from reportlab.lib.pagesizes import A4
        from reportlab.lib import colors
        from reportlab.platypus import SimpleDocTemplate, Table, TableStyle
    except ImportError as exc:
        raise HTTPException(status_code=501, detail="reportlab not installed") from exc

    import io as _io

    buf = _io.BytesIO()
    doc = SimpleDocTemplate(buf, pagesize=A4)

    t = Table(table_data)
    t.setStyle(TableStyle([
        ('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#0d9488')),
        ('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
        ('FONTSIZE', (0, 0), (-1, -1), 9),
        ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
        ('VALIGN', (0, 0), (-1, -1), 'TOP'),
        ('WORDWRAP', (0, 0), (-1, -1), True),
    ]))
    doc.build([t])
    buf.seek(0)

    return StreamingResponse(
        buf,
        media_type="application/pdf",
        headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.pdf"'},
    )
|
||
|
||
|
||
@router.get("/sessions/{session_id}/reconstruction/export/docx")
async def export_reconstruction_docx(session_id: str):
    """Export the reconstructed cell grid as a DOCX table.

    The document contains a heading with the first eight characters of the
    session id and one table: a header row taken from ``columns_used``
    (``label``, falling back to ``type``, then ``Col {i}``; generated from the
    grid width when empty) followed by one row per grid row, filled from the
    cells whose id is ``R{row:02d}_C{col}``.

    Returns:
        A ``StreamingResponse`` carrying the DOCX file as an attachment.

    Raises:
        HTTPException: 404 for an unknown session; 400 without a word result;
            501 when python-docx is not installed.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    cells = word_result.get("cells", [])
    columns_used = word_result.get("columns_used", [])
    grid_shape = word_result.get("grid_shape", {})
    n_rows = grid_shape.get("rows", 0)
    n_cols = grid_shape.get("cols", 0)

    # python-docx is an optional dependency; only its import is guarded.
    try:
        from docx import Document
    except ImportError as exc:
        raise HTTPException(status_code=501, detail="python-docx not installed") from exc

    import io as _io

    doc = Document()
    doc.add_heading(f'Rekonstruktion – Session {session_id[:8]}', level=1)

    # Build header
    header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
    if not header:
        header = [f"Col {i}" for i in range(n_cols)]

    # The header comes from columns_used while the grid width comes from
    # grid_shape; size the table for the wider of the two so that writing the
    # header can never index past the last column.
    table = doc.add_table(rows=1 + n_rows, cols=max(n_cols, len(header), 1))
    table.style = 'Table Grid'

    # Header row
    header_cells = table.rows[0].cells
    for ci, h in enumerate(header):
        header_cells[ci].text = h

    # Index the cells once instead of scanning the list for every grid
    # position; setdefault keeps the first cell when an id occurs twice.
    cell_by_id: Dict[Any, Dict] = {}
    for c in cells:
        cell_by_id.setdefault(c.get("cell_id"), c)

    # Data rows
    for r in range(n_rows):
        row_cells = table.rows[r + 1].cells  # fetched once per row
        for ci in range(n_cols):
            cell = cell_by_id.get(f"R{r:02d}_C{ci}")
            row_cells[ci].text = cell.get("text", "") if cell else ""

    buf = _io.BytesIO()
    doc.save(buf)
    buf.seek(0)

    return StreamingResponse(
        buf,
        media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.docx"'},
    )
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Validation — Original vs. Reconstruction
# (listed as step 10, "Ground Truth Validierung", in the module docstring;
# this section was previously labelled "Step 8")
# ---------------------------------------------------------------------------
|
||
|
||
# Style name -> text appended to an image-generation prompt.  Unknown style
# names fall back to the "educational" entry where this table is looked up.
STYLE_SUFFIXES = dict(
    educational="educational illustration, textbook style, clear, colorful",
    cartoon="cartoon, child-friendly, simple shapes",
    sketch="pencil sketch, hand-drawn, black and white",
    clipart="clipart, flat vector style, simple",
    realistic="photorealistic, high detail",
)
|
||
|
||
|
||
# Request body of the validation-save endpoint.  Both fields are optional and
# are stored unchanged in the session's ground_truth["validation"] record.
class ValidationRequest(BaseModel):
    # Free-text reviewer notes.
    notes: Optional[str] = None
    # Reviewer score; no range is enforced here.
    score: Optional[int] = None
|
||
|
||
|
||
# Request body of the image-generation endpoint.
class GenerateImageRequest(BaseModel):
    # Index into the session's detected image_regions list.
    region_index: int
    # Prompt text; a style suffix is appended before it is sent to the generator.
    prompt: str
    # Key into STYLE_SUFFIXES; unknown values fall back to "educational".
    style: str = "educational"
|
||
|
||
|
||
def _extract_json_object_array(text: str) -> list:
    """Return the first JSON array of objects embedded in *text*.

    Scans every ``[`` in *text* and decodes a complete JSON value from there,
    so ``]`` characters inside strings or nested arrays cannot cut the array
    short.  An empty array counts as a match.

    Returns:
        The decoded list, or ``[]`` when *text* contains no ``[`` at all.

    Raises:
        ValueError: If *text* contains ``[`` but no array of objects can be
            decoded from any of them.
    """
    decoder = json.JSONDecoder()
    pos = text.find("[")
    if pos == -1:
        return []
    while pos != -1:
        try:
            value, _ = decoder.raw_decode(text, pos)
        except ValueError:
            value = None
        if isinstance(value, list) and all(isinstance(item, dict) for item in value):
            return value
        pos = text.find("[", pos + 1)
    raise ValueError("VLM response contains no parseable JSON array of regions")


@router.post("/sessions/{session_id}/reconstruction/detect-images")
async def detect_image_regions(session_id: str):
    """Detect illustration/image regions in the original scan using a VLM.

    Sends the original image to the Ollama model named by ``OLLAMA_HTR_MODEL``
    (default ``qwen2.5vl:32b``) at ``OLLAMA_BASE_URL`` and asks for non-text,
    non-table image areas as a JSON array of bounding boxes (in % of the page)
    with descriptions.  The regions are clamped, enriched with nearby
    vocabulary as prompt context, stored under
    ``ground_truth["validation"]["image_regions"]`` and returned.

    Returns:
        ``{"regions": [...], "count": n}``.  If the VLM is unreachable or the
        reply cannot be processed, an empty result with an ``error`` message
        is returned instead and nothing is stored.

    Raises:
        HTTPException: 404 for an unknown session, 400 without an original
            image.
    """
    import base64
    import httpx

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # Get original image bytes
    original_png = await get_session_image(session_id, "original")
    if not original_png:
        raise HTTPException(status_code=400, detail="No original image found")

    # Build context from vocab entries for richer descriptions
    word_result = session.get("word_result") or {}
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    vocab_context = ""
    if entries:
        sample = entries[:10]
        words = [f"{e.get('english', '')} / {e.get('german', '')}" for e in sample if e.get('english')]
        if words:
            vocab_context = f"\nContext: This is a vocabulary page with words like: {', '.join(words)}"

    ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
    model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")

    prompt = (
        "Analyze this scanned page. Find ALL illustration/image/picture regions "
        "(NOT text, NOT table cells, NOT blank areas). "
        "For each image region found, return its bounding box as percentage of page dimensions "
        "and a short English description of what the image shows. "
        "Reply with ONLY a JSON array like: "
        '[{"x": 10, "y": 20, "w": 30, "h": 25, "description": "drawing of a cat"}] '
        "where x, y, w, h are percentages (0-100) of the page width/height. "
        "If there are NO images on the page, return an empty array: []"
        f"{vocab_context}"
    )

    img_b64 = base64.b64encode(original_png).decode("utf-8")
    payload = {
        "model": model,
        "prompt": prompt,
        "images": [img_b64],
        "stream": False,
    }

    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            resp = await client.post(f"{ollama_base}/api/generate", json=payload)
            resp.raise_for_status()
            text = resp.json().get("response", "")

        # Parse the JSON array out of the (possibly chatty) model reply.
        raw_regions = _extract_json_object_array(text)

        # Normalize to ImageRegion format; position is clamped to 0-100 and
        # size to 1-100 percent.
        regions = []
        for r in raw_regions:
            regions.append({
                "bbox_pct": {
                    "x": max(0, min(100, float(r.get("x", 0)))),
                    "y": max(0, min(100, float(r.get("y", 0)))),
                    "w": max(1, min(100, float(r.get("w", 10)))),
                    "h": max(1, min(100, float(r.get("h", 10)))),
                },
                "description": r.get("description", ""),
                "prompt": r.get("description", ""),
                "image_b64": None,
                "style": "educational",
            })

        # Enrich prompts with nearby vocab context
        if entries:
            for region in regions:
                ry = region["bbox_pct"]["y"]
                rh = region["bbox_pct"]["h"]
                # NOTE(review): compares the entry bbox y with a percentage;
                # presumably entry bboxes are in percent too — confirm.
                nearby = [
                    e for e in entries
                    if e.get("bbox") and abs(e["bbox"].get("y", 0) - ry) < rh + 10
                ]
                if nearby:
                    en_words = [e.get("english", "") for e in nearby if e.get("english")]
                    de_words = [e.get("german", "") for e in nearby if e.get("german")]
                    if en_words or de_words:
                        context = f" (vocabulary context: {', '.join(en_words[:5])}"
                        if de_words:
                            context += f" / {', '.join(de_words[:5])}"
                        context += ")"
                        region["prompt"] = region["description"] + context

        # Save to ground_truth JSONB
        ground_truth = session.get("ground_truth") or {}
        validation = ground_truth.get("validation") or {}
        validation["image_regions"] = regions
        validation["detected_at"] = datetime.utcnow().isoformat()
        ground_truth["validation"] = validation
        await update_session_db(session_id, ground_truth=ground_truth)

        if session_id in _cache:
            _cache[session_id]["ground_truth"] = ground_truth

        logger.info(f"Detected {len(regions)} image regions for session {session_id}")

        return {"regions": regions, "count": len(regions)}

    except httpx.ConnectError:
        logger.warning(f"VLM not available at {ollama_base} for image detection")
        return {"regions": [], "count": 0, "error": "VLM not available"}
    except Exception as e:
        # Best-effort endpoint: report the failure to the client instead of
        # failing the request, but keep the traceback in the log.
        logger.exception(f"Image detection failed for {session_id}: {e}")
        return {"regions": [], "count": 0, "error": str(e)}
|
||
|
||
|
||
@router.post("/sessions/{session_id}/reconstruction/generate-image")
async def generate_image_for_region(session_id: str, req: GenerateImageRequest):
    """Generate a replacement image for one detected region via mflux.

    The request prompt, extended by the suffix of the requested style, is
    posted to ``{MFLUX_URL}/generate``.  The output size is one of three
    presets picked from the region's width/height ratio.  On success the
    image, prompt and style are stored on the region inside
    ``ground_truth["validation"]["image_regions"]``.

    Returns:
        ``{"image_b64": ..., "success": True}`` on success, otherwise
        ``{"image_b64": None, "success": False, "error": ...}``.

    Raises:
        HTTPException: 404 for an unknown session, 400 when ``region_index``
            is out of range.
    """
    import httpx

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    validation = ground_truth.get("validation") or {}
    regions = validation.get("image_regions") or []

    region_count = len(regions)
    if not 0 <= req.region_index < region_count:
        raise HTTPException(status_code=400, detail=f"Invalid region_index {req.region_index}, have {len(regions)} regions")

    mflux_url = os.getenv("MFLUX_URL", "http://host.docker.internal:8095")
    suffix = STYLE_SUFFIXES.get(req.style, STYLE_SUFFIXES["educational"])

    # Landscape, portrait or square preset, chosen from the ratio of the
    # region's percentage width to its percentage height.
    target = regions[req.region_index]
    pct_box = target["bbox_pct"]
    ratio = pct_box["w"] / max(pct_box["h"], 1)
    if ratio > 1.3:
        size = (768, 512)
    elif ratio < 0.7:
        size = (512, 768)
    else:
        size = (512, 512)

    request_body = {
        "prompt": f"{req.prompt}, {suffix}",
        "width": size[0],
        "height": size[1],
        "steps": 4,
    }

    try:
        async with httpx.AsyncClient(timeout=300.0) as client:
            resp = await client.post(f"{mflux_url}/generate", json=request_body)
            resp.raise_for_status()
            image_b64 = resp.json().get("image_b64")

        if not image_b64:
            return {"image_b64": None, "success": False, "error": "No image returned"}

        # Record the result on the region and persist the whole ground truth.
        target["image_b64"] = image_b64
        target["prompt"] = req.prompt
        target["style"] = req.style
        validation["image_regions"] = regions
        ground_truth["validation"] = validation
        await update_session_db(session_id, ground_truth=ground_truth)

        if session_id in _cache:
            _cache[session_id]["ground_truth"] = ground_truth

        logger.info(f"Generated image for session {session_id} region {req.region_index}")
        return {"image_b64": image_b64, "success": True}

    except httpx.ConnectError:
        logger.warning(f"mflux-service not available at {mflux_url}")
        return {"image_b64": None, "success": False, "error": f"mflux-service not available at {mflux_url}"}
    except Exception as e:
        logger.error(f"Image generation failed for {session_id}: {e}")
        return {"image_b64": None, "success": False, "error": str(e)}
|
||
|
||
|
||
@router.post("/sessions/{session_id}/reconstruction/validate")
async def save_validation(session_id: str, req: ValidationRequest):
    """Store the final validation verdict for a session.

    Writes ``validated_at``, ``notes`` and ``score`` into
    ``ground_truth["validation"]`` while leaving any other keys already there
    (for example detected or generated image regions) untouched, and saves
    the session with ``current_step = 11``.

    Returns:
        ``session_id`` and the resulting ``validation`` record.

    Raises:
        HTTPException: 404 for an unknown session.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    validation = ground_truth.get("validation") or {}
    validation.update({
        "validated_at": datetime.utcnow().isoformat(),
        "notes": req.notes,
        "score": req.score,
    })
    ground_truth["validation"] = validation

    await update_session_db(session_id, ground_truth=ground_truth, current_step=11)

    # Keep the in-memory copy in step with the DB, if the session is cached.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    logger.info(f"Validation saved for session {session_id}: score={req.score}")

    return {"session_id": session_id, "validation": validation}
|
||
|
||
|
||
@router.get("/sessions/{session_id}/reconstruction/validation")
async def get_validation(session_id: str):
    """Return the stored validation record together with the word result.

    ``validation`` is ``None`` when nothing has been saved yet.

    Raises:
        HTTPException: 404 for an unknown session.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    stored_ground_truth = session.get("ground_truth") or {}

    return {
        "session_id": session_id,
        "validation": stored_ground_truth.get("validation"),
        "word_result": session.get("word_result"),
    }
|
||
|
||
|
||
@router.post("/sessions/{session_id}/reprocess")
async def reprocess_session(session_id: str, request: Request):
    """Reset a session to an earlier step and clear downstream results.

    Body: ``{"from_step": N}`` with an integer ``N`` from 1 to 10 (default 1).
    ``current_step`` is set to ``N`` and results are cleared as follows:

    - ``N <= 8``: ``word_result``
    - ``N == 9``: only ``llm_review`` / ``llm_corrections`` inside
      ``word_result``
    - ``N <= 7``: ``row_result``
    - ``N <= 6``: ``column_result``
    - ``N <= 4``: ``crop_result``
    - ``N <= 3``: ``dewarp_result``
    - ``N <= 2``: ``deskew_result``
    - ``N <= 1``: ``orientation_result``

    NOTE(review): these thresholds mix two numbering schemes.  The word, row
    and column limits follow the numbering Orient(2) → Deskew(3) → Dewarp(4)
    → Crop(5) → Columns(6) → Rows(7) → Words(8) → LLM(9) → Recon(10) →
    GT(11), while the crop, dewarp, deskew and orientation limits follow the
    1-based numbering Orientation(1) … Validation(10).  Confirm which one
    callers send before changing anything.

    Returns:
        ``session_id``, ``from_step`` and the list of ``cleared`` result keys.

    Raises:
        HTTPException: 404 for an unknown session, 400 for an invalid
            ``from_step``.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    body = await request.json()
    from_step = body.get("from_step", 1)
    if not (isinstance(from_step, int) and 1 <= from_step <= 10):
        raise HTTPException(status_code=400, detail="from_step must be between 1 and 10")

    update_kwargs: Dict[str, Any] = {"current_step": from_step}

    if from_step <= 8:
        update_kwargs["word_result"] = None
    elif from_step == 9:
        # Keep the recognised words; drop only what the LLM review added.
        stored_words = session.get("word_result")
        if stored_words:
            for llm_key in ("llm_review", "llm_corrections"):
                stored_words.pop(llm_key, None)
            update_kwargs["word_result"] = stored_words

    # (highest from_step that still clears the result, result key)
    clear_rules = (
        (7, "row_result"),
        (6, "column_result"),
        (4, "crop_result"),
        (3, "dewarp_result"),
        (2, "deskew_result"),
        (1, "orientation_result"),
    )
    for limit, result_key in clear_rules:
        if from_step <= limit:
            update_kwargs[result_key] = None

    await update_session_db(session_id, **update_kwargs)

    # Mirror the same changes into the in-memory cache.
    if session_id in _cache:
        cached = _cache[session_id]
        for field, value in update_kwargs.items():
            cached[field] = value

    logger.info(f"Session {session_id} reprocessing from step {from_step}")

    cleared = [field for field in update_kwargs if field != "current_step"]
    return {
        "session_id": session_id,
        "from_step": from_step,
        "cleared": cleared,
    }
|
||
|
||
|
||
async def _get_rows_overlay(session_id: str) -> Response:
    """Render the session's base image with the detected rows drawn on it.

    Each row gets a coloured border and a label (``R{index} {TYPE}``, plus the
    word count when non-zero) and a fill of the same colour blended in at 15%
    opacity.  Box zones from the column result are marked with dashed
    horizontal lines at their top and bottom edges, then passed to
    ``_draw_box_exclusion_overlay``.

    Returns:
        A PNG ``Response``.

    Raises:
        HTTPException: 404 when the session, its row data or a base image is
            missing; 500 when the image cannot be decoded or encoded.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    row_result = session.get("row_result")
    if not (row_result and row_result.get("rows")):
        raise HTTPException(status_code=404, detail="No row data available")

    # Best available base image, as chosen by the helper.
    base_png = await _get_base_image_png(session_id)
    if not base_png:
        raise HTTPException(status_code=404, detail="No base image available")

    img = cv2.imdecode(np.frombuffer(base_png, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    # Row type -> BGR colour
    gray = (128, 128, 128)
    dark_gray = (100, 100, 100)
    row_colors = {
        "content": (255, 180, 0),
        "header": gray,
        "footer": gray,
        "margin_top": dark_gray,
        "margin_bottom": dark_gray,
    }
    fallback_color = (200, 200, 200)

    # Fills go onto a separate copy so they can be blended in afterwards;
    # borders and labels are drawn straight onto the image.
    fill_layer = img.copy()
    for row in row_result["rows"]:
        left, top = row["x"], row["y"]
        right = left + row["width"]
        bottom = top + row["height"]
        kind = row.get("row_type", "content")
        bgr = row_colors.get(kind, fallback_color)

        cv2.rectangle(fill_layer, (left, top), (right, bottom), bgr, -1)
        cv2.rectangle(img, (left, top), (right, bottom), bgr, 2)

        caption = f"R{row.get('index', 0)} {kind.upper()}"
        word_count = row.get("word_count", 0)
        if word_count:
            caption = f"{caption} ({word_count}w)"
        cv2.putText(img, caption, (left + 5, top + 18),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, bgr, 1)

    # 15% fill layer over 85% annotated image, written back into img.
    cv2.addWeighted(fill_layer, 0.15, img, 0.85, 0, img)

    zones = (session.get("column_result") or {}).get("zones") or []
    if zones:
        page_width = img.shape[1]
        separator_bgr = (0, 200, 255)
        dash = 20
        for zone in zones:
            if zone.get("zone_type") != "box":
                continue
            zone_top = zone["y"]
            zone_bottom = zone_top + zone["height"]
            for edge_y in (zone_top, zone_bottom):
                # Dashed line: one dash, then a gap of the same length.
                for dash_start in range(0, page_width, dash * 2):
                    dash_end = min(dash_start + dash, page_width)
                    cv2.line(img, (dash_start, edge_y), (dash_end, edge_y), separator_bgr, 2)

    # Red semi-transparent overlay for box zones
    _draw_box_exclusion_overlay(img, zones)

    ok, encoded = cv2.imencode(".png", img)
    if not ok:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")

    return Response(content=encoded.tobytes(), media_type="image/png")
|
||
|
||
|
||
def _overlay_confidence_color(conf):
    """Map an OCR confidence value to the BGR label color of the words overlay.

    >= 70 -> green, >= 50 -> amber, anything lower -> red.
    """
    if conf >= 70:
        return (0, 180, 0)
    if conf >= 50:
        return (0, 180, 220)
    return (0, 0, 220)


async def _get_words_overlay(session_id: str) -> Response:
    """Render the word-recognition result on top of the session's base image.

    Two result formats are supported:

    * cell-based (``word_result["cells"]``): every cell with a positive-size
      ``bbox_px`` gets a border, a tinted fill, its ``cell_id`` and the first
      30 characters of its text (colored by confidence);
    * legacy entry-based (``word_result["entries"]``): the grid is rebuilt
      from the stored column/row results and the entry fields are drawn into
      the matching column of each content row.

    Box zones from the column result are finally marked via
    ``_draw_box_exclusion_overlay``.

    Args:
        session_id: ID of the OCR pipeline session.

    Returns:
        A ``Response`` carrying the PNG-encoded overlay image.

    Raises:
        HTTPException: 404 if the session, word data or base image is
            missing; 500 if the image cannot be decoded or encoded.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=404, detail="No word data available")

    # Support both new cell-based and legacy entry-based formats
    cells = word_result.get("cells")
    if not cells and not word_result.get("entries"):
        raise HTTPException(status_code=404, detail="No word data available")

    # Load best available base image (cropped > dewarped > original)
    base_png = await _get_base_image_png(session_id)
    if not base_png:
        raise HTTPException(status_code=404, detail="No base image available")

    arr = np.frombuffer(base_png, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=500, detail="Failed to decode image")

    # Filled rectangles go onto a copy so they can be blended in translucently;
    # borders and labels are drawn directly on ``img``.
    overlay = img.copy()

    if cells:
        # New cell-based overlay: color by column index
        col_palette = [
            (255, 180, 0),    # Blue (BGR)
            (0, 200, 0),      # Green
            (0, 140, 255),    # Orange
            (200, 100, 200),  # Purple
            (200, 200, 0),    # Cyan
            (100, 200, 200),  # Yellow-ish
        ]

        for cell in cells:
            # ``or {}`` also covers an explicit ``"bbox_px": None``.
            bbox = cell.get("bbox_px") or {}
            cx = bbox.get("x", 0)
            cy = bbox.get("y", 0)
            cw = bbox.get("w", 0)
            ch = bbox.get("h", 0)
            if cw <= 0 or ch <= 0:
                continue

            col_idx = cell.get("col_index", 0)
            color = col_palette[col_idx % len(col_palette)]

            # Cell rectangle border
            cv2.rectangle(img, (cx, cy), (cx + cw, cy + ch), color, 1)
            # Semi-transparent fill
            cv2.rectangle(overlay, (cx, cy), (cx + cw, cy + ch), color, -1)

            # Cell-ID label (top-left corner)
            cell_id = cell.get("cell_id", "")
            cv2.putText(img, cell_id, (cx + 2, cy + 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.28, color, 1)

            # Text label (bottom of cell)
            text = cell.get("text", "")
            if text:
                text_color = _overlay_confidence_color(cell.get("confidence", 0))
                label = text.replace('\n', ' ')[:30]
                cv2.putText(img, label, (cx + 3, cy + ch - 4),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)
    else:
        # Legacy fallback: entry-based overlay (for old sessions)
        column_result = session.get("column_result")
        row_result = session.get("row_result")
        col_colors = {
            "column_en": (255, 180, 0),
            "column_de": (0, 200, 0),
            "column_example": (0, 140, 255),
        }
        # Which entry field is rendered into which column type.
        field_by_col_type = {
            "column_en": "english",
            "column_de": "german",
            "column_example": "example",
        }

        columns = []
        if column_result and column_result.get("columns"):
            columns = [c for c in column_result["columns"]
                       if c.get("type", "").startswith("column_")]

        content_rows_data = []
        if row_result and row_result.get("rows"):
            content_rows_data = [r for r in row_result["rows"]
                                 if r.get("row_type") == "content"]

        # Grid: one rectangle per (column, content row) pair.
        for col in columns:
            color = col_colors.get(col.get("type", ""), (200, 200, 200))
            cx, cw = col["x"], col["width"]
            for row in content_rows_data:
                ry, rh = row["y"], row["height"]
                cv2.rectangle(img, (cx, ry), (cx + cw, ry + rh), color, 1)
                cv2.rectangle(overlay, (cx, ry), (cx + cw, ry + rh), color, -1)

        # Entries are matched to content rows by position; on duplicate
        # ``row_index`` values the last entry wins.
        entry_by_row: Dict[int, Dict] = {
            entry.get("row_index", -1): entry
            for entry in word_result["entries"]
        }

        for row_idx, row in enumerate(content_rows_data):
            entry = entry_by_row.get(row_idx)
            if not entry:
                continue
            text_color = _overlay_confidence_color(entry.get("confidence", 0))
            ry, rh = row["y"], row["height"]
            for col in columns:
                field = field_by_col_type.get(col.get("type", ""), "")
                text = entry.get(field, "") if field else ""
                if text:
                    cx = col["x"]
                    label = text.replace('\n', ' ')[:30]
                    cv2.putText(img, label, (cx + 3, ry + rh - 4),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)

    # Blend overlay at 10% opacity
    cv2.addWeighted(overlay, 0.1, img, 0.9, 0, img)

    # Red semi-transparent overlay for box zones
    column_result = session.get("column_result") or {}
    zones = column_result.get("zones") or []
    _draw_box_exclusion_overlay(img, zones)

    success, result_png = cv2.imencode(".png", img)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode overlay image")

    return Response(content=result_png.tobytes(), media_type="image/png")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Handwriting Removal Endpoint
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@router.post("/sessions/{session_id}/remove-handwriting")
async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingRequest):
    """
    Remove handwriting from a session image using inpainting.

    Steps:
      1. Load source image (auto → deskewed if available, else original)
      2. Detect handwriting mask (filtered by target_ink)
      3. Dilate mask to cover stroke edges
      4. Inpaint the image
      5. Store result as clean_png in the session

    Args:
        session_id: ID of the OCR pipeline session.
        req: Request options; this handler reads ``use_source``,
            ``target_ink``, ``dilation`` and ``method``.

    Returns:
        Metadata dict (method used, handwriting ratio, detection confidence,
        timing, ...) plus ``image_url`` of the clean image and ``session_id``.

    Raises:
        HTTPException: 404 if the session or the source image is missing,
            500 if inpainting reports failure.
    """
    # Imported lazily inside the handler, as in the original code.
    import io

    from PIL import Image as _PILImage

    from services.handwriting_detection import detect_handwriting
    from services.inpainting_service import (
        InpaintingMethod,
        dilate_mask as _dilate_mask,
        image_to_png,
        inpaint_image,
    )

    t0 = time.monotonic()

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # 1. Determine source image
    source = req.use_source
    image_bytes = None
    if source == "auto":
        # Probe for the deskewed image and keep the bytes, so the same image
        # is not loaded a second time below.
        image_bytes = await get_session_image(session_id, "deskewed")
        source = "deskewed" if image_bytes else "original"

    if not image_bytes:
        image_bytes = await get_session_image(session_id, source)
    if not image_bytes:
        raise HTTPException(status_code=404, detail=f"Source image '{source}' not available")

    # 2. Detect handwriting mask
    detection = detect_handwriting(image_bytes, target_ink=req.target_ink)

    # 3. Convert mask to PNG bytes and dilate
    mask_buf = io.BytesIO()
    _PILImage.fromarray(detection.mask).save(mask_buf, format="PNG")
    mask_bytes = mask_buf.getvalue()

    if req.dilation > 0:
        mask_bytes = _dilate_mask(mask_bytes, iterations=req.dilation)

    # 4. Inpaint (unknown method names fall back to AUTO)
    method_map = {
        "telea": InpaintingMethod.OPENCV_TELEA,
        "ns": InpaintingMethod.OPENCV_NS,
        "auto": InpaintingMethod.AUTO,
    }
    inpaint_method = method_map.get(req.method, InpaintingMethod.AUTO)

    result = inpaint_image(image_bytes, mask_bytes, method=inpaint_method)
    if not result.success:
        raise HTTPException(status_code=500, detail="Inpainting failed")

    # Timing covers detection + inpainting, not the persistence below.
    elapsed_ms = int((time.monotonic() - t0) * 1000)

    meta = {
        "method_used": result.method_used.value if hasattr(result.method_used, "value") else str(result.method_used),
        "handwriting_ratio": round(detection.handwriting_ratio, 4),
        "detection_confidence": round(detection.confidence, 4),
        "target_ink": req.target_ink,
        "dilation": req.dilation,
        "source_image": source,
        "processing_time_ms": elapsed_ms,
    }

    # 5. Persist clean image (convert BGR ndarray → PNG bytes)
    clean_png_bytes = image_to_png(result.image)
    await update_session_db(session_id, clean_png=clean_png_bytes, handwriting_removal_meta=meta)

    return {
        **meta,
        "image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean",
        "session_id": session_id,
    }
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Auto-Mode Endpoint (Improvement 3)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
# Request body of the run-auto endpoint (POST /sessions/{session_id}/run-auto).
class RunAutoRequest(BaseModel):
    # First pipeline step to execute; run_auto rejects values outside 1-6.
    from_step: int = 1  # 1=deskew, 2=dewarp, 3=columns, 4=rows, 5=words, 6=llm-review
    # OCR engine name, forwarded to build_cell_grid in the words step.
    ocr_engine: str = "auto"  # "auto" | "rapid" | "tesseract"
    # Pronunciation variant passed to the phonetic fix-up helpers.
    pronunciation: str = "british"
    # When True, the LLM review step is reported as skipped and not run.
    skip_llm_review: bool = False
    # Dewarp strategy; run_auto rejects any other value with HTTP 400.
    dewarp_method: str = "ensemble"  # "ensemble" | "vlm" | "cv"
|
||
|
||
|
||
async def _auto_sse_event(step: str, status: str, data: Dict[str, Any]) -> str:
    """Serialize one auto-mode progress update as a server-sent-event frame.

    The JSON payload starts with ``step`` and ``status`` and is then extended
    with the entries of ``data`` (which win on key collisions). The result is
    a ``data: <json>`` line terminated by a blank line.
    """
    import json as _json

    body: Dict[str, Any] = {"step": step, "status": status}
    body.update(data)
    encoded = _json.dumps(body)
    return "data: " + encoded + "\n\n"
|
||
|
||
|
||
@router.post("/sessions/{session_id}/run-auto")
async def run_auto(session_id: str, req: RunAutoRequest, request: Request):
    """Run the OCR pipeline automatically from a given step, streaming SSE progress.

    Steps:
      1. Deskew — straighten the scan
      2. Dewarp — correct vertical shear (ensemble CV or VLM)
      3. Columns — detect column layout
      4. Rows — detect row layout
      5. Words — OCR each cell
      6. LLM review — correct OCR errors (optional)

    Steps numbered below ``req.from_step`` are reported as "skipped"; every
    step from ``from_step`` onward is (re)run, regardless of whether a result
    already exists. A failure in steps 1-5 aborts the run; a failure in the
    LLM review is reported as non-fatal.

    Yields SSE events of the form:
      data: {"step": "deskew", "status": "start"|"done"|"skipped"|"error", ...}

    Final event:
      data: {"step": "complete", "status": "done", "steps_run": [...], "steps_skipped": [...]}
    or, after a fatal step error:
      data: {"step": "complete", "status": "error", "error_step": "..."}

    Args:
        session_id: ID of the OCR pipeline session.
        req: Auto-mode options (start step, OCR engine, dewarp method, ...).
        request: Incoming request object; currently not used by the handler.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.

    Raises:
        HTTPException: 400 if ``from_step`` or ``dewarp_method`` is invalid.
    """
    if req.from_step < 1 or req.from_step > 6:
        raise HTTPException(status_code=400, detail="from_step must be 1-6")
    if req.dewarp_method not in ("ensemble", "vlm", "cv"):
        raise HTTPException(status_code=400, detail="dewarp_method must be: ensemble, vlm, cv")

    if session_id not in _cache:
        await _load_session_to_cache(session_id)

    async def _generate():
        steps_run: List[str] = []
        steps_skipped: List[str] = []
        error_step: Optional[str] = None

        session = await get_session_db(session_id)
        if not session:
            yield await _auto_sse_event("error", "error", {"message": f"Session {session_id} not found"})
            return

        cached = _get_cached(session_id)

        def _working_image():
            """Return the cached cropped image, else the dewarped one (may be None)."""
            # NOTE(review): a cropped image cached before this run takes
            # precedence even when deskew/dewarp were just re-run — confirm
            # that this is intended.
            cropped = cached.get("cropped_bgr")
            return cropped if cropped is not None else cached.get("dewarped_bgr")

        # -----------------------------------------------------------------
        # Step 1: Deskew
        # -----------------------------------------------------------------
        if req.from_step <= 1:
            yield await _auto_sse_event("deskew", "start", {})
            try:
                t0 = time.time()
                orig_bgr = cached.get("original_bgr")
                if orig_bgr is None:
                    raise ValueError("Original image not loaded")

                # Method 1: Hough lines (best effort — falls back to "no rotation")
                try:
                    deskewed_hough, angle_hough = deskew_image(orig_bgr.copy())
                except Exception:
                    deskewed_hough, angle_hough = orig_bgr, 0.0

                # Method 2: Word alignment (works on PNG bytes; best effort too)
                success_enc, png_orig = cv2.imencode(".png", orig_bgr)
                orig_bytes = png_orig.tobytes() if success_enc else b""
                try:
                    deskewed_wa_bytes, angle_wa = deskew_image_by_word_alignment(orig_bytes)
                except Exception:
                    deskewed_wa_bytes, angle_wa = orig_bytes, 0.0

                # Pick best method: prefer word alignment when it found the
                # larger angle or when the Hough angle is negligible.
                if abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1:
                    method_used = "word_alignment"
                    angle_applied = angle_wa
                    wa_arr = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8)
                    deskewed_bgr = cv2.imdecode(wa_arr, cv2.IMREAD_COLOR)
                    if deskewed_bgr is None:
                        # Undecodable word-alignment output: use the Hough result.
                        deskewed_bgr = deskewed_hough
                        method_used = "hough"
                        angle_applied = angle_hough
                else:
                    method_used = "hough"
                    angle_applied = angle_hough
                    deskewed_bgr = deskewed_hough

                success, png_buf = cv2.imencode(".png", deskewed_bgr)
                deskewed_png = png_buf.tobytes() if success else b""

                deskew_result = {
                    "method_used": method_used,
                    "rotation_degrees": round(float(angle_applied), 3),
                    "duration_seconds": round(time.time() - t0, 2),
                }

                cached["deskewed_bgr"] = deskewed_bgr
                cached["deskew_result"] = deskew_result
                await update_session_db(
                    session_id,
                    deskewed_png=deskewed_png,
                    deskew_result=deskew_result,
                    auto_rotation_degrees=float(angle_applied),
                    current_step=3,
                )
                session = await get_session_db(session_id)

                steps_run.append("deskew")
                yield await _auto_sse_event("deskew", "done", deskew_result)
            except Exception as e:
                logger.error(f"Auto-mode deskew failed for {session_id}: {e}")
                error_step = "deskew"
                yield await _auto_sse_event("deskew", "error", {"message": str(e)})
                yield await _auto_sse_event("complete", "error", {"error_step": error_step})
                return
        else:
            steps_skipped.append("deskew")
            yield await _auto_sse_event("deskew", "skipped", {"reason": "from_step > 1"})

        # -----------------------------------------------------------------
        # Step 2: Dewarp
        # -----------------------------------------------------------------
        if req.from_step <= 2:
            yield await _auto_sse_event("dewarp", "start", {"method": req.dewarp_method})
            try:
                t0 = time.time()
                deskewed_bgr = cached.get("deskewed_bgr")
                if deskewed_bgr is None:
                    raise ValueError("Deskewed image not available")

                if req.dewarp_method == "vlm":
                    success_enc, png_buf = cv2.imencode(".png", deskewed_bgr)
                    img_bytes = png_buf.tobytes() if success_enc else b""
                    vlm_det = await _detect_shear_with_vlm(img_bytes)
                    shear_deg = vlm_det["shear_degrees"]
                    # Only correct when the shear is noticeable and the VLM
                    # is reasonably confident.
                    if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3:
                        dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg)
                    else:
                        dewarped_bgr = deskewed_bgr
                    dewarp_info = {
                        "method": vlm_det["method"],
                        "shear_degrees": shear_deg,
                        "confidence": vlm_det["confidence"],
                        "detections": [vlm_det],
                    }
                else:
                    dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr)

                success_enc, png_buf = cv2.imencode(".png", dewarped_bgr)
                dewarped_png = png_buf.tobytes() if success_enc else b""

                dewarp_result = {
                    "method_used": dewarp_info["method"],
                    "shear_degrees": dewarp_info["shear_degrees"],
                    "confidence": dewarp_info["confidence"],
                    "duration_seconds": round(time.time() - t0, 2),
                    "detections": dewarp_info.get("detections", []),
                }

                cached["dewarped_bgr"] = dewarped_bgr
                cached["dewarp_result"] = dewarp_result
                await update_session_db(
                    session_id,
                    dewarped_png=dewarped_png,
                    dewarp_result=dewarp_result,
                    auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
                    current_step=4,
                )
                session = await get_session_db(session_id)

                steps_run.append("dewarp")
                yield await _auto_sse_event("dewarp", "done", dewarp_result)
            except Exception as e:
                logger.error(f"Auto-mode dewarp failed for {session_id}: {e}")
                error_step = "dewarp"
                yield await _auto_sse_event("dewarp", "error", {"message": str(e)})
                yield await _auto_sse_event("complete", "error", {"error_step": error_step})
                return
        else:
            steps_skipped.append("dewarp")
            yield await _auto_sse_event("dewarp", "skipped", {"reason": "from_step > 2"})

        # -----------------------------------------------------------------
        # Step 3: Columns
        # -----------------------------------------------------------------
        if req.from_step <= 3:
            yield await _auto_sse_event("columns", "start", {})
            try:
                t0 = time.time()
                col_img = _working_image()
                if col_img is None:
                    raise ValueError("Cropped/dewarped image not available")

                ocr_img = create_ocr_image(col_img)
                h, w = ocr_img.shape[:2]

                geo_result = detect_column_geometry(ocr_img, col_img)
                if geo_result is None:
                    # Geometry detection gave nothing: fall back to layout
                    # analysis and clear the intermediates the later steps reuse.
                    layout_img = create_layout_image(col_img)
                    regions = analyze_layout(layout_img, ocr_img)
                    cached["_word_dicts"] = None
                    cached["_inv"] = None
                    cached["_content_bounds"] = None
                else:
                    geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
                    content_w = right_x - left_x
                    # Keep the intermediates for the rows/words steps.
                    cached["_word_dicts"] = word_dicts
                    cached["_inv"] = inv
                    cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)

                    header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None)
                    geometries = _detect_sub_columns(geometries, content_w, left_x=left_x,
                                                     top_y=top_y, header_y=header_y, footer_y=footer_y)
                    regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y,
                                                    left_x=left_x, right_x=right_x, inv=inv)

                columns = [asdict(r) for r in regions]
                column_result = {
                    "columns": columns,
                    "classification_methods": list({c.get("classification_method", "") for c in columns if c.get("classification_method")}),
                    "duration_seconds": round(time.time() - t0, 2),
                }

                cached["column_result"] = column_result
                # New columns invalidate any previously stored rows and words.
                await update_session_db(session_id, column_result=column_result,
                                        row_result=None, word_result=None, current_step=6)
                session = await get_session_db(session_id)

                steps_run.append("columns")
                yield await _auto_sse_event("columns", "done", {
                    "column_count": len(columns),
                    "duration_seconds": column_result["duration_seconds"],
                })
            except Exception as e:
                logger.error(f"Auto-mode columns failed for {session_id}: {e}")
                error_step = "columns"
                yield await _auto_sse_event("columns", "error", {"message": str(e)})
                yield await _auto_sse_event("complete", "error", {"error_step": error_step})
                return
        else:
            steps_skipped.append("columns")
            yield await _auto_sse_event("columns", "skipped", {"reason": "from_step > 3"})

        # -----------------------------------------------------------------
        # Step 4: Rows
        # -----------------------------------------------------------------
        if req.from_step <= 4:
            yield await _auto_sse_event("rows", "start", {})
            try:
                t0 = time.time()
                session = await get_session_db(session_id)
                column_result = session.get("column_result") or cached.get("column_result")
                if not column_result or not column_result.get("columns"):
                    raise ValueError("Column detection must complete first")

                word_dicts = cached.get("_word_dicts")
                inv = cached.get("_inv")
                content_bounds = cached.get("_content_bounds")

                if word_dicts is None or inv is None or content_bounds is None:
                    # Intermediates from the columns step are missing
                    # (e.g. step 3 was skipped): recompute them from the image.
                    row_img = _working_image()
                    ocr_img_tmp = create_ocr_image(row_img)
                    geo_result = detect_column_geometry(ocr_img_tmp, row_img)
                    if geo_result is None:
                        raise ValueError("Column geometry detection failed — cannot detect rows")
                    _g, lx, rx, ty, by, word_dicts, inv = geo_result
                    cached["_word_dicts"] = word_dicts
                    cached["_inv"] = inv
                    cached["_content_bounds"] = (lx, rx, ty, by)
                    content_bounds = (lx, rx, ty, by)

                left_x, right_x, top_y, bottom_y = content_bounds
                row_geoms = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)

                row_list = [
                    {
                        "index": r.index, "x": r.x, "y": r.y,
                        "width": r.width, "height": r.height,
                        "word_count": r.word_count,
                        "row_type": r.row_type,
                        "gap_before": r.gap_before,
                    }
                    for r in row_geoms
                ]
                row_result = {
                    "rows": row_list,
                    "row_count": len(row_list),
                    "content_rows": len([r for r in row_geoms if r.row_type == "content"]),
                    "duration_seconds": round(time.time() - t0, 2),
                }

                cached["row_result"] = row_result
                await update_session_db(session_id, row_result=row_result, current_step=7)
                session = await get_session_db(session_id)

                steps_run.append("rows")
                yield await _auto_sse_event("rows", "done", {
                    "row_count": len(row_list),
                    "content_rows": row_result["content_rows"],
                    "duration_seconds": row_result["duration_seconds"],
                })
            except Exception as e:
                logger.error(f"Auto-mode rows failed for {session_id}: {e}")
                error_step = "rows"
                yield await _auto_sse_event("rows", "error", {"message": str(e)})
                yield await _auto_sse_event("complete", "error", {"error_step": error_step})
                return
        else:
            steps_skipped.append("rows")
            yield await _auto_sse_event("rows", "skipped", {"reason": "from_step > 4"})

        # -----------------------------------------------------------------
        # Step 5: Words (OCR)
        # -----------------------------------------------------------------
        if req.from_step <= 5:
            yield await _auto_sse_event("words", "start", {"engine": req.ocr_engine})
            try:
                t0 = time.time()
                word_img = _working_image()
                session = await get_session_db(session_id)

                column_result = session.get("column_result") or cached.get("column_result")
                row_result = session.get("row_result") or cached.get("row_result")

                # Fail with a readable message instead of a bare
                # "'NoneType' object is not subscriptable".
                if word_img is None:
                    raise ValueError("Cropped/dewarped image not available")
                if not column_result or "columns" not in column_result:
                    raise ValueError("Column detection must complete first")
                if not row_result or "rows" not in row_result:
                    raise ValueError("Row detection must complete first")

                col_regions = [
                    PageRegion(
                        type=c["type"], x=c["x"], y=c["y"],
                        width=c["width"], height=c["height"],
                        classification_confidence=c.get("classification_confidence", 1.0),
                        classification_method=c.get("classification_method", ""),
                    )
                    for c in column_result["columns"]
                ]
                row_geoms = [
                    RowGeometry(
                        index=r["index"], x=r["x"], y=r["y"],
                        width=r["width"], height=r["height"],
                        word_count=r.get("word_count", 0), words=[],
                        row_type=r.get("row_type", "content"),
                        gap_before=r.get("gap_before", 0),
                    )
                    for r in row_result["rows"]
                ]

                word_dicts = cached.get("_word_dicts")
                if word_dicts is not None:
                    # Re-attach the cached words to their rows. A word belongs
                    # to the row containing its vertical center; word 'top'
                    # values are compared after subtracting top_y from row.y.
                    content_bounds = cached.get("_content_bounds")
                    top_y = content_bounds[2] if content_bounds else min(r.y for r in row_geoms)
                    for row in row_geoms:
                        row_y_rel = row.y - top_y
                        row_bottom_rel = row_y_rel + row.height
                        row.words = [
                            w for w in word_dicts
                            if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
                        ]
                        row.word_count = len(row.words)

                ocr_img = create_ocr_image(word_img)
                img_h, img_w = word_img.shape[:2]

                cells, columns_meta = build_cell_grid(
                    ocr_img, col_regions, row_geoms, img_w, img_h,
                    ocr_engine=req.ocr_engine, img_bgr=word_img,
                )
                duration = time.time() - t0

                col_types = {c['type'] for c in columns_meta}
                is_vocab = bool(col_types & {'column_en', 'column_de'})
                n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
                used_engine = cells[0].get("ocr_engine", "tesseract") if cells else req.ocr_engine

                # Apply IPA phonetic fixes directly to cell texts
                fix_cell_phonetics(cells, pronunciation=req.pronunciation)

                word_result_data = {
                    "cells": cells,
                    "grid_shape": {
                        "rows": n_content_rows,
                        "cols": len(columns_meta),
                        "total_cells": len(cells),
                    },
                    "columns_used": columns_meta,
                    "layout": "vocab" if is_vocab else "generic",
                    "image_width": img_w,
                    "image_height": img_h,
                    "duration_seconds": round(duration, 2),
                    "ocr_engine": used_engine,
                    "summary": {
                        "total_cells": len(cells),
                        "non_empty_cells": sum(1 for c in cells if c.get("text")),
                        "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
                    },
                }

                has_text_col = 'column_text' in col_types
                if is_vocab or has_text_col:
                    entries = _cells_to_vocab_entries(cells, columns_meta)
                    entries = _fix_character_confusion(entries)
                    entries = _fix_phonetic_brackets(entries, pronunciation=req.pronunciation)
                    word_result_data["vocab_entries"] = entries
                    word_result_data["entries"] = entries
                    word_result_data["entry_count"] = len(entries)
                    word_result_data["summary"]["total_entries"] = len(entries)

                await update_session_db(session_id, word_result=word_result_data, current_step=8)
                cached["word_result"] = word_result_data
                session = await get_session_db(session_id)

                steps_run.append("words")
                yield await _auto_sse_event("words", "done", {
                    "total_cells": len(cells),
                    "layout": word_result_data["layout"],
                    "duration_seconds": round(duration, 2),
                    "ocr_engine": used_engine,
                    "summary": word_result_data["summary"],
                })
            except Exception as e:
                logger.error(f"Auto-mode words failed for {session_id}: {e}")
                error_step = "words"
                yield await _auto_sse_event("words", "error", {"message": str(e)})
                yield await _auto_sse_event("complete", "error", {"error_step": error_step})
                return
        else:
            steps_skipped.append("words")
            yield await _auto_sse_event("words", "skipped", {"reason": "from_step > 5"})

        # -----------------------------------------------------------------
        # Step 6: LLM Review (optional)
        # -----------------------------------------------------------------
        if req.from_step <= 6 and not req.skip_llm_review:
            yield await _auto_sse_event("llm_review", "start", {"model": OLLAMA_REVIEW_MODEL})
            try:
                session = await get_session_db(session_id)
                word_result = session.get("word_result") or cached.get("word_result")
                if not word_result:
                    raise ValueError("Word recognition must complete first")
                entries = word_result.get("entries") or word_result.get("vocab_entries") or []

                if not entries:
                    yield await _auto_sse_event("llm_review", "skipped", {"reason": "no entries"})
                    steps_skipped.append("llm_review")
                else:
                    reviewed = await llm_review_entries(entries)

                    # Re-read the session so the reviewed entries are merged
                    # into the latest stored word result.
                    session = await get_session_db(session_id)
                    word_result_updated = dict(session.get("word_result") or {})
                    word_result_updated["entries"] = reviewed
                    word_result_updated["vocab_entries"] = reviewed
                    word_result_updated["llm_reviewed"] = True
                    word_result_updated["llm_model"] = OLLAMA_REVIEW_MODEL

                    await update_session_db(session_id, word_result=word_result_updated, current_step=9)
                    cached["word_result"] = word_result_updated

                    steps_run.append("llm_review")
                    yield await _auto_sse_event("llm_review", "done", {
                        "entries_reviewed": len(reviewed),
                        "model": OLLAMA_REVIEW_MODEL,
                    })
            except Exception as e:
                # The review is optional: report the error but still finish
                # the run with a regular "complete" event.
                logger.warning(f"Auto-mode llm_review failed for {session_id} (non-fatal): {e}")
                yield await _auto_sse_event("llm_review", "error", {"message": str(e), "fatal": False})
                steps_skipped.append("llm_review")
        else:
            steps_skipped.append("llm_review")
            reason = "skipped by request" if req.skip_llm_review else "from_step > 6"
            yield await _auto_sse_event("llm_review", "skipped", {"reason": reason})

        # -----------------------------------------------------------------
        # Final event
        # -----------------------------------------------------------------
        yield await _auto_sse_event("complete", "done", {
            "steps_run": steps_run,
            "steps_skipped": steps_skipped,
        })

    return StreamingResponse(
        _generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Tell nginx-style reverse proxies not to buffer the event stream.
            "X-Accel-Buffering": "no",
        },
    )
|