diff --git a/.env.example b/.env.example index b5cdf99..5ba1abc 100644 --- a/.env.example +++ b/.env.example @@ -30,6 +30,23 @@ OLLAMA_VISION_MODEL=llama3.2-vision OLLAMA_CORRECTION_MODEL=llama3.2 OLLAMA_TIMEOUT=120 +# OCR-Pipeline: LLM-Review (Schritt 6) +# Kleine Modelle reichen fuer Zeichen-Korrekturen (0->O, 1->l, 5->S) +# Optionen: qwen3:0.6b, qwen3:1.7b, gemma3:1b, qwen3:30b-a3b +OLLAMA_REVIEW_MODEL=qwen3:0.6b +# Eintraege pro Ollama-Call. Groesser = weniger HTTP-Overhead. +OLLAMA_REVIEW_BATCH_SIZE=20 + +# OCR-Pipeline: Engine fuer Schritt 5 (Worterkennung) +# Optionen: auto (bevorzugt RapidOCR), rapid, tesseract, +# trocr-printed, trocr-handwritten, lighton +OCR_ENGINE=auto + +# Klausur-HTR: Primaeres Modell fuer Handschriftenerkennung (qwen2.5vl bereits auf Mac Mini) +OLLAMA_HTR_MODEL=qwen2.5vl:32b +# HTR Fallback: genutzt wenn Ollama nicht erreichbar (auto-download ~340 MB) +HTR_FALLBACK_MODEL=trocr-large + # Anthropic (optional) ANTHROPIC_API_KEY= diff --git a/docker-compose.yml b/docker-compose.yml index 962352e..7a3686a 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -15,6 +15,7 @@ volumes: eh_uploads: ocr_labeling: paddle_models: + lighton_models: paddleocr_models: transcription_models: transcription_temp: @@ -209,6 +210,7 @@ services: - eh_uploads:/app/eh-uploads - ocr_labeling:/app/ocr-labeling - paddle_models:/root/.paddlex + - lighton_models:/root/.cache/huggingface environment: JWT_SECRET: ${JWT_SECRET:-your-super-secret-jwt-key-change-in-production} BACKEND_URL: http://backend-lehrer:8001 @@ -231,6 +233,11 @@ services: OLLAMA_DEFAULT_MODEL: ${OLLAMA_DEFAULT_MODEL:-llama3.2} OLLAMA_VISION_MODEL: ${OLLAMA_VISION_MODEL:-llama3.2-vision} OLLAMA_CORRECTION_MODEL: ${OLLAMA_CORRECTION_MODEL:-llama3.2} + OLLAMA_REVIEW_MODEL: ${OLLAMA_REVIEW_MODEL:-qwen3:0.6b} + OLLAMA_REVIEW_BATCH_SIZE: ${OLLAMA_REVIEW_BATCH_SIZE:-20} + OCR_ENGINE: ${OCR_ENGINE:-auto} + OLLAMA_HTR_MODEL: ${OLLAMA_HTR_MODEL:-qwen2.5vl:32b} + HTR_FALLBACK_MODEL: 
${HTR_FALLBACK_MODEL:-trocr-large} RAG_SERVICE_URL: http://bp-core-rag-service:8097 extra_hosts: - "host.docker.internal:host-gateway" diff --git a/klausur-service/backend/handwriting_htr_api.py b/klausur-service/backend/handwriting_htr_api.py new file mode 100644 index 0000000..2976069 --- /dev/null +++ b/klausur-service/backend/handwriting_htr_api.py @@ -0,0 +1,276 @@ +""" +Handwriting HTR API - Hochwertige Handschriftenerkennung (HTR) fuer Klausurkorrekturen. + +Endpoints: +- POST /api/v1/htr/recognize - Bild hochladen → handgeschriebener Text +- POST /api/v1/htr/recognize-session - OCR-Pipeline Session als Quelle nutzen + +Modell-Strategie: + 1. qwen2.5vl:32b via Ollama (primaer, hoechste Qualitaet als VLM) + 2. microsoft/trocr-large-handwritten (Fallback, offline, kein Ollama) + +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal auf dem Mac Mini. +""" + +import io +import os +import logging +import time +import base64 +from typing import Optional + +import cv2 +import numpy as np +from fastapi import APIRouter, HTTPException, Query, UploadFile, File +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/htr", tags=["HTR"]) + +OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434") +OLLAMA_HTR_MODEL = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b") +HTR_FALLBACK_MODEL = os.getenv("HTR_FALLBACK_MODEL", "trocr-large") + + +# --------------------------------------------------------------------------- +# Pydantic Models +# --------------------------------------------------------------------------- + +class HTRSessionRequest(BaseModel): + session_id: str + model: str = "auto" # "auto" | "qwen2.5vl" | "trocr-large" + use_clean: bool = True # Prefer clean_png (after handwriting removal) + + +# --------------------------------------------------------------------------- +# Preprocessing +# --------------------------------------------------------------------------- + +def 
_preprocess_for_htr(img_bgr: np.ndarray) -> np.ndarray: + """ + CLAHE contrast enhancement + upscale to improve HTR accuracy. + Returns grayscale enhanced image. + """ + gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY) + clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) + enhanced = clahe.apply(gray) + + # Upscale if image is too small + h, w = enhanced.shape + if min(h, w) < 800: + scale = 800 / min(h, w) + enhanced = cv2.resize( + enhanced, None, fx=scale, fy=scale, + interpolation=cv2.INTER_CUBIC + ) + + return enhanced + + +def _bgr_to_png_bytes(img_bgr: np.ndarray) -> bytes: + """Convert BGR ndarray to PNG bytes.""" + success, buf = cv2.imencode(".png", img_bgr) + if not success: + raise RuntimeError("Failed to encode image to PNG") + return buf.tobytes() + + +def _preprocess_image_bytes(image_bytes: bytes) -> bytes: + """Load image, apply HTR preprocessing, return PNG bytes.""" + arr = np.frombuffer(image_bytes, dtype=np.uint8) + img_bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR) + if img_bgr is None: + raise ValueError("Could not decode image") + + enhanced = _preprocess_for_htr(img_bgr) + # Convert grayscale back to BGR for encoding + enhanced_bgr = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR) + return _bgr_to_png_bytes(enhanced_bgr) + + +# --------------------------------------------------------------------------- +# Backend: Ollama qwen2.5vl +# --------------------------------------------------------------------------- + +async def _recognize_with_qwen_vl(image_bytes: bytes, language: str) -> Optional[str]: + """ + Send image to Ollama qwen2.5vl:32b for HTR. + Returns extracted text or None on error. + """ + import httpx + + lang_hint = { + "de": "Deutsch", + "en": "Englisch", + "de+en": "Deutsch und Englisch", + }.get(language, "Deutsch") + + prompt = ( + f"Du bist ein OCR-Experte fuer handgeschriebenen Text auf {lang_hint}. " + "Lies den Text im Bild exakt ab — korrigiere KEINE Rechtschreibfehler. 
" + "Antworte NUR mit dem erkannten Text, ohne Erklaerungen." + ) + + img_b64 = base64.b64encode(image_bytes).decode("utf-8") + + payload = { + "model": OLLAMA_HTR_MODEL, + "prompt": prompt, + "images": [img_b64], + "stream": False, + } + + try: + async with httpx.AsyncClient(timeout=120.0) as client: + resp = await client.post(f"{OLLAMA_BASE_URL}/api/generate", json=payload) + resp.raise_for_status() + data = resp.json() + return data.get("response", "").strip() + except Exception as e: + logger.warning(f"Ollama qwen2.5vl HTR failed: {e}") + return None + + +# --------------------------------------------------------------------------- +# Backend: TrOCR-large fallback +# --------------------------------------------------------------------------- + +async def _recognize_with_trocr_large(image_bytes: bytes) -> Optional[str]: + """ + Use microsoft/trocr-large-handwritten via trocr_service.py. + Returns extracted text or None on error. + """ + try: + from services.trocr_service import run_trocr_ocr, _check_trocr_available + if not _check_trocr_available(): + logger.warning("TrOCR not available for HTR fallback") + return None + + text, confidence = await run_trocr_ocr(image_bytes, handwritten=True, size="large") + return text.strip() if text else None + except Exception as e: + logger.warning(f"TrOCR-large HTR failed: {e}") + return None + + +# --------------------------------------------------------------------------- +# Core recognition logic +# --------------------------------------------------------------------------- + +async def _do_recognize( + image_bytes: bytes, + model: str = "auto", + preprocess: bool = True, + language: str = "de", +) -> dict: + """ + Core HTR logic: preprocess → try Ollama → fallback to TrOCR-large. + Returns dict with text, model_used, processing_time_ms. 
+ """ + t0 = time.monotonic() + + if preprocess: + try: + image_bytes = _preprocess_image_bytes(image_bytes) + except Exception as e: + logger.warning(f"HTR preprocessing failed, using raw image: {e}") + + text: Optional[str] = None + model_used: str = "none" + + use_qwen = model in ("auto", "qwen2.5vl") + use_trocr = model in ("auto", "trocr-large") or (use_qwen and text is None) + + if use_qwen: + text = await _recognize_with_qwen_vl(image_bytes, language) + if text is not None: + model_used = f"qwen2.5vl ({OLLAMA_HTR_MODEL})" + + if text is None and (use_trocr or model == "trocr-large"): + text = await _recognize_with_trocr_large(image_bytes) + if text is not None: + model_used = "trocr-large-handwritten" + + if text is None: + text = "" + model_used = "none (all backends failed)" + + elapsed_ms = int((time.monotonic() - t0) * 1000) + + return { + "text": text, + "model_used": model_used, + "processing_time_ms": elapsed_ms, + "language": language, + "preprocessed": preprocess, + } + + +# --------------------------------------------------------------------------- +# Endpoints +# --------------------------------------------------------------------------- + +@router.post("/recognize") +async def recognize_handwriting( + file: UploadFile = File(...), + model: str = Query("auto", description="auto | qwen2.5vl | trocr-large"), + preprocess: bool = Query(True, description="Apply CLAHE + upscale before recognition"), + language: str = Query("de", description="de | en | de+en"), +): + """ + Upload an image and get back the handwritten text as plain text. + + Tries qwen2.5vl:32b via Ollama first, falls back to TrOCR-large-handwritten. 
+ """ + if model not in ("auto", "qwen2.5vl", "trocr-large"): + raise HTTPException(status_code=400, detail="model must be one of: auto, qwen2.5vl, trocr-large") + if language not in ("de", "en", "de+en"): + raise HTTPException(status_code=400, detail="language must be one of: de, en, de+en") + + image_bytes = await file.read() + if not image_bytes: + raise HTTPException(status_code=400, detail="Empty file") + + return await _do_recognize(image_bytes, model=model, preprocess=preprocess, language=language) + + +@router.post("/recognize-session") +async def recognize_from_session(req: HTRSessionRequest): + """ + Use an OCR-Pipeline session as image source for HTR. + + Set use_clean=true to prefer the clean image (after handwriting removal step). + This is useful when you want to do HTR on isolated handwriting regions. + """ + from ocr_pipeline_session_store import get_session_db, get_session_image + + session = await get_session_db(req.session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {req.session_id} not found") + + # Choose source image + image_bytes: Optional[bytes] = None + source_used: str = "" + + if req.use_clean: + image_bytes = await get_session_image(req.session_id, "clean") + if image_bytes: + source_used = "clean" + + if not image_bytes: + image_bytes = await get_session_image(req.session_id, "deskewed") + if image_bytes: + source_used = "deskewed" + + if not image_bytes: + image_bytes = await get_session_image(req.session_id, "original") + source_used = "original" + + if not image_bytes: + raise HTTPException(status_code=404, detail="No image available in session") + + result = await _do_recognize(image_bytes, model=req.model) + result["session_id"] = req.session_id + result["source_image"] = source_used + return result diff --git a/klausur-service/backend/main.py b/klausur-service/backend/main.py index 51887c1..4c1ef18 100644 --- a/klausur-service/backend/main.py +++ b/klausur-service/backend/main.py @@ -44,6 +44,10 
@@ except ImportError: from vocab_worksheet_api import router as vocab_router, set_db_pool as set_vocab_db_pool, _init_vocab_table, _load_all_sessions, DATABASE_URL as VOCAB_DATABASE_URL from ocr_pipeline_api import router as ocr_pipeline_router from ocr_pipeline_session_store import init_ocr_pipeline_tables +try: + from handwriting_htr_api import router as htr_router +except ImportError: + htr_router = None try: from dsfa_rag_api import router as dsfa_rag_router, set_db_pool as set_dsfa_db_pool from dsfa_corpus_ingestion import DSFAQdrantService, DATABASE_URL as DSFA_DATABASE_URL @@ -113,6 +117,19 @@ async def lifespan(app: FastAPI): # Ensure EH upload directory exists os.makedirs(EH_UPLOAD_DIR, exist_ok=True) + # Preload LightOnOCR model if OCR_ENGINE=lighton (avoids cold-start on first request) + ocr_engine_env = os.getenv("OCR_ENGINE", "auto") + if ocr_engine_env == "lighton": + try: + import asyncio + from services.lighton_ocr_service import get_lighton_model + loop = asyncio.get_event_loop() + print("Preloading LightOnOCR-2-1B at startup (OCR_ENGINE=lighton)...") + await loop.run_in_executor(None, get_lighton_model) + print("LightOnOCR-2-1B preloaded") + except Exception as e: + print(f"Warning: LightOnOCR preload failed: {e}") + yield print("Klausur-Service shutting down...") @@ -160,6 +177,8 @@ if trocr_router: app.include_router(trocr_router) # TrOCR Handwriting OCR app.include_router(vocab_router) # Vocabulary Worksheet Generator app.include_router(ocr_pipeline_router) # OCR Pipeline (step-by-step) +if htr_router: + app.include_router(htr_router) # Handwriting HTR (Klausur) if dsfa_rag_router: app.include_router(dsfa_rag_router) # DSFA RAG Corpus Search diff --git a/klausur-service/backend/ocr_pipeline_api.py b/klausur-service/backend/ocr_pipeline_api.py index a989c4e..cae425d 100644 --- a/klausur-service/backend/ocr_pipeline_api.py +++ b/klausur-service/backend/ocr_pipeline_api.py @@ -168,6 +168,13 @@ class RowGroundTruthRequest(BaseModel): notes: 
Optional[str] = None +class RemoveHandwritingRequest(BaseModel): + method: str = "auto" # "auto" | "telea" | "ns" + target_ink: str = "all" # "all" | "colored" | "pencil" + dilation: int = 2 # mask dilation iterations (0-5) + use_source: str = "auto" # "original" | "deskewed" | "auto" + + # --------------------------------------------------------------------------- # Session Management Endpoints # --------------------------------------------------------------------------- @@ -309,7 +316,7 @@ async def delete_session(session_id: str): @router.get("/sessions/{session_id}/image/{image_type}") async def get_image(session_id: str, image_type: str): """Serve session images: original, deskewed, dewarped, binarized, columns-overlay, or rows-overlay.""" - valid_types = {"original", "deskewed", "dewarped", "binarized", "columns-overlay", "rows-overlay", "words-overlay"} + valid_types = {"original", "deskewed", "dewarped", "binarized", "columns-overlay", "rows-overlay", "words-overlay", "clean"} if image_type not in valid_types: raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}") @@ -1906,3 +1913,90 @@ async def _get_words_overlay(session_id: str) -> Response: raise HTTPException(status_code=500, detail="Failed to encode overlay image") return Response(content=result_png.tobytes(), media_type="image/png") + + +# --------------------------------------------------------------------------- +# Handwriting Removal Endpoint +# --------------------------------------------------------------------------- + +@router.post("/sessions/{session_id}/remove-handwriting") +async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingRequest): + """ + Remove handwriting from a session image using inpainting. + + Steps: + 1. Load source image (auto → deskewed if available, else original) + 2. Detect handwriting mask (filtered by target_ink) + 3. Dilate mask to cover stroke edges + 4. Inpaint the image + 5. 
Store result as clean_png in the session + + Returns metadata including the URL to fetch the clean image. + """ + import time as _time + t0 = _time.monotonic() + + from services.handwriting_detection import detect_handwriting + from services.inpainting_service import inpaint_image, dilate_mask as _dilate_mask, InpaintingMethod, image_to_png + + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + # 1. Determine source image + source = req.use_source + if source == "auto": + deskewed = await get_session_image(session_id, "deskewed") + source = "deskewed" if deskewed else "original" + + image_bytes = await get_session_image(session_id, source) + if not image_bytes: + raise HTTPException(status_code=404, detail=f"Source image '{source}' not available") + + # 2. Detect handwriting mask + detection = detect_handwriting(image_bytes, target_ink=req.target_ink) + + # 3. Convert mask to PNG bytes and dilate + import io + from PIL import Image as _PILImage + mask_img = _PILImage.fromarray(detection.mask) + mask_buf = io.BytesIO() + mask_img.save(mask_buf, format="PNG") + mask_bytes = mask_buf.getvalue() + + if req.dilation > 0: + mask_bytes = _dilate_mask(mask_bytes, iterations=req.dilation) + + # 4. 
Inpaint + method_map = { + "telea": InpaintingMethod.OPENCV_TELEA, + "ns": InpaintingMethod.OPENCV_NS, + "auto": InpaintingMethod.AUTO, + } + inpaint_method = method_map.get(req.method, InpaintingMethod.AUTO) + + result = inpaint_image(image_bytes, mask_bytes, method=inpaint_method) + if not result.success: + raise HTTPException(status_code=500, detail="Inpainting failed") + + elapsed_ms = int((_time.monotonic() - t0) * 1000) + + meta = { + "method_used": result.method_used.value if hasattr(result.method_used, "value") else str(result.method_used), + "handwriting_ratio": round(detection.handwriting_ratio, 4), + "detection_confidence": round(detection.confidence, 4), + "target_ink": req.target_ink, + "dilation": req.dilation, + "source_image": source, + "processing_time_ms": elapsed_ms, + } + + # 5. Persist clean image (convert BGR ndarray → PNG bytes) + clean_png_bytes = image_to_png(result.image) + await update_session_db(session_id, clean_png=clean_png_bytes, handwriting_removal_meta=meta) + + return { + **meta, + "image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean", + "session_id": session_id, + } diff --git a/klausur-service/backend/ocr_pipeline_session_store.py b/klausur-service/backend/ocr_pipeline_session_store.py index 84343c6..8c58def 100644 --- a/klausur-service/backend/ocr_pipeline_session_store.py +++ b/klausur-service/backend/ocr_pipeline_session_store.py @@ -60,6 +60,13 @@ async def init_ocr_pipeline_tables(): else: logger.debug("OCR pipeline tables already exist") + # Ensure new columns exist (idempotent ALTER TABLE) + await conn.execute(""" + ALTER TABLE ocr_pipeline_sessions + ADD COLUMN IF NOT EXISTS clean_png BYTEA, + ADD COLUMN IF NOT EXISTS handwriting_removal_meta JSONB + """) + # ============================================================================= # SESSION CRUD @@ -111,6 +118,7 @@ async def get_session_image(session_id: str, image_type: str) -> Optional[bytes] "deskewed": "deskewed_png", "binarized": 
"binarized_png", "dewarped": "dewarped_png", + "clean": "clean_png", } column = column_map.get(image_type) if not column: @@ -135,11 +143,12 @@ async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any allowed_fields = { 'name', 'filename', 'status', 'current_step', 'original_png', 'deskewed_png', 'binarized_png', 'dewarped_png', + 'clean_png', 'handwriting_removal_meta', 'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'auto_shear_degrees', } - jsonb_fields = {'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth'} + jsonb_fields = {'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'handwriting_removal_meta'} for key, value in kwargs.items(): if key in allowed_fields: diff --git a/klausur-service/backend/services/handwriting_detection.py b/klausur-service/backend/services/handwriting_detection.py index 081f177..2537ad3 100644 --- a/klausur-service/backend/services/handwriting_detection.py +++ b/klausur-service/backend/services/handwriting_detection.py @@ -6,6 +6,7 @@ Uses multiple detection methods: 1. Color-based detection (blue/red ink) 2. Stroke analysis (thin irregular strokes) 3. Edge density variance +4. Pencil detection (gray ink) DATENSCHUTZ: All processing happens locally on Mac Mini. """ @@ -37,12 +38,16 @@ class DetectionResult: detection_method: str # Which method was primarily used -def detect_handwriting(image_bytes: bytes) -> DetectionResult: +def detect_handwriting(image_bytes: bytes, target_ink: str = "all") -> DetectionResult: """ Detect handwriting in an image. Args: image_bytes: Image as bytes (PNG, JPG, etc.) + target_ink: Which ink types to detect: + - "all" → all methods combined (incl. 
pencil) + - "colored" → only color-based (blue/red/green pen) + - "pencil" → only pencil (gray ink) Returns: DetectionResult with binary mask where handwriting is white (255) @@ -62,35 +67,51 @@ def detect_handwriting(image_bytes: bytes) -> DetectionResult: # Convert to BGR if needed (OpenCV format) if len(img_array.shape) == 2: - # Grayscale to BGR img_bgr = cv2.cvtColor(img_array, cv2.COLOR_GRAY2BGR) elif img_array.shape[2] == 4: - # RGBA to BGR img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGBA2BGR) elif img_array.shape[2] == 3: - # RGB to BGR img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR) else: img_bgr = img_array - # Run multiple detection methods - color_mask, color_confidence = _detect_by_color(img_bgr) - stroke_mask, stroke_confidence = _detect_by_stroke_analysis(img_bgr) - variance_mask, variance_confidence = _detect_by_variance(img_bgr) + # Select detection methods based on target_ink + masks_and_weights = [] + + if target_ink in ("all", "colored"): + color_mask, color_conf = _detect_by_color(img_bgr) + masks_and_weights.append((color_mask, color_conf, "color")) + + if target_ink == "all": + stroke_mask, stroke_conf = _detect_by_stroke_analysis(img_bgr) + variance_mask, variance_conf = _detect_by_variance(img_bgr) + masks_and_weights.append((stroke_mask, stroke_conf, "stroke")) + masks_and_weights.append((variance_mask, variance_conf, "variance")) + + if target_ink in ("all", "pencil"): + pencil_mask, pencil_conf = _detect_pencil(img_bgr) + masks_and_weights.append((pencil_mask, pencil_conf, "pencil")) + + if not masks_and_weights: + # Fallback: use all methods + color_mask, color_conf = _detect_by_color(img_bgr) + stroke_mask, stroke_conf = _detect_by_stroke_analysis(img_bgr) + variance_mask, variance_conf = _detect_by_variance(img_bgr) + pencil_mask, pencil_conf = _detect_pencil(img_bgr) + masks_and_weights = [ + (color_mask, color_conf, "color"), + (stroke_mask, stroke_conf, "stroke"), + (variance_mask, variance_conf, "variance"), + (pencil_mask, 
pencil_conf, "pencil"), + ] # Combine masks using weighted average - weights = [color_confidence, stroke_confidence, variance_confidence] - total_weight = sum(weights) + total_weight = sum(w for _, w, _ in masks_and_weights) if total_weight > 0: - # Weighted combination - combined_mask = ( - color_mask.astype(np.float32) * color_confidence + - stroke_mask.astype(np.float32) * stroke_confidence + - variance_mask.astype(np.float32) * variance_confidence + combined_mask = sum( + m.astype(np.float32) * w for m, w, _ in masks_and_weights ) / total_weight - - # Threshold to binary combined_mask = (combined_mask > 127).astype(np.uint8) * 255 else: combined_mask = np.zeros(img_bgr.shape[:2], dtype=np.uint8) @@ -103,19 +124,11 @@ def detect_handwriting(image_bytes: bytes) -> DetectionResult: handwriting_pixels = np.sum(combined_mask > 0) handwriting_ratio = handwriting_pixels / total_pixels if total_pixels > 0 else 0 - # Determine primary method - primary_method = "combined" - max_conf = max(color_confidence, stroke_confidence, variance_confidence) - if max_conf == color_confidence: - primary_method = "color" - elif max_conf == stroke_confidence: - primary_method = "stroke" - else: - primary_method = "variance" + # Determine primary method (highest confidence) + primary_method = max(masks_and_weights, key=lambda x: x[1])[2] if masks_and_weights else "combined" + overall_confidence = total_weight / len(masks_and_weights) if masks_and_weights else 0.0 - overall_confidence = total_weight / 3.0 # Average confidence - - logger.info(f"Handwriting detection: {handwriting_ratio:.2%} handwriting, " + logger.info(f"Handwriting detection (target_ink={target_ink}): {handwriting_ratio:.2%} handwriting, " f"confidence={overall_confidence:.2f}, method={primary_method}") return DetectionResult( @@ -180,6 +193,27 @@ def _detect_by_color(img_bgr: np.ndarray) -> Tuple[np.ndarray, float]: return color_mask, confidence +def _detect_pencil(img_bgr: np.ndarray) -> Tuple[np.ndarray, float]: + """ 
+ Detect pencil marks (gray ink, ~140-220 on 255-scale). + + Paper is usually >230, dark ink <130. + Pencil falls in the 140-220 gray range. + """ + gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY) + pencil_mask = cv2.inRange(gray, 140, 220) + + # Remove small noise artifacts + kernel = np.ones((2, 2), np.uint8) + pencil_mask = cv2.morphologyEx(pencil_mask, cv2.MORPH_OPEN, kernel, iterations=1) + + ratio = np.sum(pencil_mask > 0) / pencil_mask.size + # Good confidence if pencil pixels are in a plausible range + confidence = 0.75 if 0.002 < ratio < 0.2 else 0.2 + + return pencil_mask, confidence + + def _detect_by_stroke_analysis(img_bgr: np.ndarray) -> Tuple[np.ndarray, float]: """ Detect handwriting by analyzing stroke characteristics. diff --git a/klausur-service/backend/services/trocr_service.py b/klausur-service/backend/services/trocr_service.py index 0715d1c..1ff32fa 100644 --- a/klausur-service/backend/services/trocr_service.py +++ b/klausur-service/backend/services/trocr_service.py @@ -31,8 +31,10 @@ from datetime import datetime, timedelta logger = logging.getLogger(__name__) # Lazy loading for heavy dependencies -_trocr_processor = None -_trocr_model = None +# Cache keyed by model_name to support base and large variants simultaneously +_trocr_models: dict = {} # {model_name: (processor, model)} +_trocr_processor = None # backwards-compat alias → base-printed +_trocr_model = None # backwards-compat alias → base-printed _trocr_available = None _model_loaded_at = None @@ -124,12 +126,14 @@ def _check_trocr_available() -> bool: return _trocr_available -def get_trocr_model(handwritten: bool = False): +def get_trocr_model(handwritten: bool = False, size: str = "base"): """ Lazy load TrOCR model and processor. Args: handwritten: Use handwritten model instead of printed model + size: Model size — "base" (300 MB) or "large" (340 MB, higher accuracy + for exam HTR). Only applies to handwritten variant. 
Returns tuple of (processor, model) or (None, None) if unavailable. """ @@ -138,31 +142,43 @@ if not _check_trocr_available(): return None, None - if _trocr_processor is None or _trocr_model is None: - try: - import torch - from transformers import TrOCRProcessor, VisionEncoderDecoderModel + # Select model name + if size == "large" and handwritten: + model_name = "microsoft/trocr-large-handwritten" + elif handwritten: + model_name = "microsoft/trocr-base-handwritten" + else: + model_name = "microsoft/trocr-base-printed" - # Choose model based on use case - if handwritten: - model_name = "microsoft/trocr-base-handwritten" - else: - model_name = "microsoft/trocr-base-printed" + if model_name in _trocr_models: + return _trocr_models[model_name] - logger.info(f"Loading TrOCR model: {model_name}") - _trocr_processor = TrOCRProcessor.from_pretrained(model_name) - _trocr_model = VisionEncoderDecoderModel.from_pretrained(model_name) + try: + import torch + from transformers import TrOCRProcessor, VisionEncoderDecoderModel - # Use GPU if available - device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" - _trocr_model.to(device) - logger.info(f"TrOCR model loaded on device: {device}") + logger.info(f"Loading TrOCR model: {model_name}") + processor = TrOCRProcessor.from_pretrained(model_name) + model = VisionEncoderDecoderModel.from_pretrained(model_name) - except Exception as e: - logger.error(f"Failed to load TrOCR model: {e}") - return None, None + # Use GPU if available + device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" + model.to(device) + logger.info(f"TrOCR model loaded on device: {device}") - return _trocr_processor, _trocr_model + _trocr_models[model_name] = (processor, model) + + global _trocr_processor, _trocr_model + # Keep backwards-compat globals pointing at base-printed + if model_name == "microsoft/trocr-base-printed": + _trocr_processor = processor 
+ _trocr_model = model + + return processor, model + + except Exception as e: + logger.error(f"Failed to load TrOCR model {model_name}: {e}") + return None, None def preload_trocr_model(handwritten: bool = True) -> bool: @@ -209,7 +224,8 @@ def get_model_status() -> Dict[str, Any]: async def run_trocr_ocr( image_data: bytes, handwritten: bool = False, - split_lines: bool = True + split_lines: bool = True, + size: str = "base", ) -> Tuple[Optional[str], float]: """ Run TrOCR on an image. @@ -223,11 +239,12 @@ async def run_trocr_ocr( image_data: Raw image bytes handwritten: Use handwritten model (slower but better for handwriting) split_lines: Whether to split image into lines first + size: "base" or "large" (only for handwritten variant) Returns: Tuple of (extracted_text, confidence) """ - processor, model = get_trocr_model(handwritten=handwritten) + processor, model = get_trocr_model(handwritten=handwritten, size=size) if processor is None or model is None: logger.error("TrOCR model not available")