feat(klausur): Handschrift entfernen + Klausur-HTR implementiert
Some checks failed
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 26s
CI / test-go-edu-search (push) Successful in 26s
CI / test-python-klausur (push) Failing after 1m49s
CI / test-python-agent-core (push) Successful in 14s
CI / test-nodejs-website (push) Successful in 15s
Some checks failed
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 26s
CI / test-go-edu-search (push) Successful in 26s
CI / test-python-klausur (push) Failing after 1m49s
CI / test-python-agent-core (push) Successful in 14s
CI / test-nodejs-website (push) Successful in 15s
Feature 1: Handschrift entfernen via OCR-Pipeline Session
- services/handwriting_detection.py: _detect_pencil() + target_ink Parameter
("all" | "colored" | "pencil") für gezielte Tinten-Erkennung
- ocr_pipeline_session_store.py: clean_png + handwriting_removal_meta Spalten
(idempotentes ALTER TABLE in init_ocr_pipeline_tables)
- ocr_pipeline_api.py: POST /sessions/{id}/remove-handwriting Endpoint
+ "clean" zu valid_types für Image-Serving hinzugefügt
Feature 2: Klausur-HTR (Hochwertige Handschriftenerkennung)
- handwriting_htr_api.py: Neuer Router /api/v1/htr/recognize + /recognize-session
Primary: qwen2.5vl:32b via Ollama, Fallback: trocr-large-handwritten
- services/trocr_service.py: size Parameter (base | large) für get_trocr_model()
+ run_trocr_ocr() - unterstützt jetzt trocr-large-handwritten
- main.py: HTR Router registriert
Config:
- docker-compose.yml: OLLAMA_HTR_MODEL, HTR_FALLBACK_MODEL
- .env.example: HTR Env-Vars dokumentiert
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
17
.env.example
17
.env.example
@@ -30,6 +30,23 @@ OLLAMA_VISION_MODEL=llama3.2-vision
|
|||||||
OLLAMA_CORRECTION_MODEL=llama3.2
|
OLLAMA_CORRECTION_MODEL=llama3.2
|
||||||
OLLAMA_TIMEOUT=120
|
OLLAMA_TIMEOUT=120
|
||||||
|
|
||||||
|
# OCR-Pipeline: LLM-Review (Schritt 6)
|
||||||
|
# Kleine Modelle reichen fuer Zeichen-Korrekturen (0->O, 1->l, 5->S)
|
||||||
|
# Optionen: qwen3:0.6b, qwen3:1.7b, gemma3:1b, qwen3:30b-a3b
|
||||||
|
OLLAMA_REVIEW_MODEL=qwen3:0.6b
|
||||||
|
# Eintraege pro Ollama-Call. Groesser = weniger HTTP-Overhead.
|
||||||
|
OLLAMA_REVIEW_BATCH_SIZE=20
|
||||||
|
|
||||||
|
# OCR-Pipeline: Engine fuer Schritt 5 (Worterkennung)
|
||||||
|
# Optionen: auto (bevorzugt RapidOCR), rapid, tesseract,
|
||||||
|
# trocr-printed, trocr-handwritten, lighton
|
||||||
|
OCR_ENGINE=auto
|
||||||
|
|
||||||
|
# Klausur-HTR: Primaerem Modell fuer Handschriftenerkennung (qwen2.5vl bereits auf Mac Mini)
|
||||||
|
OLLAMA_HTR_MODEL=qwen2.5vl:32b
|
||||||
|
# HTR Fallback: genutzt wenn Ollama nicht erreichbar (auto-download ~340 MB)
|
||||||
|
HTR_FALLBACK_MODEL=trocr-large
|
||||||
|
|
||||||
# Anthropic (optional)
|
# Anthropic (optional)
|
||||||
ANTHROPIC_API_KEY=
|
ANTHROPIC_API_KEY=
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ volumes:
|
|||||||
eh_uploads:
|
eh_uploads:
|
||||||
ocr_labeling:
|
ocr_labeling:
|
||||||
paddle_models:
|
paddle_models:
|
||||||
|
lighton_models:
|
||||||
paddleocr_models:
|
paddleocr_models:
|
||||||
transcription_models:
|
transcription_models:
|
||||||
transcription_temp:
|
transcription_temp:
|
||||||
@@ -209,6 +210,7 @@ services:
|
|||||||
- eh_uploads:/app/eh-uploads
|
- eh_uploads:/app/eh-uploads
|
||||||
- ocr_labeling:/app/ocr-labeling
|
- ocr_labeling:/app/ocr-labeling
|
||||||
- paddle_models:/root/.paddlex
|
- paddle_models:/root/.paddlex
|
||||||
|
- lighton_models:/root/.cache/huggingface
|
||||||
environment:
|
environment:
|
||||||
JWT_SECRET: ${JWT_SECRET:-your-super-secret-jwt-key-change-in-production}
|
JWT_SECRET: ${JWT_SECRET:-your-super-secret-jwt-key-change-in-production}
|
||||||
BACKEND_URL: http://backend-lehrer:8001
|
BACKEND_URL: http://backend-lehrer:8001
|
||||||
@@ -231,6 +233,11 @@ services:
|
|||||||
OLLAMA_DEFAULT_MODEL: ${OLLAMA_DEFAULT_MODEL:-llama3.2}
|
OLLAMA_DEFAULT_MODEL: ${OLLAMA_DEFAULT_MODEL:-llama3.2}
|
||||||
OLLAMA_VISION_MODEL: ${OLLAMA_VISION_MODEL:-llama3.2-vision}
|
OLLAMA_VISION_MODEL: ${OLLAMA_VISION_MODEL:-llama3.2-vision}
|
||||||
OLLAMA_CORRECTION_MODEL: ${OLLAMA_CORRECTION_MODEL:-llama3.2}
|
OLLAMA_CORRECTION_MODEL: ${OLLAMA_CORRECTION_MODEL:-llama3.2}
|
||||||
|
OLLAMA_REVIEW_MODEL: ${OLLAMA_REVIEW_MODEL:-qwen3:0.6b}
|
||||||
|
OLLAMA_REVIEW_BATCH_SIZE: ${OLLAMA_REVIEW_BATCH_SIZE:-20}
|
||||||
|
OCR_ENGINE: ${OCR_ENGINE:-auto}
|
||||||
|
OLLAMA_HTR_MODEL: ${OLLAMA_HTR_MODEL:-qwen2.5vl:32b}
|
||||||
|
HTR_FALLBACK_MODEL: ${HTR_FALLBACK_MODEL:-trocr-large}
|
||||||
RAG_SERVICE_URL: http://bp-core-rag-service:8097
|
RAG_SERVICE_URL: http://bp-core-rag-service:8097
|
||||||
extra_hosts:
|
extra_hosts:
|
||||||
- "host.docker.internal:host-gateway"
|
- "host.docker.internal:host-gateway"
|
||||||
|
|||||||
276
klausur-service/backend/handwriting_htr_api.py
Normal file
276
klausur-service/backend/handwriting_htr_api.py
Normal file
@@ -0,0 +1,276 @@
|
|||||||
|
"""
|
||||||
|
Handwriting HTR API - Hochwertige Handschriftenerkennung (HTR) fuer Klausurkorrekturen.
|
||||||
|
|
||||||
|
Endpoints:
|
||||||
|
- POST /api/v1/htr/recognize - Bild hochladen → handgeschriebener Text
|
||||||
|
- POST /api/v1/htr/recognize-session - OCR-Pipeline Session als Quelle nutzen
|
||||||
|
|
||||||
|
Modell-Strategie:
|
||||||
|
1. qwen2.5vl:32b via Ollama (primaer, hoechste Qualitaet als VLM)
|
||||||
|
2. microsoft/trocr-large-handwritten (Fallback, offline, kein Ollama)
|
||||||
|
|
||||||
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal auf dem Mac Mini.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import base64
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import APIRouter, HTTPException, Query, UploadFile, File
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/htr", tags=["HTR"])
|
||||||
|
|
||||||
|
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
|
||||||
|
OLLAMA_HTR_MODEL = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
|
||||||
|
HTR_FALLBACK_MODEL = os.getenv("HTR_FALLBACK_MODEL", "trocr-large")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pydantic Models
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class HTRSessionRequest(BaseModel):
|
||||||
|
session_id: str
|
||||||
|
model: str = "auto" # "auto" | "qwen2.5vl" | "trocr-large"
|
||||||
|
use_clean: bool = True # Prefer clean_png (after handwriting removal)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Preprocessing
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _preprocess_for_htr(img_bgr: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
CLAHE contrast enhancement + upscale to improve HTR accuracy.
|
||||||
|
Returns grayscale enhanced image.
|
||||||
|
"""
|
||||||
|
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
|
||||||
|
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
||||||
|
enhanced = clahe.apply(gray)
|
||||||
|
|
||||||
|
# Upscale if image is too small
|
||||||
|
h, w = enhanced.shape
|
||||||
|
if min(h, w) < 800:
|
||||||
|
scale = 800 / min(h, w)
|
||||||
|
enhanced = cv2.resize(
|
||||||
|
enhanced, None, fx=scale, fy=scale,
|
||||||
|
interpolation=cv2.INTER_CUBIC
|
||||||
|
)
|
||||||
|
|
||||||
|
return enhanced
|
||||||
|
|
||||||
|
|
||||||
|
def _bgr_to_png_bytes(img_bgr: np.ndarray) -> bytes:
|
||||||
|
"""Convert BGR ndarray to PNG bytes."""
|
||||||
|
success, buf = cv2.imencode(".png", img_bgr)
|
||||||
|
if not success:
|
||||||
|
raise RuntimeError("Failed to encode image to PNG")
|
||||||
|
return buf.tobytes()
|
||||||
|
|
||||||
|
|
||||||
|
def _preprocess_image_bytes(image_bytes: bytes) -> bytes:
|
||||||
|
"""Load image, apply HTR preprocessing, return PNG bytes."""
|
||||||
|
arr = np.frombuffer(image_bytes, dtype=np.uint8)
|
||||||
|
img_bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||||
|
if img_bgr is None:
|
||||||
|
raise ValueError("Could not decode image")
|
||||||
|
|
||||||
|
enhanced = _preprocess_for_htr(img_bgr)
|
||||||
|
# Convert grayscale back to BGR for encoding
|
||||||
|
enhanced_bgr = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR)
|
||||||
|
return _bgr_to_png_bytes(enhanced_bgr)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Backend: Ollama qwen2.5vl
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _recognize_with_qwen_vl(image_bytes: bytes, language: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Send image to Ollama qwen2.5vl:32b for HTR.
|
||||||
|
Returns extracted text or None on error.
|
||||||
|
"""
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
lang_hint = {
|
||||||
|
"de": "Deutsch",
|
||||||
|
"en": "Englisch",
|
||||||
|
"de+en": "Deutsch und Englisch",
|
||||||
|
}.get(language, "Deutsch")
|
||||||
|
|
||||||
|
prompt = (
|
||||||
|
f"Du bist ein OCR-Experte fuer handgeschriebenen Text auf {lang_hint}. "
|
||||||
|
"Lies den Text im Bild exakt ab — korrigiere KEINE Rechtschreibfehler. "
|
||||||
|
"Antworte NUR mit dem erkannten Text, ohne Erklaerungen."
|
||||||
|
)
|
||||||
|
|
||||||
|
img_b64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": OLLAMA_HTR_MODEL,
|
||||||
|
"prompt": prompt,
|
||||||
|
"images": [img_b64],
|
||||||
|
"stream": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||||
|
resp = await client.post(f"{OLLAMA_BASE_URL}/api/generate", json=payload)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
return data.get("response", "").strip()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Ollama qwen2.5vl HTR failed: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Backend: TrOCR-large fallback
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _recognize_with_trocr_large(image_bytes: bytes) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Use microsoft/trocr-large-handwritten via trocr_service.py.
|
||||||
|
Returns extracted text or None on error.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from services.trocr_service import run_trocr_ocr, _check_trocr_available
|
||||||
|
if not _check_trocr_available():
|
||||||
|
logger.warning("TrOCR not available for HTR fallback")
|
||||||
|
return None
|
||||||
|
|
||||||
|
text, confidence = await run_trocr_ocr(image_bytes, handwritten=True, size="large")
|
||||||
|
return text.strip() if text else None
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"TrOCR-large HTR failed: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Core recognition logic
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _do_recognize(
|
||||||
|
image_bytes: bytes,
|
||||||
|
model: str = "auto",
|
||||||
|
preprocess: bool = True,
|
||||||
|
language: str = "de",
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Core HTR logic: preprocess → try Ollama → fallback to TrOCR-large.
|
||||||
|
Returns dict with text, model_used, processing_time_ms.
|
||||||
|
"""
|
||||||
|
t0 = time.monotonic()
|
||||||
|
|
||||||
|
if preprocess:
|
||||||
|
try:
|
||||||
|
image_bytes = _preprocess_image_bytes(image_bytes)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"HTR preprocessing failed, using raw image: {e}")
|
||||||
|
|
||||||
|
text: Optional[str] = None
|
||||||
|
model_used: str = "none"
|
||||||
|
|
||||||
|
use_qwen = model in ("auto", "qwen2.5vl")
|
||||||
|
use_trocr = model in ("auto", "trocr-large") or (use_qwen and text is None)
|
||||||
|
|
||||||
|
if use_qwen:
|
||||||
|
text = await _recognize_with_qwen_vl(image_bytes, language)
|
||||||
|
if text is not None:
|
||||||
|
model_used = f"qwen2.5vl ({OLLAMA_HTR_MODEL})"
|
||||||
|
|
||||||
|
if text is None and (use_trocr or model == "trocr-large"):
|
||||||
|
text = await _recognize_with_trocr_large(image_bytes)
|
||||||
|
if text is not None:
|
||||||
|
model_used = "trocr-large-handwritten"
|
||||||
|
|
||||||
|
if text is None:
|
||||||
|
text = ""
|
||||||
|
model_used = "none (all backends failed)"
|
||||||
|
|
||||||
|
elapsed_ms = int((time.monotonic() - t0) * 1000)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"text": text,
|
||||||
|
"model_used": model_used,
|
||||||
|
"processing_time_ms": elapsed_ms,
|
||||||
|
"language": language,
|
||||||
|
"preprocessed": preprocess,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/recognize")
|
||||||
|
async def recognize_handwriting(
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
model: str = Query("auto", description="auto | qwen2.5vl | trocr-large"),
|
||||||
|
preprocess: bool = Query(True, description="Apply CLAHE + upscale before recognition"),
|
||||||
|
language: str = Query("de", description="de | en | de+en"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Upload an image and get back the handwritten text as plain text.
|
||||||
|
|
||||||
|
Tries qwen2.5vl:32b via Ollama first, falls back to TrOCR-large-handwritten.
|
||||||
|
"""
|
||||||
|
if model not in ("auto", "qwen2.5vl", "trocr-large"):
|
||||||
|
raise HTTPException(status_code=400, detail="model must be one of: auto, qwen2.5vl, trocr-large")
|
||||||
|
if language not in ("de", "en", "de+en"):
|
||||||
|
raise HTTPException(status_code=400, detail="language must be one of: de, en, de+en")
|
||||||
|
|
||||||
|
image_bytes = await file.read()
|
||||||
|
if not image_bytes:
|
||||||
|
raise HTTPException(status_code=400, detail="Empty file")
|
||||||
|
|
||||||
|
return await _do_recognize(image_bytes, model=model, preprocess=preprocess, language=language)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/recognize-session")
|
||||||
|
async def recognize_from_session(req: HTRSessionRequest):
|
||||||
|
"""
|
||||||
|
Use an OCR-Pipeline session as image source for HTR.
|
||||||
|
|
||||||
|
Set use_clean=true to prefer the clean image (after handwriting removal step).
|
||||||
|
This is useful when you want to do HTR on isolated handwriting regions.
|
||||||
|
"""
|
||||||
|
from ocr_pipeline_session_store import get_session_db, get_session_image
|
||||||
|
|
||||||
|
session = await get_session_db(req.session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {req.session_id} not found")
|
||||||
|
|
||||||
|
# Choose source image
|
||||||
|
image_bytes: Optional[bytes] = None
|
||||||
|
source_used: str = ""
|
||||||
|
|
||||||
|
if req.use_clean:
|
||||||
|
image_bytes = await get_session_image(req.session_id, "clean")
|
||||||
|
if image_bytes:
|
||||||
|
source_used = "clean"
|
||||||
|
|
||||||
|
if not image_bytes:
|
||||||
|
image_bytes = await get_session_image(req.session_id, "deskewed")
|
||||||
|
if image_bytes:
|
||||||
|
source_used = "deskewed"
|
||||||
|
|
||||||
|
if not image_bytes:
|
||||||
|
image_bytes = await get_session_image(req.session_id, "original")
|
||||||
|
source_used = "original"
|
||||||
|
|
||||||
|
if not image_bytes:
|
||||||
|
raise HTTPException(status_code=404, detail="No image available in session")
|
||||||
|
|
||||||
|
result = await _do_recognize(image_bytes, model=req.model)
|
||||||
|
result["session_id"] = req.session_id
|
||||||
|
result["source_image"] = source_used
|
||||||
|
return result
|
||||||
@@ -44,6 +44,10 @@ except ImportError:
|
|||||||
from vocab_worksheet_api import router as vocab_router, set_db_pool as set_vocab_db_pool, _init_vocab_table, _load_all_sessions, DATABASE_URL as VOCAB_DATABASE_URL
|
from vocab_worksheet_api import router as vocab_router, set_db_pool as set_vocab_db_pool, _init_vocab_table, _load_all_sessions, DATABASE_URL as VOCAB_DATABASE_URL
|
||||||
from ocr_pipeline_api import router as ocr_pipeline_router
|
from ocr_pipeline_api import router as ocr_pipeline_router
|
||||||
from ocr_pipeline_session_store import init_ocr_pipeline_tables
|
from ocr_pipeline_session_store import init_ocr_pipeline_tables
|
||||||
|
try:
|
||||||
|
from handwriting_htr_api import router as htr_router
|
||||||
|
except ImportError:
|
||||||
|
htr_router = None
|
||||||
try:
|
try:
|
||||||
from dsfa_rag_api import router as dsfa_rag_router, set_db_pool as set_dsfa_db_pool
|
from dsfa_rag_api import router as dsfa_rag_router, set_db_pool as set_dsfa_db_pool
|
||||||
from dsfa_corpus_ingestion import DSFAQdrantService, DATABASE_URL as DSFA_DATABASE_URL
|
from dsfa_corpus_ingestion import DSFAQdrantService, DATABASE_URL as DSFA_DATABASE_URL
|
||||||
@@ -113,6 +117,19 @@ async def lifespan(app: FastAPI):
|
|||||||
# Ensure EH upload directory exists
|
# Ensure EH upload directory exists
|
||||||
os.makedirs(EH_UPLOAD_DIR, exist_ok=True)
|
os.makedirs(EH_UPLOAD_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
# Preload LightOnOCR model if OCR_ENGINE=lighton (avoids cold-start on first request)
|
||||||
|
ocr_engine_env = os.getenv("OCR_ENGINE", "auto")
|
||||||
|
if ocr_engine_env == "lighton":
|
||||||
|
try:
|
||||||
|
import asyncio
|
||||||
|
from services.lighton_ocr_service import get_lighton_model
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
print("Preloading LightOnOCR-2-1B at startup (OCR_ENGINE=lighton)...")
|
||||||
|
await loop.run_in_executor(None, get_lighton_model)
|
||||||
|
print("LightOnOCR-2-1B preloaded")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: LightOnOCR preload failed: {e}")
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
print("Klausur-Service shutting down...")
|
print("Klausur-Service shutting down...")
|
||||||
@@ -160,6 +177,8 @@ if trocr_router:
|
|||||||
app.include_router(trocr_router) # TrOCR Handwriting OCR
|
app.include_router(trocr_router) # TrOCR Handwriting OCR
|
||||||
app.include_router(vocab_router) # Vocabulary Worksheet Generator
|
app.include_router(vocab_router) # Vocabulary Worksheet Generator
|
||||||
app.include_router(ocr_pipeline_router) # OCR Pipeline (step-by-step)
|
app.include_router(ocr_pipeline_router) # OCR Pipeline (step-by-step)
|
||||||
|
if htr_router:
|
||||||
|
app.include_router(htr_router) # Handwriting HTR (Klausur)
|
||||||
if dsfa_rag_router:
|
if dsfa_rag_router:
|
||||||
app.include_router(dsfa_rag_router) # DSFA RAG Corpus Search
|
app.include_router(dsfa_rag_router) # DSFA RAG Corpus Search
|
||||||
|
|
||||||
|
|||||||
@@ -168,6 +168,13 @@ class RowGroundTruthRequest(BaseModel):
|
|||||||
notes: Optional[str] = None
|
notes: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class RemoveHandwritingRequest(BaseModel):
|
||||||
|
method: str = "auto" # "auto" | "telea" | "ns"
|
||||||
|
target_ink: str = "all" # "all" | "colored" | "pencil"
|
||||||
|
dilation: int = 2 # mask dilation iterations (0-5)
|
||||||
|
use_source: str = "auto" # "original" | "deskewed" | "auto"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Session Management Endpoints
|
# Session Management Endpoints
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -309,7 +316,7 @@ async def delete_session(session_id: str):
|
|||||||
@router.get("/sessions/{session_id}/image/{image_type}")
|
@router.get("/sessions/{session_id}/image/{image_type}")
|
||||||
async def get_image(session_id: str, image_type: str):
|
async def get_image(session_id: str, image_type: str):
|
||||||
"""Serve session images: original, deskewed, dewarped, binarized, columns-overlay, or rows-overlay."""
|
"""Serve session images: original, deskewed, dewarped, binarized, columns-overlay, or rows-overlay."""
|
||||||
valid_types = {"original", "deskewed", "dewarped", "binarized", "columns-overlay", "rows-overlay", "words-overlay"}
|
valid_types = {"original", "deskewed", "dewarped", "binarized", "columns-overlay", "rows-overlay", "words-overlay", "clean"}
|
||||||
if image_type not in valid_types:
|
if image_type not in valid_types:
|
||||||
raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")
|
raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")
|
||||||
|
|
||||||
@@ -1906,3 +1913,90 @@ async def _get_words_overlay(session_id: str) -> Response:
|
|||||||
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
|
raise HTTPException(status_code=500, detail="Failed to encode overlay image")
|
||||||
|
|
||||||
return Response(content=result_png.tobytes(), media_type="image/png")
|
return Response(content=result_png.tobytes(), media_type="image/png")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Handwriting Removal Endpoint
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/remove-handwriting")
|
||||||
|
async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingRequest):
|
||||||
|
"""
|
||||||
|
Remove handwriting from a session image using inpainting.
|
||||||
|
|
||||||
|
Steps:
|
||||||
|
1. Load source image (auto → deskewed if available, else original)
|
||||||
|
2. Detect handwriting mask (filtered by target_ink)
|
||||||
|
3. Dilate mask to cover stroke edges
|
||||||
|
4. Inpaint the image
|
||||||
|
5. Store result as clean_png in the session
|
||||||
|
|
||||||
|
Returns metadata including the URL to fetch the clean image.
|
||||||
|
"""
|
||||||
|
import time as _time
|
||||||
|
t0 = _time.monotonic()
|
||||||
|
|
||||||
|
from services.handwriting_detection import detect_handwriting
|
||||||
|
from services.inpainting_service import inpaint_image, dilate_mask as _dilate_mask, InpaintingMethod, image_to_png
|
||||||
|
|
||||||
|
session = await get_session_db(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||||
|
|
||||||
|
# 1. Determine source image
|
||||||
|
source = req.use_source
|
||||||
|
if source == "auto":
|
||||||
|
deskewed = await get_session_image(session_id, "deskewed")
|
||||||
|
source = "deskewed" if deskewed else "original"
|
||||||
|
|
||||||
|
image_bytes = await get_session_image(session_id, source)
|
||||||
|
if not image_bytes:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Source image '{source}' not available")
|
||||||
|
|
||||||
|
# 2. Detect handwriting mask
|
||||||
|
detection = detect_handwriting(image_bytes, target_ink=req.target_ink)
|
||||||
|
|
||||||
|
# 3. Convert mask to PNG bytes and dilate
|
||||||
|
import io
|
||||||
|
from PIL import Image as _PILImage
|
||||||
|
mask_img = _PILImage.fromarray(detection.mask)
|
||||||
|
mask_buf = io.BytesIO()
|
||||||
|
mask_img.save(mask_buf, format="PNG")
|
||||||
|
mask_bytes = mask_buf.getvalue()
|
||||||
|
|
||||||
|
if req.dilation > 0:
|
||||||
|
mask_bytes = _dilate_mask(mask_bytes, iterations=req.dilation)
|
||||||
|
|
||||||
|
# 4. Inpaint
|
||||||
|
method_map = {
|
||||||
|
"telea": InpaintingMethod.OPENCV_TELEA,
|
||||||
|
"ns": InpaintingMethod.OPENCV_NS,
|
||||||
|
"auto": InpaintingMethod.AUTO,
|
||||||
|
}
|
||||||
|
inpaint_method = method_map.get(req.method, InpaintingMethod.AUTO)
|
||||||
|
|
||||||
|
result = inpaint_image(image_bytes, mask_bytes, method=inpaint_method)
|
||||||
|
if not result.success:
|
||||||
|
raise HTTPException(status_code=500, detail="Inpainting failed")
|
||||||
|
|
||||||
|
elapsed_ms = int((_time.monotonic() - t0) * 1000)
|
||||||
|
|
||||||
|
meta = {
|
||||||
|
"method_used": result.method_used.value if hasattr(result.method_used, "value") else str(result.method_used),
|
||||||
|
"handwriting_ratio": round(detection.handwriting_ratio, 4),
|
||||||
|
"detection_confidence": round(detection.confidence, 4),
|
||||||
|
"target_ink": req.target_ink,
|
||||||
|
"dilation": req.dilation,
|
||||||
|
"source_image": source,
|
||||||
|
"processing_time_ms": elapsed_ms,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 5. Persist clean image (convert BGR ndarray → PNG bytes)
|
||||||
|
clean_png_bytes = image_to_png(result.image)
|
||||||
|
await update_session_db(session_id, clean_png=clean_png_bytes, handwriting_removal_meta=meta)
|
||||||
|
|
||||||
|
return {
|
||||||
|
**meta,
|
||||||
|
"image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean",
|
||||||
|
"session_id": session_id,
|
||||||
|
}
|
||||||
|
|||||||
@@ -60,6 +60,13 @@ async def init_ocr_pipeline_tables():
|
|||||||
else:
|
else:
|
||||||
logger.debug("OCR pipeline tables already exist")
|
logger.debug("OCR pipeline tables already exist")
|
||||||
|
|
||||||
|
# Ensure new columns exist (idempotent ALTER TABLE)
|
||||||
|
await conn.execute("""
|
||||||
|
ALTER TABLE ocr_pipeline_sessions
|
||||||
|
ADD COLUMN IF NOT EXISTS clean_png BYTEA,
|
||||||
|
ADD COLUMN IF NOT EXISTS handwriting_removal_meta JSONB
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# SESSION CRUD
|
# SESSION CRUD
|
||||||
@@ -111,6 +118,7 @@ async def get_session_image(session_id: str, image_type: str) -> Optional[bytes]
|
|||||||
"deskewed": "deskewed_png",
|
"deskewed": "deskewed_png",
|
||||||
"binarized": "binarized_png",
|
"binarized": "binarized_png",
|
||||||
"dewarped": "dewarped_png",
|
"dewarped": "dewarped_png",
|
||||||
|
"clean": "clean_png",
|
||||||
}
|
}
|
||||||
column = column_map.get(image_type)
|
column = column_map.get(image_type)
|
||||||
if not column:
|
if not column:
|
||||||
@@ -135,11 +143,12 @@ async def update_session_db(session_id: str, **kwargs) -> Optional[Dict[str, Any
|
|||||||
allowed_fields = {
|
allowed_fields = {
|
||||||
'name', 'filename', 'status', 'current_step',
|
'name', 'filename', 'status', 'current_step',
|
||||||
'original_png', 'deskewed_png', 'binarized_png', 'dewarped_png',
|
'original_png', 'deskewed_png', 'binarized_png', 'dewarped_png',
|
||||||
|
'clean_png', 'handwriting_removal_meta',
|
||||||
'deskew_result', 'dewarp_result', 'column_result', 'row_result',
|
'deskew_result', 'dewarp_result', 'column_result', 'row_result',
|
||||||
'word_result', 'ground_truth', 'auto_shear_degrees',
|
'word_result', 'ground_truth', 'auto_shear_degrees',
|
||||||
}
|
}
|
||||||
|
|
||||||
jsonb_fields = {'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth'}
|
jsonb_fields = {'deskew_result', 'dewarp_result', 'column_result', 'row_result', 'word_result', 'ground_truth', 'handwriting_removal_meta'}
|
||||||
|
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
if key in allowed_fields:
|
if key in allowed_fields:
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ Uses multiple detection methods:
|
|||||||
1. Color-based detection (blue/red ink)
|
1. Color-based detection (blue/red ink)
|
||||||
2. Stroke analysis (thin irregular strokes)
|
2. Stroke analysis (thin irregular strokes)
|
||||||
3. Edge density variance
|
3. Edge density variance
|
||||||
|
4. Pencil detection (gray ink)
|
||||||
|
|
||||||
DATENSCHUTZ: All processing happens locally on Mac Mini.
|
DATENSCHUTZ: All processing happens locally on Mac Mini.
|
||||||
"""
|
"""
|
||||||
@@ -37,12 +38,16 @@ class DetectionResult:
|
|||||||
detection_method: str # Which method was primarily used
|
detection_method: str # Which method was primarily used
|
||||||
|
|
||||||
|
|
||||||
def detect_handwriting(image_bytes: bytes) -> DetectionResult:
|
def detect_handwriting(image_bytes: bytes, target_ink: str = "all") -> DetectionResult:
|
||||||
"""
|
"""
|
||||||
Detect handwriting in an image.
|
Detect handwriting in an image.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image_bytes: Image as bytes (PNG, JPG, etc.)
|
image_bytes: Image as bytes (PNG, JPG, etc.)
|
||||||
|
target_ink: Which ink types to detect:
|
||||||
|
- "all" → all methods combined (incl. pencil)
|
||||||
|
- "colored" → only color-based (blue/red/green pen)
|
||||||
|
- "pencil" → only pencil (gray ink)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
DetectionResult with binary mask where handwriting is white (255)
|
DetectionResult with binary mask where handwriting is white (255)
|
||||||
@@ -62,35 +67,51 @@ def detect_handwriting(image_bytes: bytes) -> DetectionResult:
|
|||||||
|
|
||||||
# Convert to BGR if needed (OpenCV format)
|
# Convert to BGR if needed (OpenCV format)
|
||||||
if len(img_array.shape) == 2:
|
if len(img_array.shape) == 2:
|
||||||
# Grayscale to BGR
|
|
||||||
img_bgr = cv2.cvtColor(img_array, cv2.COLOR_GRAY2BGR)
|
img_bgr = cv2.cvtColor(img_array, cv2.COLOR_GRAY2BGR)
|
||||||
elif img_array.shape[2] == 4:
|
elif img_array.shape[2] == 4:
|
||||||
# RGBA to BGR
|
|
||||||
img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGBA2BGR)
|
img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGBA2BGR)
|
||||||
elif img_array.shape[2] == 3:
|
elif img_array.shape[2] == 3:
|
||||||
# RGB to BGR
|
|
||||||
img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
|
img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
|
||||||
else:
|
else:
|
||||||
img_bgr = img_array
|
img_bgr = img_array
|
||||||
|
|
||||||
# Run multiple detection methods
|
# Select detection methods based on target_ink
|
||||||
color_mask, color_confidence = _detect_by_color(img_bgr)
|
masks_and_weights = []
|
||||||
stroke_mask, stroke_confidence = _detect_by_stroke_analysis(img_bgr)
|
|
||||||
variance_mask, variance_confidence = _detect_by_variance(img_bgr)
|
if target_ink in ("all", "colored"):
|
||||||
|
color_mask, color_conf = _detect_by_color(img_bgr)
|
||||||
|
masks_and_weights.append((color_mask, color_conf, "color"))
|
||||||
|
|
||||||
|
if target_ink == "all":
|
||||||
|
stroke_mask, stroke_conf = _detect_by_stroke_analysis(img_bgr)
|
||||||
|
variance_mask, variance_conf = _detect_by_variance(img_bgr)
|
||||||
|
masks_and_weights.append((stroke_mask, stroke_conf, "stroke"))
|
||||||
|
masks_and_weights.append((variance_mask, variance_conf, "variance"))
|
||||||
|
|
||||||
|
if target_ink in ("all", "pencil"):
|
||||||
|
pencil_mask, pencil_conf = _detect_pencil(img_bgr)
|
||||||
|
masks_and_weights.append((pencil_mask, pencil_conf, "pencil"))
|
||||||
|
|
||||||
|
if not masks_and_weights:
|
||||||
|
# Fallback: use all methods
|
||||||
|
color_mask, color_conf = _detect_by_color(img_bgr)
|
||||||
|
stroke_mask, stroke_conf = _detect_by_stroke_analysis(img_bgr)
|
||||||
|
variance_mask, variance_conf = _detect_by_variance(img_bgr)
|
||||||
|
pencil_mask, pencil_conf = _detect_pencil(img_bgr)
|
||||||
|
masks_and_weights = [
|
||||||
|
(color_mask, color_conf, "color"),
|
||||||
|
(stroke_mask, stroke_conf, "stroke"),
|
||||||
|
(variance_mask, variance_conf, "variance"),
|
||||||
|
(pencil_mask, pencil_conf, "pencil"),
|
||||||
|
]
|
||||||
|
|
||||||
# Combine masks using weighted average
|
# Combine masks using weighted average
|
||||||
weights = [color_confidence, stroke_confidence, variance_confidence]
|
total_weight = sum(w for _, w, _ in masks_and_weights)
|
||||||
total_weight = sum(weights)
|
|
||||||
|
|
||||||
if total_weight > 0:
|
if total_weight > 0:
|
||||||
# Weighted combination
|
combined_mask = sum(
|
||||||
combined_mask = (
|
m.astype(np.float32) * w for m, w, _ in masks_and_weights
|
||||||
color_mask.astype(np.float32) * color_confidence +
|
|
||||||
stroke_mask.astype(np.float32) * stroke_confidence +
|
|
||||||
variance_mask.astype(np.float32) * variance_confidence
|
|
||||||
) / total_weight
|
) / total_weight
|
||||||
|
|
||||||
# Threshold to binary
|
|
||||||
combined_mask = (combined_mask > 127).astype(np.uint8) * 255
|
combined_mask = (combined_mask > 127).astype(np.uint8) * 255
|
||||||
else:
|
else:
|
||||||
combined_mask = np.zeros(img_bgr.shape[:2], dtype=np.uint8)
|
combined_mask = np.zeros(img_bgr.shape[:2], dtype=np.uint8)
|
||||||
@@ -103,19 +124,11 @@ def detect_handwriting(image_bytes: bytes) -> DetectionResult:
|
|||||||
handwriting_pixels = np.sum(combined_mask > 0)
|
handwriting_pixels = np.sum(combined_mask > 0)
|
||||||
handwriting_ratio = handwriting_pixels / total_pixels if total_pixels > 0 else 0
|
handwriting_ratio = handwriting_pixels / total_pixels if total_pixels > 0 else 0
|
||||||
|
|
||||||
# Determine primary method
|
# Determine primary method (highest confidence)
|
||||||
primary_method = "combined"
|
primary_method = max(masks_and_weights, key=lambda x: x[1])[2] if masks_and_weights else "combined"
|
||||||
max_conf = max(color_confidence, stroke_confidence, variance_confidence)
|
overall_confidence = total_weight / len(masks_and_weights) if masks_and_weights else 0.0
|
||||||
if max_conf == color_confidence:
|
|
||||||
primary_method = "color"
|
|
||||||
elif max_conf == stroke_confidence:
|
|
||||||
primary_method = "stroke"
|
|
||||||
else:
|
|
||||||
primary_method = "variance"
|
|
||||||
|
|
||||||
overall_confidence = total_weight / 3.0 # Average confidence
|
logger.info(f"Handwriting detection (target_ink={target_ink}): {handwriting_ratio:.2%} handwriting, "
|
||||||
|
|
||||||
logger.info(f"Handwriting detection: {handwriting_ratio:.2%} handwriting, "
|
|
||||||
f"confidence={overall_confidence:.2f}, method={primary_method}")
|
f"confidence={overall_confidence:.2f}, method={primary_method}")
|
||||||
|
|
||||||
return DetectionResult(
|
return DetectionResult(
|
||||||
@@ -180,6 +193,27 @@ def _detect_by_color(img_bgr: np.ndarray) -> Tuple[np.ndarray, float]:
|
|||||||
return color_mask, confidence
|
return color_mask, confidence
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_pencil(img_bgr: np.ndarray) -> Tuple[np.ndarray, float]:
|
||||||
|
"""
|
||||||
|
Detect pencil marks (gray ink, ~140-220 on 255-scale).
|
||||||
|
|
||||||
|
Paper is usually >230, dark ink <130.
|
||||||
|
Pencil falls in the 140-220 gray range.
|
||||||
|
"""
|
||||||
|
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
|
||||||
|
pencil_mask = cv2.inRange(gray, 140, 220)
|
||||||
|
|
||||||
|
# Remove small noise artifacts
|
||||||
|
kernel = np.ones((2, 2), np.uint8)
|
||||||
|
pencil_mask = cv2.morphologyEx(pencil_mask, cv2.MORPH_OPEN, kernel, iterations=1)
|
||||||
|
|
||||||
|
ratio = np.sum(pencil_mask > 0) / pencil_mask.size
|
||||||
|
# Good confidence if pencil pixels are in a plausible range
|
||||||
|
confidence = 0.75 if 0.002 < ratio < 0.2 else 0.2
|
||||||
|
|
||||||
|
return pencil_mask, confidence
|
||||||
|
|
||||||
|
|
||||||
def _detect_by_stroke_analysis(img_bgr: np.ndarray) -> Tuple[np.ndarray, float]:
|
def _detect_by_stroke_analysis(img_bgr: np.ndarray) -> Tuple[np.ndarray, float]:
|
||||||
"""
|
"""
|
||||||
Detect handwriting by analyzing stroke characteristics.
|
Detect handwriting by analyzing stroke characteristics.
|
||||||
|
|||||||
@@ -31,8 +31,10 @@ from datetime import datetime, timedelta
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Lazy loading for heavy dependencies
|
# Lazy loading for heavy dependencies
|
||||||
_trocr_processor = None
|
# Cache keyed by model_name to support base and large variants simultaneously
|
||||||
_trocr_model = None
|
_trocr_models: dict = {} # {model_name: (processor, model)}
|
||||||
|
_trocr_processor = None # backwards-compat alias → base-printed
|
||||||
|
_trocr_model = None # backwards-compat alias → base-printed
|
||||||
_trocr_available = None
|
_trocr_available = None
|
||||||
_model_loaded_at = None
|
_model_loaded_at = None
|
||||||
|
|
||||||
@@ -124,12 +126,14 @@ def _check_trocr_available() -> bool:
|
|||||||
return _trocr_available
|
return _trocr_available
|
||||||
|
|
||||||
|
|
||||||
def get_trocr_model(handwritten: bool = False):
|
def get_trocr_model(handwritten: bool = False, size: str = "base"):
|
||||||
"""
|
"""
|
||||||
Lazy load TrOCR model and processor.
|
Lazy load TrOCR model and processor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
handwritten: Use handwritten model instead of printed model
|
handwritten: Use handwritten model instead of printed model
|
||||||
|
size: Model size — "base" (300 MB) or "large" (340 MB, higher accuracy
|
||||||
|
for exam HTR). Only applies to handwritten variant.
|
||||||
|
|
||||||
Returns tuple of (processor, model) or (None, None) if unavailable.
|
Returns tuple of (processor, model) or (None, None) if unavailable.
|
||||||
"""
|
"""
|
||||||
@@ -138,31 +142,42 @@ def get_trocr_model(handwritten: bool = False):
|
|||||||
if not _check_trocr_available():
|
if not _check_trocr_available():
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
if _trocr_processor is None or _trocr_model is None:
|
# Select model name
|
||||||
try:
|
if size == "large" and handwritten:
|
||||||
import torch
|
model_name = "microsoft/trocr-large-handwritten"
|
||||||
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
elif handwritten:
|
||||||
|
model_name = "microsoft/trocr-base-handwritten"
|
||||||
|
else:
|
||||||
|
model_name = "microsoft/trocr-base-printed"
|
||||||
|
|
||||||
# Choose model based on use case
|
if model_name in _trocr_models:
|
||||||
if handwritten:
|
return _trocr_models[model_name]
|
||||||
model_name = "microsoft/trocr-base-handwritten"
|
|
||||||
else:
|
|
||||||
model_name = "microsoft/trocr-base-printed"
|
|
||||||
|
|
||||||
logger.info(f"Loading TrOCR model: {model_name}")
|
try:
|
||||||
_trocr_processor = TrOCRProcessor.from_pretrained(model_name)
|
import torch
|
||||||
_trocr_model = VisionEncoderDecoderModel.from_pretrained(model_name)
|
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
||||||
|
|
||||||
# Use GPU if available
|
logger.info(f"Loading TrOCR model: {model_name}")
|
||||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
processor = TrOCRProcessor.from_pretrained(model_name)
|
||||||
_trocr_model.to(device)
|
model = VisionEncoderDecoderModel.from_pretrained(model_name)
|
||||||
logger.info(f"TrOCR model loaded on device: {device}")
|
|
||||||
|
|
||||||
except Exception as e:
|
# Use GPU if available
|
||||||
logger.error(f"Failed to load TrOCR model: {e}")
|
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||||
return None, None
|
model.to(device)
|
||||||
|
logger.info(f"TrOCR model loaded on device: {device}")
|
||||||
|
|
||||||
return _trocr_processor, _trocr_model
|
_trocr_models[model_name] = (processor, model)
|
||||||
|
|
||||||
|
# Keep backwards-compat globals pointing at base-printed
|
||||||
|
if model_name == "microsoft/trocr-base-printed":
|
||||||
|
_trocr_processor = processor
|
||||||
|
_trocr_model = model
|
||||||
|
|
||||||
|
return processor, model
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load TrOCR model {model_name}: {e}")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
def preload_trocr_model(handwritten: bool = True) -> bool:
|
def preload_trocr_model(handwritten: bool = True) -> bool:
|
||||||
@@ -209,7 +224,8 @@ def get_model_status() -> Dict[str, Any]:
|
|||||||
async def run_trocr_ocr(
|
async def run_trocr_ocr(
|
||||||
image_data: bytes,
|
image_data: bytes,
|
||||||
handwritten: bool = False,
|
handwritten: bool = False,
|
||||||
split_lines: bool = True
|
split_lines: bool = True,
|
||||||
|
size: str = "base",
|
||||||
) -> Tuple[Optional[str], float]:
|
) -> Tuple[Optional[str], float]:
|
||||||
"""
|
"""
|
||||||
Run TrOCR on an image.
|
Run TrOCR on an image.
|
||||||
@@ -223,11 +239,12 @@ async def run_trocr_ocr(
|
|||||||
image_data: Raw image bytes
|
image_data: Raw image bytes
|
||||||
handwritten: Use handwritten model (slower but better for handwriting)
|
handwritten: Use handwritten model (slower but better for handwriting)
|
||||||
split_lines: Whether to split image into lines first
|
split_lines: Whether to split image into lines first
|
||||||
|
size: "base" or "large" (only for handwritten variant)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (extracted_text, confidence)
|
Tuple of (extracted_text, confidence)
|
||||||
"""
|
"""
|
||||||
processor, model = get_trocr_model(handwritten=handwritten)
|
processor, model = get_trocr_model(handwritten=handwritten, size=size)
|
||||||
|
|
||||||
if processor is None or model is None:
|
if processor is None or model is None:
|
||||||
logger.error("TrOCR model not available")
|
logger.error("TrOCR model not available")
|
||||||
|
|||||||
Reference in New Issue
Block a user