""" Donut OCR Service Document Understanding Transformer (Donut) fuer strukturierte Dokumentenverarbeitung. Besonders geeignet fuer: - Tabellen - Formulare - Strukturierte Dokumente - Rechnungen/Quittungen Model: naver-clova-ix/donut-base (oder fine-tuned Variante) DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. """ import io import logging from typing import Tuple, Optional logger = logging.getLogger(__name__) # Lazy loading for heavy dependencies _donut_processor = None _donut_model = None _donut_available = None def _check_donut_available() -> bool: """Check if Donut dependencies are available.""" global _donut_available if _donut_available is not None: return _donut_available try: import torch from transformers import DonutProcessor, VisionEncoderDecoderModel _donut_available = True except ImportError as e: logger.warning(f"Donut dependencies not available: {e}") _donut_available = False return _donut_available def get_donut_model(): """ Lazy load Donut model and processor. Returns tuple of (processor, model) or (None, None) if unavailable. """ global _donut_processor, _donut_model if not _check_donut_available(): return None, None if _donut_processor is None or _donut_model is None: try: import torch from transformers import DonutProcessor, VisionEncoderDecoderModel model_name = "naver-clova-ix/donut-base" logger.info(f"Loading Donut model: {model_name}") _donut_processor = DonutProcessor.from_pretrained(model_name) _donut_model = VisionEncoderDecoderModel.from_pretrained(model_name) # Use GPU if available device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" _donut_model.to(device) logger.info(f"Donut model loaded on device: {device}") except Exception as e: logger.error(f"Failed to load Donut model: {e}") return None, None return _donut_processor, _donut_model async def run_donut_ocr(image_data: bytes, task_prompt: str = "") -> Tuple[Optional[str], float]: """ Run Donut OCR on an image. Donut is an end-to-end document understanding model that can: - Extract text - Understand document structure - Parse forms and tables Args: image_data: Raw image bytes task_prompt: Donut task prompt (default: CORD for receipts/documents) - "": Receipt/document parsing - "": Document Visual QA - "": Synthetic document generation Returns: Tuple of (extracted_text, confidence) """ processor, model = get_donut_model() if processor is None or model is None: logger.error("Donut model not available") return None, 0.0 try: import torch from PIL import Image # Load image image = Image.open(io.BytesIO(image_data)).convert("RGB") # Prepare input pixel_values = processor(image, return_tensors="pt").pixel_values # Move to same device as model device = next(model.parameters()).device pixel_values = pixel_values.to(device) # Prepare decoder input task_prompt_ids = processor.tokenizer( task_prompt, add_special_tokens=False, return_tensors="pt" ).input_ids.to(device) # Generate with torch.no_grad(): outputs = model.generate( pixel_values, decoder_input_ids=task_prompt_ids, max_length=model.config.decoder.max_position_embeddings, early_stopping=True, pad_token_id=processor.tokenizer.pad_token_id, eos_token_id=processor.tokenizer.eos_token_id, use_cache=True, num_beams=1, bad_words_ids=[[processor.tokenizer.unk_token_id]], return_dict_in_generate=True, ) # Decode output sequence = processor.batch_decode(outputs.sequences)[0] # Remove task prompt and special tokens sequence = sequence.replace(task_prompt, "").replace( processor.tokenizer.eos_token, "").replace( processor.tokenizer.pad_token, "") # Parse the output (Donut outputs JSON-like structure) text = _parse_donut_output(sequence) # Calculate confidence (rough estimate based on output quality) confidence = 0.8 if text and len(text) > 10 else 0.5 logger.info(f"Donut OCR extracted {len(text)} characters") return text, confidence except Exception as e: logger.error(f"Donut OCR failed: {e}") import traceback logger.error(traceback.format_exc()) return None, 0.0 def _parse_donut_output(sequence: str) -> str: """ Parse Donut output into plain text. Donut outputs structured data (JSON-like), we extract readable text. """ import re import json # Clean up the sequence sequence = sequence.strip() # Try to parse as JSON try: # Find JSON-like content json_match = re.search(r'\{.*\}', sequence, re.DOTALL) if json_match: data = json.loads(json_match.group()) # Extract text from various possible fields text_parts = [] _extract_text_recursive(data, text_parts) return "\n".join(text_parts) except json.JSONDecodeError: pass # Fallback: extract text between tags text_parts = [] # Match patterns like value or value pattern = r'<[^>]+>([^<]+)]+>' matches = re.findall(pattern, sequence) if matches: text_parts.extend(matches) return "\n".join(text_parts) # Last resort: return cleaned sequence # Remove XML-like tags clean = re.sub(r'<[^>]+>', ' ', sequence) clean = re.sub(r'\s+', ' ', clean).strip() return clean def _extract_text_recursive(data, text_parts: list, indent: int = 0): """Recursively extract text from nested data structure.""" if isinstance(data, dict): for key, value in data.items(): if isinstance(value, str) and value.strip(): # Skip keys that look like metadata if not key.startswith('_'): text_parts.append(f"{value.strip()}") elif isinstance(value, (dict, list)): _extract_text_recursive(value, text_parts, indent + 1) elif isinstance(data, list): for item in data: _extract_text_recursive(item, text_parts, indent) elif isinstance(data, str) and data.strip(): text_parts.append(data.strip()) # Alternative: Simple text extraction mode async def run_donut_ocr_simple(image_data: bytes) -> Tuple[Optional[str], float]: """ Simplified Donut OCR that just extracts text without structured parsing. Uses a more general task prompt for plain text extraction. """ return await run_donut_ocr(image_data, task_prompt="") # Test function async def test_donut_ocr(image_path: str): """Test Donut OCR on a local image file.""" with open(image_path, "rb") as f: image_data = f.read() text, confidence = await run_donut_ocr(image_data) print(f"\n=== Donut OCR Test ===") print(f"Confidence: {confidence:.2f}") print(f"Text:\n{text}") return text, confidence if __name__ == "__main__": import asyncio import sys if len(sys.argv) > 1: asyncio.run(test_donut_ocr(sys.argv[1])) else: print("Usage: python donut_ocr_service.py ")