""" OCR Pipeline Validation — image detection, generation, validation save, and handwriting removal endpoints. Extracted from ocr_pipeline_postprocess.py. Lizenz: Apache 2.0 DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. """ import json import logging import os from datetime import datetime from typing import Optional from fastapi import APIRouter, HTTPException from pydantic import BaseModel from ocr_pipeline_session_store import ( get_session_db, get_session_image, update_session_db, ) from ocr_pipeline_common import ( _cache, RemoveHandwritingRequest, ) logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) # --------------------------------------------------------------------------- # Pydantic Models # --------------------------------------------------------------------------- STYLE_SUFFIXES = { "educational": "educational illustration, textbook style, clear, colorful", "cartoon": "cartoon, child-friendly, simple shapes", "sketch": "pencil sketch, hand-drawn, black and white", "clipart": "clipart, flat vector style, simple", "realistic": "photorealistic, high detail", } class ValidationRequest(BaseModel): notes: Optional[str] = None score: Optional[int] = None class GenerateImageRequest(BaseModel): region_index: int prompt: str style: str = "educational" # --------------------------------------------------------------------------- # Image detection + generation # --------------------------------------------------------------------------- @router.post("/sessions/{session_id}/reconstruction/detect-images") async def detect_image_regions(session_id: str): """Detect illustration/image regions in the original scan using VLM.""" import base64 import httpx import re session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") original_png = await get_session_image(session_id, "original") if not original_png: raise HTTPException(status_code=400, detail="No original image found") word_result = session.get("word_result") or {} entries = word_result.get("vocab_entries") or word_result.get("entries") or [] vocab_context = "" if entries: sample = entries[:10] words = [f"{e.get('english', '')} / {e.get('german', '')}" for e in sample if e.get('english')] if words: vocab_context = f"\nContext: This is a vocabulary page with words like: {', '.join(words)}" ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434") model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b") prompt = ( "Analyze this scanned page. Find ALL illustration/image/picture regions " "(NOT text, NOT table cells, NOT blank areas). " "For each image region found, return its bounding box as percentage of page dimensions " "and a short English description of what the image shows. " "Reply with ONLY a JSON array like: " '[{"x": 10, "y": 20, "w": 30, "h": 25, "description": "drawing of a cat"}] ' "where x, y, w, h are percentages (0-100) of the page width/height. " "If there are NO images on the page, return an empty array: []" f"{vocab_context}" ) img_b64 = base64.b64encode(original_png).decode("utf-8") payload = { "model": model, "prompt": prompt, "images": [img_b64], "stream": False, } try: async with httpx.AsyncClient(timeout=120.0) as client: resp = await client.post(f"{ollama_base}/api/generate", json=payload) resp.raise_for_status() text = resp.json().get("response", "") match = re.search(r'\[.*?\]', text, re.DOTALL) if match: raw_regions = json.loads(match.group(0)) else: raw_regions = [] regions = [] for r in raw_regions: regions.append({ "bbox_pct": { "x": max(0, min(100, float(r.get("x", 0)))), "y": max(0, min(100, float(r.get("y", 0)))), "w": max(1, min(100, float(r.get("w", 10)))), "h": max(1, min(100, float(r.get("h", 10)))), }, "description": r.get("description", ""), "prompt": r.get("description", ""), "image_b64": None, "style": "educational", }) # Enrich prompts with nearby vocab context if entries: for region in regions: ry = region["bbox_pct"]["y"] rh = region["bbox_pct"]["h"] nearby = [ e for e in entries if e.get("bbox") and abs(e["bbox"].get("y", 0) - ry) < rh + 10 ] if nearby: en_words = [e.get("english", "") for e in nearby if e.get("english")] de_words = [e.get("german", "") for e in nearby if e.get("german")] if en_words or de_words: context = f" (vocabulary context: {', '.join(en_words[:5])}" if de_words: context += f" / {', '.join(de_words[:5])}" context += ")" region["prompt"] = region["description"] + context ground_truth = session.get("ground_truth") or {} validation = ground_truth.get("validation") or {} validation["image_regions"] = regions validation["detected_at"] = datetime.utcnow().isoformat() ground_truth["validation"] = validation await update_session_db(session_id, ground_truth=ground_truth) if session_id in _cache: _cache[session_id]["ground_truth"] = ground_truth logger.info(f"Detected {len(regions)} image regions for session {session_id}") return {"regions": regions, "count": len(regions)} except httpx.ConnectError: logger.warning(f"VLM not available at {ollama_base} for image detection") return {"regions": [], "count": 0, "error": "VLM not available"} except Exception as e: logger.error(f"Image detection failed for {session_id}: {e}") return {"regions": [], "count": 0, "error": str(e)} @router.post("/sessions/{session_id}/reconstruction/generate-image") async def generate_image_for_region(session_id: str, req: GenerateImageRequest): """Generate a replacement image for a detected region using mflux.""" import httpx session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") ground_truth = session.get("ground_truth") or {} validation = ground_truth.get("validation") or {} regions = validation.get("image_regions") or [] if req.region_index < 0 or req.region_index >= len(regions): raise HTTPException(status_code=400, detail=f"Invalid region_index {req.region_index}, have {len(regions)} regions") mflux_url = os.getenv("MFLUX_URL", "http://host.docker.internal:8095") style_suffix = STYLE_SUFFIXES.get(req.style, STYLE_SUFFIXES["educational"]) full_prompt = f"{req.prompt}, {style_suffix}" region = regions[req.region_index] bbox = region["bbox_pct"] aspect = bbox["w"] / max(bbox["h"], 1) if aspect > 1.3: width, height = 768, 512 elif aspect < 0.7: width, height = 512, 768 else: width, height = 512, 512 try: async with httpx.AsyncClient(timeout=300.0) as client: resp = await client.post(f"{mflux_url}/generate", json={ "prompt": full_prompt, "width": width, "height": height, "steps": 4, }) resp.raise_for_status() data = resp.json() image_b64 = data.get("image_b64") if not image_b64: return {"image_b64": None, "success": False, "error": "No image returned"} regions[req.region_index]["image_b64"] = image_b64 regions[req.region_index]["prompt"] = req.prompt regions[req.region_index]["style"] = req.style validation["image_regions"] = regions ground_truth["validation"] = validation await update_session_db(session_id, ground_truth=ground_truth) if session_id in _cache: _cache[session_id]["ground_truth"] = ground_truth logger.info(f"Generated image for session {session_id} region {req.region_index}") return {"image_b64": image_b64, "success": True} except httpx.ConnectError: logger.warning(f"mflux-service not available at {mflux_url}") return {"image_b64": None, "success": False, "error": f"mflux-service not available at {mflux_url}"} except Exception as e: logger.error(f"Image generation failed for {session_id}: {e}") return {"image_b64": None, "success": False, "error": str(e)} # --------------------------------------------------------------------------- # Validation save/get # --------------------------------------------------------------------------- @router.post("/sessions/{session_id}/reconstruction/validate") async def save_validation(session_id: str, req: ValidationRequest): """Save final validation results for step 8.""" session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") ground_truth = session.get("ground_truth") or {} validation = ground_truth.get("validation") or {} validation["validated_at"] = datetime.utcnow().isoformat() validation["notes"] = req.notes validation["score"] = req.score ground_truth["validation"] = validation await update_session_db(session_id, ground_truth=ground_truth, current_step=11) if session_id in _cache: _cache[session_id]["ground_truth"] = ground_truth logger.info(f"Validation saved for session {session_id}: score={req.score}") return {"session_id": session_id, "validation": validation} @router.get("/sessions/{session_id}/reconstruction/validation") async def get_validation(session_id: str): """Retrieve saved validation data for step 8.""" session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") ground_truth = session.get("ground_truth") or {} validation = ground_truth.get("validation") return { "session_id": session_id, "validation": validation, "word_result": session.get("word_result"), } # --------------------------------------------------------------------------- # Remove handwriting # --------------------------------------------------------------------------- @router.post("/sessions/{session_id}/remove-handwriting") async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingRequest): """Remove handwriting from a session image using inpainting.""" import time as _time from services.handwriting_detection import detect_handwriting from services.inpainting_service import inpaint_image, dilate_mask as _dilate_mask, InpaintingMethod, image_to_png session = await get_session_db(session_id) if not session: raise HTTPException(status_code=404, detail=f"Session {session_id} not found") t0 = _time.monotonic() # 1. Determine source image source = req.use_source if source == "auto": deskewed = await get_session_image(session_id, "deskewed") source = "deskewed" if deskewed else "original" image_bytes = await get_session_image(session_id, source) if not image_bytes: raise HTTPException(status_code=404, detail=f"Source image '{source}' not available") # 2. Detect handwriting mask detection = detect_handwriting(image_bytes, target_ink=req.target_ink) # 3. Convert mask to PNG bytes and dilate import io from PIL import Image as _PILImage mask_img = _PILImage.fromarray(detection.mask) mask_buf = io.BytesIO() mask_img.save(mask_buf, format="PNG") mask_bytes = mask_buf.getvalue() if req.dilation > 0: mask_bytes = _dilate_mask(mask_bytes, iterations=req.dilation) # 4. Inpaint method_map = { "telea": InpaintingMethod.OPENCV_TELEA, "ns": InpaintingMethod.OPENCV_NS, "auto": InpaintingMethod.AUTO, } inpaint_method = method_map.get(req.method, InpaintingMethod.AUTO) result = inpaint_image(image_bytes, mask_bytes, method=inpaint_method) if not result.success: raise HTTPException(status_code=500, detail="Inpainting failed") elapsed_ms = int((_time.monotonic() - t0) * 1000) meta = { "method_used": result.method_used.value if hasattr(result.method_used, "value") else str(result.method_used), "handwriting_ratio": round(detection.handwriting_ratio, 4), "detection_confidence": round(detection.confidence, 4), "target_ink": req.target_ink, "dilation": req.dilation, "source_image": source, "processing_time_ms": elapsed_ms, } # 5. Persist clean image clean_png_bytes = image_to_png(result.image) await update_session_db(session_id, clean_png=clean_png_bytes, handwriting_removal_meta=meta) return { **meta, "image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean", "session_id": session_id, }