[split-required] Split 500-850 LOC files (batch 2)
backend-lehrer (10 files):
- game/database.py (785 → 5), correction_api.py (683 → 4)
- classroom_engine/antizipation.py (676 → 5)
- llm_gateway schools/edu_search already done in prior batch

klausur-service (12 files):
- orientation_crop_api.py (694 → 5), pdf_export.py (677 → 4)
- zeugnis_crawler.py (676 → 5), grid_editor_api.py (671 → 5)
- eh_templates.py (658 → 5), mail/api.py (651 → 5)
- qdrant_service.py (638 → 5), training_api.py (625 → 4)

website (6 pages):
- middleware (696 → 8), mail (733 → 6), consent (628 → 8)
- compliance/risks (622 → 5), export (502 → 5), brandbook (629 → 7)

studio-v2 (3 components):
- B2BMigrationWizard (848 → 3), CleanupPanel (765 → 2)
- dashboard-experimental (739 → 2)

admin-lehrer (4 files):
- uebersetzungen (769 → 4), manager (670 → 2)
- ChunkBrowserQA (675 → 6), dsfa/page (674 → 5)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@@ -0,0 +1,290 @@
"""
Crop API endpoints (Step 4 / UI index 3 of OCR Pipeline).

Auto-crop, manual crop, and skip-crop for scanner/book borders.
"""

import logging
import time
from typing import Any, Dict

import cv2
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from page_crop import detect_and_crop_page, detect_page_splits
from ocr_pipeline_session_store import get_sub_sessions, update_session_db

from orientation_crop_helpers import ensure_cached, append_pipeline_log
from page_sub_sessions import create_page_sub_sessions

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])


# ---------------------------------------------------------------------------
# Step 4 (UI index 3): Crop — runs after deskew + dewarp
# ---------------------------------------------------------------------------

@router.post("/sessions/{session_id}/crop")
async def auto_crop(session_id: str):
    """Auto-detect and crop scanner/book borders.

    Reads the dewarped image (post-deskew + dewarp, so the page is straight).
    Falls back to oriented -> original if earlier steps were skipped.

    If the image is a multi-page spread (e.g. book on scanner), it will
    automatically split into separate sub-sessions per page, crop each
    individually, and return the split info.
    """
    cached = await ensure_cached(session_id)

    # Use dewarped (preferred), fall back to oriented, then original
    img_bgr = next(
        (v for k in ("dewarped_bgr", "oriented_bgr", "original_bgr")
         if (v := cached.get(k)) is not None),
        None,
    )
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="No image available for cropping")

    t0 = time.time()

    # --- Check for existing sub-sessions (from page-split step) ---
    # If page-split already created sub-sessions, skip multi-page detection
    # in the crop step. Each sub-session runs its own crop independently.
    existing_subs = await get_sub_sessions(session_id)
    if existing_subs:
        crop_result = cached.get("crop_result") or {}
        if crop_result.get("multi_page"):
            # Already split -- just return the existing info
            duration = time.time() - t0
            h, w = img_bgr.shape[:2]
            return {
                "session_id": session_id,
                **crop_result,
                "image_width": w,
                "image_height": h,
                "sub_sessions": [
                    {"id": s["id"], "name": s.get("name"), "page_index": s.get("box_index", i)}
                    for i, s in enumerate(existing_subs)
                ],
                "note": "Page split was already performed; each sub-session runs its own crop.",
            }

    # --- Multi-page detection (fallback for sessions that skipped page-split) ---
    page_splits = detect_page_splits(img_bgr)

    if page_splits and len(page_splits) >= 2:
        # Multi-page spread detected -- create sub-sessions
        sub_sessions = await create_page_sub_sessions(
            session_id, cached, img_bgr, page_splits,
        )
        duration = time.time() - t0

        crop_info: Dict[str, Any] = {
            "crop_applied": True,
            "multi_page": True,
            "page_count": len(page_splits),
            "page_splits": page_splits,
            "duration_seconds": round(duration, 2),
        }
        cached["crop_result"] = crop_info

        # Store the first page as the main cropped image for backward compat
        first_page = page_splits[0]
        first_bgr = img_bgr[
            first_page["y"]:first_page["y"] + first_page["height"],
            first_page["x"]:first_page["x"] + first_page["width"],
        ].copy()
        first_cropped, _ = detect_and_crop_page(first_bgr)
        cached["cropped_bgr"] = first_cropped

        ok, png_buf = cv2.imencode(".png", first_cropped)
        await update_session_db(
            session_id,
            cropped_png=png_buf.tobytes() if ok else b"",
            crop_result=crop_info,
            current_step=5,
            status='split',
        )

        logger.info(
            "OCR Pipeline: crop session %s: multi-page split into %d pages in %.2fs",
            session_id, len(page_splits), duration,
        )

        await append_pipeline_log(session_id, "crop", {
            "multi_page": True,
            "page_count": len(page_splits),
        }, duration_ms=int(duration * 1000))

        h, w = first_cropped.shape[:2]
        return {
            "session_id": session_id,
            **crop_info,
            "image_width": w,
            "image_height": h,
            "cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
            "sub_sessions": sub_sessions,
        }

    # --- Single page (normal) ---
    cropped_bgr, crop_info = detect_and_crop_page(img_bgr)

    duration = time.time() - t0
    crop_info["duration_seconds"] = round(duration, 2)
    crop_info["multi_page"] = False

    # Encode cropped image
    success, png_buf = cv2.imencode(".png", cropped_bgr)
    cropped_png = png_buf.tobytes() if success else b""

    # Update cache
    cached["cropped_bgr"] = cropped_bgr
    cached["crop_result"] = crop_info

    # Persist to DB
    await update_session_db(
        session_id,
        cropped_png=cropped_png,
        crop_result=crop_info,
        current_step=5,
    )

    logger.info(
        "OCR Pipeline: crop session %s: applied=%s format=%s in %.2fs",
        session_id, crop_info["crop_applied"],
        crop_info.get("detected_format", "?"),
        duration,
    )

    await append_pipeline_log(session_id, "crop", {
        "crop_applied": crop_info["crop_applied"],
        "detected_format": crop_info.get("detected_format"),
        "format_confidence": crop_info.get("format_confidence"),
    }, duration_ms=int(duration * 1000))

    h, w = cropped_bgr.shape[:2]
    return {
        "session_id": session_id,
        **crop_info,
        "image_width": w,
        "image_height": h,
        "cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
    }


class ManualCropRequest(BaseModel):
    x: float       # percentage 0-100
    y: float       # percentage 0-100
    width: float   # percentage 0-100
    height: float  # percentage 0-100


@router.post("/sessions/{session_id}/crop/manual")
async def manual_crop(session_id: str, req: ManualCropRequest):
    """Manually crop using percentage coordinates."""
    cached = await ensure_cached(session_id)

    img_bgr = next(
        (v for k in ("dewarped_bgr", "oriented_bgr", "original_bgr")
         if (v := cached.get(k)) is not None),
        None,
    )
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="No image available for cropping")

    h, w = img_bgr.shape[:2]

    # Convert percentages to pixels
    px_x = int(w * req.x / 100.0)
    px_y = int(h * req.y / 100.0)
    px_w = int(w * req.width / 100.0)
    px_h = int(h * req.height / 100.0)

    # Clamp
    px_x = max(0, min(px_x, w - 1))
    px_y = max(0, min(px_y, h - 1))
    px_w = max(1, min(px_w, w - px_x))
    px_h = max(1, min(px_h, h - px_y))

    cropped_bgr = img_bgr[px_y:px_y + px_h, px_x:px_x + px_w].copy()

    success, png_buf = cv2.imencode(".png", cropped_bgr)
    cropped_png = png_buf.tobytes() if success else b""

    crop_result = {
        "crop_applied": True,
        "crop_rect": {"x": px_x, "y": px_y, "width": px_w, "height": px_h},
        "crop_rect_pct": {"x": round(req.x, 2), "y": round(req.y, 2),
                          "width": round(req.width, 2), "height": round(req.height, 2)},
        "original_size": {"width": w, "height": h},
        "cropped_size": {"width": px_w, "height": px_h},
        "method": "manual",
    }

    cached["cropped_bgr"] = cropped_bgr
    cached["crop_result"] = crop_result

    await update_session_db(
        session_id,
        cropped_png=cropped_png,
        crop_result=crop_result,
        current_step=5,
    )

    ch, cw = cropped_bgr.shape[:2]
    return {
        "session_id": session_id,
        **crop_result,
        "image_width": cw,
        "image_height": ch,
        "cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
    }


@router.post("/sessions/{session_id}/crop/skip")
async def skip_crop(session_id: str):
    """Skip cropping -- use dewarped (or oriented/original) image as-is."""
    cached = await ensure_cached(session_id)

    img_bgr = next(
        (v for k in ("dewarped_bgr", "oriented_bgr", "original_bgr")
         if (v := cached.get(k)) is not None),
        None,
    )
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="No image available")

    h, w = img_bgr.shape[:2]

    # Store the dewarped image as cropped (identity crop)
    success, png_buf = cv2.imencode(".png", img_bgr)
    cropped_png = png_buf.tobytes() if success else b""

    crop_result = {
        "crop_applied": False,
        "skipped": True,
        "original_size": {"width": w, "height": h},
        "cropped_size": {"width": w, "height": h},
    }

    cached["cropped_bgr"] = img_bgr
    cached["crop_result"] = crop_result

    await update_session_db(
        session_id,
        cropped_png=cropped_png,
        crop_result=crop_result,
        current_step=5,
    )

    return {
        "session_id": session_id,
        **crop_result,
        "image_width": w,
        "image_height": h,
        "cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
    }
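For reference, a minimal client sketch against these crop endpoints. The base URL, port, and session id are hypothetical and not part of this commit; only the routes and the ManualCropRequest shape come from the code above.

    # sketch_crop_client.py (illustrative only)
    import requests

    BASE = "http://localhost:8000/api/v1/ocr-pipeline"  # hypothetical deployment
    session_id = "demo-session"                          # hypothetical session id

    # 1) Try the automatic crop; a multi-page spread returns sub_sessions.
    auto = requests.post(f"{BASE}/sessions/{session_id}/crop").json()
    if auto.get("multi_page"):
        print("split into", auto["page_count"], "pages:",
              [s["id"] for s in auto["sub_sessions"]])

    # 2) Or crop manually with percentage coordinates (0-100),
    #    matching the ManualCropRequest model.
    manual = requests.post(
        f"{BASE}/sessions/{session_id}/crop/manual",
        json={"x": 5.0, "y": 5.0, "width": 90.0, "height": 90.0},
    ).json()
    print(manual["crop_rect"], manual["cropped_image_url"])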
@@ -1,658 +1,34 @@
"""
Erwartungshorizont Templates for Vorabitur Mode
Erwartungshorizont Templates for Vorabitur Mode — barrel re-export.

Provides pre-defined templates based on German Abitur text analysis types:
- Textanalyse (pragmatische Texte)
- Sachtextanalyse
- Gedichtanalyse / Lyrikinterpretation
- Dramenanalyse
- Epische Textanalyse / Prosaanalyse
- Eroerterung (textgebunden / frei)
- Literarische Eroerterung
- Materialgestuetztes Schreiben

Each template includes:
- Structured criteria with weights
- Typical expectations per section
- NiBiS-aligned evaluation points
The actual code lives in:
- eh_templates_types.py (AUFGABENTYPEN, EHKriterium, EHTemplate)
- eh_templates_analyse.py (Textanalyse, Gedicht, Prosa, Drama)
- eh_templates_eroerterung.py (Eroerterung textgebunden)
- eh_templates_registry.py (TEMPLATES, get_template, list_templates, etc.)
"""

from typing import Dict, List, Optional
from dataclasses import dataclass, field, asdict
from datetime import datetime
import uuid
# Types
from eh_templates_types import (  # noqa: F401
    AUFGABENTYPEN,
    EHKriterium,
    EHTemplate,
)

# Template factories
from eh_templates_analyse import (  # noqa: F401
    get_textanalyse_template,
    get_gedichtanalyse_template,
    get_prosaanalyse_template,
    get_dramenanalyse_template,
)
from eh_templates_eroerterung import get_eroerterung_template  # noqa: F401

# =============================================
# TEMPLATE TYPES
# =============================================

AUFGABENTYPEN = {
    "textanalyse_pragmatisch": {
        "name": "Textanalyse (pragmatische Texte)",
        "description": "Analyse von Sachtexten, Reden, Kommentaren, Essays",
        "category": "analyse"
    },
    "sachtextanalyse": {
        "name": "Sachtextanalyse",
        "description": "Analyse von informativen und appellativen Sachtexten",
        "category": "analyse"
    },
    "gedichtanalyse": {
        "name": "Gedichtanalyse / Lyrikinterpretation",
        "description": "Analyse und Interpretation lyrischer Texte",
        "category": "interpretation"
    },
    "dramenanalyse": {
        "name": "Dramenanalyse",
        "description": "Analyse dramatischer Texte und Szenen",
        "category": "interpretation"
    },
    "prosaanalyse": {
        "name": "Epische Textanalyse / Prosaanalyse",
        "description": "Analyse von Romanauszuegen, Kurzgeschichten, Novellen",
        "category": "interpretation"
    },
    "eroerterung_textgebunden": {
        "name": "Textgebundene Eroerterung",
        "description": "Eroerterung auf Basis eines Sachtextes",
        "category": "argumentation"
    },
    "eroerterung_frei": {
        "name": "Freie Eroerterung",
        "description": "Freie Eroerterung zu einem Thema",
        "category": "argumentation"
    },
    "eroerterung_literarisch": {
        "name": "Literarische Eroerterung",
        "description": "Eroerterung zu literarischen Fragestellungen",
        "category": "argumentation"
    },
    "materialgestuetzt": {
        "name": "Materialgestuetztes Schreiben",
        "description": "Verfassen eines Textes auf Materialbasis",
        "category": "produktion"
    }
}


# =============================================
# TEMPLATE STRUCTURES
# =============================================

@dataclass
class EHKriterium:
    """Single criterion in an Erwartungshorizont."""
    id: str
    name: str
    beschreibung: str
    gewichtung: int  # Percentage weight (0-100)
    erwartungen: List[str]  # Expected points/elements
    max_punkte: int = 100

    def to_dict(self):
        return asdict(self)


@dataclass
class EHTemplate:
    """Complete Erwartungshorizont template."""
    id: str
    aufgabentyp: str
    name: str
    beschreibung: str
    kriterien: List[EHKriterium]
    einleitung_hinweise: List[str]
    hauptteil_hinweise: List[str]
    schluss_hinweise: List[str]
    sprachliche_aspekte: List[str]
    created_at: datetime = field(default_factory=lambda: datetime.now())

    def to_dict(self):
        d = {
            'id': self.id,
            'aufgabentyp': self.aufgabentyp,
            'name': self.name,
            'beschreibung': self.beschreibung,
            'kriterien': [k.to_dict() for k in self.kriterien],
            'einleitung_hinweise': self.einleitung_hinweise,
            'hauptteil_hinweise': self.hauptteil_hinweise,
            'schluss_hinweise': self.schluss_hinweise,
            'sprachliche_aspekte': self.sprachliche_aspekte,
            'created_at': self.created_at.isoformat()
        }
        return d


# =============================================
# PRE-DEFINED TEMPLATES
# =============================================

def get_textanalyse_template() -> EHTemplate:
    """Template for pragmatic text analysis."""
    return EHTemplate(
        id="template_textanalyse_pragmatisch",
        aufgabentyp="textanalyse_pragmatisch",
        name="Textanalyse pragmatischer Texte",
        beschreibung="Vorlage fuer die Analyse von Sachtexten, Reden, Kommentaren und Essays",
        kriterien=[
            EHKriterium(
                id="inhalt",
                name="Inhaltliche Leistung",
                beschreibung="Erfassung und Wiedergabe des Textinhalts",
                gewichtung=40,
                erwartungen=[
                    "Korrekte Erfassung der Textaussage/These",
                    "Vollstaendige Wiedergabe der Argumentationsstruktur",
                    "Erkennen von Intention und Adressatenbezug",
                    "Einordnung in den historischen/gesellschaftlichen Kontext",
                    "Beruecksichtigung aller relevanten Textaspekte"
                ]
            ),
            EHKriterium(
                id="struktur",
                name="Aufbau und Struktur",
                beschreibung="Logischer Aufbau und Gliederung der Analyse",
                gewichtung=15,
                erwartungen=[
                    "Sinnvolle Einleitung mit Basisinformationen",
                    "Logische Gliederung des Hauptteils",
                    "Stringente Gedankenfuehrung",
                    "Angemessener Schluss mit Fazit/Wertung",
                    "Absatzgliederung und Ueberlaenge"
                ]
            ),
            EHKriterium(
                id="analyse",
                name="Analytische Qualitaet",
                beschreibung="Tiefe und Qualitaet der Analyse",
                gewichtung=15,
                erwartungen=[
                    "Erkennen rhetorischer Mittel",
                    "Funktionale Deutung der Stilmittel",
                    "Analyse der Argumentationsweise",
                    "Beruecksichtigung von Wortwahl und Satzbau",
                    "Verknuepfung von Form und Inhalt"
                ]
            ),
            EHKriterium(
                id="rechtschreibung",
                name="Sprachliche Richtigkeit (Rechtschreibung)",
                beschreibung="Orthografische Korrektheit",
                gewichtung=15,
                erwartungen=[
                    "Korrekte Rechtschreibung",
                    "Korrekte Gross- und Kleinschreibung",
                    "Korrekte Getrennt- und Zusammenschreibung",
                    "Korrekte Fremdwortschreibung"
                ]
            ),
            EHKriterium(
                id="grammatik",
                name="Sprachliche Richtigkeit (Grammatik)",
                beschreibung="Grammatische Korrektheit und Zeichensetzung",
                gewichtung=15,
                erwartungen=[
                    "Korrekter Satzbau",
                    "Korrekte Flexion",
                    "Korrekte Zeichensetzung",
                    "Korrekte Bezuege und Kongruenz"
                ]
            )
        ],
        einleitung_hinweise=[
            "Nennung von Autor, Titel, Textsorte, Erscheinungsjahr",
            "Benennung des Themas",
            "Formulierung der Kernthese/Hauptaussage",
            "Ggf. Einordnung in den Kontext"
        ],
        hauptteil_hinweise=[
            "Systematische Analyse der Argumentationsstruktur",
            "Untersuchung der sprachlichen Gestaltung",
            "Funktionale Deutung der Stilmittel",
            "Beruecksichtigung von Adressatenbezug und Intention",
            "Textbelege durch Zitate"
        ],
        schluss_hinweise=[
            "Zusammenfassung der Analyseergebnisse",
            "Bewertung der Ueberzeugungskraft",
            "Ggf. aktuelle Relevanz",
            "Persoenliche Stellungnahme (wenn gefordert)"
        ],
        sprachliche_aspekte=[
            "Fachsprachliche Begriffe korrekt verwenden",
            "Konjunktiv fuer indirekte Rede",
            "Praesens als Tempus der Analyse",
            "Sachlicher, analytischer Stil"
        ]
    )


def get_gedichtanalyse_template() -> EHTemplate:
    """Template for poetry analysis."""
    return EHTemplate(
        id="template_gedichtanalyse",
        aufgabentyp="gedichtanalyse",
        name="Gedichtanalyse / Lyrikinterpretation",
        beschreibung="Vorlage fuer die Analyse und Interpretation lyrischer Texte",
        kriterien=[
            EHKriterium(
                id="inhalt",
                name="Inhaltliche Leistung",
                beschreibung="Erfassung und Deutung des Gedichtinhalts",
                gewichtung=40,
                erwartungen=[
                    "Korrekte Erfassung des lyrischen Ichs und der Sprechsituation",
                    "Vollstaendige inhaltliche Erschliessung aller Strophen",
                    "Erkennen der zentralen Motive und Themen",
                    "Epochenzuordnung und literaturgeschichtliche Einordnung",
                    "Deutung der Bildlichkeit und Symbolik"
                ]
            ),
            EHKriterium(
                id="struktur",
                name="Aufbau und Struktur",
                beschreibung="Logischer Aufbau der Interpretation",
                gewichtung=15,
                erwartungen=[
                    "Einleitung mit Basisinformationen",
                    "Systematische strophenweise oder aspektorientierte Analyse",
                    "Verknuepfung von Form- und Inhaltsanalyse",
                    "Schluessige Gesamtdeutung im Schluss"
                ]
            ),
            EHKriterium(
                id="formanalyse",
                name="Formale Analyse",
                beschreibung="Analyse der lyrischen Gestaltungsmittel",
                gewichtung=15,
                erwartungen=[
                    "Bestimmung von Metrum und Reimschema",
                    "Analyse der Klanggestaltung",
                    "Erkennen von Enjambements und Zaesuren",
                    "Deutung der formalen Mittel",
                    "Verknuepfung von Form und Inhalt"
                ]
            ),
            EHKriterium(
                id="rechtschreibung",
                name="Sprachliche Richtigkeit (Rechtschreibung)",
                beschreibung="Orthografische Korrektheit",
                gewichtung=15,
                erwartungen=[
                    "Korrekte Rechtschreibung",
                    "Korrekte Gross- und Kleinschreibung",
                    "Korrekte Getrennt- und Zusammenschreibung"
                ]
            ),
            EHKriterium(
                id="grammatik",
                name="Sprachliche Richtigkeit (Grammatik)",
                beschreibung="Grammatische Korrektheit und Zeichensetzung",
                gewichtung=15,
                erwartungen=[
                    "Korrekter Satzbau",
                    "Korrekte Flexion",
                    "Korrekte Zeichensetzung"
                ]
            )
        ],
        einleitung_hinweise=[
            "Autor, Titel, Entstehungsjahr/Epoche",
            "Thema/Motiv des Gedichts",
            "Erste Deutungshypothese",
            "Formale Grunddaten (Strophen, Verse)"
        ],
        hauptteil_hinweise=[
            "Inhaltliche Analyse (strophenweise oder aspektorientiert)",
            "Formale Analyse (Metrum, Reim, Klang)",
            "Sprachliche Analyse (Stilmittel, Bildlichkeit)",
            "Funktionale Verknuepfung aller Ebenen",
            "Textbelege durch Zitate mit Versangabe"
        ],
        schluss_hinweise=[
            "Zusammenfassung der Interpretationsergebnisse",
            "Bestaetigung/Modifikation der Deutungshypothese",
            "Einordnung in Epoche/Werk des Autors",
            "Aktualitaetsbezug (wenn sinnvoll)"
        ],
        sprachliche_aspekte=[
            "Fachbegriffe der Lyrikanalyse verwenden",
            "Zwischen lyrischem Ich und Autor unterscheiden",
            "Praesens als Analysetempus",
            "Deutende statt beschreibende Formulierungen"
        ]
    )


def get_eroerterung_template() -> EHTemplate:
    """Template for textgebundene Eroerterung."""
    return EHTemplate(
        id="template_eroerterung_textgebunden",
        aufgabentyp="eroerterung_textgebunden",
        name="Textgebundene Eroerterung",
        beschreibung="Vorlage fuer die textgebundene Eroerterung auf Basis eines Sachtextes",
        kriterien=[
            EHKriterium(
                id="inhalt",
                name="Inhaltliche Leistung",
                beschreibung="Qualitaet der Argumentation",
                gewichtung=40,
                erwartungen=[
                    "Korrekte Wiedergabe der Textposition",
                    "Differenzierte eigene Argumentation",
                    "Vielfaeltige und ueberzeugende Argumente",
                    "Beruecksichtigung von Pro und Contra",
                    "Sinnvolle Beispiele und Belege",
                    "Eigenstaendige Schlussfolgerung"
                ]
            ),
            EHKriterium(
                id="struktur",
                name="Aufbau und Struktur",
                beschreibung="Logischer Aufbau der Eroerterung",
                gewichtung=15,
                erwartungen=[
                    "Problemorientierte Einleitung",
                    "Klare Gliederung der Argumentation",
                    "Logische Argumentationsfolge",
                    "Sinnvolle Ueberlaetze",
                    "Begruendetes Fazit"
                ]
            ),
            EHKriterium(
                id="textbezug",
                name="Textbezug",
                beschreibung="Verknuepfung mit dem Ausgangstext",
                gewichtung=15,
                erwartungen=[
                    "Angemessene Textwiedergabe",
                    "Kritische Auseinandersetzung mit Textposition",
                    "Korrekte Zitierweise",
                    "Verknuepfung eigener Argumente mit Text"
                ]
            ),
            EHKriterium(
                id="rechtschreibung",
                name="Sprachliche Richtigkeit (Rechtschreibung)",
                beschreibung="Orthografische Korrektheit",
                gewichtung=15,
                erwartungen=[
                    "Korrekte Rechtschreibung",
                    "Korrekte Gross- und Kleinschreibung"
                ]
            ),
            EHKriterium(
                id="grammatik",
                name="Sprachliche Richtigkeit (Grammatik)",
                beschreibung="Grammatische Korrektheit und Zeichensetzung",
                gewichtung=15,
                erwartungen=[
                    "Korrekter Satzbau",
                    "Korrekte Zeichensetzung",
                    "Variationsreicher Ausdruck"
                ]
            )
        ],
        einleitung_hinweise=[
            "Hinfuehrung zum Thema",
            "Nennung des Ausgangstextes",
            "Formulierung der Leitfrage/These",
            "Ueberleitung zum Hauptteil"
        ],
        hauptteil_hinweise=[
            "Kurze Wiedergabe der Textposition",
            "Systematische Argumentation (dialektisch oder linear)",
            "Jedes Argument: These - Begruendung - Beispiel",
            "Gewichtung der Argumente",
            "Verknuepfung mit Textposition"
        ],
        schluss_hinweise=[
            "Zusammenfassung der wichtigsten Argumente",
            "Eigene begruendete Stellungnahme",
            "Ggf. Ausblick oder Appell"
        ],
        sprachliche_aspekte=[
            "Argumentative Konnektoren verwenden",
            "Sachlicher, ueberzeugender Stil",
            "Eigene Meinung kennzeichnen",
            "Konjunktiv fuer Textpositionen"
        ]
    )


def get_prosaanalyse_template() -> EHTemplate:
    """Template for prose/narrative text analysis."""
    return EHTemplate(
        id="template_prosaanalyse",
        aufgabentyp="prosaanalyse",
        name="Epische Textanalyse / Prosaanalyse",
        beschreibung="Vorlage fuer die Analyse von Romanauszuegen, Kurzgeschichten und Novellen",
        kriterien=[
            EHKriterium(
                id="inhalt",
                name="Inhaltliche Leistung",
                beschreibung="Erfassung und Deutung des Textinhalts",
                gewichtung=40,
                erwartungen=[
                    "Korrekte Erfassung der Handlung",
                    "Charakterisierung der Figuren",
                    "Erkennen der Erzaehlsituation",
                    "Deutung der Konflikte und Motive",
                    "Einordnung in den Gesamtzusammenhang"
                ]
            ),
            EHKriterium(
                id="struktur",
                name="Aufbau und Struktur",
                beschreibung="Logischer Aufbau der Analyse",
                gewichtung=15,
                erwartungen=[
                    "Informative Einleitung",
                    "Systematische Analyse im Hauptteil",
                    "Verknuepfung der Analyseergebnisse",
                    "Schluessige Gesamtdeutung"
                ]
            ),
            EHKriterium(
                id="erzaehltechnik",
                name="Erzaehltechnische Analyse",
                beschreibung="Analyse narrativer Gestaltungsmittel",
                gewichtung=15,
                erwartungen=[
                    "Bestimmung der Erzaehlperspektive",
                    "Analyse von Zeitgestaltung",
                    "Raumgestaltung und Atmosphaere",
                    "Figurenrede und Bewusstseinsdarstellung",
                    "Funktionale Deutung"
                ]
            ),
            EHKriterium(
                id="rechtschreibung",
                name="Sprachliche Richtigkeit (Rechtschreibung)",
                beschreibung="Orthografische Korrektheit",
                gewichtung=15,
                erwartungen=[
                    "Korrekte Rechtschreibung",
                    "Korrekte Gross- und Kleinschreibung"
                ]
            ),
            EHKriterium(
                id="grammatik",
                name="Sprachliche Richtigkeit (Grammatik)",
                beschreibung="Grammatische Korrektheit und Zeichensetzung",
                gewichtung=15,
                erwartungen=[
                    "Korrekter Satzbau",
                    "Korrekte Zeichensetzung"
                ]
            )
        ],
        einleitung_hinweise=[
            "Autor, Titel, Textsorte, Erscheinungsjahr",
            "Einordnung des Auszugs in den Gesamttext",
            "Thema und Deutungshypothese"
        ],
        hauptteil_hinweise=[
            "Kurze Inhaltsangabe des Auszugs",
            "Analyse der Handlungsstruktur",
            "Figurenanalyse mit Textbelegen",
            "Erzaehltechnische Analyse",
            "Sprachliche Analyse",
            "Verknuepfung aller Ebenen"
        ],
        schluss_hinweise=[
            "Zusammenfassung der Analyseergebnisse",
            "Bestaetigung der Deutungshypothese",
            "Bedeutung fuer Gesamtwerk",
            "Ggf. Aktualitaetsbezug"
        ],
        sprachliche_aspekte=[
            "Fachbegriffe der Erzaehltextanalyse",
            "Zwischen Erzaehler und Autor unterscheiden",
            "Praesens als Analysetempus",
            "Deutende Formulierungen"
        ]
    )


def get_dramenanalyse_template() -> EHTemplate:
    """Template for drama analysis."""
    return EHTemplate(
        id="template_dramenanalyse",
        aufgabentyp="dramenanalyse",
        name="Dramenanalyse",
        beschreibung="Vorlage fuer die Analyse dramatischer Texte und Szenen",
        kriterien=[
            EHKriterium(
                id="inhalt",
                name="Inhaltliche Leistung",
                beschreibung="Erfassung und Deutung des Szeneninhalts",
                gewichtung=40,
                erwartungen=[
                    "Korrekte Erfassung der Handlung",
                    "Analyse der Figurenkonstellation",
                    "Erkennen des dramatischen Konflikts",
                    "Einordnung in den Handlungsverlauf",
                    "Deutung der Szene im Gesamtzusammenhang"
                ]
            ),
            EHKriterium(
                id="struktur",
                name="Aufbau und Struktur",
                beschreibung="Logischer Aufbau der Analyse",
                gewichtung=15,
                erwartungen=[
                    "Einleitung mit Kontextualisierung",
                    "Systematische Szenenanalyse",
                    "Verknuepfung der Analyseergebnisse",
                    "Schluessige Deutung"
                ]
            ),
            EHKriterium(
                id="dramentechnik",
                name="Dramentechnische Analyse",
                beschreibung="Analyse dramatischer Gestaltungsmittel",
                gewichtung=15,
                erwartungen=[
                    "Analyse der Dialoggestaltung",
                    "Regieanweisungen und Buehnenraum",
                    "Dramatische Spannung",
                    "Monolog/Dialog-Formen",
                    "Funktionale Deutung"
                ]
            ),
            EHKriterium(
                id="rechtschreibung",
                name="Sprachliche Richtigkeit (Rechtschreibung)",
                beschreibung="Orthografische Korrektheit",
                gewichtung=15,
                erwartungen=[
                    "Korrekte Rechtschreibung"
                ]
            ),
            EHKriterium(
                id="grammatik",
                name="Sprachliche Richtigkeit (Grammatik)",
                beschreibung="Grammatische Korrektheit und Zeichensetzung",
                gewichtung=15,
                erwartungen=[
                    "Korrekter Satzbau",
                    "Korrekte Zeichensetzung"
                ]
            )
        ],
        einleitung_hinweise=[
            "Autor, Titel, Uraufführungsjahr, Dramenform",
            "Einordnung der Szene in den Handlungsverlauf",
            "Thema und Deutungshypothese"
        ],
        hauptteil_hinweise=[
            "Situierung der Szene",
            "Analyse des Dialogverlaufs",
            "Figurenanalyse im Dialog",
            "Sprachliche Analyse",
            "Dramentechnische Mittel",
            "Bedeutung fuer den Konflikt"
        ],
        schluss_hinweise=[
            "Zusammenfassung der Analyseergebnisse",
            "Funktion der Szene im Drama",
            "Bedeutung fuer die Gesamtdeutung"
        ],
        sprachliche_aspekte=[
            "Fachbegriffe der Dramenanalyse",
            "Praesens als Analysetempus",
            "Korrekte Zitierweise mit Akt/Szene/Zeile"
        ]
    )


# =============================================
# TEMPLATE REGISTRY
# =============================================

TEMPLATES: Dict[str, EHTemplate] = {}


def initialize_templates():
    """Initialize all pre-defined templates."""
    global TEMPLATES
    TEMPLATES = {
        "textanalyse_pragmatisch": get_textanalyse_template(),
        "gedichtanalyse": get_gedichtanalyse_template(),
        "eroerterung_textgebunden": get_eroerterung_template(),
        "prosaanalyse": get_prosaanalyse_template(),
        "dramenanalyse": get_dramenanalyse_template(),
    }


def get_template(aufgabentyp: str) -> Optional[EHTemplate]:
    """Get a template by Aufgabentyp."""
    if not TEMPLATES:
        initialize_templates()
    return TEMPLATES.get(aufgabentyp)


def list_templates() -> List[Dict]:
    """List all available templates."""
    if not TEMPLATES:
        initialize_templates()
    return [
        {
            "aufgabentyp": typ,
            "name": AUFGABENTYPEN.get(typ, {}).get("name", typ),
            "description": AUFGABENTYPEN.get(typ, {}).get("description", ""),
            "category": AUFGABENTYPEN.get(typ, {}).get("category", "other"),
        }
        for typ in TEMPLATES.keys()
    ]


def get_aufgabentypen() -> Dict:
    """Get all Aufgabentypen definitions."""
    return AUFGABENTYPEN


# Initialize on import
initialize_templates()
# Registry
from eh_templates_registry import (  # noqa: F401
    TEMPLATES,
    initialize_templates,
    get_template,
    list_templates,
    get_aufgabentypen,
)
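A quick sketch of how downstream code is expected to keep importing through this barrel after the split (module and symbol names come from the docstring above; the call itself is illustrative):

    # illustrative usage: eh_templates re-exports the split modules
    from eh_templates import AUFGABENTYPEN, EHTemplate, get_template, list_templates

    template = get_template("gedichtanalyse")   # EHTemplate or None
    assert template is None or isinstance(template, EHTemplate)
    print(len(list_templates()), "templates,", len(AUFGABENTYPEN), "Aufgabentypen")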
@@ -0,0 +1,395 @@
"""
Erwartungshorizont Templates — Analyse templates.

Contains templates for:
- Textanalyse (pragmatische Texte)
- Gedichtanalyse / Lyrikinterpretation
- Prosaanalyse
- Dramenanalyse
"""

from eh_templates_types import EHTemplate, EHKriterium


def get_textanalyse_template() -> EHTemplate:
    """Template for pragmatic text analysis."""
    return EHTemplate(
        id="template_textanalyse_pragmatisch",
        aufgabentyp="textanalyse_pragmatisch",
        name="Textanalyse pragmatischer Texte",
        beschreibung="Vorlage fuer die Analyse von Sachtexten, Reden, Kommentaren und Essays",
        kriterien=[
            EHKriterium(
                id="inhalt",
                name="Inhaltliche Leistung",
                beschreibung="Erfassung und Wiedergabe des Textinhalts",
                gewichtung=40,
                erwartungen=[
                    "Korrekte Erfassung der Textaussage/These",
                    "Vollstaendige Wiedergabe der Argumentationsstruktur",
                    "Erkennen von Intention und Adressatenbezug",
                    "Einordnung in den historischen/gesellschaftlichen Kontext",
                    "Beruecksichtigung aller relevanten Textaspekte"
                ]
            ),
            EHKriterium(
                id="struktur",
                name="Aufbau und Struktur",
                beschreibung="Logischer Aufbau und Gliederung der Analyse",
                gewichtung=15,
                erwartungen=[
                    "Sinnvolle Einleitung mit Basisinformationen",
                    "Logische Gliederung des Hauptteils",
                    "Stringente Gedankenfuehrung",
                    "Angemessener Schluss mit Fazit/Wertung",
                    "Absatzgliederung und Ueberlaenge"
                ]
            ),
            EHKriterium(
                id="analyse",
                name="Analytische Qualitaet",
                beschreibung="Tiefe und Qualitaet der Analyse",
                gewichtung=15,
                erwartungen=[
                    "Erkennen rhetorischer Mittel",
                    "Funktionale Deutung der Stilmittel",
                    "Analyse der Argumentationsweise",
                    "Beruecksichtigung von Wortwahl und Satzbau",
                    "Verknuepfung von Form und Inhalt"
                ]
            ),
            EHKriterium(
                id="rechtschreibung",
                name="Sprachliche Richtigkeit (Rechtschreibung)",
                beschreibung="Orthografische Korrektheit",
                gewichtung=15,
                erwartungen=[
                    "Korrekte Rechtschreibung",
                    "Korrekte Gross- und Kleinschreibung",
                    "Korrekte Getrennt- und Zusammenschreibung",
                    "Korrekte Fremdwortschreibung"
                ]
            ),
            EHKriterium(
                id="grammatik",
                name="Sprachliche Richtigkeit (Grammatik)",
                beschreibung="Grammatische Korrektheit und Zeichensetzung",
                gewichtung=15,
                erwartungen=[
                    "Korrekter Satzbau",
                    "Korrekte Flexion",
                    "Korrekte Zeichensetzung",
                    "Korrekte Bezuege und Kongruenz"
                ]
            )
        ],
        einleitung_hinweise=[
            "Nennung von Autor, Titel, Textsorte, Erscheinungsjahr",
            "Benennung des Themas",
            "Formulierung der Kernthese/Hauptaussage",
            "Ggf. Einordnung in den Kontext"
        ],
        hauptteil_hinweise=[
            "Systematische Analyse der Argumentationsstruktur",
            "Untersuchung der sprachlichen Gestaltung",
            "Funktionale Deutung der Stilmittel",
            "Beruecksichtigung von Adressatenbezug und Intention",
            "Textbelege durch Zitate"
        ],
        schluss_hinweise=[
            "Zusammenfassung der Analyseergebnisse",
            "Bewertung der Ueberzeugungskraft",
            "Ggf. aktuelle Relevanz",
            "Persoenliche Stellungnahme (wenn gefordert)"
        ],
        sprachliche_aspekte=[
            "Fachsprachliche Begriffe korrekt verwenden",
            "Konjunktiv fuer indirekte Rede",
            "Praesens als Tempus der Analyse",
            "Sachlicher, analytischer Stil"
        ]
    )


def get_gedichtanalyse_template() -> EHTemplate:
    """Template for poetry analysis."""
    return EHTemplate(
        id="template_gedichtanalyse",
        aufgabentyp="gedichtanalyse",
        name="Gedichtanalyse / Lyrikinterpretation",
        beschreibung="Vorlage fuer die Analyse und Interpretation lyrischer Texte",
        kriterien=[
            EHKriterium(
                id="inhalt",
                name="Inhaltliche Leistung",
                beschreibung="Erfassung und Deutung des Gedichtinhalts",
                gewichtung=40,
                erwartungen=[
                    "Korrekte Erfassung des lyrischen Ichs und der Sprechsituation",
                    "Vollstaendige inhaltliche Erschliessung aller Strophen",
                    "Erkennen der zentralen Motive und Themen",
                    "Epochenzuordnung und literaturgeschichtliche Einordnung",
                    "Deutung der Bildlichkeit und Symbolik"
                ]
            ),
            EHKriterium(
                id="struktur",
                name="Aufbau und Struktur",
                beschreibung="Logischer Aufbau der Interpretation",
                gewichtung=15,
                erwartungen=[
                    "Einleitung mit Basisinformationen",
                    "Systematische strophenweise oder aspektorientierte Analyse",
                    "Verknuepfung von Form- und Inhaltsanalyse",
                    "Schluessige Gesamtdeutung im Schluss"
                ]
            ),
            EHKriterium(
                id="formanalyse",
                name="Formale Analyse",
                beschreibung="Analyse der lyrischen Gestaltungsmittel",
                gewichtung=15,
                erwartungen=[
                    "Bestimmung von Metrum und Reimschema",
                    "Analyse der Klanggestaltung",
                    "Erkennen von Enjambements und Zaesuren",
                    "Deutung der formalen Mittel",
                    "Verknuepfung von Form und Inhalt"
                ]
            ),
            EHKriterium(
                id="rechtschreibung",
                name="Sprachliche Richtigkeit (Rechtschreibung)",
                beschreibung="Orthografische Korrektheit",
                gewichtung=15,
                erwartungen=[
                    "Korrekte Rechtschreibung",
                    "Korrekte Gross- und Kleinschreibung",
                    "Korrekte Getrennt- und Zusammenschreibung"
                ]
            ),
            EHKriterium(
                id="grammatik",
                name="Sprachliche Richtigkeit (Grammatik)",
                beschreibung="Grammatische Korrektheit und Zeichensetzung",
                gewichtung=15,
                erwartungen=[
                    "Korrekter Satzbau",
                    "Korrekte Flexion",
                    "Korrekte Zeichensetzung"
                ]
            )
        ],
        einleitung_hinweise=[
            "Autor, Titel, Entstehungsjahr/Epoche",
            "Thema/Motiv des Gedichts",
            "Erste Deutungshypothese",
            "Formale Grunddaten (Strophen, Verse)"
        ],
        hauptteil_hinweise=[
            "Inhaltliche Analyse (strophenweise oder aspektorientiert)",
            "Formale Analyse (Metrum, Reim, Klang)",
            "Sprachliche Analyse (Stilmittel, Bildlichkeit)",
            "Funktionale Verknuepfung aller Ebenen",
            "Textbelege durch Zitate mit Versangabe"
        ],
        schluss_hinweise=[
            "Zusammenfassung der Interpretationsergebnisse",
            "Bestaetigung/Modifikation der Deutungshypothese",
            "Einordnung in Epoche/Werk des Autors",
            "Aktualitaetsbezug (wenn sinnvoll)"
        ],
        sprachliche_aspekte=[
            "Fachbegriffe der Lyrikanalyse verwenden",
            "Zwischen lyrischem Ich und Autor unterscheiden",
            "Praesens als Analysetempus",
            "Deutende statt beschreibende Formulierungen"
        ]
    )


def get_prosaanalyse_template() -> EHTemplate:
    """Template for prose/narrative text analysis."""
    return EHTemplate(
        id="template_prosaanalyse",
        aufgabentyp="prosaanalyse",
        name="Epische Textanalyse / Prosaanalyse",
        beschreibung="Vorlage fuer die Analyse von Romanauszuegen, Kurzgeschichten und Novellen",
        kriterien=[
            EHKriterium(
                id="inhalt",
                name="Inhaltliche Leistung",
                beschreibung="Erfassung und Deutung des Textinhalts",
                gewichtung=40,
                erwartungen=[
                    "Korrekte Erfassung der Handlung",
                    "Charakterisierung der Figuren",
                    "Erkennen der Erzaehlsituation",
                    "Deutung der Konflikte und Motive",
                    "Einordnung in den Gesamtzusammenhang"
                ]
            ),
            EHKriterium(
                id="struktur",
                name="Aufbau und Struktur",
                beschreibung="Logischer Aufbau der Analyse",
                gewichtung=15,
                erwartungen=[
                    "Informative Einleitung",
                    "Systematische Analyse im Hauptteil",
                    "Verknuepfung der Analyseergebnisse",
                    "Schluessige Gesamtdeutung"
                ]
            ),
            EHKriterium(
                id="erzaehltechnik",
                name="Erzaehltechnische Analyse",
                beschreibung="Analyse narrativer Gestaltungsmittel",
                gewichtung=15,
                erwartungen=[
                    "Bestimmung der Erzaehlperspektive",
                    "Analyse von Zeitgestaltung",
                    "Raumgestaltung und Atmosphaere",
                    "Figurenrede und Bewusstseinsdarstellung",
                    "Funktionale Deutung"
                ]
            ),
            EHKriterium(
                id="rechtschreibung",
                name="Sprachliche Richtigkeit (Rechtschreibung)",
                beschreibung="Orthografische Korrektheit",
                gewichtung=15,
                erwartungen=[
                    "Korrekte Rechtschreibung",
                    "Korrekte Gross- und Kleinschreibung"
                ]
            ),
            EHKriterium(
                id="grammatik",
                name="Sprachliche Richtigkeit (Grammatik)",
                beschreibung="Grammatische Korrektheit und Zeichensetzung",
                gewichtung=15,
                erwartungen=[
                    "Korrekter Satzbau",
                    "Korrekte Zeichensetzung"
                ]
            )
        ],
        einleitung_hinweise=[
            "Autor, Titel, Textsorte, Erscheinungsjahr",
            "Einordnung des Auszugs in den Gesamttext",
            "Thema und Deutungshypothese"
        ],
        hauptteil_hinweise=[
            "Kurze Inhaltsangabe des Auszugs",
            "Analyse der Handlungsstruktur",
            "Figurenanalyse mit Textbelegen",
            "Erzaehltechnische Analyse",
            "Sprachliche Analyse",
            "Verknuepfung aller Ebenen"
        ],
        schluss_hinweise=[
            "Zusammenfassung der Analyseergebnisse",
            "Bestaetigung der Deutungshypothese",
            "Bedeutung fuer Gesamtwerk",
            "Ggf. Aktualitaetsbezug"
        ],
        sprachliche_aspekte=[
            "Fachbegriffe der Erzaehltextanalyse",
            "Zwischen Erzaehler und Autor unterscheiden",
            "Praesens als Analysetempus",
            "Deutende Formulierungen"
        ]
    )


def get_dramenanalyse_template() -> EHTemplate:
    """Template for drama analysis."""
    return EHTemplate(
        id="template_dramenanalyse",
        aufgabentyp="dramenanalyse",
        name="Dramenanalyse",
        beschreibung="Vorlage fuer die Analyse dramatischer Texte und Szenen",
        kriterien=[
            EHKriterium(
                id="inhalt",
                name="Inhaltliche Leistung",
                beschreibung="Erfassung und Deutung des Szeneninhalts",
                gewichtung=40,
                erwartungen=[
                    "Korrekte Erfassung der Handlung",
                    "Analyse der Figurenkonstellation",
                    "Erkennen des dramatischen Konflikts",
                    "Einordnung in den Handlungsverlauf",
                    "Deutung der Szene im Gesamtzusammenhang"
                ]
            ),
            EHKriterium(
                id="struktur",
                name="Aufbau und Struktur",
                beschreibung="Logischer Aufbau der Analyse",
                gewichtung=15,
                erwartungen=[
                    "Einleitung mit Kontextualisierung",
                    "Systematische Szenenanalyse",
                    "Verknuepfung der Analyseergebnisse",
                    "Schluessige Deutung"
                ]
            ),
            EHKriterium(
                id="dramentechnik",
                name="Dramentechnische Analyse",
                beschreibung="Analyse dramatischer Gestaltungsmittel",
                gewichtung=15,
                erwartungen=[
                    "Analyse der Dialoggestaltung",
                    "Regieanweisungen und Buehnenraum",
                    "Dramatische Spannung",
                    "Monolog/Dialog-Formen",
                    "Funktionale Deutung"
                ]
            ),
            EHKriterium(
                id="rechtschreibung",
                name="Sprachliche Richtigkeit (Rechtschreibung)",
                beschreibung="Orthografische Korrektheit",
                gewichtung=15,
                erwartungen=[
                    "Korrekte Rechtschreibung"
                ]
            ),
            EHKriterium(
                id="grammatik",
                name="Sprachliche Richtigkeit (Grammatik)",
                beschreibung="Grammatische Korrektheit und Zeichensetzung",
                gewichtung=15,
                erwartungen=[
                    "Korrekter Satzbau",
                    "Korrekte Zeichensetzung"
                ]
            )
        ],
        einleitung_hinweise=[
            "Autor, Titel, Urauffuehrungsjahr, Dramenform",
            "Einordnung der Szene in den Handlungsverlauf",
            "Thema und Deutungshypothese"
        ],
        hauptteil_hinweise=[
            "Situierung der Szene",
            "Analyse des Dialogverlaufs",
            "Figurenanalyse im Dialog",
            "Sprachliche Analyse",
            "Dramentechnische Mittel",
            "Bedeutung fuer den Konflikt"
        ],
        schluss_hinweise=[
            "Zusammenfassung der Analyseergebnisse",
            "Funktion der Szene im Drama",
            "Bedeutung fuer die Gesamtdeutung"
        ],
        sprachliche_aspekte=[
            "Fachbegriffe der Dramenanalyse",
            "Praesens als Analysetempus",
            "Korrekte Zitierweise mit Akt/Szene/Zeile"
        ]
    )
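Each factory above assigns gewichtung values of 40/15/15/15/15, so the criterion weights of a template sum to 100. An illustrative standalone check (assumes the module is importable as eh_templates_analyse):

    # illustrative sanity check for criterion weights
    from eh_templates_analyse import get_textanalyse_template

    template = get_textanalyse_template()
    total = sum(k.gewichtung for k in template.kriterien)
    assert total == 100, total   # 40 + 15 + 15 + 15 + 15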
@@ -0,0 +1,101 @@
"""
Erwartungshorizont Templates — Eroerterung template.
"""

from eh_templates_types import EHTemplate, EHKriterium


def get_eroerterung_template() -> EHTemplate:
    """Template for textgebundene Eroerterung."""
    return EHTemplate(
        id="template_eroerterung_textgebunden",
        aufgabentyp="eroerterung_textgebunden",
        name="Textgebundene Eroerterung",
        beschreibung="Vorlage fuer die textgebundene Eroerterung auf Basis eines Sachtextes",
        kriterien=[
            EHKriterium(
                id="inhalt",
                name="Inhaltliche Leistung",
                beschreibung="Qualitaet der Argumentation",
                gewichtung=40,
                erwartungen=[
                    "Korrekte Wiedergabe der Textposition",
                    "Differenzierte eigene Argumentation",
                    "Vielfaeltige und ueberzeugende Argumente",
                    "Beruecksichtigung von Pro und Contra",
                    "Sinnvolle Beispiele und Belege",
                    "Eigenstaendige Schlussfolgerung"
                ]
            ),
            EHKriterium(
                id="struktur",
                name="Aufbau und Struktur",
                beschreibung="Logischer Aufbau der Eroerterung",
                gewichtung=15,
                erwartungen=[
                    "Problemorientierte Einleitung",
                    "Klare Gliederung der Argumentation",
                    "Logische Argumentationsfolge",
                    "Sinnvolle Ueberlaetze",
                    "Begruendetes Fazit"
                ]
            ),
            EHKriterium(
                id="textbezug",
                name="Textbezug",
                beschreibung="Verknuepfung mit dem Ausgangstext",
                gewichtung=15,
                erwartungen=[
                    "Angemessene Textwiedergabe",
                    "Kritische Auseinandersetzung mit Textposition",
                    "Korrekte Zitierweise",
                    "Verknuepfung eigener Argumente mit Text"
                ]
            ),
            EHKriterium(
                id="rechtschreibung",
                name="Sprachliche Richtigkeit (Rechtschreibung)",
                beschreibung="Orthografische Korrektheit",
                gewichtung=15,
                erwartungen=[
                    "Korrekte Rechtschreibung",
                    "Korrekte Gross- und Kleinschreibung"
                ]
            ),
            EHKriterium(
                id="grammatik",
                name="Sprachliche Richtigkeit (Grammatik)",
                beschreibung="Grammatische Korrektheit und Zeichensetzung",
                gewichtung=15,
                erwartungen=[
                    "Korrekter Satzbau",
                    "Korrekte Zeichensetzung",
                    "Variationsreicher Ausdruck"
                ]
            )
        ],
        einleitung_hinweise=[
            "Hinfuehrung zum Thema",
            "Nennung des Ausgangstextes",
            "Formulierung der Leitfrage/These",
            "Ueberleitung zum Hauptteil"
        ],
        hauptteil_hinweise=[
            "Kurze Wiedergabe der Textposition",
            "Systematische Argumentation (dialektisch oder linear)",
            "Jedes Argument: These - Begruendung - Beispiel",
            "Gewichtung der Argumente",
            "Verknuepfung mit Textposition"
        ],
        schluss_hinweise=[
            "Zusammenfassung der wichtigsten Argumente",
            "Eigene begruendete Stellungnahme",
            "Ggf. Ausblick oder Appell"
        ],
        sprachliche_aspekte=[
            "Argumentative Konnektoren verwenden",
            "Sachlicher, ueberzeugender Stil",
            "Eigene Meinung kennzeichnen",
            "Konjunktiv fuer Textpositionen"
        ]
    )
@@ -0,0 +1,60 @@
"""
Erwartungshorizont Templates — registry for template lookup.
"""

from typing import Dict, List, Optional

from eh_templates_types import EHTemplate, AUFGABENTYPEN
from eh_templates_analyse import (
    get_textanalyse_template,
    get_gedichtanalyse_template,
    get_prosaanalyse_template,
    get_dramenanalyse_template,
)
from eh_templates_eroerterung import get_eroerterung_template


TEMPLATES: Dict[str, EHTemplate] = {}


def initialize_templates():
    """Initialize all pre-defined templates."""
    global TEMPLATES
    TEMPLATES = {
        "textanalyse_pragmatisch": get_textanalyse_template(),
        "gedichtanalyse": get_gedichtanalyse_template(),
        "eroerterung_textgebunden": get_eroerterung_template(),
        "prosaanalyse": get_prosaanalyse_template(),
        "dramenanalyse": get_dramenanalyse_template(),
    }


def get_template(aufgabentyp: str) -> Optional[EHTemplate]:
    """Get a template by Aufgabentyp."""
    if not TEMPLATES:
        initialize_templates()
    return TEMPLATES.get(aufgabentyp)


def list_templates() -> List[Dict]:
    """List all available templates."""
    if not TEMPLATES:
        initialize_templates()
    return [
        {
            "aufgabentyp": typ,
            "name": AUFGABENTYPEN.get(typ, {}).get("name", typ),
            "description": AUFGABENTYPEN.get(typ, {}).get("description", ""),
            "category": AUFGABENTYPEN.get(typ, {}).get("category", "other"),
        }
        for typ in TEMPLATES.keys()
    ]


def get_aufgabentypen() -> Dict:
    """Get all Aufgabentypen definitions."""
    return AUFGABENTYPEN


# Initialize on import
initialize_templates()
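A small usage sketch for the registry (illustrative; "gedichtanalyse" is one of the keys initialized above):

    # illustrative lookup via the registry
    from eh_templates_registry import get_template, list_templates, get_aufgabentypen

    tpl = get_template("gedichtanalyse")
    if tpl is not None:
        print(tpl.name, [k.id for k in tpl.kriterien])
    print([t["aufgabentyp"] for t in list_templates()])
    print(sorted(get_aufgabentypen().keys()))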
@@ -0,0 +1,100 @@
"""
Erwartungshorizont Templates — types and Aufgabentypen registry.
"""

from typing import Dict, List, Optional
from dataclasses import dataclass, field, asdict
from datetime import datetime


AUFGABENTYPEN = {
    "textanalyse_pragmatisch": {
        "name": "Textanalyse (pragmatische Texte)",
        "description": "Analyse von Sachtexten, Reden, Kommentaren, Essays",
        "category": "analyse"
    },
    "sachtextanalyse": {
        "name": "Sachtextanalyse",
        "description": "Analyse von informativen und appellativen Sachtexten",
        "category": "analyse"
    },
    "gedichtanalyse": {
        "name": "Gedichtanalyse / Lyrikinterpretation",
        "description": "Analyse und Interpretation lyrischer Texte",
        "category": "interpretation"
    },
    "dramenanalyse": {
        "name": "Dramenanalyse",
        "description": "Analyse dramatischer Texte und Szenen",
        "category": "interpretation"
    },
    "prosaanalyse": {
        "name": "Epische Textanalyse / Prosaanalyse",
        "description": "Analyse von Romanauszuegen, Kurzgeschichten, Novellen",
        "category": "interpretation"
    },
    "eroerterung_textgebunden": {
        "name": "Textgebundene Eroerterung",
        "description": "Eroerterung auf Basis eines Sachtextes",
        "category": "argumentation"
    },
    "eroerterung_frei": {
        "name": "Freie Eroerterung",
        "description": "Freie Eroerterung zu einem Thema",
        "category": "argumentation"
    },
    "eroerterung_literarisch": {
        "name": "Literarische Eroerterung",
        "description": "Eroerterung zu literarischen Fragestellungen",
        "category": "argumentation"
    },
    "materialgestuetzt": {
        "name": "Materialgestuetztes Schreiben",
        "description": "Verfassen eines Textes auf Materialbasis",
        "category": "produktion"
    }
}


@dataclass
class EHKriterium:
    """Single criterion in an Erwartungshorizont."""
    id: str
    name: str
    beschreibung: str
    gewichtung: int  # Percentage weight (0-100)
    erwartungen: List[str]  # Expected points/elements
    max_punkte: int = 100

    def to_dict(self):
        return asdict(self)


@dataclass
class EHTemplate:
    """Complete Erwartungshorizont template."""
    id: str
    aufgabentyp: str
    name: str
    beschreibung: str
    kriterien: List[EHKriterium]
    einleitung_hinweise: List[str]
    hauptteil_hinweise: List[str]
    schluss_hinweise: List[str]
    sprachliche_aspekte: List[str]
    created_at: datetime = field(default_factory=lambda: datetime.now())

    def to_dict(self):
        d = {
            'id': self.id,
            'aufgabentyp': self.aufgabentyp,
            'name': self.name,
            'beschreibung': self.beschreibung,
            'kriterien': [k.to_dict() for k in self.kriterien],
            'einleitung_hinweise': self.einleitung_hinweise,
            'hauptteil_hinweise': self.hauptteil_hinweise,
            'schluss_hinweise': self.schluss_hinweise,
            'sprachliche_aspekte': self.sprachliche_aspekte,
            'created_at': self.created_at.isoformat()
        }
        return d
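The dataclasses serialize via to_dict() for API responses; a minimal round-trip sketch (the field values below are illustrative, not from the commit):

    # illustrative construction and serialization
    from eh_templates_types import EHKriterium, EHTemplate

    kriterium = EHKriterium(
        id="inhalt",
        name="Inhaltliche Leistung",
        beschreibung="Beispielkriterium",
        gewichtung=40,
        erwartungen=["Beispielerwartung"],
    )
    template = EHTemplate(
        id="demo",
        aufgabentyp="textanalyse_pragmatisch",
        name="Demo",
        beschreibung="Nur zur Illustration",
        kriterien=[kriterium],
        einleitung_hinweise=[],
        hauptteil_hinweise=[],
        schluss_hinweise=[],
        sprachliche_aspekte=[],
    )
    payload = template.to_dict()   # created_at becomes an ISO string
    assert payload["kriterien"][0]["max_punkte"] == 100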
@@ -1,671 +1,31 @@
|
||||
"""
|
||||
Grid Editor API — endpoints for grid building, editing, and export.
|
||||
Grid Editor API — barrel re-export.
|
||||
|
||||
The core grid building logic is in grid_build_core.py.
|
||||
The actual endpoints live in:
|
||||
- grid_editor_api_grid.py (build-grid, rerun-ocr, save-grid, get-grid)
|
||||
- grid_editor_api_gutter.py (gutter-repair, gutter-repair/apply)
|
||||
- grid_editor_api_box.py (build-box-grids)
|
||||
- grid_editor_api_unified.py (build-unified-grid, unified-grid)
|
||||
|
||||
This module re-exports the combined router and key symbols so that
|
||||
existing `from grid_editor_api import router` / `from grid_editor_api import _build_grid_core`
|
||||
continue to work unchanged.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, Request
|
||||
|
||||
from grid_build_core import _build_grid_core
|
||||
from grid_editor_helpers import _words_in_zone
|
||||
from ocr_pipeline_session_store import (
|
||||
get_session_db,
|
||||
update_session_db,
|
||||
)
|
||||
from ocr_pipeline_common import (
|
||||
_cache,
|
||||
_load_session_to_cache,
|
||||
_get_cached,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["grid-editor"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/build-grid")
|
||||
async def build_grid(
|
||||
session_id: str,
|
||||
ipa_mode: str = Query("auto", pattern="^(auto|all|de|en|none)$"),
|
||||
syllable_mode: str = Query("auto", pattern="^(auto|all|de|en|none)$"),
|
||||
enhance: bool = Query(True, description="Step 3: CLAHE + denoise for degraded scans"),
|
||||
max_cols: int = Query(0, description="Step 2: Max column count (0=unlimited)"),
|
||||
min_conf: int = Query(0, description="Step 1: Min OCR confidence (0=auto)"),
|
||||
):
|
||||
"""Build a structured, zone-aware grid from existing Kombi word results.
|
||||
|
||||
Requires that paddle-kombi or rapid-kombi has already been run on the session.
|
||||
Uses the image for box detection and the word positions for grid structuring.
|
||||
|
||||
Query params:
|
||||
ipa_mode: "auto" (only when English IPA detected), "all" (force), "none" (skip)
|
||||
syllable_mode: "auto" (only when original has dividers), "all" (force), "none" (skip)
|
||||
|
||||
Returns a StructuredGrid with zones, each containing their own
|
||||
columns, rows, and cells — ready for the frontend Excel-like editor.
|
||||
"""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
try:
|
||||
result = await _build_grid_core(
|
||||
session_id, session,
|
||||
ipa_mode=ipa_mode, syllable_mode=syllable_mode,
|
||||
enhance=enhance,
|
||||
max_columns=max_cols if max_cols > 0 else None,
|
||||
min_conf=min_conf if min_conf > 0 else None,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
# Save automatic grid snapshot for later comparison with manual corrections
|
||||
# Lazy import to avoid circular dependency with ocr_pipeline_regression
|
||||
from ocr_pipeline_regression import _build_reference_snapshot
|
||||
|
||||
wr = session.get("word_result") or {}
|
||||
engine = wr.get("ocr_engine", "")
|
||||
if engine in ("kombi", "rapid_kombi"):
|
||||
auto_pipeline = "kombi"
|
||||
elif engine == "paddle_direct":
|
||||
auto_pipeline = "paddle-direct"
|
||||
else:
|
||||
auto_pipeline = "pipeline"
|
||||
auto_snapshot = _build_reference_snapshot(result, pipeline=auto_pipeline)
|
||||
|
||||
gt = session.get("ground_truth") or {}
|
||||
gt["auto_grid_snapshot"] = auto_snapshot
|
||||
|
||||
# Persist to DB and advance current_step to 11 (reconstruction complete)
|
||||
await update_session_db(session_id, grid_editor_result=result, ground_truth=gt, current_step=11)
|
||||
|
||||
logger.info(
|
||||
"build-grid session %s: %d zones, %d cols, %d rows, %d cells, "
|
||||
"%d boxes in %.2fs",
|
||||
session_id,
|
||||
len(result.get("zones", [])),
|
||||
result.get("summary", {}).get("total_columns", 0),
|
||||
result.get("summary", {}).get("total_rows", 0),
|
||||
result.get("summary", {}).get("total_cells", 0),
|
||||
result.get("boxes_detected", 0),
|
||||
result.get("duration_seconds", 0),
|
||||
)
|
||||
|
||||
return result
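A minimal client-side sketch of calling this endpoint; the httpx dependency, the base URL, and the session id are assumptions for illustration and are not part of this commit:

import httpx

BASE = "http://localhost:8000/api/v1/ocr-pipeline"  # hypothetical deployment URL

resp = httpx.post(
    f"{BASE}/sessions/abc123/build-grid",
    params={"ipa_mode": "auto", "syllable_mode": "none", "enhance": True, "max_cols": 4},
    timeout=120,
)
resp.raise_for_status()
grid = resp.json()
print(len(grid.get("zones", [])), "zones,", grid.get("summary", {}).get("total_cells", 0), "cells")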
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/rerun-ocr-and-build-grid")
|
||||
async def rerun_ocr_and_build_grid(
|
||||
session_id: str,
|
||||
ipa_mode: str = Query("auto", pattern="^(auto|all|de|en|none)$"),
|
||||
syllable_mode: str = Query("auto", pattern="^(auto|all|de|en|none)$"),
|
||||
enhance: bool = Query(True, description="Step 3: CLAHE + denoise for degraded scans"),
|
||||
max_cols: int = Query(0, description="Step 2: Max column count (0=unlimited)"),
|
||||
min_conf: int = Query(0, description="Step 1: Min OCR confidence (0=auto)"),
|
||||
vision_fusion: bool = Query(False, description="Step 4: Vision-LLM fusion for degraded scans"),
|
||||
doc_category: str = Query("", description="Document type for Vision-LLM prompt context"),
|
||||
):
|
||||
"""Re-run OCR with quality settings, then rebuild the grid.
|
||||
|
||||
Unlike build-grid (which only rebuilds from existing words),
|
||||
this endpoint re-runs the full OCR pipeline on the cropped image
|
||||
with optional CLAHE enhancement, then builds the grid.
|
||||
|
||||
Steps executed: Image Enhancement → OCR → Grid Build
|
||||
"""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
import time as _time
|
||||
t0 = _time.time()
|
||||
|
||||
# 1. Load the cropped/dewarped image from cache or session
|
||||
if session_id not in _cache:
|
||||
await _load_session_to_cache(session_id)
|
||||
cached = _get_cached(session_id)
|
||||
|
||||
dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
|
||||
if dewarped_bgr is None:
|
||||
raise HTTPException(status_code=400, detail="No cropped/dewarped image available. Run preprocessing steps first.")
|
||||
|
||||
import numpy as np
|
||||
img_h, img_w = dewarped_bgr.shape[:2]
|
||||
ocr_input = dewarped_bgr.copy()
|
||||
|
||||
# 2. Scan quality assessment
|
||||
scan_quality_info = {}
|
||||
try:
|
||||
from scan_quality import score_scan_quality
|
||||
quality_report = score_scan_quality(ocr_input)
|
||||
scan_quality_info = quality_report.to_dict()
|
||||
actual_min_conf = min_conf if min_conf > 0 else quality_report.recommended_min_conf
|
||||
except Exception as e:
|
||||
logger.warning(f"rerun-ocr: scan quality failed: {e}")
|
||||
actual_min_conf = min_conf if min_conf > 0 else 40
|
||||
|
||||
# 3. Image enhancement (Step 3)
|
||||
is_degraded = scan_quality_info.get("is_degraded", False)
|
||||
if enhance and is_degraded:
|
||||
try:
|
||||
from ocr_image_enhance import enhance_for_ocr
|
||||
ocr_input = enhance_for_ocr(ocr_input, is_degraded=True)
|
||||
logger.info("rerun-ocr: CLAHE enhancement applied")
|
||||
except Exception as e:
|
||||
logger.warning(f"rerun-ocr: enhancement failed: {e}")
|
||||
|
||||
# 4. Run dual-engine OCR
|
||||
from PIL import Image
|
||||
import pytesseract
|
||||
|
||||
# RapidOCR
|
||||
rapid_words = []
|
||||
try:
|
||||
from cv_ocr_engines import ocr_region_rapid
|
||||
from cv_vocab_types import PageRegion
|
||||
full_region = PageRegion(type="full_page", x=0, y=0, width=img_w, height=img_h)
|
||||
rapid_words = ocr_region_rapid(ocr_input, full_region) or []
|
||||
except Exception as e:
|
||||
logger.warning(f"rerun-ocr: RapidOCR failed: {e}")
|
||||
|
||||
# Tesseract
|
||||
pil_img = Image.fromarray(ocr_input[:, :, ::-1])
|
||||
data = pytesseract.image_to_data(pil_img, lang='eng+deu', config='--psm 6 --oem 3', output_type=pytesseract.Output.DICT)
|
||||
tess_words = []
|
||||
for i in range(len(data["text"])):
|
||||
text = (data["text"][i] or "").strip()
|
||||
conf_raw = str(data["conf"][i])
|
||||
conf = int(conf_raw) if conf_raw.lstrip("-").isdigit() else -1
|
||||
if not text or conf < actual_min_conf:
|
||||
continue
|
||||
tess_words.append({
|
||||
"text": text, "left": data["left"][i], "top": data["top"][i],
|
||||
"width": data["width"][i], "height": data["height"][i], "conf": conf,
|
||||
})
|
||||
|
||||
# 5. Merge OCR results
|
||||
from ocr_pipeline_ocr_merge import _split_paddle_multi_words, _merge_paddle_tesseract, _deduplicate_words
|
||||
rapid_split = _split_paddle_multi_words(rapid_words) if rapid_words else []
|
||||
if rapid_split or tess_words:
|
||||
merged_words = _merge_paddle_tesseract(rapid_split, tess_words)
|
||||
merged_words = _deduplicate_words(merged_words)
|
||||
else:
|
||||
merged_words = tess_words
|
||||
|
||||
# 6. Store updated word_result in session
|
||||
cells_for_storage = [{"text": w["text"], "left": w["left"], "top": w["top"],
|
||||
"width": w["width"], "height": w["height"], "conf": w.get("conf", 0)}
|
||||
for w in merged_words]
|
||||
word_result = {
|
||||
"cells": [{"text": " ".join(w["text"] for w in merged_words),
|
||||
"word_boxes": cells_for_storage}],
|
||||
"image_width": img_w,
|
||||
"image_height": img_h,
|
||||
"ocr_engine": "rapid_kombi",
|
||||
"word_count": len(merged_words),
|
||||
"raw_paddle_words": rapid_words,
|
||||
}
|
||||
# 6b. Vision-LLM Fusion (Step 4) — correct OCR using Vision model
|
||||
vision_applied = False
|
||||
if vision_fusion:
|
||||
try:
|
||||
from vision_ocr_fusion import vision_fuse_ocr
|
||||
category = doc_category or session.get("document_category") or "vokabelseite"
|
||||
logger.info(f"rerun-ocr: running Vision-LLM fusion (category={category})")
|
||||
merged_words = await vision_fuse_ocr(ocr_input, merged_words, category)
|
||||
vision_applied = True
|
||||
# Rebuild storage from fused words
|
||||
cells_for_storage = [{"text": w["text"], "left": w["left"], "top": w["top"],
|
||||
"width": w["width"], "height": w["height"], "conf": w.get("conf", 0)}
|
||||
for w in merged_words]
|
||||
word_result["cells"] = [{"text": " ".join(w["text"] for w in merged_words),
|
||||
"word_boxes": cells_for_storage}]
|
||||
word_result["word_count"] = len(merged_words)
|
||||
word_result["ocr_engine"] = "vision_fusion"
|
||||
except Exception as e:
|
||||
logger.warning(f"rerun-ocr: Vision-LLM fusion failed: {e}")
|
||||
|
||||
await update_session_db(session_id, word_result=word_result)
|
||||
|
||||
# Reload session with updated word_result
|
||||
session = await get_session_db(session_id)
|
||||
|
||||
ocr_duration = _time.time() - t0
|
||||
logger.info(
|
||||
"rerun-ocr session %s: %d words (rapid=%d, tess=%d, merged=%d) in %.1fs "
|
||||
"(enhance=%s, min_conf=%d, quality=%s)",
|
||||
session_id, len(merged_words), len(rapid_words), len(tess_words),
|
||||
len(merged_words), ocr_duration, enhance, actual_min_conf,
|
||||
scan_quality_info.get("quality_pct", "?"),
|
||||
)
|
||||
|
||||
# 7. Build grid from new words
|
||||
try:
|
||||
result = await _build_grid_core(
|
||||
session_id, session,
|
||||
ipa_mode=ipa_mode, syllable_mode=syllable_mode,
|
||||
enhance=enhance,
|
||||
max_columns=max_cols if max_cols > 0 else None,
|
||||
min_conf=min_conf if min_conf > 0 else None,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
# Persist grid
|
||||
await update_session_db(session_id, grid_editor_result=result, current_step=11)
|
||||
|
||||
# Add quality info to response
|
||||
result["scan_quality"] = scan_quality_info
|
||||
result["ocr_stats"] = {
|
||||
"rapid_words": len(rapid_words),
|
||||
"tess_words": len(tess_words),
|
||||
"merged_words": len(merged_words),
|
||||
"min_conf_used": actual_min_conf,
|
||||
"enhance_applied": enhance and is_degraded,
|
||||
"vision_fusion_applied": vision_applied,
|
||||
"document_category": doc_category or session.get("document_category", ""),
|
||||
"ocr_duration_seconds": round(ocr_duration, 1),
|
||||
}
|
||||
|
||||
total_duration = _time.time() - t0
|
||||
logger.info(
|
||||
"rerun-ocr+build-grid session %s: %d zones, %d cols, %d cells in %.1fs",
|
||||
session_id,
|
||||
len(result.get("zones", [])),
|
||||
result.get("summary", {}).get("total_columns", 0),
|
||||
result.get("summary", {}).get("total_cells", 0),
|
||||
total_duration,
|
||||
)
|
||||
|
||||
return result
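The same pattern for the heavier re-run, enabling enhancement and Vision-LLM fusion for a degraded scan (again a sketch; client, URL, and ids are assumed):

import httpx

BASE = "http://localhost:8000/api/v1/ocr-pipeline"  # hypothetical deployment URL

resp = httpx.post(
    f"{BASE}/sessions/abc123/rerun-ocr-and-build-grid",
    params={
        "enhance": True,
        "min_conf": 35,
        "vision_fusion": True,
        "doc_category": "vokabelseite",
    },
    timeout=600,  # full OCR plus optional Vision-LLM fusion can be slow
)
stats = resp.json().get("ocr_stats", {})
print(stats.get("merged_words"), "words; vision fusion applied:", stats.get("vision_fusion_applied"))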
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/save-grid")
|
||||
async def save_grid(session_id: str, request: Request):
|
||||
"""Save edited grid data from the frontend Excel-like editor.
|
||||
|
||||
Receives the full StructuredGrid with user edits (text changes,
|
||||
formatting changes like bold columns, header rows, etc.) and
|
||||
persists it to the session's grid_editor_result.
|
||||
"""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
body = await request.json()
|
||||
|
||||
# Validate basic structure
|
||||
if "zones" not in body:
|
||||
raise HTTPException(status_code=400, detail="Missing 'zones' in request body")
|
||||
|
||||
# Preserve metadata from the original build
|
||||
existing = session.get("grid_editor_result") or {}
|
||||
result = {
|
||||
"session_id": session_id,
|
||||
"image_width": body.get("image_width", existing.get("image_width", 0)),
|
||||
"image_height": body.get("image_height", existing.get("image_height", 0)),
|
||||
"zones": body["zones"],
|
||||
"boxes_detected": body.get("boxes_detected", existing.get("boxes_detected", 0)),
|
||||
"summary": body.get("summary", existing.get("summary", {})),
|
||||
"formatting": body.get("formatting", existing.get("formatting", {})),
|
||||
"duration_seconds": existing.get("duration_seconds", 0),
|
||||
"edited": True,
|
||||
}
|
||||
|
||||
await update_session_db(session_id, grid_editor_result=result, current_step=11)
|
||||
|
||||
logger.info("save-grid session %s: %d zones saved", session_id, len(body["zones"]))
|
||||
|
||||
return {"session_id": session_id, "saved": True}
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/grid-editor")
|
||||
async def get_grid(session_id: str):
|
||||
"""Retrieve the current grid editor state for a session."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
result = session.get("grid_editor_result")
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No grid editor data. Run build-grid first.",
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gutter Repair endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/gutter-repair")
|
||||
async def gutter_repair(session_id: str):
|
||||
"""Analyse grid for gutter-edge OCR errors and return repair suggestions.
|
||||
|
||||
Detects:
|
||||
- Words truncated/blurred at the book binding (spell_fix)
|
||||
- Words split across rows with missing hyphen chars (hyphen_join)
|
||||
"""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
grid_data = session.get("grid_editor_result")
|
||||
if not grid_data:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No grid data. Run build-grid first.",
|
||||
)
|
||||
|
||||
from cv_gutter_repair import analyse_grid_for_gutter_repair
|
||||
|
||||
image_width = grid_data.get("image_width", 0)
|
||||
result = analyse_grid_for_gutter_repair(grid_data, image_width=image_width)
|
||||
|
||||
# Persist suggestions in ground_truth.gutter_repair (avoids DB migration)
|
||||
gt = session.get("ground_truth") or {}
|
||||
gt["gutter_repair"] = result
|
||||
await update_session_db(session_id, ground_truth=gt)
|
||||
|
||||
logger.info(
|
||||
"gutter-repair session %s: %d suggestions in %.2fs",
|
||||
session_id,
|
||||
result.get("stats", {}).get("suggestions_found", 0),
|
||||
result.get("duration_seconds", 0),
|
||||
)
|
||||
|
||||
return result
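A sketch of consuming the analysis result; the fields read below (stats.suggestions_found, suggestions[].id, suggestions[].suggested_text) are the ones this module itself logs or reads, and everything else about the response shape is left open:

import httpx

BASE = "http://localhost:8000/api/v1/ocr-pipeline"  # hypothetical deployment URL

report = httpx.post(f"{BASE}/sessions/abc123/gutter-repair").json()
print(report.get("stats", {}).get("suggestions_found", 0), "suggestions")
for s in report.get("suggestions", []):
    print(s.get("id"), "->", s.get("suggested_text"))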
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/gutter-repair/apply")
|
||||
async def gutter_repair_apply(session_id: str, request: Request):
|
||||
"""Apply accepted gutter repair suggestions to the grid.
|
||||
|
||||
Body: { "accepted": ["suggestion_id_1", "suggestion_id_2", ...] }
|
||||
"""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
grid_data = session.get("grid_editor_result")
|
||||
if not grid_data:
|
||||
raise HTTPException(status_code=400, detail="No grid data.")
|
||||
|
||||
gt = session.get("ground_truth") or {}
|
||||
gutter_result = gt.get("gutter_repair")
|
||||
if not gutter_result:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No gutter repair data. Run gutter-repair first.",
|
||||
)
|
||||
|
||||
body = await request.json()
|
||||
accepted_ids = body.get("accepted", [])
|
||||
if not accepted_ids:
|
||||
return {"applied_count": 0, "changes": []}
|
||||
|
||||
# text_overrides: { suggestion_id: "alternative_text" }
|
||||
# Allows the user to pick a different correction from the alternatives list
|
||||
text_overrides = body.get("text_overrides", {})
|
||||
|
||||
from cv_gutter_repair import apply_gutter_suggestions
|
||||
|
||||
suggestions = gutter_result.get("suggestions", [])
|
||||
|
||||
# Apply user-selected alternatives before passing to apply
|
||||
for s in suggestions:
|
||||
sid = s.get("id", "")
|
||||
if sid in text_overrides and text_overrides[sid]:
|
||||
s["suggested_text"] = text_overrides[sid]
|
||||
|
||||
result = apply_gutter_suggestions(grid_data, accepted_ids, suggestions)
|
||||
|
||||
# Save updated grid back to session
|
||||
await update_session_db(session_id, grid_editor_result=grid_data)
|
||||
|
||||
logger.info(
|
||||
"gutter-repair/apply session %s: %d changes applied",
|
||||
session_id,
|
||||
result.get("applied_count", 0),
|
||||
)
|
||||
|
||||
return result
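And the matching apply call; the suggestion ids and the override text are invented examples:

import httpx

BASE = "http://localhost:8000/api/v1/ocr-pipeline"  # hypothetical deployment URL

body = {
    "accepted": ["sugg_3", "sugg_7"],             # ids taken from the gutter-repair response
    "text_overrides": {"sugg_7": "Woerterbuch"},  # optional: pick an alternative correction
}
resp = httpx.post(f"{BASE}/sessions/abc123/gutter-repair/apply", json=body)
print(resp.json().get("applied_count", 0), "changes applied")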
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Box-Grid-Review endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/build-box-grids")
|
||||
async def build_box_grids(session_id: str, request: Request):
|
||||
"""Rebuild grid structure for all detected boxes with layout-aware detection.
|
||||
|
||||
Uses structure_result.boxes (from Step 7) as the source of box coordinates,
|
||||
and raw_paddle_words as OCR word source. Creates or updates box zones in
|
||||
the grid_editor_result.
|
||||
|
||||
Optional body: { "overrides": { "0": "bullet_list" } }
|
||||
Maps box_index → forced layout_type.
|
||||
"""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
grid_data = session.get("grid_editor_result")
|
||||
if not grid_data:
|
||||
raise HTTPException(status_code=400, detail="No grid data. Run build-grid first.")
|
||||
|
||||
# Get raw OCR words (with top/left/width/height keys)
|
||||
word_result = session.get("word_result") or {}
|
||||
all_words = word_result.get("raw_paddle_words") or word_result.get("raw_tesseract_words") or []
|
||||
if not all_words:
|
||||
raise HTTPException(status_code=400, detail="No raw OCR words available.")
|
||||
|
||||
# Get detected boxes from structure_result
|
||||
structure_result = session.get("structure_result") or {}
|
||||
gt = session.get("ground_truth") or {}
|
||||
if not structure_result:
|
||||
structure_result = gt.get("structure_result") or {}
|
||||
detected_boxes = structure_result.get("boxes") or []
|
||||
if not detected_boxes:
|
||||
return {"session_id": session_id, "box_zones_rebuilt": 0, "spell_fixes": 0, "message": "No boxes detected"}
|
||||
|
||||
# Filter out false-positive boxes in header/footer margins.
|
||||
# Textbook pages have ~2.5cm margins at top/bottom. At typical scan
|
||||
# resolutions (150-300 DPI), that's roughly 5-10% of image height.
|
||||
# A box whose vertical CENTER falls within the top or bottom 7% of
|
||||
# the image is likely a page number, unit header, or running footer.
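    # Worked example (values assumed, not measured): a 300 DPI A4 scan is
    # roughly 3500 px tall, so the 7% band is about 245 px, which is close
    # to the ~2.5 cm (~295 px) header/footer margin of a typical textbook page.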
|
||||
img_h_for_filter = grid_data.get("image_height", 0) or word_result.get("image_height", 0)
|
||||
if img_h_for_filter > 0:
|
||||
margin_frac = 0.07 # 7% of image height
|
||||
margin_top = img_h_for_filter * margin_frac
|
||||
margin_bottom = img_h_for_filter * (1 - margin_frac)
|
||||
filtered = []
|
||||
for box in detected_boxes:
|
||||
by = box.get("y", 0)
|
||||
bh = box.get("h", 0)
|
||||
box_center_y = by + bh / 2
|
||||
if box_center_y < margin_top or box_center_y > margin_bottom:
|
||||
logger.info("build-box-grids: skipping header/footer box at y=%d h=%d (center=%.0f, margins=%.0f/%.0f)",
|
||||
by, bh, box_center_y, margin_top, margin_bottom)
|
||||
continue
|
||||
filtered.append(box)
|
||||
detected_boxes = filtered
|
||||
|
||||
body = {}
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
pass
|
||||
layout_overrides = body.get("overrides", {})
|
||||
|
||||
from cv_box_layout import build_box_zone_grid
|
||||
from grid_editor_helpers import _words_in_zone
|
||||
|
||||
img_w = grid_data.get("image_width", 0) or word_result.get("image_width", 0)
|
||||
img_h = grid_data.get("image_height", 0) or word_result.get("image_height", 0)
|
||||
|
||||
zones = grid_data.get("zones", [])
|
||||
|
||||
# Find highest existing zone_index
|
||||
max_zone_idx = max((z.get("zone_index", 0) for z in zones), default=-1)
|
||||
|
||||
# Remove old box zones (we'll rebuild them)
|
||||
zones = [z for z in zones if z.get("zone_type") != "box"]
|
||||
|
||||
box_count = 0
|
||||
spell_fixes = 0
|
||||
|
||||
for box_idx, box in enumerate(detected_boxes):
|
||||
bx = box.get("x", 0)
|
||||
by = box.get("y", 0)
|
||||
bw = box.get("w", 0)
|
||||
bh = box.get("h", 0)
|
||||
|
||||
if bw <= 0 or bh <= 0:
|
||||
continue
|
||||
|
||||
# Filter raw OCR words inside this box
|
||||
zone_words = _words_in_zone(all_words, by, bh, bx, bw)
|
||||
if not zone_words:
|
||||
logger.info("Box %d: no words found in bbox (%d,%d,%d,%d)", box_idx, bx, by, bw, bh)
|
||||
continue
|
||||
|
||||
zone_idx = max_zone_idx + 1 + box_idx
|
||||
forced_layout = layout_overrides.get(str(box_idx))
|
||||
|
||||
# Build box grid
|
||||
box_grid = build_box_zone_grid(
|
||||
zone_words, bx, by, bw, bh,
|
||||
zone_idx, img_w, img_h,
|
||||
layout_type=forced_layout,
|
||||
)
|
||||
|
||||
# Apply SmartSpellChecker to all box cells
|
||||
try:
|
||||
from smart_spell import SmartSpellChecker
|
||||
ssc = SmartSpellChecker()
|
||||
for cell in box_grid.get("cells", []):
|
||||
text = cell.get("text", "")
|
||||
if not text:
|
||||
continue
|
||||
result = ssc.correct_text(text, lang="auto")
|
||||
if result.changed:
|
||||
cell["text"] = result.corrected
|
||||
spell_fixes += 1
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Build zone entry
|
||||
zone_entry = {
|
||||
"zone_index": zone_idx,
|
||||
"zone_type": "box",
|
||||
"bbox_px": {"x": bx, "y": by, "w": bw, "h": bh},
|
||||
"bbox_pct": {
|
||||
"x": round(bx / img_w * 100, 2) if img_w else 0,
|
||||
"y": round(by / img_h * 100, 2) if img_h else 0,
|
||||
"w": round(bw / img_w * 100, 2) if img_w else 0,
|
||||
"h": round(bh / img_h * 100, 2) if img_h else 0,
|
||||
},
|
||||
"border": None,
|
||||
"word_count": len(zone_words),
|
||||
"columns": box_grid["columns"],
|
||||
"rows": box_grid["rows"],
|
||||
"cells": box_grid["cells"],
|
||||
"header_rows": box_grid.get("header_rows", []),
|
||||
"box_layout_type": box_grid.get("box_layout_type", "flowing"),
|
||||
"box_grid_reviewed": False,
|
||||
"box_bg_color": box.get("bg_color_name", ""),
|
||||
"box_bg_hex": box.get("bg_color_hex", ""),
|
||||
}
|
||||
zones.append(zone_entry)
|
||||
box_count += 1
|
||||
|
||||
# Sort zones by y-position for correct reading order
|
||||
zones.sort(key=lambda z: z.get("bbox_px", {}).get("y", 0))
|
||||
|
||||
grid_data["zones"] = zones
|
||||
await update_session_db(session_id, grid_editor_result=grid_data)
|
||||
|
||||
logger.info(
|
||||
"build-box-grids session %s: %d boxes processed (%d words spell-fixed) from %d detected",
|
||||
session_id, box_count, spell_fixes, len(detected_boxes),
|
||||
)
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"box_zones_rebuilt": box_count,
|
||||
"total_detected_boxes": len(detected_boxes),
|
||||
"spell_fixes": spell_fixes,
|
||||
"zones": zones,
|
||||
}
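Example call with a layout override for the first detected box (sketch only; client, URL, and ids are assumed):

import httpx

BASE = "http://localhost:8000/api/v1/ocr-pipeline"  # hypothetical deployment URL

resp = httpx.post(
    f"{BASE}/sessions/abc123/build-box-grids",
    json={"overrides": {"0": "bullet_list"}},  # force box 0 to bullet_list, auto-detect the rest
)
out = resp.json()
print(out.get("box_zones_rebuilt"), "of", out.get("total_detected_boxes"),
      "boxes rebuilt,", out.get("spell_fixes"), "spell fixes")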
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unified Grid endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/build-unified-grid")
|
||||
async def build_unified_grid_endpoint(session_id: str):
|
||||
"""Build a single-zone unified grid merging content + box zones.
|
||||
|
||||
Takes the existing multi-zone grid_editor_result and produces a
|
||||
unified grid where boxes are integrated into the main row sequence.
|
||||
Persists as unified_grid_result (preserves original multi-zone data).
|
||||
"""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
grid_data = session.get("grid_editor_result")
|
||||
if not grid_data:
|
||||
raise HTTPException(status_code=400, detail="No grid data. Run build-grid first.")
|
||||
|
||||
from unified_grid import build_unified_grid
|
||||
|
||||
result = build_unified_grid(
|
||||
zones=grid_data.get("zones", []),
|
||||
image_width=grid_data.get("image_width", 0),
|
||||
image_height=grid_data.get("image_height", 0),
|
||||
layout_metrics=grid_data.get("layout_metrics", {}),
|
||||
)
|
||||
|
||||
# Persist as separate field (don't overwrite original multi-zone grid)
|
||||
await update_session_db(session_id, unified_grid_result=result)
|
||||
|
||||
logger.info(
|
||||
"build-unified-grid session %s: %d rows, %d cells",
|
||||
session_id,
|
||||
result.get("summary", {}).get("total_rows", 0),
|
||||
result.get("summary", {}).get("total_cells", 0),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/unified-grid")
|
||||
async def get_unified_grid(session_id: str):
|
||||
"""Retrieve the unified grid for a session."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
result = session.get("unified_grid_result")
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No unified grid. Run build-unified-grid first.",
|
||||
)
|
||||
|
||||
return result
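Taken together, the endpoints above form a small pipeline. A hedged end-to-end sketch (httpx client, localhost URL, and session id are illustrative assumptions, not part of this commit):

import httpx

BASE = "http://localhost:8000/api/v1/ocr-pipeline"  # hypothetical deployment URL
SID = "abc123"                                      # hypothetical session id

with httpx.Client(base_url=BASE, timeout=300) as client:
    client.post(f"sessions/{SID}/build-grid")                # multi-zone grid from existing Kombi words
    client.post(f"sessions/{SID}/build-box-grids", json={})  # layout-aware grids for detected boxes
    client.post(f"sessions/{SID}/build-unified-grid")        # merge content + box zones
    unified = client.get(f"sessions/{SID}/unified-grid").json()

print(unified.get("summary", {}).get("total_rows", 0), "rows in unified grid")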
|
||||
from fastapi import APIRouter

from grid_editor_api_grid import router as _grid_router
from grid_editor_api_gutter import router as _gutter_router
from grid_editor_api_box import router as _box_router
from grid_editor_api_unified import router as _unified_router

# Re-export _build_grid_core so callers that do
# `from grid_editor_api import _build_grid_core` keep working.
from grid_build_core import _build_grid_core  # noqa: F401

# Merge all sub-routers into one combined router
router = APIRouter()
router.include_router(_grid_router)
router.include_router(_gutter_router)
router.include_router(_box_router)
router.include_router(_unified_router)
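For context, a sketch of how an application might mount the combined router above; the FastAPI app module and names below are assumptions for illustration, not part of this commit:

from fastapi import FastAPI

from grid_editor_api import router as grid_editor_router  # import path unchanged by the split

app = FastAPI()
app.include_router(grid_editor_router)  # serves all /api/v1/ocr-pipeline grid-editor endpoints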
@@ -0,0 +1,177 @@
|
||||
"""
|
||||
Grid Editor API — box-grid-review endpoints.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
|
||||
from grid_editor_helpers import _words_in_zone
|
||||
from ocr_pipeline_session_store import (
|
||||
get_session_db,
|
||||
update_session_db,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["grid-editor"])
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/build-box-grids")
|
||||
async def build_box_grids(session_id: str, request: Request):
|
||||
"""Rebuild grid structure for all detected boxes with layout-aware detection.
|
||||
|
||||
Uses structure_result.boxes (from Step 7) as the source of box coordinates,
|
||||
and raw_paddle_words as OCR word source. Creates or updates box zones in
|
||||
the grid_editor_result.
|
||||
|
||||
Optional body: { "overrides": { "0": "bullet_list" } }
|
||||
Maps box_index -> forced layout_type.
|
||||
"""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
grid_data = session.get("grid_editor_result")
|
||||
if not grid_data:
|
||||
raise HTTPException(status_code=400, detail="No grid data. Run build-grid first.")
|
||||
|
||||
# Get raw OCR words (with top/left/width/height keys)
|
||||
word_result = session.get("word_result") or {}
|
||||
all_words = word_result.get("raw_paddle_words") or word_result.get("raw_tesseract_words") or []
|
||||
if not all_words:
|
||||
raise HTTPException(status_code=400, detail="No raw OCR words available.")
|
||||
|
||||
# Get detected boxes from structure_result
|
||||
structure_result = session.get("structure_result") or {}
|
||||
gt = session.get("ground_truth") or {}
|
||||
if not structure_result:
|
||||
structure_result = gt.get("structure_result") or {}
|
||||
detected_boxes = structure_result.get("boxes") or []
|
||||
if not detected_boxes:
|
||||
return {"session_id": session_id, "box_zones_rebuilt": 0, "spell_fixes": 0, "message": "No boxes detected"}
|
||||
|
||||
# Filter out false-positive boxes in header/footer margins.
|
||||
img_h_for_filter = grid_data.get("image_height", 0) or word_result.get("image_height", 0)
|
||||
if img_h_for_filter > 0:
|
||||
margin_frac = 0.07 # 7% of image height
|
||||
margin_top = img_h_for_filter * margin_frac
|
||||
margin_bottom = img_h_for_filter * (1 - margin_frac)
|
||||
filtered = []
|
||||
for box in detected_boxes:
|
||||
by = box.get("y", 0)
|
||||
bh = box.get("h", 0)
|
||||
box_center_y = by + bh / 2
|
||||
if box_center_y < margin_top or box_center_y > margin_bottom:
|
||||
logger.info("build-box-grids: skipping header/footer box at y=%d h=%d (center=%.0f, margins=%.0f/%.0f)",
|
||||
by, bh, box_center_y, margin_top, margin_bottom)
|
||||
continue
|
||||
filtered.append(box)
|
||||
detected_boxes = filtered
|
||||
|
||||
body = {}
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
pass
|
||||
layout_overrides = body.get("overrides", {})
|
||||
|
||||
from cv_box_layout import build_box_zone_grid
|
||||
|
||||
img_w = grid_data.get("image_width", 0) or word_result.get("image_width", 0)
|
||||
img_h = grid_data.get("image_height", 0) or word_result.get("image_height", 0)
|
||||
|
||||
zones = grid_data.get("zones", [])
|
||||
|
||||
# Find highest existing zone_index
|
||||
max_zone_idx = max((z.get("zone_index", 0) for z in zones), default=-1)
|
||||
|
||||
# Remove old box zones (we'll rebuild them)
|
||||
zones = [z for z in zones if z.get("zone_type") != "box"]
|
||||
|
||||
box_count = 0
|
||||
spell_fixes = 0
|
||||
|
||||
for box_idx, box in enumerate(detected_boxes):
|
||||
bx = box.get("x", 0)
|
||||
by = box.get("y", 0)
|
||||
bw = box.get("w", 0)
|
||||
bh = box.get("h", 0)
|
||||
|
||||
if bw <= 0 or bh <= 0:
|
||||
continue
|
||||
|
||||
# Filter raw OCR words inside this box
|
||||
zone_words = _words_in_zone(all_words, by, bh, bx, bw)
|
||||
if not zone_words:
|
||||
logger.info("Box %d: no words found in bbox (%d,%d,%d,%d)", box_idx, bx, by, bw, bh)
|
||||
continue
|
||||
|
||||
zone_idx = max_zone_idx + 1 + box_idx
|
||||
forced_layout = layout_overrides.get(str(box_idx))
|
||||
|
||||
# Build box grid
|
||||
box_grid = build_box_zone_grid(
|
||||
zone_words, bx, by, bw, bh,
|
||||
zone_idx, img_w, img_h,
|
||||
layout_type=forced_layout,
|
||||
)
|
||||
|
||||
# Apply SmartSpellChecker to all box cells
|
||||
try:
|
||||
from smart_spell import SmartSpellChecker
|
||||
ssc = SmartSpellChecker()
|
||||
for cell in box_grid.get("cells", []):
|
||||
text = cell.get("text", "")
|
||||
if not text:
|
||||
continue
|
||||
result = ssc.correct_text(text, lang="auto")
|
||||
if result.changed:
|
||||
cell["text"] = result.corrected
|
||||
spell_fixes += 1
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Build zone entry
|
||||
zone_entry = {
|
||||
"zone_index": zone_idx,
|
||||
"zone_type": "box",
|
||||
"bbox_px": {"x": bx, "y": by, "w": bw, "h": bh},
|
||||
"bbox_pct": {
|
||||
"x": round(bx / img_w * 100, 2) if img_w else 0,
|
||||
"y": round(by / img_h * 100, 2) if img_h else 0,
|
||||
"w": round(bw / img_w * 100, 2) if img_w else 0,
|
||||
"h": round(bh / img_h * 100, 2) if img_h else 0,
|
||||
},
|
||||
"border": None,
|
||||
"word_count": len(zone_words),
|
||||
"columns": box_grid["columns"],
|
||||
"rows": box_grid["rows"],
|
||||
"cells": box_grid["cells"],
|
||||
"header_rows": box_grid.get("header_rows", []),
|
||||
"box_layout_type": box_grid.get("box_layout_type", "flowing"),
|
||||
"box_grid_reviewed": False,
|
||||
"box_bg_color": box.get("bg_color_name", ""),
|
||||
"box_bg_hex": box.get("bg_color_hex", ""),
|
||||
}
|
||||
zones.append(zone_entry)
|
||||
box_count += 1
|
||||
|
||||
# Sort zones by y-position for correct reading order
|
||||
zones.sort(key=lambda z: z.get("bbox_px", {}).get("y", 0))
|
||||
|
||||
grid_data["zones"] = zones
|
||||
await update_session_db(session_id, grid_editor_result=grid_data)
|
||||
|
||||
logger.info(
|
||||
"build-box-grids session %s: %d boxes processed (%d words spell-fixed) from %d detected",
|
||||
session_id, box_count, spell_fixes, len(detected_boxes),
|
||||
)
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"box_zones_rebuilt": box_count,
|
||||
"total_detected_boxes": len(detected_boxes),
|
||||
"spell_fixes": spell_fixes,
|
||||
"zones": zones,
|
||||
}
|
||||
@@ -0,0 +1,337 @@
|
||||
"""
|
||||
Grid Editor API — grid build, save, and retrieve endpoints.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, Request
|
||||
|
||||
from grid_build_core import _build_grid_core
|
||||
from ocr_pipeline_session_store import (
|
||||
get_session_db,
|
||||
update_session_db,
|
||||
)
|
||||
from ocr_pipeline_common import (
|
||||
_cache,
|
||||
_load_session_to_cache,
|
||||
_get_cached,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["grid-editor"])
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/build-grid")
|
||||
async def build_grid(
|
||||
session_id: str,
|
||||
ipa_mode: str = Query("auto", pattern="^(auto|all|de|en|none)$"),
|
||||
syllable_mode: str = Query("auto", pattern="^(auto|all|de|en|none)$"),
|
||||
enhance: bool = Query(True, description="Step 3: CLAHE + denoise for degraded scans"),
|
||||
max_cols: int = Query(0, description="Step 2: Max column count (0=unlimited)"),
|
||||
min_conf: int = Query(0, description="Step 1: Min OCR confidence (0=auto)"),
|
||||
):
|
||||
"""Build a structured, zone-aware grid from existing Kombi word results.
|
||||
|
||||
Requires that paddle-kombi or rapid-kombi has already been run on the session.
|
||||
Uses the image for box detection and the word positions for grid structuring.
|
||||
|
||||
Query params:
|
||||
ipa_mode: "auto" (only when English IPA detected), "all" (force), "none" (skip)
|
||||
syllable_mode: "auto" (only when original has dividers), "all" (force), "none" (skip)
|
||||
|
||||
Returns a StructuredGrid with zones, each containing their own
|
||||
columns, rows, and cells — ready for the frontend Excel-like editor.
|
||||
"""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
try:
|
||||
result = await _build_grid_core(
|
||||
session_id, session,
|
||||
ipa_mode=ipa_mode, syllable_mode=syllable_mode,
|
||||
enhance=enhance,
|
||||
max_columns=max_cols if max_cols > 0 else None,
|
||||
min_conf=min_conf if min_conf > 0 else None,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
# Save automatic grid snapshot for later comparison with manual corrections
|
||||
# Lazy import to avoid circular dependency with ocr_pipeline_regression
|
||||
from ocr_pipeline_regression import _build_reference_snapshot
|
||||
|
||||
wr = session.get("word_result") or {}
|
||||
engine = wr.get("ocr_engine", "")
|
||||
if engine in ("kombi", "rapid_kombi"):
|
||||
auto_pipeline = "kombi"
|
||||
elif engine == "paddle_direct":
|
||||
auto_pipeline = "paddle-direct"
|
||||
else:
|
||||
auto_pipeline = "pipeline"
|
||||
auto_snapshot = _build_reference_snapshot(result, pipeline=auto_pipeline)
|
||||
|
||||
gt = session.get("ground_truth") or {}
|
||||
gt["auto_grid_snapshot"] = auto_snapshot
|
||||
|
||||
# Persist to DB and advance current_step to 11 (reconstruction complete)
|
||||
await update_session_db(session_id, grid_editor_result=result, ground_truth=gt, current_step=11)
|
||||
|
||||
logger.info(
|
||||
"build-grid session %s: %d zones, %d cols, %d rows, %d cells, "
|
||||
"%d boxes in %.2fs",
|
||||
session_id,
|
||||
len(result.get("zones", [])),
|
||||
result.get("summary", {}).get("total_columns", 0),
|
||||
result.get("summary", {}).get("total_rows", 0),
|
||||
result.get("summary", {}).get("total_cells", 0),
|
||||
result.get("boxes_detected", 0),
|
||||
result.get("duration_seconds", 0),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/rerun-ocr-and-build-grid")
|
||||
async def rerun_ocr_and_build_grid(
|
||||
session_id: str,
|
||||
ipa_mode: str = Query("auto", pattern="^(auto|all|de|en|none)$"),
|
||||
syllable_mode: str = Query("auto", pattern="^(auto|all|de|en|none)$"),
|
||||
enhance: bool = Query(True, description="Step 3: CLAHE + denoise for degraded scans"),
|
||||
max_cols: int = Query(0, description="Step 2: Max column count (0=unlimited)"),
|
||||
min_conf: int = Query(0, description="Step 1: Min OCR confidence (0=auto)"),
|
||||
vision_fusion: bool = Query(False, description="Step 4: Vision-LLM fusion for degraded scans"),
|
||||
doc_category: str = Query("", description="Document type for Vision-LLM prompt context"),
|
||||
):
|
||||
"""Re-run OCR with quality settings, then rebuild the grid.
|
||||
|
||||
Unlike build-grid (which only rebuilds from existing words),
|
||||
this endpoint re-runs the full OCR pipeline on the cropped image
|
||||
with optional CLAHE enhancement, then builds the grid.
|
||||
|
||||
Steps executed: Image Enhancement -> OCR -> Grid Build
|
||||
"""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
import time as _time
|
||||
t0 = _time.time()
|
||||
|
||||
# 1. Load the cropped/dewarped image from cache or session
|
||||
if session_id not in _cache:
|
||||
await _load_session_to_cache(session_id)
|
||||
cached = _get_cached(session_id)
|
||||
|
||||
dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
|
||||
if dewarped_bgr is None:
|
||||
raise HTTPException(status_code=400, detail="No cropped/dewarped image available. Run preprocessing steps first.")
|
||||
|
||||
import numpy as np
|
||||
img_h, img_w = dewarped_bgr.shape[:2]
|
||||
ocr_input = dewarped_bgr.copy()
|
||||
|
||||
# 2. Scan quality assessment
|
||||
scan_quality_info = {}
|
||||
try:
|
||||
from scan_quality import score_scan_quality
|
||||
quality_report = score_scan_quality(ocr_input)
|
||||
scan_quality_info = quality_report.to_dict()
|
||||
actual_min_conf = min_conf if min_conf > 0 else quality_report.recommended_min_conf
|
||||
except Exception as e:
|
||||
logger.warning(f"rerun-ocr: scan quality failed: {e}")
|
||||
actual_min_conf = min_conf if min_conf > 0 else 40
|
||||
|
||||
# 3. Image enhancement (Step 3)
|
||||
is_degraded = scan_quality_info.get("is_degraded", False)
|
||||
if enhance and is_degraded:
|
||||
try:
|
||||
from ocr_image_enhance import enhance_for_ocr
|
||||
ocr_input = enhance_for_ocr(ocr_input, is_degraded=True)
|
||||
logger.info("rerun-ocr: CLAHE enhancement applied")
|
||||
except Exception as e:
|
||||
logger.warning(f"rerun-ocr: enhancement failed: {e}")
|
||||
|
||||
# 4. Run dual-engine OCR
|
||||
from PIL import Image
|
||||
import pytesseract
|
||||
|
||||
# RapidOCR
|
||||
rapid_words = []
|
||||
try:
|
||||
from cv_ocr_engines import ocr_region_rapid
|
||||
from cv_vocab_types import PageRegion
|
||||
full_region = PageRegion(type="full_page", x=0, y=0, width=img_w, height=img_h)
|
||||
rapid_words = ocr_region_rapid(ocr_input, full_region) or []
|
||||
except Exception as e:
|
||||
logger.warning(f"rerun-ocr: RapidOCR failed: {e}")
|
||||
|
||||
# Tesseract
|
||||
pil_img = Image.fromarray(ocr_input[:, :, ::-1])
|
||||
data = pytesseract.image_to_data(pil_img, lang='eng+deu', config='--psm 6 --oem 3', output_type=pytesseract.Output.DICT)
|
||||
tess_words = []
|
||||
for i in range(len(data["text"])):
|
||||
text = (data["text"][i] or "").strip()
|
||||
conf_raw = str(data["conf"][i])
|
||||
conf = int(conf_raw) if conf_raw.lstrip("-").isdigit() else -1
|
||||
if not text or conf < actual_min_conf:
|
||||
continue
|
||||
tess_words.append({
|
||||
"text": text, "left": data["left"][i], "top": data["top"][i],
|
||||
"width": data["width"][i], "height": data["height"][i], "conf": conf,
|
||||
})
|
||||
|
||||
# 5. Merge OCR results
|
||||
from ocr_pipeline_ocr_merge import _split_paddle_multi_words, _merge_paddle_tesseract, _deduplicate_words
|
||||
rapid_split = _split_paddle_multi_words(rapid_words) if rapid_words else []
|
||||
if rapid_split or tess_words:
|
||||
merged_words = _merge_paddle_tesseract(rapid_split, tess_words)
|
||||
merged_words = _deduplicate_words(merged_words)
|
||||
else:
|
||||
merged_words = tess_words
|
||||
|
||||
# 6. Store updated word_result in session
|
||||
cells_for_storage = [{"text": w["text"], "left": w["left"], "top": w["top"],
|
||||
"width": w["width"], "height": w["height"], "conf": w.get("conf", 0)}
|
||||
for w in merged_words]
|
||||
word_result = {
|
||||
"cells": [{"text": " ".join(w["text"] for w in merged_words),
|
||||
"word_boxes": cells_for_storage}],
|
||||
"image_width": img_w,
|
||||
"image_height": img_h,
|
||||
"ocr_engine": "rapid_kombi",
|
||||
"word_count": len(merged_words),
|
||||
"raw_paddle_words": rapid_words,
|
||||
}
|
||||
# 6b. Vision-LLM Fusion (Step 4) — correct OCR using Vision model
|
||||
vision_applied = False
|
||||
if vision_fusion:
|
||||
try:
|
||||
from vision_ocr_fusion import vision_fuse_ocr
|
||||
category = doc_category or session.get("document_category") or "vokabelseite"
|
||||
logger.info(f"rerun-ocr: running Vision-LLM fusion (category={category})")
|
||||
merged_words = await vision_fuse_ocr(ocr_input, merged_words, category)
|
||||
vision_applied = True
|
||||
# Rebuild storage from fused words
|
||||
cells_for_storage = [{"text": w["text"], "left": w["left"], "top": w["top"],
|
||||
"width": w["width"], "height": w["height"], "conf": w.get("conf", 0)}
|
||||
for w in merged_words]
|
||||
word_result["cells"] = [{"text": " ".join(w["text"] for w in merged_words),
|
||||
"word_boxes": cells_for_storage}]
|
||||
word_result["word_count"] = len(merged_words)
|
||||
word_result["ocr_engine"] = "vision_fusion"
|
||||
except Exception as e:
|
||||
logger.warning(f"rerun-ocr: Vision-LLM fusion failed: {e}")
|
||||
|
||||
await update_session_db(session_id, word_result=word_result)
|
||||
|
||||
# Reload session with updated word_result
|
||||
session = await get_session_db(session_id)
|
||||
|
||||
ocr_duration = _time.time() - t0
|
||||
logger.info(
|
||||
"rerun-ocr session %s: %d words (rapid=%d, tess=%d, merged=%d) in %.1fs "
|
||||
"(enhance=%s, min_conf=%d, quality=%s)",
|
||||
session_id, len(merged_words), len(rapid_words), len(tess_words),
|
||||
len(merged_words), ocr_duration, enhance, actual_min_conf,
|
||||
scan_quality_info.get("quality_pct", "?"),
|
||||
)
|
||||
|
||||
# 7. Build grid from new words
|
||||
try:
|
||||
result = await _build_grid_core(
|
||||
session_id, session,
|
||||
ipa_mode=ipa_mode, syllable_mode=syllable_mode,
|
||||
enhance=enhance,
|
||||
max_columns=max_cols if max_cols > 0 else None,
|
||||
min_conf=min_conf if min_conf > 0 else None,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
# Persist grid
|
||||
await update_session_db(session_id, grid_editor_result=result, current_step=11)
|
||||
|
||||
# Add quality info to response
|
||||
result["scan_quality"] = scan_quality_info
|
||||
result["ocr_stats"] = {
|
||||
"rapid_words": len(rapid_words),
|
||||
"tess_words": len(tess_words),
|
||||
"merged_words": len(merged_words),
|
||||
"min_conf_used": actual_min_conf,
|
||||
"enhance_applied": enhance and is_degraded,
|
||||
"vision_fusion_applied": vision_applied,
|
||||
"document_category": doc_category or session.get("document_category", ""),
|
||||
"ocr_duration_seconds": round(ocr_duration, 1),
|
||||
}
|
||||
|
||||
total_duration = _time.time() - t0
|
||||
logger.info(
|
||||
"rerun-ocr+build-grid session %s: %d zones, %d cols, %d cells in %.1fs",
|
||||
session_id,
|
||||
len(result.get("zones", [])),
|
||||
result.get("summary", {}).get("total_columns", 0),
|
||||
result.get("summary", {}).get("total_cells", 0),
|
||||
total_duration,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/save-grid")
|
||||
async def save_grid(session_id: str, request: Request):
|
||||
"""Save edited grid data from the frontend Excel-like editor.
|
||||
|
||||
Receives the full StructuredGrid with user edits (text changes,
|
||||
formatting changes like bold columns, header rows, etc.) and
|
||||
persists it to the session's grid_editor_result.
|
||||
"""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
body = await request.json()
|
||||
|
||||
# Validate basic structure
|
||||
if "zones" not in body:
|
||||
raise HTTPException(status_code=400, detail="Missing 'zones' in request body")
|
||||
|
||||
# Preserve metadata from the original build
|
||||
existing = session.get("grid_editor_result") or {}
|
||||
result = {
|
||||
"session_id": session_id,
|
||||
"image_width": body.get("image_width", existing.get("image_width", 0)),
|
||||
"image_height": body.get("image_height", existing.get("image_height", 0)),
|
||||
"zones": body["zones"],
|
||||
"boxes_detected": body.get("boxes_detected", existing.get("boxes_detected", 0)),
|
||||
"summary": body.get("summary", existing.get("summary", {})),
|
||||
"formatting": body.get("formatting", existing.get("formatting", {})),
|
||||
"duration_seconds": existing.get("duration_seconds", 0),
|
||||
"edited": True,
|
||||
}
|
||||
|
||||
await update_session_db(session_id, grid_editor_result=result, current_step=11)
|
||||
|
||||
logger.info("save-grid session %s: %d zones saved", session_id, len(body["zones"]))
|
||||
|
||||
return {"session_id": session_id, "saved": True}
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/grid-editor")
|
||||
async def get_grid(session_id: str):
|
||||
"""Retrieve the current grid editor state for a session."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
result = session.get("grid_editor_result")
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No grid editor data. Run build-grid first.",
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,110 @@
|
||||
"""
|
||||
Grid Editor API — gutter repair endpoints.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
|
||||
from ocr_pipeline_session_store import (
|
||||
get_session_db,
|
||||
update_session_db,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["grid-editor"])
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/gutter-repair")
|
||||
async def gutter_repair(session_id: str):
|
||||
"""Analyse grid for gutter-edge OCR errors and return repair suggestions.
|
||||
|
||||
Detects:
|
||||
- Words truncated/blurred at the book binding (spell_fix)
|
||||
- Words split across rows with missing hyphen chars (hyphen_join)
|
||||
"""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
grid_data = session.get("grid_editor_result")
|
||||
if not grid_data:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No grid data. Run build-grid first.",
|
||||
)
|
||||
|
||||
from cv_gutter_repair import analyse_grid_for_gutter_repair
|
||||
|
||||
image_width = grid_data.get("image_width", 0)
|
||||
result = analyse_grid_for_gutter_repair(grid_data, image_width=image_width)
|
||||
|
||||
# Persist suggestions in ground_truth.gutter_repair (avoids DB migration)
|
||||
gt = session.get("ground_truth") or {}
|
||||
gt["gutter_repair"] = result
|
||||
await update_session_db(session_id, ground_truth=gt)
|
||||
|
||||
logger.info(
|
||||
"gutter-repair session %s: %d suggestions in %.2fs",
|
||||
session_id,
|
||||
result.get("stats", {}).get("suggestions_found", 0),
|
||||
result.get("duration_seconds", 0),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/gutter-repair/apply")
|
||||
async def gutter_repair_apply(session_id: str, request: Request):
|
||||
"""Apply accepted gutter repair suggestions to the grid.
|
||||
|
||||
Body: { "accepted": ["suggestion_id_1", "suggestion_id_2", ...] }
|
||||
"""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
grid_data = session.get("grid_editor_result")
|
||||
if not grid_data:
|
||||
raise HTTPException(status_code=400, detail="No grid data.")
|
||||
|
||||
gt = session.get("ground_truth") or {}
|
||||
gutter_result = gt.get("gutter_repair")
|
||||
if not gutter_result:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No gutter repair data. Run gutter-repair first.",
|
||||
)
|
||||
|
||||
body = await request.json()
|
||||
accepted_ids = body.get("accepted", [])
|
||||
if not accepted_ids:
|
||||
return {"applied_count": 0, "changes": []}
|
||||
|
||||
# text_overrides: { suggestion_id: "alternative_text" }
|
||||
# Allows the user to pick a different correction from the alternatives list
|
||||
text_overrides = body.get("text_overrides", {})
|
||||
|
||||
from cv_gutter_repair import apply_gutter_suggestions
|
||||
|
||||
suggestions = gutter_result.get("suggestions", [])
|
||||
|
||||
# Apply user-selected alternatives before passing to apply
|
||||
for s in suggestions:
|
||||
sid = s.get("id", "")
|
||||
if sid in text_overrides and text_overrides[sid]:
|
||||
s["suggested_text"] = text_overrides[sid]
|
||||
|
||||
result = apply_gutter_suggestions(grid_data, accepted_ids, suggestions)
|
||||
|
||||
# Save updated grid back to session
|
||||
await update_session_db(session_id, grid_editor_result=grid_data)
|
||||
|
||||
logger.info(
|
||||
"gutter-repair/apply session %s: %d changes applied",
|
||||
session_id,
|
||||
result.get("applied_count", 0),
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,71 @@
"""
Grid Editor API — unified grid endpoints.
"""

import logging

from fastapi import APIRouter, HTTPException

from ocr_pipeline_session_store import (
    get_session_db,
    update_session_db,
)

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["grid-editor"])


@router.post("/sessions/{session_id}/build-unified-grid")
async def build_unified_grid_endpoint(session_id: str):
    """Build a single-zone unified grid merging content + box zones.

    Takes the existing multi-zone grid_editor_result and produces a
    unified grid where boxes are integrated into the main row sequence.
    Persists as unified_grid_result (preserves original multi-zone data).
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    grid_data = session.get("grid_editor_result")
    if not grid_data:
        raise HTTPException(status_code=400, detail="No grid data. Run build-grid first.")

    from unified_grid import build_unified_grid

    result = build_unified_grid(
        zones=grid_data.get("zones", []),
        image_width=grid_data.get("image_width", 0),
        image_height=grid_data.get("image_height", 0),
        layout_metrics=grid_data.get("layout_metrics", {}),
    )

    # Persist as separate field (don't overwrite original multi-zone grid)
    await update_session_db(session_id, unified_grid_result=result)

    logger.info(
        "build-unified-grid session %s: %d rows, %d cells",
        session_id,
        result.get("summary", {}).get("total_rows", 0),
        result.get("summary", {}).get("total_cells", 0),
    )

    return result


@router.get("/sessions/{session_id}/unified-grid")
async def get_unified_grid(session_id: str):
    """Retrieve the unified grid for a session."""
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    result = session.get("unified_grid_result")
    if not result:
        raise HTTPException(
            status_code=404,
            detail="No unified grid. Run build-unified-grid first.",
        )

    return result
@@ -1,68 +1,43 @@
|
||||
"""
|
||||
Unified Inbox Mail API
|
||||
Unified Inbox Mail API — barrel re-export.
|
||||
|
||||
FastAPI router for the mail system.
|
||||
The actual endpoints live in:
|
||||
- api_accounts.py (account CRUD, test, sync)
|
||||
- api_inbox.py (unified inbox, email detail, send)
|
||||
- api_ai.py (AI analysis, response suggestions)
|
||||
- api_tasks.py (task CRUD, dashboard, from-email)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
|
||||
from .models import (
|
||||
EmailAccountCreate,
|
||||
EmailAccountUpdate,
|
||||
EmailAccount,
|
||||
AccountTestResult,
|
||||
AggregatedEmail,
|
||||
EmailSearchParams,
|
||||
TaskCreate,
|
||||
TaskUpdate,
|
||||
InboxTask,
|
||||
TaskDashboardStats,
|
||||
EmailComposeRequest,
|
||||
EmailSendResult,
|
||||
MailStats,
|
||||
MailHealthCheck,
|
||||
EmailAnalysisResult,
|
||||
ResponseSuggestion,
|
||||
TaskStatus,
|
||||
TaskPriority,
|
||||
EmailCategory,
|
||||
)
|
||||
from .mail_db import (
|
||||
init_mail_tables,
|
||||
create_email_account,
|
||||
get_email_accounts,
|
||||
get_email_account,
|
||||
delete_email_account,
|
||||
get_unified_inbox,
|
||||
get_email,
|
||||
mark_email_read,
|
||||
mark_email_starred,
|
||||
get_mail_stats,
|
||||
log_mail_audit,
|
||||
)
|
||||
from .credentials import get_credentials_service
|
||||
from .aggregator import get_mail_aggregator
|
||||
from .ai_service import get_ai_email_service
|
||||
from .task_service import get_task_service
|
||||
from .models import MailHealthCheck, MailStats
|
||||
from .mail_db import init_mail_tables, get_mail_stats
|
||||
|
||||
from .api_accounts import router as _accounts_router
|
||||
from .api_inbox import router as _inbox_router
|
||||
from .api_ai import router as _ai_router
|
||||
from .api_tasks import router as _tasks_router
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/mail", tags=["Mail"])
|
||||
router = APIRouter()
|
||||
|
||||
# Merge sub-routers
|
||||
router.include_router(_accounts_router)
|
||||
router.include_router(_inbox_router)
|
||||
router.include_router(_ai_router)
|
||||
router.include_router(_tasks_router)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Health & Init
|
||||
# Health & Init (kept here as they are small)
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/health", response_model=MailHealthCheck)
|
||||
@router.get("/api/v1/mail/health", response_model=MailHealthCheck)
|
||||
async def health_check():
|
||||
"""Health check for the mail system."""
|
||||
# TODO: Implement full health check
|
||||
return MailHealthCheck(
|
||||
status="healthy",
|
||||
database_connected=True,
|
||||
@@ -70,7 +45,7 @@ async def health_check():
|
||||
)
|
||||
|
||||
|
||||
@router.post("/init")
|
||||
@router.post("/api/v1/mail/init")
|
||||
async def initialize_mail_system():
|
||||
"""Initialize mail database tables."""
|
||||
success = await init_mail_tables()
|
||||
@@ -79,573 +54,14 @@ async def initialize_mail_system():
|
||||
return {"status": "initialized"}
|
||||
|
||||
|
||||
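# Illustrative sketch, not part of this commit: how an application entrypoint is
# typically expected to mount the merged barrel router. The package path
# "mail.api" and the create_app() wiring are assumptions for illustration only.
from fastapi import FastAPI

from mail.api import router as mail_router  # assumed import path


def create_app() -> FastAPI:
    app = FastAPI()
    # The sub-routers already carry the /api/v1/mail prefix, and health/init are
    # registered with their full paths, so no extra prefix is added here.
    app.include_router(mail_router)
    return app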
# =============================================================================
|
||||
# Account Management
|
||||
# =============================================================================
|
||||
|
||||
class AccountCreateRequest(BaseModel):
|
||||
"""Request to create an email account."""
|
||||
email: str
|
||||
display_name: str
|
||||
account_type: str = "personal"
|
||||
imap_host: str
|
||||
imap_port: int = 993
|
||||
imap_ssl: bool = True
|
||||
smtp_host: str
|
||||
smtp_port: int = 465
|
||||
smtp_ssl: bool = True
|
||||
password: str
|
||||
|
||||
|
||||
@router.post("/accounts", response_model=dict)
|
||||
async def create_account(
|
||||
request: AccountCreateRequest,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
tenant_id: str = Query(..., description="Tenant ID"),
|
||||
):
|
||||
"""Create a new email account."""
|
||||
credentials_service = get_credentials_service()
|
||||
|
||||
# Store credentials securely
|
||||
vault_path = await credentials_service.store_credentials(
|
||||
account_id=f"{user_id}_{request.email}",
|
||||
email=request.email,
|
||||
password=request.password,
|
||||
imap_host=request.imap_host,
|
||||
imap_port=request.imap_port,
|
||||
smtp_host=request.smtp_host,
|
||||
smtp_port=request.smtp_port,
|
||||
)
|
||||
|
||||
# Create account in database
|
||||
account_id = await create_email_account(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
email=request.email,
|
||||
display_name=request.display_name,
|
||||
account_type=request.account_type,
|
||||
imap_host=request.imap_host,
|
||||
imap_port=request.imap_port,
|
||||
imap_ssl=request.imap_ssl,
|
||||
smtp_host=request.smtp_host,
|
||||
smtp_port=request.smtp_port,
|
||||
smtp_ssl=request.smtp_ssl,
|
||||
vault_path=vault_path,
|
||||
)
|
||||
|
||||
if not account_id:
|
||||
raise HTTPException(status_code=500, detail="Failed to create account")
|
||||
|
||||
# Log audit
|
||||
await log_mail_audit(
|
||||
user_id=user_id,
|
||||
action="account_created",
|
||||
entity_type="account",
|
||||
entity_id=account_id,
|
||||
details={"email": request.email},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
return {"id": account_id, "status": "created"}
|
||||
|
||||
|
||||
@router.get("/accounts", response_model=List[dict])
|
||||
async def list_accounts(
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
tenant_id: Optional[str] = Query(None, description="Tenant ID"),
|
||||
):
|
||||
"""List all email accounts for a user."""
|
||||
accounts = await get_email_accounts(user_id, tenant_id)
|
||||
# Remove sensitive fields
|
||||
for account in accounts:
|
||||
account.pop("vault_path", None)
|
||||
return accounts
|
||||
|
||||
|
||||
@router.get("/accounts/{account_id}", response_model=dict)
|
||||
async def get_account(
|
||||
account_id: str,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
):
|
||||
"""Get a single email account."""
|
||||
account = await get_email_account(account_id, user_id)
|
||||
if not account:
|
||||
raise HTTPException(status_code=404, detail="Account not found")
|
||||
account.pop("vault_path", None)
|
||||
return account
|
||||
|
||||
|
||||
@router.delete("/accounts/{account_id}")
|
||||
async def remove_account(
|
||||
account_id: str,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
):
|
||||
"""Delete an email account."""
|
||||
account = await get_email_account(account_id, user_id)
|
||||
if not account:
|
||||
raise HTTPException(status_code=404, detail="Account not found")
|
||||
|
||||
# Delete credentials
|
||||
credentials_service = get_credentials_service()
|
||||
vault_path = account.get("vault_path", "")
|
||||
if vault_path:
|
||||
await credentials_service.delete_credentials(account_id, vault_path)
|
||||
|
||||
# Delete from database (cascades to emails)
|
||||
success = await delete_email_account(account_id, user_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete account")
|
||||
|
||||
await log_mail_audit(
|
||||
user_id=user_id,
|
||||
action="account_deleted",
|
||||
entity_type="account",
|
||||
entity_id=account_id,
|
||||
)
|
||||
|
||||
return {"status": "deleted"}
|
||||
|
||||
|
||||
@router.post("/accounts/{account_id}/test", response_model=AccountTestResult)
|
||||
async def test_account_connection(
|
||||
account_id: str,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
):
|
||||
"""Test connection for an email account."""
|
||||
account = await get_email_account(account_id, user_id)
|
||||
if not account:
|
||||
raise HTTPException(status_code=404, detail="Account not found")
|
||||
|
||||
# Get credentials
|
||||
credentials_service = get_credentials_service()
|
||||
vault_path = account.get("vault_path", "")
|
||||
creds = await credentials_service.get_credentials(account_id, vault_path)
|
||||
|
||||
if not creds:
|
||||
return AccountTestResult(
|
||||
success=False,
|
||||
error_message="Credentials not found"
|
||||
)
|
||||
|
||||
# Test connection
|
||||
aggregator = get_mail_aggregator()
|
||||
result = await aggregator.test_account_connection(
|
||||
imap_host=account["imap_host"],
|
||||
imap_port=account["imap_port"],
|
||||
imap_ssl=account["imap_ssl"],
|
||||
smtp_host=account["smtp_host"],
|
||||
smtp_port=account["smtp_port"],
|
||||
smtp_ssl=account["smtp_ssl"],
|
||||
email_address=creds.email,
|
||||
password=creds.password,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class ConnectionTestRequest(BaseModel):
|
||||
"""Request to test connection before saving account."""
|
||||
email: str
|
||||
imap_host: str
|
||||
imap_port: int = 993
|
||||
imap_ssl: bool = True
|
||||
smtp_host: str
|
||||
smtp_port: int = 465
|
||||
smtp_ssl: bool = True
|
||||
password: str
|
||||
|
||||
|
||||
@router.post("/accounts/test-connection", response_model=AccountTestResult)
|
||||
async def test_connection_before_save(request: ConnectionTestRequest):
|
||||
"""
|
||||
Test IMAP/SMTP connection before saving an account.
|
||||
|
||||
This allows the wizard to verify credentials are correct
|
||||
before creating the account in the database.
|
||||
"""
|
||||
aggregator = get_mail_aggregator()
|
||||
|
||||
result = await aggregator.test_account_connection(
|
||||
imap_host=request.imap_host,
|
||||
imap_port=request.imap_port,
|
||||
imap_ssl=request.imap_ssl,
|
||||
smtp_host=request.smtp_host,
|
||||
smtp_port=request.smtp_port,
|
||||
smtp_ssl=request.smtp_ssl,
|
||||
email_address=request.email,
|
||||
password=request.password,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/accounts/{account_id}/sync")
|
||||
async def sync_account(
|
||||
account_id: str,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
max_emails: int = Query(100, ge=1, le=500),
|
||||
background_tasks: BackgroundTasks = None,
|
||||
):
|
||||
"""Sync emails from an account."""
|
||||
aggregator = get_mail_aggregator()
|
||||
|
||||
try:
|
||||
new_count, total_count = await aggregator.sync_account(
|
||||
account_id=account_id,
|
||||
user_id=user_id,
|
||||
max_emails=max_emails,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "synced",
|
||||
"new_emails": new_count,
|
||||
"total_emails": total_count,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Unified Inbox
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/inbox", response_model=List[dict])
|
||||
async def get_inbox(
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
account_ids: Optional[str] = Query(None, description="Comma-separated account IDs"),
|
||||
categories: Optional[str] = Query(None, description="Comma-separated categories"),
|
||||
is_read: Optional[bool] = Query(None),
|
||||
is_starred: Optional[bool] = Query(None),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
offset: int = Query(0, ge=0),
|
||||
):
|
||||
"""Get unified inbox with all accounts aggregated."""
|
||||
# Parse comma-separated values
|
||||
account_id_list = account_ids.split(",") if account_ids else None
|
||||
category_list = categories.split(",") if categories else None
|
||||
|
||||
emails = await get_unified_inbox(
|
||||
user_id=user_id,
|
||||
account_ids=account_id_list,
|
||||
categories=category_list,
|
||||
is_read=is_read,
|
||||
is_starred=is_starred,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
return emails
|
||||
|
||||
|
||||
@router.get("/inbox/{email_id}", response_model=dict)
|
||||
async def get_email_detail(
|
||||
email_id: str,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
):
|
||||
"""Get a single email with full details."""
|
||||
email_data = await get_email(email_id, user_id)
|
||||
if not email_data:
|
||||
raise HTTPException(status_code=404, detail="Email not found")
|
||||
|
||||
# Mark as read
|
||||
await mark_email_read(email_id, user_id, is_read=True)
|
||||
|
||||
return email_data
|
||||
|
||||
|
||||
@router.post("/inbox/{email_id}/read")
|
||||
async def mark_read(
|
||||
email_id: str,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
is_read: bool = Query(True),
|
||||
):
|
||||
"""Mark email as read/unread."""
|
||||
success = await mark_email_read(email_id, user_id, is_read)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to update email")
|
||||
return {"status": "updated", "is_read": is_read}
|
||||
|
||||
|
||||
@router.post("/inbox/{email_id}/star")
|
||||
async def mark_starred(
|
||||
email_id: str,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
is_starred: bool = Query(True),
|
||||
):
|
||||
"""Mark email as starred/unstarred."""
|
||||
success = await mark_email_starred(email_id, user_id, is_starred)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to update email")
|
||||
return {"status": "updated", "is_starred": is_starred}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Send Email
|
||||
# =============================================================================
|
||||
|
||||
@router.post("/send", response_model=EmailSendResult)
|
||||
async def send_email(
|
||||
request: EmailComposeRequest,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
):
|
||||
"""Send an email."""
|
||||
aggregator = get_mail_aggregator()
|
||||
result = await aggregator.send_email(
|
||||
account_id=request.account_id,
|
||||
user_id=user_id,
|
||||
request=request,
|
||||
)
|
||||
|
||||
if result.success:
|
||||
await log_mail_audit(
|
||||
user_id=user_id,
|
||||
action="email_sent",
|
||||
entity_type="email",
|
||||
details={
|
||||
"account_id": request.account_id,
|
||||
"to": request.to,
|
||||
"subject": request.subject,
|
||||
},
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# AI Analysis
|
||||
# =============================================================================
|
||||
|
||||
@router.post("/analyze/{email_id}", response_model=EmailAnalysisResult)
|
||||
async def analyze_email(
|
||||
email_id: str,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
):
|
||||
"""Run AI analysis on an email."""
|
||||
email_data = await get_email(email_id, user_id)
|
||||
if not email_data:
|
||||
raise HTTPException(status_code=404, detail="Email not found")
|
||||
|
||||
ai_service = get_ai_email_service()
|
||||
result = await ai_service.analyze_email(
|
||||
email_id=email_id,
|
||||
sender_email=email_data.get("sender_email", ""),
|
||||
sender_name=email_data.get("sender_name"),
|
||||
subject=email_data.get("subject", ""),
|
||||
body_text=email_data.get("body_text"),
|
||||
body_preview=email_data.get("body_preview"),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/suggestions/{email_id}", response_model=List[ResponseSuggestion])
|
||||
async def get_response_suggestions(
|
||||
email_id: str,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
):
|
||||
"""Get AI-generated response suggestions for an email."""
|
||||
email_data = await get_email(email_id, user_id)
|
||||
if not email_data:
|
||||
raise HTTPException(status_code=404, detail="Email not found")
|
||||
|
||||
ai_service = get_ai_email_service()
|
||||
|
||||
# Use stored analysis if available
|
||||
from .models import SenderType, EmailCategory as EC
|
||||
sender_type = SenderType(email_data.get("sender_type", "unbekannt"))
|
||||
category = EC(email_data.get("category", "sonstiges"))
|
||||
|
||||
suggestions = await ai_service.suggest_response(
|
||||
subject=email_data.get("subject", ""),
|
||||
body_text=email_data.get("body_text", ""),
|
||||
sender_type=sender_type,
|
||||
category=category,
|
||||
)
|
||||
|
||||
return suggestions
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tasks (Arbeitsvorrat)
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/tasks", response_model=List[dict])
|
||||
async def list_tasks(
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
status: Optional[str] = Query(None, description="Filter by status"),
|
||||
priority: Optional[str] = Query(None, description="Filter by priority"),
|
||||
include_completed: bool = Query(False),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
offset: int = Query(0, ge=0),
|
||||
):
|
||||
"""Get all tasks for a user."""
|
||||
task_service = get_task_service()
|
||||
|
||||
status_enum = TaskStatus(status) if status else None
|
||||
priority_enum = TaskPriority(priority) if priority else None
|
||||
|
||||
tasks = await task_service.get_user_tasks(
|
||||
user_id=user_id,
|
||||
status=status_enum,
|
||||
priority=priority_enum,
|
||||
include_completed=include_completed,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
return tasks
|
||||
|
||||
|
||||
@router.post("/tasks", response_model=dict)
|
||||
async def create_task(
|
||||
request: TaskCreate,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
tenant_id: str = Query(..., description="Tenant ID"),
|
||||
):
|
||||
"""Create a new task manually."""
|
||||
task_service = get_task_service()
|
||||
|
||||
task_id = await task_service.create_manual_task(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
task_data=request,
|
||||
)
|
||||
|
||||
if not task_id:
|
||||
raise HTTPException(status_code=500, detail="Failed to create task")
|
||||
|
||||
return {"id": task_id, "status": "created"}
|
||||
|
||||
|
||||
@router.get("/tasks/dashboard", response_model=TaskDashboardStats)
|
||||
async def get_task_dashboard(
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
):
|
||||
"""Get dashboard statistics for tasks."""
|
||||
task_service = get_task_service()
|
||||
return await task_service.get_dashboard_stats(user_id)
|
||||
|
||||
|
||||
@router.get("/tasks/{task_id}", response_model=dict)
|
||||
async def get_task(
|
||||
task_id: str,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
):
|
||||
"""Get a single task."""
|
||||
task_service = get_task_service()
|
||||
task = await task_service.get_task(task_id, user_id)
|
||||
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
return task
|
||||
|
||||
|
||||
@router.put("/tasks/{task_id}")
|
||||
async def update_task(
|
||||
task_id: str,
|
||||
request: TaskUpdate,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
):
|
||||
"""Update a task."""
|
||||
task_service = get_task_service()
|
||||
|
||||
success = await task_service.update_task(task_id, user_id, request)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to update task")
|
||||
|
||||
return {"status": "updated"}
|
||||
|
||||
|
||||
@router.post("/tasks/{task_id}/complete")
|
||||
async def complete_task(
|
||||
task_id: str,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
):
|
||||
"""Mark a task as completed."""
|
||||
task_service = get_task_service()
|
||||
|
||||
success = await task_service.mark_completed(task_id, user_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to complete task")
|
||||
|
||||
return {"status": "completed"}
|
||||
|
||||
|
||||
@router.post("/tasks/from-email/{email_id}")
|
||||
async def create_task_from_email(
|
||||
email_id: str,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
tenant_id: str = Query(..., description="Tenant ID"),
|
||||
):
|
||||
"""Create a task from an email (after analysis)."""
|
||||
email_data = await get_email(email_id, user_id)
|
||||
if not email_data:
|
||||
raise HTTPException(status_code=404, detail="Email not found")
|
||||
|
||||
# Get deadlines from stored analysis
|
||||
deadlines_raw = email_data.get("detected_deadlines", [])
|
||||
from .models import DeadlineExtraction, SenderType
|
||||
|
||||
deadlines = []
|
||||
for d in deadlines_raw:
|
||||
try:
|
||||
deadlines.append(DeadlineExtraction(
|
||||
deadline_date=datetime.fromisoformat(d["date"]),
|
||||
description=d.get("description", "Frist"),
|
||||
confidence=0.8,
|
||||
source_text="",
|
||||
is_firm=d.get("is_firm", True),
|
||||
))
|
||||
except (KeyError, ValueError):
|
||||
continue
|
||||
|
||||
sender_type = None
|
||||
if email_data.get("sender_type"):
|
||||
try:
|
||||
sender_type = SenderType(email_data["sender_type"])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
task_service = get_task_service()
|
||||
task_id = await task_service.create_task_from_email(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
email_id=email_id,
|
||||
deadlines=deadlines,
|
||||
sender_type=sender_type,
|
||||
)
|
||||
|
||||
if not task_id:
|
||||
raise HTTPException(status_code=500, detail="Failed to create task")
|
||||
|
||||
return {"id": task_id, "status": "created"}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Statistics
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/stats", response_model=MailStats)
|
||||
@router.get("/api/v1/mail/stats", response_model=MailStats)
|
||||
async def get_statistics(
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
):
|
||||
"""Get overall mail statistics for a user."""
|
||||
stats = await get_mail_stats(user_id)
|
||||
return MailStats(**stats)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Sync All
|
||||
# =============================================================================
|
||||
|
||||
@router.post("/sync-all")
|
||||
async def sync_all_accounts(
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
tenant_id: Optional[str] = Query(None),
|
||||
):
|
||||
"""Sync all email accounts for a user."""
|
||||
aggregator = get_mail_aggregator()
|
||||
results = await aggregator.sync_all_accounts(user_id, tenant_id)
|
||||
return {"status": "synced", "results": results}
|
||||
|
||||
@@ -0,0 +1,258 @@
"""
Mail API — account management and sync endpoints.
"""

import logging
from typing import Optional, List

from fastapi import APIRouter, HTTPException, Query, BackgroundTasks
from pydantic import BaseModel

from .models import AccountTestResult
from .mail_db import (
    create_email_account,
    get_email_accounts,
    get_email_account,
    delete_email_account,
    log_mail_audit,
)
from .credentials import get_credentials_service
from .aggregator import get_mail_aggregator

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/v1/mail", tags=["Mail"])


class AccountCreateRequest(BaseModel):
|
||||
"""Request to create an email account."""
|
||||
email: str
|
||||
display_name: str
|
||||
account_type: str = "personal"
|
||||
imap_host: str
|
||||
imap_port: int = 993
|
||||
imap_ssl: bool = True
|
||||
smtp_host: str
|
||||
smtp_port: int = 465
|
||||
smtp_ssl: bool = True
|
||||
password: str
|
||||
|
||||
|
||||
@router.post("/accounts", response_model=dict)
|
||||
async def create_account(
|
||||
request: AccountCreateRequest,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
tenant_id: str = Query(..., description="Tenant ID"),
|
||||
):
|
||||
"""Create a new email account."""
|
||||
credentials_service = get_credentials_service()
|
||||
|
||||
# Store credentials securely
|
||||
vault_path = await credentials_service.store_credentials(
|
||||
account_id=f"{user_id}_{request.email}",
|
||||
email=request.email,
|
||||
password=request.password,
|
||||
imap_host=request.imap_host,
|
||||
imap_port=request.imap_port,
|
||||
smtp_host=request.smtp_host,
|
||||
smtp_port=request.smtp_port,
|
||||
)
|
||||
|
||||
# Create account in database
|
||||
account_id = await create_email_account(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
email=request.email,
|
||||
display_name=request.display_name,
|
||||
account_type=request.account_type,
|
||||
imap_host=request.imap_host,
|
||||
imap_port=request.imap_port,
|
||||
imap_ssl=request.imap_ssl,
|
||||
smtp_host=request.smtp_host,
|
||||
smtp_port=request.smtp_port,
|
||||
smtp_ssl=request.smtp_ssl,
|
||||
vault_path=vault_path,
|
||||
)
|
||||
|
||||
if not account_id:
|
||||
raise HTTPException(status_code=500, detail="Failed to create account")
|
||||
|
||||
# Log audit
|
||||
await log_mail_audit(
|
||||
user_id=user_id,
|
||||
action="account_created",
|
||||
entity_type="account",
|
||||
entity_id=account_id,
|
||||
details={"email": request.email},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
return {"id": account_id, "status": "created"}
|
||||
|
||||
|
||||
@router.get("/accounts", response_model=List[dict])
|
||||
async def list_accounts(
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
tenant_id: Optional[str] = Query(None, description="Tenant ID"),
|
||||
):
|
||||
"""List all email accounts for a user."""
|
||||
accounts = await get_email_accounts(user_id, tenant_id)
|
||||
# Remove sensitive fields
|
||||
for account in accounts:
|
||||
account.pop("vault_path", None)
|
||||
return accounts
|
||||
|
||||
|
||||
@router.get("/accounts/{account_id}", response_model=dict)
|
||||
async def get_account(
|
||||
account_id: str,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
):
|
||||
"""Get a single email account."""
|
||||
account = await get_email_account(account_id, user_id)
|
||||
if not account:
|
||||
raise HTTPException(status_code=404, detail="Account not found")
|
||||
account.pop("vault_path", None)
|
||||
return account
|
||||
|
||||
|
||||
@router.delete("/accounts/{account_id}")
|
||||
async def remove_account(
|
||||
account_id: str,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
):
|
||||
"""Delete an email account."""
|
||||
account = await get_email_account(account_id, user_id)
|
||||
if not account:
|
||||
raise HTTPException(status_code=404, detail="Account not found")
|
||||
|
||||
# Delete credentials
|
||||
credentials_service = get_credentials_service()
|
||||
vault_path = account.get("vault_path", "")
|
||||
if vault_path:
|
||||
await credentials_service.delete_credentials(account_id, vault_path)
|
||||
|
||||
# Delete from database (cascades to emails)
|
||||
success = await delete_email_account(account_id, user_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete account")
|
||||
|
||||
await log_mail_audit(
|
||||
user_id=user_id,
|
||||
action="account_deleted",
|
||||
entity_type="account",
|
||||
entity_id=account_id,
|
||||
)
|
||||
|
||||
return {"status": "deleted"}
|
||||
|
||||
|
||||
@router.post("/accounts/{account_id}/test", response_model=AccountTestResult)
|
||||
async def test_account_connection(
|
||||
account_id: str,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
):
|
||||
"""Test connection for an email account."""
|
||||
account = await get_email_account(account_id, user_id)
|
||||
if not account:
|
||||
raise HTTPException(status_code=404, detail="Account not found")
|
||||
|
||||
# Get credentials
|
||||
credentials_service = get_credentials_service()
|
||||
vault_path = account.get("vault_path", "")
|
||||
creds = await credentials_service.get_credentials(account_id, vault_path)
|
||||
|
||||
if not creds:
|
||||
return AccountTestResult(
|
||||
success=False,
|
||||
error_message="Credentials not found"
|
||||
)
|
||||
|
||||
# Test connection
|
||||
aggregator = get_mail_aggregator()
|
||||
result = await aggregator.test_account_connection(
|
||||
imap_host=account["imap_host"],
|
||||
imap_port=account["imap_port"],
|
||||
imap_ssl=account["imap_ssl"],
|
||||
smtp_host=account["smtp_host"],
|
||||
smtp_port=account["smtp_port"],
|
||||
smtp_ssl=account["smtp_ssl"],
|
||||
email_address=creds.email,
|
||||
password=creds.password,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class ConnectionTestRequest(BaseModel):
|
||||
"""Request to test connection before saving account."""
|
||||
email: str
|
||||
imap_host: str
|
||||
imap_port: int = 993
|
||||
imap_ssl: bool = True
|
||||
smtp_host: str
|
||||
smtp_port: int = 465
|
||||
smtp_ssl: bool = True
|
||||
password: str
|
||||
|
||||
|
||||
@router.post("/accounts/test-connection", response_model=AccountTestResult)
|
||||
async def test_connection_before_save(request: ConnectionTestRequest):
|
||||
"""
|
||||
Test IMAP/SMTP connection before saving an account.
|
||||
|
||||
This allows the wizard to verify credentials are correct
|
||||
before creating the account in the database.
|
||||
"""
|
||||
aggregator = get_mail_aggregator()
|
||||
|
||||
result = await aggregator.test_account_connection(
|
||||
imap_host=request.imap_host,
|
||||
imap_port=request.imap_port,
|
||||
imap_ssl=request.imap_ssl,
|
||||
smtp_host=request.smtp_host,
|
||||
smtp_port=request.smtp_port,
|
||||
smtp_ssl=request.smtp_ssl,
|
||||
email_address=request.email,
|
||||
password=request.password,
|
||||
)
|
||||
|
||||
return result
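# Illustrative sketch, not part of this diff: the wizard flow described above.
# Verify the credentials first, then create the account only on success. The base
# URL, httpx, and all concrete values are assumptions for illustration.
import httpx

BASE = "http://localhost:8000/api/v1/mail"  # assumed host/port
settings = {
    "email": "teacher@example.org",
    "imap_host": "imap.example.org",
    "imap_port": 993,
    "imap_ssl": True,
    "smtp_host": "smtp.example.org",
    "smtp_port": 465,
    "smtp_ssl": True,
    "password": "app-password",
}

with httpx.Client() as client:
    test = client.post(f"{BASE}/accounts/test-connection", json=settings).json()
    if test.get("success"):
        client.post(
            f"{BASE}/accounts",
            params={"user_id": "u-123", "tenant_id": "t-456"},
            json={**settings, "display_name": "School mailbox"},
        )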
|
||||
|
||||
|
||||
@router.post("/accounts/{account_id}/sync")
|
||||
async def sync_account(
|
||||
account_id: str,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
max_emails: int = Query(100, ge=1, le=500),
|
||||
background_tasks: BackgroundTasks = None,
|
||||
):
|
||||
"""Sync emails from an account."""
|
||||
aggregator = get_mail_aggregator()
|
||||
|
||||
try:
|
||||
new_count, total_count = await aggregator.sync_account(
|
||||
account_id=account_id,
|
||||
user_id=user_id,
|
||||
max_emails=max_emails,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "synced",
|
||||
"new_emails": new_count,
|
||||
"total_emails": total_count,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/sync-all")
|
||||
async def sync_all_accounts(
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
tenant_id: Optional[str] = Query(None),
|
||||
):
|
||||
"""Sync all email accounts for a user."""
|
||||
aggregator = get_mail_aggregator()
|
||||
results = await aggregator.sync_all_accounts(user_id, tenant_id)
|
||||
return {"status": "synced", "results": results}
|
||||
@@ -0,0 +1,69 @@
"""
Mail API — AI analysis and response suggestion endpoints.
"""

import logging
from typing import List

from fastapi import APIRouter, HTTPException, Query

from .models import (
    EmailAnalysisResult,
    ResponseSuggestion,
)
from .mail_db import get_email
from .ai_service import get_ai_email_service

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/v1/mail", tags=["Mail"])


@router.post("/analyze/{email_id}", response_model=EmailAnalysisResult)
|
||||
async def analyze_email(
|
||||
email_id: str,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
):
|
||||
"""Run AI analysis on an email."""
|
||||
email_data = await get_email(email_id, user_id)
|
||||
if not email_data:
|
||||
raise HTTPException(status_code=404, detail="Email not found")
|
||||
|
||||
ai_service = get_ai_email_service()
|
||||
result = await ai_service.analyze_email(
|
||||
email_id=email_id,
|
||||
sender_email=email_data.get("sender_email", ""),
|
||||
sender_name=email_data.get("sender_name"),
|
||||
subject=email_data.get("subject", ""),
|
||||
body_text=email_data.get("body_text"),
|
||||
body_preview=email_data.get("body_preview"),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/suggestions/{email_id}", response_model=List[ResponseSuggestion])
|
||||
async def get_response_suggestions(
|
||||
email_id: str,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
):
|
||||
"""Get AI-generated response suggestions for an email."""
|
||||
email_data = await get_email(email_id, user_id)
|
||||
if not email_data:
|
||||
raise HTTPException(status_code=404, detail="Email not found")
|
||||
|
||||
ai_service = get_ai_email_service()
|
||||
|
||||
# Use stored analysis if available
|
||||
from .models import SenderType, EmailCategory as EC
|
||||
sender_type = SenderType(email_data.get("sender_type", "unbekannt"))
|
||||
category = EC(email_data.get("category", "sonstiges"))
|
||||
|
||||
suggestions = await ai_service.suggest_response(
|
||||
subject=email_data.get("subject", ""),
|
||||
body_text=email_data.get("body_text", ""),
|
||||
sender_type=sender_type,
|
||||
category=category,
|
||||
)
|
||||
|
||||
return suggestions
|
||||
@@ -0,0 +1,123 @@
"""
Mail API — unified inbox, send, and email detail endpoints.
"""

import logging
from typing import Optional, List

from fastapi import APIRouter, HTTPException, Query

from .models import (
    EmailComposeRequest,
    EmailSendResult,
)
from .mail_db import (
    get_unified_inbox,
    get_email,
    mark_email_read,
    mark_email_starred,
    log_mail_audit,
)
from .aggregator import get_mail_aggregator

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/v1/mail", tags=["Mail"])


@router.get("/inbox", response_model=List[dict])
|
||||
async def get_inbox(
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
account_ids: Optional[str] = Query(None, description="Comma-separated account IDs"),
|
||||
categories: Optional[str] = Query(None, description="Comma-separated categories"),
|
||||
is_read: Optional[bool] = Query(None),
|
||||
is_starred: Optional[bool] = Query(None),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
offset: int = Query(0, ge=0),
|
||||
):
|
||||
"""Get unified inbox with all accounts aggregated."""
|
||||
# Parse comma-separated values
|
||||
account_id_list = account_ids.split(",") if account_ids else None
|
||||
category_list = categories.split(",") if categories else None
|
||||
|
||||
emails = await get_unified_inbox(
|
||||
user_id=user_id,
|
||||
account_ids=account_id_list,
|
||||
categories=category_list,
|
||||
is_read=is_read,
|
||||
is_starred=is_starred,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
return emails
|
||||
|
||||
|
||||
@router.get("/inbox/{email_id}", response_model=dict)
|
||||
async def get_email_detail(
|
||||
email_id: str,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
):
|
||||
"""Get a single email with full details."""
|
||||
email_data = await get_email(email_id, user_id)
|
||||
if not email_data:
|
||||
raise HTTPException(status_code=404, detail="Email not found")
|
||||
|
||||
# Mark as read
|
||||
await mark_email_read(email_id, user_id, is_read=True)
|
||||
|
||||
return email_data
|
||||
|
||||
|
||||
@router.post("/inbox/{email_id}/read")
|
||||
async def mark_read(
|
||||
email_id: str,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
is_read: bool = Query(True),
|
||||
):
|
||||
"""Mark email as read/unread."""
|
||||
success = await mark_email_read(email_id, user_id, is_read)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to update email")
|
||||
return {"status": "updated", "is_read": is_read}
|
||||
|
||||
|
||||
@router.post("/inbox/{email_id}/star")
|
||||
async def mark_starred(
|
||||
email_id: str,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
is_starred: bool = Query(True),
|
||||
):
|
||||
"""Mark email as starred/unstarred."""
|
||||
success = await mark_email_starred(email_id, user_id, is_starred)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to update email")
|
||||
return {"status": "updated", "is_starred": is_starred}
|
||||
|
||||
|
||||
@router.post("/send", response_model=EmailSendResult)
|
||||
async def send_email(
|
||||
request: EmailComposeRequest,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
):
|
||||
"""Send an email."""
|
||||
aggregator = get_mail_aggregator()
|
||||
result = await aggregator.send_email(
|
||||
account_id=request.account_id,
|
||||
user_id=user_id,
|
||||
request=request,
|
||||
)
|
||||
|
||||
if result.success:
|
||||
await log_mail_audit(
|
||||
user_id=user_id,
|
||||
action="email_sent",
|
||||
entity_type="email",
|
||||
details={
|
||||
"account_id": request.account_id,
|
||||
"to": request.to,
|
||||
"subject": request.subject,
|
||||
},
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,176 @@
"""
Mail API — task (Arbeitsvorrat) endpoints.
"""

import logging
from datetime import datetime
from typing import Optional, List

from fastapi import APIRouter, HTTPException, Query

from .models import (
    TaskCreate,
    TaskUpdate,
    TaskDashboardStats,
    TaskStatus,
    TaskPriority,
)
from .mail_db import get_email
from .task_service import get_task_service

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/v1/mail", tags=["Mail"])


@router.get("/tasks", response_model=List[dict])
|
||||
async def list_tasks(
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
status: Optional[str] = Query(None, description="Filter by status"),
|
||||
priority: Optional[str] = Query(None, description="Filter by priority"),
|
||||
include_completed: bool = Query(False),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
offset: int = Query(0, ge=0),
|
||||
):
|
||||
"""Get all tasks for a user."""
|
||||
task_service = get_task_service()
|
||||
|
||||
status_enum = TaskStatus(status) if status else None
|
||||
priority_enum = TaskPriority(priority) if priority else None
|
||||
|
||||
tasks = await task_service.get_user_tasks(
|
||||
user_id=user_id,
|
||||
status=status_enum,
|
||||
priority=priority_enum,
|
||||
include_completed=include_completed,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
return tasks
|
||||
|
||||
|
||||
@router.post("/tasks", response_model=dict)
|
||||
async def create_task(
|
||||
request: TaskCreate,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
tenant_id: str = Query(..., description="Tenant ID"),
|
||||
):
|
||||
"""Create a new task manually."""
|
||||
task_service = get_task_service()
|
||||
|
||||
task_id = await task_service.create_manual_task(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
task_data=request,
|
||||
)
|
||||
|
||||
if not task_id:
|
||||
raise HTTPException(status_code=500, detail="Failed to create task")
|
||||
|
||||
return {"id": task_id, "status": "created"}
|
||||
|
||||
|
||||
@router.get("/tasks/dashboard", response_model=TaskDashboardStats)
|
||||
async def get_task_dashboard(
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
):
|
||||
"""Get dashboard statistics for tasks."""
|
||||
task_service = get_task_service()
|
||||
return await task_service.get_dashboard_stats(user_id)
|
||||
|
||||
|
||||
@router.get("/tasks/{task_id}", response_model=dict)
|
||||
async def get_task(
|
||||
task_id: str,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
):
|
||||
"""Get a single task."""
|
||||
task_service = get_task_service()
|
||||
task = await task_service.get_task(task_id, user_id)
|
||||
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
return task
|
||||
|
||||
|
||||
@router.put("/tasks/{task_id}")
|
||||
async def update_task(
|
||||
task_id: str,
|
||||
request: TaskUpdate,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
):
|
||||
"""Update a task."""
|
||||
task_service = get_task_service()
|
||||
|
||||
success = await task_service.update_task(task_id, user_id, request)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to update task")
|
||||
|
||||
return {"status": "updated"}
|
||||
|
||||
|
||||
@router.post("/tasks/{task_id}/complete")
|
||||
async def complete_task(
|
||||
task_id: str,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
):
|
||||
"""Mark a task as completed."""
|
||||
task_service = get_task_service()
|
||||
|
||||
success = await task_service.mark_completed(task_id, user_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to complete task")
|
||||
|
||||
return {"status": "completed"}
|
||||
|
||||
|
||||
@router.post("/tasks/from-email/{email_id}")
|
||||
async def create_task_from_email(
|
||||
email_id: str,
|
||||
user_id: str = Query(..., description="User ID"),
|
||||
tenant_id: str = Query(..., description="Tenant ID"),
|
||||
):
|
||||
"""Create a task from an email (after analysis)."""
|
||||
email_data = await get_email(email_id, user_id)
|
||||
if not email_data:
|
||||
raise HTTPException(status_code=404, detail="Email not found")
|
||||
|
||||
# Get deadlines from stored analysis
|
||||
deadlines_raw = email_data.get("detected_deadlines", [])
|
||||
from .models import DeadlineExtraction, SenderType
|
||||
|
||||
deadlines = []
|
||||
for d in deadlines_raw:
|
||||
try:
|
||||
deadlines.append(DeadlineExtraction(
|
||||
deadline_date=datetime.fromisoformat(d["date"]),
|
||||
description=d.get("description", "Frist"),
|
||||
confidence=0.8,
|
||||
source_text="",
|
||||
is_firm=d.get("is_firm", True),
|
||||
))
|
||||
except (KeyError, ValueError):
|
||||
continue
|
||||
|
||||
sender_type = None
|
||||
if email_data.get("sender_type"):
|
||||
try:
|
||||
sender_type = SenderType(email_data["sender_type"])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
task_service = get_task_service()
|
||||
task_id = await task_service.create_task_from_email(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
email_id=email_id,
|
||||
deadlines=deadlines,
|
||||
sender_type=sender_type,
|
||||
)
|
||||
|
||||
if not task_id:
|
||||
raise HTTPException(status_code=500, detail="Failed to create task")
|
||||
|
||||
return {"id": task_id, "status": "created"}
|
||||
@@ -0,0 +1,188 @@
"""
Orientation & Page-Split API endpoints (Steps 1 and 1b of OCR Pipeline).
"""

import logging
import time
from typing import Any, Dict

import cv2
from fastapi import APIRouter, HTTPException

from cv_vocab_pipeline import detect_and_fix_orientation
from page_crop import detect_page_splits
from ocr_pipeline_session_store import update_session_db

from orientation_crop_helpers import ensure_cached, append_pipeline_log
from page_sub_sessions import create_page_sub_sessions_full

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])


# ---------------------------------------------------------------------------
|
||||
# Step 1: Orientation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/orientation")
|
||||
async def detect_orientation(session_id: str):
|
||||
"""Detect and fix 90/180/270 degree rotations from scanners.
|
||||
|
||||
Reads the original image, applies orientation correction,
|
||||
stores the result as oriented_png.
|
||||
"""
|
||||
cached = await ensure_cached(session_id)
|
||||
|
||||
img_bgr = cached.get("original_bgr")
|
||||
if img_bgr is None:
|
||||
raise HTTPException(status_code=400, detail="Original image not available")
|
||||
|
||||
t0 = time.time()
|
||||
|
||||
# Detect and fix orientation
|
||||
oriented_bgr, orientation_deg = detect_and_fix_orientation(img_bgr.copy())
|
||||
|
||||
duration = time.time() - t0
|
||||
|
||||
orientation_result = {
|
||||
"orientation_degrees": orientation_deg,
|
||||
"corrected": orientation_deg != 0,
|
||||
"duration_seconds": round(duration, 2),
|
||||
}
|
||||
|
||||
# Encode oriented image
|
||||
success, png_buf = cv2.imencode(".png", oriented_bgr)
|
||||
oriented_png = png_buf.tobytes() if success else b""
|
||||
|
||||
# Update cache
|
||||
cached["oriented_bgr"] = oriented_bgr
|
||||
cached["orientation_result"] = orientation_result
|
||||
|
||||
# Persist to DB
|
||||
await update_session_db(
|
||||
session_id,
|
||||
oriented_png=oriented_png,
|
||||
orientation_result=orientation_result,
|
||||
current_step=2,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"OCR Pipeline: orientation session %s: %d° (%s) in %.2fs",
|
||||
session_id, orientation_deg,
|
||||
"corrected" if orientation_deg else "no change",
|
||||
duration,
|
||||
)
|
||||
|
||||
await append_pipeline_log(session_id, "orientation", {
|
||||
"orientation_degrees": orientation_deg,
|
||||
"corrected": orientation_deg != 0,
|
||||
}, duration_ms=int(duration * 1000))
|
||||
|
||||
h, w = oriented_bgr.shape[:2]
|
||||
return {
|
||||
"session_id": session_id,
|
||||
**orientation_result,
|
||||
"image_width": w,
|
||||
"image_height": h,
|
||||
"oriented_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/oriented",
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step 1b: Page-split detection — runs AFTER orientation, BEFORE deskew
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/page-split")
|
||||
async def detect_page_split(session_id: str):
|
||||
"""Detect if the image is a double-page book spread and split into sub-sessions.
|
||||
|
||||
Must be called **after orientation** (step 1) and **before deskew** (step 2).
|
||||
Each sub-session receives the raw page region and goes through the full
|
||||
pipeline (deskew -> dewarp -> crop -> columns -> rows -> words -> grid)
|
||||
independently, so each page gets its own deskew correction.
|
||||
|
||||
Returns ``{"multi_page": false}`` if only one page is detected.
|
||||
"""
|
||||
cached = await ensure_cached(session_id)
|
||||
|
||||
# Use oriented (preferred), fall back to original
|
||||
img_bgr = next(
|
||||
(v for k in ("oriented_bgr", "original_bgr")
|
||||
if (v := cached.get(k)) is not None),
|
||||
None,
|
||||
)
|
||||
if img_bgr is None:
|
||||
raise HTTPException(status_code=400, detail="No image available for page-split detection")
|
||||
|
||||
t0 = time.time()
|
||||
page_splits = detect_page_splits(img_bgr)
|
||||
used_original = False
|
||||
|
||||
if not page_splits or len(page_splits) < 2:
|
||||
# Orientation may have rotated a landscape double-page spread to
|
||||
# portrait. Try the original (pre-orientation) image as fallback.
|
||||
orig_bgr = cached.get("original_bgr")
|
||||
if orig_bgr is not None and orig_bgr is not img_bgr:
|
||||
page_splits_orig = detect_page_splits(orig_bgr)
|
||||
if page_splits_orig and len(page_splits_orig) >= 2:
|
||||
logger.info(
|
||||
"OCR Pipeline: page-split session %s: spread detected on "
|
||||
"ORIGINAL (orientation rotated it away)",
|
||||
session_id,
|
||||
)
|
||||
img_bgr = orig_bgr
|
||||
page_splits = page_splits_orig
|
||||
used_original = True
|
||||
|
||||
if not page_splits or len(page_splits) < 2:
|
||||
duration = time.time() - t0
|
||||
logger.info(
|
||||
"OCR Pipeline: page-split session %s: single page (%.2fs)",
|
||||
session_id, duration,
|
||||
)
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"multi_page": False,
|
||||
"duration_seconds": round(duration, 2),
|
||||
}
|
||||
|
||||
# Multi-page spread detected — create sub-sessions for full pipeline.
|
||||
# start_step=2 means "ready for deskew" (orientation already applied).
|
||||
# start_step=1 means "needs orientation too" (split from original image).
|
||||
start_step = 1 if used_original else 2
|
||||
sub_sessions = await create_page_sub_sessions_full(
|
||||
session_id, cached, img_bgr, page_splits, start_step=start_step,
|
||||
)
|
||||
duration = time.time() - t0
|
||||
|
||||
split_info: Dict[str, Any] = {
|
||||
"multi_page": True,
|
||||
"page_count": len(page_splits),
|
||||
"page_splits": page_splits,
|
||||
"used_original": used_original,
|
||||
"duration_seconds": round(duration, 2),
|
||||
}
|
||||
|
||||
# Mark parent session as split and hidden from session list
|
||||
await update_session_db(session_id, crop_result=split_info, status='split')
|
||||
cached["crop_result"] = split_info
|
||||
|
||||
await append_pipeline_log(session_id, "page_split", {
|
||||
"multi_page": True,
|
||||
"page_count": len(page_splits),
|
||||
}, duration_ms=int(duration * 1000))
|
||||
|
||||
logger.info(
|
||||
"OCR Pipeline: page-split session %s: %d pages detected in %.2fs",
|
||||
session_id, len(page_splits), duration,
|
||||
)
|
||||
|
||||
h, w = img_bgr.shape[:2]
|
||||
return {
|
||||
"session_id": session_id,
|
||||
**split_info,
|
||||
"image_width": w,
|
||||
"image_height": h,
|
||||
"sub_sessions": sub_sessions,
|
||||
}
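# Illustrative sketch, not part of this diff: driving steps 1 and 1b over HTTP in
# the order the docstrings require (orientation first, page-split before deskew).
# The base URL, the httpx client, and the pre-existing session_id are assumptions;
# the exact sub-session payload shape may differ.
import httpx

BASE = "http://localhost:8000/api/v1/ocr-pipeline"  # assumed host/port
session_id = "existing-session-id"  # assumed to have been created earlier

with httpx.Client() as client:
    # Step 1: fix 90/180/270 degree scanner rotations.
    client.post(f"{BASE}/sessions/{session_id}/orientation").raise_for_status()

    # Step 1b: detect double-page book spreads before deskew runs.
    split = client.post(f"{BASE}/sessions/{session_id}/page-split").json()

    if split.get("multi_page"):
        # Each detected page continues through deskew, dewarp, crop, etc.
        # as its own sub-session.
        for sub in split.get("sub_sessions", []):
            print("continue pipeline for sub-session:", sub)
    else:
        print("single page; continue with deskew on", session_id)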
|
||||
@@ -1,694 +1,16 @@
"""
Orientation & Crop API - Steps 1 and 4 of the OCR Pipeline.

Step 1: Orientation detection (fix 90/180/270 degree rotations)
Step 4 (UI index 3): Page cropping (after deskew + dewarp, so the image is straight)

These endpoints were extracted from the main pipeline to keep files manageable.
Barrel re-export: merges routers from orientation_api and crop_api,
and re-exports set_cache_ref for main.py.
"""

import logging
import time
import uuid as uuid_mod
from typing import Any, Dict, List, Optional
from fastapi import APIRouter

import cv2
import numpy as np
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from orientation_crop_helpers import set_cache_ref  # noqa: F401
from orientation_api import router as _orientation_router
from crop_api import router as _crop_router

from cv_vocab_pipeline import detect_and_fix_orientation
from page_crop import detect_and_crop_page, detect_page_splits
from ocr_pipeline_session_store import (
    create_session_db,
    get_session_db,
    get_session_image,
    get_sub_sessions,
    update_session_db,
)

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])


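# Illustrative sketch, not part of this diff: how main.py is presumably expected
# to hand the shared session cache to this barrel module at startup. The name of
# the cache object inside ocr_pipeline_api is an assumption for illustration.
from fastapi import FastAPI

import ocr_pipeline_api
import orientation_crop_api

app = FastAPI()

# Both modules must read and write the same in-memory session cache, so the
# reference is passed once before any requests are served.
orientation_crop_api.set_cache_ref(ocr_pipeline_api.SESSION_CACHE)  # attribute name assumed

app.include_router(orientation_crop_api.router)
# the main ocr_pipeline_api router would be included the same way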
# Reference to the shared cache from ocr_pipeline_api (set in main.py)
|
||||
_cache: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
def set_cache_ref(cache: Dict[str, Dict[str, Any]]):
|
||||
"""Set reference to the shared cache from ocr_pipeline_api."""
|
||||
global _cache
|
||||
_cache = cache
|
||||
|
||||
|
||||
async def _ensure_cached(session_id: str) -> Dict[str, Any]:
|
||||
"""Ensure session is in cache, loading from DB if needed."""
|
||||
if session_id in _cache:
|
||||
return _cache[session_id]
|
||||
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
cache_entry: Dict[str, Any] = {
|
||||
"id": session_id,
|
||||
**session,
|
||||
"original_bgr": None,
|
||||
"oriented_bgr": None,
|
||||
"cropped_bgr": None,
|
||||
"deskewed_bgr": None,
|
||||
"dewarped_bgr": None,
|
||||
}
|
||||
|
||||
for img_type, bgr_key in [
|
||||
("original", "original_bgr"),
|
||||
("oriented", "oriented_bgr"),
|
||||
("cropped", "cropped_bgr"),
|
||||
("deskewed", "deskewed_bgr"),
|
||||
("dewarped", "dewarped_bgr"),
|
||||
]:
|
||||
png_data = await get_session_image(session_id, img_type)
|
||||
if png_data:
|
||||
arr = np.frombuffer(png_data, dtype=np.uint8)
|
||||
bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||
cache_entry[bgr_key] = bgr
|
||||
|
||||
_cache[session_id] = cache_entry
|
||||
return cache_entry
|
||||
|
||||
|
||||
async def _append_pipeline_log(session_id: str, step: str, metrics: dict, duration_ms: int):
|
||||
"""Append a step entry to the pipeline log."""
|
||||
from datetime import datetime
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
return
|
||||
pipeline_log = session.get("pipeline_log") or {"steps": []}
|
||||
pipeline_log["steps"].append({
|
||||
"step": step,
|
||||
"completed_at": datetime.utcnow().isoformat(),
|
||||
"success": True,
|
||||
"duration_ms": duration_ms,
|
||||
"metrics": metrics,
|
||||
})
|
||||
await update_session_db(session_id, pipeline_log=pipeline_log)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step 1: Orientation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/orientation")
|
||||
async def detect_orientation(session_id: str):
|
||||
"""Detect and fix 90/180/270 degree rotations from scanners.
|
||||
|
||||
Reads the original image, applies orientation correction,
|
||||
stores the result as oriented_png.
|
||||
"""
|
||||
cached = await _ensure_cached(session_id)
|
||||
|
||||
img_bgr = cached.get("original_bgr")
|
||||
if img_bgr is None:
|
||||
raise HTTPException(status_code=400, detail="Original image not available")
|
||||
|
||||
t0 = time.time()
|
||||
|
||||
# Detect and fix orientation
|
||||
oriented_bgr, orientation_deg = detect_and_fix_orientation(img_bgr.copy())
|
||||
|
||||
duration = time.time() - t0
|
||||
|
||||
orientation_result = {
|
||||
"orientation_degrees": orientation_deg,
|
||||
"corrected": orientation_deg != 0,
|
||||
"duration_seconds": round(duration, 2),
|
||||
}
|
||||
|
||||
# Encode oriented image
|
||||
success, png_buf = cv2.imencode(".png", oriented_bgr)
|
||||
oriented_png = png_buf.tobytes() if success else b""
|
||||
|
||||
# Update cache
|
||||
cached["oriented_bgr"] = oriented_bgr
|
||||
cached["orientation_result"] = orientation_result
|
||||
|
||||
# Persist to DB
|
||||
await update_session_db(
|
||||
session_id,
|
||||
oriented_png=oriented_png,
|
||||
orientation_result=orientation_result,
|
||||
current_step=2,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"OCR Pipeline: orientation session %s: %d° (%s) in %.2fs",
|
||||
session_id, orientation_deg,
|
||||
"corrected" if orientation_deg else "no change",
|
||||
duration,
|
||||
)
|
||||
|
||||
await _append_pipeline_log(session_id, "orientation", {
|
||||
"orientation_degrees": orientation_deg,
|
||||
"corrected": orientation_deg != 0,
|
||||
}, duration_ms=int(duration * 1000))
|
||||
|
||||
h, w = oriented_bgr.shape[:2]
|
||||
return {
|
||||
"session_id": session_id,
|
||||
**orientation_result,
|
||||
"image_width": w,
|
||||
"image_height": h,
|
||||
"oriented_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/oriented",
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step 1b: Page-split detection — runs AFTER orientation, BEFORE deskew
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/page-split")
|
||||
async def detect_page_split(session_id: str):
|
||||
"""Detect if the image is a double-page book spread and split into sub-sessions.
|
||||
|
||||
Must be called **after orientation** (step 1) and **before deskew** (step 2).
|
||||
Each sub-session receives the raw page region and goes through the full
|
||||
pipeline (deskew → dewarp → crop → columns → rows → words → grid)
|
||||
independently, so each page gets its own deskew correction.
|
||||
|
||||
Returns ``{"multi_page": false}`` if only one page is detected.
|
||||
"""
|
||||
cached = await _ensure_cached(session_id)
|
||||
|
||||
# Use oriented (preferred), fall back to original
|
||||
img_bgr = next(
|
||||
(v for k in ("oriented_bgr", "original_bgr")
|
||||
if (v := cached.get(k)) is not None),
|
||||
None,
|
||||
)
|
||||
if img_bgr is None:
|
||||
raise HTTPException(status_code=400, detail="No image available for page-split detection")
|
||||
|
||||
t0 = time.time()
|
||||
page_splits = detect_page_splits(img_bgr)
|
||||
used_original = False
|
||||
|
||||
if not page_splits or len(page_splits) < 2:
|
||||
# Orientation may have rotated a landscape double-page spread to
|
||||
# portrait. Try the original (pre-orientation) image as fallback.
|
||||
orig_bgr = cached.get("original_bgr")
|
||||
if orig_bgr is not None and orig_bgr is not img_bgr:
|
||||
page_splits_orig = detect_page_splits(orig_bgr)
|
||||
if page_splits_orig and len(page_splits_orig) >= 2:
|
||||
logger.info(
|
||||
"OCR Pipeline: page-split session %s: spread detected on "
|
||||
"ORIGINAL (orientation rotated it away)",
|
||||
session_id,
|
||||
)
|
||||
img_bgr = orig_bgr
|
||||
page_splits = page_splits_orig
|
||||
used_original = True
|
||||
|
||||
if not page_splits or len(page_splits) < 2:
|
||||
duration = time.time() - t0
|
||||
logger.info(
|
||||
"OCR Pipeline: page-split session %s: single page (%.2fs)",
|
||||
session_id, duration,
|
||||
)
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"multi_page": False,
|
||||
"duration_seconds": round(duration, 2),
|
||||
}
|
||||
|
||||
# Multi-page spread detected — create sub-sessions for full pipeline.
|
||||
# start_step=2 means "ready for deskew" (orientation already applied).
|
||||
# start_step=1 means "needs orientation too" (split from original image).
|
||||
start_step = 1 if used_original else 2
|
||||
sub_sessions = await _create_page_sub_sessions_full(
|
||||
session_id, cached, img_bgr, page_splits, start_step=start_step,
|
||||
)
|
||||
duration = time.time() - t0
|
||||
|
||||
split_info: Dict[str, Any] = {
|
||||
"multi_page": True,
|
||||
"page_count": len(page_splits),
|
||||
"page_splits": page_splits,
|
||||
"used_original": used_original,
|
||||
"duration_seconds": round(duration, 2),
|
||||
}
|
||||
|
||||
# Mark parent session as split and hidden from session list
|
||||
await update_session_db(session_id, crop_result=split_info, status='split')
|
||||
cached["crop_result"] = split_info
|
||||
|
||||
await _append_pipeline_log(session_id, "page_split", {
|
||||
"multi_page": True,
|
||||
"page_count": len(page_splits),
|
||||
}, duration_ms=int(duration * 1000))
|
||||
|
||||
logger.info(
|
||||
"OCR Pipeline: page-split session %s: %d pages detected in %.2fs",
|
||||
session_id, len(page_splits), duration,
|
||||
)
|
||||
|
||||
h, w = img_bgr.shape[:2]
|
||||
return {
|
||||
"session_id": session_id,
|
||||
**split_info,
|
||||
"image_width": w,
|
||||
"image_height": h,
|
||||
"sub_sessions": sub_sessions,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step 4 (UI index 3): Crop — runs after deskew + dewarp
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/crop")
|
||||
async def auto_crop(session_id: str):
|
||||
"""Auto-detect and crop scanner/book borders.
|
||||
|
||||
Reads the dewarped image (post-deskew + dewarp, so the page is straight).
|
||||
Falls back to oriented → original if earlier steps were skipped.
|
||||
|
||||
If the image is a multi-page spread (e.g. book on scanner), it will
|
||||
automatically split into separate sub-sessions per page, crop each
|
||||
individually, and return the split info.
|
||||
"""
|
||||
cached = await _ensure_cached(session_id)
|
||||
|
||||
# Use dewarped (preferred), fall back to oriented, then original
|
||||
img_bgr = next(
|
||||
(v for k in ("dewarped_bgr", "oriented_bgr", "original_bgr")
|
||||
if (v := cached.get(k)) is not None),
|
||||
None,
|
||||
)
|
||||
if img_bgr is None:
|
||||
raise HTTPException(status_code=400, detail="No image available for cropping")
|
||||
|
||||
t0 = time.time()
|
||||
|
||||
# --- Check for existing sub-sessions (from page-split step) ---
|
||||
# If page-split already created sub-sessions, skip multi-page detection
|
||||
# in the crop step. Each sub-session runs its own crop independently.
|
||||
existing_subs = await get_sub_sessions(session_id)
|
||||
if existing_subs:
|
||||
crop_result = cached.get("crop_result") or {}
|
||||
if crop_result.get("multi_page"):
|
||||
# Already split — just return the existing info
|
||||
duration = time.time() - t0
|
||||
h, w = img_bgr.shape[:2]
|
||||
return {
|
||||
"session_id": session_id,
|
||||
**crop_result,
|
||||
"image_width": w,
|
||||
"image_height": h,
|
||||
"sub_sessions": [
|
||||
{"id": s["id"], "name": s.get("name"), "page_index": s.get("box_index", i)}
|
||||
for i, s in enumerate(existing_subs)
|
||||
],
|
||||
"note": "Page split was already performed; each sub-session runs its own crop.",
|
||||
}
|
||||
|
||||
# --- Multi-page detection (fallback for sessions that skipped page-split) ---
|
||||
page_splits = detect_page_splits(img_bgr)
|
||||
|
||||
if page_splits and len(page_splits) >= 2:
|
||||
# Multi-page spread detected — create sub-sessions
|
||||
sub_sessions = await _create_page_sub_sessions(
|
||||
session_id, cached, img_bgr, page_splits,
|
||||
)
|
||||
duration = time.time() - t0
|
||||
|
||||
crop_info: Dict[str, Any] = {
|
||||
"crop_applied": True,
|
||||
"multi_page": True,
|
||||
"page_count": len(page_splits),
|
||||
"page_splits": page_splits,
|
||||
"duration_seconds": round(duration, 2),
|
||||
}
|
||||
cached["crop_result"] = crop_info
|
||||
|
||||
# Store the first page as the main cropped image for backward compat
|
||||
first_page = page_splits[0]
|
||||
first_bgr = img_bgr[
|
||||
first_page["y"]:first_page["y"] + first_page["height"],
|
||||
first_page["x"]:first_page["x"] + first_page["width"],
|
||||
].copy()
|
||||
first_cropped, _ = detect_and_crop_page(first_bgr)
|
||||
cached["cropped_bgr"] = first_cropped
|
||||
|
||||
ok, png_buf = cv2.imencode(".png", first_cropped)
|
||||
await update_session_db(
|
||||
session_id,
|
||||
cropped_png=png_buf.tobytes() if ok else b"",
|
||||
crop_result=crop_info,
|
||||
current_step=5,
|
||||
status='split',
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"OCR Pipeline: crop session %s: multi-page split into %d pages in %.2fs",
|
||||
session_id, len(page_splits), duration,
|
||||
)
|
||||
|
||||
await _append_pipeline_log(session_id, "crop", {
|
||||
"multi_page": True,
|
||||
"page_count": len(page_splits),
|
||||
}, duration_ms=int(duration * 1000))
|
||||
|
||||
h, w = first_cropped.shape[:2]
|
||||
return {
|
||||
"session_id": session_id,
|
||||
**crop_info,
|
||||
"image_width": w,
|
||||
"image_height": h,
|
||||
"cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
|
||||
"sub_sessions": sub_sessions,
|
||||
}
|
||||
|
||||
# --- Single page (normal) ---
|
||||
cropped_bgr, crop_info = detect_and_crop_page(img_bgr)
|
||||
|
||||
duration = time.time() - t0
|
||||
crop_info["duration_seconds"] = round(duration, 2)
|
||||
crop_info["multi_page"] = False
|
||||
|
||||
# Encode cropped image
|
||||
success, png_buf = cv2.imencode(".png", cropped_bgr)
|
||||
cropped_png = png_buf.tobytes() if success else b""
|
||||
|
||||
# Update cache
|
||||
cached["cropped_bgr"] = cropped_bgr
|
||||
cached["crop_result"] = crop_info
|
||||
|
||||
# Persist to DB
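    # current_step=5 advances the session past crop (step 4 / UI index 3), so the
    # UI presumably unlocks the next pipeline step once this write succeeds.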
|
||||
await update_session_db(
|
||||
session_id,
|
||||
cropped_png=cropped_png,
|
||||
crop_result=crop_info,
|
||||
current_step=5,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"OCR Pipeline: crop session %s: applied=%s format=%s in %.2fs",
|
||||
session_id, crop_info["crop_applied"],
|
||||
crop_info.get("detected_format", "?"),
|
||||
duration,
|
||||
)
|
||||
|
||||
await _append_pipeline_log(session_id, "crop", {
|
||||
"crop_applied": crop_info["crop_applied"],
|
||||
"detected_format": crop_info.get("detected_format"),
|
||||
"format_confidence": crop_info.get("format_confidence"),
|
||||
}, duration_ms=int(duration * 1000))
|
||||
|
||||
h, w = cropped_bgr.shape[:2]
|
||||
return {
|
||||
"session_id": session_id,
|
||||
**crop_info,
|
||||
"image_width": w,
|
||||
"image_height": h,
|
||||
"cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
|
||||
}
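

# Sketch of client-side handling for the crop response. The endpoint path matches
# the route defined above; ``client`` (an httpx.AsyncClient-like object) and this
# helper itself are illustrative assumptions, not part of the API.
async def _example_handle_crop_response(client, session_id: str) -> None:
    resp = await client.post(f"/api/v1/ocr-pipeline/sessions/{session_id}/crop")
    data = resp.json()
    if data.get("multi_page"):
        # The spread was split: each detected page continues as its own sub-session.
        for sub in data["sub_sessions"]:
            print("sub-session", sub["id"], "page", sub.get("page_index"))
    else:
        # Single page: the cropped image can be fetched from the returned URL.
        print("cropped image at", data["cropped_image_url"])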
|
||||
|
||||
|
||||
async def _create_page_sub_sessions(
|
||||
parent_session_id: str,
|
||||
parent_cached: dict,
|
||||
full_img_bgr: np.ndarray,
|
||||
page_splits: List[Dict[str, Any]],
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Create sub-sessions for each detected page in a multi-page spread.
|
||||
|
||||
Each page region is individually cropped, then stored as a sub-session
|
||||
with its own cropped image ready for the rest of the pipeline.
|
||||
"""
|
||||
# Check for existing sub-sessions (idempotent)
|
||||
existing = await get_sub_sessions(parent_session_id)
|
||||
if existing:
|
||||
return [
|
||||
{"id": s["id"], "name": s["name"], "page_index": s.get("box_index", i)}
|
||||
for i, s in enumerate(existing)
|
||||
]
|
||||
|
||||
parent_name = parent_cached.get("name", "Scan")
|
||||
parent_filename = parent_cached.get("filename", "scan.png")
|
||||
|
||||
sub_sessions: List[Dict[str, Any]] = []
|
||||
|
||||
for page in page_splits:
|
||||
pi = page["page_index"]
|
||||
px, py = page["x"], page["y"]
|
||||
pw, ph = page["width"], page["height"]
|
||||
|
||||
# Extract page region
|
||||
page_bgr = full_img_bgr[py:py + ph, px:px + pw].copy()
|
||||
|
||||
# Crop each page individually (remove its own borders)
|
||||
cropped_page, page_crop_info = detect_and_crop_page(page_bgr)
|
||||
|
||||
# Encode as PNG
|
||||
ok, png_buf = cv2.imencode(".png", cropped_page)
|
||||
page_png = png_buf.tobytes() if ok else b""
|
||||
|
||||
sub_id = str(uuid_mod.uuid4())
|
||||
sub_name = f"{parent_name} — Seite {pi + 1}"
|
||||
|
||||
await create_session_db(
|
||||
session_id=sub_id,
|
||||
name=sub_name,
|
||||
filename=parent_filename,
|
||||
original_png=page_png,
|
||||
)
|
||||
|
||||
# Pre-populate: set cropped = original (already cropped)
|
||||
await update_session_db(
|
||||
sub_id,
|
||||
cropped_png=page_png,
|
||||
crop_result=page_crop_info,
|
||||
current_step=5,
|
||||
)
|
||||
|
||||
ch, cw = cropped_page.shape[:2]
|
||||
sub_sessions.append({
|
||||
"id": sub_id,
|
||||
"name": sub_name,
|
||||
"page_index": pi,
|
||||
"source_rect": page,
|
||||
"cropped_size": {"width": cw, "height": ch},
|
||||
"detected_format": page_crop_info.get("detected_format"),
|
||||
})
|
||||
|
||||
logger.info(
|
||||
"Page sub-session %s: page %d, region x=%d w=%d -> cropped %dx%d",
|
||||
sub_id, pi + 1, px, pw, cw, ch,
|
||||
)
|
||||
|
||||
return sub_sessions
|
||||
|
||||
|
||||
async def _create_page_sub_sessions_full(
|
||||
parent_session_id: str,
|
||||
parent_cached: dict,
|
||||
full_img_bgr: np.ndarray,
|
||||
page_splits: List[Dict[str, Any]],
|
||||
start_step: int = 2,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Create sub-sessions for each page with RAW regions for full pipeline processing.
|
||||
|
||||
Unlike ``_create_page_sub_sessions`` (used by the crop step), these
|
||||
sub-sessions store the *uncropped* page region and start at
|
||||
``start_step`` (default 2 = ready for deskew; 1 if orientation still
|
||||
needed). Each page goes through its own pipeline independently,
|
||||
which is essential for book spreads where each page has a different tilt.
|
||||
"""
|
||||
# Idempotent: reuse existing sub-sessions
|
||||
existing = await get_sub_sessions(parent_session_id)
|
||||
if existing:
|
||||
return [
|
||||
{"id": s["id"], "name": s["name"], "page_index": s.get("box_index", i)}
|
||||
for i, s in enumerate(existing)
|
||||
]
|
||||
|
||||
parent_name = parent_cached.get("name", "Scan")
|
||||
parent_filename = parent_cached.get("filename", "scan.png")
|
||||
|
||||
sub_sessions: List[Dict[str, Any]] = []
|
||||
|
||||
for page in page_splits:
|
||||
pi = page["page_index"]
|
||||
px, py = page["x"], page["y"]
|
||||
pw, ph = page["width"], page["height"]
|
||||
|
||||
# Extract RAW page region — NO individual cropping here; each
|
||||
# sub-session will run its own crop step after deskew + dewarp.
|
||||
page_bgr = full_img_bgr[py:py + ph, px:px + pw].copy()
|
||||
|
||||
# Encode as PNG
|
||||
ok, png_buf = cv2.imencode(".png", page_bgr)
|
||||
page_png = png_buf.tobytes() if ok else b""
|
||||
|
||||
sub_id = str(uuid_mod.uuid4())
|
||||
sub_name = f"{parent_name} — Seite {pi + 1}"
|
||||
|
||||
await create_session_db(
|
||||
session_id=sub_id,
|
||||
name=sub_name,
|
||||
filename=parent_filename,
|
||||
original_png=page_png,
|
||||
)
|
||||
|
||||
# start_step=2 → ready for deskew (orientation already done on spread)
|
||||
# start_step=1 → needs its own orientation (split from original image)
|
||||
await update_session_db(sub_id, current_step=start_step)
|
||||
|
||||
# Cache the BGR so the pipeline can start immediately
|
||||
_cache[sub_id] = {
|
||||
"id": sub_id,
|
||||
"filename": parent_filename,
|
||||
"name": sub_name,
|
||||
"original_bgr": page_bgr,
|
||||
"oriented_bgr": None,
|
||||
"cropped_bgr": None,
|
||||
"deskewed_bgr": None,
|
||||
"dewarped_bgr": None,
|
||||
"orientation_result": None,
|
||||
"crop_result": None,
|
||||
"deskew_result": None,
|
||||
"dewarp_result": None,
|
||||
"ground_truth": {},
|
||||
"current_step": start_step,
|
||||
}
|
||||
|
||||
rh, rw = page_bgr.shape[:2]
|
||||
sub_sessions.append({
|
||||
"id": sub_id,
|
||||
"name": sub_name,
|
||||
"page_index": pi,
|
||||
"source_rect": page,
|
||||
"image_size": {"width": rw, "height": rh},
|
||||
})
|
||||
|
||||
logger.info(
|
||||
"Page sub-session %s (full pipeline): page %d, region x=%d w=%d → %dx%d",
|
||||
sub_id, pi + 1, px, pw, rw, rh,
|
||||
)
|
||||
|
||||
return sub_sessions
|
||||
|
||||
|
||||
class ManualCropRequest(BaseModel):
|
||||
x: float # percentage 0-100
|
||||
y: float # percentage 0-100
|
||||
width: float # percentage 0-100
|
||||
height: float # percentage 0-100
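

# Worked example (illustrative numbers): for a 2480x3508 px image, a ManualCropRequest
# of x=10, y=5, width=80, height=90 converts to pixels via int(dim * pct / 100), i.e.
# x=248, y=175, width=1984, height=3157, which manual_crop below then clamps to the
# image bounds before slicing.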
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/crop/manual")
|
||||
async def manual_crop(session_id: str, req: ManualCropRequest):
|
||||
"""Manually crop using percentage coordinates."""
|
||||
cached = await _ensure_cached(session_id)
|
||||
|
||||
img_bgr = next(
|
||||
(v for k in ("dewarped_bgr", "oriented_bgr", "original_bgr")
|
||||
if (v := cached.get(k)) is not None),
|
||||
None,
|
||||
)
|
||||
if img_bgr is None:
|
||||
raise HTTPException(status_code=400, detail="No image available for cropping")
|
||||
|
||||
h, w = img_bgr.shape[:2]
|
||||
|
||||
# Convert percentages to pixels
|
||||
px_x = int(w * req.x / 100.0)
|
||||
px_y = int(h * req.y / 100.0)
|
||||
px_w = int(w * req.width / 100.0)
|
||||
px_h = int(h * req.height / 100.0)
|
||||
|
||||
# Clamp
|
||||
px_x = max(0, min(px_x, w - 1))
|
||||
px_y = max(0, min(px_y, h - 1))
|
||||
px_w = max(1, min(px_w, w - px_x))
|
||||
px_h = max(1, min(px_h, h - px_y))
|
||||
|
||||
cropped_bgr = img_bgr[px_y:px_y + px_h, px_x:px_x + px_w].copy()
|
||||
|
||||
success, png_buf = cv2.imencode(".png", cropped_bgr)
|
||||
cropped_png = png_buf.tobytes() if success else b""
|
||||
|
||||
crop_result = {
|
||||
"crop_applied": True,
|
||||
"crop_rect": {"x": px_x, "y": px_y, "width": px_w, "height": px_h},
|
||||
"crop_rect_pct": {"x": round(req.x, 2), "y": round(req.y, 2),
|
||||
"width": round(req.width, 2), "height": round(req.height, 2)},
|
||||
"original_size": {"width": w, "height": h},
|
||||
"cropped_size": {"width": px_w, "height": px_h},
|
||||
"method": "manual",
|
||||
}
|
||||
|
||||
cached["cropped_bgr"] = cropped_bgr
|
||||
cached["crop_result"] = crop_result
|
||||
|
||||
await update_session_db(
|
||||
session_id,
|
||||
cropped_png=cropped_png,
|
||||
crop_result=crop_result,
|
||||
current_step=5,
|
||||
)
|
||||
|
||||
ch, cw = cropped_bgr.shape[:2]
|
||||
return {
|
||||
"session_id": session_id,
|
||||
**crop_result,
|
||||
"image_width": cw,
|
||||
"image_height": ch,
|
||||
"cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
|
||||
}
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/crop/skip")
|
||||
async def skip_crop(session_id: str):
|
||||
"""Skip cropping — use dewarped (or oriented/original) image as-is."""
|
||||
cached = await _ensure_cached(session_id)
|
||||
|
||||
img_bgr = next(
|
||||
(v for k in ("dewarped_bgr", "oriented_bgr", "original_bgr")
|
||||
if (v := cached.get(k)) is not None),
|
||||
None,
|
||||
)
|
||||
if img_bgr is None:
|
||||
raise HTTPException(status_code=400, detail="No image available")
|
||||
|
||||
h, w = img_bgr.shape[:2]
|
||||
|
||||
    # Store the selected image (dewarped/oriented/original) unchanged as the cropped result (identity crop)
|
||||
success, png_buf = cv2.imencode(".png", img_bgr)
|
||||
cropped_png = png_buf.tobytes() if success else b""
|
||||
|
||||
crop_result = {
|
||||
"crop_applied": False,
|
||||
"skipped": True,
|
||||
"original_size": {"width": w, "height": h},
|
||||
"cropped_size": {"width": w, "height": h},
|
||||
}
|
||||
|
||||
cached["cropped_bgr"] = img_bgr
|
||||
cached["crop_result"] = crop_result
|
||||
|
||||
await update_session_db(
|
||||
session_id,
|
||||
cropped_png=cropped_png,
|
||||
crop_result=crop_result,
|
||||
current_step=5,
|
||||
)
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
**crop_result,
|
||||
"image_width": w,
|
||||
"image_height": h,
|
||||
"cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
|
||||
}
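

# Aggregate router: presumably combines the orientation and crop sub-routers
# (``_orientation_router`` / ``_crop_router``, imported from the split modules)
# into the single router that the application mounts.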
|
||||
router = APIRouter()
|
||||
router.include_router(_orientation_router)
|
||||
router.include_router(_crop_router)
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
"""
|
||||
Orientation & Crop shared helpers - cache management and pipeline logging.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ocr_pipeline_session_store import (
|
||||
get_session_db,
|
||||
get_session_image,
|
||||
update_session_db,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Reference to the shared cache from ocr_pipeline_api (set in main.py)
|
||||
_cache: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
def set_cache_ref(cache: Dict[str, Dict[str, Any]]):
|
||||
"""Set reference to the shared cache from ocr_pipeline_api."""
|
||||
global _cache
|
||||
_cache = cache
|
||||
|
||||
|
||||
def get_cache_ref() -> Dict[str, Dict[str, Any]]:
|
||||
"""Get reference to the shared cache."""
|
||||
return _cache
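

# Sketch of the expected wiring (module and attribute names in main.py are
# assumptions based on the comment above, not verified here):
#
#   from ocr_pipeline_api import _session_cache
#   import orientation_crop_helpers
#   orientation_crop_helpers.set_cache_ref(_session_cache)
#
# After this call, every module that uses get_cache_ref()/ensure_cached() shares
# the same in-memory dict of decoded images.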
|
||||
|
||||
|
||||
async def ensure_cached(session_id: str) -> Dict[str, Any]:
|
||||
"""Ensure session is in cache, loading from DB if needed."""
|
||||
if session_id in _cache:
|
||||
return _cache[session_id]
|
||||
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
cache_entry: Dict[str, Any] = {
|
||||
"id": session_id,
|
||||
**session,
|
||||
"original_bgr": None,
|
||||
"oriented_bgr": None,
|
||||
"cropped_bgr": None,
|
||||
"deskewed_bgr": None,
|
||||
"dewarped_bgr": None,
|
||||
}
|
||||
|
||||
for img_type, bgr_key in [
|
||||
("original", "original_bgr"),
|
||||
("oriented", "oriented_bgr"),
|
||||
("cropped", "cropped_bgr"),
|
||||
("deskewed", "deskewed_bgr"),
|
||||
("dewarped", "dewarped_bgr"),
|
||||
]:
|
||||
png_data = await get_session_image(session_id, img_type)
|
||||
if png_data:
|
||||
arr = np.frombuffer(png_data, dtype=np.uint8)
|
||||
bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR)
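            # Note: cv2.imdecode returns None for corrupt or unsupported data, so a
            # bad PNG simply leaves this stage's entry as None rather than raising.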
|
||||
cache_entry[bgr_key] = bgr
|
||||
|
||||
_cache[session_id] = cache_entry
|
||||
return cache_entry
|
||||
|
||||
|
||||
async def append_pipeline_log(session_id: str, step: str, metrics: dict, duration_ms: int):
|
||||
"""Append a step entry to the pipeline log."""
|
||||
from datetime import datetime
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
return
|
||||
pipeline_log = session.get("pipeline_log") or {"steps": []}
|
||||
pipeline_log["steps"].append({
|
||||
"step": step,
|
||||
"completed_at": datetime.utcnow().isoformat(),
|
||||
"success": True,
|
||||
"duration_ms": duration_ms,
|
||||
"metrics": metrics,
|
||||
})
|
||||
await update_session_db(session_id, pipeline_log=pipeline_log)
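

# Resulting pipeline_log structure (illustrative values):
#   {"steps": [
#       {"step": "crop", "completed_at": "2025-01-01T12:00:00",
#        "success": True, "duration_ms": 420,
#        "metrics": {"crop_applied": True, "detected_format": "A4"}},
#       ...
#   ]}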
|
||||
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
Sub-session creation for multi-page spreads.
|
||||
|
||||
Used by both the page-split and crop steps when a double-page scan is detected.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid as uuid_mod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from page_crop import detect_and_crop_page
|
||||
from ocr_pipeline_session_store import (
|
||||
create_session_db,
|
||||
get_sub_sessions,
|
||||
update_session_db,
|
||||
)
|
||||
from orientation_crop_helpers import get_cache_ref
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def create_page_sub_sessions(
|
||||
parent_session_id: str,
|
||||
parent_cached: dict,
|
||||
full_img_bgr: np.ndarray,
|
||||
page_splits: List[Dict[str, Any]],
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Create sub-sessions for each detected page in a multi-page spread.
|
||||
|
||||
Each page region is individually cropped, then stored as a sub-session
|
||||
with its own cropped image ready for the rest of the pipeline.
|
||||
"""
|
||||
# Check for existing sub-sessions (idempotent)
|
||||
existing = await get_sub_sessions(parent_session_id)
|
||||
if existing:
|
||||
return [
|
||||
{"id": s["id"], "name": s["name"], "page_index": s.get("box_index", i)}
|
||||
for i, s in enumerate(existing)
|
||||
]
|
||||
|
||||
parent_name = parent_cached.get("name", "Scan")
|
||||
parent_filename = parent_cached.get("filename", "scan.png")
|
||||
|
||||
sub_sessions: List[Dict[str, Any]] = []
|
||||
|
||||
for page in page_splits:
|
||||
pi = page["page_index"]
|
||||
px, py = page["x"], page["y"]
|
||||
pw, ph = page["width"], page["height"]
|
||||
|
||||
# Extract page region
|
||||
page_bgr = full_img_bgr[py:py + ph, px:px + pw].copy()
|
||||
|
||||
# Crop each page individually (remove its own borders)
|
||||
cropped_page, page_crop_info = detect_and_crop_page(page_bgr)
|
||||
|
||||
# Encode as PNG
|
||||
ok, png_buf = cv2.imencode(".png", cropped_page)
|
||||
page_png = png_buf.tobytes() if ok else b""
|
||||
|
||||
sub_id = str(uuid_mod.uuid4())
|
||||
sub_name = f"{parent_name} — Seite {pi + 1}"
|
||||
|
||||
await create_session_db(
|
||||
session_id=sub_id,
|
||||
name=sub_name,
|
||||
filename=parent_filename,
|
||||
original_png=page_png,
|
||||
)
|
||||
|
||||
# Pre-populate: set cropped = original (already cropped)
|
||||
await update_session_db(
|
||||
sub_id,
|
||||
cropped_png=page_png,
|
||||
crop_result=page_crop_info,
|
||||
current_step=5,
|
||||
)
|
||||
|
||||
ch, cw = cropped_page.shape[:2]
|
||||
sub_sessions.append({
|
||||
"id": sub_id,
|
||||
"name": sub_name,
|
||||
"page_index": pi,
|
||||
"source_rect": page,
|
||||
"cropped_size": {"width": cw, "height": ch},
|
||||
"detected_format": page_crop_info.get("detected_format"),
|
||||
})
|
||||
|
||||
logger.info(
|
||||
"Page sub-session %s: page %d, region x=%d w=%d -> cropped %dx%d",
|
||||
sub_id, pi + 1, px, pw, cw, ch,
|
||||
)
|
||||
|
||||
return sub_sessions
|
||||
|
||||
|
||||
async def create_page_sub_sessions_full(
|
||||
parent_session_id: str,
|
||||
parent_cached: dict,
|
||||
full_img_bgr: np.ndarray,
|
||||
page_splits: List[Dict[str, Any]],
|
||||
start_step: int = 2,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Create sub-sessions for each page with RAW regions for full pipeline processing.
|
||||
|
||||
Unlike ``create_page_sub_sessions`` (used by the crop step), these
|
||||
sub-sessions store the *uncropped* page region and start at
|
||||
``start_step`` (default 2 = ready for deskew; 1 if orientation still
|
||||
needed). Each page goes through its own pipeline independently,
|
||||
which is essential for book spreads where each page has a different tilt.
|
||||
"""
|
||||
_cache = get_cache_ref()
|
||||
|
||||
# Idempotent: reuse existing sub-sessions
|
||||
existing = await get_sub_sessions(parent_session_id)
|
||||
if existing:
|
||||
return [
|
||||
{"id": s["id"], "name": s["name"], "page_index": s.get("box_index", i)}
|
||||
for i, s in enumerate(existing)
|
||||
]
|
||||
|
||||
parent_name = parent_cached.get("name", "Scan")
|
||||
parent_filename = parent_cached.get("filename", "scan.png")
|
||||
|
||||
sub_sessions: List[Dict[str, Any]] = []
|
||||
|
||||
for page in page_splits:
|
||||
pi = page["page_index"]
|
||||
px, py = page["x"], page["y"]
|
||||
pw, ph = page["width"], page["height"]
|
||||
|
||||
# Extract RAW page region — NO individual cropping here; each
|
||||
# sub-session will run its own crop step after deskew + dewarp.
|
||||
page_bgr = full_img_bgr[py:py + ph, px:px + pw].copy()
|
||||
|
||||
# Encode as PNG
|
||||
ok, png_buf = cv2.imencode(".png", page_bgr)
|
||||
page_png = png_buf.tobytes() if ok else b""
|
||||
|
||||
sub_id = str(uuid_mod.uuid4())
|
||||
sub_name = f"{parent_name} — Seite {pi + 1}"
|
||||
|
||||
await create_session_db(
|
||||
session_id=sub_id,
|
||||
name=sub_name,
|
||||
filename=parent_filename,
|
||||
original_png=page_png,
|
||||
)
|
||||
|
||||
# start_step=2 -> ready for deskew (orientation already done on spread)
|
||||
# start_step=1 -> needs its own orientation (split from original image)
|
||||
await update_session_db(sub_id, current_step=start_step)
|
||||
|
||||
# Cache the BGR so the pipeline can start immediately
|
||||
_cache[sub_id] = {
|
||||
"id": sub_id,
|
||||
"filename": parent_filename,
|
||||
"name": sub_name,
|
||||
"original_bgr": page_bgr,
|
||||
"oriented_bgr": None,
|
||||
"cropped_bgr": None,
|
||||
"deskewed_bgr": None,
|
||||
"dewarped_bgr": None,
|
||||
"orientation_result": None,
|
||||
"crop_result": None,
|
||||
"deskew_result": None,
|
||||
"dewarp_result": None,
|
||||
"ground_truth": {},
|
||||
"current_step": start_step,
|
||||
}
|
||||
|
||||
rh, rw = page_bgr.shape[:2]
|
||||
sub_sessions.append({
|
||||
"id": sub_id,
|
||||
"name": sub_name,
|
||||
"page_index": pi,
|
||||
"source_rect": page,
|
||||
"image_size": {"width": rw, "height": rh},
|
||||
})
|
||||
|
||||
logger.info(
|
||||
"Page sub-session %s (full pipeline): page %d, region x=%d w=%d -> %dx%d",
|
||||
sub_id, pi + 1, px, pw, rw, rh,
|
||||
)
|
||||
|
||||
return sub_sessions
|
||||
@@ -1,677 +1,17 @@
|
||||
"""
|
||||
PDF Export Module for Abiturkorrektur System
|
||||
|
||||
Generates:
|
||||
- Individual Gutachten PDFs for each student
|
||||
- Klausur overview PDFs with grade distribution
|
||||
- Niedersachsen-compliant formatting
|
||||

Barrel re-export: all PDF generation functions and constants.
|
||||
"""
|
||||
|
||||
import io
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
from reportlab.lib import colors
|
||||
from reportlab.lib.pagesizes import A4
|
||||
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
|
||||
from reportlab.lib.units import cm, mm
|
||||
from reportlab.lib.enums import TA_LEFT, TA_CENTER, TA_JUSTIFY, TA_RIGHT
|
||||
from reportlab.platypus import (
|
||||
SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle,
|
||||
    PageBreak, HRFlowable, Image, KeepTogether
)
|
||||
from pdf_export_styles import ( # noqa: F401
|
||||
GRADE_POINTS_TO_NOTE,
|
||||
CRITERIA_DISPLAY_NAMES,
|
||||
CRITERIA_WEIGHTS,
|
||||
get_custom_styles,
|
||||
)
|
||||
from pdf_export_gutachten import generate_gutachten_pdf # noqa: F401
|
||||
from pdf_export_overview import ( # noqa: F401
|
||||
generate_klausur_overview_pdf,
|
||||
generate_annotations_pdf,
|
||||
)
|
||||
from reportlab.pdfbase import pdfmetrics
|
||||
from reportlab.pdfbase.ttfonts import TTFont
|
||||
|
||||
|
||||
# =============================================
|
||||
# CONSTANTS
|
||||
# =============================================
|
||||
|
||||
GRADE_POINTS_TO_NOTE = {
|
||||
15: "1+", 14: "1", 13: "1-",
|
||||
12: "2+", 11: "2", 10: "2-",
|
||||
9: "3+", 8: "3", 7: "3-",
|
||||
6: "4+", 5: "4", 4: "4-",
|
||||
3: "5+", 2: "5", 1: "5-",
|
||||
0: "6"
|
||||
}
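
# Example lookups: GRADE_POINTS_TO_NOTE[13] -> "1-", GRADE_POINTS_TO_NOTE[5] -> "4";
# callers fall back to "?" / "-" via .get() when the value is out of range.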
|
||||
|
||||
CRITERIA_DISPLAY_NAMES = {
|
||||
"rechtschreibung": "Sprachliche Richtigkeit (Rechtschreibung)",
|
||||
"grammatik": "Sprachliche Richtigkeit (Grammatik)",
|
||||
"inhalt": "Inhaltliche Leistung",
|
||||
"struktur": "Aufbau und Struktur",
|
||||
"stil": "Ausdruck und Stil"
|
||||
}
|
||||
|
||||
CRITERIA_WEIGHTS = {
|
||||
"rechtschreibung": 15,
|
||||
"grammatik": 15,
|
||||
"inhalt": 40,
|
||||
"struktur": 15,
|
||||
"stil": 15
|
||||
}
|
||||
|
||||
|
||||
# =============================================
|
||||
# STYLES
|
||||
# =============================================
|
||||
|
||||
def get_custom_styles():
|
||||
"""Create custom paragraph styles for Gutachten."""
|
||||
styles = getSampleStyleSheet()
|
||||
|
||||
# Title style
|
||||
styles.add(ParagraphStyle(
|
||||
name='GutachtenTitle',
|
||||
parent=styles['Heading1'],
|
||||
fontSize=16,
|
||||
spaceAfter=12,
|
||||
alignment=TA_CENTER,
|
||||
textColor=colors.HexColor('#1e3a5f')
|
||||
))
|
||||
|
||||
# Subtitle style
|
||||
styles.add(ParagraphStyle(
|
||||
name='GutachtenSubtitle',
|
||||
parent=styles['Heading2'],
|
||||
fontSize=12,
|
||||
spaceAfter=8,
|
||||
spaceBefore=16,
|
||||
textColor=colors.HexColor('#2c5282')
|
||||
))
|
||||
|
||||
# Section header
|
||||
styles.add(ParagraphStyle(
|
||||
name='SectionHeader',
|
||||
parent=styles['Heading3'],
|
||||
fontSize=11,
|
||||
spaceAfter=6,
|
||||
spaceBefore=12,
|
||||
textColor=colors.HexColor('#2d3748'),
|
||||
borderColor=colors.HexColor('#e2e8f0'),
|
||||
borderWidth=0,
|
||||
borderPadding=0
|
||||
))
|
||||
|
||||
# Body text
|
||||
styles.add(ParagraphStyle(
|
||||
name='GutachtenBody',
|
||||
parent=styles['Normal'],
|
||||
fontSize=10,
|
||||
leading=14,
|
||||
alignment=TA_JUSTIFY,
|
||||
spaceAfter=6
|
||||
))
|
||||
|
||||
# Small text for footer/meta
|
||||
styles.add(ParagraphStyle(
|
||||
name='MetaText',
|
||||
parent=styles['Normal'],
|
||||
fontSize=8,
|
||||
textColor=colors.grey,
|
||||
alignment=TA_LEFT
|
||||
))
|
||||
|
||||
# List item
|
||||
styles.add(ParagraphStyle(
|
||||
name='ListItem',
|
||||
parent=styles['Normal'],
|
||||
fontSize=10,
|
||||
leftIndent=20,
|
||||
bulletIndent=10,
|
||||
spaceAfter=4
|
||||
))
|
||||
|
||||
return styles
|
||||
|
||||
|
||||
# =============================================
|
||||
# PDF GENERATION FUNCTIONS
|
||||
# =============================================
|
||||
|
||||
def generate_gutachten_pdf(
|
||||
student_data: Dict[str, Any],
|
||||
klausur_data: Dict[str, Any],
|
||||
annotations: List[Dict[str, Any]] = None,
|
||||
workflow_data: Dict[str, Any] = None
|
||||
) -> bytes:
|
||||
"""
|
||||
Generate a PDF Gutachten for a single student.
|
||||
|
||||
Args:
|
||||
student_data: Student work data including criteria_scores, gutachten, grade_points
|
||||
klausur_data: Klausur metadata (title, subject, year, etc.)
|
||||
annotations: List of annotations for annotation summary
|
||||
workflow_data: Examiner workflow data (EK, ZK, DK info)
|
||||
|
||||
Returns:
|
||||
PDF as bytes
|
||||
"""
|
||||
buffer = io.BytesIO()
|
||||
doc = SimpleDocTemplate(
|
||||
buffer,
|
||||
pagesize=A4,
|
||||
rightMargin=2*cm,
|
||||
leftMargin=2*cm,
|
||||
topMargin=2*cm,
|
||||
bottomMargin=2*cm
|
||||
)
|
||||
|
||||
styles = get_custom_styles()
|
||||
story = []
|
||||
|
||||
# Header
|
||||
story.append(Paragraph("Gutachten zur Abiturklausur", styles['GutachtenTitle']))
|
||||
story.append(Paragraph(f"{klausur_data.get('subject', 'Deutsch')} - {klausur_data.get('title', '')}", styles['GutachtenSubtitle']))
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
|
||||
# Meta information table
|
||||
meta_data = [
|
||||
["Pruefling:", student_data.get('student_name', 'Anonym')],
|
||||
["Schuljahr:", f"{klausur_data.get('year', 2025)}"],
|
||||
["Kurs:", klausur_data.get('semester', 'Abitur')],
|
||||
["Datum:", datetime.now().strftime("%d.%m.%Y")]
|
||||
]
|
||||
|
||||
meta_table = Table(meta_data, colWidths=[4*cm, 10*cm])
|
||||
meta_table.setStyle(TableStyle([
|
||||
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
|
||||
('FONTSIZE', (0, 0), (-1, -1), 10),
|
||||
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
|
||||
('TOPPADDING', (0, 0), (-1, -1), 4),
|
||||
]))
|
||||
story.append(meta_table)
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
|
||||
# Gutachten content
|
||||
gutachten = student_data.get('gutachten', {})
|
||||
|
||||
if gutachten:
|
||||
# Einleitung
|
||||
if gutachten.get('einleitung'):
|
||||
story.append(Paragraph("Einleitung", styles['SectionHeader']))
|
||||
story.append(Paragraph(gutachten['einleitung'], styles['GutachtenBody']))
|
||||
story.append(Spacer(1, 0.3*cm))
|
||||
|
||||
# Hauptteil
|
||||
if gutachten.get('hauptteil'):
|
||||
story.append(Paragraph("Hauptteil", styles['SectionHeader']))
|
||||
story.append(Paragraph(gutachten['hauptteil'], styles['GutachtenBody']))
|
||||
story.append(Spacer(1, 0.3*cm))
|
||||
|
||||
# Fazit
|
||||
if gutachten.get('fazit'):
|
||||
story.append(Paragraph("Fazit", styles['SectionHeader']))
|
||||
story.append(Paragraph(gutachten['fazit'], styles['GutachtenBody']))
|
||||
story.append(Spacer(1, 0.3*cm))
|
||||
|
||||
# Staerken und Schwaechen
|
||||
if gutachten.get('staerken') or gutachten.get('schwaechen'):
|
||||
story.append(Spacer(1, 0.3*cm))
|
||||
|
||||
if gutachten.get('staerken'):
|
||||
story.append(Paragraph("Staerken:", styles['SectionHeader']))
|
||||
for s in gutachten['staerken']:
|
||||
story.append(Paragraph(f"• {s}", styles['ListItem']))
|
||||
|
||||
if gutachten.get('schwaechen'):
|
||||
story.append(Paragraph("Verbesserungspotenzial:", styles['SectionHeader']))
|
||||
for s in gutachten['schwaechen']:
|
||||
story.append(Paragraph(f"• {s}", styles['ListItem']))
|
||||
else:
|
||||
story.append(Paragraph("<i>Kein Gutachten-Text vorhanden.</i>", styles['GutachtenBody']))
|
||||
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
|
||||
# Bewertungstabelle
|
||||
story.append(Paragraph("Bewertung nach Kriterien", styles['SectionHeader']))
|
||||
story.append(Spacer(1, 0.2*cm))
|
||||
|
||||
criteria_scores = student_data.get('criteria_scores', {})
|
||||
|
||||
# Build criteria table data
|
||||
table_data = [["Kriterium", "Gewichtung", "Erreicht", "Punkte"]]
|
||||
total_weighted = 0
|
||||
total_weight = 0
|
||||
|
||||
for key, display_name in CRITERIA_DISPLAY_NAMES.items():
|
||||
weight = CRITERIA_WEIGHTS.get(key, 0)
|
||||
score_data = criteria_scores.get(key, {})
|
||||
score = score_data.get('score', 0) if isinstance(score_data, dict) else score_data
|
||||
|
||||
# Calculate weighted contribution
|
||||
weighted_score = (score / 100) * weight if score else 0
|
||||
total_weighted += weighted_score
|
||||
total_weight += weight
|
||||
|
||||
table_data.append([
|
||||
display_name,
|
||||
f"{weight}%",
|
||||
f"{score}%",
|
||||
f"{weighted_score:.1f}"
|
||||
])
|
||||
|
||||
# Add total row
|
||||
table_data.append([
|
||||
"Gesamt",
|
||||
f"{total_weight}%",
|
||||
"",
|
||||
f"{total_weighted:.1f}"
|
||||
])
|
||||
|
||||
criteria_table = Table(table_data, colWidths=[8*cm, 2.5*cm, 2.5*cm, 2.5*cm])
|
||||
criteria_table.setStyle(TableStyle([
|
||||
# Header row
|
||||
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#2c5282')),
|
||||
('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
|
||||
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
|
||||
('FONTSIZE', (0, 0), (-1, 0), 10),
|
||||
('ALIGN', (1, 0), (-1, -1), 'CENTER'),
|
||||
# Body rows
|
||||
('FONTSIZE', (0, 1), (-1, -1), 9),
|
||||
('BOTTOMPADDING', (0, 0), (-1, -1), 6),
|
||||
('TOPPADDING', (0, 0), (-1, -1), 6),
|
||||
# Grid
|
||||
('GRID', (0, 0), (-1, -1), 0.5, colors.HexColor('#e2e8f0')),
|
||||
# Total row
|
||||
('BACKGROUND', (0, -1), (-1, -1), colors.HexColor('#f7fafc')),
|
||||
('FONTNAME', (0, -1), (-1, -1), 'Helvetica-Bold'),
|
||||
# Alternating row colors
|
||||
('ROWBACKGROUNDS', (0, 1), (-1, -2), [colors.white, colors.HexColor('#f7fafc')]),
|
||||
]))
|
||||
story.append(criteria_table)
|
||||
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
|
||||
# Final grade box
|
||||
grade_points = student_data.get('grade_points', 0)
|
||||
grade_note = GRADE_POINTS_TO_NOTE.get(grade_points, "?")
|
||||
raw_points = student_data.get('raw_points', 0)
|
||||
|
||||
grade_data = [
|
||||
["Rohpunkte:", f"{raw_points} / 100"],
|
||||
["Notenpunkte:", f"{grade_points} Punkte"],
|
||||
["Note:", grade_note]
|
||||
]
|
||||
|
||||
grade_table = Table(grade_data, colWidths=[4*cm, 4*cm])
|
||||
grade_table.setStyle(TableStyle([
|
||||
('BACKGROUND', (0, 0), (-1, -1), colors.HexColor('#ebf8ff')),
|
||||
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
|
||||
('FONTNAME', (1, -1), (1, -1), 'Helvetica-Bold'),
|
||||
('FONTSIZE', (0, 0), (-1, -1), 11),
|
||||
('FONTSIZE', (1, -1), (1, -1), 14),
|
||||
('TEXTCOLOR', (1, -1), (1, -1), colors.HexColor('#2c5282')),
|
||||
('BOTTOMPADDING', (0, 0), (-1, -1), 8),
|
||||
('TOPPADDING', (0, 0), (-1, -1), 8),
|
||||
('LEFTPADDING', (0, 0), (-1, -1), 12),
|
||||
('BOX', (0, 0), (-1, -1), 1, colors.HexColor('#2c5282')),
|
||||
('ALIGN', (1, 0), (1, -1), 'RIGHT'),
|
||||
]))
|
||||
|
||||
story.append(KeepTogether([
|
||||
Paragraph("Endergebnis", styles['SectionHeader']),
|
||||
Spacer(1, 0.2*cm),
|
||||
grade_table
|
||||
]))
|
||||
|
||||
# Examiner workflow information
|
||||
if workflow_data:
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
|
||||
story.append(Spacer(1, 0.3*cm))
|
||||
story.append(Paragraph("Korrekturverlauf", styles['SectionHeader']))
|
||||
|
||||
workflow_rows = []
|
||||
|
||||
if workflow_data.get('erst_korrektor'):
|
||||
ek = workflow_data['erst_korrektor']
|
||||
workflow_rows.append([
|
||||
"Erstkorrektor:",
|
||||
ek.get('name', 'Unbekannt'),
|
||||
f"{ek.get('grade_points', '-')} Punkte"
|
||||
])
|
||||
|
||||
if workflow_data.get('zweit_korrektor'):
|
||||
zk = workflow_data['zweit_korrektor']
|
||||
workflow_rows.append([
|
||||
"Zweitkorrektor:",
|
||||
zk.get('name', 'Unbekannt'),
|
||||
f"{zk.get('grade_points', '-')} Punkte"
|
||||
])
|
||||
|
||||
if workflow_data.get('dritt_korrektor'):
|
||||
dk = workflow_data['dritt_korrektor']
|
||||
workflow_rows.append([
|
||||
"Drittkorrektor:",
|
||||
dk.get('name', 'Unbekannt'),
|
||||
f"{dk.get('grade_points', '-')} Punkte"
|
||||
])
|
||||
|
||||
if workflow_data.get('final_grade_source'):
|
||||
workflow_rows.append([
|
||||
"Endnote durch:",
|
||||
workflow_data['final_grade_source'],
|
||||
""
|
||||
])
|
||||
|
||||
if workflow_rows:
|
||||
workflow_table = Table(workflow_rows, colWidths=[4*cm, 6*cm, 4*cm])
|
||||
workflow_table.setStyle(TableStyle([
|
||||
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
|
||||
('FONTSIZE', (0, 0), (-1, -1), 9),
|
||||
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
|
||||
('TOPPADDING', (0, 0), (-1, -1), 4),
|
||||
]))
|
||||
story.append(workflow_table)
|
||||
|
||||
# Annotation summary (if any)
|
||||
if annotations:
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
|
||||
story.append(Spacer(1, 0.3*cm))
|
||||
story.append(Paragraph("Anmerkungen (Zusammenfassung)", styles['SectionHeader']))
|
||||
|
||||
# Group annotations by type
|
||||
by_type = {}
|
||||
for ann in annotations:
|
||||
ann_type = ann.get('type', 'comment')
|
||||
if ann_type not in by_type:
|
||||
by_type[ann_type] = []
|
||||
by_type[ann_type].append(ann)
|
||||
|
||||
for ann_type, anns in by_type.items():
|
||||
type_name = CRITERIA_DISPLAY_NAMES.get(ann_type, ann_type.replace('_', ' ').title())
|
||||
story.append(Paragraph(f"{type_name} ({len(anns)} Anmerkungen)", styles['ListItem']))
|
||||
|
||||
# Footer with generation info
|
||||
story.append(Spacer(1, 1*cm))
|
||||
story.append(HRFlowable(width="100%", thickness=0.5, color=colors.HexColor('#cbd5e0')))
|
||||
story.append(Spacer(1, 0.2*cm))
|
||||
story.append(Paragraph(
|
||||
f"Erstellt am {datetime.now().strftime('%d.%m.%Y um %H:%M Uhr')} | BreakPilot Abiturkorrektur-System",
|
||||
styles['MetaText']
|
||||
))
|
||||
|
||||
# Build PDF
|
||||
doc.build(story)
|
||||
buffer.seek(0)
|
||||
return buffer.getvalue()
|
||||
|
||||
|
||||
def generate_klausur_overview_pdf(
|
||||
klausur_data: Dict[str, Any],
|
||||
students: List[Dict[str, Any]],
|
||||
fairness_data: Optional[Dict[str, Any]] = None
|
||||
) -> bytes:
|
||||
"""
|
||||
Generate an overview PDF for an entire Klausur with all student grades.
|
||||
|
||||
Args:
|
||||
klausur_data: Klausur metadata
|
||||
students: List of all student work data
|
||||
fairness_data: Optional fairness analysis data
|
||||
|
||||
Returns:
|
||||
PDF as bytes
|
||||
"""
|
||||
buffer = io.BytesIO()
|
||||
doc = SimpleDocTemplate(
|
||||
buffer,
|
||||
pagesize=A4,
|
||||
rightMargin=1.5*cm,
|
||||
leftMargin=1.5*cm,
|
||||
topMargin=2*cm,
|
||||
bottomMargin=2*cm
|
||||
)
|
||||
|
||||
styles = get_custom_styles()
|
||||
story = []
|
||||
|
||||
# Header
|
||||
story.append(Paragraph("Notenuebersicht", styles['GutachtenTitle']))
|
||||
story.append(Paragraph(f"{klausur_data.get('subject', 'Deutsch')} - {klausur_data.get('title', '')}", styles['GutachtenSubtitle']))
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
|
||||
# Meta information
|
||||
meta_data = [
|
||||
["Schuljahr:", f"{klausur_data.get('year', 2025)}"],
|
||||
["Kurs:", klausur_data.get('semester', 'Abitur')],
|
||||
["Anzahl Arbeiten:", str(len(students))],
|
||||
["Stand:", datetime.now().strftime("%d.%m.%Y")]
|
||||
]
|
||||
|
||||
meta_table = Table(meta_data, colWidths=[4*cm, 10*cm])
|
||||
meta_table.setStyle(TableStyle([
|
||||
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
|
||||
('FONTSIZE', (0, 0), (-1, -1), 10),
|
||||
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
|
||||
('TOPPADDING', (0, 0), (-1, -1), 4),
|
||||
]))
|
||||
story.append(meta_table)
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
|
||||
# Statistics (if fairness data available)
|
||||
if fairness_data and fairness_data.get('statistics'):
|
||||
stats = fairness_data['statistics']
|
||||
story.append(Paragraph("Statistik", styles['SectionHeader']))
|
||||
|
||||
stats_data = [
|
||||
["Durchschnitt:", f"{stats.get('average_grade', 0):.1f} Punkte"],
|
||||
["Minimum:", f"{stats.get('min_grade', 0)} Punkte"],
|
||||
["Maximum:", f"{stats.get('max_grade', 0)} Punkte"],
|
||||
["Standardabweichung:", f"{stats.get('standard_deviation', 0):.2f}"],
|
||||
]
|
||||
|
||||
stats_table = Table(stats_data, colWidths=[4*cm, 4*cm])
|
||||
stats_table.setStyle(TableStyle([
|
||||
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
|
||||
('FONTSIZE', (0, 0), (-1, -1), 9),
|
||||
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
|
||||
('BACKGROUND', (0, 0), (-1, -1), colors.HexColor('#f7fafc')),
|
||||
('BOX', (0, 0), (-1, -1), 0.5, colors.HexColor('#e2e8f0')),
|
||||
]))
|
||||
story.append(stats_table)
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
|
||||
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
|
||||
# Student grades table
|
||||
story.append(Paragraph("Einzelergebnisse", styles['SectionHeader']))
|
||||
story.append(Spacer(1, 0.2*cm))
|
||||
|
||||
# Sort students by grade (descending)
|
||||
sorted_students = sorted(students, key=lambda s: s.get('grade_points', 0), reverse=True)
|
||||
|
||||
# Build table header
|
||||
table_data = [["#", "Name", "Rohpunkte", "Notenpunkte", "Note", "Status"]]
|
||||
|
||||
for idx, student in enumerate(sorted_students, 1):
|
||||
grade_points = student.get('grade_points', 0)
|
||||
grade_note = GRADE_POINTS_TO_NOTE.get(grade_points, "-")
|
||||
raw_points = student.get('raw_points', 0)
|
||||
status = student.get('status', 'unknown')
|
||||
|
||||
# Format status
|
||||
status_display = {
|
||||
'completed': 'Abgeschlossen',
|
||||
'first_examiner': 'In Korrektur',
|
||||
'second_examiner': 'Zweitkorrektur',
|
||||
'uploaded': 'Hochgeladen',
|
||||
'ocr_complete': 'OCR fertig',
|
||||
'analyzing': 'Wird analysiert'
|
||||
}.get(status, status)
|
||||
|
||||
table_data.append([
|
||||
str(idx),
|
||||
student.get('student_name', 'Anonym'),
|
||||
f"{raw_points}/100",
|
||||
str(grade_points),
|
||||
grade_note,
|
||||
status_display
|
||||
])
|
||||
|
||||
# Create table
|
||||
student_table = Table(table_data, colWidths=[1*cm, 5*cm, 2.5*cm, 3*cm, 2*cm, 3*cm])
|
||||
student_table.setStyle(TableStyle([
|
||||
# Header
|
||||
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#2c5282')),
|
||||
('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
|
||||
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
|
||||
('FONTSIZE', (0, 0), (-1, 0), 9),
|
||||
('ALIGN', (0, 0), (-1, 0), 'CENTER'),
|
||||
# Body
|
||||
('FONTSIZE', (0, 1), (-1, -1), 9),
|
||||
('ALIGN', (0, 1), (0, -1), 'CENTER'),
|
||||
('ALIGN', (2, 1), (4, -1), 'CENTER'),
|
||||
('BOTTOMPADDING', (0, 0), (-1, -1), 6),
|
||||
('TOPPADDING', (0, 0), (-1, -1), 6),
|
||||
# Grid
|
||||
('GRID', (0, 0), (-1, -1), 0.5, colors.HexColor('#e2e8f0')),
|
||||
# Alternating rows
|
||||
('ROWBACKGROUNDS', (0, 1), (-1, -1), [colors.white, colors.HexColor('#f7fafc')]),
|
||||
]))
|
||||
story.append(student_table)
|
||||
|
||||
# Grade distribution
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
story.append(Paragraph("Notenverteilung", styles['SectionHeader']))
|
||||
story.append(Spacer(1, 0.2*cm))
|
||||
|
||||
# Count grades
|
||||
grade_counts = {}
|
||||
for student in sorted_students:
|
||||
gp = student.get('grade_points', 0)
|
||||
grade_counts[gp] = grade_counts.get(gp, 0) + 1
|
||||
|
||||
# Build grade distribution table
|
||||
dist_data = [["Punkte", "Note", "Anzahl"]]
|
||||
for points in range(15, -1, -1):
|
||||
if points in grade_counts:
|
||||
note = GRADE_POINTS_TO_NOTE.get(points, "-")
|
||||
count = grade_counts[points]
|
||||
dist_data.append([str(points), note, str(count)])
|
||||
|
||||
if len(dist_data) > 1:
|
||||
dist_table = Table(dist_data, colWidths=[2.5*cm, 2.5*cm, 2.5*cm])
|
||||
dist_table.setStyle(TableStyle([
|
||||
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#2c5282')),
|
||||
('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
|
||||
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
|
||||
('FONTSIZE', (0, 0), (-1, -1), 9),
|
||||
('ALIGN', (0, 0), (-1, -1), 'CENTER'),
|
||||
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
|
||||
('TOPPADDING', (0, 0), (-1, -1), 4),
|
||||
('GRID', (0, 0), (-1, -1), 0.5, colors.HexColor('#e2e8f0')),
|
||||
]))
|
||||
story.append(dist_table)
|
||||
|
||||
# Footer
|
||||
story.append(Spacer(1, 1*cm))
|
||||
story.append(HRFlowable(width="100%", thickness=0.5, color=colors.HexColor('#cbd5e0')))
|
||||
story.append(Spacer(1, 0.2*cm))
|
||||
story.append(Paragraph(
|
||||
f"Erstellt am {datetime.now().strftime('%d.%m.%Y um %H:%M Uhr')} | BreakPilot Abiturkorrektur-System",
|
||||
styles['MetaText']
|
||||
))
|
||||
|
||||
# Build PDF
|
||||
doc.build(story)
|
||||
buffer.seek(0)
|
||||
return buffer.getvalue()
|
||||
|
||||
|
||||
def generate_annotations_pdf(
|
||||
student_data: Dict[str, Any],
|
||||
klausur_data: Dict[str, Any],
|
||||
annotations: List[Dict[str, Any]]
|
||||
) -> bytes:
|
||||
"""
|
||||
Generate a PDF with all annotations for a student work.
|
||||
|
||||
Args:
|
||||
student_data: Student work data
|
||||
klausur_data: Klausur metadata
|
||||
annotations: List of all annotations
|
||||
|
||||
Returns:
|
||||
PDF as bytes
|
||||
"""
|
||||
buffer = io.BytesIO()
|
||||
doc = SimpleDocTemplate(
|
||||
buffer,
|
||||
pagesize=A4,
|
||||
rightMargin=2*cm,
|
||||
leftMargin=2*cm,
|
||||
topMargin=2*cm,
|
||||
bottomMargin=2*cm
|
||||
)
|
||||
|
||||
styles = get_custom_styles()
|
||||
story = []
|
||||
|
||||
# Header
|
||||
story.append(Paragraph("Anmerkungen zur Klausur", styles['GutachtenTitle']))
|
||||
story.append(Paragraph(f"{student_data.get('student_name', 'Anonym')}", styles['GutachtenSubtitle']))
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
|
||||
if not annotations:
|
||||
story.append(Paragraph("<i>Keine Anmerkungen vorhanden.</i>", styles['GutachtenBody']))
|
||||
else:
|
||||
# Group by type
|
||||
by_type = {}
|
||||
for ann in annotations:
|
||||
ann_type = ann.get('type', 'comment')
|
||||
if ann_type not in by_type:
|
||||
by_type[ann_type] = []
|
||||
by_type[ann_type].append(ann)
|
||||
|
||||
for ann_type, anns in by_type.items():
|
||||
type_name = CRITERIA_DISPLAY_NAMES.get(ann_type, ann_type.replace('_', ' ').title())
|
||||
story.append(Paragraph(f"{type_name} ({len(anns)})", styles['SectionHeader']))
|
||||
story.append(Spacer(1, 0.2*cm))
|
||||
|
||||
# Sort by page then position
|
||||
sorted_anns = sorted(anns, key=lambda a: (a.get('page', 0), a.get('position', {}).get('y', 0)))
|
||||
|
||||
for idx, ann in enumerate(sorted_anns, 1):
|
||||
page = ann.get('page', 1)
|
||||
text = ann.get('text', '')
|
||||
suggestion = ann.get('suggestion', '')
|
||||
severity = ann.get('severity', 'minor')
|
||||
|
||||
# Build annotation text
|
||||
ann_text = f"<b>[S.{page}]</b> {text}"
|
||||
if suggestion:
|
||||
ann_text += f" → <i>{suggestion}</i>"
|
||||
|
||||
# Color code by severity
|
||||
if severity == 'critical':
|
||||
ann_text = f"<font color='red'>{ann_text}</font>"
|
||||
elif severity == 'major':
|
||||
ann_text = f"<font color='orange'>{ann_text}</font>"
|
||||
|
||||
story.append(Paragraph(f"{idx}. {ann_text}", styles['ListItem']))
|
||||
|
||||
story.append(Spacer(1, 0.3*cm))
|
||||
|
||||
# Footer
|
||||
story.append(Spacer(1, 1*cm))
|
||||
story.append(HRFlowable(width="100%", thickness=0.5, color=colors.HexColor('#cbd5e0')))
|
||||
story.append(Spacer(1, 0.2*cm))
|
||||
story.append(Paragraph(
|
||||
f"Erstellt am {datetime.now().strftime('%d.%m.%Y um %H:%M Uhr')} | BreakPilot Abiturkorrektur-System",
|
||||
styles['MetaText']
|
||||
))
|
||||
|
||||
# Build PDF
|
||||
doc.build(story)
|
||||
buffer.seek(0)
|
||||
return buffer.getvalue()
|
||||
|
||||
@@ -0,0 +1,315 @@
|
||||
"""
|
||||
PDF Export - Individual Gutachten PDF generation.
|
||||
|
||||
Generates a single student's Gutachten with criteria table,
|
||||
workflow info, and annotation summary.
|
||||
"""
|
||||
|
||||
import io
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
from reportlab.lib import colors
|
||||
from reportlab.lib.pagesizes import A4
|
||||
from reportlab.lib.units import cm
|
||||
from reportlab.platypus import (
|
||||
SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle,
|
||||
HRFlowable, KeepTogether
|
||||
)
|
||||
|
||||
from pdf_export_styles import (
|
||||
GRADE_POINTS_TO_NOTE,
|
||||
CRITERIA_DISPLAY_NAMES,
|
||||
CRITERIA_WEIGHTS,
|
||||
get_custom_styles,
|
||||
)
|
||||
|
||||
|
||||
def generate_gutachten_pdf(
|
||||
student_data: Dict[str, Any],
|
||||
klausur_data: Dict[str, Any],
|
||||
    annotations: Optional[List[Dict[str, Any]]] = None,
|
||||
    workflow_data: Optional[Dict[str, Any]] = None
|
||||
) -> bytes:
|
||||
"""
|
||||
Generate a PDF Gutachten for a single student.
|
||||
|
||||
Args:
|
||||
student_data: Student work data including criteria_scores, gutachten, grade_points
|
||||
klausur_data: Klausur metadata (title, subject, year, etc.)
|
||||
annotations: List of annotations for annotation summary
|
||||
workflow_data: Examiner workflow data (EK, ZK, DK info)
|
||||
|
||||
Returns:
|
||||
PDF as bytes
|
||||
"""
|
||||
buffer = io.BytesIO()
|
||||
doc = SimpleDocTemplate(
|
||||
buffer,
|
||||
pagesize=A4,
|
||||
rightMargin=2*cm,
|
||||
leftMargin=2*cm,
|
||||
topMargin=2*cm,
|
||||
bottomMargin=2*cm
|
||||
)
|
||||
|
||||
styles = get_custom_styles()
|
||||
story = []
|
||||
|
||||
# Header
|
||||
story.append(Paragraph("Gutachten zur Abiturklausur", styles['GutachtenTitle']))
|
||||
story.append(Paragraph(f"{klausur_data.get('subject', 'Deutsch')} - {klausur_data.get('title', '')}", styles['GutachtenSubtitle']))
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
|
||||
# Meta information table
|
||||
meta_data = [
|
||||
["Pruefling:", student_data.get('student_name', 'Anonym')],
|
||||
["Schuljahr:", f"{klausur_data.get('year', 2025)}"],
|
||||
["Kurs:", klausur_data.get('semester', 'Abitur')],
|
||||
["Datum:", datetime.now().strftime("%d.%m.%Y")]
|
||||
]
|
||||
|
||||
meta_table = Table(meta_data, colWidths=[4*cm, 10*cm])
|
||||
meta_table.setStyle(TableStyle([
|
||||
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
|
||||
('FONTSIZE', (0, 0), (-1, -1), 10),
|
||||
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
|
||||
('TOPPADDING', (0, 0), (-1, -1), 4),
|
||||
]))
|
||||
story.append(meta_table)
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
|
||||
# Gutachten content
|
||||
_add_gutachten_content(story, styles, student_data)
|
||||
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
|
||||
# Bewertungstabelle
|
||||
_add_criteria_table(story, styles, student_data)
|
||||
|
||||
# Final grade box
|
||||
_add_grade_box(story, styles, student_data)
|
||||
|
||||
# Examiner workflow information
|
||||
if workflow_data:
|
||||
_add_workflow_info(story, styles, workflow_data)
|
||||
|
||||
# Annotation summary
|
||||
if annotations:
|
||||
_add_annotation_summary(story, styles, annotations)
|
||||
|
||||
# Footer
|
||||
_add_footer(story, styles)
|
||||
|
||||
# Build PDF
|
||||
doc.build(story)
|
||||
buffer.seek(0)
|
||||
return buffer.getvalue()
|
||||
|
||||
|
||||
def _add_gutachten_content(story, styles, student_data):
|
||||
"""Add gutachten text sections to the story."""
|
||||
gutachten = student_data.get('gutachten', {})
|
||||
|
||||
if gutachten:
|
||||
if gutachten.get('einleitung'):
|
||||
story.append(Paragraph("Einleitung", styles['SectionHeader']))
|
||||
story.append(Paragraph(gutachten['einleitung'], styles['GutachtenBody']))
|
||||
story.append(Spacer(1, 0.3*cm))
|
||||
|
||||
if gutachten.get('hauptteil'):
|
||||
story.append(Paragraph("Hauptteil", styles['SectionHeader']))
|
||||
story.append(Paragraph(gutachten['hauptteil'], styles['GutachtenBody']))
|
||||
story.append(Spacer(1, 0.3*cm))
|
||||
|
||||
if gutachten.get('fazit'):
|
||||
story.append(Paragraph("Fazit", styles['SectionHeader']))
|
||||
story.append(Paragraph(gutachten['fazit'], styles['GutachtenBody']))
|
||||
story.append(Spacer(1, 0.3*cm))
|
||||
|
||||
if gutachten.get('staerken') or gutachten.get('schwaechen'):
|
||||
story.append(Spacer(1, 0.3*cm))
|
||||
|
||||
if gutachten.get('staerken'):
|
||||
story.append(Paragraph("Staerken:", styles['SectionHeader']))
|
||||
for s in gutachten['staerken']:
|
||||
story.append(Paragraph(f"• {s}", styles['ListItem']))
|
||||
|
||||
if gutachten.get('schwaechen'):
|
||||
story.append(Paragraph("Verbesserungspotenzial:", styles['SectionHeader']))
|
||||
for s in gutachten['schwaechen']:
|
||||
story.append(Paragraph(f"• {s}", styles['ListItem']))
|
||||
else:
|
||||
story.append(Paragraph("<i>Kein Gutachten-Text vorhanden.</i>", styles['GutachtenBody']))
|
||||
|
||||
|
||||
def _add_criteria_table(story, styles, student_data):
|
||||
"""Add criteria scoring table to the story."""
|
||||
story.append(Paragraph("Bewertung nach Kriterien", styles['SectionHeader']))
|
||||
story.append(Spacer(1, 0.2*cm))
|
||||
|
||||
criteria_scores = student_data.get('criteria_scores', {})
|
||||
|
||||
table_data = [["Kriterium", "Gewichtung", "Erreicht", "Punkte"]]
|
||||
total_weighted = 0
|
||||
total_weight = 0
|
||||
|
||||
for key, display_name in CRITERIA_DISPLAY_NAMES.items():
|
||||
weight = CRITERIA_WEIGHTS.get(key, 0)
|
||||
score_data = criteria_scores.get(key, {})
|
||||
score = score_data.get('score', 0) if isinstance(score_data, dict) else score_data
|
||||
|
||||
weighted_score = (score / 100) * weight if score else 0
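        # Worked example: a score of 80% on "inhalt" (weight 40) contributes
        # (80 / 100) * 40 = 32.0 of the 100 weighted points.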
|
||||
total_weighted += weighted_score
|
||||
total_weight += weight
|
||||
|
||||
table_data.append([
|
||||
display_name,
|
||||
f"{weight}%",
|
||||
f"{score}%",
|
||||
f"{weighted_score:.1f}"
|
||||
])
|
||||
|
||||
table_data.append([
|
||||
"Gesamt",
|
||||
f"{total_weight}%",
|
||||
"",
|
||||
f"{total_weighted:.1f}"
|
||||
])
|
||||
|
||||
criteria_table = Table(table_data, colWidths=[8*cm, 2.5*cm, 2.5*cm, 2.5*cm])
|
||||
criteria_table.setStyle(TableStyle([
|
||||
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#2c5282')),
|
||||
('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
|
||||
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
|
||||
('FONTSIZE', (0, 0), (-1, 0), 10),
|
||||
('ALIGN', (1, 0), (-1, -1), 'CENTER'),
|
||||
('FONTSIZE', (0, 1), (-1, -1), 9),
|
||||
('BOTTOMPADDING', (0, 0), (-1, -1), 6),
|
||||
('TOPPADDING', (0, 0), (-1, -1), 6),
|
||||
('GRID', (0, 0), (-1, -1), 0.5, colors.HexColor('#e2e8f0')),
|
||||
('BACKGROUND', (0, -1), (-1, -1), colors.HexColor('#f7fafc')),
|
||||
('FONTNAME', (0, -1), (-1, -1), 'Helvetica-Bold'),
|
||||
('ROWBACKGROUNDS', (0, 1), (-1, -2), [colors.white, colors.HexColor('#f7fafc')]),
|
||||
]))
|
||||
story.append(criteria_table)
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
|
||||
|
||||
def _add_grade_box(story, styles, student_data):
|
||||
"""Add final grade box to the story."""
|
||||
grade_points = student_data.get('grade_points', 0)
|
||||
grade_note = GRADE_POINTS_TO_NOTE.get(grade_points, "?")
|
||||
raw_points = student_data.get('raw_points', 0)
|
||||
|
||||
grade_data = [
|
||||
["Rohpunkte:", f"{raw_points} / 100"],
|
||||
["Notenpunkte:", f"{grade_points} Punkte"],
|
||||
["Note:", grade_note]
|
||||
]
|
||||
|
||||
grade_table = Table(grade_data, colWidths=[4*cm, 4*cm])
|
||||
grade_table.setStyle(TableStyle([
|
||||
('BACKGROUND', (0, 0), (-1, -1), colors.HexColor('#ebf8ff')),
|
||||
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
|
||||
('FONTNAME', (1, -1), (1, -1), 'Helvetica-Bold'),
|
||||
('FONTSIZE', (0, 0), (-1, -1), 11),
|
||||
('FONTSIZE', (1, -1), (1, -1), 14),
|
||||
('TEXTCOLOR', (1, -1), (1, -1), colors.HexColor('#2c5282')),
|
||||
('BOTTOMPADDING', (0, 0), (-1, -1), 8),
|
||||
('TOPPADDING', (0, 0), (-1, -1), 8),
|
||||
('LEFTPADDING', (0, 0), (-1, -1), 12),
|
||||
('BOX', (0, 0), (-1, -1), 1, colors.HexColor('#2c5282')),
|
||||
('ALIGN', (1, 0), (1, -1), 'RIGHT'),
|
||||
]))
|
||||
|
||||
story.append(KeepTogether([
|
||||
Paragraph("Endergebnis", styles['SectionHeader']),
|
||||
Spacer(1, 0.2*cm),
|
||||
grade_table
|
||||
]))
|
||||
|
||||
|
||||
def _add_workflow_info(story, styles, workflow_data):
|
||||
"""Add examiner workflow information to the story."""
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
|
||||
story.append(Spacer(1, 0.3*cm))
|
||||
story.append(Paragraph("Korrekturverlauf", styles['SectionHeader']))
|
||||
|
||||
workflow_rows = []
|
||||
|
||||
if workflow_data.get('erst_korrektor'):
|
||||
ek = workflow_data['erst_korrektor']
|
||||
workflow_rows.append([
|
||||
"Erstkorrektor:",
|
||||
ek.get('name', 'Unbekannt'),
|
||||
f"{ek.get('grade_points', '-')} Punkte"
|
||||
])
|
||||
|
||||
if workflow_data.get('zweit_korrektor'):
|
||||
zk = workflow_data['zweit_korrektor']
|
||||
workflow_rows.append([
|
||||
"Zweitkorrektor:",
|
||||
zk.get('name', 'Unbekannt'),
|
||||
f"{zk.get('grade_points', '-')} Punkte"
|
||||
])
|
||||
|
||||
if workflow_data.get('dritt_korrektor'):
|
||||
dk = workflow_data['dritt_korrektor']
|
||||
workflow_rows.append([
|
||||
"Drittkorrektor:",
|
||||
dk.get('name', 'Unbekannt'),
|
||||
f"{dk.get('grade_points', '-')} Punkte"
|
||||
])
|
||||
|
||||
if workflow_data.get('final_grade_source'):
|
||||
workflow_rows.append([
|
||||
"Endnote durch:",
|
||||
workflow_data['final_grade_source'],
|
||||
""
|
||||
])
|
||||
|
||||
if workflow_rows:
|
||||
workflow_table = Table(workflow_rows, colWidths=[4*cm, 6*cm, 4*cm])
|
||||
workflow_table.setStyle(TableStyle([
|
||||
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
|
||||
('FONTSIZE', (0, 0), (-1, -1), 9),
|
||||
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
|
||||
('TOPPADDING', (0, 0), (-1, -1), 4),
|
||||
]))
|
||||
story.append(workflow_table)
|
||||
|
||||
|
||||
def _add_annotation_summary(story, styles, annotations):
|
||||
"""Add annotation summary to the story."""
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
|
||||
story.append(Spacer(1, 0.3*cm))
|
||||
story.append(Paragraph("Anmerkungen (Zusammenfassung)", styles['SectionHeader']))
|
||||
|
||||
by_type = {}
|
||||
for ann in annotations:
|
||||
ann_type = ann.get('type', 'comment')
|
||||
if ann_type not in by_type:
|
||||
by_type[ann_type] = []
|
||||
by_type[ann_type].append(ann)
|
||||
|
||||
for ann_type, anns in by_type.items():
|
||||
type_name = CRITERIA_DISPLAY_NAMES.get(ann_type, ann_type.replace('_', ' ').title())
|
||||
story.append(Paragraph(f"{type_name} ({len(anns)} Anmerkungen)", styles['ListItem']))
|
||||
|
||||
|
||||
def _add_footer(story, styles):
    """Add generation footer to the story."""
    story.append(Spacer(1, 1*cm))
    story.append(HRFlowable(width="100%", thickness=0.5, color=colors.HexColor('#cbd5e0')))
    story.append(Spacer(1, 0.2*cm))
    story.append(Paragraph(
        f"Erstellt am {datetime.now().strftime('%d.%m.%Y um %H:%M Uhr')} | BreakPilot Abiturkorrektur-System",
        styles['MetaText']
    ))
|
||||
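# Illustrative sketch (not part of this commit): how the helper functions above can
# be composed into a single Gutachten document. The wrapper name `build_gutachten_pdf`
# and the exact shapes of `student_data`/`workflow_data` are assumptions for illustration.
import io
from reportlab.lib.pagesizes import A4
from reportlab.platypus import SimpleDocTemplate
from pdf_export_styles import get_custom_styles

def build_gutachten_pdf(student_data, workflow_data, annotations) -> bytes:
    buffer = io.BytesIO()
    doc = SimpleDocTemplate(buffer, pagesize=A4)
    styles = get_custom_styles()
    story = []
    _add_grade_box(story, styles, student_data)          # Rohpunkte / Notenpunkte / Note box
    _add_workflow_info(story, styles, workflow_data)     # Erst-/Zweit-/Drittkorrektor rows
    _add_annotation_summary(story, styles, annotations)  # per-type counts
    _add_footer(story, styles)
    doc.build(story)
    return buffer.getvalue()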
@@ -0,0 +1,297 @@
|
||||
"""
|
||||
PDF Export - Klausur overview and annotations PDF generation.
|
||||
|
||||
Generates:
|
||||
- Klausur overview with grade distribution for all students
|
||||
- Annotations PDF for a single student
|
||||
"""
|
||||
|
||||
import io
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
from reportlab.lib import colors
|
||||
from reportlab.lib.pagesizes import A4
|
||||
from reportlab.lib.units import cm
|
||||
from reportlab.platypus import (
|
||||
SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle,
|
||||
HRFlowable
|
||||
)
|
||||
|
||||
from pdf_export_styles import (
|
||||
GRADE_POINTS_TO_NOTE,
|
||||
CRITERIA_DISPLAY_NAMES,
|
||||
get_custom_styles,
|
||||
)
|
||||
|
||||
|
||||
def generate_klausur_overview_pdf(
|
||||
klausur_data: Dict[str, Any],
|
||||
students: List[Dict[str, Any]],
|
||||
fairness_data: Optional[Dict[str, Any]] = None
|
||||
) -> bytes:
|
||||
"""
|
||||
Generate an overview PDF for an entire Klausur with all student grades.
|
||||
|
||||
Args:
|
||||
klausur_data: Klausur metadata
|
||||
students: List of all student work data
|
||||
fairness_data: Optional fairness analysis data
|
||||
|
||||
Returns:
|
||||
PDF as bytes
|
||||
"""
|
||||
buffer = io.BytesIO()
|
||||
doc = SimpleDocTemplate(
|
||||
buffer,
|
||||
pagesize=A4,
|
||||
rightMargin=1.5*cm,
|
||||
leftMargin=1.5*cm,
|
||||
topMargin=2*cm,
|
||||
bottomMargin=2*cm
|
||||
)
|
||||
|
||||
styles = get_custom_styles()
|
||||
story = []
|
||||
|
||||
# Header
|
||||
story.append(Paragraph("Notenuebersicht", styles['GutachtenTitle']))
|
||||
story.append(Paragraph(f"{klausur_data.get('subject', 'Deutsch')} - {klausur_data.get('title', '')}", styles['GutachtenSubtitle']))
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
|
||||
# Meta information
|
||||
meta_data = [
|
||||
["Schuljahr:", f"{klausur_data.get('year', 2025)}"],
|
||||
["Kurs:", klausur_data.get('semester', 'Abitur')],
|
||||
["Anzahl Arbeiten:", str(len(students))],
|
||||
["Stand:", datetime.now().strftime("%d.%m.%Y")]
|
||||
]
|
||||
|
||||
meta_table = Table(meta_data, colWidths=[4*cm, 10*cm])
|
||||
meta_table.setStyle(TableStyle([
|
||||
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
|
||||
('FONTSIZE', (0, 0), (-1, -1), 10),
|
||||
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
|
||||
('TOPPADDING', (0, 0), (-1, -1), 4),
|
||||
]))
|
||||
story.append(meta_table)
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
|
||||
# Statistics (if fairness data available)
|
||||
if fairness_data and fairness_data.get('statistics'):
|
||||
_add_statistics(story, styles, fairness_data['statistics'])
|
||||
|
||||
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
|
||||
# Student grades table
|
||||
sorted_students = sorted(students, key=lambda s: s.get('grade_points', 0), reverse=True)
|
||||
_add_student_table(story, styles, sorted_students)
|
||||
|
||||
# Grade distribution
|
||||
_add_grade_distribution(story, styles, sorted_students)
|
||||
|
||||
# Footer
|
||||
story.append(Spacer(1, 1*cm))
|
||||
story.append(HRFlowable(width="100%", thickness=0.5, color=colors.HexColor('#cbd5e0')))
|
||||
story.append(Spacer(1, 0.2*cm))
|
||||
story.append(Paragraph(
|
||||
f"Erstellt am {datetime.now().strftime('%d.%m.%Y um %H:%M Uhr')} | BreakPilot Abiturkorrektur-System",
|
||||
styles['MetaText']
|
||||
))
|
||||
|
||||
# Build PDF
|
||||
doc.build(story)
|
||||
buffer.seek(0)
|
||||
return buffer.getvalue()
|
||||
|
||||
|
||||
def _add_statistics(story, styles, stats):
|
||||
"""Add statistics section."""
|
||||
story.append(Paragraph("Statistik", styles['SectionHeader']))
|
||||
|
||||
stats_data = [
|
||||
["Durchschnitt:", f"{stats.get('average_grade', 0):.1f} Punkte"],
|
||||
["Minimum:", f"{stats.get('min_grade', 0)} Punkte"],
|
||||
["Maximum:", f"{stats.get('max_grade', 0)} Punkte"],
|
||||
["Standardabweichung:", f"{stats.get('standard_deviation', 0):.2f}"],
|
||||
]
|
||||
|
||||
stats_table = Table(stats_data, colWidths=[4*cm, 4*cm])
|
||||
stats_table.setStyle(TableStyle([
|
||||
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
|
||||
('FONTSIZE', (0, 0), (-1, -1), 9),
|
||||
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
|
||||
('BACKGROUND', (0, 0), (-1, -1), colors.HexColor('#f7fafc')),
|
||||
('BOX', (0, 0), (-1, -1), 0.5, colors.HexColor('#e2e8f0')),
|
||||
]))
|
||||
story.append(stats_table)
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
|
||||
|
||||
def _add_student_table(story, styles, sorted_students):
|
||||
"""Add student grades table."""
|
||||
story.append(Paragraph("Einzelergebnisse", styles['SectionHeader']))
|
||||
story.append(Spacer(1, 0.2*cm))
|
||||
|
||||
table_data = [["#", "Name", "Rohpunkte", "Notenpunkte", "Note", "Status"]]
|
||||
|
||||
for idx, student in enumerate(sorted_students, 1):
|
||||
grade_points = student.get('grade_points', 0)
|
||||
grade_note = GRADE_POINTS_TO_NOTE.get(grade_points, "-")
|
||||
raw_points = student.get('raw_points', 0)
|
||||
status = student.get('status', 'unknown')
|
||||
|
||||
status_display = {
|
||||
'completed': 'Abgeschlossen',
|
||||
'first_examiner': 'In Korrektur',
|
||||
'second_examiner': 'Zweitkorrektur',
|
||||
'uploaded': 'Hochgeladen',
|
||||
'ocr_complete': 'OCR fertig',
|
||||
'analyzing': 'Wird analysiert'
|
||||
}.get(status, status)
|
||||
|
||||
table_data.append([
|
||||
str(idx),
|
||||
student.get('student_name', 'Anonym'),
|
||||
f"{raw_points}/100",
|
||||
str(grade_points),
|
||||
grade_note,
|
||||
status_display
|
||||
])
|
||||
|
||||
student_table = Table(table_data, colWidths=[1*cm, 5*cm, 2.5*cm, 3*cm, 2*cm, 3*cm])
|
||||
student_table.setStyle(TableStyle([
|
||||
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#2c5282')),
|
||||
('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
|
||||
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
|
||||
('FONTSIZE', (0, 0), (-1, 0), 9),
|
||||
('ALIGN', (0, 0), (-1, 0), 'CENTER'),
|
||||
('FONTSIZE', (0, 1), (-1, -1), 9),
|
||||
('ALIGN', (0, 1), (0, -1), 'CENTER'),
|
||||
('ALIGN', (2, 1), (4, -1), 'CENTER'),
|
||||
('BOTTOMPADDING', (0, 0), (-1, -1), 6),
|
||||
('TOPPADDING', (0, 0), (-1, -1), 6),
|
||||
('GRID', (0, 0), (-1, -1), 0.5, colors.HexColor('#e2e8f0')),
|
||||
('ROWBACKGROUNDS', (0, 1), (-1, -1), [colors.white, colors.HexColor('#f7fafc')]),
|
||||
]))
|
||||
story.append(student_table)
|
||||
|
||||
|
||||
def _add_grade_distribution(story, styles, sorted_students):
|
||||
"""Add grade distribution table."""
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
story.append(Paragraph("Notenverteilung", styles['SectionHeader']))
|
||||
story.append(Spacer(1, 0.2*cm))
|
||||
|
||||
grade_counts = {}
|
||||
for student in sorted_students:
|
||||
gp = student.get('grade_points', 0)
|
||||
grade_counts[gp] = grade_counts.get(gp, 0) + 1
|
||||
|
||||
dist_data = [["Punkte", "Note", "Anzahl"]]
|
||||
for points in range(15, -1, -1):
|
||||
if points in grade_counts:
|
||||
note = GRADE_POINTS_TO_NOTE.get(points, "-")
|
||||
count = grade_counts[points]
|
||||
dist_data.append([str(points), note, str(count)])
|
||||
|
||||
if len(dist_data) > 1:
|
||||
dist_table = Table(dist_data, colWidths=[2.5*cm, 2.5*cm, 2.5*cm])
|
||||
dist_table.setStyle(TableStyle([
|
||||
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#2c5282')),
|
||||
('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
|
||||
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
|
||||
('FONTSIZE', (0, 0), (-1, -1), 9),
|
||||
('ALIGN', (0, 0), (-1, -1), 'CENTER'),
|
||||
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
|
||||
('TOPPADDING', (0, 0), (-1, -1), 4),
|
||||
('GRID', (0, 0), (-1, -1), 0.5, colors.HexColor('#e2e8f0')),
|
||||
]))
|
||||
story.append(dist_table)
|
||||
|
||||
|
||||
def generate_annotations_pdf(
|
||||
student_data: Dict[str, Any],
|
||||
klausur_data: Dict[str, Any],
|
||||
annotations: List[Dict[str, Any]]
|
||||
) -> bytes:
|
||||
"""
|
||||
Generate a PDF with all annotations for a student work.
|
||||
|
||||
Args:
|
||||
student_data: Student work data
|
||||
klausur_data: Klausur metadata
|
||||
annotations: List of all annotations
|
||||
|
||||
Returns:
|
||||
PDF as bytes
|
||||
"""
|
||||
buffer = io.BytesIO()
|
||||
doc = SimpleDocTemplate(
|
||||
buffer,
|
||||
pagesize=A4,
|
||||
rightMargin=2*cm,
|
||||
leftMargin=2*cm,
|
||||
topMargin=2*cm,
|
||||
bottomMargin=2*cm
|
||||
)
|
||||
|
||||
styles = get_custom_styles()
|
||||
story = []
|
||||
|
||||
# Header
|
||||
story.append(Paragraph("Anmerkungen zur Klausur", styles['GutachtenTitle']))
|
||||
story.append(Paragraph(f"{student_data.get('student_name', 'Anonym')}", styles['GutachtenSubtitle']))
|
||||
story.append(Spacer(1, 0.5*cm))
|
||||
|
||||
if not annotations:
|
||||
story.append(Paragraph("<i>Keine Anmerkungen vorhanden.</i>", styles['GutachtenBody']))
|
||||
else:
|
||||
# Group by type
|
||||
by_type = {}
|
||||
for ann in annotations:
|
||||
ann_type = ann.get('type', 'comment')
|
||||
if ann_type not in by_type:
|
||||
by_type[ann_type] = []
|
||||
by_type[ann_type].append(ann)
|
||||
|
||||
for ann_type, anns in by_type.items():
|
||||
type_name = CRITERIA_DISPLAY_NAMES.get(ann_type, ann_type.replace('_', ' ').title())
|
||||
story.append(Paragraph(f"{type_name} ({len(anns)})", styles['SectionHeader']))
|
||||
story.append(Spacer(1, 0.2*cm))
|
||||
|
||||
sorted_anns = sorted(anns, key=lambda a: (a.get('page', 0), a.get('position', {}).get('y', 0)))
|
||||
|
||||
for idx, ann in enumerate(sorted_anns, 1):
|
||||
page = ann.get('page', 1)
|
||||
text = ann.get('text', '')
|
||||
suggestion = ann.get('suggestion', '')
|
||||
severity = ann.get('severity', 'minor')
|
||||
|
||||
ann_text = f"<b>[S.{page}]</b> {text}"
|
||||
if suggestion:
|
||||
ann_text += f" -> <i>{suggestion}</i>"
|
||||
|
||||
if severity == 'critical':
|
||||
ann_text = f"<font color='red'>{ann_text}</font>"
|
||||
elif severity == 'major':
|
||||
ann_text = f"<font color='orange'>{ann_text}</font>"
|
||||
|
||||
story.append(Paragraph(f"{idx}. {ann_text}", styles['ListItem']))
|
||||
|
||||
story.append(Spacer(1, 0.3*cm))
|
||||
|
||||
# Footer
|
||||
story.append(Spacer(1, 1*cm))
|
||||
story.append(HRFlowable(width="100%", thickness=0.5, color=colors.HexColor('#cbd5e0')))
|
||||
story.append(Spacer(1, 0.2*cm))
|
||||
story.append(Paragraph(
|
||||
f"Erstellt am {datetime.now().strftime('%d.%m.%Y um %H:%M Uhr')} | BreakPilot Abiturkorrektur-System",
|
||||
styles['MetaText']
|
||||
))
|
||||
|
||||
# Build PDF
|
||||
doc.build(story)
|
||||
buffer.seek(0)
|
||||
return buffer.getvalue()
|
||||
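# Illustrative usage sketch (not part of this commit). The example dicts are assumptions;
# only the keys actually read above (subject, title, student_name, grade_points,
# raw_points, status, type, page, text, severity) matter.
klausur = {"subject": "Deutsch", "title": "Abitur 2025", "year": 2025, "semester": "Abitur"}
students = [
    {"student_name": "Anonym 01", "grade_points": 11, "raw_points": 78, "status": "completed"},
    {"student_name": "Anonym 02", "grade_points": 7, "raw_points": 55, "status": "second_examiner"},
]
with open("notenuebersicht.pdf", "wb") as f:
    f.write(generate_klausur_overview_pdf(klausur, students))

annotations = [{"type": "rechtschreibung", "page": 2, "text": "Komma fehlt", "severity": "minor"}]
with open("anmerkungen.pdf", "wb") as f:
    f.write(generate_annotations_pdf(students[0], klausur, annotations))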
@@ -0,0 +1,110 @@
|
||||
"""
|
||||
PDF Export - Constants and ReportLab styles for Abiturkorrektur PDFs.
|
||||
"""
|
||||
|
||||
from reportlab.lib import colors
|
||||
from reportlab.lib.enums import TA_LEFT, TA_CENTER, TA_JUSTIFY
|
||||
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
|
||||
|
||||
|
||||
# =============================================
|
||||
# CONSTANTS
|
||||
# =============================================
|
||||
|
||||
GRADE_POINTS_TO_NOTE = {
|
||||
15: "1+", 14: "1", 13: "1-",
|
||||
12: "2+", 11: "2", 10: "2-",
|
||||
9: "3+", 8: "3", 7: "3-",
|
||||
6: "4+", 5: "4", 4: "4-",
|
||||
3: "5+", 2: "5", 1: "5-",
|
||||
0: "6"
|
||||
}
|
||||
|
||||
CRITERIA_DISPLAY_NAMES = {
|
||||
"rechtschreibung": "Sprachliche Richtigkeit (Rechtschreibung)",
|
||||
"grammatik": "Sprachliche Richtigkeit (Grammatik)",
|
||||
"inhalt": "Inhaltliche Leistung",
|
||||
"struktur": "Aufbau und Struktur",
|
||||
"stil": "Ausdruck und Stil"
|
||||
}
|
||||
|
||||
CRITERIA_WEIGHTS = {
|
||||
"rechtschreibung": 15,
|
||||
"grammatik": 15,
|
||||
"inhalt": 40,
|
||||
"struktur": 15,
|
||||
"stil": 15
|
||||
}
|
||||
|
||||
|
||||
# =============================================
|
||||
# STYLES
|
||||
# =============================================
|
||||
|
||||
def get_custom_styles():
|
||||
"""Create custom paragraph styles for Gutachten."""
|
||||
styles = getSampleStyleSheet()
|
||||
|
||||
# Title style
|
||||
styles.add(ParagraphStyle(
|
||||
name='GutachtenTitle',
|
||||
parent=styles['Heading1'],
|
||||
fontSize=16,
|
||||
spaceAfter=12,
|
||||
alignment=TA_CENTER,
|
||||
textColor=colors.HexColor('#1e3a5f')
|
||||
))
|
||||
|
||||
# Subtitle style
|
||||
styles.add(ParagraphStyle(
|
||||
name='GutachtenSubtitle',
|
||||
parent=styles['Heading2'],
|
||||
fontSize=12,
|
||||
spaceAfter=8,
|
||||
spaceBefore=16,
|
||||
textColor=colors.HexColor('#2c5282')
|
||||
))
|
||||
|
||||
# Section header
|
||||
styles.add(ParagraphStyle(
|
||||
name='SectionHeader',
|
||||
parent=styles['Heading3'],
|
||||
fontSize=11,
|
||||
spaceAfter=6,
|
||||
spaceBefore=12,
|
||||
textColor=colors.HexColor('#2d3748'),
|
||||
borderColor=colors.HexColor('#e2e8f0'),
|
||||
borderWidth=0,
|
||||
borderPadding=0
|
||||
))
|
||||
|
||||
# Body text
|
||||
styles.add(ParagraphStyle(
|
||||
name='GutachtenBody',
|
||||
parent=styles['Normal'],
|
||||
fontSize=10,
|
||||
leading=14,
|
||||
alignment=TA_JUSTIFY,
|
||||
spaceAfter=6
|
||||
))
|
||||
|
||||
# Small text for footer/meta
|
||||
styles.add(ParagraphStyle(
|
||||
name='MetaText',
|
||||
parent=styles['Normal'],
|
||||
fontSize=8,
|
||||
textColor=colors.grey,
|
||||
alignment=TA_LEFT
|
||||
))
|
||||
|
||||
# List item
|
||||
styles.add(ParagraphStyle(
|
||||
name='ListItem',
|
||||
parent=styles['Normal'],
|
||||
fontSize=10,
|
||||
leftIndent=20,
|
||||
bulletIndent=10,
|
||||
spaceAfter=4
|
||||
))
|
||||
|
||||
return styles
|
||||
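# Illustrative sketch (not part of this commit): how the constants above combine into a
# weighted total. The per-criterion percentage scores are assumptions; the real mapping
# from the weighted total to Notenpunkte lives in the correction service, not here.
scores = {"rechtschreibung": 80, "grammatik": 75, "inhalt": 60, "struktur": 70, "stil": 65}
total_weighted = sum(score * CRITERIA_WEIGHTS[name] / 100 for name, score in scores.items())
grade_note = GRADE_POINTS_TO_NOTE.get(11, "?")  # 11 Notenpunkte -> "2"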
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
Qdrant Vector Database Service — QdrantService class for NiBiS Ingestion Pipeline.
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Optional
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import VectorParams, Distance, PointStruct, Filter, FieldCondition, MatchValue
|
||||
|
||||
from qdrant_core import QDRANT_URL, VECTOR_SIZE
|
||||
|
||||
|
||||
class QdrantService:
|
||||
"""
|
||||
Class-based Qdrant service for flexible collection management.
|
||||
Used by nibis_ingestion.py for bulk indexing.
|
||||
"""
|
||||
|
||||
def __init__(self, url: str = None):
|
||||
self.url = url or QDRANT_URL
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def client(self) -> QdrantClient:
|
||||
if self._client is None:
|
||||
self._client = QdrantClient(url=self.url)
|
||||
return self._client
|
||||
|
||||
async def ensure_collection(self, collection_name: str, vector_size: int = VECTOR_SIZE) -> bool:
|
||||
"""
|
||||
Ensure collection exists, create if needed.
|
||||
|
||||
Args:
|
||||
collection_name: Name of the collection
|
||||
vector_size: Dimension of vectors
|
||||
|
||||
Returns:
|
||||
True if collection exists/created
|
||||
"""
|
||||
try:
|
||||
collections = self.client.get_collections().collections
|
||||
collection_names = [c.name for c in collections]
|
||||
|
||||
if collection_name not in collection_names:
|
||||
self.client.create_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config=VectorParams(
|
||||
size=vector_size,
|
||||
distance=Distance.COSINE
|
||||
)
|
||||
)
|
||||
print(f"Created collection: {collection_name}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error ensuring collection: {e}")
|
||||
return False
|
||||
|
||||
async def upsert_points(self, collection_name: str, points: List[Dict]) -> int:
|
||||
"""
|
||||
Upsert points into collection.
|
||||
|
||||
Args:
|
||||
collection_name: Target collection
|
||||
points: List of {id, vector, payload}
|
||||
|
||||
Returns:
|
||||
Number of upserted points
|
||||
"""
|
||||
import uuid
|
||||
|
||||
if not points:
|
||||
return 0
|
||||
|
||||
qdrant_points = []
|
||||
for p in points:
|
||||
# Convert string ID to UUID for Qdrant compatibility
|
||||
point_id = p["id"]
|
||||
if isinstance(point_id, str):
|
||||
# Use uuid5 with DNS namespace for deterministic UUID from string
|
||||
point_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, point_id))
|
||||
|
||||
qdrant_points.append(
|
||||
PointStruct(
|
||||
id=point_id,
|
||||
vector=p["vector"],
|
||||
payload={**p.get("payload", {}), "original_id": p["id"]} # Keep original ID in payload
|
||||
)
|
||||
)
|
||||
|
||||
self.client.upsert(collection_name=collection_name, points=qdrant_points)
|
||||
return len(qdrant_points)
|
||||
|
||||
async def search(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_vector: List[float],
|
||||
filter_conditions: Optional[Dict] = None,
|
||||
limit: int = 10
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Semantic search in collection.
|
||||
|
||||
Args:
|
||||
collection_name: Collection to search
|
||||
query_vector: Query embedding
|
||||
filter_conditions: Optional filters (key: value pairs)
|
||||
limit: Max results
|
||||
|
||||
Returns:
|
||||
List of matching points with scores
|
||||
"""
|
||||
query_filter = None
|
||||
if filter_conditions:
|
||||
must_conditions = [
|
||||
FieldCondition(key=k, match=MatchValue(value=v))
|
||||
for k, v in filter_conditions.items()
|
||||
]
|
||||
query_filter = Filter(must=must_conditions)
|
||||
|
||||
results = self.client.search(
|
||||
collection_name=collection_name,
|
||||
query_vector=query_vector,
|
||||
query_filter=query_filter,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"id": str(r.id),
|
||||
"score": r.score,
|
||||
"payload": r.payload
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
|
||||
async def get_stats(self, collection_name: str) -> Dict:
|
||||
"""Get collection statistics."""
|
||||
try:
|
||||
info = self.client.get_collection(collection_name)
|
||||
return {
|
||||
"name": collection_name,
|
||||
"vectors_count": info.vectors_count,
|
||||
"points_count": info.points_count,
|
||||
"status": info.status.value
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": str(e), "name": collection_name}
|
||||
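# Illustrative usage sketch (not part of this commit): bulk indexing with the class-based
# service, roughly as nibis_ingestion.py would use it. The collection name, vector size
# and the zero vectors are placeholders, not real embeddings.
import asyncio

async def _demo_ingest():
    svc = QdrantService()
    await svc.ensure_collection("bp_nibis_eh", vector_size=VECTOR_SIZE)
    await svc.upsert_points("bp_nibis_eh", [
        {"id": "2025_deutsch_eA_task1_chunk0",
         "vector": [0.0] * VECTOR_SIZE,
         "payload": {"year": 2025, "subject": "deutsch", "niveau": "eA"}},
    ])
    hits = await svc.search("bp_nibis_eh", query_vector=[0.0] * VECTOR_SIZE,
                            filter_conditions={"subject": "deutsch"}, limit=3)
    print(hits)

# asyncio.run(_demo_ingest())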
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
Qdrant Vector Database Service — core client and BYOEH functions.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Dict, Optional
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.models import VectorParams, Distance, PointStruct, Filter, FieldCondition, MatchValue
|
||||
|
||||
QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
|
||||
COLLECTION_NAME = "bp_eh"
|
||||
VECTOR_SIZE = 1536 # OpenAI text-embedding-3-small
|
||||
|
||||
_client: Optional[QdrantClient] = None
|
||||
|
||||
|
||||
def get_qdrant_client() -> QdrantClient:
|
||||
"""Get or create Qdrant client singleton."""
|
||||
global _client
|
||||
if _client is None:
|
||||
_client = QdrantClient(url=QDRANT_URL)
|
||||
return _client
|
||||
|
||||
|
||||
async def init_qdrant_collection() -> bool:
|
||||
"""Initialize Qdrant collection for BYOEH if not exists."""
|
||||
try:
|
||||
client = get_qdrant_client()
|
||||
|
||||
# Check if collection exists
|
||||
collections = client.get_collections().collections
|
||||
collection_names = [c.name for c in collections]
|
||||
|
||||
if COLLECTION_NAME not in collection_names:
|
||||
client.create_collection(
|
||||
collection_name=COLLECTION_NAME,
|
||||
vectors_config=VectorParams(
|
||||
size=VECTOR_SIZE,
|
||||
distance=Distance.COSINE
|
||||
)
|
||||
)
|
||||
print(f"Created Qdrant collection: {COLLECTION_NAME}")
|
||||
else:
|
||||
print(f"Qdrant collection {COLLECTION_NAME} already exists")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to initialize Qdrant: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def index_eh_chunks(
|
||||
eh_id: str,
|
||||
tenant_id: str,
|
||||
subject: str,
|
||||
chunks: List[Dict]
|
||||
) -> int:
|
||||
"""
|
||||
Index EH chunks in Qdrant.
|
||||
|
||||
Args:
|
||||
eh_id: Erwartungshorizont ID
|
||||
tenant_id: Tenant/School ID for isolation
|
||||
subject: Subject (deutsch, englisch, etc.)
|
||||
chunks: List of {text, embedding, encrypted_content}
|
||||
|
||||
Returns:
|
||||
Number of indexed chunks
|
||||
"""
|
||||
client = get_qdrant_client()
|
||||
|
||||
points = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
point_id = f"{eh_id}_{i}"
|
||||
points.append(
|
||||
PointStruct(
|
||||
id=point_id,
|
||||
vector=chunk["embedding"],
|
||||
payload={
|
||||
"tenant_id": tenant_id,
|
||||
"eh_id": eh_id,
|
||||
"chunk_index": i,
|
||||
"subject": subject,
|
||||
"encrypted_content": chunk.get("encrypted_content", ""),
|
||||
"training_allowed": False # ALWAYS FALSE - critical for compliance
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
if points:
|
||||
client.upsert(collection_name=COLLECTION_NAME, points=points)
|
||||
|
||||
return len(points)
|
||||
|
||||
|
||||
async def search_eh(
|
||||
query_embedding: List[float],
|
||||
tenant_id: str,
|
||||
subject: Optional[str] = None,
|
||||
limit: int = 5
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Semantic search in tenant's Erwartungshorizonte.
|
||||
|
||||
Args:
|
||||
query_embedding: Query vector (1536 dimensions)
|
||||
tenant_id: Tenant ID for isolation
|
||||
subject: Optional subject filter
|
||||
limit: Max results
|
||||
|
||||
Returns:
|
||||
List of matching chunks with scores
|
||||
"""
|
||||
client = get_qdrant_client()
|
||||
|
||||
# Build filter conditions
|
||||
must_conditions = [
|
||||
FieldCondition(key="tenant_id", match=MatchValue(value=tenant_id))
|
||||
]
|
||||
|
||||
if subject:
|
||||
must_conditions.append(
|
||||
FieldCondition(key="subject", match=MatchValue(value=subject))
|
||||
)
|
||||
|
||||
query_filter = Filter(must=must_conditions)
|
||||
|
||||
results = client.search(
|
||||
collection_name=COLLECTION_NAME,
|
||||
query_vector=query_embedding,
|
||||
query_filter=query_filter,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"id": str(r.id),
|
||||
"score": r.score,
|
||||
"eh_id": r.payload.get("eh_id"),
|
||||
"chunk_index": r.payload.get("chunk_index"),
|
||||
"encrypted_content": r.payload.get("encrypted_content"),
|
||||
"subject": r.payload.get("subject")
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
|
||||
|
||||
async def delete_eh_vectors(eh_id: str) -> int:
|
||||
"""
|
||||
Delete all vectors for a specific Erwartungshorizont.
|
||||
|
||||
Args:
|
||||
eh_id: Erwartungshorizont ID
|
||||
|
||||
Returns:
|
||||
Number of deleted points
|
||||
"""
|
||||
client = get_qdrant_client()
|
||||
|
||||
# Get all points for this EH first
|
||||
scroll_result = client.scroll(
|
||||
collection_name=COLLECTION_NAME,
|
||||
scroll_filter=Filter(
|
||||
must=[FieldCondition(key="eh_id", match=MatchValue(value=eh_id))]
|
||||
),
|
||||
limit=1000
|
||||
)
|
||||
|
||||
point_ids = [str(p.id) for p in scroll_result[0]]
|
||||
|
||||
if point_ids:
|
||||
client.delete(
|
||||
collection_name=COLLECTION_NAME,
|
||||
points_selector=models.PointIdsList(points=point_ids)
|
||||
)
|
||||
|
||||
return len(point_ids)
|
||||
|
||||
|
||||
async def get_collection_info() -> Dict:
|
||||
"""Get collection statistics."""
|
||||
try:
|
||||
client = get_qdrant_client()
|
||||
info = client.get_collection(COLLECTION_NAME)
|
||||
return {
|
||||
"name": COLLECTION_NAME,
|
||||
"vectors_count": info.vectors_count,
|
||||
"points_count": info.points_count,
|
||||
"status": info.status.value
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
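# Illustrative BYOEH flow (not part of this commit). The zero vectors are placeholders;
# in the real pipeline the 1536-dim embeddings come from the embedding service and the
# chunk text is stored only as ciphertext in `encrypted_content`.
import asyncio

async def _demo_byoeh():
    await init_qdrant_collection()
    await index_eh_chunks(
        eh_id="eh-123",
        tenant_id="school-42",
        subject="deutsch",
        chunks=[{"embedding": [0.0] * VECTOR_SIZE, "encrypted_content": "<ciphertext>"}],
    )
    hits = await search_eh([0.0] * VECTOR_SIZE, tenant_id="school-42", subject="deutsch")
    print(hits)
    await delete_eh_vectors("eh-123")

# asyncio.run(_demo_byoeh())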
@@ -0,0 +1,231 @@
|
||||
"""
|
||||
Qdrant Vector Database Service — Legal Templates RAG Search.
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Optional
|
||||
from qdrant_client.models import VectorParams, Distance, Filter, FieldCondition, MatchValue
|
||||
|
||||
from qdrant_core import get_qdrant_client
|
||||
|
||||
LEGAL_TEMPLATES_COLLECTION = "bp_legal_templates"
|
||||
LEGAL_TEMPLATES_VECTOR_SIZE = 1024 # BGE-M3
|
||||
|
||||
|
||||
async def init_legal_templates_collection() -> bool:
|
||||
"""Initialize Qdrant collection for legal templates if not exists."""
|
||||
try:
|
||||
client = get_qdrant_client()
|
||||
collections = client.get_collections().collections
|
||||
collection_names = [c.name for c in collections]
|
||||
|
||||
if LEGAL_TEMPLATES_COLLECTION not in collection_names:
|
||||
client.create_collection(
|
||||
collection_name=LEGAL_TEMPLATES_COLLECTION,
|
||||
vectors_config=VectorParams(
|
||||
size=LEGAL_TEMPLATES_VECTOR_SIZE,
|
||||
distance=Distance.COSINE
|
||||
)
|
||||
)
|
||||
print(f"Created Qdrant collection: {LEGAL_TEMPLATES_COLLECTION}")
|
||||
else:
|
||||
print(f"Qdrant collection {LEGAL_TEMPLATES_COLLECTION} already exists")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to initialize legal templates collection: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def search_legal_templates(
|
||||
query_embedding: List[float],
|
||||
template_type: Optional[str] = None,
|
||||
license_types: Optional[List[str]] = None,
|
||||
language: Optional[str] = None,
|
||||
jurisdiction: Optional[str] = None,
|
||||
attribution_required: Optional[bool] = None,
|
||||
limit: int = 10
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Search in legal templates collection for document generation.
|
||||
|
||||
Args:
|
||||
query_embedding: Query vector (1024 dimensions, BGE-M3)
|
||||
template_type: Filter by template type (privacy_policy, terms_of_service, etc.)
|
||||
license_types: Filter by license types (cc0, mit, cc_by_4, etc.)
|
||||
language: Filter by language (de, en)
|
||||
jurisdiction: Filter by jurisdiction (DE, EU, US, etc.)
|
||||
attribution_required: Filter by attribution requirement
|
||||
limit: Max results
|
||||
|
||||
Returns:
|
||||
List of matching template chunks with full metadata
|
||||
"""
|
||||
client = get_qdrant_client()
|
||||
|
||||
# Build filter conditions
|
||||
must_conditions = []
|
||||
|
||||
if template_type:
|
||||
must_conditions.append(
|
||||
FieldCondition(key="template_type", match=MatchValue(value=template_type))
|
||||
)
|
||||
|
||||
if language:
|
||||
must_conditions.append(
|
||||
FieldCondition(key="language", match=MatchValue(value=language))
|
||||
)
|
||||
|
||||
if jurisdiction:
|
||||
must_conditions.append(
|
||||
FieldCondition(key="jurisdiction", match=MatchValue(value=jurisdiction))
|
||||
)
|
||||
|
||||
if attribution_required is not None:
|
||||
must_conditions.append(
|
||||
FieldCondition(key="attribution_required", match=MatchValue(value=attribution_required))
|
||||
)
|
||||
|
||||
# License type filter (OR condition)
|
||||
should_conditions = []
|
||||
if license_types:
|
||||
for license_type in license_types:
|
||||
should_conditions.append(
|
||||
FieldCondition(key="license_id", match=MatchValue(value=license_type))
|
||||
)
|
||||
|
||||
# Construct filter
|
||||
query_filter = None
|
||||
if must_conditions or should_conditions:
|
||||
filter_args = {}
|
||||
if must_conditions:
|
||||
filter_args["must"] = must_conditions
|
||||
if should_conditions:
|
||||
filter_args["should"] = should_conditions
|
||||
query_filter = Filter(**filter_args)
|
||||
|
||||
try:
|
||||
results = client.search(
|
||||
collection_name=LEGAL_TEMPLATES_COLLECTION,
|
||||
query_vector=query_embedding,
|
||||
query_filter=query_filter,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"id": str(r.id),
|
||||
"score": r.score,
|
||||
"text": r.payload.get("text", ""),
|
||||
"document_title": r.payload.get("document_title"),
|
||||
"template_type": r.payload.get("template_type"),
|
||||
"clause_category": r.payload.get("clause_category"),
|
||||
"language": r.payload.get("language"),
|
||||
"jurisdiction": r.payload.get("jurisdiction"),
|
||||
"license_id": r.payload.get("license_id"),
|
||||
"license_name": r.payload.get("license_name"),
|
||||
"license_url": r.payload.get("license_url"),
|
||||
"attribution_required": r.payload.get("attribution_required"),
|
||||
"attribution_text": r.payload.get("attribution_text"),
|
||||
"source_name": r.payload.get("source_name"),
|
||||
"source_url": r.payload.get("source_url"),
|
||||
"source_repo": r.payload.get("source_repo"),
|
||||
"placeholders": r.payload.get("placeholders", []),
|
||||
"is_complete_document": r.payload.get("is_complete_document"),
|
||||
"is_modular": r.payload.get("is_modular"),
|
||||
"requires_customization": r.payload.get("requires_customization"),
|
||||
"output_allowed": r.payload.get("output_allowed"),
|
||||
"modification_allowed": r.payload.get("modification_allowed"),
|
||||
"distortion_prohibited": r.payload.get("distortion_prohibited"),
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"Legal templates search error: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def get_legal_templates_stats() -> Dict:
|
||||
"""Get statistics for the legal templates collection."""
|
||||
try:
|
||||
client = get_qdrant_client()
|
||||
info = client.get_collection(LEGAL_TEMPLATES_COLLECTION)
|
||||
|
||||
# Count by template type
|
||||
template_types = ["privacy_policy", "terms_of_service", "cookie_banner",
|
||||
"impressum", "widerruf", "dpa", "sla", "agb"]
|
||||
type_counts = {}
|
||||
for ttype in template_types:
|
||||
result = client.count(
|
||||
collection_name=LEGAL_TEMPLATES_COLLECTION,
|
||||
count_filter=Filter(
|
||||
must=[FieldCondition(key="template_type", match=MatchValue(value=ttype))]
|
||||
)
|
||||
)
|
||||
if result.count > 0:
|
||||
type_counts[ttype] = result.count
|
||||
|
||||
# Count by language
|
||||
lang_counts = {}
|
||||
for lang in ["de", "en"]:
|
||||
result = client.count(
|
||||
collection_name=LEGAL_TEMPLATES_COLLECTION,
|
||||
count_filter=Filter(
|
||||
must=[FieldCondition(key="language", match=MatchValue(value=lang))]
|
||||
)
|
||||
)
|
||||
lang_counts[lang] = result.count
|
||||
|
||||
# Count by license
|
||||
license_counts = {}
|
||||
for license_id in ["cc0", "mit", "cc_by_4", "public_domain", "unlicense"]:
|
||||
result = client.count(
|
||||
collection_name=LEGAL_TEMPLATES_COLLECTION,
|
||||
count_filter=Filter(
|
||||
must=[FieldCondition(key="license_id", match=MatchValue(value=license_id))]
|
||||
)
|
||||
)
|
||||
if result.count > 0:
|
||||
license_counts[license_id] = result.count
|
||||
|
||||
return {
|
||||
"collection": LEGAL_TEMPLATES_COLLECTION,
|
||||
"vectors_count": info.vectors_count,
|
||||
"points_count": info.points_count,
|
||||
"status": info.status.value,
|
||||
"template_types": type_counts,
|
||||
"languages": lang_counts,
|
||||
"licenses": license_counts,
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": str(e), "collection": LEGAL_TEMPLATES_COLLECTION}
|
||||
|
||||
|
||||
async def delete_legal_templates_by_source(source_name: str) -> int:
|
||||
"""
|
||||
Delete all legal template chunks from a specific source.
|
||||
|
||||
Args:
|
||||
source_name: Name of the source to delete
|
||||
|
||||
Returns:
|
||||
Number of deleted points
|
||||
"""
|
||||
client = get_qdrant_client()
|
||||
|
||||
# Count first
|
||||
count_result = client.count(
|
||||
collection_name=LEGAL_TEMPLATES_COLLECTION,
|
||||
count_filter=Filter(
|
||||
must=[FieldCondition(key="source_name", match=MatchValue(value=source_name))]
|
||||
)
|
||||
)
|
||||
|
||||
# Delete by filter
|
||||
client.delete(
|
||||
collection_name=LEGAL_TEMPLATES_COLLECTION,
|
||||
points_selector=Filter(
|
||||
must=[FieldCondition(key="source_name", match=MatchValue(value=source_name))]
|
||||
)
|
||||
)
|
||||
|
||||
return count_result.count
|
||||
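# Illustrative usage sketch (not part of this commit). The zero vector is a placeholder
# for a real 1024-dim BGE-M3 query embedding.
import asyncio

async def _demo_legal_search():
    results = await search_legal_templates(
        query_embedding=[0.0] * LEGAL_TEMPLATES_VECTOR_SIZE,
        template_type="privacy_policy",
        language="de",
        jurisdiction="DE",
        limit=5,
    )
    for r in results:
        print(r["score"], r["document_title"], r["license_id"])

# asyncio.run(_demo_legal_search())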
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
Qdrant Vector Database Service — NiBiS RAG Search for Klausurkorrektur.
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Optional
|
||||
from qdrant_client.models import Filter, FieldCondition, MatchValue
|
||||
|
||||
from qdrant_core import get_qdrant_client
|
||||
|
||||
|
||||
async def search_nibis_eh(
|
||||
query_embedding: List[float],
|
||||
year: Optional[int] = None,
|
||||
subject: Optional[str] = None,
|
||||
niveau: Optional[str] = None,
|
||||
limit: int = 5
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Search in NiBiS Erwartungshorizonte (public, pre-indexed data).
|
||||
|
||||
Unlike search_eh(), this searches in the public NiBiS collection
|
||||
and returns plaintext (not encrypted).
|
||||
|
||||
Args:
|
||||
query_embedding: Query vector
|
||||
year: Optional year filter (2016, 2017, 2024, 2025)
|
||||
subject: Optional subject filter
|
||||
niveau: Optional niveau filter (eA, gA)
|
||||
limit: Max results
|
||||
|
||||
Returns:
|
||||
List of matching chunks with metadata
|
||||
"""
|
||||
client = get_qdrant_client()
|
||||
collection = "bp_nibis_eh"
|
||||
|
||||
# Build filter
|
||||
must_conditions = []
|
||||
|
||||
if year:
|
||||
must_conditions.append(
|
||||
FieldCondition(key="year", match=MatchValue(value=year))
|
||||
)
|
||||
if subject:
|
||||
must_conditions.append(
|
||||
FieldCondition(key="subject", match=MatchValue(value=subject))
|
||||
)
|
||||
if niveau:
|
||||
must_conditions.append(
|
||||
FieldCondition(key="niveau", match=MatchValue(value=niveau))
|
||||
)
|
||||
|
||||
query_filter = Filter(must=must_conditions) if must_conditions else None
|
||||
|
||||
try:
|
||||
results = client.search(
|
||||
collection_name=collection,
|
||||
query_vector=query_embedding,
|
||||
query_filter=query_filter,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"id": str(r.id),
|
||||
"score": r.score,
|
||||
"text": r.payload.get("text", ""),
|
||||
"year": r.payload.get("year"),
|
||||
"subject": r.payload.get("subject"),
|
||||
"niveau": r.payload.get("niveau"),
|
||||
"task_number": r.payload.get("task_number"),
|
||||
"doc_type": r.payload.get("doc_type"),
|
||||
"variant": r.payload.get("variant"),
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"NiBiS search error: {e}")
|
||||
return []
|
||||
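# Illustrative usage sketch (not part of this commit). The zero vector is a placeholder;
# its dimensionality is assumed to match the embeddings indexed into bp_nibis_eh.
import asyncio

async def _demo_nibis_search():
    hits = await search_nibis_eh(
        query_embedding=[0.0] * 1536,
        year=2025,
        subject="deutsch",
        niveau="eA",
        limit=3,
    )
    for h in hits:
        print(h["score"], h["year"], h["task_number"], h["doc_type"])

# asyncio.run(_demo_nibis_search())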
@@ -1,638 +1,38 @@
|
||||
"""
|
||||
Qdrant Vector Database Service for BYOEH
|
||||
Manages vector storage and semantic search for Erwartungshorizonte.
|
||||
Qdrant Vector Database Service for BYOEH — barrel re-export.
|
||||
|
||||
The actual code lives in:
|
||||
- qdrant_core.py (client singleton, BYOEH index/search/delete)
|
||||
- qdrant_class.py (QdrantService class for NiBiS pipeline)
|
||||
- qdrant_nibis.py (NiBiS RAG search)
|
||||
- qdrant_legal.py (Legal Templates RAG search)
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Dict, Optional
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.models import VectorParams, Distance, PointStruct, Filter, FieldCondition, MatchValue
|
||||
|
||||
QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
|
||||
COLLECTION_NAME = "bp_eh"
|
||||
VECTOR_SIZE = 1536 # OpenAI text-embedding-3-small
|
||||
|
||||
_client: Optional[QdrantClient] = None
|
||||
|
||||
|
||||
def get_qdrant_client() -> QdrantClient:
|
||||
"""Get or create Qdrant client singleton."""
|
||||
global _client
|
||||
if _client is None:
|
||||
_client = QdrantClient(url=QDRANT_URL)
|
||||
return _client
|
||||
|
||||
|
||||
async def init_qdrant_collection() -> bool:
|
||||
"""Initialize Qdrant collection for BYOEH if not exists."""
|
||||
try:
|
||||
client = get_qdrant_client()
|
||||
|
||||
# Check if collection exists
|
||||
collections = client.get_collections().collections
|
||||
collection_names = [c.name for c in collections]
|
||||
|
||||
if COLLECTION_NAME not in collection_names:
|
||||
client.create_collection(
|
||||
collection_name=COLLECTION_NAME,
|
||||
vectors_config=VectorParams(
|
||||
size=VECTOR_SIZE,
|
||||
distance=Distance.COSINE
|
||||
)
|
||||
)
|
||||
print(f"Created Qdrant collection: {COLLECTION_NAME}")
|
||||
else:
|
||||
print(f"Qdrant collection {COLLECTION_NAME} already exists")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to initialize Qdrant: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def index_eh_chunks(
|
||||
eh_id: str,
|
||||
tenant_id: str,
|
||||
subject: str,
|
||||
chunks: List[Dict]
|
||||
) -> int:
|
||||
"""
|
||||
Index EH chunks in Qdrant.
|
||||
|
||||
Args:
|
||||
eh_id: Erwartungshorizont ID
|
||||
tenant_id: Tenant/School ID for isolation
|
||||
subject: Subject (deutsch, englisch, etc.)
|
||||
chunks: List of {text, embedding, encrypted_content}
|
||||
|
||||
Returns:
|
||||
Number of indexed chunks
|
||||
"""
|
||||
client = get_qdrant_client()
|
||||
|
||||
points = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
point_id = f"{eh_id}_{i}"
|
||||
points.append(
|
||||
PointStruct(
|
||||
id=point_id,
|
||||
vector=chunk["embedding"],
|
||||
payload={
|
||||
"tenant_id": tenant_id,
|
||||
"eh_id": eh_id,
|
||||
"chunk_index": i,
|
||||
"subject": subject,
|
||||
"encrypted_content": chunk.get("encrypted_content", ""),
|
||||
"training_allowed": False # ALWAYS FALSE - critical for compliance
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
if points:
|
||||
client.upsert(collection_name=COLLECTION_NAME, points=points)
|
||||
|
||||
return len(points)
|
||||
|
||||
|
||||
async def search_eh(
|
||||
query_embedding: List[float],
|
||||
tenant_id: str,
|
||||
subject: Optional[str] = None,
|
||||
limit: int = 5
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Semantic search in tenant's Erwartungshorizonte.
|
||||
|
||||
Args:
|
||||
query_embedding: Query vector (1536 dimensions)
|
||||
tenant_id: Tenant ID for isolation
|
||||
subject: Optional subject filter
|
||||
limit: Max results
|
||||
|
||||
Returns:
|
||||
List of matching chunks with scores
|
||||
"""
|
||||
client = get_qdrant_client()
|
||||
|
||||
# Build filter conditions
|
||||
must_conditions = [
|
||||
FieldCondition(key="tenant_id", match=MatchValue(value=tenant_id))
|
||||
]
|
||||
|
||||
if subject:
|
||||
must_conditions.append(
|
||||
FieldCondition(key="subject", match=MatchValue(value=subject))
|
||||
)
|
||||
|
||||
query_filter = Filter(must=must_conditions)
|
||||
|
||||
results = client.search(
|
||||
collection_name=COLLECTION_NAME,
|
||||
query_vector=query_embedding,
|
||||
query_filter=query_filter,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"id": str(r.id),
|
||||
"score": r.score,
|
||||
"eh_id": r.payload.get("eh_id"),
|
||||
"chunk_index": r.payload.get("chunk_index"),
|
||||
"encrypted_content": r.payload.get("encrypted_content"),
|
||||
"subject": r.payload.get("subject")
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
|
||||
|
||||
async def delete_eh_vectors(eh_id: str) -> int:
|
||||
"""
|
||||
Delete all vectors for a specific Erwartungshorizont.
|
||||
|
||||
Args:
|
||||
eh_id: Erwartungshorizont ID
|
||||
|
||||
Returns:
|
||||
Number of deleted points
|
||||
"""
|
||||
client = get_qdrant_client()
|
||||
|
||||
# Get all points for this EH first
|
||||
scroll_result = client.scroll(
|
||||
collection_name=COLLECTION_NAME,
|
||||
scroll_filter=Filter(
|
||||
must=[FieldCondition(key="eh_id", match=MatchValue(value=eh_id))]
|
||||
),
|
||||
limit=1000
|
||||
)
|
||||
|
||||
point_ids = [str(p.id) for p in scroll_result[0]]
|
||||
|
||||
if point_ids:
|
||||
client.delete(
|
||||
collection_name=COLLECTION_NAME,
|
||||
points_selector=models.PointIdsList(points=point_ids)
|
||||
)
|
||||
|
||||
return len(point_ids)
|
||||
|
||||
|
||||
async def get_collection_info() -> Dict:
|
||||
"""Get collection statistics."""
|
||||
try:
|
||||
client = get_qdrant_client()
|
||||
info = client.get_collection(COLLECTION_NAME)
|
||||
return {
|
||||
"name": COLLECTION_NAME,
|
||||
"vectors_count": info.vectors_count,
|
||||
"points_count": info.points_count,
|
||||
"status": info.status.value
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# QdrantService Class (for NiBiS Ingestion Pipeline)
|
||||
# =============================================================================
|
||||
|
||||
class QdrantService:
|
||||
"""
|
||||
Class-based Qdrant service for flexible collection management.
|
||||
Used by nibis_ingestion.py for bulk indexing.
|
||||
"""
|
||||
|
||||
def __init__(self, url: str = None):
|
||||
self.url = url or QDRANT_URL
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def client(self) -> QdrantClient:
|
||||
if self._client is None:
|
||||
self._client = QdrantClient(url=self.url)
|
||||
return self._client
|
||||
|
||||
async def ensure_collection(self, collection_name: str, vector_size: int = VECTOR_SIZE) -> bool:
|
||||
"""
|
||||
Ensure collection exists, create if needed.
|
||||
|
||||
Args:
|
||||
collection_name: Name of the collection
|
||||
vector_size: Dimension of vectors
|
||||
|
||||
Returns:
|
||||
True if collection exists/created
|
||||
"""
|
||||
try:
|
||||
collections = self.client.get_collections().collections
|
||||
collection_names = [c.name for c in collections]
|
||||
|
||||
if collection_name not in collection_names:
|
||||
self.client.create_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config=VectorParams(
|
||||
size=vector_size,
|
||||
distance=Distance.COSINE
|
||||
)
|
||||
)
|
||||
print(f"Created collection: {collection_name}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error ensuring collection: {e}")
|
||||
return False
|
||||
|
||||
async def upsert_points(self, collection_name: str, points: List[Dict]) -> int:
|
||||
"""
|
||||
Upsert points into collection.
|
||||
|
||||
Args:
|
||||
collection_name: Target collection
|
||||
points: List of {id, vector, payload}
|
||||
|
||||
Returns:
|
||||
Number of upserted points
|
||||
"""
|
||||
import uuid
|
||||
|
||||
if not points:
|
||||
return 0
|
||||
|
||||
qdrant_points = []
|
||||
for p in points:
|
||||
# Convert string ID to UUID for Qdrant compatibility
|
||||
point_id = p["id"]
|
||||
if isinstance(point_id, str):
|
||||
# Use uuid5 with DNS namespace for deterministic UUID from string
|
||||
point_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, point_id))
|
||||
|
||||
qdrant_points.append(
|
||||
PointStruct(
|
||||
id=point_id,
|
||||
vector=p["vector"],
|
||||
payload={**p.get("payload", {}), "original_id": p["id"]} # Keep original ID in payload
|
||||
)
|
||||
)
|
||||
|
||||
self.client.upsert(collection_name=collection_name, points=qdrant_points)
|
||||
return len(qdrant_points)
|
||||
|
||||
async def search(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_vector: List[float],
|
||||
filter_conditions: Optional[Dict] = None,
|
||||
limit: int = 10
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Semantic search in collection.
|
||||
|
||||
Args:
|
||||
collection_name: Collection to search
|
||||
query_vector: Query embedding
|
||||
filter_conditions: Optional filters (key: value pairs)
|
||||
limit: Max results
|
||||
|
||||
Returns:
|
||||
List of matching points with scores
|
||||
"""
|
||||
query_filter = None
|
||||
if filter_conditions:
|
||||
must_conditions = [
|
||||
FieldCondition(key=k, match=MatchValue(value=v))
|
||||
for k, v in filter_conditions.items()
|
||||
]
|
||||
query_filter = Filter(must=must_conditions)
|
||||
|
||||
results = self.client.search(
|
||||
collection_name=collection_name,
|
||||
query_vector=query_vector,
|
||||
query_filter=query_filter,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"id": str(r.id),
|
||||
"score": r.score,
|
||||
"payload": r.payload
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
|
||||
async def get_stats(self, collection_name: str) -> Dict:
|
||||
"""Get collection statistics."""
|
||||
try:
|
||||
info = self.client.get_collection(collection_name)
|
||||
return {
|
||||
"name": collection_name,
|
||||
"vectors_count": info.vectors_count,
|
||||
"points_count": info.points_count,
|
||||
"status": info.status.value
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": str(e), "name": collection_name}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# NiBiS RAG Search (for Klausurkorrektur Module)
|
||||
# =============================================================================
|
||||
|
||||
async def search_nibis_eh(
|
||||
query_embedding: List[float],
|
||||
year: Optional[int] = None,
|
||||
subject: Optional[str] = None,
|
||||
niveau: Optional[str] = None,
|
||||
limit: int = 5
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Search in NiBiS Erwartungshorizonte (public, pre-indexed data).
|
||||
|
||||
Unlike search_eh(), this searches in the public NiBiS collection
|
||||
and returns plaintext (not encrypted).
|
||||
|
||||
Args:
|
||||
query_embedding: Query vector
|
||||
year: Optional year filter (2016, 2017, 2024, 2025)
|
||||
subject: Optional subject filter
|
||||
niveau: Optional niveau filter (eA, gA)
|
||||
limit: Max results
|
||||
|
||||
Returns:
|
||||
List of matching chunks with metadata
|
||||
"""
|
||||
client = get_qdrant_client()
|
||||
collection = "bp_nibis_eh"
|
||||
|
||||
# Build filter
|
||||
must_conditions = []
|
||||
|
||||
if year:
|
||||
must_conditions.append(
|
||||
FieldCondition(key="year", match=MatchValue(value=year))
|
||||
)
|
||||
if subject:
|
||||
must_conditions.append(
|
||||
FieldCondition(key="subject", match=MatchValue(value=subject))
|
||||
)
|
||||
if niveau:
|
||||
must_conditions.append(
|
||||
FieldCondition(key="niveau", match=MatchValue(value=niveau))
|
||||
)
|
||||
|
||||
query_filter = Filter(must=must_conditions) if must_conditions else None
|
||||
|
||||
try:
|
||||
results = client.search(
|
||||
collection_name=collection,
|
||||
query_vector=query_embedding,
|
||||
query_filter=query_filter,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"id": str(r.id),
|
||||
"score": r.score,
|
||||
"text": r.payload.get("text", ""),
|
||||
"year": r.payload.get("year"),
|
||||
"subject": r.payload.get("subject"),
|
||||
"niveau": r.payload.get("niveau"),
|
||||
"task_number": r.payload.get("task_number"),
|
||||
"doc_type": r.payload.get("doc_type"),
|
||||
"variant": r.payload.get("variant"),
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"NiBiS search error: {e}")
|
||||
return []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Legal Templates RAG Search (for Document Generator)
|
||||
# =============================================================================
|
||||
|
||||
LEGAL_TEMPLATES_COLLECTION = "bp_legal_templates"
|
||||
LEGAL_TEMPLATES_VECTOR_SIZE = 1024 # BGE-M3
|
||||
|
||||
|
||||
async def init_legal_templates_collection() -> bool:
|
||||
"""Initialize Qdrant collection for legal templates if not exists."""
|
||||
try:
|
||||
client = get_qdrant_client()
|
||||
collections = client.get_collections().collections
|
||||
collection_names = [c.name for c in collections]
|
||||
|
||||
if LEGAL_TEMPLATES_COLLECTION not in collection_names:
|
||||
client.create_collection(
|
||||
collection_name=LEGAL_TEMPLATES_COLLECTION,
|
||||
vectors_config=VectorParams(
|
||||
size=LEGAL_TEMPLATES_VECTOR_SIZE,
|
||||
distance=Distance.COSINE
|
||||
)
|
||||
)
|
||||
print(f"Created Qdrant collection: {LEGAL_TEMPLATES_COLLECTION}")
|
||||
else:
|
||||
print(f"Qdrant collection {LEGAL_TEMPLATES_COLLECTION} already exists")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to initialize legal templates collection: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def search_legal_templates(
|
||||
query_embedding: List[float],
|
||||
template_type: Optional[str] = None,
|
||||
license_types: Optional[List[str]] = None,
|
||||
language: Optional[str] = None,
|
||||
jurisdiction: Optional[str] = None,
|
||||
attribution_required: Optional[bool] = None,
|
||||
limit: int = 10
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Search in legal templates collection for document generation.
|
||||
|
||||
Args:
|
||||
query_embedding: Query vector (1024 dimensions, BGE-M3)
|
||||
template_type: Filter by template type (privacy_policy, terms_of_service, etc.)
|
||||
license_types: Filter by license types (cc0, mit, cc_by_4, etc.)
|
||||
language: Filter by language (de, en)
|
||||
jurisdiction: Filter by jurisdiction (DE, EU, US, etc.)
|
||||
attribution_required: Filter by attribution requirement
|
||||
limit: Max results
|
||||
|
||||
Returns:
|
||||
List of matching template chunks with full metadata
|
||||
"""
|
||||
client = get_qdrant_client()
|
||||
|
||||
# Build filter conditions
|
||||
must_conditions = []
|
||||
|
||||
if template_type:
|
||||
must_conditions.append(
|
||||
FieldCondition(key="template_type", match=MatchValue(value=template_type))
|
||||
)
|
||||
|
||||
if language:
|
||||
must_conditions.append(
|
||||
FieldCondition(key="language", match=MatchValue(value=language))
|
||||
)
|
||||
|
||||
if jurisdiction:
|
||||
must_conditions.append(
|
||||
FieldCondition(key="jurisdiction", match=MatchValue(value=jurisdiction))
|
||||
)
|
||||
|
||||
if attribution_required is not None:
|
||||
must_conditions.append(
|
||||
FieldCondition(key="attribution_required", match=MatchValue(value=attribution_required))
|
||||
)
|
||||
|
||||
# License type filter (OR condition)
|
||||
should_conditions = []
|
||||
if license_types:
|
||||
for license_type in license_types:
|
||||
should_conditions.append(
|
||||
FieldCondition(key="license_id", match=MatchValue(value=license_type))
|
||||
)
|
||||
|
||||
# Construct filter
|
||||
query_filter = None
|
||||
if must_conditions or should_conditions:
|
||||
filter_args = {}
|
||||
if must_conditions:
|
||||
filter_args["must"] = must_conditions
|
||||
if should_conditions:
|
||||
filter_args["should"] = should_conditions
|
||||
query_filter = Filter(**filter_args)
|
||||
|
||||
try:
|
||||
results = client.search(
|
||||
collection_name=LEGAL_TEMPLATES_COLLECTION,
|
||||
query_vector=query_embedding,
|
||||
query_filter=query_filter,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"id": str(r.id),
|
||||
"score": r.score,
|
||||
"text": r.payload.get("text", ""),
|
||||
"document_title": r.payload.get("document_title"),
|
||||
"template_type": r.payload.get("template_type"),
|
||||
"clause_category": r.payload.get("clause_category"),
|
||||
"language": r.payload.get("language"),
|
||||
"jurisdiction": r.payload.get("jurisdiction"),
|
||||
"license_id": r.payload.get("license_id"),
|
||||
"license_name": r.payload.get("license_name"),
|
||||
"license_url": r.payload.get("license_url"),
|
||||
"attribution_required": r.payload.get("attribution_required"),
|
||||
"attribution_text": r.payload.get("attribution_text"),
|
||||
"source_name": r.payload.get("source_name"),
|
||||
"source_url": r.payload.get("source_url"),
|
||||
"source_repo": r.payload.get("source_repo"),
|
||||
"placeholders": r.payload.get("placeholders", []),
|
||||
"is_complete_document": r.payload.get("is_complete_document"),
|
||||
"is_modular": r.payload.get("is_modular"),
|
||||
"requires_customization": r.payload.get("requires_customization"),
|
||||
"output_allowed": r.payload.get("output_allowed"),
|
||||
"modification_allowed": r.payload.get("modification_allowed"),
|
||||
"distortion_prohibited": r.payload.get("distortion_prohibited"),
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"Legal templates search error: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def get_legal_templates_stats() -> Dict:
|
||||
"""Get statistics for the legal templates collection."""
|
||||
try:
|
||||
client = get_qdrant_client()
|
||||
info = client.get_collection(LEGAL_TEMPLATES_COLLECTION)
|
||||
|
||||
# Count by template type
|
||||
template_types = ["privacy_policy", "terms_of_service", "cookie_banner",
|
||||
"impressum", "widerruf", "dpa", "sla", "agb"]
|
||||
type_counts = {}
|
||||
for ttype in template_types:
|
||||
result = client.count(
|
||||
collection_name=LEGAL_TEMPLATES_COLLECTION,
|
||||
count_filter=Filter(
|
||||
must=[FieldCondition(key="template_type", match=MatchValue(value=ttype))]
|
||||
)
|
||||
)
|
||||
if result.count > 0:
|
||||
type_counts[ttype] = result.count
|
||||
|
||||
# Count by language
|
||||
lang_counts = {}
|
||||
for lang in ["de", "en"]:
|
||||
result = client.count(
|
||||
collection_name=LEGAL_TEMPLATES_COLLECTION,
|
||||
count_filter=Filter(
|
||||
must=[FieldCondition(key="language", match=MatchValue(value=lang))]
|
||||
)
|
||||
)
|
||||
lang_counts[lang] = result.count
|
||||
|
||||
# Count by license
|
||||
license_counts = {}
|
||||
for license_id in ["cc0", "mit", "cc_by_4", "public_domain", "unlicense"]:
|
||||
result = client.count(
|
||||
collection_name=LEGAL_TEMPLATES_COLLECTION,
|
||||
count_filter=Filter(
|
||||
must=[FieldCondition(key="license_id", match=MatchValue(value=license_id))]
|
||||
)
|
||||
)
|
||||
if result.count > 0:
|
||||
license_counts[license_id] = result.count
|
||||
|
||||
return {
|
||||
"collection": LEGAL_TEMPLATES_COLLECTION,
|
||||
"vectors_count": info.vectors_count,
|
||||
"points_count": info.points_count,
|
||||
"status": info.status.value,
|
||||
"template_types": type_counts,
|
||||
"languages": lang_counts,
|
||||
"licenses": license_counts,
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": str(e), "collection": LEGAL_TEMPLATES_COLLECTION}
|
||||
|
||||
|
||||
async def delete_legal_templates_by_source(source_name: str) -> int:
|
||||
"""
|
||||
Delete all legal template chunks from a specific source.
|
||||
|
||||
Args:
|
||||
source_name: Name of the source to delete
|
||||
|
||||
Returns:
|
||||
Number of deleted points
|
||||
"""
|
||||
client = get_qdrant_client()
|
||||
|
||||
# Count first
|
||||
count_result = client.count(
|
||||
collection_name=LEGAL_TEMPLATES_COLLECTION,
|
||||
count_filter=Filter(
|
||||
must=[FieldCondition(key="source_name", match=MatchValue(value=source_name))]
|
||||
)
|
||||
)
|
||||
|
||||
# Delete by filter
|
||||
client.delete(
|
||||
collection_name=LEGAL_TEMPLATES_COLLECTION,
|
||||
points_selector=Filter(
|
||||
must=[FieldCondition(key="source_name", match=MatchValue(value=source_name))]
|
||||
)
|
||||
)
|
||||
|
||||
return count_result.count
|
||||
# Core client & BYOEH functions
|
||||
from qdrant_core import ( # noqa: F401
|
||||
QDRANT_URL,
|
||||
COLLECTION_NAME,
|
||||
VECTOR_SIZE,
|
||||
get_qdrant_client,
|
||||
init_qdrant_collection,
|
||||
index_eh_chunks,
|
||||
search_eh,
|
||||
delete_eh_vectors,
|
||||
get_collection_info,
|
||||
)
|
||||
|
||||
# Class-based service
|
||||
from qdrant_class import QdrantService # noqa: F401
|
||||
|
||||
# NiBiS search
|
||||
from qdrant_nibis import search_nibis_eh # noqa: F401
|
||||
|
||||
# Legal templates
|
||||
from qdrant_legal import ( # noqa: F401
|
||||
LEGAL_TEMPLATES_COLLECTION,
|
||||
LEGAL_TEMPLATES_VECTOR_SIZE,
|
||||
init_legal_templates_collection,
|
||||
search_legal_templates,
|
||||
get_legal_templates_stats,
|
||||
delete_legal_templates_by_source,
|
||||
)
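# Illustrative usage sketch, not part of the diff: assuming this barrel file is
# qdrant_service.py, existing callers keep their old import path while the
# implementations now live in qdrant_legal.
import asyncio

from qdrant_service import (
    get_legal_templates_stats,
    delete_legal_templates_by_source,
)


async def prune_source(source: str) -> None:
    deleted = await delete_legal_templates_by_source(source)
    stats = await get_legal_templates_stats()
    print(f"Removed {deleted} chunks; {stats.get('points_count')} points remain")


# asyncio.run(prune_source("example-source"))  # hypothetical source name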
|
||||
|
||||
@@ -1,625 +1,31 @@
|
||||
"""
|
||||
Training API - Endpoints for managing AI training jobs
|
||||
Training API — barrel re-export.
|
||||
|
||||
Provides endpoints for:
|
||||
- Starting/stopping training jobs
|
||||
- Monitoring training progress
|
||||
- Managing model versions
|
||||
- Configuring training parameters
|
||||
- SSE streaming for real-time metrics
|
||||
|
||||
Phase 2.2: Server-Sent Events for live progress
|
||||
The actual code lives in:
|
||||
- training_models.py (enums, Pydantic models, in-memory state)
|
||||
- training_simulation.py (simulate_training_progress, SSE generators)
|
||||
- training_routes.py (FastAPI router + all endpoints)
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Dict, Any
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ENUMS & MODELS
|
||||
# ============================================================================
|
||||
|
||||
class TrainingStatus(str, Enum):
|
||||
QUEUED = "queued"
|
||||
PREPARING = "preparing"
|
||||
TRAINING = "training"
|
||||
VALIDATING = "validating"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
PAUSED = "paused"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class ModelType(str, Enum):
|
||||
ZEUGNIS = "zeugnis"
|
||||
KLAUSUR = "klausur"
|
||||
GENERAL = "general"
|
||||
|
||||
|
||||
# Request/Response Models
|
||||
class TrainingConfig(BaseModel):
|
||||
"""Configuration for a training job."""
|
||||
name: str = Field(..., description="Name for the training job")
|
||||
model_type: ModelType = Field(ModelType.ZEUGNIS, description="Type of model to train")
|
||||
bundeslaender: List[str] = Field(..., description="List of Bundesland codes to include")
|
||||
batch_size: int = Field(16, ge=1, le=128)
|
||||
learning_rate: float = Field(0.00005, ge=0.000001, le=0.1)
|
||||
epochs: int = Field(10, ge=1, le=100)
|
||||
warmup_steps: int = Field(500, ge=0, le=10000)
|
||||
weight_decay: float = Field(0.01, ge=0, le=1)
|
||||
gradient_accumulation: int = Field(4, ge=1, le=32)
|
||||
mixed_precision: bool = Field(True, description="Use FP16 mixed precision training")
|
||||
|
||||
|
||||
class TrainingMetrics(BaseModel):
|
||||
"""Metrics from a training job."""
|
||||
precision: float = 0.0
|
||||
recall: float = 0.0
|
||||
f1_score: float = 0.0
|
||||
accuracy: float = 0.0
|
||||
loss_history: List[float] = []
|
||||
val_loss_history: List[float] = []
|
||||
|
||||
|
||||
class TrainingJob(BaseModel):
|
||||
"""A training job with full details."""
|
||||
id: str
|
||||
name: str
|
||||
model_type: ModelType
|
||||
status: TrainingStatus
|
||||
progress: float
|
||||
current_epoch: int
|
||||
total_epochs: int
|
||||
loss: float
|
||||
val_loss: float
|
||||
learning_rate: float
|
||||
documents_processed: int
|
||||
total_documents: int
|
||||
started_at: Optional[datetime]
|
||||
estimated_completion: Optional[datetime]
|
||||
completed_at: Optional[datetime]
|
||||
error_message: Optional[str]
|
||||
metrics: TrainingMetrics
|
||||
config: TrainingConfig
|
||||
|
||||
|
||||
class ModelVersion(BaseModel):
|
||||
"""A trained model version."""
|
||||
id: str
|
||||
job_id: str
|
||||
version: str
|
||||
model_type: ModelType
|
||||
created_at: datetime
|
||||
metrics: TrainingMetrics
|
||||
is_active: bool
|
||||
size_mb: float
|
||||
bundeslaender: List[str]
|
||||
|
||||
|
||||
class DatasetStats(BaseModel):
|
||||
"""Statistics about the training dataset."""
|
||||
total_documents: int
|
||||
total_chunks: int
|
||||
training_allowed: int
|
||||
by_bundesland: Dict[str, int]
|
||||
by_doc_type: Dict[str, int]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# IN-MEMORY STATE (Replace with database in production)
|
||||
# ============================================================================
|
||||
|
||||
@dataclass
|
||||
class TrainingState:
|
||||
"""Global training state."""
|
||||
jobs: Dict[str, dict] = field(default_factory=dict)
|
||||
model_versions: Dict[str, dict] = field(default_factory=dict)
|
||||
active_job_id: Optional[str] = None
|
||||
|
||||
|
||||
_state = TrainingState()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# HELPER FUNCTIONS
|
||||
# ============================================================================
|
||||
|
||||
async def simulate_training_progress(job_id: str):
|
||||
"""Simulate training progress (replace with actual training logic)."""
|
||||
global _state
|
||||
|
||||
if job_id not in _state.jobs:
|
||||
return
|
||||
|
||||
job = _state.jobs[job_id]
|
||||
job["status"] = TrainingStatus.TRAINING.value
|
||||
job["started_at"] = datetime.now().isoformat()
|
||||
|
||||
total_steps = job["total_epochs"] * 100 # Simulate 100 steps per epoch
|
||||
current_step = 0
|
||||
|
||||
while current_step < total_steps and job["status"] == TrainingStatus.TRAINING.value:
|
||||
# Update progress
|
||||
progress = (current_step / total_steps) * 100
|
||||
current_epoch = current_step // 100 + 1
|
||||
|
||||
# Simulate decreasing loss
|
||||
base_loss = 0.8 * (1 - progress / 100) + 0.1
|
||||
loss = base_loss + (0.05 * (0.5 - (current_step % 100) / 100))
|
||||
val_loss = loss * 1.1
|
||||
|
||||
# Update job state
|
||||
job["progress"] = progress
|
||||
job["current_epoch"] = min(current_epoch, job["total_epochs"])
|
||||
job["loss"] = round(loss, 4)
|
||||
job["val_loss"] = round(val_loss, 4)
|
||||
job["documents_processed"] = int((progress / 100) * job["total_documents"])
|
||||
|
||||
# Update metrics
|
||||
job["metrics"]["loss_history"].append(round(loss, 4))
|
||||
job["metrics"]["val_loss_history"].append(round(val_loss, 4))
|
||||
job["metrics"]["precision"] = round(0.5 + (progress / 200), 3)
|
||||
job["metrics"]["recall"] = round(0.45 + (progress / 200), 3)
|
||||
job["metrics"]["f1_score"] = round(0.47 + (progress / 200), 3)
|
||||
job["metrics"]["accuracy"] = round(0.6 + (progress / 250), 3)
|
||||
|
||||
# Keep only last 50 history points
|
||||
if len(job["metrics"]["loss_history"]) > 50:
|
||||
job["metrics"]["loss_history"] = job["metrics"]["loss_history"][-50:]
|
||||
job["metrics"]["val_loss_history"] = job["metrics"]["val_loss_history"][-50:]
|
||||
|
||||
# Estimate completion
|
||||
if progress > 0:
|
||||
elapsed = (datetime.now() - datetime.fromisoformat(job["started_at"])).total_seconds()
|
||||
remaining = (elapsed / progress) * (100 - progress)
|
||||
job["estimated_completion"] = (datetime.now() + timedelta(seconds=remaining)).isoformat()
|
||||
|
||||
current_step += 1
|
||||
await asyncio.sleep(0.5) # Simulate work
|
||||
|
||||
# Mark as completed
|
||||
if job["status"] == TrainingStatus.TRAINING.value:
|
||||
job["status"] = TrainingStatus.COMPLETED.value
|
||||
job["progress"] = 100
|
||||
job["completed_at"] = datetime.now().isoformat()
|
||||
|
||||
# Create model version
|
||||
version_id = str(uuid.uuid4())
|
||||
_state.model_versions[version_id] = {
|
||||
"id": version_id,
|
||||
"job_id": job_id,
|
||||
"version": f"v{len(_state.model_versions) + 1}.0",
|
||||
"model_type": job["model_type"],
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"metrics": job["metrics"],
|
||||
"is_active": True,
|
||||
"size_mb": 245.7,
|
||||
"bundeslaender": job["config"]["bundeslaender"],
|
||||
}
|
||||
|
||||
_state.active_job_id = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ROUTER
|
||||
# ============================================================================
|
||||
|
||||
router = APIRouter(prefix="/api/v1/admin/training", tags=["Training"])
|
||||
|
||||
|
||||
@router.get("/jobs", response_model=List[dict])
|
||||
async def list_training_jobs():
|
||||
"""Get all training jobs."""
|
||||
return list(_state.jobs.values())
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}", response_model=dict)
|
||||
async def get_training_job(job_id: str):
|
||||
"""Get details for a specific training job."""
|
||||
if job_id not in _state.jobs:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
return _state.jobs[job_id]
|
||||
|
||||
|
||||
@router.post("/jobs", response_model=dict)
|
||||
async def create_training_job(config: TrainingConfig, background_tasks: BackgroundTasks):
|
||||
"""Create and start a new training job."""
|
||||
global _state
|
||||
|
||||
# Check if there's already an active job
|
||||
if _state.active_job_id:
|
||||
active_job = _state.jobs.get(_state.active_job_id)
|
||||
if active_job and active_job["status"] in [
|
||||
TrainingStatus.TRAINING.value,
|
||||
TrainingStatus.PREPARING.value,
|
||||
]:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Another training job is already running"
|
||||
)
|
||||
|
||||
# Create job
|
||||
job_id = str(uuid.uuid4())
|
||||
job = {
|
||||
"id": job_id,
|
||||
"name": config.name,
|
||||
"model_type": config.model_type.value,
|
||||
"status": TrainingStatus.QUEUED.value,
|
||||
"progress": 0,
|
||||
"current_epoch": 0,
|
||||
"total_epochs": config.epochs,
|
||||
"loss": 1.0,
|
||||
"val_loss": 1.0,
|
||||
"learning_rate": config.learning_rate,
|
||||
"documents_processed": 0,
|
||||
"total_documents": len(config.bundeslaender) * 50, # Estimate
|
||||
"started_at": None,
|
||||
"estimated_completion": None,
|
||||
"completed_at": None,
|
||||
"error_message": None,
|
||||
"metrics": {
|
||||
"precision": 0.0,
|
||||
"recall": 0.0,
|
||||
"f1_score": 0.0,
|
||||
"accuracy": 0.0,
|
||||
"loss_history": [],
|
||||
"val_loss_history": [],
|
||||
},
|
||||
"config": config.dict(),
|
||||
}
|
||||
|
||||
_state.jobs[job_id] = job
|
||||
_state.active_job_id = job_id
|
||||
|
||||
# Start training in background
|
||||
background_tasks.add_task(simulate_training_progress, job_id)
|
||||
|
||||
return {"id": job_id, "status": "queued", "message": "Training job created"}
|
||||
|
||||
|
||||
@router.post("/jobs/{job_id}/pause", response_model=dict)
|
||||
async def pause_training_job(job_id: str):
|
||||
"""Pause a running training job."""
|
||||
if job_id not in _state.jobs:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
job = _state.jobs[job_id]
|
||||
if job["status"] != TrainingStatus.TRAINING.value:
|
||||
raise HTTPException(status_code=400, detail="Job is not running")
|
||||
|
||||
job["status"] = TrainingStatus.PAUSED.value
|
||||
return {"success": True, "message": "Training paused"}
|
||||
|
||||
|
||||
@router.post("/jobs/{job_id}/resume", response_model=dict)
|
||||
async def resume_training_job(job_id: str, background_tasks: BackgroundTasks):
|
||||
"""Resume a paused training job."""
|
||||
if job_id not in _state.jobs:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
job = _state.jobs[job_id]
|
||||
if job["status"] != TrainingStatus.PAUSED.value:
|
||||
raise HTTPException(status_code=400, detail="Job is not paused")
|
||||
|
||||
job["status"] = TrainingStatus.TRAINING.value
|
||||
_state.active_job_id = job_id
|
||||
background_tasks.add_task(simulate_training_progress, job_id)
|
||||
|
||||
return {"success": True, "message": "Training resumed"}
|
||||
|
||||
|
||||
@router.post("/jobs/{job_id}/cancel", response_model=dict)
|
||||
async def cancel_training_job(job_id: str):
|
||||
"""Cancel a training job."""
|
||||
if job_id not in _state.jobs:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
job = _state.jobs[job_id]
|
||||
job["status"] = TrainingStatus.CANCELLED.value
|
||||
job["completed_at"] = datetime.now().isoformat()
|
||||
|
||||
if _state.active_job_id == job_id:
|
||||
_state.active_job_id = None
|
||||
|
||||
return {"success": True, "message": "Training cancelled"}
|
||||
|
||||
|
||||
@router.delete("/jobs/{job_id}", response_model=dict)
|
||||
async def delete_training_job(job_id: str):
|
||||
"""Delete a training job."""
|
||||
if job_id not in _state.jobs:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
job = _state.jobs[job_id]
|
||||
if job["status"] == TrainingStatus.TRAINING.value:
|
||||
raise HTTPException(status_code=400, detail="Cannot delete running job")
|
||||
|
||||
del _state.jobs[job_id]
|
||||
return {"success": True, "message": "Job deleted"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# MODEL VERSIONS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/models", response_model=List[dict])
|
||||
async def list_model_versions():
|
||||
"""Get all trained model versions."""
|
||||
return list(_state.model_versions.values())
|
||||
|
||||
|
||||
@router.get("/models/{version_id}", response_model=dict)
|
||||
async def get_model_version(version_id: str):
|
||||
"""Get details for a specific model version."""
|
||||
if version_id not in _state.model_versions:
|
||||
raise HTTPException(status_code=404, detail="Model version not found")
|
||||
return _state.model_versions[version_id]
|
||||
|
||||
|
||||
@router.post("/models/{version_id}/activate", response_model=dict)
|
||||
async def activate_model_version(version_id: str):
|
||||
"""Set a model version as active."""
|
||||
if version_id not in _state.model_versions:
|
||||
raise HTTPException(status_code=404, detail="Model version not found")
|
||||
|
||||
# Deactivate all other versions of same type
|
||||
model = _state.model_versions[version_id]
|
||||
for v in _state.model_versions.values():
|
||||
if v["model_type"] == model["model_type"]:
|
||||
v["is_active"] = False
|
||||
|
||||
model["is_active"] = True
|
||||
return {"success": True, "message": "Model activated"}
|
||||
|
||||
|
||||
@router.delete("/models/{version_id}", response_model=dict)
|
||||
async def delete_model_version(version_id: str):
|
||||
"""Delete a model version."""
|
||||
if version_id not in _state.model_versions:
|
||||
raise HTTPException(status_code=404, detail="Model version not found")
|
||||
|
||||
model = _state.model_versions[version_id]
|
||||
if model["is_active"]:
|
||||
raise HTTPException(status_code=400, detail="Cannot delete active model")
|
||||
|
||||
del _state.model_versions[version_id]
|
||||
return {"success": True, "message": "Model deleted"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DATASET STATS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/dataset/stats", response_model=dict)
|
||||
async def get_dataset_stats():
|
||||
"""Get statistics about the training dataset."""
|
||||
# Get stats from zeugnis sources
|
||||
from metrics_db import get_zeugnis_stats
|
||||
|
||||
zeugnis_stats = await get_zeugnis_stats()
|
||||
|
||||
return {
|
||||
"total_documents": zeugnis_stats.get("total_documents", 0),
|
||||
"total_chunks": zeugnis_stats.get("total_documents", 0) * 12, # Estimate ~12 chunks per doc
|
||||
"training_allowed": zeugnis_stats.get("training_allowed_documents", 0),
|
||||
"by_bundesland": {
|
||||
bl["bundesland"]: bl.get("doc_count", 0)
|
||||
for bl in zeugnis_stats.get("per_bundesland", [])
|
||||
},
|
||||
"by_doc_type": {
|
||||
"verordnung": 150,
|
||||
"schulordnung": 80,
|
||||
"handreichung": 45,
|
||||
"erlass": 30,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TRAINING STATUS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/status", response_model=dict)
|
||||
async def get_training_status():
|
||||
"""Get overall training system status."""
|
||||
active_job = None
|
||||
if _state.active_job_id and _state.active_job_id in _state.jobs:
|
||||
active_job = _state.jobs[_state.active_job_id]
|
||||
|
||||
return {
|
||||
"is_training": _state.active_job_id is not None and active_job is not None and
|
||||
active_job["status"] == TrainingStatus.TRAINING.value,
|
||||
"active_job_id": _state.active_job_id,
|
||||
"total_jobs": len(_state.jobs),
|
||||
"completed_jobs": sum(
|
||||
1 for j in _state.jobs.values()
|
||||
if j["status"] == TrainingStatus.COMPLETED.value
|
||||
),
|
||||
"failed_jobs": sum(
|
||||
1 for j in _state.jobs.values()
|
||||
if j["status"] == TrainingStatus.FAILED.value
|
||||
),
|
||||
"model_versions": len(_state.model_versions),
|
||||
"active_models": sum(1 for m in _state.model_versions.values() if m["is_active"]),
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SERVER-SENT EVENTS (SSE) ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
async def training_metrics_generator(job_id: str, request: Request):
|
||||
"""
|
||||
SSE generator for streaming training metrics.
|
||||
|
||||
Yields JSON-encoded training status updates every 500ms.
|
||||
"""
|
||||
while True:
|
||||
# Check if client disconnected
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
|
||||
# Get job status
|
||||
if job_id not in _state.jobs:
|
||||
yield f"data: {json.dumps({'error': 'Job not found'})}\n\n"
|
||||
break
|
||||
|
||||
job = _state.jobs[job_id]
|
||||
|
||||
# Build metrics response
|
||||
metrics_data = {
|
||||
"job_id": job["id"],
|
||||
"status": job["status"],
|
||||
"progress": job["progress"],
|
||||
"current_epoch": job["current_epoch"],
|
||||
"total_epochs": job["total_epochs"],
|
||||
"current_step": int(job["progress"] * job["total_epochs"]),
|
||||
"total_steps": job["total_epochs"] * 100,
|
||||
"elapsed_time_ms": 0,
|
||||
"estimated_remaining_ms": 0,
|
||||
"metrics": {
|
||||
"loss": job["loss"],
|
||||
"val_loss": job["val_loss"],
|
||||
"accuracy": job["metrics"]["accuracy"],
|
||||
"learning_rate": job["learning_rate"]
|
||||
},
|
||||
"history": [
|
||||
{
|
||||
"epoch": i + 1,
|
||||
"step": (i + 1) * 10,
|
||||
"loss": loss,
|
||||
"val_loss": job["metrics"]["val_loss_history"][i] if i < len(job["metrics"]["val_loss_history"]) else None,
|
||||
"learning_rate": job["learning_rate"],
|
||||
"timestamp": 0
|
||||
}
|
||||
for i, loss in enumerate(job["metrics"]["loss_history"][-50:])
|
||||
]
|
||||
}
|
||||
|
||||
# Calculate elapsed time
|
||||
if job["started_at"]:
|
||||
started = datetime.fromisoformat(job["started_at"])
|
||||
metrics_data["elapsed_time_ms"] = int((datetime.now() - started).total_seconds() * 1000)
|
||||
|
||||
# Calculate remaining time
|
||||
if job["estimated_completion"]:
|
||||
estimated = datetime.fromisoformat(job["estimated_completion"])
|
||||
metrics_data["estimated_remaining_ms"] = max(0, int((estimated - datetime.now()).total_seconds() * 1000))
|
||||
|
||||
# Send SSE event
|
||||
yield f"data: {json.dumps(metrics_data)}\n\n"
|
||||
|
||||
# Check if job completed
|
||||
if job["status"] in [TrainingStatus.COMPLETED.value, TrainingStatus.FAILED.value, TrainingStatus.CANCELLED.value]:
|
||||
break
|
||||
|
||||
# Wait before next update
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
|
||||
@router.get("/metrics/stream")
|
||||
async def stream_training_metrics(job_id: str, request: Request):
|
||||
"""
|
||||
SSE endpoint for streaming training metrics.
|
||||
|
||||
Streams real-time training progress for a specific job.
|
||||
|
||||
Usage:
|
||||
const eventSource = new EventSource('/api/v1/admin/training/metrics/stream?job_id=xxx')
|
||||
eventSource.onmessage = (event) => {
|
||||
const data = JSON.parse(event.data)
|
||||
console.log(data.progress, data.metrics.loss)
|
||||
}
|
||||
"""
|
||||
if job_id not in _state.jobs:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
return StreamingResponse(
|
||||
training_metrics_generator(job_id, request),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no" # Disable nginx buffering
|
||||
}
|
||||
)
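# Illustrative consumer sketch, not part of the diff: reading the same stream from
# Python with httpx instead of the browser EventSource shown in the docstring above.
# Host and port are assumptions; each event arrives as a "data: <json>" line.
import json
import httpx


async def follow_metrics(job_id: str) -> None:
    url = "http://localhost:8000/api/v1/admin/training/metrics/stream"
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("GET", url, params={"job_id": job_id}) as response:
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    event = json.loads(line[len("data: "):])
                    print(event.get("progress"), event.get("metrics", {}).get("loss"))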
|
||||
|
||||
|
||||
async def batch_ocr_progress_generator(images_count: int, request: Request):
|
||||
"""
|
||||
SSE generator for batch OCR progress simulation.
|
||||
|
||||
In production, this would integrate with actual OCR processing.
|
||||
"""
|
||||
import random
|
||||
|
||||
for i in range(images_count):
|
||||
# Check if client disconnected
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
|
||||
# Simulate processing time
|
||||
await asyncio.sleep(random.uniform(0.3, 0.8))
|
||||
|
||||
progress_data = {
|
||||
"type": "progress",
|
||||
"current": i + 1,
|
||||
"total": images_count,
|
||||
"progress_percent": ((i + 1) / images_count) * 100,
|
||||
"elapsed_ms": (i + 1) * 500,
|
||||
"estimated_remaining_ms": (images_count - i - 1) * 500,
|
||||
"result": {
|
||||
"text": f"Sample recognized text for image {i + 1}",
|
||||
"confidence": round(random.uniform(0.7, 0.98), 2),
|
||||
"processing_time_ms": random.randint(200, 600),
|
||||
"from_cache": random.random() < 0.2
|
||||
}
|
||||
}
|
||||
|
||||
yield f"data: {json.dumps(progress_data)}\n\n"
|
||||
|
||||
# Send completion event
|
||||
yield f"data: {json.dumps({'type': 'complete', 'total_time_ms': images_count * 500, 'processed_count': images_count})}\n\n"
|
||||
|
||||
|
||||
@router.get("/ocr/stream")
|
||||
async def stream_batch_ocr(images_count: int, request: Request):
|
||||
"""
|
||||
SSE endpoint for streaming batch OCR progress.
|
||||
|
||||
Simulates batch OCR processing with progress updates.
|
||||
In production, integrate with actual TrOCR batch processing.
|
||||
|
||||
Args:
|
||||
images_count: Number of images to process
|
||||
|
||||
Usage:
|
||||
const eventSource = new EventSource('/api/v1/admin/training/ocr/stream?images_count=10')
|
||||
eventSource.onmessage = (event) => {
|
||||
const data = JSON.parse(event.data)
|
||||
if (data.type === 'progress') {
|
||||
console.log(`${data.current}/${data.total}`)
|
||||
}
|
||||
}
|
||||
"""
|
||||
if images_count < 1 or images_count > 100:
|
||||
raise HTTPException(status_code=400, detail="images_count must be between 1 and 100")
|
||||
|
||||
return StreamingResponse(
|
||||
batch_ocr_progress_generator(images_count, request),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
# Models & enums
|
||||
from training_models import ( # noqa: F401
|
||||
TrainingStatus,
|
||||
ModelType,
|
||||
TrainingConfig,
|
||||
TrainingMetrics,
|
||||
TrainingJob,
|
||||
ModelVersion,
|
||||
DatasetStats,
|
||||
TrainingState,
|
||||
_state,
|
||||
)
|
||||
|
||||
# Simulation helpers
|
||||
from training_simulation import ( # noqa: F401
|
||||
simulate_training_progress,
|
||||
training_metrics_generator,
|
||||
batch_ocr_progress_generator,
|
||||
)
|
||||
|
||||
# Router
|
||||
from training_routes import router # noqa: F401
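# Illustrative wiring sketch, not part of the diff: because the barrel re-exports
# `router`, an app that previously imported it from training_api needs no changes.
# The app module below is an assumption.
from fastapi import FastAPI
from training_api import router as training_router

app = FastAPI()
app.include_router(training_router)  # routes stay under /api/v1/admin/training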
|
||||
|
||||
@@ -0,0 +1,118 @@
|
||||
"""
|
||||
Training API — enums, request/response models, and in-memory state.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass, field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ENUMS
|
||||
# ============================================================================
|
||||
|
||||
class TrainingStatus(str, Enum):
|
||||
QUEUED = "queued"
|
||||
PREPARING = "preparing"
|
||||
TRAINING = "training"
|
||||
VALIDATING = "validating"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
PAUSED = "paused"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class ModelType(str, Enum):
|
||||
ZEUGNIS = "zeugnis"
|
||||
KLAUSUR = "klausur"
|
||||
GENERAL = "general"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# REQUEST/RESPONSE MODELS
|
||||
# ============================================================================
|
||||
|
||||
class TrainingConfig(BaseModel):
|
||||
"""Configuration for a training job."""
|
||||
name: str = Field(..., description="Name for the training job")
|
||||
model_type: ModelType = Field(ModelType.ZEUGNIS, description="Type of model to train")
|
||||
bundeslaender: List[str] = Field(..., description="List of Bundesland codes to include")
|
||||
batch_size: int = Field(16, ge=1, le=128)
|
||||
learning_rate: float = Field(0.00005, ge=0.000001, le=0.1)
|
||||
epochs: int = Field(10, ge=1, le=100)
|
||||
warmup_steps: int = Field(500, ge=0, le=10000)
|
||||
weight_decay: float = Field(0.01, ge=0, le=1)
|
||||
gradient_accumulation: int = Field(4, ge=1, le=32)
|
||||
mixed_precision: bool = Field(True, description="Use FP16 mixed precision training")
|
||||
|
||||
|
||||
class TrainingMetrics(BaseModel):
|
||||
"""Metrics from a training job."""
|
||||
precision: float = 0.0
|
||||
recall: float = 0.0
|
||||
f1_score: float = 0.0
|
||||
accuracy: float = 0.0
|
||||
loss_history: List[float] = []
|
||||
val_loss_history: List[float] = []
|
||||
|
||||
|
||||
class TrainingJob(BaseModel):
|
||||
"""A training job with full details."""
|
||||
id: str
|
||||
name: str
|
||||
model_type: ModelType
|
||||
status: TrainingStatus
|
||||
progress: float
|
||||
current_epoch: int
|
||||
total_epochs: int
|
||||
loss: float
|
||||
val_loss: float
|
||||
learning_rate: float
|
||||
documents_processed: int
|
||||
total_documents: int
|
||||
started_at: Optional[datetime]
|
||||
estimated_completion: Optional[datetime]
|
||||
completed_at: Optional[datetime]
|
||||
error_message: Optional[str]
|
||||
metrics: TrainingMetrics
|
||||
config: TrainingConfig
|
||||
|
||||
|
||||
class ModelVersion(BaseModel):
|
||||
"""A trained model version."""
|
||||
id: str
|
||||
job_id: str
|
||||
version: str
|
||||
model_type: ModelType
|
||||
created_at: datetime
|
||||
metrics: TrainingMetrics
|
||||
is_active: bool
|
||||
size_mb: float
|
||||
bundeslaender: List[str]
|
||||
|
||||
|
||||
class DatasetStats(BaseModel):
|
||||
"""Statistics about the training dataset."""
|
||||
total_documents: int
|
||||
total_chunks: int
|
||||
training_allowed: int
|
||||
by_bundesland: Dict[str, int]
|
||||
by_doc_type: Dict[str, int]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# IN-MEMORY STATE (Replace with database in production)
|
||||
# ============================================================================
|
||||
|
||||
@dataclass
|
||||
class TrainingState:
|
||||
"""Global training state."""
|
||||
jobs: Dict[str, dict] = field(default_factory=dict)
|
||||
model_versions: Dict[str, dict] = field(default_factory=dict)
|
||||
active_job_id: Optional[str] = None
|
||||
|
||||
|
||||
_state = TrainingState()
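# Illustrative sketch, not part of the diff: the Field bounds above reject invalid
# configs before any job state is created. Constructing a config directly shows the
# same behaviour (values below are made up).
from pydantic import ValidationError

try:
    TrainingConfig(name="demo", bundeslaender=["NI"], batch_size=999)
except ValidationError as exc:
    print(exc.errors()[0]["loc"])  # ('batch_size',) -- violates le=128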
|
||||
@@ -0,0 +1,303 @@
|
||||
"""
|
||||
Training API — FastAPI route handlers.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from training_models import (
|
||||
TrainingStatus,
|
||||
TrainingConfig,
|
||||
_state,
|
||||
)
|
||||
from training_simulation import (
|
||||
simulate_training_progress,
|
||||
training_metrics_generator,
|
||||
batch_ocr_progress_generator,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/admin/training", tags=["Training"])
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TRAINING JOBS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/jobs", response_model=List[dict])
|
||||
async def list_training_jobs():
|
||||
"""Get all training jobs."""
|
||||
return list(_state.jobs.values())
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}", response_model=dict)
|
||||
async def get_training_job(job_id: str):
|
||||
"""Get details for a specific training job."""
|
||||
if job_id not in _state.jobs:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
return _state.jobs[job_id]
|
||||
|
||||
|
||||
@router.post("/jobs", response_model=dict)
|
||||
async def create_training_job(config: TrainingConfig, background_tasks: BackgroundTasks):
|
||||
"""Create and start a new training job."""
|
||||
# Check if there's already an active job
|
||||
if _state.active_job_id:
|
||||
active_job = _state.jobs.get(_state.active_job_id)
|
||||
if active_job and active_job["status"] in [
|
||||
TrainingStatus.TRAINING.value,
|
||||
TrainingStatus.PREPARING.value,
|
||||
]:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Another training job is already running"
|
||||
)
|
||||
|
||||
# Create job
|
||||
job_id = str(uuid.uuid4())
|
||||
job = {
|
||||
"id": job_id,
|
||||
"name": config.name,
|
||||
"model_type": config.model_type.value,
|
||||
"status": TrainingStatus.QUEUED.value,
|
||||
"progress": 0,
|
||||
"current_epoch": 0,
|
||||
"total_epochs": config.epochs,
|
||||
"loss": 1.0,
|
||||
"val_loss": 1.0,
|
||||
"learning_rate": config.learning_rate,
|
||||
"documents_processed": 0,
|
||||
"total_documents": len(config.bundeslaender) * 50, # Estimate
|
||||
"started_at": None,
|
||||
"estimated_completion": None,
|
||||
"completed_at": None,
|
||||
"error_message": None,
|
||||
"metrics": {
|
||||
"precision": 0.0,
|
||||
"recall": 0.0,
|
||||
"f1_score": 0.0,
|
||||
"accuracy": 0.0,
|
||||
"loss_history": [],
|
||||
"val_loss_history": [],
|
||||
},
|
||||
"config": config.dict(),
|
||||
}
|
||||
|
||||
_state.jobs[job_id] = job
|
||||
_state.active_job_id = job_id
|
||||
|
||||
# Start training in background
|
||||
background_tasks.add_task(simulate_training_progress, job_id)
|
||||
|
||||
return {"id": job_id, "status": "queued", "message": "Training job created"}
|
||||
|
||||
|
||||
@router.post("/jobs/{job_id}/pause", response_model=dict)
|
||||
async def pause_training_job(job_id: str):
|
||||
"""Pause a running training job."""
|
||||
if job_id not in _state.jobs:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
job = _state.jobs[job_id]
|
||||
if job["status"] != TrainingStatus.TRAINING.value:
|
||||
raise HTTPException(status_code=400, detail="Job is not running")
|
||||
|
||||
job["status"] = TrainingStatus.PAUSED.value
|
||||
return {"success": True, "message": "Training paused"}
|
||||
|
||||
|
||||
@router.post("/jobs/{job_id}/resume", response_model=dict)
|
||||
async def resume_training_job(job_id: str, background_tasks: BackgroundTasks):
|
||||
"""Resume a paused training job."""
|
||||
if job_id not in _state.jobs:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
job = _state.jobs[job_id]
|
||||
if job["status"] != TrainingStatus.PAUSED.value:
|
||||
raise HTTPException(status_code=400, detail="Job is not paused")
|
||||
|
||||
job["status"] = TrainingStatus.TRAINING.value
|
||||
_state.active_job_id = job_id
|
||||
background_tasks.add_task(simulate_training_progress, job_id)
|
||||
|
||||
return {"success": True, "message": "Training resumed"}
|
||||
|
||||
|
||||
@router.post("/jobs/{job_id}/cancel", response_model=dict)
|
||||
async def cancel_training_job(job_id: str):
|
||||
"""Cancel a training job."""
|
||||
if job_id not in _state.jobs:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
job = _state.jobs[job_id]
|
||||
job["status"] = TrainingStatus.CANCELLED.value
|
||||
job["completed_at"] = datetime.now().isoformat()
|
||||
|
||||
if _state.active_job_id == job_id:
|
||||
_state.active_job_id = None
|
||||
|
||||
return {"success": True, "message": "Training cancelled"}
|
||||
|
||||
|
||||
@router.delete("/jobs/{job_id}", response_model=dict)
|
||||
async def delete_training_job(job_id: str):
|
||||
"""Delete a training job."""
|
||||
if job_id not in _state.jobs:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
job = _state.jobs[job_id]
|
||||
if job["status"] == TrainingStatus.TRAINING.value:
|
||||
raise HTTPException(status_code=400, detail="Cannot delete running job")
|
||||
|
||||
del _state.jobs[job_id]
|
||||
return {"success": True, "message": "Job deleted"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# MODEL VERSIONS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/models", response_model=List[dict])
|
||||
async def list_model_versions():
|
||||
"""Get all trained model versions."""
|
||||
return list(_state.model_versions.values())
|
||||
|
||||
|
||||
@router.get("/models/{version_id}", response_model=dict)
|
||||
async def get_model_version(version_id: str):
|
||||
"""Get details for a specific model version."""
|
||||
if version_id not in _state.model_versions:
|
||||
raise HTTPException(status_code=404, detail="Model version not found")
|
||||
return _state.model_versions[version_id]
|
||||
|
||||
|
||||
@router.post("/models/{version_id}/activate", response_model=dict)
|
||||
async def activate_model_version(version_id: str):
|
||||
"""Set a model version as active."""
|
||||
if version_id not in _state.model_versions:
|
||||
raise HTTPException(status_code=404, detail="Model version not found")
|
||||
|
||||
# Deactivate all other versions of same type
|
||||
model = _state.model_versions[version_id]
|
||||
for v in _state.model_versions.values():
|
||||
if v["model_type"] == model["model_type"]:
|
||||
v["is_active"] = False
|
||||
|
||||
model["is_active"] = True
|
||||
return {"success": True, "message": "Model activated"}
|
||||
|
||||
|
||||
@router.delete("/models/{version_id}", response_model=dict)
|
||||
async def delete_model_version(version_id: str):
|
||||
"""Delete a model version."""
|
||||
if version_id not in _state.model_versions:
|
||||
raise HTTPException(status_code=404, detail="Model version not found")
|
||||
|
||||
model = _state.model_versions[version_id]
|
||||
if model["is_active"]:
|
||||
raise HTTPException(status_code=400, detail="Cannot delete active model")
|
||||
|
||||
del _state.model_versions[version_id]
|
||||
return {"success": True, "message": "Model deleted"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DATASET STATS & STATUS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/dataset/stats", response_model=dict)
|
||||
async def get_dataset_stats():
|
||||
"""Get statistics about the training dataset."""
|
||||
from metrics_db import get_zeugnis_stats
|
||||
|
||||
zeugnis_stats = await get_zeugnis_stats()
|
||||
|
||||
return {
|
||||
"total_documents": zeugnis_stats.get("total_documents", 0),
|
||||
"total_chunks": zeugnis_stats.get("total_documents", 0) * 12,
|
||||
"training_allowed": zeugnis_stats.get("training_allowed_documents", 0),
|
||||
"by_bundesland": {
|
||||
bl["bundesland"]: bl.get("doc_count", 0)
|
||||
for bl in zeugnis_stats.get("per_bundesland", [])
|
||||
},
|
||||
"by_doc_type": {
|
||||
"verordnung": 150,
|
||||
"schulordnung": 80,
|
||||
"handreichung": 45,
|
||||
"erlass": 30,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.get("/status", response_model=dict)
|
||||
async def get_training_status():
|
||||
"""Get overall training system status."""
|
||||
active_job = None
|
||||
if _state.active_job_id and _state.active_job_id in _state.jobs:
|
||||
active_job = _state.jobs[_state.active_job_id]
|
||||
|
||||
return {
|
||||
"is_training": _state.active_job_id is not None and active_job is not None and
|
||||
active_job["status"] == TrainingStatus.TRAINING.value,
|
||||
"active_job_id": _state.active_job_id,
|
||||
"total_jobs": len(_state.jobs),
|
||||
"completed_jobs": sum(
|
||||
1 for j in _state.jobs.values()
|
||||
if j["status"] == TrainingStatus.COMPLETED.value
|
||||
),
|
||||
"failed_jobs": sum(
|
||||
1 for j in _state.jobs.values()
|
||||
if j["status"] == TrainingStatus.FAILED.value
|
||||
),
|
||||
"model_versions": len(_state.model_versions),
|
||||
"active_models": sum(1 for m in _state.model_versions.values() if m["is_active"]),
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SSE ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/metrics/stream")
|
||||
async def stream_training_metrics(job_id: str, request: Request):
|
||||
"""
|
||||
SSE endpoint for streaming training metrics.
|
||||
|
||||
Streams real-time training progress for a specific job.
|
||||
"""
|
||||
if job_id not in _state.jobs:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
return StreamingResponse(
|
||||
training_metrics_generator(job_id, request),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/ocr/stream")
|
||||
async def stream_batch_ocr(images_count: int, request: Request):
|
||||
"""
|
||||
SSE endpoint for streaming batch OCR progress.
|
||||
|
||||
Simulates batch OCR processing with progress updates.
|
||||
"""
|
||||
if images_count < 1 or images_count > 100:
|
||||
raise HTTPException(status_code=400, detail="images_count must be between 1 and 100")
|
||||
|
||||
return StreamingResponse(
|
||||
batch_ocr_progress_generator(images_count, request),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
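# Illustrative test sketch, not part of the diff: mounting the router on a throwaway
# FastAPI app and hitting the read-only status endpoint. (Creating a job here would
# also run the simulated background training synchronously under TestClient, so the
# sketch sticks to a GET endpoint.)
from fastapi import FastAPI
from fastapi.testclient import TestClient
from training_routes import router

app = FastAPI()
app.include_router(router)
client = TestClient(app)

response = client.get("/api/v1/admin/training/status")
assert response.status_code == 200
assert response.json()["total_jobs"] == 0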
|
||||
@@ -0,0 +1,190 @@
|
||||
"""
|
||||
Training API — simulation helper and SSE generators.
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from training_models import TrainingStatus, _state
|
||||
|
||||
|
||||
async def simulate_training_progress(job_id: str):
|
||||
"""Simulate training progress (replace with actual training logic)."""
|
||||
if job_id not in _state.jobs:
|
||||
return
|
||||
|
||||
job = _state.jobs[job_id]
|
||||
job["status"] = TrainingStatus.TRAINING.value
|
||||
job["started_at"] = datetime.now().isoformat()
|
||||
|
||||
total_steps = job["total_epochs"] * 100 # Simulate 100 steps per epoch
|
||||
current_step = 0
|
||||
|
||||
while current_step < total_steps and job["status"] == TrainingStatus.TRAINING.value:
|
||||
# Update progress
|
||||
progress = (current_step / total_steps) * 100
|
||||
current_epoch = current_step // 100 + 1
|
||||
|
||||
# Simulate decreasing loss
|
||||
base_loss = 0.8 * (1 - progress / 100) + 0.1
|
||||
loss = base_loss + (0.05 * (0.5 - (current_step % 100) / 100))
|
||||
val_loss = loss * 1.1
|
||||
|
||||
# Update job state
|
||||
job["progress"] = progress
|
||||
job["current_epoch"] = min(current_epoch, job["total_epochs"])
|
||||
job["loss"] = round(loss, 4)
|
||||
job["val_loss"] = round(val_loss, 4)
|
||||
job["documents_processed"] = int((progress / 100) * job["total_documents"])
|
||||
|
||||
# Update metrics
|
||||
job["metrics"]["loss_history"].append(round(loss, 4))
|
||||
job["metrics"]["val_loss_history"].append(round(val_loss, 4))
|
||||
job["metrics"]["precision"] = round(0.5 + (progress / 200), 3)
|
||||
job["metrics"]["recall"] = round(0.45 + (progress / 200), 3)
|
||||
job["metrics"]["f1_score"] = round(0.47 + (progress / 200), 3)
|
||||
job["metrics"]["accuracy"] = round(0.6 + (progress / 250), 3)
|
||||
|
||||
# Keep only last 50 history points
|
||||
if len(job["metrics"]["loss_history"]) > 50:
|
||||
job["metrics"]["loss_history"] = job["metrics"]["loss_history"][-50:]
|
||||
job["metrics"]["val_loss_history"] = job["metrics"]["val_loss_history"][-50:]
|
||||
|
||||
# Estimate completion
|
||||
if progress > 0:
|
||||
elapsed = (datetime.now() - datetime.fromisoformat(job["started_at"])).total_seconds()
|
||||
remaining = (elapsed / progress) * (100 - progress)
|
||||
job["estimated_completion"] = (datetime.now() + timedelta(seconds=remaining)).isoformat()
|
||||
|
||||
current_step += 1
|
||||
await asyncio.sleep(0.5) # Simulate work
|
||||
|
||||
# Mark as completed
|
||||
if job["status"] == TrainingStatus.TRAINING.value:
|
||||
job["status"] = TrainingStatus.COMPLETED.value
|
||||
job["progress"] = 100
|
||||
job["completed_at"] = datetime.now().isoformat()
|
||||
|
||||
# Create model version
|
||||
version_id = str(uuid.uuid4())
|
||||
_state.model_versions[version_id] = {
|
||||
"id": version_id,
|
||||
"job_id": job_id,
|
||||
"version": f"v{len(_state.model_versions) + 1}.0",
|
||||
"model_type": job["model_type"],
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"metrics": job["metrics"],
|
||||
"is_active": True,
|
||||
"size_mb": 245.7,
|
||||
"bundeslaender": job["config"]["bundeslaender"],
|
||||
}
|
||||
|
||||
_state.active_job_id = None
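# Worked values (by hand, not part of the diff): base_loss = 0.8 * (1 - progress/100) + 0.1
# gives 0.9 at 0 % progress, 0.5 at 50 %, and 0.1 at 100 %; the extra term adds a small
# within-epoch wobble of roughly +/-0.025, and val_loss tracks loss at a fixed 1.1x.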
|
||||
|
||||
|
||||
async def training_metrics_generator(job_id: str, request):
|
||||
"""
|
||||
SSE generator for streaming training metrics.
|
||||
|
||||
Yields JSON-encoded training status updates every 500ms.
|
||||
"""
|
||||
while True:
|
||||
# Check if client disconnected
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
|
||||
# Get job status
|
||||
if job_id not in _state.jobs:
|
||||
yield f"data: {json.dumps({'error': 'Job not found'})}\n\n"
|
||||
break
|
||||
|
||||
job = _state.jobs[job_id]
|
||||
|
||||
# Build metrics response
|
||||
metrics_data = {
|
||||
"job_id": job["id"],
|
||||
"status": job["status"],
|
||||
"progress": job["progress"],
|
||||
"current_epoch": job["current_epoch"],
|
||||
"total_epochs": job["total_epochs"],
|
||||
"current_step": int(job["progress"] * job["total_epochs"]),
|
||||
"total_steps": job["total_epochs"] * 100,
|
||||
"elapsed_time_ms": 0,
|
||||
"estimated_remaining_ms": 0,
|
||||
"metrics": {
|
||||
"loss": job["loss"],
|
||||
"val_loss": job["val_loss"],
|
||||
"accuracy": job["metrics"]["accuracy"],
|
||||
"learning_rate": job["learning_rate"]
|
||||
},
|
||||
"history": [
|
||||
{
|
||||
"epoch": i + 1,
|
||||
"step": (i + 1) * 10,
|
||||
"loss": loss,
|
||||
"val_loss": job["metrics"]["val_loss_history"][i] if i < len(job["metrics"]["val_loss_history"]) else None,
|
||||
"learning_rate": job["learning_rate"],
|
||||
"timestamp": 0
|
||||
}
|
||||
for i, loss in enumerate(job["metrics"]["loss_history"][-50:])
|
||||
]
|
||||
}
|
||||
|
||||
# Calculate elapsed time
|
||||
if job["started_at"]:
|
||||
started = datetime.fromisoformat(job["started_at"])
|
||||
metrics_data["elapsed_time_ms"] = int((datetime.now() - started).total_seconds() * 1000)
|
||||
|
||||
# Calculate remaining time
|
||||
if job["estimated_completion"]:
|
||||
estimated = datetime.fromisoformat(job["estimated_completion"])
|
||||
metrics_data["estimated_remaining_ms"] = max(0, int((estimated - datetime.now()).total_seconds() * 1000))
|
||||
|
||||
# Send SSE event
|
||||
yield f"data: {json.dumps(metrics_data)}\n\n"
|
||||
|
||||
# Check if job completed
|
||||
if job["status"] in [TrainingStatus.COMPLETED.value, TrainingStatus.FAILED.value, TrainingStatus.CANCELLED.value]:
|
||||
break
|
||||
|
||||
# Wait before next update
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
|
||||
async def batch_ocr_progress_generator(images_count: int, request):
|
||||
"""
|
||||
SSE generator for batch OCR progress simulation.
|
||||
|
||||
In production, this would integrate with actual OCR processing.
|
||||
"""
|
||||
import random
|
||||
|
||||
for i in range(images_count):
|
||||
# Check if client disconnected
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
|
||||
# Simulate processing time
|
||||
await asyncio.sleep(random.uniform(0.3, 0.8))
|
||||
|
||||
progress_data = {
|
||||
"type": "progress",
|
||||
"current": i + 1,
|
||||
"total": images_count,
|
||||
"progress_percent": ((i + 1) / images_count) * 100,
|
||||
"elapsed_ms": (i + 1) * 500,
|
||||
"estimated_remaining_ms": (images_count - i - 1) * 500,
|
||||
"result": {
|
||||
"text": f"Sample recognized text for image {i + 1}",
|
||||
"confidence": round(random.uniform(0.7, 0.98), 2),
|
||||
"processing_time_ms": random.randint(200, 600),
|
||||
"from_cache": random.random() < 0.2
|
||||
}
|
||||
}
|
||||
|
||||
yield f"data: {json.dumps(progress_data)}\n\n"
|
||||
|
||||
# Send completion event
|
||||
yield f"data: {json.dumps({'type': 'complete', 'total_time_ms': images_count * 500, 'processed_count': images_count})}\n\n"
|
||||
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
Zeugnis Crawler - Start/stop/status control functions.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from zeugnis_worker import ZeugnisCrawler, get_crawler_state
|
||||
|
||||
|
||||
_crawler_instance: Optional[ZeugnisCrawler] = None
|
||||
_crawler_task: Optional[asyncio.Task] = None
|
||||
|
||||
|
||||
async def start_crawler(bundesland: Optional[str] = None, source_id: Optional[str] = None) -> bool:
|
||||
"""Start the crawler."""
|
||||
global _crawler_instance, _crawler_task
|
||||
|
||||
state = get_crawler_state()
|
||||
|
||||
if state.is_running:
|
||||
return False
|
||||
|
||||
state.is_running = True
|
||||
state.documents_crawled_today = 0
|
||||
state.documents_indexed_today = 0
|
||||
state.errors_today = 0
|
||||
|
||||
_crawler_instance = ZeugnisCrawler()
|
||||
await _crawler_instance.init()
|
||||
|
||||
async def run_crawler():
|
||||
try:
|
||||
from metrics_db import get_pool
|
||||
pool = await get_pool()
|
||||
|
||||
if pool:
|
||||
async with pool.acquire() as conn:
|
||||
# Get sources to crawl
|
||||
if source_id:
|
||||
sources = await conn.fetch(
|
||||
"SELECT id, bundesland FROM zeugnis_sources WHERE id = $1",
|
||||
source_id
|
||||
)
|
||||
elif bundesland:
|
||||
sources = await conn.fetch(
|
||||
"SELECT id, bundesland FROM zeugnis_sources WHERE bundesland = $1",
|
||||
bundesland
|
||||
)
|
||||
else:
|
||||
sources = await conn.fetch(
|
||||
"SELECT id, bundesland FROM zeugnis_sources ORDER BY bundesland"
|
||||
)
|
||||
|
||||
for source in sources:
|
||||
if not state.is_running:
|
||||
break
|
||||
await _crawler_instance.crawl_source(source["id"])
|
||||
|
||||
except Exception as e:
|
||||
print(f"Crawler error: {e}")
|
||||
|
||||
finally:
|
||||
state.is_running = False
|
||||
if _crawler_instance:
|
||||
await _crawler_instance.close()
|
||||
|
||||
_crawler_task = asyncio.create_task(run_crawler())
|
||||
return True
|
||||
|
||||
|
||||
async def stop_crawler() -> bool:
|
||||
"""Stop the crawler."""
|
||||
global _crawler_task
|
||||
|
||||
state = get_crawler_state()
|
||||
|
||||
if not state.is_running:
|
||||
return False
|
||||
|
||||
state.is_running = False
|
||||
|
||||
if _crawler_task:
|
||||
_crawler_task.cancel()
|
||||
try:
|
||||
await _crawler_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_crawler_status() -> Dict[str, Any]:
|
||||
"""Get current crawler status."""
|
||||
state = get_crawler_state()
|
||||
return {
|
||||
"is_running": state.is_running,
|
||||
"current_source": state.current_source_id,
|
||||
"current_bundesland": state.current_bundesland,
|
||||
"queue_length": len(state.queue),
|
||||
"documents_crawled_today": state.documents_crawled_today,
|
||||
"documents_indexed_today": state.documents_indexed_today,
|
||||
"errors_today": state.errors_today,
|
||||
"last_activity": state.last_activity.isoformat() if state.last_activity else None,
|
||||
}
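# Illustrative usage sketch, not part of the diff: driving the control functions from
# an async context. A second start_crawler() call returns False while a run is active.
import asyncio


async def demo() -> None:
    started = await start_crawler(bundesland="NI")  # bundesland code is an example
    print("started:", started, get_crawler_status())
    await asyncio.sleep(5)  # let the background task crawl for a bit
    stopped = await stop_crawler()
    print("stopped:", stopped)


# asyncio.run(demo())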
|
||||
@@ -1,676 +1,26 @@
|
||||
"""
|
||||
Zeugnis Rights-Aware Crawler
|
||||
|
||||
Crawls official government documents about school certificates (Zeugnisse)
|
||||
from all 16 German federal states. Only indexes documents where AI training
|
||||
is legally permitted.
|
||||
Barrel re-export: all public symbols for backward compatibility.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import httpx
|
||||
|
||||
# Local imports
|
||||
from zeugnis_models import (  # noqa: F401
    CrawlStatus, LicenseType, DocType, EventType,
    BUNDESLAENDER, TRAINING_PERMISSIONS,
    generate_id, get_training_allowed, get_bundesland_name,
)
|
||||
from zeugnis_text import ( # noqa: F401
|
||||
extract_text_from_pdf,
|
||||
extract_text_from_html,
|
||||
chunk_text,
|
||||
compute_hash,
|
||||
)
|
||||
from zeugnis_storage import ( # noqa: F401
|
||||
generate_embeddings,
|
||||
upload_to_minio,
|
||||
index_in_qdrant,
|
||||
)
|
||||
from zeugnis_worker import ( # noqa: F401
|
||||
CrawlerState,
|
||||
ZeugnisCrawler,
|
||||
)
|
||||
from zeugnis_control import ( # noqa: F401
|
||||
start_crawler,
|
||||
stop_crawler,
|
||||
get_crawler_status,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Configuration
|
||||
# =============================================================================
|
||||
|
||||
QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
|
||||
MINIO_ENDPOINT = os.getenv("MINIO_ENDPOINT", "localhost:9000")
|
||||
MINIO_ACCESS_KEY = os.getenv("MINIO_ACCESS_KEY", "test-access-key")
|
||||
MINIO_SECRET_KEY = os.getenv("MINIO_SECRET_KEY", "test-secret-key")
|
||||
MINIO_BUCKET = os.getenv("MINIO_BUCKET", "breakpilot-rag")
|
||||
EMBEDDING_BACKEND = os.getenv("EMBEDDING_BACKEND", "local")
|
||||
|
||||
ZEUGNIS_COLLECTION = "bp_zeugnis"
|
||||
CHUNK_SIZE = 1000
|
||||
CHUNK_OVERLAP = 200
|
||||
MAX_RETRIES = 3
|
||||
RETRY_DELAY = 5 # seconds
|
||||
REQUEST_TIMEOUT = 30 # seconds
|
||||
USER_AGENT = "BreakPilot-Zeugnis-Crawler/1.0 (Educational Research)"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Crawler State
|
||||
# =============================================================================
|
||||
|
||||
@dataclass
|
||||
class CrawlerState:
|
||||
"""Global crawler state."""
|
||||
is_running: bool = False
|
||||
current_source_id: Optional[str] = None
|
||||
current_bundesland: Optional[str] = None
|
||||
queue: List[Dict] = field(default_factory=list)
|
||||
documents_crawled_today: int = 0
|
||||
documents_indexed_today: int = 0
|
||||
errors_today: int = 0
|
||||
last_activity: Optional[datetime] = None
|
||||
|
||||
|
||||
_crawler_state = CrawlerState()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Text Extraction
|
||||
# =============================================================================
|
||||
|
||||
def extract_text_from_pdf(content: bytes) -> str:
|
||||
"""Extract text from PDF bytes."""
|
||||
try:
|
||||
from PyPDF2 import PdfReader
|
||||
import io
|
||||
|
||||
reader = PdfReader(io.BytesIO(content))
|
||||
text_parts = []
|
||||
for page in reader.pages:
|
||||
text = page.extract_text()
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
return "\n\n".join(text_parts)
|
||||
except Exception as e:
|
||||
print(f"PDF extraction failed: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
def extract_text_from_html(content: bytes, encoding: str = "utf-8") -> str:
|
||||
"""Extract text from HTML bytes."""
|
||||
try:
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
html = content.decode(encoding, errors="replace")
|
||||
soup = BeautifulSoup(html, "html.parser")
|
||||
|
||||
# Remove script and style elements
|
||||
for element in soup(["script", "style", "nav", "header", "footer"]):
|
||||
element.decompose()
|
||||
|
||||
# Get text
|
||||
text = soup.get_text(separator="\n", strip=True)
|
||||
|
||||
# Clean up whitespace
|
||||
lines = [line.strip() for line in text.splitlines() if line.strip()]
|
||||
return "\n".join(lines)
|
||||
except Exception as e:
|
||||
print(f"HTML extraction failed: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[str]:
|
||||
"""Split text into overlapping chunks."""
|
||||
if not text:
|
||||
return []
|
||||
|
||||
chunks = []
|
||||
separators = ["\n\n", "\n", ". ", " "]
|
||||
|
||||
def split_recursive(text: str, sep_index: int = 0) -> List[str]:
|
||||
if len(text) <= chunk_size:
|
||||
return [text] if text.strip() else []
|
||||
|
||||
if sep_index >= len(separators):
|
||||
# Force split at chunk_size
|
||||
result = []
|
||||
for i in range(0, len(text), chunk_size - overlap):
|
||||
chunk = text[i:i + chunk_size]
|
||||
if chunk.strip():
|
||||
result.append(chunk)
|
||||
return result
|
||||
|
||||
sep = separators[sep_index]
|
||||
parts = text.split(sep)
|
||||
result = []
|
||||
current = ""
|
||||
|
||||
for part in parts:
|
||||
if len(current) + len(sep) + len(part) <= chunk_size:
|
||||
current = current + sep + part if current else part
|
||||
else:
|
||||
if current.strip():
|
||||
result.extend(split_recursive(current, sep_index + 1) if len(current) > chunk_size else [current])
|
||||
current = part
|
||||
|
||||
if current.strip():
|
||||
result.extend(split_recursive(current, sep_index + 1) if len(current) > chunk_size else [current])
|
||||
|
||||
return result
|
||||
|
||||
chunks = split_recursive(text)
|
||||
|
||||
# Add overlap
|
||||
if overlap > 0 and len(chunks) > 1:
|
||||
overlapped = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
if i > 0:
|
||||
# Add end of previous chunk
|
||||
prev_end = chunks[i - 1][-overlap:]
|
||||
chunk = prev_end + chunk
|
||||
overlapped.append(chunk)
|
||||
chunks = overlapped
|
||||
|
||||
return chunks
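# Worked example (traced by hand, not executed, not part of the diff): with
# chunk_text("alpha beta gamma delta", chunk_size=12, overlap=4) the recursive split
# falls through "\n\n", "\n" and ". " to the space separator, packs words greedily
# into <=12-character pieces -> ["alpha beta", "gamma delta"], then prefixes each
# later chunk with the last 4 characters of its predecessor:
#   ["alpha beta", "betagamma delta"]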
|
||||
|
||||
|
||||
def compute_hash(content: bytes) -> str:
|
||||
"""Compute SHA-256 hash of content."""
|
||||
return hashlib.sha256(content).hexdigest()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Embedding Generation
|
||||
# =============================================================================
|
||||
|
||||
_embedding_model = None
|
||||
|
||||
|
||||
def get_embedding_model():
|
||||
"""Get or initialize embedding model."""
|
||||
global _embedding_model
|
||||
if _embedding_model is None and EMBEDDING_BACKEND == "local":
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
_embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
|
||||
print("Loaded local embedding model: all-MiniLM-L6-v2")
|
||||
except ImportError:
|
||||
print("Warning: sentence-transformers not installed")
|
||||
return _embedding_model
|
||||
|
||||
|
||||
async def generate_embeddings(texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for a list of texts."""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
if EMBEDDING_BACKEND == "local":
|
||||
model = get_embedding_model()
|
||||
if model:
|
||||
embeddings = model.encode(texts, show_progress_bar=False)
|
||||
return [emb.tolist() for emb in embeddings]
|
||||
return []
|
||||
|
||||
elif EMBEDDING_BACKEND == "openai":
|
||||
import openai
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
print("Warning: OPENAI_API_KEY not set")
|
||||
return []
|
||||
|
||||
client = openai.AsyncOpenAI(api_key=api_key)
|
||||
response = await client.embeddings.create(
|
||||
input=texts,
|
||||
model="text-embedding-3-small"
|
||||
)
|
||||
return [item.embedding for item in response.data]
|
||||
|
||||
return []
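# Illustrative sketch, not part of the diff: the backend is chosen purely via the
# EMBEDDING_BACKEND env var read at import time -- "local" loads all-MiniLM-L6-v2
# (384-dimensional vectors, matching the Qdrant fallback size below), while "openai"
# calls text-embedding-3-small.
#
#   import asyncio
#   vectors = asyncio.run(generate_embeddings(["Zeugnisverordnung Niedersachsen"]))
#   print(len(vectors), len(vectors[0]) if vectors else 0)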


# =============================================================================
# MinIO Storage
# =============================================================================

async def upload_to_minio(
    content: bytes,
    bundesland: str,
    filename: str,
    content_type: str = "application/pdf",
    year: Optional[int] = None,
) -> Optional[str]:
    """Upload document to MinIO."""
    try:
        from minio import Minio

        client = Minio(
            MINIO_ENDPOINT,
            access_key=MINIO_ACCESS_KEY,
            secret_key=MINIO_SECRET_KEY,
            secure=os.getenv("MINIO_SECURE", "false").lower() == "true"
        )

        # Ensure bucket exists
        if not client.bucket_exists(MINIO_BUCKET):
            client.make_bucket(MINIO_BUCKET)

        # Build path
        year_str = str(year) if year else str(datetime.now().year)
        object_name = f"landes-daten/{bundesland}/zeugnis/{year_str}/{filename}"

        # Upload
        import io
        client.put_object(
            MINIO_BUCKET,
            object_name,
            io.BytesIO(content),
            len(content),
            content_type=content_type,
        )

        return object_name
    except Exception as e:
        print(f"MinIO upload failed: {e}")
        return None
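
# --- Illustrative sketch (editor's addition, not part of the module) ---
# upload_to_minio derives the object key from bundesland, year, and filename,
# e.g. landes-daten/bayern/zeugnis/2024/notenverordnung.pdf inside MINIO_BUCKET.
# The payload and names below are placeholders.
async def _example_upload() -> None:
    pdf_bytes = b"%PDF-1.4 placeholder"
    key = await upload_to_minio(
        pdf_bytes,
        bundesland="bayern",
        filename="notenverordnung.pdf",
        year=2024,
    )
    print(key)  # landes-daten/bayern/zeugnis/2024/notenverordnung.pdf, or None on failure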


# =============================================================================
# Qdrant Indexing
# =============================================================================

async def index_in_qdrant(
    doc_id: str,
    chunks: List[str],
    embeddings: List[List[float]],
    metadata: Dict[str, Any],
) -> int:
    """Index document chunks in Qdrant."""
    try:
        from qdrant_client import QdrantClient
        from qdrant_client.models import VectorParams, Distance, PointStruct

        client = QdrantClient(url=QDRANT_URL)

        # Ensure collection exists
        collections = client.get_collections().collections
        if not any(c.name == ZEUGNIS_COLLECTION for c in collections):
            vector_size = len(embeddings[0]) if embeddings else 384
            client.create_collection(
                collection_name=ZEUGNIS_COLLECTION,
                vectors_config=VectorParams(
                    size=vector_size,
                    distance=Distance.COSINE,
                ),
            )
            print(f"Created Qdrant collection: {ZEUGNIS_COLLECTION}")

        # Create points
        points = []
        for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
            point_id = str(uuid.uuid4())
            points.append(PointStruct(
                id=point_id,
                vector=embedding,
                payload={
                    "document_id": doc_id,
                    "chunk_index": i,
                    "chunk_text": chunk[:500],  # Store first 500 chars for preview
                    "bundesland": metadata.get("bundesland"),
                    "doc_type": metadata.get("doc_type"),
                    "title": metadata.get("title"),
                    "source_url": metadata.get("url"),
                    "training_allowed": metadata.get("training_allowed", False),
                    "indexed_at": datetime.now().isoformat(),
                }
            ))

        # Upsert
        if points:
            client.upsert(
                collection_name=ZEUGNIS_COLLECTION,
                points=points,
            )

        return len(points)
    except Exception as e:
        print(f"Qdrant indexing failed: {e}")
        return 0
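
# --- Illustrative sketch (editor's addition, not part of the module) ---
# A minimal semantic query against the points written by index_in_qdrant,
# assuming the query is embedded with the same backend as the documents.
async def _example_search(query: str):
    from qdrant_client import QdrantClient

    vectors = await generate_embeddings([query])
    if not vectors:
        return []
    client = QdrantClient(url=QDRANT_URL)
    hits = client.search(
        collection_name=ZEUGNIS_COLLECTION,
        query_vector=vectors[0],
        limit=5,
    )
    # Each hit carries the payload stored above (document_id, chunk_text, ...).
    return [(hit.score, hit.payload.get("document_id")) for hit in hits]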


# =============================================================================
# Crawler Worker
# =============================================================================

class ZeugnisCrawler:
    """Rights-aware crawler for zeugnis documents."""

    def __init__(self):
        self.http_client: Optional[httpx.AsyncClient] = None
        self.db_pool = None

    async def init(self):
        """Initialize crawler resources."""
        self.http_client = httpx.AsyncClient(
            timeout=REQUEST_TIMEOUT,
            follow_redirects=True,
            headers={"User-Agent": USER_AGENT},
        )

        # Initialize database connection
        try:
            from metrics_db import get_pool
            self.db_pool = await get_pool()
        except Exception as e:
            print(f"Failed to get database pool: {e}")

    async def close(self):
        """Close crawler resources."""
        if self.http_client:
            await self.http_client.aclose()

    async def fetch_url(self, url: str) -> Tuple[Optional[bytes], Optional[str]]:
        """Fetch URL with retry logic."""
        for attempt in range(MAX_RETRIES):
            try:
                response = await self.http_client.get(url)
                response.raise_for_status()
                content_type = response.headers.get("content-type", "")
                return response.content, content_type
            except httpx.HTTPStatusError as e:
                print(f"HTTP error {e.response.status_code} for {url}")
                if e.response.status_code == 404:
                    return None, None
            except Exception as e:
                print(f"Attempt {attempt + 1}/{MAX_RETRIES} failed for {url}: {e}")
                if attempt < MAX_RETRIES - 1:
                    await asyncio.sleep(RETRY_DELAY * (attempt + 1))
        return None, None

    async def crawl_seed_url(
        self,
        seed_url_id: str,
        url: str,
        bundesland: str,
        doc_type: str,
        training_allowed: bool,
    ) -> Dict[str, Any]:
        """Crawl a single seed URL."""
        global _crawler_state

        result = {
            "seed_url_id": seed_url_id,
            "url": url,
            "success": False,
            "document_id": None,
            "indexed": False,
            "error": None,
        }

        try:
            # Fetch content
            content, content_type = await self.fetch_url(url)
            if not content:
                result["error"] = "Failed to fetch URL"
                return result

            # Determine file type
            is_pdf = "pdf" in content_type.lower() or url.lower().endswith(".pdf")

            # Extract text
            if is_pdf:
                text = extract_text_from_pdf(content)
                filename = url.split("/")[-1] or f"document_{seed_url_id}.pdf"
            else:
                text = extract_text_from_html(content)
                filename = f"document_{seed_url_id}.html"

            if not text:
                result["error"] = "No text extracted"
                return result

            # Compute hash for versioning
            content_hash = compute_hash(content)

            # Upload to MinIO
            minio_path = await upload_to_minio(
                content,
                bundesland,
                filename,
                content_type=content_type or "application/octet-stream",
            )

            # Generate document ID
            doc_id = generate_id()

            # Store document in database
            if self.db_pool:
                async with self.db_pool.acquire() as conn:
                    await conn.execute(
                        """
                        INSERT INTO zeugnis_documents
                        (id, seed_url_id, title, url, content_hash, minio_path,
                         training_allowed, file_size, content_type)
                        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
                        ON CONFLICT DO NOTHING
                        """,
                        doc_id, seed_url_id, filename, url, content_hash,
                        minio_path, training_allowed, len(content), content_type
                    )

            result["document_id"] = doc_id
            result["success"] = True
            _crawler_state.documents_crawled_today += 1

            # Only index if training is allowed
            if training_allowed:
                chunks = chunk_text(text)
                if chunks:
                    embeddings = await generate_embeddings(chunks)
                    if embeddings:
                        indexed_count = await index_in_qdrant(
                            doc_id,
                            chunks,
                            embeddings,
                            {
                                "bundesland": bundesland,
                                "doc_type": doc_type,
                                "title": filename,
                                "url": url,
                                "training_allowed": True,
                            }
                        )
                        if indexed_count > 0:
                            result["indexed"] = True
                            _crawler_state.documents_indexed_today += 1

                            # Update database
                            if self.db_pool:
                                async with self.db_pool.acquire() as conn:
                                    await conn.execute(
                                        "UPDATE zeugnis_documents SET indexed_in_qdrant = true WHERE id = $1",
                                        doc_id
                                    )
            else:
                result["indexed"] = False
                result["error"] = "Training not allowed for this source"

            _crawler_state.last_activity = datetime.now()

        except Exception as e:
            result["error"] = str(e)
            _crawler_state.errors_today += 1

        return result

    async def crawl_source(self, source_id: str) -> Dict[str, Any]:
        """Crawl all seed URLs for a source."""
        global _crawler_state

        result = {
            "source_id": source_id,
            "documents_found": 0,
            "documents_indexed": 0,
            "errors": [],
            "started_at": datetime.now(),
            "completed_at": None,
        }

        if not self.db_pool:
            result["errors"].append("Database not available")
            return result

        try:
            async with self.db_pool.acquire() as conn:
                # Get source info
                source = await conn.fetchrow(
                    "SELECT * FROM zeugnis_sources WHERE id = $1",
                    source_id
                )
                if not source:
                    result["errors"].append(f"Source not found: {source_id}")
                    return result

                bundesland = source["bundesland"]
                training_allowed = source["training_allowed"]

                _crawler_state.current_source_id = source_id
                _crawler_state.current_bundesland = bundesland

                # Get seed URLs
                seed_urls = await conn.fetch(
                    "SELECT * FROM zeugnis_seed_urls WHERE source_id = $1 AND status != 'completed'",
                    source_id
                )

                for seed_url in seed_urls:
                    # Update status to running
                    await conn.execute(
                        "UPDATE zeugnis_seed_urls SET status = 'running' WHERE id = $1",
                        seed_url["id"]
                    )

                    # Crawl
                    crawl_result = await self.crawl_seed_url(
                        seed_url["id"],
                        seed_url["url"],
                        bundesland,
                        seed_url["doc_type"],
                        training_allowed,
                    )

                    # Update status
                    if crawl_result["success"]:
                        result["documents_found"] += 1
                        if crawl_result["indexed"]:
                            result["documents_indexed"] += 1
                        await conn.execute(
                            "UPDATE zeugnis_seed_urls SET status = 'completed', last_crawled = NOW() WHERE id = $1",
                            seed_url["id"]
                        )
                    else:
                        result["errors"].append(f"{seed_url['url']}: {crawl_result['error']}")
                        await conn.execute(
                            "UPDATE zeugnis_seed_urls SET status = 'failed', error_message = $2 WHERE id = $1",
                            seed_url["id"], crawl_result["error"]
                        )

                    # Small delay between requests
                    await asyncio.sleep(1)

        except Exception as e:
            result["errors"].append(str(e))

        finally:
            result["completed_at"] = datetime.now()
            _crawler_state.current_source_id = None
            _crawler_state.current_bundesland = None

        return result


# =============================================================================
# Crawler Control Functions
# =============================================================================

_crawler_instance: Optional[ZeugnisCrawler] = None
_crawler_task: Optional[asyncio.Task] = None


async def start_crawler(bundesland: Optional[str] = None, source_id: Optional[str] = None) -> bool:
    """Start the crawler."""
    global _crawler_state, _crawler_instance, _crawler_task

    if _crawler_state.is_running:
        return False

    _crawler_state.is_running = True
    _crawler_state.documents_crawled_today = 0
    _crawler_state.documents_indexed_today = 0
    _crawler_state.errors_today = 0

    _crawler_instance = ZeugnisCrawler()
    await _crawler_instance.init()

    async def run_crawler():
        try:
            from metrics_db import get_pool
            pool = await get_pool()

            if pool:
                async with pool.acquire() as conn:
                    # Get sources to crawl
                    if source_id:
                        sources = await conn.fetch(
                            "SELECT id, bundesland FROM zeugnis_sources WHERE id = $1",
                            source_id
                        )
                    elif bundesland:
                        sources = await conn.fetch(
                            "SELECT id, bundesland FROM zeugnis_sources WHERE bundesland = $1",
                            bundesland
                        )
                    else:
                        sources = await conn.fetch(
                            "SELECT id, bundesland FROM zeugnis_sources ORDER BY bundesland"
                        )

                    for source in sources:
                        if not _crawler_state.is_running:
                            break
                        await _crawler_instance.crawl_source(source["id"])

        except Exception as e:
            print(f"Crawler error: {e}")

        finally:
            _crawler_state.is_running = False
            if _crawler_instance:
                await _crawler_instance.close()

    _crawler_task = asyncio.create_task(run_crawler())
    return True


async def stop_crawler() -> bool:
    """Stop the crawler."""
    global _crawler_state, _crawler_task

    if not _crawler_state.is_running:
        return False

    _crawler_state.is_running = False

    if _crawler_task:
        _crawler_task.cancel()
        try:
            await _crawler_task
        except asyncio.CancelledError:
            pass

    return True


def get_crawler_status() -> Dict[str, Any]:
    """Get current crawler status."""
    global _crawler_state
    return {
        "is_running": _crawler_state.is_running,
        "current_source": _crawler_state.current_source_id,
        "current_bundesland": _crawler_state.current_bundesland,
        "queue_length": len(_crawler_state.queue),
        "documents_crawled_today": _crawler_state.documents_crawled_today,
        "documents_indexed_today": _crawler_state.documents_indexed_today,
        "errors_today": _crawler_state.errors_today,
        "last_activity": _crawler_state.last_activity.isoformat() if _crawler_state.last_activity else None,
    }
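
# --- Illustrative sketch (editor's addition, not part of the module) ---
# One way the control functions could be exposed over HTTP via a FastAPI
# APIRouter. The prefix and route names here are assumptions made for the
# example, not the service's actual API surface.
def _example_build_crawler_router():
    from typing import Optional
    from fastapi import APIRouter

    router = APIRouter(prefix="/api/v1/zeugnis-crawler", tags=["zeugnis-crawler"])

    @router.post("/start")
    async def start(bundesland: Optional[str] = None, source_id: Optional[str] = None):
        return {"started": await start_crawler(bundesland=bundesland, source_id=source_id)}

    @router.post("/stop")
    async def stop():
        return {"stopped": await stop_crawler()}

    @router.get("/status")
    async def status():
        return get_crawler_status()

    return router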

@@ -0,0 +1,180 @@
"""
Zeugnis Crawler - Embedding generation, MinIO upload, and Qdrant indexing.
"""

import io
import os
import uuid
from datetime import datetime
from typing import Optional, List, Dict, Any


# =============================================================================
# Configuration
# =============================================================================

QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
MINIO_ENDPOINT = os.getenv("MINIO_ENDPOINT", "localhost:9000")
MINIO_ACCESS_KEY = os.getenv("MINIO_ACCESS_KEY", "test-access-key")
MINIO_SECRET_KEY = os.getenv("MINIO_SECRET_KEY", "test-secret-key")
MINIO_BUCKET = os.getenv("MINIO_BUCKET", "breakpilot-rag")
EMBEDDING_BACKEND = os.getenv("EMBEDDING_BACKEND", "local")

ZEUGNIS_COLLECTION = "bp_zeugnis"


# =============================================================================
# Embedding Generation
# =============================================================================

_embedding_model = None


def get_embedding_model():
    """Get or initialize embedding model."""
    global _embedding_model
    if _embedding_model is None and EMBEDDING_BACKEND == "local":
        try:
            from sentence_transformers import SentenceTransformer
            _embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
            print("Loaded local embedding model: all-MiniLM-L6-v2")
        except ImportError:
            print("Warning: sentence-transformers not installed")
    return _embedding_model


async def generate_embeddings(texts: List[str]) -> List[List[float]]:
    """Generate embeddings for a list of texts."""
    if not texts:
        return []

    if EMBEDDING_BACKEND == "local":
        model = get_embedding_model()
        if model:
            embeddings = model.encode(texts, show_progress_bar=False)
            return [emb.tolist() for emb in embeddings]
        return []

    elif EMBEDDING_BACKEND == "openai":
        import openai
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            print("Warning: OPENAI_API_KEY not set")
            return []

        client = openai.AsyncOpenAI(api_key=api_key)
        response = await client.embeddings.create(
            input=texts,
            model="text-embedding-3-small"
        )
        return [item.embedding for item in response.data]

    return []


# =============================================================================
# MinIO Storage
# =============================================================================

async def upload_to_minio(
    content: bytes,
    bundesland: str,
    filename: str,
    content_type: str = "application/pdf",
    year: Optional[int] = None,
) -> Optional[str]:
    """Upload document to MinIO."""
    try:
        from minio import Minio

        client = Minio(
            MINIO_ENDPOINT,
            access_key=MINIO_ACCESS_KEY,
            secret_key=MINIO_SECRET_KEY,
            secure=os.getenv("MINIO_SECURE", "false").lower() == "true"
        )

        # Ensure bucket exists
        if not client.bucket_exists(MINIO_BUCKET):
            client.make_bucket(MINIO_BUCKET)

        # Build path
        year_str = str(year) if year else str(datetime.now().year)
        object_name = f"landes-daten/{bundesland}/zeugnis/{year_str}/{filename}"

        # Upload
        client.put_object(
            MINIO_BUCKET,
            object_name,
            io.BytesIO(content),
            len(content),
            content_type=content_type,
        )

        return object_name
    except Exception as e:
        print(f"MinIO upload failed: {e}")
        return None


# =============================================================================
# Qdrant Indexing
# =============================================================================

async def index_in_qdrant(
    doc_id: str,
    chunks: List[str],
    embeddings: List[List[float]],
    metadata: Dict[str, Any],
) -> int:
    """Index document chunks in Qdrant."""
    try:
        from qdrant_client import QdrantClient
        from qdrant_client.models import VectorParams, Distance, PointStruct

        client = QdrantClient(url=QDRANT_URL)

        # Ensure collection exists
        collections = client.get_collections().collections
        if not any(c.name == ZEUGNIS_COLLECTION for c in collections):
            vector_size = len(embeddings[0]) if embeddings else 384
            client.create_collection(
                collection_name=ZEUGNIS_COLLECTION,
                vectors_config=VectorParams(
                    size=vector_size,
                    distance=Distance.COSINE,
                ),
            )
            print(f"Created Qdrant collection: {ZEUGNIS_COLLECTION}")

        # Create points
        points = []
        for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
            point_id = str(uuid.uuid4())
            points.append(PointStruct(
                id=point_id,
                vector=embedding,
                payload={
                    "document_id": doc_id,
                    "chunk_index": i,
                    "chunk_text": chunk[:500],  # Store first 500 chars for preview
                    "bundesland": metadata.get("bundesland"),
                    "doc_type": metadata.get("doc_type"),
                    "title": metadata.get("title"),
                    "source_url": metadata.get("url"),
                    "training_allowed": metadata.get("training_allowed", False),
                    "indexed_at": datetime.now().isoformat(),
                }
            ))

        # Upsert
        if points:
            client.upsert(
                collection_name=ZEUGNIS_COLLECTION,
                points=points,
            )

        return len(points)
    except Exception as e:
        print(f"Qdrant indexing failed: {e}")
        return 0
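
# --- Illustrative sketch (editor's addition, not part of the module) ---
# Because each point stores bundesland and training_allowed in its payload,
# retrieval can be restricted per federal state; the filter values below are
# purely illustrative.
def _example_filtered_search(query_vector: List[float]):
    from qdrant_client import QdrantClient
    from qdrant_client.models import Filter, FieldCondition, MatchValue

    client = QdrantClient(url=QDRANT_URL)
    return client.search(
        collection_name=ZEUGNIS_COLLECTION,
        query_vector=query_vector,
        query_filter=Filter(must=[
            FieldCondition(key="bundesland", match=MatchValue(value="bayern")),
            FieldCondition(key="training_allowed", match=MatchValue(value=True)),
        ]),
        limit=5,
    )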

@@ -0,0 +1,110 @@
"""
Zeugnis Crawler - Text extraction, chunking, and hashing utilities.
"""

import hashlib
from typing import List

CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200


def extract_text_from_pdf(content: bytes) -> str:
    """Extract text from PDF bytes."""
    try:
        from PyPDF2 import PdfReader
        import io

        reader = PdfReader(io.BytesIO(content))
        text_parts = []
        for page in reader.pages:
            text = page.extract_text()
            if text:
                text_parts.append(text)
        return "\n\n".join(text_parts)
    except Exception as e:
        print(f"PDF extraction failed: {e}")
        return ""


def extract_text_from_html(content: bytes, encoding: str = "utf-8") -> str:
    """Extract text from HTML bytes."""
    try:
        from bs4 import BeautifulSoup

        html = content.decode(encoding, errors="replace")
        soup = BeautifulSoup(html, "html.parser")

        # Remove script and style elements
        for element in soup(["script", "style", "nav", "header", "footer"]):
            element.decompose()

        # Get text
        text = soup.get_text(separator="\n", strip=True)

        # Clean up whitespace
        lines = [line.strip() for line in text.splitlines() if line.strip()]
        return "\n".join(lines)
    except Exception as e:
        print(f"HTML extraction failed: {e}")
        return ""


def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[str]:
    """Split text into overlapping chunks."""
    if not text:
        return []

    chunks = []
    separators = ["\n\n", "\n", ". ", " "]

    def split_recursive(text: str, sep_index: int = 0) -> List[str]:
        if len(text) <= chunk_size:
            return [text] if text.strip() else []

        if sep_index >= len(separators):
            # Force split at chunk_size
            result = []
            for i in range(0, len(text), chunk_size - overlap):
                chunk = text[i:i + chunk_size]
                if chunk.strip():
                    result.append(chunk)
            return result

        sep = separators[sep_index]
        parts = text.split(sep)
        result = []
        current = ""

        for part in parts:
            if len(current) + len(sep) + len(part) <= chunk_size:
                current = current + sep + part if current else part
            else:
                if current.strip():
                    result.extend(split_recursive(current, sep_index + 1) if len(current) > chunk_size else [current])
                current = part

        if current.strip():
            result.extend(split_recursive(current, sep_index + 1) if len(current) > chunk_size else [current])

        return result

    chunks = split_recursive(text)

    # Add overlap
    if overlap > 0 and len(chunks) > 1:
        overlapped = []
        for i, chunk in enumerate(chunks):
            if i > 0:
                # Add end of previous chunk
                prev_end = chunks[i - 1][-overlap:]
                chunk = prev_end + chunk
            overlapped.append(chunk)
        chunks = overlapped

    return chunks


def compute_hash(content: bytes) -> str:
    """Compute SHA-256 hash of content."""
    return hashlib.sha256(content).hexdigest()
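
# --- Illustrative sketch (editor's addition, not part of the module) ---
# The overlap pass prepends the last `overlap` characters of the previous
# chunk, so consecutive chunks share a window of text; a quick check of that
# property on the first chunk pair, with hypothetical parameters.
def _example_overlap_check(text: str) -> bool:
    chunks = chunk_text(text, chunk_size=300, overlap=50)
    if len(chunks) < 2:
        return True
    # The second chunk is prefixed with the tail of the first chunk.
    return chunks[1].startswith(chunks[0][-50:])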

@@ -0,0 +1,313 @@
"""
Zeugnis Crawler - ZeugnisCrawler worker class and CrawlerState.

Crawls official government documents about school certificates from
all 16 German federal states. Only indexes documents where AI training
is legally permitted.
"""

import asyncio
from datetime import datetime
from typing import Optional, List, Dict, Any, Tuple
from dataclasses import dataclass, field

import httpx

from zeugnis_models import generate_id
from zeugnis_text import (
    extract_text_from_pdf,
    extract_text_from_html,
    chunk_text,
    compute_hash,
)
from zeugnis_storage import (
    upload_to_minio,
    generate_embeddings,
    index_in_qdrant,
)


# =============================================================================
# Configuration
# =============================================================================

MAX_RETRIES = 3
RETRY_DELAY = 5  # seconds
REQUEST_TIMEOUT = 30  # seconds
USER_AGENT = "BreakPilot-Zeugnis-Crawler/1.0 (Educational Research)"


# =============================================================================
# Crawler State
# =============================================================================

@dataclass
class CrawlerState:
    """Global crawler state."""
    is_running: bool = False
    current_source_id: Optional[str] = None
    current_bundesland: Optional[str] = None
    queue: List[Dict] = field(default_factory=list)
    documents_crawled_today: int = 0
    documents_indexed_today: int = 0
    errors_today: int = 0
    last_activity: Optional[datetime] = None


_crawler_state = CrawlerState()


def get_crawler_state() -> CrawlerState:
    """Get the global crawler state."""
    return _crawler_state


# =============================================================================
# Crawler Worker
# =============================================================================

class ZeugnisCrawler:
    """Rights-aware crawler for zeugnis documents."""

    def __init__(self):
        self.http_client: Optional[httpx.AsyncClient] = None
        self.db_pool = None

    async def init(self):
        """Initialize crawler resources."""
        self.http_client = httpx.AsyncClient(
            timeout=REQUEST_TIMEOUT,
            follow_redirects=True,
            headers={"User-Agent": USER_AGENT},
        )

        # Initialize database connection
        try:
            from metrics_db import get_pool
            self.db_pool = await get_pool()
        except Exception as e:
            print(f"Failed to get database pool: {e}")

    async def close(self):
        """Close crawler resources."""
        if self.http_client:
            await self.http_client.aclose()

    async def fetch_url(self, url: str) -> Tuple[Optional[bytes], Optional[str]]:
        """Fetch URL with retry logic."""
        for attempt in range(MAX_RETRIES):
            try:
                response = await self.http_client.get(url)
                response.raise_for_status()
                content_type = response.headers.get("content-type", "")
                return response.content, content_type
            except httpx.HTTPStatusError as e:
                print(f"HTTP error {e.response.status_code} for {url}")
                if e.response.status_code == 404:
                    return None, None
            except Exception as e:
                print(f"Attempt {attempt + 1}/{MAX_RETRIES} failed for {url}: {e}")
                if attempt < MAX_RETRIES - 1:
                    await asyncio.sleep(RETRY_DELAY * (attempt + 1))
        return None, None

    async def crawl_seed_url(
        self,
        seed_url_id: str,
        url: str,
        bundesland: str,
        doc_type: str,
        training_allowed: bool,
    ) -> Dict[str, Any]:
        """Crawl a single seed URL."""
        global _crawler_state

        result = {
            "seed_url_id": seed_url_id,
            "url": url,
            "success": False,
            "document_id": None,
            "indexed": False,
            "error": None,
        }

        try:
            # Fetch content
            content, content_type = await self.fetch_url(url)
            if not content:
                result["error"] = "Failed to fetch URL"
                return result

            # Determine file type
            is_pdf = "pdf" in content_type.lower() or url.lower().endswith(".pdf")

            # Extract text
            if is_pdf:
                text = extract_text_from_pdf(content)
                filename = url.split("/")[-1] or f"document_{seed_url_id}.pdf"
            else:
                text = extract_text_from_html(content)
                filename = f"document_{seed_url_id}.html"

            if not text:
                result["error"] = "No text extracted"
                return result

            # Compute hash for versioning
            content_hash = compute_hash(content)

            # Upload to MinIO
            minio_path = await upload_to_minio(
                content,
                bundesland,
                filename,
                content_type=content_type or "application/octet-stream",
            )

            # Generate document ID
            doc_id = generate_id()

            # Store document in database
            if self.db_pool:
                async with self.db_pool.acquire() as conn:
                    await conn.execute(
                        """
                        INSERT INTO zeugnis_documents
                        (id, seed_url_id, title, url, content_hash, minio_path,
                         training_allowed, file_size, content_type)
                        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
                        ON CONFLICT DO NOTHING
                        """,
                        doc_id, seed_url_id, filename, url, content_hash,
                        minio_path, training_allowed, len(content), content_type
                    )

            result["document_id"] = doc_id
            result["success"] = True
            _crawler_state.documents_crawled_today += 1

            # Only index if training is allowed
            if training_allowed:
                chunks = chunk_text(text)
                if chunks:
                    embeddings = await generate_embeddings(chunks)
                    if embeddings:
                        indexed_count = await index_in_qdrant(
                            doc_id,
                            chunks,
                            embeddings,
                            {
                                "bundesland": bundesland,
                                "doc_type": doc_type,
                                "title": filename,
                                "url": url,
                                "training_allowed": True,
                            }
                        )
                        if indexed_count > 0:
                            result["indexed"] = True
                            _crawler_state.documents_indexed_today += 1

                            # Update database
                            if self.db_pool:
                                async with self.db_pool.acquire() as conn:
                                    await conn.execute(
                                        "UPDATE zeugnis_documents SET indexed_in_qdrant = true WHERE id = $1",
                                        doc_id
                                    )
            else:
                result["indexed"] = False
                result["error"] = "Training not allowed for this source"

            _crawler_state.last_activity = datetime.now()

        except Exception as e:
            result["error"] = str(e)
            _crawler_state.errors_today += 1

        return result

    async def crawl_source(self, source_id: str) -> Dict[str, Any]:
        """Crawl all seed URLs for a source."""
        global _crawler_state

        result = {
            "source_id": source_id,
            "documents_found": 0,
            "documents_indexed": 0,
            "errors": [],
            "started_at": datetime.now(),
            "completed_at": None,
        }

        if not self.db_pool:
            result["errors"].append("Database not available")
            return result

        try:
            async with self.db_pool.acquire() as conn:
                # Get source info
                source = await conn.fetchrow(
                    "SELECT * FROM zeugnis_sources WHERE id = $1",
                    source_id
                )
                if not source:
                    result["errors"].append(f"Source not found: {source_id}")
                    return result

                bundesland = source["bundesland"]
                training_allowed = source["training_allowed"]

                _crawler_state.current_source_id = source_id
                _crawler_state.current_bundesland = bundesland

                # Get seed URLs
                seed_urls = await conn.fetch(
                    "SELECT * FROM zeugnis_seed_urls WHERE source_id = $1 AND status != 'completed'",
                    source_id
                )

                for seed_url in seed_urls:
                    # Update status to running
                    await conn.execute(
                        "UPDATE zeugnis_seed_urls SET status = 'running' WHERE id = $1",
                        seed_url["id"]
                    )

                    # Crawl
                    crawl_result = await self.crawl_seed_url(
                        seed_url["id"],
                        seed_url["url"],
                        bundesland,
                        seed_url["doc_type"],
                        training_allowed,
                    )

                    # Update status
                    if crawl_result["success"]:
                        result["documents_found"] += 1
                        if crawl_result["indexed"]:
                            result["documents_indexed"] += 1
                        await conn.execute(
                            "UPDATE zeugnis_seed_urls SET status = 'completed', last_crawled = NOW() WHERE id = $1",
                            seed_url["id"]
                        )
                    else:
                        result["errors"].append(f"{seed_url['url']}: {crawl_result['error']}")
                        await conn.execute(
                            "UPDATE zeugnis_seed_urls SET status = 'failed', error_message = $2 WHERE id = $1",
                            seed_url["id"], crawl_result["error"]
                        )

                    # Small delay between requests
                    await asyncio.sleep(1)

        except Exception as e:
            result["errors"].append(str(e))

        finally:
            result["completed_at"] = datetime.now()
            _crawler_state.current_source_id = None
            _crawler_state.current_bundesland = None

        return result
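
# --- Illustrative sketch (editor's addition, not part of the module) ---
# crawl_source is normally driven by the crawler control loop, but it can also
# be run standalone, e.g. for a one-off re-crawl; the source id below is a
# placeholder.
async def _example_run_once(source_id: str = "source-placeholder-id") -> None:
    crawler = ZeugnisCrawler()
    await crawler.init()
    try:
        summary = await crawler.crawl_source(source_id)
        print(summary["documents_found"], "found,",
              summary["documents_indexed"], "indexed,",
              len(summary["errors"]), "errors")
    finally:
        await crawler.close()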