"""
TrOCR Service for Handwriting Recognition.

Uses Microsoft's TrOCR model for extracting handwritten text from exam images.
Supports fine-tuning with teacher corrections via LoRA adapters.

PRIVACY BY DESIGN:
- All processing happens locally
- No data sent to external services
- Fine-tuning data stays on-premise
"""

import json
import logging
import math
import os
from dataclasses import asdict, dataclass
from io import BytesIO
from pathlib import Path
from typing import Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)

# Model paths (each overridable through an environment variable)
MODEL_CACHE_DIR = Path(os.environ.get("TROCR_CACHE_DIR", "/app/models/trocr"))
LORA_ADAPTERS_DIR = Path(os.environ.get("TROCR_LORA_DIR", "/app/models/trocr/lora"))
TRAINING_DATA_DIR = Path(os.environ.get("TROCR_TRAINING_DIR", "/app/data/trocr_training"))

# Confidence reported when the generation output carries no usable scores.
_DEFAULT_CONFIDENCE = 0.75
# Maximum decoder length, shared by inference (generate) and training labels.
_MAX_TARGET_LENGTH = 128
# Cropped line images shorter than this many pixels are not sent to the model.
_MIN_LINE_CROP_HEIGHT = 20


@dataclass
class OCRResult:
    """Result from TrOCR extraction."""

    text: str
    confidence: float
    # Line mode: [{"x": 0, "y": 0, "w": 100, "h": 20, "text": "...", "confidence": 0.9}]
    bounding_boxes: List[Dict]
    processing_time_ms: int


@dataclass
class TrainingExample:
    """A single training example for fine-tuning."""

    image_path: str
    ground_truth: str
    teacher_id: str
    created_at: str


class TrOCRService:
    """
    Handwriting recognition service using TrOCR.

    Features:
    - Line-by-line handwriting extraction
    - Confidence scoring
    - LoRA fine-tuning support
    - Batch processing
    """

    # Available models (from smallest to largest)
    MODELS = {
        "trocr-small": "microsoft/trocr-small-handwritten",
        "trocr-base": "microsoft/trocr-base-handwritten",  # Recommended
        "trocr-large": "microsoft/trocr-large-handwritten",
    }

    def __init__(self, model_name: str = "trocr-base", device: str = "auto"):
        """
        Initialize TrOCR service.

        The model itself is loaded lazily on first use; only the working
        directories are created here.

        Args:
            model_name: One of "trocr-small", "trocr-base", "trocr-large".
                Unknown names fall back to the "trocr-base" checkpoint.
            device: "cpu", "cuda", "mps" (Apple Silicon), any other torch
                device string (e.g. "cuda:1"), or "auto"
        """
        if model_name not in self.MODELS:
            logger.warning(
                "Unknown TrOCR model %r, falling back to %s",
                model_name,
                self.MODELS["trocr-base"],
            )
        self.model_name = model_name
        self.model_id = self.MODELS.get(model_name, self.MODELS["trocr-base"])
        self.device = self._get_device(device)
        self._processor = None
        self._model = None
        self._lora_adapter = None

        # Create directories
        for directory in (MODEL_CACHE_DIR, LORA_ADAPTERS_DIR, TRAINING_DATA_DIR):
            directory.mkdir(parents=True, exist_ok=True)

        logger.info(f"TrOCR Service initialized: model={model_name}, device={self.device}")

    def _get_device(self, device: str) -> str:
        """Resolve ``"auto"`` to cuda, then mps, then cpu; return anything else unchanged."""
        if device != "auto":
            return device
        try:
            import torch
        except ImportError:
            return "cpu"

        if torch.cuda.is_available():
            return "cuda"
        # torch.backends.mps is missing on older torch builds.
        mps = getattr(torch.backends, "mps", None)
        if mps is not None and mps.is_available():
            return "mps"
        return "cpu"

    def _adapter_path(self) -> Path:
        """Directory where this model's LoRA adapter is stored."""
        return LORA_ADAPTERS_DIR / f"{self.model_name}_adapter"

    def _to_device(self, obj):
        """Move a torch module or tensor to ``self.device`` (no-op on CPU)."""
        if self.device == "cpu":
            return obj
        return obj.to(self.device)

    def _load_model(self):
        """
        Lazy-load the TrOCR processor and model (plus LoRA adapter if present).

        Raises:
            ImportError: if transformers/torch are not installed.
            Exception: any error raised while loading the checkpoint is logged
                and re-raised.
        """
        if self._model is not None:
            return

        try:
            from transformers import TrOCRProcessor, VisionEncoderDecoderModel
            import torch  # noqa: F401  (surface a missing torch as ImportError here)

            logger.info(f"Loading TrOCR model: {self.model_id}")

            self._processor = TrOCRProcessor.from_pretrained(
                self.model_id, cache_dir=str(MODEL_CACHE_DIR)
            )
            model = VisionEncoderDecoderModel.from_pretrained(
                self.model_id, cache_dir=str(MODEL_CACHE_DIR)
            )
            self._model = self._to_device(model)

            # Load LoRA adapter if exists
            adapter_path = self._adapter_path()
            if adapter_path.exists():
                self._load_lora_adapter(adapter_path)

            logger.info(f"TrOCR model loaded successfully on {self.device}")

        except ImportError as e:
            logger.error(f"Missing dependencies: {e}")
            logger.error("Install with: pip install transformers torch pillow")
            raise
        except Exception as e:
            logger.exception(f"Failed to load TrOCR model: {e}")
            raise

    def _load_lora_adapter(self, adapter_path: Path):
        """
        Wrap the loaded model with the LoRA adapter at ``adapter_path``.

        Best effort: a missing ``peft`` package or a broken adapter only logs a
        warning and leaves the base model in place.
        """
        try:
            from peft import PeftModel

            logger.info(f"Loading LoRA adapter from {adapter_path}")
            self._model = PeftModel.from_pretrained(self._model, str(adapter_path))
            self._lora_adapter = str(adapter_path)
            logger.info("LoRA adapter loaded successfully")
        except ImportError:
            logger.warning("peft not installed, skipping LoRA adapter")
        except Exception as e:
            logger.warning(f"Failed to load LoRA adapter: {e}")

    async def extract_text(
        self,
        image_data: bytes,
        detect_lines: bool = True
    ) -> OCRResult:
        """
        Extract handwritten text from an image.

        Args:
            image_data: Raw image bytes (PNG, JPG, etc.)
            detect_lines: If True, detect and process individual lines

        Returns:
            OCRResult with extracted text and confidence. In line mode the
            confidence is the mean of the per-line confidences (0.0 when no
            line produced text). Any failure after the model is loaded yields
            an empty result with confidence 0.0.

        Raises:
            Exception: errors while loading the model are propagated.
        """
        import time

        start_time = time.perf_counter()
        self._load_model()

        try:
            from PIL import Image

            image = Image.open(BytesIO(image_data)).convert("RGB")

            if detect_lines:
                lines, bboxes, confidence = await self._extract_lines(image)
                text = "\n".join(lines)
            else:
                text, confidence = await self._extract_single(image)
                bboxes = []

            return OCRResult(
                text=text,
                confidence=confidence,
                bounding_boxes=bboxes,
                processing_time_ms=int((time.perf_counter() - start_time) * 1000),
            )

        except Exception as e:
            logger.exception(f"OCR extraction failed: {e}")
            return OCRResult(
                text="",
                confidence=0.0,
                bounding_boxes=[],
                processing_time_ms=int((time.perf_counter() - start_time) * 1000),
            )

    async def _extract_single(self, image) -> Tuple[str, float]:
        """Run the model on one image (no line detection); return (text, confidence)."""
        import torch

        pixel_values = self._processor(images=image, return_tensors="pt").pixel_values
        pixel_values = self._to_device(pixel_values)

        with torch.no_grad():
            # Keyword argument: PEFT-wrapped models only accept generate(**kwargs).
            generated = self._model.generate(
                pixel_values=pixel_values,
                max_length=_MAX_TARGET_LENGTH,
                num_beams=4,
                return_dict_in_generate=True,
                output_scores=True,
            )

        text = self._processor.batch_decode(
            generated.sequences, skip_special_tokens=True
        )[0]
        confidence = self._estimate_confidence(generated)
        return text.strip(), confidence

    async def _extract_lines(self, image) -> Tuple[List[str], List[Dict], float]:
        """
        Detect text lines, OCR each one and aggregate a confidence.

        Returns:
            (lines, bboxes, confidence). If no lines are detected, the whole
            image is OCR'd: its text is the single line, bboxes is empty and
            its confidence is returned. Otherwise confidence is the mean of the
            per-line confidences, or 0.0 when no line yielded text.
        """
        import numpy as np

        gray = np.array(image.convert("L"))
        line_spans = self._detect_line_positions(gray)

        if not line_spans:
            # Fallback: process whole image
            text, confidence = await self._extract_single(image)
            return [text], [], confidence

        texts: List[str] = []
        bboxes: List[Dict] = []
        width = image.width
        for y_start, y_end in line_spans:
            line_img = image.crop((0, y_start, width, y_end))
            if line_img.height < _MIN_LINE_CROP_HEIGHT:
                continue

            text, conf = await self._extract_single(line_img)
            if not text.strip():
                continue

            texts.append(text)
            bboxes.append({
                "x": 0,
                "y": y_start,
                "w": width,
                "h": y_end - y_start,
                "text": text,
                "confidence": conf,
            })

        confidence = (
            sum(box["confidence"] for box in bboxes) / len(bboxes) if bboxes else 0.0
        )
        return texts, bboxes, confidence

    async def _detect_and_extract_lines(self, image) -> Tuple[List[str], List[Dict]]:
        """Detect text lines and extract each separately; return (lines, bboxes)."""
        lines, bboxes, _ = await self._extract_lines(image)
        return lines, bboxes

    def _detect_line_positions(
        self,
        img_array,
        padding: int = 5,
        threshold_ratio: float = 0.1,
        min_height: int = 15,
    ) -> List[Tuple[int, int]]:
        """
        Detect horizontal text line positions using a projection profile.

        Args:
            img_array: 2-D grayscale array (rows x columns), dark ink on a
                light background.
            padding: Rows added above and below each detected run.
            threshold_ratio: A row counts as text when its ink sum exceeds
                this fraction of the maximum row ink sum.
            min_height: Runs whose padded span is not larger than this are
                dropped as noise.

        Returns:
            List of (y_start, y_end) row spans, top to bottom.
        """
        import numpy as np

        if img_array.size == 0:
            return []

        # Horizontal projection: inverted pixels, so ink contributes most.
        projection = np.sum(255 - img_array, axis=1)
        threshold = np.max(projection) * threshold_ratio
        text_rows = projection > threshold
        last_row = len(text_rows) - 1

        lines: List[Tuple[int, int]] = []
        in_line = False
        line_start = 0
        for i, is_text in enumerate(text_rows):
            if is_text and not in_line:
                in_line = True
                line_start = max(0, i - padding)
            elif not is_text and in_line:
                in_line = False
                line_end = min(last_row, i + padding)
                if line_end - line_start > min_height:
                    lines.append((line_start, line_end))

        # A line running to the bottom edge gets the same height filter.
        if in_line and last_row - line_start > min_height:
            lines.append((line_start, last_row))

        return lines

    def _estimate_confidence(self, generated_output) -> float:
        """
        Estimate a [0, 1] confidence from a ``generate`` output.

        With beam search, ``sequences_scores`` holds the length-normalised
        log-probability of the returned beam; its exponent is used. Otherwise
        the mean max softmax probability over the per-step ``scores`` is
        used. Falls back to a fixed default when neither is available.
        """
        try:
            seq_scores = getattr(generated_output, "sequences_scores", None)
            if seq_scores is not None and len(seq_scores) > 0:
                return min(1.0, max(0.0, math.exp(float(seq_scores[0]))))

            scores = getattr(generated_output, "scores", None)
            if scores:
                import torch

                probs = [torch.softmax(step, dim=-1).max().item() for step in scores]
                return sum(probs) / len(probs)

            return _DEFAULT_CONFIDENCE
        except Exception:
            return _DEFAULT_CONFIDENCE

    async def batch_extract(
        self,
        images: List[bytes],
        detect_lines: bool = True
    ) -> List[OCRResult]:
        """
        Extract text from multiple images, one after another.

        Args:
            images: List of image bytes
            detect_lines: If True, detect lines in each image

        Returns:
            List of OCRResult, in input order
        """
        return [await self.extract_text(data, detect_lines) for data in images]

    # ==========================================
    # FINE-TUNING SUPPORT
    # ==========================================

    def add_training_example(
        self,
        image_data: bytes,
        ground_truth: str,
        teacher_id: str
    ) -> str:
        """
        Add a training example for fine-tuning.

        Stores the raw image bytes as ``<id>.png`` (bytes are written unchanged,
        whatever their real format) and a ``<id>.json`` metadata file in
        TRAINING_DATA_DIR.

        Args:
            image_data: Image bytes
            ground_truth: Correct text (teacher-provided)
            teacher_id: ID of the teacher providing correction

        Returns:
            Example ID
        """
        import uuid
        from datetime import datetime, timezone

        example_id = str(uuid.uuid4())

        image_path = TRAINING_DATA_DIR / f"{example_id}.png"
        image_path.write_bytes(image_data)

        example = TrainingExample(
            image_path=str(image_path),
            ground_truth=ground_truth,
            teacher_id=teacher_id,
            # Naive UTC ISO timestamp; same format as the deprecated utcnow().
            created_at=datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
        )

        meta_path = TRAINING_DATA_DIR / f"{example_id}.json"
        meta_path.write_text(json.dumps(asdict(example), indent=2), encoding="utf-8")

        logger.info(f"Training example added: {example_id}")
        return example_id

    def get_training_examples(self, teacher_id: Optional[str] = None) -> List[TrainingExample]:
        """
        Get all training examples, optionally filtered by teacher.

        Metadata files that cannot be read or parsed are skipped with a
        warning. Results are ordered by metadata filename.
        """
        examples: List[TrainingExample] = []
        for meta_file in sorted(TRAINING_DATA_DIR.glob("*.json")):
            try:
                data = json.loads(meta_file.read_text(encoding="utf-8"))
                example = TrainingExample(**data)
            except (OSError, ValueError, TypeError) as e:
                logger.warning(f"Skipping unreadable training example {meta_file}: {e}")
                continue

            if teacher_id is None or example.teacher_id == teacher_id:
                examples.append(example)

        return examples

    async def fine_tune(
        self,
        teacher_id: Optional[str] = None,
        epochs: int = 3,
        learning_rate: float = 5e-5
    ) -> Dict:
        """
        Fine-tune the model with collected training examples.

        Uses LoRA for efficient fine-tuning. Training runs synchronously, so
        it blocks the event loop for its whole duration.

        Args:
            teacher_id: If provided, only use examples from this teacher
            epochs: Number of training epochs
            learning_rate: Learning rate for fine-tuning

        Returns:
            Training statistics (``status`` is "success" or "error")
        """
        examples = self.get_training_examples(teacher_id)

        if len(examples) < 10:
            return {
                "status": "error",
                "message": f"Need at least 10 examples, have {len(examples)}"
            }

        try:
            from peft import LoraConfig, get_peft_model, TaskType
            from transformers import Trainer, TrainingArguments
            from PIL import Image
            import torch

            self._load_model()

            logger.info(f"Starting fine-tuning with {len(examples)} examples")

            tokenizer = self._processor.tokenizer
            # Label shifting in encoder-decoder models needs these ids set.
            config = self._model.config
            if getattr(config, "pad_token_id", None) is None:
                config.pad_token_id = tokenizer.pad_token_id
            if getattr(config, "decoder_start_token_id", None) is None:
                config.decoder_start_token_id = tokenizer.cls_token_id

            # Configure LoRA
            lora_config = LoraConfig(
                task_type=TaskType.SEQ_2_SEQ_LM,
                r=16,  # LoRA rank
                lora_alpha=32,
                lora_dropout=0.1,
                target_modules=["q_proj", "v_proj"]  # Attention layers
            )

            # NOTE(review): if an adapter is already loaded, self._model is a
            # PeftModel and this stacks a new LoRA on it — confirm intended.
            model = get_peft_model(self._model, lora_config)

            class OCRDataset(torch.utils.data.Dataset):
                """(pixel_values, labels) pairs built from stored examples."""

                def __init__(self, examples, processor):
                    self.examples = examples
                    self.processor = processor

                def __len__(self):
                    return len(self.examples)

                def __getitem__(self, idx):
                    ex = self.examples[idx]
                    # Context manager closes the file; convert() returns a loaded copy.
                    with Image.open(ex.image_path) as img:
                        image = img.convert("RGB")
                    pixel_values = self.processor(
                        images=image, return_tensors="pt"
                    ).pixel_values.squeeze()

                    tok = self.processor.tokenizer
                    labels = tok(
                        ex.ground_truth,
                        return_tensors="pt",
                        padding="max_length",
                        truncation=True,  # keep every label exactly max_length long
                        max_length=_MAX_TARGET_LENGTH,
                    ).input_ids.squeeze()
                    # Padding must not contribute to the loss (-100 = ignore index).
                    if tok.pad_token_id is not None:
                        labels[labels == tok.pad_token_id] = -100

                    return {
                        "pixel_values": pixel_values,
                        "labels": labels
                    }

            dataset = OCRDataset(examples, self._processor)

            output_dir = self._adapter_path()
            training_args = TrainingArguments(
                output_dir=str(output_dir),
                num_train_epochs=epochs,
                per_device_train_batch_size=4,
                learning_rate=learning_rate,
                save_strategy="epoch",
                logging_steps=10,
                remove_unused_columns=False,
                # Privacy: never send metrics to installed experiment trackers.
                report_to="none",
            )

            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=dataset,
            )

            train_result = trainer.train()

            # Save adapter
            model.save_pretrained(str(output_dir))

            # Reload model with new adapter
            self._model = None
            self._lora_adapter = None
            self._load_model()

            return {
                "status": "success",
                "examples_used": len(examples),
                "epochs": epochs,
                "adapter_path": str(output_dir),
                "train_loss": train_result.training_loss
            }

        except ImportError as e:
            logger.error(f"Missing dependencies for fine-tuning: {e}")
            return {
                "status": "error",
                "message": f"Missing dependencies: {e}. Install with: pip install peft"
            }
        except Exception as e:
            logger.exception(f"Fine-tuning failed: {e}")
            return {
                "status": "error",
                "message": str(e)
            }

    def get_model_info(self) -> Dict:
        """Get information about the configured (and possibly loaded) model."""
        adapter_path = self._adapter_path()
        has_adapter = adapter_path.exists()
        return {
            "model_name": self.model_name,
            "model_id": self.model_id,
            "device": self.device,
            "is_loaded": self._model is not None,
            "has_lora_adapter": has_adapter,
            "lora_adapter_path": str(adapter_path) if has_adapter else None,
            "training_examples_count": len(list(TRAINING_DATA_DIR.glob("*.json"))),
        }


# Singleton instance
_trocr_service: Optional[TrOCRService] = None


def get_trocr_service(model_name: str = "trocr-base") -> TrOCRService:
    """Get or create the TrOCR service singleton (recreated if model_name changes)."""
    global _trocr_service
    if _trocr_service is None or _trocr_service.model_name != model_name:
        _trocr_service = TrOCRService(model_name=model_name)
    return _trocr_service