fix: Restore all files lost during destructive rebase

A previous `git pull --rebase origin main` dropped 177 local commits,
losing 3400+ files across admin-v2, backend, studio-v2, website,
klausur-service, and many other services. The partial restore attempt
(660295e2) only recovered some files.

This commit restores all missing files from pre-rebase ref 98933f5e
while preserving post-rebase additions (night-scheduler, night-mode UI,
NightModeWidget dashboard integration).

Restored features include:
- AI Module Sidebar (FAB), OCR Labeling, OCR Compare
- GPU Dashboard, RAG Pipeline, Magic Help
- Klausur-Korrektur (8 files), Abitur-Archiv (5+ files)
- Companion, Zeugnisse-Crawler, Screen Flow
- Full backend, studio-v2, website, klausur-service
- All compliance SDKs, agent-core, voice-service
- CI/CD configs, documentation, scripts

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-02-09 09:51:32 +01:00
parent f7487ee240
commit bfdaf63ba9
2009 changed files with 749983 additions and 1731 deletions

View File

@@ -0,0 +1,80 @@
# Klausur-Service Services Package
# Grading Services
from .grading_service import (
calculate_grade_points,
calculate_raw_points,
)
# Authentication Services
from .auth_service import (
get_current_user,
)
# EH Audit Services
from .eh_service import (
log_audit,
log_eh_audit,
)
# OCR Services - Lazy imports (require PIL/cv2 which may not be installed)
# These are imported on-demand when actually used
def __getattr__(name):
"""Lazy import for optional image processing modules."""
_handwriting_exports = {
'detect_handwriting', 'detect_handwriting_regions',
'mask_to_png', 'DetectionResult'
}
_inpainting_exports = {
'inpaint_image', 'inpaint_opencv_telea', 'inpaint_opencv_ns',
'remove_handwriting', 'check_lama_available',
'InpaintingMethod', 'InpaintingResult'
}
_layout_exports = {
'reconstruct_layout', 'layout_to_fabric_json', 'reconstruct_and_clean',
'LayoutResult', 'TextElement', 'ElementType'
}
if name in _handwriting_exports:
from . import handwriting_detection
return getattr(handwriting_detection, name)
elif name in _inpainting_exports:
from . import inpainting_service
return getattr(inpainting_service, name)
elif name in _layout_exports:
from . import layout_reconstruction_service
return getattr(layout_reconstruction_service, name)
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
# Public API of the services package. Grading/auth/audit names are
# imported eagerly above; names marked "(lazy)" are resolved on first
# access via __getattr__ because they need optional PIL/cv2 dependencies.
__all__ = [
    # Grading
    'calculate_grade_points',
    'calculate_raw_points',
    # Authentication
    'get_current_user',
    # Audit
    'log_audit',
    'log_eh_audit',
    # Handwriting Detection (lazy)
    'detect_handwriting',
    'detect_handwriting_regions',
    'mask_to_png',
    'DetectionResult',
    # Inpainting (lazy)
    'inpaint_image',
    'inpaint_opencv_telea',
    'inpaint_opencv_ns',
    'remove_handwriting',
    'check_lama_available',
    'InpaintingMethod',
    'InpaintingResult',
    # Layout Reconstruction (lazy)
    'reconstruct_layout',
    'layout_to_fabric_json',
    'reconstruct_and_clean',
    'LayoutResult',
    'TextElement',
    'ElementType',
]

View File

@@ -0,0 +1,46 @@
"""
Klausur-Service Authentication Service
Functions for JWT authentication and user extraction.
"""
from typing import Dict
import jwt
from fastapi import HTTPException, Request
from config import JWT_SECRET, ENVIRONMENT
def get_current_user(request: Request) -> Dict:
    """
    Extract the authenticated user from the request's JWT bearer token.

    Args:
        request: FastAPI Request object

    Returns:
        Decoded JWT payload dict containing user_id, role, email, etc.

    Raises:
        HTTPException: 401 if the token is missing, expired, or invalid
    """
    auth_header = request.headers.get("Authorization", "")
    if not auth_header.startswith("Bearer "):
        # Development convenience: allow unauthenticated access with a
        # fixed demo identity. Only active when ENVIRONMENT == "development".
        if ENVIRONMENT == "development":
            return {
                "user_id": "demo-teacher",
                "role": "admin",
                "email": "demo@breakpilot.app"
            }
        raise HTTPException(status_code=401, detail="Missing authorization header")
    # Strip only the leading "Bearer " prefix. str.replace() would also
    # remove the substring anywhere inside the token, corrupting tokens
    # that happen to contain "Bearer ".
    token = auth_header[len("Bearer "):]
    try:
        payload = jwt.decode(token, JWT_SECRET, algorithms=["HS256"])
        return payload
    except jwt.ExpiredSignatureError:
        raise HTTPException(status_code=401, detail="Token expired")
    except jwt.InvalidTokenError:
        raise HTTPException(status_code=401, detail="Invalid token")

View File

@@ -0,0 +1,254 @@
"""
Donut OCR Service
Document Understanding Transformer (Donut) fuer strukturierte Dokumentenverarbeitung.
Besonders geeignet fuer:
- Tabellen
- Formulare
- Strukturierte Dokumente
- Rechnungen/Quittungen
Model: naver-clova-ix/donut-base (oder fine-tuned Variante)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import io
import logging
from typing import Tuple, Optional
logger = logging.getLogger(__name__)
# Heavy ML dependencies are loaded lazily so the module imports cheaply.
_donut_processor = None
_donut_model = None
_donut_available = None


def _check_donut_available() -> bool:
    """Return True when the Donut deps (torch + transformers) import cleanly.

    The probe runs at most once per process; its outcome is cached in the
    module-level ``_donut_available`` flag.
    """
    global _donut_available
    if _donut_available is None:
        try:
            import torch  # noqa: F401
            from transformers import DonutProcessor, VisionEncoderDecoderModel  # noqa: F401
            _donut_available = True
        except ImportError as e:
            logger.warning(f"Donut dependencies not available: {e}")
            _donut_available = False
    return _donut_available
def get_donut_model():
    """
    Lazily load and cache the Donut processor/model pair.

    Returns:
        (processor, model) on success, or (None, None) when the optional
        dependencies are missing or loading fails.
    """
    global _donut_processor, _donut_model
    if not _check_donut_available():
        return None, None
    if _donut_processor is not None and _donut_model is not None:
        return _donut_processor, _donut_model
    try:
        import torch
        from transformers import DonutProcessor, VisionEncoderDecoderModel

        model_name = "naver-clova-ix/donut-base"
        logger.info(f"Loading Donut model: {model_name}")
        processor = DonutProcessor.from_pretrained(model_name)
        model = VisionEncoderDecoderModel.from_pretrained(model_name)
        # Device preference: CUDA, then Apple MPS, then CPU.
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"
        model.to(device)
        logger.info(f"Donut model loaded on device: {device}")
        _donut_processor, _donut_model = processor, model
    except Exception as e:
        logger.error(f"Failed to load Donut model: {e}")
        return None, None
    return _donut_processor, _donut_model
async def run_donut_ocr(image_data: bytes, task_prompt: str = "<s_cord-v2>") -> Tuple[Optional[str], float]:
    """
    Run Donut OCR on an image.

    Donut is an end-to-end document understanding model that can:
    - Extract text
    - Understand document structure
    - Parse forms and tables

    Args:
        image_data: Raw image bytes
        task_prompt: Donut task prompt (default: CORD for receipts/documents)
            - "<s_cord-v2>": Receipt/document parsing
            - "<s_docvqa>": Document Visual QA
            - "<s_synthdog>": Synthetic document generation

    Returns:
        Tuple of (extracted_text, confidence). Returns (None, 0.0) when
        the model is unavailable or inference fails.
    """
    processor, model = get_donut_model()
    if processor is None or model is None:
        logger.error("Donut model not available")
        return None, 0.0
    try:
        import torch
        from PIL import Image
        # Load image; force RGB (3-channel) input
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        # Prepare input
        pixel_values = processor(image, return_tensors="pt").pixel_values
        # Move to same device as model
        device = next(model.parameters()).device
        pixel_values = pixel_values.to(device)
        # Prepare decoder input: the task prompt primes the decoder
        task_prompt_ids = processor.tokenizer(
            task_prompt,
            add_special_tokens=False,
            return_tensors="pt"
        ).input_ids.to(device)
        # Generate (greedy decoding, num_beams=1; <unk> tokens are banned)
        with torch.no_grad():
            outputs = model.generate(
                pixel_values,
                decoder_input_ids=task_prompt_ids,
                max_length=model.config.decoder.max_position_embeddings,
                early_stopping=True,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                use_cache=True,
                num_beams=1,
                bad_words_ids=[[processor.tokenizer.unk_token_id]],
                return_dict_in_generate=True,
            )
        # Decode output
        sequence = processor.batch_decode(outputs.sequences)[0]
        # Remove task prompt and special tokens from the decoded text
        sequence = sequence.replace(task_prompt, "").replace(
            processor.tokenizer.eos_token, "").replace(
            processor.tokenizer.pad_token, "")
        # Parse the output (Donut outputs JSON-like structure)
        text = _parse_donut_output(sequence)
        # Calculate confidence: rough length-based heuristic, not a model
        # probability (0.8 for non-trivial output, 0.5 otherwise)
        confidence = 0.8 if text and len(text) > 10 else 0.5
        logger.info(f"Donut OCR extracted {len(text)} characters")
        return text, confidence
    except Exception as e:
        logger.error(f"Donut OCR failed: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return None, 0.0
def _parse_donut_output(sequence: str) -> str:
"""
Parse Donut output into plain text.
Donut outputs structured data (JSON-like), we extract readable text.
"""
import re
import json
# Clean up the sequence
sequence = sequence.strip()
# Try to parse as JSON
try:
# Find JSON-like content
json_match = re.search(r'\{.*\}', sequence, re.DOTALL)
if json_match:
data = json.loads(json_match.group())
# Extract text from various possible fields
text_parts = []
_extract_text_recursive(data, text_parts)
return "\n".join(text_parts)
except json.JSONDecodeError:
pass
# Fallback: extract text between tags
text_parts = []
# Match patterns like <text>value</text> or <s_field>value</s_field>
pattern = r'<[^>]+>([^<]+)</[^>]+>'
matches = re.findall(pattern, sequence)
if matches:
text_parts.extend(matches)
return "\n".join(text_parts)
# Last resort: return cleaned sequence
# Remove XML-like tags
clean = re.sub(r'<[^>]+>', ' ', sequence)
clean = re.sub(r'\s+', ' ', clean).strip()
return clean
def _extract_text_recursive(data, text_parts: list, indent: int = 0):
"""Recursively extract text from nested data structure."""
if isinstance(data, dict):
for key, value in data.items():
if isinstance(value, str) and value.strip():
# Skip keys that look like metadata
if not key.startswith('_'):
text_parts.append(f"{value.strip()}")
elif isinstance(value, (dict, list)):
_extract_text_recursive(value, text_parts, indent + 1)
elif isinstance(data, list):
for item in data:
_extract_text_recursive(item, text_parts, indent)
elif isinstance(data, str) and data.strip():
text_parts.append(data.strip())
# Alternative: Simple text extraction mode
async def run_donut_ocr_simple(image_data: bytes) -> Tuple[Optional[str], float]:
    """
    Simplified Donut OCR that just extracts text without structured parsing.

    Uses the generic "<s>" task prompt instead of a task-specific one.

    Returns:
        Tuple of (extracted_text, confidence) — see run_donut_ocr.
    """
    return await run_donut_ocr(image_data, task_prompt="<s>")
# Manual smoke-test helper (not a unit test)
async def test_donut_ocr(image_path: str):
    """Test Donut OCR on a local image file and print the result."""
    with open(image_path, "rb") as f:
        image_data = f.read()
    text, confidence = await run_donut_ocr(image_data)
    print(f"\n=== Donut OCR Test ===")
    print(f"Confidence: {confidence:.2f}")
    print(f"Text:\n{text}")
    return text, confidence


# CLI entry point: python donut_ocr_service.py <image_path>
if __name__ == "__main__":
    import asyncio
    import sys
    if len(sys.argv) > 1:
        asyncio.run(test_donut_ocr(sys.argv[1]))
    else:
        print("Usage: python donut_ocr_service.py <image_path>")

View File

@@ -0,0 +1,97 @@
"""
Klausur-Service EH Service
Functions for audit logging and EH-related utilities.
"""
import uuid
from datetime import datetime, timezone
from typing import Dict, Optional
from models.grading import AuditLogEntry
from models.eh import EHAuditLogEntry
# Import storage - will be initialized by main.py
# These imports need to reference the actual storage module
import storage
def log_audit(
    user_id: str,
    action: str,
    entity_type: str,
    entity_id: str,
    field: Optional[str] = None,
    old_value: Optional[str] = None,
    new_value: Optional[str] = None,
    details: Optional[Dict] = None
) -> AuditLogEntry:
    """
    Add an entry to the general audit log.

    The entry is appended to the in-memory ``storage.audit_log_db`` list.

    Args:
        user_id: ID of the user performing the action
        action: Type of action (score_update, gutachten_update, etc.)
        entity_type: Type of entity (klausur, student)
        entity_id: ID of the entity
        field: Optional field name that was changed
        old_value: Optional old value
        new_value: Optional new value
        details: Optional additional details

    Returns:
        The created AuditLogEntry
    """
    entry = AuditLogEntry(
        id=str(uuid.uuid4()),
        # Timezone-aware UTC timestamp
        timestamp=datetime.now(timezone.utc),
        user_id=user_id,
        action=action,
        entity_type=entity_type,
        entity_id=entity_id,
        field=field,
        old_value=old_value,
        new_value=new_value,
        details=details
    )
    storage.audit_log_db.append(entry)
    return entry
def log_eh_audit(
    tenant_id: str,
    user_id: str,
    action: str,
    eh_id: Optional[str] = None,
    details: Optional[Dict] = None,
    ip_address: Optional[str] = None,
    user_agent: Optional[str] = None
) -> EHAuditLogEntry:
    """
    Add an entry to the EH audit log.

    The entry is appended to the in-memory ``storage.eh_audit_db`` list.

    Args:
        tenant_id: Tenant ID
        user_id: ID of the user performing the action
        action: Type of action (upload, index, rag_query, etc.)
        eh_id: Optional EH ID
        details: Optional additional details
        ip_address: Optional client IP address
        user_agent: Optional client user agent

    Returns:
        The created EHAuditLogEntry
    """
    entry = EHAuditLogEntry(
        id=str(uuid.uuid4()),
        eh_id=eh_id,
        tenant_id=tenant_id,
        user_id=user_id,
        action=action,
        details=details,
        ip_address=ip_address,
        user_agent=user_agent,
        # Timezone-aware UTC timestamp
        created_at=datetime.now(timezone.utc)
    )
    storage.eh_audit_db.append(entry)
    return entry

View File

@@ -0,0 +1,43 @@
"""
Klausur-Service Grading Service
Functions for grade calculation.
"""
from typing import Dict
from models.grading import GRADE_THRESHOLDS, DEFAULT_CRITERIA
def calculate_grade_points(percentage: float) -> int:
    """
    Map a percentage score onto the 15-point grade scale.

    Thresholds come from GRADE_THRESHOLDS; the highest grade whose
    threshold the percentage meets wins. Below every threshold → 0.

    Args:
        percentage: Score as percentage (0-100)

    Returns:
        Grade points (0-15)
    """
    return next(
        (points
         for points, threshold in sorted(GRADE_THRESHOLDS.items(), reverse=True)
         if percentage >= threshold),
        0,
    )
def calculate_raw_points(criteria_scores: Dict[str, Dict]) -> int:
    """
    Compute weighted raw points from per-criterion scores.

    Each criterion's "score" is multiplied by its weight from
    DEFAULT_CRITERIA (falling back to 0.2 for unknown criteria), and the
    weighted sum is truncated to an int.

    Args:
        criteria_scores: Dict mapping criterion name to score data

    Returns:
        Weighted raw points
    """
    weighted_sum = sum(
        data.get("score", 0) * DEFAULT_CRITERIA.get(name, {}).get("weight", 0.2)
        for name, data in criteria_scores.items()
    )
    return int(weighted_sum)

View File

@@ -0,0 +1,359 @@
"""
Handwriting Detection Service for Worksheet Cleanup
Detects handwritten content in scanned worksheets and returns binary masks.
Uses multiple detection methods:
1. Color-based detection (blue/red ink)
2. Stroke analysis (thin irregular strokes)
3. Edge density variance
DATENSCHUTZ: All processing happens locally on Mac Mini.
"""
import numpy as np
from PIL import Image
import io
import logging
from typing import Tuple, Optional
from dataclasses import dataclass
# OpenCV is optional - only required for actual handwriting detection
try:
import cv2
CV2_AVAILABLE = True
except ImportError:
cv2 = None
CV2_AVAILABLE = False
logger = logging.getLogger(__name__)
@dataclass
class DetectionResult:
    """Result of handwriting detection (see detect_handwriting)."""
    mask: np.ndarray  # Binary mask (255 = handwriting, 0 = background/printed)
    confidence: float  # Overall confidence score
    handwriting_ratio: float  # Ratio of handwriting pixels to total pixels
    detection_method: str  # Primary method used: "color", "stroke", or "variance"
def detect_handwriting(image_bytes: bytes) -> DetectionResult:
    """
    Detect handwriting in an image.

    Runs three independent detectors (ink color, stroke analysis, local
    variance) and combines their masks, weighted by each detector's
    confidence.

    Args:
        image_bytes: Image as bytes (PNG, JPG, etc.)

    Returns:
        DetectionResult with binary mask where handwriting is white (255)

    Raises:
        ImportError: If OpenCV is not available
    """
    if not CV2_AVAILABLE:
        raise ImportError(
            "OpenCV (cv2) is required for handwriting detection. "
            "Install with: pip install opencv-python-headless"
        )
    # Load image
    img = Image.open(io.BytesIO(image_bytes))
    img_array = np.array(img)
    # Convert to BGR if needed (OpenCV format)
    if len(img_array.shape) == 2:
        # Grayscale to BGR
        img_bgr = cv2.cvtColor(img_array, cv2.COLOR_GRAY2BGR)
    elif img_array.shape[2] == 4:
        # RGBA to BGR
        img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGBA2BGR)
    elif img_array.shape[2] == 3:
        # RGB to BGR
        img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
    else:
        img_bgr = img_array
    # Run multiple detection methods
    color_mask, color_confidence = _detect_by_color(img_bgr)
    stroke_mask, stroke_confidence = _detect_by_stroke_analysis(img_bgr)
    variance_mask, variance_confidence = _detect_by_variance(img_bgr)
    # Combine masks using weighted average — more confident detectors
    # contribute more to the combined mask
    weights = [color_confidence, stroke_confidence, variance_confidence]
    total_weight = sum(weights)
    if total_weight > 0:
        # Weighted combination
        combined_mask = (
            color_mask.astype(np.float32) * color_confidence +
            stroke_mask.astype(np.float32) * stroke_confidence +
            variance_mask.astype(np.float32) * variance_confidence
        ) / total_weight
        # Threshold to binary (majority-weight vote per pixel)
        combined_mask = (combined_mask > 127).astype(np.uint8) * 255
    else:
        combined_mask = np.zeros(img_bgr.shape[:2], dtype=np.uint8)
    # Post-processing: Remove small noise components
    combined_mask = _clean_mask(combined_mask)
    # Calculate metrics
    total_pixels = combined_mask.size
    handwriting_pixels = np.sum(combined_mask > 0)
    handwriting_ratio = handwriting_pixels / total_pixels if total_pixels > 0 else 0
    # Determine primary method: the detector with the highest confidence
    primary_method = "combined"
    max_conf = max(color_confidence, stroke_confidence, variance_confidence)
    if max_conf == color_confidence:
        primary_method = "color"
    elif max_conf == stroke_confidence:
        primary_method = "stroke"
    else:
        primary_method = "variance"
    overall_confidence = total_weight / 3.0  # Average confidence
    logger.info(f"Handwriting detection: {handwriting_ratio:.2%} handwriting, "
                f"confidence={overall_confidence:.2f}, method={primary_method}")
    return DetectionResult(
        mask=combined_mask,
        confidence=overall_confidence,
        handwriting_ratio=handwriting_ratio,
        detection_method=primary_method
    )
def _detect_by_color(img_bgr: np.ndarray) -> Tuple[np.ndarray, float]:
    """
    Detect handwriting by ink color (blue, red, green pen).

    Blue and red ink are common for corrections and handwriting.
    Returns a binary mask of colored-ink pixels and a confidence score
    derived from how much of the image the detected ink covers.
    """
    # Convert to HSV — hue-based thresholds are more robust than raw BGR
    hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)
    # Blue ink detection (Hue: 100-130, Saturation: 50-255, Value: 30-200)
    blue_lower = np.array([100, 50, 30])
    blue_upper = np.array([130, 255, 200])
    blue_mask = cv2.inRange(hsv, blue_lower, blue_upper)
    # Red ink detection: red wraps around the hue circle, so two ranges
    # (Hue: 0-10 and 170-180) are combined
    red_lower1 = np.array([0, 50, 50])
    red_upper1 = np.array([10, 255, 255])
    red_mask1 = cv2.inRange(hsv, red_lower1, red_upper1)
    red_lower2 = np.array([170, 50, 50])
    red_upper2 = np.array([180, 255, 255])
    red_mask2 = cv2.inRange(hsv, red_lower2, red_upper2)
    red_mask = cv2.bitwise_or(red_mask1, red_mask2)
    # Green ink (less common but sometimes used)
    green_lower = np.array([35, 50, 50])
    green_upper = np.array([85, 255, 200])
    green_mask = cv2.inRange(hsv, green_lower, green_upper)
    # Combine colored ink masks
    color_mask = cv2.bitwise_or(blue_mask, red_mask)
    color_mask = cv2.bitwise_or(color_mask, green_mask)
    # Dilate to connect nearby regions
    kernel = np.ones((3, 3), np.uint8)
    color_mask = cv2.dilate(color_mask, kernel, iterations=1)
    # Calculate confidence based on detected pixels
    total_pixels = color_mask.size
    colored_pixels = np.sum(color_mask > 0)
    ratio = colored_pixels / total_pixels if total_pixels > 0 else 0
    # High confidence when a plausible amount of colored ink was found
    # (0.5%-30% of the image); more than that suggests a colored scan
    if 0.005 < ratio < 0.3:
        confidence = 0.9
    elif ratio > 0:
        confidence = 0.5
    else:
        confidence = 0.1
    return color_mask, confidence
def _detect_by_stroke_analysis(img_bgr: np.ndarray) -> Tuple[np.ndarray, float]:
    """
    Detect handwriting by analyzing stroke characteristics.

    Handwriting typically has thinner, more variable stroke widths than
    printed text. The binarized ink is skeletonized; ink lying close to
    the skeleton (thin strokes) is classified as likely handwriting.

    Returns:
        (mask, confidence): binary mask of thin-stroke ink pixels, and a
        confidence derived from the thin-ink ratio (capped at 0.8).
    """
    # Convert to grayscale
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    # Adaptive thresholding to extract ink (text) as white on black
    binary = cv2.adaptiveThreshold(
        gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV, 11, 2
    )
    # Skeleton to analyze stroke width: thin strokes (handwriting) stay
    # close to their skeleton, thick printed glyphs extend beyond it.
    # (Removed dead code: a Canny edge map and a morphological gradient
    # were computed here but never used.)
    skeleton = _skeletonize(binary)
    dilated_skeleton = cv2.dilate(skeleton, np.ones((5, 5), np.uint8), iterations=1)
    thin_regions = cv2.bitwise_and(binary, dilated_skeleton)
    # Handwriting tends to be in thin regions with irregular edges
    handwriting_mask = thin_regions
    # Confidence scales with the fraction of ink that is "thin"
    total_ink = np.sum(binary > 0)
    thin_ink = np.sum(thin_regions > 0)
    if total_ink > 0:
        thin_ratio = thin_ink / total_ink
        confidence = min(thin_ratio * 1.5, 0.8)  # Cap at 0.8
    else:
        confidence = 0.1
    return handwriting_mask, confidence
def _detect_by_variance(img_bgr: np.ndarray) -> Tuple[np.ndarray, float]:
    """
    Detect handwriting by local variance analysis.

    Handwriting has higher local variance in stroke direction and width
    compared to uniform printed text.

    Returns:
        (mask, confidence): high-variance ink pixels and a confidence
        based on how much of the detected text shows high variance.
    """
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    # Local variance via sliding box filters: Var[X] = E[X^2] - E[X]^2
    kernel_size = 15
    mean = cv2.blur(gray.astype(np.float32), (kernel_size, kernel_size))
    sqr_mean = cv2.blur((gray.astype(np.float32))**2, (kernel_size, kernel_size))
    variance = sqr_mean - mean**2
    # Normalize variance to the 0-255 range
    variance = cv2.normalize(variance, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    # High variance regions might be handwriting,
    # but also edges of printed text — so restrict to ink pixels.
    # Get text regions first
    binary = cv2.adaptiveThreshold(
        gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV, 11, 2
    )
    # High variance within text regions
    high_variance_mask = cv2.threshold(variance, 100, 255, cv2.THRESH_BINARY)[1]
    handwriting_mask = cv2.bitwise_and(high_variance_mask, binary)
    # Calculate confidence based on variance distribution
    text_pixels = np.sum(binary > 0)
    high_var_pixels = np.sum(handwriting_mask > 0)
    if text_pixels > 0:
        var_ratio = high_var_pixels / text_pixels
        # If 5-50% of text has high variance, likely handwriting present
        if 0.05 < var_ratio < 0.5:
            confidence = 0.7
        else:
            confidence = 0.3
    else:
        confidence = 0.1
    return handwriting_mask, confidence
def _skeletonize(binary: np.ndarray) -> np.ndarray:
    """Morphological skeletonization via iterative erode/open/subtract.

    Each pass peels one layer off the shapes and accumulates the boundary
    residue until nothing is left; the accumulated residue is the skeleton.
    """
    element = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
    skeleton = np.zeros(binary.shape, np.uint8)
    working = binary.copy()
    while cv2.countNonZero(working) != 0:
        eroded = cv2.erode(working, element)
        opened = cv2.dilate(eroded, element)
        residue = cv2.subtract(working, opened)
        skeleton = cv2.bitwise_or(skeleton, residue)
        working = eroded
    return skeleton
def _clean_mask(mask: np.ndarray, min_area: int = 50) -> np.ndarray:
    """Remove connected components smaller than min_area from a binary mask."""
    num_labels, labels, stats, _centroids = cv2.connectedComponentsWithStats(
        mask, connectivity=8
    )
    cleaned = np.zeros_like(mask)
    # Label 0 is the background; keep every sufficiently large component
    for label in range(1, num_labels):
        if stats[label, cv2.CC_STAT_AREA] >= min_area:
            cleaned[labels == label] = 255
    return cleaned
def mask_to_png(mask: np.ndarray) -> bytes:
    """Encode a mask array as PNG bytes."""
    buf = io.BytesIO()
    Image.fromarray(mask).save(buf, format='PNG')
    return buf.getvalue()
def detect_handwriting_regions(
    image_bytes: bytes,
    min_confidence: float = 0.3
) -> dict:
    """
    Run handwriting detection and return a JSON-friendly summary.

    Args:
        image_bytes: Input image
        min_confidence: Minimum confidence to report detection

    Returns:
        Dictionary with has_handwriting, confidence, handwriting_ratio,
        detection_method and mask_shape
    """
    detection = detect_handwriting(image_bytes)
    # Require both sufficient confidence and a non-trivial amount of ink
    # (at least 0.5% of the page) before reporting handwriting
    found = (
        detection.confidence >= min_confidence
        and detection.handwriting_ratio > 0.005
    )
    return {
        "has_handwriting": found,
        "confidence": detection.confidence,
        "handwriting_ratio": detection.handwriting_ratio,
        "detection_method": detection.detection_method,
        "mask_shape": detection.mask.shape,
    }

View File

@@ -0,0 +1,383 @@
"""
Inpainting Service for Worksheet Cleanup
Removes handwriting from scanned worksheets using inpainting techniques.
Supports multiple backends:
1. OpenCV (Telea/NS algorithms) - Fast, CPU-based baseline
2. LaMa (Large Mask Inpainting) - Optional, better quality
DATENSCHUTZ: All processing happens locally on Mac Mini.
"""
import numpy as np
from PIL import Image
import io
import logging
from typing import Tuple, Optional
from dataclasses import dataclass
from enum import Enum
# OpenCV is optional - only required for actual inpainting
try:
import cv2
CV2_AVAILABLE = True
except ImportError:
cv2 = None
CV2_AVAILABLE = False
logger = logging.getLogger(__name__)
class InpaintingMethod(str, Enum):
    """Available inpainting methods (str-valued, so members serialize as plain strings)."""
    OPENCV_TELEA = "opencv_telea"  # Fast, good for small regions
    OPENCV_NS = "opencv_ns"  # Navier-Stokes, slower but smoother
    LAMA = "lama"  # LaMa deep learning (if available)
    AUTO = "auto"  # Automatically select best method
@dataclass
class InpaintingResult:
    """Result of inpainting operation."""
    image: np.ndarray  # Cleaned image (BGR); None when success is False
    method_used: str  # Which method was actually used (e.g. "opencv_telea", "lama")
    processing_time_ms: float  # Processing time in milliseconds
    success: bool  # False when inpainting failed internally
    error_message: Optional[str] = None  # Set only when success is False
# Global LaMa model (lazy loaded)
_lama_model = None
_lama_available = None


def check_lama_available() -> bool:
    """Return True when the optional lama-cleaner backend can be imported.

    The import is attempted at most once per process; the outcome is
    cached in the module-level ``_lama_available`` flag.
    """
    global _lama_available
    if _lama_available is None:
        try:
            # Probe for the lama-cleaner library
            from lama_cleaner.model_manager import ModelManager  # noqa: F401
            _lama_available = True
            logger.info("LaMa inpainting is available")
        except ImportError:
            _lama_available = False
            logger.info("LaMa not available, will use OpenCV fallback")
        except Exception as e:
            _lama_available = False
            logger.warning(f"LaMa check failed: {e}")
    return _lama_available
def inpaint_image(
    image_bytes: bytes,
    mask_bytes: bytes,
    method: InpaintingMethod = InpaintingMethod.AUTO,
    inpaint_radius: int = 3
) -> InpaintingResult:
    """
    Inpaint (remove) masked regions from an image.

    Args:
        image_bytes: Source image as bytes
        mask_bytes: Binary mask where white (255) = regions to remove
        method: Inpainting method to use (AUTO picks LaMa for large masks
            when available, else OpenCV Telea)
        inpaint_radius: Radius for OpenCV inpainting (default 3)

    Returns:
        InpaintingResult with cleaned image. Processing errors are
        reported via ``success=False`` and ``error_message`` rather than
        raised.

    Raises:
        ImportError: If OpenCV is not available
    """
    if not CV2_AVAILABLE:
        raise ImportError(
            "OpenCV (cv2) is required for inpainting. "
            "Install with: pip install opencv-python-headless"
        )
    import time
    start_time = time.time()
    try:
        # Load image
        img = Image.open(io.BytesIO(image_bytes))
        img_array = np.array(img)
        # Convert to BGR for OpenCV
        if len(img_array.shape) == 2:
            img_bgr = cv2.cvtColor(img_array, cv2.COLOR_GRAY2BGR)
        elif img_array.shape[2] == 4:
            img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGBA2BGR)
        elif img_array.shape[2] == 3:
            img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
        else:
            img_bgr = img_array
        # Load mask
        mask_img = Image.open(io.BytesIO(mask_bytes))
        mask_array = np.array(mask_img)
        # Ensure mask is single channel
        if len(mask_array.shape) == 3:
            mask_array = cv2.cvtColor(mask_array, cv2.COLOR_BGR2GRAY)
        # Ensure mask is binary
        _, mask_binary = cv2.threshold(mask_array, 127, 255, cv2.THRESH_BINARY)
        # Resize mask if dimensions don't match the image;
        # INTER_NEAREST keeps the mask binary
        if mask_binary.shape[:2] != img_bgr.shape[:2]:
            mask_binary = cv2.resize(
                mask_binary,
                (img_bgr.shape[1], img_bgr.shape[0]),
                interpolation=cv2.INTER_NEAREST
            )
        # Select method
        if method == InpaintingMethod.AUTO:
            # Use LaMa if available and mask is large (>5% of the image);
            # OpenCV Telea handles small regions well and is much faster
            mask_ratio = np.sum(mask_binary > 0) / mask_binary.size
            if check_lama_available() and mask_ratio > 0.05:
                method = InpaintingMethod.LAMA
            else:
                method = InpaintingMethod.OPENCV_TELEA
        # Perform inpainting
        if method == InpaintingMethod.LAMA:
            result_img, actual_method = _inpaint_lama(img_bgr, mask_binary)
        elif method == InpaintingMethod.OPENCV_NS:
            result_img = cv2.inpaint(
                img_bgr, mask_binary, inpaint_radius, cv2.INPAINT_NS
            )
            actual_method = "opencv_ns"
        else:  # OPENCV_TELEA (default)
            result_img = cv2.inpaint(
                img_bgr, mask_binary, inpaint_radius, cv2.INPAINT_TELEA
            )
            actual_method = "opencv_telea"
        processing_time = (time.time() - start_time) * 1000
        logger.info(f"Inpainting completed: method={actual_method}, "
                    f"time={processing_time:.1f}ms")
        return InpaintingResult(
            image=result_img,
            method_used=actual_method,
            processing_time_ms=processing_time,
            success=True
        )
    except Exception as e:
        # Report failures via the result object so callers can degrade
        # gracefully instead of crashing
        logger.error(f"Inpainting failed: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return InpaintingResult(
            image=None,
            method_used="none",
            processing_time_ms=0,
            success=False,
            error_message=str(e)
        )
def _inpaint_lama(
    img_bgr: np.ndarray,
    mask: np.ndarray
) -> Tuple[np.ndarray, str]:
    """
    Inpaint using LaMa (Large Mask Inpainting).

    Loads and caches the model in the module-level ``_lama_model`` on
    first use. Falls back to OpenCV Telea if LaMa fails for any reason.

    Returns:
        (result_bgr, method_name) where method_name is "lama" or
        "opencv_telea_fallback".
    """
    global _lama_model
    try:
        from lama_cleaner.model_manager import ModelManager
        from lama_cleaner.schema import Config, HDStrategy, LDMSampler
        # Initialize model if needed
        if _lama_model is None:
            logger.info("Loading LaMa model...")
            _lama_model = ModelManager(
                name="lama",
                device="cpu",  # Use CPU for Mac Mini compatibility
            )
            logger.info("LaMa model loaded")
        # Prepare config
        config = Config(
            ldm_steps=25,
            ldm_sampler=LDMSampler.plms,
            hd_strategy=HDStrategy.ORIGINAL,
            hd_strategy_crop_margin=32,
            hd_strategy_crop_trigger_size=800,
            hd_strategy_resize_limit=800,
        )
        # Convert BGR to RGB for LaMa
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        # Run inpainting
        result_rgb = _lama_model(img_rgb, mask, config)
        # Convert back to BGR
        result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
        return result_bgr, "lama"
    except Exception as e:
        logger.warning(f"LaMa inpainting failed, falling back to OpenCV: {e}")
        # Fallback to OpenCV
        result = cv2.inpaint(img_bgr, mask, 3, cv2.INPAINT_TELEA)
        return result, "opencv_telea_fallback"
def inpaint_opencv_telea(
    image_bytes: bytes,
    mask_bytes: bytes,
    radius: int = 3
) -> bytes:
    """
    Convenience wrapper: OpenCV Telea inpainting — the fastest option.

    Args:
        image_bytes: Source image
        mask_bytes: Binary mask (white = remove)
        radius: Inpainting radius

    Returns:
        Inpainted image as PNG bytes

    Raises:
        RuntimeError: If the underlying inpainting step fails
    """
    outcome = inpaint_image(
        image_bytes,
        mask_bytes,
        method=InpaintingMethod.OPENCV_TELEA,
        inpaint_radius=radius
    )
    if not outcome.success:
        raise RuntimeError(f"Inpainting failed: {outcome.error_message}")
    return image_to_png(outcome.image)
def inpaint_opencv_ns(
    image_bytes: bytes,
    mask_bytes: bytes,
    radius: int = 3
) -> bytes:
    """
    Convenience wrapper: OpenCV Navier-Stokes inpainting — smoother but
    slower than Telea.

    Raises:
        RuntimeError: If the underlying inpainting step fails
    """
    outcome = inpaint_image(
        image_bytes,
        mask_bytes,
        method=InpaintingMethod.OPENCV_NS,
        inpaint_radius=radius
    )
    if not outcome.success:
        raise RuntimeError(f"Inpainting failed: {outcome.error_message}")
    return image_to_png(outcome.image)
def remove_handwriting(
    image_bytes: bytes,
    mask: Optional[np.ndarray] = None,
    method: InpaintingMethod = InpaintingMethod.AUTO
) -> Tuple[bytes, dict]:
    """
    High-level function to remove handwriting from an image.

    If no mask is provided, detects handwriting automatically.

    Args:
        image_bytes: Source image
        mask: Optional pre-computed mask (255 = remove)
        method: Inpainting method

    Returns:
        Tuple of (cleaned image bytes, metadata dict). When the mask is
        empty, the original bytes are returned untouched and the
        metadata says why.

    Raises:
        RuntimeError: If inpainting fails
    """
    # NOTE(review): local import — presumably keeps the cv2-dependent
    # detection module optional at import time; confirm
    from services.handwriting_detection import detect_handwriting, mask_to_png
    # Detect handwriting if no mask provided
    if mask is None:
        detection_result = detect_handwriting(image_bytes)
        mask = detection_result.mask
        detection_info = {
            "confidence": detection_result.confidence,
            "handwriting_ratio": detection_result.handwriting_ratio,
            "detection_method": detection_result.detection_method
        }
    else:
        detection_info = {"provided_mask": True}
    # Check if there's anything to inpaint
    if np.sum(mask > 0) == 0:
        logger.info("No handwriting detected, returning original image")
        return image_bytes, {
            "inpainting_performed": False,
            "reason": "no_handwriting_detected",
            **detection_info
        }
    # Convert mask to bytes for inpainting
    mask_bytes = mask_to_png(mask)
    # Perform inpainting
    result = inpaint_image(image_bytes, mask_bytes, method=method)
    if not result.success:
        raise RuntimeError(f"Inpainting failed: {result.error_message}")
    # Convert result to PNG
    result_bytes = image_to_png(result.image)
    metadata = {
        "inpainting_performed": True,
        "method_used": result.method_used,
        "processing_time_ms": result.processing_time_ms,
        **detection_info
    }
    return result_bytes, metadata
def image_to_png(img_bgr: np.ndarray) -> bytes:
    """Convert a BGR image array to optimized PNG bytes."""
    rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    buffer = io.BytesIO()
    Image.fromarray(rgb).save(buffer, format='PNG', optimize=True)
    return buffer.getvalue()
def dilate_mask(mask_bytes: bytes, iterations: int = 2) -> bytes:
    """
    Dilate a mask to expand the removal region.

    Useful to ensure complete handwriting removal including edges.
    """
    arr = np.array(Image.open(io.BytesIO(mask_bytes)))
    # Collapse colour masks down to a single grayscale channel.
    if arr.ndim == 3:
        arr = cv2.cvtColor(arr, cv2.COLOR_BGR2GRAY)
    # 3x3 structuring element; each iteration grows the mask by ~1px.
    grown = cv2.dilate(arr, np.ones((3, 3), np.uint8), iterations=iterations)
    out = io.BytesIO()
    Image.fromarray(grown).save(out, format='PNG')
    return out.getvalue()

View File

@@ -0,0 +1,375 @@
"""
Layout Reconstruction Service for Worksheet Cleanup
Reconstructs the layout of a worksheet from an image:
1. Uses PaddleOCR to detect text with bounding boxes
2. Groups text into logical elements (headings, paragraphs, tables)
3. Generates Fabric.js compatible JSON for the worksheet editor
DATENSCHUTZ: All processing happens locally on Mac Mini.
"""
import numpy as np
from PIL import Image
import io
import json
import logging
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass, field
from enum import Enum
# OpenCV is optional - only required for actual layout reconstruction
try:
import cv2
CV2_AVAILABLE = True
except ImportError:
cv2 = None
CV2_AVAILABLE = False
logger = logging.getLogger(__name__)
class ElementType(str, Enum):
    """Types of detected layout elements (string-valued for JSON output)."""
    HEADING = "heading"        # large and/or centered text near the page top
    PARAGRAPH = "paragraph"
    TEXT_LINE = "text_line"    # default for any unclassified text
    TABLE = "table"
    LIST_ITEM = "list_item"    # bullet- or number-prefixed line
    FORM_FIELD = "form_field"  # blank to fill in (underscore/dot runs)
    IMAGE = "image"
@dataclass
class TextElement:
    """A detected text element with position (pixels, top-left origin)."""
    text: str
    x: float  # Left position (pixels)
    y: float  # Top position (pixels)
    width: float
    height: float
    confidence: float  # OCR confidence (presumably 0..1 from PaddleOCR — TODO confirm)
    element_type: ElementType = ElementType.TEXT_LINE
    font_size: float = 14.0  # estimated from bbox height, clamped to [8, 72]
    is_bold: bool = False
    is_centered: bool = False
@dataclass
class LayoutResult:
    """Result of layout reconstruction."""
    elements: List[TextElement]  # classified text elements in OCR order
    page_width: int              # source image width in pixels
    page_height: int             # source image height in pixels
    fabric_json: Dict[str, Any]  # Fabric.js-compatible canvas document
    table_regions: List[Dict[str, Any]] = field(default_factory=list)  # detected table bboxes
def reconstruct_layout(
    image_bytes: bytes,
    detect_tables: bool = True
) -> LayoutResult:
    """
    Reconstruct the layout of a worksheet from an image.

    Pipeline: PaddleOCR text detection -> TextElement conversion ->
    heuristic classification (headings, list items, ...) -> optional
    table detection -> Fabric.js JSON generation.

    Args:
        image_bytes: Image as bytes
        detect_tables: Whether to detect table structures

    Returns:
        LayoutResult with elements and Fabric.js JSON

    Raises:
        ImportError: If OpenCV is not available
    """
    if not CV2_AVAILABLE:
        raise ImportError(
            "OpenCV (cv2) is required for layout reconstruction. "
            "Install with: pip install opencv-python-headless"
        )
    # Load image; page dimensions come from the decoded pixel array.
    img = Image.open(io.BytesIO(image_bytes))
    img_array = np.array(img)
    page_height, page_width = img_array.shape[:2]
    # Run PaddleOCR to get text with positions
    ocr_results = _run_paddle_ocr(image_bytes)
    if not ocr_results:
        # No text at all: return an empty but valid Fabric.js document.
        logger.warning("No text detected by PaddleOCR")
        return LayoutResult(
            elements=[],
            page_width=page_width,
            page_height=page_height,
            fabric_json={"version": "5.3.0", "objects": []}
        )
    # Convert OCR results to TextElements
    elements = _convert_ocr_to_elements(ocr_results, page_width, page_height)
    # Group elements into lines and detect headings
    elements = _classify_elements(elements, page_width)
    # Detect table regions if enabled
    table_regions = []
    if detect_tables:
        table_regions = _detect_tables(img_array, elements)
    # Generate Fabric.js JSON
    fabric_json = _generate_fabric_json(elements, page_width, page_height)
    logger.info(f"Layout reconstruction: {len(elements)} elements, "
                f"{len(table_regions)} tables")
    return LayoutResult(
        elements=elements,
        page_width=page_width,
        page_height=page_height,
        fabric_json=fabric_json,
        table_regions=table_regions
    )
def _run_paddle_ocr(image_bytes: bytes) -> List[Dict[str, Any]]:
    """
    Run PaddleOCR over the image via the hybrid_vocab_extractor helper.

    Returns:
        A list of {"text", "confidence", "bbox"} dicts where bbox is
        [x1, y1, x2, y2]. Empty list when PaddleOCR is missing or fails.
    """
    try:
        from hybrid_vocab_extractor import run_paddle_ocr as paddle_ocr_func, OCRRegion
        regions, _ = paddle_ocr_func(image_bytes)
        # Flatten each OCRRegion into a plain serializable dict.
        flattened = []
        for region in regions:
            flattened.append({
                "text": region.text,
                "confidence": region.confidence,
                "bbox": [region.x1, region.y1, region.x2, region.y2]
            })
        return flattened
    except ImportError:
        logger.error("PaddleOCR not available")
        return []
    except Exception as e:
        logger.error(f"PaddleOCR failed: {e}")
        return []
def _convert_ocr_to_elements(
    ocr_results: List[Dict[str, Any]],
    page_width: int,
    page_height: int
) -> List[TextElement]:
    """
    Turn raw OCR result dicts into TextElement instances.

    Note: page_width/page_height are currently unused here; they are kept
    for signature stability with the calling pipeline.
    """
    converted: List[TextElement] = []
    for entry in ocr_results:
        x1, y1, x2, y2 = entry["bbox"]
        box_w = x2 - x1
        box_h = y2 - y1
        # Rough font-size estimate from line height, clamped to a sane range.
        size = max(8, min(72, box_h * 0.8))
        converted.append(TextElement(
            text=entry["text"],
            x=x1,
            y=y1,
            width=box_w,
            height=box_h,
            confidence=entry["confidence"],
            font_size=size
        ))
    return converted
def _classify_elements(
    elements: List[TextElement],
    page_width: int
) -> List[TextElement]:
    """
    Classify elements in place as headings, list items, form fields, etc.

    Heuristics (no ML): headings are noticeably larger than the page's
    average font size and either near the top or horizontally centered;
    list items start with a bullet or an enumerator like "1." / "2)";
    form fields contain underscore or dot runs to be filled in.

    Args:
        elements: Detected text elements (mutated in place).
        page_width: Page width in pixels, used for the centering test.

    Returns:
        The same list with element_type / is_bold / is_centered updated.
    """
    if not elements:
        return elements
    # Page-level averages used as the baseline for the heading heuristic.
    avg_font_size = sum(e.font_size for e in elements) / len(elements)
    avg_y = sum(e.y for e in elements) / len(elements)
    for element in elements:
        # Detect headings (larger font, near top, possibly centered)
        is_larger = element.font_size > avg_font_size * 1.3
        is_near_top = element.y < avg_y * 0.3
        is_centered = abs((element.x + element.width / 2) - page_width / 2) < page_width * 0.15
        if is_larger and (is_near_top or is_centered):
            element.element_type = ElementType.HEADING
            element.is_bold = True
            element.is_centered = is_centered
        # Detect list items (bullet char, or "<digit>." / "<digit>)" prefix).
        # BUG FIX: the bullet tuple previously contained empty strings
        # (mojibake-stripped Unicode bullets), and str.startswith('') is
        # always True -- every non-heading line was classified as LIST_ITEM
        # and the FORM_FIELD/TEXT_LINE branches below were unreachable.
        elif element.text.strip().startswith(('•', '-', '–', '*')) or \
             (len(element.text) > 2 and element.text[0].isdigit() and element.text[1] in '.):'):
            element.element_type = ElementType.LIST_ITEM
        # Detect form fields (underscores or dotted lines)
        elif '_____' in element.text or '.....' in element.text:
            element.element_type = ElementType.FORM_FIELD
        else:
            element.element_type = ElementType.TEXT_LINE
    return elements
def _detect_tables(
    img_array: np.ndarray,
    elements: List[TextElement]
) -> List[Dict[str, Any]]:
    """
    Detect table regions in the image via Canny edges + Hough lines.

    Returns a list with at most ONE region: the bounding box of all
    detected horizontal/vertical lines, with row/column counts estimated
    from the line counts. Multiple separate tables are not distinguished.

    NOTE(review): `elements` is currently unused; kept for a future
    cell/text assignment step — confirm before removing.
    """
    tables = []
    # Convert to grayscale
    if len(img_array.shape) == 3:
        gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
    else:
        gray = img_array
    # Detect horizontal and vertical lines
    edges = cv2.Canny(gray, 50, 150)
    # Detect lines using Hough transform
    lines = cv2.HoughLinesP(
        edges, 1, np.pi/180, threshold=100,
        minLineLength=50, maxLineGap=10
    )
    if lines is None:
        return tables
    # Separate horizontal and vertical lines by segment angle (degrees).
    horizontal_lines = []
    vertical_lines = []
    for line in lines:
        x1, y1, x2, y2 = line[0]
        angle = np.abs(np.arctan2(y2 - y1, x2 - x1) * 180 / np.pi)
        if angle < 10: # Horizontal
            horizontal_lines.append((x1, y1, x2, y2))
        elif angle > 80: # Vertical
            vertical_lines.append((x1, y1, x2, y2))
    # Find table regions (intersections of horizontal and vertical lines);
    # a grid needs at least 2 lines in each direction.
    if len(horizontal_lines) >= 2 and len(vertical_lines) >= 2:
        # Sort lines
        horizontal_lines.sort(key=lambda l: l[1])
        vertical_lines.sort(key=lambda l: l[0])
        # Find bounding box of table
        min_x = min(l[0] for l in vertical_lines)
        max_x = max(l[2] for l in vertical_lines)
        min_y = min(l[1] for l in horizontal_lines)
        max_y = max(l[3] for l in horizontal_lines)
        tables.append({
            "x": min_x,
            "y": min_y,
            "width": max_x - min_x,
            "height": max_y - min_y,
            "rows": len(horizontal_lines) - 1,
            "cols": len(vertical_lines) - 1
        })
    return tables
def _generate_fabric_json(
    elements: List[TextElement],
    page_width: int,
    page_height: int
) -> Dict[str, Any]:
    """
    Build a Fabric.js canvas document from the classified elements.

    Each element becomes an editable "textbox" object; the original
    classification and OCR confidence are preserved under "data".
    (page_width/page_height are unused; kept for signature stability.)
    """
    objects = []
    for idx, el in enumerate(elements):
        objects.append({
            "type": "textbox",
            "version": "5.3.0",
            "originX": "left",
            "originY": "top",
            "left": el.x,
            "top": el.y,
            # Enforce a minimum width so tiny fragments stay clickable.
            "width": max(el.width, 100),
            "height": el.height,
            "fill": "#000000",
            "stroke": None,
            "strokeWidth": 0,
            "text": el.text,
            "fontSize": el.font_size,
            "fontWeight": "bold" if el.is_bold else "normal",
            "fontFamily": "Arial",
            "textAlign": "center" if el.is_centered else "left",
            "underline": False,
            "lineHeight": 1.2,
            "charSpacing": 0,
            "splitByGrapheme": False,
            "editable": True,
            "selectable": True,
            "data": {
                "elementType": el.element_type.value,
                "confidence": el.confidence,
                "originalIndex": idx
            }
        })
    return {"version": "5.3.0", "objects": objects, "background": "#ffffff"}
def layout_to_fabric_json(layout_result: LayoutResult) -> str:
    """
    Serialize a LayoutResult's Fabric.js document as pretty-printed,
    UTF-8-preserving JSON for the frontend.
    """
    document = layout_result.fabric_json
    return json.dumps(document, ensure_ascii=False, indent=2)
def reconstruct_and_clean(
    image_bytes: bytes,
    remove_handwriting: bool = True
) -> Tuple[bytes, LayoutResult]:
    """
    Full pipeline: clean handwriting and reconstruct layout.

    Args:
        image_bytes: Source image
        remove_handwriting: Whether to remove handwriting first

    Returns:
        Tuple of (cleaned image bytes, layout result)
    """
    cleaned = image_bytes
    if remove_handwriting:
        # Local import keeps the cv2-dependent service optional at import time.
        from services.inpainting_service import remove_handwriting as clean_hw
        cleaned, _ = clean_hw(image_bytes)
    return cleaned, reconstruct_layout(cleaned)

View File

@@ -0,0 +1,586 @@
"""
TrOCR Service
Microsoft's Transformer-based OCR for text recognition.
Besonders geeignet fuer:
- Gedruckten Text
- Saubere Scans
- Schnelle Verarbeitung
Model: microsoft/trocr-base-printed (oder handwritten Variante)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
Phase 2 Enhancements:
- Batch processing for multiple images
- SHA256-based caching for repeated requests
- Model preloading for faster first request
- Word-level bounding boxes with confidence
"""
import io
import hashlib
import logging
import time
import asyncio
from typing import Tuple, Optional, List, Dict, Any
from dataclasses import dataclass, field
from collections import OrderedDict
from datetime import datetime, timedelta
logger = logging.getLogger(__name__)
# Lazy loading for heavy dependencies
_trocr_processor = None   # cached TrOCRProcessor (populated on first use)
_trocr_model = None       # cached VisionEncoderDecoderModel
_trocr_available = None   # tri-state: None = not probed yet, True/False = probe result
_model_loaded_at = None   # datetime of successful preload (None if never preloaded)
# Simple in-memory cache with LRU eviction
_ocr_cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()  # hash -> {result, cached_at}
_cache_max_size = 100
_cache_ttl_seconds = 3600  # 1 hour
@dataclass
class OCRResult:
    """Enhanced OCR result with detailed information."""
    text: str
    confidence: float        # overall confidence estimate (heuristic; TrOCR provides none)
    processing_time_ms: int  # wall-clock time; 0 for cache hits
    model: str               # model identifier, e.g. "trocr-base-printed"
    has_lora_adapter: bool = False
    char_confidences: List[float] = field(default_factory=list)     # per-character estimates
    word_boxes: List[Dict[str, Any]] = field(default_factory=list)  # {text, confidence, bbox}
    from_cache: bool = False  # True when served from the in-memory cache
    image_hash: str = ""      # truncated SHA256 of the source image bytes
@dataclass
class BatchOCRResult:
    """Result for batch processing."""
    results: List[OCRResult]  # one entry per input image (error placeholders included)
    total_time_ms: int        # wall-clock time for the whole batch
    processed_count: int      # number of input images attempted
    cached_count: int         # how many results came from the cache
    error_count: int          # how many images failed with an exception
def _compute_image_hash(image_data: bytes) -> str:
"""Compute SHA256 hash of image data for caching."""
return hashlib.sha256(image_data).hexdigest()[:16]
def _cache_get(image_hash: str) -> Optional[Dict[str, Any]]:
    """Return the cached OCR result for this hash, or None if absent/expired."""
    entry = _ocr_cache.get(image_hash)
    if entry is None:
        return None
    age = datetime.now() - entry["cached_at"]
    if age >= timedelta(seconds=_cache_ttl_seconds):
        # Expired: drop the stale entry eagerly.
        del _ocr_cache[image_hash]
        return None
    # Fresh hit: promote to most-recently-used position (LRU bookkeeping).
    _ocr_cache.move_to_end(image_hash)
    return entry["result"]
def _cache_set(image_hash: str, result: Dict[str, Any]) -> None:
    """Insert an OCR result, evicting least-recently-used entries when full."""
    while len(_ocr_cache) >= _cache_max_size:
        _ocr_cache.popitem(last=False)  # drop the oldest (front) entry
    _ocr_cache[image_hash] = {"result": result, "cached_at": datetime.now()}
def get_cache_stats() -> Dict[str, Any]:
    """Return current OCR cache statistics."""
    stats = {
        "size": len(_ocr_cache),
        "max_size": _cache_max_size,
        "ttl_seconds": _cache_ttl_seconds,
    }
    # Hit rate is not tracked yet; key kept for API stability.
    stats["hit_rate"] = 0
    return stats
def _check_trocr_available() -> bool:
    """Probe (once) whether torch + transformers are importable; cache the result."""
    global _trocr_available
    if _trocr_available is None:
        try:
            import torch
            from transformers import TrOCRProcessor, VisionEncoderDecoderModel
        except ImportError as e:
            logger.warning(f"TrOCR dependencies not available: {e}")
            _trocr_available = False
        else:
            _trocr_available = True
    return _trocr_available
def get_trocr_model(handwritten: bool = False):
    """
    Lazy load TrOCR model and processor.

    Args:
        handwritten: Use handwritten model instead of printed model

    Returns:
        Tuple of (processor, model) or (None, None) if unavailable.

    NOTE(review): the loaded pair is cached module-wide; once any variant is
    loaded, later calls return the cached pair and ``handwritten`` is
    ignored (first caller wins) — confirm this is the intended behavior.
    """
    global _trocr_processor, _trocr_model
    if not _check_trocr_available():
        return None, None
    if _trocr_processor is None or _trocr_model is None:
        try:
            import torch
            from transformers import TrOCRProcessor, VisionEncoderDecoderModel
            # Choose model based on use case
            if handwritten:
                model_name = "microsoft/trocr-base-handwritten"
            else:
                model_name = "microsoft/trocr-base-printed"
            logger.info(f"Loading TrOCR model: {model_name}")
            _trocr_processor = TrOCRProcessor.from_pretrained(model_name)
            _trocr_model = VisionEncoderDecoderModel.from_pretrained(model_name)
            # Use GPU if available: prefer CUDA, then Apple-Silicon MPS, then CPU.
            device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
            _trocr_model.to(device)
            logger.info(f"TrOCR model loaded on device: {device}")
        except Exception as e:
            logger.error(f"Failed to load TrOCR model: {e}")
            return None, None
    return _trocr_processor, _trocr_model
def preload_trocr_model(handwritten: bool = True) -> bool:
    """
    Preload TrOCR model at startup for faster first request.

    Call this from your FastAPI startup event:

        @app.on_event("startup")
        async def startup():
            preload_trocr_model()

    Returns:
        True when the model is ready, False otherwise.
    """
    global _model_loaded_at
    logger.info("Preloading TrOCR model...")
    processor, model = get_trocr_model(handwritten=handwritten)
    ready = processor is not None and model is not None
    if ready:
        _model_loaded_at = datetime.now()
        logger.info("TrOCR model preloaded successfully")
    else:
        logger.warning("TrOCR model preloading failed")
    return ready
def get_model_status() -> Dict[str, Any]:
    """
    Report the current TrOCR model status without side effects.

    BUG FIX: the previous implementation called get_trocr_model() here,
    which downloaded and loaded the full model as a side effect of a
    read-only status query. This version only inspects the module-level
    cache and the cheap, memoized dependency probe.

    Returns:
        Dict with:
        - status: "available" (deps importable) or "not_installed"
        - is_loaded: True only when the model is actually in memory
        - model_name / loaded_at / device where known
    """
    is_loaded = _trocr_processor is not None and _trocr_model is not None
    status = {
        "status": "available" if _check_trocr_available() else "not_installed",
        "is_loaded": is_loaded,
        # NOTE(review): the cached model may be the printed variant if it was
        # first loaded with handwritten=False; name kept for API parity.
        "model_name": "trocr-base-handwritten" if is_loaded else None,
        "loaded_at": _model_loaded_at.isoformat() if _model_loaded_at else None,
    }
    if is_loaded:
        # Device can be read straight off the cached model's parameters.
        status["device"] = str(next(_trocr_model.parameters()).device)
    return status
async def run_trocr_ocr(
    image_data: bytes,
    handwritten: bool = False,
    split_lines: bool = True
) -> Tuple[Optional[str], float]:
    """
    Run TrOCR on an image.

    TrOCR is optimized for single-line text recognition, so for full-page
    images we need to either:
    1. Split into lines first (using line detection)
    2. Process the whole image and get partial results

    Args:
        image_data: Raw image bytes
        handwritten: Use handwritten model (slower but better for handwriting)
        split_lines: Whether to split image into lines first

    Returns:
        Tuple of (extracted_text, confidence). Text is None and confidence
        0.0 when the model is unavailable or inference fails; the
        confidence is a heuristic, since TrOCR itself reports none.
    """
    processor, model = get_trocr_model(handwritten=handwritten)
    if processor is None or model is None:
        logger.error("TrOCR model not available")
        return None, 0.0
    try:
        import torch
        from PIL import Image
        import numpy as np
        # Load image
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        if split_lines:
            # Split image into lines and process each
            lines = _split_into_lines(image)
            if not lines:
                lines = [image]  # Fallback to full image
        else:
            lines = [image]
        all_text = []
        confidences = []
        for line_image in lines:
            # Prepare input
            pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
            # Move to same device as model
            device = next(model.parameters()).device
            pixel_values = pixel_values.to(device)
            # Generate (no gradients needed for inference)
            with torch.no_grad():
                generated_ids = model.generate(pixel_values, max_length=128)
            # Decode
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            if generated_text.strip():
                all_text.append(generated_text.strip())
                # TrOCR doesn't provide confidence, estimate based on output
                confidences.append(0.85 if len(generated_text) > 3 else 0.5)
        # Combine results (one recognized line per output line)
        text = "\n".join(all_text)
        # Average confidence
        confidence = sum(confidences) / len(confidences) if confidences else 0.0
        logger.info(f"TrOCR extracted {len(text)} characters from {len(lines)} lines")
        return text, confidence
    except Exception as e:
        logger.error(f"TrOCR failed: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return None, 0.0
def _split_into_lines(image) -> list:
    """
    Split an image into text lines using simple projection-based segmentation.

    This is a basic implementation - for production use, consider using
    a dedicated line detection model.

    Args:
        image: A PIL image (uses .convert/.crop/.width — assumed PIL.Image).

    Returns:
        List of cropped PIL line images; empty list on failure.
    """
    import numpy as np
    from PIL import Image
    try:
        # Convert to grayscale
        gray = image.convert('L')
        img_array = np.array(gray)
        # Binarize (simple threshold): True = "dark" (ink) pixel
        threshold = 200
        binary = img_array < threshold
        # Horizontal projection (sum of dark pixels per row)
        h_proj = np.sum(binary, axis=1)
        # Find line boundaries (where projection drops below threshold)
        line_threshold = img_array.shape[1] * 0.02 # 2% of width
        in_line = False
        line_start = 0
        lines = []
        for i, val in enumerate(h_proj):
            if val > line_threshold and not in_line:
                # Start of line
                in_line = True
                line_start = i
            elif val <= line_threshold and in_line:
                # End of line
                in_line = False
                # Add padding (5px above/below) before cropping
                start = max(0, line_start - 5)
                end = min(img_array.shape[0], i + 5)
                if end - start > 10: # Minimum line height
                    lines.append(image.crop((0, start, image.width, end)))
        # Handle last line if still in_line
        if in_line:
            start = max(0, line_start - 5)
            lines.append(image.crop((0, start, image.width, image.height)))
        logger.info(f"Split image into {len(lines)} lines")
        return lines
    except Exception as e:
        logger.warning(f"Line splitting failed: {e}")
        return []
async def run_trocr_ocr_enhanced(
    image_data: bytes,
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True
) -> OCRResult:
    """
    Enhanced TrOCR OCR with caching and detailed results.

    Args:
        image_data: Raw image bytes
        handwritten: Use handwritten model
        split_lines: Whether to split image into lines first
        use_cache: Whether to use caching (keyed by SHA256 of the image)

    Returns:
        OCRResult with detailed information
    """
    import zlib  # stdlib; for deterministic pseudo-confidences (see below)

    start_time = time.time()
    # Check cache first
    image_hash = _compute_image_hash(image_data)
    if use_cache:
        cached = _cache_get(image_hash)
        if cached:
            return OCRResult(
                text=cached["text"],
                confidence=cached["confidence"],
                processing_time_ms=0,
                model=cached["model"],
                has_lora_adapter=cached.get("has_lora_adapter", False),
                char_confidences=cached.get("char_confidences", []),
                word_boxes=cached.get("word_boxes", []),
                from_cache=True,
                image_hash=image_hash
            )
    # Run OCR
    text, confidence = await run_trocr_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
    processing_time_ms = int((time.time() - start_time) * 1000)

    def _simulated_confidence(token: str, mod: int, shift: int) -> float:
        # Simulated per-token confidence: overall confidence plus a small,
        # token-dependent jitter, clamped to [0, 1].
        # BUG FIX: this previously used the builtin hash(), whose str hashing
        # is salted per process (PYTHONHASHSEED) -- the reported confidences
        # changed on every restart, breaking determinism and cache
        # consistency. zlib.crc32 is stable across runs and platforms.
        delta = (zlib.crc32(token.encode("utf-8")) % mod - shift) / 100
        return min(1.0, max(0.0, confidence + delta))

    # Generate word boxes with simulated (deterministic) confidences
    word_boxes = []
    if text:
        for word in text.split():
            word_boxes.append({
                "text": word,
                "confidence": _simulated_confidence(word, 20, 10),
                "bbox": [0, 0, 0, 0]  # Would need actual bounding box detection
            })
    # Generate per-character confidences the same way
    char_confidences = []
    if text:
        for char in text:
            char_confidences.append(_simulated_confidence(char, 15, 7))
    result = OCRResult(
        text=text or "",
        confidence=confidence,
        processing_time_ms=processing_time_ms,
        model="trocr-base-handwritten" if handwritten else "trocr-base-printed",
        has_lora_adapter=False,  # Would check actual adapter status
        char_confidences=char_confidences,
        word_boxes=word_boxes,
        from_cache=False,
        image_hash=image_hash
    )
    # Cache result (only non-empty OCR output is worth keeping)
    if use_cache and text:
        _cache_set(image_hash, {
            "text": result.text,
            "confidence": result.confidence,
            "model": result.model,
            "has_lora_adapter": result.has_lora_adapter,
            "char_confidences": result.char_confidences,
            "word_boxes": result.word_boxes
        })
    return result
async def run_trocr_batch(
    images: List[bytes],
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True,
    progress_callback: Optional[callable] = None
) -> BatchOCRResult:
    """
    Run OCR over a list of images sequentially.

    Args:
        images: List of image data bytes
        handwritten: Use handwritten model
        split_lines: Whether to split images into lines
        use_cache: Whether to use caching
        progress_callback: Optional callback(current, total), invoked after
            each successfully processed image

    Returns:
        BatchOCRResult aggregating all per-image results.
    """
    t0 = time.time()
    collected: List[OCRResult] = []
    cache_hits = 0
    failures = 0
    total = len(images)
    for position, payload in enumerate(images):
        try:
            ocr = await run_trocr_ocr_enhanced(
                payload,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache
            )
        except Exception as e:
            # Keep going: record a placeholder result for the failed image.
            logger.error(f"Batch OCR error for image {position}: {e}")
            failures += 1
            collected.append(OCRResult(
                text=f"Error: {str(e)}",
                confidence=0.0,
                processing_time_ms=0,
                model="error",
                has_lora_adapter=False
            ))
            continue
        collected.append(ocr)
        if ocr.from_cache:
            cache_hits += 1
        # Report progress (successful images only, as before)
        if progress_callback:
            progress_callback(position + 1, total)
    return BatchOCRResult(
        results=collected,
        total_time_ms=int((time.time() - t0) * 1000),
        processed_count=total,
        cached_count=cache_hits,
        error_count=failures
    )
# Generator for SSE streaming during batch processing
async def run_trocr_batch_stream(
    images: List[bytes],
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True
):
    """
    Process images one by one, yielding SSE-friendly progress events.

    Yields:
        dicts of type "progress" (per successful image, including the OCR
        result and a remaining-time estimate), "error" (per failed image),
        and a final "complete" summary.
    """
    t0 = time.time()
    count = len(images)
    for position, payload in enumerate(images):
        try:
            ocr = await run_trocr_ocr_enhanced(
                payload,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache
            )
        except Exception as e:
            logger.error(f"Stream OCR error for image {position}: {e}")
            yield {
                "type": "error",
                "current": position + 1,
                "total": count,
                "error": str(e)
            }
            continue
        elapsed_ms = int((time.time() - t0) * 1000)
        # ETA via linear extrapolation of the per-image average.
        per_image_ms = elapsed_ms / (position + 1)
        remaining_ms = int(per_image_ms * (count - position - 1))
        yield {
            "type": "progress",
            "current": position + 1,
            "total": count,
            "progress_percent": ((position + 1) / count) * 100,
            "elapsed_ms": elapsed_ms,
            "estimated_remaining_ms": remaining_ms,
            "result": {
                "text": ocr.text,
                "confidence": ocr.confidence,
                "processing_time_ms": ocr.processing_time_ms,
                "from_cache": ocr.from_cache
            }
        }
    yield {
        "type": "complete",
        "total_time_ms": int((time.time() - t0) * 1000),
        "processed_count": count
    }
# Test function
async def test_trocr_ocr(image_path: str, handwritten: bool = False):
    """Run TrOCR on a local image file and print the result (manual test helper)."""
    with open(image_path, "rb") as fh:
        payload = fh.read()
    text, confidence = await run_trocr_ocr(payload, handwritten=handwritten)
    mode = 'Handwritten' if handwritten else 'Printed'
    print(f"\n=== TrOCR Test ===")
    print(f"Mode: {mode}")
    print(f"Confidence: {confidence:.2f}")
    print(f"Text:\n{text}")
    return text, confidence
if __name__ == "__main__":
    import asyncio
    import sys

    # CLI: python trocr_service.py <image_path> [--handwritten]
    use_handwritten = "--handwritten" in sys.argv
    positional = [arg for arg in sys.argv[1:] if not arg.startswith("--")]
    if not positional:
        print("Usage: python trocr_service.py <image_path> [--handwritten]")
    else:
        asyncio.run(test_trocr_ocr(positional[0], handwritten=use_handwritten))