A previous `git pull --rebase origin main` dropped 177 local commits,
losing 3400+ files across admin-v2, backend, studio-v2, website,
klausur-service, and many other services. The partial restore attempt
(660295e2) only recovered some files.
This commit restores all missing files from pre-rebase ref 98933f5e
while preserving post-rebase additions (night-scheduler, night-mode UI,
NightModeWidget dashboard integration).
Restored features include:
- AI Module Sidebar (FAB), OCR Labeling, OCR Compare
- GPU Dashboard, RAG Pipeline, Magic Help
- Klausur-Korrektur (8 files), Abitur-Archiv (5+ files)
- Companion, Zeugnisse-Crawler, Screen Flow
- Full backend, studio-v2, website, klausur-service
- All compliance SDKs, agent-core, voice-service
- CI/CD configs, documentation, scripts
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
578 lines
18 KiB
Python
"""
|
|
TrOCR Service for Handwriting Recognition.
|
|
|
|
Uses Microsoft's TrOCR model for extracting handwritten text from exam images.
|
|
Supports fine-tuning with teacher corrections via LoRA adapters.
|
|
|
|
PRIVACY BY DESIGN:
|
|
- All processing happens locally
|
|
- No data sent to external services
|
|
- Fine-tuning data stays on-premise
|
|
"""
|
|
import logging
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Optional, List, Dict, Tuple
|
|
from dataclasses import dataclass
|
|
from io import BytesIO
|
|
import json
|
|
|
|
# Module-level logger; handlers/levels are configured by the host application.
logger = logging.getLogger(__name__)

# Model paths — defaults target the container layout (/app/...); each is
# overridable via environment variable for local or non-container deployments.
MODEL_CACHE_DIR = Path(os.environ.get("TROCR_CACHE_DIR", "/app/models/trocr"))  # downloaded model weights
LORA_ADAPTERS_DIR = Path(os.environ.get("TROCR_LORA_DIR", "/app/models/trocr/lora"))  # fine-tuned LoRA adapters
TRAINING_DATA_DIR = Path(os.environ.get("TROCR_TRAINING_DIR", "/app/data/trocr_training"))  # teacher-corrected examples
@dataclass
class OCRResult:
    """Result from TrOCR extraction."""
    text: str  # full extracted text (newline-joined when line detection is used)
    confidence: float  # 0.0-1.0 estimate derived from generation scores
    bounding_boxes: List[Dict]  # [{"x": 0, "y": 0, "w": 100, "h": 20, "text": "..."}]
    processing_time_ms: int  # wall-clock extraction time in milliseconds
@dataclass
class TrainingExample:
    """A single training example for fine-tuning."""
    image_path: str  # path to the saved image under TRAINING_DATA_DIR
    ground_truth: str  # teacher-provided correct transcription
    teacher_id: str  # ID of the teacher who supplied the correction
    created_at: str  # ISO-8601 creation timestamp
class TrOCRService:
    """
    Handwriting recognition service using TrOCR.

    Features:
    - Line-by-line handwriting extraction
    - Confidence scoring
    - LoRA fine-tuning support
    - Batch processing

    All processing is local (privacy by design); heavy dependencies
    (transformers, torch, peft, PIL) are imported lazily so the service
    can be constructed without them installed.
    """

    # Available models (from smallest to largest)
    MODELS = {
        "trocr-small": "microsoft/trocr-small-handwritten",
        "trocr-base": "microsoft/trocr-base-handwritten",  # Recommended
        "trocr-large": "microsoft/trocr-large-handwritten",
    }

    def __init__(self, model_name: str = "trocr-base", device: str = "auto"):
        """
        Initialize TrOCR service.

        Args:
            model_name: One of "trocr-small", "trocr-base", "trocr-large"
            device: "cpu", "cuda", "mps" (Apple Silicon), or "auto"
        """
        self.model_name = model_name
        # Unknown model names silently fall back to the recommended base model.
        self.model_id = self.MODELS.get(model_name, self.MODELS["trocr-base"])
        self.device = self._get_device(device)

        # Heavy objects are lazy-loaded on first use (see _load_model).
        self._processor = None
        self._model = None
        self._lora_adapter = None

        # Create directories
        MODEL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
        LORA_ADAPTERS_DIR.mkdir(parents=True, exist_ok=True)
        TRAINING_DATA_DIR.mkdir(parents=True, exist_ok=True)

        logger.info(f"TrOCR Service initialized: model={model_name}, device={self.device}")

    def _get_device(self, device: str) -> str:
        """Determine the best device for inference.

        An explicit device string is honored as-is; "auto" probes CUDA,
        then Apple MPS, then falls back to CPU. A missing torch install
        also resolves to CPU.
        """
        if device != "auto":
            return device

        try:
            import torch
            if torch.cuda.is_available():
                return "cuda"
            elif torch.backends.mps.is_available():
                return "mps"
            return "cpu"
        except ImportError:
            return "cpu"

    def _load_model(self):
        """Lazy-load the TrOCR processor and model (no-op once loaded).

        Also applies a previously saved LoRA adapter if one exists for
        this model name.

        Raises:
            ImportError: transformers/torch are not installed.
            Exception: model download or initialization failed.
        """
        if self._model is not None:
            return

        try:
            from transformers import TrOCRProcessor, VisionEncoderDecoderModel
            import torch

            logger.info(f"Loading TrOCR model: {self.model_id}")

            self._processor = TrOCRProcessor.from_pretrained(
                self.model_id,
                cache_dir=str(MODEL_CACHE_DIR)
            )

            self._model = VisionEncoderDecoderModel.from_pretrained(
                self.model_id,
                cache_dir=str(MODEL_CACHE_DIR)
            )

            # Move to device
            if self.device == "cuda":
                self._model = self._model.cuda()
            elif self.device == "mps":
                self._model = self._model.to("mps")

            # Load LoRA adapter if exists
            adapter_path = LORA_ADAPTERS_DIR / f"{self.model_name}_adapter"
            if adapter_path.exists():
                self._load_lora_adapter(adapter_path)

            logger.info(f"TrOCR model loaded successfully on {self.device}")

        except ImportError as e:
            logger.error(f"Missing dependencies: {e}")
            logger.error("Install with: pip install transformers torch pillow")
            raise
        except Exception as e:
            logger.error(f"Failed to load TrOCR model: {e}")
            raise

    def _load_lora_adapter(self, adapter_path: Path):
        """Load a LoRA adapter for a fine-tuned model.

        Best-effort: a missing peft package or a broken adapter only logs
        a warning and leaves the base model in place.
        """
        try:
            from peft import PeftModel

            logger.info(f"Loading LoRA adapter from {adapter_path}")
            self._model = PeftModel.from_pretrained(self._model, str(adapter_path))
            self._lora_adapter = str(adapter_path)
            logger.info("LoRA adapter loaded successfully")

        except ImportError:
            logger.warning("peft not installed, skipping LoRA adapter")
        except Exception as e:
            logger.warning(f"Failed to load LoRA adapter: {e}")

    async def extract_text(
        self,
        image_data: bytes,
        detect_lines: bool = True
    ) -> OCRResult:
        """
        Extract handwritten text from an image.

        Args:
            image_data: Raw image bytes (PNG, JPG, etc.)
            detect_lines: If True, detect and process individual lines

        Returns:
            OCRResult with extracted text and confidence. Extraction errors
            are swallowed and reported as an empty result with confidence
            0.0 (model *loading* errors still propagate from _load_model).
        """
        import time
        start_time = time.time()

        self._load_model()

        try:
            from PIL import Image
            import torch

            # Load image
            image = Image.open(BytesIO(image_data)).convert("RGB")

            if detect_lines:
                # Detect text lines and process each
                lines, bboxes = await self._detect_and_extract_lines(image)
                text = "\n".join(lines)
                # FIX: average the real per-line confidences computed during
                # extraction (previously hard-coded to 0.85). The fallback
                # covers the no-lines-detected path, which returns no boxes.
                line_confs = [b["confidence"] for b in bboxes if "confidence" in b]
                confidence = sum(line_confs) / len(line_confs) if line_confs else 0.85
            else:
                # Process whole image
                text, confidence = await self._extract_single(image)
                bboxes = []

            processing_time_ms = int((time.time() - start_time) * 1000)

            return OCRResult(
                text=text,
                confidence=confidence,
                bounding_boxes=bboxes,
                processing_time_ms=processing_time_ms
            )

        except Exception as e:
            logger.error(f"OCR extraction failed: {e}")
            return OCRResult(
                text="",
                confidence=0.0,
                bounding_boxes=[],
                processing_time_ms=int((time.time() - start_time) * 1000)
            )

    async def _extract_single(self, image) -> Tuple[str, float]:
        """Extract text from a single image (no line detection).

        Args:
            image: PIL RGB image.

        Returns:
            (stripped text, confidence estimate).
        """
        import torch

        # Preprocess
        pixel_values = self._processor(
            images=image,
            return_tensors="pt"
        ).pixel_values

        if self.device == "cuda":
            pixel_values = pixel_values.cuda()
        elif self.device == "mps":
            pixel_values = pixel_values.to("mps")

        # Generate (beam search; scores kept for confidence estimation)
        with torch.no_grad():
            generated_ids = self._model.generate(
                pixel_values,
                max_length=128,
                num_beams=4,
                return_dict_in_generate=True,
                output_scores=True
            )

        # Decode
        text = self._processor.batch_decode(
            generated_ids.sequences,
            skip_special_tokens=True
        )[0]

        # Estimate confidence from generation scores
        confidence = self._estimate_confidence(generated_ids)

        return text.strip(), confidence

    async def _detect_and_extract_lines(self, image) -> Tuple[List[str], List[Dict]]:
        """Detect text lines and extract each separately.

        Falls back to whole-image extraction (with an empty bbox list)
        when no lines are detected.
        """
        from PIL import Image
        import numpy as np

        # Convert to numpy for line detection
        img_array = np.array(image.convert("L"))  # Grayscale

        # Simple horizontal projection for line detection
        lines_y = self._detect_line_positions(img_array)

        if not lines_y:
            # Fallback: process whole image
            text, _ = await self._extract_single(image)
            return [text], []

        # Extract each line
        results = []
        bboxes = []
        width = image.width

        for i, (y_start, y_end) in enumerate(lines_y):
            # Crop line (full width; only vertical bounds are detected)
            line_img = image.crop((0, y_start, width, y_end))

            # Ensure minimum height — skip slivers the model can't read
            if line_img.height < 20:
                continue

            # Extract text
            text, conf = await self._extract_single(line_img)

            if text.strip():
                results.append(text)
                bboxes.append({
                    "x": 0,
                    "y": y_start,
                    "w": width,
                    "h": y_end - y_start,
                    "text": text,
                    "confidence": conf
                })

        return results, bboxes

    def _detect_line_positions(self, img_array) -> List[Tuple[int, int]]:
        """Detect horizontal text line positions using a projection profile.

        Args:
            img_array: 2-D grayscale numpy array (dark ink on light paper).

        Returns:
            List of (y_start, y_end) row ranges, padded by 5px, with runs
            shorter than 15px discarded as noise.
        """
        import numpy as np

        # Horizontal projection (sum of ink per row; 255 - pixel inverts
        # so dark ink contributes positively)
        projection = np.sum(255 - img_array, axis=1)

        # Threshold to find text rows (10% of the darkest row)
        threshold = np.max(projection) * 0.1
        text_rows = projection > threshold

        # Find line boundaries
        lines = []
        in_line = False
        line_start = 0

        for i, is_text in enumerate(text_rows):
            if is_text and not in_line:
                in_line = True
                line_start = max(0, i - 5)  # Add padding
            elif not is_text and in_line:
                in_line = False
                line_end = min(len(text_rows) - 1, i + 5)  # Add padding
                if line_end - line_start > 15:  # Minimum line height
                    lines.append((line_start, line_end))

        # Handle last line (text runs to the bottom edge)
        if in_line:
            lines.append((line_start, len(text_rows) - 1))

        return lines

    def _estimate_confidence(self, generated_output) -> float:
        """Estimate confidence from generation scores.

        Uses the mean max-softmax probability over generated tokens;
        returns a 0.75 default when scores are unavailable or anything
        goes wrong (best-effort by design).
        """
        try:
            import torch

            if hasattr(generated_output, 'scores') and generated_output.scores:
                # Average probability of selected tokens
                probs = []
                for score in generated_output.scores:
                    prob = torch.softmax(score, dim=-1).max().item()
                    probs.append(prob)
                return sum(probs) / len(probs) if probs else 0.5
            return 0.75  # Default confidence
        except Exception:
            return 0.75

    async def batch_extract(
        self,
        images: List[bytes],
        detect_lines: bool = True
    ) -> List[OCRResult]:
        """
        Extract text from multiple images (sequentially, preserving order).

        Args:
            images: List of image bytes
            detect_lines: If True, detect lines in each image

        Returns:
            List of OCRResult
        """
        results = []
        for img_data in images:
            result = await self.extract_text(img_data, detect_lines)
            results.append(result)
        return results

    # ==========================================
    # FINE-TUNING SUPPORT
    # ==========================================

    def add_training_example(
        self,
        image_data: bytes,
        ground_truth: str,
        teacher_id: str
    ) -> str:
        """
        Add a training example for fine-tuning.

        Stores the image as <id>.png and metadata as <id>.json in
        TRAINING_DATA_DIR (data stays on-premise by design).

        Args:
            image_data: Image bytes
            ground_truth: Correct text (teacher-provided)
            teacher_id: ID of the teacher providing correction

        Returns:
            Example ID
        """
        import uuid
        from datetime import datetime, timezone

        example_id = str(uuid.uuid4())

        # Save image
        image_path = TRAINING_DATA_DIR / f"{example_id}.png"
        image_path.write_bytes(image_data)

        # Save metadata; timezone-aware UTC (datetime.utcnow() is deprecated)
        example = TrainingExample(
            image_path=str(image_path),
            ground_truth=ground_truth,
            teacher_id=teacher_id,
            created_at=datetime.now(timezone.utc).isoformat()
        )

        meta_path = TRAINING_DATA_DIR / f"{example_id}.json"
        meta_path.write_text(json.dumps(example.__dict__, indent=2))

        logger.info(f"Training example added: {example_id}")
        return example_id

    def get_training_examples(self, teacher_id: Optional[str] = None) -> List[TrainingExample]:
        """Get all training examples, optionally filtered by teacher.

        Scans the *.json metadata files written by add_training_example.
        """
        examples = []

        for meta_file in TRAINING_DATA_DIR.glob("*.json"):
            with open(meta_file) as f:
                data = json.load(f)
                example = TrainingExample(**data)

            if teacher_id is None or example.teacher_id == teacher_id:
                examples.append(example)

        return examples

    async def fine_tune(
        self,
        teacher_id: Optional[str] = None,
        epochs: int = 3,
        learning_rate: float = 5e-5
    ) -> Dict:
        """
        Fine-tune the model with collected training examples.

        Uses LoRA for efficient fine-tuning. Requires at least 10 examples.
        Errors are reported via the returned dict rather than raised.

        NOTE(review): trainer.train() blocks the event loop despite the
        async signature — callers should run this off the hot path.

        Args:
            teacher_id: If provided, only use examples from this teacher
            epochs: Number of training epochs
            learning_rate: Learning rate for fine-tuning

        Returns:
            Training statistics ({"status": "success"/"error", ...})
        """
        examples = self.get_training_examples(teacher_id)

        if len(examples) < 10:
            return {
                "status": "error",
                "message": f"Need at least 10 examples, have {len(examples)}"
            }

        try:
            from peft import LoraConfig, get_peft_model, TaskType
            from transformers import Trainer, TrainingArguments
            from PIL import Image
            import torch

            self._load_model()

            logger.info(f"Starting fine-tuning with {len(examples)} examples")

            # Configure LoRA
            lora_config = LoraConfig(
                task_type=TaskType.SEQ_2_SEQ_LM,
                r=16,  # LoRA rank
                lora_alpha=32,
                lora_dropout=0.1,
                target_modules=["q_proj", "v_proj"]  # Attention layers
            )

            # Apply LoRA
            model = get_peft_model(self._model, lora_config)

            # Prepare dataset: image -> pixel_values, ground truth -> labels
            class OCRDataset(torch.utils.data.Dataset):
                def __init__(self, examples, processor):
                    self.examples = examples
                    self.processor = processor

                def __len__(self):
                    return len(self.examples)

                def __getitem__(self, idx):
                    ex = self.examples[idx]
                    image = Image.open(ex.image_path).convert("RGB")

                    pixel_values = self.processor(
                        images=image, return_tensors="pt"
                    ).pixel_values.squeeze()

                    labels = self.processor.tokenizer(
                        ex.ground_truth,
                        return_tensors="pt",
                        padding="max_length",
                        max_length=128
                    ).input_ids.squeeze()

                    return {
                        "pixel_values": pixel_values,
                        "labels": labels
                    }

            dataset = OCRDataset(examples, self._processor)

            # Training arguments
            output_dir = LORA_ADAPTERS_DIR / f"{self.model_name}_adapter"
            training_args = TrainingArguments(
                output_dir=str(output_dir),
                num_train_epochs=epochs,
                per_device_train_batch_size=4,
                learning_rate=learning_rate,
                save_strategy="epoch",
                logging_steps=10,
                remove_unused_columns=False,
            )

            # Train
            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=dataset,
            )

            train_result = trainer.train()

            # Save adapter
            model.save_pretrained(str(output_dir))

            # Reload model with new adapter (adapter path now exists, so
            # _load_model will pick it up)
            self._model = None
            self._load_model()

            return {
                "status": "success",
                "examples_used": len(examples),
                "epochs": epochs,
                "adapter_path": str(output_dir),
                "train_loss": train_result.training_loss
            }

        except ImportError as e:
            logger.error(f"Missing dependencies for fine-tuning: {e}")
            return {
                "status": "error",
                "message": f"Missing dependencies: {e}. Install with: pip install peft"
            }
        except Exception as e:
            logger.error(f"Fine-tuning failed: {e}")
            return {
                "status": "error",
                "message": str(e)
            }

    def get_model_info(self) -> Dict:
        """Get information about the loaded model (does not trigger loading)."""
        adapter_path = LORA_ADAPTERS_DIR / f"{self.model_name}_adapter"

        return {
            "model_name": self.model_name,
            "model_id": self.model_id,
            "device": self.device,
            "is_loaded": self._model is not None,
            "has_lora_adapter": adapter_path.exists(),
            "lora_adapter_path": str(adapter_path) if adapter_path.exists() else None,
            "training_examples_count": len(list(TRAINING_DATA_DIR.glob("*.json"))),
        }
# Singleton instance (module-level cache shared by all callers)
_trocr_service: Optional[TrOCRService] = None


def get_trocr_service(model_name: str = "trocr-base") -> TrOCRService:
    """Return the shared TrOCR service, creating it on first use.

    Requesting a different model name than the cached instance was built
    with replaces the singleton with a fresh service.
    """
    global _trocr_service
    cached = _trocr_service
    if cached is None or cached.model_name != model_name:
        cached = TrOCRService(model_name=model_name)
        _trocr_service = cached
    return cached