backend-lehrer (5 files): - alerts_agent/db/repository.py (992 → 5), abitur_docs_api.py (956 → 3) - teacher_dashboard_api.py (951 → 3), services/pdf_service.py (916 → 3) - mail/mail_db.py (987 → 6) klausur-service (5 files): - legal_templates_ingestion.py (942 → 3), ocr_pipeline_postprocess.py (929 → 4) - ocr_pipeline_words.py (876 → 3), ocr_pipeline_ocr_merge.py (616 → 2) - KorrekturPage.tsx (956 → 6) website (5 pages): - mail (985 → 9), edu-search (958 → 8), mac-mini (950 → 7) - ocr-labeling (946 → 7), audit-workspace (871 → 4) studio-v2 (5 files + 1 deleted): - page.tsx (946 → 5), MessagesContext.tsx (925 → 4) - korrektur (914 → 6), worksheet-cleanup (899 → 6) - useVocabWorksheet.ts (888 → 3) - Deleted dead page-original.tsx (934 LOC) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
363 lines
13 KiB
Python
363 lines
13 KiB
Python
"""
|
|
OCR Pipeline Validation — image detection, generation, validation save,
|
|
and handwriting removal endpoints.
|
|
|
|
Extracted from ocr_pipeline_postprocess.py.
|
|
|
|
Lizenz: Apache 2.0
|
|
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
from datetime import datetime
|
|
from typing import Optional
|
|
|
|
from fastapi import APIRouter, HTTPException
|
|
from pydantic import BaseModel
|
|
|
|
from ocr_pipeline_session_store import (
|
|
get_session_db,
|
|
get_session_image,
|
|
update_session_db,
|
|
)
|
|
from ocr_pipeline_common import (
|
|
_cache,
|
|
RemoveHandwritingRequest,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Pydantic Models
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Prompt suffixes appended per image style; keys are the accepted values of
# GenerateImageRequest.style (unknown styles fall back to "educational").
STYLE_SUFFIXES = dict(
    educational="educational illustration, textbook style, clear, colorful",
    cartoon="cartoon, child-friendly, simple shapes",
    sketch="pencil sketch, hand-drawn, black and white",
    clipart="clipart, flat vector style, simple",
    realistic="photorealistic, high detail",
)
|
|
|
|
|
|
class ValidationRequest(BaseModel):
    """Payload for saving step-8 validation results."""

    # Optional free-text reviewer notes.
    notes: Optional[str] = None
    # Optional reviewer quality score; no range validation is applied here.
    score: Optional[int] = None
|
|
|
|
|
|
class GenerateImageRequest(BaseModel):
    """Payload for generating a replacement image for a detected region."""

    # Index into the session's saved ``image_regions`` list.
    region_index: int
    # Text description of the image to generate.
    prompt: str
    # One of the STYLE_SUFFIXES keys; unknown values fall back to "educational".
    style: str = "educational"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Image detection + generation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@router.post("/sessions/{session_id}/reconstruction/detect-images")
|
|
async def detect_image_regions(session_id: str):
|
|
"""Detect illustration/image regions in the original scan using VLM."""
|
|
import base64
|
|
import httpx
|
|
import re
|
|
|
|
session = await get_session_db(session_id)
|
|
if not session:
|
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
|
|
|
original_png = await get_session_image(session_id, "original")
|
|
if not original_png:
|
|
raise HTTPException(status_code=400, detail="No original image found")
|
|
|
|
word_result = session.get("word_result") or {}
|
|
entries = word_result.get("vocab_entries") or word_result.get("entries") or []
|
|
vocab_context = ""
|
|
if entries:
|
|
sample = entries[:10]
|
|
words = [f"{e.get('english', '')} / {e.get('german', '')}" for e in sample if e.get('english')]
|
|
if words:
|
|
vocab_context = f"\nContext: This is a vocabulary page with words like: {', '.join(words)}"
|
|
|
|
ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
|
|
model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
|
|
|
|
prompt = (
|
|
"Analyze this scanned page. Find ALL illustration/image/picture regions "
|
|
"(NOT text, NOT table cells, NOT blank areas). "
|
|
"For each image region found, return its bounding box as percentage of page dimensions "
|
|
"and a short English description of what the image shows. "
|
|
"Reply with ONLY a JSON array like: "
|
|
'[{"x": 10, "y": 20, "w": 30, "h": 25, "description": "drawing of a cat"}] '
|
|
"where x, y, w, h are percentages (0-100) of the page width/height. "
|
|
"If there are NO images on the page, return an empty array: []"
|
|
f"{vocab_context}"
|
|
)
|
|
|
|
img_b64 = base64.b64encode(original_png).decode("utf-8")
|
|
payload = {
|
|
"model": model,
|
|
"prompt": prompt,
|
|
"images": [img_b64],
|
|
"stream": False,
|
|
}
|
|
|
|
try:
|
|
async with httpx.AsyncClient(timeout=120.0) as client:
|
|
resp = await client.post(f"{ollama_base}/api/generate", json=payload)
|
|
resp.raise_for_status()
|
|
text = resp.json().get("response", "")
|
|
|
|
match = re.search(r'\[.*?\]', text, re.DOTALL)
|
|
if match:
|
|
raw_regions = json.loads(match.group(0))
|
|
else:
|
|
raw_regions = []
|
|
|
|
regions = []
|
|
for r in raw_regions:
|
|
regions.append({
|
|
"bbox_pct": {
|
|
"x": max(0, min(100, float(r.get("x", 0)))),
|
|
"y": max(0, min(100, float(r.get("y", 0)))),
|
|
"w": max(1, min(100, float(r.get("w", 10)))),
|
|
"h": max(1, min(100, float(r.get("h", 10)))),
|
|
},
|
|
"description": r.get("description", ""),
|
|
"prompt": r.get("description", ""),
|
|
"image_b64": None,
|
|
"style": "educational",
|
|
})
|
|
|
|
# Enrich prompts with nearby vocab context
|
|
if entries:
|
|
for region in regions:
|
|
ry = region["bbox_pct"]["y"]
|
|
rh = region["bbox_pct"]["h"]
|
|
nearby = [
|
|
e for e in entries
|
|
if e.get("bbox") and abs(e["bbox"].get("y", 0) - ry) < rh + 10
|
|
]
|
|
if nearby:
|
|
en_words = [e.get("english", "") for e in nearby if e.get("english")]
|
|
de_words = [e.get("german", "") for e in nearby if e.get("german")]
|
|
if en_words or de_words:
|
|
context = f" (vocabulary context: {', '.join(en_words[:5])}"
|
|
if de_words:
|
|
context += f" / {', '.join(de_words[:5])}"
|
|
context += ")"
|
|
region["prompt"] = region["description"] + context
|
|
|
|
ground_truth = session.get("ground_truth") or {}
|
|
validation = ground_truth.get("validation") or {}
|
|
validation["image_regions"] = regions
|
|
validation["detected_at"] = datetime.utcnow().isoformat()
|
|
ground_truth["validation"] = validation
|
|
await update_session_db(session_id, ground_truth=ground_truth)
|
|
|
|
if session_id in _cache:
|
|
_cache[session_id]["ground_truth"] = ground_truth
|
|
|
|
logger.info(f"Detected {len(regions)} image regions for session {session_id}")
|
|
|
|
return {"regions": regions, "count": len(regions)}
|
|
|
|
except httpx.ConnectError:
|
|
logger.warning(f"VLM not available at {ollama_base} for image detection")
|
|
return {"regions": [], "count": 0, "error": "VLM not available"}
|
|
except Exception as e:
|
|
logger.error(f"Image detection failed for {session_id}: {e}")
|
|
return {"regions": [], "count": 0, "error": str(e)}
|
|
|
|
|
|
@router.post("/sessions/{session_id}/reconstruction/generate-image")
|
|
async def generate_image_for_region(session_id: str, req: GenerateImageRequest):
|
|
"""Generate a replacement image for a detected region using mflux."""
|
|
import httpx
|
|
|
|
session = await get_session_db(session_id)
|
|
if not session:
|
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
|
|
|
ground_truth = session.get("ground_truth") or {}
|
|
validation = ground_truth.get("validation") or {}
|
|
regions = validation.get("image_regions") or []
|
|
|
|
if req.region_index < 0 or req.region_index >= len(regions):
|
|
raise HTTPException(status_code=400, detail=f"Invalid region_index {req.region_index}, have {len(regions)} regions")
|
|
|
|
mflux_url = os.getenv("MFLUX_URL", "http://host.docker.internal:8095")
|
|
style_suffix = STYLE_SUFFIXES.get(req.style, STYLE_SUFFIXES["educational"])
|
|
full_prompt = f"{req.prompt}, {style_suffix}"
|
|
|
|
region = regions[req.region_index]
|
|
bbox = region["bbox_pct"]
|
|
aspect = bbox["w"] / max(bbox["h"], 1)
|
|
if aspect > 1.3:
|
|
width, height = 768, 512
|
|
elif aspect < 0.7:
|
|
width, height = 512, 768
|
|
else:
|
|
width, height = 512, 512
|
|
|
|
try:
|
|
async with httpx.AsyncClient(timeout=300.0) as client:
|
|
resp = await client.post(f"{mflux_url}/generate", json={
|
|
"prompt": full_prompt,
|
|
"width": width,
|
|
"height": height,
|
|
"steps": 4,
|
|
})
|
|
resp.raise_for_status()
|
|
data = resp.json()
|
|
image_b64 = data.get("image_b64")
|
|
|
|
if not image_b64:
|
|
return {"image_b64": None, "success": False, "error": "No image returned"}
|
|
|
|
regions[req.region_index]["image_b64"] = image_b64
|
|
regions[req.region_index]["prompt"] = req.prompt
|
|
regions[req.region_index]["style"] = req.style
|
|
validation["image_regions"] = regions
|
|
ground_truth["validation"] = validation
|
|
await update_session_db(session_id, ground_truth=ground_truth)
|
|
|
|
if session_id in _cache:
|
|
_cache[session_id]["ground_truth"] = ground_truth
|
|
|
|
logger.info(f"Generated image for session {session_id} region {req.region_index}")
|
|
return {"image_b64": image_b64, "success": True}
|
|
|
|
except httpx.ConnectError:
|
|
logger.warning(f"mflux-service not available at {mflux_url}")
|
|
return {"image_b64": None, "success": False, "error": f"mflux-service not available at {mflux_url}"}
|
|
except Exception as e:
|
|
logger.error(f"Image generation failed for {session_id}: {e}")
|
|
return {"image_b64": None, "success": False, "error": str(e)}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Validation save/get
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@router.post("/sessions/{session_id}/reconstruction/validate")
|
|
async def save_validation(session_id: str, req: ValidationRequest):
|
|
"""Save final validation results for step 8."""
|
|
session = await get_session_db(session_id)
|
|
if not session:
|
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
|
|
|
ground_truth = session.get("ground_truth") or {}
|
|
validation = ground_truth.get("validation") or {}
|
|
validation["validated_at"] = datetime.utcnow().isoformat()
|
|
validation["notes"] = req.notes
|
|
validation["score"] = req.score
|
|
ground_truth["validation"] = validation
|
|
|
|
await update_session_db(session_id, ground_truth=ground_truth, current_step=11)
|
|
|
|
if session_id in _cache:
|
|
_cache[session_id]["ground_truth"] = ground_truth
|
|
|
|
logger.info(f"Validation saved for session {session_id}: score={req.score}")
|
|
|
|
return {"session_id": session_id, "validation": validation}
|
|
|
|
|
|
@router.get("/sessions/{session_id}/reconstruction/validation")
|
|
async def get_validation(session_id: str):
|
|
"""Retrieve saved validation data for step 8."""
|
|
session = await get_session_db(session_id)
|
|
if not session:
|
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
|
|
|
ground_truth = session.get("ground_truth") or {}
|
|
validation = ground_truth.get("validation")
|
|
|
|
return {
|
|
"session_id": session_id,
|
|
"validation": validation,
|
|
"word_result": session.get("word_result"),
|
|
}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Remove handwriting
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@router.post("/sessions/{session_id}/remove-handwriting")
|
|
async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingRequest):
|
|
"""Remove handwriting from a session image using inpainting."""
|
|
import time as _time
|
|
|
|
from services.handwriting_detection import detect_handwriting
|
|
from services.inpainting_service import inpaint_image, dilate_mask as _dilate_mask, InpaintingMethod, image_to_png
|
|
|
|
session = await get_session_db(session_id)
|
|
if not session:
|
|
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
|
|
|
t0 = _time.monotonic()
|
|
|
|
# 1. Determine source image
|
|
source = req.use_source
|
|
if source == "auto":
|
|
deskewed = await get_session_image(session_id, "deskewed")
|
|
source = "deskewed" if deskewed else "original"
|
|
|
|
image_bytes = await get_session_image(session_id, source)
|
|
if not image_bytes:
|
|
raise HTTPException(status_code=404, detail=f"Source image '{source}' not available")
|
|
|
|
# 2. Detect handwriting mask
|
|
detection = detect_handwriting(image_bytes, target_ink=req.target_ink)
|
|
|
|
# 3. Convert mask to PNG bytes and dilate
|
|
import io
|
|
from PIL import Image as _PILImage
|
|
mask_img = _PILImage.fromarray(detection.mask)
|
|
mask_buf = io.BytesIO()
|
|
mask_img.save(mask_buf, format="PNG")
|
|
mask_bytes = mask_buf.getvalue()
|
|
|
|
if req.dilation > 0:
|
|
mask_bytes = _dilate_mask(mask_bytes, iterations=req.dilation)
|
|
|
|
# 4. Inpaint
|
|
method_map = {
|
|
"telea": InpaintingMethod.OPENCV_TELEA,
|
|
"ns": InpaintingMethod.OPENCV_NS,
|
|
"auto": InpaintingMethod.AUTO,
|
|
}
|
|
inpaint_method = method_map.get(req.method, InpaintingMethod.AUTO)
|
|
|
|
result = inpaint_image(image_bytes, mask_bytes, method=inpaint_method)
|
|
if not result.success:
|
|
raise HTTPException(status_code=500, detail="Inpainting failed")
|
|
|
|
elapsed_ms = int((_time.monotonic() - t0) * 1000)
|
|
|
|
meta = {
|
|
"method_used": result.method_used.value if hasattr(result.method_used, "value") else str(result.method_used),
|
|
"handwriting_ratio": round(detection.handwriting_ratio, 4),
|
|
"detection_confidence": round(detection.confidence, 4),
|
|
"target_ink": req.target_ink,
|
|
"dilation": req.dilation,
|
|
"source_image": source,
|
|
"processing_time_ms": elapsed_ms,
|
|
}
|
|
|
|
# 5. Persist clean image
|
|
clean_png_bytes = image_to_png(result.image)
|
|
await update_session_db(session_id, clean_png=clean_png_bytes, handwriting_removal_meta=meta)
|
|
|
|
return {
|
|
**meta,
|
|
"image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean",
|
|
"session_id": session_id,
|
|
}
|