"""
OCR Pipeline Dewarp Endpoints

Auto dewarp (with VLM/CV ensemble), manual dewarp, combined
rotation+shear adjustment, and ground truth.

Extracted from ocr_pipeline_geometry.py for file-size compliance.
"""

import json
import logging
import math
import os
import re
import time
from datetime import datetime
from typing import Any, Dict

import cv2
from fastapi import APIRouter, HTTPException, Query

from cv_vocab_pipeline import (
    _apply_shear,
    create_ocr_image,
    dewarp_image,
    dewarp_image_manual,
)
from ocr_pipeline_session_store import (
    get_session_db,
    update_session_db,
)
from ocr_pipeline_common import (
    _cache,
    _load_session_to_cache,
    _get_cached,
    _append_pipeline_log,
    ManualDewarpRequest,
    CombinedAdjustRequest,
    DewarpGroundTruthRequest,
)

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])

# VLM shear estimates are clamped to +/- this many degrees.
_VLM_MAX_SHEAR_DEG = 3.0
# A VLM estimate is only applied to the image when |shear| and the model's
# self-reported confidence both reach these thresholds; otherwise the
# deskewed image is passed through unchanged.
_VLM_MIN_APPLY_SHEAR_DEG = 0.05
_VLM_MIN_APPLY_CONFIDENCE = 0.3


async def _detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]:
    """Ask qwen2.5vl:32b to estimate the vertical shear angle of a scanned page.

    The VLM is shown the image and asked: are the column/table borders tilted?
    If yes, by how many degrees?

    The Ollama endpoint and model are read from the ``OLLAMA_BASE_URL`` and
    ``OLLAMA_HTR_MODEL`` environment variables.

    Args:
        image_bytes: Encoded image to show the model (auto_dewarp passes PNG
            bytes).

    Returns:
        A dict with ``method``, ``shear_degrees`` (clamped to
        +/- ``_VLM_MAX_SHEAR_DEG``) and ``confidence`` (clamped to 0.0-1.0).
        Confidence is 0.0 if no image was given, Ollama is unavailable, or
        the reply cannot be parsed into finite numbers. This function never
        raises; failures are logged as warnings.
    """
    import base64

    # NOTE(review): imported lazily, presumably so this module can be
    # imported without httpx installed — confirm before hoisting.
    import httpx

    method_name = "vlm_qwen2.5vl"

    if not image_bytes:
        # Nothing to show the model (e.g. PNG encoding failed upstream).
        logger.warning("VLM dewarp skipped: empty image")
        return {"method": method_name, "shear_degrees": 0.0, "confidence": 0.0}

    ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
    model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")

    prompt = (
        "This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. "
        "Are they perfectly vertical, or do they tilt slightly? "
        "If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). "
        "Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} "
        "Use confidence 0.0-1.0 based on how clearly you can see the tilt. "
        "If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}"
    )

    payload = {
        "model": model,
        "prompt": prompt,
        "images": [base64.b64encode(image_bytes).decode("utf-8")],
        "stream": False,
    }

    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.post(f"{ollama_base}/api/generate", json=payload)
            resp.raise_for_status()
            text = resp.json().get("response", "")

        # Parse JSON from response (may have surrounding text)
        match = re.search(r'\{[^}]+\}', text)
        if match:
            data = json.loads(match.group(0))
            shear = float(data.get("shear_degrees", 0.0))
            conf = float(data.get("confidence", 0.0))
            # json.loads accepts NaN/Infinity. max(lo, min(hi, nan)) yields
            # hi, which would turn a NaN confidence into 1.0, so treat any
            # non-finite value as an unparseable reply instead of clamping it.
            if math.isfinite(shear) and math.isfinite(conf):
                # Clamp to reasonable range
                shear = max(-_VLM_MAX_SHEAR_DEG, min(_VLM_MAX_SHEAR_DEG, shear))
                conf = max(0.0, min(1.0, conf))
                return {
                    "method": method_name,
                    "shear_degrees": round(shear, 3),
                    "confidence": round(conf, 2),
                }
            logger.warning(
                f"VLM dewarp returned non-finite values: shear={shear}, confidence={conf}"
            )
    except Exception as e:
        # Best-effort: any network/HTTP/parse problem degrades to "no shear".
        logger.warning(f"VLM dewarp failed: {e}")

    return {"method": method_name, "shear_degrees": 0.0, "confidence": 0.0}


@router.post("/sessions/{session_id}/dewarp")
async def auto_dewarp(
    session_id: str,
    method: str = Query("ensemble", description="Detection method: ensemble | vlm | cv"),
):
    """Detect and correct vertical shear on the deskewed image.

    Methods:
    - **ensemble** (default): 3-method CV ensemble (vertical edges + projection + Hough)
    - **cv**: CV ensemble only (same as ensemble)
    - **vlm**: Ask qwen2.5vl:32b to estimate the shear angle visually

    The VLM estimate is only applied when both its magnitude and confidence
    reach the module thresholds; otherwise the deskewed image is used as-is.
    The result is stored in the session cache (``dewarped_bgr``,
    ``dewarp_result``), persisted to the DB with ``current_step=4``, and
    appended to the pipeline log.

    Raises:
        HTTPException: 400 if ``method`` is unknown or the session has no
            deskewed image yet.
    """
    if method not in ("ensemble", "cv", "vlm"):
        raise HTTPException(status_code=400, detail="method must be one of: ensemble, cv, vlm")

    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    deskewed_bgr = cached.get("deskewed_bgr")
    if deskewed_bgr is None:
        raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")

    # Monotonic clock: duration must not be affected by wall-clock jumps.
    t0 = time.perf_counter()

    if method == "vlm":
        # Encode deskewed image to PNG for VLM
        success, png_buf = cv2.imencode(".png", deskewed_bgr)
        img_bytes = png_buf.tobytes() if success else b""
        vlm_det = await _detect_shear_with_vlm(img_bytes)
        shear_deg = vlm_det["shear_degrees"]
        should_apply = (
            abs(shear_deg) >= _VLM_MIN_APPLY_SHEAR_DEG
            and vlm_det["confidence"] >= _VLM_MIN_APPLY_CONFIDENCE
        )
        # The reported shear is corrected by shearing in the opposite direction.
        dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg) if should_apply else deskewed_bgr
        dewarp_info = {
            "method": vlm_det["method"],
            "shear_degrees": shear_deg,
            "confidence": vlm_det["confidence"],
            "detections": [vlm_det],
        }
    else:
        # "ensemble" and "cv" both run the CV ensemble in dewarp_image().
        dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr)

    duration = time.perf_counter() - t0

    # Encode as PNG
    success, png_buf = cv2.imencode(".png", dewarped_bgr)
    dewarped_png = png_buf.tobytes() if success else b""

    dewarp_result = {
        "method_used": dewarp_info["method"],
        "shear_degrees": dewarp_info["shear_degrees"],
        "confidence": dewarp_info["confidence"],
        "duration_seconds": round(duration, 2),
        "detections": dewarp_info.get("detections", []),
    }

    # Update cache
    cached["dewarped_bgr"] = dewarped_bgr
    cached["dewarp_result"] = dewarp_result

    # Persist to DB
    await update_session_db(
        session_id,
        dewarped_png=dewarped_png,
        dewarp_result=dewarp_result,
        auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
        current_step=4,
    )

    logger.info(
        f"OCR Pipeline: dewarp session {session_id}: "
        f"method={dewarp_info['method']} shear={dewarp_info['shear_degrees']:.3f} "
        f"conf={dewarp_info['confidence']:.2f} ({duration:.2f}s)"
    )

    await _append_pipeline_log(session_id, "dewarp", {
        "shear_degrees": dewarp_info["shear_degrees"],
        "confidence": dewarp_info["confidence"],
        "method": dewarp_info["method"],
        "ensemble_methods": [d.get("method", "") for d in dewarp_info.get("detections", [])],
    }, duration_ms=int(duration * 1000))

    return {
        "session_id": session_id,
        **dewarp_result,
        "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
    }


@router.post("/sessions/{session_id}/dewarp/manual")
async def manual_dewarp(session_id: str, req: ManualDewarpRequest):
    """Apply shear correction with a manual angle.

    ``req.shear_degrees`` is clamped to [-2.0, 2.0]; angles below 0.001 in
    magnitude leave the deskewed image unchanged. Earlier dewarp_result
    fields are kept, with ``method_used`` and ``shear_degrees`` overwritten.

    Raises:
        HTTPException: 400 if the session has no deskewed image yet.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    deskewed_bgr = cached.get("deskewed_bgr")
    if deskewed_bgr is None:
        raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp")

    shear_deg = max(-2.0, min(2.0, req.shear_degrees))

    if abs(shear_deg) < 0.001:
        dewarped_bgr = deskewed_bgr
    else:
        dewarped_bgr = dewarp_image_manual(deskewed_bgr, shear_deg)

    success, png_buf = cv2.imencode(".png", dewarped_bgr)
    dewarped_png = png_buf.tobytes() if success else b""

    dewarp_result = {
        **(cached.get("dewarp_result") or {}),
        "method_used": "manual",
        "shear_degrees": round(shear_deg, 3),
    }

    # Update cache
    cached["dewarped_bgr"] = dewarped_bgr
    cached["dewarp_result"] = dewarp_result

    # Persist to DB
    await update_session_db(
        session_id,
        dewarped_png=dewarped_png,
        dewarp_result=dewarp_result,
    )

    logger.info(f"OCR Pipeline: manual dewarp session {session_id}: shear={shear_deg:.3f}")

    return {
        "session_id": session_id,
        "shear_degrees": round(shear_deg, 3),
        "method_used": "manual",
        "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
    }


@router.post("/sessions/{session_id}/adjust-combined")
async def adjust_combined(session_id: str, req: CombinedAdjustRequest):
    """Apply rotation + shear combined to the original image.

    Used by the fine-tuning sliders to preview arbitrary rotation/shear
    combinations without re-running the full deskew/dewarp pipeline.

    Rotation is clamped to [-15, 15] degrees and shear to [-5, 5] degrees.
    The resulting image replaces both ``deskewed_bgr`` and ``dewarped_bgr``
    in the cache, and both PNGs in the DB. A binarized PNG is also stored
    when ``create_ocr_image`` succeeds; binarization failures are logged and
    otherwise ignored.

    Raises:
        HTTPException: 400 if the session has no original image.
    """
    if session_id not in _cache:
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    img_bgr = cached.get("original_bgr")
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Original image not available")

    rotation = max(-15.0, min(15.0, req.rotation_degrees))
    shear_deg = max(-5.0, min(5.0, req.shear_degrees))

    h, w = img_bgr.shape[:2]
    result_bgr = img_bgr

    # Step 1: Apply rotation
    if abs(rotation) >= 0.001:
        center = (w // 2, h // 2)
        M = cv2.getRotationMatrix2D(center, rotation, 1.0)
        result_bgr = cv2.warpAffine(
            result_bgr, M, (w, h),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_REPLICATE,
        )

    # Step 2: Apply shear
    if abs(shear_deg) >= 0.001:
        result_bgr = dewarp_image_manual(result_bgr, shear_deg)

    # Encode
    success, png_buf = cv2.imencode(".png", result_bgr)
    dewarped_png = png_buf.tobytes() if success else b""

    # Binarize (best-effort: a failure here must not block the adjustment,
    # but it is logged instead of silently swallowed).
    binarized_png = None
    try:
        binarized = create_ocr_image(result_bgr)
        success_bin, bin_buf = cv2.imencode(".png", binarized)
        binarized_png = bin_buf.tobytes() if success_bin else None
    except Exception as e:
        logger.warning(
            f"OCR Pipeline: binarization failed in combined adjust session {session_id}: {e}"
        )

    # Build combined result dicts
    deskew_result = {
        **(cached.get("deskew_result") or {}),
        "angle_applied": round(rotation, 3),
        "method_used": "manual_combined",
    }
    dewarp_result = {
        **(cached.get("dewarp_result") or {}),
        "method_used": "manual_combined",
        "shear_degrees": round(shear_deg, 3),
    }

    # Update cache
    cached["deskewed_bgr"] = result_bgr
    cached["dewarped_bgr"] = result_bgr
    cached["deskew_result"] = deskew_result
    cached["dewarp_result"] = dewarp_result

    # Persist to DB. The combined image is both the deskewed and the dewarped
    # image (mirroring the cache above), so deskewed_png is always written,
    # independent of whether binarization succeeded.
    db_update = {
        "dewarped_png": dewarped_png,
        "deskewed_png": dewarped_png,
        "deskew_result": deskew_result,
        "dewarp_result": dewarp_result,
    }
    if binarized_png:
        db_update["binarized_png"] = binarized_png
    await update_session_db(session_id, **db_update)

    logger.info(
        f"OCR Pipeline: combined adjust session {session_id}: "
        f"rotation={rotation:.3f} shear={shear_deg:.3f}"
    )

    return {
        "session_id": session_id,
        "rotation_degrees": round(rotation, 3),
        "shear_degrees": round(shear_deg, 3),
        "method_used": "manual_combined",
        "dewarped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/dewarped",
    }


@router.post("/sessions/{session_id}/ground-truth/dewarp")
async def save_dewarp_ground_truth(session_id: str, req: DewarpGroundTruthRequest):
    """Save ground truth feedback for the dewarp step.

    Stores the feedback, together with a snapshot of the session's current
    ``dewarp_result``, under ``ground_truth["dewarp"]`` in the DB. It also
    mirrors the change into the in-memory cache when the session is loaded.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    gt = {
        "is_correct": req.is_correct,
        "corrected_shear": req.corrected_shear,
        "notes": req.notes,
        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12.
        # Kept here so the stored (naive, no UTC offset) ISO format stays
        # unchanged for existing consumers.
        "saved_at": datetime.utcnow().isoformat(),
        "dewarp_result": session.get("dewarp_result"),
    }
    ground_truth["dewarp"] = gt

    await update_session_db(session_id, ground_truth=ground_truth)

    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    logger.info(
        f"OCR Pipeline: ground truth dewarp session {session_id}: "
        f"correct={req.is_correct}, corrected_shear={req.corrected_shear}"
    )

    return {"session_id": session_id, "ground_truth": gt}