[split-required] Split 500-850 LOC files (batch 2)

backend-lehrer (10 files):
- game/database.py (785 → 5), correction_api.py (683 → 4)
- classroom_engine/antizipation.py (676 → 5)
- llm_gateway, schools/edu_search: already done in a prior batch

klausur-service (12 files):
- orientation_crop_api.py (694 → 5), pdf_export.py (677 → 4)
- zeugnis_crawler.py (676 → 5), grid_editor_api.py (671 → 5)
- eh_templates.py (658 → 5), mail/api.py (651 → 5)
- qdrant_service.py (638 → 5), training_api.py (625 → 4)

website (6 pages):
- middleware (696 → 8), mail (733 → 6), consent (628 → 8)
- compliance/risks (622 → 5), export (502 → 5), brandbook (629 → 7)

studio-v2 (3 components):
- B2BMigrationWizard (848 → 3), CleanupPanel (765 → 2)
- dashboard-experimental (739 → 2)

admin-lehrer (4 files):
- uebersetzungen (769 → 4), manager (670 → 2)
- ChunkBrowserQA (675 → 6), dsfa/page (674 → 5)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Benjamin Admin
2026-04-25 08:24:01 +02:00
parent 34da9f4cda
commit b4613e26f3
118 changed files with 15258 additions and 14680 deletions
+290
@@ -0,0 +1,290 @@
"""
Crop API endpoints (Step 4 / UI index 3 of OCR Pipeline).
Auto-crop, manual crop, and skip-crop for scanner/book borders.
"""
import logging
import time
from typing import Any, Dict
import cv2
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from page_crop import detect_and_crop_page, detect_page_splits
from ocr_pipeline_session_store import get_sub_sessions, update_session_db
from orientation_crop_helpers import ensure_cached, append_pipeline_log
from page_sub_sessions import create_page_sub_sessions
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Step 4 (UI index 3): Crop — runs after deskew + dewarp
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/crop")
async def auto_crop(session_id: str):
"""Auto-detect and crop scanner/book borders.
Reads the dewarped image (post-deskew + dewarp, so the page is straight).
Falls back to oriented -> original if earlier steps were skipped.
If the image is a multi-page spread (e.g. book on scanner), it will
automatically split into separate sub-sessions per page, crop each
individually, and return the split info.
"""
cached = await ensure_cached(session_id)
# Use dewarped (preferred), fall back to oriented, then original
img_bgr = next(
(v for k in ("dewarped_bgr", "oriented_bgr", "original_bgr")
if (v := cached.get(k)) is not None),
None,
)
if img_bgr is None:
raise HTTPException(status_code=400, detail="No image available for cropping")
t0 = time.time()
# --- Check for existing sub-sessions (from page-split step) ---
# If page-split already created sub-sessions, skip multi-page detection
# in the crop step. Each sub-session runs its own crop independently.
existing_subs = await get_sub_sessions(session_id)
if existing_subs:
crop_result = cached.get("crop_result") or {}
if crop_result.get("multi_page"):
# Already split -- just return the existing info
duration = time.time() - t0
h, w = img_bgr.shape[:2]
return {
"session_id": session_id,
**crop_result,
"image_width": w,
"image_height": h,
"sub_sessions": [
{"id": s["id"], "name": s.get("name"), "page_index": s.get("box_index", i)}
for i, s in enumerate(existing_subs)
],
"note": "Page split was already performed; each sub-session runs its own crop.",
}
# --- Multi-page detection (fallback for sessions that skipped page-split) ---
page_splits = detect_page_splits(img_bgr)
if page_splits and len(page_splits) >= 2:
# Multi-page spread detected -- create sub-sessions
sub_sessions = await create_page_sub_sessions(
session_id, cached, img_bgr, page_splits,
)
duration = time.time() - t0
crop_info: Dict[str, Any] = {
"crop_applied": True,
"multi_page": True,
"page_count": len(page_splits),
"page_splits": page_splits,
"duration_seconds": round(duration, 2),
}
cached["crop_result"] = crop_info
# Store the first page as the main cropped image for backward compat
first_page = page_splits[0]
first_bgr = img_bgr[
first_page["y"]:first_page["y"] + first_page["height"],
first_page["x"]:first_page["x"] + first_page["width"],
].copy()
first_cropped, _ = detect_and_crop_page(first_bgr)
cached["cropped_bgr"] = first_cropped
ok, png_buf = cv2.imencode(".png", first_cropped)
await update_session_db(
session_id,
cropped_png=png_buf.tobytes() if ok else b"",
crop_result=crop_info,
current_step=5,
status='split',
)
logger.info(
"OCR Pipeline: crop session %s: multi-page split into %d pages in %.2fs",
session_id, len(page_splits), duration,
)
await append_pipeline_log(session_id, "crop", {
"multi_page": True,
"page_count": len(page_splits),
}, duration_ms=int(duration * 1000))
h, w = first_cropped.shape[:2]
return {
"session_id": session_id,
**crop_info,
"image_width": w,
"image_height": h,
"cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
"sub_sessions": sub_sessions,
}
# --- Single page (normal) ---
cropped_bgr, crop_info = detect_and_crop_page(img_bgr)
duration = time.time() - t0
crop_info["duration_seconds"] = round(duration, 2)
crop_info["multi_page"] = False
# Encode cropped image
success, png_buf = cv2.imencode(".png", cropped_bgr)
cropped_png = png_buf.tobytes() if success else b""
# Update cache
cached["cropped_bgr"] = cropped_bgr
cached["crop_result"] = crop_info
# Persist to DB
await update_session_db(
session_id,
cropped_png=cropped_png,
crop_result=crop_info,
current_step=5,
)
logger.info(
"OCR Pipeline: crop session %s: applied=%s format=%s in %.2fs",
session_id, crop_info["crop_applied"],
crop_info.get("detected_format", "?"),
duration,
)
await append_pipeline_log(session_id, "crop", {
"crop_applied": crop_info["crop_applied"],
"detected_format": crop_info.get("detected_format"),
"format_confidence": crop_info.get("format_confidence"),
}, duration_ms=int(duration * 1000))
h, w = cropped_bgr.shape[:2]
return {
"session_id": session_id,
**crop_info,
"image_width": w,
"image_height": h,
"cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
}
class ManualCropRequest(BaseModel):
x: float # percentage 0-100
y: float # percentage 0-100
width: float # percentage 0-100
height: float # percentage 0-100
@router.post("/sessions/{session_id}/crop/manual")
async def manual_crop(session_id: str, req: ManualCropRequest):
"""Manually crop using percentage coordinates."""
cached = await ensure_cached(session_id)
img_bgr = next(
(v for k in ("dewarped_bgr", "oriented_bgr", "original_bgr")
if (v := cached.get(k)) is not None),
None,
)
if img_bgr is None:
raise HTTPException(status_code=400, detail="No image available for cropping")
h, w = img_bgr.shape[:2]
# Convert percentages to pixels
px_x = int(w * req.x / 100.0)
px_y = int(h * req.y / 100.0)
px_w = int(w * req.width / 100.0)
px_h = int(h * req.height / 100.0)
# Clamp
px_x = max(0, min(px_x, w - 1))
px_y = max(0, min(px_y, h - 1))
px_w = max(1, min(px_w, w - px_x))
px_h = max(1, min(px_h, h - px_y))
cropped_bgr = img_bgr[px_y:px_y + px_h, px_x:px_x + px_w].copy()
success, png_buf = cv2.imencode(".png", cropped_bgr)
cropped_png = png_buf.tobytes() if success else b""
crop_result = {
"crop_applied": True,
"crop_rect": {"x": px_x, "y": px_y, "width": px_w, "height": px_h},
"crop_rect_pct": {"x": round(req.x, 2), "y": round(req.y, 2),
"width": round(req.width, 2), "height": round(req.height, 2)},
"original_size": {"width": w, "height": h},
"cropped_size": {"width": px_w, "height": px_h},
"method": "manual",
}
cached["cropped_bgr"] = cropped_bgr
cached["crop_result"] = crop_result
await update_session_db(
session_id,
cropped_png=cropped_png,
crop_result=crop_result,
current_step=5,
)
ch, cw = cropped_bgr.shape[:2]
return {
"session_id": session_id,
**crop_result,
"image_width": cw,
"image_height": ch,
"cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
}
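# Worked example of the percent-to-pixel conversion above (illustrative
# numbers only): for a 2000 x 1000 px image and a request of
# x=10, y=20, width=50, height=60 (all percent), the crop rect becomes
#   px_x = int(2000 * 10 / 100) = 200
#   px_y = int(1000 * 20 / 100) = 200
#   px_w = int(2000 * 50 / 100) = 1000
#   px_h = int(1000 * 60 / 100) = 600
# and the clamps only change these values when the requested rectangle
# would extend past the image borders.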
@router.post("/sessions/{session_id}/crop/skip")
async def skip_crop(session_id: str):
"""Skip cropping -- use dewarped (or oriented/original) image as-is."""
cached = await ensure_cached(session_id)
img_bgr = next(
(v for k in ("dewarped_bgr", "oriented_bgr", "original_bgr")
if (v := cached.get(k)) is not None),
None,
)
if img_bgr is None:
raise HTTPException(status_code=400, detail="No image available")
h, w = img_bgr.shape[:2]
# Store the dewarped image as cropped (identity crop)
success, png_buf = cv2.imencode(".png", img_bgr)
cropped_png = png_buf.tobytes() if success else b""
crop_result = {
"crop_applied": False,
"skipped": True,
"original_size": {"width": w, "height": h},
"cropped_size": {"width": w, "height": h},
}
cached["cropped_bgr"] = img_bgr
cached["crop_result"] = crop_result
await update_session_db(
session_id,
cropped_png=cropped_png,
crop_result=crop_result,
current_step=5,
)
return {
"session_id": session_id,
**crop_result,
"image_width": w,
"image_height": h,
"cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
}
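# ---------------------------------------------------------------------------
# Illustrative client flow (a sketch only: BASE_URL and the session id are
# placeholders, and the optional `requests` dependency is assumed).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import requests

    BASE_URL = "http://localhost:8000"   # placeholder service address
    sid = "demo-session"                 # placeholder session id

    # Auto-crop; a book spread may come back with multi_page=True plus
    # sub_sessions that each run their own crop.
    info = requests.post(f"{BASE_URL}/api/v1/ocr-pipeline/sessions/{sid}/crop").json()

    if not info.get("multi_page") and not info.get("crop_applied"):
        # Fall back to a generous manual crop in percent coordinates (0-100).
        info = requests.post(
            f"{BASE_URL}/api/v1/ocr-pipeline/sessions/{sid}/crop/manual",
            json={"x": 2.0, "y": 2.0, "width": 96.0, "height": 96.0},
        ).json()

    print(info.get("cropped_image_url"))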
eh_templates.py
+28 -652
@@ -1,658 +1,34 @@
"""
Erwartungshorizont Templates for Vorabitur Mode
Erwartungshorizont Templates for Vorabitur Mode — barrel re-export.
Provides pre-defined templates based on German Abitur text analysis types:
- Textanalyse (pragmatische Texte)
- Sachtextanalyse
- Gedichtanalyse / Lyrikinterpretation
- Dramenanalyse
- Epische Textanalyse / Prosaanalyse
- Eroerterung (textgebunden / frei)
- Literarische Eroerterung
- Materialgestuetztes Schreiben
Each template includes:
- Structured criteria with weights
- Typical expectations per section
- NiBiS-aligned evaluation points
The actual code lives in:
- eh_templates_types.py (AUFGABENTYPEN, EHKriterium, EHTemplate)
- eh_templates_analyse.py (Textanalyse, Gedicht, Prosa, Drama)
- eh_templates_eroerterung.py (Eroerterung textgebunden)
- eh_templates_registry.py (TEMPLATES, get_template, list_templates, etc.)
"""
from typing import Dict, List, Optional
from dataclasses import dataclass, field, asdict
from datetime import datetime
import uuid
# Types
from eh_templates_types import ( # noqa: F401
AUFGABENTYPEN,
EHKriterium,
EHTemplate,
)
# Template factories
from eh_templates_analyse import ( # noqa: F401
get_textanalyse_template,
get_gedichtanalyse_template,
get_prosaanalyse_template,
get_dramenanalyse_template,
)
from eh_templates_eroerterung import get_eroerterung_template # noqa: F401
# =============================================
# TEMPLATE TYPES
# =============================================
AUFGABENTYPEN = {
"textanalyse_pragmatisch": {
"name": "Textanalyse (pragmatische Texte)",
"description": "Analyse von Sachtexten, Reden, Kommentaren, Essays",
"category": "analyse"
},
"sachtextanalyse": {
"name": "Sachtextanalyse",
"description": "Analyse von informativen und appellativen Sachtexten",
"category": "analyse"
},
"gedichtanalyse": {
"name": "Gedichtanalyse / Lyrikinterpretation",
"description": "Analyse und Interpretation lyrischer Texte",
"category": "interpretation"
},
"dramenanalyse": {
"name": "Dramenanalyse",
"description": "Analyse dramatischer Texte und Szenen",
"category": "interpretation"
},
"prosaanalyse": {
"name": "Epische Textanalyse / Prosaanalyse",
"description": "Analyse von Romanauszuegen, Kurzgeschichten, Novellen",
"category": "interpretation"
},
"eroerterung_textgebunden": {
"name": "Textgebundene Eroerterung",
"description": "Eroerterung auf Basis eines Sachtextes",
"category": "argumentation"
},
"eroerterung_frei": {
"name": "Freie Eroerterung",
"description": "Freie Eroerterung zu einem Thema",
"category": "argumentation"
},
"eroerterung_literarisch": {
"name": "Literarische Eroerterung",
"description": "Eroerterung zu literarischen Fragestellungen",
"category": "argumentation"
},
"materialgestuetzt": {
"name": "Materialgestuetztes Schreiben",
"description": "Verfassen eines Textes auf Materialbasis",
"category": "produktion"
}
}
# =============================================
# TEMPLATE STRUCTURES
# =============================================
@dataclass
class EHKriterium:
"""Single criterion in an Erwartungshorizont."""
id: str
name: str
beschreibung: str
gewichtung: int # Percentage weight (0-100)
erwartungen: List[str] # Expected points/elements
max_punkte: int = 100
def to_dict(self):
return asdict(self)
@dataclass
class EHTemplate:
"""Complete Erwartungshorizont template."""
id: str
aufgabentyp: str
name: str
beschreibung: str
kriterien: List[EHKriterium]
einleitung_hinweise: List[str]
hauptteil_hinweise: List[str]
schluss_hinweise: List[str]
sprachliche_aspekte: List[str]
created_at: datetime = field(default_factory=lambda: datetime.now())
def to_dict(self):
d = {
'id': self.id,
'aufgabentyp': self.aufgabentyp,
'name': self.name,
'beschreibung': self.beschreibung,
'kriterien': [k.to_dict() for k in self.kriterien],
'einleitung_hinweise': self.einleitung_hinweise,
'hauptteil_hinweise': self.hauptteil_hinweise,
'schluss_hinweise': self.schluss_hinweise,
'sprachliche_aspekte': self.sprachliche_aspekte,
'created_at': self.created_at.isoformat()
}
return d
# =============================================
# PRE-DEFINED TEMPLATES
# =============================================
def get_textanalyse_template() -> EHTemplate:
"""Template for pragmatic text analysis."""
return EHTemplate(
id="template_textanalyse_pragmatisch",
aufgabentyp="textanalyse_pragmatisch",
name="Textanalyse pragmatischer Texte",
beschreibung="Vorlage fuer die Analyse von Sachtexten, Reden, Kommentaren und Essays",
kriterien=[
EHKriterium(
id="inhalt",
name="Inhaltliche Leistung",
beschreibung="Erfassung und Wiedergabe des Textinhalts",
gewichtung=40,
erwartungen=[
"Korrekte Erfassung der Textaussage/These",
"Vollstaendige Wiedergabe der Argumentationsstruktur",
"Erkennen von Intention und Adressatenbezug",
"Einordnung in den historischen/gesellschaftlichen Kontext",
"Beruecksichtigung aller relevanten Textaspekte"
]
),
EHKriterium(
id="struktur",
name="Aufbau und Struktur",
beschreibung="Logischer Aufbau und Gliederung der Analyse",
gewichtung=15,
erwartungen=[
"Sinnvolle Einleitung mit Basisinformationen",
"Logische Gliederung des Hauptteils",
"Stringente Gedankenfuehrung",
"Angemessener Schluss mit Fazit/Wertung",
"Absatzgliederung und Ueberlaenge"
]
),
EHKriterium(
id="analyse",
name="Analytische Qualitaet",
beschreibung="Tiefe und Qualitaet der Analyse",
gewichtung=15,
erwartungen=[
"Erkennen rhetorischer Mittel",
"Funktionale Deutung der Stilmittel",
"Analyse der Argumentationsweise",
"Beruecksichtigung von Wortwahl und Satzbau",
"Verknuepfung von Form und Inhalt"
]
),
EHKriterium(
id="rechtschreibung",
name="Sprachliche Richtigkeit (Rechtschreibung)",
beschreibung="Orthografische Korrektheit",
gewichtung=15,
erwartungen=[
"Korrekte Rechtschreibung",
"Korrekte Gross- und Kleinschreibung",
"Korrekte Getrennt- und Zusammenschreibung",
"Korrekte Fremdwortschreibung"
]
),
EHKriterium(
id="grammatik",
name="Sprachliche Richtigkeit (Grammatik)",
beschreibung="Grammatische Korrektheit und Zeichensetzung",
gewichtung=15,
erwartungen=[
"Korrekter Satzbau",
"Korrekte Flexion",
"Korrekte Zeichensetzung",
"Korrekte Bezuege und Kongruenz"
]
)
],
einleitung_hinweise=[
"Nennung von Autor, Titel, Textsorte, Erscheinungsjahr",
"Benennung des Themas",
"Formulierung der Kernthese/Hauptaussage",
"Ggf. Einordnung in den Kontext"
],
hauptteil_hinweise=[
"Systematische Analyse der Argumentationsstruktur",
"Untersuchung der sprachlichen Gestaltung",
"Funktionale Deutung der Stilmittel",
"Beruecksichtigung von Adressatenbezug und Intention",
"Textbelege durch Zitate"
],
schluss_hinweise=[
"Zusammenfassung der Analyseergebnisse",
"Bewertung der Ueberzeugungskraft",
"Ggf. aktuelle Relevanz",
"Persoenliche Stellungnahme (wenn gefordert)"
],
sprachliche_aspekte=[
"Fachsprachliche Begriffe korrekt verwenden",
"Konjunktiv fuer indirekte Rede",
"Praesens als Tempus der Analyse",
"Sachlicher, analytischer Stil"
]
)
def get_gedichtanalyse_template() -> EHTemplate:
"""Template for poetry analysis."""
return EHTemplate(
id="template_gedichtanalyse",
aufgabentyp="gedichtanalyse",
name="Gedichtanalyse / Lyrikinterpretation",
beschreibung="Vorlage fuer die Analyse und Interpretation lyrischer Texte",
kriterien=[
EHKriterium(
id="inhalt",
name="Inhaltliche Leistung",
beschreibung="Erfassung und Deutung des Gedichtinhalts",
gewichtung=40,
erwartungen=[
"Korrekte Erfassung des lyrischen Ichs und der Sprechsituation",
"Vollstaendige inhaltliche Erschliessung aller Strophen",
"Erkennen der zentralen Motive und Themen",
"Epochenzuordnung und literaturgeschichtliche Einordnung",
"Deutung der Bildlichkeit und Symbolik"
]
),
EHKriterium(
id="struktur",
name="Aufbau und Struktur",
beschreibung="Logischer Aufbau der Interpretation",
gewichtung=15,
erwartungen=[
"Einleitung mit Basisinformationen",
"Systematische strophenweise oder aspektorientierte Analyse",
"Verknuepfung von Form- und Inhaltsanalyse",
"Schluessige Gesamtdeutung im Schluss"
]
),
EHKriterium(
id="formanalyse",
name="Formale Analyse",
beschreibung="Analyse der lyrischen Gestaltungsmittel",
gewichtung=15,
erwartungen=[
"Bestimmung von Metrum und Reimschema",
"Analyse der Klanggestaltung",
"Erkennen von Enjambements und Zaesuren",
"Deutung der formalen Mittel",
"Verknuepfung von Form und Inhalt"
]
),
EHKriterium(
id="rechtschreibung",
name="Sprachliche Richtigkeit (Rechtschreibung)",
beschreibung="Orthografische Korrektheit",
gewichtung=15,
erwartungen=[
"Korrekte Rechtschreibung",
"Korrekte Gross- und Kleinschreibung",
"Korrekte Getrennt- und Zusammenschreibung"
]
),
EHKriterium(
id="grammatik",
name="Sprachliche Richtigkeit (Grammatik)",
beschreibung="Grammatische Korrektheit und Zeichensetzung",
gewichtung=15,
erwartungen=[
"Korrekter Satzbau",
"Korrekte Flexion",
"Korrekte Zeichensetzung"
]
)
],
einleitung_hinweise=[
"Autor, Titel, Entstehungsjahr/Epoche",
"Thema/Motiv des Gedichts",
"Erste Deutungshypothese",
"Formale Grunddaten (Strophen, Verse)"
],
hauptteil_hinweise=[
"Inhaltliche Analyse (strophenweise oder aspektorientiert)",
"Formale Analyse (Metrum, Reim, Klang)",
"Sprachliche Analyse (Stilmittel, Bildlichkeit)",
"Funktionale Verknuepfung aller Ebenen",
"Textbelege durch Zitate mit Versangabe"
],
schluss_hinweise=[
"Zusammenfassung der Interpretationsergebnisse",
"Bestaetigung/Modifikation der Deutungshypothese",
"Einordnung in Epoche/Werk des Autors",
"Aktualitaetsbezug (wenn sinnvoll)"
],
sprachliche_aspekte=[
"Fachbegriffe der Lyrikanalyse verwenden",
"Zwischen lyrischem Ich und Autor unterscheiden",
"Praesens als Analysetempus",
"Deutende statt beschreibende Formulierungen"
]
)
def get_eroerterung_template() -> EHTemplate:
"""Template for textgebundene Eroerterung."""
return EHTemplate(
id="template_eroerterung_textgebunden",
aufgabentyp="eroerterung_textgebunden",
name="Textgebundene Eroerterung",
beschreibung="Vorlage fuer die textgebundene Eroerterung auf Basis eines Sachtextes",
kriterien=[
EHKriterium(
id="inhalt",
name="Inhaltliche Leistung",
beschreibung="Qualitaet der Argumentation",
gewichtung=40,
erwartungen=[
"Korrekte Wiedergabe der Textposition",
"Differenzierte eigene Argumentation",
"Vielfaeltige und ueberzeugende Argumente",
"Beruecksichtigung von Pro und Contra",
"Sinnvolle Beispiele und Belege",
"Eigenstaendige Schlussfolgerung"
]
),
EHKriterium(
id="struktur",
name="Aufbau und Struktur",
beschreibung="Logischer Aufbau der Eroerterung",
gewichtung=15,
erwartungen=[
"Problemorientierte Einleitung",
"Klare Gliederung der Argumentation",
"Logische Argumentationsfolge",
"Sinnvolle Ueberlaetze",
"Begruendetes Fazit"
]
),
EHKriterium(
id="textbezug",
name="Textbezug",
beschreibung="Verknuepfung mit dem Ausgangstext",
gewichtung=15,
erwartungen=[
"Angemessene Textwiedergabe",
"Kritische Auseinandersetzung mit Textposition",
"Korrekte Zitierweise",
"Verknuepfung eigener Argumente mit Text"
]
),
EHKriterium(
id="rechtschreibung",
name="Sprachliche Richtigkeit (Rechtschreibung)",
beschreibung="Orthografische Korrektheit",
gewichtung=15,
erwartungen=[
"Korrekte Rechtschreibung",
"Korrekte Gross- und Kleinschreibung"
]
),
EHKriterium(
id="grammatik",
name="Sprachliche Richtigkeit (Grammatik)",
beschreibung="Grammatische Korrektheit und Zeichensetzung",
gewichtung=15,
erwartungen=[
"Korrekter Satzbau",
"Korrekte Zeichensetzung",
"Variationsreicher Ausdruck"
]
)
],
einleitung_hinweise=[
"Hinfuehrung zum Thema",
"Nennung des Ausgangstextes",
"Formulierung der Leitfrage/These",
"Ueberleitung zum Hauptteil"
],
hauptteil_hinweise=[
"Kurze Wiedergabe der Textposition",
"Systematische Argumentation (dialektisch oder linear)",
"Jedes Argument: These - Begruendung - Beispiel",
"Gewichtung der Argumente",
"Verknuepfung mit Textposition"
],
schluss_hinweise=[
"Zusammenfassung der wichtigsten Argumente",
"Eigene begruendete Stellungnahme",
"Ggf. Ausblick oder Appell"
],
sprachliche_aspekte=[
"Argumentative Konnektoren verwenden",
"Sachlicher, ueberzeugender Stil",
"Eigene Meinung kennzeichnen",
"Konjunktiv fuer Textpositionen"
]
)
def get_prosaanalyse_template() -> EHTemplate:
"""Template for prose/narrative text analysis."""
return EHTemplate(
id="template_prosaanalyse",
aufgabentyp="prosaanalyse",
name="Epische Textanalyse / Prosaanalyse",
beschreibung="Vorlage fuer die Analyse von Romanauszuegen, Kurzgeschichten und Novellen",
kriterien=[
EHKriterium(
id="inhalt",
name="Inhaltliche Leistung",
beschreibung="Erfassung und Deutung des Textinhalts",
gewichtung=40,
erwartungen=[
"Korrekte Erfassung der Handlung",
"Charakterisierung der Figuren",
"Erkennen der Erzaehlsituation",
"Deutung der Konflikte und Motive",
"Einordnung in den Gesamtzusammenhang"
]
),
EHKriterium(
id="struktur",
name="Aufbau und Struktur",
beschreibung="Logischer Aufbau der Analyse",
gewichtung=15,
erwartungen=[
"Informative Einleitung",
"Systematische Analyse im Hauptteil",
"Verknuepfung der Analyseergebnisse",
"Schluessige Gesamtdeutung"
]
),
EHKriterium(
id="erzaehltechnik",
name="Erzaehltechnische Analyse",
beschreibung="Analyse narrativer Gestaltungsmittel",
gewichtung=15,
erwartungen=[
"Bestimmung der Erzaehlperspektive",
"Analyse von Zeitgestaltung",
"Raumgestaltung und Atmosphaere",
"Figurenrede und Bewusstseinsdarstellung",
"Funktionale Deutung"
]
),
EHKriterium(
id="rechtschreibung",
name="Sprachliche Richtigkeit (Rechtschreibung)",
beschreibung="Orthografische Korrektheit",
gewichtung=15,
erwartungen=[
"Korrekte Rechtschreibung",
"Korrekte Gross- und Kleinschreibung"
]
),
EHKriterium(
id="grammatik",
name="Sprachliche Richtigkeit (Grammatik)",
beschreibung="Grammatische Korrektheit und Zeichensetzung",
gewichtung=15,
erwartungen=[
"Korrekter Satzbau",
"Korrekte Zeichensetzung"
]
)
],
einleitung_hinweise=[
"Autor, Titel, Textsorte, Erscheinungsjahr",
"Einordnung des Auszugs in den Gesamttext",
"Thema und Deutungshypothese"
],
hauptteil_hinweise=[
"Kurze Inhaltsangabe des Auszugs",
"Analyse der Handlungsstruktur",
"Figurenanalyse mit Textbelegen",
"Erzaehltechnische Analyse",
"Sprachliche Analyse",
"Verknuepfung aller Ebenen"
],
schluss_hinweise=[
"Zusammenfassung der Analyseergebnisse",
"Bestaetigung der Deutungshypothese",
"Bedeutung fuer Gesamtwerk",
"Ggf. Aktualitaetsbezug"
],
sprachliche_aspekte=[
"Fachbegriffe der Erzaehltextanalyse",
"Zwischen Erzaehler und Autor unterscheiden",
"Praesens als Analysetempus",
"Deutende Formulierungen"
]
)
def get_dramenanalyse_template() -> EHTemplate:
"""Template for drama analysis."""
return EHTemplate(
id="template_dramenanalyse",
aufgabentyp="dramenanalyse",
name="Dramenanalyse",
beschreibung="Vorlage fuer die Analyse dramatischer Texte und Szenen",
kriterien=[
EHKriterium(
id="inhalt",
name="Inhaltliche Leistung",
beschreibung="Erfassung und Deutung des Szeneninhalts",
gewichtung=40,
erwartungen=[
"Korrekte Erfassung der Handlung",
"Analyse der Figurenkonstellation",
"Erkennen des dramatischen Konflikts",
"Einordnung in den Handlungsverlauf",
"Deutung der Szene im Gesamtzusammenhang"
]
),
EHKriterium(
id="struktur",
name="Aufbau und Struktur",
beschreibung="Logischer Aufbau der Analyse",
gewichtung=15,
erwartungen=[
"Einleitung mit Kontextualisierung",
"Systematische Szenenanalyse",
"Verknuepfung der Analyseergebnisse",
"Schluessige Deutung"
]
),
EHKriterium(
id="dramentechnik",
name="Dramentechnische Analyse",
beschreibung="Analyse dramatischer Gestaltungsmittel",
gewichtung=15,
erwartungen=[
"Analyse der Dialoggestaltung",
"Regieanweisungen und Buehnenraum",
"Dramatische Spannung",
"Monolog/Dialog-Formen",
"Funktionale Deutung"
]
),
EHKriterium(
id="rechtschreibung",
name="Sprachliche Richtigkeit (Rechtschreibung)",
beschreibung="Orthografische Korrektheit",
gewichtung=15,
erwartungen=[
"Korrekte Rechtschreibung"
]
),
EHKriterium(
id="grammatik",
name="Sprachliche Richtigkeit (Grammatik)",
beschreibung="Grammatische Korrektheit und Zeichensetzung",
gewichtung=15,
erwartungen=[
"Korrekter Satzbau",
"Korrekte Zeichensetzung"
]
)
],
einleitung_hinweise=[
"Autor, Titel, Uraufführungsjahr, Dramenform",
"Einordnung der Szene in den Handlungsverlauf",
"Thema und Deutungshypothese"
],
hauptteil_hinweise=[
"Situierung der Szene",
"Analyse des Dialogverlaufs",
"Figurenanalyse im Dialog",
"Sprachliche Analyse",
"Dramentechnische Mittel",
"Bedeutung fuer den Konflikt"
],
schluss_hinweise=[
"Zusammenfassung der Analyseergebnisse",
"Funktion der Szene im Drama",
"Bedeutung fuer die Gesamtdeutung"
],
sprachliche_aspekte=[
"Fachbegriffe der Dramenanalyse",
"Praesens als Analysetempus",
"Korrekte Zitierweise mit Akt/Szene/Zeile"
]
)
# =============================================
# TEMPLATE REGISTRY
# =============================================
TEMPLATES: Dict[str, EHTemplate] = {}
def initialize_templates():
"""Initialize all pre-defined templates."""
global TEMPLATES
TEMPLATES = {
"textanalyse_pragmatisch": get_textanalyse_template(),
"gedichtanalyse": get_gedichtanalyse_template(),
"eroerterung_textgebunden": get_eroerterung_template(),
"prosaanalyse": get_prosaanalyse_template(),
"dramenanalyse": get_dramenanalyse_template(),
}
def get_template(aufgabentyp: str) -> Optional[EHTemplate]:
"""Get a template by Aufgabentyp."""
if not TEMPLATES:
initialize_templates()
return TEMPLATES.get(aufgabentyp)
def list_templates() -> List[Dict]:
"""List all available templates."""
if not TEMPLATES:
initialize_templates()
return [
{
"aufgabentyp": typ,
"name": AUFGABENTYPEN.get(typ, {}).get("name", typ),
"description": AUFGABENTYPEN.get(typ, {}).get("description", ""),
"category": AUFGABENTYPEN.get(typ, {}).get("category", "other"),
}
for typ in TEMPLATES.keys()
]
def get_aufgabentypen() -> Dict:
"""Get all Aufgabentypen definitions."""
return AUFGABENTYPEN
# Initialize on import
initialize_templates()
# Registry
from eh_templates_registry import ( # noqa: F401
TEMPLATES,
initialize_templates,
get_template,
list_templates,
get_aufgabentypen,
)
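# A minimal compatibility check (illustrative; the barrel only re-exports,
# so existing call sites keep importing from `eh_templates` unchanged):
if __name__ == "__main__":
    demo = get_template("gedichtanalyse")
    assert demo is not None and demo.kriterien
    print(f"{len(list_templates())} templates available:",
          [t["aufgabentyp"] for t in list_templates()])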
eh_templates_analyse.py
@@ -0,0 +1,395 @@
"""
Erwartungshorizont Templates — Analyse templates.
Contains templates for:
- Textanalyse (pragmatische Texte)
- Gedichtanalyse / Lyrikinterpretation
- Prosaanalyse
- Dramenanalyse
"""
from eh_templates_types import EHTemplate, EHKriterium
def get_textanalyse_template() -> EHTemplate:
"""Template for pragmatic text analysis."""
return EHTemplate(
id="template_textanalyse_pragmatisch",
aufgabentyp="textanalyse_pragmatisch",
name="Textanalyse pragmatischer Texte",
beschreibung="Vorlage fuer die Analyse von Sachtexten, Reden, Kommentaren und Essays",
kriterien=[
EHKriterium(
id="inhalt",
name="Inhaltliche Leistung",
beschreibung="Erfassung und Wiedergabe des Textinhalts",
gewichtung=40,
erwartungen=[
"Korrekte Erfassung der Textaussage/These",
"Vollstaendige Wiedergabe der Argumentationsstruktur",
"Erkennen von Intention und Adressatenbezug",
"Einordnung in den historischen/gesellschaftlichen Kontext",
"Beruecksichtigung aller relevanten Textaspekte"
]
),
EHKriterium(
id="struktur",
name="Aufbau und Struktur",
beschreibung="Logischer Aufbau und Gliederung der Analyse",
gewichtung=15,
erwartungen=[
"Sinnvolle Einleitung mit Basisinformationen",
"Logische Gliederung des Hauptteils",
"Stringente Gedankenfuehrung",
"Angemessener Schluss mit Fazit/Wertung",
"Absatzgliederung und Ueberlaenge"
]
),
EHKriterium(
id="analyse",
name="Analytische Qualitaet",
beschreibung="Tiefe und Qualitaet der Analyse",
gewichtung=15,
erwartungen=[
"Erkennen rhetorischer Mittel",
"Funktionale Deutung der Stilmittel",
"Analyse der Argumentationsweise",
"Beruecksichtigung von Wortwahl und Satzbau",
"Verknuepfung von Form und Inhalt"
]
),
EHKriterium(
id="rechtschreibung",
name="Sprachliche Richtigkeit (Rechtschreibung)",
beschreibung="Orthografische Korrektheit",
gewichtung=15,
erwartungen=[
"Korrekte Rechtschreibung",
"Korrekte Gross- und Kleinschreibung",
"Korrekte Getrennt- und Zusammenschreibung",
"Korrekte Fremdwortschreibung"
]
),
EHKriterium(
id="grammatik",
name="Sprachliche Richtigkeit (Grammatik)",
beschreibung="Grammatische Korrektheit und Zeichensetzung",
gewichtung=15,
erwartungen=[
"Korrekter Satzbau",
"Korrekte Flexion",
"Korrekte Zeichensetzung",
"Korrekte Bezuege und Kongruenz"
]
)
],
einleitung_hinweise=[
"Nennung von Autor, Titel, Textsorte, Erscheinungsjahr",
"Benennung des Themas",
"Formulierung der Kernthese/Hauptaussage",
"Ggf. Einordnung in den Kontext"
],
hauptteil_hinweise=[
"Systematische Analyse der Argumentationsstruktur",
"Untersuchung der sprachlichen Gestaltung",
"Funktionale Deutung der Stilmittel",
"Beruecksichtigung von Adressatenbezug und Intention",
"Textbelege durch Zitate"
],
schluss_hinweise=[
"Zusammenfassung der Analyseergebnisse",
"Bewertung der Ueberzeugungskraft",
"Ggf. aktuelle Relevanz",
"Persoenliche Stellungnahme (wenn gefordert)"
],
sprachliche_aspekte=[
"Fachsprachliche Begriffe korrekt verwenden",
"Konjunktiv fuer indirekte Rede",
"Praesens als Tempus der Analyse",
"Sachlicher, analytischer Stil"
]
)
def get_gedichtanalyse_template() -> EHTemplate:
"""Template for poetry analysis."""
return EHTemplate(
id="template_gedichtanalyse",
aufgabentyp="gedichtanalyse",
name="Gedichtanalyse / Lyrikinterpretation",
beschreibung="Vorlage fuer die Analyse und Interpretation lyrischer Texte",
kriterien=[
EHKriterium(
id="inhalt",
name="Inhaltliche Leistung",
beschreibung="Erfassung und Deutung des Gedichtinhalts",
gewichtung=40,
erwartungen=[
"Korrekte Erfassung des lyrischen Ichs und der Sprechsituation",
"Vollstaendige inhaltliche Erschliessung aller Strophen",
"Erkennen der zentralen Motive und Themen",
"Epochenzuordnung und literaturgeschichtliche Einordnung",
"Deutung der Bildlichkeit und Symbolik"
]
),
EHKriterium(
id="struktur",
name="Aufbau und Struktur",
beschreibung="Logischer Aufbau der Interpretation",
gewichtung=15,
erwartungen=[
"Einleitung mit Basisinformationen",
"Systematische strophenweise oder aspektorientierte Analyse",
"Verknuepfung von Form- und Inhaltsanalyse",
"Schluessige Gesamtdeutung im Schluss"
]
),
EHKriterium(
id="formanalyse",
name="Formale Analyse",
beschreibung="Analyse der lyrischen Gestaltungsmittel",
gewichtung=15,
erwartungen=[
"Bestimmung von Metrum und Reimschema",
"Analyse der Klanggestaltung",
"Erkennen von Enjambements und Zaesuren",
"Deutung der formalen Mittel",
"Verknuepfung von Form und Inhalt"
]
),
EHKriterium(
id="rechtschreibung",
name="Sprachliche Richtigkeit (Rechtschreibung)",
beschreibung="Orthografische Korrektheit",
gewichtung=15,
erwartungen=[
"Korrekte Rechtschreibung",
"Korrekte Gross- und Kleinschreibung",
"Korrekte Getrennt- und Zusammenschreibung"
]
),
EHKriterium(
id="grammatik",
name="Sprachliche Richtigkeit (Grammatik)",
beschreibung="Grammatische Korrektheit und Zeichensetzung",
gewichtung=15,
erwartungen=[
"Korrekter Satzbau",
"Korrekte Flexion",
"Korrekte Zeichensetzung"
]
)
],
einleitung_hinweise=[
"Autor, Titel, Entstehungsjahr/Epoche",
"Thema/Motiv des Gedichts",
"Erste Deutungshypothese",
"Formale Grunddaten (Strophen, Verse)"
],
hauptteil_hinweise=[
"Inhaltliche Analyse (strophenweise oder aspektorientiert)",
"Formale Analyse (Metrum, Reim, Klang)",
"Sprachliche Analyse (Stilmittel, Bildlichkeit)",
"Funktionale Verknuepfung aller Ebenen",
"Textbelege durch Zitate mit Versangabe"
],
schluss_hinweise=[
"Zusammenfassung der Interpretationsergebnisse",
"Bestaetigung/Modifikation der Deutungshypothese",
"Einordnung in Epoche/Werk des Autors",
"Aktualitaetsbezug (wenn sinnvoll)"
],
sprachliche_aspekte=[
"Fachbegriffe der Lyrikanalyse verwenden",
"Zwischen lyrischem Ich und Autor unterscheiden",
"Praesens als Analysetempus",
"Deutende statt beschreibende Formulierungen"
]
)
def get_prosaanalyse_template() -> EHTemplate:
"""Template for prose/narrative text analysis."""
return EHTemplate(
id="template_prosaanalyse",
aufgabentyp="prosaanalyse",
name="Epische Textanalyse / Prosaanalyse",
beschreibung="Vorlage fuer die Analyse von Romanauszuegen, Kurzgeschichten und Novellen",
kriterien=[
EHKriterium(
id="inhalt",
name="Inhaltliche Leistung",
beschreibung="Erfassung und Deutung des Textinhalts",
gewichtung=40,
erwartungen=[
"Korrekte Erfassung der Handlung",
"Charakterisierung der Figuren",
"Erkennen der Erzaehlsituation",
"Deutung der Konflikte und Motive",
"Einordnung in den Gesamtzusammenhang"
]
),
EHKriterium(
id="struktur",
name="Aufbau und Struktur",
beschreibung="Logischer Aufbau der Analyse",
gewichtung=15,
erwartungen=[
"Informative Einleitung",
"Systematische Analyse im Hauptteil",
"Verknuepfung der Analyseergebnisse",
"Schluessige Gesamtdeutung"
]
),
EHKriterium(
id="erzaehltechnik",
name="Erzaehltechnische Analyse",
beschreibung="Analyse narrativer Gestaltungsmittel",
gewichtung=15,
erwartungen=[
"Bestimmung der Erzaehlperspektive",
"Analyse von Zeitgestaltung",
"Raumgestaltung und Atmosphaere",
"Figurenrede und Bewusstseinsdarstellung",
"Funktionale Deutung"
]
),
EHKriterium(
id="rechtschreibung",
name="Sprachliche Richtigkeit (Rechtschreibung)",
beschreibung="Orthografische Korrektheit",
gewichtung=15,
erwartungen=[
"Korrekte Rechtschreibung",
"Korrekte Gross- und Kleinschreibung"
]
),
EHKriterium(
id="grammatik",
name="Sprachliche Richtigkeit (Grammatik)",
beschreibung="Grammatische Korrektheit und Zeichensetzung",
gewichtung=15,
erwartungen=[
"Korrekter Satzbau",
"Korrekte Zeichensetzung"
]
)
],
einleitung_hinweise=[
"Autor, Titel, Textsorte, Erscheinungsjahr",
"Einordnung des Auszugs in den Gesamttext",
"Thema und Deutungshypothese"
],
hauptteil_hinweise=[
"Kurze Inhaltsangabe des Auszugs",
"Analyse der Handlungsstruktur",
"Figurenanalyse mit Textbelegen",
"Erzaehltechnische Analyse",
"Sprachliche Analyse",
"Verknuepfung aller Ebenen"
],
schluss_hinweise=[
"Zusammenfassung der Analyseergebnisse",
"Bestaetigung der Deutungshypothese",
"Bedeutung fuer Gesamtwerk",
"Ggf. Aktualitaetsbezug"
],
sprachliche_aspekte=[
"Fachbegriffe der Erzaehltextanalyse",
"Zwischen Erzaehler und Autor unterscheiden",
"Praesens als Analysetempus",
"Deutende Formulierungen"
]
)
def get_dramenanalyse_template() -> EHTemplate:
"""Template for drama analysis."""
return EHTemplate(
id="template_dramenanalyse",
aufgabentyp="dramenanalyse",
name="Dramenanalyse",
beschreibung="Vorlage fuer die Analyse dramatischer Texte und Szenen",
kriterien=[
EHKriterium(
id="inhalt",
name="Inhaltliche Leistung",
beschreibung="Erfassung und Deutung des Szeneninhalts",
gewichtung=40,
erwartungen=[
"Korrekte Erfassung der Handlung",
"Analyse der Figurenkonstellation",
"Erkennen des dramatischen Konflikts",
"Einordnung in den Handlungsverlauf",
"Deutung der Szene im Gesamtzusammenhang"
]
),
EHKriterium(
id="struktur",
name="Aufbau und Struktur",
beschreibung="Logischer Aufbau der Analyse",
gewichtung=15,
erwartungen=[
"Einleitung mit Kontextualisierung",
"Systematische Szenenanalyse",
"Verknuepfung der Analyseergebnisse",
"Schluessige Deutung"
]
),
EHKriterium(
id="dramentechnik",
name="Dramentechnische Analyse",
beschreibung="Analyse dramatischer Gestaltungsmittel",
gewichtung=15,
erwartungen=[
"Analyse der Dialoggestaltung",
"Regieanweisungen und Buehnenraum",
"Dramatische Spannung",
"Monolog/Dialog-Formen",
"Funktionale Deutung"
]
),
EHKriterium(
id="rechtschreibung",
name="Sprachliche Richtigkeit (Rechtschreibung)",
beschreibung="Orthografische Korrektheit",
gewichtung=15,
erwartungen=[
"Korrekte Rechtschreibung"
]
),
EHKriterium(
id="grammatik",
name="Sprachliche Richtigkeit (Grammatik)",
beschreibung="Grammatische Korrektheit und Zeichensetzung",
gewichtung=15,
erwartungen=[
"Korrekter Satzbau",
"Korrekte Zeichensetzung"
]
)
],
einleitung_hinweise=[
"Autor, Titel, Urauffuehrungsjahr, Dramenform",
"Einordnung der Szene in den Handlungsverlauf",
"Thema und Deutungshypothese"
],
hauptteil_hinweise=[
"Situierung der Szene",
"Analyse des Dialogverlaufs",
"Figurenanalyse im Dialog",
"Sprachliche Analyse",
"Dramentechnische Mittel",
"Bedeutung fuer den Konflikt"
],
schluss_hinweise=[
"Zusammenfassung der Analyseergebnisse",
"Funktion der Szene im Drama",
"Bedeutung fuer die Gesamtdeutung"
],
sprachliche_aspekte=[
"Fachbegriffe der Dramenanalyse",
"Praesens als Analysetempus",
"Korrekte Zitierweise mit Akt/Szene/Zeile"
]
)
eh_templates_eroerterung.py
@@ -0,0 +1,101 @@
"""
Erwartungshorizont Templates — Eroerterung template.
"""
from eh_templates_types import EHTemplate, EHKriterium
def get_eroerterung_template() -> EHTemplate:
"""Template for textgebundene Eroerterung."""
return EHTemplate(
id="template_eroerterung_textgebunden",
aufgabentyp="eroerterung_textgebunden",
name="Textgebundene Eroerterung",
beschreibung="Vorlage fuer die textgebundene Eroerterung auf Basis eines Sachtextes",
kriterien=[
EHKriterium(
id="inhalt",
name="Inhaltliche Leistung",
beschreibung="Qualitaet der Argumentation",
gewichtung=40,
erwartungen=[
"Korrekte Wiedergabe der Textposition",
"Differenzierte eigene Argumentation",
"Vielfaeltige und ueberzeugende Argumente",
"Beruecksichtigung von Pro und Contra",
"Sinnvolle Beispiele und Belege",
"Eigenstaendige Schlussfolgerung"
]
),
EHKriterium(
id="struktur",
name="Aufbau und Struktur",
beschreibung="Logischer Aufbau der Eroerterung",
gewichtung=15,
erwartungen=[
"Problemorientierte Einleitung",
"Klare Gliederung der Argumentation",
"Logische Argumentationsfolge",
"Sinnvolle Ueberlaetze",
"Begruendetes Fazit"
]
),
EHKriterium(
id="textbezug",
name="Textbezug",
beschreibung="Verknuepfung mit dem Ausgangstext",
gewichtung=15,
erwartungen=[
"Angemessene Textwiedergabe",
"Kritische Auseinandersetzung mit Textposition",
"Korrekte Zitierweise",
"Verknuepfung eigener Argumente mit Text"
]
),
EHKriterium(
id="rechtschreibung",
name="Sprachliche Richtigkeit (Rechtschreibung)",
beschreibung="Orthografische Korrektheit",
gewichtung=15,
erwartungen=[
"Korrekte Rechtschreibung",
"Korrekte Gross- und Kleinschreibung"
]
),
EHKriterium(
id="grammatik",
name="Sprachliche Richtigkeit (Grammatik)",
beschreibung="Grammatische Korrektheit und Zeichensetzung",
gewichtung=15,
erwartungen=[
"Korrekter Satzbau",
"Korrekte Zeichensetzung",
"Variationsreicher Ausdruck"
]
)
],
einleitung_hinweise=[
"Hinfuehrung zum Thema",
"Nennung des Ausgangstextes",
"Formulierung der Leitfrage/These",
"Ueberleitung zum Hauptteil"
],
hauptteil_hinweise=[
"Kurze Wiedergabe der Textposition",
"Systematische Argumentation (dialektisch oder linear)",
"Jedes Argument: These - Begruendung - Beispiel",
"Gewichtung der Argumente",
"Verknuepfung mit Textposition"
],
schluss_hinweise=[
"Zusammenfassung der wichtigsten Argumente",
"Eigene begruendete Stellungnahme",
"Ggf. Ausblick oder Appell"
],
sprachliche_aspekte=[
"Argumentative Konnektoren verwenden",
"Sachlicher, ueberzeugender Stil",
"Eigene Meinung kennzeichnen",
"Konjunktiv fuer Textpositionen"
]
)
eh_templates_registry.py
@@ -0,0 +1,60 @@
"""
Erwartungshorizont Templates — registry for template lookup.
"""
from typing import Dict, List, Optional
from eh_templates_types import EHTemplate, AUFGABENTYPEN
from eh_templates_analyse import (
get_textanalyse_template,
get_gedichtanalyse_template,
get_prosaanalyse_template,
get_dramenanalyse_template,
)
from eh_templates_eroerterung import get_eroerterung_template
TEMPLATES: Dict[str, EHTemplate] = {}
def initialize_templates():
"""Initialize all pre-defined templates."""
global TEMPLATES
TEMPLATES = {
"textanalyse_pragmatisch": get_textanalyse_template(),
"gedichtanalyse": get_gedichtanalyse_template(),
"eroerterung_textgebunden": get_eroerterung_template(),
"prosaanalyse": get_prosaanalyse_template(),
"dramenanalyse": get_dramenanalyse_template(),
}
def get_template(aufgabentyp: str) -> Optional[EHTemplate]:
"""Get a template by Aufgabentyp."""
if not TEMPLATES:
initialize_templates()
return TEMPLATES.get(aufgabentyp)
def list_templates() -> List[Dict]:
"""List all available templates."""
if not TEMPLATES:
initialize_templates()
return [
{
"aufgabentyp": typ,
"name": AUFGABENTYPEN.get(typ, {}).get("name", typ),
"description": AUFGABENTYPEN.get(typ, {}).get("description", ""),
"category": AUFGABENTYPEN.get(typ, {}).get("category", "other"),
}
for typ in TEMPLATES.keys()
]
def get_aufgabentypen() -> Dict:
"""Get all Aufgabentypen definitions."""
return AUFGABENTYPEN
# Initialize on import
initialize_templates()
eh_templates_types.py
@@ -0,0 +1,100 @@
"""
Erwartungshorizont Templates — types and Aufgabentypen registry.
"""
from typing import Dict, List, Optional
from dataclasses import dataclass, field, asdict
from datetime import datetime
AUFGABENTYPEN = {
"textanalyse_pragmatisch": {
"name": "Textanalyse (pragmatische Texte)",
"description": "Analyse von Sachtexten, Reden, Kommentaren, Essays",
"category": "analyse"
},
"sachtextanalyse": {
"name": "Sachtextanalyse",
"description": "Analyse von informativen und appellativen Sachtexten",
"category": "analyse"
},
"gedichtanalyse": {
"name": "Gedichtanalyse / Lyrikinterpretation",
"description": "Analyse und Interpretation lyrischer Texte",
"category": "interpretation"
},
"dramenanalyse": {
"name": "Dramenanalyse",
"description": "Analyse dramatischer Texte und Szenen",
"category": "interpretation"
},
"prosaanalyse": {
"name": "Epische Textanalyse / Prosaanalyse",
"description": "Analyse von Romanauszuegen, Kurzgeschichten, Novellen",
"category": "interpretation"
},
"eroerterung_textgebunden": {
"name": "Textgebundene Eroerterung",
"description": "Eroerterung auf Basis eines Sachtextes",
"category": "argumentation"
},
"eroerterung_frei": {
"name": "Freie Eroerterung",
"description": "Freie Eroerterung zu einem Thema",
"category": "argumentation"
},
"eroerterung_literarisch": {
"name": "Literarische Eroerterung",
"description": "Eroerterung zu literarischen Fragestellungen",
"category": "argumentation"
},
"materialgestuetzt": {
"name": "Materialgestuetztes Schreiben",
"description": "Verfassen eines Textes auf Materialbasis",
"category": "produktion"
}
}
@dataclass
class EHKriterium:
"""Single criterion in an Erwartungshorizont."""
id: str
name: str
beschreibung: str
gewichtung: int # Percentage weight (0-100)
erwartungen: List[str] # Expected points/elements
max_punkte: int = 100
def to_dict(self):
return asdict(self)
@dataclass
class EHTemplate:
"""Complete Erwartungshorizont template."""
id: str
aufgabentyp: str
name: str
beschreibung: str
kriterien: List[EHKriterium]
einleitung_hinweise: List[str]
hauptteil_hinweise: List[str]
schluss_hinweise: List[str]
sprachliche_aspekte: List[str]
created_at: datetime = field(default_factory=lambda: datetime.now())
def to_dict(self):
d = {
'id': self.id,
'aufgabentyp': self.aufgabentyp,
'name': self.name,
'beschreibung': self.beschreibung,
'kriterien': [k.to_dict() for k in self.kriterien],
'einleitung_hinweise': self.einleitung_hinweise,
'hauptteil_hinweise': self.hauptteil_hinweise,
'schluss_hinweise': self.schluss_hinweise,
'sprachliche_aspekte': self.sprachliche_aspekte,
'created_at': self.created_at.isoformat()
}
return d
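# Illustrative usage of the types above (a sketch; the template below is
# made up for the example and is not one of the shipped templates):
if __name__ == "__main__":
    beispiel = EHTemplate(
        id="template_beispiel",
        aufgabentyp="sachtextanalyse",
        name="Beispielvorlage",
        beschreibung="Nur zur Illustration",
        kriterien=[
            EHKriterium(
                id="inhalt",
                name="Inhaltliche Leistung",
                beschreibung="Erfassung des Textinhalts",
                gewichtung=100,
                erwartungen=["Korrekte Erfassung der Kernthese"],
            ),
        ],
        einleitung_hinweise=["Autor, Titel, Textsorte"],
        hauptteil_hinweise=["Systematische Analyse"],
        schluss_hinweise=["Fazit"],
        sprachliche_aspekte=["Sachlicher Stil"],
    )
    print(beispiel.to_dict()["kriterien"][0]["gewichtung"])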
grid_editor_api.py
+27 -667
@@ -1,671 +1,31 @@
"""
Grid Editor API — endpoints for grid building, editing, and export.
Grid Editor API — barrel re-export.
The core grid building logic is in grid_build_core.py.
The actual endpoints live in:
- grid_editor_api_grid.py (build-grid, rerun-ocr, save-grid, get-grid)
- grid_editor_api_gutter.py (gutter-repair, gutter-repair/apply)
- grid_editor_api_box.py (build-box-grids)
- grid_editor_api_unified.py (build-unified-grid, unified-grid)
This module re-exports the combined router and key symbols so that
existing `from grid_editor_api import router` / `from grid_editor_api import _build_grid_core`
continue to work unchanged.
"""
import logging
import re
import time
from typing import Any, Dict, List, Optional, Tuple
from fastapi import APIRouter, HTTPException, Query, Request
from grid_build_core import _build_grid_core
from grid_editor_helpers import _words_in_zone
from ocr_pipeline_session_store import (
get_session_db,
update_session_db,
)
from ocr_pipeline_common import (
_cache,
_load_session_to_cache,
_get_cached,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["grid-editor"])
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/build-grid")
async def build_grid(
session_id: str,
ipa_mode: str = Query("auto", pattern="^(auto|all|de|en|none)$"),
syllable_mode: str = Query("auto", pattern="^(auto|all|de|en|none)$"),
enhance: bool = Query(True, description="Step 3: CLAHE + denoise for degraded scans"),
max_cols: int = Query(0, description="Step 2: Max column count (0=unlimited)"),
min_conf: int = Query(0, description="Step 1: Min OCR confidence (0=auto)"),
):
"""Build a structured, zone-aware grid from existing Kombi word results.
Requires that paddle-kombi or rapid-kombi has already been run on the session.
Uses the image for box detection and the word positions for grid structuring.
Query params:
ipa_mode: "auto" (only when English IPA detected), "all" (force), "none" (skip)
syllable_mode: "auto" (only when original has dividers), "all" (force), "none" (skip)
Returns a StructuredGrid with zones, each containing their own
columns, rows, and cells — ready for the frontend Excel-like editor.
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
try:
result = await _build_grid_core(
session_id, session,
ipa_mode=ipa_mode, syllable_mode=syllable_mode,
enhance=enhance,
max_columns=max_cols if max_cols > 0 else None,
min_conf=min_conf if min_conf > 0 else None,
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
# Save automatic grid snapshot for later comparison with manual corrections
# Lazy import to avoid circular dependency with ocr_pipeline_regression
from ocr_pipeline_regression import _build_reference_snapshot
wr = session.get("word_result") or {}
engine = wr.get("ocr_engine", "")
if engine in ("kombi", "rapid_kombi"):
auto_pipeline = "kombi"
elif engine == "paddle_direct":
auto_pipeline = "paddle-direct"
else:
auto_pipeline = "pipeline"
auto_snapshot = _build_reference_snapshot(result, pipeline=auto_pipeline)
gt = session.get("ground_truth") or {}
gt["auto_grid_snapshot"] = auto_snapshot
# Persist to DB and advance current_step to 11 (reconstruction complete)
await update_session_db(session_id, grid_editor_result=result, ground_truth=gt, current_step=11)
logger.info(
"build-grid session %s: %d zones, %d cols, %d rows, %d cells, "
"%d boxes in %.2fs",
session_id,
len(result.get("zones", [])),
result.get("summary", {}).get("total_columns", 0),
result.get("summary", {}).get("total_rows", 0),
result.get("summary", {}).get("total_cells", 0),
result.get("boxes_detected", 0),
result.get("duration_seconds", 0),
)
return result
@router.post("/sessions/{session_id}/rerun-ocr-and-build-grid")
async def rerun_ocr_and_build_grid(
session_id: str,
ipa_mode: str = Query("auto", pattern="^(auto|all|de|en|none)$"),
syllable_mode: str = Query("auto", pattern="^(auto|all|de|en|none)$"),
enhance: bool = Query(True, description="Step 3: CLAHE + denoise for degraded scans"),
max_cols: int = Query(0, description="Step 2: Max column count (0=unlimited)"),
min_conf: int = Query(0, description="Step 1: Min OCR confidence (0=auto)"),
vision_fusion: bool = Query(False, description="Step 4: Vision-LLM fusion for degraded scans"),
doc_category: str = Query("", description="Document type for Vision-LLM prompt context"),
):
"""Re-run OCR with quality settings, then rebuild the grid.
Unlike build-grid (which only rebuilds from existing words),
this endpoint re-runs the full OCR pipeline on the cropped image
with optional CLAHE enhancement, then builds the grid.
Steps executed: Image Enhancement → OCR → Grid Build
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
import time as _time
t0 = _time.time()
# 1. Load the cropped/dewarped image from cache or session
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
if dewarped_bgr is None:
raise HTTPException(status_code=400, detail="No cropped/dewarped image available. Run preprocessing steps first.")
import numpy as np
img_h, img_w = dewarped_bgr.shape[:2]
ocr_input = dewarped_bgr.copy()
# 2. Scan quality assessment
scan_quality_info = {}
try:
from scan_quality import score_scan_quality
quality_report = score_scan_quality(ocr_input)
scan_quality_info = quality_report.to_dict()
actual_min_conf = min_conf if min_conf > 0 else quality_report.recommended_min_conf
except Exception as e:
logger.warning(f"rerun-ocr: scan quality failed: {e}")
actual_min_conf = min_conf if min_conf > 0 else 40
# 3. Image enhancement (Step 3)
is_degraded = scan_quality_info.get("is_degraded", False)
if enhance and is_degraded:
try:
from ocr_image_enhance import enhance_for_ocr
ocr_input = enhance_for_ocr(ocr_input, is_degraded=True)
logger.info("rerun-ocr: CLAHE enhancement applied")
except Exception as e:
logger.warning(f"rerun-ocr: enhancement failed: {e}")
# 4. Run dual-engine OCR
from PIL import Image
import pytesseract
# RapidOCR
rapid_words = []
try:
from cv_ocr_engines import ocr_region_rapid
from cv_vocab_types import PageRegion
full_region = PageRegion(type="full_page", x=0, y=0, width=img_w, height=img_h)
rapid_words = ocr_region_rapid(ocr_input, full_region) or []
except Exception as e:
logger.warning(f"rerun-ocr: RapidOCR failed: {e}")
# Tesseract
pil_img = Image.fromarray(ocr_input[:, :, ::-1])
data = pytesseract.image_to_data(pil_img, lang='eng+deu', config='--psm 6 --oem 3', output_type=pytesseract.Output.DICT)
tess_words = []
for i in range(len(data["text"])):
text = (data["text"][i] or "").strip()
conf_raw = str(data["conf"][i])
conf = int(conf_raw) if conf_raw.lstrip("-").isdigit() else -1
if not text or conf < actual_min_conf:
continue
tess_words.append({
"text": text, "left": data["left"][i], "top": data["top"][i],
"width": data["width"][i], "height": data["height"][i], "conf": conf,
})
# 5. Merge OCR results
from ocr_pipeline_ocr_merge import _split_paddle_multi_words, _merge_paddle_tesseract, _deduplicate_words
rapid_split = _split_paddle_multi_words(rapid_words) if rapid_words else []
if rapid_split or tess_words:
merged_words = _merge_paddle_tesseract(rapid_split, tess_words)
merged_words = _deduplicate_words(merged_words)
else:
merged_words = tess_words
# 6. Store updated word_result in session
cells_for_storage = [{"text": w["text"], "left": w["left"], "top": w["top"],
"width": w["width"], "height": w["height"], "conf": w.get("conf", 0)}
for w in merged_words]
word_result = {
"cells": [{"text": " ".join(w["text"] for w in merged_words),
"word_boxes": cells_for_storage}],
"image_width": img_w,
"image_height": img_h,
"ocr_engine": "rapid_kombi",
"word_count": len(merged_words),
"raw_paddle_words": rapid_words,
}
# 6b. Vision-LLM Fusion (Step 4) — correct OCR using Vision model
vision_applied = False
if vision_fusion:
try:
from vision_ocr_fusion import vision_fuse_ocr
category = doc_category or session.get("document_category") or "vokabelseite"
logger.info(f"rerun-ocr: running Vision-LLM fusion (category={category})")
merged_words = await vision_fuse_ocr(ocr_input, merged_words, category)
vision_applied = True
# Rebuild storage from fused words
cells_for_storage = [{"text": w["text"], "left": w["left"], "top": w["top"],
"width": w["width"], "height": w["height"], "conf": w.get("conf", 0)}
for w in merged_words]
word_result["cells"] = [{"text": " ".join(w["text"] for w in merged_words),
"word_boxes": cells_for_storage}]
word_result["word_count"] = len(merged_words)
word_result["ocr_engine"] = "vision_fusion"
except Exception as e:
logger.warning(f"rerun-ocr: Vision-LLM fusion failed: {e}")
await update_session_db(session_id, word_result=word_result)
# Reload session with updated word_result
session = await get_session_db(session_id)
ocr_duration = _time.time() - t0
logger.info(
"rerun-ocr session %s: %d words (rapid=%d, tess=%d, merged=%d) in %.1fs "
"(enhance=%s, min_conf=%d, quality=%s)",
session_id, len(merged_words), len(rapid_words), len(tess_words),
len(merged_words), ocr_duration, enhance, actual_min_conf,
scan_quality_info.get("quality_pct", "?"),
)
# 7. Build grid from new words
try:
result = await _build_grid_core(
session_id, session,
ipa_mode=ipa_mode, syllable_mode=syllable_mode,
enhance=enhance,
max_columns=max_cols if max_cols > 0 else None,
min_conf=min_conf if min_conf > 0 else None,
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
# Persist grid
await update_session_db(session_id, grid_editor_result=result, current_step=11)
# Add quality info to response
result["scan_quality"] = scan_quality_info
result["ocr_stats"] = {
"rapid_words": len(rapid_words),
"tess_words": len(tess_words),
"merged_words": len(merged_words),
"min_conf_used": actual_min_conf,
"enhance_applied": enhance and is_degraded,
"vision_fusion_applied": vision_applied,
"document_category": doc_category or session.get("document_category", ""),
"ocr_duration_seconds": round(ocr_duration, 1),
}
total_duration = _time.time() - t0
logger.info(
"rerun-ocr+build-grid session %s: %d zones, %d cols, %d cells in %.1fs",
session_id,
len(result.get("zones", [])),
result.get("summary", {}).get("total_columns", 0),
result.get("summary", {}).get("total_cells", 0),
total_duration,
)
return result
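# Illustrative request for a degraded scan (a sketch; the query values are
# placeholders, not recommendations):
#   POST /api/v1/ocr-pipeline/sessions/{sid}/rerun-ocr-and-build-grid
#        ?enhance=true&min_conf=55&vision_fusion=true&doc_category=vokabelseite
# This re-runs RapidOCR + Tesseract on the cropped image, merges and
# deduplicates the words, optionally applies Vision-LLM fusion, and rebuilds
# the grid; the response includes scan_quality and ocr_stats.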
@router.post("/sessions/{session_id}/save-grid")
async def save_grid(session_id: str, request: Request):
"""Save edited grid data from the frontend Excel-like editor.
Receives the full StructuredGrid with user edits (text changes,
formatting changes like bold columns, header rows, etc.) and
persists it to the session's grid_editor_result.
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
body = await request.json()
# Validate basic structure
if "zones" not in body:
raise HTTPException(status_code=400, detail="Missing 'zones' in request body")
# Preserve metadata from the original build
existing = session.get("grid_editor_result") or {}
result = {
"session_id": session_id,
"image_width": body.get("image_width", existing.get("image_width", 0)),
"image_height": body.get("image_height", existing.get("image_height", 0)),
"zones": body["zones"],
"boxes_detected": body.get("boxes_detected", existing.get("boxes_detected", 0)),
"summary": body.get("summary", existing.get("summary", {})),
"formatting": body.get("formatting", existing.get("formatting", {})),
"duration_seconds": existing.get("duration_seconds", 0),
"edited": True,
}
await update_session_db(session_id, grid_editor_result=result, current_step=11)
logger.info("save-grid session %s: %d zones saved", session_id, len(body["zones"]))
return {"session_id": session_id, "saved": True}
@router.get("/sessions/{session_id}/grid-editor")
async def get_grid(session_id: str):
"""Retrieve the current grid editor state for a session."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
result = session.get("grid_editor_result")
if not result:
raise HTTPException(
status_code=404,
detail="No grid editor data. Run build-grid first.",
)
return result
# ---------------------------------------------------------------------------
# Gutter Repair endpoints
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/gutter-repair")
async def gutter_repair(session_id: str):
"""Analyse grid for gutter-edge OCR errors and return repair suggestions.
Detects:
- Words truncated/blurred at the book binding (spell_fix)
- Words split across rows with missing hyphen chars (hyphen_join)
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
grid_data = session.get("grid_editor_result")
if not grid_data:
raise HTTPException(
status_code=400,
detail="No grid data. Run build-grid first.",
)
from cv_gutter_repair import analyse_grid_for_gutter_repair
image_width = grid_data.get("image_width", 0)
result = analyse_grid_for_gutter_repair(grid_data, image_width=image_width)
# Persist suggestions in ground_truth.gutter_repair (avoids DB migration)
gt = session.get("ground_truth") or {}
gt["gutter_repair"] = result
await update_session_db(session_id, ground_truth=gt)
logger.info(
"gutter-repair session %s: %d suggestions in %.2fs",
session_id,
result.get("stats", {}).get("suggestions_found", 0),
result.get("duration_seconds", 0),
)
return result
@router.post("/sessions/{session_id}/gutter-repair/apply")
async def gutter_repair_apply(session_id: str, request: Request):
"""Apply accepted gutter repair suggestions to the grid.
Body: { "accepted": ["suggestion_id_1", "suggestion_id_2", ...] }
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
grid_data = session.get("grid_editor_result")
if not grid_data:
raise HTTPException(status_code=400, detail="No grid data.")
gt = session.get("ground_truth") or {}
gutter_result = gt.get("gutter_repair")
if not gutter_result:
raise HTTPException(
status_code=400,
detail="No gutter repair data. Run gutter-repair first.",
)
body = await request.json()
accepted_ids = body.get("accepted", [])
if not accepted_ids:
return {"applied_count": 0, "changes": []}
# text_overrides: { suggestion_id: "alternative_text" }
# Allows the user to pick a different correction from the alternatives list
text_overrides = body.get("text_overrides", {})
from cv_gutter_repair import apply_gutter_suggestions
suggestions = gutter_result.get("suggestions", [])
# Apply user-selected alternatives before passing to apply
for s in suggestions:
sid = s.get("id", "")
if sid in text_overrides and text_overrides[sid]:
s["suggested_text"] = text_overrides[sid]
result = apply_gutter_suggestions(grid_data, accepted_ids, suggestions)
# Save updated grid back to session
await update_session_db(session_id, grid_editor_result=grid_data)
logger.info(
"gutter-repair/apply session %s: %d changes applied",
session_id,
result.get("applied_count", 0),
)
return result
# ---------------------------------------------------------------------------
# Box-Grid-Review endpoints
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/build-box-grids")
async def build_box_grids(session_id: str, request: Request):
"""Rebuild grid structure for all detected boxes with layout-aware detection.
Uses structure_result.boxes (from Step 7) as the source of box coordinates,
and raw_paddle_words as OCR word source. Creates or updates box zones in
the grid_editor_result.
Optional body: { "overrides": { "0": "bullet_list" } }
Maps box_index → forced layout_type.
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
grid_data = session.get("grid_editor_result")
if not grid_data:
raise HTTPException(status_code=400, detail="No grid data. Run build-grid first.")
# Get raw OCR words (with top/left/width/height keys)
word_result = session.get("word_result") or {}
all_words = word_result.get("raw_paddle_words") or word_result.get("raw_tesseract_words") or []
if not all_words:
raise HTTPException(status_code=400, detail="No raw OCR words available.")
# Get detected boxes from structure_result
structure_result = session.get("structure_result") or {}
gt = session.get("ground_truth") or {}
if not structure_result:
structure_result = gt.get("structure_result") or {}
detected_boxes = structure_result.get("boxes") or []
if not detected_boxes:
return {"session_id": session_id, "box_zones_rebuilt": 0, "spell_fixes": 0, "message": "No boxes detected"}
# Filter out false-positive boxes in header/footer margins.
# Textbook pages have ~2.5cm margins at top/bottom. At typical scan
# resolutions (150-300 DPI), that's roughly 5-10% of image height.
# A box whose vertical CENTER falls within the top or bottom 7% of
# the image is likely a page number, unit header, or running footer.
img_h_for_filter = grid_data.get("image_height", 0) or word_result.get("image_height", 0)
if img_h_for_filter > 0:
margin_frac = 0.07 # 7% of image height
margin_top = img_h_for_filter * margin_frac
margin_bottom = img_h_for_filter * (1 - margin_frac)
filtered = []
for box in detected_boxes:
by = box.get("y", 0)
bh = box.get("h", 0)
box_center_y = by + bh / 2
if box_center_y < margin_top or box_center_y > margin_bottom:
logger.info("build-box-grids: skipping header/footer box at y=%d h=%d (center=%.0f, margins=%.0f/%.0f)",
by, bh, box_center_y, margin_top, margin_bottom)
continue
filtered.append(box)
detected_boxes = filtered
body = {}
try:
body = await request.json()
except Exception:
pass
layout_overrides = body.get("overrides", {})
from cv_box_layout import build_box_zone_grid
from grid_editor_helpers import _words_in_zone
img_w = grid_data.get("image_width", 0) or word_result.get("image_width", 0)
img_h = grid_data.get("image_height", 0) or word_result.get("image_height", 0)
zones = grid_data.get("zones", [])
# Find highest existing zone_index
max_zone_idx = max((z.get("zone_index", 0) for z in zones), default=-1)
# Remove old box zones (we'll rebuild them)
zones = [z for z in zones if z.get("zone_type") != "box"]
box_count = 0
spell_fixes = 0
for box_idx, box in enumerate(detected_boxes):
bx = box.get("x", 0)
by = box.get("y", 0)
bw = box.get("w", 0)
bh = box.get("h", 0)
if bw <= 0 or bh <= 0:
continue
# Filter raw OCR words inside this box
zone_words = _words_in_zone(all_words, by, bh, bx, bw)
if not zone_words:
logger.info("Box %d: no words found in bbox (%d,%d,%d,%d)", box_idx, bx, by, bw, bh)
continue
zone_idx = max_zone_idx + 1 + box_idx
forced_layout = layout_overrides.get(str(box_idx))
# Build box grid
box_grid = build_box_zone_grid(
zone_words, bx, by, bw, bh,
zone_idx, img_w, img_h,
layout_type=forced_layout,
)
# Apply SmartSpellChecker to all box cells
try:
from smart_spell import SmartSpellChecker
ssc = SmartSpellChecker()
for cell in box_grid.get("cells", []):
text = cell.get("text", "")
if not text:
continue
result = ssc.correct_text(text, lang="auto")
if result.changed:
cell["text"] = result.corrected
spell_fixes += 1
except ImportError:
pass
# Build zone entry
zone_entry = {
"zone_index": zone_idx,
"zone_type": "box",
"bbox_px": {"x": bx, "y": by, "w": bw, "h": bh},
"bbox_pct": {
"x": round(bx / img_w * 100, 2) if img_w else 0,
"y": round(by / img_h * 100, 2) if img_h else 0,
"w": round(bw / img_w * 100, 2) if img_w else 0,
"h": round(bh / img_h * 100, 2) if img_h else 0,
},
"border": None,
"word_count": len(zone_words),
"columns": box_grid["columns"],
"rows": box_grid["rows"],
"cells": box_grid["cells"],
"header_rows": box_grid.get("header_rows", []),
"box_layout_type": box_grid.get("box_layout_type", "flowing"),
"box_grid_reviewed": False,
"box_bg_color": box.get("bg_color_name", ""),
"box_bg_hex": box.get("bg_color_hex", ""),
}
zones.append(zone_entry)
box_count += 1
# Sort zones by y-position for correct reading order
zones.sort(key=lambda z: z.get("bbox_px", {}).get("y", 0))
grid_data["zones"] = zones
await update_session_db(session_id, grid_editor_result=grid_data)
logger.info(
"build-box-grids session %s: %d boxes processed (%d words spell-fixed) from %d detected",
session_id, box_count, spell_fixes, len(detected_boxes),
)
return {
"session_id": session_id,
"box_zones_rebuilt": box_count,
"total_detected_boxes": len(detected_boxes),
"spell_fixes": spell_fixes,
"zones": zones,
}
# ---------------------------------------------------------------------------
# Unified Grid endpoint
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/build-unified-grid")
async def build_unified_grid_endpoint(session_id: str):
"""Build a single-zone unified grid merging content + box zones.
Takes the existing multi-zone grid_editor_result and produces a
unified grid where boxes are integrated into the main row sequence.
Persists as unified_grid_result (preserves original multi-zone data).
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
grid_data = session.get("grid_editor_result")
if not grid_data:
raise HTTPException(status_code=400, detail="No grid data. Run build-grid first.")
from unified_grid import build_unified_grid
result = build_unified_grid(
zones=grid_data.get("zones", []),
image_width=grid_data.get("image_width", 0),
image_height=grid_data.get("image_height", 0),
layout_metrics=grid_data.get("layout_metrics", {}),
)
# Persist as separate field (don't overwrite original multi-zone grid)
await update_session_db(session_id, unified_grid_result=result)
logger.info(
"build-unified-grid session %s: %d rows, %d cells",
session_id,
result.get("summary", {}).get("total_rows", 0),
result.get("summary", {}).get("total_cells", 0),
)
return result
@router.get("/sessions/{session_id}/unified-grid")
async def get_unified_grid(session_id: str):
"""Retrieve the unified grid for a session."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
result = session.get("unified_grid_result")
if not result:
raise HTTPException(
status_code=404,
detail="No unified grid. Run build-unified-grid first.",
)
return result
from fastapi import APIRouter
from grid_editor_api_grid import router as _grid_router
from grid_editor_api_gutter import router as _gutter_router
from grid_editor_api_box import router as _box_router
from grid_editor_api_unified import router as _unified_router
# Re-export _build_grid_core so callers that do
# `from grid_editor_api import _build_grid_core` keep working.
from grid_build_core import _build_grid_core # noqa: F401
# Merge all sub-routers into one combined router
router = APIRouter()
router.include_router(_grid_router)
router.include_router(_gutter_router)
router.include_router(_box_router)
router.include_router(_unified_router)
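# Call sites keep working unchanged; a minimal sketch, assuming the existing FastAPI app wiring
# (the `app` object below is hypothetical):
#   from grid_editor_api import router, _build_grid_core
#   app.include_router(router)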
@@ -0,0 +1,177 @@
"""
Grid Editor API — box-grid-review endpoints.
"""
import logging
from fastapi import APIRouter, HTTPException, Request
from grid_editor_helpers import _words_in_zone
from ocr_pipeline_session_store import (
get_session_db,
update_session_db,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["grid-editor"])
@router.post("/sessions/{session_id}/build-box-grids")
async def build_box_grids(session_id: str, request: Request):
"""Rebuild grid structure for all detected boxes with layout-aware detection.
Uses structure_result.boxes (from Step 7) as the source of box coordinates,
and raw_paddle_words as OCR word source. Creates or updates box zones in
the grid_editor_result.
Optional body: { "overrides": { "0": "bullet_list" } }
Maps box_index -> forced layout_type.
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
grid_data = session.get("grid_editor_result")
if not grid_data:
raise HTTPException(status_code=400, detail="No grid data. Run build-grid first.")
# Get raw OCR words (with top/left/width/height keys)
word_result = session.get("word_result") or {}
all_words = word_result.get("raw_paddle_words") or word_result.get("raw_tesseract_words") or []
if not all_words:
raise HTTPException(status_code=400, detail="No raw OCR words available.")
# Get detected boxes from structure_result
structure_result = session.get("structure_result") or {}
gt = session.get("ground_truth") or {}
if not structure_result:
structure_result = gt.get("structure_result") or {}
detected_boxes = structure_result.get("boxes") or []
if not detected_boxes:
return {"session_id": session_id, "box_zones_rebuilt": 0, "spell_fixes": 0, "message": "No boxes detected"}
# Filter out false-positive boxes in header/footer margins.
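# Textbook pages have ~2.5cm margins at top/bottom. At typical scan
# resolutions (150-300 DPI), that's roughly 5-10% of image height.
# A box whose vertical CENTER falls within the top or bottom 7% of
# the image is likely a page number, unit header, or running footer.
# Worked example: for a 3508 px tall scan, margin_top = 245.6 and
# margin_bottom = 3262.4, so a page-number box centered at y = 120 is dropped.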
img_h_for_filter = grid_data.get("image_height", 0) or word_result.get("image_height", 0)
if img_h_for_filter > 0:
margin_frac = 0.07 # 7% of image height
margin_top = img_h_for_filter * margin_frac
margin_bottom = img_h_for_filter * (1 - margin_frac)
filtered = []
for box in detected_boxes:
by = box.get("y", 0)
bh = box.get("h", 0)
box_center_y = by + bh / 2
if box_center_y < margin_top or box_center_y > margin_bottom:
logger.info("build-box-grids: skipping header/footer box at y=%d h=%d (center=%.0f, margins=%.0f/%.0f)",
by, bh, box_center_y, margin_top, margin_bottom)
continue
filtered.append(box)
detected_boxes = filtered
body = {}
try:
body = await request.json()
except Exception:
pass
layout_overrides = body.get("overrides", {})
from cv_box_layout import build_box_zone_grid
img_w = grid_data.get("image_width", 0) or word_result.get("image_width", 0)
img_h = grid_data.get("image_height", 0) or word_result.get("image_height", 0)
zones = grid_data.get("zones", [])
# Find highest existing zone_index
max_zone_idx = max((z.get("zone_index", 0) for z in zones), default=-1)
# Remove old box zones (we'll rebuild them)
zones = [z for z in zones if z.get("zone_type") != "box"]
box_count = 0
spell_fixes = 0
for box_idx, box in enumerate(detected_boxes):
bx = box.get("x", 0)
by = box.get("y", 0)
bw = box.get("w", 0)
bh = box.get("h", 0)
if bw <= 0 or bh <= 0:
continue
# Filter raw OCR words inside this box
zone_words = _words_in_zone(all_words, by, bh, bx, bw)
if not zone_words:
logger.info("Box %d: no words found in bbox (%d,%d,%d,%d)", box_idx, bx, by, bw, bh)
continue
zone_idx = max_zone_idx + 1 + box_idx
forced_layout = layout_overrides.get(str(box_idx))
# Build box grid
box_grid = build_box_zone_grid(
zone_words, bx, by, bw, bh,
zone_idx, img_w, img_h,
layout_type=forced_layout,
)
# Apply SmartSpellChecker to all box cells
try:
from smart_spell import SmartSpellChecker
ssc = SmartSpellChecker()
for cell in box_grid.get("cells", []):
text = cell.get("text", "")
if not text:
continue
result = ssc.correct_text(text, lang="auto")
if result.changed:
cell["text"] = result.corrected
spell_fixes += 1
except ImportError:
pass
# Build zone entry
zone_entry = {
"zone_index": zone_idx,
"zone_type": "box",
"bbox_px": {"x": bx, "y": by, "w": bw, "h": bh},
"bbox_pct": {
"x": round(bx / img_w * 100, 2) if img_w else 0,
"y": round(by / img_h * 100, 2) if img_h else 0,
"w": round(bw / img_w * 100, 2) if img_w else 0,
"h": round(bh / img_h * 100, 2) if img_h else 0,
},
"border": None,
"word_count": len(zone_words),
"columns": box_grid["columns"],
"rows": box_grid["rows"],
"cells": box_grid["cells"],
"header_rows": box_grid.get("header_rows", []),
"box_layout_type": box_grid.get("box_layout_type", "flowing"),
"box_grid_reviewed": False,
"box_bg_color": box.get("bg_color_name", ""),
"box_bg_hex": box.get("bg_color_hex", ""),
}
zones.append(zone_entry)
box_count += 1
# Sort zones by y-position for correct reading order
zones.sort(key=lambda z: z.get("bbox_px", {}).get("y", 0))
grid_data["zones"] = zones
await update_session_db(session_id, grid_editor_result=grid_data)
logger.info(
"build-box-grids session %s: %d boxes processed (%d words spell-fixed) from %d detected",
session_id, box_count, spell_fixes, len(detected_boxes),
)
return {
"session_id": session_id,
"box_zones_rebuilt": box_count,
"total_detected_boxes": len(detected_boxes),
"spell_fixes": spell_fixes,
"zones": zones,
}
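# Usage sketch (endpoint path as defined above; the body value is illustrative):
#   POST /api/v1/ocr-pipeline/sessions/{session_id}/build-box-grids
#   {"overrides": {"0": "bullet_list"}}
# forces bullet-list layout for the first detected box; the response reports
# box_zones_rebuilt, total_detected_boxes, spell_fixes and the rebuilt zones.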
@@ -0,0 +1,337 @@
"""
Grid Editor API — grid build, save, and retrieve endpoints.
"""
import logging
import time
from typing import Any, Dict
from fastapi import APIRouter, HTTPException, Query, Request
from grid_build_core import _build_grid_core
from ocr_pipeline_session_store import (
get_session_db,
update_session_db,
)
from ocr_pipeline_common import (
_cache,
_load_session_to_cache,
_get_cached,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["grid-editor"])
@router.post("/sessions/{session_id}/build-grid")
async def build_grid(
session_id: str,
ipa_mode: str = Query("auto", pattern="^(auto|all|de|en|none)$"),
syllable_mode: str = Query("auto", pattern="^(auto|all|de|en|none)$"),
enhance: bool = Query(True, description="Step 3: CLAHE + denoise for degraded scans"),
max_cols: int = Query(0, description="Step 2: Max column count (0=unlimited)"),
min_conf: int = Query(0, description="Step 1: Min OCR confidence (0=auto)"),
):
"""Build a structured, zone-aware grid from existing Kombi word results.
Requires that paddle-kombi or rapid-kombi has already been run on the session.
Uses the image for box detection and the word positions for grid structuring.
Query params:
ipa_mode: "auto" (only when English IPA detected), "all" (force), "none" (skip)
syllable_mode: "auto" (only when original has dividers), "all" (force), "none" (skip)
Returns a StructuredGrid with zones, each containing their own
columns, rows, and cells — ready for the frontend Excel-like editor.
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
try:
result = await _build_grid_core(
session_id, session,
ipa_mode=ipa_mode, syllable_mode=syllable_mode,
enhance=enhance,
max_columns=max_cols if max_cols > 0 else None,
min_conf=min_conf if min_conf > 0 else None,
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
# Save automatic grid snapshot for later comparison with manual corrections
# Lazy import to avoid circular dependency with ocr_pipeline_regression
from ocr_pipeline_regression import _build_reference_snapshot
wr = session.get("word_result") or {}
engine = wr.get("ocr_engine", "")
if engine in ("kombi", "rapid_kombi"):
auto_pipeline = "kombi"
elif engine == "paddle_direct":
auto_pipeline = "paddle-direct"
else:
auto_pipeline = "pipeline"
auto_snapshot = _build_reference_snapshot(result, pipeline=auto_pipeline)
gt = session.get("ground_truth") or {}
gt["auto_grid_snapshot"] = auto_snapshot
# Persist to DB and advance current_step to 11 (reconstruction complete)
await update_session_db(session_id, grid_editor_result=result, ground_truth=gt, current_step=11)
logger.info(
"build-grid session %s: %d zones, %d cols, %d rows, %d cells, "
"%d boxes in %.2fs",
session_id,
len(result.get("zones", [])),
result.get("summary", {}).get("total_columns", 0),
result.get("summary", {}).get("total_rows", 0),
result.get("summary", {}).get("total_cells", 0),
result.get("boxes_detected", 0),
result.get("duration_seconds", 0),
)
return result
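# Usage sketch (query parameters as declared above; values are illustrative):
#   POST /api/v1/ocr-pipeline/sessions/{session_id}/build-grid?ipa_mode=auto&syllable_mode=none&max_cols=3
# returns the StructuredGrid consumed by the frontend editor and stores an
# auto_grid_snapshot in ground_truth for later regression comparison.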
@router.post("/sessions/{session_id}/rerun-ocr-and-build-grid")
async def rerun_ocr_and_build_grid(
session_id: str,
ipa_mode: str = Query("auto", pattern="^(auto|all|de|en|none)$"),
syllable_mode: str = Query("auto", pattern="^(auto|all|de|en|none)$"),
enhance: bool = Query(True, description="Step 3: CLAHE + denoise for degraded scans"),
max_cols: int = Query(0, description="Step 2: Max column count (0=unlimited)"),
min_conf: int = Query(0, description="Step 1: Min OCR confidence (0=auto)"),
vision_fusion: bool = Query(False, description="Step 4: Vision-LLM fusion for degraded scans"),
doc_category: str = Query("", description="Document type for Vision-LLM prompt context"),
):
"""Re-run OCR with quality settings, then rebuild the grid.
Unlike build-grid (which only rebuilds from existing words),
this endpoint re-runs the full OCR pipeline on the cropped image
with optional CLAHE enhancement, then builds the grid.
Steps executed: Image Enhancement -> OCR -> Grid Build
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
import time as _time
t0 = _time.time()
# 1. Load the cropped/dewarped image from cache or session
if session_id not in _cache:
await _load_session_to_cache(session_id)
cached = _get_cached(session_id)
dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
if dewarped_bgr is None:
raise HTTPException(status_code=400, detail="No cropped/dewarped image available. Run preprocessing steps first.")
import numpy as np
img_h, img_w = dewarped_bgr.shape[:2]
ocr_input = dewarped_bgr.copy()
# 2. Scan quality assessment
scan_quality_info = {}
try:
from scan_quality import score_scan_quality
quality_report = score_scan_quality(ocr_input)
scan_quality_info = quality_report.to_dict()
actual_min_conf = min_conf if min_conf > 0 else quality_report.recommended_min_conf
except Exception as e:
logger.warning(f"rerun-ocr: scan quality failed: {e}")
actual_min_conf = min_conf if min_conf > 0 else 40
# 3. Image enhancement (Step 3)
is_degraded = scan_quality_info.get("is_degraded", False)
if enhance and is_degraded:
try:
from ocr_image_enhance import enhance_for_ocr
ocr_input = enhance_for_ocr(ocr_input, is_degraded=True)
logger.info("rerun-ocr: CLAHE enhancement applied")
except Exception as e:
logger.warning(f"rerun-ocr: enhancement failed: {e}")
# 4. Run dual-engine OCR
from PIL import Image
import pytesseract
# RapidOCR
rapid_words = []
try:
from cv_ocr_engines import ocr_region_rapid
from cv_vocab_types import PageRegion
full_region = PageRegion(type="full_page", x=0, y=0, width=img_w, height=img_h)
rapid_words = ocr_region_rapid(ocr_input, full_region) or []
except Exception as e:
logger.warning(f"rerun-ocr: RapidOCR failed: {e}")
# Tesseract
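# ocr_input is a BGR numpy array (OpenCV convention); [:, :, ::-1] reverses the
# channel order to RGB before handing the image to PIL/pytesseract.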
pil_img = Image.fromarray(ocr_input[:, :, ::-1])
data = pytesseract.image_to_data(pil_img, lang='eng+deu', config='--psm 6 --oem 3', output_type=pytesseract.Output.DICT)
tess_words = []
for i in range(len(data["text"])):
text = (data["text"][i] or "").strip()
conf_raw = str(data["conf"][i])
# pytesseract may report confidence as an int or a float string ("95" or "95.67")
conf = int(float(conf_raw)) if conf_raw.replace(".", "", 1).lstrip("-").isdigit() else -1
if not text or conf < actual_min_conf:
continue
tess_words.append({
"text": text, "left": data["left"][i], "top": data["top"][i],
"width": data["width"][i], "height": data["height"][i], "conf": conf,
})
# 5. Merge OCR results
from ocr_pipeline_ocr_merge import _split_paddle_multi_words, _merge_paddle_tesseract, _deduplicate_words
rapid_split = _split_paddle_multi_words(rapid_words) if rapid_words else []
if rapid_split or tess_words:
merged_words = _merge_paddle_tesseract(rapid_split, tess_words)
merged_words = _deduplicate_words(merged_words)
else:
merged_words = tess_words
# 6. Store updated word_result in session
cells_for_storage = [{"text": w["text"], "left": w["left"], "top": w["top"],
"width": w["width"], "height": w["height"], "conf": w.get("conf", 0)}
for w in merged_words]
word_result = {
"cells": [{"text": " ".join(w["text"] for w in merged_words),
"word_boxes": cells_for_storage}],
"image_width": img_w,
"image_height": img_h,
"ocr_engine": "rapid_kombi",
"word_count": len(merged_words),
"raw_paddle_words": rapid_words,
}
# 6b. Vision-LLM Fusion (Step 4) — correct OCR using Vision model
vision_applied = False
if vision_fusion:
try:
from vision_ocr_fusion import vision_fuse_ocr
category = doc_category or session.get("document_category") or "vokabelseite"
logger.info(f"rerun-ocr: running Vision-LLM fusion (category={category})")
merged_words = await vision_fuse_ocr(ocr_input, merged_words, category)
vision_applied = True
# Rebuild storage from fused words
cells_for_storage = [{"text": w["text"], "left": w["left"], "top": w["top"],
"width": w["width"], "height": w["height"], "conf": w.get("conf", 0)}
for w in merged_words]
word_result["cells"] = [{"text": " ".join(w["text"] for w in merged_words),
"word_boxes": cells_for_storage}]
word_result["word_count"] = len(merged_words)
word_result["ocr_engine"] = "vision_fusion"
except Exception as e:
logger.warning(f"rerun-ocr: Vision-LLM fusion failed: {e}")
await update_session_db(session_id, word_result=word_result)
# Reload session with updated word_result
session = await get_session_db(session_id)
ocr_duration = _time.time() - t0
logger.info(
"rerun-ocr session %s: %d words (rapid=%d, tess=%d, merged=%d) in %.1fs "
"(enhance=%s, min_conf=%d, quality=%s)",
session_id, len(merged_words), len(rapid_words), len(tess_words),
len(merged_words), ocr_duration, enhance, actual_min_conf,
scan_quality_info.get("quality_pct", "?"),
)
# 7. Build grid from new words
try:
result = await _build_grid_core(
session_id, session,
ipa_mode=ipa_mode, syllable_mode=syllable_mode,
enhance=enhance,
max_columns=max_cols if max_cols > 0 else None,
min_conf=min_conf if min_conf > 0 else None,
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
# Persist grid
await update_session_db(session_id, grid_editor_result=result, current_step=11)
# Add quality info to response
result["scan_quality"] = scan_quality_info
result["ocr_stats"] = {
"rapid_words": len(rapid_words),
"tess_words": len(tess_words),
"merged_words": len(merged_words),
"min_conf_used": actual_min_conf,
"enhance_applied": enhance and is_degraded,
"vision_fusion_applied": vision_applied,
"document_category": doc_category or session.get("document_category", ""),
"ocr_duration_seconds": round(ocr_duration, 1),
}
total_duration = _time.time() - t0
logger.info(
"rerun-ocr+build-grid session %s: %d zones, %d cols, %d cells in %.1fs",
session_id,
len(result.get("zones", [])),
result.get("summary", {}).get("total_columns", 0),
result.get("summary", {}).get("total_cells", 0),
total_duration,
)
return result
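# Usage sketch for a degraded scan (flags as declared above; the category value is illustrative):
#   POST /api/v1/ocr-pipeline/sessions/{session_id}/rerun-ocr-and-build-grid?vision_fusion=true&doc_category=vokabelseite
# The response carries the rebuilt grid plus scan_quality and ocr_stats.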
@router.post("/sessions/{session_id}/save-grid")
async def save_grid(session_id: str, request: Request):
"""Save edited grid data from the frontend Excel-like editor.
Receives the full StructuredGrid with user edits (text changes,
formatting changes like bold columns, header rows, etc.) and
persists it to the session's grid_editor_result.
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
body = await request.json()
# Validate basic structure
if "zones" not in body:
raise HTTPException(status_code=400, detail="Missing 'zones' in request body")
# Preserve metadata from the original build
existing = session.get("grid_editor_result") or {}
result = {
"session_id": session_id,
"image_width": body.get("image_width", existing.get("image_width", 0)),
"image_height": body.get("image_height", existing.get("image_height", 0)),
"zones": body["zones"],
"boxes_detected": body.get("boxes_detected", existing.get("boxes_detected", 0)),
"summary": body.get("summary", existing.get("summary", {})),
"formatting": body.get("formatting", existing.get("formatting", {})),
"duration_seconds": existing.get("duration_seconds", 0),
"edited": True,
}
await update_session_db(session_id, grid_editor_result=result, current_step=11)
logger.info("save-grid session %s: %d zones saved", session_id, len(body["zones"]))
return {"session_id": session_id, "saved": True}
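# The editor posts the full StructuredGrid back; a minimal body sketch (dimensions are illustrative):
#   {"zones": [...], "image_width": 2480, "image_height": 3508}
# Only "zones" is required; any missing metadata falls back to the previously stored build.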
@router.get("/sessions/{session_id}/grid-editor")
async def get_grid(session_id: str):
"""Retrieve the current grid editor state for a session."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
result = session.get("grid_editor_result")
if not result:
raise HTTPException(
status_code=404,
detail="No grid editor data. Run build-grid first.",
)
return result
@@ -0,0 +1,110 @@
"""
Grid Editor API — gutter repair endpoints.
"""
import logging
from fastapi import APIRouter, HTTPException, Request
from ocr_pipeline_session_store import (
get_session_db,
update_session_db,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["grid-editor"])
@router.post("/sessions/{session_id}/gutter-repair")
async def gutter_repair(session_id: str):
"""Analyse grid for gutter-edge OCR errors and return repair suggestions.
Detects:
- Words truncated/blurred at the book binding (spell_fix)
- Words split across rows with missing hyphen chars (hyphen_join)
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
grid_data = session.get("grid_editor_result")
if not grid_data:
raise HTTPException(
status_code=400,
detail="No grid data. Run build-grid first.",
)
from cv_gutter_repair import analyse_grid_for_gutter_repair
image_width = grid_data.get("image_width", 0)
result = analyse_grid_for_gutter_repair(grid_data, image_width=image_width)
# Persist suggestions in ground_truth.gutter_repair (avoids DB migration)
gt = session.get("ground_truth") or {}
gt["gutter_repair"] = result
await update_session_db(session_id, ground_truth=gt)
logger.info(
"gutter-repair session %s: %d suggestions in %.2fs",
session_id,
result.get("stats", {}).get("suggestions_found", 0),
result.get("duration_seconds", 0),
)
return result
@router.post("/sessions/{session_id}/gutter-repair/apply")
async def gutter_repair_apply(session_id: str, request: Request):
"""Apply accepted gutter repair suggestions to the grid.
Body: { "accepted": ["suggestion_id_1", "suggestion_id_2", ...] }
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
grid_data = session.get("grid_editor_result")
if not grid_data:
raise HTTPException(status_code=400, detail="No grid data.")
gt = session.get("ground_truth") or {}
gutter_result = gt.get("gutter_repair")
if not gutter_result:
raise HTTPException(
status_code=400,
detail="No gutter repair data. Run gutter-repair first.",
)
body = await request.json()
accepted_ids = body.get("accepted", [])
if not accepted_ids:
return {"applied_count": 0, "changes": []}
# text_overrides: { suggestion_id: "alternative_text" }
# Allows the user to pick a different correction from the alternatives list
text_overrides = body.get("text_overrides", {})
from cv_gutter_repair import apply_gutter_suggestions
suggestions = gutter_result.get("suggestions", [])
# Apply user-selected alternatives before passing to apply
for s in suggestions:
sid = s.get("id", "")
if sid in text_overrides and text_overrides[sid]:
s["suggested_text"] = text_overrides[sid]
result = apply_gutter_suggestions(grid_data, accepted_ids, suggestions)
# Save updated grid back to session
await update_session_db(session_id, grid_editor_result=grid_data)
logger.info(
"gutter-repair/apply session %s: %d changes applied",
session_id,
result.get("applied_count", 0),
)
return result
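# Apply sketch (suggestion ids and the override text are illustrative):
#   POST /api/v1/ocr-pipeline/sessions/{session_id}/gutter-repair/apply
#   {"accepted": ["s1", "s2"], "text_overrides": {"s2": "Wörterbuch"}}
# accepts both suggestions and replaces the suggested text of "s2" with the chosen alternative.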
@@ -0,0 +1,71 @@
"""
Grid Editor API — unified grid endpoints.
"""
import logging
from fastapi import APIRouter, HTTPException
from ocr_pipeline_session_store import (
get_session_db,
update_session_db,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["grid-editor"])
@router.post("/sessions/{session_id}/build-unified-grid")
async def build_unified_grid_endpoint(session_id: str):
"""Build a single-zone unified grid merging content + box zones.
Takes the existing multi-zone grid_editor_result and produces a
unified grid where boxes are integrated into the main row sequence.
Persists as unified_grid_result (preserves original multi-zone data).
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
grid_data = session.get("grid_editor_result")
if not grid_data:
raise HTTPException(status_code=400, detail="No grid data. Run build-grid first.")
from unified_grid import build_unified_grid
result = build_unified_grid(
zones=grid_data.get("zones", []),
image_width=grid_data.get("image_width", 0),
image_height=grid_data.get("image_height", 0),
layout_metrics=grid_data.get("layout_metrics", {}),
)
# Persist as separate field (don't overwrite original multi-zone grid)
await update_session_db(session_id, unified_grid_result=result)
logger.info(
"build-unified-grid session %s: %d rows, %d cells",
session_id,
result.get("summary", {}).get("total_rows", 0),
result.get("summary", {}).get("total_cells", 0),
)
return result
@router.get("/sessions/{session_id}/unified-grid")
async def get_unified_grid(session_id: str):
"""Retrieve the unified grid for a session."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
result = session.get("unified_grid_result")
if not result:
raise HTTPException(
status_code=404,
detail="No unified grid. Run build-unified-grid first.",
)
return result
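# Typical flow (sketch): build-grid, optionally build-box-grids, then build-unified-grid;
# GET /api/v1/ocr-pipeline/sessions/{session_id}/unified-grid returns the merged single-zone
# result while the original multi-zone grid_editor_result stays untouched.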
+25 -609
View File
@@ -1,68 +1,43 @@
"""
Unified Inbox Mail API
Unified Inbox Mail API — barrel re-export.
FastAPI router for the mail system.
The actual endpoints live in:
- api_accounts.py (account CRUD, test, sync)
- api_inbox.py (unified inbox, email detail, send)
- api_ai.py (AI analysis, response suggestions)
- api_tasks.py (task CRUD, dashboard, from-email)
"""
import logging
from typing import Optional, List
from datetime import datetime
from fastapi import APIRouter, HTTPException, Depends, Query, BackgroundTasks
from pydantic import BaseModel
from fastapi import APIRouter, HTTPException, Query
from .models import (
EmailAccountCreate,
EmailAccountUpdate,
EmailAccount,
AccountTestResult,
AggregatedEmail,
EmailSearchParams,
TaskCreate,
TaskUpdate,
InboxTask,
TaskDashboardStats,
EmailComposeRequest,
EmailSendResult,
MailStats,
MailHealthCheck,
EmailAnalysisResult,
ResponseSuggestion,
TaskStatus,
TaskPriority,
EmailCategory,
)
from .mail_db import (
init_mail_tables,
create_email_account,
get_email_accounts,
get_email_account,
delete_email_account,
get_unified_inbox,
get_email,
mark_email_read,
mark_email_starred,
get_mail_stats,
log_mail_audit,
)
from .credentials import get_credentials_service
from .aggregator import get_mail_aggregator
from .ai_service import get_ai_email_service
from .task_service import get_task_service
from .models import MailHealthCheck, MailStats
from .mail_db import init_mail_tables, get_mail_stats
from .api_accounts import router as _accounts_router
from .api_inbox import router as _inbox_router
from .api_ai import router as _ai_router
from .api_tasks import router as _tasks_router
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/mail", tags=["Mail"])
router = APIRouter()
# Merge sub-routers
router.include_router(_accounts_router)
router.include_router(_inbox_router)
router.include_router(_ai_router)
router.include_router(_tasks_router)
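# Wiring sketch (assuming the usual FastAPI app object): the sub-routers already carry the
# /api/v1/mail prefix, so the barrel router is included without an extra prefix:
#   app.include_router(router)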
# =============================================================================
# Health & Init
# Health & Init (kept here as they are small)
# =============================================================================
@router.get("/health", response_model=MailHealthCheck)
@router.get("/api/v1/mail/health", response_model=MailHealthCheck)
async def health_check():
"""Health check for the mail system."""
# TODO: Implement full health check
return MailHealthCheck(
status="healthy",
database_connected=True,
@@ -70,7 +45,7 @@ async def health_check():
)
@router.post("/init")
@router.post("/api/v1/mail/init")
async def initialize_mail_system():
"""Initialize mail database tables."""
success = await init_mail_tables()
@@ -79,573 +54,14 @@ async def initialize_mail_system():
return {"status": "initialized"}
# =============================================================================
# Account Management
# =============================================================================
class AccountCreateRequest(BaseModel):
"""Request to create an email account."""
email: str
display_name: str
account_type: str = "personal"
imap_host: str
imap_port: int = 993
imap_ssl: bool = True
smtp_host: str
smtp_port: int = 465
smtp_ssl: bool = True
password: str
@router.post("/accounts", response_model=dict)
async def create_account(
request: AccountCreateRequest,
user_id: str = Query(..., description="User ID"),
tenant_id: str = Query(..., description="Tenant ID"),
):
"""Create a new email account."""
credentials_service = get_credentials_service()
# Store credentials securely
vault_path = await credentials_service.store_credentials(
account_id=f"{user_id}_{request.email}",
email=request.email,
password=request.password,
imap_host=request.imap_host,
imap_port=request.imap_port,
smtp_host=request.smtp_host,
smtp_port=request.smtp_port,
)
# Create account in database
account_id = await create_email_account(
user_id=user_id,
tenant_id=tenant_id,
email=request.email,
display_name=request.display_name,
account_type=request.account_type,
imap_host=request.imap_host,
imap_port=request.imap_port,
imap_ssl=request.imap_ssl,
smtp_host=request.smtp_host,
smtp_port=request.smtp_port,
smtp_ssl=request.smtp_ssl,
vault_path=vault_path,
)
if not account_id:
raise HTTPException(status_code=500, detail="Failed to create account")
# Log audit
await log_mail_audit(
user_id=user_id,
action="account_created",
entity_type="account",
entity_id=account_id,
details={"email": request.email},
tenant_id=tenant_id,
)
return {"id": account_id, "status": "created"}
@router.get("/accounts", response_model=List[dict])
async def list_accounts(
user_id: str = Query(..., description="User ID"),
tenant_id: Optional[str] = Query(None, description="Tenant ID"),
):
"""List all email accounts for a user."""
accounts = await get_email_accounts(user_id, tenant_id)
# Remove sensitive fields
for account in accounts:
account.pop("vault_path", None)
return accounts
@router.get("/accounts/{account_id}", response_model=dict)
async def get_account(
account_id: str,
user_id: str = Query(..., description="User ID"),
):
"""Get a single email account."""
account = await get_email_account(account_id, user_id)
if not account:
raise HTTPException(status_code=404, detail="Account not found")
account.pop("vault_path", None)
return account
@router.delete("/accounts/{account_id}")
async def remove_account(
account_id: str,
user_id: str = Query(..., description="User ID"),
):
"""Delete an email account."""
account = await get_email_account(account_id, user_id)
if not account:
raise HTTPException(status_code=404, detail="Account not found")
# Delete credentials
credentials_service = get_credentials_service()
vault_path = account.get("vault_path", "")
if vault_path:
await credentials_service.delete_credentials(account_id, vault_path)
# Delete from database (cascades to emails)
success = await delete_email_account(account_id, user_id)
if not success:
raise HTTPException(status_code=500, detail="Failed to delete account")
await log_mail_audit(
user_id=user_id,
action="account_deleted",
entity_type="account",
entity_id=account_id,
)
return {"status": "deleted"}
@router.post("/accounts/{account_id}/test", response_model=AccountTestResult)
async def test_account_connection(
account_id: str,
user_id: str = Query(..., description="User ID"),
):
"""Test connection for an email account."""
account = await get_email_account(account_id, user_id)
if not account:
raise HTTPException(status_code=404, detail="Account not found")
# Get credentials
credentials_service = get_credentials_service()
vault_path = account.get("vault_path", "")
creds = await credentials_service.get_credentials(account_id, vault_path)
if not creds:
return AccountTestResult(
success=False,
error_message="Credentials not found"
)
# Test connection
aggregator = get_mail_aggregator()
result = await aggregator.test_account_connection(
imap_host=account["imap_host"],
imap_port=account["imap_port"],
imap_ssl=account["imap_ssl"],
smtp_host=account["smtp_host"],
smtp_port=account["smtp_port"],
smtp_ssl=account["smtp_ssl"],
email_address=creds.email,
password=creds.password,
)
return result
class ConnectionTestRequest(BaseModel):
"""Request to test connection before saving account."""
email: str
imap_host: str
imap_port: int = 993
imap_ssl: bool = True
smtp_host: str
smtp_port: int = 465
smtp_ssl: bool = True
password: str
@router.post("/accounts/test-connection", response_model=AccountTestResult)
async def test_connection_before_save(request: ConnectionTestRequest):
"""
Test IMAP/SMTP connection before saving an account.
This allows the wizard to verify credentials are correct
before creating the account in the database.
"""
aggregator = get_mail_aggregator()
result = await aggregator.test_account_connection(
imap_host=request.imap_host,
imap_port=request.imap_port,
imap_ssl=request.imap_ssl,
smtp_host=request.smtp_host,
smtp_port=request.smtp_port,
smtp_ssl=request.smtp_ssl,
email_address=request.email,
password=request.password,
)
return result
@router.post("/accounts/{account_id}/sync")
async def sync_account(
account_id: str,
user_id: str = Query(..., description="User ID"),
max_emails: int = Query(100, ge=1, le=500),
background_tasks: BackgroundTasks = None,
):
"""Sync emails from an account."""
aggregator = get_mail_aggregator()
try:
new_count, total_count = await aggregator.sync_account(
account_id=account_id,
user_id=user_id,
max_emails=max_emails,
)
return {
"status": "synced",
"new_emails": new_count,
"total_emails": total_count,
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# =============================================================================
# Unified Inbox
# =============================================================================
@router.get("/inbox", response_model=List[dict])
async def get_inbox(
user_id: str = Query(..., description="User ID"),
account_ids: Optional[str] = Query(None, description="Comma-separated account IDs"),
categories: Optional[str] = Query(None, description="Comma-separated categories"),
is_read: Optional[bool] = Query(None),
is_starred: Optional[bool] = Query(None),
limit: int = Query(50, ge=1, le=200),
offset: int = Query(0, ge=0),
):
"""Get unified inbox with all accounts aggregated."""
# Parse comma-separated values
account_id_list = account_ids.split(",") if account_ids else None
category_list = categories.split(",") if categories else None
emails = await get_unified_inbox(
user_id=user_id,
account_ids=account_id_list,
categories=category_list,
is_read=is_read,
is_starred=is_starred,
limit=limit,
offset=offset,
)
return emails
@router.get("/inbox/{email_id}", response_model=dict)
async def get_email_detail(
email_id: str,
user_id: str = Query(..., description="User ID"),
):
"""Get a single email with full details."""
email_data = await get_email(email_id, user_id)
if not email_data:
raise HTTPException(status_code=404, detail="Email not found")
# Mark as read
await mark_email_read(email_id, user_id, is_read=True)
return email_data
@router.post("/inbox/{email_id}/read")
async def mark_read(
email_id: str,
user_id: str = Query(..., description="User ID"),
is_read: bool = Query(True),
):
"""Mark email as read/unread."""
success = await mark_email_read(email_id, user_id, is_read)
if not success:
raise HTTPException(status_code=500, detail="Failed to update email")
return {"status": "updated", "is_read": is_read}
@router.post("/inbox/{email_id}/star")
async def mark_starred(
email_id: str,
user_id: str = Query(..., description="User ID"),
is_starred: bool = Query(True),
):
"""Mark email as starred/unstarred."""
success = await mark_email_starred(email_id, user_id, is_starred)
if not success:
raise HTTPException(status_code=500, detail="Failed to update email")
return {"status": "updated", "is_starred": is_starred}
# =============================================================================
# Send Email
# =============================================================================
@router.post("/send", response_model=EmailSendResult)
async def send_email(
request: EmailComposeRequest,
user_id: str = Query(..., description="User ID"),
):
"""Send an email."""
aggregator = get_mail_aggregator()
result = await aggregator.send_email(
account_id=request.account_id,
user_id=user_id,
request=request,
)
if result.success:
await log_mail_audit(
user_id=user_id,
action="email_sent",
entity_type="email",
details={
"account_id": request.account_id,
"to": request.to,
"subject": request.subject,
},
)
return result
# =============================================================================
# AI Analysis
# =============================================================================
@router.post("/analyze/{email_id}", response_model=EmailAnalysisResult)
async def analyze_email(
email_id: str,
user_id: str = Query(..., description="User ID"),
):
"""Run AI analysis on an email."""
email_data = await get_email(email_id, user_id)
if not email_data:
raise HTTPException(status_code=404, detail="Email not found")
ai_service = get_ai_email_service()
result = await ai_service.analyze_email(
email_id=email_id,
sender_email=email_data.get("sender_email", ""),
sender_name=email_data.get("sender_name"),
subject=email_data.get("subject", ""),
body_text=email_data.get("body_text"),
body_preview=email_data.get("body_preview"),
)
return result
@router.get("/suggestions/{email_id}", response_model=List[ResponseSuggestion])
async def get_response_suggestions(
email_id: str,
user_id: str = Query(..., description="User ID"),
):
"""Get AI-generated response suggestions for an email."""
email_data = await get_email(email_id, user_id)
if not email_data:
raise HTTPException(status_code=404, detail="Email not found")
ai_service = get_ai_email_service()
# Use stored analysis if available
from .models import SenderType, EmailCategory as EC
sender_type = SenderType(email_data.get("sender_type", "unbekannt"))
category = EC(email_data.get("category", "sonstiges"))
suggestions = await ai_service.suggest_response(
subject=email_data.get("subject", ""),
body_text=email_data.get("body_text", ""),
sender_type=sender_type,
category=category,
)
return suggestions
# =============================================================================
# Tasks (Arbeitsvorrat)
# =============================================================================
@router.get("/tasks", response_model=List[dict])
async def list_tasks(
user_id: str = Query(..., description="User ID"),
status: Optional[str] = Query(None, description="Filter by status"),
priority: Optional[str] = Query(None, description="Filter by priority"),
include_completed: bool = Query(False),
limit: int = Query(50, ge=1, le=200),
offset: int = Query(0, ge=0),
):
"""Get all tasks for a user."""
task_service = get_task_service()
status_enum = TaskStatus(status) if status else None
priority_enum = TaskPriority(priority) if priority else None
tasks = await task_service.get_user_tasks(
user_id=user_id,
status=status_enum,
priority=priority_enum,
include_completed=include_completed,
limit=limit,
offset=offset,
)
return tasks
@router.post("/tasks", response_model=dict)
async def create_task(
request: TaskCreate,
user_id: str = Query(..., description="User ID"),
tenant_id: str = Query(..., description="Tenant ID"),
):
"""Create a new task manually."""
task_service = get_task_service()
task_id = await task_service.create_manual_task(
user_id=user_id,
tenant_id=tenant_id,
task_data=request,
)
if not task_id:
raise HTTPException(status_code=500, detail="Failed to create task")
return {"id": task_id, "status": "created"}
@router.get("/tasks/dashboard", response_model=TaskDashboardStats)
async def get_task_dashboard(
user_id: str = Query(..., description="User ID"),
):
"""Get dashboard statistics for tasks."""
task_service = get_task_service()
return await task_service.get_dashboard_stats(user_id)
@router.get("/tasks/{task_id}", response_model=dict)
async def get_task(
task_id: str,
user_id: str = Query(..., description="User ID"),
):
"""Get a single task."""
task_service = get_task_service()
task = await task_service.get_task(task_id, user_id)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
return task
@router.put("/tasks/{task_id}")
async def update_task(
task_id: str,
request: TaskUpdate,
user_id: str = Query(..., description="User ID"),
):
"""Update a task."""
task_service = get_task_service()
success = await task_service.update_task(task_id, user_id, request)
if not success:
raise HTTPException(status_code=500, detail="Failed to update task")
return {"status": "updated"}
@router.post("/tasks/{task_id}/complete")
async def complete_task(
task_id: str,
user_id: str = Query(..., description="User ID"),
):
"""Mark a task as completed."""
task_service = get_task_service()
success = await task_service.mark_completed(task_id, user_id)
if not success:
raise HTTPException(status_code=500, detail="Failed to complete task")
return {"status": "completed"}
@router.post("/tasks/from-email/{email_id}")
async def create_task_from_email(
email_id: str,
user_id: str = Query(..., description="User ID"),
tenant_id: str = Query(..., description="Tenant ID"),
):
"""Create a task from an email (after analysis)."""
email_data = await get_email(email_id, user_id)
if not email_data:
raise HTTPException(status_code=404, detail="Email not found")
# Get deadlines from stored analysis
deadlines_raw = email_data.get("detected_deadlines", [])
from .models import DeadlineExtraction, SenderType
deadlines = []
for d in deadlines_raw:
try:
deadlines.append(DeadlineExtraction(
deadline_date=datetime.fromisoformat(d["date"]),
description=d.get("description", "Frist"),
confidence=0.8,
source_text="",
is_firm=d.get("is_firm", True),
))
except (KeyError, ValueError):
continue
sender_type = None
if email_data.get("sender_type"):
try:
sender_type = SenderType(email_data["sender_type"])
except ValueError:
pass
task_service = get_task_service()
task_id = await task_service.create_task_from_email(
user_id=user_id,
tenant_id=tenant_id,
email_id=email_id,
deadlines=deadlines,
sender_type=sender_type,
)
if not task_id:
raise HTTPException(status_code=500, detail="Failed to create task")
return {"id": task_id, "status": "created"}
# =============================================================================
# Statistics
# =============================================================================
@router.get("/stats", response_model=MailStats)
@router.get("/api/v1/mail/stats", response_model=MailStats)
async def get_statistics(
user_id: str = Query(..., description="User ID"),
):
"""Get overall mail statistics for a user."""
stats = await get_mail_stats(user_id)
return MailStats(**stats)
# =============================================================================
# Sync All
# =============================================================================
@router.post("/sync-all")
async def sync_all_accounts(
user_id: str = Query(..., description="User ID"),
tenant_id: Optional[str] = Query(None),
):
"""Sync all email accounts for a user."""
aggregator = get_mail_aggregator()
results = await aggregator.sync_all_accounts(user_id, tenant_id)
return {"status": "synced", "results": results}
@@ -0,0 +1,258 @@
"""
Mail API — account management and sync endpoints.
"""
import logging
from typing import Optional, List
from fastapi import APIRouter, HTTPException, Query, BackgroundTasks
from pydantic import BaseModel
from .models import AccountTestResult
from .mail_db import (
create_email_account,
get_email_accounts,
get_email_account,
delete_email_account,
log_mail_audit,
)
from .credentials import get_credentials_service
from .aggregator import get_mail_aggregator
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/mail", tags=["Mail"])
class AccountCreateRequest(BaseModel):
"""Request to create an email account."""
email: str
display_name: str
account_type: str = "personal"
imap_host: str
imap_port: int = 993
imap_ssl: bool = True
smtp_host: str
smtp_port: int = 465
smtp_ssl: bool = True
password: str
@router.post("/accounts", response_model=dict)
async def create_account(
request: AccountCreateRequest,
user_id: str = Query(..., description="User ID"),
tenant_id: str = Query(..., description="Tenant ID"),
):
"""Create a new email account."""
credentials_service = get_credentials_service()
# Store credentials securely
vault_path = await credentials_service.store_credentials(
account_id=f"{user_id}_{request.email}",
email=request.email,
password=request.password,
imap_host=request.imap_host,
imap_port=request.imap_port,
smtp_host=request.smtp_host,
smtp_port=request.smtp_port,
)
# Create account in database
account_id = await create_email_account(
user_id=user_id,
tenant_id=tenant_id,
email=request.email,
display_name=request.display_name,
account_type=request.account_type,
imap_host=request.imap_host,
imap_port=request.imap_port,
imap_ssl=request.imap_ssl,
smtp_host=request.smtp_host,
smtp_port=request.smtp_port,
smtp_ssl=request.smtp_ssl,
vault_path=vault_path,
)
if not account_id:
raise HTTPException(status_code=500, detail="Failed to create account")
# Log audit
await log_mail_audit(
user_id=user_id,
action="account_created",
entity_type="account",
entity_id=account_id,
details={"email": request.email},
tenant_id=tenant_id,
)
return {"id": account_id, "status": "created"}
@router.get("/accounts", response_model=List[dict])
async def list_accounts(
user_id: str = Query(..., description="User ID"),
tenant_id: Optional[str] = Query(None, description="Tenant ID"),
):
"""List all email accounts for a user."""
accounts = await get_email_accounts(user_id, tenant_id)
# Remove sensitive fields
for account in accounts:
account.pop("vault_path", None)
return accounts
@router.get("/accounts/{account_id}", response_model=dict)
async def get_account(
account_id: str,
user_id: str = Query(..., description="User ID"),
):
"""Get a single email account."""
account = await get_email_account(account_id, user_id)
if not account:
raise HTTPException(status_code=404, detail="Account not found")
account.pop("vault_path", None)
return account
@router.delete("/accounts/{account_id}")
async def remove_account(
account_id: str,
user_id: str = Query(..., description="User ID"),
):
"""Delete an email account."""
account = await get_email_account(account_id, user_id)
if not account:
raise HTTPException(status_code=404, detail="Account not found")
# Delete credentials
credentials_service = get_credentials_service()
vault_path = account.get("vault_path", "")
if vault_path:
await credentials_service.delete_credentials(account_id, vault_path)
# Delete from database (cascades to emails)
success = await delete_email_account(account_id, user_id)
if not success:
raise HTTPException(status_code=500, detail="Failed to delete account")
await log_mail_audit(
user_id=user_id,
action="account_deleted",
entity_type="account",
entity_id=account_id,
)
return {"status": "deleted"}
@router.post("/accounts/{account_id}/test", response_model=AccountTestResult)
async def test_account_connection(
account_id: str,
user_id: str = Query(..., description="User ID"),
):
"""Test connection for an email account."""
account = await get_email_account(account_id, user_id)
if not account:
raise HTTPException(status_code=404, detail="Account not found")
# Get credentials
credentials_service = get_credentials_service()
vault_path = account.get("vault_path", "")
creds = await credentials_service.get_credentials(account_id, vault_path)
if not creds:
return AccountTestResult(
success=False,
error_message="Credentials not found"
)
# Test connection
aggregator = get_mail_aggregator()
result = await aggregator.test_account_connection(
imap_host=account["imap_host"],
imap_port=account["imap_port"],
imap_ssl=account["imap_ssl"],
smtp_host=account["smtp_host"],
smtp_port=account["smtp_port"],
smtp_ssl=account["smtp_ssl"],
email_address=creds.email,
password=creds.password,
)
return result
class ConnectionTestRequest(BaseModel):
"""Request to test connection before saving account."""
email: str
imap_host: str
imap_port: int = 993
imap_ssl: bool = True
smtp_host: str
smtp_port: int = 465
smtp_ssl: bool = True
password: str
@router.post("/accounts/test-connection", response_model=AccountTestResult)
async def test_connection_before_save(request: ConnectionTestRequest):
"""
Test IMAP/SMTP connection before saving an account.
This allows the wizard to verify credentials are correct
before creating the account in the database.
"""
aggregator = get_mail_aggregator()
result = await aggregator.test_account_connection(
imap_host=request.imap_host,
imap_port=request.imap_port,
imap_ssl=request.imap_ssl,
smtp_host=request.smtp_host,
smtp_port=request.smtp_port,
smtp_ssl=request.smtp_ssl,
email_address=request.email,
password=request.password,
)
return result
@router.post("/accounts/{account_id}/sync")
async def sync_account(
account_id: str,
user_id: str = Query(..., description="User ID"),
max_emails: int = Query(100, ge=1, le=500),
background_tasks: BackgroundTasks = None,
):
"""Sync emails from an account."""
aggregator = get_mail_aggregator()
try:
new_count, total_count = await aggregator.sync_account(
account_id=account_id,
user_id=user_id,
max_emails=max_emails,
)
return {
"status": "synced",
"new_emails": new_count,
"total_emails": total_count,
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/sync-all")
async def sync_all_accounts(
user_id: str = Query(..., description="User ID"),
tenant_id: Optional[str] = Query(None),
):
"""Sync all email accounts for a user."""
aggregator = get_mail_aggregator()
results = await aggregator.sync_all_accounts(user_id, tenant_id)
return {"status": "synced", "results": results}
+69
View File
@@ -0,0 +1,69 @@
"""
Mail API — AI analysis and response suggestion endpoints.
"""
import logging
from typing import List
from fastapi import APIRouter, HTTPException, Query
from .models import (
EmailAnalysisResult,
ResponseSuggestion,
)
from .mail_db import get_email
from .ai_service import get_ai_email_service
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/mail", tags=["Mail"])
@router.post("/analyze/{email_id}", response_model=EmailAnalysisResult)
async def analyze_email(
email_id: str,
user_id: str = Query(..., description="User ID"),
):
"""Run AI analysis on an email."""
email_data = await get_email(email_id, user_id)
if not email_data:
raise HTTPException(status_code=404, detail="Email not found")
ai_service = get_ai_email_service()
result = await ai_service.analyze_email(
email_id=email_id,
sender_email=email_data.get("sender_email", ""),
sender_name=email_data.get("sender_name"),
subject=email_data.get("subject", ""),
body_text=email_data.get("body_text"),
body_preview=email_data.get("body_preview"),
)
return result
@router.get("/suggestions/{email_id}", response_model=List[ResponseSuggestion])
async def get_response_suggestions(
email_id: str,
user_id: str = Query(..., description="User ID"),
):
"""Get AI-generated response suggestions for an email."""
email_data = await get_email(email_id, user_id)
if not email_data:
raise HTTPException(status_code=404, detail="Email not found")
ai_service = get_ai_email_service()
# Use stored analysis if available
from .models import SenderType, EmailCategory as EC
sender_type = SenderType(email_data.get("sender_type", "unbekannt"))
category = EC(email_data.get("category", "sonstiges"))
suggestions = await ai_service.suggest_response(
subject=email_data.get("subject", ""),
body_text=email_data.get("body_text", ""),
sender_type=sender_type,
category=category,
)
return suggestions
+123
View File
@@ -0,0 +1,123 @@
"""
Mail API — unified inbox, send, and email detail endpoints.
"""
import logging
from typing import Optional, List
from fastapi import APIRouter, HTTPException, Query
from .models import (
EmailComposeRequest,
EmailSendResult,
)
from .mail_db import (
get_unified_inbox,
get_email,
mark_email_read,
mark_email_starred,
log_mail_audit,
)
from .aggregator import get_mail_aggregator
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/mail", tags=["Mail"])
@router.get("/inbox", response_model=List[dict])
async def get_inbox(
user_id: str = Query(..., description="User ID"),
account_ids: Optional[str] = Query(None, description="Comma-separated account IDs"),
categories: Optional[str] = Query(None, description="Comma-separated categories"),
is_read: Optional[bool] = Query(None),
is_starred: Optional[bool] = Query(None),
limit: int = Query(50, ge=1, le=200),
offset: int = Query(0, ge=0),
):
"""Get unified inbox with all accounts aggregated."""
# Parse comma-separated values
account_id_list = account_ids.split(",") if account_ids else None
category_list = categories.split(",") if categories else None
emails = await get_unified_inbox(
user_id=user_id,
account_ids=account_id_list,
categories=category_list,
is_read=is_read,
is_starred=is_starred,
limit=limit,
offset=offset,
)
return emails
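# Usage sketch (illustrative; httpx and the base URL are assumptions).
# account_ids and categories are passed as comma-separated strings and are
# split server-side, as in the parsing above:
#
#     import httpx
#
#     resp = httpx.get(
#         "http://localhost:8000/api/v1/mail/inbox",
#         params={
#             "user_id": "user-123",
#             "account_ids": "acc-1,acc-2",        # comma-separated account IDs
#             "categories": "elternbrief,termin",  # example category values
#             "is_read": "false",
#             "limit": 50,
#         },
#     )
#     emails = resp.json()  # list of email dicts from the unified inbox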
@router.get("/inbox/{email_id}", response_model=dict)
async def get_email_detail(
email_id: str,
user_id: str = Query(..., description="User ID"),
):
"""Get a single email with full details."""
email_data = await get_email(email_id, user_id)
if not email_data:
raise HTTPException(status_code=404, detail="Email not found")
# Mark as read
await mark_email_read(email_id, user_id, is_read=True)
return email_data
@router.post("/inbox/{email_id}/read")
async def mark_read(
email_id: str,
user_id: str = Query(..., description="User ID"),
is_read: bool = Query(True),
):
"""Mark email as read/unread."""
success = await mark_email_read(email_id, user_id, is_read)
if not success:
raise HTTPException(status_code=500, detail="Failed to update email")
return {"status": "updated", "is_read": is_read}
@router.post("/inbox/{email_id}/star")
async def mark_starred(
email_id: str,
user_id: str = Query(..., description="User ID"),
is_starred: bool = Query(True),
):
"""Mark email as starred/unstarred."""
success = await mark_email_starred(email_id, user_id, is_starred)
if not success:
raise HTTPException(status_code=500, detail="Failed to update email")
return {"status": "updated", "is_starred": is_starred}
@router.post("/send", response_model=EmailSendResult)
async def send_email(
request: EmailComposeRequest,
user_id: str = Query(..., description="User ID"),
):
"""Send an email."""
aggregator = get_mail_aggregator()
result = await aggregator.send_email(
account_id=request.account_id,
user_id=user_id,
request=request,
)
if result.success:
await log_mail_audit(
user_id=user_id,
action="email_sent",
entity_type="email",
details={
"account_id": request.account_id,
"to": request.to,
"subject": request.subject,
},
)
return result
+176
View File
@@ -0,0 +1,176 @@
"""
Mail API — task (Arbeitsvorrat) endpoints.
"""
import logging
from datetime import datetime
from typing import Optional, List
from fastapi import APIRouter, HTTPException, Query
from .models import (
TaskCreate,
TaskUpdate,
TaskDashboardStats,
TaskStatus,
TaskPriority,
)
from .mail_db import get_email
from .task_service import get_task_service
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/mail", tags=["Mail"])
@router.get("/tasks", response_model=List[dict])
async def list_tasks(
user_id: str = Query(..., description="User ID"),
status: Optional[str] = Query(None, description="Filter by status"),
priority: Optional[str] = Query(None, description="Filter by priority"),
include_completed: bool = Query(False),
limit: int = Query(50, ge=1, le=200),
offset: int = Query(0, ge=0),
):
"""Get all tasks for a user."""
task_service = get_task_service()
status_enum = TaskStatus(status) if status else None
priority_enum = TaskPriority(priority) if priority else None
tasks = await task_service.get_user_tasks(
user_id=user_id,
status=status_enum,
priority=priority_enum,
include_completed=include_completed,
limit=limit,
offset=offset,
)
return tasks
@router.post("/tasks", response_model=dict)
async def create_task(
request: TaskCreate,
user_id: str = Query(..., description="User ID"),
tenant_id: str = Query(..., description="Tenant ID"),
):
"""Create a new task manually."""
task_service = get_task_service()
task_id = await task_service.create_manual_task(
user_id=user_id,
tenant_id=tenant_id,
task_data=request,
)
if not task_id:
raise HTTPException(status_code=500, detail="Failed to create task")
return {"id": task_id, "status": "created"}
@router.get("/tasks/dashboard", response_model=TaskDashboardStats)
async def get_task_dashboard(
user_id: str = Query(..., description="User ID"),
):
"""Get dashboard statistics for tasks."""
task_service = get_task_service()
return await task_service.get_dashboard_stats(user_id)
@router.get("/tasks/{task_id}", response_model=dict)
async def get_task(
task_id: str,
user_id: str = Query(..., description="User ID"),
):
"""Get a single task."""
task_service = get_task_service()
task = await task_service.get_task(task_id, user_id)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
return task
@router.put("/tasks/{task_id}")
async def update_task(
task_id: str,
request: TaskUpdate,
user_id: str = Query(..., description="User ID"),
):
"""Update a task."""
task_service = get_task_service()
success = await task_service.update_task(task_id, user_id, request)
if not success:
raise HTTPException(status_code=500, detail="Failed to update task")
return {"status": "updated"}
@router.post("/tasks/{task_id}/complete")
async def complete_task(
task_id: str,
user_id: str = Query(..., description="User ID"),
):
"""Mark a task as completed."""
task_service = get_task_service()
success = await task_service.mark_completed(task_id, user_id)
if not success:
raise HTTPException(status_code=500, detail="Failed to complete task")
return {"status": "completed"}
@router.post("/tasks/from-email/{email_id}")
async def create_task_from_email(
email_id: str,
user_id: str = Query(..., description="User ID"),
tenant_id: str = Query(..., description="Tenant ID"),
):
"""Create a task from an email (after analysis)."""
email_data = await get_email(email_id, user_id)
if not email_data:
raise HTTPException(status_code=404, detail="Email not found")
# Get deadlines from stored analysis
deadlines_raw = email_data.get("detected_deadlines", [])
from .models import DeadlineExtraction, SenderType
deadlines = []
for d in deadlines_raw:
try:
deadlines.append(DeadlineExtraction(
deadline_date=datetime.fromisoformat(d["date"]),
description=d.get("description", "Frist"),
confidence=0.8,
source_text="",
is_firm=d.get("is_firm", True),
))
except (KeyError, ValueError):
continue
sender_type = None
if email_data.get("sender_type"):
try:
sender_type = SenderType(email_data["sender_type"])
except ValueError:
pass
task_service = get_task_service()
task_id = await task_service.create_task_from_email(
user_id=user_id,
tenant_id=tenant_id,
email_id=email_id,
deadlines=deadlines,
sender_type=sender_type,
)
if not task_id:
raise HTTPException(status_code=500, detail="Failed to create task")
return {"id": task_id, "status": "created"}
+188
View File
@@ -0,0 +1,188 @@
"""
Orientation & Page-Split API endpoints (Steps 1 and 1b of OCR Pipeline).
"""
import logging
import time
from typing import Any, Dict
import cv2
from fastapi import APIRouter, HTTPException
from cv_vocab_pipeline import detect_and_fix_orientation
from page_crop import detect_page_splits
from ocr_pipeline_session_store import update_session_db
from orientation_crop_helpers import ensure_cached, append_pipeline_log
from page_sub_sessions import create_page_sub_sessions_full
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Step 1: Orientation
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/orientation")
async def detect_orientation(session_id: str):
"""Detect and fix 90/180/270 degree rotations from scanners.
Reads the original image, applies orientation correction,
stores the result as oriented_png.
"""
cached = await ensure_cached(session_id)
img_bgr = cached.get("original_bgr")
if img_bgr is None:
raise HTTPException(status_code=400, detail="Original image not available")
t0 = time.time()
# Detect and fix orientation
oriented_bgr, orientation_deg = detect_and_fix_orientation(img_bgr.copy())
duration = time.time() - t0
orientation_result = {
"orientation_degrees": orientation_deg,
"corrected": orientation_deg != 0,
"duration_seconds": round(duration, 2),
}
# Encode oriented image
success, png_buf = cv2.imencode(".png", oriented_bgr)
oriented_png = png_buf.tobytes() if success else b""
# Update cache
cached["oriented_bgr"] = oriented_bgr
cached["orientation_result"] = orientation_result
# Persist to DB
await update_session_db(
session_id,
oriented_png=oriented_png,
orientation_result=orientation_result,
current_step=2,
)
logger.info(
"OCR Pipeline: orientation session %s: %d° (%s) in %.2fs",
session_id, orientation_deg,
"corrected" if orientation_deg else "no change",
duration,
)
await append_pipeline_log(session_id, "orientation", {
"orientation_degrees": orientation_deg,
"corrected": orientation_deg != 0,
}, duration_ms=int(duration * 1000))
h, w = oriented_bgr.shape[:2]
return {
"session_id": session_id,
**orientation_result,
"image_width": w,
"image_height": h,
"oriented_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/oriented",
}
# ---------------------------------------------------------------------------
# Step 1b: Page-split detection — runs AFTER orientation, BEFORE deskew
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/page-split")
async def detect_page_split(session_id: str):
"""Detect if the image is a double-page book spread and split into sub-sessions.
Must be called **after orientation** (step 1) and **before deskew** (step 2).
Each sub-session receives the raw page region and goes through the full
pipeline (deskew -> dewarp -> crop -> columns -> rows -> words -> grid)
independently, so each page gets its own deskew correction.
Returns ``{"multi_page": false}`` if only one page is detected.
"""
cached = await ensure_cached(session_id)
# Use oriented (preferred), fall back to original
img_bgr = next(
(v for k in ("oriented_bgr", "original_bgr")
if (v := cached.get(k)) is not None),
None,
)
if img_bgr is None:
raise HTTPException(status_code=400, detail="No image available for page-split detection")
t0 = time.time()
page_splits = detect_page_splits(img_bgr)
used_original = False
if not page_splits or len(page_splits) < 2:
# Orientation may have rotated a landscape double-page spread to
# portrait. Try the original (pre-orientation) image as fallback.
orig_bgr = cached.get("original_bgr")
if orig_bgr is not None and orig_bgr is not img_bgr:
page_splits_orig = detect_page_splits(orig_bgr)
if page_splits_orig and len(page_splits_orig) >= 2:
logger.info(
"OCR Pipeline: page-split session %s: spread detected on "
"ORIGINAL (orientation rotated it away)",
session_id,
)
img_bgr = orig_bgr
page_splits = page_splits_orig
used_original = True
if not page_splits or len(page_splits) < 2:
duration = time.time() - t0
logger.info(
"OCR Pipeline: page-split session %s: single page (%.2fs)",
session_id, duration,
)
return {
"session_id": session_id,
"multi_page": False,
"duration_seconds": round(duration, 2),
}
# Multi-page spread detected — create sub-sessions for full pipeline.
# start_step=2 means "ready for deskew" (orientation already applied).
# start_step=1 means "needs orientation too" (split from original image).
start_step = 1 if used_original else 2
sub_sessions = await create_page_sub_sessions_full(
session_id, cached, img_bgr, page_splits, start_step=start_step,
)
duration = time.time() - t0
split_info: Dict[str, Any] = {
"multi_page": True,
"page_count": len(page_splits),
"page_splits": page_splits,
"used_original": used_original,
"duration_seconds": round(duration, 2),
}
# Mark parent session as split and hidden from session list
await update_session_db(session_id, crop_result=split_info, status='split')
cached["crop_result"] = split_info
await append_pipeline_log(session_id, "page_split", {
"multi_page": True,
"page_count": len(page_splits),
}, duration_ms=int(duration * 1000))
logger.info(
"OCR Pipeline: page-split session %s: %d pages detected in %.2fs",
session_id, len(page_splits), duration,
)
h, w = img_bgr.shape[:2]
return {
"session_id": session_id,
**split_info,
"image_width": w,
"image_height": h,
"sub_sessions": sub_sessions,
}
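# Driver sketch (illustrative; httpx and the base URL are assumptions).
# The call order mirrors the docstrings above: orientation first, page-split
# second, then the remaining steps run per sub-session:
#
#     import httpx
#
#     base = "http://localhost:8000/api/v1/ocr-pipeline"
#     with httpx.Client(base_url=base, timeout=120.0) as client:
#         client.post(f"/sessions/{session_id}/orientation")
#         split = client.post(f"/sessions/{session_id}/page-split").json()
#         if split["multi_page"]:
#             # Parent is marked status='split'; continue with each page's session.
#             session_ids = [s["id"] for s in split["sub_sessions"]]
#         else:
#             session_ids = [session_id]
#         # Each session then runs deskew -> dewarp -> crop -> ... independently.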
+9 -687
View File
@@ -1,694 +1,16 @@
"""
Orientation & Crop API - Steps 1 and 4 of the OCR Pipeline.
Step 1: Orientation detection (fix 90/180/270 degree rotations)
Step 4 (UI index 3): Page cropping (after deskew + dewarp, so the image is straight)
These endpoints were extracted from the main pipeline to keep files manageable.
Barrel re-export: merges routers from orientation_api and crop_api,
and re-exports set_cache_ref for main.py.
"""
import logging
import time
import uuid as uuid_mod
from typing import Any, Dict, List, Optional
from fastapi import APIRouter
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from orientation_crop_helpers import set_cache_ref # noqa: F401
from orientation_api import router as _orientation_router
from crop_api import router as _crop_router
from cv_vocab_pipeline import detect_and_fix_orientation
from page_crop import detect_and_crop_page, detect_page_splits
from ocr_pipeline_session_store import (
create_session_db,
get_session_db,
get_session_image,
get_sub_sessions,
update_session_db,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# Reference to the shared cache from ocr_pipeline_api (set in main.py)
_cache: Dict[str, Dict[str, Any]] = {}
def set_cache_ref(cache: Dict[str, Dict[str, Any]]):
"""Set reference to the shared cache from ocr_pipeline_api."""
global _cache
_cache = cache
async def _ensure_cached(session_id: str) -> Dict[str, Any]:
"""Ensure session is in cache, loading from DB if needed."""
if session_id in _cache:
return _cache[session_id]
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
cache_entry: Dict[str, Any] = {
"id": session_id,
**session,
"original_bgr": None,
"oriented_bgr": None,
"cropped_bgr": None,
"deskewed_bgr": None,
"dewarped_bgr": None,
}
for img_type, bgr_key in [
("original", "original_bgr"),
("oriented", "oriented_bgr"),
("cropped", "cropped_bgr"),
("deskewed", "deskewed_bgr"),
("dewarped", "dewarped_bgr"),
]:
png_data = await get_session_image(session_id, img_type)
if png_data:
arr = np.frombuffer(png_data, dtype=np.uint8)
bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR)
cache_entry[bgr_key] = bgr
_cache[session_id] = cache_entry
return cache_entry
async def _append_pipeline_log(session_id: str, step: str, metrics: dict, duration_ms: int):
"""Append a step entry to the pipeline log."""
from datetime import datetime
session = await get_session_db(session_id)
if not session:
return
pipeline_log = session.get("pipeline_log") or {"steps": []}
pipeline_log["steps"].append({
"step": step,
"completed_at": datetime.utcnow().isoformat(),
"success": True,
"duration_ms": duration_ms,
"metrics": metrics,
})
await update_session_db(session_id, pipeline_log=pipeline_log)
# ---------------------------------------------------------------------------
# Step 1: Orientation
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/orientation")
async def detect_orientation(session_id: str):
"""Detect and fix 90/180/270 degree rotations from scanners.
Reads the original image, applies orientation correction,
stores the result as oriented_png.
"""
cached = await _ensure_cached(session_id)
img_bgr = cached.get("original_bgr")
if img_bgr is None:
raise HTTPException(status_code=400, detail="Original image not available")
t0 = time.time()
# Detect and fix orientation
oriented_bgr, orientation_deg = detect_and_fix_orientation(img_bgr.copy())
duration = time.time() - t0
orientation_result = {
"orientation_degrees": orientation_deg,
"corrected": orientation_deg != 0,
"duration_seconds": round(duration, 2),
}
# Encode oriented image
success, png_buf = cv2.imencode(".png", oriented_bgr)
oriented_png = png_buf.tobytes() if success else b""
# Update cache
cached["oriented_bgr"] = oriented_bgr
cached["orientation_result"] = orientation_result
# Persist to DB
await update_session_db(
session_id,
oriented_png=oriented_png,
orientation_result=orientation_result,
current_step=2,
)
logger.info(
"OCR Pipeline: orientation session %s: %d° (%s) in %.2fs",
session_id, orientation_deg,
"corrected" if orientation_deg else "no change",
duration,
)
await _append_pipeline_log(session_id, "orientation", {
"orientation_degrees": orientation_deg,
"corrected": orientation_deg != 0,
}, duration_ms=int(duration * 1000))
h, w = oriented_bgr.shape[:2]
return {
"session_id": session_id,
**orientation_result,
"image_width": w,
"image_height": h,
"oriented_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/oriented",
}
# ---------------------------------------------------------------------------
# Step 1b: Page-split detection — runs AFTER orientation, BEFORE deskew
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/page-split")
async def detect_page_split(session_id: str):
"""Detect if the image is a double-page book spread and split into sub-sessions.
Must be called **after orientation** (step 1) and **before deskew** (step 2).
Each sub-session receives the raw page region and goes through the full
pipeline (deskew → dewarp → crop → columns → rows → words → grid)
independently, so each page gets its own deskew correction.
Returns ``{"multi_page": false}`` if only one page is detected.
"""
cached = await _ensure_cached(session_id)
# Use oriented (preferred), fall back to original
img_bgr = next(
(v for k in ("oriented_bgr", "original_bgr")
if (v := cached.get(k)) is not None),
None,
)
if img_bgr is None:
raise HTTPException(status_code=400, detail="No image available for page-split detection")
t0 = time.time()
page_splits = detect_page_splits(img_bgr)
used_original = False
if not page_splits or len(page_splits) < 2:
# Orientation may have rotated a landscape double-page spread to
# portrait. Try the original (pre-orientation) image as fallback.
orig_bgr = cached.get("original_bgr")
if orig_bgr is not None and orig_bgr is not img_bgr:
page_splits_orig = detect_page_splits(orig_bgr)
if page_splits_orig and len(page_splits_orig) >= 2:
logger.info(
"OCR Pipeline: page-split session %s: spread detected on "
"ORIGINAL (orientation rotated it away)",
session_id,
)
img_bgr = orig_bgr
page_splits = page_splits_orig
used_original = True
if not page_splits or len(page_splits) < 2:
duration = time.time() - t0
logger.info(
"OCR Pipeline: page-split session %s: single page (%.2fs)",
session_id, duration,
)
return {
"session_id": session_id,
"multi_page": False,
"duration_seconds": round(duration, 2),
}
# Multi-page spread detected — create sub-sessions for full pipeline.
# start_step=2 means "ready for deskew" (orientation already applied).
# start_step=1 means "needs orientation too" (split from original image).
start_step = 1 if used_original else 2
sub_sessions = await _create_page_sub_sessions_full(
session_id, cached, img_bgr, page_splits, start_step=start_step,
)
duration = time.time() - t0
split_info: Dict[str, Any] = {
"multi_page": True,
"page_count": len(page_splits),
"page_splits": page_splits,
"used_original": used_original,
"duration_seconds": round(duration, 2),
}
# Mark parent session as split and hidden from session list
await update_session_db(session_id, crop_result=split_info, status='split')
cached["crop_result"] = split_info
await _append_pipeline_log(session_id, "page_split", {
"multi_page": True,
"page_count": len(page_splits),
}, duration_ms=int(duration * 1000))
logger.info(
"OCR Pipeline: page-split session %s: %d pages detected in %.2fs",
session_id, len(page_splits), duration,
)
h, w = img_bgr.shape[:2]
return {
"session_id": session_id,
**split_info,
"image_width": w,
"image_height": h,
"sub_sessions": sub_sessions,
}
# ---------------------------------------------------------------------------
# Step 4 (UI index 3): Crop — runs after deskew + dewarp
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/crop")
async def auto_crop(session_id: str):
"""Auto-detect and crop scanner/book borders.
Reads the dewarped image (post-deskew + dewarp, so the page is straight).
Falls back to oriented → original if earlier steps were skipped.
If the image is a multi-page spread (e.g. book on scanner), it will
automatically split into separate sub-sessions per page, crop each
individually, and return the split info.
"""
cached = await _ensure_cached(session_id)
# Use dewarped (preferred), fall back to oriented, then original
img_bgr = next(
(v for k in ("dewarped_bgr", "oriented_bgr", "original_bgr")
if (v := cached.get(k)) is not None),
None,
)
if img_bgr is None:
raise HTTPException(status_code=400, detail="No image available for cropping")
t0 = time.time()
# --- Check for existing sub-sessions (from page-split step) ---
# If page-split already created sub-sessions, skip multi-page detection
# in the crop step. Each sub-session runs its own crop independently.
existing_subs = await get_sub_sessions(session_id)
if existing_subs:
crop_result = cached.get("crop_result") or {}
if crop_result.get("multi_page"):
# Already split — just return the existing info
duration = time.time() - t0
h, w = img_bgr.shape[:2]
return {
"session_id": session_id,
**crop_result,
"image_width": w,
"image_height": h,
"sub_sessions": [
{"id": s["id"], "name": s.get("name"), "page_index": s.get("box_index", i)}
for i, s in enumerate(existing_subs)
],
"note": "Page split was already performed; each sub-session runs its own crop.",
}
# --- Multi-page detection (fallback for sessions that skipped page-split) ---
page_splits = detect_page_splits(img_bgr)
if page_splits and len(page_splits) >= 2:
# Multi-page spread detected — create sub-sessions
sub_sessions = await _create_page_sub_sessions(
session_id, cached, img_bgr, page_splits,
)
duration = time.time() - t0
crop_info: Dict[str, Any] = {
"crop_applied": True,
"multi_page": True,
"page_count": len(page_splits),
"page_splits": page_splits,
"duration_seconds": round(duration, 2),
}
cached["crop_result"] = crop_info
# Store the first page as the main cropped image for backward compat
first_page = page_splits[0]
first_bgr = img_bgr[
first_page["y"]:first_page["y"] + first_page["height"],
first_page["x"]:first_page["x"] + first_page["width"],
].copy()
first_cropped, _ = detect_and_crop_page(first_bgr)
cached["cropped_bgr"] = first_cropped
ok, png_buf = cv2.imencode(".png", first_cropped)
await update_session_db(
session_id,
cropped_png=png_buf.tobytes() if ok else b"",
crop_result=crop_info,
current_step=5,
status='split',
)
logger.info(
"OCR Pipeline: crop session %s: multi-page split into %d pages in %.2fs",
session_id, len(page_splits), duration,
)
await _append_pipeline_log(session_id, "crop", {
"multi_page": True,
"page_count": len(page_splits),
}, duration_ms=int(duration * 1000))
h, w = first_cropped.shape[:2]
return {
"session_id": session_id,
**crop_info,
"image_width": w,
"image_height": h,
"cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
"sub_sessions": sub_sessions,
}
# --- Single page (normal) ---
cropped_bgr, crop_info = detect_and_crop_page(img_bgr)
duration = time.time() - t0
crop_info["duration_seconds"] = round(duration, 2)
crop_info["multi_page"] = False
# Encode cropped image
success, png_buf = cv2.imencode(".png", cropped_bgr)
cropped_png = png_buf.tobytes() if success else b""
# Update cache
cached["cropped_bgr"] = cropped_bgr
cached["crop_result"] = crop_info
# Persist to DB
await update_session_db(
session_id,
cropped_png=cropped_png,
crop_result=crop_info,
current_step=5,
)
logger.info(
"OCR Pipeline: crop session %s: applied=%s format=%s in %.2fs",
session_id, crop_info["crop_applied"],
crop_info.get("detected_format", "?"),
duration,
)
await _append_pipeline_log(session_id, "crop", {
"crop_applied": crop_info["crop_applied"],
"detected_format": crop_info.get("detected_format"),
"format_confidence": crop_info.get("format_confidence"),
}, duration_ms=int(duration * 1000))
h, w = cropped_bgr.shape[:2]
return {
"session_id": session_id,
**crop_info,
"image_width": w,
"image_height": h,
"cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
}
async def _create_page_sub_sessions(
parent_session_id: str,
parent_cached: dict,
full_img_bgr: np.ndarray,
page_splits: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
"""Create sub-sessions for each detected page in a multi-page spread.
Each page region is individually cropped, then stored as a sub-session
with its own cropped image ready for the rest of the pipeline.
"""
# Check for existing sub-sessions (idempotent)
existing = await get_sub_sessions(parent_session_id)
if existing:
return [
{"id": s["id"], "name": s["name"], "page_index": s.get("box_index", i)}
for i, s in enumerate(existing)
]
parent_name = parent_cached.get("name", "Scan")
parent_filename = parent_cached.get("filename", "scan.png")
sub_sessions: List[Dict[str, Any]] = []
for page in page_splits:
pi = page["page_index"]
px, py = page["x"], page["y"]
pw, ph = page["width"], page["height"]
# Extract page region
page_bgr = full_img_bgr[py:py + ph, px:px + pw].copy()
# Crop each page individually (remove its own borders)
cropped_page, page_crop_info = detect_and_crop_page(page_bgr)
# Encode as PNG
ok, png_buf = cv2.imencode(".png", cropped_page)
page_png = png_buf.tobytes() if ok else b""
sub_id = str(uuid_mod.uuid4())
sub_name = f"{parent_name} — Seite {pi + 1}"
await create_session_db(
session_id=sub_id,
name=sub_name,
filename=parent_filename,
original_png=page_png,
)
# Pre-populate: set cropped = original (already cropped)
await update_session_db(
sub_id,
cropped_png=page_png,
crop_result=page_crop_info,
current_step=5,
)
ch, cw = cropped_page.shape[:2]
sub_sessions.append({
"id": sub_id,
"name": sub_name,
"page_index": pi,
"source_rect": page,
"cropped_size": {"width": cw, "height": ch},
"detected_format": page_crop_info.get("detected_format"),
})
logger.info(
"Page sub-session %s: page %d, region x=%d w=%d -> cropped %dx%d",
sub_id, pi + 1, px, pw, cw, ch,
)
return sub_sessions
async def _create_page_sub_sessions_full(
parent_session_id: str,
parent_cached: dict,
full_img_bgr: np.ndarray,
page_splits: List[Dict[str, Any]],
start_step: int = 2,
) -> List[Dict[str, Any]]:
"""Create sub-sessions for each page with RAW regions for full pipeline processing.
Unlike ``_create_page_sub_sessions`` (used by the crop step), these
sub-sessions store the *uncropped* page region and start at
``start_step`` (default 2 = ready for deskew; 1 if orientation still
needed). Each page goes through its own pipeline independently,
which is essential for book spreads where each page has a different tilt.
"""
# Idempotent: reuse existing sub-sessions
existing = await get_sub_sessions(parent_session_id)
if existing:
return [
{"id": s["id"], "name": s["name"], "page_index": s.get("box_index", i)}
for i, s in enumerate(existing)
]
parent_name = parent_cached.get("name", "Scan")
parent_filename = parent_cached.get("filename", "scan.png")
sub_sessions: List[Dict[str, Any]] = []
for page in page_splits:
pi = page["page_index"]
px, py = page["x"], page["y"]
pw, ph = page["width"], page["height"]
# Extract RAW page region — NO individual cropping here; each
# sub-session will run its own crop step after deskew + dewarp.
page_bgr = full_img_bgr[py:py + ph, px:px + pw].copy()
# Encode as PNG
ok, png_buf = cv2.imencode(".png", page_bgr)
page_png = png_buf.tobytes() if ok else b""
sub_id = str(uuid_mod.uuid4())
sub_name = f"{parent_name} — Seite {pi + 1}"
await create_session_db(
session_id=sub_id,
name=sub_name,
filename=parent_filename,
original_png=page_png,
)
# start_step=2 → ready for deskew (orientation already done on spread)
# start_step=1 → needs its own orientation (split from original image)
await update_session_db(sub_id, current_step=start_step)
# Cache the BGR so the pipeline can start immediately
_cache[sub_id] = {
"id": sub_id,
"filename": parent_filename,
"name": sub_name,
"original_bgr": page_bgr,
"oriented_bgr": None,
"cropped_bgr": None,
"deskewed_bgr": None,
"dewarped_bgr": None,
"orientation_result": None,
"crop_result": None,
"deskew_result": None,
"dewarp_result": None,
"ground_truth": {},
"current_step": start_step,
}
rh, rw = page_bgr.shape[:2]
sub_sessions.append({
"id": sub_id,
"name": sub_name,
"page_index": pi,
"source_rect": page,
"image_size": {"width": rw, "height": rh},
})
logger.info(
"Page sub-session %s (full pipeline): page %d, region x=%d w=%d%dx%d",
sub_id, pi + 1, px, pw, rw, rh,
)
return sub_sessions
class ManualCropRequest(BaseModel):
x: float # percentage 0-100
y: float # percentage 0-100
width: float # percentage 0-100
height: float # percentage 0-100
@router.post("/sessions/{session_id}/crop/manual")
async def manual_crop(session_id: str, req: ManualCropRequest):
"""Manually crop using percentage coordinates."""
cached = await _ensure_cached(session_id)
img_bgr = next(
(v for k in ("dewarped_bgr", "oriented_bgr", "original_bgr")
if (v := cached.get(k)) is not None),
None,
)
if img_bgr is None:
raise HTTPException(status_code=400, detail="No image available for cropping")
h, w = img_bgr.shape[:2]
# Convert percentages to pixels
px_x = int(w * req.x / 100.0)
px_y = int(h * req.y / 100.0)
px_w = int(w * req.width / 100.0)
px_h = int(h * req.height / 100.0)
# Clamp
px_x = max(0, min(px_x, w - 1))
px_y = max(0, min(px_y, h - 1))
px_w = max(1, min(px_w, w - px_x))
px_h = max(1, min(px_h, h - px_y))
cropped_bgr = img_bgr[px_y:px_y + px_h, px_x:px_x + px_w].copy()
success, png_buf = cv2.imencode(".png", cropped_bgr)
cropped_png = png_buf.tobytes() if success else b""
crop_result = {
"crop_applied": True,
"crop_rect": {"x": px_x, "y": px_y, "width": px_w, "height": px_h},
"crop_rect_pct": {"x": round(req.x, 2), "y": round(req.y, 2),
"width": round(req.width, 2), "height": round(req.height, 2)},
"original_size": {"width": w, "height": h},
"cropped_size": {"width": px_w, "height": px_h},
"method": "manual",
}
cached["cropped_bgr"] = cropped_bgr
cached["crop_result"] = crop_result
await update_session_db(
session_id,
cropped_png=cropped_png,
crop_result=crop_result,
current_step=5,
)
ch, cw = cropped_bgr.shape[:2]
return {
"session_id": session_id,
**crop_result,
"image_width": cw,
"image_height": ch,
"cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
}
@router.post("/sessions/{session_id}/crop/skip")
async def skip_crop(session_id: str):
"""Skip cropping — use dewarped (or oriented/original) image as-is."""
cached = await _ensure_cached(session_id)
img_bgr = next(
(v for k in ("dewarped_bgr", "oriented_bgr", "original_bgr")
if (v := cached.get(k)) is not None),
None,
)
if img_bgr is None:
raise HTTPException(status_code=400, detail="No image available")
h, w = img_bgr.shape[:2]
# Store the dewarped image as cropped (identity crop)
success, png_buf = cv2.imencode(".png", img_bgr)
cropped_png = png_buf.tobytes() if success else b""
crop_result = {
"crop_applied": False,
"skipped": True,
"original_size": {"width": w, "height": h},
"cropped_size": {"width": w, "height": h},
}
cached["cropped_bgr"] = img_bgr
cached["crop_result"] = crop_result
await update_session_db(
session_id,
cropped_png=cropped_png,
crop_result=crop_result,
current_step=5,
)
return {
"session_id": session_id,
**crop_result,
"image_width": w,
"image_height": h,
"cropped_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/cropped",
}
router = APIRouter()
router.include_router(_orientation_router)
router.include_router(_crop_router)
@@ -0,0 +1,86 @@
"""
Orientation & Crop shared helpers - cache management and pipeline logging.
"""
import logging
from typing import Any, Dict
import cv2
import numpy as np
from fastapi import HTTPException
from ocr_pipeline_session_store import (
get_session_db,
get_session_image,
update_session_db,
)
logger = logging.getLogger(__name__)
# Reference to the shared cache from ocr_pipeline_api (set in main.py)
_cache: Dict[str, Dict[str, Any]] = {}
def set_cache_ref(cache: Dict[str, Dict[str, Any]]):
"""Set reference to the shared cache from ocr_pipeline_api."""
global _cache
_cache = cache
def get_cache_ref() -> Dict[str, Dict[str, Any]]:
"""Get reference to the shared cache."""
return _cache
async def ensure_cached(session_id: str) -> Dict[str, Any]:
"""Ensure session is in cache, loading from DB if needed."""
if session_id in _cache:
return _cache[session_id]
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
cache_entry: Dict[str, Any] = {
"id": session_id,
**session,
"original_bgr": None,
"oriented_bgr": None,
"cropped_bgr": None,
"deskewed_bgr": None,
"dewarped_bgr": None,
}
for img_type, bgr_key in [
("original", "original_bgr"),
("oriented", "oriented_bgr"),
("cropped", "cropped_bgr"),
("deskewed", "deskewed_bgr"),
("dewarped", "dewarped_bgr"),
]:
png_data = await get_session_image(session_id, img_type)
if png_data:
arr = np.frombuffer(png_data, dtype=np.uint8)
bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR)
cache_entry[bgr_key] = bgr
_cache[session_id] = cache_entry
return cache_entry
async def append_pipeline_log(session_id: str, step: str, metrics: dict, duration_ms: int):
"""Append a step entry to the pipeline log."""
from datetime import datetime
session = await get_session_db(session_id)
if not session:
return
pipeline_log = session.get("pipeline_log") or {"steps": []}
pipeline_log["steps"].append({
"step": step,
"completed_at": datetime.utcnow().isoformat(),
"success": True,
"duration_ms": duration_ms,
"metrics": metrics,
})
await update_session_db(session_id, pipeline_log=pipeline_log)
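# Wiring sketch (illustrative; the module and attribute names on the
# ocr_pipeline_api side are assumptions, only set_cache_ref itself is defined
# here). main.py is expected to share the pipeline's in-memory session cache
# so ensure_cached() sees images decoded by other steps without re-reading
# the database:
#
#     # main.py (hypothetical)
#     import ocr_pipeline_api
#     from orientation_crop_helpers import set_cache_ref
#
#     set_cache_ref(ocr_pipeline_api.SESSION_CACHE)  # cache attribute name is an assumption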
@@ -0,0 +1,189 @@
"""
Sub-session creation for multi-page spreads.
Used by both the page-split and crop steps when a double-page scan is detected.
"""
import logging
import uuid as uuid_mod
from typing import Any, Dict, List
import cv2
import numpy as np
from page_crop import detect_and_crop_page
from ocr_pipeline_session_store import (
create_session_db,
get_sub_sessions,
update_session_db,
)
from orientation_crop_helpers import get_cache_ref
logger = logging.getLogger(__name__)
async def create_page_sub_sessions(
parent_session_id: str,
parent_cached: dict,
full_img_bgr: np.ndarray,
page_splits: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
"""Create sub-sessions for each detected page in a multi-page spread.
Each page region is individually cropped, then stored as a sub-session
with its own cropped image ready for the rest of the pipeline.
"""
# Check for existing sub-sessions (idempotent)
existing = await get_sub_sessions(parent_session_id)
if existing:
return [
{"id": s["id"], "name": s["name"], "page_index": s.get("box_index", i)}
for i, s in enumerate(existing)
]
parent_name = parent_cached.get("name", "Scan")
parent_filename = parent_cached.get("filename", "scan.png")
sub_sessions: List[Dict[str, Any]] = []
for page in page_splits:
pi = page["page_index"]
px, py = page["x"], page["y"]
pw, ph = page["width"], page["height"]
# Extract page region
page_bgr = full_img_bgr[py:py + ph, px:px + pw].copy()
# Crop each page individually (remove its own borders)
cropped_page, page_crop_info = detect_and_crop_page(page_bgr)
# Encode as PNG
ok, png_buf = cv2.imencode(".png", cropped_page)
page_png = png_buf.tobytes() if ok else b""
sub_id = str(uuid_mod.uuid4())
sub_name = f"{parent_name} — Seite {pi + 1}"
await create_session_db(
session_id=sub_id,
name=sub_name,
filename=parent_filename,
original_png=page_png,
)
# Pre-populate: set cropped = original (already cropped)
await update_session_db(
sub_id,
cropped_png=page_png,
crop_result=page_crop_info,
current_step=5,
)
ch, cw = cropped_page.shape[:2]
sub_sessions.append({
"id": sub_id,
"name": sub_name,
"page_index": pi,
"source_rect": page,
"cropped_size": {"width": cw, "height": ch},
"detected_format": page_crop_info.get("detected_format"),
})
logger.info(
"Page sub-session %s: page %d, region x=%d w=%d -> cropped %dx%d",
sub_id, pi + 1, px, pw, cw, ch,
)
return sub_sessions
async def create_page_sub_sessions_full(
parent_session_id: str,
parent_cached: dict,
full_img_bgr: np.ndarray,
page_splits: List[Dict[str, Any]],
start_step: int = 2,
) -> List[Dict[str, Any]]:
"""Create sub-sessions for each page with RAW regions for full pipeline processing.
Unlike ``create_page_sub_sessions`` (used by the crop step), these
sub-sessions store the *uncropped* page region and start at
``start_step`` (default 2 = ready for deskew; 1 if orientation still
needed). Each page goes through its own pipeline independently,
which is essential for book spreads where each page has a different tilt.
"""
_cache = get_cache_ref()
# Idempotent: reuse existing sub-sessions
existing = await get_sub_sessions(parent_session_id)
if existing:
return [
{"id": s["id"], "name": s["name"], "page_index": s.get("box_index", i)}
for i, s in enumerate(existing)
]
parent_name = parent_cached.get("name", "Scan")
parent_filename = parent_cached.get("filename", "scan.png")
sub_sessions: List[Dict[str, Any]] = []
for page in page_splits:
pi = page["page_index"]
px, py = page["x"], page["y"]
pw, ph = page["width"], page["height"]
# Extract RAW page region — NO individual cropping here; each
# sub-session will run its own crop step after deskew + dewarp.
page_bgr = full_img_bgr[py:py + ph, px:px + pw].copy()
# Encode as PNG
ok, png_buf = cv2.imencode(".png", page_bgr)
page_png = png_buf.tobytes() if ok else b""
sub_id = str(uuid_mod.uuid4())
sub_name = f"{parent_name} — Seite {pi + 1}"
await create_session_db(
session_id=sub_id,
name=sub_name,
filename=parent_filename,
original_png=page_png,
)
# start_step=2 -> ready for deskew (orientation already done on spread)
# start_step=1 -> needs its own orientation (split from original image)
await update_session_db(sub_id, current_step=start_step)
# Cache the BGR so the pipeline can start immediately
_cache[sub_id] = {
"id": sub_id,
"filename": parent_filename,
"name": sub_name,
"original_bgr": page_bgr,
"oriented_bgr": None,
"cropped_bgr": None,
"deskewed_bgr": None,
"dewarped_bgr": None,
"orientation_result": None,
"crop_result": None,
"deskew_result": None,
"dewarp_result": None,
"ground_truth": {},
"current_step": start_step,
}
rh, rw = page_bgr.shape[:2]
sub_sessions.append({
"id": sub_id,
"name": sub_name,
"page_index": pi,
"source_rect": page,
"image_size": {"width": rw, "height": rh},
})
logger.info(
"Page sub-session %s (full pipeline): page %d, region x=%d w=%d -> %dx%d",
sub_id, pi + 1, px, pw, rw, rh,
)
return sub_sessions
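# Shape sketch: each entry in ``page_splits`` (as consumed by both functions
# above) carries at least these keys; concrete values are illustrative:
#
#     page_splits = [
#         {"page_index": 0, "x": 0,    "y": 0, "width": 1240, "height": 1754},
#         {"page_index": 1, "x": 1240, "y": 0, "width": 1240, "height": 1754},
#     ]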
+11 -671
View File
@@ -1,677 +1,17 @@
"""
PDF Export Module for Abiturkorrektur System
Generates:
- Individual Gutachten PDFs for each student
- Klausur overview PDFs with grade distribution
- Niedersachsen-compliant formatting
Barrel re-export: all PDF generation functions and constants.
"""
import io
from datetime import datetime
from typing import Dict, List, Optional, Any
from reportlab.lib import colors
from reportlab.lib.pagesizes import A4
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib.units import cm, mm
from reportlab.lib.enums import TA_LEFT, TA_CENTER, TA_JUSTIFY, TA_RIGHT
from reportlab.platypus import (
SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle,
PageBreak, HRFlowable, Image, KeepTogether
)
from pdf_export_styles import ( # noqa: F401
GRADE_POINTS_TO_NOTE,
CRITERIA_DISPLAY_NAMES,
CRITERIA_WEIGHTS,
get_custom_styles,
)
from pdf_export_gutachten import generate_gutachten_pdf # noqa: F401
from pdf_export_overview import ( # noqa: F401
generate_klausur_overview_pdf,
generate_annotations_pdf,
)
from reportlab.pdfbase import pdfmetrics
from reportlab.pdfbase.ttfonts import TTFont
# =============================================
# CONSTANTS
# =============================================
GRADE_POINTS_TO_NOTE = {
15: "1+", 14: "1", 13: "1-",
12: "2+", 11: "2", 10: "2-",
9: "3+", 8: "3", 7: "3-",
6: "4+", 5: "4", 4: "4-",
3: "5+", 2: "5", 1: "5-",
0: "6"
}
CRITERIA_DISPLAY_NAMES = {
"rechtschreibung": "Sprachliche Richtigkeit (Rechtschreibung)",
"grammatik": "Sprachliche Richtigkeit (Grammatik)",
"inhalt": "Inhaltliche Leistung",
"struktur": "Aufbau und Struktur",
"stil": "Ausdruck und Stil"
}
CRITERIA_WEIGHTS = {
"rechtschreibung": 15,
"grammatik": 15,
"inhalt": 40,
"struktur": 15,
"stil": 15
}
# =============================================
# STYLES
# =============================================
def get_custom_styles():
"""Create custom paragraph styles for Gutachten."""
styles = getSampleStyleSheet()
# Title style
styles.add(ParagraphStyle(
name='GutachtenTitle',
parent=styles['Heading1'],
fontSize=16,
spaceAfter=12,
alignment=TA_CENTER,
textColor=colors.HexColor('#1e3a5f')
))
# Subtitle style
styles.add(ParagraphStyle(
name='GutachtenSubtitle',
parent=styles['Heading2'],
fontSize=12,
spaceAfter=8,
spaceBefore=16,
textColor=colors.HexColor('#2c5282')
))
# Section header
styles.add(ParagraphStyle(
name='SectionHeader',
parent=styles['Heading3'],
fontSize=11,
spaceAfter=6,
spaceBefore=12,
textColor=colors.HexColor('#2d3748'),
borderColor=colors.HexColor('#e2e8f0'),
borderWidth=0,
borderPadding=0
))
# Body text
styles.add(ParagraphStyle(
name='GutachtenBody',
parent=styles['Normal'],
fontSize=10,
leading=14,
alignment=TA_JUSTIFY,
spaceAfter=6
))
# Small text for footer/meta
styles.add(ParagraphStyle(
name='MetaText',
parent=styles['Normal'],
fontSize=8,
textColor=colors.grey,
alignment=TA_LEFT
))
# List item
styles.add(ParagraphStyle(
name='ListItem',
parent=styles['Normal'],
fontSize=10,
leftIndent=20,
bulletIndent=10,
spaceAfter=4
))
return styles
# =============================================
# PDF GENERATION FUNCTIONS
# =============================================
def generate_gutachten_pdf(
student_data: Dict[str, Any],
klausur_data: Dict[str, Any],
annotations: List[Dict[str, Any]] = None,
workflow_data: Dict[str, Any] = None
) -> bytes:
"""
Generate a PDF Gutachten for a single student.
Args:
student_data: Student work data including criteria_scores, gutachten, grade_points
klausur_data: Klausur metadata (title, subject, year, etc.)
annotations: List of annotations for annotation summary
workflow_data: Examiner workflow data (EK, ZK, DK info)
Returns:
PDF as bytes
"""
buffer = io.BytesIO()
doc = SimpleDocTemplate(
buffer,
pagesize=A4,
rightMargin=2*cm,
leftMargin=2*cm,
topMargin=2*cm,
bottomMargin=2*cm
)
styles = get_custom_styles()
story = []
# Header
story.append(Paragraph("Gutachten zur Abiturklausur", styles['GutachtenTitle']))
story.append(Paragraph(f"{klausur_data.get('subject', 'Deutsch')} - {klausur_data.get('title', '')}", styles['GutachtenSubtitle']))
story.append(Spacer(1, 0.5*cm))
# Meta information table
meta_data = [
["Pruefling:", student_data.get('student_name', 'Anonym')],
["Schuljahr:", f"{klausur_data.get('year', 2025)}"],
["Kurs:", klausur_data.get('semester', 'Abitur')],
["Datum:", datetime.now().strftime("%d.%m.%Y")]
]
meta_table = Table(meta_data, colWidths=[4*cm, 10*cm])
meta_table.setStyle(TableStyle([
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, -1), 10),
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
('TOPPADDING', (0, 0), (-1, -1), 4),
]))
story.append(meta_table)
story.append(Spacer(1, 0.5*cm))
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
story.append(Spacer(1, 0.5*cm))
# Gutachten content
gutachten = student_data.get('gutachten', {})
if gutachten:
# Einleitung
if gutachten.get('einleitung'):
story.append(Paragraph("Einleitung", styles['SectionHeader']))
story.append(Paragraph(gutachten['einleitung'], styles['GutachtenBody']))
story.append(Spacer(1, 0.3*cm))
# Hauptteil
if gutachten.get('hauptteil'):
story.append(Paragraph("Hauptteil", styles['SectionHeader']))
story.append(Paragraph(gutachten['hauptteil'], styles['GutachtenBody']))
story.append(Spacer(1, 0.3*cm))
# Fazit
if gutachten.get('fazit'):
story.append(Paragraph("Fazit", styles['SectionHeader']))
story.append(Paragraph(gutachten['fazit'], styles['GutachtenBody']))
story.append(Spacer(1, 0.3*cm))
# Staerken und Schwaechen
if gutachten.get('staerken') or gutachten.get('schwaechen'):
story.append(Spacer(1, 0.3*cm))
if gutachten.get('staerken'):
story.append(Paragraph("Staerken:", styles['SectionHeader']))
for s in gutachten['staerken']:
story.append(Paragraph(f"{s}", styles['ListItem']))
if gutachten.get('schwaechen'):
story.append(Paragraph("Verbesserungspotenzial:", styles['SectionHeader']))
for s in gutachten['schwaechen']:
story.append(Paragraph(f"{s}", styles['ListItem']))
else:
story.append(Paragraph("<i>Kein Gutachten-Text vorhanden.</i>", styles['GutachtenBody']))
story.append(Spacer(1, 0.5*cm))
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
story.append(Spacer(1, 0.5*cm))
# Bewertungstabelle
story.append(Paragraph("Bewertung nach Kriterien", styles['SectionHeader']))
story.append(Spacer(1, 0.2*cm))
criteria_scores = student_data.get('criteria_scores', {})
# Build criteria table data
table_data = [["Kriterium", "Gewichtung", "Erreicht", "Punkte"]]
total_weighted = 0
total_weight = 0
for key, display_name in CRITERIA_DISPLAY_NAMES.items():
weight = CRITERIA_WEIGHTS.get(key, 0)
score_data = criteria_scores.get(key, {})
score = score_data.get('score', 0) if isinstance(score_data, dict) else score_data
# Calculate weighted contribution
weighted_score = (score / 100) * weight if score else 0
total_weighted += weighted_score
total_weight += weight
table_data.append([
display_name,
f"{weight}%",
f"{score}%",
f"{weighted_score:.1f}"
])
# Add total row
table_data.append([
"Gesamt",
f"{total_weight}%",
"",
f"{total_weighted:.1f}"
])
criteria_table = Table(table_data, colWidths=[8*cm, 2.5*cm, 2.5*cm, 2.5*cm])
criteria_table.setStyle(TableStyle([
# Header row
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#2c5282')),
('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, 0), 10),
('ALIGN', (1, 0), (-1, -1), 'CENTER'),
# Body rows
('FONTSIZE', (0, 1), (-1, -1), 9),
('BOTTOMPADDING', (0, 0), (-1, -1), 6),
('TOPPADDING', (0, 0), (-1, -1), 6),
# Grid
('GRID', (0, 0), (-1, -1), 0.5, colors.HexColor('#e2e8f0')),
# Total row
('BACKGROUND', (0, -1), (-1, -1), colors.HexColor('#f7fafc')),
('FONTNAME', (0, -1), (-1, -1), 'Helvetica-Bold'),
# Alternating row colors
('ROWBACKGROUNDS', (0, 1), (-1, -2), [colors.white, colors.HexColor('#f7fafc')]),
]))
story.append(criteria_table)
story.append(Spacer(1, 0.5*cm))
# Final grade box
grade_points = student_data.get('grade_points', 0)
grade_note = GRADE_POINTS_TO_NOTE.get(grade_points, "?")
raw_points = student_data.get('raw_points', 0)
grade_data = [
["Rohpunkte:", f"{raw_points} / 100"],
["Notenpunkte:", f"{grade_points} Punkte"],
["Note:", grade_note]
]
grade_table = Table(grade_data, colWidths=[4*cm, 4*cm])
grade_table.setStyle(TableStyle([
('BACKGROUND', (0, 0), (-1, -1), colors.HexColor('#ebf8ff')),
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
('FONTNAME', (1, -1), (1, -1), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, -1), 11),
('FONTSIZE', (1, -1), (1, -1), 14),
('TEXTCOLOR', (1, -1), (1, -1), colors.HexColor('#2c5282')),
('BOTTOMPADDING', (0, 0), (-1, -1), 8),
('TOPPADDING', (0, 0), (-1, -1), 8),
('LEFTPADDING', (0, 0), (-1, -1), 12),
('BOX', (0, 0), (-1, -1), 1, colors.HexColor('#2c5282')),
('ALIGN', (1, 0), (1, -1), 'RIGHT'),
]))
story.append(KeepTogether([
Paragraph("Endergebnis", styles['SectionHeader']),
Spacer(1, 0.2*cm),
grade_table
]))
# Examiner workflow information
if workflow_data:
story.append(Spacer(1, 0.5*cm))
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
story.append(Spacer(1, 0.3*cm))
story.append(Paragraph("Korrekturverlauf", styles['SectionHeader']))
workflow_rows = []
if workflow_data.get('erst_korrektor'):
ek = workflow_data['erst_korrektor']
workflow_rows.append([
"Erstkorrektor:",
ek.get('name', 'Unbekannt'),
f"{ek.get('grade_points', '-')} Punkte"
])
if workflow_data.get('zweit_korrektor'):
zk = workflow_data['zweit_korrektor']
workflow_rows.append([
"Zweitkorrektor:",
zk.get('name', 'Unbekannt'),
f"{zk.get('grade_points', '-')} Punkte"
])
if workflow_data.get('dritt_korrektor'):
dk = workflow_data['dritt_korrektor']
workflow_rows.append([
"Drittkorrektor:",
dk.get('name', 'Unbekannt'),
f"{dk.get('grade_points', '-')} Punkte"
])
if workflow_data.get('final_grade_source'):
workflow_rows.append([
"Endnote durch:",
workflow_data['final_grade_source'],
""
])
if workflow_rows:
workflow_table = Table(workflow_rows, colWidths=[4*cm, 6*cm, 4*cm])
workflow_table.setStyle(TableStyle([
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, -1), 9),
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
('TOPPADDING', (0, 0), (-1, -1), 4),
]))
story.append(workflow_table)
# Annotation summary (if any)
if annotations:
story.append(Spacer(1, 0.5*cm))
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
story.append(Spacer(1, 0.3*cm))
story.append(Paragraph("Anmerkungen (Zusammenfassung)", styles['SectionHeader']))
# Group annotations by type
by_type = {}
for ann in annotations:
ann_type = ann.get('type', 'comment')
if ann_type not in by_type:
by_type[ann_type] = []
by_type[ann_type].append(ann)
for ann_type, anns in by_type.items():
type_name = CRITERIA_DISPLAY_NAMES.get(ann_type, ann_type.replace('_', ' ').title())
story.append(Paragraph(f"{type_name} ({len(anns)} Anmerkungen)", styles['ListItem']))
# Footer with generation info
story.append(Spacer(1, 1*cm))
story.append(HRFlowable(width="100%", thickness=0.5, color=colors.HexColor('#cbd5e0')))
story.append(Spacer(1, 0.2*cm))
story.append(Paragraph(
f"Erstellt am {datetime.now().strftime('%d.%m.%Y um %H:%M Uhr')} | BreakPilot Abiturkorrektur-System",
styles['MetaText']
))
# Build PDF
doc.build(story)
buffer.seek(0)
return buffer.getvalue()
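# Worked example of the criteria weighting used above (scores are
# illustrative): with Rechtschreibung 90%, Grammatik 85%, Inhalt 80%,
# Struktur 70% and Stil 75%, the table totals
#
#     weighted = (90/100)*15 + (85/100)*15 + (80/100)*40 + (70/100)*15 + (75/100)*15
#              = 13.5 + 12.75 + 32.0 + 10.5 + 11.25
#              = 80.0   # out of 100, since the weights sum to 100
#
# Only GRADE_POINTS_TO_NOTE (Notenpunkte -> Note, e.g. 11 -> "2") is defined
# in this module; the mapping from raw points to Notenpunkte happens elsewhere.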
def generate_klausur_overview_pdf(
klausur_data: Dict[str, Any],
students: List[Dict[str, Any]],
fairness_data: Optional[Dict[str, Any]] = None
) -> bytes:
"""
Generate an overview PDF for an entire Klausur with all student grades.
Args:
klausur_data: Klausur metadata
students: List of all student work data
fairness_data: Optional fairness analysis data
Returns:
PDF as bytes
"""
buffer = io.BytesIO()
doc = SimpleDocTemplate(
buffer,
pagesize=A4,
rightMargin=1.5*cm,
leftMargin=1.5*cm,
topMargin=2*cm,
bottomMargin=2*cm
)
styles = get_custom_styles()
story = []
# Header
story.append(Paragraph("Notenuebersicht", styles['GutachtenTitle']))
story.append(Paragraph(f"{klausur_data.get('subject', 'Deutsch')} - {klausur_data.get('title', '')}", styles['GutachtenSubtitle']))
story.append(Spacer(1, 0.5*cm))
# Meta information
meta_data = [
["Schuljahr:", f"{klausur_data.get('year', 2025)}"],
["Kurs:", klausur_data.get('semester', 'Abitur')],
["Anzahl Arbeiten:", str(len(students))],
["Stand:", datetime.now().strftime("%d.%m.%Y")]
]
meta_table = Table(meta_data, colWidths=[4*cm, 10*cm])
meta_table.setStyle(TableStyle([
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, -1), 10),
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
('TOPPADDING', (0, 0), (-1, -1), 4),
]))
story.append(meta_table)
story.append(Spacer(1, 0.5*cm))
# Statistics (if fairness data available)
if fairness_data and fairness_data.get('statistics'):
stats = fairness_data['statistics']
story.append(Paragraph("Statistik", styles['SectionHeader']))
stats_data = [
["Durchschnitt:", f"{stats.get('average_grade', 0):.1f} Punkte"],
["Minimum:", f"{stats.get('min_grade', 0)} Punkte"],
["Maximum:", f"{stats.get('max_grade', 0)} Punkte"],
["Standardabweichung:", f"{stats.get('standard_deviation', 0):.2f}"],
]
stats_table = Table(stats_data, colWidths=[4*cm, 4*cm])
stats_table.setStyle(TableStyle([
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, -1), 9),
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
('BACKGROUND', (0, 0), (-1, -1), colors.HexColor('#f7fafc')),
('BOX', (0, 0), (-1, -1), 0.5, colors.HexColor('#e2e8f0')),
]))
story.append(stats_table)
story.append(Spacer(1, 0.5*cm))
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
story.append(Spacer(1, 0.5*cm))
# Student grades table
story.append(Paragraph("Einzelergebnisse", styles['SectionHeader']))
story.append(Spacer(1, 0.2*cm))
# Sort students by grade (descending)
sorted_students = sorted(students, key=lambda s: s.get('grade_points', 0), reverse=True)
# Build table header
table_data = [["#", "Name", "Rohpunkte", "Notenpunkte", "Note", "Status"]]
for idx, student in enumerate(sorted_students, 1):
grade_points = student.get('grade_points', 0)
grade_note = GRADE_POINTS_TO_NOTE.get(grade_points, "-")
raw_points = student.get('raw_points', 0)
status = student.get('status', 'unknown')
# Format status
status_display = {
'completed': 'Abgeschlossen',
'first_examiner': 'In Korrektur',
'second_examiner': 'Zweitkorrektur',
'uploaded': 'Hochgeladen',
'ocr_complete': 'OCR fertig',
'analyzing': 'Wird analysiert'
}.get(status, status)
table_data.append([
str(idx),
student.get('student_name', 'Anonym'),
f"{raw_points}/100",
str(grade_points),
grade_note,
status_display
])
# Create table
student_table = Table(table_data, colWidths=[1*cm, 5*cm, 2.5*cm, 3*cm, 2*cm, 3*cm])
student_table.setStyle(TableStyle([
# Header
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#2c5282')),
('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, 0), 9),
('ALIGN', (0, 0), (-1, 0), 'CENTER'),
# Body
('FONTSIZE', (0, 1), (-1, -1), 9),
('ALIGN', (0, 1), (0, -1), 'CENTER'),
('ALIGN', (2, 1), (4, -1), 'CENTER'),
('BOTTOMPADDING', (0, 0), (-1, -1), 6),
('TOPPADDING', (0, 0), (-1, -1), 6),
# Grid
('GRID', (0, 0), (-1, -1), 0.5, colors.HexColor('#e2e8f0')),
# Alternating rows
('ROWBACKGROUNDS', (0, 1), (-1, -1), [colors.white, colors.HexColor('#f7fafc')]),
]))
story.append(student_table)
# Grade distribution
story.append(Spacer(1, 0.5*cm))
story.append(Paragraph("Notenverteilung", styles['SectionHeader']))
story.append(Spacer(1, 0.2*cm))
# Count grades
grade_counts = {}
for student in sorted_students:
gp = student.get('grade_points', 0)
grade_counts[gp] = grade_counts.get(gp, 0) + 1
# Build grade distribution table
dist_data = [["Punkte", "Note", "Anzahl"]]
for points in range(15, -1, -1):
if points in grade_counts:
note = GRADE_POINTS_TO_NOTE.get(points, "-")
count = grade_counts[points]
dist_data.append([str(points), note, str(count)])
if len(dist_data) > 1:
dist_table = Table(dist_data, colWidths=[2.5*cm, 2.5*cm, 2.5*cm])
dist_table.setStyle(TableStyle([
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#2c5282')),
('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, -1), 9),
('ALIGN', (0, 0), (-1, -1), 'CENTER'),
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
('TOPPADDING', (0, 0), (-1, -1), 4),
('GRID', (0, 0), (-1, -1), 0.5, colors.HexColor('#e2e8f0')),
]))
story.append(dist_table)
# Footer
story.append(Spacer(1, 1*cm))
story.append(HRFlowable(width="100%", thickness=0.5, color=colors.HexColor('#cbd5e0')))
story.append(Spacer(1, 0.2*cm))
story.append(Paragraph(
f"Erstellt am {datetime.now().strftime('%d.%m.%Y um %H:%M Uhr')} | BreakPilot Abiturkorrektur-System",
styles['MetaText']
))
# Build PDF
doc.build(story)
buffer.seek(0)
return buffer.getvalue()
def generate_annotations_pdf(
student_data: Dict[str, Any],
klausur_data: Dict[str, Any],
annotations: List[Dict[str, Any]]
) -> bytes:
"""
Generate a PDF with all annotations for a student work.
Args:
student_data: Student work data
klausur_data: Klausur metadata
annotations: List of all annotations
Returns:
PDF as bytes
"""
buffer = io.BytesIO()
doc = SimpleDocTemplate(
buffer,
pagesize=A4,
rightMargin=2*cm,
leftMargin=2*cm,
topMargin=2*cm,
bottomMargin=2*cm
)
styles = get_custom_styles()
story = []
# Header
story.append(Paragraph("Anmerkungen zur Klausur", styles['GutachtenTitle']))
story.append(Paragraph(f"{student_data.get('student_name', 'Anonym')}", styles['GutachtenSubtitle']))
story.append(Spacer(1, 0.5*cm))
if not annotations:
story.append(Paragraph("<i>Keine Anmerkungen vorhanden.</i>", styles['GutachtenBody']))
else:
# Group by type
by_type = {}
for ann in annotations:
ann_type = ann.get('type', 'comment')
if ann_type not in by_type:
by_type[ann_type] = []
by_type[ann_type].append(ann)
for ann_type, anns in by_type.items():
type_name = CRITERIA_DISPLAY_NAMES.get(ann_type, ann_type.replace('_', ' ').title())
story.append(Paragraph(f"{type_name} ({len(anns)})", styles['SectionHeader']))
story.append(Spacer(1, 0.2*cm))
# Sort by page then position
sorted_anns = sorted(anns, key=lambda a: (a.get('page', 0), a.get('position', {}).get('y', 0)))
for idx, ann in enumerate(sorted_anns, 1):
page = ann.get('page', 1)
text = ann.get('text', '')
suggestion = ann.get('suggestion', '')
severity = ann.get('severity', 'minor')
# Build annotation text
ann_text = f"<b>[S.{page}]</b> {text}"
if suggestion:
ann_text += f" → <i>{suggestion}</i>"
# Color code by severity
if severity == 'critical':
ann_text = f"<font color='red'>{ann_text}</font>"
elif severity == 'major':
ann_text = f"<font color='orange'>{ann_text}</font>"
story.append(Paragraph(f"{idx}. {ann_text}", styles['ListItem']))
story.append(Spacer(1, 0.3*cm))
# Footer
story.append(Spacer(1, 1*cm))
story.append(HRFlowable(width="100%", thickness=0.5, color=colors.HexColor('#cbd5e0')))
story.append(Spacer(1, 0.2*cm))
story.append(Paragraph(
f"Erstellt am {datetime.now().strftime('%d.%m.%Y um %H:%M Uhr')} | BreakPilot Abiturkorrektur-System",
styles['MetaText']
))
# Build PDF
doc.build(story)
buffer.seek(0)
return buffer.getvalue()
@@ -0,0 +1,315 @@
"""
PDF Export - Individual Gutachten PDF generation.
Generates a single student's Gutachten with criteria table,
workflow info, and annotation summary.
"""
import io
from datetime import datetime
from typing import Dict, List, Optional, Any
from reportlab.lib import colors
from reportlab.lib.pagesizes import A4
from reportlab.lib.units import cm
from reportlab.platypus import (
SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle,
HRFlowable, KeepTogether
)
from pdf_export_styles import (
GRADE_POINTS_TO_NOTE,
CRITERIA_DISPLAY_NAMES,
CRITERIA_WEIGHTS,
get_custom_styles,
)
def generate_gutachten_pdf(
student_data: Dict[str, Any],
klausur_data: Dict[str, Any],
    annotations: Optional[List[Dict[str, Any]]] = None,
    workflow_data: Optional[Dict[str, Any]] = None
) -> bytes:
"""
Generate a PDF Gutachten for a single student.
Args:
student_data: Student work data including criteria_scores, gutachten, grade_points
klausur_data: Klausur metadata (title, subject, year, etc.)
annotations: List of annotations for annotation summary
workflow_data: Examiner workflow data (EK, ZK, DK info)
Returns:
PDF as bytes
"""
buffer = io.BytesIO()
doc = SimpleDocTemplate(
buffer,
pagesize=A4,
rightMargin=2*cm,
leftMargin=2*cm,
topMargin=2*cm,
bottomMargin=2*cm
)
styles = get_custom_styles()
story = []
# Header
story.append(Paragraph("Gutachten zur Abiturklausur", styles['GutachtenTitle']))
story.append(Paragraph(f"{klausur_data.get('subject', 'Deutsch')} - {klausur_data.get('title', '')}", styles['GutachtenSubtitle']))
story.append(Spacer(1, 0.5*cm))
# Meta information table
meta_data = [
["Pruefling:", student_data.get('student_name', 'Anonym')],
["Schuljahr:", f"{klausur_data.get('year', 2025)}"],
["Kurs:", klausur_data.get('semester', 'Abitur')],
["Datum:", datetime.now().strftime("%d.%m.%Y")]
]
meta_table = Table(meta_data, colWidths=[4*cm, 10*cm])
meta_table.setStyle(TableStyle([
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, -1), 10),
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
('TOPPADDING', (0, 0), (-1, -1), 4),
]))
story.append(meta_table)
story.append(Spacer(1, 0.5*cm))
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
story.append(Spacer(1, 0.5*cm))
# Gutachten content
_add_gutachten_content(story, styles, student_data)
story.append(Spacer(1, 0.5*cm))
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
story.append(Spacer(1, 0.5*cm))
# Bewertungstabelle
_add_criteria_table(story, styles, student_data)
# Final grade box
_add_grade_box(story, styles, student_data)
# Examiner workflow information
if workflow_data:
_add_workflow_info(story, styles, workflow_data)
# Annotation summary
if annotations:
_add_annotation_summary(story, styles, annotations)
# Footer
_add_footer(story, styles)
# Build PDF
doc.build(story)
buffer.seek(0)
return buffer.getvalue()
def _add_gutachten_content(story, styles, student_data):
"""Add gutachten text sections to the story."""
gutachten = student_data.get('gutachten', {})
if gutachten:
if gutachten.get('einleitung'):
story.append(Paragraph("Einleitung", styles['SectionHeader']))
story.append(Paragraph(gutachten['einleitung'], styles['GutachtenBody']))
story.append(Spacer(1, 0.3*cm))
if gutachten.get('hauptteil'):
story.append(Paragraph("Hauptteil", styles['SectionHeader']))
story.append(Paragraph(gutachten['hauptteil'], styles['GutachtenBody']))
story.append(Spacer(1, 0.3*cm))
if gutachten.get('fazit'):
story.append(Paragraph("Fazit", styles['SectionHeader']))
story.append(Paragraph(gutachten['fazit'], styles['GutachtenBody']))
story.append(Spacer(1, 0.3*cm))
if gutachten.get('staerken') or gutachten.get('schwaechen'):
story.append(Spacer(1, 0.3*cm))
if gutachten.get('staerken'):
story.append(Paragraph("Staerken:", styles['SectionHeader']))
for s in gutachten['staerken']:
story.append(Paragraph(f"{s}", styles['ListItem']))
if gutachten.get('schwaechen'):
story.append(Paragraph("Verbesserungspotenzial:", styles['SectionHeader']))
for s in gutachten['schwaechen']:
story.append(Paragraph(f"{s}", styles['ListItem']))
else:
story.append(Paragraph("<i>Kein Gutachten-Text vorhanden.</i>", styles['GutachtenBody']))
def _add_criteria_table(story, styles, student_data):
"""Add criteria scoring table to the story."""
story.append(Paragraph("Bewertung nach Kriterien", styles['SectionHeader']))
story.append(Spacer(1, 0.2*cm))
criteria_scores = student_data.get('criteria_scores', {})
table_data = [["Kriterium", "Gewichtung", "Erreicht", "Punkte"]]
total_weighted = 0
total_weight = 0
for key, display_name in CRITERIA_DISPLAY_NAMES.items():
weight = CRITERIA_WEIGHTS.get(key, 0)
score_data = criteria_scores.get(key, {})
score = score_data.get('score', 0) if isinstance(score_data, dict) else score_data
weighted_score = (score / 100) * weight if score else 0
total_weighted += weighted_score
total_weight += weight
table_data.append([
display_name,
f"{weight}%",
f"{score}%",
f"{weighted_score:.1f}"
])
table_data.append([
"Gesamt",
f"{total_weight}%",
"",
f"{total_weighted:.1f}"
])
criteria_table = Table(table_data, colWidths=[8*cm, 2.5*cm, 2.5*cm, 2.5*cm])
criteria_table.setStyle(TableStyle([
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#2c5282')),
('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, 0), 10),
('ALIGN', (1, 0), (-1, -1), 'CENTER'),
('FONTSIZE', (0, 1), (-1, -1), 9),
('BOTTOMPADDING', (0, 0), (-1, -1), 6),
('TOPPADDING', (0, 0), (-1, -1), 6),
('GRID', (0, 0), (-1, -1), 0.5, colors.HexColor('#e2e8f0')),
('BACKGROUND', (0, -1), (-1, -1), colors.HexColor('#f7fafc')),
('FONTNAME', (0, -1), (-1, -1), 'Helvetica-Bold'),
('ROWBACKGROUNDS', (0, 1), (-1, -2), [colors.white, colors.HexColor('#f7fafc')]),
]))
story.append(criteria_table)
story.append(Spacer(1, 0.5*cm))
def _add_grade_box(story, styles, student_data):
"""Add final grade box to the story."""
grade_points = student_data.get('grade_points', 0)
grade_note = GRADE_POINTS_TO_NOTE.get(grade_points, "?")
raw_points = student_data.get('raw_points', 0)
grade_data = [
["Rohpunkte:", f"{raw_points} / 100"],
["Notenpunkte:", f"{grade_points} Punkte"],
["Note:", grade_note]
]
grade_table = Table(grade_data, colWidths=[4*cm, 4*cm])
grade_table.setStyle(TableStyle([
('BACKGROUND', (0, 0), (-1, -1), colors.HexColor('#ebf8ff')),
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
('FONTNAME', (1, -1), (1, -1), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, -1), 11),
('FONTSIZE', (1, -1), (1, -1), 14),
('TEXTCOLOR', (1, -1), (1, -1), colors.HexColor('#2c5282')),
('BOTTOMPADDING', (0, 0), (-1, -1), 8),
('TOPPADDING', (0, 0), (-1, -1), 8),
('LEFTPADDING', (0, 0), (-1, -1), 12),
('BOX', (0, 0), (-1, -1), 1, colors.HexColor('#2c5282')),
('ALIGN', (1, 0), (1, -1), 'RIGHT'),
]))
story.append(KeepTogether([
Paragraph("Endergebnis", styles['SectionHeader']),
Spacer(1, 0.2*cm),
grade_table
]))
def _add_workflow_info(story, styles, workflow_data):
"""Add examiner workflow information to the story."""
story.append(Spacer(1, 0.5*cm))
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
story.append(Spacer(1, 0.3*cm))
story.append(Paragraph("Korrekturverlauf", styles['SectionHeader']))
workflow_rows = []
if workflow_data.get('erst_korrektor'):
ek = workflow_data['erst_korrektor']
workflow_rows.append([
"Erstkorrektor:",
ek.get('name', 'Unbekannt'),
f"{ek.get('grade_points', '-')} Punkte"
])
if workflow_data.get('zweit_korrektor'):
zk = workflow_data['zweit_korrektor']
workflow_rows.append([
"Zweitkorrektor:",
zk.get('name', 'Unbekannt'),
f"{zk.get('grade_points', '-')} Punkte"
])
if workflow_data.get('dritt_korrektor'):
dk = workflow_data['dritt_korrektor']
workflow_rows.append([
"Drittkorrektor:",
dk.get('name', 'Unbekannt'),
f"{dk.get('grade_points', '-')} Punkte"
])
if workflow_data.get('final_grade_source'):
workflow_rows.append([
"Endnote durch:",
workflow_data['final_grade_source'],
""
])
if workflow_rows:
workflow_table = Table(workflow_rows, colWidths=[4*cm, 6*cm, 4*cm])
workflow_table.setStyle(TableStyle([
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, -1), 9),
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
('TOPPADDING', (0, 0), (-1, -1), 4),
]))
story.append(workflow_table)
def _add_annotation_summary(story, styles, annotations):
"""Add annotation summary to the story."""
story.append(Spacer(1, 0.5*cm))
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
story.append(Spacer(1, 0.3*cm))
story.append(Paragraph("Anmerkungen (Zusammenfassung)", styles['SectionHeader']))
by_type = {}
for ann in annotations:
ann_type = ann.get('type', 'comment')
if ann_type not in by_type:
by_type[ann_type] = []
by_type[ann_type].append(ann)
for ann_type, anns in by_type.items():
type_name = CRITERIA_DISPLAY_NAMES.get(ann_type, ann_type.replace('_', ' ').title())
story.append(Paragraph(f"{type_name} ({len(anns)} Anmerkungen)", styles['ListItem']))
def _add_footer(story, styles):
"""Add generation footer to the story."""
story.append(Spacer(1, 1*cm))
story.append(HRFlowable(width="100%", thickness=0.5, color=colors.HexColor('#cbd5e0')))
story.append(Spacer(1, 0.2*cm))
story.append(Paragraph(
f"Erstellt am {datetime.now().strftime('%d.%m.%Y um %H:%M Uhr')} | BreakPilot Abiturkorrektur-System",
styles['MetaText']
))
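# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The dict keys mirror the .get()
# calls above; the values are placeholders, not real Klausur data.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo_student = {
        "student_name": "Anonym",
        "raw_points": 72,
        "grade_points": 11,
        "criteria_scores": {"inhalt": {"score": 75}, "stil": {"score": 68}},
        "gutachten": {"einleitung": "Die Arbeit behandelt ...", "staerken": ["Klarer Aufbau"]},
    }
    demo_klausur = {"subject": "Deutsch", "title": "Abitur 2025", "year": 2025}
    pdf_bytes = generate_gutachten_pdf(demo_student, demo_klausur)
    with open("gutachten_demo.pdf", "wb") as fh:
        fh.write(pdf_bytes)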
@@ -0,0 +1,297 @@
"""
PDF Export - Klausur overview and annotations PDF generation.
Generates:
- Klausur overview with grade distribution for all students
- Annotations PDF for a single student
"""
import io
from datetime import datetime
from typing import Dict, List, Optional, Any
from reportlab.lib import colors
from reportlab.lib.pagesizes import A4
from reportlab.lib.units import cm
from reportlab.platypus import (
SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle,
HRFlowable
)
from pdf_export_styles import (
GRADE_POINTS_TO_NOTE,
CRITERIA_DISPLAY_NAMES,
get_custom_styles,
)
def generate_klausur_overview_pdf(
klausur_data: Dict[str, Any],
students: List[Dict[str, Any]],
fairness_data: Optional[Dict[str, Any]] = None
) -> bytes:
"""
Generate an overview PDF for an entire Klausur with all student grades.
Args:
klausur_data: Klausur metadata
students: List of all student work data
fairness_data: Optional fairness analysis data
Returns:
PDF as bytes
"""
buffer = io.BytesIO()
doc = SimpleDocTemplate(
buffer,
pagesize=A4,
rightMargin=1.5*cm,
leftMargin=1.5*cm,
topMargin=2*cm,
bottomMargin=2*cm
)
styles = get_custom_styles()
story = []
# Header
story.append(Paragraph("Notenuebersicht", styles['GutachtenTitle']))
story.append(Paragraph(f"{klausur_data.get('subject', 'Deutsch')} - {klausur_data.get('title', '')}", styles['GutachtenSubtitle']))
story.append(Spacer(1, 0.5*cm))
# Meta information
meta_data = [
["Schuljahr:", f"{klausur_data.get('year', 2025)}"],
["Kurs:", klausur_data.get('semester', 'Abitur')],
["Anzahl Arbeiten:", str(len(students))],
["Stand:", datetime.now().strftime("%d.%m.%Y")]
]
meta_table = Table(meta_data, colWidths=[4*cm, 10*cm])
meta_table.setStyle(TableStyle([
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, -1), 10),
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
('TOPPADDING', (0, 0), (-1, -1), 4),
]))
story.append(meta_table)
story.append(Spacer(1, 0.5*cm))
# Statistics (if fairness data available)
if fairness_data and fairness_data.get('statistics'):
_add_statistics(story, styles, fairness_data['statistics'])
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
story.append(Spacer(1, 0.5*cm))
# Student grades table
sorted_students = sorted(students, key=lambda s: s.get('grade_points', 0), reverse=True)
_add_student_table(story, styles, sorted_students)
# Grade distribution
_add_grade_distribution(story, styles, sorted_students)
# Footer
story.append(Spacer(1, 1*cm))
story.append(HRFlowable(width="100%", thickness=0.5, color=colors.HexColor('#cbd5e0')))
story.append(Spacer(1, 0.2*cm))
story.append(Paragraph(
f"Erstellt am {datetime.now().strftime('%d.%m.%Y um %H:%M Uhr')} | BreakPilot Abiturkorrektur-System",
styles['MetaText']
))
# Build PDF
doc.build(story)
buffer.seek(0)
return buffer.getvalue()
def _add_statistics(story, styles, stats):
"""Add statistics section."""
story.append(Paragraph("Statistik", styles['SectionHeader']))
stats_data = [
["Durchschnitt:", f"{stats.get('average_grade', 0):.1f} Punkte"],
["Minimum:", f"{stats.get('min_grade', 0)} Punkte"],
["Maximum:", f"{stats.get('max_grade', 0)} Punkte"],
["Standardabweichung:", f"{stats.get('standard_deviation', 0):.2f}"],
]
stats_table = Table(stats_data, colWidths=[4*cm, 4*cm])
stats_table.setStyle(TableStyle([
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, -1), 9),
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
('BACKGROUND', (0, 0), (-1, -1), colors.HexColor('#f7fafc')),
('BOX', (0, 0), (-1, -1), 0.5, colors.HexColor('#e2e8f0')),
]))
story.append(stats_table)
story.append(Spacer(1, 0.5*cm))
def _add_student_table(story, styles, sorted_students):
"""Add student grades table."""
story.append(Paragraph("Einzelergebnisse", styles['SectionHeader']))
story.append(Spacer(1, 0.2*cm))
table_data = [["#", "Name", "Rohpunkte", "Notenpunkte", "Note", "Status"]]
for idx, student in enumerate(sorted_students, 1):
grade_points = student.get('grade_points', 0)
grade_note = GRADE_POINTS_TO_NOTE.get(grade_points, "-")
raw_points = student.get('raw_points', 0)
status = student.get('status', 'unknown')
status_display = {
'completed': 'Abgeschlossen',
'first_examiner': 'In Korrektur',
'second_examiner': 'Zweitkorrektur',
'uploaded': 'Hochgeladen',
'ocr_complete': 'OCR fertig',
'analyzing': 'Wird analysiert'
}.get(status, status)
table_data.append([
str(idx),
student.get('student_name', 'Anonym'),
f"{raw_points}/100",
str(grade_points),
grade_note,
status_display
])
student_table = Table(table_data, colWidths=[1*cm, 5*cm, 2.5*cm, 3*cm, 2*cm, 3*cm])
student_table.setStyle(TableStyle([
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#2c5282')),
('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, 0), 9),
('ALIGN', (0, 0), (-1, 0), 'CENTER'),
('FONTSIZE', (0, 1), (-1, -1), 9),
('ALIGN', (0, 1), (0, -1), 'CENTER'),
('ALIGN', (2, 1), (4, -1), 'CENTER'),
('BOTTOMPADDING', (0, 0), (-1, -1), 6),
('TOPPADDING', (0, 0), (-1, -1), 6),
('GRID', (0, 0), (-1, -1), 0.5, colors.HexColor('#e2e8f0')),
('ROWBACKGROUNDS', (0, 1), (-1, -1), [colors.white, colors.HexColor('#f7fafc')]),
]))
story.append(student_table)
def _add_grade_distribution(story, styles, sorted_students):
"""Add grade distribution table."""
story.append(Spacer(1, 0.5*cm))
story.append(Paragraph("Notenverteilung", styles['SectionHeader']))
story.append(Spacer(1, 0.2*cm))
grade_counts = {}
for student in sorted_students:
gp = student.get('grade_points', 0)
grade_counts[gp] = grade_counts.get(gp, 0) + 1
dist_data = [["Punkte", "Note", "Anzahl"]]
for points in range(15, -1, -1):
if points in grade_counts:
note = GRADE_POINTS_TO_NOTE.get(points, "-")
count = grade_counts[points]
dist_data.append([str(points), note, str(count)])
if len(dist_data) > 1:
dist_table = Table(dist_data, colWidths=[2.5*cm, 2.5*cm, 2.5*cm])
dist_table.setStyle(TableStyle([
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#2c5282')),
('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, -1), 9),
('ALIGN', (0, 0), (-1, -1), 'CENTER'),
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
('TOPPADDING', (0, 0), (-1, -1), 4),
('GRID', (0, 0), (-1, -1), 0.5, colors.HexColor('#e2e8f0')),
]))
story.append(dist_table)
def generate_annotations_pdf(
student_data: Dict[str, Any],
klausur_data: Dict[str, Any],
annotations: List[Dict[str, Any]]
) -> bytes:
"""
Generate a PDF with all annotations for a student work.
Args:
student_data: Student work data
klausur_data: Klausur metadata
annotations: List of all annotations
Returns:
PDF as bytes
"""
buffer = io.BytesIO()
doc = SimpleDocTemplate(
buffer,
pagesize=A4,
rightMargin=2*cm,
leftMargin=2*cm,
topMargin=2*cm,
bottomMargin=2*cm
)
styles = get_custom_styles()
story = []
# Header
story.append(Paragraph("Anmerkungen zur Klausur", styles['GutachtenTitle']))
story.append(Paragraph(f"{student_data.get('student_name', 'Anonym')}", styles['GutachtenSubtitle']))
story.append(Spacer(1, 0.5*cm))
if not annotations:
story.append(Paragraph("<i>Keine Anmerkungen vorhanden.</i>", styles['GutachtenBody']))
else:
# Group by type
by_type = {}
for ann in annotations:
ann_type = ann.get('type', 'comment')
if ann_type not in by_type:
by_type[ann_type] = []
by_type[ann_type].append(ann)
for ann_type, anns in by_type.items():
type_name = CRITERIA_DISPLAY_NAMES.get(ann_type, ann_type.replace('_', ' ').title())
story.append(Paragraph(f"{type_name} ({len(anns)})", styles['SectionHeader']))
story.append(Spacer(1, 0.2*cm))
sorted_anns = sorted(anns, key=lambda a: (a.get('page', 0), a.get('position', {}).get('y', 0)))
for idx, ann in enumerate(sorted_anns, 1):
page = ann.get('page', 1)
text = ann.get('text', '')
suggestion = ann.get('suggestion', '')
severity = ann.get('severity', 'minor')
ann_text = f"<b>[S.{page}]</b> {text}"
if suggestion:
ann_text += f" -> <i>{suggestion}</i>"
if severity == 'critical':
ann_text = f"<font color='red'>{ann_text}</font>"
elif severity == 'major':
ann_text = f"<font color='orange'>{ann_text}</font>"
story.append(Paragraph(f"{idx}. {ann_text}", styles['ListItem']))
story.append(Spacer(1, 0.3*cm))
# Footer
story.append(Spacer(1, 1*cm))
story.append(HRFlowable(width="100%", thickness=0.5, color=colors.HexColor('#cbd5e0')))
story.append(Spacer(1, 0.2*cm))
story.append(Paragraph(
f"Erstellt am {datetime.now().strftime('%d.%m.%Y um %H:%M Uhr')} | BreakPilot Abiturkorrektur-System",
styles['MetaText']
))
# Build PDF
doc.build(story)
buffer.seek(0)
return buffer.getvalue()
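# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only) covering both public functions of
# this module. Field names follow the .get() keys used above; the data is
# made up.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo_klausur = {"subject": "Deutsch", "title": "Abitur 2025", "year": 2025}
    demo_students = [
        {"student_name": "A", "raw_points": 81, "grade_points": 12, "status": "completed"},
        {"student_name": "B", "raw_points": 55, "grade_points": 7, "status": "second_examiner"},
    ]
    overview_pdf = generate_klausur_overview_pdf(demo_klausur, demo_students)
    annotations_pdf = generate_annotations_pdf(
        demo_students[0],
        demo_klausur,
        annotations=[{
            "type": "grammatik", "page": 2, "text": "Kommafehler",
            "suggestion": "Komma setzen", "severity": "minor",
        }],
    )
    with open("uebersicht_demo.pdf", "wb") as fh:
        fh.write(overview_pdf)
    with open("anmerkungen_demo.pdf", "wb") as fh:
        fh.write(annotations_pdf)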
@@ -0,0 +1,110 @@
"""
PDF Export - Constants and ReportLab styles for Abiturkorrektur PDFs.
"""
from reportlab.lib import colors
from reportlab.lib.enums import TA_LEFT, TA_CENTER, TA_JUSTIFY
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
# =============================================
# CONSTANTS
# =============================================
GRADE_POINTS_TO_NOTE = {
15: "1+", 14: "1", 13: "1-",
12: "2+", 11: "2", 10: "2-",
9: "3+", 8: "3", 7: "3-",
6: "4+", 5: "4", 4: "4-",
3: "5+", 2: "5", 1: "5-",
0: "6"
}
CRITERIA_DISPLAY_NAMES = {
"rechtschreibung": "Sprachliche Richtigkeit (Rechtschreibung)",
"grammatik": "Sprachliche Richtigkeit (Grammatik)",
"inhalt": "Inhaltliche Leistung",
"struktur": "Aufbau und Struktur",
"stil": "Ausdruck und Stil"
}
CRITERIA_WEIGHTS = {
"rechtschreibung": 15,
"grammatik": 15,
"inhalt": 40,
"struktur": 15,
"stil": 15
}
# =============================================
# STYLES
# =============================================
def get_custom_styles():
"""Create custom paragraph styles for Gutachten."""
styles = getSampleStyleSheet()
# Title style
styles.add(ParagraphStyle(
name='GutachtenTitle',
parent=styles['Heading1'],
fontSize=16,
spaceAfter=12,
alignment=TA_CENTER,
textColor=colors.HexColor('#1e3a5f')
))
# Subtitle style
styles.add(ParagraphStyle(
name='GutachtenSubtitle',
parent=styles['Heading2'],
fontSize=12,
spaceAfter=8,
spaceBefore=16,
textColor=colors.HexColor('#2c5282')
))
# Section header
styles.add(ParagraphStyle(
name='SectionHeader',
parent=styles['Heading3'],
fontSize=11,
spaceAfter=6,
spaceBefore=12,
textColor=colors.HexColor('#2d3748'),
borderColor=colors.HexColor('#e2e8f0'),
borderWidth=0,
borderPadding=0
))
# Body text
styles.add(ParagraphStyle(
name='GutachtenBody',
parent=styles['Normal'],
fontSize=10,
leading=14,
alignment=TA_JUSTIFY,
spaceAfter=6
))
# Small text for footer/meta
styles.add(ParagraphStyle(
name='MetaText',
parent=styles['Normal'],
fontSize=8,
textColor=colors.grey,
alignment=TA_LEFT
))
# List item
styles.add(ParagraphStyle(
name='ListItem',
parent=styles['Normal'],
fontSize=10,
leftIndent=20,
bulletIndent=10,
spaceAfter=4
))
return styles
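# ---------------------------------------------------------------------------
# Minimal sanity check (illustrative only): the criteria weights are expected
# to sum to 100, and every style name referenced by the export modules should
# be registered by get_custom_styles().
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from reportlab.platypus import Paragraph

    assert sum(CRITERIA_WEIGHTS.values()) == 100
    styles = get_custom_styles()
    for name in ("GutachtenTitle", "GutachtenSubtitle", "SectionHeader",
                 "GutachtenBody", "MetaText", "ListItem"):
        Paragraph("Test", styles[name])
    print({p: GRADE_POINTS_TO_NOTE[p] for p in (15, 10, 4, 0)})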
+146
@@ -0,0 +1,146 @@
"""
Qdrant Vector Database Service — QdrantService class for NiBiS Ingestion Pipeline.
"""
from typing import List, Dict, Optional
from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, Distance, PointStruct, Filter, FieldCondition, MatchValue
from qdrant_core import QDRANT_URL, VECTOR_SIZE
class QdrantService:
"""
Class-based Qdrant service for flexible collection management.
Used by nibis_ingestion.py for bulk indexing.
"""
    def __init__(self, url: Optional[str] = None):
self.url = url or QDRANT_URL
self._client = None
@property
def client(self) -> QdrantClient:
if self._client is None:
self._client = QdrantClient(url=self.url)
return self._client
async def ensure_collection(self, collection_name: str, vector_size: int = VECTOR_SIZE) -> bool:
"""
Ensure collection exists, create if needed.
Args:
collection_name: Name of the collection
vector_size: Dimension of vectors
Returns:
True if collection exists/created
"""
try:
collections = self.client.get_collections().collections
collection_names = [c.name for c in collections]
if collection_name not in collection_names:
self.client.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(
size=vector_size,
distance=Distance.COSINE
)
)
print(f"Created collection: {collection_name}")
return True
except Exception as e:
print(f"Error ensuring collection: {e}")
return False
async def upsert_points(self, collection_name: str, points: List[Dict]) -> int:
"""
Upsert points into collection.
Args:
collection_name: Target collection
points: List of {id, vector, payload}
Returns:
Number of upserted points
"""
import uuid
if not points:
return 0
qdrant_points = []
for p in points:
# Convert string ID to UUID for Qdrant compatibility
point_id = p["id"]
if isinstance(point_id, str):
# Use uuid5 with DNS namespace for deterministic UUID from string
point_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, point_id))
qdrant_points.append(
PointStruct(
id=point_id,
vector=p["vector"],
payload={**p.get("payload", {}), "original_id": p["id"]} # Keep original ID in payload
)
)
self.client.upsert(collection_name=collection_name, points=qdrant_points)
return len(qdrant_points)
async def search(
self,
collection_name: str,
query_vector: List[float],
filter_conditions: Optional[Dict] = None,
limit: int = 10
) -> List[Dict]:
"""
Semantic search in collection.
Args:
collection_name: Collection to search
query_vector: Query embedding
filter_conditions: Optional filters (key: value pairs)
limit: Max results
Returns:
List of matching points with scores
"""
query_filter = None
if filter_conditions:
must_conditions = [
FieldCondition(key=k, match=MatchValue(value=v))
for k, v in filter_conditions.items()
]
query_filter = Filter(must=must_conditions)
results = self.client.search(
collection_name=collection_name,
query_vector=query_vector,
query_filter=query_filter,
limit=limit
)
return [
{
"id": str(r.id),
"score": r.score,
"payload": r.payload
}
for r in results
]
async def get_stats(self, collection_name: str) -> Dict:
"""Get collection statistics."""
try:
info = self.client.get_collection(collection_name)
return {
"name": collection_name,
"vectors_count": info.vectors_count,
"points_count": info.points_count,
"status": info.status.value
}
except Exception as e:
return {"error": str(e), "name": collection_name}
+193
@@ -0,0 +1,193 @@
"""
Qdrant Vector Database Service — core client and BYOEH functions.
"""
import os
from typing import List, Dict, Optional
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.models import VectorParams, Distance, PointStruct, Filter, FieldCondition, MatchValue
QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
COLLECTION_NAME = "bp_eh"
VECTOR_SIZE = 1536 # OpenAI text-embedding-3-small
_client: Optional[QdrantClient] = None
def get_qdrant_client() -> QdrantClient:
"""Get or create Qdrant client singleton."""
global _client
if _client is None:
_client = QdrantClient(url=QDRANT_URL)
return _client
async def init_qdrant_collection() -> bool:
"""Initialize Qdrant collection for BYOEH if not exists."""
try:
client = get_qdrant_client()
# Check if collection exists
collections = client.get_collections().collections
collection_names = [c.name for c in collections]
if COLLECTION_NAME not in collection_names:
client.create_collection(
collection_name=COLLECTION_NAME,
vectors_config=VectorParams(
size=VECTOR_SIZE,
distance=Distance.COSINE
)
)
print(f"Created Qdrant collection: {COLLECTION_NAME}")
else:
print(f"Qdrant collection {COLLECTION_NAME} already exists")
return True
except Exception as e:
print(f"Failed to initialize Qdrant: {e}")
return False
async def index_eh_chunks(
eh_id: str,
tenant_id: str,
subject: str,
chunks: List[Dict]
) -> int:
"""
Index EH chunks in Qdrant.
Args:
eh_id: Erwartungshorizont ID
tenant_id: Tenant/School ID for isolation
subject: Subject (deutsch, englisch, etc.)
chunks: List of {text, embedding, encrypted_content}
Returns:
Number of indexed chunks
"""
client = get_qdrant_client()
points = []
for i, chunk in enumerate(chunks):
point_id = f"{eh_id}_{i}"
points.append(
PointStruct(
id=point_id,
vector=chunk["embedding"],
payload={
"tenant_id": tenant_id,
"eh_id": eh_id,
"chunk_index": i,
"subject": subject,
"encrypted_content": chunk.get("encrypted_content", ""),
"training_allowed": False # ALWAYS FALSE - critical for compliance
}
)
)
if points:
client.upsert(collection_name=COLLECTION_NAME, points=points)
return len(points)
async def search_eh(
query_embedding: List[float],
tenant_id: str,
subject: Optional[str] = None,
limit: int = 5
) -> List[Dict]:
"""
Semantic search in tenant's Erwartungshorizonte.
Args:
query_embedding: Query vector (1536 dimensions)
tenant_id: Tenant ID for isolation
subject: Optional subject filter
limit: Max results
Returns:
List of matching chunks with scores
"""
client = get_qdrant_client()
# Build filter conditions
must_conditions = [
FieldCondition(key="tenant_id", match=MatchValue(value=tenant_id))
]
if subject:
must_conditions.append(
FieldCondition(key="subject", match=MatchValue(value=subject))
)
query_filter = Filter(must=must_conditions)
results = client.search(
collection_name=COLLECTION_NAME,
query_vector=query_embedding,
query_filter=query_filter,
limit=limit
)
return [
{
"id": str(r.id),
"score": r.score,
"eh_id": r.payload.get("eh_id"),
"chunk_index": r.payload.get("chunk_index"),
"encrypted_content": r.payload.get("encrypted_content"),
"subject": r.payload.get("subject")
}
for r in results
]
async def delete_eh_vectors(eh_id: str) -> int:
"""
Delete all vectors for a specific Erwartungshorizont.
Args:
eh_id: Erwartungshorizont ID
Returns:
Number of deleted points
"""
client = get_qdrant_client()
# Get all points for this EH first
scroll_result = client.scroll(
collection_name=COLLECTION_NAME,
scroll_filter=Filter(
must=[FieldCondition(key="eh_id", match=MatchValue(value=eh_id))]
),
limit=1000
)
point_ids = [str(p.id) for p in scroll_result[0]]
if point_ids:
client.delete(
collection_name=COLLECTION_NAME,
points_selector=models.PointIdsList(points=point_ids)
)
return len(point_ids)
async def get_collection_info() -> Dict:
"""Get collection statistics."""
try:
client = get_qdrant_client()
info = client.get_collection(COLLECTION_NAME)
return {
"name": COLLECTION_NAME,
"vectors_count": info.vectors_count,
"points_count": info.points_count,
"status": info.status.value
}
except Exception as e:
return {"error": str(e)}
+231
@@ -0,0 +1,231 @@
"""
Qdrant Vector Database Service — Legal Templates RAG Search.
"""
from typing import List, Dict, Optional
from qdrant_client.models import VectorParams, Distance, Filter, FieldCondition, MatchValue
from qdrant_core import get_qdrant_client
LEGAL_TEMPLATES_COLLECTION = "bp_legal_templates"
LEGAL_TEMPLATES_VECTOR_SIZE = 1024 # BGE-M3
async def init_legal_templates_collection() -> bool:
"""Initialize Qdrant collection for legal templates if not exists."""
try:
client = get_qdrant_client()
collections = client.get_collections().collections
collection_names = [c.name for c in collections]
if LEGAL_TEMPLATES_COLLECTION not in collection_names:
client.create_collection(
collection_name=LEGAL_TEMPLATES_COLLECTION,
vectors_config=VectorParams(
size=LEGAL_TEMPLATES_VECTOR_SIZE,
distance=Distance.COSINE
)
)
print(f"Created Qdrant collection: {LEGAL_TEMPLATES_COLLECTION}")
else:
print(f"Qdrant collection {LEGAL_TEMPLATES_COLLECTION} already exists")
return True
except Exception as e:
print(f"Failed to initialize legal templates collection: {e}")
return False
async def search_legal_templates(
query_embedding: List[float],
template_type: Optional[str] = None,
license_types: Optional[List[str]] = None,
language: Optional[str] = None,
jurisdiction: Optional[str] = None,
attribution_required: Optional[bool] = None,
limit: int = 10
) -> List[Dict]:
"""
Search in legal templates collection for document generation.
Args:
query_embedding: Query vector (1024 dimensions, BGE-M3)
template_type: Filter by template type (privacy_policy, terms_of_service, etc.)
license_types: Filter by license types (cc0, mit, cc_by_4, etc.)
language: Filter by language (de, en)
jurisdiction: Filter by jurisdiction (DE, EU, US, etc.)
attribution_required: Filter by attribution requirement
limit: Max results
Returns:
List of matching template chunks with full metadata
"""
client = get_qdrant_client()
# Build filter conditions
must_conditions = []
if template_type:
must_conditions.append(
FieldCondition(key="template_type", match=MatchValue(value=template_type))
)
if language:
must_conditions.append(
FieldCondition(key="language", match=MatchValue(value=language))
)
if jurisdiction:
must_conditions.append(
FieldCondition(key="jurisdiction", match=MatchValue(value=jurisdiction))
)
if attribution_required is not None:
must_conditions.append(
FieldCondition(key="attribution_required", match=MatchValue(value=attribution_required))
)
# License type filter (OR condition)
should_conditions = []
if license_types:
for license_type in license_types:
should_conditions.append(
FieldCondition(key="license_id", match=MatchValue(value=license_type))
)
# Construct filter
query_filter = None
if must_conditions or should_conditions:
filter_args = {}
if must_conditions:
filter_args["must"] = must_conditions
if should_conditions:
filter_args["should"] = should_conditions
query_filter = Filter(**filter_args)
try:
results = client.search(
collection_name=LEGAL_TEMPLATES_COLLECTION,
query_vector=query_embedding,
query_filter=query_filter,
limit=limit
)
return [
{
"id": str(r.id),
"score": r.score,
"text": r.payload.get("text", ""),
"document_title": r.payload.get("document_title"),
"template_type": r.payload.get("template_type"),
"clause_category": r.payload.get("clause_category"),
"language": r.payload.get("language"),
"jurisdiction": r.payload.get("jurisdiction"),
"license_id": r.payload.get("license_id"),
"license_name": r.payload.get("license_name"),
"license_url": r.payload.get("license_url"),
"attribution_required": r.payload.get("attribution_required"),
"attribution_text": r.payload.get("attribution_text"),
"source_name": r.payload.get("source_name"),
"source_url": r.payload.get("source_url"),
"source_repo": r.payload.get("source_repo"),
"placeholders": r.payload.get("placeholders", []),
"is_complete_document": r.payload.get("is_complete_document"),
"is_modular": r.payload.get("is_modular"),
"requires_customization": r.payload.get("requires_customization"),
"output_allowed": r.payload.get("output_allowed"),
"modification_allowed": r.payload.get("modification_allowed"),
"distortion_prohibited": r.payload.get("distortion_prohibited"),
}
for r in results
]
except Exception as e:
print(f"Legal templates search error: {e}")
return []
async def get_legal_templates_stats() -> Dict:
"""Get statistics for the legal templates collection."""
try:
client = get_qdrant_client()
info = client.get_collection(LEGAL_TEMPLATES_COLLECTION)
# Count by template type
template_types = ["privacy_policy", "terms_of_service", "cookie_banner",
"impressum", "widerruf", "dpa", "sla", "agb"]
type_counts = {}
for ttype in template_types:
result = client.count(
collection_name=LEGAL_TEMPLATES_COLLECTION,
count_filter=Filter(
must=[FieldCondition(key="template_type", match=MatchValue(value=ttype))]
)
)
if result.count > 0:
type_counts[ttype] = result.count
# Count by language
lang_counts = {}
for lang in ["de", "en"]:
result = client.count(
collection_name=LEGAL_TEMPLATES_COLLECTION,
count_filter=Filter(
must=[FieldCondition(key="language", match=MatchValue(value=lang))]
)
)
lang_counts[lang] = result.count
# Count by license
license_counts = {}
for license_id in ["cc0", "mit", "cc_by_4", "public_domain", "unlicense"]:
result = client.count(
collection_name=LEGAL_TEMPLATES_COLLECTION,
count_filter=Filter(
must=[FieldCondition(key="license_id", match=MatchValue(value=license_id))]
)
)
if result.count > 0:
license_counts[license_id] = result.count
return {
"collection": LEGAL_TEMPLATES_COLLECTION,
"vectors_count": info.vectors_count,
"points_count": info.points_count,
"status": info.status.value,
"template_types": type_counts,
"languages": lang_counts,
"licenses": license_counts,
}
except Exception as e:
return {"error": str(e), "collection": LEGAL_TEMPLATES_COLLECTION}
async def delete_legal_templates_by_source(source_name: str) -> int:
"""
Delete all legal template chunks from a specific source.
Args:
source_name: Name of the source to delete
Returns:
Number of deleted points
"""
client = get_qdrant_client()
# Count first
count_result = client.count(
collection_name=LEGAL_TEMPLATES_COLLECTION,
count_filter=Filter(
must=[FieldCondition(key="source_name", match=MatchValue(value=source_name))]
)
)
# Delete by filter
client.delete(
collection_name=LEGAL_TEMPLATES_COLLECTION,
points_selector=Filter(
must=[FieldCondition(key="source_name", match=MatchValue(value=source_name))]
)
)
return count_result.count
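# ---------------------------------------------------------------------------
# Minimal search sketch (illustrative only). The 1024-dim placeholder stands
# in for a real BGE-M3 embedding; against an empty collection the call
# returns [].
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo():
        await init_legal_templates_collection()
        results = await search_legal_templates(
            query_embedding=[0.01] * LEGAL_TEMPLATES_VECTOR_SIZE,
            template_type="privacy_policy",
            language="de",
            limit=5,
        )
        for r in results:
            print(r["score"], r["template_type"], r["license_id"])

    asyncio.run(_demo())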
+79
@@ -0,0 +1,79 @@
"""
Qdrant Vector Database Service — NiBiS RAG Search for Klausurkorrektur.
"""
from typing import List, Dict, Optional
from qdrant_client.models import Filter, FieldCondition, MatchValue
from qdrant_core import get_qdrant_client
async def search_nibis_eh(
query_embedding: List[float],
year: Optional[int] = None,
subject: Optional[str] = None,
niveau: Optional[str] = None,
limit: int = 5
) -> List[Dict]:
"""
Search in NiBiS Erwartungshorizonte (public, pre-indexed data).
Unlike search_eh(), this searches in the public NiBiS collection
and returns plaintext (not encrypted).
Args:
query_embedding: Query vector
year: Optional year filter (2016, 2017, 2024, 2025)
subject: Optional subject filter
niveau: Optional niveau filter (eA, gA)
limit: Max results
Returns:
List of matching chunks with metadata
"""
client = get_qdrant_client()
collection = "bp_nibis_eh"
# Build filter
must_conditions = []
if year:
must_conditions.append(
FieldCondition(key="year", match=MatchValue(value=year))
)
if subject:
must_conditions.append(
FieldCondition(key="subject", match=MatchValue(value=subject))
)
if niveau:
must_conditions.append(
FieldCondition(key="niveau", match=MatchValue(value=niveau))
)
query_filter = Filter(must=must_conditions) if must_conditions else None
try:
results = client.search(
collection_name=collection,
query_vector=query_embedding,
query_filter=query_filter,
limit=limit
)
return [
{
"id": str(r.id),
"score": r.score,
"text": r.payload.get("text", ""),
"year": r.payload.get("year"),
"subject": r.payload.get("subject"),
"niveau": r.payload.get("niveau"),
"task_number": r.payload.get("task_number"),
"doc_type": r.payload.get("doc_type"),
"variant": r.payload.get("variant"),
}
for r in results
]
except Exception as e:
print(f"NiBiS search error: {e}")
return []
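# ---------------------------------------------------------------------------
# Minimal search sketch (illustrative only). Assumes bp_nibis_eh has been
# indexed by the ingestion pipeline and uses 1536-dim embeddings (matching
# VECTOR_SIZE in qdrant_core, which is an assumption here); otherwise the
# helper prints the error and returns [].
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo():
        hits = await search_nibis_eh(
            query_embedding=[0.01] * 1536,
            year=2025,
            subject="deutsch",
            niveau="eA",
            limit=3,
        )
        for h in hits:
            print(h["score"], h["year"], h["task_number"])

    asyncio.run(_demo())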
+35 -635
@@ -1,638 +1,38 @@
"""
Qdrant Vector Database Service for BYOEH
Manages vector storage and semantic search for Erwartungshorizonte.
Qdrant Vector Database Service for BYOEH — barrel re-export.
The actual code lives in:
- qdrant_core.py (client singleton, BYOEH index/search/delete)
- qdrant_class.py (QdrantService class for NiBiS pipeline)
- qdrant_nibis.py (NiBiS RAG search)
- qdrant_legal.py (Legal Templates RAG search)
"""
import os
from typing import List, Dict, Optional
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.models import VectorParams, Distance, PointStruct, Filter, FieldCondition, MatchValue
QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
COLLECTION_NAME = "bp_eh"
VECTOR_SIZE = 1536 # OpenAI text-embedding-3-small
_client: Optional[QdrantClient] = None
def get_qdrant_client() -> QdrantClient:
"""Get or create Qdrant client singleton."""
global _client
if _client is None:
_client = QdrantClient(url=QDRANT_URL)
return _client
async def init_qdrant_collection() -> bool:
"""Initialize Qdrant collection for BYOEH if not exists."""
try:
client = get_qdrant_client()
# Check if collection exists
collections = client.get_collections().collections
collection_names = [c.name for c in collections]
if COLLECTION_NAME not in collection_names:
client.create_collection(
collection_name=COLLECTION_NAME,
vectors_config=VectorParams(
size=VECTOR_SIZE,
distance=Distance.COSINE
)
)
print(f"Created Qdrant collection: {COLLECTION_NAME}")
else:
print(f"Qdrant collection {COLLECTION_NAME} already exists")
return True
except Exception as e:
print(f"Failed to initialize Qdrant: {e}")
return False
async def index_eh_chunks(
eh_id: str,
tenant_id: str,
subject: str,
chunks: List[Dict]
) -> int:
"""
Index EH chunks in Qdrant.
Args:
eh_id: Erwartungshorizont ID
tenant_id: Tenant/School ID for isolation
subject: Subject (deutsch, englisch, etc.)
chunks: List of {text, embedding, encrypted_content}
Returns:
Number of indexed chunks
"""
client = get_qdrant_client()
points = []
for i, chunk in enumerate(chunks):
point_id = f"{eh_id}_{i}"
points.append(
PointStruct(
id=point_id,
vector=chunk["embedding"],
payload={
"tenant_id": tenant_id,
"eh_id": eh_id,
"chunk_index": i,
"subject": subject,
"encrypted_content": chunk.get("encrypted_content", ""),
"training_allowed": False # ALWAYS FALSE - critical for compliance
}
)
)
if points:
client.upsert(collection_name=COLLECTION_NAME, points=points)
return len(points)
async def search_eh(
query_embedding: List[float],
tenant_id: str,
subject: Optional[str] = None,
limit: int = 5
) -> List[Dict]:
"""
Semantic search in tenant's Erwartungshorizonte.
Args:
query_embedding: Query vector (1536 dimensions)
tenant_id: Tenant ID for isolation
subject: Optional subject filter
limit: Max results
Returns:
List of matching chunks with scores
"""
client = get_qdrant_client()
# Build filter conditions
must_conditions = [
FieldCondition(key="tenant_id", match=MatchValue(value=tenant_id))
]
if subject:
must_conditions.append(
FieldCondition(key="subject", match=MatchValue(value=subject))
)
query_filter = Filter(must=must_conditions)
results = client.search(
collection_name=COLLECTION_NAME,
query_vector=query_embedding,
query_filter=query_filter,
limit=limit
)
return [
{
"id": str(r.id),
"score": r.score,
"eh_id": r.payload.get("eh_id"),
"chunk_index": r.payload.get("chunk_index"),
"encrypted_content": r.payload.get("encrypted_content"),
"subject": r.payload.get("subject")
}
for r in results
]
async def delete_eh_vectors(eh_id: str) -> int:
"""
Delete all vectors for a specific Erwartungshorizont.
Args:
eh_id: Erwartungshorizont ID
Returns:
Number of deleted points
"""
client = get_qdrant_client()
# Get all points for this EH first
scroll_result = client.scroll(
collection_name=COLLECTION_NAME,
scroll_filter=Filter(
must=[FieldCondition(key="eh_id", match=MatchValue(value=eh_id))]
),
limit=1000
)
point_ids = [str(p.id) for p in scroll_result[0]]
if point_ids:
client.delete(
collection_name=COLLECTION_NAME,
points_selector=models.PointIdsList(points=point_ids)
)
return len(point_ids)
async def get_collection_info() -> Dict:
"""Get collection statistics."""
try:
client = get_qdrant_client()
info = client.get_collection(COLLECTION_NAME)
return {
"name": COLLECTION_NAME,
"vectors_count": info.vectors_count,
"points_count": info.points_count,
"status": info.status.value
}
except Exception as e:
return {"error": str(e)}
# =============================================================================
# QdrantService Class (for NiBiS Ingestion Pipeline)
# =============================================================================
class QdrantService:
"""
Class-based Qdrant service for flexible collection management.
Used by nibis_ingestion.py for bulk indexing.
"""
def __init__(self, url: str = None):
self.url = url or QDRANT_URL
self._client = None
@property
def client(self) -> QdrantClient:
if self._client is None:
self._client = QdrantClient(url=self.url)
return self._client
async def ensure_collection(self, collection_name: str, vector_size: int = VECTOR_SIZE) -> bool:
"""
Ensure collection exists, create if needed.
Args:
collection_name: Name of the collection
vector_size: Dimension of vectors
Returns:
True if collection exists/created
"""
try:
collections = self.client.get_collections().collections
collection_names = [c.name for c in collections]
if collection_name not in collection_names:
self.client.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(
size=vector_size,
distance=Distance.COSINE
)
)
print(f"Created collection: {collection_name}")
return True
except Exception as e:
print(f"Error ensuring collection: {e}")
return False
async def upsert_points(self, collection_name: str, points: List[Dict]) -> int:
"""
Upsert points into collection.
Args:
collection_name: Target collection
points: List of {id, vector, payload}
Returns:
Number of upserted points
"""
import uuid
if not points:
return 0
qdrant_points = []
for p in points:
# Convert string ID to UUID for Qdrant compatibility
point_id = p["id"]
if isinstance(point_id, str):
# Use uuid5 with DNS namespace for deterministic UUID from string
point_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, point_id))
qdrant_points.append(
PointStruct(
id=point_id,
vector=p["vector"],
payload={**p.get("payload", {}), "original_id": p["id"]} # Keep original ID in payload
)
)
self.client.upsert(collection_name=collection_name, points=qdrant_points)
return len(qdrant_points)
async def search(
self,
collection_name: str,
query_vector: List[float],
filter_conditions: Optional[Dict] = None,
limit: int = 10
) -> List[Dict]:
"""
Semantic search in collection.
Args:
collection_name: Collection to search
query_vector: Query embedding
filter_conditions: Optional filters (key: value pairs)
limit: Max results
Returns:
List of matching points with scores
"""
query_filter = None
if filter_conditions:
must_conditions = [
FieldCondition(key=k, match=MatchValue(value=v))
for k, v in filter_conditions.items()
]
query_filter = Filter(must=must_conditions)
results = self.client.search(
collection_name=collection_name,
query_vector=query_vector,
query_filter=query_filter,
limit=limit
)
return [
{
"id": str(r.id),
"score": r.score,
"payload": r.payload
}
for r in results
]
async def get_stats(self, collection_name: str) -> Dict:
"""Get collection statistics."""
try:
info = self.client.get_collection(collection_name)
return {
"name": collection_name,
"vectors_count": info.vectors_count,
"points_count": info.points_count,
"status": info.status.value
}
except Exception as e:
return {"error": str(e), "name": collection_name}
# =============================================================================
# NiBiS RAG Search (for Klausurkorrektur Module)
# =============================================================================
async def search_nibis_eh(
query_embedding: List[float],
year: Optional[int] = None,
subject: Optional[str] = None,
niveau: Optional[str] = None,
limit: int = 5
) -> List[Dict]:
"""
Search in NiBiS Erwartungshorizonte (public, pre-indexed data).
Unlike search_eh(), this searches in the public NiBiS collection
and returns plaintext (not encrypted).
Args:
query_embedding: Query vector
year: Optional year filter (2016, 2017, 2024, 2025)
subject: Optional subject filter
niveau: Optional niveau filter (eA, gA)
limit: Max results
Returns:
List of matching chunks with metadata
"""
client = get_qdrant_client()
collection = "bp_nibis_eh"
# Build filter
must_conditions = []
if year:
must_conditions.append(
FieldCondition(key="year", match=MatchValue(value=year))
)
if subject:
must_conditions.append(
FieldCondition(key="subject", match=MatchValue(value=subject))
)
if niveau:
must_conditions.append(
FieldCondition(key="niveau", match=MatchValue(value=niveau))
)
query_filter = Filter(must=must_conditions) if must_conditions else None
try:
results = client.search(
collection_name=collection,
query_vector=query_embedding,
query_filter=query_filter,
limit=limit
)
return [
{
"id": str(r.id),
"score": r.score,
"text": r.payload.get("text", ""),
"year": r.payload.get("year"),
"subject": r.payload.get("subject"),
"niveau": r.payload.get("niveau"),
"task_number": r.payload.get("task_number"),
"doc_type": r.payload.get("doc_type"),
"variant": r.payload.get("variant"),
}
for r in results
]
except Exception as e:
print(f"NiBiS search error: {e}")
return []
# =============================================================================
# Legal Templates RAG Search (for Document Generator)
# =============================================================================
LEGAL_TEMPLATES_COLLECTION = "bp_legal_templates"
LEGAL_TEMPLATES_VECTOR_SIZE = 1024 # BGE-M3
async def init_legal_templates_collection() -> bool:
"""Initialize Qdrant collection for legal templates if not exists."""
try:
client = get_qdrant_client()
collections = client.get_collections().collections
collection_names = [c.name for c in collections]
if LEGAL_TEMPLATES_COLLECTION not in collection_names:
client.create_collection(
collection_name=LEGAL_TEMPLATES_COLLECTION,
vectors_config=VectorParams(
size=LEGAL_TEMPLATES_VECTOR_SIZE,
distance=Distance.COSINE
)
)
print(f"Created Qdrant collection: {LEGAL_TEMPLATES_COLLECTION}")
else:
print(f"Qdrant collection {LEGAL_TEMPLATES_COLLECTION} already exists")
return True
except Exception as e:
print(f"Failed to initialize legal templates collection: {e}")
return False
async def search_legal_templates(
query_embedding: List[float],
template_type: Optional[str] = None,
license_types: Optional[List[str]] = None,
language: Optional[str] = None,
jurisdiction: Optional[str] = None,
attribution_required: Optional[bool] = None,
limit: int = 10
) -> List[Dict]:
"""
Search in legal templates collection for document generation.
Args:
query_embedding: Query vector (1024 dimensions, BGE-M3)
template_type: Filter by template type (privacy_policy, terms_of_service, etc.)
license_types: Filter by license types (cc0, mit, cc_by_4, etc.)
language: Filter by language (de, en)
jurisdiction: Filter by jurisdiction (DE, EU, US, etc.)
attribution_required: Filter by attribution requirement
limit: Max results
Returns:
List of matching template chunks with full metadata
"""
client = get_qdrant_client()
# Build filter conditions
must_conditions = []
if template_type:
must_conditions.append(
FieldCondition(key="template_type", match=MatchValue(value=template_type))
)
if language:
must_conditions.append(
FieldCondition(key="language", match=MatchValue(value=language))
)
if jurisdiction:
must_conditions.append(
FieldCondition(key="jurisdiction", match=MatchValue(value=jurisdiction))
)
if attribution_required is not None:
must_conditions.append(
FieldCondition(key="attribution_required", match=MatchValue(value=attribution_required))
)
# License type filter (OR condition)
should_conditions = []
if license_types:
for license_type in license_types:
should_conditions.append(
FieldCondition(key="license_id", match=MatchValue(value=license_type))
)
# Construct filter
query_filter = None
if must_conditions or should_conditions:
filter_args = {}
if must_conditions:
filter_args["must"] = must_conditions
if should_conditions:
filter_args["should"] = should_conditions
query_filter = Filter(**filter_args)
try:
results = client.search(
collection_name=LEGAL_TEMPLATES_COLLECTION,
query_vector=query_embedding,
query_filter=query_filter,
limit=limit
)
return [
{
"id": str(r.id),
"score": r.score,
"text": r.payload.get("text", ""),
"document_title": r.payload.get("document_title"),
"template_type": r.payload.get("template_type"),
"clause_category": r.payload.get("clause_category"),
"language": r.payload.get("language"),
"jurisdiction": r.payload.get("jurisdiction"),
"license_id": r.payload.get("license_id"),
"license_name": r.payload.get("license_name"),
"license_url": r.payload.get("license_url"),
"attribution_required": r.payload.get("attribution_required"),
"attribution_text": r.payload.get("attribution_text"),
"source_name": r.payload.get("source_name"),
"source_url": r.payload.get("source_url"),
"source_repo": r.payload.get("source_repo"),
"placeholders": r.payload.get("placeholders", []),
"is_complete_document": r.payload.get("is_complete_document"),
"is_modular": r.payload.get("is_modular"),
"requires_customization": r.payload.get("requires_customization"),
"output_allowed": r.payload.get("output_allowed"),
"modification_allowed": r.payload.get("modification_allowed"),
"distortion_prohibited": r.payload.get("distortion_prohibited"),
}
for r in results
]
except Exception as e:
print(f"Legal templates search error: {e}")
return []
async def get_legal_templates_stats() -> Dict:
"""Get statistics for the legal templates collection."""
try:
client = get_qdrant_client()
info = client.get_collection(LEGAL_TEMPLATES_COLLECTION)
# Count by template type
template_types = ["privacy_policy", "terms_of_service", "cookie_banner",
"impressum", "widerruf", "dpa", "sla", "agb"]
type_counts = {}
for ttype in template_types:
result = client.count(
collection_name=LEGAL_TEMPLATES_COLLECTION,
count_filter=Filter(
must=[FieldCondition(key="template_type", match=MatchValue(value=ttype))]
)
)
if result.count > 0:
type_counts[ttype] = result.count
# Count by language
lang_counts = {}
for lang in ["de", "en"]:
result = client.count(
collection_name=LEGAL_TEMPLATES_COLLECTION,
count_filter=Filter(
must=[FieldCondition(key="language", match=MatchValue(value=lang))]
)
)
lang_counts[lang] = result.count
# Count by license
license_counts = {}
for license_id in ["cc0", "mit", "cc_by_4", "public_domain", "unlicense"]:
result = client.count(
collection_name=LEGAL_TEMPLATES_COLLECTION,
count_filter=Filter(
must=[FieldCondition(key="license_id", match=MatchValue(value=license_id))]
)
)
if result.count > 0:
license_counts[license_id] = result.count
return {
"collection": LEGAL_TEMPLATES_COLLECTION,
"vectors_count": info.vectors_count,
"points_count": info.points_count,
"status": info.status.value,
"template_types": type_counts,
"languages": lang_counts,
"licenses": license_counts,
}
except Exception as e:
return {"error": str(e), "collection": LEGAL_TEMPLATES_COLLECTION}
async def delete_legal_templates_by_source(source_name: str) -> int:
"""
Delete all legal template chunks from a specific source.
Args:
source_name: Name of the source to delete
Returns:
Number of deleted points
"""
client = get_qdrant_client()
# Count first
count_result = client.count(
collection_name=LEGAL_TEMPLATES_COLLECTION,
count_filter=Filter(
must=[FieldCondition(key="source_name", match=MatchValue(value=source_name))]
)
)
# Delete by filter
client.delete(
collection_name=LEGAL_TEMPLATES_COLLECTION,
points_selector=Filter(
must=[FieldCondition(key="source_name", match=MatchValue(value=source_name))]
)
)
return count_result.count
# Core client & BYOEH functions
from qdrant_core import ( # noqa: F401
QDRANT_URL,
COLLECTION_NAME,
VECTOR_SIZE,
get_qdrant_client,
init_qdrant_collection,
index_eh_chunks,
search_eh,
delete_eh_vectors,
get_collection_info,
)
# Class-based service
from qdrant_class import QdrantService # noqa: F401
# NiBiS search
from qdrant_nibis import search_nibis_eh # noqa: F401
# Legal templates
from qdrant_legal import ( # noqa: F401
LEGAL_TEMPLATES_COLLECTION,
LEGAL_TEMPLATES_VECTOR_SIZE,
init_legal_templates_collection,
search_legal_templates,
get_legal_templates_stats,
delete_legal_templates_by_source,
)
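# ---------------------------------------------------------------------------
# Quick re-export check (illustrative only): existing call sites keep
# importing from qdrant_service unchanged after the split.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import qdrant_service as qs

    assert qs.COLLECTION_NAME == "bp_eh"
    assert callable(qs.search_eh) and callable(qs.search_nibis_eh)
    assert callable(qs.search_legal_templates) and callable(qs.index_eh_chunks)
    print("re-exports OK:", type(qs.QdrantService()).__name__, qs.VECTOR_SIZE)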
+27 -621
@@ -1,625 +1,31 @@
"""
Training API - Endpoints for managing AI training jobs
Training API — barrel re-export.
Provides endpoints for:
- Starting/stopping training jobs
- Monitoring training progress
- Managing model versions
- Configuring training parameters
- SSE streaming for real-time metrics
Phase 2.2: Server-Sent Events for live progress
The actual code lives in:
- training_models.py (enums, Pydantic models, in-memory state)
- training_simulation.py (simulate_training_progress, SSE generators)
- training_routes.py (FastAPI router + all endpoints)
"""
import os
import json
import uuid
import asyncio
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any
from enum import Enum
from dataclasses import dataclass, field, asdict
from fastapi import APIRouter, HTTPException, BackgroundTasks, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
# ============================================================================
# ENUMS & MODELS
# ============================================================================
class TrainingStatus(str, Enum):
QUEUED = "queued"
PREPARING = "preparing"
TRAINING = "training"
VALIDATING = "validating"
COMPLETED = "completed"
FAILED = "failed"
PAUSED = "paused"
CANCELLED = "cancelled"
class ModelType(str, Enum):
ZEUGNIS = "zeugnis"
KLAUSUR = "klausur"
GENERAL = "general"
# Request/Response Models
class TrainingConfig(BaseModel):
"""Configuration for a training job."""
name: str = Field(..., description="Name for the training job")
model_type: ModelType = Field(ModelType.ZEUGNIS, description="Type of model to train")
bundeslaender: List[str] = Field(..., description="List of Bundesland codes to include")
batch_size: int = Field(16, ge=1, le=128)
learning_rate: float = Field(0.00005, ge=0.000001, le=0.1)
epochs: int = Field(10, ge=1, le=100)
warmup_steps: int = Field(500, ge=0, le=10000)
weight_decay: float = Field(0.01, ge=0, le=1)
gradient_accumulation: int = Field(4, ge=1, le=32)
mixed_precision: bool = Field(True, description="Use FP16 mixed precision training")
class TrainingMetrics(BaseModel):
"""Metrics from a training job."""
precision: float = 0.0
recall: float = 0.0
f1_score: float = 0.0
accuracy: float = 0.0
loss_history: List[float] = []
val_loss_history: List[float] = []
class TrainingJob(BaseModel):
"""A training job with full details."""
id: str
name: str
model_type: ModelType
status: TrainingStatus
progress: float
current_epoch: int
total_epochs: int
loss: float
val_loss: float
learning_rate: float
documents_processed: int
total_documents: int
started_at: Optional[datetime]
estimated_completion: Optional[datetime]
completed_at: Optional[datetime]
error_message: Optional[str]
metrics: TrainingMetrics
config: TrainingConfig
class ModelVersion(BaseModel):
"""A trained model version."""
id: str
job_id: str
version: str
model_type: ModelType
created_at: datetime
metrics: TrainingMetrics
is_active: bool
size_mb: float
bundeslaender: List[str]
class DatasetStats(BaseModel):
"""Statistics about the training dataset."""
total_documents: int
total_chunks: int
training_allowed: int
by_bundesland: Dict[str, int]
by_doc_type: Dict[str, int]
# ============================================================================
# IN-MEMORY STATE (Replace with database in production)
# ============================================================================
@dataclass
class TrainingState:
"""Global training state."""
jobs: Dict[str, dict] = field(default_factory=dict)
model_versions: Dict[str, dict] = field(default_factory=dict)
active_job_id: Optional[str] = None
_state = TrainingState()
# ============================================================================
# HELPER FUNCTIONS
# ============================================================================
async def simulate_training_progress(job_id: str):
"""Simulate training progress (replace with actual training logic)."""
global _state
if job_id not in _state.jobs:
return
job = _state.jobs[job_id]
job["status"] = TrainingStatus.TRAINING.value
job["started_at"] = datetime.now().isoformat()
total_steps = job["total_epochs"] * 100 # Simulate 100 steps per epoch
current_step = 0
while current_step < total_steps and job["status"] == TrainingStatus.TRAINING.value:
# Update progress
progress = (current_step / total_steps) * 100
current_epoch = current_step // 100 + 1
# Simulate decreasing loss
base_loss = 0.8 * (1 - progress / 100) + 0.1
loss = base_loss + (0.05 * (0.5 - (current_step % 100) / 100))
val_loss = loss * 1.1
# Update job state
job["progress"] = progress
job["current_epoch"] = min(current_epoch, job["total_epochs"])
job["loss"] = round(loss, 4)
job["val_loss"] = round(val_loss, 4)
job["documents_processed"] = int((progress / 100) * job["total_documents"])
# Update metrics
job["metrics"]["loss_history"].append(round(loss, 4))
job["metrics"]["val_loss_history"].append(round(val_loss, 4))
job["metrics"]["precision"] = round(0.5 + (progress / 200), 3)
job["metrics"]["recall"] = round(0.45 + (progress / 200), 3)
job["metrics"]["f1_score"] = round(0.47 + (progress / 200), 3)
job["metrics"]["accuracy"] = round(0.6 + (progress / 250), 3)
# Keep only last 50 history points
if len(job["metrics"]["loss_history"]) > 50:
job["metrics"]["loss_history"] = job["metrics"]["loss_history"][-50:]
job["metrics"]["val_loss_history"] = job["metrics"]["val_loss_history"][-50:]
# Estimate completion
if progress > 0:
elapsed = (datetime.now() - datetime.fromisoformat(job["started_at"])).total_seconds()
remaining = (elapsed / progress) * (100 - progress)
job["estimated_completion"] = (datetime.now() + timedelta(seconds=remaining)).isoformat()
current_step += 1
await asyncio.sleep(0.5) # Simulate work
# Mark as completed
if job["status"] == TrainingStatus.TRAINING.value:
job["status"] = TrainingStatus.COMPLETED.value
job["progress"] = 100
job["completed_at"] = datetime.now().isoformat()
# Create model version
version_id = str(uuid.uuid4())
_state.model_versions[version_id] = {
"id": version_id,
"job_id": job_id,
"version": f"v{len(_state.model_versions) + 1}.0",
"model_type": job["model_type"],
"created_at": datetime.now().isoformat(),
"metrics": job["metrics"],
"is_active": True,
"size_mb": 245.7,
"bundeslaender": job["config"]["bundeslaender"],
}
_state.active_job_id = None
# ============================================================================
# ROUTER
# ============================================================================
router = APIRouter(prefix="/api/v1/admin/training", tags=["Training"])
@router.get("/jobs", response_model=List[dict])
async def list_training_jobs():
"""Get all training jobs."""
return list(_state.jobs.values())
@router.get("/jobs/{job_id}", response_model=dict)
async def get_training_job(job_id: str):
"""Get details for a specific training job."""
if job_id not in _state.jobs:
raise HTTPException(status_code=404, detail="Job not found")
return _state.jobs[job_id]
@router.post("/jobs", response_model=dict)
async def create_training_job(config: TrainingConfig, background_tasks: BackgroundTasks):
"""Create and start a new training job."""
global _state
# Check if there's already an active job
if _state.active_job_id:
active_job = _state.jobs.get(_state.active_job_id)
if active_job and active_job["status"] in [
TrainingStatus.TRAINING.value,
TrainingStatus.PREPARING.value,
]:
raise HTTPException(
status_code=409,
detail="Another training job is already running"
)
# Create job
job_id = str(uuid.uuid4())
job = {
"id": job_id,
"name": config.name,
"model_type": config.model_type.value,
"status": TrainingStatus.QUEUED.value,
"progress": 0,
"current_epoch": 0,
"total_epochs": config.epochs,
"loss": 1.0,
"val_loss": 1.0,
"learning_rate": config.learning_rate,
"documents_processed": 0,
"total_documents": len(config.bundeslaender) * 50, # Estimate
"started_at": None,
"estimated_completion": None,
"completed_at": None,
"error_message": None,
"metrics": {
"precision": 0.0,
"recall": 0.0,
"f1_score": 0.0,
"accuracy": 0.0,
"loss_history": [],
"val_loss_history": [],
},
"config": config.dict(),
}
_state.jobs[job_id] = job
_state.active_job_id = job_id
# Start training in background
background_tasks.add_task(simulate_training_progress, job_id)
return {"id": job_id, "status": "queued", "message": "Training job created"}
@router.post("/jobs/{job_id}/pause", response_model=dict)
async def pause_training_job(job_id: str):
"""Pause a running training job."""
if job_id not in _state.jobs:
raise HTTPException(status_code=404, detail="Job not found")
job = _state.jobs[job_id]
if job["status"] != TrainingStatus.TRAINING.value:
raise HTTPException(status_code=400, detail="Job is not running")
job["status"] = TrainingStatus.PAUSED.value
return {"success": True, "message": "Training paused"}
@router.post("/jobs/{job_id}/resume", response_model=dict)
async def resume_training_job(job_id: str, background_tasks: BackgroundTasks):
"""Resume a paused training job."""
if job_id not in _state.jobs:
raise HTTPException(status_code=404, detail="Job not found")
job = _state.jobs[job_id]
if job["status"] != TrainingStatus.PAUSED.value:
raise HTTPException(status_code=400, detail="Job is not paused")
job["status"] = TrainingStatus.TRAINING.value
_state.active_job_id = job_id
background_tasks.add_task(simulate_training_progress, job_id)
return {"success": True, "message": "Training resumed"}
@router.post("/jobs/{job_id}/cancel", response_model=dict)
async def cancel_training_job(job_id: str):
"""Cancel a training job."""
if job_id not in _state.jobs:
raise HTTPException(status_code=404, detail="Job not found")
job = _state.jobs[job_id]
job["status"] = TrainingStatus.CANCELLED.value
job["completed_at"] = datetime.now().isoformat()
if _state.active_job_id == job_id:
_state.active_job_id = None
return {"success": True, "message": "Training cancelled"}
@router.delete("/jobs/{job_id}", response_model=dict)
async def delete_training_job(job_id: str):
"""Delete a training job."""
if job_id not in _state.jobs:
raise HTTPException(status_code=404, detail="Job not found")
job = _state.jobs[job_id]
if job["status"] == TrainingStatus.TRAINING.value:
raise HTTPException(status_code=400, detail="Cannot delete running job")
del _state.jobs[job_id]
return {"success": True, "message": "Job deleted"}
# ============================================================================
# MODEL VERSIONS
# ============================================================================
@router.get("/models", response_model=List[dict])
async def list_model_versions():
"""Get all trained model versions."""
return list(_state.model_versions.values())
@router.get("/models/{version_id}", response_model=dict)
async def get_model_version(version_id: str):
"""Get details for a specific model version."""
if version_id not in _state.model_versions:
raise HTTPException(status_code=404, detail="Model version not found")
return _state.model_versions[version_id]
@router.post("/models/{version_id}/activate", response_model=dict)
async def activate_model_version(version_id: str):
"""Set a model version as active."""
if version_id not in _state.model_versions:
raise HTTPException(status_code=404, detail="Model version not found")
# Deactivate all other versions of same type
model = _state.model_versions[version_id]
for v in _state.model_versions.values():
if v["model_type"] == model["model_type"]:
v["is_active"] = False
model["is_active"] = True
return {"success": True, "message": "Model activated"}
@router.delete("/models/{version_id}", response_model=dict)
async def delete_model_version(version_id: str):
"""Delete a model version."""
if version_id not in _state.model_versions:
raise HTTPException(status_code=404, detail="Model version not found")
model = _state.model_versions[version_id]
if model["is_active"]:
raise HTTPException(status_code=400, detail="Cannot delete active model")
del _state.model_versions[version_id]
return {"success": True, "message": "Model deleted"}
# ============================================================================
# DATASET STATS
# ============================================================================
@router.get("/dataset/stats", response_model=dict)
async def get_dataset_stats():
"""Get statistics about the training dataset."""
# Get stats from zeugnis sources
from metrics_db import get_zeugnis_stats
zeugnis_stats = await get_zeugnis_stats()
return {
"total_documents": zeugnis_stats.get("total_documents", 0),
"total_chunks": zeugnis_stats.get("total_documents", 0) * 12, # Estimate ~12 chunks per doc
"training_allowed": zeugnis_stats.get("training_allowed_documents", 0),
"by_bundesland": {
bl["bundesland"]: bl.get("doc_count", 0)
for bl in zeugnis_stats.get("per_bundesland", [])
},
"by_doc_type": {
"verordnung": 150,
"schulordnung": 80,
"handreichung": 45,
"erlass": 30,
},
}
# ============================================================================
# TRAINING STATUS
# ============================================================================
@router.get("/status", response_model=dict)
async def get_training_status():
"""Get overall training system status."""
active_job = None
if _state.active_job_id and _state.active_job_id in _state.jobs:
active_job = _state.jobs[_state.active_job_id]
return {
"is_training": _state.active_job_id is not None and active_job is not None and
active_job["status"] == TrainingStatus.TRAINING.value,
"active_job_id": _state.active_job_id,
"total_jobs": len(_state.jobs),
"completed_jobs": sum(
1 for j in _state.jobs.values()
if j["status"] == TrainingStatus.COMPLETED.value
),
"failed_jobs": sum(
1 for j in _state.jobs.values()
if j["status"] == TrainingStatus.FAILED.value
),
"model_versions": len(_state.model_versions),
"active_models": sum(1 for m in _state.model_versions.values() if m["is_active"]),
}
# ============================================================================
# SERVER-SENT EVENTS (SSE) ENDPOINTS
# ============================================================================
async def training_metrics_generator(job_id: str, request: Request):
"""
SSE generator for streaming training metrics.
Yields JSON-encoded training status updates every 500ms.
"""
while True:
# Check if client disconnected
if await request.is_disconnected():
break
# Get job status
if job_id not in _state.jobs:
yield f"data: {json.dumps({'error': 'Job not found'})}\n\n"
break
job = _state.jobs[job_id]
# Build metrics response
metrics_data = {
"job_id": job["id"],
"status": job["status"],
"progress": job["progress"],
"current_epoch": job["current_epoch"],
"total_epochs": job["total_epochs"],
"current_step": int(job["progress"] * job["total_epochs"]),
"total_steps": job["total_epochs"] * 100,
"elapsed_time_ms": 0,
"estimated_remaining_ms": 0,
"metrics": {
"loss": job["loss"],
"val_loss": job["val_loss"],
"accuracy": job["metrics"]["accuracy"],
"learning_rate": job["learning_rate"]
},
"history": [
{
"epoch": i + 1,
"step": (i + 1) * 10,
"loss": loss,
"val_loss": job["metrics"]["val_loss_history"][i] if i < len(job["metrics"]["val_loss_history"]) else None,
"learning_rate": job["learning_rate"],
"timestamp": 0
}
for i, loss in enumerate(job["metrics"]["loss_history"][-50:])
]
}
# Calculate elapsed time
if job["started_at"]:
started = datetime.fromisoformat(job["started_at"])
metrics_data["elapsed_time_ms"] = int((datetime.now() - started).total_seconds() * 1000)
# Calculate remaining time
if job["estimated_completion"]:
estimated = datetime.fromisoformat(job["estimated_completion"])
metrics_data["estimated_remaining_ms"] = max(0, int((estimated - datetime.now()).total_seconds() * 1000))
# Send SSE event
yield f"data: {json.dumps(metrics_data)}\n\n"
# Check if job completed
if job["status"] in [TrainingStatus.COMPLETED.value, TrainingStatus.FAILED.value, TrainingStatus.CANCELLED.value]:
break
# Wait before next update
await asyncio.sleep(0.5)
@router.get("/metrics/stream")
async def stream_training_metrics(job_id: str, request: Request):
"""
SSE endpoint for streaming training metrics.
Streams real-time training progress for a specific job.
Usage:
const eventSource = new EventSource('/api/v1/admin/training/metrics/stream?job_id=xxx')
eventSource.onmessage = (event) => {
const data = JSON.parse(event.data)
console.log(data.progress, data.metrics.loss)
}
"""
if job_id not in _state.jobs:
raise HTTPException(status_code=404, detail="Job not found")
return StreamingResponse(
training_metrics_generator(job_id, request),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no" # Disable nginx buffering
}
)
async def batch_ocr_progress_generator(images_count: int, request: Request):
"""
SSE generator for batch OCR progress simulation.
In production, this would integrate with actual OCR processing.
"""
import random
for i in range(images_count):
# Check if client disconnected
if await request.is_disconnected():
break
# Simulate processing time
await asyncio.sleep(random.uniform(0.3, 0.8))
progress_data = {
"type": "progress",
"current": i + 1,
"total": images_count,
"progress_percent": ((i + 1) / images_count) * 100,
"elapsed_ms": (i + 1) * 500,
"estimated_remaining_ms": (images_count - i - 1) * 500,
"result": {
"text": f"Sample recognized text for image {i + 1}",
"confidence": round(random.uniform(0.7, 0.98), 2),
"processing_time_ms": random.randint(200, 600),
"from_cache": random.random() < 0.2
}
}
yield f"data: {json.dumps(progress_data)}\n\n"
# Send completion event
yield f"data: {json.dumps({'type': 'complete', 'total_time_ms': images_count * 500, 'processed_count': images_count})}\n\n"
@router.get("/ocr/stream")
async def stream_batch_ocr(images_count: int, request: Request):
"""
SSE endpoint for streaming batch OCR progress.
Simulates batch OCR processing with progress updates.
In production, integrate with actual TrOCR batch processing.
Args:
images_count: Number of images to process
Usage:
const eventSource = new EventSource('/api/v1/admin/training/ocr/stream?images_count=10')
eventSource.onmessage = (event) => {
const data = JSON.parse(event.data)
if (data.type === 'progress') {
console.log(`${data.current}/${data.total}`)
}
}
"""
if images_count < 1 or images_count > 100:
raise HTTPException(status_code=400, detail="images_count must be between 1 and 100")
return StreamingResponse(
batch_ocr_progress_generator(images_count, request),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
# Models & enums
from training_models import ( # noqa: F401
TrainingStatus,
ModelType,
TrainingConfig,
TrainingMetrics,
TrainingJob,
ModelVersion,
DatasetStats,
TrainingState,
_state,
)
# Simulation helpers
from training_simulation import ( # noqa: F401
simulate_training_progress,
training_metrics_generator,
batch_ocr_progress_generator,
)
# Router
from training_routes import router # noqa: F401
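Downstream wiring should not need to change; a sketch (app module assumed) of mounting the re-exported router:

# Hypothetical app wiring: training_api.router is the same object that
# training_routes defines, so existing include_router calls keep working.
from fastapi import FastAPI
from training_api import router as training_router

app = FastAPI()
app.include_router(training_router)  # serves /api/v1/admin/training/*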
+118
View File
@@ -0,0 +1,118 @@
"""
Training API — enums, request/response models, and in-memory state.
"""
import uuid
from datetime import datetime
from typing import Optional, List, Dict, Any
from enum import Enum
from dataclasses import dataclass, field
from pydantic import BaseModel, Field
# ============================================================================
# ENUMS
# ============================================================================
class TrainingStatus(str, Enum):
QUEUED = "queued"
PREPARING = "preparing"
TRAINING = "training"
VALIDATING = "validating"
COMPLETED = "completed"
FAILED = "failed"
PAUSED = "paused"
CANCELLED = "cancelled"
class ModelType(str, Enum):
ZEUGNIS = "zeugnis"
KLAUSUR = "klausur"
GENERAL = "general"
# ============================================================================
# REQUEST/RESPONSE MODELS
# ============================================================================
class TrainingConfig(BaseModel):
"""Configuration for a training job."""
name: str = Field(..., description="Name for the training job")
model_type: ModelType = Field(ModelType.ZEUGNIS, description="Type of model to train")
bundeslaender: List[str] = Field(..., description="List of Bundesland codes to include")
batch_size: int = Field(16, ge=1, le=128)
learning_rate: float = Field(0.00005, ge=0.000001, le=0.1)
epochs: int = Field(10, ge=1, le=100)
warmup_steps: int = Field(500, ge=0, le=10000)
weight_decay: float = Field(0.01, ge=0, le=1)
gradient_accumulation: int = Field(4, ge=1, le=32)
mixed_precision: bool = Field(True, description="Use FP16 mixed precision training")
class TrainingMetrics(BaseModel):
"""Metrics from a training job."""
precision: float = 0.0
recall: float = 0.0
f1_score: float = 0.0
accuracy: float = 0.0
loss_history: List[float] = []
val_loss_history: List[float] = []
class TrainingJob(BaseModel):
"""A training job with full details."""
id: str
name: str
model_type: ModelType
status: TrainingStatus
progress: float
current_epoch: int
total_epochs: int
loss: float
val_loss: float
learning_rate: float
documents_processed: int
total_documents: int
started_at: Optional[datetime]
estimated_completion: Optional[datetime]
completed_at: Optional[datetime]
error_message: Optional[str]
metrics: TrainingMetrics
config: TrainingConfig
class ModelVersion(BaseModel):
"""A trained model version."""
id: str
job_id: str
version: str
model_type: ModelType
created_at: datetime
metrics: TrainingMetrics
is_active: bool
size_mb: float
bundeslaender: List[str]
class DatasetStats(BaseModel):
"""Statistics about the training dataset."""
total_documents: int
total_chunks: int
training_allowed: int
by_bundesland: Dict[str, int]
by_doc_type: Dict[str, int]
# ============================================================================
# IN-MEMORY STATE (Replace with database in production)
# ============================================================================
@dataclass
class TrainingState:
"""Global training state."""
jobs: Dict[str, dict] = field(default_factory=dict)
model_versions: Dict[str, dict] = field(default_factory=dict)
active_job_id: Optional[str] = None
_state = TrainingState()
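A short sketch of building a job payload from these models (values illustrative); only name and bundeslaender are required, everything else falls back to the Field defaults above:

from training_models import TrainingConfig, ModelType

config = TrainingConfig(
    name="zeugnis-ni-nw",                 # illustrative job name
    model_type=ModelType.ZEUGNIS,
    bundeslaender=["NI", "NW"],
    epochs=5,
)
print(config.dict())  # same shape that create_training_job stores in job["config"]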
+303
View File
@@ -0,0 +1,303 @@
"""
Training API — FastAPI route handlers.
"""
import uuid
from datetime import datetime
from typing import List
from fastapi import APIRouter, HTTPException, BackgroundTasks, Request
from fastapi.responses import StreamingResponse
from training_models import (
TrainingStatus,
TrainingConfig,
_state,
)
from training_simulation import (
simulate_training_progress,
training_metrics_generator,
batch_ocr_progress_generator,
)
router = APIRouter(prefix="/api/v1/admin/training", tags=["Training"])
# ============================================================================
# TRAINING JOBS
# ============================================================================
@router.get("/jobs", response_model=List[dict])
async def list_training_jobs():
"""Get all training jobs."""
return list(_state.jobs.values())
@router.get("/jobs/{job_id}", response_model=dict)
async def get_training_job(job_id: str):
"""Get details for a specific training job."""
if job_id not in _state.jobs:
raise HTTPException(status_code=404, detail="Job not found")
return _state.jobs[job_id]
@router.post("/jobs", response_model=dict)
async def create_training_job(config: TrainingConfig, background_tasks: BackgroundTasks):
"""Create and start a new training job."""
# Check if there's already an active job
if _state.active_job_id:
active_job = _state.jobs.get(_state.active_job_id)
if active_job and active_job["status"] in [
TrainingStatus.TRAINING.value,
TrainingStatus.PREPARING.value,
]:
raise HTTPException(
status_code=409,
detail="Another training job is already running"
)
# Create job
job_id = str(uuid.uuid4())
job = {
"id": job_id,
"name": config.name,
"model_type": config.model_type.value,
"status": TrainingStatus.QUEUED.value,
"progress": 0,
"current_epoch": 0,
"total_epochs": config.epochs,
"loss": 1.0,
"val_loss": 1.0,
"learning_rate": config.learning_rate,
"documents_processed": 0,
"total_documents": len(config.bundeslaender) * 50, # Estimate
"started_at": None,
"estimated_completion": None,
"completed_at": None,
"error_message": None,
"metrics": {
"precision": 0.0,
"recall": 0.0,
"f1_score": 0.0,
"accuracy": 0.0,
"loss_history": [],
"val_loss_history": [],
},
"config": config.dict(),
}
_state.jobs[job_id] = job
_state.active_job_id = job_id
# Start training in background
background_tasks.add_task(simulate_training_progress, job_id)
return {"id": job_id, "status": "queued", "message": "Training job created"}
@router.post("/jobs/{job_id}/pause", response_model=dict)
async def pause_training_job(job_id: str):
"""Pause a running training job."""
if job_id not in _state.jobs:
raise HTTPException(status_code=404, detail="Job not found")
job = _state.jobs[job_id]
if job["status"] != TrainingStatus.TRAINING.value:
raise HTTPException(status_code=400, detail="Job is not running")
job["status"] = TrainingStatus.PAUSED.value
return {"success": True, "message": "Training paused"}
@router.post("/jobs/{job_id}/resume", response_model=dict)
async def resume_training_job(job_id: str, background_tasks: BackgroundTasks):
"""Resume a paused training job."""
if job_id not in _state.jobs:
raise HTTPException(status_code=404, detail="Job not found")
job = _state.jobs[job_id]
if job["status"] != TrainingStatus.PAUSED.value:
raise HTTPException(status_code=400, detail="Job is not paused")
job["status"] = TrainingStatus.TRAINING.value
_state.active_job_id = job_id
background_tasks.add_task(simulate_training_progress, job_id)
return {"success": True, "message": "Training resumed"}
@router.post("/jobs/{job_id}/cancel", response_model=dict)
async def cancel_training_job(job_id: str):
"""Cancel a training job."""
if job_id not in _state.jobs:
raise HTTPException(status_code=404, detail="Job not found")
job = _state.jobs[job_id]
job["status"] = TrainingStatus.CANCELLED.value
job["completed_at"] = datetime.now().isoformat()
if _state.active_job_id == job_id:
_state.active_job_id = None
return {"success": True, "message": "Training cancelled"}
@router.delete("/jobs/{job_id}", response_model=dict)
async def delete_training_job(job_id: str):
"""Delete a training job."""
if job_id not in _state.jobs:
raise HTTPException(status_code=404, detail="Job not found")
job = _state.jobs[job_id]
if job["status"] == TrainingStatus.TRAINING.value:
raise HTTPException(status_code=400, detail="Cannot delete running job")
del _state.jobs[job_id]
return {"success": True, "message": "Job deleted"}
# ============================================================================
# MODEL VERSIONS
# ============================================================================
@router.get("/models", response_model=List[dict])
async def list_model_versions():
"""Get all trained model versions."""
return list(_state.model_versions.values())
@router.get("/models/{version_id}", response_model=dict)
async def get_model_version(version_id: str):
"""Get details for a specific model version."""
if version_id not in _state.model_versions:
raise HTTPException(status_code=404, detail="Model version not found")
return _state.model_versions[version_id]
@router.post("/models/{version_id}/activate", response_model=dict)
async def activate_model_version(version_id: str):
"""Set a model version as active."""
if version_id not in _state.model_versions:
raise HTTPException(status_code=404, detail="Model version not found")
# Deactivate all other versions of same type
model = _state.model_versions[version_id]
for v in _state.model_versions.values():
if v["model_type"] == model["model_type"]:
v["is_active"] = False
model["is_active"] = True
return {"success": True, "message": "Model activated"}
@router.delete("/models/{version_id}", response_model=dict)
async def delete_model_version(version_id: str):
"""Delete a model version."""
if version_id not in _state.model_versions:
raise HTTPException(status_code=404, detail="Model version not found")
model = _state.model_versions[version_id]
if model["is_active"]:
raise HTTPException(status_code=400, detail="Cannot delete active model")
del _state.model_versions[version_id]
return {"success": True, "message": "Model deleted"}
# ============================================================================
# DATASET STATS & STATUS
# ============================================================================
@router.get("/dataset/stats", response_model=dict)
async def get_dataset_stats():
"""Get statistics about the training dataset."""
from metrics_db import get_zeugnis_stats
zeugnis_stats = await get_zeugnis_stats()
return {
"total_documents": zeugnis_stats.get("total_documents", 0),
"total_chunks": zeugnis_stats.get("total_documents", 0) * 12,
"training_allowed": zeugnis_stats.get("training_allowed_documents", 0),
"by_bundesland": {
bl["bundesland"]: bl.get("doc_count", 0)
for bl in zeugnis_stats.get("per_bundesland", [])
},
"by_doc_type": {
"verordnung": 150,
"schulordnung": 80,
"handreichung": 45,
"erlass": 30,
},
}
@router.get("/status", response_model=dict)
async def get_training_status():
"""Get overall training system status."""
active_job = None
if _state.active_job_id and _state.active_job_id in _state.jobs:
active_job = _state.jobs[_state.active_job_id]
return {
"is_training": _state.active_job_id is not None and active_job is not None and
active_job["status"] == TrainingStatus.TRAINING.value,
"active_job_id": _state.active_job_id,
"total_jobs": len(_state.jobs),
"completed_jobs": sum(
1 for j in _state.jobs.values()
if j["status"] == TrainingStatus.COMPLETED.value
),
"failed_jobs": sum(
1 for j in _state.jobs.values()
if j["status"] == TrainingStatus.FAILED.value
),
"model_versions": len(_state.model_versions),
"active_models": sum(1 for m in _state.model_versions.values() if m["is_active"]),
}
# ============================================================================
# SSE ENDPOINTS
# ============================================================================
@router.get("/metrics/stream")
async def stream_training_metrics(job_id: str, request: Request):
"""
SSE endpoint for streaming training metrics.
Streams real-time training progress for a specific job.
"""
if job_id not in _state.jobs:
raise HTTPException(status_code=404, detail="Job not found")
return StreamingResponse(
training_metrics_generator(job_id, request),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
@router.get("/ocr/stream")
async def stream_batch_ocr(images_count: int, request: Request):
"""
SSE endpoint for streaming batch OCR progress.
Simulates batch OCR processing with progress updates.
"""
if images_count < 1 or images_count > 100:
raise HTTPException(status_code=400, detail="images_count must be between 1 and 100")
return StreamingResponse(
batch_ocr_progress_generator(images_count, request),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
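A sketch (base URL and job id assumed) of consuming the metrics stream from Python instead of a browser EventSource:

# Hypothetical SSE consumer for GET /api/v1/admin/training/metrics/stream.
import json
import httpx

def follow_training(job_id: str) -> None:
    url = "http://localhost:8000/api/v1/admin/training/metrics/stream"
    with httpx.stream("GET", url, params={"job_id": job_id}, timeout=None) as resp:
        for line in resp.iter_lines():
            if not line.startswith("data: "):
                continue                          # skip blank separator lines
            event = json.loads(line[len("data: "):])
            if "error" in event:
                break
            print(event["status"], event["progress"], event["metrics"]["loss"])
            if event["status"] in ("completed", "failed", "cancelled"):
                break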
@@ -0,0 +1,190 @@
"""
Training API — simulation helper and SSE generators.
"""
import json
import uuid
import asyncio
from datetime import datetime, timedelta
from training_models import TrainingStatus, _state
async def simulate_training_progress(job_id: str):
"""Simulate training progress (replace with actual training logic)."""
if job_id not in _state.jobs:
return
job = _state.jobs[job_id]
job["status"] = TrainingStatus.TRAINING.value
job["started_at"] = datetime.now().isoformat()
total_steps = job["total_epochs"] * 100 # Simulate 100 steps per epoch
current_step = 0
while current_step < total_steps and job["status"] == TrainingStatus.TRAINING.value:
# Update progress
progress = (current_step / total_steps) * 100
current_epoch = current_step // 100 + 1
# Simulate decreasing loss
base_loss = 0.8 * (1 - progress / 100) + 0.1
loss = base_loss + (0.05 * (0.5 - (current_step % 100) / 100))
val_loss = loss * 1.1
# Update job state
job["progress"] = progress
job["current_epoch"] = min(current_epoch, job["total_epochs"])
job["loss"] = round(loss, 4)
job["val_loss"] = round(val_loss, 4)
job["documents_processed"] = int((progress / 100) * job["total_documents"])
# Update metrics
job["metrics"]["loss_history"].append(round(loss, 4))
job["metrics"]["val_loss_history"].append(round(val_loss, 4))
job["metrics"]["precision"] = round(0.5 + (progress / 200), 3)
job["metrics"]["recall"] = round(0.45 + (progress / 200), 3)
job["metrics"]["f1_score"] = round(0.47 + (progress / 200), 3)
job["metrics"]["accuracy"] = round(0.6 + (progress / 250), 3)
# Keep only last 50 history points
if len(job["metrics"]["loss_history"]) > 50:
job["metrics"]["loss_history"] = job["metrics"]["loss_history"][-50:]
job["metrics"]["val_loss_history"] = job["metrics"]["val_loss_history"][-50:]
# Estimate completion
if progress > 0:
elapsed = (datetime.now() - datetime.fromisoformat(job["started_at"])).total_seconds()
remaining = (elapsed / progress) * (100 - progress)
job["estimated_completion"] = (datetime.now() + timedelta(seconds=remaining)).isoformat()
current_step += 1
await asyncio.sleep(0.5) # Simulate work
# Mark as completed
if job["status"] == TrainingStatus.TRAINING.value:
job["status"] = TrainingStatus.COMPLETED.value
job["progress"] = 100
job["completed_at"] = datetime.now().isoformat()
# Create model version
version_id = str(uuid.uuid4())
_state.model_versions[version_id] = {
"id": version_id,
"job_id": job_id,
"version": f"v{len(_state.model_versions) + 1}.0",
"model_type": job["model_type"],
"created_at": datetime.now().isoformat(),
"metrics": job["metrics"],
"is_active": True,
"size_mb": 245.7,
"bundeslaender": job["config"]["bundeslaender"],
}
_state.active_job_id = None
async def training_metrics_generator(job_id: str, request):
"""
SSE generator for streaming training metrics.
Yields JSON-encoded training status updates every 500ms.
"""
while True:
# Check if client disconnected
if await request.is_disconnected():
break
# Get job status
if job_id not in _state.jobs:
yield f"data: {json.dumps({'error': 'Job not found'})}\n\n"
break
job = _state.jobs[job_id]
# Build metrics response
metrics_data = {
"job_id": job["id"],
"status": job["status"],
"progress": job["progress"],
"current_epoch": job["current_epoch"],
"total_epochs": job["total_epochs"],
"current_step": int(job["progress"] * job["total_epochs"]),
"total_steps": job["total_epochs"] * 100,
"elapsed_time_ms": 0,
"estimated_remaining_ms": 0,
"metrics": {
"loss": job["loss"],
"val_loss": job["val_loss"],
"accuracy": job["metrics"]["accuracy"],
"learning_rate": job["learning_rate"]
},
"history": [
{
"epoch": i + 1,
"step": (i + 1) * 10,
"loss": loss,
"val_loss": job["metrics"]["val_loss_history"][i] if i < len(job["metrics"]["val_loss_history"]) else None,
"learning_rate": job["learning_rate"],
"timestamp": 0
}
for i, loss in enumerate(job["metrics"]["loss_history"][-50:])
]
}
# Calculate elapsed time
if job["started_at"]:
started = datetime.fromisoformat(job["started_at"])
metrics_data["elapsed_time_ms"] = int((datetime.now() - started).total_seconds() * 1000)
# Calculate remaining time
if job["estimated_completion"]:
estimated = datetime.fromisoformat(job["estimated_completion"])
metrics_data["estimated_remaining_ms"] = max(0, int((estimated - datetime.now()).total_seconds() * 1000))
# Send SSE event
yield f"data: {json.dumps(metrics_data)}\n\n"
# Check if job completed
if job["status"] in [TrainingStatus.COMPLETED.value, TrainingStatus.FAILED.value, TrainingStatus.CANCELLED.value]:
break
# Wait before next update
await asyncio.sleep(0.5)
async def batch_ocr_progress_generator(images_count: int, request):
"""
SSE generator for batch OCR progress simulation.
In production, this would integrate with actual OCR processing.
"""
import random
for i in range(images_count):
# Check if client disconnected
if await request.is_disconnected():
break
# Simulate processing time
await asyncio.sleep(random.uniform(0.3, 0.8))
progress_data = {
"type": "progress",
"current": i + 1,
"total": images_count,
"progress_percent": ((i + 1) / images_count) * 100,
"elapsed_ms": (i + 1) * 500,
"estimated_remaining_ms": (images_count - i - 1) * 500,
"result": {
"text": f"Sample recognized text for image {i + 1}",
"confidence": round(random.uniform(0.7, 0.98), 2),
"processing_time_ms": random.randint(200, 600),
"from_cache": random.random() < 0.2
}
}
yield f"data: {json.dumps(progress_data)}\n\n"
# Send completion event
yield f"data: {json.dumps({'type': 'complete', 'total_time_ms': images_count * 500, 'processed_count': images_count})}\n\n"
+105
View File
@@ -0,0 +1,105 @@
"""
Zeugnis Crawler - Start/stop/status control functions.
"""
import asyncio
from typing import Optional, Dict, Any
from zeugnis_worker import ZeugnisCrawler, get_crawler_state
_crawler_instance: Optional[ZeugnisCrawler] = None
_crawler_task: Optional[asyncio.Task] = None
async def start_crawler(bundesland: Optional[str] = None, source_id: Optional[str] = None) -> bool:
"""Start the crawler."""
global _crawler_instance, _crawler_task
state = get_crawler_state()
if state.is_running:
return False
state.is_running = True
state.documents_crawled_today = 0
state.documents_indexed_today = 0
state.errors_today = 0
_crawler_instance = ZeugnisCrawler()
await _crawler_instance.init()
async def run_crawler():
try:
from metrics_db import get_pool
pool = await get_pool()
if pool:
async with pool.acquire() as conn:
# Get sources to crawl
if source_id:
sources = await conn.fetch(
"SELECT id, bundesland FROM zeugnis_sources WHERE id = $1",
source_id
)
elif bundesland:
sources = await conn.fetch(
"SELECT id, bundesland FROM zeugnis_sources WHERE bundesland = $1",
bundesland
)
else:
sources = await conn.fetch(
"SELECT id, bundesland FROM zeugnis_sources ORDER BY bundesland"
)
for source in sources:
if not state.is_running:
break
await _crawler_instance.crawl_source(source["id"])
except Exception as e:
print(f"Crawler error: {e}")
finally:
state.is_running = False
if _crawler_instance:
await _crawler_instance.close()
_crawler_task = asyncio.create_task(run_crawler())
return True
async def stop_crawler() -> bool:
"""Stop the crawler."""
global _crawler_task
state = get_crawler_state()
if not state.is_running:
return False
state.is_running = False
if _crawler_task:
_crawler_task.cancel()
try:
await _crawler_task
except asyncio.CancelledError:
pass
return True
def get_crawler_status() -> Dict[str, Any]:
"""Get current crawler status."""
state = get_crawler_state()
return {
"is_running": state.is_running,
"current_source": state.current_source_id,
"current_bundesland": state.current_bundesland,
"queue_length": len(state.queue),
"documents_crawled_today": state.documents_crawled_today,
"documents_indexed_today": state.documents_indexed_today,
"errors_today": state.errors_today,
"last_activity": state.last_activity.isoformat() if state.last_activity else None,
}
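A sketch of a one-off run using the control functions above (Bundesland code illustrative):

import asyncio
from zeugnis_control import start_crawler, stop_crawler, get_crawler_status

async def main():
    if not await start_crawler(bundesland="NI"):
        print("crawler already running")
        return
    await asyncio.sleep(10)          # let it crawl for a bit
    print(get_crawler_status())      # is_running, documents_crawled_today, ...
    await stop_crawler()

asyncio.run(main())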
+20 -670
View File
@@ -1,676 +1,26 @@
"""
Zeugnis Rights-Aware Crawler
Crawls official government documents about school certificates (Zeugnisse)
from all 16 German federal states. Only indexes documents where AI training
is legally permitted.
Barrel re-export: all public symbols for backward compatibility.
"""
import asyncio
import hashlib
import os
import re
import uuid
from datetime import datetime
from typing import Optional, List, Dict, Any, Tuple
from dataclasses import dataclass, field
import httpx
# Local imports
from zeugnis_models import (
CrawlStatus, LicenseType, DocType, EventType,
BUNDESLAENDER, TRAINING_PERMISSIONS,
    generate_id, get_training_allowed, get_bundesland_name,
)
from zeugnis_text import ( # noqa: F401
extract_text_from_pdf,
extract_text_from_html,
chunk_text,
compute_hash,
)
from zeugnis_storage import ( # noqa: F401
generate_embeddings,
upload_to_minio,
index_in_qdrant,
)
from zeugnis_worker import ( # noqa: F401
CrawlerState,
ZeugnisCrawler,
)
from zeugnis_control import ( # noqa: F401
start_crawler,
stop_crawler,
get_crawler_status,
)
# =============================================================================
# Configuration
# =============================================================================
QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
MINIO_ENDPOINT = os.getenv("MINIO_ENDPOINT", "localhost:9000")
MINIO_ACCESS_KEY = os.getenv("MINIO_ACCESS_KEY", "test-access-key")
MINIO_SECRET_KEY = os.getenv("MINIO_SECRET_KEY", "test-secret-key")
MINIO_BUCKET = os.getenv("MINIO_BUCKET", "breakpilot-rag")
EMBEDDING_BACKEND = os.getenv("EMBEDDING_BACKEND", "local")
ZEUGNIS_COLLECTION = "bp_zeugnis"
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200
MAX_RETRIES = 3
RETRY_DELAY = 5 # seconds
REQUEST_TIMEOUT = 30 # seconds
USER_AGENT = "BreakPilot-Zeugnis-Crawler/1.0 (Educational Research)"
# =============================================================================
# Crawler State
# =============================================================================
@dataclass
class CrawlerState:
"""Global crawler state."""
is_running: bool = False
current_source_id: Optional[str] = None
current_bundesland: Optional[str] = None
queue: List[Dict] = field(default_factory=list)
documents_crawled_today: int = 0
documents_indexed_today: int = 0
errors_today: int = 0
last_activity: Optional[datetime] = None
_crawler_state = CrawlerState()
# =============================================================================
# Text Extraction
# =============================================================================
def extract_text_from_pdf(content: bytes) -> str:
"""Extract text from PDF bytes."""
try:
from PyPDF2 import PdfReader
import io
reader = PdfReader(io.BytesIO(content))
text_parts = []
for page in reader.pages:
text = page.extract_text()
if text:
text_parts.append(text)
return "\n\n".join(text_parts)
except Exception as e:
print(f"PDF extraction failed: {e}")
return ""
def extract_text_from_html(content: bytes, encoding: str = "utf-8") -> str:
"""Extract text from HTML bytes."""
try:
from bs4 import BeautifulSoup
html = content.decode(encoding, errors="replace")
soup = BeautifulSoup(html, "html.parser")
# Remove script and style elements
for element in soup(["script", "style", "nav", "header", "footer"]):
element.decompose()
# Get text
text = soup.get_text(separator="\n", strip=True)
# Clean up whitespace
lines = [line.strip() for line in text.splitlines() if line.strip()]
return "\n".join(lines)
except Exception as e:
print(f"HTML extraction failed: {e}")
return ""
def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[str]:
"""Split text into overlapping chunks."""
if not text:
return []
chunks = []
separators = ["\n\n", "\n", ". ", " "]
def split_recursive(text: str, sep_index: int = 0) -> List[str]:
if len(text) <= chunk_size:
return [text] if text.strip() else []
if sep_index >= len(separators):
# Force split at chunk_size
result = []
for i in range(0, len(text), chunk_size - overlap):
chunk = text[i:i + chunk_size]
if chunk.strip():
result.append(chunk)
return result
sep = separators[sep_index]
parts = text.split(sep)
result = []
current = ""
for part in parts:
if len(current) + len(sep) + len(part) <= chunk_size:
current = current + sep + part if current else part
else:
if current.strip():
result.extend(split_recursive(current, sep_index + 1) if len(current) > chunk_size else [current])
current = part
if current.strip():
result.extend(split_recursive(current, sep_index + 1) if len(current) > chunk_size else [current])
return result
chunks = split_recursive(text)
# Add overlap
if overlap > 0 and len(chunks) > 1:
overlapped = []
for i, chunk in enumerate(chunks):
if i > 0:
# Add end of previous chunk
prev_end = chunks[i - 1][-overlap:]
chunk = prev_end + chunk
overlapped.append(chunk)
chunks = overlapped
return chunks
def compute_hash(content: bytes) -> str:
"""Compute SHA-256 hash of content."""
return hashlib.sha256(content).hexdigest()
# =============================================================================
# Embedding Generation
# =============================================================================
_embedding_model = None
def get_embedding_model():
"""Get or initialize embedding model."""
global _embedding_model
if _embedding_model is None and EMBEDDING_BACKEND == "local":
try:
from sentence_transformers import SentenceTransformer
_embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
print("Loaded local embedding model: all-MiniLM-L6-v2")
except ImportError:
print("Warning: sentence-transformers not installed")
return _embedding_model
async def generate_embeddings(texts: List[str]) -> List[List[float]]:
"""Generate embeddings for a list of texts."""
if not texts:
return []
if EMBEDDING_BACKEND == "local":
model = get_embedding_model()
if model:
embeddings = model.encode(texts, show_progress_bar=False)
return [emb.tolist() for emb in embeddings]
return []
elif EMBEDDING_BACKEND == "openai":
import openai
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
print("Warning: OPENAI_API_KEY not set")
return []
client = openai.AsyncOpenAI(api_key=api_key)
response = await client.embeddings.create(
input=texts,
model="text-embedding-3-small"
)
return [item.embedding for item in response.data]
return []
# =============================================================================
# MinIO Storage
# =============================================================================
async def upload_to_minio(
content: bytes,
bundesland: str,
filename: str,
content_type: str = "application/pdf",
year: Optional[int] = None,
) -> Optional[str]:
"""Upload document to MinIO."""
try:
from minio import Minio
client = Minio(
MINIO_ENDPOINT,
access_key=MINIO_ACCESS_KEY,
secret_key=MINIO_SECRET_KEY,
secure=os.getenv("MINIO_SECURE", "false").lower() == "true"
)
# Ensure bucket exists
if not client.bucket_exists(MINIO_BUCKET):
client.make_bucket(MINIO_BUCKET)
# Build path
year_str = str(year) if year else str(datetime.now().year)
object_name = f"landes-daten/{bundesland}/zeugnis/{year_str}/{filename}"
# Upload
import io
client.put_object(
MINIO_BUCKET,
object_name,
io.BytesIO(content),
len(content),
content_type=content_type,
)
return object_name
except Exception as e:
print(f"MinIO upload failed: {e}")
return None
# =============================================================================
# Qdrant Indexing
# =============================================================================
async def index_in_qdrant(
doc_id: str,
chunks: List[str],
embeddings: List[List[float]],
metadata: Dict[str, Any],
) -> int:
"""Index document chunks in Qdrant."""
try:
from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, Distance, PointStruct
client = QdrantClient(url=QDRANT_URL)
# Ensure collection exists
collections = client.get_collections().collections
if not any(c.name == ZEUGNIS_COLLECTION for c in collections):
vector_size = len(embeddings[0]) if embeddings else 384
client.create_collection(
collection_name=ZEUGNIS_COLLECTION,
vectors_config=VectorParams(
size=vector_size,
distance=Distance.COSINE,
),
)
print(f"Created Qdrant collection: {ZEUGNIS_COLLECTION}")
# Create points
points = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
point_id = str(uuid.uuid4())
points.append(PointStruct(
id=point_id,
vector=embedding,
payload={
"document_id": doc_id,
"chunk_index": i,
"chunk_text": chunk[:500], # Store first 500 chars for preview
"bundesland": metadata.get("bundesland"),
"doc_type": metadata.get("doc_type"),
"title": metadata.get("title"),
"source_url": metadata.get("url"),
"training_allowed": metadata.get("training_allowed", False),
"indexed_at": datetime.now().isoformat(),
}
))
# Upsert
if points:
client.upsert(
collection_name=ZEUGNIS_COLLECTION,
points=points,
)
return len(points)
except Exception as e:
print(f"Qdrant indexing failed: {e}")
return 0
# =============================================================================
# Crawler Worker
# =============================================================================
class ZeugnisCrawler:
"""Rights-aware crawler for zeugnis documents."""
def __init__(self):
self.http_client: Optional[httpx.AsyncClient] = None
self.db_pool = None
async def init(self):
"""Initialize crawler resources."""
self.http_client = httpx.AsyncClient(
timeout=REQUEST_TIMEOUT,
follow_redirects=True,
headers={"User-Agent": USER_AGENT},
)
# Initialize database connection
try:
from metrics_db import get_pool
self.db_pool = await get_pool()
except Exception as e:
print(f"Failed to get database pool: {e}")
async def close(self):
"""Close crawler resources."""
if self.http_client:
await self.http_client.aclose()
async def fetch_url(self, url: str) -> Tuple[Optional[bytes], Optional[str]]:
"""Fetch URL with retry logic."""
for attempt in range(MAX_RETRIES):
try:
response = await self.http_client.get(url)
response.raise_for_status()
content_type = response.headers.get("content-type", "")
return response.content, content_type
except httpx.HTTPStatusError as e:
print(f"HTTP error {e.response.status_code} for {url}")
if e.response.status_code == 404:
return None, None
except Exception as e:
print(f"Attempt {attempt + 1}/{MAX_RETRIES} failed for {url}: {e}")
if attempt < MAX_RETRIES - 1:
await asyncio.sleep(RETRY_DELAY * (attempt + 1))
return None, None
async def crawl_seed_url(
self,
seed_url_id: str,
url: str,
bundesland: str,
doc_type: str,
training_allowed: bool,
) -> Dict[str, Any]:
"""Crawl a single seed URL."""
global _crawler_state
result = {
"seed_url_id": seed_url_id,
"url": url,
"success": False,
"document_id": None,
"indexed": False,
"error": None,
}
try:
# Fetch content
content, content_type = await self.fetch_url(url)
if not content:
result["error"] = "Failed to fetch URL"
return result
# Determine file type
is_pdf = "pdf" in content_type.lower() or url.lower().endswith(".pdf")
# Extract text
if is_pdf:
text = extract_text_from_pdf(content)
filename = url.split("/")[-1] or f"document_{seed_url_id}.pdf"
else:
text = extract_text_from_html(content)
filename = f"document_{seed_url_id}.html"
if not text:
result["error"] = "No text extracted"
return result
# Compute hash for versioning
content_hash = compute_hash(content)
# Upload to MinIO
minio_path = await upload_to_minio(
content,
bundesland,
filename,
content_type=content_type or "application/octet-stream",
)
# Generate document ID
doc_id = generate_id()
# Store document in database
if self.db_pool:
async with self.db_pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO zeugnis_documents
(id, seed_url_id, title, url, content_hash, minio_path,
training_allowed, file_size, content_type)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
ON CONFLICT DO NOTHING
""",
doc_id, seed_url_id, filename, url, content_hash,
minio_path, training_allowed, len(content), content_type
)
result["document_id"] = doc_id
result["success"] = True
_crawler_state.documents_crawled_today += 1
# Only index if training is allowed
if training_allowed:
chunks = chunk_text(text)
if chunks:
embeddings = await generate_embeddings(chunks)
if embeddings:
indexed_count = await index_in_qdrant(
doc_id,
chunks,
embeddings,
{
"bundesland": bundesland,
"doc_type": doc_type,
"title": filename,
"url": url,
"training_allowed": True,
}
)
if indexed_count > 0:
result["indexed"] = True
_crawler_state.documents_indexed_today += 1
# Update database
if self.db_pool:
async with self.db_pool.acquire() as conn:
await conn.execute(
"UPDATE zeugnis_documents SET indexed_in_qdrant = true WHERE id = $1",
doc_id
)
else:
result["indexed"] = False
result["error"] = "Training not allowed for this source"
_crawler_state.last_activity = datetime.now()
except Exception as e:
result["error"] = str(e)
_crawler_state.errors_today += 1
return result
async def crawl_source(self, source_id: str) -> Dict[str, Any]:
"""Crawl all seed URLs for a source."""
global _crawler_state
result = {
"source_id": source_id,
"documents_found": 0,
"documents_indexed": 0,
"errors": [],
"started_at": datetime.now(),
"completed_at": None,
}
if not self.db_pool:
result["errors"].append("Database not available")
return result
try:
async with self.db_pool.acquire() as conn:
# Get source info
source = await conn.fetchrow(
"SELECT * FROM zeugnis_sources WHERE id = $1",
source_id
)
if not source:
result["errors"].append(f"Source not found: {source_id}")
return result
bundesland = source["bundesland"]
training_allowed = source["training_allowed"]
_crawler_state.current_source_id = source_id
_crawler_state.current_bundesland = bundesland
# Get seed URLs
seed_urls = await conn.fetch(
"SELECT * FROM zeugnis_seed_urls WHERE source_id = $1 AND status != 'completed'",
source_id
)
for seed_url in seed_urls:
# Update status to running
await conn.execute(
"UPDATE zeugnis_seed_urls SET status = 'running' WHERE id = $1",
seed_url["id"]
)
# Crawl
crawl_result = await self.crawl_seed_url(
seed_url["id"],
seed_url["url"],
bundesland,
seed_url["doc_type"],
training_allowed,
)
# Update status
if crawl_result["success"]:
result["documents_found"] += 1
if crawl_result["indexed"]:
result["documents_indexed"] += 1
await conn.execute(
"UPDATE zeugnis_seed_urls SET status = 'completed', last_crawled = NOW() WHERE id = $1",
seed_url["id"]
)
else:
result["errors"].append(f"{seed_url['url']}: {crawl_result['error']}")
await conn.execute(
"UPDATE zeugnis_seed_urls SET status = 'failed', error_message = $2 WHERE id = $1",
seed_url["id"], crawl_result["error"]
)
# Small delay between requests
await asyncio.sleep(1)
except Exception as e:
result["errors"].append(str(e))
finally:
result["completed_at"] = datetime.now()
_crawler_state.current_source_id = None
_crawler_state.current_bundesland = None
return result
# =============================================================================
# Crawler Control Functions
# =============================================================================
_crawler_instance: Optional[ZeugnisCrawler] = None
_crawler_task: Optional[asyncio.Task] = None
async def start_crawler(bundesland: Optional[str] = None, source_id: Optional[str] = None) -> bool:
"""Start the crawler."""
global _crawler_state, _crawler_instance, _crawler_task
if _crawler_state.is_running:
return False
_crawler_state.is_running = True
_crawler_state.documents_crawled_today = 0
_crawler_state.documents_indexed_today = 0
_crawler_state.errors_today = 0
_crawler_instance = ZeugnisCrawler()
await _crawler_instance.init()
async def run_crawler():
try:
from metrics_db import get_pool
pool = await get_pool()
if pool:
async with pool.acquire() as conn:
# Get sources to crawl
if source_id:
sources = await conn.fetch(
"SELECT id, bundesland FROM zeugnis_sources WHERE id = $1",
source_id
)
elif bundesland:
sources = await conn.fetch(
"SELECT id, bundesland FROM zeugnis_sources WHERE bundesland = $1",
bundesland
)
else:
sources = await conn.fetch(
"SELECT id, bundesland FROM zeugnis_sources ORDER BY bundesland"
)
for source in sources:
if not _crawler_state.is_running:
break
await _crawler_instance.crawl_source(source["id"])
except Exception as e:
print(f"Crawler error: {e}")
finally:
_crawler_state.is_running = False
if _crawler_instance:
await _crawler_instance.close()
_crawler_task = asyncio.create_task(run_crawler())
return True
async def stop_crawler() -> bool:
"""Stop the crawler."""
global _crawler_state, _crawler_task
if not _crawler_state.is_running:
return False
_crawler_state.is_running = False
if _crawler_task:
_crawler_task.cancel()
try:
await _crawler_task
except asyncio.CancelledError:
pass
return True
def get_crawler_status() -> Dict[str, Any]:
"""Get current crawler status."""
global _crawler_state
return {
"is_running": _crawler_state.is_running,
"current_source": _crawler_state.current_source_id,
"current_bundesland": _crawler_state.current_bundesland,
"queue_length": len(_crawler_state.queue),
"documents_crawled_today": _crawler_state.documents_crawled_today,
"documents_indexed_today": _crawler_state.documents_indexed_today,
"errors_today": _crawler_state.errors_today,
"last_activity": _crawler_state.last_activity.isoformat() if _crawler_state.last_activity else None,
}
+180
View File
@@ -0,0 +1,180 @@
"""
Zeugnis Crawler - Embedding generation, MinIO upload, and Qdrant indexing.
"""
import io
import os
import uuid
from datetime import datetime
from typing import Optional, List, Dict, Any
# =============================================================================
# Configuration
# =============================================================================
QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
MINIO_ENDPOINT = os.getenv("MINIO_ENDPOINT", "localhost:9000")
MINIO_ACCESS_KEY = os.getenv("MINIO_ACCESS_KEY", "test-access-key")
MINIO_SECRET_KEY = os.getenv("MINIO_SECRET_KEY", "test-secret-key")
MINIO_BUCKET = os.getenv("MINIO_BUCKET", "breakpilot-rag")
EMBEDDING_BACKEND = os.getenv("EMBEDDING_BACKEND", "local")
ZEUGNIS_COLLECTION = "bp_zeugnis"
# =============================================================================
# Embedding Generation
# =============================================================================
_embedding_model = None
def get_embedding_model():
"""Get or initialize embedding model."""
global _embedding_model
if _embedding_model is None and EMBEDDING_BACKEND == "local":
try:
from sentence_transformers import SentenceTransformer
_embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
print("Loaded local embedding model: all-MiniLM-L6-v2")
except ImportError:
print("Warning: sentence-transformers not installed")
return _embedding_model
async def generate_embeddings(texts: List[str]) -> List[List[float]]:
"""Generate embeddings for a list of texts."""
if not texts:
return []
if EMBEDDING_BACKEND == "local":
model = get_embedding_model()
if model:
embeddings = model.encode(texts, show_progress_bar=False)
return [emb.tolist() for emb in embeddings]
return []
elif EMBEDDING_BACKEND == "openai":
import openai
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
print("Warning: OPENAI_API_KEY not set")
return []
client = openai.AsyncOpenAI(api_key=api_key)
response = await client.embeddings.create(
input=texts,
model="text-embedding-3-small"
)
return [item.embedding for item in response.data]
return []
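

# Usage sketch (illustrative only): the backend is selected via EMBEDDING_BACKEND,
# and both paths return one vector per input text, so callers can zip texts and
# vectors regardless of backend. Vector size differs: 384 for the local MiniLM
# model, 1536 for OpenAI's text-embedding-3-small.
async def _example_embed() -> None:
    vectors = await generate_embeddings(["Zeugnis Notenschlüssel Bayern", "Versetzungsordnung"])
    for vec in vectors:
        print(len(vec))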
# =============================================================================
# MinIO Storage
# =============================================================================
async def upload_to_minio(
content: bytes,
bundesland: str,
filename: str,
content_type: str = "application/pdf",
year: Optional[int] = None,
) -> Optional[str]:
"""Upload document to MinIO."""
try:
from minio import Minio
client = Minio(
MINIO_ENDPOINT,
access_key=MINIO_ACCESS_KEY,
secret_key=MINIO_SECRET_KEY,
secure=os.getenv("MINIO_SECURE", "false").lower() == "true"
)
# Ensure bucket exists
if not client.bucket_exists(MINIO_BUCKET):
client.make_bucket(MINIO_BUCKET)
# Build path
year_str = str(year) if year else str(datetime.now().year)
object_name = f"landes-daten/{bundesland}/zeugnis/{year_str}/{filename}"
# Upload
client.put_object(
MINIO_BUCKET,
object_name,
io.BytesIO(content),
len(content),
content_type=content_type,
)
return object_name
except Exception as e:
print(f"MinIO upload failed: {e}")
return None
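

# Usage sketch (illustrative only; bundesland, filename, and year are made up).
# Note that the minio client is synchronous, so this call blocks the event loop
# for the duration of the PUT; moving the blocking put_object call into a worker
# thread (e.g. via asyncio.to_thread) is one option if that becomes a problem.
async def _example_upload(pdf_bytes: bytes) -> None:
    path = await upload_to_minio(pdf_bytes, "bayern", "zeugnisordnung_2024.pdf", year=2024)
    # Expected object key: "landes-daten/bayern/zeugnis/2024/zeugnisordnung_2024.pdf"
    # (None if MinIO is unreachable or the upload fails).
    print(path)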
# =============================================================================
# Qdrant Indexing
# =============================================================================
async def index_in_qdrant(
doc_id: str,
chunks: List[str],
embeddings: List[List[float]],
metadata: Dict[str, Any],
) -> int:
"""Index document chunks in Qdrant."""
try:
from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, Distance, PointStruct
client = QdrantClient(url=QDRANT_URL)
# Ensure collection exists
collections = client.get_collections().collections
if not any(c.name == ZEUGNIS_COLLECTION for c in collections):
vector_size = len(embeddings[0]) if embeddings else 384
client.create_collection(
collection_name=ZEUGNIS_COLLECTION,
vectors_config=VectorParams(
size=vector_size,
distance=Distance.COSINE,
),
)
print(f"Created Qdrant collection: {ZEUGNIS_COLLECTION}")
# Create points
points = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
point_id = str(uuid.uuid4())
points.append(PointStruct(
id=point_id,
vector=embedding,
payload={
"document_id": doc_id,
"chunk_index": i,
"chunk_text": chunk[:500], # Store first 500 chars for preview
"bundesland": metadata.get("bundesland"),
"doc_type": metadata.get("doc_type"),
"title": metadata.get("title"),
"source_url": metadata.get("url"),
"training_allowed": metadata.get("training_allowed", False),
"indexed_at": datetime.now().isoformat(),
}
))
# Upsert
if points:
client.upsert(
collection_name=ZEUGNIS_COLLECTION,
points=points,
)
return len(points)
except Exception as e:
print(f"Qdrant indexing failed: {e}")
return 0
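

# --- End-to-end sketch (illustrative only) --------------------------------------
# How the helpers above fit together for a single document; the metadata values
# and URL are placeholders, and chunking comes from the zeugnis_text module.
async def _example_index_document(pdf_bytes: bytes, text: str) -> int:
    from zeugnis_text import chunk_text

    chunks = chunk_text(text)
    if not chunks:
        return 0
    embeddings = await generate_embeddings(chunks)
    if not embeddings:
        return 0
    minio_path = await upload_to_minio(pdf_bytes, "berlin", "notenverordnung.pdf")
    print(f"Stored original at {minio_path}")
    return await index_in_qdrant(
        doc_id=str(uuid.uuid4()),
        chunks=chunks,
        embeddings=embeddings,
        metadata={
            "bundesland": "berlin",
            "doc_type": "verordnung",
            "title": "Notenverordnung",
            "url": "https://example.org/notenverordnung.pdf",
            "training_allowed": True,
        },
    )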
+110
View File
@@ -0,0 +1,110 @@
"""
Zeugnis Crawler - Text extraction, chunking, and hashing utilities.
"""
import hashlib
from typing import List
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200
def extract_text_from_pdf(content: bytes) -> str:
"""Extract text from PDF bytes."""
try:
from PyPDF2 import PdfReader
import io
reader = PdfReader(io.BytesIO(content))
text_parts = []
for page in reader.pages:
text = page.extract_text()
if text:
text_parts.append(text)
return "\n\n".join(text_parts)
except Exception as e:
print(f"PDF extraction failed: {e}")
return ""
def extract_text_from_html(content: bytes, encoding: str = "utf-8") -> str:
"""Extract text from HTML bytes."""
try:
from bs4 import BeautifulSoup
html = content.decode(encoding, errors="replace")
soup = BeautifulSoup(html, "html.parser")
# Remove script and style elements
for element in soup(["script", "style", "nav", "header", "footer"]):
element.decompose()
# Get text
text = soup.get_text(separator="\n", strip=True)
# Clean up whitespace
lines = [line.strip() for line in text.splitlines() if line.strip()]
return "\n".join(lines)
except Exception as e:
print(f"HTML extraction failed: {e}")
return ""
def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[str]:
"""Split text into overlapping chunks."""
if not text:
return []
separators = ["\n\n", "\n", ". ", " "]
def split_recursive(text: str, sep_index: int = 0) -> List[str]:
if len(text) <= chunk_size:
return [text] if text.strip() else []
if sep_index >= len(separators):
            # No separator left: hard-split at chunk_size. Overlap is applied
            # once, uniformly, in the post-processing pass below, so step by
            # chunk_size here rather than chunk_size - overlap to avoid
            # doubling the overlap for hard-split chunks.
            result = []
            for i in range(0, len(text), chunk_size):
                chunk = text[i:i + chunk_size]
                if chunk.strip():
                    result.append(chunk)
            return result
sep = separators[sep_index]
parts = text.split(sep)
result = []
current = ""
for part in parts:
if len(current) + len(sep) + len(part) <= chunk_size:
current = current + sep + part if current else part
else:
if current.strip():
result.extend(split_recursive(current, sep_index + 1) if len(current) > chunk_size else [current])
current = part
if current.strip():
result.extend(split_recursive(current, sep_index + 1) if len(current) > chunk_size else [current])
return result
chunks = split_recursive(text)
# Add overlap
if overlap > 0 and len(chunks) > 1:
overlapped = []
for i, chunk in enumerate(chunks):
if i > 0:
# Add end of previous chunk
prev_end = chunks[i - 1][-overlap:]
chunk = prev_end + chunk
overlapped.append(chunk)
chunks = overlapped
return chunks
def compute_hash(content: bytes) -> str:
"""Compute SHA-256 hash of content."""
return hashlib.sha256(content).hexdigest()
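

# --- Behavior sketch (illustrative only) ----------------------------------------
# chunk_text first tries paragraph and sentence boundaries and only falls back to
# hard cuts; after the overlap pass, each chunk beyond the first typically starts
# with the last CHUNK_OVERLAP characters of its predecessor.
if __name__ == "__main__":
    sample = "\n\n".join(f"Paragraph {i}: " + "x" * 400 for i in range(6))
    demo_chunks = chunk_text(sample)
    for i, c in enumerate(demo_chunks):
        print(i, len(c), repr(c[:30]))
    if len(demo_chunks) > 1:
        assert demo_chunks[1][:CHUNK_OVERLAP] == demo_chunks[0][-CHUNK_OVERLAP:]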
+313
View File
@@ -0,0 +1,313 @@
"""
Zeugnis Crawler - ZeugnisCrawler worker class and CrawlerState.
Crawls official government documents about school certificates from
all 16 German federal states. Only indexes documents where AI training
is legally permitted.
"""
import asyncio
from datetime import datetime
from typing import Optional, List, Dict, Any, Tuple
from dataclasses import dataclass, field
import httpx
from zeugnis_models import generate_id
from zeugnis_text import (
extract_text_from_pdf,
extract_text_from_html,
chunk_text,
compute_hash,
)
from zeugnis_storage import (
upload_to_minio,
generate_embeddings,
index_in_qdrant,
)
# =============================================================================
# Configuration
# =============================================================================
MAX_RETRIES = 3
RETRY_DELAY = 5 # seconds
REQUEST_TIMEOUT = 30 # seconds
USER_AGENT = "BreakPilot-Zeugnis-Crawler/1.0 (Educational Research)"
# =============================================================================
# Crawler State
# =============================================================================
@dataclass
class CrawlerState:
"""Global crawler state."""
is_running: bool = False
current_source_id: Optional[str] = None
current_bundesland: Optional[str] = None
queue: List[Dict] = field(default_factory=list)
documents_crawled_today: int = 0
documents_indexed_today: int = 0
errors_today: int = 0
last_activity: Optional[datetime] = None
_crawler_state = CrawlerState()
def get_crawler_state() -> CrawlerState:
"""Get the global crawler state."""
return _crawler_state
# =============================================================================
# Crawler Worker
# =============================================================================
class ZeugnisCrawler:
"""Rights-aware crawler for zeugnis documents."""
def __init__(self):
self.http_client: Optional[httpx.AsyncClient] = None
self.db_pool = None
async def init(self):
"""Initialize crawler resources."""
self.http_client = httpx.AsyncClient(
timeout=REQUEST_TIMEOUT,
follow_redirects=True,
headers={"User-Agent": USER_AGENT},
)
# Initialize database connection
try:
from metrics_db import get_pool
self.db_pool = await get_pool()
except Exception as e:
print(f"Failed to get database pool: {e}")
async def close(self):
"""Close crawler resources."""
if self.http_client:
await self.http_client.aclose()
    async def fetch_url(self, url: str) -> Tuple[Optional[bytes], Optional[str]]:
        """Fetch URL with retry logic and a growing delay between attempts."""
        for attempt in range(MAX_RETRIES):
            try:
                response = await self.http_client.get(url)
                response.raise_for_status()
                content_type = response.headers.get("content-type", "")
                return response.content, content_type
            except httpx.HTTPStatusError as e:
                print(f"HTTP error {e.response.status_code} for {url}")
                if e.response.status_code == 404:
                    # Not found is permanent; retrying will not help.
                    return None, None
            except Exception as e:
                print(f"Attempt {attempt + 1}/{MAX_RETRIES} failed for {url}: {e}")
            # Back off before the next attempt (applies to HTTP status errors
            # as well as transport errors).
            if attempt < MAX_RETRIES - 1:
                await asyncio.sleep(RETRY_DELAY * (attempt + 1))
        return None, None
async def crawl_seed_url(
self,
seed_url_id: str,
url: str,
bundesland: str,
doc_type: str,
training_allowed: bool,
) -> Dict[str, Any]:
"""Crawl a single seed URL."""
global _crawler_state
result = {
"seed_url_id": seed_url_id,
"url": url,
"success": False,
"document_id": None,
"indexed": False,
"error": None,
}
try:
# Fetch content
content, content_type = await self.fetch_url(url)
if not content:
result["error"] = "Failed to fetch URL"
return result
# Determine file type
is_pdf = "pdf" in content_type.lower() or url.lower().endswith(".pdf")
# Extract text
if is_pdf:
text = extract_text_from_pdf(content)
filename = url.split("/")[-1] or f"document_{seed_url_id}.pdf"
else:
text = extract_text_from_html(content)
filename = f"document_{seed_url_id}.html"
if not text:
result["error"] = "No text extracted"
return result
# Compute hash for versioning
content_hash = compute_hash(content)
# Upload to MinIO
minio_path = await upload_to_minio(
content,
bundesland,
filename,
content_type=content_type or "application/octet-stream",
)
# Generate document ID
doc_id = generate_id()
# Store document in database
if self.db_pool:
async with self.db_pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO zeugnis_documents
(id, seed_url_id, title, url, content_hash, minio_path,
training_allowed, file_size, content_type)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
ON CONFLICT DO NOTHING
""",
doc_id, seed_url_id, filename, url, content_hash,
minio_path, training_allowed, len(content), content_type
)
result["document_id"] = doc_id
result["success"] = True
_crawler_state.documents_crawled_today += 1
# Only index if training is allowed
if training_allowed:
chunks = chunk_text(text)
if chunks:
embeddings = await generate_embeddings(chunks)
if embeddings:
indexed_count = await index_in_qdrant(
doc_id,
chunks,
embeddings,
{
"bundesland": bundesland,
"doc_type": doc_type,
"title": filename,
"url": url,
"training_allowed": True,
}
)
if indexed_count > 0:
result["indexed"] = True
_crawler_state.documents_indexed_today += 1
# Update database
if self.db_pool:
async with self.db_pool.acquire() as conn:
await conn.execute(
"UPDATE zeugnis_documents SET indexed_in_qdrant = true WHERE id = $1",
doc_id
)
else:
result["indexed"] = False
result["error"] = "Training not allowed for this source"
_crawler_state.last_activity = datetime.now()
except Exception as e:
result["error"] = str(e)
_crawler_state.errors_today += 1
return result
async def crawl_source(self, source_id: str) -> Dict[str, Any]:
"""Crawl all seed URLs for a source."""
global _crawler_state
result = {
"source_id": source_id,
"documents_found": 0,
"documents_indexed": 0,
"errors": [],
"started_at": datetime.now(),
"completed_at": None,
}
if not self.db_pool:
result["errors"].append("Database not available")
return result
try:
async with self.db_pool.acquire() as conn:
# Get source info
source = await conn.fetchrow(
"SELECT * FROM zeugnis_sources WHERE id = $1",
source_id
)
if not source:
result["errors"].append(f"Source not found: {source_id}")
return result
bundesland = source["bundesland"]
training_allowed = source["training_allowed"]
_crawler_state.current_source_id = source_id
_crawler_state.current_bundesland = bundesland
# Get seed URLs
seed_urls = await conn.fetch(
"SELECT * FROM zeugnis_seed_urls WHERE source_id = $1 AND status != 'completed'",
source_id
)
for seed_url in seed_urls:
# Update status to running
await conn.execute(
"UPDATE zeugnis_seed_urls SET status = 'running' WHERE id = $1",
seed_url["id"]
)
# Crawl
crawl_result = await self.crawl_seed_url(
seed_url["id"],
seed_url["url"],
bundesland,
seed_url["doc_type"],
training_allowed,
)
# Update status
if crawl_result["success"]:
result["documents_found"] += 1
if crawl_result["indexed"]:
result["documents_indexed"] += 1
await conn.execute(
"UPDATE zeugnis_seed_urls SET status = 'completed', last_crawled = NOW() WHERE id = $1",
seed_url["id"]
)
else:
result["errors"].append(f"{seed_url['url']}: {crawl_result['error']}")
await conn.execute(
"UPDATE zeugnis_seed_urls SET status = 'failed', error_message = $2 WHERE id = $1",
seed_url["id"], crawl_result["error"]
)
# Small delay between requests
await asyncio.sleep(1)
except Exception as e:
result["errors"].append(str(e))
finally:
result["completed_at"] = datetime.now()
_crawler_state.current_source_id = None
_crawler_state.current_bundesland = None
return result
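

# --- Driver sketch (illustrative only) -------------------------------------------
# Runs one source end to end outside the API layer. The source id is a
# placeholder and the zeugnis_sources / zeugnis_seed_urls tables are assumed to
# be populated.
async def _example_run(source_id: str) -> None:
    crawler = ZeugnisCrawler()
    await crawler.init()
    try:
        result = await crawler.crawl_source(source_id)
        print(
            f"{result['documents_found']} fetched, "
            f"{result['documents_indexed']} indexed, "
            f"{len(result['errors'])} errors"
        )
    finally:
        await crawler.close()


if __name__ == "__main__":
    asyncio.run(_example_run("00000000-0000-0000-0000-000000000000"))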