Files
breakpilot-lehrer/klausur-service/backend/ocr/pipeline/words.py
Benjamin Admin 0504d22b8e
Some checks failed
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 29s
CI / test-go-edu-search (push) Successful in 29s
CI / test-python-klausur (push) Failing after 2m25s
CI / test-python-agent-core (push) Successful in 19s
CI / test-nodejs-website (push) Successful in 20s
Restructure: Move ocr_pipeline + labeling + crop into ocr/ package
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-25 21:51:43 +02:00

186 lines
6.1 KiB
Python

"""
OCR Pipeline Words — composite router for word detection, PaddleOCR direct,
and ground truth endpoints.

Split into sub-modules:
    words_detect — main detect_words endpoint (Step 7)
    ocr_pipeline_words_stream — SSE streaming generators
        (module name not re-checked after the move into the ocr/ package)

This barrel module contains the PaddleOCR direct endpoint and the ground truth
endpoints, and assembles all word-related routers.

License: Apache 2.0
PRIVACY: All processing is performed locally.
"""
import logging
import time
from datetime import datetime
from typing import Any, Dict, List, Optional
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from cv_words_first import build_grid_from_words
from .session_store import (
get_session_db,
get_session_image,
update_session_db,
)
from .common import (
_cache,
_append_pipeline_log,
)
from .words_detect import router as _detect_router
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Router holding the endpoints defined in this file (PaddleOCR direct and
# word ground truth); every route is served under /api/v1/ocr-pipeline.
_local_router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Pydantic models
# ---------------------------------------------------------------------------
class WordGroundTruthRequest(BaseModel):
    # Request body for saving ground truth feedback on the word recognition
    # step. Documented with comments rather than a class docstring so the
    # generated schema is not given a new description.

    # Required verdict on the word recognition result.
    is_correct: bool

    # Optional list of corrected entries; each entry is a free-form mapping.
    corrected_entries: Optional[List[Dict[str, Any]]] = None

    # Optional free-text remarks.
    notes: Optional[str] = None
# ---------------------------------------------------------------------------
# PaddleOCR Direct Endpoint
# ---------------------------------------------------------------------------
@_local_router.post("/sessions/{session_id}/paddle-direct")
async def paddle_direct(session_id: str):
    """Run PaddleOCR on the session image and build a word grid directly.

    The image is taken from the first available variant, tried in the order
    "cropped", "dewarped", "original". The OCR words are turned into grid
    cells, every cell is tagged with the engine name, and the result is
    stored on the session and appended to the pipeline log.

    Args:
        session_id: Identifier of the session whose image is processed.

    Returns:
        A dict with ``session_id`` plus all keys of the stored word result
        (cells, grid shape, columns, layout, image size, duration, summary).

    Raises:
        HTTPException: 404 if the session has no image in any variant;
            400 if the image bytes cannot be decoded or PaddleOCR returns
            no words.
    """
    # Take the first image variant that exists, in the listed order.
    img_png = None
    for image_kind in ("cropped", "dewarped", "original"):
        img_png = await get_session_image(session_id, image_kind)
        if img_png:
            break
    if not img_png:
        raise HTTPException(status_code=404, detail="No image found for this session")

    img_bgr = cv2.imdecode(np.frombuffer(img_png, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img_bgr is None:
        # NOTE(review): the message says "original" although the bytes may be
        # the cropped or dewarped variant; kept verbatim for API compatibility.
        raise HTTPException(status_code=400, detail="Failed to decode original image")
    img_h, img_w = img_bgr.shape[:2]

    # Function-scope import kept from the original code — presumably to avoid
    # loading the OCR engine at module import time; confirm before hoisting.
    from cv_ocr_engines import ocr_region_paddle

    # perf_counter is monotonic, so the measured duration cannot be skewed
    # (or go negative) when the system clock is adjusted mid-request.
    t0 = time.perf_counter()
    word_dicts = await ocr_region_paddle(img_bgr, region=None)
    if not word_dicts:
        raise HTTPException(status_code=400, detail="PaddleOCR returned no words")
    cells, columns_meta = build_grid_from_words(word_dicts, img_w, img_h)
    duration = time.perf_counter() - t0

    for cell in cells:
        cell["ocr_engine"] = "paddle_direct"

    n_rows = len({c["row_index"] for c in cells})
    n_cols = len(columns_meta)
    col_types = {c.get("type") for c in columns_meta}
    # The layout is reported as "vocab" as soon as one column is typed
    # column_en or column_de.
    is_vocab = bool(col_types & {"column_en", "column_de"})

    non_empty_cells = sum(1 for c in cells if c.get("text"))
    # `or 0` treats an explicit None confidence like a missing one instead of
    # raising TypeError on the comparison.
    low_confidence = sum(
        1 for c in cells if 0 < (c.get("confidence") or 0) < 50
    )

    word_result = {
        "cells": cells,
        "grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": "paddle_direct",
        "grid_method": "paddle_direct",
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": non_empty_cells,
            "low_confidence": low_confidence,
        },
    }

    # The image that was actually OCR'd is stored as the session's cropped
    # image, whichever variant it came from.
    await update_session_db(
        session_id,
        word_result=word_result,
        cropped_png=img_png,
        current_step=8,
    )

    logger.info(
        "paddle_direct session %s: %d cells (%d rows, %d cols) in %.2fs",
        session_id, len(cells), n_rows, n_cols, duration,
    )

    await _append_pipeline_log(session_id, "paddle_direct", {
        "total_cells": len(cells),
        "non_empty_cells": non_empty_cells,
        "ocr_engine": "paddle_direct",
    }, duration_ms=int(duration * 1000))

    return {"session_id": session_id, **word_result}
# ---------------------------------------------------------------------------
# Ground Truth Words Endpoints
# ---------------------------------------------------------------------------
@_local_router.post("/sessions/{session_id}/ground-truth/words")
async def save_word_ground_truth(session_id: str, req: WordGroundTruthRequest):
    """Save ground truth feedback for the word recognition step.

    The feedback is stored under the ``"words"`` key of the session's
    ``ground_truth`` mapping, together with a snapshot of the session's
    current ``word_result`` and a UTC timestamp. If the session is present
    in the in-memory cache, the cached ground truth is updated as well.

    Args:
        session_id: Identifier of the session the feedback belongs to.
        req: Feedback payload (verdict, optional corrections and notes).

    Returns:
        A dict with ``session_id`` and the saved ``ground_truth`` entry.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    # Function-scope import: the module header only imports `datetime`.
    from datetime import timezone

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}

    # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
    # time and drop the tzinfo so the stored string keeps the same naive
    # ISO-8601 format (no "+00:00" suffix) as before.
    saved_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    gt = {
        "is_correct": req.is_correct,
        "corrected_entries": req.corrected_entries,
        "notes": req.notes,
        "saved_at": saved_at,
        "word_result": session.get("word_result"),
    }
    ground_truth["words"] = gt

    await update_session_db(session_id, ground_truth=ground_truth)

    # Keep the cached copy of the session consistent with what was persisted.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    return {"session_id": session_id, "ground_truth": gt}
@_local_router.get("/sessions/{session_id}/ground-truth/words")
async def get_word_ground_truth(session_id: str):
    """Return the stored word ground truth alongside the automatic result.

    Args:
        session_id: Identifier of the session to look up.

    Returns:
        A dict with ``session_id``, the saved feedback as ``words_gt`` and
        the session's ``word_result`` as ``words_auto``.

    Raises:
        HTTPException: 404 if the session does not exist or has no word
            ground truth saved.
    """
    record = await get_session_db(session_id)
    if not record:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # A missing or empty ground_truth mapping is treated like "nothing saved".
    saved_words = (record.get("ground_truth") or {}).get("words")
    if not saved_words:
        raise HTTPException(status_code=404, detail="No word ground truth saved")

    response = {"session_id": session_id, "words_gt": saved_words}
    response["words_auto"] = record.get("word_result")
    return response
# ---------------------------------------------------------------------------
# Composite router
# ---------------------------------------------------------------------------
# Public router of this module: combines the router imported from
# .words_detect with the endpoints defined in this file. The detect router
# is included first, then the local one.
router = APIRouter()
router.include_router(_detect_router)
router.include_router(_local_router)