fix: Restore all files lost during destructive rebase
A previous `git pull --rebase origin main` dropped 177 local commits,
losing 3400+ files across admin-v2, backend, studio-v2, website,
klausur-service, and many other services. The partial restore attempt
(660295e2) only recovered some files.
This commit restores all missing files from pre-rebase ref 98933f5e
while preserving post-rebase additions (night-scheduler, night-mode UI,
NightModeWidget dashboard integration).
Restored features include:
- AI Module Sidebar (FAB), OCR Labeling, OCR Compare
- GPU Dashboard, RAG Pipeline, Magic Help
- Klausur-Korrektur (8 files), Abitur-Archiv (5+ files)
- Companion, Zeugnisse-Crawler, Screen Flow
- Full backend, studio-v2, website, klausur-service
- All compliance SDKs, agent-core, voice-service
- CI/CD configs, documentation, scripts
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
577
backend/klausur/services/trocr_service.py
Normal file
577
backend/klausur/services/trocr_service.py
Normal file
@@ -0,0 +1,577 @@
|
||||
"""
|
||||
TrOCR Service for Handwriting Recognition.
|
||||
|
||||
Uses Microsoft's TrOCR model for extracting handwritten text from exam images.
|
||||
Supports fine-tuning with teacher corrections via LoRA adapters.
|
||||
|
||||
PRIVACY BY DESIGN:
|
||||
- All processing happens locally
|
||||
- No data sent to external services
|
||||
- Fine-tuning data stays on-premise
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Tuple
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
import json
|
||||
|
||||
# Module-level logger; configuration (handlers/level) is left to the host app.
logger = logging.getLogger(__name__)

# Model paths
# All three locations are overridable via environment variables so deployments
# can relocate model weights and training data without code changes.
MODEL_CACHE_DIR = Path(os.environ.get("TROCR_CACHE_DIR", "/app/models/trocr"))  # downloaded TrOCR weights
LORA_ADAPTERS_DIR = Path(os.environ.get("TROCR_LORA_DIR", "/app/models/trocr/lora"))  # fine-tuned LoRA adapters
TRAINING_DATA_DIR = Path(os.environ.get("TROCR_TRAINING_DIR", "/app/data/trocr_training"))  # teacher-corrected examples
|
||||
|
||||
|
||||
@dataclass
class OCRResult:
    """Result from TrOCR extraction."""

    # Full extracted text; when line detection was used, lines are joined with "\n".
    text: str
    # Confidence estimate in [0.0, 1.0]; 0.0 signals a failed extraction.
    confidence: float
    bounding_boxes: List[Dict]  # [{"x": 0, "y": 0, "w": 100, "h": 20, "text": "..."}]
    # Wall-clock time the extraction took, in milliseconds.
    processing_time_ms: int
|
||||
|
||||
|
||||
@dataclass
class TrainingExample:
    """A single training example for fine-tuning."""

    # Path to the saved PNG image on disk (under TRAINING_DATA_DIR).
    image_path: str
    # Correct transcription provided by the teacher.
    ground_truth: str
    # ID of the teacher who supplied the correction; used for filtering.
    teacher_id: str
    # ISO-8601 creation timestamp.
    created_at: str
|
||||
|
||||
|
||||
class TrOCRService:
    """
    Handwriting recognition service using TrOCR.

    Features:
    - Line-by-line handwriting extraction
    - Confidence scoring
    - LoRA fine-tuning support
    - Batch processing

    Heavy dependencies (torch, transformers, peft, PIL, numpy) are imported
    lazily inside methods so the service object can be constructed even where
    they are not installed; import errors surface on first actual use.
    """

    # Available models (from smallest to largest)
    MODELS = {
        "trocr-small": "microsoft/trocr-small-handwritten",
        "trocr-base": "microsoft/trocr-base-handwritten",  # Recommended
        "trocr-large": "microsoft/trocr-large-handwritten",
    }

    def __init__(self, model_name: str = "trocr-base", device: str = "auto"):
        """
        Initialize TrOCR service.

        Args:
            model_name: One of "trocr-small", "trocr-base", "trocr-large".
                Unknown names silently fall back to "trocr-base".
            device: "cpu", "cuda", "mps" (Apple Silicon), or "auto"
        """
        self.model_name = model_name
        self.model_id = self.MODELS.get(model_name, self.MODELS["trocr-base"])
        self.device = self._get_device(device)

        # Lazily populated by _load_model() on first use.
        self._processor = None
        self._model = None
        self._lora_adapter = None

        # Create directories
        MODEL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
        LORA_ADAPTERS_DIR.mkdir(parents=True, exist_ok=True)
        TRAINING_DATA_DIR.mkdir(parents=True, exist_ok=True)

        logger.info(f"TrOCR Service initialized: model={model_name}, device={self.device}")

    def _get_device(self, device: str) -> str:
        """
        Determine the best device for inference.

        Explicit requests ("cpu"/"cuda"/"mps") are returned unchanged; "auto"
        probes torch for CUDA, then MPS, then falls back to CPU. Missing torch
        also falls back to CPU so construction never fails on imports.
        """
        if device != "auto":
            return device

        try:
            import torch
            if torch.cuda.is_available():
                return "cuda"
            elif torch.backends.mps.is_available():
                return "mps"
            return "cpu"
        except ImportError:
            return "cpu"

    def _load_model(self):
        """
        Lazy-load the TrOCR model and processor (idempotent).

        Also applies a previously saved LoRA adapter for this model name if one
        exists on disk.

        Raises:
            ImportError: transformers/torch are not installed.
            Exception: model download or initialization failed.
        """
        if self._model is not None:
            return

        try:
            from transformers import TrOCRProcessor, VisionEncoderDecoderModel
            import torch

            logger.info(f"Loading TrOCR model: {self.model_id}")

            self._processor = TrOCRProcessor.from_pretrained(
                self.model_id,
                cache_dir=str(MODEL_CACHE_DIR)
            )

            self._model = VisionEncoderDecoderModel.from_pretrained(
                self.model_id,
                cache_dir=str(MODEL_CACHE_DIR)
            )

            # Move to device
            if self.device == "cuda":
                self._model = self._model.cuda()
            elif self.device == "mps":
                self._model = self._model.to("mps")

            # Load LoRA adapter if exists
            adapter_path = LORA_ADAPTERS_DIR / f"{self.model_name}_adapter"
            if adapter_path.exists():
                self._load_lora_adapter(adapter_path)

            logger.info(f"TrOCR model loaded successfully on {self.device}")

        except ImportError as e:
            logger.error(f"Missing dependencies: {e}")
            logger.error("Install with: pip install transformers torch pillow")
            raise
        except Exception as e:
            logger.error(f"Failed to load TrOCR model: {e}")
            raise

    def _load_lora_adapter(self, adapter_path: Path):
        """
        Load a LoRA adapter for fine-tuned model.

        Best-effort: a missing peft package or a broken adapter only logs a
        warning and leaves the base model in place.
        """
        try:
            from peft import PeftModel

            logger.info(f"Loading LoRA adapter from {adapter_path}")
            self._model = PeftModel.from_pretrained(self._model, str(adapter_path))
            self._lora_adapter = str(adapter_path)
            logger.info("LoRA adapter loaded successfully")

        except ImportError:
            logger.warning("peft not installed, skipping LoRA adapter")
        except Exception as e:
            logger.warning(f"Failed to load LoRA adapter: {e}")

    async def extract_text(
        self,
        image_data: bytes,
        detect_lines: bool = True
    ) -> "OCRResult":
        """
        Extract handwritten text from an image.

        Args:
            image_data: Raw image bytes (PNG, JPG, etc.)
            detect_lines: If True, detect and process individual lines

        Returns:
            OCRResult with extracted text and confidence. On any failure an
            empty result with confidence 0.0 is returned instead of raising,
            so callers can treat "no text" and "failed" uniformly.
        """
        import time
        start_time = time.time()

        self._load_model()

        try:
            from PIL import Image
            import torch

            # Load image
            image = Image.open(BytesIO(image_data)).convert("RGB")

            if detect_lines:
                # Detect text lines and process each
                lines, bboxes = await self._detect_and_extract_lines(image)
                text = "\n".join(lines)
                # FIX: the per-line confidences are already computed and stored
                # in the bounding boxes — average them instead of reporting a
                # hard-coded estimate. Fall back to the historical 0.85 only
                # when no scored boxes exist (whole-image fallback path).
                if bboxes:
                    confidence = sum(b["confidence"] for b in bboxes) / len(bboxes)
                else:
                    confidence = 0.85
            else:
                # Process whole image
                text, confidence = await self._extract_single(image)
                bboxes = []

            processing_time_ms = int((time.time() - start_time) * 1000)

            return OCRResult(
                text=text,
                confidence=confidence,
                bounding_boxes=bboxes,
                processing_time_ms=processing_time_ms
            )

        except Exception as e:
            logger.error(f"OCR extraction failed: {e}")
            return OCRResult(
                text="",
                confidence=0.0,
                bounding_boxes=[],
                processing_time_ms=int((time.time() - start_time) * 1000)
            )

    async def _extract_single(self, image) -> Tuple[str, float]:
        """
        Extract text from a single image (no line detection).

        Args:
            image: A PIL RGB image.

        Returns:
            (text, confidence) where confidence is estimated from the
            generation scores.
        """
        import torch

        # Preprocess
        pixel_values = self._processor(
            images=image,
            return_tensors="pt"
        ).pixel_values

        if self.device == "cuda":
            pixel_values = pixel_values.cuda()
        elif self.device == "mps":
            pixel_values = pixel_values.to("mps")

        # Generate (scores requested so confidence can be estimated below)
        with torch.no_grad():
            generated_ids = self._model.generate(
                pixel_values,
                max_length=128,
                num_beams=4,
                return_dict_in_generate=True,
                output_scores=True
            )

        # Decode
        text = self._processor.batch_decode(
            generated_ids.sequences,
            skip_special_tokens=True
        )[0]

        # Estimate confidence from generation scores
        confidence = self._estimate_confidence(generated_ids)

        return text.strip(), confidence

    async def _detect_and_extract_lines(self, image) -> Tuple[List[str], List[Dict]]:
        """
        Detect text lines and extract each separately.

        Returns:
            (lines, bboxes): recognized text per line, and full-width bounding
            boxes with per-line confidence. Falls back to whole-image
            extraction (with empty bboxes) when no lines are detected.
        """
        import numpy as np

        # Convert to numpy for line detection
        img_array = np.array(image.convert("L"))  # Grayscale

        # Simple horizontal projection for line detection
        lines_y = self._detect_line_positions(img_array)

        if not lines_y:
            # Fallback: process whole image
            text, _ = await self._extract_single(image)
            return [text], []

        # Extract each line
        results = []
        bboxes = []
        width = image.width

        for i, (y_start, y_end) in enumerate(lines_y):
            # Crop line (full width — projection profile gives no x bounds)
            line_img = image.crop((0, y_start, width, y_end))

            # Ensure minimum height; tiny strips are projection noise
            if line_img.height < 20:
                continue

            # Extract text
            text, conf = await self._extract_single(line_img)

            if text.strip():
                results.append(text)
                bboxes.append({
                    "x": 0,
                    "y": y_start,
                    "w": width,
                    "h": y_end - y_start,
                    "text": text,
                    "confidence": conf
                })

        return results, bboxes

    def _detect_line_positions(self, img_array) -> List[Tuple[int, int]]:
        """
        Detect horizontal text line positions using projection profile.

        Args:
            img_array: 2-D grayscale numpy array (0 = black ink, 255 = paper).

        Returns:
            List of (y_start, y_end) row ranges, each padded by 5 px and at
            least 15 px tall.
        """
        import numpy as np

        # Horizontal projection: total "ink" per row (invert so dark = high)
        projection = np.sum(255 - img_array, axis=1)

        # Threshold to find text rows
        threshold = np.max(projection) * 0.1
        text_rows = projection > threshold

        # Find line boundaries by scanning for text/non-text transitions
        lines = []
        in_line = False
        line_start = 0

        for i, is_text in enumerate(text_rows):
            if is_text and not in_line:
                in_line = True
                line_start = max(0, i - 5)  # Add padding
            elif not is_text and in_line:
                in_line = False
                line_end = min(len(text_rows) - 1, i + 5)  # Add padding
                if line_end - line_start > 15:  # Minimum line height
                    lines.append((line_start, line_end))

        # Handle last line (image ends while still inside a text band)
        if in_line:
            lines.append((line_start, len(text_rows) - 1))

        return lines

    def _estimate_confidence(self, generated_output) -> float:
        """
        Estimate confidence from generation scores.

        Averages the max softmax probability of each generated token; returns
        a neutral 0.75 when scores are unavailable or anything goes wrong.
        """
        try:
            import torch

            if hasattr(generated_output, 'scores') and generated_output.scores:
                # Average probability of selected tokens
                probs = []
                for score in generated_output.scores:
                    prob = torch.softmax(score, dim=-1).max().item()
                    probs.append(prob)
                return sum(probs) / len(probs) if probs else 0.5
            return 0.75  # Default confidence
        except Exception:
            return 0.75

    async def batch_extract(
        self,
        images: List[bytes],
        detect_lines: bool = True
    ) -> List["OCRResult"]:
        """
        Extract text from multiple images (sequentially).

        Args:
            images: List of image bytes
            detect_lines: If True, detect lines in each image

        Returns:
            List of OCRResult, in the same order as the input images.
        """
        results = []
        for img_data in images:
            result = await self.extract_text(img_data, detect_lines)
            results.append(result)
        return results

    # ==========================================
    # FINE-TUNING SUPPORT
    # ==========================================

    def add_training_example(
        self,
        image_data: bytes,
        ground_truth: str,
        teacher_id: str
    ) -> str:
        """
        Add a training example for fine-tuning.

        Persists the image as PNG plus a JSON metadata sidecar under
        TRAINING_DATA_DIR, keyed by a fresh UUID.

        Args:
            image_data: Image bytes
            ground_truth: Correct text (teacher-provided)
            teacher_id: ID of the teacher providing correction

        Returns:
            Example ID
        """
        import uuid
        from datetime import datetime, timezone

        example_id = str(uuid.uuid4())

        # Save image
        image_path = TRAINING_DATA_DIR / f"{example_id}.png"
        with open(image_path, "wb") as f:
            f.write(image_data)

        # Save metadata (timezone-aware UTC; datetime.utcnow() is deprecated)
        example = TrainingExample(
            image_path=str(image_path),
            ground_truth=ground_truth,
            teacher_id=teacher_id,
            created_at=datetime.now(timezone.utc).isoformat()
        )

        meta_path = TRAINING_DATA_DIR / f"{example_id}.json"
        with open(meta_path, "w") as f:
            json.dump(example.__dict__, f, indent=2)

        logger.info(f"Training example added: {example_id}")
        return example_id

    def get_training_examples(self, teacher_id: Optional[str] = None) -> List["TrainingExample"]:
        """
        Get all training examples, optionally filtered by teacher.

        A corrupt or mis-shaped metadata file is skipped with a warning
        instead of aborting the whole listing.
        """
        examples = []

        for meta_file in TRAINING_DATA_DIR.glob("*.json"):
            try:
                with open(meta_file) as f:
                    data = json.load(f)
                example = TrainingExample(**data)
            except (OSError, json.JSONDecodeError, TypeError) as e:
                # FIX: previously one bad file raised and broke fine_tune() and
                # get_model_info() for every caller.
                logger.warning(f"Skipping unreadable training example {meta_file}: {e}")
                continue

            if teacher_id is None or example.teacher_id == teacher_id:
                examples.append(example)

        return examples

    async def fine_tune(
        self,
        teacher_id: Optional[str] = None,
        epochs: int = 3,
        learning_rate: float = 5e-5
    ) -> Dict:
        """
        Fine-tune the model with collected training examples.

        Uses LoRA for efficient fine-tuning.

        Args:
            teacher_id: If provided, only use examples from this teacher
            epochs: Number of training epochs
            learning_rate: Learning rate for fine-tuning

        Returns:
            Training statistics dict with "status" of "success" or "error";
            never raises (errors are reported in the dict).
        """
        examples = self.get_training_examples(teacher_id)

        if len(examples) < 10:
            return {
                "status": "error",
                "message": f"Need at least 10 examples, have {len(examples)}"
            }

        try:
            from peft import LoraConfig, get_peft_model, TaskType
            from transformers import Trainer, TrainingArguments
            from PIL import Image
            import torch

            self._load_model()

            logger.info(f"Starting fine-tuning with {len(examples)} examples")

            # Configure LoRA
            lora_config = LoraConfig(
                task_type=TaskType.SEQ_2_SEQ_LM,
                r=16,  # LoRA rank
                lora_alpha=32,
                lora_dropout=0.1,
                target_modules=["q_proj", "v_proj"]  # Attention layers
            )

            # Apply LoRA
            model = get_peft_model(self._model, lora_config)

            # Prepare dataset (images loaded lazily per item)
            class OCRDataset(torch.utils.data.Dataset):
                def __init__(self, examples, processor):
                    self.examples = examples
                    self.processor = processor

                def __len__(self):
                    return len(self.examples)

                def __getitem__(self, idx):
                    ex = self.examples[idx]
                    image = Image.open(ex.image_path).convert("RGB")

                    pixel_values = self.processor(
                        images=image, return_tensors="pt"
                    ).pixel_values.squeeze()

                    labels = self.processor.tokenizer(
                        ex.ground_truth,
                        return_tensors="pt",
                        padding="max_length",
                        max_length=128
                    ).input_ids.squeeze()

                    return {
                        "pixel_values": pixel_values,
                        "labels": labels
                    }

            dataset = OCRDataset(examples, self._processor)

            # Training arguments
            output_dir = LORA_ADAPTERS_DIR / f"{self.model_name}_adapter"
            training_args = TrainingArguments(
                output_dir=str(output_dir),
                num_train_epochs=epochs,
                per_device_train_batch_size=4,
                learning_rate=learning_rate,
                save_strategy="epoch",
                logging_steps=10,
                remove_unused_columns=False,
            )

            # Train
            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=dataset,
            )

            train_result = trainer.train()

            # Save adapter
            model.save_pretrained(str(output_dir))

            # Reload model with new adapter
            self._model = None
            self._load_model()

            return {
                "status": "success",
                "examples_used": len(examples),
                "epochs": epochs,
                "adapter_path": str(output_dir),
                "train_loss": train_result.training_loss
            }

        except ImportError as e:
            logger.error(f"Missing dependencies for fine-tuning: {e}")
            return {
                "status": "error",
                "message": f"Missing dependencies: {e}. Install with: pip install peft"
            }
        except Exception as e:
            logger.error(f"Fine-tuning failed: {e}")
            return {
                "status": "error",
                "message": str(e)
            }

    def get_model_info(self) -> Dict:
        """Get information about the loaded model."""
        adapter_path = LORA_ADAPTERS_DIR / f"{self.model_name}_adapter"

        return {
            "model_name": self.model_name,
            "model_id": self.model_id,
            "device": self.device,
            "is_loaded": self._model is not None,
            "has_lora_adapter": adapter_path.exists(),
            "lora_adapter_path": str(adapter_path) if adapter_path.exists() else None,
            "training_examples_count": len(list(TRAINING_DATA_DIR.glob("*.json"))),
        }
|
||||
|
||||
|
||||
# Singleton instance
_trocr_service: Optional[TrOCRService] = None


def get_trocr_service(model_name: str = "trocr-base") -> TrOCRService:
    """Get or create the TrOCR service singleton."""
    global _trocr_service
    cached = _trocr_service
    # Reuse the cached instance only when it matches the requested model;
    # otherwise build (or rebuild) the singleton for the new model name.
    if cached is not None and cached.model_name == model_name:
        return cached
    _trocr_service = TrOCRService(model_name=model_name)
    return _trocr_service
|
||||
Reference in New Issue
Block a user