This repository has been archived on 2026-02-15. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
breakpilot-pwa/backend/klausur/services/trocr_service.py
Benjamin Admin bfdaf63ba9 fix: Restore all files lost during destructive rebase
A previous `git pull --rebase origin main` dropped 177 local commits,
losing 3400+ files across admin-v2, backend, studio-v2, website,
klausur-service, and many other services. The partial restore attempt
(660295e2) only recovered some files.

This commit restores all missing files from pre-rebase ref 98933f5e
while preserving post-rebase additions (night-scheduler, night-mode UI,
NightModeWidget dashboard integration).

Restored features include:
- AI Module Sidebar (FAB), OCR Labeling, OCR Compare
- GPU Dashboard, RAG Pipeline, Magic Help
- Klausur-Korrektur (8 files), Abitur-Archiv (5+ files)
- Companion, Zeugnisse-Crawler, Screen Flow
- Full backend, studio-v2, website, klausur-service
- All compliance SDKs, agent-core, voice-service
- CI/CD configs, documentation, scripts

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 09:51:32 +01:00

578 lines
18 KiB
Python

"""
TrOCR Service for Handwriting Recognition.
Uses Microsoft's TrOCR model for extracting handwritten text from exam images.
Supports fine-tuning with teacher corrections via LoRA adapters.
PRIVACY BY DESIGN:
- All processing happens locally
- No data sent to external services
- Fine-tuning data stays on-premise
"""
import logging
import os
from pathlib import Path
from typing import Optional, List, Dict, Tuple
from dataclasses import dataclass
from io import BytesIO
import json
logger = logging.getLogger(__name__)

# Model paths — each overridable via environment variable for containerized deploys.
MODEL_CACHE_DIR = Path(os.environ.get("TROCR_CACHE_DIR", "/app/models/trocr"))  # downloaded base-model weights
LORA_ADAPTERS_DIR = Path(os.environ.get("TROCR_LORA_DIR", "/app/models/trocr/lora"))  # fine-tuned LoRA adapters
TRAINING_DATA_DIR = Path(os.environ.get("TROCR_TRAINING_DIR", "/app/data/trocr_training"))  # teacher-corrected examples
@dataclass
class OCRResult:
    """Result from TrOCR extraction."""
    text: str                   # Extracted text; detected lines are joined with "\n"
    confidence: float           # Confidence estimate in [0.0, 1.0]
    bounding_boxes: List[Dict]  # [{"x": 0, "y": 0, "w": 100, "h": 20, "text": "..."}]
    processing_time_ms: int     # Wall-clock extraction time in milliseconds
@dataclass
class TrainingExample:
    """A single training example for fine-tuning."""
    image_path: str    # Path of the saved image file on disk
    ground_truth: str  # Correct transcription supplied by the teacher
    teacher_id: str    # ID of the teacher who provided the correction
    created_at: str    # ISO-8601 timestamp of when the example was stored
class TrOCRService:
    """
    Handwriting recognition service using TrOCR.

    Features:
    - Line-by-line handwriting extraction
    - Confidence scoring
    - LoRA fine-tuning support
    - Batch processing

    Heavy dependencies (torch, transformers, peft, PIL) are imported lazily
    inside methods so the service can be constructed even when they are not
    installed; a clear error is raised only when inference is attempted.
    """

    # Available models (from smallest to largest)
    MODELS = {
        "trocr-small": "microsoft/trocr-small-handwritten",
        "trocr-base": "microsoft/trocr-base-handwritten",  # Recommended
        "trocr-large": "microsoft/trocr-large-handwritten",
    }

    def __init__(self, model_name: str = "trocr-base", device: str = "auto"):
        """
        Initialize TrOCR service.

        Args:
            model_name: One of "trocr-small", "trocr-base", "trocr-large"
            device: "cpu", "cuda", "mps" (Apple Silicon), or "auto"
        """
        self.model_name = model_name
        # Unknown names silently fall back to the recommended base model.
        self.model_id = self.MODELS.get(model_name, self.MODELS["trocr-base"])
        self.device = self._get_device(device)
        # Lazily populated by _load_model() on first use.
        self._processor = None
        self._model = None
        self._lora_adapter = None
        # Create directories
        MODEL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
        LORA_ADAPTERS_DIR.mkdir(parents=True, exist_ok=True)
        TRAINING_DATA_DIR.mkdir(parents=True, exist_ok=True)
        logger.info(f"TrOCR Service initialized: model={model_name}, device={self.device}")

    def _get_device(self, device: str) -> str:
        """Determine the best device for inference.

        Returns the given device unchanged unless it is "auto", in which case
        the best available backend is probed: cuda > mps > cpu.
        """
        if device != "auto":
            return device
        try:
            import torch
            if torch.cuda.is_available():
                return "cuda"
            # FIX: older torch builds have no `torch.backends.mps` attribute,
            # which made the original raise AttributeError here; guard it.
            mps_backend = getattr(torch.backends, "mps", None)
            if mps_backend is not None and mps_backend.is_available():
                return "mps"
            return "cpu"
        except ImportError:
            # torch not installed — CPU-only fallback (inference will still
            # fail later with a clear error from _load_model).
            return "cpu"

    def _load_model(self):
        """Lazy-load the TrOCR model (idempotent — no-op once loaded).

        Raises:
            ImportError: if transformers/torch are not installed.
            Exception: re-raised on any other model-loading failure.
        """
        if self._model is not None:
            return
        try:
            from transformers import TrOCRProcessor, VisionEncoderDecoderModel
            import torch
            logger.info(f"Loading TrOCR model: {self.model_id}")
            self._processor = TrOCRProcessor.from_pretrained(
                self.model_id,
                cache_dir=str(MODEL_CACHE_DIR)
            )
            self._model = VisionEncoderDecoderModel.from_pretrained(
                self.model_id,
                cache_dir=str(MODEL_CACHE_DIR)
            )
            # Move to device
            if self.device == "cuda":
                self._model = self._model.cuda()
            elif self.device == "mps":
                self._model = self._model.to("mps")
            # Load LoRA adapter if exists
            adapter_path = LORA_ADAPTERS_DIR / f"{self.model_name}_adapter"
            if adapter_path.exists():
                self._load_lora_adapter(adapter_path)
            logger.info(f"TrOCR model loaded successfully on {self.device}")
        except ImportError as e:
            logger.error(f"Missing dependencies: {e}")
            logger.error("Install with: pip install transformers torch pillow")
            raise
        except Exception as e:
            logger.error(f"Failed to load TrOCR model: {e}")
            raise

    def _load_lora_adapter(self, adapter_path: Path):
        """Load a LoRA adapter for fine-tuned model.

        Best-effort: failures are logged as warnings and the base model
        keeps serving without the adapter.
        """
        try:
            from peft import PeftModel
            logger.info(f"Loading LoRA adapter from {adapter_path}")
            self._model = PeftModel.from_pretrained(self._model, str(adapter_path))
            self._lora_adapter = str(adapter_path)
            logger.info("LoRA adapter loaded successfully")
        except ImportError:
            logger.warning("peft not installed, skipping LoRA adapter")
        except Exception as e:
            logger.warning(f"Failed to load LoRA adapter: {e}")

    async def extract_text(
        self,
        image_data: bytes,
        detect_lines: bool = True
    ) -> "OCRResult":
        """
        Extract handwritten text from an image.

        Args:
            image_data: Raw image bytes (PNG, JPG, etc.)
            detect_lines: If True, detect and process individual lines

        Returns:
            OCRResult with extracted text and confidence. On failure an
            empty OCRResult with confidence 0.0 is returned (never raises).
        """
        import time
        start_time = time.time()
        self._load_model()
        try:
            from PIL import Image
            import torch
            # Load image
            image = Image.open(BytesIO(image_data)).convert("RGB")
            if detect_lines:
                # Detect text lines and process each
                lines, bboxes = await self._detect_and_extract_lines(image)
                text = "\n".join(lines)
                # FIX: average the per-line confidences actually measured by
                # _detect_and_extract_lines instead of a hard-coded 0.85;
                # 0.85 remains only as the no-bboxes fallback.
                if bboxes:
                    confidence = sum(b["confidence"] for b in bboxes) / len(bboxes)
                else:
                    confidence = 0.85
            else:
                # Process whole image
                text, confidence = await self._extract_single(image)
                bboxes = []
            processing_time_ms = int((time.time() - start_time) * 1000)
            return OCRResult(
                text=text,
                confidence=confidence,
                bounding_boxes=bboxes,
                processing_time_ms=processing_time_ms
            )
        except Exception as e:
            # Deliberate best-effort contract: OCR failures degrade to an
            # empty result so a single bad scan cannot break batch grading.
            logger.error(f"OCR extraction failed: {e}")
            return OCRResult(
                text="",
                confidence=0.0,
                bounding_boxes=[],
                processing_time_ms=int((time.time() - start_time) * 1000)
            )

    async def _extract_single(self, image) -> Tuple[str, float]:
        """Extract text from a single image (no line detection).

        Returns:
            (stripped_text, confidence_estimate)
        """
        import torch
        # Preprocess
        pixel_values = self._processor(
            images=image,
            return_tensors="pt"
        ).pixel_values
        if self.device == "cuda":
            pixel_values = pixel_values.cuda()
        elif self.device == "mps":
            pixel_values = pixel_values.to("mps")
        # Generate (beam search; scores kept for confidence estimation)
        with torch.no_grad():
            generated_ids = self._model.generate(
                pixel_values,
                max_length=128,
                num_beams=4,
                return_dict_in_generate=True,
                output_scores=True
            )
        # Decode
        text = self._processor.batch_decode(
            generated_ids.sequences,
            skip_special_tokens=True
        )[0]
        # Estimate confidence from generation scores
        confidence = self._estimate_confidence(generated_ids)
        return text.strip(), confidence

    async def _detect_and_extract_lines(self, image) -> Tuple[List[str], List[Dict]]:
        """Detect text lines and extract each separately.

        Returns:
            (texts, bboxes) where each bbox dict carries x/y/w/h plus the
            recognized text and its confidence. Falls back to whole-image
            extraction (with empty bboxes) when no lines are detected.
        """
        import numpy as np
        # Convert to numpy for line detection
        img_array = np.array(image.convert("L"))  # Grayscale
        # Simple horizontal projection for line detection
        lines_y = self._detect_line_positions(img_array)
        if not lines_y:
            # Fallback: process whole image
            text, _ = await self._extract_single(image)
            return [text], []
        # Extract each line
        results = []
        bboxes = []
        width = image.width
        for i, (y_start, y_end) in enumerate(lines_y):
            # Crop line
            line_img = image.crop((0, y_start, width, y_end))
            # Ensure minimum height — skip slivers TrOCR can't read
            if line_img.height < 20:
                continue
            # Extract text
            text, conf = await self._extract_single(line_img)
            if text.strip():
                results.append(text)
                bboxes.append({
                    "x": 0,
                    "y": y_start,
                    "w": width,
                    "h": y_end - y_start,
                    "text": text,
                    "confidence": conf
                })
        return results, bboxes

    def _detect_line_positions(self, img_array) -> List[Tuple[int, int]]:
        """Detect horizontal text line positions using projection profile.

        Args:
            img_array: 2-D grayscale numpy array (0 = black ink, 255 = paper).

        Returns:
            List of (y_start, y_end) row ranges, each padded by 5 px and at
            least 15 px tall.
        """
        import numpy as np
        # Horizontal projection (sum of inverted pixels per row — dark rows score high)
        projection = np.sum(255 - img_array, axis=1)
        # Threshold at 10% of the darkest row to find text rows
        threshold = np.max(projection) * 0.1
        text_rows = projection > threshold
        # Find line boundaries via rising/falling edges of the boolean profile
        lines = []
        in_line = False
        line_start = 0
        for i, is_text in enumerate(text_rows):
            if is_text and not in_line:
                in_line = True
                line_start = max(0, i - 5)  # Add padding
            elif not is_text and in_line:
                in_line = False
                line_end = min(len(text_rows) - 1, i + 5)  # Add padding
                if line_end - line_start > 15:  # Minimum line height
                    lines.append((line_start, line_end))
        # Handle last line (text reaches the bottom edge)
        if in_line:
            lines.append((line_start, len(text_rows) - 1))
        return lines

    def _estimate_confidence(self, generated_output) -> float:
        """Estimate confidence from generation scores.

        Averages the max softmax probability of each generated token;
        returns 0.75 when scores are unavailable or anything goes wrong.
        """
        try:
            import torch
            if hasattr(generated_output, 'scores') and generated_output.scores:
                # Average probability of selected tokens
                probs = []
                for score in generated_output.scores:
                    prob = torch.softmax(score, dim=-1).max().item()
                    probs.append(prob)
                return sum(probs) / len(probs) if probs else 0.5
            return 0.75  # Default confidence
        except Exception:
            return 0.75

    async def batch_extract(
        self,
        images: List[bytes],
        detect_lines: bool = True
    ) -> "List[OCRResult]":
        """
        Extract text from multiple images (sequentially).

        Args:
            images: List of image bytes
            detect_lines: If True, detect lines in each image

        Returns:
            List of OCRResult, one per input image, in input order.
        """
        results = []
        for img_data in images:
            result = await self.extract_text(img_data, detect_lines)
            results.append(result)
        return results

    # ==========================================
    # FINE-TUNING SUPPORT
    # ==========================================

    def add_training_example(
        self,
        image_data: bytes,
        ground_truth: str,
        teacher_id: str
    ) -> str:
        """
        Add a training example for fine-tuning.

        Persists the image as <id>.png and the metadata as <id>.json in
        TRAINING_DATA_DIR.

        Args:
            image_data: Image bytes
            ground_truth: Correct text (teacher-provided)
            teacher_id: ID of the teacher providing correction

        Returns:
            Example ID
        """
        import uuid
        from datetime import datetime, timezone
        example_id = str(uuid.uuid4())
        # Save image
        image_path = TRAINING_DATA_DIR / f"{example_id}.png"
        with open(image_path, "wb") as f:
            f.write(image_data)
        # Save metadata
        example = TrainingExample(
            image_path=str(image_path),
            ground_truth=ground_truth,
            teacher_id=teacher_id,
            # FIX: datetime.utcnow() is deprecated (3.12+) and naive; use an
            # explicit timezone-aware UTC timestamp.
            created_at=datetime.now(timezone.utc).isoformat()
        )
        meta_path = TRAINING_DATA_DIR / f"{example_id}.json"
        # FIX: explicit UTF-8 + ensure_ascii=False so non-ASCII ground truth
        # (e.g. German umlauts) round-trips on any platform default encoding.
        with open(meta_path, "w", encoding="utf-8") as f:
            json.dump(example.__dict__, f, indent=2, ensure_ascii=False)
        logger.info(f"Training example added: {example_id}")
        return example_id

    def get_training_examples(self, teacher_id: Optional[str] = None) -> "List[TrainingExample]":
        """Get all training examples, optionally filtered by teacher."""
        examples = []
        for meta_file in TRAINING_DATA_DIR.glob("*.json"):
            # FIX: one corrupt or schema-incompatible metadata file no longer
            # aborts the whole listing — skip it with a warning instead.
            try:
                with open(meta_file, encoding="utf-8") as f:
                    data = json.load(f)
                example = TrainingExample(**data)
            except (json.JSONDecodeError, OSError, TypeError) as e:
                logger.warning(f"Skipping invalid training metadata {meta_file}: {e}")
                continue
            if teacher_id is None or example.teacher_id == teacher_id:
                examples.append(example)
        return examples

    async def fine_tune(
        self,
        teacher_id: Optional[str] = None,
        epochs: int = 3,
        learning_rate: float = 5e-5
    ) -> Dict:
        """
        Fine-tune the model with collected training examples.

        Uses LoRA for efficient fine-tuning.

        Args:
            teacher_id: If provided, only use examples from this teacher
            epochs: Number of training epochs
            learning_rate: Learning rate for fine-tuning

        Returns:
            Training statistics dict; {"status": "error", ...} on failure
            (never raises).
        """
        examples = self.get_training_examples(teacher_id)
        # Guard against overfitting on a tiny corpus
        if len(examples) < 10:
            return {
                "status": "error",
                "message": f"Need at least 10 examples, have {len(examples)}"
            }
        try:
            from peft import LoraConfig, get_peft_model, TaskType
            from transformers import Trainer, TrainingArguments
            from PIL import Image
            import torch
            self._load_model()
            logger.info(f"Starting fine-tuning with {len(examples)} examples")
            # Configure LoRA
            lora_config = LoraConfig(
                task_type=TaskType.SEQ_2_SEQ_LM,
                r=16,  # LoRA rank
                lora_alpha=32,
                lora_dropout=0.1,
                target_modules=["q_proj", "v_proj"]  # Attention layers
            )
            # Apply LoRA
            model = get_peft_model(self._model, lora_config)

            # Prepare dataset: each item pairs processed pixel values with
            # tokenized ground-truth labels.
            class OCRDataset(torch.utils.data.Dataset):
                def __init__(self, examples, processor):
                    self.examples = examples
                    self.processor = processor

                def __len__(self):
                    return len(self.examples)

                def __getitem__(self, idx):
                    ex = self.examples[idx]
                    image = Image.open(ex.image_path).convert("RGB")
                    pixel_values = self.processor(
                        images=image, return_tensors="pt"
                    ).pixel_values.squeeze()
                    labels = self.processor.tokenizer(
                        ex.ground_truth,
                        return_tensors="pt",
                        padding="max_length",
                        max_length=128
                    ).input_ids.squeeze()
                    return {
                        "pixel_values": pixel_values,
                        "labels": labels
                    }

            dataset = OCRDataset(examples, self._processor)
            # Training arguments
            output_dir = LORA_ADAPTERS_DIR / f"{self.model_name}_adapter"
            training_args = TrainingArguments(
                output_dir=str(output_dir),
                num_train_epochs=epochs,
                per_device_train_batch_size=4,
                learning_rate=learning_rate,
                save_strategy="epoch",
                logging_steps=10,
                remove_unused_columns=False,
            )
            # Train
            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=dataset,
            )
            train_result = trainer.train()
            # Save adapter
            model.save_pretrained(str(output_dir))
            # Reload model with new adapter so inference picks it up
            self._model = None
            self._load_model()
            return {
                "status": "success",
                "examples_used": len(examples),
                "epochs": epochs,
                "adapter_path": str(output_dir),
                "train_loss": train_result.training_loss
            }
        except ImportError as e:
            logger.error(f"Missing dependencies for fine-tuning: {e}")
            return {
                "status": "error",
                "message": f"Missing dependencies: {e}. Install with: pip install peft"
            }
        except Exception as e:
            logger.error(f"Fine-tuning failed: {e}")
            return {
                "status": "error",
                "message": str(e)
            }

    def get_model_info(self) -> Dict:
        """Get information about the loaded model (for status endpoints)."""
        adapter_path = LORA_ADAPTERS_DIR / f"{self.model_name}_adapter"
        return {
            "model_name": self.model_name,
            "model_id": self.model_id,
            "device": self.device,
            "is_loaded": self._model is not None,
            "has_lora_adapter": adapter_path.exists(),
            "lora_adapter_path": str(adapter_path) if adapter_path.exists() else None,
            "training_examples_count": len(list(TRAINING_DATA_DIR.glob("*.json"))),
        }
# Singleton instance — lazily created by get_trocr_service()
_trocr_service: Optional[TrOCRService] = None
def get_trocr_service(model_name: str = "trocr-base") -> TrOCRService:
    """Return the shared TrOCR service, (re)creating it on first use or
    whenever a different model name is requested."""
    global _trocr_service
    current = _trocr_service
    needs_new = current is None or current.model_name != model_name
    if needs_new:
        current = TrOCRService(model_name=model_name)
        _trocr_service = current
    return current