"""
|
|
Donut OCR Service
|
|
|
|
Document Understanding Transformer (Donut) fuer strukturierte Dokumentenverarbeitung.
|
|
Besonders geeignet fuer:
|
|
- Tabellen
|
|
- Formulare
|
|
- Strukturierte Dokumente
|
|
- Rechnungen/Quittungen
|
|
|
|
Model: naver-clova-ix/donut-base (oder fine-tuned Variante)
|
|
|
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
"""
|
|
|
|
import io
|
|
import logging
|
|
from typing import Tuple, Optional
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Lazy loading for heavy dependencies
|
|
_donut_processor = None
|
|
_donut_model = None
|
|
_donut_available = None
|
|
|
|
|
|
def _check_donut_available() -> bool:
|
|
"""Check if Donut dependencies are available."""
|
|
global _donut_available
|
|
if _donut_available is not None:
|
|
return _donut_available
|
|
|
|
try:
|
|
import torch
|
|
from transformers import DonutProcessor, VisionEncoderDecoderModel
|
|
_donut_available = True
|
|
except ImportError as e:
|
|
logger.warning(f"Donut dependencies not available: {e}")
|
|
_donut_available = False
|
|
|
|
return _donut_available
|
|
|
|
|
|
def get_donut_model():
    """Lazily load and cache the Donut processor/model pair.

    Returns:
        Tuple of (processor, model), or (None, None) if the dependencies are
        missing or loading fails. Successful loads are cached in module
        globals so the expensive ``from_pretrained`` calls run only once.
    """
    global _donut_processor, _donut_model

    if not _check_donut_available():
        return None, None

    if _donut_processor is None or _donut_model is None:
        try:
            import torch
            from transformers import DonutProcessor, VisionEncoderDecoderModel

            model_name = "naver-clova-ix/donut-base"

            # Lazy %-args instead of f-strings for log records.
            logger.info("Loading Donut model: %s", model_name)
            _donut_processor = DonutProcessor.from_pretrained(model_name)
            _donut_model = VisionEncoderDecoderModel.from_pretrained(model_name)

            # Device preference: CUDA, then Apple-Silicon MPS, then CPU.
            device = (
                "cuda" if torch.cuda.is_available()
                else "mps" if torch.backends.mps.is_available()
                else "cpu"
            )
            _donut_model.to(device)
            logger.info("Donut model loaded on device: %s", device)

        except Exception as e:
            # Deliberate best-effort: callers handle (None, None) gracefully.
            logger.error("Failed to load Donut model: %s", e)
            return None, None

    return _donut_processor, _donut_model
|
|
|
|
|
|
async def run_donut_ocr(image_data: bytes, task_prompt: str = "<s_cord-v2>") -> Tuple[Optional[str], float]:
    """
    Run Donut OCR on an image.

    Donut is an end-to-end document understanding model that can:
    - Extract text
    - Understand document structure
    - Parse forms and tables

    Args:
        image_data: Raw image bytes (any format PIL can open).
        task_prompt: Donut task prompt (default: CORD for receipts/documents)
            - "<s_cord-v2>": Receipt/document parsing
            - "<s_docvqa>": Document Visual QA
            - "<s_synthdog>": Synthetic document generation

    Returns:
        Tuple of (extracted_text, confidence). On any failure the tuple is
        (None, 0.0); the confidence is a rough heuristic, not a model score.
    """
    processor, model = get_donut_model()

    if processor is None or model is None:
        logger.error("Donut model not available")
        return None, 0.0

    try:
        import torch
        from PIL import Image

        # Load image; RGB conversion normalizes grayscale/RGBA inputs.
        image = Image.open(io.BytesIO(image_data)).convert("RGB")

        # Prepare encoder input (pixel tensor).
        pixel_values = processor(image, return_tensors="pt").pixel_values

        # Move input to the same device the model was loaded on.
        device = next(model.parameters()).device
        pixel_values = pixel_values.to(device)

        # Prepare decoder input: the task prompt tokens seed the generation.
        task_prompt_ids = processor.tokenizer(
            task_prompt,
            add_special_tokens=False,
            return_tensors="pt"
        ).input_ids.to(device)

        # Greedy generation (num_beams=1); unk tokens are banned outright.
        with torch.no_grad():
            outputs = model.generate(
                pixel_values,
                decoder_input_ids=task_prompt_ids,
                max_length=model.config.decoder.max_position_embeddings,
                early_stopping=True,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                use_cache=True,
                num_beams=1,
                bad_words_ids=[[processor.tokenizer.unk_token_id]],
                return_dict_in_generate=True,
            )

        # Decode the generated token ids back to a string.
        sequence = processor.batch_decode(outputs.sequences)[0]

        # Remove the task prompt and special tokens from the raw sequence.
        sequence = sequence.replace(task_prompt, "").replace(
            processor.tokenizer.eos_token, "").replace(
            processor.tokenizer.pad_token, "")

        # Parse the output (Donut emits a JSON/tag-like structure).
        text = _parse_donut_output(sequence)

        # Heuristic confidence based on output length only — NOT a model
        # probability; treat it as a coarse quality flag.
        confidence = 0.8 if text and len(text) > 10 else 0.5

        logger.info(f"Donut OCR extracted {len(text)} characters")
        return text, confidence

    except Exception as e:
        # Broad catch is intentional at this service boundary: any failure
        # (decode, inference, parsing) degrades to (None, 0.0) with a trace.
        logger.error(f"Donut OCR failed: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return None, 0.0
|
|
|
|
|
|
def _parse_donut_output(sequence: str) -> str:
|
|
"""
|
|
Parse Donut output into plain text.
|
|
|
|
Donut outputs structured data (JSON-like), we extract readable text.
|
|
"""
|
|
import re
|
|
import json
|
|
|
|
# Clean up the sequence
|
|
sequence = sequence.strip()
|
|
|
|
# Try to parse as JSON
|
|
try:
|
|
# Find JSON-like content
|
|
json_match = re.search(r'\{.*\}', sequence, re.DOTALL)
|
|
if json_match:
|
|
data = json.loads(json_match.group())
|
|
# Extract text from various possible fields
|
|
text_parts = []
|
|
_extract_text_recursive(data, text_parts)
|
|
return "\n".join(text_parts)
|
|
except json.JSONDecodeError:
|
|
pass
|
|
|
|
# Fallback: extract text between tags
|
|
text_parts = []
|
|
# Match patterns like <text>value</text> or <s_field>value</s_field>
|
|
pattern = r'<[^>]+>([^<]+)</[^>]+>'
|
|
matches = re.findall(pattern, sequence)
|
|
if matches:
|
|
text_parts.extend(matches)
|
|
return "\n".join(text_parts)
|
|
|
|
# Last resort: return cleaned sequence
|
|
# Remove XML-like tags
|
|
clean = re.sub(r'<[^>]+>', ' ', sequence)
|
|
clean = re.sub(r'\s+', ' ', clean).strip()
|
|
return clean
|
|
|
|
|
|
def _extract_text_recursive(data, text_parts: list, indent: int = 0):
|
|
"""Recursively extract text from nested data structure."""
|
|
if isinstance(data, dict):
|
|
for key, value in data.items():
|
|
if isinstance(value, str) and value.strip():
|
|
# Skip keys that look like metadata
|
|
if not key.startswith('_'):
|
|
text_parts.append(f"{value.strip()}")
|
|
elif isinstance(value, (dict, list)):
|
|
_extract_text_recursive(value, text_parts, indent + 1)
|
|
elif isinstance(data, list):
|
|
for item in data:
|
|
_extract_text_recursive(item, text_parts, indent)
|
|
elif isinstance(data, str) and data.strip():
|
|
text_parts.append(data.strip())
|
|
|
|
|
|
# Alternative entry point: plain text extraction without structured parsing.
async def run_donut_ocr_simple(image_data: bytes) -> Tuple[Optional[str], float]:
    """Run Donut OCR with the generic "<s>" prompt for plain text extraction.

    Delegates to :func:`run_donut_ocr`; the generic prompt skips the
    structured (CORD) parsing behavior.
    """
    result = await run_donut_ocr(image_data, task_prompt="<s>")
    return result
|
|
|
|
|
|
# Manual smoke test — not a unit test; run via the __main__ entry point.
async def test_donut_ocr(image_path: str):
    """Run Donut OCR on a local image file and print the result."""
    with open(image_path, "rb") as f:
        payload = f.read()

    extracted, score = await run_donut_ocr(payload)

    print("\n=== Donut OCR Test ===")
    print(f"Confidence: {score:.2f}")
    print(f"Text:\n{extracted}")

    return extracted, score
|
|
|
|
|
|
if __name__ == "__main__":
    import asyncio
    import sys

    # Require an image path argument; otherwise print usage and exit.
    if len(sys.argv) <= 1:
        print("Usage: python donut_ocr_service.py <image_path>")
    else:
        asyncio.run(test_donut_ocr(sys.argv[1]))
|