Files
breakpilot-lehrer/klausur-service/backend/ocr_pipeline_postprocess.py
Benjamin Admin ec287fd12e refactor: split ocr_pipeline_api.py (5426 lines) into 8 modules
Each module is under 1050 lines:
- ocr_pipeline_common.py (354) - shared state, cache, models, helpers
- ocr_pipeline_sessions.py (483) - session CRUD, image serving, doc-type
- ocr_pipeline_geometry.py (1025) - deskew, dewarp, structure, columns
- ocr_pipeline_rows.py (348) - row detection, box-overlay helper
- ocr_pipeline_words.py (876) - word detection (SSE), paddle-direct
- ocr_pipeline_ocr_merge.py (615) - merge helpers, kombi endpoints
- ocr_pipeline_postprocess.py (929) - LLM review, reconstruction, export
- ocr_pipeline_auto.py (705) - auto-mode orchestrator, reprocess

ocr_pipeline_api.py is now a 61-line thin wrapper that re-exports
router, _cache, and test-imported symbols for backward compatibility.
No changes needed in main.py or tests.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 08:42:00 +01:00

930 lines
35 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
OCR Pipeline Postprocessing API — LLM review, reconstruction, export, validation,
image detection/generation, and handwriting removal endpoints.
Extracted from ocr_pipeline_api.py to keep the main module manageable.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import os
import re
from datetime import datetime
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from cv_vocab_pipeline import (
OLLAMA_REVIEW_MODEL,
llm_review_entries,
llm_review_entries_streaming,
)
from ocr_pipeline_session_store import (
get_session_db,
get_session_image,
get_sub_sessions,
update_session_db,
)
from ocr_pipeline_common import (
_cache,
_load_session_to_cache,
_get_cached,
_get_base_image_png,
_append_pipeline_log,
RemoveHandwritingRequest,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Constants & Pydantic Models
# ---------------------------------------------------------------------------
# Maps a style name (as sent by the frontend) to the prompt suffix appended
# when generating replacement images (see generate_image_for_region);
# "educational" is the fallback default.
STYLE_SUFFIXES = {
    "educational": "educational illustration, textbook style, clear, colorful",
    "cartoon": "cartoon, child-friendly, simple shapes",
    "sketch": "pencil sketch, hand-drawn, black and white",
    "clipart": "clipart, flat vector style, simple",
    "realistic": "photorealistic, high detail",
}
class ValidationRequest(BaseModel):
    """Request body for saving step-8 validation results; both fields optional."""
    notes: Optional[str] = None  # free-text reviewer notes
    score: Optional[int] = None  # numeric quality score (range not enforced here)
class GenerateImageRequest(BaseModel):
    """Request body for generating a replacement image for a detected region."""
    region_index: int  # index into the previously detected image_regions list
    prompt: str  # base prompt; a suffix from STYLE_SUFFIXES is appended
    style: str = "educational"  # key into STYLE_SUFFIXES (unknown keys fall back to "educational")
# ---------------------------------------------------------------------------
# Step 8: LLM Review
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/llm-review")
async def run_llm_review(session_id: str, request: Request, stream: bool = False):
    """Run LLM-based correction on vocab entries from Step 5.

    Query params:
        stream: false (default) for JSON response, true for SSE streaming

    Body (optional JSON object):
        model: override for the default OLLAMA_REVIEW_MODEL

    Raises:
        HTTPException: 404 if the session does not exist, 400 if Step 5 has
            not produced entries yet, 502 if the LLM call itself fails.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found — run Step 5 first")
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    if not entries:
        raise HTTPException(status_code=400, detail="No vocab entries found — run Step 5 first")
    # Optional model override from the request body. The body may be absent or
    # invalid JSON; it must also be a JSON *object* — a bare list/string body
    # previously crashed on body.get() with an AttributeError.
    body = {}
    try:
        parsed = await request.json()
        if isinstance(parsed, dict):
            body = parsed
    except Exception:
        pass  # best-effort: no/invalid body simply means "use the default model"
    model = body.get("model") or OLLAMA_REVIEW_MODEL
    if stream:
        return StreamingResponse(
            _llm_review_stream_generator(session_id, entries, word_result, model, request),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
        )
    # Non-streaming path
    try:
        result = await llm_review_entries(entries, model=model)
    except Exception as e:
        import traceback
        logger.error(f"LLM review failed for session {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}")
        raise HTTPException(status_code=502, detail=f"LLM review failed ({type(e).__name__}): {e}")
    # Store result inside word_result as a sub-key
    word_result["llm_review"] = {
        "changes": result["changes"],
        "model_used": result["model_used"],
        "duration_ms": result["duration_ms"],
        "entries_corrected": result["entries_corrected"],
    }
    await update_session_db(session_id, word_result=word_result, current_step=9)
    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result
    logger.info(f"LLM review session {session_id}: {len(result['changes'])} changes, "
                f"{result['duration_ms']}ms, model={result['model_used']}")
    await _append_pipeline_log(session_id, "correction", {
        "engine": "llm",
        "model": result["model_used"],
        "total_entries": len(entries),
        "corrections_proposed": len(result["changes"]),
    }, duration_ms=result["duration_ms"])
    return {
        "session_id": session_id,
        "changes": result["changes"],
        "model_used": result["model_used"],
        "duration_ms": result["duration_ms"],
        "total_entries": len(entries),
        "corrections_found": len(result["changes"]),
    }
async def _llm_review_stream_generator(
    session_id: str,
    entries: List[Dict],
    word_result: Dict,
    model: str,
    request: Request,
):
    """SSE generator that yields batch-by-batch LLM review progress.

    Each event from llm_review_entries_streaming is forwarded verbatim as an
    SSE "data:" frame. On the final "complete" event the review result is
    persisted to the session DB and the in-process cache — note this happens
    *after* the event has already been yielded to the client. Client
    disconnects abort the stream early without persisting anything.
    """
    try:
        async for event in llm_review_entries_streaming(entries, model=model):
            # Stop producing work as soon as the client goes away.
            if await request.is_disconnected():
                logger.info(f"SSE: client disconnected during LLM review for {session_id}")
                return
            yield f"data: {json.dumps(event, ensure_ascii=False)}\n\n"
            # On complete: persist to DB (and mirror into the cache)
            if event.get("type") == "complete":
                word_result["llm_review"] = {
                    "changes": event["changes"],
                    "model_used": event["model_used"],
                    "duration_ms": event["duration_ms"],
                    "entries_corrected": event["entries_corrected"],
                }
                await update_session_db(session_id, word_result=word_result, current_step=9)
                if session_id in _cache:
                    _cache[session_id]["word_result"] = word_result
                logger.info(f"LLM review SSE session {session_id}: {event['corrections_found']} changes, "
                            f"{event['duration_ms']}ms, skipped={event['skipped']}, model={event['model_used']}")
    except Exception as e:
        import traceback
        logger.error(f"LLM review SSE failed for {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}")
        # Errors are reported in-band as a final SSE event, not as an HTTP error.
        error_event = {"type": "error", "detail": f"{type(e).__name__}: {e}"}
        yield f"data: {json.dumps(error_event)}\n\n"
@router.post("/sessions/{session_id}/llm-review/apply")
async def apply_llm_corrections(session_id: str, request: Request):
    """Apply the caller-accepted subset of LLM corrections to vocab entries.

    Body: {"accepted_indices": [...]} — positions into llm_review["changes"].
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")
    llm_review = word_result.get("llm_review")
    if not llm_review:
        raise HTTPException(status_code=400, detail="No LLM review found — run /llm-review first")
    body = await request.json()
    accepted_indices = set(body.get("accepted_indices", []))  # indices into changes[]
    changes = llm_review.get("changes", [])
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    # Collect accepted changes into a (row_index, field) -> new_value map.
    # applied_count counts every accepted change, even duplicates of one key.
    corrections = {}
    applied_count = 0
    for position, change in enumerate(changes):
        if position not in accepted_indices:
            continue
        corrections[(change["row_index"], change["field"])] = change["new"]
        applied_count += 1
    # Write the accepted values back into the matching entry fields.
    for entry in entries:
        entry_row = entry.get("row_index", -1)
        for field_name in ("english", "german", "example"):
            correction_key = (entry_row, field_name)
            if correction_key in corrections:
                entry[field_name] = corrections[correction_key]
                entry["llm_corrected"] = True
    # Persist: mirror entries into both keys and record apply metadata.
    word_result["vocab_entries"] = entries
    word_result["entries"] = entries
    word_result["llm_review"]["applied_count"] = applied_count
    word_result["llm_review"]["applied_at"] = datetime.utcnow().isoformat()
    await update_session_db(session_id, word_result=word_result)
    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result
    logger.info(f"Applied {applied_count}/{len(changes)} LLM corrections for session {session_id}")
    return {
        "session_id": session_id,
        "applied_count": applied_count,
        "total_changes": len(changes),
    }
# ---------------------------------------------------------------------------
# Step 9: Reconstruction + Fabric JSON export
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/reconstruction")
async def save_reconstruction(session_id: str, request: Request):
    """Save edited cell texts from reconstruction step.

    Body: {"cells": [{"cell_id": ..., "text": ...}, ...]}. Cell ids prefixed
    with "box{N}_" are routed to the sub-session for box N; all other ids
    update the main session. Advances current_step to 10 even when the body
    contains no cell updates.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")
    body = await request.json()
    cell_updates = body.get("cells", [])
    if not cell_updates:
        # Nothing to write — still advance the pipeline step.
        await update_session_db(session_id, current_step=10)
        return {"session_id": session_id, "updated": 0}
    # Build update map: cell_id -> new text
    update_map = {c["cell_id"]: c["text"] for c in cell_updates}
    # Separate sub-session updates (cell_ids prefixed with "box{N}_")
    sub_updates: Dict[int, Dict[str, str]] = {}  # box_index -> {original_cell_id: text}
    main_updates: Dict[str, str] = {}
    for cell_id, text in update_map.items():
        m = re.match(r'^box(\d+)_(.+)$', cell_id)
        if m:
            bi = int(m.group(1))
            original_id = m.group(2)  # the cell_id as known inside the sub-session
            sub_updates.setdefault(bi, {})[original_id] = text
        else:
            main_updates[cell_id] = text
    # Update main session cells
    cells = word_result.get("cells", [])
    updated_count = 0
    for cell in cells:
        if cell["cell_id"] in main_updates:
            cell["text"] = main_updates[cell["cell_id"]]
            cell["status"] = "edited"  # mark as human-edited
            updated_count += 1
    word_result["cells"] = cells
    # Also update vocab_entries if present
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    if entries:
        # Map cell_id pattern "R{row}_C{col}" to entry fields; column order is
        # assumed to be english/german/example — TODO confirm against Step 5.
        for entry in entries:
            row_idx = entry.get("row_index", -1)
            # Check each field's cell
            for col_idx, field_name in enumerate(["english", "german", "example"]):
                cell_id = f"R{row_idx:02d}_C{col_idx}"
                # Also try without zero-padding
                cell_id_alt = f"R{row_idx}_C{col_idx}"
                new_text = main_updates.get(cell_id) or main_updates.get(cell_id_alt)
                if new_text is not None:
                    entry[field_name] = new_text
        word_result["vocab_entries"] = entries
        if "entries" in word_result:
            word_result["entries"] = entries
    await update_session_db(session_id, word_result=word_result, current_step=10)
    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result
    # Route sub-session updates
    sub_updated = 0
    if sub_updates:
        subs = await get_sub_sessions(session_id)
        sub_by_index = {s.get("box_index"): s["id"] for s in subs}
        for bi, updates in sub_updates.items():
            sub_id = sub_by_index.get(bi)
            if not sub_id:
                # Box index in the payload has no matching sub-session — skip.
                continue
            sub_session = await get_session_db(sub_id)
            if not sub_session:
                continue
            sub_word = sub_session.get("word_result")
            if not sub_word:
                continue
            sub_cells = sub_word.get("cells", [])
            for cell in sub_cells:
                if cell["cell_id"] in updates:
                    cell["text"] = updates[cell["cell_id"]]
                    cell["status"] = "edited"
                    sub_updated += 1
            sub_word["cells"] = sub_cells
            await update_session_db(sub_id, word_result=sub_word)
            if sub_id in _cache:
                _cache[sub_id]["word_result"] = sub_word
    total_updated = updated_count + sub_updated
    logger.info(f"Reconstruction saved for session {session_id}: "
                f"{updated_count} main + {sub_updated} sub-session cells updated")
    return {
        "session_id": session_id,
        "updated": total_updated,
        "main_updated": updated_count,
        "sub_updated": sub_updated,
    }
@router.get("/sessions/{session_id}/reconstruction/fabric-json")
async def get_fabric_json(session_id: str):
    """Return cell grid as Fabric.js-compatible JSON for the canvas editor.

    If the session has sub-sessions (box regions), their cells are merged
    into the result at the correct Y positions.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")
    cells = list(word_result.get("cells", []))  # copy: sub-session cells get appended below
    img_w = word_result.get("image_width", 800)
    img_h = word_result.get("image_height", 600)
    # Merge sub-session cells at box positions
    subs = await get_sub_sessions(session_id)
    if subs:
        column_result = session.get("column_result") or {}
        zones = column_result.get("zones") or []
        # Only "box" zones that carry a pixel bbox can position a sub-session.
        box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
        for sub in subs:
            sub_session = await get_session_db(sub["id"])
            if not sub_session:
                continue
            sub_word = sub_session.get("word_result")
            if not sub_word or not sub_word.get("cells"):
                continue
            bi = sub.get("box_index", 0)
            if bi < len(box_zones):
                box = box_zones[bi]["box"]
                box_y, box_x = box["y"], box["x"]
            else:
                # No matching zone: fall back to no offset (page top-left).
                box_y, box_x = 0, 0
            # Offset sub-session cells to absolute page coordinates
            for cell in sub_word["cells"]:
                cell_copy = dict(cell)  # shallow copy so stored sub-session data stays untouched
                # Prefix cell_id with box index so ids stay unique page-wide
                cell_copy["cell_id"] = f"box{bi}_{cell_copy.get('cell_id', '')}"
                cell_copy["source"] = f"box_{bi}"
                # Offset bbox_px
                bbox = cell_copy.get("bbox_px", {})
                if bbox:
                    bbox = dict(bbox)  # copy before mutating
                    bbox["x"] = bbox.get("x", 0) + box_x
                    bbox["y"] = bbox.get("y", 0) + box_y
                    cell_copy["bbox_px"] = bbox
                cells.append(cell_copy)
    # Local import: service module with heavy dependencies, loaded on demand.
    from services.layout_reconstruction_service import cells_to_fabric_json
    fabric_json = cells_to_fabric_json(cells, img_w, img_h)
    return fabric_json
# ---------------------------------------------------------------------------
# Vocab entries merged + PDF/DOCX export
# ---------------------------------------------------------------------------
@router.get("/sessions/{session_id}/vocab-entries/merged")
async def get_merged_vocab_entries(session_id: str):
    """Return vocab entries from main session + all sub-sessions, sorted by Y position."""
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    word_result = session.get("word_result") or {}
    entries = list(word_result.get("vocab_entries") or word_result.get("entries") or [])
    # Tag entries from the main session so consumers can tell them apart.
    for entry in entries:
        entry.setdefault("source", "main")
    # Pull in entries from every sub-session (one per detected box region).
    subs = await get_sub_sessions(session_id)
    if subs:
        column_result = session.get("column_result") or {}
        zones = column_result.get("zones") or []
        box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
        for sub in subs:
            sub_session = await get_session_db(sub["id"])
            if not sub_session:
                continue
            sub_word = sub_session.get("word_result") or {}
            sub_entries = sub_word.get("vocab_entries") or sub_word.get("entries") or []
            box_index = sub.get("box_index", 0)
            # The box zone's top edge anchors this sub-session vertically.
            box_top = box_zones[box_index]["box"]["y"] if box_index < len(box_zones) else 0
            for sub_entry in sub_entries:
                tagged = dict(sub_entry)
                tagged["source"] = f"box_{box_index}"
                tagged["source_y"] = box_top  # used for Y-ordering below
                entries.append(tagged)

    # Order by approximate vertical position on the page.
    def _sort_key(entry):
        if entry.get("source", "main") == "main":
            return entry.get("row_index", 0) * 100  # main entries keep row order
        return entry.get("source_y", 0) * 100 + entry.get("row_index", 0)

    entries.sort(key=_sort_key)
    return {
        "session_id": session_id,
        "entries": entries,
        "total": len(entries),
        "sources": list({entry.get("source", "main") for entry in entries}),
    }
@router.get("/sessions/{session_id}/reconstruction/export/pdf")
async def export_reconstruction_pdf(session_id: str):
    """Export the reconstructed cell grid as a PDF table.

    Returns a streamed application/pdf attachment. Raises 404 for an unknown
    session, 400 when no word result / no columns exist, and 501 when the
    optional reportlab dependency is not installed.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")
    cells = word_result.get("cells", [])
    columns_used = word_result.get("columns_used", [])
    grid_shape = word_result.get("grid_shape", {})
    n_rows = grid_shape.get("rows", 0)
    n_cols = grid_shape.get("cols", 0)
    # Index cells by cell_id once instead of a linear scan per grid position
    # (previously O(rows * cols * cells)).
    cell_by_id = {c.get("cell_id"): c for c in cells}
    # Build table data: one header row + n_rows data rows.
    header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
    if not header:
        header = [f"Col {i}" for i in range(n_cols)]
    table_data: list[list[str]] = [header]
    for r in range(n_rows):
        row_texts = []
        for ci in range(n_cols):
            cell = cell_by_id.get(f"R{r:02d}_C{ci}")
            row_texts.append(cell.get("text", "") if cell else "")
        table_data.append(row_texts)
    if not table_data[0]:
        raise HTTPException(status_code=400, detail="No data to export")
    # reportlab is optional; keep the ImportError guard narrow so only a
    # missing library maps to 501 (other failures surface as 500).
    try:
        from reportlab.lib.pagesizes import A4
        from reportlab.lib import colors
        from reportlab.platypus import SimpleDocTemplate, Table, TableStyle
    except ImportError:
        raise HTTPException(status_code=501, detail="reportlab not installed")
    import io as _io
    buf = _io.BytesIO()
    doc = SimpleDocTemplate(buf, pagesize=A4)
    t = Table(table_data)
    t.setStyle(TableStyle([
        ('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#0d9488')),
        ('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
        ('FONTSIZE', (0, 0), (-1, -1), 9),
        ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
        ('VALIGN', (0, 0), (-1, -1), 'TOP'),
        # NOTE(review): 'WORDWRAP' is not a documented TableStyle command in
        # reportlab — verify it has the intended effect (wrapping normally
        # requires Paragraph flowables inside cells).
        ('WORDWRAP', (0, 0), (-1, -1), True),
    ]))
    doc.build([t])
    buf.seek(0)
    # StreamingResponse is already imported at module level (fastapi.responses).
    return StreamingResponse(
        buf,
        media_type="application/pdf",
        headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.pdf"'},
    )
@router.get("/sessions/{session_id}/reconstruction/export/docx")
async def export_reconstruction_docx(session_id: str):
    """Export the reconstructed cell grid as a DOCX table.

    Returns a streamed .docx attachment. Raises 404 for an unknown session,
    400 when no word result exists, and 501 when python-docx is missing.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")
    cells = word_result.get("cells", [])
    columns_used = word_result.get("columns_used", [])
    grid_shape = word_result.get("grid_shape", {})
    n_rows = grid_shape.get("rows", 0)
    n_cols = grid_shape.get("cols", 0)
    # python-docx is optional; only a missing library maps to 501.
    try:
        from docx import Document
    except ImportError:
        raise HTTPException(status_code=501, detail="python-docx not installed")
    import io as _io
    doc = Document()
    doc.add_heading(f'Rekonstruktion Session {session_id[:8]}', level=1)
    # Header labels come from columns_used, falling back to generic names.
    header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
    if not header:
        header = [f"Col {i}" for i in range(n_cols)]
    # Size the table to fit BOTH the grid and the header: previously a header
    # longer than n_cols (e.g. grid_shape missing while columns_used is set)
    # raised IndexError when writing the header row below.
    table = doc.add_table(rows=1 + n_rows, cols=max(n_cols, len(header), 1))
    table.style = 'Table Grid'
    # Header row
    for ci, h in enumerate(header):
        table.rows[0].cells[ci].text = h
    # Data rows: index cells by id once, O(cells) instead of a scan per cell.
    cell_by_id = {c.get("cell_id"): c for c in cells}
    for r in range(n_rows):
        for ci in range(n_cols):
            cell = cell_by_id.get(f"R{r:02d}_C{ci}")
            table.rows[r + 1].cells[ci].text = cell.get("text", "") if cell else ""
    buf = _io.BytesIO()
    doc.save(buf)
    buf.seek(0)
    # StreamingResponse is already imported at module level (fastapi.responses).
    return StreamingResponse(
        buf,
        media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.docx"'},
    )
# ---------------------------------------------------------------------------
# Step 8: Validation — Original vs. Reconstruction
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/reconstruction/detect-images")
async def detect_image_regions(session_id: str):
    """Detect illustration/image regions in the original scan using VLM.

    Sends the original image to qwen2.5vl to find non-text, non-table
    image areas, returning bounding boxes (in %) and descriptions.
    Results are persisted under ground_truth["validation"]["image_regions"].
    VLM/network failures are reported in the response body, not as HTTP errors.
    """
    import base64
    import httpx
    import re  # NOTE: shadows the module-level re import; behavior is identical
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    # Get original image bytes
    original_png = await get_session_image(session_id, "original")
    if not original_png:
        raise HTTPException(status_code=400, detail="No original image found")
    # Build context from vocab entries for richer descriptions
    word_result = session.get("word_result") or {}
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    vocab_context = ""
    if entries:
        sample = entries[:10]  # a small sample is enough context for the VLM
        words = [f"{e.get('english', '')} / {e.get('german', '')}" for e in sample if e.get('english')]
        if words:
            vocab_context = f"\nContext: This is a vocabulary page with words like: {', '.join(words)}"
    ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
    model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
    prompt = (
        "Analyze this scanned page. Find ALL illustration/image/picture regions "
        "(NOT text, NOT table cells, NOT blank areas). "
        "For each image region found, return its bounding box as percentage of page dimensions "
        "and a short English description of what the image shows. "
        "Reply with ONLY a JSON array like: "
        '[{"x": 10, "y": 20, "w": 30, "h": 25, "description": "drawing of a cat"}] '
        "where x, y, w, h are percentages (0-100) of the page width/height. "
        "If there are NO images on the page, return an empty array: []"
        f"{vocab_context}"
    )
    img_b64 = base64.b64encode(original_png).decode("utf-8")
    payload = {
        "model": model,
        "prompt": prompt,
        "images": [img_b64],
        "stream": False,  # single blocking generation call
    }
    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            resp = await client.post(f"{ollama_base}/api/generate", json=payload)
            resp.raise_for_status()
            text = resp.json().get("response", "")
        # Parse JSON array from response (non-greedy: first [...] span found)
        match = re.search(r'\[.*?\]', text, re.DOTALL)
        if match:
            raw_regions = json.loads(match.group(0))
        else:
            raw_regions = []
        # Normalize to ImageRegion format, clamping coordinates to 0-100
        # (width/height get a minimum of 1% to avoid degenerate boxes).
        regions = []
        for r in raw_regions:
            regions.append({
                "bbox_pct": {
                    "x": max(0, min(100, float(r.get("x", 0)))),
                    "y": max(0, min(100, float(r.get("y", 0)))),
                    "w": max(1, min(100, float(r.get("w", 10)))),
                    "h": max(1, min(100, float(r.get("h", 10)))),
                },
                "description": r.get("description", ""),
                "prompt": r.get("description", ""),  # initial prompt defaults to the description
                "image_b64": None,  # filled in later by the generate-image endpoint
                "style": "educational",
            })
        # Enrich prompts with nearby vocab context
        if entries:
            for region in regions:
                ry = region["bbox_pct"]["y"]
                rh = region["bbox_pct"]["h"]
                # NOTE(review): entry bbox "y" is compared directly against the
                # region's percentage coordinates — confirm both use the same
                # units (percent vs. pixels) or this proximity filter is off.
                nearby = [
                    e for e in entries
                    if e.get("bbox") and abs(e["bbox"].get("y", 0) - ry) < rh + 10
                ]
                if nearby:
                    en_words = [e.get("english", "") for e in nearby if e.get("english")]
                    de_words = [e.get("german", "") for e in nearby if e.get("german")]
                    if en_words or de_words:
                        context = f" (vocabulary context: {', '.join(en_words[:5])}"
                        if de_words:
                            context += f" / {', '.join(de_words[:5])}"
                        context += ")"
                        region["prompt"] = region["description"] + context
        # Save to ground_truth JSONB
        ground_truth = session.get("ground_truth") or {}
        validation = ground_truth.get("validation") or {}
        validation["image_regions"] = regions
        validation["detected_at"] = datetime.utcnow().isoformat()
        ground_truth["validation"] = validation
        await update_session_db(session_id, ground_truth=ground_truth)
        if session_id in _cache:
            _cache[session_id]["ground_truth"] = ground_truth
        logger.info(f"Detected {len(regions)} image regions for session {session_id}")
        return {"regions": regions, "count": len(regions)}
    except httpx.ConnectError:
        # Ollama not reachable — degrade gracefully with an empty result.
        logger.warning(f"VLM not available at {ollama_base} for image detection")
        return {"regions": [], "count": 0, "error": "VLM not available"}
    except Exception as e:
        logger.error(f"Image detection failed for {session_id}: {e}")
        return {"regions": [], "count": 0, "error": str(e)}
@router.post("/sessions/{session_id}/reconstruction/generate-image")
async def generate_image_for_region(session_id: str, req: GenerateImageRequest):
    """Generate a replacement image for a detected region using mflux.

    Sends the prompt (with style suffix) to the mflux-service running
    natively on the Mac Mini (Metal GPU required). On success the image is
    stored back into the region entry in ground_truth["validation"].
    Service/network failures are reported in the body, not as HTTP errors.
    """
    import httpx
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    ground_truth = session.get("ground_truth") or {}
    validation = ground_truth.get("validation") or {}
    regions = validation.get("image_regions") or []
    if req.region_index < 0 or req.region_index >= len(regions):
        raise HTTPException(status_code=400, detail=f"Invalid region_index {req.region_index}, have {len(regions)} regions")
    mflux_url = os.getenv("MFLUX_URL", "http://host.docker.internal:8095")
    style_suffix = STYLE_SUFFIXES.get(req.style, STYLE_SUFFIXES["educational"])
    full_prompt = f"{req.prompt}, {style_suffix}"
    # Pick an output size matching the region's aspect ratio (fixed presets
    # whose dimensions are multiples of 64).
    region = regions[req.region_index]
    bbox = region["bbox_pct"]
    aspect = bbox["w"] / max(bbox["h"], 1)  # guard against h == 0
    if aspect > 1.3:
        width, height = 768, 512  # landscape
    elif aspect < 0.7:
        width, height = 512, 768  # portrait
    else:
        width, height = 512, 512  # roughly square
    try:
        # Generation can be slow, hence the generous timeout.
        async with httpx.AsyncClient(timeout=300.0) as client:
            resp = await client.post(f"{mflux_url}/generate", json={
                "prompt": full_prompt,
                "width": width,
                "height": height,
                "steps": 4,
            })
            resp.raise_for_status()
            data = resp.json()
        image_b64 = data.get("image_b64")
        if not image_b64:
            return {"image_b64": None, "success": False, "error": "No image returned"}
        # Save to ground_truth, remembering the prompt/style actually used.
        regions[req.region_index]["image_b64"] = image_b64
        regions[req.region_index]["prompt"] = req.prompt
        regions[req.region_index]["style"] = req.style
        validation["image_regions"] = regions
        ground_truth["validation"] = validation
        await update_session_db(session_id, ground_truth=ground_truth)
        if session_id in _cache:
            _cache[session_id]["ground_truth"] = ground_truth
        logger.info(f"Generated image for session {session_id} region {req.region_index}")
        return {"image_b64": image_b64, "success": True}
    except httpx.ConnectError:
        logger.warning(f"mflux-service not available at {mflux_url}")
        return {"image_b64": None, "success": False, "error": f"mflux-service not available at {mflux_url}"}
    except Exception as e:
        logger.error(f"Image generation failed for {session_id}: {e}")
        return {"image_b64": None, "success": False, "error": str(e)}
@router.post("/sessions/{session_id}/reconstruction/validate")
async def save_validation(session_id: str, req: ValidationRequest):
    """Save final validation results for step 8.

    Stores notes and score, preserving any detected/generated image regions
    already present in the same "validation" dict.
    Sets current_step = 11 (the value actually written below; an earlier
    docstring incorrectly said 10) to mark the pipeline as complete.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    ground_truth = session.get("ground_truth") or {}
    validation = ground_truth.get("validation") or {}
    validation["validated_at"] = datetime.utcnow().isoformat()
    validation["notes"] = req.notes
    validation["score"] = req.score
    ground_truth["validation"] = validation
    await update_session_db(session_id, ground_truth=ground_truth, current_step=11)
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth
    logger.info(f"Validation saved for session {session_id}: score={req.score}")
    return {"session_id": session_id, "validation": validation}
@router.get("/sessions/{session_id}/reconstruction/validation")
async def get_validation(session_id: str):
    """Return the stored validation block (plus word_result) for step 8."""
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    ground_truth = session.get("ground_truth") or {}
    # validation is None when step 8 has not been run yet.
    response = {
        "session_id": session_id,
        "validation": ground_truth.get("validation"),
        "word_result": session.get("word_result"),
    }
    return response
# ---------------------------------------------------------------------------
# Remove handwriting
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/remove-handwriting")
async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingRequest):
    """
    Remove handwriting from a session image using inpainting.

    Steps:
    1. Load source image (auto -> deskewed if available, else original)
    2. Detect handwriting mask (filtered by target_ink)
    3. Dilate mask to cover stroke edges
    4. Inpaint the image
    5. Store result as clean_png in the session

    Returns metadata including the URL to fetch the clean image.
    Raises 404 for an unknown session or missing source image, 500 when
    the inpainting service reports failure.
    """
    import time as _time
    t0 = _time.monotonic()  # total processing time, reported in the response
    # Heavy CV service imports are kept local to the endpoint.
    from services.handwriting_detection import detect_handwriting
    from services.inpainting_service import inpaint_image, dilate_mask as _dilate_mask, InpaintingMethod, image_to_png
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    # 1. Determine source image ("auto" prefers the deskewed variant)
    source = req.use_source
    if source == "auto":
        deskewed = await get_session_image(session_id, "deskewed")
        source = "deskewed" if deskewed else "original"
    image_bytes = await get_session_image(session_id, source)
    if not image_bytes:
        raise HTTPException(status_code=404, detail=f"Source image '{source}' not available")
    # 2. Detect handwriting mask
    detection = detect_handwriting(image_bytes, target_ink=req.target_ink)
    # 3. Convert mask ndarray to PNG bytes and dilate so the mask covers
    #    stroke edges, not just stroke cores.
    import io
    from PIL import Image as _PILImage
    mask_img = _PILImage.fromarray(detection.mask)
    mask_buf = io.BytesIO()
    mask_img.save(mask_buf, format="PNG")
    mask_bytes = mask_buf.getvalue()
    if req.dilation > 0:
        mask_bytes = _dilate_mask(mask_bytes, iterations=req.dilation)
    # 4. Inpaint, mapping the request's method string onto the service enum
    #    (unknown values fall back to AUTO).
    method_map = {
        "telea": InpaintingMethod.OPENCV_TELEA,
        "ns": InpaintingMethod.OPENCV_NS,
        "auto": InpaintingMethod.AUTO,
    }
    inpaint_method = method_map.get(req.method, InpaintingMethod.AUTO)
    result = inpaint_image(image_bytes, mask_bytes, method=inpaint_method)
    if not result.success:
        raise HTTPException(status_code=500, detail="Inpainting failed")
    elapsed_ms = int((_time.monotonic() - t0) * 1000)
    meta = {
        # method_used may be an enum or a plain string depending on the service
        "method_used": result.method_used.value if hasattr(result.method_used, "value") else str(result.method_used),
        "handwriting_ratio": round(detection.handwriting_ratio, 4),
        "detection_confidence": round(detection.confidence, 4),
        "target_ink": req.target_ink,
        "dilation": req.dilation,
        "source_image": source,
        "processing_time_ms": elapsed_ms,
    }
    # 5. Persist clean image (convert BGR ndarray -> PNG bytes)
    clean_png_bytes = image_to_png(result.image)
    await update_session_db(session_id, clean_png=clean_png_bytes, handwriting_removal_meta=meta)
    return {
        **meta,
        "image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean",
        "session_id": session_id,
    }