fix: Restore all files lost during destructive rebase
A previous `git pull --rebase origin main` dropped 177 local commits,
losing 3400+ files across admin-v2, backend, studio-v2, website,
klausur-service, and many other services. The partial restore attempt
(660295e2) only recovered some files.
This commit restores all missing files from pre-rebase ref 98933f5e
while preserving post-rebase additions (night-scheduler, night-mode UI,
NightModeWidget dashboard integration).
Restored features include:
- AI Module Sidebar (FAB), OCR Labeling, OCR Compare
- GPU Dashboard, RAG Pipeline, Magic Help
- Klausur-Korrektur (8 files), Abitur-Archiv (5+ files)
- Companion, Zeugnisse-Crawler, Screen Flow
- Full backend, studio-v2, website, klausur-service
- All compliance SDKs, agent-core, voice-service
- CI/CD configs, documentation, scripts
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
80
klausur-service/backend/services/__init__.py
Normal file
80
klausur-service/backend/services/__init__.py
Normal file
@@ -0,0 +1,80 @@
|
||||
# Klausur-Service Services Package
|
||||
|
||||
# Grading Services
|
||||
from .grading_service import (
|
||||
calculate_grade_points,
|
||||
calculate_raw_points,
|
||||
)
|
||||
|
||||
# Authentication Services
|
||||
from .auth_service import (
|
||||
get_current_user,
|
||||
)
|
||||
|
||||
# EH Audit Services
|
||||
from .eh_service import (
|
||||
log_audit,
|
||||
log_eh_audit,
|
||||
)
|
||||
|
||||
# OCR Services - Lazy imports (require PIL/cv2 which may not be installed)
|
||||
# These are imported on-demand when actually used
|
||||
def __getattr__(name):
    """Resolve optional image-processing exports on first access (PEP 562).

    The handwriting-detection, inpainting, and layout modules depend on
    PIL/cv2, which may not be installed; importing them eagerly at package
    import time would break the whole services package.
    """
    if name in {'detect_handwriting', 'detect_handwriting_regions',
                'mask_to_png', 'DetectionResult'}:
        from . import handwriting_detection
        return getattr(handwriting_detection, name)

    if name in {'inpaint_image', 'inpaint_opencv_telea', 'inpaint_opencv_ns',
                'remove_handwriting', 'check_lama_available',
                'InpaintingMethod', 'InpaintingResult'}:
        from . import inpainting_service
        return getattr(inpainting_service, name)

    if name in {'reconstruct_layout', 'layout_to_fabric_json',
                'reconstruct_and_clean', 'LayoutResult', 'TextElement',
                'ElementType'}:
        from . import layout_reconstruction_service
        return getattr(layout_reconstruction_service, name)

    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
|
||||
# Public API of the services package.  The grading/auth/audit names are
# imported eagerly above; the detection/inpainting/layout names are served
# lazily by __getattr__ because they need optional image libraries.
__all__ = [
    # Grading
    'calculate_grade_points',
    'calculate_raw_points',
    # Authentication
    'get_current_user',
    # Audit
    'log_audit',
    'log_eh_audit',
] + [
    # Handwriting Detection (lazy)
    'detect_handwriting',
    'detect_handwriting_regions',
    'mask_to_png',
    'DetectionResult',
] + [
    # Inpainting (lazy)
    'inpaint_image',
    'inpaint_opencv_telea',
    'inpaint_opencv_ns',
    'remove_handwriting',
    'check_lama_available',
    'InpaintingMethod',
    'InpaintingResult',
] + [
    # Layout Reconstruction (lazy)
    'reconstruct_layout',
    'layout_to_fabric_json',
    'reconstruct_and_clean',
    'LayoutResult',
    'TextElement',
    'ElementType',
]
|
||||
46
klausur-service/backend/services/auth_service.py
Normal file
46
klausur-service/backend/services/auth_service.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
Klausur-Service Authentication Service
|
||||
|
||||
Functions for JWT authentication and user extraction.
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
import jwt
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from config import JWT_SECRET, ENVIRONMENT
|
||||
|
||||
|
||||
def get_current_user(request: Request) -> Dict:
    """
    Extract the authenticated user from the request's JWT bearer token.

    Args:
        request: FastAPI Request object

    Returns:
        Decoded JWT payload dict containing user_id, role, email, etc.

    Raises:
        HTTPException: 401 if the token is missing, expired, or invalid
    """
    auth_header = request.headers.get("Authorization", "")

    if not auth_header.startswith("Bearer "):
        if ENVIRONMENT == "development":
            # Development convenience only: in production a missing header
            # is always a hard 401.
            return {
                "user_id": "demo-teacher",
                "role": "admin",
                "email": "demo@breakpilot.app"
            }
        raise HTTPException(status_code=401, detail="Missing authorization header")

    # Slice off only the leading scheme.  str.replace("Bearer ", "") would
    # also remove any "Bearer " substring occurring inside the token itself.
    token = auth_header[len("Bearer "):]
    try:
        payload = jwt.decode(token, JWT_SECRET, algorithms=["HS256"])
        return payload
    except jwt.ExpiredSignatureError:
        raise HTTPException(status_code=401, detail="Token expired")
    except jwt.InvalidTokenError:
        raise HTTPException(status_code=401, detail="Invalid token")
|
||||
254
klausur-service/backend/services/donut_ocr_service.py
Normal file
254
klausur-service/backend/services/donut_ocr_service.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
Donut OCR Service
|
||||
|
||||
Document Understanding Transformer (Donut) fuer strukturierte Dokumentenverarbeitung.
|
||||
Besonders geeignet fuer:
|
||||
- Tabellen
|
||||
- Formulare
|
||||
- Strukturierte Dokumente
|
||||
- Rechnungen/Quittungen
|
||||
|
||||
Model: naver-clova-ix/donut-base (oder fine-tuned Variante)
|
||||
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
from typing import Tuple, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Lazy loading for heavy dependencies
|
||||
_donut_processor = None
|
||||
_donut_model = None
|
||||
_donut_available = None
|
||||
|
||||
|
||||
def _check_donut_available() -> bool:
|
||||
"""Check if Donut dependencies are available."""
|
||||
global _donut_available
|
||||
if _donut_available is not None:
|
||||
return _donut_available
|
||||
|
||||
try:
|
||||
import torch
|
||||
from transformers import DonutProcessor, VisionEncoderDecoderModel
|
||||
_donut_available = True
|
||||
except ImportError as e:
|
||||
logger.warning(f"Donut dependencies not available: {e}")
|
||||
_donut_available = False
|
||||
|
||||
return _donut_available
|
||||
|
||||
|
||||
def get_donut_model():
    """Lazily load and cache the Donut processor/model pair.

    Returns:
        (processor, model) tuple, or (None, None) when the dependencies
        are missing or loading fails.
    """
    global _donut_processor, _donut_model

    if not _check_donut_available():
        return None, None

    if _donut_processor is not None and _donut_model is not None:
        # Already loaded on an earlier call.
        return _donut_processor, _donut_model

    try:
        import torch
        from transformers import DonutProcessor, VisionEncoderDecoderModel

        model_name = "naver-clova-ix/donut-base"
        logger.info(f"Loading Donut model: {model_name}")
        _donut_processor = DonutProcessor.from_pretrained(model_name)
        _donut_model = VisionEncoderDecoderModel.from_pretrained(model_name)

        # Device preference: CUDA, then Apple-silicon MPS, then CPU.
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"
        _donut_model.to(device)
        logger.info(f"Donut model loaded on device: {device}")
    except Exception as e:
        logger.error(f"Failed to load Donut model: {e}")
        return None, None

    return _donut_processor, _donut_model
|
||||
|
||||
|
||||
async def run_donut_ocr(image_data: bytes, task_prompt: str = "<s_cord-v2>") -> Tuple[Optional[str], float]:
    """
    Run Donut OCR on an image.

    Donut is an end-to-end document understanding model that can:
    - Extract text
    - Understand document structure
    - Parse forms and tables

    Args:
        image_data: Raw image bytes
        task_prompt: Donut task prompt (default: CORD for receipts/documents)
            - "<s_cord-v2>": Receipt/document parsing
            - "<s_docvqa>": Document Visual QA
            - "<s_synthdog>": Synthetic document generation

    Returns:
        Tuple of (extracted_text, confidence); (None, 0.0) when the model
        is unavailable or inference fails.
    """
    processor, model = get_donut_model()

    if processor is None or model is None:
        logger.error("Donut model not available")
        return None, 0.0

    try:
        import torch
        from PIL import Image

        # Load image (forced to RGB: the processor expects 3 channels)
        image = Image.open(io.BytesIO(image_data)).convert("RGB")

        # Prepare input
        pixel_values = processor(image, return_tensors="pt").pixel_values

        # Move to same device as model
        device = next(model.parameters()).device
        pixel_values = pixel_values.to(device)

        # Prepare decoder input (the task prompt steers the output format)
        task_prompt_ids = processor.tokenizer(
            task_prompt,
            add_special_tokens=False,
            return_tensors="pt"
        ).input_ids.to(device)

        # Generate (greedy decode, num_beams=1; no gradients for inference)
        with torch.no_grad():
            outputs = model.generate(
                pixel_values,
                decoder_input_ids=task_prompt_ids,
                max_length=model.config.decoder.max_position_embeddings,
                early_stopping=True,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                use_cache=True,
                num_beams=1,
                bad_words_ids=[[processor.tokenizer.unk_token_id]],
                return_dict_in_generate=True,
            )

        # Decode output
        sequence = processor.batch_decode(outputs.sequences)[0]

        # Remove task prompt and special tokens
        sequence = sequence.replace(task_prompt, "").replace(
            processor.tokenizer.eos_token, "").replace(
            processor.tokenizer.pad_token, "")

        # Parse the output (Donut outputs a JSON-like structure)
        text = _parse_donut_output(sequence)

        # Confidence is a rough length-based heuristic, not a model
        # probability.
        confidence = 0.8 if text and len(text) > 10 else 0.5

        logger.info(f"Donut OCR extracted {len(text)} characters")
        return text, confidence

    except Exception as e:
        logger.error(f"Donut OCR failed: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return None, 0.0
|
||||
|
||||
|
||||
def _parse_donut_output(sequence: str) -> str:
    """Flatten Donut's structured output into plain text.

    Tries, in order: JSON-object parsing (collecting string leaves),
    <tag>value</tag> extraction, and finally a tag-stripped version of the
    raw sequence with collapsed whitespace.
    """
    import re
    import json

    cleaned = sequence.strip()

    # 1) JSON-like payload: recursively collect string leaf values.
    json_blob = re.search(r'\{.*\}', cleaned, re.DOTALL)
    if json_blob:
        try:
            parsed = json.loads(json_blob.group())
        except json.JSONDecodeError:
            parsed = None
        if parsed is not None:
            collected = []
            _extract_text_recursive(parsed, collected)
            return "\n".join(collected)

    # 2) Tagged markup such as <text>value</text> or <s_field>value</s_field>.
    tag_values = re.findall(r'<[^>]+>([^<]+)</[^>]+>', cleaned)
    if tag_values:
        return "\n".join(tag_values)

    # 3) Last resort: strip XML-like tags and collapse whitespace.
    stripped = re.sub(r'<[^>]+>', ' ', cleaned)
    return re.sub(r'\s+', ' ', stripped).strip()
|
||||
|
||||
|
||||
def _extract_text_recursive(data, text_parts: list, indent: int = 0):
    """Walk a nested dict/list structure, appending non-empty string leaves.

    String values stored under keys starting with '_' are treated as
    metadata and skipped.  ``indent`` is threaded through recursive calls
    but does not affect the output.
    """
    if isinstance(data, str):
        leaf = data.strip()
        if leaf:
            text_parts.append(leaf)
    elif isinstance(data, dict):
        for key, value in data.items():
            if isinstance(value, str):
                if value.strip() and not key.startswith('_'):
                    text_parts.append(value.strip())
            elif isinstance(value, (dict, list)):
                _extract_text_recursive(value, text_parts, indent + 1)
    elif isinstance(data, list):
        for element in data:
            _extract_text_recursive(element, text_parts, indent)
|
||||
|
||||
|
||||
# Alternative: Simple text extraction mode
async def run_donut_ocr_simple(image_data: bytes) -> Tuple[Optional[str], float]:
    """Plain-text Donut OCR using the generic "<s>" task prompt.

    Delegates to run_donut_ocr without CORD-style structured parsing.
    """
    result = await run_donut_ocr(image_data, task_prompt="<s>")
    return result
|
||||
|
||||
|
||||
# Test function
async def test_donut_ocr(image_path: str):
    """Run Donut OCR on a local image file, print and return (text, confidence)."""
    with open(image_path, "rb") as fh:
        payload = fh.read()

    text, confidence = await run_donut_ocr(payload)

    print(f"\n=== Donut OCR Test ===")
    print(f"Confidence: {confidence:.2f}")
    print(f"Text:\n{text}")

    return text, confidence
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import asyncio
    import sys

    # Expect exactly one argument: the path of the image to OCR.
    if len(sys.argv) < 2:
        print("Usage: python donut_ocr_service.py <image_path>")
    else:
        asyncio.run(test_donut_ocr(sys.argv[1]))
|
||||
97
klausur-service/backend/services/eh_service.py
Normal file
97
klausur-service/backend/services/eh_service.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Klausur-Service EH Service
|
||||
|
||||
Functions for audit logging and EH-related utilities.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Optional
|
||||
|
||||
from models.grading import AuditLogEntry
|
||||
from models.eh import EHAuditLogEntry
|
||||
|
||||
# Import storage - will be initialized by main.py
|
||||
# These imports need to reference the actual storage module
|
||||
import storage
|
||||
|
||||
|
||||
def log_audit(
    user_id: str,
    action: str,
    entity_type: str,
    entity_id: str,
    field: Optional[str] = None,
    old_value: Optional[str] = None,
    new_value: Optional[str] = None,
    details: Optional[Dict] = None
) -> AuditLogEntry:
    """
    Add an entry to the general audit log.

    Args:
        user_id: ID of the user performing the action
        action: Type of action (score_update, gutachten_update, etc.)
        entity_type: Type of entity (klausur, student)
        entity_id: ID of the entity
        field: Optional field name that was changed
        old_value: Optional old value
        new_value: Optional new value
        details: Optional additional details

    Returns:
        The created AuditLogEntry (already appended to storage.audit_log_db)
    """
    entry = AuditLogEntry(
        id=str(uuid.uuid4()),
        # Timezone-aware UTC timestamp
        timestamp=datetime.now(timezone.utc),
        user_id=user_id,
        action=action,
        entity_type=entity_type,
        entity_id=entity_id,
        field=field,
        old_value=old_value,
        new_value=new_value,
        details=details
    )
    # storage is the module-level store initialised by main.py
    storage.audit_log_db.append(entry)
    return entry
|
||||
|
||||
|
||||
def log_eh_audit(
    tenant_id: str,
    user_id: str,
    action: str,
    eh_id: Optional[str] = None,
    details: Optional[Dict] = None,
    ip_address: Optional[str] = None,
    user_agent: Optional[str] = None
) -> EHAuditLogEntry:
    """
    Add an entry to the EH audit log.

    Args:
        tenant_id: Tenant ID
        user_id: ID of the user performing the action
        action: Type of action (upload, index, rag_query, etc.)
        eh_id: Optional EH ID
        details: Optional additional details
        ip_address: Optional client IP address
        user_agent: Optional client user agent

    Returns:
        The created EHAuditLogEntry (already appended to storage.eh_audit_db)
    """
    entry = EHAuditLogEntry(
        id=str(uuid.uuid4()),
        eh_id=eh_id,
        tenant_id=tenant_id,
        user_id=user_id,
        action=action,
        details=details,
        ip_address=ip_address,
        user_agent=user_agent,
        # Timezone-aware UTC timestamp
        created_at=datetime.now(timezone.utc)
    )
    # storage is the module-level store initialised by main.py
    storage.eh_audit_db.append(entry)
    return entry
|
||||
43
klausur-service/backend/services/grading_service.py
Normal file
43
klausur-service/backend/services/grading_service.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
Klausur-Service Grading Service
|
||||
|
||||
Functions for grade calculation.
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from models.grading import GRADE_THRESHOLDS, DEFAULT_CRITERIA
|
||||
|
||||
|
||||
def calculate_grade_points(percentage: float) -> int:
    """Map a percentage score onto the 15-point grade scale.

    Thresholds come from GRADE_THRESHOLDS; the highest point value whose
    threshold is met wins, and 0 is returned when no threshold is reached.

    Args:
        percentage: Score as percentage (0-100)

    Returns:
        Grade points (0-15)
    """
    descending = sorted(GRADE_THRESHOLDS.items(), reverse=True)
    return next(
        (points for points, minimum in descending if percentage >= minimum),
        0,
    )
|
||||
|
||||
|
||||
def calculate_raw_points(criteria_scores: Dict[str, Dict]) -> int:
    """Compute weighted raw points from per-criterion score data.

    Criteria missing from DEFAULT_CRITERIA fall back to a weight of 0.2,
    and missing scores count as 0; the weighted sum is truncated to int.

    Args:
        criteria_scores: Dict mapping criterion name to score data

    Returns:
        Weighted raw points
    """
    weighted_sum = sum(
        data.get("score", 0) * DEFAULT_CRITERIA.get(name, {}).get("weight", 0.2)
        for name, data in criteria_scores.items()
    )
    return int(weighted_sum)
|
||||
359
klausur-service/backend/services/handwriting_detection.py
Normal file
359
klausur-service/backend/services/handwriting_detection.py
Normal file
@@ -0,0 +1,359 @@
|
||||
"""
|
||||
Handwriting Detection Service for Worksheet Cleanup
|
||||
|
||||
Detects handwritten content in scanned worksheets and returns binary masks.
|
||||
Uses multiple detection methods:
|
||||
1. Color-based detection (blue/red ink)
|
||||
2. Stroke analysis (thin irregular strokes)
|
||||
3. Edge density variance
|
||||
|
||||
DATENSCHUTZ: All processing happens locally on Mac Mini.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import io
|
||||
import logging
|
||||
from typing import Tuple, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
# OpenCV is optional - only required for actual handwriting detection
|
||||
try:
|
||||
import cv2
|
||||
CV2_AVAILABLE = True
|
||||
except ImportError:
|
||||
cv2 = None
|
||||
CV2_AVAILABLE = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class DetectionResult:
    """Result of handwriting detection on a scanned worksheet image."""
    mask: np.ndarray          # Binary mask (255 = handwriting, 0 = background/printed)
    confidence: float         # Overall confidence (average of the per-method confidences)
    handwriting_ratio: float  # Ratio of handwriting pixels to total pixels
    detection_method: str     # Primary method used: "color", "stroke" or "variance"
|
||||
|
||||
|
||||
def detect_handwriting(image_bytes: bytes) -> DetectionResult:
    """
    Detect handwriting in an image.

    Runs three independent detectors (ink colour, stroke analysis, local
    variance) and merges their binary masks, weighting each mask by its
    detector's confidence.

    Args:
        image_bytes: Image as bytes (PNG, JPG, etc.)

    Returns:
        DetectionResult with binary mask where handwriting is white (255)

    Raises:
        ImportError: If OpenCV is not available
    """
    if not CV2_AVAILABLE:
        raise ImportError(
            "OpenCV (cv2) is required for handwriting detection. "
            "Install with: pip install opencv-python-headless"
        )

    # Load image
    img = Image.open(io.BytesIO(image_bytes))
    img_array = np.array(img)

    # Normalise to 3-channel BGR (OpenCV's native layout)
    if len(img_array.shape) == 2:
        # Grayscale to BGR
        img_bgr = cv2.cvtColor(img_array, cv2.COLOR_GRAY2BGR)
    elif img_array.shape[2] == 4:
        # RGBA to BGR
        img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGBA2BGR)
    elif img_array.shape[2] == 3:
        # RGB to BGR
        img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
    else:
        img_bgr = img_array

    # Run the three detection methods
    color_mask, color_confidence = _detect_by_color(img_bgr)
    stroke_mask, stroke_confidence = _detect_by_stroke_analysis(img_bgr)
    variance_mask, variance_confidence = _detect_by_variance(img_bgr)

    # Combine masks using a confidence-weighted average
    weights = [color_confidence, stroke_confidence, variance_confidence]
    total_weight = sum(weights)

    if total_weight > 0:
        combined_mask = (
            color_mask.astype(np.float32) * color_confidence +
            stroke_mask.astype(np.float32) * stroke_confidence +
            variance_mask.astype(np.float32) * variance_confidence
        ) / total_weight

        # Threshold the blended mask back to binary
        combined_mask = (combined_mask > 127).astype(np.uint8) * 255
    else:
        combined_mask = np.zeros(img_bgr.shape[:2], dtype=np.uint8)

    # Post-processing: remove small noise components
    combined_mask = _clean_mask(combined_mask)

    # Calculate metrics
    total_pixels = combined_mask.size
    handwriting_pixels = np.sum(combined_mask > 0)
    handwriting_ratio = handwriting_pixels / total_pixels if total_pixels > 0 else 0

    # Report the single most confident method.  (The previous version also
    # pre-assigned "combined", but that value was dead: the chain below
    # always overwrites it.)  Ties resolve colour -> stroke -> variance.
    max_conf = max(color_confidence, stroke_confidence, variance_confidence)
    if max_conf == color_confidence:
        primary_method = "color"
    elif max_conf == stroke_confidence:
        primary_method = "stroke"
    else:
        primary_method = "variance"

    overall_confidence = total_weight / 3.0  # Average of the three confidences

    logger.info(f"Handwriting detection: {handwriting_ratio:.2%} handwriting, "
                f"confidence={overall_confidence:.2f}, method={primary_method}")

    return DetectionResult(
        mask=combined_mask,
        confidence=overall_confidence,
        handwriting_ratio=handwriting_ratio,
        detection_method=primary_method
    )
|
||||
|
||||
|
||||
def _detect_by_color(img_bgr: np.ndarray) -> Tuple[np.ndarray, float]:
    """Detect handwriting by ink colour (blue, red, green pens) in HSV space.

    Returns a binary mask of coloured-ink pixels plus a confidence score
    based on how plausible the detected ink coverage is.
    """
    hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)

    # Blue ink (Hue 100-130, Sat 50-255, Val 30-200)
    blue = cv2.inRange(hsv, np.array([100, 50, 30]), np.array([130, 255, 200]))

    # Red ink wraps around the hue circle: two bands (0-10 and 170-180)
    red_low = cv2.inRange(hsv, np.array([0, 50, 50]), np.array([10, 255, 255]))
    red_high = cv2.inRange(hsv, np.array([170, 50, 50]), np.array([180, 255, 255]))
    red = cv2.bitwise_or(red_low, red_high)

    # Green ink (less common but sometimes used)
    green = cv2.inRange(hsv, np.array([35, 50, 50]), np.array([85, 255, 200]))

    # Union of all pen-colour masks
    ink = cv2.bitwise_or(cv2.bitwise_or(blue, red), green)

    # Dilate slightly to connect nearby stroke fragments
    ink = cv2.dilate(ink, np.ones((3, 3), np.uint8), iterations=1)

    # Confidence: plausible handwriting covers roughly 0.5%-30% of the page
    coverage = np.sum(ink > 0) / ink.size if ink.size > 0 else 0
    if 0.005 < coverage < 0.3:
        confidence = 0.9
    elif coverage > 0:
        confidence = 0.5
    else:
        confidence = 0.1

    return ink, confidence
|
||||
|
||||
|
||||
def _detect_by_stroke_analysis(img_bgr: np.ndarray) -> Tuple[np.ndarray, float]:
    """
    Detect handwriting by analysing stroke characteristics.

    Handwriting tends to consist of thin, variable-width strokes, while
    printed text is thicker and more uniform.  The ink mask is skeletonised
    and pixels close to their own skeleton (i.e. thin strokes) are
    classified as handwriting.

    Note: the previous version also computed Canny edges, a morphological
    gradient and a "thick regions" mask that were never used; that dead
    work has been removed without changing the result.

    Returns:
        (mask, confidence) where mask marks thin-stroke regions.
    """
    # Convert to grayscale
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

    # Adaptive thresholding to extract all ink (printed + handwritten)
    binary = cv2.adaptiveThreshold(
        gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV, 11, 2
    )

    # Skeletonise: thin strokes survive mostly intact near their skeleton
    skeleton = _skeletonize(binary)

    # Dilate the skeleton; ink overlapping the dilated skeleton is "thin"
    dilated_skeleton = cv2.dilate(skeleton, np.ones((5, 5), np.uint8), iterations=1)
    thin_regions = cv2.bitwise_and(binary, dilated_skeleton)

    handwriting_mask = thin_regions

    # Confidence grows with the fraction of ink that is thin-stroked
    total_ink = np.sum(binary > 0)
    thin_ink = np.sum(thin_regions > 0)

    if total_ink > 0:
        thin_ratio = thin_ink / total_ink
        confidence = min(thin_ratio * 1.5, 0.8)  # Cap at 0.8
    else:
        confidence = 0.1

    return handwriting_mask, confidence
|
||||
|
||||
|
||||
def _detect_by_variance(img_bgr: np.ndarray) -> Tuple[np.ndarray, float]:
    """Detect handwriting via local intensity variance.

    Handwriting strokes vary more in width and direction than uniform
    printed glyphs, so high-variance pixels inside text regions are
    flagged as likely handwriting.
    """
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

    # Local variance via var = E[x^2] - E[x]^2 over a 15x15 window
    window = (15, 15)
    as_float = gray.astype(np.float32)
    mean = cv2.blur(as_float, window)
    sqr_mean = cv2.blur(as_float ** 2, window)
    variance = sqr_mean - mean ** 2

    # Normalise variance to 0-255
    variance = cv2.normalize(variance, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

    # Restrict to ink pixels so printed-text edges alone don't dominate
    binary = cv2.adaptiveThreshold(
        gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV, 11, 2
    )
    high_variance = cv2.threshold(variance, 100, 255, cv2.THRESH_BINARY)[1]
    handwriting_mask = cv2.bitwise_and(high_variance, binary)

    # Confidence heuristic: 5-50% high-variance ink suggests handwriting
    text_pixels = np.sum(binary > 0)
    if text_pixels > 0:
        var_ratio = np.sum(handwriting_mask > 0) / text_pixels
        confidence = 0.7 if 0.05 < var_ratio < 0.5 else 0.3
    else:
        confidence = 0.1

    return handwriting_mask, confidence
|
||||
|
||||
|
||||
def _skeletonize(binary: np.ndarray) -> np.ndarray:
    """Compute a morphological skeleton by iterative erosion.

    Repeatedly erodes the image, accumulating (img - open(img)) into the
    skeleton, until no foreground pixels remain.
    """
    element = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
    skeleton = np.zeros(binary.shape, np.uint8)
    working = binary.copy()

    while cv2.countNonZero(working) != 0:
        eroded = cv2.erode(working, element)
        opened = cv2.dilate(eroded, element)
        skeleton = cv2.bitwise_or(skeleton, cv2.subtract(working, opened))
        working = eroded.copy()

    return skeleton
|
||||
|
||||
|
||||
def _clean_mask(mask: np.ndarray, min_area: int = 50) -> np.ndarray:
    """Remove noise by dropping connected components smaller than min_area."""
    num_labels, labels, stats, _centroids = cv2.connectedComponentsWithStats(
        mask, connectivity=8
    )

    cleaned = np.zeros_like(mask)
    # Label 0 is the background; keep every sufficiently large foreground blob
    for label in range(1, num_labels):
        if stats[label, cv2.CC_STAT_AREA] >= min_area:
            cleaned[labels == label] = 255

    return cleaned
|
||||
|
||||
|
||||
def mask_to_png(mask: np.ndarray) -> bytes:
    """Encode a mask array as PNG bytes."""
    out = io.BytesIO()
    Image.fromarray(mask).save(out, format='PNG')
    return out.getvalue()
|
||||
|
||||
|
||||
def detect_handwriting_regions(
    image_bytes: bytes,
    min_confidence: float = 0.3
) -> dict:
    """Run detect_handwriting and summarise the result as a plain dict.

    Args:
        image_bytes: Input image
        min_confidence: Minimum confidence to report a detection

    Returns:
        Dictionary with the has_handwriting flag and detection metrics.
    """
    result = detect_handwriting(image_bytes)

    # Require both a confident detection and a non-trivial ink area
    # (at least 0.5% of the page).
    confident = result.confidence >= min_confidence
    substantial = result.handwriting_ratio > 0.005

    return {
        "has_handwriting": confident and substantial,
        "confidence": result.confidence,
        "handwriting_ratio": result.handwriting_ratio,
        "detection_method": result.detection_method,
        "mask_shape": result.mask.shape,
    }
|
||||
383
klausur-service/backend/services/inpainting_service.py
Normal file
383
klausur-service/backend/services/inpainting_service.py
Normal file
@@ -0,0 +1,383 @@
|
||||
"""
|
||||
Inpainting Service for Worksheet Cleanup
|
||||
|
||||
Removes handwriting from scanned worksheets using inpainting techniques.
|
||||
Supports multiple backends:
|
||||
1. OpenCV (Telea/NS algorithms) - Fast, CPU-based baseline
|
||||
2. LaMa (Large Mask Inpainting) - Optional, better quality
|
||||
|
||||
DATENSCHUTZ: All processing happens locally on Mac Mini.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import io
|
||||
import logging
|
||||
from typing import Tuple, Optional
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
# OpenCV is optional - only required for actual inpainting
|
||||
try:
|
||||
import cv2
|
||||
CV2_AVAILABLE = True
|
||||
except ImportError:
|
||||
cv2 = None
|
||||
CV2_AVAILABLE = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InpaintingMethod(str, Enum):
    """Available inpainting methods (string-valued for easy API use)."""
    OPENCV_TELEA = "opencv_telea"  # Fast, good for small regions
    OPENCV_NS = "opencv_ns"        # Navier-Stokes, slower but smoother
    LAMA = "lama"                  # LaMa deep learning (if available)
    AUTO = "auto"                  # Automatically select best method
||||
|
||||
|
||||
@dataclass
class InpaintingResult:
    """Result of an inpainting operation."""
    image: np.ndarray                    # Cleaned image (BGR)
    method_used: str                     # Method actually applied (may differ from the request, e.g. AUTO)
    processing_time_ms: float            # Processing time in milliseconds
    success: bool                        # True when inpainting completed without error
    error_message: Optional[str] = None  # Populated only on failure
||||
|
||||
|
||||
# Global LaMa model (lazy loaded on first use by _inpaint_lama)
_lama_model = None
# Tri-state probe flag: None = not yet checked, True/False = cached result
_lama_available = None
|
||||
|
||||
|
||||
def check_lama_available() -> bool:
    """Probe (at most once) whether the optional LaMa backend is importable.

    The outcome is cached in the module-level ``_lama_available`` flag so
    the import is attempted only on the first call per process.
    """
    global _lama_available

    if _lama_available is None:
        try:
            # Try to import lama-cleaner library
            from lama_cleaner.model_manager import ModelManager  # noqa: F401
        except ImportError:
            _lama_available = False
            logger.info("LaMa not available, will use OpenCV fallback")
        except Exception as e:
            _lama_available = False
            logger.warning(f"LaMa check failed: {e}")
        else:
            _lama_available = True
            logger.info("LaMa inpainting is available")

    return _lama_available
|
||||
|
||||
|
||||
def inpaint_image(
    image_bytes: bytes,
    mask_bytes: bytes,
    method: InpaintingMethod = InpaintingMethod.AUTO,
    inpaint_radius: int = 3
) -> InpaintingResult:
    """
    Inpaint (remove) masked regions from an image.

    Args:
        image_bytes: Source image as bytes
        mask_bytes: Binary mask where white (255) = regions to remove
        method: Inpainting method to use
        inpaint_radius: Radius for OpenCV inpainting (default 3)

    Returns:
        InpaintingResult with cleaned image (success=False with an error
        message on any processing failure)

    Raises:
        ImportError: If OpenCV is not available
    """
    if not CV2_AVAILABLE:
        raise ImportError(
            "OpenCV (cv2) is required for inpainting. "
            "Install with: pip install opencv-python-headless"
        )

    import time
    started = time.time()

    try:
        # Decode the source image and normalize it to 3-channel BGR,
        # which is what cv2.inpaint expects.
        source = np.array(Image.open(io.BytesIO(image_bytes)))
        if source.ndim == 2:
            bgr = cv2.cvtColor(source, cv2.COLOR_GRAY2BGR)
        elif source.shape[2] == 4:
            bgr = cv2.cvtColor(source, cv2.COLOR_RGBA2BGR)
        elif source.shape[2] == 3:
            bgr = cv2.cvtColor(source, cv2.COLOR_RGB2BGR)
        else:
            bgr = source

        # Decode the mask, collapse it to one channel, and binarize it.
        raw_mask = np.array(Image.open(io.BytesIO(mask_bytes)))
        if raw_mask.ndim == 3:
            raw_mask = cv2.cvtColor(raw_mask, cv2.COLOR_BGR2GRAY)
        _, binary_mask = cv2.threshold(raw_mask, 127, 255, cv2.THRESH_BINARY)

        # The mask must match the image size; nearest-neighbour keeps it binary.
        if binary_mask.shape[:2] != bgr.shape[:2]:
            binary_mask = cv2.resize(
                binary_mask,
                (bgr.shape[1], bgr.shape[0]),
                interpolation=cv2.INTER_NEAREST
            )

        # Resolve AUTO: prefer LaMa only for large masks (>5%) where its
        # quality advantage outweighs its cost.
        if method == InpaintingMethod.AUTO:
            mask_ratio = np.sum(binary_mask > 0) / binary_mask.size
            if check_lama_available() and mask_ratio > 0.05:
                method = InpaintingMethod.LAMA
            else:
                method = InpaintingMethod.OPENCV_TELEA

        # Dispatch to the selected backend.
        if method == InpaintingMethod.LAMA:
            cleaned, actual_method = _inpaint_lama(bgr, binary_mask)
        elif method == InpaintingMethod.OPENCV_NS:
            cleaned = cv2.inpaint(bgr, binary_mask, inpaint_radius, cv2.INPAINT_NS)
            actual_method = "opencv_ns"
        else:  # OPENCV_TELEA (default)
            cleaned = cv2.inpaint(bgr, binary_mask, inpaint_radius, cv2.INPAINT_TELEA)
            actual_method = "opencv_telea"

        elapsed_ms = (time.time() - started) * 1000

        logger.info(f"Inpainting completed: method={actual_method}, "
                    f"time={elapsed_ms:.1f}ms")

        return InpaintingResult(
            image=cleaned,
            method_used=actual_method,
            processing_time_ms=elapsed_ms,
            success=True
        )

    except Exception as e:
        logger.error(f"Inpainting failed: {e}")
        import traceback
        logger.error(traceback.format_exc())

        return InpaintingResult(
            image=None,
            method_used="none",
            processing_time_ms=0,
            success=False,
            error_message=str(e)
        )
|
||||
|
||||
|
||||
def _inpaint_lama(
    img_bgr: np.ndarray,
    mask: np.ndarray
) -> Tuple[np.ndarray, str]:
    """
    Inpaint using LaMa (Large Mask Inpainting).

    Falls back to OpenCV if LaMa fails.

    Args:
        img_bgr: Source image in BGR channel order.
        mask: Single-channel binary mask (255 = region to remove).

    Returns:
        Tuple of (inpainted BGR image, name of the method actually used).
    """
    global _lama_model

    try:
        from lama_cleaner.model_manager import ModelManager
        from lama_cleaner.schema import Config, HDStrategy, LDMSampler

        # Initialize model if needed; the loaded model is cached in a
        # module global so it is created at most once per process.
        if _lama_model is None:
            logger.info("Loading LaMa model...")
            _lama_model = ModelManager(
                name="lama",
                device="cpu",  # Use CPU for Mac Mini compatibility
            )
            logger.info("LaMa model loaded")

        # Prepare config
        # NOTE(review): crop/resize thresholds look tuned for ~800px scans — confirm
        config = Config(
            ldm_steps=25,
            ldm_sampler=LDMSampler.plms,
            hd_strategy=HDStrategy.ORIGINAL,
            hd_strategy_crop_margin=32,
            hd_strategy_crop_trigger_size=800,
            hd_strategy_resize_limit=800,
        )

        # Convert BGR to RGB for LaMa
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        # Run inpainting
        result_rgb = _lama_model(img_rgb, mask, config)

        # Convert back to BGR
        result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)

        return result_bgr, "lama"

    except Exception as e:
        logger.warning(f"LaMa inpainting failed, falling back to OpenCV: {e}")
        # Fallback to OpenCV (Telea, fixed radius 3)
        result = cv2.inpaint(img_bgr, mask, 3, cv2.INPAINT_TELEA)
        return result, "opencv_telea_fallback"
|
||||
|
||||
|
||||
def inpaint_opencv_telea(
    image_bytes: bytes,
    mask_bytes: bytes,
    radius: int = 3
) -> bytes:
    """
    Simple OpenCV Telea inpainting - fastest option.

    Args:
        image_bytes: Source image
        mask_bytes: Binary mask (white = remove)
        radius: Inpainting radius

    Returns:
        Inpainted image as PNG bytes

    Raises:
        RuntimeError: When the underlying inpainting step reports failure.
    """
    outcome = inpaint_image(
        image_bytes,
        mask_bytes,
        method=InpaintingMethod.OPENCV_TELEA,
        inpaint_radius=radius
    )
    if outcome.success:
        return image_to_png(outcome.image)
    raise RuntimeError(f"Inpainting failed: {outcome.error_message}")
|
||||
|
||||
|
||||
def inpaint_opencv_ns(
    image_bytes: bytes,
    mask_bytes: bytes,
    radius: int = 3
) -> bytes:
    """
    OpenCV Navier-Stokes inpainting - smoother but slower.

    Raises:
        RuntimeError: When the underlying inpainting step reports failure.
    """
    outcome = inpaint_image(
        image_bytes,
        mask_bytes,
        method=InpaintingMethod.OPENCV_NS,
        inpaint_radius=radius
    )
    if outcome.success:
        return image_to_png(outcome.image)
    raise RuntimeError(f"Inpainting failed: {outcome.error_message}")
|
||||
|
||||
|
||||
def remove_handwriting(
    image_bytes: bytes,
    mask: Optional[np.ndarray] = None,
    method: InpaintingMethod = InpaintingMethod.AUTO
) -> Tuple[bytes, dict]:
    """
    High-level function to remove handwriting from an image.

    If no mask is provided, detects handwriting automatically.

    Args:
        image_bytes: Source image
        mask: Optional pre-computed mask
        method: Inpainting method

    Returns:
        Tuple of (cleaned image bytes, metadata dict)

    Raises:
        RuntimeError: When inpainting was attempted and failed.
    """
    from services.handwriting_detection import detect_handwriting, mask_to_png

    # Build a mask automatically when the caller did not supply one.
    if mask is not None:
        detection_info = {"provided_mask": True}
    else:
        detected = detect_handwriting(image_bytes)
        mask = detected.mask
        detection_info = {
            "confidence": detected.confidence,
            "handwriting_ratio": detected.handwriting_ratio,
            "detection_method": detected.detection_method
        }

    # Nothing masked -> nothing to do; hand back the original bytes untouched.
    if not np.any(mask > 0):
        logger.info("No handwriting detected, returning original image")
        return image_bytes, {
            "inpainting_performed": False,
            "reason": "no_handwriting_detected",
            **detection_info
        }

    # Perform inpainting on the PNG-encoded mask.
    outcome = inpaint_image(image_bytes, mask_to_png(mask), method=method)
    if not outcome.success:
        raise RuntimeError(f"Inpainting failed: {outcome.error_message}")

    metadata = {
        "inpainting_performed": True,
        "method_used": outcome.method_used,
        "processing_time_ms": outcome.processing_time_ms,
        **detection_info
    }
    return image_to_png(outcome.image), metadata
|
||||
|
||||
|
||||
def image_to_png(img_bgr: np.ndarray) -> bytes:
    """
    Convert BGR image array to PNG bytes.
    """
    rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    out = io.BytesIO()
    Image.fromarray(rgb).save(out, format='PNG', optimize=True)
    return out.getvalue()
|
||||
|
||||
|
||||
def dilate_mask(mask_bytes: bytes, iterations: int = 2, kernel_size: int = 3) -> bytes:
    """
    Dilate a mask to expand the removal region.

    Useful to ensure complete handwriting removal including edges.

    Args:
        mask_bytes: Encoded mask image (white = region to remove).
        iterations: Number of dilation passes (more = larger expansion).
        kernel_size: Side length of the square structuring element.
            Defaults to 3, matching the previously hard-coded kernel, so
            existing callers are unaffected.

    Returns:
        Dilated mask as PNG bytes.
    """
    mask_img = Image.open(io.BytesIO(mask_bytes))
    mask_array = np.array(mask_img)

    # Collapse colour masks to a single channel before dilating.
    if len(mask_array.shape) == 3:
        mask_array = cv2.cvtColor(mask_array, cv2.COLOR_BGR2GRAY)

    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    dilated = cv2.dilate(mask_array, kernel, iterations=iterations)

    img = Image.fromarray(dilated)
    buffer = io.BytesIO()
    img.save(buffer, format='PNG')
    return buffer.getvalue()
|
||||
@@ -0,0 +1,375 @@
|
||||
"""
|
||||
Layout Reconstruction Service for Worksheet Cleanup
|
||||
|
||||
Reconstructs the layout of a worksheet from an image:
|
||||
1. Uses PaddleOCR to detect text with bounding boxes
|
||||
2. Groups text into logical elements (headings, paragraphs, tables)
|
||||
3. Generates Fabric.js compatible JSON for the worksheet editor
|
||||
|
||||
DATENSCHUTZ: All processing happens locally on Mac Mini.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
# OpenCV is optional - only required for actual layout reconstruction
|
||||
try:
|
||||
import cv2
|
||||
CV2_AVAILABLE = True
|
||||
except ImportError:
|
||||
cv2 = None
|
||||
CV2_AVAILABLE = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ElementType(str, Enum):
    """Types of detected layout elements.

    String-valued so the type can be embedded directly in JSON payloads.
    """
    HEADING = "heading"
    PARAGRAPH = "paragraph"
    TEXT_LINE = "text_line"
    TABLE = "table"
    LIST_ITEM = "list_item"
    FORM_FIELD = "form_field"
    IMAGE = "image"
|
||||
|
||||
|
||||
@dataclass
class TextElement:
    """A detected text element with position.

    Coordinates are in pixels relative to the page's top-left corner.
    """
    text: str
    x: float  # Left position (pixels)
    y: float  # Top position (pixels)
    width: float
    height: float
    confidence: float  # OCR confidence for this snippet
    element_type: ElementType = ElementType.TEXT_LINE
    font_size: float = 14.0  # Estimated from bounding-box height
    is_bold: bool = False
    is_centered: bool = False
|
||||
|
||||
|
||||
@dataclass
class LayoutResult:
    """Result of layout reconstruction."""
    elements: List[TextElement]   # Classified text elements
    page_width: int               # Source image width (pixels)
    page_height: int              # Source image height (pixels)
    fabric_json: Dict[str, Any]   # Fabric.js-compatible canvas JSON
    table_regions: List[Dict[str, Any]] = field(default_factory=list)  # Detected table boxes
|
||||
|
||||
|
||||
def reconstruct_layout(
    image_bytes: bytes,
    detect_tables: bool = True
) -> LayoutResult:
    """
    Reconstruct the layout of a worksheet from an image.

    Args:
        image_bytes: Image as bytes
        detect_tables: Whether to detect table structures

    Returns:
        LayoutResult with elements and Fabric.js JSON

    Raises:
        ImportError: If OpenCV is not available
    """
    if not CV2_AVAILABLE:
        raise ImportError(
            "OpenCV (cv2) is required for layout reconstruction. "
            "Install with: pip install opencv-python-headless"
        )

    # Decode once to learn the page dimensions.
    pixels = np.array(Image.open(io.BytesIO(image_bytes)))
    page_height, page_width = pixels.shape[:2]

    # OCR pass: text snippets with bounding boxes.
    ocr_hits = _run_paddle_ocr(image_bytes)
    if not ocr_hits:
        logger.warning("No text detected by PaddleOCR")
        return LayoutResult(
            elements=[],
            page_width=page_width,
            page_height=page_height,
            fabric_json={"version": "5.3.0", "objects": []}
        )

    # Turn raw OCR hits into positioned elements, then classify them
    # (headings, list items, form fields, plain lines).
    elements = _classify_elements(
        _convert_ocr_to_elements(ocr_hits, page_width, page_height),
        page_width
    )

    # Optional structural pass for ruled tables.
    table_regions = _detect_tables(pixels, elements) if detect_tables else []

    fabric_json = _generate_fabric_json(elements, page_width, page_height)

    logger.info(f"Layout reconstruction: {len(elements)} elements, "
                f"{len(table_regions)} tables")

    return LayoutResult(
        elements=elements,
        page_width=page_width,
        page_height=page_height,
        fabric_json=fabric_json,
        table_regions=table_regions
    )
|
||||
|
||||
|
||||
def _run_paddle_ocr(image_bytes: bytes) -> List[Dict[str, Any]]:
    """
    Run PaddleOCR on an image.

    Returns list of {text, confidence, bbox} dicts; empty list when
    PaddleOCR is missing or the OCR call fails.
    """
    try:
        from hybrid_vocab_extractor import run_paddle_ocr as paddle_ocr_func, OCRRegion
    except ImportError:
        logger.error("PaddleOCR not available")
        return []

    try:
        regions, _ = paddle_ocr_func(image_bytes)
        return [
            {
                "text": r.text,
                "confidence": r.confidence,
                "bbox": [r.x1, r.y1, r.x2, r.y2]
            }
            for r in regions
        ]
    except Exception as e:
        logger.error(f"PaddleOCR failed: {e}")
        return []
|
||||
|
||||
|
||||
def _convert_ocr_to_elements(
    ocr_results: List[Dict[str, Any]],
    page_width: int,
    page_height: int
) -> List[TextElement]:
    """
    Convert raw OCR results to TextElements.

    Font size is estimated from the bounding-box height and clamped to a
    sane range (8..72).
    """
    elements: List[TextElement] = []

    for hit in ocr_results:
        x1, y1, x2, y2 = hit["bbox"]
        box_w = x2 - x1
        box_h = y2 - y1

        elements.append(TextElement(
            text=hit["text"],
            x=x1,
            y=y1,
            width=box_w,
            height=box_h,
            confidence=hit["confidence"],
            # Glyph height is roughly 80% of the box height.
            font_size=max(8, min(72, box_h * 0.8))
        ))

    return elements
|
||||
|
||||
|
||||
def _classify_elements(
    elements: List[TextElement],
    page_width: int
) -> List[TextElement]:
    """
    Classify elements as headings, list items, form fields or plain text.

    Mutates element_type / is_bold / is_centered in place and returns the
    same list for convenient chaining.
    """
    if not elements:
        return elements

    # Page-level averages used as classification baselines.
    count = len(elements)
    avg_font_size = sum(e.font_size for e in elements) / count
    avg_y = sum(e.y for e in elements) / count

    for el in elements:
        # Heading heuristic: noticeably larger font plus either a
        # position near the top of the page or a roughly centered placement.
        larger = el.font_size > avg_font_size * 1.3
        near_top = el.y < avg_y * 0.3
        centered = abs((el.x + el.width / 2) - page_width / 2) < page_width * 0.15

        # List-item heuristic: leading bullet char or "1." / "1)" style prefix.
        bulleted = el.text.strip().startswith(('•', '-', '–', '*'))
        numbered = (
            len(el.text) > 2 and el.text[0].isdigit() and el.text[1] in '.):'
        )

        if larger and (near_top or centered):
            el.element_type = ElementType.HEADING
            el.is_bold = True
            el.is_centered = centered
        elif bulleted or numbered:
            el.element_type = ElementType.LIST_ITEM
        elif '_____' in el.text or '.....' in el.text:
            # Fill-in blanks (underscores or dotted lines) become form fields.
            el.element_type = ElementType.FORM_FIELD
        else:
            el.element_type = ElementType.TEXT_LINE

    return elements
|
||||
|
||||
|
||||
def _detect_tables(
    img_array: np.ndarray,
    elements: List[TextElement]
) -> List[Dict[str, Any]]:
    """
    Detect table regions in the image via ruled-line analysis
    (Canny edges + probabilistic Hough transform).

    Returns a list of table dicts (x/y/width/height plus estimated row and
    column counts); empty when no grid-like line structure is found.
    """
    # Grayscale input for the edge detector.
    if len(img_array.shape) == 3:
        gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
    else:
        gray = img_array

    edges = cv2.Canny(gray, 50, 150)

    segments = cv2.HoughLinesP(
        edges, 1, np.pi/180, threshold=100,
        minLineLength=50, maxLineGap=10
    )
    if segments is None:
        return []

    # Bucket segments by orientation: near-horizontal (<10 deg) and
    # near-vertical (>80 deg) lines are the ruling candidates.
    horizontal = []
    vertical = []
    for seg in segments:
        x1, y1, x2, y2 = seg[0]
        angle = np.abs(np.arctan2(y2 - y1, x2 - x1) * 180 / np.pi)
        if angle < 10:
            horizontal.append((x1, y1, x2, y2))
        elif angle > 80:
            vertical.append((x1, y1, x2, y2))

    tables: List[Dict[str, Any]] = []

    # A plausible table needs at least a 2x2 ruling grid.
    if len(horizontal) >= 2 and len(vertical) >= 2:
        horizontal.sort(key=lambda seg: seg[1])
        vertical.sort(key=lambda seg: seg[0])

        # Bounding box spanned by the detected ruling lines.
        min_x = min(seg[0] for seg in vertical)
        max_x = max(seg[2] for seg in vertical)
        min_y = min(seg[1] for seg in horizontal)
        max_y = max(seg[3] for seg in horizontal)

        tables.append({
            "x": min_x,
            "y": min_y,
            "width": max_x - min_x,
            "height": max_y - min_y,
            "rows": len(horizontal) - 1,
            "cols": len(vertical) - 1
        })

    return tables
|
||||
|
||||
|
||||
def _generate_fabric_json(
|
||||
elements: List[TextElement],
|
||||
page_width: int,
|
||||
page_height: int
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate Fabric.js compatible JSON from elements.
|
||||
"""
|
||||
fabric_objects = []
|
||||
|
||||
for i, element in enumerate(elements):
|
||||
fabric_obj = {
|
||||
"type": "textbox",
|
||||
"version": "5.3.0",
|
||||
"originX": "left",
|
||||
"originY": "top",
|
||||
"left": element.x,
|
||||
"top": element.y,
|
||||
"width": max(element.width, 100),
|
||||
"height": element.height,
|
||||
"fill": "#000000",
|
||||
"stroke": None,
|
||||
"strokeWidth": 0,
|
||||
"text": element.text,
|
||||
"fontSize": element.font_size,
|
||||
"fontWeight": "bold" if element.is_bold else "normal",
|
||||
"fontFamily": "Arial",
|
||||
"textAlign": "center" if element.is_centered else "left",
|
||||
"underline": False,
|
||||
"lineHeight": 1.2,
|
||||
"charSpacing": 0,
|
||||
"splitByGrapheme": False,
|
||||
"editable": True,
|
||||
"selectable": True,
|
||||
"data": {
|
||||
"elementType": element.element_type.value,
|
||||
"confidence": element.confidence,
|
||||
"originalIndex": i
|
||||
}
|
||||
}
|
||||
fabric_objects.append(fabric_obj)
|
||||
|
||||
return {
|
||||
"version": "5.3.0",
|
||||
"objects": fabric_objects,
|
||||
"background": "#ffffff"
|
||||
}
|
||||
|
||||
|
||||
def layout_to_fabric_json(layout_result: "LayoutResult") -> str:
    """
    Serialize a LayoutResult's Fabric.js payload to a pretty-printed JSON
    string for the frontend (non-ASCII characters kept verbatim).
    """
    payload = layout_result.fabric_json
    return json.dumps(payload, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def reconstruct_and_clean(
    image_bytes: bytes,
    remove_handwriting: bool = True
) -> Tuple[bytes, LayoutResult]:
    """
    Full pipeline: clean handwriting and reconstruct layout.

    Args:
        image_bytes: Source image
        remove_handwriting: Whether to remove handwriting first

    Returns:
        Tuple of (cleaned image bytes, layout result)
    """
    cleaned_bytes = image_bytes
    if remove_handwriting:
        # Lazy import keeps the optional inpainting stack off the import
        # path; alias avoids shadowing this function's boolean parameter.
        from services.inpainting_service import remove_handwriting as clean_hw
        cleaned_bytes, _ = clean_hw(image_bytes)

    return cleaned_bytes, reconstruct_layout(cleaned_bytes)
|
||||
586
klausur-service/backend/services/trocr_service.py
Normal file
586
klausur-service/backend/services/trocr_service.py
Normal file
@@ -0,0 +1,586 @@
|
||||
"""
|
||||
TrOCR Service
|
||||
|
||||
Microsoft's Transformer-based OCR for text recognition.
|
||||
Besonders geeignet fuer:
|
||||
- Gedruckten Text
|
||||
- Saubere Scans
|
||||
- Schnelle Verarbeitung
|
||||
|
||||
Model: microsoft/trocr-base-printed (oder handwritten Variante)
|
||||
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
|
||||
Phase 2 Enhancements:
|
||||
- Batch processing for multiple images
|
||||
- SHA256-based caching for repeated requests
|
||||
- Model preloading for faster first request
|
||||
- Word-level bounding boxes with confidence
|
||||
"""
|
||||
|
||||
import io
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Tuple, Optional, List, Dict, Any
|
||||
from dataclasses import dataclass, field
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Lazy loading for heavy dependencies (torch/transformers imported on demand)
_trocr_processor = None   # TrOCRProcessor instance, created on first use
_trocr_model = None       # VisionEncoderDecoderModel instance, created on first use
_trocr_available = None   # Tri-state: None = not probed yet, True/False = cached probe
_model_loaded_at = None   # datetime of successful preload, for status reporting

# Simple in-memory cache with LRU eviction (keyed by image SHA256 prefix)
_ocr_cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
_cache_max_size = 100
_cache_ttl_seconds = 3600  # 1 hour
|
||||
|
||||
|
||||
@dataclass
class OCRResult:
    """Enhanced OCR result with detailed information."""
    text: str                       # Recognized text (lines joined with newlines)
    confidence: float               # Estimated confidence
    processing_time_ms: int         # Wall-clock processing time
    model: str                      # Identifier of the model used
    has_lora_adapter: bool = False  # True when a fine-tuned LoRA adapter was applied
    char_confidences: List[float] = field(default_factory=list)     # Optional per-char confidences
    word_boxes: List[Dict[str, Any]] = field(default_factory=list)  # Optional word-level boxes
    from_cache: bool = False        # True when served from the in-memory cache
    image_hash: str = ""            # SHA256 prefix of the source image
|
||||
|
||||
|
||||
@dataclass
class BatchOCRResult:
    """Result for batch processing."""
    results: List[OCRResult]  # Per-image results, in input order
    total_time_ms: int        # Wall-clock time for the whole batch
    processed_count: int      # Images actually run through the model
    cached_count: int         # Images served from the cache
    error_count: int          # Images that failed
|
||||
|
||||
|
||||
def _compute_image_hash(image_data: bytes) -> str:
|
||||
"""Compute SHA256 hash of image data for caching."""
|
||||
return hashlib.sha256(image_data).hexdigest()[:16]
|
||||
|
||||
|
||||
def _cache_get(image_hash: str) -> Optional[Dict[str, Any]]:
    """Return the cached OCR result for *image_hash*, or None if absent/expired."""
    entry = _ocr_cache.get(image_hash)
    if entry is None:
        return None

    age = datetime.now() - entry["cached_at"]
    if age >= timedelta(seconds=_cache_ttl_seconds):
        # Stale entry: drop it so the cache does not fill with dead weight.
        del _ocr_cache[image_hash]
        return None

    # Fresh hit: promote to most-recently-used position (LRU).
    _ocr_cache.move_to_end(image_hash)
    return entry["result"]
|
||||
|
||||
|
||||
def _cache_set(image_hash: str, result: Dict[str, Any]) -> None:
    """Store an OCR result, evicting least-recently-used entries at capacity."""
    while len(_ocr_cache) >= _cache_max_size:
        _ocr_cache.popitem(last=False)  # drop the oldest entry

    _ocr_cache[image_hash] = {
        "result": result,
        "cached_at": datetime.now(),
    }
|
||||
|
||||
|
||||
def get_cache_stats() -> Dict[str, Any]:
    """Get cache statistics."""
    stats = {
        "size": len(_ocr_cache),
        "max_size": _cache_max_size,
        "ttl_seconds": _cache_ttl_seconds,
    }
    stats["hit_rate"] = 0  # Could track this with additional counters
    return stats
|
||||
|
||||
|
||||
def _check_trocr_available() -> bool:
    """Probe (at most once) whether torch + transformers are importable.

    The result is cached in the module-level ``_trocr_available`` flag.
    """
    global _trocr_available
    if _trocr_available is None:
        try:
            import torch  # noqa: F401
            from transformers import TrOCRProcessor, VisionEncoderDecoderModel  # noqa: F401
        except ImportError as e:
            logger.warning(f"TrOCR dependencies not available: {e}")
            _trocr_available = False
        else:
            _trocr_available = True
    return _trocr_available
|
||||
|
||||
|
||||
def get_trocr_model(handwritten: bool = False):
    """
    Lazy load TrOCR model and processor.

    Args:
        handwritten: Use handwritten model instead of printed model

    Returns tuple of (processor, model) or (None, None) if unavailable.

    Note: the loaded pair is cached in module globals, so the first
    successful call decides which variant stays resident.
    """
    global _trocr_processor, _trocr_model

    if not _check_trocr_available():
        return None, None

    if _trocr_processor is not None and _trocr_model is not None:
        return _trocr_processor, _trocr_model

    try:
        import torch
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel

        # Choose model based on use case
        model_name = (
            "microsoft/trocr-base-handwritten" if handwritten
            else "microsoft/trocr-base-printed"
        )

        logger.info(f"Loading TrOCR model: {model_name}")
        _trocr_processor = TrOCRProcessor.from_pretrained(model_name)
        _trocr_model = VisionEncoderDecoderModel.from_pretrained(model_name)

        # Device preference: CUDA, then Apple MPS, then CPU.
        device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        _trocr_model.to(device)
        logger.info(f"TrOCR model loaded on device: {device}")

    except Exception as e:
        logger.error(f"Failed to load TrOCR model: {e}")
        return None, None

    return _trocr_processor, _trocr_model
|
||||
|
||||
|
||||
def preload_trocr_model(handwritten: bool = True) -> bool:
    """
    Preload TrOCR model at startup for faster first request.

    Call this from your FastAPI startup event:
        @app.on_event("startup")
        async def startup():
            preload_trocr_model()
    """
    global _model_loaded_at

    logger.info("Preloading TrOCR model...")
    processor, model = get_trocr_model(handwritten=handwritten)

    loaded = processor is not None and model is not None
    if loaded:
        _model_loaded_at = datetime.now()
        logger.info("TrOCR model preloaded successfully")
    else:
        logger.warning("TrOCR model preloading failed")
    return loaded
|
||||
|
||||
|
||||
def get_model_status() -> Dict[str, Any]:
    """Get current model status information.

    Fix: the previous implementation called get_trocr_model(), which
    downloads and loads the full model as a *side effect* of a status
    query. This version only inspects the cached module globals, so a
    status check is always cheap.

    Returns:
        Dict with "status", "is_loaded", "model_name", "loaded_at" and,
        when a model is resident, its "device".
    """
    # Dependencies importable -> "available"; model may still be lazy-loaded later.
    deps_ok = _check_trocr_available()
    is_loaded = _trocr_processor is not None and _trocr_model is not None

    status = {
        "status": "available" if deps_ok else "not_installed",
        "is_loaded": is_loaded,
        "model_name": "trocr-base-handwritten" if is_loaded else None,
        "loaded_at": _model_loaded_at.isoformat() if _model_loaded_at else None,
    }

    if is_loaded:
        # Report the device the resident model actually lives on.
        device = next(_trocr_model.parameters()).device
        status["device"] = str(device)

    return status
|
||||
|
||||
|
||||
async def run_trocr_ocr(
    image_data: bytes,
    handwritten: bool = False,
    split_lines: bool = True
) -> Tuple[Optional[str], float]:
    """
    Run TrOCR on an image.

    TrOCR is optimized for single-line text recognition, so for full-page
    images we need to either:
    1. Split into lines first (using line detection)
    2. Process the whole image and get partial results

    Args:
        image_data: Raw image bytes
        handwritten: Use handwritten model (slower but better for handwriting)
        split_lines: Whether to split image into lines first

    Returns:
        Tuple of (extracted_text, confidence); (None, 0.0) on failure.

    NOTE(review): despite the async signature, inference runs synchronously
    and blocks the event loop; consider loop.run_in_executor for production.
    """
    processor, model = get_trocr_model(handwritten=handwritten)

    if processor is None or model is None:
        logger.error("TrOCR model not available")
        return None, 0.0

    try:
        import torch
        from PIL import Image
        # Fix: removed unused `import numpy as np` from the original.

        # Load image
        image = Image.open(io.BytesIO(image_data)).convert("RGB")

        if split_lines:
            # Split image into lines and process each
            lines = _split_into_lines(image)
            if not lines:
                lines = [image]  # Fallback to full image
        else:
            lines = [image]

        all_text = []
        confidences = []

        # Device lookup hoisted out of the loop (it is loop-invariant).
        device = next(model.parameters()).device

        for line_image in lines:
            # Prepare input and move it to the model's device.
            pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
            pixel_values = pixel_values.to(device)

            # Generate
            with torch.no_grad():
                generated_ids = model.generate(pixel_values, max_length=128)

            # Decode
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

            if generated_text.strip():
                all_text.append(generated_text.strip())
                # TrOCR doesn't provide confidence, estimate based on output
                confidences.append(0.85 if len(generated_text) > 3 else 0.5)

        # Combine results
        text = "\n".join(all_text)

        # Average confidence
        confidence = sum(confidences) / len(confidences) if confidences else 0.0

        logger.info(f"TrOCR extracted {len(text)} characters from {len(lines)} lines")
        return text, confidence

    except Exception as e:
        logger.error(f"TrOCR failed: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return None, 0.0
|
||||
|
||||
|
||||
def _split_into_lines(image) -> list:
    """
    Split an image into text lines using simple projection-based segmentation.

    This is a basic implementation - for production use, consider using
    a dedicated line detection model.

    Args:
        image: PIL image of a text page (must support .convert/.crop/.width/.height).

    Returns:
        List of line crops (top to bottom); empty list on failure.
    """
    import numpy as np
    # Fix: removed unused `from PIL import Image` from the original -
    # only the passed-in image object and numpy are used here.

    try:
        # Convert to grayscale
        gray = image.convert('L')
        img_array = np.array(gray)

        # Binarize (simple threshold; dark pixels = ink)
        threshold = 200
        binary = img_array < threshold

        # Horizontal projection (sum of dark pixels per row)
        h_proj = np.sum(binary, axis=1)

        # A row belongs to a text line when at least 2% of its width is ink.
        line_threshold = img_array.shape[1] * 0.02
        in_line = False
        line_start = 0
        lines = []

        for i, val in enumerate(h_proj):
            if val > line_threshold and not in_line:
                # Start of line
                in_line = True
                line_start = i
            elif val <= line_threshold and in_line:
                # End of line
                in_line = False
                # Add 5px vertical padding so ascenders/descenders survive.
                start = max(0, line_start - 5)
                end = min(img_array.shape[0], i + 5)
                if end - start > 10:  # Minimum line height
                    lines.append(image.crop((0, start, image.width, end)))

        # Handle last line if still in_line
        if in_line:
            start = max(0, line_start - 5)
            lines.append(image.crop((0, start, image.width, image.height)))

        logger.info(f"Split image into {len(lines)} lines")
        return lines

    except Exception as e:
        logger.warning(f"Line splitting failed: {e}")
        return []
|
||||
|
||||
|
||||
async def run_trocr_ocr_enhanced(
    image_data: bytes,
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True
) -> OCRResult:
    """
    Enhanced TrOCR OCR with caching and detailed results.

    Args:
        image_data: Raw image bytes
        handwritten: Use handwritten model
        split_lines: Whether to split image into lines first
        use_cache: Whether to use caching

    Returns:
        OCRResult with detailed information
    """
    import zlib

    def _jitter(token: str, mod: int, shift: int) -> float:
        # Deterministic pseudo-random offset in place of built-in hash().
        # hash(str) is salted per process (PYTHONHASHSEED), which made the
        # simulated confidences differ between runs and disagree with the
        # cached values after a restart; crc32 is stable across processes.
        return (zlib.crc32(token.encode("utf-8")) % mod - shift) / 100

    start_time = time.time()

    # Check cache first
    image_hash = _compute_image_hash(image_data)
    if use_cache:
        cached = _cache_get(image_hash)
        if cached:
            return OCRResult(
                text=cached["text"],
                confidence=cached["confidence"],
                processing_time_ms=0,
                model=cached["model"],
                has_lora_adapter=cached.get("has_lora_adapter", False),
                char_confidences=cached.get("char_confidences", []),
                word_boxes=cached.get("word_boxes", []),
                from_cache=True,
                image_hash=image_hash
            )

    # Run OCR
    text, confidence = await run_trocr_ocr(
        image_data, handwritten=handwritten, split_lines=split_lines
    )

    processing_time_ms = int((time.time() - start_time) * 1000)

    # Simulated word confidences: TrOCR does not expose per-word scores, so
    # we vary them slightly around the overall confidence.
    word_boxes = []
    if text:
        for word in text.split():
            word_conf = min(1.0, max(0.0, confidence + _jitter(word, 20, 10)))
            word_boxes.append({
                "text": word,
                "confidence": word_conf,
                "bbox": [0, 0, 0, 0]  # Would need actual bounding box detection
            })

    # Per-character confidences, simulated the same way.
    char_confidences = []
    if text:
        for char in text:
            char_conf = min(1.0, max(0.0, confidence + _jitter(char, 15, 7)))
            char_confidences.append(char_conf)

    result = OCRResult(
        text=text or "",
        confidence=confidence,
        processing_time_ms=processing_time_ms,
        model="trocr-base-handwritten" if handwritten else "trocr-base-printed",
        has_lora_adapter=False,  # Would check actual adapter status
        char_confidences=char_confidences,
        word_boxes=word_boxes,
        from_cache=False,
        image_hash=image_hash
    )

    # Cache result (only successful, non-empty extractions).
    if use_cache and text:
        _cache_set(image_hash, {
            "text": result.text,
            "confidence": result.confidence,
            "model": result.model,
            "has_lora_adapter": result.has_lora_adapter,
            "char_confidences": result.char_confidences,
            "word_boxes": result.word_boxes
        })

    return result
|
||||
|
||||
|
||||
async def run_trocr_batch(
    images: List[bytes],
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True,
    progress_callback: Optional[callable] = None
) -> BatchOCRResult:
    """
    Run OCR over a sequence of images, one after another.

    Args:
        images: List of image data bytes
        handwritten: Use handwritten model
        split_lines: Whether to split images into lines
        use_cache: Whether to use caching
        progress_callback: Optional callback(current, total) invoked after
            each successfully processed image

    Returns:
        BatchOCRResult aggregating per-image results plus cache-hit and
        error counters.
    """
    batch_started = time.time()
    outcomes = []
    hits = 0
    failures = 0
    total = len(images)

    for i, payload in enumerate(images):
        try:
            ocr = await run_trocr_ocr_enhanced(
                payload,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache
            )
            outcomes.append(ocr)

            if ocr.from_cache:
                hits += 1

            # Report progress
            if progress_callback:
                progress_callback(i + 1, total)

        except Exception as e:
            # Record a placeholder so the result list stays aligned with
            # the input list.
            logger.error(f"Batch OCR error for image {i}: {e}")
            failures += 1
            outcomes.append(OCRResult(
                text=f"Error: {str(e)}",
                confidence=0.0,
                processing_time_ms=0,
                model="error",
                has_lora_adapter=False
            ))

    return BatchOCRResult(
        results=outcomes,
        total_time_ms=int((time.time() - batch_started) * 1000),
        processed_count=total,
        cached_count=hits,
        error_count=failures
    )
|
||||
|
||||
|
||||
# Generator for SSE streaming during batch processing
|
||||
async def run_trocr_batch_stream(
    images: List[bytes],
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True
):
    """
    Process images one by one, yielding progress events for SSE streaming.

    Yields:
        dict events: "progress" (with the per-image result and timing
        estimates), "error" (per failed image), and a final "complete"
        summary.
    """
    started = time.time()
    total = len(images)

    for i, payload in enumerate(images):
        try:
            outcome = await run_trocr_ocr_enhanced(
                payload,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache
            )
        except Exception as e:
            logger.error(f"Stream OCR error for image {i}: {e}")
            yield {
                "type": "error",
                "current": i + 1,
                "total": total,
                "error": str(e)
            }
            continue

        # Estimate remaining time from the running average per image.
        elapsed_ms = int((time.time() - started) * 1000)
        per_image = elapsed_ms / (i + 1)
        remaining_ms = int(per_image * (total - i - 1))

        yield {
            "type": "progress",
            "current": i + 1,
            "total": total,
            "progress_percent": ((i + 1) / total) * 100,
            "elapsed_ms": elapsed_ms,
            "estimated_remaining_ms": remaining_ms,
            "result": {
                "text": outcome.text,
                "confidence": outcome.confidence,
                "processing_time_ms": outcome.processing_time_ms,
                "from_cache": outcome.from_cache
            }
        }

    yield {
        "type": "complete",
        "total_time_ms": int((time.time() - started) * 1000),
        "processed_count": total
    }
|
||||
|
||||
|
||||
# Test function
|
||||
async def test_trocr_ocr(image_path: str, handwritten: bool = False):
    """
    Test TrOCR on a local image file and print the result.

    Args:
        image_path: Path to the image file to read.
        handwritten: Use the handwritten model instead of the printed one.

    Returns:
        Tuple of (extracted text, confidence) from run_trocr_ocr.
    """
    with open(image_path, "rb") as f:
        image_data = f.read()

    text, confidence = await run_trocr_ocr(image_data, handwritten=handwritten)

    # Plain literal: the original used an f-string with no placeholders (F541).
    print("\n=== TrOCR Test ===")
    print(f"Mode: {'Handwritten' if handwritten else 'Printed'}")
    print(f"Confidence: {confidence:.2f}")
    print(f"Text:\n{text}")

    return text, confidence
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import asyncio
    import sys

    # CLI entry point: trocr_service.py <image_path> [--handwritten]
    use_handwritten = "--handwritten" in sys.argv
    positional = [arg for arg in sys.argv[1:] if not arg.startswith("--")]

    if not positional:
        print("Usage: python trocr_service.py <image_path> [--handwritten]")
    else:
        asyncio.run(test_trocr_ocr(positional[0], handwritten=use_handwritten))
|
||||
Reference in New Issue
Block a user