""" OCR Pipeline Postprocessing API — LLM review, reconstruction, export, validation, image detection/generation, and handwriting removal endpoints. Extracted from ocr_pipeline_api.py to keep the main module manageable. Lizenz: Apache 2.0 DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. """ import json import logging import os import re from datetime import datetime from typing import Any, Dict, List, Optional from fastapi import APIRouter, HTTPException, Request from fastapi.responses import StreamingResponse from pydantic import BaseModel from cv_vocab_pipeline import ( OLLAMA_REVIEW_MODEL, llm_review_entries, llm_review_entries_streaming, ) from ocr_pipeline_session_store import ( get_session_db, get_session_image, get_sub_sessions, update_session_db, ) from ocr_pipeline_common import ( _cache, _load_session_to_cache, _get_cached, _get_base_image_png, _append_pipeline_log, RemoveHandwritingRequest, ) logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) # --------------------------------------------------------------------------- # Pydantic Models # --------------------------------------------------------------------------- STYLE_SUFFIXES = { "educational": "educational illustration, textbook style, clear, colorful", "cartoon": "cartoon, child-friendly, simple shapes", "sketch": "pencil sketch, hand-drawn, black and white", "clipart": "clipart, flat vector style, simple", "realistic": "photorealistic, high detail", } class ValidationRequest(BaseModel): notes: Optional[str] = None score: Optional[int] = None class GenerateImageRequest(BaseModel): region_index: int prompt: str style: str = "educational" # --------------------------------------------------------------------------- # Step 8: LLM Review # --------------------------------------------------------------------------- @router.post("/sessions/{session_id}/llm-review") async def run_llm_review(session_id: str, request: Request, stream: bool = False): """Run LLM-based correction on vocab entries from Step 5. Query params: stream: false (default) for JSON response, true for SSE streaming """ session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") word_result = session.get("word_result") if not word_result: raise HTTPException(status_code=400, detail="No word result found — run Step 5 first") entries = word_result.get("vocab_entries") or word_result.get("entries") or [] if not entries: raise HTTPException(status_code=400, detail="No vocab entries found — run Step 5 first") # Optional model override from request body body = {} try: body = await request.json() except Exception: pass model = body.get("model") or OLLAMA_REVIEW_MODEL if stream: return StreamingResponse( _llm_review_stream_generator(session_id, entries, word_result, model, request), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}, ) # Non-streaming path try: result = await llm_review_entries(entries, model=model) except Exception as e: import traceback logger.error(f"LLM review failed for session {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}") raise HTTPException(status_code=502, detail=f"LLM review failed ({type(e).__name__}): {e}") # Store result inside word_result as a sub-key word_result["llm_review"] = { "changes": result["changes"], "model_used": result["model_used"], "duration_ms": result["duration_ms"], "entries_corrected": result["entries_corrected"], } await update_session_db(session_id, word_result=word_result, current_step=9) if session_id in _cache: _cache[session_id]["word_result"] = word_result logger.info(f"LLM review session {session_id}: {len(result['changes'])} changes, " f"{result['duration_ms']}ms, model={result['model_used']}") await _append_pipeline_log(session_id, "correction", { "engine": "llm", "model": result["model_used"], "total_entries": len(entries), "corrections_proposed": len(result["changes"]), }, duration_ms=result["duration_ms"]) return { "session_id": session_id, "changes": result["changes"], "model_used": result["model_used"], "duration_ms": result["duration_ms"], "total_entries": len(entries), "corrections_found": len(result["changes"]), } async def _llm_review_stream_generator( session_id: str, entries: List[Dict], word_result: Dict, model: str, request: Request, ): """SSE generator that yields batch-by-batch LLM review progress.""" try: async for event in llm_review_entries_streaming(entries, model=model): if await request.is_disconnected(): logger.info(f"SSE: client disconnected during LLM review for {session_id}") return yield f"data: {json.dumps(event, ensure_ascii=False)}\n\n" # On complete: persist to DB if event.get("type") == "complete": word_result["llm_review"] = { "changes": event["changes"], "model_used": event["model_used"], "duration_ms": event["duration_ms"], "entries_corrected": event["entries_corrected"], } await update_session_db(session_id, word_result=word_result, current_step=9) if session_id in _cache: _cache[session_id]["word_result"] = word_result logger.info(f"LLM review SSE session {session_id}: {event['corrections_found']} changes, " f"{event['duration_ms']}ms, skipped={event['skipped']}, model={event['model_used']}") except Exception as e: import traceback logger.error(f"LLM review SSE failed for {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}") error_event = {"type": "error", "detail": f"{type(e).__name__}: {e}"} yield f"data: {json.dumps(error_event)}\n\n" @router.post("/sessions/{session_id}/llm-review/apply") async def apply_llm_corrections(session_id: str, request: Request): """Apply selected LLM corrections to vocab entries.""" session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") word_result = session.get("word_result") if not word_result: raise HTTPException(status_code=400, detail="No word result found") llm_review = word_result.get("llm_review") if not llm_review: raise HTTPException(status_code=400, detail="No LLM review found — run /llm-review first") body = await request.json() accepted_indices = set(body.get("accepted_indices", [])) # indices into changes[] changes = llm_review.get("changes", []) entries = word_result.get("vocab_entries") or word_result.get("entries") or [] # Build a lookup: (row_index, field) -> new_value for accepted changes corrections = {} applied_count = 0 for idx, change in enumerate(changes): if idx in accepted_indices: key = (change["row_index"], change["field"]) corrections[key] = change["new"] applied_count += 1 # Apply corrections to entries for entry in entries: row_idx = entry.get("row_index", -1) for field_name in ("english", "german", "example"): key = (row_idx, field_name) if key in corrections: entry[field_name] = corrections[key] entry["llm_corrected"] = True # Update word_result word_result["vocab_entries"] = entries word_result["entries"] = entries word_result["llm_review"]["applied_count"] = applied_count word_result["llm_review"]["applied_at"] = datetime.utcnow().isoformat() await update_session_db(session_id, word_result=word_result) if session_id in _cache: _cache[session_id]["word_result"] = word_result logger.info(f"Applied {applied_count}/{len(changes)} LLM corrections for session {session_id}") return { "session_id": session_id, "applied_count": applied_count, "total_changes": len(changes), } # --------------------------------------------------------------------------- # Step 9: Reconstruction + Fabric JSON export # --------------------------------------------------------------------------- @router.post("/sessions/{session_id}/reconstruction") async def save_reconstruction(session_id: str, request: Request): """Save edited cell texts from reconstruction step.""" session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") word_result = session.get("word_result") if not word_result: raise HTTPException(status_code=400, detail="No word result found") body = await request.json() cell_updates = body.get("cells", []) if not cell_updates: await update_session_db(session_id, current_step=10) return {"session_id": session_id, "updated": 0} # Build update map: cell_id -> new text update_map = {c["cell_id"]: c["text"] for c in cell_updates} # Separate sub-session updates (cell_ids prefixed with "box{N}_") sub_updates: Dict[int, Dict[str, str]] = {} # box_index -> {original_cell_id: text} main_updates: Dict[str, str] = {} for cell_id, text in update_map.items(): m = re.match(r'^box(\d+)_(.+)$', cell_id) if m: bi = int(m.group(1)) original_id = m.group(2) sub_updates.setdefault(bi, {})[original_id] = text else: main_updates[cell_id] = text # Update main session cells cells = word_result.get("cells", []) updated_count = 0 for cell in cells: if cell["cell_id"] in main_updates: cell["text"] = main_updates[cell["cell_id"]] cell["status"] = "edited" updated_count += 1 word_result["cells"] = cells # Also update vocab_entries if present entries = word_result.get("vocab_entries") or word_result.get("entries") or [] if entries: # Map cell_id pattern "R{row}_C{col}" to entry fields for entry in entries: row_idx = entry.get("row_index", -1) # Check each field's cell for col_idx, field_name in enumerate(["english", "german", "example"]): cell_id = f"R{row_idx:02d}_C{col_idx}" # Also try without zero-padding cell_id_alt = f"R{row_idx}_C{col_idx}" new_text = main_updates.get(cell_id) or main_updates.get(cell_id_alt) if new_text is not None: entry[field_name] = new_text word_result["vocab_entries"] = entries if "entries" in word_result: word_result["entries"] = entries await update_session_db(session_id, word_result=word_result, current_step=10) if session_id in _cache: _cache[session_id]["word_result"] = word_result # Route sub-session updates sub_updated = 0 if sub_updates: subs = await get_sub_sessions(session_id) sub_by_index = {s.get("box_index"): s["id"] for s in subs} for bi, updates in sub_updates.items(): sub_id = sub_by_index.get(bi) if not sub_id: continue sub_session = await get_session_db(sub_id) if not sub_session: continue sub_word = sub_session.get("word_result") if not sub_word: continue sub_cells = sub_word.get("cells", []) for cell in sub_cells: if cell["cell_id"] in updates: cell["text"] = updates[cell["cell_id"]] cell["status"] = "edited" sub_updated += 1 sub_word["cells"] = sub_cells await update_session_db(sub_id, word_result=sub_word) if sub_id in _cache: _cache[sub_id]["word_result"] = sub_word total_updated = updated_count + sub_updated logger.info(f"Reconstruction saved for session {session_id}: " f"{updated_count} main + {sub_updated} sub-session cells updated") return { "session_id": session_id, "updated": total_updated, "main_updated": updated_count, "sub_updated": sub_updated, } @router.get("/sessions/{session_id}/reconstruction/fabric-json") async def get_fabric_json(session_id: str): """Return cell grid as Fabric.js-compatible JSON for the canvas editor. If the session has sub-sessions (box regions), their cells are merged into the result at the correct Y positions. """ session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") word_result = session.get("word_result") if not word_result: raise HTTPException(status_code=400, detail="No word result found") cells = list(word_result.get("cells", [])) img_w = word_result.get("image_width", 800) img_h = word_result.get("image_height", 600) # Merge sub-session cells at box positions subs = await get_sub_sessions(session_id) if subs: column_result = session.get("column_result") or {} zones = column_result.get("zones") or [] box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")] for sub in subs: sub_session = await get_session_db(sub["id"]) if not sub_session: continue sub_word = sub_session.get("word_result") if not sub_word or not sub_word.get("cells"): continue bi = sub.get("box_index", 0) if bi < len(box_zones): box = box_zones[bi]["box"] box_y, box_x = box["y"], box["x"] else: box_y, box_x = 0, 0 # Offset sub-session cells to absolute page coordinates for cell in sub_word["cells"]: cell_copy = dict(cell) # Prefix cell_id with box index cell_copy["cell_id"] = f"box{bi}_{cell_copy.get('cell_id', '')}" cell_copy["source"] = f"box_{bi}" # Offset bbox_px bbox = cell_copy.get("bbox_px", {}) if bbox: bbox = dict(bbox) bbox["x"] = bbox.get("x", 0) + box_x bbox["y"] = bbox.get("y", 0) + box_y cell_copy["bbox_px"] = bbox cells.append(cell_copy) from services.layout_reconstruction_service import cells_to_fabric_json fabric_json = cells_to_fabric_json(cells, img_w, img_h) return fabric_json # --------------------------------------------------------------------------- # Vocab entries merged + PDF/DOCX export # --------------------------------------------------------------------------- @router.get("/sessions/{session_id}/vocab-entries/merged") async def get_merged_vocab_entries(session_id: str): """Return vocab entries from main session + all sub-sessions, sorted by Y position.""" session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") word_result = session.get("word_result") or {} entries = list(word_result.get("vocab_entries") or word_result.get("entries") or []) # Tag main entries for e in entries: e.setdefault("source", "main") # Merge sub-session entries subs = await get_sub_sessions(session_id) if subs: column_result = session.get("column_result") or {} zones = column_result.get("zones") or [] box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")] for sub in subs: sub_session = await get_session_db(sub["id"]) if not sub_session: continue sub_word = sub_session.get("word_result") or {} sub_entries = sub_word.get("vocab_entries") or sub_word.get("entries") or [] bi = sub.get("box_index", 0) box_y = 0 if bi < len(box_zones): box_y = box_zones[bi]["box"]["y"] for e in sub_entries: e_copy = dict(e) e_copy["source"] = f"box_{bi}" e_copy["source_y"] = box_y # for sorting entries.append(e_copy) # Sort by approximate Y position def _sort_key(e): if e.get("source", "main") == "main": return e.get("row_index", 0) * 100 # main entries by row index return e.get("source_y", 0) * 100 + e.get("row_index", 0) entries.sort(key=_sort_key) return { "session_id": session_id, "entries": entries, "total": len(entries), "sources": list(set(e.get("source", "main") for e in entries)), } @router.get("/sessions/{session_id}/reconstruction/export/pdf") async def export_reconstruction_pdf(session_id: str): """Export the reconstructed cell grid as a PDF table.""" session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") word_result = session.get("word_result") if not word_result: raise HTTPException(status_code=400, detail="No word result found") cells = word_result.get("cells", []) columns_used = word_result.get("columns_used", []) grid_shape = word_result.get("grid_shape", {}) n_rows = grid_shape.get("rows", 0) n_cols = grid_shape.get("cols", 0) # Build table data: rows x columns table_data: list[list[str]] = [] header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)] if not header: header = [f"Col {i}" for i in range(n_cols)] table_data.append(header) for r in range(n_rows): row_texts = [] for ci in range(n_cols): cell_id = f"R{r:02d}_C{ci}" cell = next((c for c in cells if c.get("cell_id") == cell_id), None) row_texts.append(cell.get("text", "") if cell else "") table_data.append(row_texts) # Generate PDF with reportlab try: from reportlab.lib.pagesizes import A4 from reportlab.lib import colors from reportlab.platypus import SimpleDocTemplate, Table, TableStyle import io as _io buf = _io.BytesIO() doc = SimpleDocTemplate(buf, pagesize=A4) if not table_data or not table_data[0]: raise HTTPException(status_code=400, detail="No data to export") t = Table(table_data) t.setStyle(TableStyle([ ('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#0d9488')), ('TEXTCOLOR', (0, 0), (-1, 0), colors.white), ('FONTSIZE', (0, 0), (-1, -1), 9), ('GRID', (0, 0), (-1, -1), 0.5, colors.grey), ('VALIGN', (0, 0), (-1, -1), 'TOP'), ('WORDWRAP', (0, 0), (-1, -1), True), ])) doc.build([t]) buf.seek(0) from fastapi.responses import StreamingResponse return StreamingResponse( buf, media_type="application/pdf", headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.pdf"'}, ) except ImportError: raise HTTPException(status_code=501, detail="reportlab not installed") @router.get("/sessions/{session_id}/reconstruction/export/docx") async def export_reconstruction_docx(session_id: str): """Export the reconstructed cell grid as a DOCX table.""" session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") word_result = session.get("word_result") if not word_result: raise HTTPException(status_code=400, detail="No word result found") cells = word_result.get("cells", []) columns_used = word_result.get("columns_used", []) grid_shape = word_result.get("grid_shape", {}) n_rows = grid_shape.get("rows", 0) n_cols = grid_shape.get("cols", 0) try: from docx import Document from docx.shared import Pt import io as _io doc = Document() doc.add_heading(f'Rekonstruktion – Session {session_id[:8]}', level=1) # Build header header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)] if not header: header = [f"Col {i}" for i in range(n_cols)] table = doc.add_table(rows=1 + n_rows, cols=max(n_cols, 1)) table.style = 'Table Grid' # Header row for ci, h in enumerate(header): table.rows[0].cells[ci].text = h # Data rows for r in range(n_rows): for ci in range(n_cols): cell_id = f"R{r:02d}_C{ci}" cell = next((c for c in cells if c.get("cell_id") == cell_id), None) table.rows[r + 1].cells[ci].text = cell.get("text", "") if cell else "" buf = _io.BytesIO() doc.save(buf) buf.seek(0) from fastapi.responses import StreamingResponse return StreamingResponse( buf, media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document", headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.docx"'}, ) except ImportError: raise HTTPException(status_code=501, detail="python-docx not installed") # --------------------------------------------------------------------------- # Step 8: Validation — Original vs. Reconstruction # --------------------------------------------------------------------------- @router.post("/sessions/{session_id}/reconstruction/detect-images") async def detect_image_regions(session_id: str): """Detect illustration/image regions in the original scan using VLM. Sends the original image to qwen2.5vl to find non-text, non-table image areas, returning bounding boxes (in %) and descriptions. """ import base64 import httpx import re session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") # Get original image bytes original_png = await get_session_image(session_id, "original") if not original_png: raise HTTPException(status_code=400, detail="No original image found") # Build context from vocab entries for richer descriptions word_result = session.get("word_result") or {} entries = word_result.get("vocab_entries") or word_result.get("entries") or [] vocab_context = "" if entries: sample = entries[:10] words = [f"{e.get('english', '')} / {e.get('german', '')}" for e in sample if e.get('english')] if words: vocab_context = f"\nContext: This is a vocabulary page with words like: {', '.join(words)}" ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434") model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b") prompt = ( "Analyze this scanned page. Find ALL illustration/image/picture regions " "(NOT text, NOT table cells, NOT blank areas). " "For each image region found, return its bounding box as percentage of page dimensions " "and a short English description of what the image shows. " "Reply with ONLY a JSON array like: " '[{"x": 10, "y": 20, "w": 30, "h": 25, "description": "drawing of a cat"}] ' "where x, y, w, h are percentages (0-100) of the page width/height. " "If there are NO images on the page, return an empty array: []" f"{vocab_context}" ) img_b64 = base64.b64encode(original_png).decode("utf-8") payload = { "model": model, "prompt": prompt, "images": [img_b64], "stream": False, } try: async with httpx.AsyncClient(timeout=120.0) as client: resp = await client.post(f"{ollama_base}/api/generate", json=payload) resp.raise_for_status() text = resp.json().get("response", "") # Parse JSON array from response match = re.search(r'\[.*?\]', text, re.DOTALL) if match: raw_regions = json.loads(match.group(0)) else: raw_regions = [] # Normalize to ImageRegion format regions = [] for r in raw_regions: regions.append({ "bbox_pct": { "x": max(0, min(100, float(r.get("x", 0)))), "y": max(0, min(100, float(r.get("y", 0)))), "w": max(1, min(100, float(r.get("w", 10)))), "h": max(1, min(100, float(r.get("h", 10)))), }, "description": r.get("description", ""), "prompt": r.get("description", ""), "image_b64": None, "style": "educational", }) # Enrich prompts with nearby vocab context if entries: for region in regions: ry = region["bbox_pct"]["y"] rh = region["bbox_pct"]["h"] nearby = [ e for e in entries if e.get("bbox") and abs(e["bbox"].get("y", 0) - ry) < rh + 10 ] if nearby: en_words = [e.get("english", "") for e in nearby if e.get("english")] de_words = [e.get("german", "") for e in nearby if e.get("german")] if en_words or de_words: context = f" (vocabulary context: {', '.join(en_words[:5])}" if de_words: context += f" / {', '.join(de_words[:5])}" context += ")" region["prompt"] = region["description"] + context # Save to ground_truth JSONB ground_truth = session.get("ground_truth") or {} validation = ground_truth.get("validation") or {} validation["image_regions"] = regions validation["detected_at"] = datetime.utcnow().isoformat() ground_truth["validation"] = validation await update_session_db(session_id, ground_truth=ground_truth) if session_id in _cache: _cache[session_id]["ground_truth"] = ground_truth logger.info(f"Detected {len(regions)} image regions for session {session_id}") return {"regions": regions, "count": len(regions)} except httpx.ConnectError: logger.warning(f"VLM not available at {ollama_base} for image detection") return {"regions": [], "count": 0, "error": "VLM not available"} except Exception as e: logger.error(f"Image detection failed for {session_id}: {e}") return {"regions": [], "count": 0, "error": str(e)} @router.post("/sessions/{session_id}/reconstruction/generate-image") async def generate_image_for_region(session_id: str, req: GenerateImageRequest): """Generate a replacement image for a detected region using mflux. Sends the prompt (with style suffix) to the mflux-service running natively on the Mac Mini (Metal GPU required). """ import httpx session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") ground_truth = session.get("ground_truth") or {} validation = ground_truth.get("validation") or {} regions = validation.get("image_regions") or [] if req.region_index < 0 or req.region_index >= len(regions): raise HTTPException(status_code=400, detail=f"Invalid region_index {req.region_index}, have {len(regions)} regions") mflux_url = os.getenv("MFLUX_URL", "http://host.docker.internal:8095") style_suffix = STYLE_SUFFIXES.get(req.style, STYLE_SUFFIXES["educational"]) full_prompt = f"{req.prompt}, {style_suffix}" # Determine image size from region aspect ratio (snap to multiples of 64) region = regions[req.region_index] bbox = region["bbox_pct"] aspect = bbox["w"] / max(bbox["h"], 1) if aspect > 1.3: width, height = 768, 512 elif aspect < 0.7: width, height = 512, 768 else: width, height = 512, 512 try: async with httpx.AsyncClient(timeout=300.0) as client: resp = await client.post(f"{mflux_url}/generate", json={ "prompt": full_prompt, "width": width, "height": height, "steps": 4, }) resp.raise_for_status() data = resp.json() image_b64 = data.get("image_b64") if not image_b64: return {"image_b64": None, "success": False, "error": "No image returned"} # Save to ground_truth regions[req.region_index]["image_b64"] = image_b64 regions[req.region_index]["prompt"] = req.prompt regions[req.region_index]["style"] = req.style validation["image_regions"] = regions ground_truth["validation"] = validation await update_session_db(session_id, ground_truth=ground_truth) if session_id in _cache: _cache[session_id]["ground_truth"] = ground_truth logger.info(f"Generated image for session {session_id} region {req.region_index}") return {"image_b64": image_b64, "success": True} except httpx.ConnectError: logger.warning(f"mflux-service not available at {mflux_url}") return {"image_b64": None, "success": False, "error": f"mflux-service not available at {mflux_url}"} except Exception as e: logger.error(f"Image generation failed for {session_id}: {e}") return {"image_b64": None, "success": False, "error": str(e)} @router.post("/sessions/{session_id}/reconstruction/validate") async def save_validation(session_id: str, req: ValidationRequest): """Save final validation results for step 8. Stores notes, score, and preserves any detected/generated image regions. Sets current_step = 10 to mark pipeline as complete. """ session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") ground_truth = session.get("ground_truth") or {} validation = ground_truth.get("validation") or {} validation["validated_at"] = datetime.utcnow().isoformat() validation["notes"] = req.notes validation["score"] = req.score ground_truth["validation"] = validation await update_session_db(session_id, ground_truth=ground_truth, current_step=11) if session_id in _cache: _cache[session_id]["ground_truth"] = ground_truth logger.info(f"Validation saved for session {session_id}: score={req.score}") return {"session_id": session_id, "validation": validation} @router.get("/sessions/{session_id}/reconstruction/validation") async def get_validation(session_id: str): """Retrieve saved validation data for step 8.""" session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") ground_truth = session.get("ground_truth") or {} validation = ground_truth.get("validation") return { "session_id": session_id, "validation": validation, "word_result": session.get("word_result"), } # --------------------------------------------------------------------------- # Remove handwriting # --------------------------------------------------------------------------- @router.post("/sessions/{session_id}/remove-handwriting") async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingRequest): """ Remove handwriting from a session image using inpainting. Steps: 1. Load source image (auto -> deskewed if available, else original) 2. Detect handwriting mask (filtered by target_ink) 3. Dilate mask to cover stroke edges 4. Inpaint the image 5. Store result as clean_png in the session Returns metadata including the URL to fetch the clean image. """ import time as _time t0 = _time.monotonic() from services.handwriting_detection import detect_handwriting from services.inpainting_service import inpaint_image, dilate_mask as _dilate_mask, InpaintingMethod, image_to_png session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") # 1. Determine source image source = req.use_source if source == "auto": deskewed = await get_session_image(session_id, "deskewed") source = "deskewed" if deskewed else "original" image_bytes = await get_session_image(session_id, source) if not image_bytes: raise HTTPException(status_code=404, detail=f"Source image '{source}' not available") # 2. Detect handwriting mask detection = detect_handwriting(image_bytes, target_ink=req.target_ink) # 3. Convert mask to PNG bytes and dilate import io from PIL import Image as _PILImage mask_img = _PILImage.fromarray(detection.mask) mask_buf = io.BytesIO() mask_img.save(mask_buf, format="PNG") mask_bytes = mask_buf.getvalue() if req.dilation > 0: mask_bytes = _dilate_mask(mask_bytes, iterations=req.dilation) # 4. Inpaint method_map = { "telea": InpaintingMethod.OPENCV_TELEA, "ns": InpaintingMethod.OPENCV_NS, "auto": InpaintingMethod.AUTO, } inpaint_method = method_map.get(req.method, InpaintingMethod.AUTO) result = inpaint_image(image_bytes, mask_bytes, method=inpaint_method) if not result.success: raise HTTPException(status_code=500, detail="Inpainting failed") elapsed_ms = int((_time.monotonic() - t0) * 1000) meta = { "method_used": result.method_used.value if hasattr(result.method_used, "value") else str(result.method_used), "handwriting_ratio": round(detection.handwriting_ratio, 4), "detection_confidence": round(detection.confidence, 4), "target_ink": req.target_ink, "dilation": req.dilation, "source_image": source, "processing_time_ms": elapsed_ms, } # 5. Persist clean image (convert BGR ndarray -> PNG bytes) clean_png_bytes = image_to_png(result.image) await update_session_db(session_id, clean_png=clean_png_bytes, handwriting_removal_meta=meta) return { **meta, "image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean", "session_id": session_id, }