feat(klausur): remove handwriting + implement Klausur HTR
Some checks failed
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 26s
CI / test-go-edu-search (push) Successful in 26s
CI / test-python-klausur (push) Failing after 1m49s
CI / test-python-agent-core (push) Successful in 14s
CI / test-nodejs-website (push) Successful in 15s

Feature 1: remove handwriting via OCR pipeline session
- services/handwriting_detection.py: _detect_pencil() + target_ink parameter
  ("all" | "colored" | "pencil") for targeted ink detection
- ocr_pipeline_session_store.py: clean_png + handwriting_removal_meta columns
  (idempotent ALTER TABLE in init_ocr_pipeline_tables)
- ocr_pipeline_api.py: POST /sessions/{id}/remove-handwriting endpoint
  + added "clean" to valid_types for image serving

Feature 2: Klausur HTR (high-quality handwriting recognition)
- handwriting_htr_api.py: new router /api/v1/htr/recognize + /recognize-session
  Primary: qwen2.5vl:32b via Ollama, fallback: trocr-large-handwritten
- services/trocr_service.py: size parameter (base | large) for get_trocr_model()
  + run_trocr_ocr() - now supports trocr-large-handwritten
- main.py: HTR router registered

Config:
- docker-compose.yml: OLLAMA_HTR_MODEL, HTR_FALLBACK_MODEL
- .env.example: HTR env vars documented

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-03-03 12:04:26 +01:00
parent 606bef0591
commit 2e0f8632f8
8 changed files with 529 additions and 56 deletions

View File

@@ -30,6 +30,23 @@ OLLAMA_VISION_MODEL=llama3.2-vision
OLLAMA_CORRECTION_MODEL=llama3.2
OLLAMA_TIMEOUT=120
# OCR pipeline: LLM review (step 6)
# Small models suffice for character corrections (0->O, 1->l, 5->S)
# Options: qwen3:0.6b, qwen3:1.7b, gemma3:1b, qwen3:30b-a3b
OLLAMA_REVIEW_MODEL=qwen3:0.6b
# Entries per Ollama call. Larger = less HTTP overhead.
OLLAMA_REVIEW_BATCH_SIZE=20
# OCR pipeline: engine for step 5 (word recognition)
# Options: auto (prefers RapidOCR), rapid, tesseract,
# trocr-printed, trocr-handwritten, lighton
OCR_ENGINE=auto
# Klausur HTR: primary model for handwriting recognition (qwen2.5vl already on the Mac Mini)
OLLAMA_HTR_MODEL=qwen2.5vl:32b
# HTR fallback: used when Ollama is unreachable (auto-download ~340 MB)
HTR_FALLBACK_MODEL=trocr-large
# Anthropic (optional)
ANTHROPIC_API_KEY=

View File

@@ -15,6 +15,7 @@ volumes:
eh_uploads:
ocr_labeling:
paddle_models:
lighton_models:
paddleocr_models:
transcription_models:
transcription_temp:
@@ -209,6 +210,7 @@ services:
- eh_uploads:/app/eh-uploads
- ocr_labeling:/app/ocr-labeling
- paddle_models:/root/.paddlex
- lighton_models:/root/.cache/huggingface
environment:
JWT_SECRET: ${JWT_SECRET:-your-super-secret-jwt-key-change-in-production}
BACKEND_URL: http://backend-lehrer:8001
@@ -231,6 +233,11 @@ services:
OLLAMA_DEFAULT_MODEL: ${OLLAMA_DEFAULT_MODEL:-llama3.2}
OLLAMA_VISION_MODEL: ${OLLAMA_VISION_MODEL:-llama3.2-vision}
OLLAMA_CORRECTION_MODEL: ${OLLAMA_CORRECTION_MODEL:-llama3.2}
OLLAMA_REVIEW_MODEL: ${OLLAMA_REVIEW_MODEL:-qwen3:0.6b}
OLLAMA_REVIEW_BATCH_SIZE: ${OLLAMA_REVIEW_BATCH_SIZE:-20}
OCR_ENGINE: ${OCR_ENGINE:-auto}
OLLAMA_HTR_MODEL: ${OLLAMA_HTR_MODEL:-qwen2.5vl:32b}
HTR_FALLBACK_MODEL: ${HTR_FALLBACK_MODEL:-trocr-large}
RAG_SERVICE_URL: http://bp-core-rag-service:8097
extra_hosts:
- "host.docker.internal:host-gateway"

View File

@@ -0,0 +1,276 @@
"""
Handwriting HTR API - Hochwertige Handschriftenerkennung (HTR) fuer Klausurkorrekturen.
Endpoints:
- POST /api/v1/htr/recognize - Bild hochladen → handgeschriebener Text
- POST /api/v1/htr/recognize-session - OCR-Pipeline Session als Quelle nutzen
Modell-Strategie:
1. qwen2.5vl:32b via Ollama (primaer, hoechste Qualitaet als VLM)
2. microsoft/trocr-large-handwritten (Fallback, offline, kein Ollama)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal auf dem Mac Mini.
"""
import io
import os
import logging
import time
import base64
from typing import Optional
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException, Query, UploadFile, File
from pydantic import BaseModel
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/htr", tags=["HTR"])
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
OLLAMA_HTR_MODEL = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
HTR_FALLBACK_MODEL = os.getenv("HTR_FALLBACK_MODEL", "trocr-large")
# ---------------------------------------------------------------------------
# Pydantic Models
# ---------------------------------------------------------------------------
class HTRSessionRequest(BaseModel):
session_id: str
model: str = "auto" # "auto" | "qwen2.5vl" | "trocr-large"
use_clean: bool = True # Prefer clean_png (after handwriting removal)
# ---------------------------------------------------------------------------
# Preprocessing
# ---------------------------------------------------------------------------
def _preprocess_for_htr(img_bgr: np.ndarray) -> np.ndarray:
"""
CLAHE contrast enhancement + upscale to improve HTR accuracy.
Returns grayscale enhanced image.
"""
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
enhanced = clahe.apply(gray)
# Upscale if image is too small
h, w = enhanced.shape
if min(h, w) < 800:
scale = 800 / min(h, w)
enhanced = cv2.resize(
enhanced, None, fx=scale, fy=scale,
interpolation=cv2.INTER_CUBIC
)
return enhanced
def _bgr_to_png_bytes(img_bgr: np.ndarray) -> bytes:
"""Convert BGR ndarray to PNG bytes."""
success, buf = cv2.imencode(".png", img_bgr)
if not success:
raise RuntimeError("Failed to encode image to PNG")
return buf.tobytes()
def _preprocess_image_bytes(image_bytes: bytes) -> bytes:
"""Load image, apply HTR preprocessing, return PNG bytes."""
arr = np.frombuffer(image_bytes, dtype=np.uint8)
img_bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img_bgr is None:
raise ValueError("Could not decode image")
enhanced = _preprocess_for_htr(img_bgr)
# Convert grayscale back to BGR for encoding
enhanced_bgr = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR)
return _bgr_to_png_bytes(enhanced_bgr)
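The upscale rule in `_preprocess_for_htr` (scale so the short image side reaches at least 800 px) can be isolated as a tiny pure function. This is an illustrative sketch, not code from the service; `htr_upscale_factor` is a hypothetical helper name:

```python
def htr_upscale_factor(h: int, w: int, target: int = 800) -> float:
    """Scale factor so the short image side reaches `target` px (1.0 = no upscale)."""
    short = min(h, w)
    return target / short if short < target else 1.0

print(htr_upscale_factor(400, 1200))  # 2.0  (short side 400 -> 800)
print(htr_upscale_factor(900, 1600))  # 1.0  (already large enough)
```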
# ---------------------------------------------------------------------------
# Backend: Ollama qwen2.5vl
# ---------------------------------------------------------------------------
async def _recognize_with_qwen_vl(image_bytes: bytes, language: str) -> Optional[str]:
"""
Send image to Ollama qwen2.5vl:32b for HTR.
Returns extracted text or None on error.
"""
import httpx
lang_hint = {
"de": "Deutsch",
"en": "Englisch",
"de+en": "Deutsch und Englisch",
}.get(language, "Deutsch")
prompt = (
f"Du bist ein OCR-Experte fuer handgeschriebenen Text auf {lang_hint}. "
"Lies den Text im Bild exakt ab — korrigiere KEINE Rechtschreibfehler. "
"Antworte NUR mit dem erkannten Text, ohne Erklaerungen."
)
img_b64 = base64.b64encode(image_bytes).decode("utf-8")
payload = {
"model": OLLAMA_HTR_MODEL,
"prompt": prompt,
"images": [img_b64],
"stream": False,
}
try:
async with httpx.AsyncClient(timeout=120.0) as client:
resp = await client.post(f"{OLLAMA_BASE_URL}/api/generate", json=payload)
resp.raise_for_status()
data = resp.json()
return data.get("response", "").strip()
except Exception as e:
logger.warning(f"Ollama qwen2.5vl HTR failed: {e}")
return None
# ---------------------------------------------------------------------------
# Backend: TrOCR-large fallback
# ---------------------------------------------------------------------------
async def _recognize_with_trocr_large(image_bytes: bytes) -> Optional[str]:
"""
Use microsoft/trocr-large-handwritten via trocr_service.py.
Returns extracted text or None on error.
"""
try:
from services.trocr_service import run_trocr_ocr, _check_trocr_available
if not _check_trocr_available():
logger.warning("TrOCR not available for HTR fallback")
return None
text, confidence = await run_trocr_ocr(image_bytes, handwritten=True, size="large")
return text.strip() if text else None
except Exception as e:
logger.warning(f"TrOCR-large HTR failed: {e}")
return None
# ---------------------------------------------------------------------------
# Core recognition logic
# ---------------------------------------------------------------------------
async def _do_recognize(
image_bytes: bytes,
model: str = "auto",
preprocess: bool = True,
language: str = "de",
) -> dict:
"""
Core HTR logic: preprocess → try Ollama → fallback to TrOCR-large.
Returns dict with text, model_used, processing_time_ms.
"""
t0 = time.monotonic()
if preprocess:
try:
image_bytes = _preprocess_image_bytes(image_bytes)
except Exception as e:
logger.warning(f"HTR preprocessing failed, using raw image: {e}")
text: Optional[str] = None
model_used: str = "none"
# Try primary backend (Ollama VLM), then fall back to TrOCR-large
if model in ("auto", "qwen2.5vl"):
text = await _recognize_with_qwen_vl(image_bytes, language)
if text is not None:
model_used = f"qwen2.5vl ({OLLAMA_HTR_MODEL})"
if text is None:
text = await _recognize_with_trocr_large(image_bytes)
if text is not None:
model_used = "trocr-large-handwritten"
if text is None:
text = ""
model_used = "none (all backends failed)"
elapsed_ms = int((time.monotonic() - t0) * 1000)
return {
"text": text,
"model_used": model_used,
"processing_time_ms": elapsed_ms,
"language": language,
"preprocessed": preprocess,
}
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post("/recognize")
async def recognize_handwriting(
file: UploadFile = File(...),
model: str = Query("auto", description="auto | qwen2.5vl | trocr-large"),
preprocess: bool = Query(True, description="Apply CLAHE + upscale before recognition"),
language: str = Query("de", description="de | en | de+en"),
):
"""
Upload an image and get back the handwritten text as plain text.
Tries qwen2.5vl:32b via Ollama first, falls back to TrOCR-large-handwritten.
"""
if model not in ("auto", "qwen2.5vl", "trocr-large"):
raise HTTPException(status_code=400, detail="model must be one of: auto, qwen2.5vl, trocr-large")
if language not in ("de", "en", "de+en"):
raise HTTPException(status_code=400, detail="language must be one of: de, en, de+en")
image_bytes = await file.read()
if not image_bytes:
raise HTTPException(status_code=400, detail="Empty file")
return await _do_recognize(image_bytes, model=model, preprocess=preprocess, language=language)
@router.post("/recognize-session")
async def recognize_from_session(req: HTRSessionRequest):
"""
Use an OCR-Pipeline session as image source for HTR.
Set use_clean=true to prefer the clean image (after handwriting removal step).
This is useful when you want to do HTR on isolated handwriting regions.
"""
from ocr_pipeline_session_store import get_session_db, get_session_image
session = await get_session_db(req.session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {req.session_id} not found")
# Choose source image
image_bytes: Optional[bytes] = None
source_used: str = ""
if req.use_clean:
image_bytes = await get_session_image(req.session_id, "clean")
if image_bytes:
source_used = "clean"
if not image_bytes:
image_bytes = await get_session_image(req.session_id, "deskewed")
if image_bytes:
source_used = "deskewed"
if not image_bytes:
image_bytes = await get_session_image(req.session_id, "original")
source_used = "original"
if not image_bytes:
raise HTTPException(status_code=404, detail="No image available in session")
result = await _do_recognize(image_bytes, model=req.model)
result["session_id"] = req.session_id
result["source_image"] = source_used
return result
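The source-image priority in `recognize_from_session` (clean → deskewed → original) can be condensed into a synchronous sketch. `pick_source` is a hypothetical helper for illustration only; the endpoint itself awaits `get_session_image` per type:

```python
from typing import Optional

def pick_source(images: dict[str, bytes], use_clean: bool = True) -> tuple[Optional[bytes], str]:
    """Prefer clean -> deskewed -> original, mirroring the session endpoint."""
    order = (["clean"] if use_clean else []) + ["deskewed", "original"]
    for name in order:
        data = images.get(name)
        if data:
            return data, name
    return None, ""

# Session without a clean image falls back to the deskewed one
print(pick_source({"original": b"png1", "deskewed": b"png2"})[1])  # deskewed
```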

View File

@@ -44,6 +44,10 @@ except ImportError:
from vocab_worksheet_api import router as vocab_router, set_db_pool as set_vocab_db_pool, _init_vocab_table, _load_all_sessions, DATABASE_URL as VOCAB_DATABASE_URL
from ocr_pipeline_api import router as ocr_pipeline_router
from ocr_pipeline_session_store import init_ocr_pipeline_tables
try:
from handwriting_htr_api import router as htr_router
except ImportError:
htr_router = None
try:
from dsfa_rag_api import router as dsfa_rag_router, set_db_pool as set_dsfa_db_pool
from dsfa_corpus_ingestion import DSFAQdrantService, DATABASE_URL as DSFA_DATABASE_URL
@@ -113,6 +117,19 @@ async def lifespan(app: FastAPI):
# Ensure EH upload directory exists
os.makedirs(EH_UPLOAD_DIR, exist_ok=True)
# Preload LightOnOCR model if OCR_ENGINE=lighton (avoids cold-start on first request)
ocr_engine_env = os.getenv("OCR_ENGINE", "auto")
if ocr_engine_env == "lighton":
try:
import asyncio
from services.lighton_ocr_service import get_lighton_model
loop = asyncio.get_event_loop()
print("Preloading LightOnOCR-2-1B at startup (OCR_ENGINE=lighton)...")
await loop.run_in_executor(None, get_lighton_model)
print("LightOnOCR-2-1B preloaded")
except Exception as e:
print(f"Warning: LightOnOCR preload failed: {e}")
yield
print("Klausur-Service shutting down...")
@@ -160,6 +177,8 @@ if trocr_router:
app.include_router(trocr_router) # TrOCR Handwriting OCR
app.include_router(vocab_router) # Vocabulary Worksheet Generator
app.include_router(ocr_pipeline_router) # OCR Pipeline (step-by-step)
if htr_router:
app.include_router(htr_router) # Handwriting HTR (Klausur)
if dsfa_rag_router:
app.include_router(dsfa_rag_router) # DSFA RAG Corpus Search

View File

@@ -168,6 +168,13 @@ class RowGroundTruthRequest(BaseModel):
notes: Optional[str] = None
class RemoveHandwritingRequest(BaseModel):
method: str = "auto" # "auto" | "telea" | "ns"
target_ink: str = "all" # "all" | "colored" | "pencil"
dilation: int = 2 # mask dilation iterations (0-5)
use_source: str = "auto" # "original" | "deskewed" | "auto"
# ---------------------------------------------------------------------------
# Session Management Endpoints
# ---------------------------------------------------------------------------
@@ -309,7 +316,7 @@ async def delete_session(session_id: str):
@router.get("/sessions/{session_id}/image/{image_type}")
async def get_image(session_id: str, image_type: str):
"""Serve session images: original, deskewed, dewarped, binarized, columns-overlay, or rows-overlay."""
valid_types = {"original", "deskewed", "dewarped", "binarized", "columns-overlay", "rows-overlay", "words-overlay"}
valid_types = {"original", "deskewed", "dewarped", "binarized", "columns-overlay", "rows-overlay", "words-overlay", "clean"}
if image_type not in valid_types:
raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")
@@ -1906,3 +1913,90 @@ async def _get_words_overlay(session_id: str) -> Response:
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
return Response(content=result_png.tobytes(), media_type="image/png")
# ---------------------------------------------------------------------------
# Handwriting Removal Endpoint
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/remove-handwriting")
async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingRequest):
"""
Remove handwriting from a session image using inpainting.
Steps:
1. Load source image (auto → deskewed if available, else original)
2. Detect handwriting mask (filtered by target_ink)
3. Dilate mask to cover stroke edges
4. Inpaint the image
5. Store result as clean_png in the session
Returns metadata including the URL to fetch the clean image.
"""
import time as _time
t0 = _time.monotonic()
from services.handwriting_detection import detect_handwriting
from services.inpainting_service import inpaint_image, dilate_mask as _dilate_mask, InpaintingMethod, image_to_png
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
# 1. Determine source image
source = req.use_source
if source == "auto":
deskewed = await get_session_image(session_id, "deskewed")
source = "deskewed" if deskewed else "original"
image_bytes = await get_session_image(session_id, source)
if not image_bytes:
raise HTTPException(status_code=404, detail=f"Source image '{source}' not available")
# 2. Detect handwriting mask
detection = detect_handwriting(image_bytes, target_ink=req.target_ink)
# 3. Convert mask to PNG bytes and dilate
import io
from PIL import Image as _PILImage
mask_img = _PILImage.fromarray(detection.mask)
mask_buf = io.BytesIO()
mask_img.save(mask_buf, format="PNG")
mask_bytes = mask_buf.getvalue()
if req.dilation > 0:
mask_bytes = _dilate_mask(mask_bytes, iterations=req.dilation)
# 4. Inpaint
method_map = {
"telea": InpaintingMethod.OPENCV_TELEA,
"ns": InpaintingMethod.OPENCV_NS,
"auto": InpaintingMethod.AUTO,
}
inpaint_method = method_map.get(req.method, InpaintingMethod.AUTO)
result = inpaint_image(image_bytes, mask_bytes, method=inpaint_method)
if not result.success:
raise HTTPException(status_code=500, detail="Inpainting failed")
elapsed_ms = int((_time.monotonic() - t0) * 1000)
meta = {
"method_used": result.method_used.value if hasattr(result.method_used, "value") else str(result.method_used),
"handwriting_ratio": round(detection.handwriting_ratio, 4),
"detection_confidence": round(detection.confidence, 4),
"target_ink": req.target_ink,
"dilation": req.dilation,
"source_image": source,
"processing_time_ms": elapsed_ms,
}
# 5. Persist clean image (convert BGR ndarray → PNG bytes)
clean_png_bytes = image_to_png(result.image)
await update_session_db(session_id, clean_png=clean_png_bytes, handwriting_removal_meta=meta)
return {
**meta,
"image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean",
"session_id": session_id,
}

View File

@@ -60,6 +60,13 @@ async def init_ocr_pipeline_tables():
else:
logger.debug("OCR pipeline tables already exist")
# Ensure new columns exist (idempotent ALTER TABLE)
await conn.execute("""
ALTER TABLE ocr_pipeline_sessions
ADD COLUMN IF NOT EXISTS clean_png BYTEA,
ADD COLUMN IF NOT EXISTS handwriting_removal_meta JSONB
""")
# =============================================================================
# SESSION CRUD
@@ -111,6 +118,7 @@ async def get_session_image(session_id: str, image_type: str) -> Optional[bytes]
"deskewed": "deskewed_png",
"binarized": "binarized_png",
"dewarped": "dewarped_png",
"clean": "clean_png",
}
column = column_map.get(image_type)
if not column:
@@ -135,11 +143,12 @@ async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any
allowed_fields = {
'name', 'filename', 'status', 'current_step',
'original_png', 'deskewed_png', 'binarized_png', 'dewarped_png',
'clean_png', 'handwriting_removal_meta',
'deskew_result', 'dewarp_result', 'column_result', 'row_result',
'word_result', 'ground_truth', 'auto_shear_degrees',
}
jsonb_fields = {'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth'}
jsonb_fields = {'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'handwriting_removal_meta'}
for key, value in kwargs.items():
if key in allowed_fields:

View File

@@ -6,6 +6,7 @@ Uses multiple detection methods:
1. Color-based detection (blue/red ink)
2. Stroke analysis (thin irregular strokes)
3. Edge density variance
4. Pencil detection (gray ink)
PRIVACY: All processing happens locally on the Mac Mini.
"""
@@ -37,12 +38,16 @@ class DetectionResult:
detection_method: str # Which method was primarily used
def detect_handwriting(image_bytes: bytes, target_ink: str = "all") -> DetectionResult:
"""
Detect handwriting in an image.
Args:
image_bytes: Image as bytes (PNG, JPG, etc.)
target_ink: Which ink types to detect:
- "all" → all methods combined (incl. pencil)
- "colored" → only color-based (blue/red/green pen)
- "pencil" → only pencil (gray ink)
Returns:
DetectionResult with binary mask where handwriting is white (255)
@@ -62,35 +67,51 @@ def detect_handwriting(image_bytes: bytes) -> DetectionResult:
# Convert to BGR if needed (OpenCV format)
if len(img_array.shape) == 2:
# Grayscale to BGR
img_bgr = cv2.cvtColor(img_array, cv2.COLOR_GRAY2BGR)
elif img_array.shape[2] == 4:
# RGBA to BGR
img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGBA2BGR)
elif img_array.shape[2] == 3:
# RGB to BGR
img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
else:
img_bgr = img_array
# Select detection methods based on target_ink
masks_and_weights = []
if target_ink in ("all", "colored"):
color_mask, color_conf = _detect_by_color(img_bgr)
masks_and_weights.append((color_mask, color_conf, "color"))
if target_ink == "all":
stroke_mask, stroke_conf = _detect_by_stroke_analysis(img_bgr)
variance_mask, variance_conf = _detect_by_variance(img_bgr)
masks_and_weights.append((stroke_mask, stroke_conf, "stroke"))
masks_and_weights.append((variance_mask, variance_conf, "variance"))
if target_ink in ("all", "pencil"):
pencil_mask, pencil_conf = _detect_pencil(img_bgr)
masks_and_weights.append((pencil_mask, pencil_conf, "pencil"))
if not masks_and_weights:
# Fallback: use all methods
color_mask, color_conf = _detect_by_color(img_bgr)
stroke_mask, stroke_conf = _detect_by_stroke_analysis(img_bgr)
variance_mask, variance_conf = _detect_by_variance(img_bgr)
pencil_mask, pencil_conf = _detect_pencil(img_bgr)
masks_and_weights = [
(color_mask, color_conf, "color"),
(stroke_mask, stroke_conf, "stroke"),
(variance_mask, variance_conf, "variance"),
(pencil_mask, pencil_conf, "pencil"),
]
# Combine masks using weighted average
total_weight = sum(w for _, w, _ in masks_and_weights)
if total_weight > 0:
# Weighted combination
combined_mask = sum(
m.astype(np.float32) * w for m, w, _ in masks_and_weights
) / total_weight
# Threshold to binary
combined_mask = (combined_mask > 127).astype(np.uint8) * 255
else:
combined_mask = np.zeros(img_bgr.shape[:2], dtype=np.uint8)
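The confidence-weighted combination above can be checked on a toy example: a pixel survives only if the weighted average of the masks exceeds 127. A self-contained numpy sketch (2x2 masks with made-up confidences):

```python
import numpy as np

# Two 2x2 binary masks (255 = handwriting) with confidences as weights
masks_and_weights = [
    (np.array([[255, 0], [0, 0]], dtype=np.uint8), 0.8, "color"),
    (np.array([[255, 255], [0, 0]], dtype=np.uint8), 0.2, "pencil"),
]
total_weight = sum(w for _, w, _ in masks_and_weights)
combined = sum(m.astype(np.float32) * w for m, w, _ in masks_and_weights) / total_weight
binary = (combined > 127).astype(np.uint8) * 255
# Pixel (0,0): 255*0.8 + 255*0.2 = 255 -> kept; pixel (0,1): 0*0.8 + 255*0.2 = 51 -> dropped
print(binary.tolist())  # [[255, 0], [0, 0]]
```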
@@ -103,19 +124,11 @@ def detect_handwriting(image_bytes: bytes) -> DetectionResult:
handwriting_pixels = np.sum(combined_mask > 0)
handwriting_ratio = handwriting_pixels / total_pixels if total_pixels > 0 else 0
# Determine primary method (highest confidence)
primary_method = max(masks_and_weights, key=lambda x: x[1])[2] if masks_and_weights else "combined"
overall_confidence = total_weight / len(masks_and_weights) if masks_and_weights else 0.0
logger.info(f"Handwriting detection (target_ink={target_ink}): {handwriting_ratio:.2%} handwriting, "
f"confidence={overall_confidence:.2f}, method={primary_method}")
return DetectionResult(
@@ -180,6 +193,27 @@ def _detect_by_color(img_bgr: np.ndarray) -> Tuple[np.ndarray, float]:
return color_mask, confidence
def _detect_pencil(img_bgr: np.ndarray) -> Tuple[np.ndarray, float]:
"""
Detect pencil marks (gray ink, ~140-220 on 255-scale).
Paper is usually >230, dark ink <130.
Pencil falls in the 140-220 gray range.
"""
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
pencil_mask = cv2.inRange(gray, 140, 220)
# Remove small noise artifacts
kernel = np.ones((2, 2), np.uint8)
pencil_mask = cv2.morphologyEx(pencil_mask, cv2.MORPH_OPEN, kernel, iterations=1)
ratio = np.sum(pencil_mask > 0) / pencil_mask.size
# Good confidence if pencil pixels are in a plausible range
confidence = 0.75 if 0.002 < ratio < 0.2 else 0.2
return pencil_mask, confidence
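The band threshold in `_detect_pencil` (`cv2.inRange(gray, 140, 220)`) keeps only mid-gray pixels: paper (>230) and dark ink (<130) fall outside the band. A pure-numpy equivalent on a toy 2x2 patch, for illustration:

```python
import numpy as np

# Toy grayscale patch: paper (250), pencil (180), dark ink (100), pencil (200)
gray = np.array([[250, 180], [100, 200]], dtype=np.uint8)
# numpy equivalent of cv2.inRange(gray, 140, 220)
pencil = ((gray >= 140) & (gray <= 220)).astype(np.uint8) * 255
print(pencil.tolist())  # [[0, 255], [0, 255]]
ratio = np.sum(pencil > 0) / pencil.size
print(ratio)  # 0.5
```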
def _detect_by_stroke_analysis(img_bgr: np.ndarray) -> Tuple[np.ndarray, float]:
"""
Detect handwriting by analyzing stroke characteristics.

View File

@@ -31,8 +31,10 @@ from datetime import datetime, timedelta
logger = logging.getLogger(__name__)
# Lazy loading for heavy dependencies
# Cache keyed by model_name to support base and large variants simultaneously
_trocr_models: dict = {}  # {model_name: (processor, model)}
_trocr_processor = None  # backwards-compat alias → base-printed
_trocr_model = None  # backwards-compat alias → base-printed
_trocr_available = None
_model_loaded_at = None
@@ -124,12 +126,14 @@ def _check_trocr_available() -> bool:
return _trocr_available
def get_trocr_model(handwritten: bool = False, size: str = "base"):
"""
Lazy load TrOCR model and processor.
Args:
handwritten: Use handwritten model instead of printed model
size: Model size — "base" (300 MB) or "large" (340 MB, higher accuracy
for exam HTR). Only applies to handwritten variant.
Returns tuple of (processor, model) or (None, None) if unavailable.
"""
@@ -138,31 +142,42 @@ def get_trocr_model(handwritten: bool = False):
if not _check_trocr_available():
return None, None
# Select model name
if size == "large" and handwritten:
model_name = "microsoft/trocr-large-handwritten"
elif handwritten:
model_name = "microsoft/trocr-base-handwritten"
else:
model_name = "microsoft/trocr-base-printed"
if model_name in _trocr_models:
return _trocr_models[model_name]
try:
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
logger.info(f"Loading TrOCR model: {model_name}")
processor = TrOCRProcessor.from_pretrained(model_name)
model = VisionEncoderDecoderModel.from_pretrained(model_name)
# Use GPU if available
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
model.to(device)
logger.info(f"TrOCR model loaded on device: {device}")
_trocr_models[model_name] = (processor, model)
# Keep backwards-compat globals pointing at base-printed
if model_name == "microsoft/trocr-base-printed":
_trocr_processor = processor
_trocr_model = model
return processor, model
except Exception as e:
logger.error(f"Failed to load TrOCR model {model_name}: {e}")
return None, None
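The `_trocr_models` dict is a lazy cache keyed by model name: load on first request, return the same objects afterwards. The pattern in isolation, with a stand-in loader instead of `from_pretrained` (sketch only):

```python
# Keyed lazy cache, as used for _trocr_models (stand-in loader for illustration)
_cache: dict = {}

def get_model(name: str):
    if name in _cache:
        return _cache[name]
    model = f"loaded:{name}"  # stand-in for the expensive from_pretrained(name)
    _cache[name] = model
    return model

a = get_model("microsoft/trocr-base-printed")
b = get_model("microsoft/trocr-base-printed")
print(a is b)  # True, the second call hits the cache
```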
def preload_trocr_model(handwritten: bool = True) -> bool:
@@ -209,7 +224,8 @@ def get_model_status() -> Dict[str, Any]:
async def run_trocr_ocr(
image_data: bytes,
handwritten: bool = False,
split_lines: bool = True,
size: str = "base",
) -> Tuple[Optional[str], float]:
"""
Run TrOCR on an image.
@@ -223,11 +239,12 @@ async def run_trocr_ocr(
image_data: Raw image bytes
handwritten: Use handwritten model (slower but better for handwriting)
split_lines: Whether to split image into lines first
size: "base" or "large" (only for handwritten variant)
Returns:
Tuple of (extracted_text, confidence)
"""
processor, model = get_trocr_model(handwritten=handwritten, size=size)
if processor is None or model is None:
logger.error("TrOCR model not available")