"""
TrOCR Batch Processing & Streaming

Batch OCR and SSE streaming for multiple images.

PRIVACY: All processing is performed locally.
"""

import asyncio
import logging
import time
from typing import Optional, List, Dict, Any, Callable

from .trocr_models import OCRResult, BatchOCRResult
from .trocr_ocr import run_trocr_ocr_enhanced

logger = logging.getLogger(__name__)


async def run_trocr_batch(
    images: List[bytes],
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True,
    progress_callback: Optional[Callable[[int, int], Any]] = None,
) -> BatchOCRResult:
    """
    Process multiple images in batch, one after another.

    A failure on one image does not abort the batch: the failure is logged,
    counted in ``error_count`` and represented in ``results`` by a placeholder
    ``OCRResult`` whose ``model`` is ``"error"``. ``results`` therefore always
    contains exactly one entry per input image, in input order.

    Args:
        images: List of image data bytes
        handwritten: Use handwritten model
        split_lines: Whether to split images into lines
        use_cache: Whether to use caching
        progress_callback: Optional callback(current, total) for progress
            updates. It is invoked once per image, whether OCR succeeded or
            failed. Exceptions raised by the callback are logged and ignored.

    Returns:
        BatchOCRResult with all results
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by wall-clock adjustments while the batch is running.
    start_time = time.perf_counter()
    total = len(images)

    results = []
    cached_count = 0
    error_count = 0

    for idx, image_data in enumerate(images):
        try:
            result = await run_trocr_ocr_enhanced(
                image_data,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache,
            )
            if result.from_cache:
                cached_count += 1
        except Exception as e:
            logger.error("Batch OCR error for image %s: %s", idx, e)
            error_count += 1
            result = OCRResult(
                text=f"Error: {str(e)}",
                confidence=0.0,
                processing_time_ms=0,
                model="error",
                has_lora_adapter=False,
            )

        # Appended exactly once per image so results stay aligned with images.
        results.append(result)

        # Report progress outside the OCR try-block: a faulty callback must
        # not be mistaken for an OCR failure or add a second result entry.
        if progress_callback:
            try:
                progress_callback(idx + 1, total)
            except Exception as e:
                logger.warning(
                    "Progress callback failed for image %s: %s", idx, e
                )

    total_time_ms = int((time.perf_counter() - start_time) * 1000)

    return BatchOCRResult(
        results=results,
        total_time_ms=total_time_ms,
        processed_count=total,
        cached_count=cached_count,
        error_count=error_count,
    )


# Generator for SSE streaming during batch processing
async def run_trocr_batch_stream(
    images: List[bytes],
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True,
):
    """
    Process images and yield progress updates for SSE streaming.

    Args:
        images: List of image data bytes
        handwritten: Use handwritten model
        split_lines: Whether to split images into lines
        use_cache: Whether to use caching

    Yields:
        One dict per image, followed by a final summary dict:

        * ``{"type": "progress", "current", "total", "progress_percent",
          "elapsed_ms", "estimated_remaining_ms", "result": {...}}`` when an
          image was processed successfully,
        * ``{"type": "error", "current", "total", "error"}`` when processing
          an image raised,
        * ``{"type": "complete", "total_time_ms", "processed_count",
          "error_count"}`` once after all images.
    """
    # Monotonic clock: elapsed/ETA values must never go backwards.
    start_time = time.perf_counter()
    total = len(images)
    error_count = 0

    for idx, image_data in enumerate(images):
        current = idx + 1
        try:
            result = await run_trocr_ocr_enhanced(
                image_data,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache,
            )

            elapsed_ms = int((time.perf_counter() - start_time) * 1000)
            # Simple linear estimate based on the average time per image so far.
            avg_time_per_image = elapsed_ms / current
            estimated_remaining = int(avg_time_per_image * (total - current))

            event = {
                "type": "progress",
                "current": current,
                "total": total,
                "progress_percent": (current / total) * 100,
                "elapsed_ms": elapsed_ms,
                "estimated_remaining_ms": estimated_remaining,
                "result": {
                    "text": result.text,
                    "confidence": result.confidence,
                    "processing_time_ms": result.processing_time_ms,
                    "from_cache": result.from_cache,
                },
            }
        except Exception as e:
            logger.error("Stream OCR error for image %s: %s", idx, e)
            error_count += 1
            event = {
                "type": "error",
                "current": current,
                "total": total,
                "error": str(e),
            }

        # Yield outside the try-block so an exception thrown into the
        # generator by the consumer is not reported as an OCR error.
        yield event

    total_time_ms = int((time.perf_counter() - start_time) * 1000)

    yield {
        "type": "complete",
        "total_time_ms": total_time_ms,
        "processed_count": total,
        "error_count": error_count,
    }


# Test function
async def test_trocr_ocr(image_path: str, handwritten: bool = False):
    """
    Test TrOCR on a local image file.

    Reads the file at ``image_path``, runs ``run_trocr_ocr`` on its bytes and
    prints the recognized text together with the confidence.

    Args:
        image_path: Path of the image file to read
        handwritten: Use handwritten model instead of the printed one

    Returns:
        The ``(text, confidence)`` pair returned by ``run_trocr_ocr``
    """
    # Imported here rather than at module level, as in the original code.
    from .trocr_ocr import run_trocr_ocr

    with open(image_path, "rb") as f:
        image_data = f.read()

    text, confidence = await run_trocr_ocr(image_data, handwritten=handwritten)

    print("\n=== TrOCR Test ===")
    print(f"Mode: {'Handwritten' if handwritten else 'Printed'}")
    print(f"Confidence: {confidence:.2f}")
    print(f"Text:\n{text}")

    return text, confidence