refactor: split ocr_pipeline_api.py (5426 lines) into 8 modules
Each module is under 1050 lines: - ocr_pipeline_common.py (354) - shared state, cache, models, helpers - ocr_pipeline_sessions.py (483) - session CRUD, image serving, doc-type - ocr_pipeline_geometry.py (1025) - deskew, dewarp, structure, columns - ocr_pipeline_rows.py (348) - row detection, box-overlay helper - ocr_pipeline_words.py (876) - word detection (SSE), paddle-direct - ocr_pipeline_ocr_merge.py (615) - merge helpers, kombi endpoints - ocr_pipeline_postprocess.py (929) - LLM review, reconstruction, export - ocr_pipeline_auto.py (705) - auto-mode orchestrator, reprocess ocr_pipeline_api.py is now a 61-line thin wrapper that re-exports router, _cache, and test-imported symbols for backward compatibility. No changes needed in main.py or tests. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
File diff suppressed because it is too large
Load Diff
705
klausur-service/backend/ocr_pipeline_auto.py
Normal file
705
klausur-service/backend/ocr_pipeline_auto.py
Normal file
@@ -0,0 +1,705 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Auto-Mode Orchestrator and Reprocess Endpoints.
|
||||||
|
|
||||||
|
Extracted from ocr_pipeline_api.py — contains:
|
||||||
|
- POST /sessions/{session_id}/reprocess (clear downstream + restart from step)
|
||||||
|
- POST /sessions/{session_id}/run-auto (full auto-mode with SSE streaming)
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from dataclasses import asdict
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from cv_vocab_pipeline import (
|
||||||
|
OLLAMA_REVIEW_MODEL,
|
||||||
|
PageRegion,
|
||||||
|
RowGeometry,
|
||||||
|
_cells_to_vocab_entries,
|
||||||
|
_detect_header_footer_gaps,
|
||||||
|
_detect_sub_columns,
|
||||||
|
_fix_character_confusion,
|
||||||
|
_fix_phonetic_brackets,
|
||||||
|
fix_cell_phonetics,
|
||||||
|
analyze_layout,
|
||||||
|
build_cell_grid,
|
||||||
|
classify_column_types,
|
||||||
|
create_layout_image,
|
||||||
|
create_ocr_image,
|
||||||
|
deskew_image,
|
||||||
|
deskew_image_by_word_alignment,
|
||||||
|
detect_column_geometry,
|
||||||
|
detect_row_geometry,
|
||||||
|
_apply_shear,
|
||||||
|
dewarp_image,
|
||||||
|
llm_review_entries,
|
||||||
|
)
|
||||||
|
from ocr_pipeline_common import (
|
||||||
|
_cache,
|
||||||
|
_load_session_to_cache,
|
||||||
|
_get_cached,
|
||||||
|
_get_base_image_png,
|
||||||
|
_append_pipeline_log,
|
||||||
|
)
|
||||||
|
from ocr_pipeline_session_store import (
|
||||||
|
get_session_db,
|
||||||
|
update_session_db,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Reprocess endpoint
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/reprocess")
|
||||||
|
async def reprocess_session(session_id: str, request: Request):
|
||||||
|
"""Re-run pipeline from a specific step, clearing downstream data.
|
||||||
|
|
||||||
|
Body: {"from_step": 5} (1-indexed step number)
|
||||||
|
|
||||||
|
Pipeline order: Orientation(1) → Deskew(2) → Dewarp(3) → Crop(4) → Columns(5) →
|
||||||
|
Rows(6) → Words(7) → LLM-Review(8) → Reconstruction(9) → Validation(10)
|
||||||
|
|
||||||
|
Clears downstream results:
|
||||||
|
- from_step <= 1: orientation_result + all downstream
|
||||||
|
- from_step <= 2: deskew_result + all downstream
|
||||||
|
- from_step <= 3: dewarp_result + all downstream
|
||||||
|
- from_step <= 4: crop_result + all downstream
|
||||||
|
- from_step <= 5: column_result, row_result, word_result
|
||||||
|
- from_step <= 6: row_result, word_result
|
||||||
|
- from_step <= 7: word_result (cells, vocab_entries)
|
||||||
|
- from_step <= 8: word_result.llm_review only
|
||||||
|
"""
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
body = await request.json()
|
||||||
|
from_step = body.get("from_step", 1)
|
||||||
|
if not isinstance(from_step, int) or from_step < 1 or from_step > 10:
|
||||||
|
raise HTTPException(status_code=400, detail="from_step must be between 1 and 10")
|
||||||
|
|
||||||
|
update_kwargs: Dict[str, Any] = {"current_step": from_step}
|
||||||
|
|
||||||
|
# Clear downstream data based on from_step
|
||||||
|
# New pipeline order: Orient(2) → Deskew(3) → Dewarp(4) → Crop(5) →
|
||||||
|
# Columns(6) → Rows(7) → Words(8) → LLM(9) → Recon(10) → GT(11)
|
||||||
|
if from_step <= 8:
|
||||||
|
update_kwargs["word_result"] = None
|
||||||
|
elif from_step == 9:
|
||||||
|
# Only clear LLM review from word_result
|
||||||
|
word_result = session.get("word_result")
|
||||||
|
if word_result:
|
||||||
|
word_result.pop("llm_review", None)
|
||||||
|
word_result.pop("llm_corrections", None)
|
||||||
|
update_kwargs["word_result"] = word_result
|
||||||
|
|
||||||
|
if from_step <= 7:
|
||||||
|
update_kwargs["row_result"] = None
|
||||||
|
if from_step <= 6:
|
||||||
|
update_kwargs["column_result"] = None
|
||||||
|
if from_step <= 4:
|
||||||
|
update_kwargs["crop_result"] = None
|
||||||
|
if from_step <= 3:
|
||||||
|
update_kwargs["dewarp_result"] = None
|
||||||
|
if from_step <= 2:
|
||||||
|
update_kwargs["deskew_result"] = None
|
||||||
|
if from_step <= 1:
|
||||||
|
update_kwargs["orientation_result"] = None
|
||||||
|
|
||||||
|
await update_session_db(session_id, **update_kwargs)
|
||||||
|
|
||||||
|
# Also clear cache
|
||||||
|
if session_id in _cache:
|
||||||
|
for key in list(update_kwargs.keys()):
|
||||||
|
if key != "current_step":
|
||||||
|
_cache[session_id][key] = update_kwargs[key]
|
||||||
|
_cache[session_id]["current_step"] = from_step
|
||||||
|
|
||||||
|
logger.info(f"Session {session_id} reprocessing from step {from_step}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"session_id": session_id,
|
||||||
|
"from_step": from_step,
|
||||||
|
"cleared": [k for k in update_kwargs if k != "current_step"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# VLM shear detection helper (used by dewarp step in auto-mode)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]:
|
||||||
|
"""Ask qwen2.5vl:32b to estimate the vertical shear angle of a scanned page.
|
||||||
|
|
||||||
|
The VLM is shown the image and asked: are the column/table borders tilted?
|
||||||
|
If yes, by how many degrees? Returns a dict with shear_degrees and confidence.
|
||||||
|
Confidence is 0.0 if Ollama is unavailable or parsing fails.
|
||||||
|
"""
|
||||||
|
import httpx
|
||||||
|
import base64
|
||||||
|
import re
|
||||||
|
|
||||||
|
ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
|
||||||
|
model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
|
||||||
|
|
||||||
|
prompt = (
|
||||||
|
"This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. "
|
||||||
|
"Are they perfectly vertical, or do they tilt slightly? "
|
||||||
|
"If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). "
|
||||||
|
"Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} "
|
||||||
|
"Use confidence 0.0-1.0 based on how clearly you can see the tilt. "
|
||||||
|
"If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}"
|
||||||
|
)
|
||||||
|
|
||||||
|
img_b64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||||
|
payload = {
|
||||||
|
"model": model,
|
||||||
|
"prompt": prompt,
|
||||||
|
"images": [img_b64],
|
||||||
|
"stream": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||||
|
resp = await client.post(f"{ollama_base}/api/generate", json=payload)
|
||||||
|
resp.raise_for_status()
|
||||||
|
text = resp.json().get("response", "")
|
||||||
|
|
||||||
|
# Parse JSON from response (may have surrounding text)
|
||||||
|
match = re.search(r'\{[^}]+\}', text)
|
||||||
|
if match:
|
||||||
|
import json
|
||||||
|
data = json.loads(match.group(0))
|
||||||
|
shear = float(data.get("shear_degrees", 0.0))
|
||||||
|
conf = float(data.get("confidence", 0.0))
|
||||||
|
# Clamp to reasonable range
|
||||||
|
shear = max(-3.0, min(3.0, shear))
|
||||||
|
conf = max(0.0, min(1.0, conf))
|
||||||
|
return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)}
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"VLM dewarp failed: {e}")
|
||||||
|
|
||||||
|
return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Auto-mode orchestrator
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class RunAutoRequest(BaseModel):
|
||||||
|
from_step: int = 1 # 1=deskew, 2=dewarp, 3=columns, 4=rows, 5=words, 6=llm-review
|
||||||
|
ocr_engine: str = "auto" # "auto" | "rapid" | "tesseract"
|
||||||
|
pronunciation: str = "british"
|
||||||
|
skip_llm_review: bool = False
|
||||||
|
dewarp_method: str = "ensemble" # "ensemble" | "vlm" | "cv"
|
||||||
|
|
||||||
|
|
||||||
|
async def _auto_sse_event(step: str, status: str, data: Dict[str, Any]) -> str:
|
||||||
|
"""Format a single SSE event line."""
|
||||||
|
import json as _json
|
||||||
|
payload = {"step": step, "status": status, **data}
|
||||||
|
return f"data: {_json.dumps(payload)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/run-auto")
|
||||||
|
async def run_auto(session_id: str, req: RunAutoRequest, request: Request):
|
||||||
|
"""Run the full OCR pipeline automatically from a given step, streaming SSE progress.
|
||||||
|
|
||||||
|
Steps:
|
||||||
|
1. Deskew — straighten the scan
|
||||||
|
2. Dewarp — correct vertical shear (ensemble CV or VLM)
|
||||||
|
3. Columns — detect column layout
|
||||||
|
4. Rows — detect row layout
|
||||||
|
5. Words — OCR each cell
|
||||||
|
6. LLM review — correct OCR errors (optional)
|
||||||
|
|
||||||
|
Already-completed steps are skipped unless `from_step` forces a rerun.
|
||||||
|
Yields SSE events of the form:
|
||||||
|
data: {"step": "deskew", "status": "start"|"done"|"skipped"|"error", ...}
|
||||||
|
|
||||||
|
Final event:
|
||||||
|
data: {"step": "complete", "status": "done", "steps_run": [...], "steps_skipped": [...]}
|
||||||
|
"""
|
||||||
|
if req.from_step < 1 or req.from_step > 6:
|
||||||
|
raise HTTPException(status_code=400, detail="from_step must be 1-6")
|
||||||
|
if req.dewarp_method not in ("ensemble", "vlm", "cv"):
|
||||||
|
raise HTTPException(status_code=400, detail="dewarp_method must be: ensemble, vlm, cv")
|
||||||
|
|
||||||
|
if session_id not in _cache:
|
||||||
|
await _load_session_to_cache(session_id)
|
||||||
|
|
||||||
|
async def _generate():
|
||||||
|
steps_run: List[str] = []
|
||||||
|
steps_skipped: List[str] = []
|
||||||
|
error_step: Optional[str] = None
|
||||||
|
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
yield await _auto_sse_event("error", "error", {"message": f"Session {session_id} not found"})
|
||||||
|
return
|
||||||
|
|
||||||
|
cached = _get_cached(session_id)
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------
|
||||||
|
# Step 1: Deskew
|
||||||
|
# -----------------------------------------------------------------
|
||||||
|
if req.from_step <= 1:
|
||||||
|
yield await _auto_sse_event("deskew", "start", {})
|
||||||
|
try:
|
||||||
|
t0 = time.time()
|
||||||
|
orig_bgr = cached.get("original_bgr")
|
||||||
|
if orig_bgr is None:
|
||||||
|
raise ValueError("Original image not loaded")
|
||||||
|
|
||||||
|
# Method 1: Hough lines
|
||||||
|
try:
|
||||||
|
deskewed_hough, angle_hough = deskew_image(orig_bgr.copy())
|
||||||
|
except Exception:
|
||||||
|
deskewed_hough, angle_hough = orig_bgr, 0.0
|
||||||
|
|
||||||
|
# Method 2: Word alignment
|
||||||
|
success_enc, png_orig = cv2.imencode(".png", orig_bgr)
|
||||||
|
orig_bytes = png_orig.tobytes() if success_enc else b""
|
||||||
|
try:
|
||||||
|
deskewed_wa_bytes, angle_wa = deskew_image_by_word_alignment(orig_bytes)
|
||||||
|
except Exception:
|
||||||
|
deskewed_wa_bytes, angle_wa = orig_bytes, 0.0
|
||||||
|
|
||||||
|
# Pick best method
|
||||||
|
if abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1:
|
||||||
|
method_used = "word_alignment"
|
||||||
|
angle_applied = angle_wa
|
||||||
|
wa_arr = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8)
|
||||||
|
deskewed_bgr = cv2.imdecode(wa_arr, cv2.IMREAD_COLOR)
|
||||||
|
if deskewed_bgr is None:
|
||||||
|
deskewed_bgr = deskewed_hough
|
||||||
|
method_used = "hough"
|
||||||
|
angle_applied = angle_hough
|
||||||
|
else:
|
||||||
|
method_used = "hough"
|
||||||
|
angle_applied = angle_hough
|
||||||
|
deskewed_bgr = deskewed_hough
|
||||||
|
|
||||||
|
success, png_buf = cv2.imencode(".png", deskewed_bgr)
|
||||||
|
deskewed_png = png_buf.tobytes() if success else b""
|
||||||
|
|
||||||
|
deskew_result = {
|
||||||
|
"method_used": method_used,
|
||||||
|
"rotation_degrees": round(float(angle_applied), 3),
|
||||||
|
"duration_seconds": round(time.time() - t0, 2),
|
||||||
|
}
|
||||||
|
|
||||||
|
cached["deskewed_bgr"] = deskewed_bgr
|
||||||
|
cached["deskew_result"] = deskew_result
|
||||||
|
await update_session_db(
|
||||||
|
session_id,
|
||||||
|
deskewed_png=deskewed_png,
|
||||||
|
deskew_result=deskew_result,
|
||||||
|
auto_rotation_degrees=float(angle_applied),
|
||||||
|
current_step=3,
|
||||||
|
)
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
|
||||||
|
steps_run.append("deskew")
|
||||||
|
yield await _auto_sse_event("deskew", "done", deskew_result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Auto-mode deskew failed for {session_id}: {e}")
|
||||||
|
error_step = "deskew"
|
||||||
|
yield await _auto_sse_event("deskew", "error", {"message": str(e)})
|
||||||
|
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
steps_skipped.append("deskew")
|
||||||
|
yield await _auto_sse_event("deskew", "skipped", {"reason": "from_step > 1"})
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------
|
||||||
|
# Step 2: Dewarp
|
||||||
|
# -----------------------------------------------------------------
|
||||||
|
if req.from_step <= 2:
|
||||||
|
yield await _auto_sse_event("dewarp", "start", {"method": req.dewarp_method})
|
||||||
|
try:
|
||||||
|
t0 = time.time()
|
||||||
|
deskewed_bgr = cached.get("deskewed_bgr")
|
||||||
|
if deskewed_bgr is None:
|
||||||
|
raise ValueError("Deskewed image not available")
|
||||||
|
|
||||||
|
if req.dewarp_method == "vlm":
|
||||||
|
success_enc, png_buf = cv2.imencode(".png", deskewed_bgr)
|
||||||
|
img_bytes = png_buf.tobytes() if success_enc else b""
|
||||||
|
vlm_det = await _detect_shear_with_vlm(img_bytes)
|
||||||
|
shear_deg = vlm_det["shear_degrees"]
|
||||||
|
if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3:
|
||||||
|
dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg)
|
||||||
|
else:
|
||||||
|
dewarped_bgr = deskewed_bgr
|
||||||
|
dewarp_info = {
|
||||||
|
"method": vlm_det["method"],
|
||||||
|
"shear_degrees": shear_deg,
|
||||||
|
"confidence": vlm_det["confidence"],
|
||||||
|
"detections": [vlm_det],
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr)
|
||||||
|
|
||||||
|
success_enc, png_buf = cv2.imencode(".png", dewarped_bgr)
|
||||||
|
dewarped_png = png_buf.tobytes() if success_enc else b""
|
||||||
|
|
||||||
|
dewarp_result = {
|
||||||
|
"method_used": dewarp_info["method"],
|
||||||
|
"shear_degrees": dewarp_info["shear_degrees"],
|
||||||
|
"confidence": dewarp_info["confidence"],
|
||||||
|
"duration_seconds": round(time.time() - t0, 2),
|
||||||
|
"detections": dewarp_info.get("detections", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
cached["dewarped_bgr"] = dewarped_bgr
|
||||||
|
cached["dewarp_result"] = dewarp_result
|
||||||
|
await update_session_db(
|
||||||
|
session_id,
|
||||||
|
dewarped_png=dewarped_png,
|
||||||
|
dewarp_result=dewarp_result,
|
||||||
|
auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
|
||||||
|
current_step=4,
|
||||||
|
)
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
|
||||||
|
steps_run.append("dewarp")
|
||||||
|
yield await _auto_sse_event("dewarp", "done", dewarp_result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Auto-mode dewarp failed for {session_id}: {e}")
|
||||||
|
error_step = "dewarp"
|
||||||
|
yield await _auto_sse_event("dewarp", "error", {"message": str(e)})
|
||||||
|
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
steps_skipped.append("dewarp")
|
||||||
|
yield await _auto_sse_event("dewarp", "skipped", {"reason": "from_step > 2"})
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------
|
||||||
|
# Step 3: Columns
|
||||||
|
# -----------------------------------------------------------------
|
||||||
|
if req.from_step <= 3:
|
||||||
|
yield await _auto_sse_event("columns", "start", {})
|
||||||
|
try:
|
||||||
|
t0 = time.time()
|
||||||
|
col_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
|
||||||
|
if col_img is None:
|
||||||
|
raise ValueError("Cropped/dewarped image not available")
|
||||||
|
|
||||||
|
ocr_img = create_ocr_image(col_img)
|
||||||
|
h, w = ocr_img.shape[:2]
|
||||||
|
|
||||||
|
geo_result = detect_column_geometry(ocr_img, col_img)
|
||||||
|
if geo_result is None:
|
||||||
|
layout_img = create_layout_image(col_img)
|
||||||
|
regions = analyze_layout(layout_img, ocr_img)
|
||||||
|
cached["_word_dicts"] = None
|
||||||
|
cached["_inv"] = None
|
||||||
|
cached["_content_bounds"] = None
|
||||||
|
else:
|
||||||
|
geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
|
||||||
|
content_w = right_x - left_x
|
||||||
|
cached["_word_dicts"] = word_dicts
|
||||||
|
cached["_inv"] = inv
|
||||||
|
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
|
||||||
|
|
||||||
|
header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None)
|
||||||
|
geometries = _detect_sub_columns(geometries, content_w, left_x=left_x,
|
||||||
|
top_y=top_y, header_y=header_y, footer_y=footer_y)
|
||||||
|
regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y,
|
||||||
|
left_x=left_x, right_x=right_x, inv=inv)
|
||||||
|
|
||||||
|
columns = [asdict(r) for r in regions]
|
||||||
|
column_result = {
|
||||||
|
"columns": columns,
|
||||||
|
"classification_methods": list({c.get("classification_method", "") for c in columns if c.get("classification_method")}),
|
||||||
|
"duration_seconds": round(time.time() - t0, 2),
|
||||||
|
}
|
||||||
|
|
||||||
|
cached["column_result"] = column_result
|
||||||
|
await update_session_db(session_id, column_result=column_result,
|
||||||
|
row_result=None, word_result=None, current_step=6)
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
|
||||||
|
steps_run.append("columns")
|
||||||
|
yield await _auto_sse_event("columns", "done", {
|
||||||
|
"column_count": len(columns),
|
||||||
|
"duration_seconds": column_result["duration_seconds"],
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Auto-mode columns failed for {session_id}: {e}")
|
||||||
|
error_step = "columns"
|
||||||
|
yield await _auto_sse_event("columns", "error", {"message": str(e)})
|
||||||
|
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
steps_skipped.append("columns")
|
||||||
|
yield await _auto_sse_event("columns", "skipped", {"reason": "from_step > 3"})
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------
|
||||||
|
# Step 4: Rows
|
||||||
|
# -----------------------------------------------------------------
|
||||||
|
if req.from_step <= 4:
|
||||||
|
yield await _auto_sse_event("rows", "start", {})
|
||||||
|
try:
|
||||||
|
t0 = time.time()
|
||||||
|
row_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
column_result = session.get("column_result") or cached.get("column_result")
|
||||||
|
if not column_result or not column_result.get("columns"):
|
||||||
|
raise ValueError("Column detection must complete first")
|
||||||
|
|
||||||
|
col_regions = [
|
||||||
|
PageRegion(
|
||||||
|
type=c["type"], x=c["x"], y=c["y"],
|
||||||
|
width=c["width"], height=c["height"],
|
||||||
|
classification_confidence=c.get("classification_confidence", 1.0),
|
||||||
|
classification_method=c.get("classification_method", ""),
|
||||||
|
)
|
||||||
|
for c in column_result["columns"]
|
||||||
|
]
|
||||||
|
|
||||||
|
word_dicts = cached.get("_word_dicts")
|
||||||
|
inv = cached.get("_inv")
|
||||||
|
content_bounds = cached.get("_content_bounds")
|
||||||
|
|
||||||
|
if word_dicts is None or inv is None or content_bounds is None:
|
||||||
|
ocr_img_tmp = create_ocr_image(row_img)
|
||||||
|
geo_result = detect_column_geometry(ocr_img_tmp, row_img)
|
||||||
|
if geo_result is None:
|
||||||
|
raise ValueError("Column geometry detection failed — cannot detect rows")
|
||||||
|
_g, lx, rx, ty, by, word_dicts, inv = geo_result
|
||||||
|
cached["_word_dicts"] = word_dicts
|
||||||
|
cached["_inv"] = inv
|
||||||
|
cached["_content_bounds"] = (lx, rx, ty, by)
|
||||||
|
content_bounds = (lx, rx, ty, by)
|
||||||
|
|
||||||
|
left_x, right_x, top_y, bottom_y = content_bounds
|
||||||
|
row_geoms = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
|
||||||
|
|
||||||
|
row_list = [
|
||||||
|
{
|
||||||
|
"index": r.index, "x": r.x, "y": r.y,
|
||||||
|
"width": r.width, "height": r.height,
|
||||||
|
"word_count": r.word_count,
|
||||||
|
"row_type": r.row_type,
|
||||||
|
"gap_before": r.gap_before,
|
||||||
|
}
|
||||||
|
for r in row_geoms
|
||||||
|
]
|
||||||
|
row_result = {
|
||||||
|
"rows": row_list,
|
||||||
|
"row_count": len(row_list),
|
||||||
|
"content_rows": len([r for r in row_geoms if r.row_type == "content"]),
|
||||||
|
"duration_seconds": round(time.time() - t0, 2),
|
||||||
|
}
|
||||||
|
|
||||||
|
cached["row_result"] = row_result
|
||||||
|
await update_session_db(session_id, row_result=row_result, current_step=7)
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
|
||||||
|
steps_run.append("rows")
|
||||||
|
yield await _auto_sse_event("rows", "done", {
|
||||||
|
"row_count": len(row_list),
|
||||||
|
"content_rows": row_result["content_rows"],
|
||||||
|
"duration_seconds": row_result["duration_seconds"],
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Auto-mode rows failed for {session_id}: {e}")
|
||||||
|
error_step = "rows"
|
||||||
|
yield await _auto_sse_event("rows", "error", {"message": str(e)})
|
||||||
|
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
steps_skipped.append("rows")
|
||||||
|
yield await _auto_sse_event("rows", "skipped", {"reason": "from_step > 4"})
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------
|
||||||
|
# Step 5: Words (OCR)
|
||||||
|
# -----------------------------------------------------------------
|
||||||
|
if req.from_step <= 5:
|
||||||
|
yield await _auto_sse_event("words", "start", {"engine": req.ocr_engine})
|
||||||
|
try:
|
||||||
|
t0 = time.time()
|
||||||
|
word_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
|
||||||
|
column_result = session.get("column_result") or cached.get("column_result")
|
||||||
|
row_result = session.get("row_result") or cached.get("row_result")
|
||||||
|
|
||||||
|
col_regions = [
|
||||||
|
PageRegion(
|
||||||
|
type=c["type"], x=c["x"], y=c["y"],
|
||||||
|
width=c["width"], height=c["height"],
|
||||||
|
classification_confidence=c.get("classification_confidence", 1.0),
|
||||||
|
classification_method=c.get("classification_method", ""),
|
||||||
|
)
|
||||||
|
for c in column_result["columns"]
|
||||||
|
]
|
||||||
|
row_geoms = [
|
||||||
|
RowGeometry(
|
||||||
|
index=r["index"], x=r["x"], y=r["y"],
|
||||||
|
width=r["width"], height=r["height"],
|
||||||
|
word_count=r.get("word_count", 0), words=[],
|
||||||
|
row_type=r.get("row_type", "content"),
|
||||||
|
gap_before=r.get("gap_before", 0),
|
||||||
|
)
|
||||||
|
for r in row_result["rows"]
|
||||||
|
]
|
||||||
|
|
||||||
|
word_dicts = cached.get("_word_dicts")
|
||||||
|
if word_dicts is not None:
|
||||||
|
content_bounds = cached.get("_content_bounds")
|
||||||
|
top_y = content_bounds[2] if content_bounds else min(r.y for r in row_geoms)
|
||||||
|
for row in row_geoms:
|
||||||
|
row_y_rel = row.y - top_y
|
||||||
|
row_bottom_rel = row_y_rel + row.height
|
||||||
|
row.words = [
|
||||||
|
w for w in word_dicts
|
||||||
|
if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
|
||||||
|
]
|
||||||
|
row.word_count = len(row.words)
|
||||||
|
|
||||||
|
ocr_img = create_ocr_image(word_img)
|
||||||
|
img_h, img_w = word_img.shape[:2]
|
||||||
|
|
||||||
|
cells, columns_meta = build_cell_grid(
|
||||||
|
ocr_img, col_regions, row_geoms, img_w, img_h,
|
||||||
|
ocr_engine=req.ocr_engine, img_bgr=word_img,
|
||||||
|
)
|
||||||
|
duration = time.time() - t0
|
||||||
|
|
||||||
|
col_types = {c['type'] for c in columns_meta}
|
||||||
|
is_vocab = bool(col_types & {'column_en', 'column_de'})
|
||||||
|
n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
|
||||||
|
used_engine = cells[0].get("ocr_engine", "tesseract") if cells else req.ocr_engine
|
||||||
|
|
||||||
|
# Apply IPA phonetic fixes directly to cell texts
|
||||||
|
fix_cell_phonetics(cells, pronunciation=req.pronunciation)
|
||||||
|
|
||||||
|
word_result_data = {
|
||||||
|
"cells": cells,
|
||||||
|
"grid_shape": {
|
||||||
|
"rows": n_content_rows,
|
||||||
|
"cols": len(columns_meta),
|
||||||
|
"total_cells": len(cells),
|
||||||
|
},
|
||||||
|
"columns_used": columns_meta,
|
||||||
|
"layout": "vocab" if is_vocab else "generic",
|
||||||
|
"image_width": img_w,
|
||||||
|
"image_height": img_h,
|
||||||
|
"duration_seconds": round(duration, 2),
|
||||||
|
"ocr_engine": used_engine,
|
||||||
|
"summary": {
|
||||||
|
"total_cells": len(cells),
|
||||||
|
"non_empty_cells": sum(1 for c in cells if c.get("text")),
|
||||||
|
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
has_text_col = 'column_text' in col_types
|
||||||
|
if is_vocab or has_text_col:
|
||||||
|
entries = _cells_to_vocab_entries(cells, columns_meta)
|
||||||
|
entries = _fix_character_confusion(entries)
|
||||||
|
entries = _fix_phonetic_brackets(entries, pronunciation=req.pronunciation)
|
||||||
|
word_result_data["vocab_entries"] = entries
|
||||||
|
word_result_data["entries"] = entries
|
||||||
|
word_result_data["entry_count"] = len(entries)
|
||||||
|
word_result_data["summary"]["total_entries"] = len(entries)
|
||||||
|
|
||||||
|
await update_session_db(session_id, word_result=word_result_data, current_step=8)
|
||||||
|
cached["word_result"] = word_result_data
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
|
||||||
|
steps_run.append("words")
|
||||||
|
yield await _auto_sse_event("words", "done", {
|
||||||
|
"total_cells": len(cells),
|
||||||
|
"layout": word_result_data["layout"],
|
||||||
|
"duration_seconds": round(duration, 2),
|
||||||
|
"ocr_engine": used_engine,
|
||||||
|
"summary": word_result_data["summary"],
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Auto-mode words failed for {session_id}: {e}")
|
||||||
|
error_step = "words"
|
||||||
|
yield await _auto_sse_event("words", "error", {"message": str(e)})
|
||||||
|
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
steps_skipped.append("words")
|
||||||
|
yield await _auto_sse_event("words", "skipped", {"reason": "from_step > 5"})
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------
|
||||||
|
# Step 6: LLM Review (optional)
|
||||||
|
# -----------------------------------------------------------------
|
||||||
|
if req.from_step <= 6 and not req.skip_llm_review:
|
||||||
|
yield await _auto_sse_event("llm_review", "start", {"model": OLLAMA_REVIEW_MODEL})
|
||||||
|
try:
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
word_result = session.get("word_result") or cached.get("word_result")
|
||||||
|
entries = word_result.get("entries") or word_result.get("vocab_entries") or []
|
||||||
|
|
||||||
|
if not entries:
|
||||||
|
yield await _auto_sse_event("llm_review", "skipped", {"reason": "no entries"})
|
||||||
|
steps_skipped.append("llm_review")
|
||||||
|
else:
|
||||||
|
reviewed = await llm_review_entries(entries)
|
||||||
|
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
word_result_updated = dict(session.get("word_result") or {})
|
||||||
|
word_result_updated["entries"] = reviewed
|
||||||
|
word_result_updated["vocab_entries"] = reviewed
|
||||||
|
word_result_updated["llm_reviewed"] = True
|
||||||
|
word_result_updated["llm_model"] = OLLAMA_REVIEW_MODEL
|
||||||
|
|
||||||
|
await update_session_db(session_id, word_result=word_result_updated, current_step=9)
|
||||||
|
cached["word_result"] = word_result_updated
|
||||||
|
|
||||||
|
steps_run.append("llm_review")
|
||||||
|
yield await _auto_sse_event("llm_review", "done", {
|
||||||
|
"entries_reviewed": len(reviewed),
|
||||||
|
"model": OLLAMA_REVIEW_MODEL,
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Auto-mode llm_review failed for {session_id} (non-fatal): {e}")
|
||||||
|
yield await _auto_sse_event("llm_review", "error", {"message": str(e), "fatal": False})
|
||||||
|
steps_skipped.append("llm_review")
|
||||||
|
else:
|
||||||
|
steps_skipped.append("llm_review")
|
||||||
|
reason = "skipped by request" if req.skip_llm_review else "from_step > 6"
|
||||||
|
yield await _auto_sse_event("llm_review", "skipped", {"reason": reason})
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------
|
||||||
|
# Final event
|
||||||
|
# -----------------------------------------------------------------
|
||||||
|
yield await _auto_sse_event("complete", "done", {
|
||||||
|
"steps_run": steps_run,
|
||||||
|
"steps_skipped": steps_skipped,
|
||||||
|
})
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
_generate(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
},
|
||||||
|
)
|
||||||
354
klausur-service/backend/ocr_pipeline_common.py
Normal file
354
klausur-service/backend/ocr_pipeline_common.py
Normal file
@@ -0,0 +1,354 @@
|
|||||||
|
"""
|
||||||
|
Shared common module for the OCR pipeline.
|
||||||
|
|
||||||
|
Contains in-memory cache, helper functions, Pydantic request models,
|
||||||
|
pipeline logging, and border-ghost word filtering used by the pipeline
|
||||||
|
API endpoints and related modules.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ocr_pipeline_session_store import get_session_db, get_session_image, update_session_db
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Cache
|
||||||
|
"_cache",
|
||||||
|
# Helper functions
|
||||||
|
"_get_base_image_png",
|
||||||
|
"_load_session_to_cache",
|
||||||
|
"_get_cached",
|
||||||
|
# Pydantic models
|
||||||
|
"ManualDeskewRequest",
|
||||||
|
"DeskewGroundTruthRequest",
|
||||||
|
"ManualDewarpRequest",
|
||||||
|
"CombinedAdjustRequest",
|
||||||
|
"DewarpGroundTruthRequest",
|
||||||
|
"VALID_DOCUMENT_CATEGORIES",
|
||||||
|
"UpdateSessionRequest",
|
||||||
|
"ManualColumnsRequest",
|
||||||
|
"ColumnGroundTruthRequest",
|
||||||
|
"ManualRowsRequest",
|
||||||
|
"RowGroundTruthRequest",
|
||||||
|
"RemoveHandwritingRequest",
|
||||||
|
# Pipeline log
|
||||||
|
"_append_pipeline_log",
|
||||||
|
# Border-ghost filter
|
||||||
|
"_BORDER_GHOST_CHARS",
|
||||||
|
"_filter_border_ghost_words",
|
||||||
|
]
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# In-memory cache for active sessions (BGR numpy arrays for processing)
|
||||||
|
# DB is source of truth, cache holds BGR arrays during active processing.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_cache: Dict[str, Dict[str, Any]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_base_image_png(session_id: str) -> Optional[bytes]:
|
||||||
|
"""Get the best available base image for a session (cropped > dewarped > original)."""
|
||||||
|
for img_type in ("cropped", "dewarped", "original"):
|
||||||
|
png_data = await get_session_image(session_id, img_type)
|
||||||
|
if png_data:
|
||||||
|
return png_data
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _load_session_to_cache(session_id: str) -> Dict[str, Any]:
|
||||||
|
"""Load session from DB into cache, decoding PNGs to BGR arrays."""
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
if session_id in _cache:
|
||||||
|
return _cache[session_id]
|
||||||
|
|
||||||
|
cache_entry: Dict[str, Any] = {
|
||||||
|
"id": session_id,
|
||||||
|
**session,
|
||||||
|
"original_bgr": None,
|
||||||
|
"oriented_bgr": None,
|
||||||
|
"cropped_bgr": None,
|
||||||
|
"deskewed_bgr": None,
|
||||||
|
"dewarped_bgr": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Decode images from DB into BGR numpy arrays
|
||||||
|
for img_type, bgr_key in [
|
||||||
|
("original", "original_bgr"),
|
||||||
|
("oriented", "oriented_bgr"),
|
||||||
|
("cropped", "cropped_bgr"),
|
||||||
|
("deskewed", "deskewed_bgr"),
|
||||||
|
("dewarped", "dewarped_bgr"),
|
||||||
|
]:
|
||||||
|
png_data = await get_session_image(session_id, img_type)
|
||||||
|
if png_data:
|
||||||
|
arr = np.frombuffer(png_data, dtype=np.uint8)
|
||||||
|
bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||||
|
cache_entry[bgr_key] = bgr
|
||||||
|
|
||||||
|
# Sub-sessions: original image IS the cropped box region.
|
||||||
|
# Promote original_bgr to cropped_bgr so downstream steps find it.
|
||||||
|
if session.get("parent_session_id") and cache_entry["original_bgr"] is not None:
|
||||||
|
if cache_entry["cropped_bgr"] is None and cache_entry["dewarped_bgr"] is None:
|
||||||
|
cache_entry["cropped_bgr"] = cache_entry["original_bgr"]
|
||||||
|
|
||||||
|
_cache[session_id] = cache_entry
|
||||||
|
return cache_entry
|
||||||
|
|
||||||
|
|
||||||
|
def _get_cached(session_id: str) -> Dict[str, Any]:
|
||||||
|
"""Get from cache or raise 404."""
|
||||||
|
entry = _cache.get(session_id)
|
||||||
|
if not entry:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not in cache — reload first")
|
||||||
|
return entry
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pydantic Models
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class ManualDeskewRequest(BaseModel):
|
||||||
|
angle: float
|
||||||
|
|
||||||
|
|
||||||
|
class DeskewGroundTruthRequest(BaseModel):
|
||||||
|
is_correct: bool
|
||||||
|
corrected_angle: Optional[float] = None
|
||||||
|
notes: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ManualDewarpRequest(BaseModel):
|
||||||
|
shear_degrees: float
|
||||||
|
|
||||||
|
|
||||||
|
class CombinedAdjustRequest(BaseModel):
|
||||||
|
rotation_degrees: float = 0.0
|
||||||
|
shear_degrees: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class DewarpGroundTruthRequest(BaseModel):
|
||||||
|
is_correct: bool
|
||||||
|
corrected_shear: Optional[float] = None
|
||||||
|
notes: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
VALID_DOCUMENT_CATEGORIES = {
|
||||||
|
'vokabelseite', 'buchseite', 'arbeitsblatt', 'klausurseite',
|
||||||
|
'mathearbeit', 'statistik', 'zeitung', 'formular', 'handschrift', 'sonstiges',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateSessionRequest(BaseModel):
|
||||||
|
name: Optional[str] = None
|
||||||
|
document_category: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ManualColumnsRequest(BaseModel):
|
||||||
|
columns: List[Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
class ColumnGroundTruthRequest(BaseModel):
|
||||||
|
is_correct: bool
|
||||||
|
corrected_columns: Optional[List[Dict[str, Any]]] = None
|
||||||
|
notes: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ManualRowsRequest(BaseModel):
|
||||||
|
rows: List[Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
class RowGroundTruthRequest(BaseModel):
|
||||||
|
is_correct: bool
|
||||||
|
corrected_rows: Optional[List[Dict[str, Any]]] = None
|
||||||
|
notes: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class RemoveHandwritingRequest(BaseModel):
|
||||||
|
method: str = "auto" # "auto" | "telea" | "ns"
|
||||||
|
target_ink: str = "all" # "all" | "colored" | "pencil"
|
||||||
|
dilation: int = 2 # mask dilation iterations (0-5)
|
||||||
|
use_source: str = "auto" # "original" | "deskewed" | "auto"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pipeline Log Helper
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _append_pipeline_log(
|
||||||
|
session_id: str,
|
||||||
|
step_name: str,
|
||||||
|
metrics: Dict[str, Any],
|
||||||
|
success: bool = True,
|
||||||
|
duration_ms: Optional[int] = None,
|
||||||
|
):
|
||||||
|
"""Append a step entry to the session's pipeline_log JSONB."""
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
return
|
||||||
|
log = session.get("pipeline_log") or {"steps": []}
|
||||||
|
if not isinstance(log, dict):
|
||||||
|
log = {"steps": []}
|
||||||
|
entry = {
|
||||||
|
"step": step_name,
|
||||||
|
"completed_at": datetime.utcnow().isoformat(),
|
||||||
|
"success": success,
|
||||||
|
"metrics": metrics,
|
||||||
|
}
|
||||||
|
if duration_ms is not None:
|
||||||
|
entry["duration_ms"] = duration_ms
|
||||||
|
log.setdefault("steps", []).append(entry)
|
||||||
|
await update_session_db(session_id, pipeline_log=log)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Border-ghost word filter
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Characters that OCR produces when reading box-border lines.
|
||||||
|
_BORDER_GHOST_CHARS = set("|1lI![](){}iíì/\\-—–_~.,;:'\"")
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_border_ghost_words(
|
||||||
|
word_result: Dict,
|
||||||
|
boxes: List,
|
||||||
|
) -> int:
|
||||||
|
"""Remove OCR words that are actually box border lines.
|
||||||
|
|
||||||
|
A word is considered a border ghost when it sits on a known box edge
|
||||||
|
(left, right, top, or bottom) and looks like a line artefact (narrow
|
||||||
|
aspect ratio or text consists only of line-like characters).
|
||||||
|
|
||||||
|
After removing ghost cells, columns that have become empty are also
|
||||||
|
removed from ``columns_used`` so the grid no longer shows phantom
|
||||||
|
columns.
|
||||||
|
|
||||||
|
Modifies *word_result* in-place and returns the number of removed cells.
|
||||||
|
"""
|
||||||
|
if not boxes or not word_result:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
cells = word_result.get("cells")
|
||||||
|
if not cells:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Build border bands — vertical (X) and horizontal (Y)
|
||||||
|
x_bands = [] # list of (x_lo, x_hi)
|
||||||
|
y_bands = [] # list of (y_lo, y_hi)
|
||||||
|
for b in boxes:
|
||||||
|
bx = b.x if hasattr(b, "x") else b.get("x", 0)
|
||||||
|
by = b.y if hasattr(b, "y") else b.get("y", 0)
|
||||||
|
bw = b.width if hasattr(b, "width") else b.get("w", b.get("width", 0))
|
||||||
|
bh = b.height if hasattr(b, "height") else b.get("h", b.get("height", 0))
|
||||||
|
bt = b.border_thickness if hasattr(b, "border_thickness") else b.get("border_thickness", 3)
|
||||||
|
margin = max(bt * 2, 10) + 6 # generous margin
|
||||||
|
|
||||||
|
# Vertical edges (left / right)
|
||||||
|
x_bands.append((bx - margin, bx + margin))
|
||||||
|
x_bands.append((bx + bw - margin, bx + bw + margin))
|
||||||
|
# Horizontal edges (top / bottom)
|
||||||
|
y_bands.append((by - margin, by + margin))
|
||||||
|
y_bands.append((by + bh - margin, by + bh + margin))
|
||||||
|
|
||||||
|
img_w = word_result.get("image_width", 1)
|
||||||
|
img_h = word_result.get("image_height", 1)
|
||||||
|
|
||||||
|
def _is_ghost(cell: Dict) -> bool:
|
||||||
|
text = (cell.get("text") or "").strip()
|
||||||
|
if not text:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Compute absolute pixel position
|
||||||
|
if cell.get("bbox_px"):
|
||||||
|
px = cell["bbox_px"]
|
||||||
|
cx = px["x"] + px["w"] / 2
|
||||||
|
cy = px["y"] + px["h"] / 2
|
||||||
|
cw = px["w"]
|
||||||
|
ch = px["h"]
|
||||||
|
elif cell.get("bbox_pct"):
|
||||||
|
pct = cell["bbox_pct"]
|
||||||
|
cx = (pct["x"] / 100) * img_w + (pct["w"] / 100) * img_w / 2
|
||||||
|
cy = (pct["y"] / 100) * img_h + (pct["h"] / 100) * img_h / 2
|
||||||
|
cw = (pct["w"] / 100) * img_w
|
||||||
|
ch = (pct["h"] / 100) * img_h
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if center sits on a vertical or horizontal border
|
||||||
|
on_vertical = any(lo <= cx <= hi for lo, hi in x_bands)
|
||||||
|
on_horizontal = any(lo <= cy <= hi for lo, hi in y_bands)
|
||||||
|
if not on_vertical and not on_horizontal:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Very short text (1-2 chars) on a border → very likely ghost
|
||||||
|
if len(text) <= 2:
|
||||||
|
# Narrow vertically (line-like) or narrow horizontally (dash-like)?
|
||||||
|
if ch > 0 and cw / ch < 0.5:
|
||||||
|
return True
|
||||||
|
if cw > 0 and ch / cw < 0.5:
|
||||||
|
return True
|
||||||
|
# Text is only border-ghost characters?
|
||||||
|
if all(c in _BORDER_GHOST_CHARS for c in text):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Longer text but still only ghost chars and very narrow
|
||||||
|
if all(c in _BORDER_GHOST_CHARS for c in text):
|
||||||
|
if ch > 0 and cw / ch < 0.35:
|
||||||
|
return True
|
||||||
|
if cw > 0 and ch / cw < 0.35:
|
||||||
|
return True
|
||||||
|
return True # all ghost chars on a border → remove
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
before = len(cells)
|
||||||
|
word_result["cells"] = [c for c in cells if not _is_ghost(c)]
|
||||||
|
removed = before - len(word_result["cells"])
|
||||||
|
|
||||||
|
# --- Remove empty columns from columns_used ---
|
||||||
|
columns_used = word_result.get("columns_used")
|
||||||
|
if removed and columns_used and len(columns_used) > 1:
|
||||||
|
remaining_cells = word_result["cells"]
|
||||||
|
occupied_cols = {c.get("col_index") for c in remaining_cells}
|
||||||
|
before_cols = len(columns_used)
|
||||||
|
columns_used = [col for col in columns_used if col.get("index") in occupied_cols]
|
||||||
|
|
||||||
|
# Re-index columns and remap cell col_index values
|
||||||
|
if len(columns_used) < before_cols:
|
||||||
|
old_to_new = {}
|
||||||
|
for new_i, col in enumerate(columns_used):
|
||||||
|
old_to_new[col["index"]] = new_i
|
||||||
|
col["index"] = new_i
|
||||||
|
for cell in remaining_cells:
|
||||||
|
old_ci = cell.get("col_index")
|
||||||
|
if old_ci in old_to_new:
|
||||||
|
cell["col_index"] = old_to_new[old_ci]
|
||||||
|
word_result["columns_used"] = columns_used
|
||||||
|
logger.info("border-ghost: removed %d empty column(s), %d remaining",
|
||||||
|
before_cols - len(columns_used), len(columns_used))
|
||||||
|
|
||||||
|
if removed:
|
||||||
|
# Update summary counts
|
||||||
|
summary = word_result.get("summary", {})
|
||||||
|
summary["total_cells"] = len(word_result["cells"])
|
||||||
|
summary["non_empty_cells"] = sum(1 for c in word_result["cells"] if c.get("text"))
|
||||||
|
word_result["summary"] = summary
|
||||||
|
gs = word_result.get("grid_shape", {})
|
||||||
|
gs["total_cells"] = len(word_result["cells"])
|
||||||
|
if columns_used is not None:
|
||||||
|
gs["cols"] = len(columns_used)
|
||||||
|
word_result["grid_shape"] = gs
|
||||||
|
|
||||||
|
return removed
|
||||||
1025
klausur-service/backend/ocr_pipeline_geometry.py
Normal file
1025
klausur-service/backend/ocr_pipeline_geometry.py
Normal file
File diff suppressed because it is too large
Load Diff
615
klausur-service/backend/ocr_pipeline_ocr_merge.py
Normal file
615
klausur-service/backend/ocr_pipeline_ocr_merge.py
Normal file
@@ -0,0 +1,615 @@
|
|||||||
|
"""
|
||||||
|
OCR Merge Helpers and Kombi Endpoints.
|
||||||
|
|
||||||
|
Contains merge helper functions for combining PaddleOCR/RapidOCR with Tesseract
|
||||||
|
results, plus the paddle-kombi and rapid-kombi endpoints.
|
||||||
|
|
||||||
|
Extracted from ocr_pipeline_api.py for modularity.
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import httpx
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
|
||||||
|
from cv_words_first import build_grid_from_words
|
||||||
|
from ocr_pipeline_common import _cache, _append_pipeline_log
|
||||||
|
from ocr_pipeline_session_store import get_session_image, update_session_db
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Merge helper functions
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _split_paddle_multi_words(words: list) -> list:
|
||||||
|
"""Split PaddleOCR multi-word boxes into individual word boxes.
|
||||||
|
|
||||||
|
PaddleOCR often returns entire phrases as a single box, e.g.
|
||||||
|
"More than 200 singers took part in the" with one bounding box.
|
||||||
|
This splits them into individual words with proportional widths.
|
||||||
|
Also handles leading "!" (e.g. "!Betonung" → ["!", "Betonung"])
|
||||||
|
and IPA brackets (e.g. "badge[bxd3]" → ["badge", "[bxd3]"]).
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for w in words:
|
||||||
|
raw_text = w.get("text", "").strip()
|
||||||
|
if not raw_text:
|
||||||
|
continue
|
||||||
|
# Split on whitespace, before "[" (IPA), and after "!" before letter
|
||||||
|
tokens = re.split(
|
||||||
|
r'\s+|(?=\[)|(?<=!)(?=[A-Za-z\u00c0-\u024f])', raw_text
|
||||||
|
)
|
||||||
|
tokens = [t for t in tokens if t]
|
||||||
|
|
||||||
|
if len(tokens) <= 1:
|
||||||
|
result.append(w)
|
||||||
|
else:
|
||||||
|
# Split proportionally by character count
|
||||||
|
total_chars = sum(len(t) for t in tokens)
|
||||||
|
if total_chars == 0:
|
||||||
|
continue
|
||||||
|
n_gaps = len(tokens) - 1
|
||||||
|
gap_px = w["width"] * 0.02
|
||||||
|
usable_w = w["width"] - gap_px * n_gaps
|
||||||
|
cursor = w["left"]
|
||||||
|
for t in tokens:
|
||||||
|
token_w = max(1, usable_w * len(t) / total_chars)
|
||||||
|
result.append({
|
||||||
|
"text": t,
|
||||||
|
"left": round(cursor),
|
||||||
|
"top": w["top"],
|
||||||
|
"width": round(token_w),
|
||||||
|
"height": w["height"],
|
||||||
|
"conf": w.get("conf", 0),
|
||||||
|
})
|
||||||
|
cursor += token_w + gap_px
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _group_words_into_rows(words: list, row_gap: int = 12) -> list:
|
||||||
|
"""Group words into rows by Y-position clustering.
|
||||||
|
|
||||||
|
Words whose vertical centers are within `row_gap` pixels are on the same row.
|
||||||
|
Returns list of rows, each row is a list of words sorted left-to-right.
|
||||||
|
"""
|
||||||
|
if not words:
|
||||||
|
return []
|
||||||
|
# Sort by vertical center
|
||||||
|
sorted_words = sorted(words, key=lambda w: w["top"] + w.get("height", 0) / 2)
|
||||||
|
rows: list = []
|
||||||
|
current_row: list = [sorted_words[0]]
|
||||||
|
current_cy = sorted_words[0]["top"] + sorted_words[0].get("height", 0) / 2
|
||||||
|
|
||||||
|
for w in sorted_words[1:]:
|
||||||
|
cy = w["top"] + w.get("height", 0) / 2
|
||||||
|
if abs(cy - current_cy) <= row_gap:
|
||||||
|
current_row.append(w)
|
||||||
|
else:
|
||||||
|
# Sort current row left-to-right before saving
|
||||||
|
rows.append(sorted(current_row, key=lambda w: w["left"]))
|
||||||
|
current_row = [w]
|
||||||
|
current_cy = cy
|
||||||
|
if current_row:
|
||||||
|
rows.append(sorted(current_row, key=lambda w: w["left"]))
|
||||||
|
return rows
|
||||||
|
|
||||||
|
|
||||||
|
def _row_center_y(row: list) -> float:
|
||||||
|
"""Average vertical center of a row of words."""
|
||||||
|
if not row:
|
||||||
|
return 0.0
|
||||||
|
return sum(w["top"] + w.get("height", 0) / 2 for w in row) / len(row)
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_row_sequences(paddle_row: list, tess_row: list) -> list:
|
||||||
|
"""Merge two word sequences from the same row using sequence alignment.
|
||||||
|
|
||||||
|
Both sequences are sorted left-to-right. Walk through both simultaneously:
|
||||||
|
- If words match (same/similar text): take Paddle text with averaged coords
|
||||||
|
- If they don't match: the extra word is unique to one engine, include it
|
||||||
|
|
||||||
|
This prevents duplicates because both engines produce words in the same order.
|
||||||
|
"""
|
||||||
|
merged = []
|
||||||
|
pi, ti = 0, 0
|
||||||
|
|
||||||
|
while pi < len(paddle_row) and ti < len(tess_row):
|
||||||
|
pw = paddle_row[pi]
|
||||||
|
tw = tess_row[ti]
|
||||||
|
|
||||||
|
# Check if these are the same word
|
||||||
|
pt = pw.get("text", "").lower().strip()
|
||||||
|
tt = tw.get("text", "").lower().strip()
|
||||||
|
|
||||||
|
# Same text or one contains the other
|
||||||
|
is_same = (pt == tt) or (len(pt) > 1 and len(tt) > 1 and (pt in tt or tt in pt))
|
||||||
|
|
||||||
|
# Spatial overlap check: if words overlap >= 40% horizontally,
|
||||||
|
# they're the same physical word regardless of OCR text differences.
|
||||||
|
# (40% catches borderline cases like "Stick"/"Stück" at 48% overlap)
|
||||||
|
spatial_match = False
|
||||||
|
if not is_same:
|
||||||
|
overlap_left = max(pw["left"], tw["left"])
|
||||||
|
overlap_right = min(
|
||||||
|
pw["left"] + pw.get("width", 0),
|
||||||
|
tw["left"] + tw.get("width", 0),
|
||||||
|
)
|
||||||
|
overlap_w = max(0, overlap_right - overlap_left)
|
||||||
|
min_w = min(pw.get("width", 1), tw.get("width", 1))
|
||||||
|
if min_w > 0 and overlap_w / min_w >= 0.4:
|
||||||
|
is_same = True
|
||||||
|
spatial_match = True
|
||||||
|
|
||||||
|
if is_same:
|
||||||
|
# Matched — average coordinates weighted by confidence
|
||||||
|
pc = pw.get("conf", 80)
|
||||||
|
tc = tw.get("conf", 50)
|
||||||
|
total = pc + tc
|
||||||
|
if total == 0:
|
||||||
|
total = 1
|
||||||
|
# Text: prefer higher-confidence engine when texts differ
|
||||||
|
# (e.g. Tesseract "Stück" conf=98 vs PaddleOCR "Stick" conf=80)
|
||||||
|
if spatial_match and pc < tc:
|
||||||
|
best_text = tw["text"]
|
||||||
|
else:
|
||||||
|
best_text = pw["text"]
|
||||||
|
merged.append({
|
||||||
|
"text": best_text,
|
||||||
|
"left": round((pw["left"] * pc + tw["left"] * tc) / total),
|
||||||
|
"top": round((pw["top"] * pc + tw["top"] * tc) / total),
|
||||||
|
"width": round((pw["width"] * pc + tw["width"] * tc) / total),
|
||||||
|
"height": round((pw["height"] * pc + tw["height"] * tc) / total),
|
||||||
|
"conf": max(pc, tc),
|
||||||
|
})
|
||||||
|
pi += 1
|
||||||
|
ti += 1
|
||||||
|
else:
|
||||||
|
# Different text — one engine found something extra
|
||||||
|
# Look ahead: is the current Paddle word somewhere in Tesseract ahead?
|
||||||
|
paddle_ahead = any(
|
||||||
|
tess_row[t].get("text", "").lower().strip() == pt
|
||||||
|
for t in range(ti + 1, min(ti + 4, len(tess_row)))
|
||||||
|
)
|
||||||
|
# Is the current Tesseract word somewhere in Paddle ahead?
|
||||||
|
tess_ahead = any(
|
||||||
|
paddle_row[p].get("text", "").lower().strip() == tt
|
||||||
|
for p in range(pi + 1, min(pi + 4, len(paddle_row)))
|
||||||
|
)
|
||||||
|
|
||||||
|
if paddle_ahead and not tess_ahead:
|
||||||
|
# Tesseract has an extra word (e.g. "!" or bullet) → include it
|
||||||
|
if tw.get("conf", 0) >= 30:
|
||||||
|
merged.append(tw)
|
||||||
|
ti += 1
|
||||||
|
elif tess_ahead and not paddle_ahead:
|
||||||
|
# Paddle has an extra word → include it
|
||||||
|
merged.append(pw)
|
||||||
|
pi += 1
|
||||||
|
else:
|
||||||
|
# Both have unique words or neither found ahead → take leftmost first
|
||||||
|
if pw["left"] <= tw["left"]:
|
||||||
|
merged.append(pw)
|
||||||
|
pi += 1
|
||||||
|
else:
|
||||||
|
if tw.get("conf", 0) >= 30:
|
||||||
|
merged.append(tw)
|
||||||
|
ti += 1
|
||||||
|
|
||||||
|
# Remaining words from either engine
|
||||||
|
while pi < len(paddle_row):
|
||||||
|
merged.append(paddle_row[pi])
|
||||||
|
pi += 1
|
||||||
|
while ti < len(tess_row):
|
||||||
|
tw = tess_row[ti]
|
||||||
|
if tw.get("conf", 0) >= 30:
|
||||||
|
merged.append(tw)
|
||||||
|
ti += 1
|
||||||
|
|
||||||
|
return merged
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_paddle_tesseract(paddle_words: list, tess_words: list) -> list:
|
||||||
|
"""Merge word boxes from PaddleOCR and Tesseract using row-based sequence alignment.
|
||||||
|
|
||||||
|
Strategy:
|
||||||
|
1. Group each engine's words into rows (by Y-position clustering)
|
||||||
|
2. Match rows between engines (by vertical center proximity)
|
||||||
|
3. Within each matched row: merge sequences left-to-right, deduplicating
|
||||||
|
words that appear in both engines at the same sequence position
|
||||||
|
4. Unmatched rows from either engine: keep as-is
|
||||||
|
|
||||||
|
This prevents:
|
||||||
|
- Cross-line averaging (words from different lines being merged)
|
||||||
|
- Duplicate words (same word from both engines shown twice)
|
||||||
|
"""
|
||||||
|
if not paddle_words and not tess_words:
|
||||||
|
return []
|
||||||
|
if not paddle_words:
|
||||||
|
return [w for w in tess_words if w.get("conf", 0) >= 40]
|
||||||
|
if not tess_words:
|
||||||
|
return list(paddle_words)
|
||||||
|
|
||||||
|
# Step 1: Group into rows
|
||||||
|
paddle_rows = _group_words_into_rows(paddle_words)
|
||||||
|
tess_rows = _group_words_into_rows(tess_words)
|
||||||
|
|
||||||
|
# Step 2: Match rows between engines by vertical center proximity
|
||||||
|
used_tess_rows: set = set()
|
||||||
|
merged_all: list = []
|
||||||
|
|
||||||
|
for pr in paddle_rows:
|
||||||
|
pr_cy = _row_center_y(pr)
|
||||||
|
best_dist, best_tri = float("inf"), -1
|
||||||
|
for tri, tr in enumerate(tess_rows):
|
||||||
|
if tri in used_tess_rows:
|
||||||
|
continue
|
||||||
|
tr_cy = _row_center_y(tr)
|
||||||
|
dist = abs(pr_cy - tr_cy)
|
||||||
|
if dist < best_dist:
|
||||||
|
best_dist, best_tri = dist, tri
|
||||||
|
|
||||||
|
# Row height threshold — rows must be within ~1.5x typical line height
|
||||||
|
max_row_dist = max(
|
||||||
|
max((w.get("height", 20) for w in pr), default=20),
|
||||||
|
15,
|
||||||
|
)
|
||||||
|
|
||||||
|
if best_tri >= 0 and best_dist <= max_row_dist:
|
||||||
|
# Matched row — merge sequences
|
||||||
|
tr = tess_rows[best_tri]
|
||||||
|
used_tess_rows.add(best_tri)
|
||||||
|
merged_all.extend(_merge_row_sequences(pr, tr))
|
||||||
|
else:
|
||||||
|
# No matching Tesseract row — keep Paddle row as-is
|
||||||
|
merged_all.extend(pr)
|
||||||
|
|
||||||
|
# Add unmatched Tesseract rows
|
||||||
|
for tri, tr in enumerate(tess_rows):
|
||||||
|
if tri not in used_tess_rows:
|
||||||
|
for tw in tr:
|
||||||
|
if tw.get("conf", 0) >= 40:
|
||||||
|
merged_all.append(tw)
|
||||||
|
|
||||||
|
return merged_all
|
||||||
|
|
||||||
|
|
||||||
|
def _deduplicate_words(words: list) -> list:
|
||||||
|
"""Remove duplicate words with same text at overlapping positions.
|
||||||
|
|
||||||
|
PaddleOCR can return overlapping phrases (e.g. "von jm." and "jm. =")
|
||||||
|
that produce duplicate words after splitting. This pass removes them.
|
||||||
|
|
||||||
|
A word is a duplicate only when BOTH horizontal AND vertical overlap
|
||||||
|
exceed 50% — same text on the same visual line at the same position.
|
||||||
|
"""
|
||||||
|
if not words:
|
||||||
|
return words
|
||||||
|
|
||||||
|
result: list = []
|
||||||
|
for w in words:
|
||||||
|
wt = w.get("text", "").lower().strip()
|
||||||
|
if not wt:
|
||||||
|
continue
|
||||||
|
is_dup = False
|
||||||
|
w_right = w["left"] + w.get("width", 0)
|
||||||
|
w_bottom = w["top"] + w.get("height", 0)
|
||||||
|
for existing in result:
|
||||||
|
et = existing.get("text", "").lower().strip()
|
||||||
|
if wt != et:
|
||||||
|
continue
|
||||||
|
# Horizontal overlap
|
||||||
|
ox_l = max(w["left"], existing["left"])
|
||||||
|
ox_r = min(w_right, existing["left"] + existing.get("width", 0))
|
||||||
|
ox = max(0, ox_r - ox_l)
|
||||||
|
min_w = min(w.get("width", 1), existing.get("width", 1))
|
||||||
|
if min_w <= 0 or ox / min_w < 0.5:
|
||||||
|
continue
|
||||||
|
# Vertical overlap — must also be on the same line
|
||||||
|
oy_t = max(w["top"], existing["top"])
|
||||||
|
oy_b = min(w_bottom, existing["top"] + existing.get("height", 0))
|
||||||
|
oy = max(0, oy_b - oy_t)
|
||||||
|
min_h = min(w.get("height", 1), existing.get("height", 1))
|
||||||
|
if min_h > 0 and oy / min_h >= 0.5:
|
||||||
|
is_dup = True
|
||||||
|
break
|
||||||
|
if not is_dup:
|
||||||
|
result.append(w)
|
||||||
|
|
||||||
|
removed = len(words) - len(result)
|
||||||
|
if removed:
|
||||||
|
logger.info("dedup: removed %d duplicate words", removed)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Kombi endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/paddle-kombi")
|
||||||
|
async def paddle_kombi(session_id: str):
|
||||||
|
"""Run PaddleOCR + Tesseract on the preprocessed image and merge results.
|
||||||
|
|
||||||
|
Both engines run on the same preprocessed (cropped/dewarped) image.
|
||||||
|
Word boxes are matched by IoU and coordinates are averaged weighted by
|
||||||
|
confidence. Unmatched Tesseract words (bullets, symbols) are added.
|
||||||
|
"""
|
||||||
|
img_png = await get_session_image(session_id, "cropped")
|
||||||
|
if not img_png:
|
||||||
|
img_png = await get_session_image(session_id, "dewarped")
|
||||||
|
if not img_png:
|
||||||
|
img_png = await get_session_image(session_id, "original")
|
||||||
|
if not img_png:
|
||||||
|
raise HTTPException(status_code=404, detail="No image found for this session")
|
||||||
|
|
||||||
|
img_arr = np.frombuffer(img_png, dtype=np.uint8)
|
||||||
|
img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
|
||||||
|
if img_bgr is None:
|
||||||
|
raise HTTPException(status_code=400, detail="Failed to decode image")
|
||||||
|
|
||||||
|
img_h, img_w = img_bgr.shape[:2]
|
||||||
|
|
||||||
|
from cv_ocr_engines import ocr_region_paddle
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
|
||||||
|
# --- PaddleOCR ---
|
||||||
|
paddle_words = await ocr_region_paddle(img_bgr, region=None)
|
||||||
|
if not paddle_words:
|
||||||
|
paddle_words = []
|
||||||
|
|
||||||
|
# --- Tesseract ---
|
||||||
|
from PIL import Image
|
||||||
|
import pytesseract
|
||||||
|
|
||||||
|
pil_img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
|
||||||
|
data = pytesseract.image_to_data(
|
||||||
|
pil_img, lang="eng+deu",
|
||||||
|
config="--psm 6 --oem 3",
|
||||||
|
output_type=pytesseract.Output.DICT,
|
||||||
|
)
|
||||||
|
tess_words = []
|
||||||
|
for i in range(len(data["text"])):
|
||||||
|
text = str(data["text"][i]).strip()
|
||||||
|
conf_raw = str(data["conf"][i])
|
||||||
|
conf = int(conf_raw) if conf_raw.lstrip("-").isdigit() else -1
|
||||||
|
if not text or conf < 20:
|
||||||
|
continue
|
||||||
|
tess_words.append({
|
||||||
|
"text": text,
|
||||||
|
"left": data["left"][i],
|
||||||
|
"top": data["top"][i],
|
||||||
|
"width": data["width"][i],
|
||||||
|
"height": data["height"][i],
|
||||||
|
"conf": conf,
|
||||||
|
})
|
||||||
|
|
||||||
|
# --- Split multi-word Paddle boxes into individual words ---
|
||||||
|
paddle_words_split = _split_paddle_multi_words(paddle_words)
|
||||||
|
logger.info(
|
||||||
|
"paddle_kombi: split %d paddle boxes → %d individual words",
|
||||||
|
len(paddle_words), len(paddle_words_split),
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Merge ---
|
||||||
|
if not paddle_words_split and not tess_words:
|
||||||
|
raise HTTPException(status_code=400, detail="Both OCR engines returned no words")
|
||||||
|
|
||||||
|
merged_words = _merge_paddle_tesseract(paddle_words_split, tess_words)
|
||||||
|
merged_words = _deduplicate_words(merged_words)
|
||||||
|
|
||||||
|
cells, columns_meta = build_grid_from_words(merged_words, img_w, img_h)
|
||||||
|
duration = time.time() - t0
|
||||||
|
|
||||||
|
for cell in cells:
|
||||||
|
cell["ocr_engine"] = "kombi"
|
||||||
|
|
||||||
|
n_rows = len(set(c["row_index"] for c in cells)) if cells else 0
|
||||||
|
n_cols = len(columns_meta)
|
||||||
|
col_types = {c.get("type") for c in columns_meta}
|
||||||
|
is_vocab = bool(col_types & {"column_en", "column_de"})
|
||||||
|
|
||||||
|
word_result = {
|
||||||
|
"cells": cells,
|
||||||
|
"grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
|
||||||
|
"columns_used": columns_meta,
|
||||||
|
"layout": "vocab" if is_vocab else "generic",
|
||||||
|
"image_width": img_w,
|
||||||
|
"image_height": img_h,
|
||||||
|
"duration_seconds": round(duration, 2),
|
||||||
|
"ocr_engine": "kombi",
|
||||||
|
"grid_method": "kombi",
|
||||||
|
"raw_paddle_words": paddle_words,
|
||||||
|
"raw_paddle_words_split": paddle_words_split,
|
||||||
|
"raw_tesseract_words": tess_words,
|
||||||
|
"summary": {
|
||||||
|
"total_cells": len(cells),
|
||||||
|
"non_empty_cells": sum(1 for c in cells if c.get("text")),
|
||||||
|
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
|
||||||
|
"paddle_words": len(paddle_words),
|
||||||
|
"paddle_words_split": len(paddle_words_split),
|
||||||
|
"tesseract_words": len(tess_words),
|
||||||
|
"merged_words": len(merged_words),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
await update_session_db(
|
||||||
|
session_id,
|
||||||
|
word_result=word_result,
|
||||||
|
cropped_png=img_png,
|
||||||
|
current_step=8,
|
||||||
|
)
|
||||||
|
# Update in-memory cache so detect-structure can access word_result
|
||||||
|
if session_id in _cache:
|
||||||
|
_cache[session_id]["word_result"] = word_result
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"paddle_kombi session %s: %d cells (%d rows, %d cols) in %.2fs "
|
||||||
|
"[paddle=%d, tess=%d, merged=%d]",
|
||||||
|
session_id, len(cells), n_rows, n_cols, duration,
|
||||||
|
len(paddle_words), len(tess_words), len(merged_words),
|
||||||
|
)
|
||||||
|
|
||||||
|
await _append_pipeline_log(session_id, "paddle_kombi", {
|
||||||
|
"total_cells": len(cells),
|
||||||
|
"non_empty_cells": word_result["summary"]["non_empty_cells"],
|
||||||
|
"paddle_words": len(paddle_words),
|
||||||
|
"tesseract_words": len(tess_words),
|
||||||
|
"merged_words": len(merged_words),
|
||||||
|
"ocr_engine": "kombi",
|
||||||
|
}, duration_ms=int(duration * 1000))
|
||||||
|
|
||||||
|
return {"session_id": session_id, **word_result}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/rapid-kombi")
|
||||||
|
async def rapid_kombi(session_id: str):
|
||||||
|
"""Run RapidOCR + Tesseract on the preprocessed image and merge results.
|
||||||
|
|
||||||
|
Same merge logic as paddle-kombi, but uses local RapidOCR (ONNX Runtime)
|
||||||
|
instead of remote PaddleOCR service.
|
||||||
|
"""
|
||||||
|
img_png = await get_session_image(session_id, "cropped")
|
||||||
|
if not img_png:
|
||||||
|
img_png = await get_session_image(session_id, "dewarped")
|
||||||
|
if not img_png:
|
||||||
|
img_png = await get_session_image(session_id, "original")
|
||||||
|
if not img_png:
|
||||||
|
raise HTTPException(status_code=404, detail="No image found for this session")
|
||||||
|
|
||||||
|
img_arr = np.frombuffer(img_png, dtype=np.uint8)
|
||||||
|
img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
|
||||||
|
if img_bgr is None:
|
||||||
|
raise HTTPException(status_code=400, detail="Failed to decode image")
|
||||||
|
|
||||||
|
img_h, img_w = img_bgr.shape[:2]
|
||||||
|
|
||||||
|
from cv_ocr_engines import ocr_region_rapid
|
||||||
|
from cv_vocab_types import PageRegion
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
|
||||||
|
# --- RapidOCR (local, synchronous) ---
|
||||||
|
full_region = PageRegion(
|
||||||
|
type="full_page", x=0, y=0, width=img_w, height=img_h,
|
||||||
|
)
|
||||||
|
rapid_words = ocr_region_rapid(img_bgr, full_region)
|
||||||
|
if not rapid_words:
|
||||||
|
rapid_words = []
|
||||||
|
|
||||||
|
# --- Tesseract ---
|
||||||
|
from PIL import Image
|
||||||
|
import pytesseract
|
||||||
|
|
||||||
|
pil_img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
|
||||||
|
data = pytesseract.image_to_data(
|
||||||
|
pil_img, lang="eng+deu",
|
||||||
|
config="--psm 6 --oem 3",
|
||||||
|
output_type=pytesseract.Output.DICT,
|
||||||
|
)
|
||||||
|
tess_words = []
|
||||||
|
for i in range(len(data["text"])):
|
||||||
|
text = str(data["text"][i]).strip()
|
||||||
|
conf_raw = str(data["conf"][i])
|
||||||
|
conf = int(conf_raw) if conf_raw.lstrip("-").isdigit() else -1
|
||||||
|
if not text or conf < 20:
|
||||||
|
continue
|
||||||
|
tess_words.append({
|
||||||
|
"text": text,
|
||||||
|
"left": data["left"][i],
|
||||||
|
"top": data["top"][i],
|
||||||
|
"width": data["width"][i],
|
||||||
|
"height": data["height"][i],
|
||||||
|
"conf": conf,
|
||||||
|
})
|
||||||
|
|
||||||
|
# --- Split multi-word RapidOCR boxes into individual words ---
|
||||||
|
rapid_words_split = _split_paddle_multi_words(rapid_words)
|
||||||
|
logger.info(
|
||||||
|
"rapid_kombi: split %d rapid boxes → %d individual words",
|
||||||
|
len(rapid_words), len(rapid_words_split),
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Merge ---
|
||||||
|
if not rapid_words_split and not tess_words:
|
||||||
|
raise HTTPException(status_code=400, detail="Both OCR engines returned no words")
|
||||||
|
|
||||||
|
merged_words = _merge_paddle_tesseract(rapid_words_split, tess_words)
|
||||||
|
merged_words = _deduplicate_words(merged_words)
|
||||||
|
|
||||||
|
cells, columns_meta = build_grid_from_words(merged_words, img_w, img_h)
|
||||||
|
duration = time.time() - t0
|
||||||
|
|
||||||
|
for cell in cells:
|
||||||
|
cell["ocr_engine"] = "rapid_kombi"
|
||||||
|
|
||||||
|
n_rows = len(set(c["row_index"] for c in cells)) if cells else 0
|
||||||
|
n_cols = len(columns_meta)
|
||||||
|
col_types = {c.get("type") for c in columns_meta}
|
||||||
|
is_vocab = bool(col_types & {"column_en", "column_de"})
|
||||||
|
|
||||||
|
word_result = {
|
||||||
|
"cells": cells,
|
||||||
|
"grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
|
||||||
|
"columns_used": columns_meta,
|
||||||
|
"layout": "vocab" if is_vocab else "generic",
|
||||||
|
"image_width": img_w,
|
||||||
|
"image_height": img_h,
|
||||||
|
"duration_seconds": round(duration, 2),
|
||||||
|
"ocr_engine": "rapid_kombi",
|
||||||
|
"grid_method": "rapid_kombi",
|
||||||
|
"raw_rapid_words": rapid_words,
|
||||||
|
"raw_rapid_words_split": rapid_words_split,
|
||||||
|
"raw_tesseract_words": tess_words,
|
||||||
|
"summary": {
|
||||||
|
"total_cells": len(cells),
|
||||||
|
"non_empty_cells": sum(1 for c in cells if c.get("text")),
|
||||||
|
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
|
||||||
|
"rapid_words": len(rapid_words),
|
||||||
|
"rapid_words_split": len(rapid_words_split),
|
||||||
|
"tesseract_words": len(tess_words),
|
||||||
|
"merged_words": len(merged_words),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
await update_session_db(
|
||||||
|
session_id,
|
||||||
|
word_result=word_result,
|
||||||
|
cropped_png=img_png,
|
||||||
|
current_step=8,
|
||||||
|
)
|
||||||
|
# Update in-memory cache so detect-structure can access word_result
|
||||||
|
if session_id in _cache:
|
||||||
|
_cache[session_id]["word_result"] = word_result
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"rapid_kombi session %s: %d cells (%d rows, %d cols) in %.2fs "
|
||||||
|
"[rapid=%d, tess=%d, merged=%d]",
|
||||||
|
session_id, len(cells), n_rows, n_cols, duration,
|
||||||
|
len(rapid_words), len(tess_words), len(merged_words),
|
||||||
|
)
|
||||||
|
|
||||||
|
await _append_pipeline_log(session_id, "rapid_kombi", {
|
||||||
|
"total_cells": len(cells),
|
||||||
|
"non_empty_cells": word_result["summary"]["non_empty_cells"],
|
||||||
|
"rapid_words": len(rapid_words),
|
||||||
|
"tesseract_words": len(tess_words),
|
||||||
|
"merged_words": len(merged_words),
|
||||||
|
"ocr_engine": "rapid_kombi",
|
||||||
|
}, duration_ms=int(duration * 1000))
|
||||||
|
|
||||||
|
return {"session_id": session_id, **word_result}
|
||||||
929
klausur-service/backend/ocr_pipeline_postprocess.py
Normal file
929
klausur-service/backend/ocr_pipeline_postprocess.py
Normal file
@@ -0,0 +1,929 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Postprocessing API — LLM review, reconstruction, export, validation,
|
||||||
|
image detection/generation, and handwriting removal endpoints.
|
||||||
|
|
||||||
|
Extracted from ocr_pipeline_api.py to keep the main module manageable.
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from cv_vocab_pipeline import (
|
||||||
|
OLLAMA_REVIEW_MODEL,
|
||||||
|
llm_review_entries,
|
||||||
|
llm_review_entries_streaming,
|
||||||
|
)
|
||||||
|
from ocr_pipeline_session_store import (
|
||||||
|
get_session_db,
|
||||||
|
get_session_image,
|
||||||
|
get_sub_sessions,
|
||||||
|
update_session_db,
|
||||||
|
)
|
||||||
|
from ocr_pipeline_common import (
|
||||||
|
_cache,
|
||||||
|
_load_session_to_cache,
|
||||||
|
_get_cached,
|
||||||
|
_get_base_image_png,
|
||||||
|
_append_pipeline_log,
|
||||||
|
RemoveHandwritingRequest,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pydantic Models
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
STYLE_SUFFIXES = {
|
||||||
|
"educational": "educational illustration, textbook style, clear, colorful",
|
||||||
|
"cartoon": "cartoon, child-friendly, simple shapes",
|
||||||
|
"sketch": "pencil sketch, hand-drawn, black and white",
|
||||||
|
"clipart": "clipart, flat vector style, simple",
|
||||||
|
"realistic": "photorealistic, high detail",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationRequest(BaseModel):
|
||||||
|
notes: Optional[str] = None
|
||||||
|
score: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateImageRequest(BaseModel):
|
||||||
|
region_index: int
|
||||||
|
prompt: str
|
||||||
|
style: str = "educational"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Step 8: LLM Review
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/llm-review")
|
||||||
|
async def run_llm_review(session_id: str, request: Request, stream: bool = False):
|
||||||
|
"""Run LLM-based correction on vocab entries from Step 5.
|
||||||
|
|
||||||
|
Query params:
|
||||||
|
stream: false (default) for JSON response, true for SSE streaming
|
||||||
|
"""
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
word_result = session.get("word_result")
|
||||||
|
if not word_result:
|
||||||
|
raise HTTPException(status_code=400, detail="No word result found — run Step 5 first")
|
||||||
|
|
||||||
|
entries = word_result.get("vocab_entries") or word_result.get("entries") or []
|
||||||
|
if not entries:
|
||||||
|
raise HTTPException(status_code=400, detail="No vocab entries found — run Step 5 first")
|
||||||
|
|
||||||
|
# Optional model override from request body
|
||||||
|
body = {}
|
||||||
|
try:
|
||||||
|
body = await request.json()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
model = body.get("model") or OLLAMA_REVIEW_MODEL
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
return StreamingResponse(
|
||||||
|
_llm_review_stream_generator(session_id, entries, word_result, model, request),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Non-streaming path
|
||||||
|
try:
|
||||||
|
result = await llm_review_entries(entries, model=model)
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
logger.error(f"LLM review failed for session {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}")
|
||||||
|
raise HTTPException(status_code=502, detail=f"LLM review failed ({type(e).__name__}): {e}")
|
||||||
|
|
||||||
|
# Store result inside word_result as a sub-key
|
||||||
|
word_result["llm_review"] = {
|
||||||
|
"changes": result["changes"],
|
||||||
|
"model_used": result["model_used"],
|
||||||
|
"duration_ms": result["duration_ms"],
|
||||||
|
"entries_corrected": result["entries_corrected"],
|
||||||
|
}
|
||||||
|
await update_session_db(session_id, word_result=word_result, current_step=9)
|
||||||
|
|
||||||
|
if session_id in _cache:
|
||||||
|
_cache[session_id]["word_result"] = word_result
|
||||||
|
|
||||||
|
logger.info(f"LLM review session {session_id}: {len(result['changes'])} changes, "
|
||||||
|
f"{result['duration_ms']}ms, model={result['model_used']}")
|
||||||
|
|
||||||
|
await _append_pipeline_log(session_id, "correction", {
|
||||||
|
"engine": "llm",
|
||||||
|
"model": result["model_used"],
|
||||||
|
"total_entries": len(entries),
|
||||||
|
"corrections_proposed": len(result["changes"]),
|
||||||
|
}, duration_ms=result["duration_ms"])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"session_id": session_id,
|
||||||
|
"changes": result["changes"],
|
||||||
|
"model_used": result["model_used"],
|
||||||
|
"duration_ms": result["duration_ms"],
|
||||||
|
"total_entries": len(entries),
|
||||||
|
"corrections_found": len(result["changes"]),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def _llm_review_stream_generator(
|
||||||
|
session_id: str,
|
||||||
|
entries: List[Dict],
|
||||||
|
word_result: Dict,
|
||||||
|
model: str,
|
||||||
|
request: Request,
|
||||||
|
):
|
||||||
|
"""SSE generator that yields batch-by-batch LLM review progress."""
|
||||||
|
try:
|
||||||
|
async for event in llm_review_entries_streaming(entries, model=model):
|
||||||
|
if await request.is_disconnected():
|
||||||
|
logger.info(f"SSE: client disconnected during LLM review for {session_id}")
|
||||||
|
return
|
||||||
|
|
||||||
|
yield f"data: {json.dumps(event, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
# On complete: persist to DB
|
||||||
|
if event.get("type") == "complete":
|
||||||
|
word_result["llm_review"] = {
|
||||||
|
"changes": event["changes"],
|
||||||
|
"model_used": event["model_used"],
|
||||||
|
"duration_ms": event["duration_ms"],
|
||||||
|
"entries_corrected": event["entries_corrected"],
|
||||||
|
}
|
||||||
|
await update_session_db(session_id, word_result=word_result, current_step=9)
|
||||||
|
if session_id in _cache:
|
||||||
|
_cache[session_id]["word_result"] = word_result
|
||||||
|
|
||||||
|
logger.info(f"LLM review SSE session {session_id}: {event['corrections_found']} changes, "
|
||||||
|
f"{event['duration_ms']}ms, skipped={event['skipped']}, model={event['model_used']}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
logger.error(f"LLM review SSE failed for {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}")
|
||||||
|
error_event = {"type": "error", "detail": f"{type(e).__name__}: {e}"}
|
||||||
|
yield f"data: {json.dumps(error_event)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/llm-review/apply")
|
||||||
|
async def apply_llm_corrections(session_id: str, request: Request):
|
||||||
|
"""Apply selected LLM corrections to vocab entries."""
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
word_result = session.get("word_result")
|
||||||
|
if not word_result:
|
||||||
|
raise HTTPException(status_code=400, detail="No word result found")
|
||||||
|
|
||||||
|
llm_review = word_result.get("llm_review")
|
||||||
|
if not llm_review:
|
||||||
|
raise HTTPException(status_code=400, detail="No LLM review found — run /llm-review first")
|
||||||
|
|
||||||
|
body = await request.json()
|
||||||
|
accepted_indices = set(body.get("accepted_indices", [])) # indices into changes[]
|
||||||
|
|
||||||
|
changes = llm_review.get("changes", [])
|
||||||
|
entries = word_result.get("vocab_entries") or word_result.get("entries") or []
|
||||||
|
|
||||||
|
# Build a lookup: (row_index, field) -> new_value for accepted changes
|
||||||
|
corrections = {}
|
||||||
|
applied_count = 0
|
||||||
|
for idx, change in enumerate(changes):
|
||||||
|
if idx in accepted_indices:
|
||||||
|
key = (change["row_index"], change["field"])
|
||||||
|
corrections[key] = change["new"]
|
||||||
|
applied_count += 1
|
||||||
|
|
||||||
|
# Apply corrections to entries
|
||||||
|
for entry in entries:
|
||||||
|
row_idx = entry.get("row_index", -1)
|
||||||
|
for field_name in ("english", "german", "example"):
|
||||||
|
key = (row_idx, field_name)
|
||||||
|
if key in corrections:
|
||||||
|
entry[field_name] = corrections[key]
|
||||||
|
entry["llm_corrected"] = True
|
||||||
|
|
||||||
|
# Update word_result
|
||||||
|
word_result["vocab_entries"] = entries
|
||||||
|
word_result["entries"] = entries
|
||||||
|
word_result["llm_review"]["applied_count"] = applied_count
|
||||||
|
word_result["llm_review"]["applied_at"] = datetime.utcnow().isoformat()
|
||||||
|
|
||||||
|
await update_session_db(session_id, word_result=word_result)
|
||||||
|
|
||||||
|
if session_id in _cache:
|
||||||
|
_cache[session_id]["word_result"] = word_result
|
||||||
|
|
||||||
|
logger.info(f"Applied {applied_count}/{len(changes)} LLM corrections for session {session_id}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"session_id": session_id,
|
||||||
|
"applied_count": applied_count,
|
||||||
|
"total_changes": len(changes),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Step 9: Reconstruction + Fabric JSON export
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/reconstruction")
|
||||||
|
async def save_reconstruction(session_id: str, request: Request):
|
||||||
|
"""Save edited cell texts from reconstruction step."""
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
word_result = session.get("word_result")
|
||||||
|
if not word_result:
|
||||||
|
raise HTTPException(status_code=400, detail="No word result found")
|
||||||
|
|
||||||
|
body = await request.json()
|
||||||
|
cell_updates = body.get("cells", [])
|
||||||
|
|
||||||
|
if not cell_updates:
|
||||||
|
await update_session_db(session_id, current_step=10)
|
||||||
|
return {"session_id": session_id, "updated": 0}
|
||||||
|
|
||||||
|
# Build update map: cell_id -> new text
|
||||||
|
update_map = {c["cell_id"]: c["text"] for c in cell_updates}
|
||||||
|
|
||||||
|
# Separate sub-session updates (cell_ids prefixed with "box{N}_")
|
||||||
|
sub_updates: Dict[int, Dict[str, str]] = {} # box_index -> {original_cell_id: text}
|
||||||
|
main_updates: Dict[str, str] = {}
|
||||||
|
for cell_id, text in update_map.items():
|
||||||
|
m = re.match(r'^box(\d+)_(.+)$', cell_id)
|
||||||
|
if m:
|
||||||
|
bi = int(m.group(1))
|
||||||
|
original_id = m.group(2)
|
||||||
|
sub_updates.setdefault(bi, {})[original_id] = text
|
||||||
|
else:
|
||||||
|
main_updates[cell_id] = text
|
||||||
|
|
||||||
|
# Update main session cells
|
||||||
|
cells = word_result.get("cells", [])
|
||||||
|
updated_count = 0
|
||||||
|
for cell in cells:
|
||||||
|
if cell["cell_id"] in main_updates:
|
||||||
|
cell["text"] = main_updates[cell["cell_id"]]
|
||||||
|
cell["status"] = "edited"
|
||||||
|
updated_count += 1
|
||||||
|
|
||||||
|
word_result["cells"] = cells
|
||||||
|
|
||||||
|
# Also update vocab_entries if present
|
||||||
|
entries = word_result.get("vocab_entries") or word_result.get("entries") or []
|
||||||
|
if entries:
|
||||||
|
# Map cell_id pattern "R{row}_C{col}" to entry fields
|
||||||
|
for entry in entries:
|
||||||
|
row_idx = entry.get("row_index", -1)
|
||||||
|
# Check each field's cell
|
||||||
|
for col_idx, field_name in enumerate(["english", "german", "example"]):
|
||||||
|
cell_id = f"R{row_idx:02d}_C{col_idx}"
|
||||||
|
# Also try without zero-padding
|
||||||
|
cell_id_alt = f"R{row_idx}_C{col_idx}"
|
||||||
|
new_text = main_updates.get(cell_id) or main_updates.get(cell_id_alt)
|
||||||
|
if new_text is not None:
|
||||||
|
entry[field_name] = new_text
|
||||||
|
|
||||||
|
word_result["vocab_entries"] = entries
|
||||||
|
if "entries" in word_result:
|
||||||
|
word_result["entries"] = entries
|
||||||
|
|
||||||
|
await update_session_db(session_id, word_result=word_result, current_step=10)
|
||||||
|
|
||||||
|
if session_id in _cache:
|
||||||
|
_cache[session_id]["word_result"] = word_result
|
||||||
|
|
||||||
|
# Route sub-session updates
|
||||||
|
sub_updated = 0
|
||||||
|
if sub_updates:
|
||||||
|
subs = await get_sub_sessions(session_id)
|
||||||
|
sub_by_index = {s.get("box_index"): s["id"] for s in subs}
|
||||||
|
for bi, updates in sub_updates.items():
|
||||||
|
sub_id = sub_by_index.get(bi)
|
||||||
|
if not sub_id:
|
||||||
|
continue
|
||||||
|
sub_session = await get_session_db(sub_id)
|
||||||
|
if not sub_session:
|
||||||
|
continue
|
||||||
|
sub_word = sub_session.get("word_result")
|
||||||
|
if not sub_word:
|
||||||
|
continue
|
||||||
|
sub_cells = sub_word.get("cells", [])
|
||||||
|
for cell in sub_cells:
|
||||||
|
if cell["cell_id"] in updates:
|
||||||
|
cell["text"] = updates[cell["cell_id"]]
|
||||||
|
cell["status"] = "edited"
|
||||||
|
sub_updated += 1
|
||||||
|
sub_word["cells"] = sub_cells
|
||||||
|
await update_session_db(sub_id, word_result=sub_word)
|
||||||
|
if sub_id in _cache:
|
||||||
|
_cache[sub_id]["word_result"] = sub_word
|
||||||
|
|
||||||
|
total_updated = updated_count + sub_updated
|
||||||
|
logger.info(f"Reconstruction saved for session {session_id}: "
|
||||||
|
f"{updated_count} main + {sub_updated} sub-session cells updated")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"session_id": session_id,
|
||||||
|
"updated": total_updated,
|
||||||
|
"main_updated": updated_count,
|
||||||
|
"sub_updated": sub_updated,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}/reconstruction/fabric-json")
|
||||||
|
async def get_fabric_json(session_id: str):
|
||||||
|
"""Return cell grid as Fabric.js-compatible JSON for the canvas editor.
|
||||||
|
|
||||||
|
If the session has sub-sessions (box regions), their cells are merged
|
||||||
|
into the result at the correct Y positions.
|
||||||
|
"""
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
word_result = session.get("word_result")
|
||||||
|
if not word_result:
|
||||||
|
raise HTTPException(status_code=400, detail="No word result found")
|
||||||
|
|
||||||
|
cells = list(word_result.get("cells", []))
|
||||||
|
img_w = word_result.get("image_width", 800)
|
||||||
|
img_h = word_result.get("image_height", 600)
|
||||||
|
|
||||||
|
# Merge sub-session cells at box positions
|
||||||
|
subs = await get_sub_sessions(session_id)
|
||||||
|
if subs:
|
||||||
|
column_result = session.get("column_result") or {}
|
||||||
|
zones = column_result.get("zones") or []
|
||||||
|
box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
|
||||||
|
|
||||||
|
for sub in subs:
|
||||||
|
sub_session = await get_session_db(sub["id"])
|
||||||
|
if not sub_session:
|
||||||
|
continue
|
||||||
|
sub_word = sub_session.get("word_result")
|
||||||
|
if not sub_word or not sub_word.get("cells"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
bi = sub.get("box_index", 0)
|
||||||
|
if bi < len(box_zones):
|
||||||
|
box = box_zones[bi]["box"]
|
||||||
|
box_y, box_x = box["y"], box["x"]
|
||||||
|
else:
|
||||||
|
box_y, box_x = 0, 0
|
||||||
|
|
||||||
|
# Offset sub-session cells to absolute page coordinates
|
||||||
|
for cell in sub_word["cells"]:
|
||||||
|
cell_copy = dict(cell)
|
||||||
|
# Prefix cell_id with box index
|
||||||
|
cell_copy["cell_id"] = f"box{bi}_{cell_copy.get('cell_id', '')}"
|
||||||
|
cell_copy["source"] = f"box_{bi}"
|
||||||
|
# Offset bbox_px
|
||||||
|
bbox = cell_copy.get("bbox_px", {})
|
||||||
|
if bbox:
|
||||||
|
bbox = dict(bbox)
|
||||||
|
bbox["x"] = bbox.get("x", 0) + box_x
|
||||||
|
bbox["y"] = bbox.get("y", 0) + box_y
|
||||||
|
cell_copy["bbox_px"] = bbox
|
||||||
|
cells.append(cell_copy)
|
||||||
|
|
||||||
|
from services.layout_reconstruction_service import cells_to_fabric_json
|
||||||
|
fabric_json = cells_to_fabric_json(cells, img_w, img_h)
|
||||||
|
|
||||||
|
return fabric_json
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Vocab entries merged + PDF/DOCX export
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}/vocab-entries/merged")
|
||||||
|
async def get_merged_vocab_entries(session_id: str):
|
||||||
|
"""Return vocab entries from main session + all sub-sessions, sorted by Y position."""
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
word_result = session.get("word_result") or {}
|
||||||
|
entries = list(word_result.get("vocab_entries") or word_result.get("entries") or [])
|
||||||
|
|
||||||
|
# Tag main entries
|
||||||
|
for e in entries:
|
||||||
|
e.setdefault("source", "main")
|
||||||
|
|
||||||
|
# Merge sub-session entries
|
||||||
|
subs = await get_sub_sessions(session_id)
|
||||||
|
if subs:
|
||||||
|
column_result = session.get("column_result") or {}
|
||||||
|
zones = column_result.get("zones") or []
|
||||||
|
box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
|
||||||
|
|
||||||
|
for sub in subs:
|
||||||
|
sub_session = await get_session_db(sub["id"])
|
||||||
|
if not sub_session:
|
||||||
|
continue
|
||||||
|
sub_word = sub_session.get("word_result") or {}
|
||||||
|
sub_entries = sub_word.get("vocab_entries") or sub_word.get("entries") or []
|
||||||
|
|
||||||
|
bi = sub.get("box_index", 0)
|
||||||
|
box_y = 0
|
||||||
|
if bi < len(box_zones):
|
||||||
|
box_y = box_zones[bi]["box"]["y"]
|
||||||
|
|
||||||
|
for e in sub_entries:
|
||||||
|
e_copy = dict(e)
|
||||||
|
e_copy["source"] = f"box_{bi}"
|
||||||
|
e_copy["source_y"] = box_y # for sorting
|
||||||
|
entries.append(e_copy)
|
||||||
|
|
||||||
|
# Sort by approximate Y position
|
||||||
|
def _sort_key(e):
|
||||||
|
if e.get("source", "main") == "main":
|
||||||
|
return e.get("row_index", 0) * 100 # main entries by row index
|
||||||
|
return e.get("source_y", 0) * 100 + e.get("row_index", 0)
|
||||||
|
|
||||||
|
entries.sort(key=_sort_key)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"session_id": session_id,
|
||||||
|
"entries": entries,
|
||||||
|
"total": len(entries),
|
||||||
|
"sources": list(set(e.get("source", "main") for e in entries)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}/reconstruction/export/pdf")
|
||||||
|
async def export_reconstruction_pdf(session_id: str):
|
||||||
|
"""Export the reconstructed cell grid as a PDF table."""
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
word_result = session.get("word_result")
|
||||||
|
if not word_result:
|
||||||
|
raise HTTPException(status_code=400, detail="No word result found")
|
||||||
|
|
||||||
|
cells = word_result.get("cells", [])
|
||||||
|
columns_used = word_result.get("columns_used", [])
|
||||||
|
grid_shape = word_result.get("grid_shape", {})
|
||||||
|
n_rows = grid_shape.get("rows", 0)
|
||||||
|
n_cols = grid_shape.get("cols", 0)
|
||||||
|
|
||||||
|
# Build table data: rows x columns
|
||||||
|
table_data: list[list[str]] = []
|
||||||
|
header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
|
||||||
|
if not header:
|
||||||
|
header = [f"Col {i}" for i in range(n_cols)]
|
||||||
|
table_data.append(header)
|
||||||
|
|
||||||
|
for r in range(n_rows):
|
||||||
|
row_texts = []
|
||||||
|
for ci in range(n_cols):
|
||||||
|
cell_id = f"R{r:02d}_C{ci}"
|
||||||
|
cell = next((c for c in cells if c.get("cell_id") == cell_id), None)
|
||||||
|
row_texts.append(cell.get("text", "") if cell else "")
|
||||||
|
table_data.append(row_texts)
|
||||||
|
|
||||||
|
# Generate PDF with reportlab
|
||||||
|
try:
|
||||||
|
from reportlab.lib.pagesizes import A4
|
||||||
|
from reportlab.lib import colors
|
||||||
|
from reportlab.platypus import SimpleDocTemplate, Table, TableStyle
|
||||||
|
import io as _io
|
||||||
|
|
||||||
|
buf = _io.BytesIO()
|
||||||
|
doc = SimpleDocTemplate(buf, pagesize=A4)
|
||||||
|
if not table_data or not table_data[0]:
|
||||||
|
raise HTTPException(status_code=400, detail="No data to export")
|
||||||
|
|
||||||
|
t = Table(table_data)
|
||||||
|
t.setStyle(TableStyle([
|
||||||
|
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#0d9488')),
|
||||||
|
('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
|
||||||
|
('FONTSIZE', (0, 0), (-1, -1), 9),
|
||||||
|
('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
|
||||||
|
('VALIGN', (0, 0), (-1, -1), 'TOP'),
|
||||||
|
('WORDWRAP', (0, 0), (-1, -1), True),
|
||||||
|
]))
|
||||||
|
doc.build([t])
|
||||||
|
buf.seek(0)
|
||||||
|
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
return StreamingResponse(
|
||||||
|
buf,
|
||||||
|
media_type="application/pdf",
|
||||||
|
headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.pdf"'},
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
raise HTTPException(status_code=501, detail="reportlab not installed")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}/reconstruction/export/docx")
|
||||||
|
async def export_reconstruction_docx(session_id: str):
|
||||||
|
"""Export the reconstructed cell grid as a DOCX table."""
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
word_result = session.get("word_result")
|
||||||
|
if not word_result:
|
||||||
|
raise HTTPException(status_code=400, detail="No word result found")
|
||||||
|
|
||||||
|
cells = word_result.get("cells", [])
|
||||||
|
columns_used = word_result.get("columns_used", [])
|
||||||
|
grid_shape = word_result.get("grid_shape", {})
|
||||||
|
n_rows = grid_shape.get("rows", 0)
|
||||||
|
n_cols = grid_shape.get("cols", 0)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from docx import Document
|
||||||
|
from docx.shared import Pt
|
||||||
|
import io as _io
|
||||||
|
|
||||||
|
doc = Document()
|
||||||
|
doc.add_heading(f'Rekonstruktion – Session {session_id[:8]}', level=1)
|
||||||
|
|
||||||
|
# Build header
|
||||||
|
header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
|
||||||
|
if not header:
|
||||||
|
header = [f"Col {i}" for i in range(n_cols)]
|
||||||
|
|
||||||
|
table = doc.add_table(rows=1 + n_rows, cols=max(n_cols, 1))
|
||||||
|
table.style = 'Table Grid'
|
||||||
|
|
||||||
|
# Header row
|
||||||
|
for ci, h in enumerate(header):
|
||||||
|
table.rows[0].cells[ci].text = h
|
||||||
|
|
||||||
|
# Data rows
|
||||||
|
for r in range(n_rows):
|
||||||
|
for ci in range(n_cols):
|
||||||
|
cell_id = f"R{r:02d}_C{ci}"
|
||||||
|
cell = next((c for c in cells if c.get("cell_id") == cell_id), None)
|
||||||
|
table.rows[r + 1].cells[ci].text = cell.get("text", "") if cell else ""
|
||||||
|
|
||||||
|
buf = _io.BytesIO()
|
||||||
|
doc.save(buf)
|
||||||
|
buf.seek(0)
|
||||||
|
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
return StreamingResponse(
|
||||||
|
buf,
|
||||||
|
media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||||
|
headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.docx"'},
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
raise HTTPException(status_code=501, detail="python-docx not installed")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Step 8: Validation — Original vs. Reconstruction
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/reconstruction/detect-images")
|
||||||
|
async def detect_image_regions(session_id: str):
|
||||||
|
"""Detect illustration/image regions in the original scan using VLM.
|
||||||
|
|
||||||
|
Sends the original image to qwen2.5vl to find non-text, non-table
|
||||||
|
image areas, returning bounding boxes (in %) and descriptions.
|
||||||
|
"""
|
||||||
|
import base64
|
||||||
|
import httpx
|
||||||
|
import re
|
||||||
|
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
# Get original image bytes
|
||||||
|
original_png = await get_session_image(session_id, "original")
|
||||||
|
if not original_png:
|
||||||
|
raise HTTPException(status_code=400, detail="No original image found")
|
||||||
|
|
||||||
|
# Build context from vocab entries for richer descriptions
|
||||||
|
word_result = session.get("word_result") or {}
|
||||||
|
entries = word_result.get("vocab_entries") or word_result.get("entries") or []
|
||||||
|
vocab_context = ""
|
||||||
|
if entries:
|
||||||
|
sample = entries[:10]
|
||||||
|
words = [f"{e.get('english', '')} / {e.get('german', '')}" for e in sample if e.get('english')]
|
||||||
|
if words:
|
||||||
|
vocab_context = f"\nContext: This is a vocabulary page with words like: {', '.join(words)}"
|
||||||
|
|
||||||
|
ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
|
||||||
|
model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
|
||||||
|
|
||||||
|
prompt = (
|
||||||
|
"Analyze this scanned page. Find ALL illustration/image/picture regions "
|
||||||
|
"(NOT text, NOT table cells, NOT blank areas). "
|
||||||
|
"For each image region found, return its bounding box as percentage of page dimensions "
|
||||||
|
"and a short English description of what the image shows. "
|
||||||
|
"Reply with ONLY a JSON array like: "
|
||||||
|
'[{"x": 10, "y": 20, "w": 30, "h": 25, "description": "drawing of a cat"}] '
|
||||||
|
"where x, y, w, h are percentages (0-100) of the page width/height. "
|
||||||
|
"If there are NO images on the page, return an empty array: []"
|
||||||
|
f"{vocab_context}"
|
||||||
|
)
|
||||||
|
|
||||||
|
img_b64 = base64.b64encode(original_png).decode("utf-8")
|
||||||
|
payload = {
|
||||||
|
"model": model,
|
||||||
|
"prompt": prompt,
|
||||||
|
"images": [img_b64],
|
||||||
|
"stream": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||||
|
resp = await client.post(f"{ollama_base}/api/generate", json=payload)
|
||||||
|
resp.raise_for_status()
|
||||||
|
text = resp.json().get("response", "")
|
||||||
|
|
||||||
|
# Parse JSON array from response
|
||||||
|
match = re.search(r'\[.*?\]', text, re.DOTALL)
|
||||||
|
if match:
|
||||||
|
raw_regions = json.loads(match.group(0))
|
||||||
|
else:
|
||||||
|
raw_regions = []
|
||||||
|
|
||||||
|
# Normalize to ImageRegion format
|
||||||
|
regions = []
|
||||||
|
for r in raw_regions:
|
||||||
|
regions.append({
|
||||||
|
"bbox_pct": {
|
||||||
|
"x": max(0, min(100, float(r.get("x", 0)))),
|
||||||
|
"y": max(0, min(100, float(r.get("y", 0)))),
|
||||||
|
"w": max(1, min(100, float(r.get("w", 10)))),
|
||||||
|
"h": max(1, min(100, float(r.get("h", 10)))),
|
||||||
|
},
|
||||||
|
"description": r.get("description", ""),
|
||||||
|
"prompt": r.get("description", ""),
|
||||||
|
"image_b64": None,
|
||||||
|
"style": "educational",
|
||||||
|
})
|
||||||
|
|
||||||
|
# Enrich prompts with nearby vocab context
|
||||||
|
if entries:
|
||||||
|
for region in regions:
|
||||||
|
ry = region["bbox_pct"]["y"]
|
||||||
|
rh = region["bbox_pct"]["h"]
|
||||||
|
nearby = [
|
||||||
|
e for e in entries
|
||||||
|
if e.get("bbox") and abs(e["bbox"].get("y", 0) - ry) < rh + 10
|
||||||
|
]
|
||||||
|
if nearby:
|
||||||
|
en_words = [e.get("english", "") for e in nearby if e.get("english")]
|
||||||
|
de_words = [e.get("german", "") for e in nearby if e.get("german")]
|
||||||
|
if en_words or de_words:
|
||||||
|
context = f" (vocabulary context: {', '.join(en_words[:5])}"
|
||||||
|
if de_words:
|
||||||
|
context += f" / {', '.join(de_words[:5])}"
|
||||||
|
context += ")"
|
||||||
|
region["prompt"] = region["description"] + context
|
||||||
|
|
||||||
|
# Save to ground_truth JSONB
|
||||||
|
ground_truth = session.get("ground_truth") or {}
|
||||||
|
validation = ground_truth.get("validation") or {}
|
||||||
|
validation["image_regions"] = regions
|
||||||
|
validation["detected_at"] = datetime.utcnow().isoformat()
|
||||||
|
ground_truth["validation"] = validation
|
||||||
|
await update_session_db(session_id, ground_truth=ground_truth)
|
||||||
|
|
||||||
|
if session_id in _cache:
|
||||||
|
_cache[session_id]["ground_truth"] = ground_truth
|
||||||
|
|
||||||
|
logger.info(f"Detected {len(regions)} image regions for session {session_id}")
|
||||||
|
|
||||||
|
return {"regions": regions, "count": len(regions)}
|
||||||
|
|
||||||
|
except httpx.ConnectError:
|
||||||
|
logger.warning(f"VLM not available at {ollama_base} for image detection")
|
||||||
|
return {"regions": [], "count": 0, "error": "VLM not available"}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Image detection failed for {session_id}: {e}")
|
||||||
|
return {"regions": [], "count": 0, "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/reconstruction/generate-image")
|
||||||
|
async def generate_image_for_region(session_id: str, req: GenerateImageRequest):
|
||||||
|
"""Generate a replacement image for a detected region using mflux.
|
||||||
|
|
||||||
|
Sends the prompt (with style suffix) to the mflux-service running
|
||||||
|
natively on the Mac Mini (Metal GPU required).
|
||||||
|
"""
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
ground_truth = session.get("ground_truth") or {}
|
||||||
|
validation = ground_truth.get("validation") or {}
|
||||||
|
regions = validation.get("image_regions") or []
|
||||||
|
|
||||||
|
if req.region_index < 0 or req.region_index >= len(regions):
|
||||||
|
raise HTTPException(status_code=400, detail=f"Invalid region_index {req.region_index}, have {len(regions)} regions")
|
||||||
|
|
||||||
|
mflux_url = os.getenv("MFLUX_URL", "http://host.docker.internal:8095")
|
||||||
|
style_suffix = STYLE_SUFFIXES.get(req.style, STYLE_SUFFIXES["educational"])
|
||||||
|
full_prompt = f"{req.prompt}, {style_suffix}"
|
||||||
|
|
||||||
|
# Determine image size from region aspect ratio (snap to multiples of 64)
|
||||||
|
region = regions[req.region_index]
|
||||||
|
bbox = region["bbox_pct"]
|
||||||
|
aspect = bbox["w"] / max(bbox["h"], 1)
|
||||||
|
if aspect > 1.3:
|
||||||
|
width, height = 768, 512
|
||||||
|
elif aspect < 0.7:
|
||||||
|
width, height = 512, 768
|
||||||
|
else:
|
||||||
|
width, height = 512, 512
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=300.0) as client:
|
||||||
|
resp = await client.post(f"{mflux_url}/generate", json={
|
||||||
|
"prompt": full_prompt,
|
||||||
|
"width": width,
|
||||||
|
"height": height,
|
||||||
|
"steps": 4,
|
||||||
|
})
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
image_b64 = data.get("image_b64")
|
||||||
|
|
||||||
|
if not image_b64:
|
||||||
|
return {"image_b64": None, "success": False, "error": "No image returned"}
|
||||||
|
|
||||||
|
# Save to ground_truth
|
||||||
|
regions[req.region_index]["image_b64"] = image_b64
|
||||||
|
regions[req.region_index]["prompt"] = req.prompt
|
||||||
|
regions[req.region_index]["style"] = req.style
|
||||||
|
validation["image_regions"] = regions
|
||||||
|
ground_truth["validation"] = validation
|
||||||
|
await update_session_db(session_id, ground_truth=ground_truth)
|
||||||
|
|
||||||
|
if session_id in _cache:
|
||||||
|
_cache[session_id]["ground_truth"] = ground_truth
|
||||||
|
|
||||||
|
logger.info(f"Generated image for session {session_id} region {req.region_index}")
|
||||||
|
return {"image_b64": image_b64, "success": True}
|
||||||
|
|
||||||
|
except httpx.ConnectError:
|
||||||
|
logger.warning(f"mflux-service not available at {mflux_url}")
|
||||||
|
return {"image_b64": None, "success": False, "error": f"mflux-service not available at {mflux_url}"}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Image generation failed for {session_id}: {e}")
|
||||||
|
return {"image_b64": None, "success": False, "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/reconstruction/validate")
|
||||||
|
async def save_validation(session_id: str, req: ValidationRequest):
|
||||||
|
"""Save final validation results for step 8.
|
||||||
|
|
||||||
|
Stores notes, score, and preserves any detected/generated image regions.
|
||||||
|
Sets current_step = 10 to mark pipeline as complete.
|
||||||
|
"""
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
ground_truth = session.get("ground_truth") or {}
|
||||||
|
validation = ground_truth.get("validation") or {}
|
||||||
|
validation["validated_at"] = datetime.utcnow().isoformat()
|
||||||
|
validation["notes"] = req.notes
|
||||||
|
validation["score"] = req.score
|
||||||
|
ground_truth["validation"] = validation
|
||||||
|
|
||||||
|
await update_session_db(session_id, ground_truth=ground_truth, current_step=11)
|
||||||
|
|
||||||
|
if session_id in _cache:
|
||||||
|
_cache[session_id]["ground_truth"] = ground_truth
|
||||||
|
|
||||||
|
logger.info(f"Validation saved for session {session_id}: score={req.score}")
|
||||||
|
|
||||||
|
return {"session_id": session_id, "validation": validation}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}/reconstruction/validation")
|
||||||
|
async def get_validation(session_id: str):
|
||||||
|
"""Retrieve saved validation data for step 8."""
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
ground_truth = session.get("ground_truth") or {}
|
||||||
|
validation = ground_truth.get("validation")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"session_id": session_id,
|
||||||
|
"validation": validation,
|
||||||
|
"word_result": session.get("word_result"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Remove handwriting
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/remove-handwriting")
|
||||||
|
async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingRequest):
|
||||||
|
"""
|
||||||
|
Remove handwriting from a session image using inpainting.
|
||||||
|
|
||||||
|
Steps:
|
||||||
|
1. Load source image (auto -> deskewed if available, else original)
|
||||||
|
2. Detect handwriting mask (filtered by target_ink)
|
||||||
|
3. Dilate mask to cover stroke edges
|
||||||
|
4. Inpaint the image
|
||||||
|
5. Store result as clean_png in the session
|
||||||
|
|
||||||
|
Returns metadata including the URL to fetch the clean image.
|
||||||
|
"""
|
||||||
|
import time as _time
|
||||||
|
t0 = _time.monotonic()
|
||||||
|
|
||||||
|
from services.handwriting_detection import detect_handwriting
|
||||||
|
from services.inpainting_service import inpaint_image, dilate_mask as _dilate_mask, InpaintingMethod, image_to_png
|
||||||
|
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
# 1. Determine source image
|
||||||
|
source = req.use_source
|
||||||
|
if source == "auto":
|
||||||
|
deskewed = await get_session_image(session_id, "deskewed")
|
||||||
|
source = "deskewed" if deskewed else "original"
|
||||||
|
|
||||||
|
image_bytes = await get_session_image(session_id, source)
|
||||||
|
if not image_bytes:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Source image '{source}' not available")
|
||||||
|
|
||||||
|
# 2. Detect handwriting mask
|
||||||
|
detection = detect_handwriting(image_bytes, target_ink=req.target_ink)
|
||||||
|
|
||||||
|
# 3. Convert mask to PNG bytes and dilate
|
||||||
|
import io
|
||||||
|
from PIL import Image as _PILImage
|
||||||
|
mask_img = _PILImage.fromarray(detection.mask)
|
||||||
|
mask_buf = io.BytesIO()
|
||||||
|
mask_img.save(mask_buf, format="PNG")
|
||||||
|
mask_bytes = mask_buf.getvalue()
|
||||||
|
|
||||||
|
if req.dilation > 0:
|
||||||
|
mask_bytes = _dilate_mask(mask_bytes, iterations=req.dilation)
|
||||||
|
|
||||||
|
# 4. Inpaint
|
||||||
|
method_map = {
|
||||||
|
"telea": InpaintingMethod.OPENCV_TELEA,
|
||||||
|
"ns": InpaintingMethod.OPENCV_NS,
|
||||||
|
"auto": InpaintingMethod.AUTO,
|
||||||
|
}
|
||||||
|
inpaint_method = method_map.get(req.method, InpaintingMethod.AUTO)
|
||||||
|
|
||||||
|
result = inpaint_image(image_bytes, mask_bytes, method=inpaint_method)
|
||||||
|
if not result.success:
|
||||||
|
raise HTTPException(status_code=500, detail="Inpainting failed")
|
||||||
|
|
||||||
|
elapsed_ms = int((_time.monotonic() - t0) * 1000)
|
||||||
|
|
||||||
|
meta = {
|
||||||
|
"method_used": result.method_used.value if hasattr(result.method_used, "value") else str(result.method_used),
|
||||||
|
"handwriting_ratio": round(detection.handwriting_ratio, 4),
|
||||||
|
"detection_confidence": round(detection.confidence, 4),
|
||||||
|
"target_ink": req.target_ink,
|
||||||
|
"dilation": req.dilation,
|
||||||
|
"source_image": source,
|
||||||
|
"processing_time_ms": elapsed_ms,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 5. Persist clean image (convert BGR ndarray -> PNG bytes)
|
||||||
|
clean_png_bytes = image_to_png(result.image)
|
||||||
|
await update_session_db(session_id, clean_png=clean_png_bytes, handwriting_removal_meta=meta)
|
||||||
|
|
||||||
|
return {
|
||||||
|
**meta,
|
||||||
|
"image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean",
|
||||||
|
"session_id": session_id,
|
||||||
|
}
|
||||||
348
klausur-service/backend/ocr_pipeline_rows.py
Normal file
348
klausur-service/backend/ocr_pipeline_rows.py
Normal file
@@ -0,0 +1,348 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline - Row Detection Endpoints.
|
||||||
|
|
||||||
|
Extracted from ocr_pipeline_api.py.
|
||||||
|
Handles row detection (auto + manual) and row ground truth.
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
|
||||||
|
from cv_vocab_pipeline import (
|
||||||
|
create_ocr_image,
|
||||||
|
detect_column_geometry,
|
||||||
|
detect_row_geometry,
|
||||||
|
)
|
||||||
|
from ocr_pipeline_common import (
|
||||||
|
_cache,
|
||||||
|
_load_session_to_cache,
|
||||||
|
_get_cached,
|
||||||
|
_append_pipeline_log,
|
||||||
|
ManualRowsRequest,
|
||||||
|
RowGroundTruthRequest,
|
||||||
|
)
|
||||||
|
from ocr_pipeline_session_store import (
|
||||||
|
get_session_db,
|
||||||
|
update_session_db,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helper: Box-exclusion overlay (used by rows overlay and columns overlay)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _draw_box_exclusion_overlay(
|
||||||
|
img: np.ndarray,
|
||||||
|
zones: List[Dict],
|
||||||
|
*,
|
||||||
|
label: str = "BOX — separat verarbeitet",
|
||||||
|
) -> None:
|
||||||
|
"""Draw red semi-transparent rectangles over box zones (in-place).
|
||||||
|
|
||||||
|
Reusable for columns, rows, and words overlays.
|
||||||
|
"""
|
||||||
|
for zone in zones:
|
||||||
|
if zone.get("zone_type") != "box" or not zone.get("box"):
|
||||||
|
continue
|
||||||
|
box = zone["box"]
|
||||||
|
bx, by = box["x"], box["y"]
|
||||||
|
bw, bh = box["width"], box["height"]
|
||||||
|
|
||||||
|
# Red semi-transparent fill (~25 %)
|
||||||
|
box_overlay = img.copy()
|
||||||
|
cv2.rectangle(box_overlay, (bx, by), (bx + bw, by + bh), (0, 0, 200), -1)
|
||||||
|
cv2.addWeighted(box_overlay, 0.25, img, 0.75, 0, img)
|
||||||
|
|
||||||
|
# Border
|
||||||
|
cv2.rectangle(img, (bx, by), (bx + bw, by + bh), (0, 0, 200), 2)
|
||||||
|
|
||||||
|
# Label
|
||||||
|
cv2.putText(img, label, (bx + 10, by + bh - 10),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Row Detection Endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/rows")
|
||||||
|
async def detect_rows(session_id: str):
|
||||||
|
"""Run row detection on the cropped (or dewarped) image using horizontal gap analysis."""
|
||||||
|
if session_id not in _cache:
|
||||||
|
await _load_session_to_cache(session_id)
|
||||||
|
cached = _get_cached(session_id)
|
||||||
|
|
||||||
|
dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
|
||||||
|
if dewarped_bgr is None:
|
||||||
|
raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before row detection")
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
|
||||||
|
# Try to reuse cached word_dicts and inv from column detection
|
||||||
|
word_dicts = cached.get("_word_dicts")
|
||||||
|
inv = cached.get("_inv")
|
||||||
|
content_bounds = cached.get("_content_bounds")
|
||||||
|
|
||||||
|
if word_dicts is None or inv is None or content_bounds is None:
|
||||||
|
# Not cached — run column geometry to get intermediates
|
||||||
|
ocr_img = create_ocr_image(dewarped_bgr)
|
||||||
|
geo_result = detect_column_geometry(ocr_img, dewarped_bgr)
|
||||||
|
if geo_result is None:
|
||||||
|
raise HTTPException(status_code=400, detail="Column geometry detection failed — cannot detect rows")
|
||||||
|
_geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
|
||||||
|
cached["_word_dicts"] = word_dicts
|
||||||
|
cached["_inv"] = inv
|
||||||
|
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
|
||||||
|
else:
|
||||||
|
left_x, right_x, top_y, bottom_y = content_bounds
|
||||||
|
|
||||||
|
# Read zones from column_result to exclude box regions
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
column_result = (session or {}).get("column_result") or {}
|
||||||
|
is_sub_session = bool((session or {}).get("parent_session_id"))
|
||||||
|
|
||||||
|
# Sub-sessions (box crops): use word-grouping instead of gap-based
|
||||||
|
# row detection. Box images are small with complex internal layouts
|
||||||
|
# (headings, sub-columns) where the horizontal projection approach
|
||||||
|
# merges rows. Word-grouping directly clusters words by Y proximity,
|
||||||
|
# which is more robust for these cases.
|
||||||
|
if is_sub_session and word_dicts:
|
||||||
|
from cv_layout import _build_rows_from_word_grouping
|
||||||
|
rows = _build_rows_from_word_grouping(
|
||||||
|
word_dicts, left_x, right_x, top_y, bottom_y,
|
||||||
|
right_x - left_x, bottom_y - top_y,
|
||||||
|
)
|
||||||
|
logger.info(f"OCR Pipeline: sub-session {session_id}: word-grouping found {len(rows)} rows")
|
||||||
|
else:
|
||||||
|
zones = column_result.get("zones") or [] # zones can be None for sub-sessions
|
||||||
|
|
||||||
|
# Collect box y-ranges for filtering.
|
||||||
|
# Use border_thickness to shrink the exclusion zone: the border pixels
|
||||||
|
# belong visually to the box frame, but text rows above/below the box
|
||||||
|
# may overlap with the border area and must not be clipped.
|
||||||
|
box_ranges = [] # [(y_start, y_end)]
|
||||||
|
box_ranges_inner = [] # [(y_start + border, y_end - border)] for row filtering
|
||||||
|
for zone in zones:
|
||||||
|
if zone.get("zone_type") == "box" and zone.get("box"):
|
||||||
|
box = zone["box"]
|
||||||
|
bt = max(box.get("border_thickness", 0), 5) # minimum 5px margin
|
||||||
|
box_ranges.append((box["y"], box["y"] + box["height"]))
|
||||||
|
# Inner range: shrink by border thickness so boundary rows aren't excluded
|
||||||
|
box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt))
|
||||||
|
|
||||||
|
if box_ranges and inv is not None:
|
||||||
|
# Combined-image approach: strip box regions from inv image,
|
||||||
|
# run row detection on the combined image, then remap y-coords back.
|
||||||
|
content_strips = [] # [(y_start, y_end)] in absolute coords
|
||||||
|
# Build content strips by subtracting box inner ranges from [top_y, bottom_y].
|
||||||
|
# Using inner ranges means the border area is included in the content
|
||||||
|
# strips, so the last row above a box isn't clipped by the border.
|
||||||
|
sorted_boxes = sorted(box_ranges_inner, key=lambda r: r[0])
|
||||||
|
strip_start = top_y
|
||||||
|
for by_start, by_end in sorted_boxes:
|
||||||
|
if by_start > strip_start:
|
||||||
|
content_strips.append((strip_start, by_start))
|
||||||
|
strip_start = max(strip_start, by_end)
|
||||||
|
if strip_start < bottom_y:
|
||||||
|
content_strips.append((strip_start, bottom_y))
|
||||||
|
|
||||||
|
# Filter to strips with meaningful height
|
||||||
|
content_strips = [(ys, ye) for ys, ye in content_strips if ye - ys >= 20]
|
||||||
|
|
||||||
|
if content_strips:
|
||||||
|
# Stack content strips vertically
|
||||||
|
inv_strips = [inv[ys:ye, :] for ys, ye in content_strips]
|
||||||
|
combined_inv = np.vstack(inv_strips)
|
||||||
|
|
||||||
|
# Filter word_dicts to only include words from content strips
|
||||||
|
combined_words = []
|
||||||
|
cum_y = 0
|
||||||
|
strip_offsets = [] # (combined_y_start, strip_height, abs_y_start)
|
||||||
|
for ys, ye in content_strips:
|
||||||
|
h = ye - ys
|
||||||
|
strip_offsets.append((cum_y, h, ys))
|
||||||
|
for w in word_dicts:
|
||||||
|
w_abs_y = w['top'] + top_y # word y is relative to content top
|
||||||
|
w_center = w_abs_y + w['height'] / 2
|
||||||
|
if ys <= w_center < ye:
|
||||||
|
# Remap to combined coordinates
|
||||||
|
w_copy = dict(w)
|
||||||
|
w_copy['top'] = cum_y + (w_abs_y - ys)
|
||||||
|
combined_words.append(w_copy)
|
||||||
|
cum_y += h
|
||||||
|
|
||||||
|
# Run row detection on combined image
|
||||||
|
combined_h = combined_inv.shape[0]
|
||||||
|
rows = detect_row_geometry(
|
||||||
|
combined_inv, combined_words, left_x, right_x, 0, combined_h,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remap y-coordinates back to absolute page coords
|
||||||
|
def _combined_y_to_abs(cy: int) -> int:
|
||||||
|
for c_start, s_h, abs_start in strip_offsets:
|
||||||
|
if cy < c_start + s_h:
|
||||||
|
return abs_start + (cy - c_start)
|
||||||
|
last_c, last_h, last_abs = strip_offsets[-1]
|
||||||
|
return last_abs + last_h
|
||||||
|
|
||||||
|
for r in rows:
|
||||||
|
abs_y = _combined_y_to_abs(r.y)
|
||||||
|
abs_y_end = _combined_y_to_abs(r.y + r.height)
|
||||||
|
r.y = abs_y
|
||||||
|
r.height = abs_y_end - abs_y
|
||||||
|
else:
|
||||||
|
rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
|
||||||
|
else:
|
||||||
|
# No boxes — standard row detection
|
||||||
|
rows = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
|
||||||
|
|
||||||
|
duration = time.time() - t0
|
||||||
|
|
||||||
|
# Assign zone_index based on which content zone each row falls in
|
||||||
|
# Build content zone list with indices
|
||||||
|
zones = column_result.get("zones") or []
|
||||||
|
content_zones = [(i, z) for i, z in enumerate(zones) if z.get("zone_type") == "content"] if zones else []
|
||||||
|
|
||||||
|
# Build serializable result (exclude words to keep payload small)
|
||||||
|
rows_data = []
|
||||||
|
for r in rows:
|
||||||
|
# Determine zone_index
|
||||||
|
zone_idx = 0
|
||||||
|
row_center_y = r.y + r.height / 2
|
||||||
|
for zi, zone in content_zones:
|
||||||
|
zy = zone["y"]
|
||||||
|
zh = zone["height"]
|
||||||
|
if zy <= row_center_y < zy + zh:
|
||||||
|
zone_idx = zi
|
||||||
|
break
|
||||||
|
|
||||||
|
rd = {
|
||||||
|
"index": r.index,
|
||||||
|
"x": r.x,
|
||||||
|
"y": r.y,
|
||||||
|
"width": r.width,
|
||||||
|
"height": r.height,
|
||||||
|
"word_count": r.word_count,
|
||||||
|
"row_type": r.row_type,
|
||||||
|
"gap_before": r.gap_before,
|
||||||
|
"zone_index": zone_idx,
|
||||||
|
}
|
||||||
|
rows_data.append(rd)
|
||||||
|
|
||||||
|
type_counts = {}
|
||||||
|
for r in rows:
|
||||||
|
type_counts[r.row_type] = type_counts.get(r.row_type, 0) + 1
|
||||||
|
|
||||||
|
row_result = {
|
||||||
|
"rows": rows_data,
|
||||||
|
"summary": type_counts,
|
||||||
|
"total_rows": len(rows),
|
||||||
|
"duration_seconds": round(duration, 2),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Persist to DB — also invalidate word_result since rows changed
|
||||||
|
await update_session_db(
|
||||||
|
session_id,
|
||||||
|
row_result=row_result,
|
||||||
|
word_result=None,
|
||||||
|
current_step=7,
|
||||||
|
)
|
||||||
|
|
||||||
|
cached["row_result"] = row_result
|
||||||
|
cached.pop("word_result", None)
|
||||||
|
|
||||||
|
logger.info(f"OCR Pipeline: rows session {session_id}: "
|
||||||
|
f"{len(rows)} rows detected ({duration:.2f}s): {type_counts}")
|
||||||
|
|
||||||
|
content_rows = sum(1 for r in rows if r.row_type == "content")
|
||||||
|
avg_height = round(sum(r.height for r in rows) / len(rows)) if rows else 0
|
||||||
|
await _append_pipeline_log(session_id, "rows", {
|
||||||
|
"total_rows": len(rows),
|
||||||
|
"content_rows": content_rows,
|
||||||
|
"artifact_rows_removed": type_counts.get("header", 0) + type_counts.get("footer", 0),
|
||||||
|
"avg_row_height_px": avg_height,
|
||||||
|
}, duration_ms=int(duration * 1000))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"session_id": session_id,
|
||||||
|
**row_result,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/rows/manual")
|
||||||
|
async def set_manual_rows(session_id: str, req: ManualRowsRequest):
|
||||||
|
"""Override detected rows with manual definitions."""
|
||||||
|
row_result = {
|
||||||
|
"rows": req.rows,
|
||||||
|
"total_rows": len(req.rows),
|
||||||
|
"duration_seconds": 0,
|
||||||
|
"method": "manual",
|
||||||
|
}
|
||||||
|
|
||||||
|
await update_session_db(session_id, row_result=row_result, word_result=None)
|
||||||
|
|
||||||
|
if session_id in _cache:
|
||||||
|
_cache[session_id]["row_result"] = row_result
|
||||||
|
_cache[session_id].pop("word_result", None)
|
||||||
|
|
||||||
|
logger.info(f"OCR Pipeline: manual rows session {session_id}: "
|
||||||
|
f"{len(req.rows)} rows set")
|
||||||
|
|
||||||
|
return {"session_id": session_id, **row_result}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/ground-truth/rows")
|
||||||
|
async def save_row_ground_truth(session_id: str, req: RowGroundTruthRequest):
|
||||||
|
"""Save ground truth feedback for the row detection step."""
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
ground_truth = session.get("ground_truth") or {}
|
||||||
|
gt = {
|
||||||
|
"is_correct": req.is_correct,
|
||||||
|
"corrected_rows": req.corrected_rows,
|
||||||
|
"notes": req.notes,
|
||||||
|
"saved_at": datetime.utcnow().isoformat(),
|
||||||
|
"row_result": session.get("row_result"),
|
||||||
|
}
|
||||||
|
ground_truth["rows"] = gt
|
||||||
|
|
||||||
|
await update_session_db(session_id, ground_truth=ground_truth)
|
||||||
|
|
||||||
|
if session_id in _cache:
|
||||||
|
_cache[session_id]["ground_truth"] = ground_truth
|
||||||
|
|
||||||
|
return {"session_id": session_id, "ground_truth": gt}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}/ground-truth/rows")
|
||||||
|
async def get_row_ground_truth(session_id: str):
|
||||||
|
"""Retrieve saved ground truth for row detection."""
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
ground_truth = session.get("ground_truth") or {}
|
||||||
|
rows_gt = ground_truth.get("rows")
|
||||||
|
if not rows_gt:
|
||||||
|
raise HTTPException(status_code=404, detail="No row ground truth saved")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"session_id": session_id,
|
||||||
|
"rows_gt": rows_gt,
|
||||||
|
"rows_auto": session.get("row_result"),
|
||||||
|
}
|
||||||
483
klausur-service/backend/ocr_pipeline_sessions.py
Normal file
483
klausur-service/backend/ocr_pipeline_sessions.py
Normal file
@@ -0,0 +1,483 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Sessions API - Session management and image serving endpoints.
|
||||||
|
|
||||||
|
Extracted from ocr_pipeline_api.py for modularity.
|
||||||
|
Handles: CRUD for sessions, thumbnails, pipeline logs, categories,
|
||||||
|
image serving (with overlay dispatch), and document type detection.
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
|
||||||
|
from fastapi.responses import Response
|
||||||
|
|
||||||
|
from cv_vocab_pipeline import (
|
||||||
|
create_ocr_image,
|
||||||
|
detect_document_type,
|
||||||
|
render_image_high_res,
|
||||||
|
render_pdf_high_res,
|
||||||
|
)
|
||||||
|
from ocr_pipeline_common import (
|
||||||
|
VALID_DOCUMENT_CATEGORIES,
|
||||||
|
UpdateSessionRequest,
|
||||||
|
_append_pipeline_log,
|
||||||
|
_cache,
|
||||||
|
_get_base_image_png,
|
||||||
|
_get_cached,
|
||||||
|
_load_session_to_cache,
|
||||||
|
)
|
||||||
|
from ocr_pipeline_overlays import render_overlay
|
||||||
|
from ocr_pipeline_session_store import (
|
||||||
|
create_session_db,
|
||||||
|
delete_all_sessions_db,
|
||||||
|
delete_session_db,
|
||||||
|
get_session_db,
|
||||||
|
get_session_image,
|
||||||
|
get_sub_sessions,
|
||||||
|
list_sessions_db,
|
||||||
|
update_session_db,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Session Management Endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.get("/sessions")
|
||||||
|
async def list_sessions(include_sub_sessions: bool = False):
|
||||||
|
"""List OCR pipeline sessions.
|
||||||
|
|
||||||
|
By default, sub-sessions (box regions) are hidden.
|
||||||
|
Pass ?include_sub_sessions=true to show them.
|
||||||
|
"""
|
||||||
|
sessions = await list_sessions_db(include_sub_sessions=include_sub_sessions)
|
||||||
|
return {"sessions": sessions}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions")
|
||||||
|
async def create_session(
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
name: Optional[str] = Form(None),
|
||||||
|
):
|
||||||
|
"""Upload a PDF or image file and create a pipeline session."""
|
||||||
|
file_data = await file.read()
|
||||||
|
filename = file.filename or "upload"
|
||||||
|
content_type = file.content_type or ""
|
||||||
|
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
is_pdf = content_type == "application/pdf" or filename.lower().endswith(".pdf")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if is_pdf:
|
||||||
|
img_bgr = render_pdf_high_res(file_data, page_number=0, zoom=3.0)
|
||||||
|
else:
|
||||||
|
img_bgr = render_image_high_res(file_data)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=400, detail=f"Could not process file: {e}")
|
||||||
|
|
||||||
|
# Encode original as PNG bytes
|
||||||
|
success, png_buf = cv2.imencode(".png", img_bgr)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to encode image")
|
||||||
|
|
||||||
|
original_png = png_buf.tobytes()
|
||||||
|
session_name = name or filename
|
||||||
|
|
||||||
|
# Persist to DB
|
||||||
|
await create_session_db(
|
||||||
|
session_id=session_id,
|
||||||
|
name=session_name,
|
||||||
|
filename=filename,
|
||||||
|
original_png=original_png,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cache BGR array for immediate processing
|
||||||
|
_cache[session_id] = {
|
||||||
|
"id": session_id,
|
||||||
|
"filename": filename,
|
||||||
|
"name": session_name,
|
||||||
|
"original_bgr": img_bgr,
|
||||||
|
"oriented_bgr": None,
|
||||||
|
"cropped_bgr": None,
|
||||||
|
"deskewed_bgr": None,
|
||||||
|
"dewarped_bgr": None,
|
||||||
|
"orientation_result": None,
|
||||||
|
"crop_result": None,
|
||||||
|
"deskew_result": None,
|
||||||
|
"dewarp_result": None,
|
||||||
|
"ground_truth": {},
|
||||||
|
"current_step": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"OCR Pipeline: created session {session_id} from {filename} "
|
||||||
|
f"({img_bgr.shape[1]}x{img_bgr.shape[0]})")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"session_id": session_id,
|
||||||
|
"filename": filename,
|
||||||
|
"name": session_name,
|
||||||
|
"image_width": img_bgr.shape[1],
|
||||||
|
"image_height": img_bgr.shape[0],
|
||||||
|
"original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}")
|
||||||
|
async def get_session_info(session_id: str):
|
||||||
|
"""Get session info including deskew/dewarp/column results for step navigation."""
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
# Get image dimensions from original PNG
|
||||||
|
original_png = await get_session_image(session_id, "original")
|
||||||
|
if original_png:
|
||||||
|
arr = np.frombuffer(original_png, dtype=np.uint8)
|
||||||
|
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||||
|
img_w, img_h = img.shape[1], img.shape[0] if img is not None else (0, 0)
|
||||||
|
else:
|
||||||
|
img_w, img_h = 0, 0
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"session_id": session["id"],
|
||||||
|
"filename": session.get("filename", ""),
|
||||||
|
"name": session.get("name", ""),
|
||||||
|
"image_width": img_w,
|
||||||
|
"image_height": img_h,
|
||||||
|
"original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
|
||||||
|
"current_step": session.get("current_step", 1),
|
||||||
|
"document_category": session.get("document_category"),
|
||||||
|
"doc_type": session.get("doc_type"),
|
||||||
|
}
|
||||||
|
|
||||||
|
if session.get("orientation_result"):
|
||||||
|
result["orientation_result"] = session["orientation_result"]
|
||||||
|
if session.get("crop_result"):
|
||||||
|
result["crop_result"] = session["crop_result"]
|
||||||
|
if session.get("deskew_result"):
|
||||||
|
result["deskew_result"] = session["deskew_result"]
|
||||||
|
if session.get("dewarp_result"):
|
||||||
|
result["dewarp_result"] = session["dewarp_result"]
|
||||||
|
if session.get("column_result"):
|
||||||
|
result["column_result"] = session["column_result"]
|
||||||
|
if session.get("row_result"):
|
||||||
|
result["row_result"] = session["row_result"]
|
||||||
|
if session.get("word_result"):
|
||||||
|
result["word_result"] = session["word_result"]
|
||||||
|
if session.get("doc_type_result"):
|
||||||
|
result["doc_type_result"] = session["doc_type_result"]
|
||||||
|
|
||||||
|
# Sub-session info
|
||||||
|
if session.get("parent_session_id"):
|
||||||
|
result["parent_session_id"] = session["parent_session_id"]
|
||||||
|
result["box_index"] = session.get("box_index")
|
||||||
|
else:
|
||||||
|
# Check for sub-sessions
|
||||||
|
subs = await get_sub_sessions(session_id)
|
||||||
|
if subs:
|
||||||
|
result["sub_sessions"] = [
|
||||||
|
{"id": s["id"], "name": s.get("name"), "box_index": s.get("box_index")}
|
||||||
|
for s in subs
|
||||||
|
]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/sessions/{session_id}")
|
||||||
|
async def update_session(session_id: str, req: UpdateSessionRequest):
|
||||||
|
"""Update session name and/or document category."""
|
||||||
|
kwargs: Dict[str, Any] = {}
|
||||||
|
if req.name is not None:
|
||||||
|
kwargs["name"] = req.name
|
||||||
|
if req.document_category is not None:
|
||||||
|
if req.document_category not in VALID_DOCUMENT_CATEGORIES:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Invalid category '{req.document_category}'. Valid: {sorted(VALID_DOCUMENT_CATEGORIES)}",
|
||||||
|
)
|
||||||
|
kwargs["document_category"] = req.document_category
|
||||||
|
if not kwargs:
|
||||||
|
raise HTTPException(status_code=400, detail="Nothing to update")
|
||||||
|
updated = await update_session_db(session_id, **kwargs)
|
||||||
|
if not updated:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
return {"session_id": session_id, **kwargs}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/sessions/{session_id}")
|
||||||
|
async def delete_session(session_id: str):
|
||||||
|
"""Delete a session."""
|
||||||
|
_cache.pop(session_id, None)
|
||||||
|
deleted = await delete_session_db(session_id)
|
||||||
|
if not deleted:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
return {"session_id": session_id, "deleted": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/sessions")
|
||||||
|
async def delete_all_sessions():
|
||||||
|
"""Delete ALL sessions (cleanup)."""
|
||||||
|
_cache.clear()
|
||||||
|
count = await delete_all_sessions_db()
|
||||||
|
return {"deleted_count": count}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/create-box-sessions")
|
||||||
|
async def create_box_sessions(session_id: str):
|
||||||
|
"""Create sub-sessions for each detected box region.
|
||||||
|
|
||||||
|
Crops box regions from the cropped/dewarped image and creates
|
||||||
|
independent sub-sessions that can be processed through the pipeline.
|
||||||
|
"""
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
column_result = session.get("column_result")
|
||||||
|
if not column_result:
|
||||||
|
raise HTTPException(status_code=400, detail="Column detection must be completed first")
|
||||||
|
|
||||||
|
zones = column_result.get("zones") or []
|
||||||
|
box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
|
||||||
|
if not box_zones:
|
||||||
|
return {"session_id": session_id, "sub_sessions": [], "message": "No boxes detected"}
|
||||||
|
|
||||||
|
# Check for existing sub-sessions
|
||||||
|
existing = await get_sub_sessions(session_id)
|
||||||
|
if existing:
|
||||||
|
return {
|
||||||
|
"session_id": session_id,
|
||||||
|
"sub_sessions": [{"id": s["id"], "box_index": s.get("box_index")} for s in existing],
|
||||||
|
"message": f"{len(existing)} sub-session(s) already exist",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Load base image
|
||||||
|
base_png = await get_session_image(session_id, "cropped")
|
||||||
|
if not base_png:
|
||||||
|
base_png = await get_session_image(session_id, "dewarped")
|
||||||
|
if not base_png:
|
||||||
|
raise HTTPException(status_code=400, detail="No base image available")
|
||||||
|
|
||||||
|
arr = np.frombuffer(base_png, dtype=np.uint8)
|
||||||
|
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||||
|
if img is None:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to decode image")
|
||||||
|
|
||||||
|
parent_name = session.get("name", "Session")
|
||||||
|
created = []
|
||||||
|
|
||||||
|
for i, zone in enumerate(box_zones):
|
||||||
|
box = zone["box"]
|
||||||
|
bx, by = box["x"], box["y"]
|
||||||
|
bw, bh = box["width"], box["height"]
|
||||||
|
|
||||||
|
# Crop box region with small padding
|
||||||
|
pad = 5
|
||||||
|
y1 = max(0, by - pad)
|
||||||
|
y2 = min(img.shape[0], by + bh + pad)
|
||||||
|
x1 = max(0, bx - pad)
|
||||||
|
x2 = min(img.shape[1], bx + bw + pad)
|
||||||
|
crop = img[y1:y2, x1:x2]
|
||||||
|
|
||||||
|
# Encode as PNG
|
||||||
|
success, png_buf = cv2.imencode(".png", crop)
|
||||||
|
if not success:
|
||||||
|
logger.warning(f"Failed to encode box {i} crop for session {session_id}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
sub_id = str(uuid.uuid4())
|
||||||
|
sub_name = f"{parent_name} — Box {i + 1}"
|
||||||
|
|
||||||
|
await create_session_db(
|
||||||
|
session_id=sub_id,
|
||||||
|
name=sub_name,
|
||||||
|
filename=session.get("filename", "box-crop.png"),
|
||||||
|
original_png=png_buf.tobytes(),
|
||||||
|
parent_session_id=session_id,
|
||||||
|
box_index=i,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cache the BGR for immediate processing
|
||||||
|
# Promote original to cropped so column/row/word detection finds it
|
||||||
|
box_bgr = crop.copy()
|
||||||
|
_cache[sub_id] = {
|
||||||
|
"id": sub_id,
|
||||||
|
"filename": session.get("filename", "box-crop.png"),
|
||||||
|
"name": sub_name,
|
||||||
|
"parent_session_id": session_id,
|
||||||
|
"original_bgr": box_bgr,
|
||||||
|
"oriented_bgr": None,
|
||||||
|
"cropped_bgr": box_bgr,
|
||||||
|
"deskewed_bgr": None,
|
||||||
|
"dewarped_bgr": None,
|
||||||
|
"orientation_result": None,
|
||||||
|
"crop_result": None,
|
||||||
|
"deskew_result": None,
|
||||||
|
"dewarp_result": None,
|
||||||
|
"ground_truth": {},
|
||||||
|
"current_step": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
created.append({
|
||||||
|
"id": sub_id,
|
||||||
|
"name": sub_name,
|
||||||
|
"box_index": i,
|
||||||
|
"box": box,
|
||||||
|
"image_width": crop.shape[1],
|
||||||
|
"image_height": crop.shape[0],
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.info(f"Created box sub-session {sub_id} for session {session_id} "
|
||||||
|
f"(box {i}, {crop.shape[1]}x{crop.shape[0]})")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"session_id": session_id,
|
||||||
|
"sub_sessions": created,
|
||||||
|
"total": len(created),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}/thumbnail")
|
||||||
|
async def get_session_thumbnail(session_id: str, size: int = Query(default=80, ge=16, le=400)):
|
||||||
|
"""Return a small thumbnail of the original image."""
|
||||||
|
original_png = await get_session_image(session_id, "original")
|
||||||
|
if not original_png:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found or no image")
|
||||||
|
arr = np.frombuffer(original_png, dtype=np.uint8)
|
||||||
|
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||||
|
if img is None:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to decode image")
|
||||||
|
h, w = img.shape[:2]
|
||||||
|
scale = size / max(h, w)
|
||||||
|
new_w, new_h = int(w * scale), int(h * scale)
|
||||||
|
thumb = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
||||||
|
_, png_bytes = cv2.imencode(".png", thumb)
|
||||||
|
return Response(content=png_bytes.tobytes(), media_type="image/png",
|
||||||
|
headers={"Cache-Control": "public, max-age=3600"})
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}/pipeline-log")
|
||||||
|
async def get_pipeline_log(session_id: str):
|
||||||
|
"""Get the pipeline execution log for a session."""
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
return {"session_id": session_id, "pipeline_log": session.get("pipeline_log") or {"steps": []}}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/categories")
|
||||||
|
async def list_categories():
|
||||||
|
"""List valid document categories."""
|
||||||
|
return {"categories": sorted(VALID_DOCUMENT_CATEGORIES)}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Image Endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}/image/{image_type}")
|
||||||
|
async def get_image(session_id: str, image_type: str):
|
||||||
|
"""Serve session images: original, deskewed, dewarped, binarized, structure-overlay, columns-overlay, or rows-overlay."""
|
||||||
|
valid_types = {"original", "oriented", "cropped", "deskewed", "dewarped", "binarized", "structure-overlay", "columns-overlay", "rows-overlay", "words-overlay", "clean"}
|
||||||
|
if image_type not in valid_types:
|
||||||
|
raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")
|
||||||
|
|
||||||
|
if image_type == "structure-overlay":
|
||||||
|
return await render_overlay("structure", session_id)
|
||||||
|
|
||||||
|
if image_type == "columns-overlay":
|
||||||
|
return await render_overlay("columns", session_id)
|
||||||
|
|
||||||
|
if image_type == "rows-overlay":
|
||||||
|
return await render_overlay("rows", session_id)
|
||||||
|
|
||||||
|
if image_type == "words-overlay":
|
||||||
|
return await render_overlay("words", session_id)
|
||||||
|
|
||||||
|
# Try cache first for fast serving
|
||||||
|
cached = _cache.get(session_id)
|
||||||
|
if cached:
|
||||||
|
png_key = f"{image_type}_png" if image_type != "original" else None
|
||||||
|
bgr_key = f"{image_type}_bgr" if image_type != "binarized" else None
|
||||||
|
|
||||||
|
# For binarized, check if we have it cached as PNG
|
||||||
|
if image_type == "binarized" and cached.get("binarized_png"):
|
||||||
|
return Response(content=cached["binarized_png"], media_type="image/png")
|
||||||
|
|
||||||
|
# Load from DB — for cropped/dewarped, fall back through the chain
|
||||||
|
if image_type in ("cropped", "dewarped"):
|
||||||
|
data = await _get_base_image_png(session_id)
|
||||||
|
else:
|
||||||
|
data = await get_session_image(session_id, image_type)
|
||||||
|
if not data:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet")
|
||||||
|
|
||||||
|
return Response(content=data, media_type="image/png")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Document Type Detection (between Dewarp and Columns)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/detect-type")
|
||||||
|
async def detect_type(session_id: str):
|
||||||
|
"""Detect document type (vocab_table, full_text, generic_table).
|
||||||
|
|
||||||
|
Should be called after crop (clean image available).
|
||||||
|
Falls back to dewarped if crop was skipped.
|
||||||
|
Stores result in session for frontend to decide pipeline flow.
|
||||||
|
"""
|
||||||
|
if session_id not in _cache:
|
||||||
|
await _load_session_to_cache(session_id)
|
||||||
|
cached = _get_cached(session_id)
|
||||||
|
|
||||||
|
img_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
|
||||||
|
if img_bgr is None:
|
||||||
|
raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first")
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
ocr_img = create_ocr_image(img_bgr)
|
||||||
|
result = detect_document_type(ocr_img, img_bgr)
|
||||||
|
duration = time.time() - t0
|
||||||
|
|
||||||
|
result_dict = {
|
||||||
|
"doc_type": result.doc_type,
|
||||||
|
"confidence": result.confidence,
|
||||||
|
"pipeline": result.pipeline,
|
||||||
|
"skip_steps": result.skip_steps,
|
||||||
|
"features": result.features,
|
||||||
|
"duration_seconds": round(duration, 2),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Persist to DB
|
||||||
|
await update_session_db(
|
||||||
|
session_id,
|
||||||
|
doc_type=result.doc_type,
|
||||||
|
doc_type_result=result_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
cached["doc_type_result"] = result_dict
|
||||||
|
|
||||||
|
logger.info(f"OCR Pipeline: detect-type session {session_id}: "
|
||||||
|
f"{result.doc_type} (confidence={result.confidence}, {duration:.2f}s)")
|
||||||
|
|
||||||
|
await _append_pipeline_log(session_id, "detect_type", {
|
||||||
|
"doc_type": result.doc_type,
|
||||||
|
"pipeline": result.pipeline,
|
||||||
|
"confidence": result.confidence,
|
||||||
|
**{k: v for k, v in (result.features or {}).items() if isinstance(v, (int, float, str, bool))},
|
||||||
|
}, duration_ms=int(duration * 1000))
|
||||||
|
|
||||||
|
return {"session_id": session_id, **result_dict}
|
||||||
876
klausur-service/backend/ocr_pipeline_words.py
Normal file
876
klausur-service/backend/ocr_pipeline_words.py
Normal file
@@ -0,0 +1,876 @@
|
|||||||
|
"""
|
||||||
|
OCR Pipeline Words - Word detection and ground truth endpoints.
|
||||||
|
|
||||||
|
Extracted from ocr_pipeline_api.py.
|
||||||
|
Handles:
|
||||||
|
- POST /sessions/{session_id}/words — main SSE streaming word detection
|
||||||
|
- POST /sessions/{session_id}/paddle-direct — PaddleOCR direct endpoint
|
||||||
|
- POST /sessions/{session_id}/ground-truth/words — save ground truth
|
||||||
|
- GET /sessions/{session_id}/ground-truth/words — get ground truth
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from cv_vocab_pipeline import (
|
||||||
|
PageRegion,
|
||||||
|
RowGeometry,
|
||||||
|
_cells_to_vocab_entries,
|
||||||
|
_fix_character_confusion,
|
||||||
|
_fix_phonetic_brackets,
|
||||||
|
fix_cell_phonetics,
|
||||||
|
build_cell_grid_v2,
|
||||||
|
build_cell_grid_v2_streaming,
|
||||||
|
create_ocr_image,
|
||||||
|
detect_column_geometry,
|
||||||
|
)
|
||||||
|
from cv_words_first import build_grid_from_words
|
||||||
|
from ocr_pipeline_session_store import (
|
||||||
|
get_session_db,
|
||||||
|
get_session_image,
|
||||||
|
update_session_db,
|
||||||
|
)
|
||||||
|
from ocr_pipeline_common import (
|
||||||
|
_cache,
|
||||||
|
_load_session_to_cache,
|
||||||
|
_get_cached,
|
||||||
|
_get_base_image_png,
|
||||||
|
_append_pipeline_log,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pydantic models
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class WordGroundTruthRequest(BaseModel):
|
||||||
|
is_correct: bool
|
||||||
|
corrected_entries: Optional[List[Dict[str, Any]]] = None
|
||||||
|
notes: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Word Detection Endpoint (Step 7)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/words")
|
||||||
|
async def detect_words(
|
||||||
|
session_id: str,
|
||||||
|
request: Request,
|
||||||
|
engine: str = "auto",
|
||||||
|
pronunciation: str = "british",
|
||||||
|
stream: bool = False,
|
||||||
|
skip_heal_gaps: bool = False,
|
||||||
|
grid_method: str = "v2",
|
||||||
|
):
|
||||||
|
"""Build word grid from columns × rows, OCR each cell.
|
||||||
|
|
||||||
|
Query params:
|
||||||
|
engine: 'auto' (default), 'tesseract', 'rapid', or 'paddle'
|
||||||
|
pronunciation: 'british' (default) or 'american' — for IPA dictionary lookup
|
||||||
|
stream: false (default) for JSON response, true for SSE streaming
|
||||||
|
skip_heal_gaps: false (default). When true, cells keep exact row geometry
|
||||||
|
positions without gap-healing expansion. Better for overlay rendering.
|
||||||
|
grid_method: 'v2' (default) or 'words_first' — grid construction strategy.
|
||||||
|
'v2' uses pre-detected columns/rows (top-down).
|
||||||
|
'words_first' clusters words bottom-up (no column/row detection needed).
|
||||||
|
"""
|
||||||
|
# PaddleOCR is full-page remote OCR → force words_first grid method
|
||||||
|
if engine == "paddle" and grid_method != "words_first":
|
||||||
|
logger.info("detect_words: engine=paddle requires words_first, overriding grid_method=%s", grid_method)
|
||||||
|
grid_method = "words_first"
|
||||||
|
|
||||||
|
if session_id not in _cache:
|
||||||
|
logger.info("detect_words: session %s not in cache, loading from DB", session_id)
|
||||||
|
await _load_session_to_cache(session_id)
|
||||||
|
cached = _get_cached(session_id)
|
||||||
|
|
||||||
|
dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
|
||||||
|
if dewarped_bgr is None:
|
||||||
|
logger.warning("detect_words: no cropped/dewarped image for session %s (cache keys: %s)",
|
||||||
|
session_id, [k for k in cached.keys() if k.endswith('_bgr')])
|
||||||
|
raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before word detection")
|
||||||
|
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
column_result = session.get("column_result")
|
||||||
|
row_result = session.get("row_result")
|
||||||
|
if not column_result or not column_result.get("columns"):
|
||||||
|
# No column detection — synthesize a single full-page pseudo-column.
|
||||||
|
# This enables the overlay pipeline which skips column detection.
|
||||||
|
img_h_tmp, img_w_tmp = dewarped_bgr.shape[:2]
|
||||||
|
column_result = {
|
||||||
|
"columns": [{
|
||||||
|
"type": "column_text",
|
||||||
|
"x": 0, "y": 0,
|
||||||
|
"width": img_w_tmp, "height": img_h_tmp,
|
||||||
|
"classification_confidence": 1.0,
|
||||||
|
"classification_method": "full_page_fallback",
|
||||||
|
}],
|
||||||
|
"zones": [],
|
||||||
|
"duration_seconds": 0,
|
||||||
|
}
|
||||||
|
logger.info("detect_words: no column_result — using full-page pseudo-column %dx%d", img_w_tmp, img_h_tmp)
|
||||||
|
if grid_method != "words_first" and (not row_result or not row_result.get("rows")):
|
||||||
|
raise HTTPException(status_code=400, detail="Row detection must be completed first")
|
||||||
|
|
||||||
|
# Convert column dicts back to PageRegion objects
|
||||||
|
col_regions = [
|
||||||
|
PageRegion(
|
||||||
|
type=c["type"],
|
||||||
|
x=c["x"], y=c["y"],
|
||||||
|
width=c["width"], height=c["height"],
|
||||||
|
classification_confidence=c.get("classification_confidence", 1.0),
|
||||||
|
classification_method=c.get("classification_method", ""),
|
||||||
|
)
|
||||||
|
for c in column_result["columns"]
|
||||||
|
]
|
||||||
|
|
||||||
|
# Convert row dicts back to RowGeometry objects
|
||||||
|
row_geoms = [
|
||||||
|
RowGeometry(
|
||||||
|
index=r["index"],
|
||||||
|
x=r["x"], y=r["y"],
|
||||||
|
width=r["width"], height=r["height"],
|
||||||
|
word_count=r.get("word_count", 0),
|
||||||
|
words=[],
|
||||||
|
row_type=r.get("row_type", "content"),
|
||||||
|
gap_before=r.get("gap_before", 0),
|
||||||
|
)
|
||||||
|
for r in row_result["rows"]
|
||||||
|
]
|
||||||
|
|
||||||
|
# Cell-First OCR (v2): no full-page word re-population needed.
|
||||||
|
# Each cell is cropped and OCR'd in isolation → no neighbour bleeding.
|
||||||
|
# We still need word_count > 0 for row filtering in build_cell_grid_v2,
|
||||||
|
# so populate from cached words if available (just for counting).
|
||||||
|
word_dicts = cached.get("_word_dicts")
|
||||||
|
if word_dicts is None:
|
||||||
|
ocr_img_tmp = create_ocr_image(dewarped_bgr)
|
||||||
|
geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr)
|
||||||
|
if geo_result is not None:
|
||||||
|
_geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
|
||||||
|
cached["_word_dicts"] = word_dicts
|
||||||
|
cached["_inv"] = inv
|
||||||
|
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
|
||||||
|
|
||||||
|
if word_dicts:
|
||||||
|
content_bounds = cached.get("_content_bounds")
|
||||||
|
if content_bounds:
|
||||||
|
_lx, _rx, top_y, _by = content_bounds
|
||||||
|
else:
|
||||||
|
top_y = min(r.y for r in row_geoms) if row_geoms else 0
|
||||||
|
|
||||||
|
for row in row_geoms:
|
||||||
|
row_y_rel = row.y - top_y
|
||||||
|
row_bottom_rel = row_y_rel + row.height
|
||||||
|
row.words = [
|
||||||
|
w for w in word_dicts
|
||||||
|
if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
|
||||||
|
]
|
||||||
|
row.word_count = len(row.words)
|
||||||
|
|
||||||
|
# Exclude rows that fall within box zones.
|
||||||
|
# Use inner box range (shrunk by border_thickness) so that rows at
|
||||||
|
# the boundary (overlapping with the box border) are NOT excluded.
|
||||||
|
zones = column_result.get("zones") or []
|
||||||
|
box_ranges_inner = []
|
||||||
|
for zone in zones:
|
||||||
|
if zone.get("zone_type") == "box" and zone.get("box"):
|
||||||
|
box = zone["box"]
|
||||||
|
bt = max(box.get("border_thickness", 0), 5) # minimum 5px margin
|
||||||
|
box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt))
|
||||||
|
|
||||||
|
if box_ranges_inner:
|
||||||
|
def _row_in_box(r):
|
||||||
|
center_y = r.y + r.height / 2
|
||||||
|
return any(by_s <= center_y < by_e for by_s, by_e in box_ranges_inner)
|
||||||
|
|
||||||
|
before_count = len(row_geoms)
|
||||||
|
row_geoms = [r for r in row_geoms if not _row_in_box(r)]
|
||||||
|
excluded = before_count - len(row_geoms)
|
||||||
|
if excluded:
|
||||||
|
logger.info(f"detect_words: excluded {excluded} rows inside box zones")
|
||||||
|
|
||||||
|
# --- Words-First path: bottom-up grid from word boxes ---
|
||||||
|
if grid_method == "words_first":
|
||||||
|
t0 = time.time()
|
||||||
|
img_h, img_w = dewarped_bgr.shape[:2]
|
||||||
|
|
||||||
|
# For paddle engine: run remote PaddleOCR full-page instead of Tesseract
|
||||||
|
if engine == "paddle":
|
||||||
|
from cv_ocr_engines import ocr_region_paddle
|
||||||
|
|
||||||
|
wf_word_dicts = await ocr_region_paddle(dewarped_bgr, region=None)
|
||||||
|
# PaddleOCR returns absolute coordinates, no content_bounds offset needed
|
||||||
|
cached["_paddle_word_dicts"] = wf_word_dicts
|
||||||
|
else:
|
||||||
|
# Get word_dicts from cache or run Tesseract full-page
|
||||||
|
wf_word_dicts = cached.get("_word_dicts")
|
||||||
|
if wf_word_dicts is None:
|
||||||
|
ocr_img_tmp = create_ocr_image(dewarped_bgr)
|
||||||
|
geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr)
|
||||||
|
if geo_result is not None:
|
||||||
|
_geoms, left_x, right_x, top_y, bottom_y, wf_word_dicts, inv = geo_result
|
||||||
|
cached["_word_dicts"] = wf_word_dicts
|
||||||
|
cached["_inv"] = inv
|
||||||
|
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
|
||||||
|
|
||||||
|
if not wf_word_dicts:
|
||||||
|
raise HTTPException(status_code=400, detail="No words detected — cannot build words-first grid")
|
||||||
|
|
||||||
|
# Convert word coordinates to absolute image coordinates if needed
|
||||||
|
# (detect_column_geometry returns words relative to content ROI)
|
||||||
|
# PaddleOCR already returns absolute coordinates — skip offset.
|
||||||
|
if engine != "paddle":
|
||||||
|
content_bounds = cached.get("_content_bounds")
|
||||||
|
if content_bounds:
|
||||||
|
lx, _rx, ty, _by = content_bounds
|
||||||
|
abs_words = []
|
||||||
|
for w in wf_word_dicts:
|
||||||
|
abs_words.append({
|
||||||
|
**w,
|
||||||
|
'left': w['left'] + lx,
|
||||||
|
'top': w['top'] + ty,
|
||||||
|
})
|
||||||
|
wf_word_dicts = abs_words
|
||||||
|
|
||||||
|
# Extract box rects for box-aware column clustering
|
||||||
|
box_rects = []
|
||||||
|
for zone in zones:
|
||||||
|
if zone.get("zone_type") == "box" and zone.get("box"):
|
||||||
|
box_rects.append(zone["box"])
|
||||||
|
|
||||||
|
cells, columns_meta = build_grid_from_words(
|
||||||
|
wf_word_dicts, img_w, img_h, box_rects=box_rects or None,
|
||||||
|
)
|
||||||
|
duration = time.time() - t0
|
||||||
|
|
||||||
|
# Apply IPA phonetic fixes
|
||||||
|
fix_cell_phonetics(cells, pronunciation=pronunciation)
|
||||||
|
|
||||||
|
# Add zone_index for backward compat
|
||||||
|
for cell in cells:
|
||||||
|
cell.setdefault("zone_index", 0)
|
||||||
|
|
||||||
|
col_types = {c['type'] for c in columns_meta}
|
||||||
|
is_vocab = bool(col_types & {'column_en', 'column_de'})
|
||||||
|
n_rows = len(set(c['row_index'] for c in cells)) if cells else 0
|
||||||
|
n_cols = len(columns_meta)
|
||||||
|
used_engine = "paddle" if engine == "paddle" else "words_first"
|
||||||
|
|
||||||
|
word_result = {
|
||||||
|
"cells": cells,
|
||||||
|
"grid_shape": {
|
||||||
|
"rows": n_rows,
|
||||||
|
"cols": n_cols,
|
||||||
|
"total_cells": len(cells),
|
||||||
|
},
|
||||||
|
"columns_used": columns_meta,
|
||||||
|
"layout": "vocab" if is_vocab else "generic",
|
||||||
|
"image_width": img_w,
|
||||||
|
"image_height": img_h,
|
||||||
|
"duration_seconds": round(duration, 2),
|
||||||
|
"ocr_engine": used_engine,
|
||||||
|
"grid_method": "words_first",
|
||||||
|
"summary": {
|
||||||
|
"total_cells": len(cells),
|
||||||
|
"non_empty_cells": sum(1 for c in cells if c.get("text")),
|
||||||
|
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if is_vocab or 'column_text' in col_types:
|
||||||
|
entries = _cells_to_vocab_entries(cells, columns_meta)
|
||||||
|
entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
|
||||||
|
word_result["vocab_entries"] = entries
|
||||||
|
word_result["entries"] = entries
|
||||||
|
word_result["entry_count"] = len(entries)
|
||||||
|
word_result["summary"]["total_entries"] = len(entries)
|
||||||
|
word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
|
||||||
|
word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
|
||||||
|
|
||||||
|
await update_session_db(session_id, word_result=word_result, current_step=8)
|
||||||
|
cached["word_result"] = word_result
|
||||||
|
|
||||||
|
logger.info(f"OCR Pipeline: words-first session {session_id}: "
|
||||||
|
f"{len(cells)} cells ({duration:.2f}s), {n_rows} rows, {n_cols} cols")
|
||||||
|
|
||||||
|
await _append_pipeline_log(session_id, "words", {
|
||||||
|
"grid_method": "words_first",
|
||||||
|
"total_cells": len(cells),
|
||||||
|
"non_empty_cells": word_result["summary"]["non_empty_cells"],
|
||||||
|
"ocr_engine": used_engine,
|
||||||
|
"layout": word_result["layout"],
|
||||||
|
}, duration_ms=int(duration * 1000))
|
||||||
|
|
||||||
|
return {"session_id": session_id, **word_result}
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
# Cell-First OCR v2: use batch-then-stream approach instead of
|
||||||
|
# per-cell streaming. The parallel ThreadPoolExecutor in
|
||||||
|
# build_cell_grid_v2 is much faster than sequential streaming.
|
||||||
|
return StreamingResponse(
|
||||||
|
_word_batch_stream_generator(
|
||||||
|
session_id, cached, col_regions, row_geoms,
|
||||||
|
dewarped_bgr, engine, pronunciation, request,
|
||||||
|
skip_heal_gaps=skip_heal_gaps,
|
||||||
|
),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Non-streaming path (grid_method=v2) ---
|
||||||
|
t0 = time.time()
|
||||||
|
|
||||||
|
# Create binarized OCR image (for Tesseract)
|
||||||
|
ocr_img = create_ocr_image(dewarped_bgr)
|
||||||
|
img_h, img_w = dewarped_bgr.shape[:2]
|
||||||
|
|
||||||
|
# Build cell grid using Cell-First OCR (v2) — each cell cropped in isolation
|
||||||
|
cells, columns_meta = build_cell_grid_v2(
|
||||||
|
ocr_img, col_regions, row_geoms, img_w, img_h,
|
||||||
|
ocr_engine=engine, img_bgr=dewarped_bgr,
|
||||||
|
skip_heal_gaps=skip_heal_gaps,
|
||||||
|
)
|
||||||
|
duration = time.time() - t0
|
||||||
|
|
||||||
|
# Add zone_index to each cell (default 0 for backward compatibility)
|
||||||
|
for cell in cells:
|
||||||
|
cell.setdefault("zone_index", 0)
|
||||||
|
|
||||||
|
# Layout detection
|
||||||
|
col_types = {c['type'] for c in columns_meta}
|
||||||
|
is_vocab = bool(col_types & {'column_en', 'column_de'})
|
||||||
|
|
||||||
|
# Count content rows and columns for grid_shape
|
||||||
|
n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
|
||||||
|
n_cols = len(columns_meta)
|
||||||
|
|
||||||
|
# Determine which engine was actually used
|
||||||
|
used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine
|
||||||
|
|
||||||
|
# Apply IPA phonetic fixes directly to cell texts (for overlay mode)
|
||||||
|
fix_cell_phonetics(cells, pronunciation=pronunciation)
|
||||||
|
|
||||||
|
# Grid result (always generic)
|
||||||
|
word_result = {
|
||||||
|
"cells": cells,
|
||||||
|
"grid_shape": {
|
||||||
|
"rows": n_content_rows,
|
||||||
|
"cols": n_cols,
|
||||||
|
"total_cells": len(cells),
|
||||||
|
},
|
||||||
|
"columns_used": columns_meta,
|
||||||
|
"layout": "vocab" if is_vocab else "generic",
|
||||||
|
"image_width": img_w,
|
||||||
|
"image_height": img_h,
|
||||||
|
"duration_seconds": round(duration, 2),
|
||||||
|
"ocr_engine": used_engine,
|
||||||
|
"summary": {
|
||||||
|
"total_cells": len(cells),
|
||||||
|
"non_empty_cells": sum(1 for c in cells if c.get("text")),
|
||||||
|
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# For vocab layout or single-column (box sub-sessions): map cells 1:1
|
||||||
|
# to vocab entries (row→entry).
|
||||||
|
has_text_col = 'column_text' in col_types
|
||||||
|
if is_vocab or has_text_col:
|
||||||
|
entries = _cells_to_vocab_entries(cells, columns_meta)
|
||||||
|
entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
|
||||||
|
word_result["vocab_entries"] = entries
|
||||||
|
word_result["entries"] = entries
|
||||||
|
word_result["entry_count"] = len(entries)
|
||||||
|
word_result["summary"]["total_entries"] = len(entries)
|
||||||
|
word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
|
||||||
|
word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
|
||||||
|
|
||||||
|
# Persist to DB
|
||||||
|
await update_session_db(
|
||||||
|
session_id,
|
||||||
|
word_result=word_result,
|
||||||
|
current_step=8,
|
||||||
|
)
|
||||||
|
|
||||||
|
cached["word_result"] = word_result
|
||||||
|
|
||||||
|
logger.info(f"OCR Pipeline: words session {session_id}: "
|
||||||
|
f"layout={word_result['layout']}, "
|
||||||
|
f"{len(cells)} cells ({duration:.2f}s), summary: {word_result['summary']}")
|
||||||
|
|
||||||
|
await _append_pipeline_log(session_id, "words", {
|
||||||
|
"total_cells": len(cells),
|
||||||
|
"non_empty_cells": word_result["summary"]["non_empty_cells"],
|
||||||
|
"low_confidence_count": word_result["summary"]["low_confidence"],
|
||||||
|
"ocr_engine": used_engine,
|
||||||
|
"layout": word_result["layout"],
|
||||||
|
"entry_count": word_result.get("entry_count", 0),
|
||||||
|
}, duration_ms=int(duration * 1000))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"session_id": session_id,
|
||||||
|
**word_result,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def _word_batch_stream_generator(
|
||||||
|
session_id: str,
|
||||||
|
cached: Dict[str, Any],
|
||||||
|
col_regions: List[PageRegion],
|
||||||
|
row_geoms: List[RowGeometry],
|
||||||
|
dewarped_bgr: np.ndarray,
|
||||||
|
engine: str,
|
||||||
|
pronunciation: str,
|
||||||
|
request: Request,
|
||||||
|
skip_heal_gaps: bool = False,
|
||||||
|
):
|
||||||
|
"""SSE generator that runs batch OCR (parallel) then streams results.
|
||||||
|
|
||||||
|
Unlike the old per-cell streaming, this uses build_cell_grid_v2 with
|
||||||
|
ThreadPoolExecutor for parallel OCR, then emits all cells as SSE events.
|
||||||
|
The 'preparing' event keeps the connection alive during OCR processing.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
ocr_img = create_ocr_image(dewarped_bgr)
|
||||||
|
img_h, img_w = dewarped_bgr.shape[:2]
|
||||||
|
|
||||||
|
_skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
|
||||||
|
n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
|
||||||
|
n_cols = len([c for c in col_regions if c.type not in _skip_types])
|
||||||
|
col_types = {c.type for c in col_regions if c.type not in _skip_types}
|
||||||
|
is_vocab = bool(col_types & {'column_en', 'column_de'})
|
||||||
|
total_cells = n_content_rows * n_cols
|
||||||
|
|
||||||
|
# 1. Send meta event immediately
|
||||||
|
meta_event = {
|
||||||
|
"type": "meta",
|
||||||
|
"grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
|
||||||
|
"layout": "vocab" if is_vocab else "generic",
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(meta_event)}\n\n"
|
||||||
|
|
||||||
|
# 2. Send preparing event (keepalive for proxy)
|
||||||
|
yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR laeuft parallel...'})}\n\n"
|
||||||
|
|
||||||
|
# 3. Run batch OCR in thread pool with periodic keepalive events.
|
||||||
|
# The OCR takes 30-60s and proxy servers (Nginx) may drop idle SSE
|
||||||
|
# connections after 30-60s. Send keepalive every 5s to prevent this.
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
ocr_future = loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: build_cell_grid_v2(
|
||||||
|
ocr_img, col_regions, row_geoms, img_w, img_h,
|
||||||
|
ocr_engine=engine, img_bgr=dewarped_bgr,
|
||||||
|
skip_heal_gaps=skip_heal_gaps,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send keepalive events every 5 seconds while OCR runs
|
||||||
|
keepalive_count = 0
|
||||||
|
while not ocr_future.done():
|
||||||
|
try:
|
||||||
|
cells, columns_meta = await asyncio.wait_for(
|
||||||
|
asyncio.shield(ocr_future), timeout=5.0,
|
||||||
|
)
|
||||||
|
break # OCR finished
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
keepalive_count += 1
|
||||||
|
elapsed = int(time.time() - t0)
|
||||||
|
yield f"data: {json.dumps({'type': 'keepalive', 'elapsed': elapsed, 'message': f'OCR laeuft... ({elapsed}s)'})}\n\n"
|
||||||
|
if await request.is_disconnected():
|
||||||
|
logger.info(f"SSE batch: client disconnected during OCR for {session_id}")
|
||||||
|
ocr_future.cancel()
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
cells, columns_meta = ocr_future.result()
|
||||||
|
|
||||||
|
if await request.is_disconnected():
|
||||||
|
logger.info(f"SSE batch: client disconnected after OCR for {session_id}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 4. Apply IPA phonetic fixes directly to cell texts (for overlay mode)
|
||||||
|
fix_cell_phonetics(cells, pronunciation=pronunciation)
|
||||||
|
|
||||||
|
# 5. Send columns meta
|
||||||
|
if columns_meta:
|
||||||
|
yield f"data: {json.dumps({'type': 'columns', 'columns_used': columns_meta})}\n\n"
|
||||||
|
|
||||||
|
# 6. Stream all cells
|
||||||
|
for idx, cell in enumerate(cells):
|
||||||
|
cell_event = {
|
||||||
|
"type": "cell",
|
||||||
|
"cell": cell,
|
||||||
|
"progress": {"current": idx + 1, "total": len(cells)},
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(cell_event)}\n\n"
|
||||||
|
|
||||||
|
# 6. Build final result and persist
|
||||||
|
duration = time.time() - t0
|
||||||
|
used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine
|
||||||
|
|
||||||
|
word_result = {
|
||||||
|
"cells": cells,
|
||||||
|
"grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(cells)},
|
||||||
|
"columns_used": columns_meta,
|
||||||
|
"layout": "vocab" if is_vocab else "generic",
|
||||||
|
"image_width": img_w,
|
||||||
|
"image_height": img_h,
|
||||||
|
"duration_seconds": round(duration, 2),
|
||||||
|
"ocr_engine": used_engine,
|
||||||
|
"summary": {
|
||||||
|
"total_cells": len(cells),
|
||||||
|
"non_empty_cells": sum(1 for c in cells if c.get("text")),
|
||||||
|
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
vocab_entries = None
|
||||||
|
has_text_col = 'column_text' in col_types
|
||||||
|
if is_vocab or has_text_col:
|
||||||
|
entries = _cells_to_vocab_entries(cells, columns_meta)
|
||||||
|
entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
|
||||||
|
word_result["vocab_entries"] = entries
|
||||||
|
word_result["entries"] = entries
|
||||||
|
word_result["entry_count"] = len(entries)
|
||||||
|
word_result["summary"]["total_entries"] = len(entries)
|
||||||
|
word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
|
||||||
|
word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
|
||||||
|
vocab_entries = entries
|
||||||
|
|
||||||
|
await update_session_db(session_id, word_result=word_result, current_step=8)
|
||||||
|
cached["word_result"] = word_result
|
||||||
|
|
||||||
|
logger.info(f"OCR Pipeline SSE batch: words session {session_id}: "
|
||||||
|
f"layout={word_result['layout']}, {len(cells)} cells ({duration:.2f}s)")
|
||||||
|
|
||||||
|
# 7. Send complete event
|
||||||
|
complete_event = {
|
||||||
|
"type": "complete",
|
||||||
|
"summary": word_result["summary"],
|
||||||
|
"duration_seconds": round(duration, 2),
|
||||||
|
"ocr_engine": used_engine,
|
||||||
|
}
|
||||||
|
if vocab_entries is not None:
|
||||||
|
complete_event["vocab_entries"] = vocab_entries
|
||||||
|
yield f"data: {json.dumps(complete_event)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
async def _word_stream_generator(
|
||||||
|
session_id: str,
|
||||||
|
cached: Dict[str, Any],
|
||||||
|
col_regions: List[PageRegion],
|
||||||
|
row_geoms: List[RowGeometry],
|
||||||
|
dewarped_bgr: np.ndarray,
|
||||||
|
engine: str,
|
||||||
|
pronunciation: str,
|
||||||
|
request: Request,
|
||||||
|
):
|
||||||
|
"""SSE generator that yields cell-by-cell OCR progress."""
|
||||||
|
t0 = time.time()
|
||||||
|
|
||||||
|
ocr_img = create_ocr_image(dewarped_bgr)
|
||||||
|
img_h, img_w = dewarped_bgr.shape[:2]
|
||||||
|
|
||||||
|
# Compute grid shape upfront for the meta event
|
||||||
|
n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
|
||||||
|
_skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
|
||||||
|
n_cols = len([c for c in col_regions if c.type not in _skip_types])
|
||||||
|
|
||||||
|
# Determine layout
|
||||||
|
col_types = {c.type for c in col_regions if c.type not in _skip_types}
|
||||||
|
is_vocab = bool(col_types & {'column_en', 'column_de'})
|
||||||
|
|
||||||
|
# Start streaming — first event: meta
|
||||||
|
columns_meta = None # will be set from first yield
|
||||||
|
total_cells = n_content_rows * n_cols
|
||||||
|
|
||||||
|
meta_event = {
|
||||||
|
"type": "meta",
|
||||||
|
"grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
|
||||||
|
"layout": "vocab" if is_vocab else "generic",
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(meta_event)}\n\n"
|
||||||
|
|
||||||
|
# Keepalive: send preparing event so proxy doesn't timeout during OCR init
|
||||||
|
yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR wird initialisiert...'})}\n\n"
|
||||||
|
|
||||||
|
# Stream cells one by one
|
||||||
|
all_cells: List[Dict[str, Any]] = []
|
||||||
|
cell_idx = 0
|
||||||
|
last_keepalive = time.time()
|
||||||
|
|
||||||
|
for cell, cols_meta, total in build_cell_grid_v2_streaming(
|
||||||
|
ocr_img, col_regions, row_geoms, img_w, img_h,
|
||||||
|
ocr_engine=engine, img_bgr=dewarped_bgr,
|
||||||
|
):
|
||||||
|
if await request.is_disconnected():
|
||||||
|
logger.info(f"SSE: client disconnected during streaming for {session_id}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if columns_meta is None:
|
||||||
|
columns_meta = cols_meta
|
||||||
|
# Send columns_used as part of first cell or update meta
|
||||||
|
meta_update = {
|
||||||
|
"type": "columns",
|
||||||
|
"columns_used": cols_meta,
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(meta_update)}\n\n"
|
||||||
|
|
||||||
|
all_cells.append(cell)
|
||||||
|
cell_idx += 1
|
||||||
|
|
||||||
|
cell_event = {
|
||||||
|
"type": "cell",
|
||||||
|
"cell": cell,
|
||||||
|
"progress": {"current": cell_idx, "total": total},
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(cell_event)}\n\n"
|
||||||
|
|
||||||
|
# All cells done — build final result
|
||||||
|
duration = time.time() - t0
|
||||||
|
if columns_meta is None:
|
||||||
|
columns_meta = []
|
||||||
|
|
||||||
|
# Post-OCR: remove rows where ALL cells are empty (inter-row gaps
|
||||||
|
# that had stray Tesseract artifacts giving word_count > 0).
|
||||||
|
rows_with_text: set = set()
|
||||||
|
for c in all_cells:
|
||||||
|
if c.get("text", "").strip():
|
||||||
|
rows_with_text.add(c["row_index"])
|
||||||
|
before_filter = len(all_cells)
|
||||||
|
all_cells = [c for c in all_cells if c["row_index"] in rows_with_text]
|
||||||
|
empty_rows_removed = (before_filter - len(all_cells)) // max(n_cols, 1)
|
||||||
|
if empty_rows_removed > 0:
|
||||||
|
logger.info(f"SSE: removed {empty_rows_removed} all-empty rows after OCR")
|
||||||
|
|
||||||
|
used_engine = all_cells[0].get("ocr_engine", "tesseract") if all_cells else engine
|
||||||
|
|
||||||
|
# Apply IPA phonetic fixes directly to cell texts (for overlay mode)
|
||||||
|
fix_cell_phonetics(all_cells, pronunciation=pronunciation)
|
||||||
|
|
||||||
|
word_result = {
|
||||||
|
"cells": all_cells,
|
||||||
|
"grid_shape": {
|
||||||
|
"rows": n_content_rows,
|
||||||
|
"cols": n_cols,
|
||||||
|
"total_cells": len(all_cells),
|
||||||
|
},
|
||||||
|
"columns_used": columns_meta,
|
||||||
|
"layout": "vocab" if is_vocab else "generic",
|
||||||
|
"image_width": img_w,
|
||||||
|
"image_height": img_h,
|
||||||
|
"duration_seconds": round(duration, 2),
|
||||||
|
"ocr_engine": used_engine,
|
||||||
|
"summary": {
|
||||||
|
"total_cells": len(all_cells),
|
||||||
|
"non_empty_cells": sum(1 for c in all_cells if c.get("text")),
|
||||||
|
"low_confidence": sum(1 for c in all_cells if 0 < c.get("confidence", 0) < 50),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# For vocab layout or single-column (box sub-sessions): map cells 1:1
|
||||||
|
# to vocab entries (row→entry).
|
||||||
|
vocab_entries = None
|
||||||
|
has_text_col = 'column_text' in col_types
|
||||||
|
if is_vocab or has_text_col:
|
||||||
|
entries = _cells_to_vocab_entries(all_cells, columns_meta)
|
||||||
|
entries = _fix_character_confusion(entries)
|
||||||
|
entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
|
||||||
|
word_result["vocab_entries"] = entries
|
||||||
|
word_result["entries"] = entries
|
||||||
|
word_result["entry_count"] = len(entries)
|
||||||
|
word_result["summary"]["total_entries"] = len(entries)
|
||||||
|
word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
|
||||||
|
word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
|
||||||
|
vocab_entries = entries
|
||||||
|
|
||||||
|
# Persist to DB
|
||||||
|
await update_session_db(
|
||||||
|
session_id,
|
||||||
|
word_result=word_result,
|
||||||
|
current_step=8,
|
||||||
|
)
|
||||||
|
cached["word_result"] = word_result
|
||||||
|
|
||||||
|
logger.info(f"OCR Pipeline SSE: words session {session_id}: "
|
||||||
|
f"layout={word_result['layout']}, "
|
||||||
|
f"{len(all_cells)} cells ({duration:.2f}s)")
|
||||||
|
|
||||||
|
# Final complete event
|
||||||
|
complete_event = {
|
||||||
|
"type": "complete",
|
||||||
|
"summary": word_result["summary"],
|
||||||
|
"duration_seconds": round(duration, 2),
|
||||||
|
"ocr_engine": used_engine,
|
||||||
|
}
|
||||||
|
if vocab_entries is not None:
|
||||||
|
complete_event["vocab_entries"] = vocab_entries
|
||||||
|
yield f"data: {json.dumps(complete_event)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# PaddleOCR Direct Endpoint
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/paddle-direct")
|
||||||
|
async def paddle_direct(session_id: str):
|
||||||
|
"""Run PaddleOCR on the preprocessed image and build a word grid directly.
|
||||||
|
|
||||||
|
Expects orientation/deskew/dewarp/crop to be done already.
|
||||||
|
Uses the cropped image (falls back to dewarped, then original).
|
||||||
|
The used image is stored as cropped_png so OverlayReconstruction
|
||||||
|
can display it as the background.
|
||||||
|
"""
|
||||||
|
# Try preprocessed images first (crop > dewarp > original)
|
||||||
|
img_png = await get_session_image(session_id, "cropped")
|
||||||
|
if not img_png:
|
||||||
|
img_png = await get_session_image(session_id, "dewarped")
|
||||||
|
if not img_png:
|
||||||
|
img_png = await get_session_image(session_id, "original")
|
||||||
|
if not img_png:
|
||||||
|
raise HTTPException(status_code=404, detail="No image found for this session")
|
||||||
|
|
||||||
|
img_arr = np.frombuffer(img_png, dtype=np.uint8)
|
||||||
|
img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
|
||||||
|
if img_bgr is None:
|
||||||
|
raise HTTPException(status_code=400, detail="Failed to decode original image")
|
||||||
|
|
||||||
|
img_h, img_w = img_bgr.shape[:2]
|
||||||
|
|
||||||
|
from cv_ocr_engines import ocr_region_paddle
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
word_dicts = await ocr_region_paddle(img_bgr, region=None)
|
||||||
|
if not word_dicts:
|
||||||
|
raise HTTPException(status_code=400, detail="PaddleOCR returned no words")
|
||||||
|
|
||||||
|
# Reuse build_grid_from_words — same function that works in the regular
|
||||||
|
# pipeline with PaddleOCR (engine=paddle, grid_method=words_first).
|
||||||
|
# Handles phrase splitting, column clustering, and reading order.
|
||||||
|
cells, columns_meta = build_grid_from_words(word_dicts, img_w, img_h)
|
||||||
|
duration = time.time() - t0
|
||||||
|
|
||||||
|
# Tag cells as paddle_direct
|
||||||
|
for cell in cells:
|
||||||
|
cell["ocr_engine"] = "paddle_direct"
|
||||||
|
|
||||||
|
n_rows = len(set(c["row_index"] for c in cells)) if cells else 0
|
||||||
|
n_cols = len(columns_meta)
|
||||||
|
col_types = {c.get("type") for c in columns_meta}
|
||||||
|
is_vocab = bool(col_types & {"column_en", "column_de"})
|
||||||
|
|
||||||
|
word_result = {
|
||||||
|
"cells": cells,
|
||||||
|
"grid_shape": {
|
||||||
|
"rows": n_rows,
|
||||||
|
"cols": n_cols,
|
||||||
|
"total_cells": len(cells),
|
||||||
|
},
|
||||||
|
"columns_used": columns_meta,
|
||||||
|
"layout": "vocab" if is_vocab else "generic",
|
||||||
|
"image_width": img_w,
|
||||||
|
"image_height": img_h,
|
||||||
|
"duration_seconds": round(duration, 2),
|
||||||
|
"ocr_engine": "paddle_direct",
|
||||||
|
"grid_method": "paddle_direct",
|
||||||
|
"summary": {
|
||||||
|
"total_cells": len(cells),
|
||||||
|
"non_empty_cells": sum(1 for c in cells if c.get("text")),
|
||||||
|
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Store preprocessed image as cropped_png so OverlayReconstruction shows it
|
||||||
|
await update_session_db(
|
||||||
|
session_id,
|
||||||
|
word_result=word_result,
|
||||||
|
cropped_png=img_png,
|
||||||
|
current_step=8,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"paddle_direct session %s: %d cells (%d rows, %d cols) in %.2fs",
|
||||||
|
session_id, len(cells), n_rows, n_cols, duration,
|
||||||
|
)
|
||||||
|
|
||||||
|
await _append_pipeline_log(session_id, "paddle_direct", {
|
||||||
|
"total_cells": len(cells),
|
||||||
|
"non_empty_cells": word_result["summary"]["non_empty_cells"],
|
||||||
|
"ocr_engine": "paddle_direct",
|
||||||
|
}, duration_ms=int(duration * 1000))
|
||||||
|
|
||||||
|
return {"session_id": session_id, **word_result}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Ground Truth Words Endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/ground-truth/words")
|
||||||
|
async def save_word_ground_truth(session_id: str, req: WordGroundTruthRequest):
|
||||||
|
"""Save ground truth feedback for the word recognition step."""
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
ground_truth = session.get("ground_truth") or {}
|
||||||
|
gt = {
|
||||||
|
"is_correct": req.is_correct,
|
||||||
|
"corrected_entries": req.corrected_entries,
|
||||||
|
"notes": req.notes,
|
||||||
|
"saved_at": datetime.utcnow().isoformat(),
|
||||||
|
"word_result": session.get("word_result"),
|
||||||
|
}
|
||||||
|
ground_truth["words"] = gt
|
||||||
|
|
||||||
|
await update_session_db(session_id, ground_truth=ground_truth)
|
||||||
|
|
||||||
|
if session_id in _cache:
|
||||||
|
_cache[session_id]["ground_truth"] = ground_truth
|
||||||
|
|
||||||
|
return {"session_id": session_id, "ground_truth": gt}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}/ground-truth/words")
|
||||||
|
async def get_word_ground_truth(session_id: str):
|
||||||
|
"""Retrieve saved ground truth for word recognition."""
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
ground_truth = session.get("ground_truth") or {}
|
||||||
|
words_gt = ground_truth.get("words")
|
||||||
|
if not words_gt:
|
||||||
|
raise HTTPException(status_code=404, detail="No word ground truth saved")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"session_id": session_id,
|
||||||
|
"words_gt": words_gt,
|
||||||
|
"words_auto": session.get("word_result"),
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user