Files
breakpilot-lehrer/klausur-service/backend/services/donut_ocr_service.py
Benjamin Boenisch 5a31f52310 Initial commit: breakpilot-lehrer - Lehrer KI Platform
Services: Admin-Lehrer, Backend-Lehrer, Studio v2, Website,
Klausur-Service, School-Service, Voice-Service, Geo-Service,
BreakPilot Drive, Agent-Core

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 23:47:26 +01:00

255 lines
7.6 KiB
Python

"""
Donut OCR Service
Document Understanding Transformer (Donut) fuer strukturierte Dokumentenverarbeitung.
Besonders geeignet fuer:
- Tabellen
- Formulare
- Strukturierte Dokumente
- Rechnungen/Quittungen
Model: naver-clova-ix/donut-base (oder fine-tuned Variante)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import io
import logging
from typing import Tuple, Optional
logger = logging.getLogger(__name__)
# Lazy loading for heavy dependencies
_donut_processor = None
_donut_model = None
_donut_available = None
def _check_donut_available() -> bool:
"""Check if Donut dependencies are available."""
global _donut_available
if _donut_available is not None:
return _donut_available
try:
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
_donut_available = True
except ImportError as e:
logger.warning(f"Donut dependencies not available: {e}")
_donut_available = False
return _donut_available
def get_donut_model():
    """Return the cached (processor, model) pair, loading it on first use.

    Returns:
        Tuple of (processor, model), or (None, None) when the dependencies
        are missing or the model fails to load.
    """
    global _donut_processor, _donut_model

    if not _check_donut_available():
        return None, None

    # Already loaded on a previous call -- reuse the singletons.
    if _donut_processor is not None and _donut_model is not None:
        return _donut_processor, _donut_model

    try:
        import torch
        from transformers import DonutProcessor, VisionEncoderDecoderModel

        model_name = "naver-clova-ix/donut-base"
        logger.info(f"Loading Donut model: {model_name}")
        _donut_processor = DonutProcessor.from_pretrained(model_name)
        _donut_model = VisionEncoderDecoderModel.from_pretrained(model_name)

        # Device preference: CUDA, then Apple MPS, then CPU.
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"
        _donut_model.to(device)
        logger.info(f"Donut model loaded on device: {device}")
    except Exception as e:
        logger.error(f"Failed to load Donut model: {e}")
        return None, None

    return _donut_processor, _donut_model
async def run_donut_ocr(image_data: bytes, task_prompt: str = "<s_cord-v2>") -> Tuple[Optional[str], float]:
    """Run Donut OCR on an image.

    Donut is an end-to-end document understanding model that can extract
    text, understand document structure, and parse forms and tables.

    Args:
        image_data: Raw image bytes.
        task_prompt: Donut task prompt (default: CORD for receipts/documents).
            - "<s_cord-v2>": Receipt/document parsing
            - "<s_docvqa>": Document Visual QA
            - "<s_synthdog>": Synthetic document generation

    Returns:
        Tuple of (extracted_text, confidence); (None, 0.0) on any failure.
    """
    processor, model = get_donut_model()
    if processor is None or model is None:
        logger.error("Donut model not available")
        return None, 0.0

    try:
        import torch
        from PIL import Image

        # Decode the raw bytes into an RGB image.
        image = Image.open(io.BytesIO(image_data)).convert("RGB")

        # Preprocess and move the tensor onto the model's device.
        device = next(model.parameters()).device
        pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

        # The task prompt seeds the decoder.
        decoder_input_ids = processor.tokenizer(
            task_prompt,
            add_special_tokens=False,
            return_tensors="pt",
        ).input_ids.to(device)

        # Greedy generation (num_beams=1) without gradient tracking.
        with torch.no_grad():
            outputs = model.generate(
                pixel_values,
                decoder_input_ids=decoder_input_ids,
                max_length=model.config.decoder.max_position_embeddings,
                early_stopping=True,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                use_cache=True,
                num_beams=1,
                bad_words_ids=[[processor.tokenizer.unk_token_id]],
                return_dict_in_generate=True,
            )

        # Decode, then strip the prompt plus the EOS/PAD markers.
        raw_sequence = processor.batch_decode(outputs.sequences)[0]
        for marker in (task_prompt, processor.tokenizer.eos_token, processor.tokenizer.pad_token):
            raw_sequence = raw_sequence.replace(marker, "")

        # Donut emits a JSON-like structure; flatten it to plain text.
        text = _parse_donut_output(raw_sequence)

        # Rough confidence heuristic based on how much text came back.
        confidence = 0.8 if text and len(text) > 10 else 0.5
        logger.info(f"Donut OCR extracted {len(text)} characters")
        return text, confidence
    except Exception as e:
        logger.error(f"Donut OCR failed: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return None, 0.0
def _parse_donut_output(sequence: str) -> str:
"""
Parse Donut output into plain text.
Donut outputs structured data (JSON-like), we extract readable text.
"""
import re
import json
# Clean up the sequence
sequence = sequence.strip()
# Try to parse as JSON
try:
# Find JSON-like content
json_match = re.search(r'\{.*\}', sequence, re.DOTALL)
if json_match:
data = json.loads(json_match.group())
# Extract text from various possible fields
text_parts = []
_extract_text_recursive(data, text_parts)
return "\n".join(text_parts)
except json.JSONDecodeError:
pass
# Fallback: extract text between tags
text_parts = []
# Match patterns like <text>value</text> or <s_field>value</s_field>
pattern = r'<[^>]+>([^<]+)</[^>]+>'
matches = re.findall(pattern, sequence)
if matches:
text_parts.extend(matches)
return "\n".join(text_parts)
# Last resort: return cleaned sequence
# Remove XML-like tags
clean = re.sub(r'<[^>]+>', ' ', sequence)
clean = re.sub(r'\s+', ' ', clean).strip()
return clean
def _extract_text_recursive(data, text_parts: list, indent: int = 0):
"""Recursively extract text from nested data structure."""
if isinstance(data, dict):
for key, value in data.items():
if isinstance(value, str) and value.strip():
# Skip keys that look like metadata
if not key.startswith('_'):
text_parts.append(f"{value.strip()}")
elif isinstance(value, (dict, list)):
_extract_text_recursive(value, text_parts, indent + 1)
elif isinstance(data, list):
for item in data:
_extract_text_recursive(item, text_parts, indent)
elif isinstance(data, str) and data.strip():
text_parts.append(data.strip())
# Alternative: simple text extraction mode
async def run_donut_ocr_simple(image_data: bytes) -> Tuple[Optional[str], float]:
    """Simplified Donut OCR without structured parsing.

    Delegates to run_donut_ocr() with the more general "<s>" task prompt
    for plain text extraction.
    """
    plain_prompt = "<s>"
    return await run_donut_ocr(image_data, task_prompt=plain_prompt)
# Manual test helper
async def test_donut_ocr(image_path: str):
    """Read *image_path* from disk, run Donut OCR on it, and print the result."""
    with open(image_path, "rb") as fh:
        raw_bytes = fh.read()

    text, confidence = await run_donut_ocr(raw_bytes)

    print(f"\n=== Donut OCR Test ===")
    print(f"Confidence: {confidence:.2f}")
    print(f"Text:\n{text}")
    return text, confidence
if __name__ == "__main__":
import asyncio
import sys
if len(sys.argv) > 1:
asyncio.run(test_donut_ocr(sys.argv[1]))
else:
print("Usage: python donut_ocr_service.py <image_path>")