Each module is under 1050 lines: - ocr_pipeline_common.py (354) - shared state, cache, models, helpers - ocr_pipeline_sessions.py (483) - session CRUD, image serving, doc-type - ocr_pipeline_geometry.py (1025) - deskew, dewarp, structure, columns - ocr_pipeline_rows.py (348) - row detection, box-overlay helper - ocr_pipeline_words.py (876) - word detection (SSE), paddle-direct - ocr_pipeline_ocr_merge.py (615) - merge helpers, kombi endpoints - ocr_pipeline_postprocess.py (929) - LLM review, reconstruction, export - ocr_pipeline_auto.py (705) - auto-mode orchestrator, reprocess ocr_pipeline_api.py is now a 61-line thin wrapper that re-exports router, _cache, and test-imported symbols for backward compatibility. No changes needed in main.py or tests. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
930 lines
35 KiB
Python
930 lines
35 KiB
Python
"""
|
||
OCR Pipeline Postprocessing API — LLM review, reconstruction, export, validation,
|
||
image detection/generation, and handwriting removal endpoints.
|
||
|
||
Extracted from ocr_pipeline_api.py to keep the main module manageable.
|
||
|
||
Lizenz: Apache 2.0
|
||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||
"""
|
||
|
||
import json
|
||
import logging
|
||
import os
|
||
import re
|
||
from datetime import datetime
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
from fastapi import APIRouter, HTTPException, Request
|
||
from fastapi.responses import StreamingResponse
|
||
from pydantic import BaseModel
|
||
|
||
from cv_vocab_pipeline import (
|
||
OLLAMA_REVIEW_MODEL,
|
||
llm_review_entries,
|
||
llm_review_entries_streaming,
|
||
)
|
||
from ocr_pipeline_session_store import (
|
||
get_session_db,
|
||
get_session_image,
|
||
get_sub_sessions,
|
||
update_session_db,
|
||
)
|
||
from ocr_pipeline_common import (
|
||
_cache,
|
||
_load_session_to_cache,
|
||
_get_cached,
|
||
_get_base_image_png,
|
||
_append_pipeline_log,
|
||
RemoveHandwritingRequest,
|
||
)
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Pydantic Models
|
||
# ---------------------------------------------------------------------------
|
||
|
||
# Prompt suffixes appended to user prompts during image generation.
# Keys are the accepted values of GenerateImageRequest.style; unknown
# styles fall back to "educational" at the call site.
STYLE_SUFFIXES = {
    "educational": "educational illustration, textbook style, clear, colorful",
    "cartoon": "cartoon, child-friendly, simple shapes",
    "sketch": "pencil sketch, hand-drawn, black and white",
    "clipart": "clipart, flat vector style, simple",
    "realistic": "photorealistic, high detail",
}
|
||
|
||
|
||
class ValidationRequest(BaseModel):
    """Body for POST .../reconstruction/validate (final sign-off)."""

    # Free-text reviewer notes; stored verbatim under ground_truth["validation"].
    notes: Optional[str] = None
    # Numeric quality score; stored verbatim (no range enforcement here).
    score: Optional[int] = None
|
||
|
||
|
||
class GenerateImageRequest(BaseModel):
    """Body for POST .../reconstruction/generate-image."""

    # Index into the previously detected validation["image_regions"] list.
    region_index: int
    # Base generation prompt; a STYLE_SUFFIXES entry is appended to it.
    prompt: str
    # One of the STYLE_SUFFIXES keys; unknown values fall back to "educational".
    style: str = "educational"
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Step 8: LLM Review
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@router.post("/sessions/{session_id}/llm-review")
async def run_llm_review(session_id: str, request: Request, stream: bool = False):
    """Run LLM-based correction on vocab entries from Step 5.

    Query params:
        stream: false (default) for JSON response, true for SSE streaming

    Raises:
        HTTPException 404: unknown session.
        HTTPException 400: no word result / no vocab entries yet.
        HTTPException 502: the (non-streaming) LLM call failed.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found — run Step 5 first")

    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    if not entries:
        raise HTTPException(status_code=400, detail="No vocab entries found — run Step 5 first")

    # The request body is optional; it may carry a model override.
    try:
        payload = await request.json()
    except Exception:
        payload = {}
    model = payload.get("model") or OLLAMA_REVIEW_MODEL

    if stream:
        sse_headers = {
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        }
        generator = _llm_review_stream_generator(session_id, entries, word_result, model, request)
        return StreamingResponse(generator, media_type="text/event-stream", headers=sse_headers)

    # Non-streaming path: single blocking review call.
    try:
        result = await llm_review_entries(entries, model=model)
    except Exception as e:
        import traceback
        logger.error(f"LLM review failed for session {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}")
        raise HTTPException(status_code=502, detail=f"LLM review failed ({type(e).__name__}): {e}")

    # Persist the review as a sub-key of word_result and advance to step 9.
    word_result["llm_review"] = {
        key: result[key]
        for key in ("changes", "model_used", "duration_ms", "entries_corrected")
    }
    await update_session_db(session_id, word_result=word_result, current_step=9)

    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    logger.info(f"LLM review session {session_id}: {len(result['changes'])} changes, "
                f"{result['duration_ms']}ms, model={result['model_used']}")

    await _append_pipeline_log(session_id, "correction", {
        "engine": "llm",
        "model": result["model_used"],
        "total_entries": len(entries),
        "corrections_proposed": len(result["changes"]),
    }, duration_ms=result["duration_ms"])

    return {
        "session_id": session_id,
        "changes": result["changes"],
        "model_used": result["model_used"],
        "duration_ms": result["duration_ms"],
        "total_entries": len(entries),
        "corrections_found": len(result["changes"]),
    }
|
||
|
||
|
||
async def _llm_review_stream_generator(
    session_id: str,
    entries: List[Dict],
    word_result: Dict,
    model: str,
    request: Request,
):
    """SSE generator that yields batch-by-batch LLM review progress.

    Each event from llm_review_entries_streaming is forwarded as one SSE
    `data:` frame; the terminal "complete" event is additionally persisted
    into word_result["llm_review"]. Errors are emitted as an SSE "error"
    event rather than raised.
    """
    try:
        async for event in llm_review_entries_streaming(entries, model=model):
            # Stop early once the client has gone away.
            if await request.is_disconnected():
                logger.info(f"SSE: client disconnected during LLM review for {session_id}")
                return

            yield f"data: {json.dumps(event, ensure_ascii=False)}\n\n"

            if event.get("type") != "complete":
                continue

            # Terminal event: persist the review summary and advance to step 9.
            word_result["llm_review"] = {
                "changes": event["changes"],
                "model_used": event["model_used"],
                "duration_ms": event["duration_ms"],
                "entries_corrected": event["entries_corrected"],
            }
            await update_session_db(session_id, word_result=word_result, current_step=9)
            if session_id in _cache:
                _cache[session_id]["word_result"] = word_result

            logger.info(f"LLM review SSE session {session_id}: {event['corrections_found']} changes, "
                        f"{event['duration_ms']}ms, skipped={event['skipped']}, model={event['model_used']}")

    except Exception as e:
        import traceback
        logger.error(f"LLM review SSE failed for {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}")
        error_event = {"type": "error", "detail": f"{type(e).__name__}: {e}"}
        yield f"data: {json.dumps(error_event)}\n\n"
|
||
|
||
|
||
@router.post("/sessions/{session_id}/llm-review/apply")
async def apply_llm_corrections(session_id: str, request: Request):
    """Apply selected LLM corrections to vocab entries.

    The request body carries accepted_indices: positions into the stored
    llm_review["changes"] list. Accepted changes are written back into the
    matching entry fields and the entries are marked llm_corrected.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    llm_review = word_result.get("llm_review")
    if not llm_review:
        raise HTTPException(status_code=400, detail="No LLM review found — run /llm-review first")

    body = await request.json()
    accepted_indices = set(body.get("accepted_indices", []))  # indices into changes[]

    changes = llm_review.get("changes", [])
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []

    # Accepted replacements keyed by (row_index, field); applied_count counts
    # every accepted change, even if two target the same cell.
    corrections = {}
    applied_count = 0
    for idx, change in enumerate(changes):
        if idx not in accepted_indices:
            continue
        corrections[(change["row_index"], change["field"])] = change["new"]
        applied_count += 1

    # Write accepted values into the matching entry fields.
    for entry in entries:
        row_idx = entry.get("row_index", -1)
        for field_name in ("english", "german", "example"):
            key = (row_idx, field_name)
            if key not in corrections:
                continue
            entry[field_name] = corrections[key]
            entry["llm_corrected"] = True

    # Mirror the entries under both keys and stamp the review as applied.
    word_result["vocab_entries"] = entries
    word_result["entries"] = entries
    word_result["llm_review"]["applied_count"] = applied_count
    word_result["llm_review"]["applied_at"] = datetime.utcnow().isoformat()

    await update_session_db(session_id, word_result=word_result)

    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    logger.info(f"Applied {applied_count}/{len(changes)} LLM corrections for session {session_id}")

    return {
        "session_id": session_id,
        "applied_count": applied_count,
        "total_changes": len(changes),
    }
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Step 9: Reconstruction + Fabric JSON export
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@router.post("/sessions/{session_id}/reconstruction")
async def save_reconstruction(session_id: str, request: Request):
    """Save edited cell texts from reconstruction step.

    Body: {"cells": [{"cell_id": ..., "text": ...}, ...]}. Cell ids prefixed
    with "box{N}_" are routed to the matching sub-session; all others update
    the main session's word_result cells. Matching vocab-entry fields
    (english/german/example) are updated alongside the main cells.
    Advances current_step to 10 even when no cells were sent.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    body = await request.json()
    cell_updates = body.get("cells", [])

    # Empty payload still advances the pipeline step.
    if not cell_updates:
        await update_session_db(session_id, current_step=10)
        return {"session_id": session_id, "updated": 0}

    # Build update map: cell_id -> new text
    update_map = {c["cell_id"]: c["text"] for c in cell_updates}

    # Separate sub-session updates (cell_ids prefixed with "box{N}_")
    sub_updates: Dict[int, Dict[str, str]] = {}  # box_index -> {original_cell_id: text}
    main_updates: Dict[str, str] = {}
    for cell_id, text in update_map.items():
        m = re.match(r'^box(\d+)_(.+)$', cell_id)
        if m:
            bi = int(m.group(1))
            original_id = m.group(2)
            sub_updates.setdefault(bi, {})[original_id] = text
        else:
            main_updates[cell_id] = text

    # Update main session cells; edited cells are flagged status="edited".
    cells = word_result.get("cells", [])
    updated_count = 0
    for cell in cells:
        if cell["cell_id"] in main_updates:
            cell["text"] = main_updates[cell["cell_id"]]
            cell["status"] = "edited"
            updated_count += 1

    word_result["cells"] = cells

    # Also update vocab_entries if present
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    if entries:
        # Map cell_id pattern "R{row}_C{col}" to entry fields
        for entry in entries:
            row_idx = entry.get("row_index", -1)
            # Check each field's cell (column order: english, german, example)
            for col_idx, field_name in enumerate(["english", "german", "example"]):
                cell_id = f"R{row_idx:02d}_C{col_idx}"
                # Also try without zero-padding
                cell_id_alt = f"R{row_idx}_C{col_idx}"
                new_text = main_updates.get(cell_id) or main_updates.get(cell_id_alt)
                if new_text is not None:
                    entry[field_name] = new_text

        # Keep both entry keys in sync (older sessions used "entries").
        word_result["vocab_entries"] = entries
        if "entries" in word_result:
            word_result["entries"] = entries

    await update_session_db(session_id, word_result=word_result, current_step=10)

    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    # Route sub-session updates: each box index maps to one sub-session.
    sub_updated = 0
    if sub_updates:
        subs = await get_sub_sessions(session_id)
        sub_by_index = {s.get("box_index"): s["id"] for s in subs}
        for bi, updates in sub_updates.items():
            sub_id = sub_by_index.get(bi)
            if not sub_id:
                continue
            sub_session = await get_session_db(sub_id)
            if not sub_session:
                continue
            sub_word = sub_session.get("word_result")
            if not sub_word:
                continue
            sub_cells = sub_word.get("cells", [])
            for cell in sub_cells:
                if cell["cell_id"] in updates:
                    cell["text"] = updates[cell["cell_id"]]
                    cell["status"] = "edited"
                    sub_updated += 1
            sub_word["cells"] = sub_cells
            await update_session_db(sub_id, word_result=sub_word)
            if sub_id in _cache:
                _cache[sub_id]["word_result"] = sub_word

    total_updated = updated_count + sub_updated
    logger.info(f"Reconstruction saved for session {session_id}: "
                f"{updated_count} main + {sub_updated} sub-session cells updated")

    return {
        "session_id": session_id,
        "updated": total_updated,
        "main_updated": updated_count,
        "sub_updated": sub_updated,
    }
|
||
|
||
|
||
@router.get("/sessions/{session_id}/reconstruction/fabric-json")
async def get_fabric_json(session_id: str):
    """Return cell grid as Fabric.js-compatible JSON for the canvas editor.

    If the session has sub-sessions (box regions), their cells are merged
    into the result at the correct Y positions.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    # Copy so merged sub-session cells don't mutate the stored word_result.
    cells = list(word_result.get("cells", []))
    img_w = word_result.get("image_width", 800)
    img_h = word_result.get("image_height", 600)

    # Merge sub-session cells at box positions
    subs = await get_sub_sessions(session_id)
    if subs:
        # Box zones carry each sub-session's origin on the full page.
        column_result = session.get("column_result") or {}
        zones = column_result.get("zones") or []
        box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]

        for sub in subs:
            sub_session = await get_session_db(sub["id"])
            if not sub_session:
                continue
            sub_word = sub_session.get("word_result")
            if not sub_word or not sub_word.get("cells"):
                continue

            # box_index pairs the sub-session with its zone; fall back to
            # origin (0, 0) when the zone list is shorter than expected.
            bi = sub.get("box_index", 0)
            if bi < len(box_zones):
                box = box_zones[bi]["box"]
                box_y, box_x = box["y"], box["x"]
            else:
                box_y, box_x = 0, 0

            # Offset sub-session cells to absolute page coordinates
            for cell in sub_word["cells"]:
                cell_copy = dict(cell)
                # Prefix cell_id with box index
                cell_copy["cell_id"] = f"box{bi}_{cell_copy.get('cell_id', '')}"
                cell_copy["source"] = f"box_{bi}"
                # Offset bbox_px (copied first so the sub-session stays intact)
                bbox = cell_copy.get("bbox_px", {})
                if bbox:
                    bbox = dict(bbox)
                    bbox["x"] = bbox.get("x", 0) + box_x
                    bbox["y"] = bbox.get("y", 0) + box_y
                    cell_copy["bbox_px"] = bbox
                cells.append(cell_copy)

    # Local import keeps the heavy service out of module import time.
    from services.layout_reconstruction_service import cells_to_fabric_json
    fabric_json = cells_to_fabric_json(cells, img_w, img_h)

    return fabric_json
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Vocab entries merged + PDF/DOCX export
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@router.get("/sessions/{session_id}/vocab-entries/merged")
async def get_merged_vocab_entries(session_id: str):
    """Return vocab entries from main session + all sub-sessions, sorted by Y position.

    Main entries are tagged source="main"; sub-session entries are tagged
    source="box_{N}" plus a source_y used for vertical ordering.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result") or {}
    entries = list(word_result.get("vocab_entries") or word_result.get("entries") or [])

    # Main-session entries default to source="main".
    for entry in entries:
        entry.setdefault("source", "main")

    subs = await get_sub_sessions(session_id)
    if subs:
        # Box zones give each sub-session its vertical placement on the page.
        zones = (session.get("column_result") or {}).get("zones") or []
        box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]

        for sub in subs:
            sub_session = await get_session_db(sub["id"])
            if not sub_session:
                continue
            sub_word = sub_session.get("word_result") or {}
            sub_entries = sub_word.get("vocab_entries") or sub_word.get("entries") or []

            bi = sub.get("box_index", 0)
            box_y = box_zones[bi]["box"]["y"] if bi < len(box_zones) else 0

            for sub_entry in sub_entries:
                tagged = dict(sub_entry)
                tagged["source"] = f"box_{bi}"
                tagged["source_y"] = box_y  # for sorting below
                entries.append(tagged)

    # Order top-to-bottom: main entries by row index, box entries by their
    # box's Y position (then row index within the box).
    def _sort_key(entry):
        if entry.get("source", "main") == "main":
            return entry.get("row_index", 0) * 100
        return entry.get("source_y", 0) * 100 + entry.get("row_index", 0)

    entries.sort(key=_sort_key)

    return {
        "session_id": session_id,
        "entries": entries,
        "total": len(entries),
        "sources": list(set(e.get("source", "main") for e in entries)),
    }
|
||
|
||
|
||
@router.get("/sessions/{session_id}/reconstruction/export/pdf")
async def export_reconstruction_pdf(session_id: str):
    """Export the reconstructed cell grid as a PDF table.

    Builds a header row plus n_rows data rows from word_result cells and
    renders them with reportlab.

    Raises:
        HTTPException 404: unknown session.
        HTTPException 400: no word result / nothing to export.
        HTTPException 501: reportlab is not installed.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    cells = word_result.get("cells", [])
    columns_used = word_result.get("columns_used", [])
    grid_shape = word_result.get("grid_shape", {})
    n_rows = grid_shape.get("rows", 0)
    n_cols = grid_shape.get("cols", 0)

    # Index cells by id once — avoids an O(cells) scan per grid position.
    cell_by_id = {c.get("cell_id"): c for c in cells}

    # Build table data: header row + n_rows data rows.
    table_data: list[list[str]] = []
    header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
    if not header:
        header = [f"Col {i}" for i in range(n_cols)]
    # BUGFIX: reportlab's Table expects rectangular data, but columns_used may
    # disagree with grid_shape.cols — pad/truncate the header to n_cols so the
    # header row matches the data-row width.
    if n_cols:
        header = (header + [f"Col {i}" for i in range(len(header), n_cols)])[:n_cols]
    table_data.append(header)

    for r in range(n_rows):
        row_texts = []
        for ci in range(n_cols):
            cell = cell_by_id.get(f"R{r:02d}_C{ci}")
            row_texts.append(cell.get("text", "") if cell else "")
        table_data.append(row_texts)

    # Generate PDF with reportlab
    try:
        from reportlab.lib.pagesizes import A4
        from reportlab.lib import colors
        from reportlab.platypus import SimpleDocTemplate, Table, TableStyle
        import io as _io

        buf = _io.BytesIO()
        doc = SimpleDocTemplate(buf, pagesize=A4)
        if not table_data or not table_data[0]:
            raise HTTPException(status_code=400, detail="No data to export")

        t = Table(table_data)
        t.setStyle(TableStyle([
            ('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#0d9488')),
            ('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
            ('FONTSIZE', (0, 0), (-1, -1), 9),
            ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
            ('VALIGN', (0, 0), (-1, -1), 'TOP'),
            # NOTE(review): 'WORDWRAP' is not a documented reportlab TableStyle
            # command — confirm it actually has an effect here.
            ('WORDWRAP', (0, 0), (-1, -1), True),
        ]))
        doc.build([t])
        buf.seek(0)

        from fastapi.responses import StreamingResponse
        return StreamingResponse(
            buf,
            media_type="application/pdf",
            headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.pdf"'},
        )
    except ImportError:
        raise HTTPException(status_code=501, detail="reportlab not installed")
|
||
|
||
|
||
@router.get("/sessions/{session_id}/reconstruction/export/docx")
async def export_reconstruction_docx(session_id: str):
    """Export the reconstructed cell grid as a DOCX table.

    Builds a (1 + n_rows) x max(n_cols, 1) 'Table Grid' table with header
    labels from columns_used and cell texts from word_result cells.

    Raises:
        HTTPException 404: unknown session.
        HTTPException 400: no word result.
        HTTPException 501: python-docx is not installed.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    cells = word_result.get("cells", [])
    columns_used = word_result.get("columns_used", [])
    grid_shape = word_result.get("grid_shape", {})
    n_rows = grid_shape.get("rows", 0)
    n_cols = grid_shape.get("cols", 0)

    try:
        from docx import Document
        import io as _io

        doc = Document()
        doc.add_heading(f'Rekonstruktion – Session {session_id[:8]}', level=1)

        # Header labels from columns_used, falling back to generic names.
        header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
        if not header:
            header = [f"Col {i}" for i in range(n_cols)]

        table_cols = max(n_cols, 1)
        table = doc.add_table(rows=1 + n_rows, cols=table_cols)
        table.style = 'Table Grid'

        # BUGFIX: clamp the header to the actual table width — when
        # columns_used had more entries than grid_shape.cols this raised
        # IndexError on table.rows[0].cells[ci].
        for ci, h in enumerate(header[:table_cols]):
            table.rows[0].cells[ci].text = h

        # Data rows: O(1) lookup by cell id instead of a scan per grid cell.
        cell_by_id = {c.get("cell_id"): c for c in cells}
        for r in range(n_rows):
            for ci in range(n_cols):
                cell = cell_by_id.get(f"R{r:02d}_C{ci}")
                table.rows[r + 1].cells[ci].text = cell.get("text", "") if cell else ""

        buf = _io.BytesIO()
        doc.save(buf)
        buf.seek(0)

        from fastapi.responses import StreamingResponse
        return StreamingResponse(
            buf,
            media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.docx"'},
        )
    except ImportError:
        raise HTTPException(status_code=501, detail="python-docx not installed")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Step 8: Validation — Original vs. Reconstruction
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@router.post("/sessions/{session_id}/reconstruction/detect-images")
async def detect_image_regions(session_id: str):
    """Detect illustration/image regions in the original scan using VLM.

    Sends the original image to qwen2.5vl to find non-text, non-table
    image areas, returning bounding boxes (in %) and descriptions.

    Detected regions are persisted under ground_truth["validation"]
    ["image_regions"] and returned to the caller. VLM or parsing failures
    are returned as an empty region list with an "error" field (HTTP 200),
    not raised.
    """
    import base64
    import httpx
    import re  # NOTE: shadows the module-level `re` import locally (harmless)

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # Get original image bytes
    original_png = await get_session_image(session_id, "original")
    if not original_png:
        raise HTTPException(status_code=400, detail="No original image found")

    # Build context from vocab entries for richer descriptions
    word_result = session.get("word_result") or {}
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    vocab_context = ""
    if entries:
        # Only the first 10 entries with an English word are sampled.
        sample = entries[:10]
        words = [f"{e.get('english', '')} / {e.get('german', '')}" for e in sample if e.get('english')]
        if words:
            vocab_context = f"\nContext: This is a vocabulary page with words like: {', '.join(words)}"

    ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
    model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")

    prompt = (
        "Analyze this scanned page. Find ALL illustration/image/picture regions "
        "(NOT text, NOT table cells, NOT blank areas). "
        "For each image region found, return its bounding box as percentage of page dimensions "
        "and a short English description of what the image shows. "
        "Reply with ONLY a JSON array like: "
        '[{"x": 10, "y": 20, "w": 30, "h": 25, "description": "drawing of a cat"}] '
        "where x, y, w, h are percentages (0-100) of the page width/height. "
        "If there are NO images on the page, return an empty array: []"
        f"{vocab_context}"
    )

    img_b64 = base64.b64encode(original_png).decode("utf-8")
    payload = {
        "model": model,
        "prompt": prompt,
        "images": [img_b64],
        "stream": False,
    }

    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            resp = await client.post(f"{ollama_base}/api/generate", json=payload)
            resp.raise_for_status()
            text = resp.json().get("response", "")

        # Parse JSON array from response. Non-greedy match stops at the first
        # ']' — fine for a flat array of objects, but would truncate if a
        # description itself contained ']'.
        match = re.search(r'\[.*?\]', text, re.DOTALL)
        if match:
            raw_regions = json.loads(match.group(0))
        else:
            raw_regions = []

        # Normalize to ImageRegion format: clamp percentages to 0-100 and
        # enforce a minimum 1% width/height.
        regions = []
        for r in raw_regions:
            regions.append({
                "bbox_pct": {
                    "x": max(0, min(100, float(r.get("x", 0)))),
                    "y": max(0, min(100, float(r.get("y", 0)))),
                    "w": max(1, min(100, float(r.get("w", 10)))),
                    "h": max(1, min(100, float(r.get("h", 10)))),
                },
                "description": r.get("description", ""),
                "prompt": r.get("description", ""),
                "image_b64": None,
                "style": "educational",
            })

        # Enrich prompts with nearby vocab context
        if entries:
            for region in regions:
                ry = region["bbox_pct"]["y"]
                rh = region["bbox_pct"]["h"]
                # NOTE(review): region y/h are in percent, but the units of an
                # entry's bbox.y are not established here — confirm both are
                # percentages, otherwise this proximity test is meaningless.
                nearby = [
                    e for e in entries
                    if e.get("bbox") and abs(e["bbox"].get("y", 0) - ry) < rh + 10
                ]
                if nearby:
                    en_words = [e.get("english", "") for e in nearby if e.get("english")]
                    de_words = [e.get("german", "") for e in nearby if e.get("german")]
                    if en_words or de_words:
                        context = f" (vocabulary context: {', '.join(en_words[:5])}"
                        if de_words:
                            context += f" / {', '.join(de_words[:5])}"
                        context += ")"
                        region["prompt"] = region["description"] + context

        # Save to ground_truth JSONB
        ground_truth = session.get("ground_truth") or {}
        validation = ground_truth.get("validation") or {}
        validation["image_regions"] = regions
        validation["detected_at"] = datetime.utcnow().isoformat()
        ground_truth["validation"] = validation
        await update_session_db(session_id, ground_truth=ground_truth)

        if session_id in _cache:
            _cache[session_id]["ground_truth"] = ground_truth

        logger.info(f"Detected {len(regions)} image regions for session {session_id}")

        return {"regions": regions, "count": len(regions)}

    except httpx.ConnectError:
        logger.warning(f"VLM not available at {ollama_base} for image detection")
        return {"regions": [], "count": 0, "error": "VLM not available"}
    except Exception as e:
        logger.error(f"Image detection failed for {session_id}: {e}")
        return {"regions": [], "count": 0, "error": str(e)}
|
||
|
||
|
||
@router.post("/sessions/{session_id}/reconstruction/generate-image")
async def generate_image_for_region(session_id: str, req: GenerateImageRequest):
    """Generate a replacement image for a detected region using mflux.

    Sends the prompt (with style suffix) to the mflux-service running
    natively on the Mac Mini (Metal GPU required).

    On success, the base64 image is stored on the region entry inside
    ground_truth["validation"]["image_regions"]. Service failures are
    returned as {"success": False, ...} with HTTP 200, not raised.
    """
    import httpx

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    validation = ground_truth.get("validation") or {}
    regions = validation.get("image_regions") or []

    if req.region_index < 0 or req.region_index >= len(regions):
        raise HTTPException(status_code=400, detail=f"Invalid region_index {req.region_index}, have {len(regions)} regions")

    mflux_url = os.getenv("MFLUX_URL", "http://host.docker.internal:8095")
    # Unknown style values fall back to the "educational" suffix.
    style_suffix = STYLE_SUFFIXES.get(req.style, STYLE_SUFFIXES["educational"])
    full_prompt = f"{req.prompt}, {style_suffix}"

    # Determine image size from region aspect ratio (snap to multiples of 64)
    region = regions[req.region_index]
    bbox = region["bbox_pct"]
    aspect = bbox["w"] / max(bbox["h"], 1)  # max() guards against h == 0
    if aspect > 1.3:
        width, height = 768, 512
    elif aspect < 0.7:
        width, height = 512, 768
    else:
        width, height = 512, 512

    try:
        async with httpx.AsyncClient(timeout=300.0) as client:
            resp = await client.post(f"{mflux_url}/generate", json={
                "prompt": full_prompt,
                "width": width,
                "height": height,
                "steps": 4,
            })
            resp.raise_for_status()
            data = resp.json()
            image_b64 = data.get("image_b64")

            if not image_b64:
                return {"image_b64": None, "success": False, "error": "No image returned"}

            # Save to ground_truth (note: the un-suffixed prompt is stored)
            regions[req.region_index]["image_b64"] = image_b64
            regions[req.region_index]["prompt"] = req.prompt
            regions[req.region_index]["style"] = req.style
            validation["image_regions"] = regions
            ground_truth["validation"] = validation
            await update_session_db(session_id, ground_truth=ground_truth)

            if session_id in _cache:
                _cache[session_id]["ground_truth"] = ground_truth

            logger.info(f"Generated image for session {session_id} region {req.region_index}")
            return {"image_b64": image_b64, "success": True}

    except httpx.ConnectError:
        logger.warning(f"mflux-service not available at {mflux_url}")
        return {"image_b64": None, "success": False, "error": f"mflux-service not available at {mflux_url}"}
    except Exception as e:
        logger.error(f"Image generation failed for {session_id}: {e}")
        return {"image_b64": None, "success": False, "error": str(e)}
|
||
|
||
|
||
@router.post("/sessions/{session_id}/reconstruction/validate")
async def save_validation(session_id: str, req: ValidationRequest):
    """Save final validation results for step 8.

    Stores notes and score, preserving any detected/generated image regions
    already present under ground_truth["validation"]. Sets current_step = 11
    to mark the pipeline as complete.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    validation = ground_truth.get("validation") or {}
    validation.update(
        validated_at=datetime.utcnow().isoformat(),
        notes=req.notes,
        score=req.score,
    )
    ground_truth["validation"] = validation

    await update_session_db(session_id, ground_truth=ground_truth, current_step=11)

    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    logger.info(f"Validation saved for session {session_id}: score={req.score}")

    return {"session_id": session_id, "validation": validation}
|
||
|
||
|
||
@router.get("/sessions/{session_id}/reconstruction/validation")
async def get_validation(session_id: str):
    """Retrieve saved validation data for step 8.

    Returns the stored validation dict (None when none was saved yet)
    alongside the session's word_result.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    validation = (session.get("ground_truth") or {}).get("validation")

    return {
        "session_id": session_id,
        "validation": validation,
        "word_result": session.get("word_result"),
    }
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Remove handwriting
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@router.post("/sessions/{session_id}/remove-handwriting")
async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingRequest):
    """
    Remove handwriting from a session image using inpainting.

    Steps:
    1. Load source image (auto -> deskewed if available, else original)
    2. Detect handwriting mask (filtered by target_ink)
    3. Dilate mask to cover stroke edges
    4. Inpaint the image
    5. Store result as clean_png in the session

    Returns metadata including the URL to fetch the clean image.

    Raises:
        HTTPException 404: unknown session or requested source image missing.
        HTTPException 500: the inpainting service reported failure.
    """
    import time as _time
    t0 = _time.monotonic()  # start of timing for processing_time_ms

    # Local imports keep the heavy CV services out of module import time.
    from services.handwriting_detection import detect_handwriting
    from services.inpainting_service import inpaint_image, dilate_mask as _dilate_mask, InpaintingMethod, image_to_png

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # 1. Determine source image ("auto" prefers the deskewed variant)
    source = req.use_source
    if source == "auto":
        deskewed = await get_session_image(session_id, "deskewed")
        source = "deskewed" if deskewed else "original"

    image_bytes = await get_session_image(session_id, source)
    if not image_bytes:
        raise HTTPException(status_code=404, detail=f"Source image '{source}' not available")

    # 2. Detect handwriting mask
    detection = detect_handwriting(image_bytes, target_ink=req.target_ink)

    # 3. Convert mask (ndarray) to PNG bytes and dilate
    import io
    from PIL import Image as _PILImage
    mask_img = _PILImage.fromarray(detection.mask)
    mask_buf = io.BytesIO()
    mask_img.save(mask_buf, format="PNG")
    mask_bytes = mask_buf.getvalue()

    if req.dilation > 0:
        # Grow the mask so inpainting also covers stroke edges.
        mask_bytes = _dilate_mask(mask_bytes, iterations=req.dilation)

    # 4. Inpaint — unknown method strings silently fall back to AUTO.
    method_map = {
        "telea": InpaintingMethod.OPENCV_TELEA,
        "ns": InpaintingMethod.OPENCV_NS,
        "auto": InpaintingMethod.AUTO,
    }
    inpaint_method = method_map.get(req.method, InpaintingMethod.AUTO)

    result = inpaint_image(image_bytes, mask_bytes, method=inpaint_method)
    if not result.success:
        raise HTTPException(status_code=500, detail="Inpainting failed")

    elapsed_ms = int((_time.monotonic() - t0) * 1000)

    meta = {
        # method_used may be an enum (has .value) or a plain object.
        "method_used": result.method_used.value if hasattr(result.method_used, "value") else str(result.method_used),
        "handwriting_ratio": round(detection.handwriting_ratio, 4),
        "detection_confidence": round(detection.confidence, 4),
        "target_ink": req.target_ink,
        "dilation": req.dilation,
        "source_image": source,
        "processing_time_ms": elapsed_ms,
    }

    # 5. Persist clean image (convert BGR ndarray -> PNG bytes)
    clean_png_bytes = image_to_png(result.image)
    await update_session_db(session_id, clean_png=clean_png_bytes, handwriting_removal_meta=meta)

    return {
        **meta,
        "image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean",
        "session_id": session_id,
    }
|