From 50e1c964ee905306eb7ceec2fe1794c0d28bf35e Mon Sep 17 00:00:00 2001 From: Benjamin Admin Date: Tue, 3 Mar 2026 13:13:20 +0100 Subject: [PATCH] feat(klausur-service): OCR-Pipeline Optimierungen (Improvements 2-4) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Improvement 2: VLM-basierter Dewarp - Neuer Query-Parameter `method` für POST /sessions/{id}/dewarp Optionen: ensemble (default) | vlm | cv - `_detect_shear_with_vlm()`: fragt qwen2.5vl:32b per Ollama nach dem Scherwinkel — gibt Zahlenwert + Konfidenz zurück - `os`, `Query` zu ocr_pipeline_api.py Imports hinzugefügt - `_apply_shear` aus cv_vocab_pipeline importiert ## Improvement 4: 3-Methoden Ensemble-Dewarp - `_detect_shear_by_projection()`: Varianz-Sweep ±3° / 0.25°-Schritte auf horizontalen Text-Zeilen-Projektionen (~30ms) - `_detect_shear_by_hough()`: Gewichteter Median über HoughLinesP auf Tabellen-Linien, Vorzeichen-Inversion (~20ms) - `_ensemble_shear()`: Kombiniert alle 3 Methoden (conf >= 0.3), Ausreißer-Filter bei >1° Abweichung, Bonus bei Agreement <0.5° - `dewarp_image()` nutzt jetzt alle 3 Methoden parallel, `use_ensemble: bool = True` für Rückwärtskompatibilität - auto_dewarp Response enthält jetzt `detections`-Array ## Improvement 3: Vollautomatik-Endpoint - POST /sessions/{id}/run-auto mit RunAutoRequest: from_step (1-6), ocr_engine, pronunciation, skip_llm_review, dewarp_method - SSE-Streaming für alle 5+1 Schritte (deskew→dewarp→columns→rows→words→llm-review) - Jeder Schritt: start / done / skipped / error Events - Abschluss-Event: {steps_run, steps_skipped} - LLM-Review-Fehler sind nicht-fatal (Pipeline läuft weiter) Co-Authored-By: Claude Sonnet 4.6 --- klausur-service/backend/cv_vocab_pipeline.py | 404 ++++++++++++- klausur-service/backend/ocr_pipeline_api.py | 598 ++++++++++++++++++- 2 files changed, 975 insertions(+), 27 deletions(-) diff --git a/klausur-service/backend/cv_vocab_pipeline.py b/klausur-service/backend/cv_vocab_pipeline.py index 5452f31..1d62319 100644 --- a/klausur-service/backend/cv_vocab_pipeline.py +++ b/klausur-service/backend/cv_vocab_pipeline.py @@ -484,6 +484,133 @@ def _detect_shear_angle(img: np.ndarray) -> Dict[str, Any]: return result +def _detect_shear_by_projection(img: np.ndarray) -> Dict[str, Any]: + """Detect shear angle by maximising variance of horizontal text-line projections. + + Principle: horizontal text lines produce a row-projection profile with sharp + peaks (high variance) when the image is correctly aligned. Any residual shear + smears the peaks and reduces variance. We sweep ±3° and pick the angle whose + corrected projection has the highest variance. + + Works best on pages with clear horizontal banding (vocabulary tables, prose). + Complements _detect_shear_angle() which needs strong vertical edges. + + Returns: + Dict with keys: method, shear_degrees, confidence. + """ + import math + result = {"method": "projection", "shear_degrees": 0.0, "confidence": 0.0} + + h, w = img.shape[:2] + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + + # Otsu binarisation + _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU) + + # Work at half resolution for speed + small = cv2.resize(binary, (w // 2, h // 2), interpolation=cv2.INTER_AREA) + sh, sw = small.shape + + # Angle sweep: ±3° in 0.25° steps + angles = [a * 0.25 for a in range(-12, 13)] # 25 values + best_angle = 0.0 + best_variance = -1.0 + variances: List[Tuple[float, float]] = [] + + for angle_deg in angles: + if abs(angle_deg) < 0.01: + rotated = small + else: + shear_tan = math.tan(math.radians(angle_deg)) + M = np.float32([[1, shear_tan, -sh / 2.0 * shear_tan], [0, 1, 0]]) + rotated = cv2.warpAffine(small, M, (sw, sh), + flags=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT) + profile = np.sum(rotated, axis=1).astype(float) + var = float(np.var(profile)) + variances.append((angle_deg, var)) + if var > best_variance: + best_variance = var + best_angle = angle_deg + + # Confidence: how much sharper is the best angle vs. the mean? + all_mean = sum(v for _, v in variances) / len(variances) + if all_mean > 0 and best_variance > all_mean: + confidence = min(1.0, (best_variance - all_mean) / (all_mean + 1.0) * 0.6) + else: + confidence = 0.0 + + result["shear_degrees"] = round(best_angle, 3) + result["confidence"] = round(max(0.0, min(1.0, confidence)), 2) + return result + + +def _detect_shear_by_hough(img: np.ndarray) -> Dict[str, Any]: + """Detect shear using Hough transform on printed table / ruled lines. + + Vocabulary worksheets have near-horizontal printed table borders. After + deskew these should be exactly horizontal; any residual tilt equals the + vertical shear angle (with inverted sign). + + The sign convention: a horizontal line tilting +α degrees (left end lower) + means the page has vertical shear of -α degrees (left column edge drifts + to the left going downward). + + Returns: + Dict with keys: method, shear_degrees, confidence. + """ + result = {"method": "hough_lines", "shear_degrees": 0.0, "confidence": 0.0} + + h, w = img.shape[:2] + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + + edges = cv2.Canny(gray, 50, 150, apertureSize=3) + + min_len = int(w * 0.15) + lines = cv2.HoughLinesP( + edges, rho=1, theta=np.pi / 360, + threshold=int(w * 0.08), + minLineLength=min_len, + maxLineGap=20, + ) + + if lines is None or len(lines) < 3: + return result + + horizontal_angles: List[Tuple[float, float]] = [] + for line in lines: + x1, y1, x2, y2 = line[0] + if x1 == x2: + continue + angle = float(np.degrees(np.arctan2(y2 - y1, x2 - x1))) + if abs(angle) <= 5.0: + length = float(np.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)) + horizontal_angles.append((angle, length)) + + if len(horizontal_angles) < 3: + return result + + # Weighted median + angles_arr = np.array([a for a, _ in horizontal_angles]) + weights_arr = np.array([l for _, l in horizontal_angles]) + sorted_idx = np.argsort(angles_arr) + s_angles = angles_arr[sorted_idx] + s_weights = weights_arr[sorted_idx] + cum = np.cumsum(s_weights) + mid_idx = int(np.searchsorted(cum, cum[-1] / 2.0)) + median_angle = float(s_angles[min(mid_idx, len(s_angles) - 1)]) + + agree = sum(1 for a, _ in horizontal_angles if abs(a - median_angle) < 1.0) + confidence = min(1.0, agree / max(len(horizontal_angles), 1)) * 0.85 + + # Sign inversion: horizontal line tilt is complementary to vertical shear + shear_degrees = -median_angle + + result["shear_degrees"] = round(shear_degrees, 3) + result["confidence"] = round(max(0.0, min(1.0, confidence)), 2) + return result + + def _apply_shear(img: np.ndarray, shear_degrees: float) -> np.ndarray: """Apply a vertical shear correction to an image. @@ -516,24 +643,78 @@ def _apply_shear(img: np.ndarray, shear_degrees: float) -> np.ndarray: return corrected -def dewarp_image(img: np.ndarray) -> Tuple[np.ndarray, Dict[str, Any]]: +def _ensemble_shear(detections: List[Dict[str, Any]]) -> Tuple[float, float, str]: + """Combine multiple shear detections into a single weighted estimate. + + Only methods with confidence >= 0.3 are considered. + Results are outlier-filtered: if any accepted result differs by more than + 1° from the weighted mean, it is discarded. + + Returns: + (shear_degrees, ensemble_confidence, methods_used_str) + """ + accepted = [(d["shear_degrees"], d["confidence"], d["method"]) + for d in detections if d["confidence"] >= 0.3] + + if not accepted: + return 0.0, 0.0, "none" + + if len(accepted) == 1: + deg, conf, method = accepted[0] + return deg, conf, method + + # First pass: weighted mean + total_w = sum(c for _, c, _ in accepted) + w_mean = sum(d * c for d, c, _ in accepted) / total_w + + # Outlier filter: keep results within 1° of weighted mean + filtered = [(d, c, m) for d, c, m in accepted if abs(d - w_mean) <= 1.0] + if not filtered: + filtered = accepted # fallback: keep all + + # Second pass: weighted mean on filtered results + total_w2 = sum(c for _, c, _ in filtered) + final_deg = sum(d * c for d, c, _ in filtered) / total_w2 + + # Ensemble confidence: average of individual confidences, boosted when + # methods agree (all within 0.5° of each other) + avg_conf = total_w2 / len(filtered) + spread = max(d for d, _, _ in filtered) - min(d for d, _, _ in filtered) + agreement_bonus = 0.15 if spread < 0.5 else 0.0 + ensemble_conf = min(1.0, avg_conf + agreement_bonus) + + methods_str = "+".join(m for _, _, m in filtered) + return round(final_deg, 3), round(ensemble_conf, 2), methods_str + + +def dewarp_image(img: np.ndarray, use_ensemble: bool = True) -> Tuple[np.ndarray, Dict[str, Any]]: """Correct vertical shear after deskew. After deskew aligns horizontal text lines, vertical features (column - edges) may still be tilted. This detects the tilt angle of the strongest - vertical edge and applies an affine shear correction. + edges) may still be tilted. This detects the tilt angle using an ensemble + of three complementary methods and applies an affine shear correction. + + Methods (all run in ~100ms total): + A. _detect_shear_angle() — vertical edge profile (~50ms) + B. _detect_shear_by_projection() — horizontal text-line variance (~30ms) + C. _detect_shear_by_hough() — Hough lines on table borders (~20ms) + + Only methods with confidence >= 0.3 contribute to the ensemble. + Outlier filtering discards results deviating > 1° from the weighted mean. Args: img: BGR image (already deskewed). + use_ensemble: If False, fall back to single-method behaviour (method A only). Returns: Tuple of (corrected_image, dewarp_info). - dewarp_info keys: method, shear_degrees, confidence. + dewarp_info keys: method, shear_degrees, confidence, detections. """ no_correction = { "method": "none", "shear_degrees": 0.0, "confidence": 0.0, + "detections": [], } if not CV2_AVAILABLE: @@ -541,14 +722,31 @@ def dewarp_image(img: np.ndarray) -> Tuple[np.ndarray, Dict[str, Any]]: t0 = time.time() - detection = _detect_shear_angle(img) + if use_ensemble: + det_a = _detect_shear_angle(img) + det_b = _detect_shear_by_projection(img) + det_c = _detect_shear_by_hough(img) + detections = [det_a, det_b, det_c] + shear_deg, confidence, method = _ensemble_shear(detections) + else: + det_a = _detect_shear_angle(img) + detections = [det_a] + shear_deg = det_a["shear_degrees"] + confidence = det_a["confidence"] + method = det_a["method"] + duration = time.time() - t0 - shear_deg = detection["shear_degrees"] - confidence = detection["confidence"] - - logger.info(f"dewarp: detected shear={shear_deg:.3f}° " - f"conf={confidence:.2f} ({duration:.2f}s)") + logger.info( + "dewarp: ensemble shear=%.3f° conf=%.2f method=%s (%.2fs) | " + "A=%.3f/%.2f B=%.3f/%.2f C=%.3f/%.2f", + shear_deg, confidence, method, duration, + detections[0]["shear_degrees"], detections[0]["confidence"], + detections[1]["shear_degrees"] if len(detections) > 1 else 0.0, + detections[1]["confidence"] if len(detections) > 1 else 0.0, + detections[2]["shear_degrees"] if len(detections) > 2 else 0.0, + detections[2]["confidence"] if len(detections) > 2 else 0.0, + ) # Only correct if shear is significant (> 0.05°) if abs(shear_deg) < 0.05 or confidence < 0.3: @@ -558,9 +756,14 @@ def dewarp_image(img: np.ndarray) -> Tuple[np.ndarray, Dict[str, Any]]: corrected = _apply_shear(img, -shear_deg) info = { - "method": detection["method"], + "method": method, "shear_degrees": shear_deg, "confidence": confidence, + "detections": [ + {"method": d["method"], "shear_degrees": d["shear_degrees"], + "confidence": d["confidence"]} + for d in detections + ], } return corrected, info @@ -3053,6 +3256,142 @@ def ocr_region_rapid( return words +def ocr_region_trocr(img_bgr: np.ndarray, region: PageRegion, handwritten: bool = False) -> List[Dict[str, Any]]: + """Run TrOCR on a region. Returns line-level word dicts (same format as ocr_region_rapid). + + Uses trocr_service.get_trocr_model() + _split_into_lines() for line segmentation. + Bboxes are approximated from equal line-height distribution within the region. + Falls back to Tesseract if TrOCR is not available. + """ + from services.trocr_service import get_trocr_model, _split_into_lines, _check_trocr_available + + if not _check_trocr_available(): + logger.warning("TrOCR not available, falling back to Tesseract") + if region.height > 0 and region.width > 0: + ocr_img_crop = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY) if img_bgr is not None else None + if ocr_img_crop is not None: + return ocr_region(ocr_img_crop, region, lang="eng+deu", psm=6) + return [] + + crop = img_bgr[region.y:region.y + region.height, region.x:region.x + region.width] + if crop.size == 0: + return [] + + try: + import torch + from PIL import Image as _PILImage + + processor, model = get_trocr_model(handwritten=handwritten) + if processor is None or model is None: + logger.warning("TrOCR model not loaded, falling back to Tesseract") + ocr_img_crop = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY) + return ocr_region(ocr_img_crop, region, lang="eng+deu", psm=6) + + pil_crop = _PILImage.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)) + lines = _split_into_lines(pil_crop) + if not lines: + lines = [pil_crop] + + device = next(model.parameters()).device + all_text = [] + confidences = [] + for line_img in lines: + pixel_values = processor(images=line_img, return_tensors="pt").pixel_values.to(device) + with torch.no_grad(): + generated_ids = model.generate(pixel_values, max_length=128) + text_line = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() + if text_line: + all_text.append(text_line) + confidences.append(0.85 if len(text_line) > 3 else 0.5) + + if not all_text: + return [] + + avg_conf = int(sum(confidences) / len(confidences) * 100) + line_h = region.height // max(len(all_text), 1) + words = [] + for i, line in enumerate(all_text): + words.append({ + "text": line, + "left": region.x, + "top": region.y + i * line_h, + "width": region.width, + "height": line_h, + "conf": avg_conf, + "region_type": region.type, + }) + return words + + except Exception as e: + logger.error(f"ocr_region_trocr failed: {e}") + return [] + + +def ocr_region_lighton(img_bgr: np.ndarray, region: PageRegion) -> List[Dict[str, Any]]: + """Run LightOnOCR-2-1B on a region. Returns line-level word dicts (same format as ocr_region_rapid). + + Falls back to RapidOCR or Tesseract if LightOnOCR is not available. + """ + from services.lighton_ocr_service import get_lighton_model, _check_lighton_available + + if not _check_lighton_available(): + logger.warning("LightOnOCR not available, falling back to RapidOCR/Tesseract") + if RAPIDOCR_AVAILABLE and img_bgr is not None: + return ocr_region_rapid(img_bgr, region) + ocr_img_crop = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY) if img_bgr is not None else None + return ocr_region(ocr_img_crop, region, lang="eng+deu", psm=6) if ocr_img_crop is not None else [] + + crop = img_bgr[region.y:region.y + region.height, region.x:region.x + region.width] + if crop.size == 0: + return [] + + try: + import io + import torch + from PIL import Image as _PILImage + + processor, model = get_lighton_model() + if processor is None or model is None: + logger.warning("LightOnOCR model not loaded, falling back to RapidOCR/Tesseract") + if RAPIDOCR_AVAILABLE and img_bgr is not None: + return ocr_region_rapid(img_bgr, region) + ocr_img_crop = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY) + return ocr_region(ocr_img_crop, region, lang="eng+deu", psm=6) + + pil_crop = _PILImage.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)) + conversation = [{"role": "user", "content": [{"type": "image"}]}] + inputs = processor.apply_chat_template( + conversation, images=[pil_crop], + add_generation_prompt=True, return_tensors="pt" + ).to(model.device) + + with torch.no_grad(): + output_ids = model.generate(**inputs, max_new_tokens=1024) + + text = processor.decode(output_ids[0], skip_special_tokens=True).strip() + if not text: + return [] + + lines = [l.strip() for l in text.split("\n") if l.strip()] + line_h = region.height // max(len(lines), 1) + words = [] + for i, line in enumerate(lines): + words.append({ + "text": line, + "left": region.x, + "top": region.y + i * line_h, + "width": region.width, + "height": line_h, + "conf": 85, + "region_type": region.type, + }) + return words + + except Exception as e: + logger.error(f"ocr_region_lighton failed: {e}") + return [] + + # ============================================================================= # Post-Processing: Deterministic Quality Fixes # ============================================================================= @@ -3900,7 +4239,11 @@ def _ocr_single_cell( x=cell_x, y=cell_y, width=cell_w, height=cell_h, ) - if use_rapid and img_bgr is not None: + if engine_name in ("trocr-printed", "trocr-handwritten") and img_bgr is not None: + fallback_words = ocr_region_trocr(img_bgr, cell_region, handwritten=(engine_name == "trocr-handwritten")) + elif engine_name == "lighton" and img_bgr is not None: + fallback_words = ocr_region_lighton(img_bgr, cell_region) + elif use_rapid and img_bgr is not None: fallback_words = ocr_region_rapid(img_bgr, cell_region) else: cell_lang = lang_map.get(col.type, lang) @@ -3981,8 +4324,8 @@ def build_cell_grid( img_w: Image width in pixels. img_h: Image height in pixels. lang: Default Tesseract language. - ocr_engine: 'tesseract', 'rapid', or 'auto'. - img_bgr: BGR color image (required for RapidOCR). + ocr_engine: 'tesseract', 'rapid', 'auto', 'trocr-printed', 'trocr-handwritten', or 'lighton'. + img_bgr: BGR color image (required for RapidOCR / TrOCR / LightOnOCR). Returns: (cells, columns_meta) where cells is a list of cell dicts and @@ -3990,15 +4333,20 @@ def build_cell_grid( """ # Resolve engine choice use_rapid = False - if ocr_engine == "auto": + if ocr_engine in ("trocr-printed", "trocr-handwritten", "lighton"): + engine_name = ocr_engine + elif ocr_engine == "auto": use_rapid = RAPIDOCR_AVAILABLE and img_bgr is not None + engine_name = "rapid" if use_rapid else "tesseract" elif ocr_engine == "rapid": if not RAPIDOCR_AVAILABLE: logger.warning("RapidOCR requested but not available, falling back to Tesseract") else: use_rapid = True + engine_name = "rapid" if use_rapid else "tesseract" + else: + engine_name = "tesseract" - engine_name = "rapid" if use_rapid else "tesseract" logger.info(f"build_cell_grid: using OCR engine '{engine_name}'") # Filter to content rows only (skip header/footer) @@ -4093,7 +4441,11 @@ def build_cell_grid( ) strip_lang = lang_map.get(relevant_cols[col_idx].type, lang) - if use_rapid and img_bgr is not None: + if engine_name in ("trocr-printed", "trocr-handwritten") and img_bgr is not None: + strip_words = ocr_region_trocr(img_bgr, strip_region, handwritten=(engine_name == "trocr-handwritten")) + elif engine_name == "lighton" and img_bgr is not None: + strip_words = ocr_region_lighton(img_bgr, strip_region) + elif use_rapid and img_bgr is not None: strip_words = ocr_region_rapid(img_bgr, strip_region) else: strip_words = ocr_region(ocr_img, strip_region, lang=strip_lang, psm=6) @@ -4169,15 +4521,19 @@ def build_cell_grid_streaming( """ # Resolve engine choice (same as build_cell_grid) use_rapid = False - if ocr_engine == "auto": + if ocr_engine in ("trocr-printed", "trocr-handwritten", "lighton"): + engine_name = ocr_engine + elif ocr_engine == "auto": use_rapid = RAPIDOCR_AVAILABLE and img_bgr is not None + engine_name = "rapid" if use_rapid else "tesseract" elif ocr_engine == "rapid": if not RAPIDOCR_AVAILABLE: logger.warning("RapidOCR requested but not available, falling back to Tesseract") else: use_rapid = True - - engine_name = "rapid" if use_rapid else "tesseract" + engine_name = "rapid" if use_rapid else "tesseract" + else: + engine_name = "tesseract" content_rows = [r for r in row_geometries if r.row_type == 'content'] if not content_rows: @@ -5026,8 +5382,10 @@ import os import json as _json import re as _re -_OLLAMA_URL = os.getenv("OLLAMA_URL", os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")) -OLLAMA_REVIEW_MODEL = os.getenv("OLLAMA_REVIEW_MODEL", "qwen3:30b-a3b") +_OLLAMA_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434") +OLLAMA_REVIEW_MODEL = os.getenv("OLLAMA_REVIEW_MODEL", "qwen3:0.6b") +_REVIEW_BATCH_SIZE = int(os.getenv("OLLAMA_REVIEW_BATCH_SIZE", "20")) +logger.info("LLM review model: %s (batch=%d)", OLLAMA_REVIEW_MODEL, _REVIEW_BATCH_SIZE) # Regex: entry contains IPA phonetic brackets like "dance [dɑːns]" _HAS_PHONETIC_RE = _re.compile(r'\[.*?[ˈˌːʃʒθðŋɑɒɔəɜɪʊʌæ].*?\]') @@ -5205,7 +5563,7 @@ async def llm_review_entries( async def llm_review_entries_streaming( entries: List[Dict], model: str = None, - batch_size: int = 8, + batch_size: int = _REVIEW_BATCH_SIZE, ): """Async generator: yield SSE events while reviewing entries in batches.""" model = model or OLLAMA_REVIEW_MODEL diff --git a/klausur-service/backend/ocr_pipeline_api.py b/klausur-service/backend/ocr_pipeline_api.py index cae425d..6224660 100644 --- a/klausur-service/backend/ocr_pipeline_api.py +++ b/klausur-service/backend/ocr_pipeline_api.py @@ -17,6 +17,7 @@ DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. import json import logging +import os import time import uuid from dataclasses import asdict @@ -25,7 +26,7 @@ from typing import Any, Dict, List, Optional import cv2 import numpy as np -from fastapi import APIRouter, File, Form, HTTPException, Request, UploadFile +from fastapi import APIRouter, File, Form, HTTPException, Query, Request, UploadFile from fastapi.responses import Response, StreamingResponse from pydantic import BaseModel @@ -50,6 +51,7 @@ from cv_vocab_pipeline import ( deskew_image_by_word_alignment, detect_column_geometry, detect_row_geometry, + _apply_shear, dewarp_image, dewarp_image_manual, llm_review_entries, @@ -544,9 +546,75 @@ async def save_deskew_ground_truth(session_id: str, req: DeskewGroundTruthReques # Dewarp Endpoints # --------------------------------------------------------------------------- +async def _detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]: + """Ask qwen2.5vl:32b to estimate the vertical shear angle of a scanned page. + + The VLM is shown the image and asked: are the column/table borders tilted? + If yes, by how many degrees? Returns a dict with shear_degrees and confidence. + Confidence is 0.0 if Ollama is unavailable or parsing fails. + """ + import httpx + import base64 + import re + + ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434") + model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b") + + prompt = ( + "This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. " + "Are they perfectly vertical, or do they tilt slightly? " + "If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). " + "Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} " + "Use confidence 0.0-1.0 based on how clearly you can see the tilt. " + "If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}" + ) + + img_b64 = base64.b64encode(image_bytes).decode("utf-8") + payload = { + "model": model, + "prompt": prompt, + "images": [img_b64], + "stream": False, + } + + try: + async with httpx.AsyncClient(timeout=60.0) as client: + resp = await client.post(f"{ollama_base}/api/generate", json=payload) + resp.raise_for_status() + text = resp.json().get("response", "") + + # Parse JSON from response (may have surrounding text) + match = re.search(r'\{[^}]+\}', text) + if match: + import json + data = json.loads(match.group(0)) + shear = float(data.get("shear_degrees", 0.0)) + conf = float(data.get("confidence", 0.0)) + # Clamp to reasonable range + shear = max(-3.0, min(3.0, shear)) + conf = max(0.0, min(1.0, conf)) + return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)} + except Exception as e: + logger.warning(f"VLM dewarp failed: {e}") + + return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0} + + @router.post("/sessions/{session_id}/dewarp") -async def auto_dewarp(session_id: str): - """Detect and correct vertical shear on the deskewed image.""" +async def auto_dewarp( + session_id: str, + method: str = Query("ensemble", description="Detection method: ensemble | vlm | cv"), +): + """Detect and correct vertical shear on the deskewed image. + + Methods: + - **ensemble** (default): 3-method CV ensemble (vertical edges + projection + Hough) + - **cv**: CV ensemble only (same as ensemble) + - **vlm**: Ask qwen2.5vl:32b to estimate the shear angle visually + """ + if method not in ("ensemble", "cv", "vlm"): + raise HTTPException(status_code=400, detail="method must be one of: ensemble, cv, vlm") + if session_id not in _cache: await _load_session_to_cache(session_id) cached = _get_cached(session_id) @@ -556,7 +624,26 @@ async def auto_dewarp(session_id: str): raise HTTPException(status_code=400, detail="Deskew must be completed before dewarp") t0 = time.time() - dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr) + + if method == "vlm": + # Encode deskewed image to PNG for VLM + success, png_buf = cv2.imencode(".png", deskewed_bgr) + img_bytes = png_buf.tobytes() if success else b"" + vlm_det = await _detect_shear_with_vlm(img_bytes) + shear_deg = vlm_det["shear_degrees"] + if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3: + dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg) + else: + dewarped_bgr = deskewed_bgr + dewarp_info = { + "method": vlm_det["method"], + "shear_degrees": shear_deg, + "confidence": vlm_det["confidence"], + "detections": [vlm_det], + } + else: + dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr) + duration = time.time() - t0 # Encode as PNG @@ -568,6 +655,7 @@ async def auto_dewarp(session_id: str): "shear_degrees": dewarp_info["shear_degrees"], "confidence": dewarp_info["confidence"], "duration_seconds": round(duration, 2), + "detections": dewarp_info.get("detections", []), } # Update cache @@ -2000,3 +2088,505 @@ async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingReq "image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean", "session_id": session_id, } + + +# --------------------------------------------------------------------------- +# Auto-Mode Endpoint (Improvement 3) +# --------------------------------------------------------------------------- + +class RunAutoRequest(BaseModel): + from_step: int = 1 # 1=deskew, 2=dewarp, 3=columns, 4=rows, 5=words, 6=llm-review + ocr_engine: str = "auto" # "auto" | "rapid" | "tesseract" + pronunciation: str = "british" + skip_llm_review: bool = False + dewarp_method: str = "ensemble" # "ensemble" | "vlm" | "cv" + + +async def _auto_sse_event(step: str, status: str, data: Dict[str, Any]) -> str: + """Format a single SSE event line.""" + import json as _json + payload = {"step": step, "status": status, **data} + return f"data: {_json.dumps(payload)}\n\n" + + +@router.post("/sessions/{session_id}/run-auto") +async def run_auto(session_id: str, req: RunAutoRequest, request: Request): + """Run the full OCR pipeline automatically from a given step, streaming SSE progress. + + Steps: + 1. Deskew — straighten the scan + 2. Dewarp — correct vertical shear (ensemble CV or VLM) + 3. Columns — detect column layout + 4. Rows — detect row layout + 5. Words — OCR each cell + 6. LLM review — correct OCR errors (optional) + + Already-completed steps are skipped unless `from_step` forces a rerun. + Yields SSE events of the form: + data: {"step": "deskew", "status": "start"|"done"|"skipped"|"error", ...} + + Final event: + data: {"step": "complete", "status": "done", "steps_run": [...], "steps_skipped": [...]} + """ + if req.from_step < 1 or req.from_step > 6: + raise HTTPException(status_code=400, detail="from_step must be 1-6") + if req.dewarp_method not in ("ensemble", "vlm", "cv"): + raise HTTPException(status_code=400, detail="dewarp_method must be: ensemble, vlm, cv") + + if session_id not in _cache: + await _load_session_to_cache(session_id) + + async def _generate(): + steps_run: List[str] = [] + steps_skipped: List[str] = [] + error_step: Optional[str] = None + + session = await get_session_db(session_id) + if not session: + yield await _auto_sse_event("error", "error", {"message": f"Session {session_id} not found"}) + return + + cached = _get_cached(session_id) + + # ----------------------------------------------------------------- + # Step 1: Deskew + # ----------------------------------------------------------------- + if req.from_step <= 1: + yield await _auto_sse_event("deskew", "start", {}) + try: + t0 = time.time() + orig_bgr = cached.get("original_bgr") + if orig_bgr is None: + raise ValueError("Original image not loaded") + + # Method 1: Hough lines + try: + deskewed_hough, angle_hough = deskew_image(orig_bgr.copy()) + except Exception: + deskewed_hough, angle_hough = orig_bgr, 0.0 + + # Method 2: Word alignment + success_enc, png_orig = cv2.imencode(".png", orig_bgr) + orig_bytes = png_orig.tobytes() if success_enc else b"" + try: + deskewed_wa_bytes, angle_wa = deskew_image_by_word_alignment(orig_bytes) + except Exception: + deskewed_wa_bytes, angle_wa = orig_bytes, 0.0 + + # Pick best method + if abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1: + method_used = "word_alignment" + angle_applied = angle_wa + wa_arr = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8) + deskewed_bgr = cv2.imdecode(wa_arr, cv2.IMREAD_COLOR) + if deskewed_bgr is None: + deskewed_bgr = deskewed_hough + method_used = "hough" + angle_applied = angle_hough + else: + method_used = "hough" + angle_applied = angle_hough + deskewed_bgr = deskewed_hough + + success, png_buf = cv2.imencode(".png", deskewed_bgr) + deskewed_png = png_buf.tobytes() if success else b"" + + deskew_result = { + "method_used": method_used, + "rotation_degrees": round(float(angle_applied), 3), + "duration_seconds": round(time.time() - t0, 2), + } + + cached["deskewed_bgr"] = deskewed_bgr + cached["deskew_result"] = deskew_result + await update_session_db( + session_id, + deskewed_png=deskewed_png, + deskew_result=deskew_result, + auto_rotation_degrees=float(angle_applied), + current_step=2, + ) + session = await get_session_db(session_id) + + steps_run.append("deskew") + yield await _auto_sse_event("deskew", "done", deskew_result) + except Exception as e: + logger.error(f"Auto-mode deskew failed for {session_id}: {e}") + error_step = "deskew" + yield await _auto_sse_event("deskew", "error", {"message": str(e)}) + yield await _auto_sse_event("complete", "error", {"error_step": error_step}) + return + else: + steps_skipped.append("deskew") + yield await _auto_sse_event("deskew", "skipped", {"reason": "from_step > 1"}) + + # ----------------------------------------------------------------- + # Step 2: Dewarp + # ----------------------------------------------------------------- + if req.from_step <= 2: + yield await _auto_sse_event("dewarp", "start", {"method": req.dewarp_method}) + try: + t0 = time.time() + deskewed_bgr = cached.get("deskewed_bgr") + if deskewed_bgr is None: + raise ValueError("Deskewed image not available") + + if req.dewarp_method == "vlm": + success_enc, png_buf = cv2.imencode(".png", deskewed_bgr) + img_bytes = png_buf.tobytes() if success_enc else b"" + vlm_det = await _detect_shear_with_vlm(img_bytes) + shear_deg = vlm_det["shear_degrees"] + if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3: + dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg) + else: + dewarped_bgr = deskewed_bgr + dewarp_info = { + "method": vlm_det["method"], + "shear_degrees": shear_deg, + "confidence": vlm_det["confidence"], + "detections": [vlm_det], + } + else: + dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr) + + success_enc, png_buf = cv2.imencode(".png", dewarped_bgr) + dewarped_png = png_buf.tobytes() if success_enc else b"" + + dewarp_result = { + "method_used": dewarp_info["method"], + "shear_degrees": dewarp_info["shear_degrees"], + "confidence": dewarp_info["confidence"], + "duration_seconds": round(time.time() - t0, 2), + "detections": dewarp_info.get("detections", []), + } + + cached["dewarped_bgr"] = dewarped_bgr + cached["dewarp_result"] = dewarp_result + await update_session_db( + session_id, + dewarped_png=dewarped_png, + dewarp_result=dewarp_result, + auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0), + current_step=3, + ) + session = await get_session_db(session_id) + + steps_run.append("dewarp") + yield await _auto_sse_event("dewarp", "done", dewarp_result) + except Exception as e: + logger.error(f"Auto-mode dewarp failed for {session_id}: {e}") + error_step = "dewarp" + yield await _auto_sse_event("dewarp", "error", {"message": str(e)}) + yield await _auto_sse_event("complete", "error", {"error_step": error_step}) + return + else: + steps_skipped.append("dewarp") + yield await _auto_sse_event("dewarp", "skipped", {"reason": "from_step > 2"}) + + # ----------------------------------------------------------------- + # Step 3: Columns + # ----------------------------------------------------------------- + if req.from_step <= 3: + yield await _auto_sse_event("columns", "start", {}) + try: + t0 = time.time() + dewarped_bgr = cached.get("dewarped_bgr") + if dewarped_bgr is None: + raise ValueError("Dewarped image not available") + + ocr_img = create_ocr_image(dewarped_bgr) + h, w = ocr_img.shape[:2] + + geo_result = detect_column_geometry(ocr_img, dewarped_bgr) + if geo_result is None: + layout_img = create_layout_image(dewarped_bgr) + regions = analyze_layout(layout_img, ocr_img) + cached["_word_dicts"] = None + cached["_inv"] = None + cached["_content_bounds"] = None + else: + geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result + content_w = right_x - left_x + cached["_word_dicts"] = word_dicts + cached["_inv"] = inv + cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y) + + header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None) + geometries = _detect_sub_columns(geometries, content_w, left_x=left_x, + top_y=top_y, header_y=header_y, footer_y=footer_y) + regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y, + left_x=left_x, right_x=right_x, inv=inv) + + columns = [asdict(r) for r in regions] + column_result = { + "columns": columns, + "classification_methods": list({c.get("classification_method", "") for c in columns if c.get("classification_method")}), + "duration_seconds": round(time.time() - t0, 2), + } + + cached["column_result"] = column_result + await update_session_db(session_id, column_result=column_result, + row_result=None, word_result=None, current_step=4) + session = await get_session_db(session_id) + + steps_run.append("columns") + yield await _auto_sse_event("columns", "done", { + "column_count": len(columns), + "duration_seconds": column_result["duration_seconds"], + }) + except Exception as e: + logger.error(f"Auto-mode columns failed for {session_id}: {e}") + error_step = "columns" + yield await _auto_sse_event("columns", "error", {"message": str(e)}) + yield await _auto_sse_event("complete", "error", {"error_step": error_step}) + return + else: + steps_skipped.append("columns") + yield await _auto_sse_event("columns", "skipped", {"reason": "from_step > 3"}) + + # ----------------------------------------------------------------- + # Step 4: Rows + # ----------------------------------------------------------------- + if req.from_step <= 4: + yield await _auto_sse_event("rows", "start", {}) + try: + t0 = time.time() + dewarped_bgr = cached.get("dewarped_bgr") + session = await get_session_db(session_id) + column_result = session.get("column_result") or cached.get("column_result") + if not column_result or not column_result.get("columns"): + raise ValueError("Column detection must complete first") + + col_regions = [ + PageRegion( + type=c["type"], x=c["x"], y=c["y"], + width=c["width"], height=c["height"], + classification_confidence=c.get("classification_confidence", 1.0), + classification_method=c.get("classification_method", ""), + ) + for c in column_result["columns"] + ] + + word_dicts = cached.get("_word_dicts") + inv = cached.get("_inv") + content_bounds = cached.get("_content_bounds") + + if word_dicts is None or inv is None or content_bounds is None: + ocr_img_tmp = create_ocr_image(dewarped_bgr) + geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr) + if geo_result is None: + raise ValueError("Column geometry detection failed — cannot detect rows") + _g, lx, rx, ty, by, word_dicts, inv = geo_result + cached["_word_dicts"] = word_dicts + cached["_inv"] = inv + cached["_content_bounds"] = (lx, rx, ty, by) + content_bounds = (lx, rx, ty, by) + + left_x, right_x, top_y, bottom_y = content_bounds + row_geoms = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y) + + row_list = [ + { + "index": r.index, "x": r.x, "y": r.y, + "width": r.width, "height": r.height, + "word_count": r.word_count, + "row_type": r.row_type, + "gap_before": r.gap_before, + } + for r in row_geoms + ] + row_result = { + "rows": row_list, + "row_count": len(row_list), + "content_rows": len([r for r in row_geoms if r.row_type == "content"]), + "duration_seconds": round(time.time() - t0, 2), + } + + cached["row_result"] = row_result + await update_session_db(session_id, row_result=row_result, current_step=5) + session = await get_session_db(session_id) + + steps_run.append("rows") + yield await _auto_sse_event("rows", "done", { + "row_count": len(row_list), + "content_rows": row_result["content_rows"], + "duration_seconds": row_result["duration_seconds"], + }) + except Exception as e: + logger.error(f"Auto-mode rows failed for {session_id}: {e}") + error_step = "rows" + yield await _auto_sse_event("rows", "error", {"message": str(e)}) + yield await _auto_sse_event("complete", "error", {"error_step": error_step}) + return + else: + steps_skipped.append("rows") + yield await _auto_sse_event("rows", "skipped", {"reason": "from_step > 4"}) + + # ----------------------------------------------------------------- + # Step 5: Words (OCR) + # ----------------------------------------------------------------- + if req.from_step <= 5: + yield await _auto_sse_event("words", "start", {"engine": req.ocr_engine}) + try: + t0 = time.time() + dewarped_bgr = cached.get("dewarped_bgr") + session = await get_session_db(session_id) + + column_result = session.get("column_result") or cached.get("column_result") + row_result = session.get("row_result") or cached.get("row_result") + + col_regions = [ + PageRegion( + type=c["type"], x=c["x"], y=c["y"], + width=c["width"], height=c["height"], + classification_confidence=c.get("classification_confidence", 1.0), + classification_method=c.get("classification_method", ""), + ) + for c in column_result["columns"] + ] + row_geoms = [ + RowGeometry( + index=r["index"], x=r["x"], y=r["y"], + width=r["width"], height=r["height"], + word_count=r.get("word_count", 0), words=[], + row_type=r.get("row_type", "content"), + gap_before=r.get("gap_before", 0), + ) + for r in row_result["rows"] + ] + + word_dicts = cached.get("_word_dicts") + if word_dicts is not None: + content_bounds = cached.get("_content_bounds") + top_y = content_bounds[2] if content_bounds else min(r.y for r in row_geoms) + for row in row_geoms: + row_y_rel = row.y - top_y + row_bottom_rel = row_y_rel + row.height + row.words = [ + w for w in word_dicts + if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel + ] + row.word_count = len(row.words) + + ocr_img = create_ocr_image(dewarped_bgr) + img_h, img_w = dewarped_bgr.shape[:2] + + cells, columns_meta = build_cell_grid( + ocr_img, col_regions, row_geoms, img_w, img_h, + ocr_engine=req.ocr_engine, img_bgr=dewarped_bgr, + ) + duration = time.time() - t0 + + col_types = {c['type'] for c in columns_meta} + is_vocab = bool(col_types & {'column_en', 'column_de'}) + n_content_rows = len([r for r in row_geoms if r.row_type == 'content']) + used_engine = cells[0].get("ocr_engine", "tesseract") if cells else req.ocr_engine + + word_result_data = { + "cells": cells, + "grid_shape": { + "rows": n_content_rows, + "cols": len(columns_meta), + "total_cells": len(cells), + }, + "columns_used": columns_meta, + "layout": "vocab" if is_vocab else "generic", + "image_width": img_w, + "image_height": img_h, + "duration_seconds": round(duration, 2), + "ocr_engine": used_engine, + "summary": { + "total_cells": len(cells), + "non_empty_cells": sum(1 for c in cells if c.get("text")), + "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), + }, + } + + if is_vocab: + entries = _cells_to_vocab_entries(cells, columns_meta) + entries = _fix_character_confusion(entries) + entries = _fix_phonetic_brackets(entries, pronunciation=req.pronunciation) + word_result_data["vocab_entries"] = entries + word_result_data["entries"] = entries + word_result_data["entry_count"] = len(entries) + word_result_data["summary"]["total_entries"] = len(entries) + + await update_session_db(session_id, word_result=word_result_data, current_step=6) + cached["word_result"] = word_result_data + session = await get_session_db(session_id) + + steps_run.append("words") + yield await _auto_sse_event("words", "done", { + "total_cells": len(cells), + "layout": word_result_data["layout"], + "duration_seconds": round(duration, 2), + "ocr_engine": used_engine, + "summary": word_result_data["summary"], + }) + except Exception as e: + logger.error(f"Auto-mode words failed for {session_id}: {e}") + error_step = "words" + yield await _auto_sse_event("words", "error", {"message": str(e)}) + yield await _auto_sse_event("complete", "error", {"error_step": error_step}) + return + else: + steps_skipped.append("words") + yield await _auto_sse_event("words", "skipped", {"reason": "from_step > 5"}) + + # ----------------------------------------------------------------- + # Step 6: LLM Review (optional) + # ----------------------------------------------------------------- + if req.from_step <= 6 and not req.skip_llm_review: + yield await _auto_sse_event("llm_review", "start", {"model": OLLAMA_REVIEW_MODEL}) + try: + session = await get_session_db(session_id) + word_result = session.get("word_result") or cached.get("word_result") + entries = word_result.get("entries") or word_result.get("vocab_entries") or [] + + if not entries: + yield await _auto_sse_event("llm_review", "skipped", {"reason": "no entries"}) + steps_skipped.append("llm_review") + else: + reviewed = await llm_review_entries(entries) + + session = await get_session_db(session_id) + word_result_updated = dict(session.get("word_result") or {}) + word_result_updated["entries"] = reviewed + word_result_updated["vocab_entries"] = reviewed + word_result_updated["llm_reviewed"] = True + word_result_updated["llm_model"] = OLLAMA_REVIEW_MODEL + + await update_session_db(session_id, word_result=word_result_updated, current_step=7) + cached["word_result"] = word_result_updated + + steps_run.append("llm_review") + yield await _auto_sse_event("llm_review", "done", { + "entries_reviewed": len(reviewed), + "model": OLLAMA_REVIEW_MODEL, + }) + except Exception as e: + logger.warning(f"Auto-mode llm_review failed for {session_id} (non-fatal): {e}") + yield await _auto_sse_event("llm_review", "error", {"message": str(e), "fatal": False}) + steps_skipped.append("llm_review") + else: + steps_skipped.append("llm_review") + reason = "skipped by request" if req.skip_llm_review else "from_step > 6" + yield await _auto_sse_event("llm_review", "skipped", {"reason": reason}) + + # ----------------------------------------------------------------- + # Final event + # ----------------------------------------------------------------- + yield await _auto_sse_event("complete", "done", { + "steps_run": steps_run, + "steps_skipped": steps_skipped, + }) + + return StreamingResponse( + _generate(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + )