Files
breakpilot-lehrer/klausur-service/backend/ocr_pipeline_validation.py
Benjamin Admin b6983ab1dc [split-required] Split 500-1000 LOC files across all services
backend-lehrer (5 files):
- alerts_agent/db/repository.py (992 → 5), abitur_docs_api.py (956 → 3)
- teacher_dashboard_api.py (951 → 3), services/pdf_service.py (916 → 3)
- mail/mail_db.py (987 → 6)

klausur-service (5 files):
- legal_templates_ingestion.py (942 → 3), ocr_pipeline_postprocess.py (929 → 4)
- ocr_pipeline_words.py (876 → 3), ocr_pipeline_ocr_merge.py (616 → 2)
- KorrekturPage.tsx (956 → 6)

website (5 pages):
- mail (985 → 9), edu-search (958 → 8), mac-mini (950 → 7)
- ocr-labeling (946 → 7), audit-workspace (871 → 4)

studio-v2 (5 files + 1 deleted):
- page.tsx (946 → 5), MessagesContext.tsx (925 → 4)
- korrektur (914 → 6), worksheet-cleanup (899 → 6)
- useVocabWorksheet.ts (888 → 3)
- Deleted dead page-original.tsx (934 LOC)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-24 23:35:37 +02:00

363 lines
13 KiB
Python

"""
OCR Pipeline Validation — image detection, generation, validation save,
and handwriting removal endpoints.
Extracted from ocr_pipeline_postprocess.py.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import os
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from ocr_pipeline_session_store import (
get_session_db,
get_session_image,
update_session_db,
)
from ocr_pipeline_common import (
_cache,
RemoveHandwritingRequest,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Pydantic Models
# ---------------------------------------------------------------------------
# Maps a style name to a prompt suffix appended to the user's prompt when
# generating replacement images via mflux (see generate_image_for_region).
# Unknown styles fall back to "educational" at the call site.
STYLE_SUFFIXES = {
    "educational": "educational illustration, textbook style, clear, colorful",
    "cartoon": "cartoon, child-friendly, simple shapes",
    "sketch": "pencil sketch, hand-drawn, black and white",
    "clipart": "clipart, flat vector style, simple",
    "realistic": "photorealistic, high detail",
}
class ValidationRequest(BaseModel):
    """Payload for saving a final validation result (step 8)."""

    # Free-form reviewer notes; optional.
    notes: Optional[str] = None
    # Numeric quality score; optional. Scale is not enforced here —
    # presumably defined by the frontend; confirm against KorrekturPage.
    score: Optional[int] = None
class GenerateImageRequest(BaseModel):
    """Payload for generating a replacement image for one detected region."""

    # Index into the session's validation["image_regions"] list.
    region_index: int
    # Base text prompt describing the desired image content.
    prompt: str
    # Key into STYLE_SUFFIXES; unknown values fall back to "educational".
    style: str = "educational"
# ---------------------------------------------------------------------------
# Image detection + generation
# ---------------------------------------------------------------------------
def _parse_detected_regions(text: str) -> list:
    """Parse the VLM's raw reply into normalized region dicts.

    Tries the whole reply as JSON first, then a greedy bracket match,
    then a lazy one — so prose around the array, or a ']' inside a
    description, does not break parsing.  Malformed entries (non-dict
    items, non-numeric coordinates) are skipped individually instead of
    discarding the whole result.
    """
    import re
    candidates = [text]
    for pattern in (r'\[.*\]', r'\[.*?\]'):
        m = re.search(pattern, text, re.DOTALL)
        if m:
            candidates.append(m.group(0))
    raw_regions = None
    for cand in candidates:
        try:
            parsed = json.loads(cand)
        except (ValueError, TypeError):
            continue
        if isinstance(parsed, list):
            raw_regions = parsed
            break
    if raw_regions is None:
        return []
    regions = []
    for r in raw_regions:
        if not isinstance(r, dict):
            continue
        try:
            # Clamp coordinates to 0-100% and sizes to at least 1%.
            bbox = {
                "x": max(0, min(100, float(r.get("x", 0)))),
                "y": max(0, min(100, float(r.get("y", 0)))),
                "w": max(1, min(100, float(r.get("w", 10)))),
                "h": max(1, min(100, float(r.get("h", 10)))),
            }
        except (TypeError, ValueError):
            # Non-numeric coordinate from the VLM: drop this entry only.
            continue
        regions.append({
            "bbox_pct": bbox,
            "description": r.get("description", ""),
            "prompt": r.get("description", ""),
            "image_b64": None,
            "style": "educational",
        })
    return regions


def _enrich_region_prompts(regions: list, entries: list) -> None:
    """Append nearby vocabulary words to each region's generation prompt.

    "Nearby" means a vocab entry whose bbox y-coordinate lies within
    (region height + 10) percent of the region's top edge. Mutates the
    region dicts in place.
    """
    for region in regions:
        ry = region["bbox_pct"]["y"]
        rh = region["bbox_pct"]["h"]
        nearby = [
            e for e in entries
            if e.get("bbox") and abs(e["bbox"].get("y", 0) - ry) < rh + 10
        ]
        if not nearby:
            continue
        en_words = [e.get("english", "") for e in nearby if e.get("english")]
        de_words = [e.get("german", "") for e in nearby if e.get("german")]
        if en_words or de_words:
            context = f" (vocabulary context: {', '.join(en_words[:5])}"
            if de_words:
                context += f" / {', '.join(de_words[:5])}"
            context += ")"
            region["prompt"] = region["description"] + context


@router.post("/sessions/{session_id}/reconstruction/detect-images")
async def detect_image_regions(session_id: str):
    """Detect illustration/image regions in the original scan using VLM.

    Sends the session's original page image to the Ollama-hosted vision
    model, parses the returned bounding boxes (percent coordinates),
    enriches each region's prompt with nearby vocabulary words, and
    persists the regions under ground_truth["validation"]["image_regions"].

    Returns a dict with "regions" and "count"; on VLM/parse failure the
    error is reported in the payload rather than raised.

    Raises:
        HTTPException 404: unknown session.
        HTTPException 400: session has no original image.
    """
    import base64
    import httpx

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    original_png = await get_session_image(session_id, "original")
    if not original_png:
        raise HTTPException(status_code=400, detail="No original image found")

    # Sample up to 10 vocab entries to give the VLM page context.
    word_result = session.get("word_result") or {}
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    vocab_context = ""
    if entries:
        sample = entries[:10]
        words = [f"{e.get('english', '')} / {e.get('german', '')}" for e in sample if e.get('english')]
        if words:
            vocab_context = f"\nContext: This is a vocabulary page with words like: {', '.join(words)}"

    ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
    model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
    prompt = (
        "Analyze this scanned page. Find ALL illustration/image/picture regions "
        "(NOT text, NOT table cells, NOT blank areas). "
        "For each image region found, return its bounding box as percentage of page dimensions "
        "and a short English description of what the image shows. "
        "Reply with ONLY a JSON array like: "
        '[{"x": 10, "y": 20, "w": 30, "h": 25, "description": "drawing of a cat"}] '
        "where x, y, w, h are percentages (0-100) of the page width/height. "
        "If there are NO images on the page, return an empty array: []"
        f"{vocab_context}"
    )
    img_b64 = base64.b64encode(original_png).decode("utf-8")
    payload = {
        "model": model,
        "prompt": prompt,
        "images": [img_b64],
        "stream": False,
    }
    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            resp = await client.post(f"{ollama_base}/api/generate", json=payload)
            resp.raise_for_status()
            text = resp.json().get("response", "")
        regions = _parse_detected_regions(text)
        if entries:
            _enrich_region_prompts(regions, entries)
        # Persist under ground_truth.validation and refresh the in-memory cache.
        ground_truth = session.get("ground_truth") or {}
        validation = ground_truth.get("validation") or {}
        validation["image_regions"] = regions
        # NOTE(review): utcnow() is naive (and deprecated in 3.12); kept for
        # format compatibility with previously stored timestamps.
        validation["detected_at"] = datetime.utcnow().isoformat()
        ground_truth["validation"] = validation
        await update_session_db(session_id, ground_truth=ground_truth)
        if session_id in _cache:
            _cache[session_id]["ground_truth"] = ground_truth
        logger.info(f"Detected {len(regions)} image regions for session {session_id}")
        return {"regions": regions, "count": len(regions)}
    except httpx.ConnectError:
        logger.warning(f"VLM not available at {ollama_base} for image detection")
        return {"regions": [], "count": 0, "error": "VLM not available"}
    except Exception as e:
        logger.error(f"Image detection failed for {session_id}: {e}")
        return {"regions": [], "count": 0, "error": str(e)}
@router.post("/sessions/{session_id}/reconstruction/generate-image")
async def generate_image_for_region(session_id: str, req: GenerateImageRequest):
    """Generate a replacement image for a detected region using mflux."""
    import httpx

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    validation = ground_truth.get("validation") or {}
    regions = validation.get("image_regions") or []
    if not (0 <= req.region_index < len(regions)):
        raise HTTPException(status_code=400, detail=f"Invalid region_index {req.region_index}, have {len(regions)} regions")

    mflux_url = os.getenv("MFLUX_URL", "http://host.docker.internal:8095")
    suffix = STYLE_SUFFIXES.get(req.style, STYLE_SUFFIXES["educational"])
    full_prompt = f"{req.prompt}, {suffix}"

    # Choose output dimensions from the region's aspect ratio:
    # landscape above 1.3, portrait below 0.7, otherwise square.
    target = regions[req.region_index]
    box = target["bbox_pct"]
    ratio = box["w"] / max(box["h"], 1)
    if ratio > 1.3:
        width, height = 768, 512
    elif ratio < 0.7:
        width, height = 512, 768
    else:
        width, height = 512, 512

    try:
        async with httpx.AsyncClient(timeout=300.0) as client:
            resp = await client.post(f"{mflux_url}/generate", json={
                "prompt": full_prompt,
                "width": width,
                "height": height,
                "steps": 4,
            })
            resp.raise_for_status()
            image_b64 = resp.json().get("image_b64")
        if not image_b64:
            return {"image_b64": None, "success": False, "error": "No image returned"}
        # Attach the generated image to the region and persist the update.
        target["image_b64"] = image_b64
        target["prompt"] = req.prompt
        target["style"] = req.style
        validation["image_regions"] = regions
        ground_truth["validation"] = validation
        await update_session_db(session_id, ground_truth=ground_truth)
        if session_id in _cache:
            _cache[session_id]["ground_truth"] = ground_truth
        logger.info(f"Generated image for session {session_id} region {req.region_index}")
        return {"image_b64": image_b64, "success": True}
    except httpx.ConnectError:
        logger.warning(f"mflux-service not available at {mflux_url}")
        return {"image_b64": None, "success": False, "error": f"mflux-service not available at {mflux_url}"}
    except Exception as e:
        logger.error(f"Image generation failed for {session_id}: {e}")
        return {"image_b64": None, "success": False, "error": str(e)}
# ---------------------------------------------------------------------------
# Validation save/get
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/reconstruction/validate")
async def save_validation(session_id: str, req: ValidationRequest):
    """Save final validation results for step 8."""
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    gt = session.get("ground_truth") or {}
    record = gt.get("validation") or {}
    record["validated_at"] = datetime.utcnow().isoformat()
    record["notes"] = req.notes
    record["score"] = req.score
    gt["validation"] = record

    # NOTE(review): docstring says step 8 but current_step is set to 11 —
    # confirm the intended step numbering against the pipeline definition.
    await update_session_db(session_id, ground_truth=gt, current_step=11)
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = gt
    logger.info(f"Validation saved for session {session_id}: score={req.score}")
    return {"session_id": session_id, "validation": record}
@router.get("/sessions/{session_id}/reconstruction/validation")
async def get_validation(session_id: str):
    """Retrieve saved validation data for step 8."""
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    gt = session.get("ground_truth") or {}
    # "validation" may be absent if save_validation was never called;
    # it is returned as None in that case.
    return {
        "session_id": session_id,
        "validation": gt.get("validation"),
        "word_result": session.get("word_result"),
    }
# ---------------------------------------------------------------------------
# Remove handwriting
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/remove-handwriting")
async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingRequest):
    """Remove handwriting from a session image using inpainting.

    Pipeline: pick a source image -> detect a handwriting mask ->
    optionally dilate the mask -> inpaint the masked pixels -> persist
    the cleaned image as the session's "clean" image.

    Raises:
        HTTPException 404: unknown session or missing source image.
        HTTPException 500: the inpainting service reports failure.
    """
    import time as _time
    from services.handwriting_detection import detect_handwriting
    from services.inpainting_service import inpaint_image, dilate_mask as _dilate_mask, InpaintingMethod, image_to_png
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    t0 = _time.monotonic()
    # 1. Determine source image: "auto" prefers the deskewed image when
    #    one exists, otherwise falls back to the original scan.
    source = req.use_source
    if source == "auto":
        deskewed = await get_session_image(session_id, "deskewed")
        source = "deskewed" if deskewed else "original"
    image_bytes = await get_session_image(session_id, source)
    if not image_bytes:
        raise HTTPException(status_code=404, detail=f"Source image '{source}' not available")
    # 2. Detect handwriting mask. detection.mask is assumed to be a
    #    PIL-compatible array (presumably uint8) — confirm in
    #    services.handwriting_detection.
    detection = detect_handwriting(image_bytes, target_ink=req.target_ink)
    # 3. Convert mask to PNG bytes and dilate so the inpainted area
    #    fully covers stroke edges (dilation is in iterations).
    import io
    from PIL import Image as _PILImage
    mask_img = _PILImage.fromarray(detection.mask)
    mask_buf = io.BytesIO()
    mask_img.save(mask_buf, format="PNG")
    mask_bytes = mask_buf.getvalue()
    if req.dilation > 0:
        mask_bytes = _dilate_mask(mask_bytes, iterations=req.dilation)
    # 4. Inpaint the masked region; unknown method strings fall back to AUTO.
    method_map = {
        "telea": InpaintingMethod.OPENCV_TELEA,
        "ns": InpaintingMethod.OPENCV_NS,
        "auto": InpaintingMethod.AUTO,
    }
    inpaint_method = method_map.get(req.method, InpaintingMethod.AUTO)
    result = inpaint_image(image_bytes, mask_bytes, method=inpaint_method)
    if not result.success:
        raise HTTPException(status_code=500, detail="Inpainting failed")
    elapsed_ms = int((_time.monotonic() - t0) * 1000)
    # Metadata returned to the caller and stored alongside the clean image.
    meta = {
        # method_used may be an enum or a plain string depending on the service.
        "method_used": result.method_used.value if hasattr(result.method_used, "value") else str(result.method_used),
        "handwriting_ratio": round(detection.handwriting_ratio, 4),
        "detection_confidence": round(detection.confidence, 4),
        "target_ink": req.target_ink,
        "dilation": req.dilation,
        "source_image": source,
        "processing_time_ms": elapsed_ms,
    }
    # 5. Persist clean image under the session's "clean" slot.
    clean_png_bytes = image_to_png(result.image)
    await update_session_db(session_id, clean_png=clean_png_bytes, handwriting_removal_meta=meta)
    return {
        **meta,
        "image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean",
        "session_id": session_id,
    }