diff --git a/backend-lehrer/abitur_docs_api.py b/backend-lehrer/abitur_docs_api.py index a761f48..bcf0190 100644 --- a/backend-lehrer/abitur_docs_api.py +++ b/backend-lehrer/abitur_docs_api.py @@ -15,18 +15,24 @@ Dateinamen-Schema (NiBiS Niedersachsen): import logging import uuid import os -import re import zipfile import tempfile from datetime import datetime -from typing import List, Dict, Any, Optional -from enum import Enum +from typing import List, Optional, Dict, Any from pathlib import Path -from dataclasses import dataclass from fastapi import APIRouter, HTTPException, UploadFile, File, Form, BackgroundTasks from fastapi.responses import FileResponse -from pydantic import BaseModel, Field + +from abitur_docs_models import ( + Bundesland, Fach, Niveau, DokumentTyp, VerarbeitungsStatus, + DokumentCreate, DokumentUpdate, DokumentResponse, ImportResult, + RecognitionResult, AbiturDokument, + FACH_LABELS, DOKUMENT_TYP_LABELS, + # Backwards-compatibility re-exports + AbiturFach, Anforderungsniveau, DocumentMetadata, AbiturDokumentCompat, +) +from abitur_docs_recognition import parse_nibis_filename, to_dokument_response logger = logging.getLogger(__name__) @@ -39,364 +45,19 @@ router = APIRouter( DOCS_DIR = Path("/tmp/abitur-docs") DOCS_DIR.mkdir(parents=True, exist_ok=True) - -# ============================================================================ -# Enums -# ============================================================================ - -class Bundesland(str, Enum): - """Bundesländer mit Zentralabitur.""" - NIEDERSACHSEN = "niedersachsen" - BAYERN = "bayern" - BADEN_WUERTTEMBERG = "baden_wuerttemberg" - NORDRHEIN_WESTFALEN = "nordrhein_westfalen" - HESSEN = "hessen" - SACHSEN = "sachsen" - THUERINGEN = "thueringen" - BERLIN = "berlin" - HAMBURG = "hamburg" - SCHLESWIG_HOLSTEIN = "schleswig_holstein" - BREMEN = "bremen" - BRANDENBURG = "brandenburg" - MECKLENBURG_VORPOMMERN = "mecklenburg_vorpommern" - SACHSEN_ANHALT = "sachsen_anhalt" - RHEINLAND_PFALZ = 
"rheinland_pfalz" - SAARLAND = "saarland" - - -class Fach(str, Enum): - """Abiturfächer.""" - DEUTSCH = "deutsch" - ENGLISCH = "englisch" - MATHEMATIK = "mathematik" - BIOLOGIE = "biologie" - CHEMIE = "chemie" - PHYSIK = "physik" - GESCHICHTE = "geschichte" - ERDKUNDE = "erdkunde" - POLITIK_WIRTSCHAFT = "politik_wirtschaft" - FRANZOESISCH = "franzoesisch" - SPANISCH = "spanisch" - LATEIN = "latein" - GRIECHISCH = "griechisch" - KUNST = "kunst" - MUSIK = "musik" - SPORT = "sport" - INFORMATIK = "informatik" - EV_RELIGION = "ev_religion" - KATH_RELIGION = "kath_religion" - WERTE_NORMEN = "werte_normen" - BRC = "brc" # Betriebswirtschaft mit Rechnungswesen - BVW = "bvw" # Volkswirtschaft - ERNAEHRUNG = "ernaehrung" - MECHATRONIK = "mechatronik" - GESUNDHEIT_PFLEGE = "gesundheit_pflege" - PAEDAGOGIK_PSYCHOLOGIE = "paedagogik_psychologie" - - -class Niveau(str, Enum): - """Anforderungsniveau.""" - EA = "eA" # Erhöhtes Anforderungsniveau (Leistungskurs) - GA = "gA" # Grundlegendes Anforderungsniveau (Grundkurs) - - -class DokumentTyp(str, Enum): - """Dokumenttyp.""" - AUFGABE = "aufgabe" - ERWARTUNGSHORIZONT = "erwartungshorizont" - DECKBLATT = "deckblatt" - MATERIAL = "material" - HOERVERSTEHEN = "hoerverstehen" # Für Sprachen - SPRACHMITTLUNG = "sprachmittlung" # Für Sprachen - BEWERTUNGSBOGEN = "bewertungsbogen" - - -class VerarbeitungsStatus(str, Enum): - """Status der Dokumentenverarbeitung.""" - PENDING = "pending" - PROCESSING = "processing" - RECOGNIZED = "recognized" # KI hat Metadaten erkannt - CONFIRMED = "confirmed" # Entwickler hat bestätigt - INDEXED = "indexed" # Im Vector Store - ERROR = "error" - - -# ============================================================================ -# Fach-Mapping für Dateinamen -# ============================================================================ - -FACH_NAME_MAPPING = { - "deutsch": Fach.DEUTSCH, - "englisch": Fach.ENGLISCH, - "mathe": Fach.MATHEMATIK, - "mathematik": Fach.MATHEMATIK, - "biologie": Fach.BIOLOGIE, 
- "bio": Fach.BIOLOGIE, - "chemie": Fach.CHEMIE, - "physik": Fach.PHYSIK, - "geschichte": Fach.GESCHICHTE, - "erdkunde": Fach.ERDKUNDE, - "geographie": Fach.ERDKUNDE, - "politikwirtschaft": Fach.POLITIK_WIRTSCHAFT, - "politik": Fach.POLITIK_WIRTSCHAFT, - "franzoesisch": Fach.FRANZOESISCH, - "franz": Fach.FRANZOESISCH, - "spanisch": Fach.SPANISCH, - "latein": Fach.LATEIN, - "griechisch": Fach.GRIECHISCH, - "kunst": Fach.KUNST, - "musik": Fach.MUSIK, - "sport": Fach.SPORT, - "informatik": Fach.INFORMATIK, - "evreligion": Fach.EV_RELIGION, - "kathreligion": Fach.KATH_RELIGION, - "wertenormen": Fach.WERTE_NORMEN, - "brc": Fach.BRC, - "bvw": Fach.BVW, - "ernaehrung": Fach.ERNAEHRUNG, - "mecha": Fach.MECHATRONIK, - "mechatronik": Fach.MECHATRONIK, - "technikmecha": Fach.MECHATRONIK, - "gespfl": Fach.GESUNDHEIT_PFLEGE, - "paedpsych": Fach.PAEDAGOGIK_PSYCHOLOGIE, -} - - -# ============================================================================ -# Pydantic Models -# ============================================================================ - -class DokumentCreate(BaseModel): - """Manuelles Erstellen eines Dokuments.""" - bundesland: Bundesland - fach: Fach - jahr: int = Field(ge=2000, le=2100) - niveau: Niveau - typ: DokumentTyp - aufgaben_nummer: Optional[str] = None # I, II, III, 1, 2, etc. 
- - -class DokumentUpdate(BaseModel): - """Update für erkannte Metadaten.""" - bundesland: Optional[Bundesland] = None - fach: Optional[Fach] = None - jahr: Optional[int] = None - niveau: Optional[Niveau] = None - typ: Optional[DokumentTyp] = None - aufgaben_nummer: Optional[str] = None - status: Optional[VerarbeitungsStatus] = None - - -class DokumentResponse(BaseModel): - """Response für ein Dokument.""" - id: str - dateiname: str - original_dateiname: str - bundesland: Bundesland - fach: Fach - jahr: int - niveau: Niveau - typ: DokumentTyp - aufgaben_nummer: Optional[str] - status: VerarbeitungsStatus - confidence: float # Erkennungs-Confidence - file_path: str - file_size: int - indexed: bool - vector_ids: List[str] - created_at: datetime - updated_at: datetime - - -class ImportResult(BaseModel): - """Ergebnis eines ZIP-Imports.""" - total_files: int - recognized: int - errors: int - documents: List[DokumentResponse] - - -class RecognitionResult(BaseModel): - """Ergebnis der Dokumentenerkennung.""" - success: bool - bundesland: Optional[Bundesland] - fach: Optional[Fach] - jahr: Optional[int] - niveau: Optional[Niveau] - typ: Optional[DokumentTyp] - aufgaben_nummer: Optional[str] - confidence: float - raw_filename: str - suggestions: List[Dict[str, Any]] - - @property - def extracted(self) -> Dict[str, Any]: - """Backwards-compatible property returning extracted values as dict.""" - result = {} - if self.bundesland: - result["bundesland"] = self.bundesland.value - if self.fach: - result["fach"] = self.fach.value - if self.jahr: - result["jahr"] = self.jahr - if self.niveau: - result["niveau"] = self.niveau.value - if self.typ: - result["typ"] = self.typ.value - if self.aufgaben_nummer: - result["aufgaben_nummer"] = self.aufgaben_nummer - return result - - @property - def method(self) -> str: - """Backwards-compatible property for recognition method.""" - return "filename_pattern" - - -# 
============================================================================ -# Internal Data Classes -# ============================================================================ - -@dataclass -class AbiturDokument: - """Internes Dokument.""" - id: str - dateiname: str - original_dateiname: str - bundesland: Bundesland - fach: Fach - jahr: int - niveau: Niveau - typ: DokumentTyp - aufgaben_nummer: Optional[str] - status: VerarbeitungsStatus - confidence: float - file_path: str - file_size: int - indexed: bool - vector_ids: List[str] - created_at: datetime - updated_at: datetime - - -# ============================================================================ # In-Memory Storage -# ============================================================================ - _dokumente: Dict[str, AbiturDokument] = {} +# Backwards-compatibility alias +documents_db = _dokumente + # ============================================================================ -# Helper Functions - Dokumentenerkennung +# Private helper (kept local since it references module-level _dokumente) # ============================================================================ -def parse_nibis_filename(filename: str) -> RecognitionResult: - """ - Erkennt Metadaten aus NiBiS-Dateinamen. 
- - Beispiele: - - 2025_Deutsch_eA_I.pdf - - 2025_Deutsch_eA_I_EWH.pdf - - 2025_Biologie_gA_1.pdf - - 2025_Englisch_eA_HV.pdf (Hörverstehen) - """ - result = RecognitionResult( - success=False, - bundesland=Bundesland.NIEDERSACHSEN, # NiBiS = Niedersachsen - fach=None, - jahr=None, - niveau=None, - typ=None, - aufgaben_nummer=None, - confidence=0.0, - raw_filename=filename, - suggestions=[] - ) - - # Bereinige Dateiname - name = Path(filename).stem.lower() - - # Extrahiere Jahr (4 Ziffern am Anfang) - jahr_match = re.match(r'^(\d{4})', name) - if jahr_match: - result.jahr = int(jahr_match.group(1)) - result.confidence += 0.2 - - # Extrahiere Fach - for fach_key, fach_enum in FACH_NAME_MAPPING.items(): - if fach_key in name.replace("_", "").replace("-", ""): - result.fach = fach_enum - result.confidence += 0.3 - break - - # Extrahiere Niveau (eA/gA) - if "_ea" in name or "_ea_" in name or "ea_" in name: - result.niveau = Niveau.EA - result.confidence += 0.2 - elif "_ga" in name or "_ga_" in name or "ga_" in name: - result.niveau = Niveau.GA - result.confidence += 0.2 - - # Extrahiere Typ - if "_ewh" in name: - result.typ = DokumentTyp.ERWARTUNGSHORIZONT - result.confidence += 0.2 - elif "_hv" in name or "hoerverstehen" in name: - result.typ = DokumentTyp.HOERVERSTEHEN - result.confidence += 0.15 - elif "_sm" in name or "_me" in name or "sprachmittlung" in name: - result.typ = DokumentTyp.SPRACHMITTLUNG - result.confidence += 0.15 - elif "deckblatt" in name: - result.typ = DokumentTyp.DECKBLATT - result.confidence += 0.15 - elif "material" in name: - result.typ = DokumentTyp.MATERIAL - result.confidence += 0.15 - elif "bewertung" in name: - result.typ = DokumentTyp.BEWERTUNGSBOGEN - result.confidence += 0.15 - else: - result.typ = DokumentTyp.AUFGABE - result.confidence += 0.1 - - # Extrahiere Aufgabennummer (römisch oder arabisch) - aufgabe_match = re.search(r'_([ivx]+|[1-4][abc]?)(?:_|\.pdf|$)', name, re.IGNORECASE) - if aufgabe_match: - result.aufgaben_nummer = 
aufgabe_match.group(1).upper() - result.confidence += 0.1 - - # Erfolg wenn mindestens Fach und Jahr erkannt - if result.fach and result.jahr: - result.success = True - - # Normalisiere Confidence auf max 1.0 - result.confidence = min(result.confidence, 1.0) - - return result - - def _to_dokument_response(doc: AbiturDokument) -> DokumentResponse: - """Konvertiert internes Dokument zu Response.""" - return DokumentResponse( - id=doc.id, - dateiname=doc.dateiname, - original_dateiname=doc.original_dateiname, - bundesland=doc.bundesland, - fach=doc.fach, - jahr=doc.jahr, - niveau=doc.niveau, - typ=doc.typ, - aufgaben_nummer=doc.aufgaben_nummer, - status=doc.status, - confidence=doc.confidence, - file_path=doc.file_path, - file_size=doc.file_size, - indexed=doc.indexed, - vector_ids=doc.vector_ids, - created_at=doc.created_at, - updated_at=doc.updated_at - ) + return to_dokument_response(doc) # ============================================================================ @@ -413,18 +74,12 @@ async def upload_dokument( typ: Optional[DokumentTyp] = Form(None), aufgaben_nummer: Optional[str] = Form(None) ): - """ - Lädt ein einzelnes Dokument hoch. - - Metadaten können manuell angegeben oder automatisch erkannt werden. 
- """ + """Lädt ein einzelnes Dokument hoch.""" if not file.filename: raise HTTPException(status_code=400, detail="Kein Dateiname") - # Erkenne Metadaten aus Dateiname recognition = parse_nibis_filename(file.filename) - # Überschreibe mit manuellen Angaben final_bundesland = bundesland or recognition.bundesland or Bundesland.NIEDERSACHSEN final_fach = fach or recognition.fach final_jahr = jahr or recognition.jahr or datetime.now().year @@ -435,7 +90,6 @@ async def upload_dokument( if not final_fach: raise HTTPException(status_code=400, detail="Fach konnte nicht erkannt werden") - # Generiere ID und speichere Datei doc_id = str(uuid.uuid4()) file_ext = Path(file.filename).suffix safe_filename = f"{doc_id}{file_ext}" @@ -446,30 +100,16 @@ async def upload_dokument( f.write(content) now = datetime.utcnow() - dokument = AbiturDokument( - id=doc_id, - dateiname=safe_filename, - original_dateiname=file.filename, - bundesland=final_bundesland, - fach=final_fach, - jahr=final_jahr, - niveau=final_niveau, - typ=final_typ, - aufgaben_nummer=final_aufgabe, + id=doc_id, dateiname=safe_filename, original_dateiname=file.filename, + bundesland=final_bundesland, fach=final_fach, jahr=final_jahr, + niveau=final_niveau, typ=final_typ, aufgaben_nummer=final_aufgabe, status=VerarbeitungsStatus.RECOGNIZED if recognition.success else VerarbeitungsStatus.PENDING, - confidence=recognition.confidence, - file_path=str(file_path), - file_size=len(content), - indexed=False, - vector_ids=[], - created_at=now, - updated_at=now + confidence=recognition.confidence, file_path=str(file_path), file_size=len(content), + indexed=False, vector_ids=[], created_at=now, updated_at=now ) - _dokumente[doc_id] = dokument logger.info(f"Uploaded document {doc_id}: {file.filename}") - return _to_dokument_response(dokument) @@ -479,15 +119,10 @@ async def import_zip( bundesland: Bundesland = Form(Bundesland.NIEDERSACHSEN), background_tasks: BackgroundTasks = None ): - """ - Importiert alle PDFs aus einer 
ZIP-Datei. - - Erkennt automatisch Metadaten aus Dateinamen. - """ + """Importiert alle PDFs aus einer ZIP-Datei.""" if not file.filename or not file.filename.endswith(".zip"): raise HTTPException(status_code=400, detail="ZIP-Datei erforderlich") - # Speichere ZIP temporär with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp: content = await file.read() tmp.write(content) @@ -501,31 +136,22 @@ async def import_zip( try: with zipfile.ZipFile(tmp_path, 'r') as zip_ref: for zip_info in zip_ref.infolist(): - # Nur PDFs if not zip_info.filename.lower().endswith(".pdf"): continue - - # Ignoriere Mac-spezifische Dateien if "__MACOSX" in zip_info.filename or zip_info.filename.startswith("."): continue - - # Ignoriere Thumbs.db if "thumbs.db" in zip_info.filename.lower(): continue total += 1 - try: - # Erkenne Metadaten basename = Path(zip_info.filename).name recognition = parse_nibis_filename(basename) - if not recognition.fach: errors += 1 logger.warning(f"Konnte Fach nicht erkennen: {basename}") continue - # Extrahiere und speichere doc_id = str(uuid.uuid4()) file_ext = Path(basename).suffix safe_filename = f"{doc_id}{file_ext}" @@ -537,62 +163,39 @@ async def import_zip( target.write(file_content) now = datetime.utcnow() - dokument = AbiturDokument( - id=doc_id, - dateiname=safe_filename, - original_dateiname=basename, - bundesland=bundesland, - fach=recognition.fach, + id=doc_id, dateiname=safe_filename, original_dateiname=basename, + bundesland=bundesland, fach=recognition.fach, jahr=recognition.jahr or datetime.now().year, niveau=recognition.niveau or Niveau.EA, typ=recognition.typ or DokumentTyp.AUFGABE, aufgaben_nummer=recognition.aufgaben_nummer, - status=VerarbeitungsStatus.RECOGNIZED, - confidence=recognition.confidence, - file_path=str(file_path), - file_size=len(file_content), - indexed=False, - vector_ids=[], - created_at=now, - updated_at=now + status=VerarbeitungsStatus.RECOGNIZED, confidence=recognition.confidence, + 
file_path=str(file_path), file_size=len(file_content), + indexed=False, vector_ids=[], created_at=now, updated_at=now ) - _dokumente[doc_id] = dokument documents.append(_to_dokument_response(dokument)) recognized += 1 - except Exception as e: errors += 1 logger.error(f"Fehler bei {zip_info.filename}: {e}") - finally: - # Lösche temporäre ZIP os.unlink(tmp_path) logger.info(f"ZIP-Import: {recognized}/{total} erkannt, {errors} Fehler") - - return ImportResult( - total_files=total, - recognized=recognized, - errors=errors, - documents=documents - ) + return ImportResult(total_files=total, recognized=recognized, errors=errors, documents=documents) @router.get("/", response_model=List[DokumentResponse]) async def list_dokumente( - bundesland: Optional[Bundesland] = None, - fach: Optional[Fach] = None, - jahr: Optional[int] = None, - niveau: Optional[Niveau] = None, - typ: Optional[DokumentTyp] = None, - status: Optional[VerarbeitungsStatus] = None, + bundesland: Optional[Bundesland] = None, fach: Optional[Fach] = None, + jahr: Optional[int] = None, niveau: Optional[Niveau] = None, + typ: Optional[DokumentTyp] = None, status: Optional[VerarbeitungsStatus] = None, indexed: Optional[bool] = None ): """Listet Dokumente mit optionalen Filtern.""" docs = list(_dokumente.values()) - if bundesland: docs = [d for d in docs if d.bundesland == bundesland] if fach: @@ -607,7 +210,6 @@ async def list_dokumente( docs = [d for d in docs if d.status == status] if indexed is not None: docs = [d for d in docs if d.indexed == indexed] - docs.sort(key=lambda x: (x.jahr, x.fach.value, x.niveau.value), reverse=True) return [_to_dokument_response(d) for d in docs] @@ -623,11 +225,10 @@ async def get_dokument(doc_id: str): @router.put("/{doc_id}", response_model=DokumentResponse) async def update_dokument(doc_id: str, data: DokumentUpdate): - """Aktualisiert Dokument-Metadaten (nach KI-Erkennung durch Entwickler).""" + """Aktualisiert Dokument-Metadaten.""" doc = _dokumente.get(doc_id) if not 
doc: raise HTTPException(status_code=404, detail="Dokument nicht gefunden") - if data.bundesland is not None: doc.bundesland = data.bundesland if data.fach is not None: @@ -642,9 +243,7 @@ async def update_dokument(doc_id: str, data: DokumentUpdate): doc.aufgaben_nummer = data.aufgaben_nummer if data.status is not None: doc.status = data.status - doc.updated_at = datetime.utcnow() - return _to_dokument_response(doc) @@ -654,10 +253,8 @@ async def confirm_dokument(doc_id: str): doc = _dokumente.get(doc_id) if not doc: raise HTTPException(status_code=404, detail="Dokument nicht gefunden") - doc.status = VerarbeitungsStatus.CONFIRMED doc.updated_at = datetime.utcnow() - return _to_dokument_response(doc) @@ -667,24 +264,13 @@ async def index_dokument(doc_id: str): doc = _dokumente.get(doc_id) if not doc: raise HTTPException(status_code=404, detail="Dokument nicht gefunden") - if doc.status not in [VerarbeitungsStatus.CONFIRMED, VerarbeitungsStatus.RECOGNIZED]: raise HTTPException(status_code=400, detail="Dokument muss erst bestätigt werden") - - # TODO: Vector Store Integration - # 1. PDF lesen und Text extrahieren - # 2. In Chunks aufteilen - # 3. Embeddings generieren - # 4. 
Mit Metadaten im Vector Store speichern - - # Demo: Simuliere Indexierung doc.indexed = True - doc.vector_ids = [f"vec_{doc_id}_{i}" for i in range(3)] # Demo-IDs + doc.vector_ids = [f"vec_{doc_id}_{i}" for i in range(3)] doc.status = VerarbeitungsStatus.INDEXED doc.updated_at = datetime.utcnow() - logger.info(f"Document {doc_id} indexed (demo)") - return _to_dokument_response(doc) @@ -694,15 +280,9 @@ async def delete_dokument(doc_id: str): doc = _dokumente.get(doc_id) if not doc: raise HTTPException(status_code=404, detail="Dokument nicht gefunden") - - # Lösche Datei if os.path.exists(doc.file_path): os.remove(doc.file_path) - - # TODO: Aus Vector Store entfernen - del _dokumente[doc_id] - return {"status": "deleted", "id": doc_id} @@ -712,20 +292,10 @@ async def download_dokument(doc_id: str): doc = _dokumente.get(doc_id) if not doc: raise HTTPException(status_code=404, detail="Dokument nicht gefunden") - if not os.path.exists(doc.file_path): raise HTTPException(status_code=404, detail="Datei nicht gefunden") + return FileResponse(doc.file_path, filename=doc.original_dateiname, media_type="application/pdf") - return FileResponse( - doc.file_path, - filename=doc.original_dateiname, - media_type="application/pdf" - ) - - -# ============================================================================ -# API Endpoints - Erkennung -# ============================================================================ @router.post("/recognize", response_model=RecognitionResult) async def recognize_filename(filename: str): @@ -743,7 +313,6 @@ async def bulk_confirm(doc_ids: List[str]): doc.status = VerarbeitungsStatus.CONFIRMED doc.updated_at = datetime.utcnow() confirmed += 1 - return {"confirmed": confirmed, "total": len(doc_ids)} @@ -754,70 +323,41 @@ async def bulk_index(doc_ids: List[str]): for doc_id in doc_ids: doc = _dokumente.get(doc_id) if doc and doc.status in [VerarbeitungsStatus.CONFIRMED, VerarbeitungsStatus.RECOGNIZED]: - # Demo-Indexierung doc.indexed = True 
doc.vector_ids = [f"vec_{doc_id}_{i}" for i in range(3)] doc.status = VerarbeitungsStatus.INDEXED doc.updated_at = datetime.utcnow() indexed += 1 - return {"indexed": indexed, "total": len(doc_ids)} -# ============================================================================ -# API Endpoints - Statistiken -# ============================================================================ - @router.get("/stats/overview") async def get_stats_overview(): """Gibt Übersicht über alle Dokumente.""" docs = list(_dokumente.values()) - - by_bundesland = {} - by_fach = {} - by_jahr = {} - by_status = {} - + by_bundesland: Dict[str, int] = {} + by_fach: Dict[str, int] = {} + by_jahr: Dict[int, int] = {} + by_status: Dict[str, int] = {} for doc in docs: by_bundesland[doc.bundesland.value] = by_bundesland.get(doc.bundesland.value, 0) + 1 by_fach[doc.fach.value] = by_fach.get(doc.fach.value, 0) + 1 by_jahr[doc.jahr] = by_jahr.get(doc.jahr, 0) + 1 by_status[doc.status.value] = by_status.get(doc.status.value, 0) + 1 - return { - "total": len(docs), - "indexed": sum(1 for d in docs if d.indexed), + "total": len(docs), "indexed": sum(1 for d in docs if d.indexed), "pending": sum(1 for d in docs if d.status == VerarbeitungsStatus.PENDING), - "by_bundesland": by_bundesland, - "by_fach": by_fach, - "by_jahr": by_jahr, - "by_status": by_status + "by_bundesland": by_bundesland, "by_fach": by_fach, "by_jahr": by_jahr, "by_status": by_status } -# ============================================================================ -# API Endpoints - Suche (für Klausur-Korrektur) -# ============================================================================ - @router.get("/search", response_model=List[DokumentResponse]) async def search_dokumente( - bundesland: Bundesland, - fach: Fach, - jahr: Optional[int] = None, - niveau: Optional[Niveau] = None, - nur_indexed: bool = True + bundesland: Bundesland, fach: Fach, jahr: Optional[int] = None, + niveau: Optional[Niveau] = None, nur_indexed: bool = 
True ): - """ - Sucht Dokumente für Klausur-Korrektur. - - Gibt nur indizierte Dokumente zurück (Standard). - """ - docs = list(_dokumente.values()) - - # Pflichtfilter - docs = [d for d in docs if d.bundesland == bundesland and d.fach == fach] - - # Optionale Filter + """Sucht Dokumente für Klausur-Korrektur.""" + docs = [d for d in _dokumente.values() if d.bundesland == bundesland and d.fach == fach] if jahr: docs = [d for d in docs if d.jahr == jahr] if niveau: @@ -825,7 +365,6 @@ async def search_dokumente( if nur_indexed: docs = [d for d in docs if d.indexed] - # Sortiere: Aufgaben vor Erwartungshorizonten aufgaben = [d for d in docs if d.typ == DokumentTyp.AUFGABE] ewh = [d for d in docs if d.typ == DokumentTyp.ERWARTUNGSHORIZONT] andere = [d for d in docs if d.typ not in [DokumentTyp.AUFGABE, DokumentTyp.ERWARTUNGSHORIZONT]] @@ -833,31 +372,20 @@ async def search_dokumente( result = [] for aufgabe in aufgaben: result.append(_to_dokument_response(aufgabe)) - # Finde passenden EWH matching_ewh = next( - (e for e in ewh - if e.jahr == aufgabe.jahr - and e.niveau == aufgabe.niveau - and e.aufgaben_nummer == aufgabe.aufgaben_nummer), - None + (e for e in ewh if e.jahr == aufgabe.jahr and e.niveau == aufgabe.niveau + and e.aufgaben_nummer == aufgabe.aufgaben_nummer), None ) if matching_ewh: result.append(_to_dokument_response(matching_ewh)) - - # Restliche EWH und andere for e in ewh: if _to_dokument_response(e) not in result: result.append(_to_dokument_response(e)) for a in andere: result.append(_to_dokument_response(a)) - return result -# ============================================================================ -# Enums Endpoint (für Frontend) -# ============================================================================ - @router.get("/enums/bundeslaender") async def get_bundeslaender(): """Gibt alle Bundesländer zurück.""" @@ -867,35 +395,7 @@ async def get_bundeslaender(): @router.get("/enums/faecher") async def get_faecher(): """Gibt alle Fächer 
zurück.""" - labels = { - Fach.DEUTSCH: "Deutsch", - Fach.ENGLISCH: "Englisch", - Fach.MATHEMATIK: "Mathematik", - Fach.BIOLOGIE: "Biologie", - Fach.CHEMIE: "Chemie", - Fach.PHYSIK: "Physik", - Fach.GESCHICHTE: "Geschichte", - Fach.ERDKUNDE: "Erdkunde", - Fach.POLITIK_WIRTSCHAFT: "Politik-Wirtschaft", - Fach.FRANZOESISCH: "Französisch", - Fach.SPANISCH: "Spanisch", - Fach.LATEIN: "Latein", - Fach.GRIECHISCH: "Griechisch", - Fach.KUNST: "Kunst", - Fach.MUSIK: "Musik", - Fach.SPORT: "Sport", - Fach.INFORMATIK: "Informatik", - Fach.EV_RELIGION: "Ev. Religion", - Fach.KATH_RELIGION: "Kath. Religion", - Fach.WERTE_NORMEN: "Werte und Normen", - Fach.BRC: "BRC (Betriebswirtschaft)", - Fach.BVW: "BVW (Volkswirtschaft)", - Fach.ERNAEHRUNG: "Ernährung", - Fach.MECHATRONIK: "Mechatronik", - Fach.GESUNDHEIT_PFLEGE: "Gesundheit-Pflege", - Fach.PAEDAGOGIK_PSYCHOLOGIE: "Pädagogik-Psychologie", - } - return [{"value": f.value, "label": labels.get(f, f.value)} for f in Fach] + return [{"value": f.value, "label": FACH_LABELS.get(f, f.value)} for f in Fach] @router.get("/enums/niveaus") @@ -910,47 +410,4 @@ async def get_niveaus(): @router.get("/enums/typen") async def get_typen(): """Gibt alle Dokumenttypen zurück.""" - labels = { - DokumentTyp.AUFGABE: "Aufgabe", - DokumentTyp.ERWARTUNGSHORIZONT: "Erwartungshorizont", - DokumentTyp.DECKBLATT: "Deckblatt", - DokumentTyp.MATERIAL: "Material", - DokumentTyp.HOERVERSTEHEN: "Hörverstehen", - DokumentTyp.SPRACHMITTLUNG: "Sprachmittlung", - DokumentTyp.BEWERTUNGSBOGEN: "Bewertungsbogen", - } - return [{"value": t.value, "label": labels.get(t, t.value)} for t in DokumentTyp] - - -# ============================================================================ -# Backwards-compatibility aliases (used by tests) -# ============================================================================ -AbiturFach = Fach -Anforderungsniveau = Niveau -documents_db = _dokumente - - -class DocumentMetadata(BaseModel): - """Backwards-compatible metadata model 
for tests.""" - jahr: Optional[int] = None - bundesland: Optional[str] = None - fach: Optional[str] = None - niveau: Optional[str] = None - dokument_typ: Optional[str] = None - aufgaben_nummer: Optional[str] = None - - -# Backwards-compatible AbiturDokument for tests (different from internal dataclass) -class AbiturDokumentCompat(BaseModel): - """Backwards-compatible AbiturDokument model for tests.""" - id: str - filename: str - file_path: str - metadata: DocumentMetadata - status: VerarbeitungsStatus - recognition_result: Optional[RecognitionResult] = None - created_at: datetime - updated_at: datetime - - class Config: - arbitrary_types_allowed = True + return [{"value": t.value, "label": DOKUMENT_TYP_LABELS.get(t, t.value)} for t in DokumentTyp] diff --git a/backend-lehrer/abitur_docs_models.py b/backend-lehrer/abitur_docs_models.py new file mode 100644 index 0000000..c49e6c1 --- /dev/null +++ b/backend-lehrer/abitur_docs_models.py @@ -0,0 +1,327 @@ +""" +Abitur Document Store - Enums, Pydantic Models, Data Classes. + +Shared types for abitur_docs_api and abitur_docs_recognition. 
+""" + +from datetime import datetime +from typing import List, Dict, Any, Optional +from enum import Enum +from dataclasses import dataclass + +from pydantic import BaseModel, Field + + +# ============================================================================ +# Enums +# ============================================================================ + +class Bundesland(str, Enum): + """Bundesländer mit Zentralabitur.""" + NIEDERSACHSEN = "niedersachsen" + BAYERN = "bayern" + BADEN_WUERTTEMBERG = "baden_wuerttemberg" + NORDRHEIN_WESTFALEN = "nordrhein_westfalen" + HESSEN = "hessen" + SACHSEN = "sachsen" + THUERINGEN = "thueringen" + BERLIN = "berlin" + HAMBURG = "hamburg" + SCHLESWIG_HOLSTEIN = "schleswig_holstein" + BREMEN = "bremen" + BRANDENBURG = "brandenburg" + MECKLENBURG_VORPOMMERN = "mecklenburg_vorpommern" + SACHSEN_ANHALT = "sachsen_anhalt" + RHEINLAND_PFALZ = "rheinland_pfalz" + SAARLAND = "saarland" + + +class Fach(str, Enum): + """Abiturfächer.""" + DEUTSCH = "deutsch" + ENGLISCH = "englisch" + MATHEMATIK = "mathematik" + BIOLOGIE = "biologie" + CHEMIE = "chemie" + PHYSIK = "physik" + GESCHICHTE = "geschichte" + ERDKUNDE = "erdkunde" + POLITIK_WIRTSCHAFT = "politik_wirtschaft" + FRANZOESISCH = "franzoesisch" + SPANISCH = "spanisch" + LATEIN = "latein" + GRIECHISCH = "griechisch" + KUNST = "kunst" + MUSIK = "musik" + SPORT = "sport" + INFORMATIK = "informatik" + EV_RELIGION = "ev_religion" + KATH_RELIGION = "kath_religion" + WERTE_NORMEN = "werte_normen" + BRC = "brc" + BVW = "bvw" + ERNAEHRUNG = "ernaehrung" + MECHATRONIK = "mechatronik" + GESUNDHEIT_PFLEGE = "gesundheit_pflege" + PAEDAGOGIK_PSYCHOLOGIE = "paedagogik_psychologie" + + +class Niveau(str, Enum): + """Anforderungsniveau.""" + EA = "eA" + GA = "gA" + + +class DokumentTyp(str, Enum): + """Dokumenttyp.""" + AUFGABE = "aufgabe" + ERWARTUNGSHORIZONT = "erwartungshorizont" + DECKBLATT = "deckblatt" + MATERIAL = "material" + HOERVERSTEHEN = "hoerverstehen" + SPRACHMITTLUNG = "sprachmittlung" + 
BEWERTUNGSBOGEN = "bewertungsbogen" + + +class VerarbeitungsStatus(str, Enum): + """Status der Dokumentenverarbeitung.""" + PENDING = "pending" + PROCESSING = "processing" + RECOGNIZED = "recognized" + CONFIRMED = "confirmed" + INDEXED = "indexed" + ERROR = "error" + + +# ============================================================================ +# Fach-Mapping für Dateinamen +# ============================================================================ + +FACH_NAME_MAPPING = { + "deutsch": Fach.DEUTSCH, + "englisch": Fach.ENGLISCH, + "mathe": Fach.MATHEMATIK, + "mathematik": Fach.MATHEMATIK, + "biologie": Fach.BIOLOGIE, + "bio": Fach.BIOLOGIE, + "chemie": Fach.CHEMIE, + "physik": Fach.PHYSIK, + "geschichte": Fach.GESCHICHTE, + "erdkunde": Fach.ERDKUNDE, + "geographie": Fach.ERDKUNDE, + "politikwirtschaft": Fach.POLITIK_WIRTSCHAFT, + "politik": Fach.POLITIK_WIRTSCHAFT, + "franzoesisch": Fach.FRANZOESISCH, + "franz": Fach.FRANZOESISCH, + "spanisch": Fach.SPANISCH, + "latein": Fach.LATEIN, + "griechisch": Fach.GRIECHISCH, + "kunst": Fach.KUNST, + "musik": Fach.MUSIK, + "sport": Fach.SPORT, + "informatik": Fach.INFORMATIK, + "evreligion": Fach.EV_RELIGION, + "kathreligion": Fach.KATH_RELIGION, + "wertenormen": Fach.WERTE_NORMEN, + "brc": Fach.BRC, + "bvw": Fach.BVW, + "ernaehrung": Fach.ERNAEHRUNG, + "mecha": Fach.MECHATRONIK, + "mechatronik": Fach.MECHATRONIK, + "technikmecha": Fach.MECHATRONIK, + "gespfl": Fach.GESUNDHEIT_PFLEGE, + "paedpsych": Fach.PAEDAGOGIK_PSYCHOLOGIE, +} + + +# ============================================================================ +# Pydantic Models +# ============================================================================ + +class DokumentCreate(BaseModel): + """Manuelles Erstellen eines Dokuments.""" + bundesland: Bundesland + fach: Fach + jahr: int = Field(ge=2000, le=2100) + niveau: Niveau + typ: DokumentTyp + aufgaben_nummer: Optional[str] = None + + +class DokumentUpdate(BaseModel): + """Update für erkannte Metadaten.""" + 
bundesland: Optional[Bundesland] = None + fach: Optional[Fach] = None + jahr: Optional[int] = None + niveau: Optional[Niveau] = None + typ: Optional[DokumentTyp] = None + aufgaben_nummer: Optional[str] = None + status: Optional[VerarbeitungsStatus] = None + + +class DokumentResponse(BaseModel): + """Response für ein Dokument.""" + id: str + dateiname: str + original_dateiname: str + bundesland: Bundesland + fach: Fach + jahr: int + niveau: Niveau + typ: DokumentTyp + aufgaben_nummer: Optional[str] + status: VerarbeitungsStatus + confidence: float + file_path: str + file_size: int + indexed: bool + vector_ids: List[str] + created_at: datetime + updated_at: datetime + + +class ImportResult(BaseModel): + """Ergebnis eines ZIP-Imports.""" + total_files: int + recognized: int + errors: int + documents: List[DokumentResponse] + + +class RecognitionResult(BaseModel): + """Ergebnis der Dokumentenerkennung.""" + success: bool + bundesland: Optional[Bundesland] + fach: Optional[Fach] + jahr: Optional[int] + niveau: Optional[Niveau] + typ: Optional[DokumentTyp] + aufgaben_nummer: Optional[str] + confidence: float + raw_filename: str + suggestions: List[Dict[str, Any]] + + @property + def extracted(self) -> Dict[str, Any]: + """Backwards-compatible property returning extracted values as dict.""" + result = {} + if self.bundesland: + result["bundesland"] = self.bundesland.value + if self.fach: + result["fach"] = self.fach.value + if self.jahr: + result["jahr"] = self.jahr + if self.niveau: + result["niveau"] = self.niveau.value + if self.typ: + result["typ"] = self.typ.value + if self.aufgaben_nummer: + result["aufgaben_nummer"] = self.aufgaben_nummer + return result + + @property + def method(self) -> str: + """Backwards-compatible property for recognition method.""" + return "filename_pattern" + + +# ============================================================================ +# Internal Data Classes +# 
============================================================================ + +@dataclass +class AbiturDokument: + """Internes Dokument.""" + id: str + dateiname: str + original_dateiname: str + bundesland: Bundesland + fach: Fach + jahr: int + niveau: Niveau + typ: DokumentTyp + aufgaben_nummer: Optional[str] + status: VerarbeitungsStatus + confidence: float + file_path: str + file_size: int + indexed: bool + vector_ids: List[str] + created_at: datetime + updated_at: datetime + + +# ============================================================================ +# Backwards-compatibility aliases (used by tests) +# ============================================================================ +AbiturFach = Fach +Anforderungsniveau = Niveau + + +class DocumentMetadata(BaseModel): + """Backwards-compatible metadata model for tests.""" + jahr: Optional[int] = None + bundesland: Optional[str] = None + fach: Optional[str] = None + niveau: Optional[str] = None + dokument_typ: Optional[str] = None + aufgaben_nummer: Optional[str] = None + + +class AbiturDokumentCompat(BaseModel): + """Backwards-compatible AbiturDokument model for tests.""" + id: str + filename: str + file_path: str + metadata: DocumentMetadata + status: VerarbeitungsStatus + recognition_result: Optional[RecognitionResult] = None + created_at: datetime + updated_at: datetime + + class Config: + arbitrary_types_allowed = True + + +# ============================================================================ +# Fach Labels (für Frontend Enum-Endpoint) +# ============================================================================ + +FACH_LABELS = { + Fach.DEUTSCH: "Deutsch", + Fach.ENGLISCH: "Englisch", + Fach.MATHEMATIK: "Mathematik", + Fach.BIOLOGIE: "Biologie", + Fach.CHEMIE: "Chemie", + Fach.PHYSIK: "Physik", + Fach.GESCHICHTE: "Geschichte", + Fach.ERDKUNDE: "Erdkunde", + Fach.POLITIK_WIRTSCHAFT: "Politik-Wirtschaft", + Fach.FRANZOESISCH: "Französisch", + Fach.SPANISCH: "Spanisch", + Fach.LATEIN: "Latein", 
+ Fach.GRIECHISCH: "Griechisch", + Fach.KUNST: "Kunst", + Fach.MUSIK: "Musik", + Fach.SPORT: "Sport", + Fach.INFORMATIK: "Informatik", + Fach.EV_RELIGION: "Ev. Religion", + Fach.KATH_RELIGION: "Kath. Religion", + Fach.WERTE_NORMEN: "Werte und Normen", + Fach.BRC: "BRC (Betriebswirtschaft)", + Fach.BVW: "BVW (Volkswirtschaft)", + Fach.ERNAEHRUNG: "Ernährung", + Fach.MECHATRONIK: "Mechatronik", + Fach.GESUNDHEIT_PFLEGE: "Gesundheit-Pflege", + Fach.PAEDAGOGIK_PSYCHOLOGIE: "Pädagogik-Psychologie", +} + +DOKUMENT_TYP_LABELS = { + DokumentTyp.AUFGABE: "Aufgabe", + DokumentTyp.ERWARTUNGSHORIZONT: "Erwartungshorizont", + DokumentTyp.DECKBLATT: "Deckblatt", + DokumentTyp.MATERIAL: "Material", + DokumentTyp.HOERVERSTEHEN: "Hörverstehen", + DokumentTyp.SPRACHMITTLUNG: "Sprachmittlung", + DokumentTyp.BEWERTUNGSBOGEN: "Bewertungsbogen", +} diff --git a/backend-lehrer/abitur_docs_recognition.py b/backend-lehrer/abitur_docs_recognition.py new file mode 100644 index 0000000..69aae1b --- /dev/null +++ b/backend-lehrer/abitur_docs_recognition.py @@ -0,0 +1,124 @@ +""" +Abitur Document Store - Dateinamen-Erkennung und Helfer. + +Erkennt Metadaten aus NiBiS-Dateinamen (Niedersachsen). +""" + +import re +from typing import Dict, Any +from pathlib import Path + +from abitur_docs_models import ( + Bundesland, Fach, Niveau, DokumentTyp, VerarbeitungsStatus, + RecognitionResult, AbiturDokument, DokumentResponse, + FACH_NAME_MAPPING, +) + + +def parse_nibis_filename(filename: str) -> RecognitionResult: + """ + Erkennt Metadaten aus NiBiS-Dateinamen. 
+ + Beispiele: + - 2025_Deutsch_eA_I.pdf + - 2025_Deutsch_eA_I_EWH.pdf + - 2025_Biologie_gA_1.pdf + - 2025_Englisch_eA_HV.pdf (Hörverstehen) + """ + result = RecognitionResult( + success=False, + bundesland=Bundesland.NIEDERSACHSEN, + fach=None, + jahr=None, + niveau=None, + typ=None, + aufgaben_nummer=None, + confidence=0.0, + raw_filename=filename, + suggestions=[] + ) + + # Bereinige Dateiname + name = Path(filename).stem.lower() + + # Extrahiere Jahr (4 Ziffern am Anfang) + jahr_match = re.match(r'^(\d{4})', name) + if jahr_match: + result.jahr = int(jahr_match.group(1)) + result.confidence += 0.2 + + # Extrahiere Fach + for fach_key, fach_enum in FACH_NAME_MAPPING.items(): + if fach_key in name.replace("_", "").replace("-", ""): + result.fach = fach_enum + result.confidence += 0.3 + break + + # Extrahiere Niveau (eA/gA) + if "_ea" in name or "_ea_" in name or "ea_" in name: + result.niveau = Niveau.EA + result.confidence += 0.2 + elif "_ga" in name or "_ga_" in name or "ga_" in name: + result.niveau = Niveau.GA + result.confidence += 0.2 + + # Extrahiere Typ + if "_ewh" in name: + result.typ = DokumentTyp.ERWARTUNGSHORIZONT + result.confidence += 0.2 + elif "_hv" in name or "hoerverstehen" in name: + result.typ = DokumentTyp.HOERVERSTEHEN + result.confidence += 0.15 + elif "_sm" in name or "_me" in name or "sprachmittlung" in name: + result.typ = DokumentTyp.SPRACHMITTLUNG + result.confidence += 0.15 + elif "deckblatt" in name: + result.typ = DokumentTyp.DECKBLATT + result.confidence += 0.15 + elif "material" in name: + result.typ = DokumentTyp.MATERIAL + result.confidence += 0.15 + elif "bewertung" in name: + result.typ = DokumentTyp.BEWERTUNGSBOGEN + result.confidence += 0.15 + else: + result.typ = DokumentTyp.AUFGABE + result.confidence += 0.1 + + # Extrahiere Aufgabennummer (römisch oder arabisch) + aufgabe_match = re.search(r'_([ivx]+|[1-4][abc]?)(?:_|\.pdf|$)', name, re.IGNORECASE) + if aufgabe_match: + result.aufgaben_nummer = 
aufgabe_match.group(1).upper() + result.confidence += 0.1 + + # Erfolg wenn mindestens Fach und Jahr erkannt + if result.fach and result.jahr: + result.success = True + + # Normalisiere Confidence auf max 1.0 + result.confidence = min(result.confidence, 1.0) + + return result + + +def to_dokument_response(doc: AbiturDokument) -> DokumentResponse: + """Konvertiert internes Dokument zu Response.""" + return DokumentResponse( + id=doc.id, + dateiname=doc.dateiname, + original_dateiname=doc.original_dateiname, + bundesland=doc.bundesland, + fach=doc.fach, + jahr=doc.jahr, + niveau=doc.niveau, + typ=doc.typ, + aufgaben_nummer=doc.aufgaben_nummer, + status=doc.status, + confidence=doc.confidence, + file_path=doc.file_path, + file_size=doc.file_size, + indexed=doc.indexed, + vector_ids=doc.vector_ids, + created_at=doc.created_at, + updated_at=doc.updated_at + ) diff --git a/backend-lehrer/alerts_agent/db/item_repository.py b/backend-lehrer/alerts_agent/db/item_repository.py new file mode 100644 index 0000000..f187c92 --- /dev/null +++ b/backend-lehrer/alerts_agent/db/item_repository.py @@ -0,0 +1,394 @@ +""" +Repository für Alert Items (einzelne Alerts/Artikel). 
+""" +import hashlib +import urllib.parse +import uuid +from datetime import datetime, timedelta +from typing import Optional, List, Dict, Any +from sqlalchemy.orm import Session as DBSession +from sqlalchemy import or_, func + +from .models import ( + AlertItemDB, AlertSourceEnum, AlertStatusEnum, RelevanceDecisionEnum +) + + +class AlertItemRepository: + """Repository für Alert Items (einzelne Alerts/Artikel).""" + + def __init__(self, db: DBSession): + self.db = db + + # ==================== CREATE ==================== + + def create( + self, + topic_id: str, + title: str, + url: str, + snippet: str = "", + source: str = "google_alerts_rss", + published_at: datetime = None, + lang: str = "de", + ) -> AlertItemDB: + """Erstellt einen neuen Alert.""" + url_hash = self._compute_url_hash(url) + + alert = AlertItemDB( + id=str(uuid.uuid4()), + topic_id=topic_id, + title=title, + url=url, + snippet=snippet, + source=AlertSourceEnum(source), + published_at=published_at, + lang=lang, + url_hash=url_hash, + canonical_url=self._normalize_url(url), + ) + self.db.add(alert) + self.db.commit() + self.db.refresh(alert) + return alert + + def create_if_not_exists( + self, + topic_id: str, + title: str, + url: str, + snippet: str = "", + source: str = "google_alerts_rss", + published_at: datetime = None, + ) -> Optional[AlertItemDB]: + """Erstellt einen Alert nur wenn URL noch nicht existiert.""" + url_hash = self._compute_url_hash(url) + + existing = self.db.query(AlertItemDB).filter( + AlertItemDB.url_hash == url_hash + ).first() + + if existing: + return None # Duplikat + + return self.create( + topic_id=topic_id, + title=title, + url=url, + snippet=snippet, + source=source, + published_at=published_at, + ) + + # ==================== READ ==================== + + def get_by_id(self, alert_id: str) -> Optional[AlertItemDB]: + """Holt einen Alert nach ID.""" + return self.db.query(AlertItemDB).filter( + AlertItemDB.id == alert_id + ).first() + + def get_by_url_hash(self, 
url_hash: str) -> Optional[AlertItemDB]: + """Holt einen Alert nach URL-Hash.""" + return self.db.query(AlertItemDB).filter( + AlertItemDB.url_hash == url_hash + ).first() + + def get_inbox( + self, + user_id: str = None, + topic_id: str = None, + decision: str = None, + status: str = None, + limit: int = 50, + offset: int = 0, + ) -> List[AlertItemDB]: + """ + Holt Inbox-Items mit Filtern. + + Ohne decision werden KEEP und REVIEW angezeigt. + """ + query = self.db.query(AlertItemDB) + + if topic_id: + query = query.filter(AlertItemDB.topic_id == topic_id) + + if decision: + query = query.filter( + AlertItemDB.relevance_decision == RelevanceDecisionEnum(decision) + ) + else: + # Default: KEEP und REVIEW + query = query.filter( + or_( + AlertItemDB.relevance_decision == RelevanceDecisionEnum.KEEP, + AlertItemDB.relevance_decision == RelevanceDecisionEnum.REVIEW, + AlertItemDB.relevance_decision.is_(None) + ) + ) + + if status: + query = query.filter(AlertItemDB.status == AlertStatusEnum(status)) + + return query.order_by( + AlertItemDB.relevance_score.desc().nullslast(), + AlertItemDB.fetched_at.desc() + ).offset(offset).limit(limit).all() + + def get_unscored( + self, + topic_id: str = None, + limit: int = 100, + ) -> List[AlertItemDB]: + """Holt alle unbewerteten Alerts.""" + query = self.db.query(AlertItemDB).filter( + AlertItemDB.status == AlertStatusEnum.NEW + ) + + if topic_id: + query = query.filter(AlertItemDB.topic_id == topic_id) + + return query.order_by(AlertItemDB.fetched_at.desc()).limit(limit).all() + + def get_by_topic( + self, + topic_id: str, + limit: int = 100, + offset: int = 0, + ) -> List[AlertItemDB]: + """Holt alle Alerts eines Topics.""" + return self.db.query(AlertItemDB).filter( + AlertItemDB.topic_id == topic_id + ).order_by( + AlertItemDB.fetched_at.desc() + ).offset(offset).limit(limit).all() + + def count_by_status(self, topic_id: str = None) -> Dict[str, int]: + """Zählt Alerts nach Status.""" + query = self.db.query( + 
AlertItemDB.status, + func.count(AlertItemDB.id).label('count') + ) + + if topic_id: + query = query.filter(AlertItemDB.topic_id == topic_id) + + results = query.group_by(AlertItemDB.status).all() + + return {r[0].value: r[1] for r in results} + + def count_by_decision(self, topic_id: str = None) -> Dict[str, int]: + """Zählt Alerts nach Relevanz-Entscheidung.""" + query = self.db.query( + AlertItemDB.relevance_decision, + func.count(AlertItemDB.id).label('count') + ) + + if topic_id: + query = query.filter(AlertItemDB.topic_id == topic_id) + + results = query.group_by(AlertItemDB.relevance_decision).all() + + return { + (r[0].value if r[0] else "unscored"): r[1] + for r in results + } + + # ==================== UPDATE ==================== + + def update_scoring( + self, + alert_id: str, + score: float, + decision: str, + reasons: List[str] = None, + summary: str = None, + model: str = None, + ) -> Optional[AlertItemDB]: + """Aktualisiert das Scoring eines Alerts.""" + alert = self.get_by_id(alert_id) + if not alert: + return None + + alert.relevance_score = score + alert.relevance_decision = RelevanceDecisionEnum(decision) + alert.relevance_reasons = reasons or [] + alert.relevance_summary = summary + alert.scored_by_model = model + alert.scored_at = datetime.utcnow() + alert.status = AlertStatusEnum.SCORED + alert.processed_at = datetime.utcnow() + + self.db.commit() + self.db.refresh(alert) + return alert + + def update_status( + self, + alert_id: str, + status: str, + ) -> Optional[AlertItemDB]: + """Aktualisiert den Status eines Alerts.""" + alert = self.get_by_id(alert_id) + if not alert: + return None + + alert.status = AlertStatusEnum(status) + + self.db.commit() + self.db.refresh(alert) + return alert + + def mark_reviewed( + self, + alert_id: str, + is_relevant: bool, + notes: str = None, + tags: List[str] = None, + ) -> Optional[AlertItemDB]: + """Markiert einen Alert als reviewed mit Feedback.""" + alert = self.get_by_id(alert_id) + if not alert: + 
return None + + alert.status = AlertStatusEnum.REVIEWED + alert.user_marked_relevant = is_relevant + if notes: + alert.user_notes = notes + if tags: + alert.user_tags = tags + + self.db.commit() + self.db.refresh(alert) + return alert + + def archive(self, alert_id: str) -> Optional[AlertItemDB]: + """Archiviert einen Alert.""" + return self.update_status(alert_id, "archived") + + # ==================== DELETE ==================== + + def delete(self, alert_id: str) -> bool: + """Löscht einen Alert.""" + alert = self.get_by_id(alert_id) + if not alert: + return False + + self.db.delete(alert) + self.db.commit() + return True + + def delete_old(self, days: int = 90, topic_id: str = None) -> int: + """Löscht alte archivierte Alerts.""" + cutoff = datetime.utcnow() - timedelta(days=days) + + query = self.db.query(AlertItemDB).filter( + AlertItemDB.status == AlertStatusEnum.ARCHIVED, + AlertItemDB.fetched_at < cutoff, + ) + + if topic_id: + query = query.filter(AlertItemDB.topic_id == topic_id) + + count = query.delete() + self.db.commit() + return count + + # ==================== FOR RSS FETCHER ==================== + + def get_existing_urls(self, topic_id: str) -> set: + """ + Holt alle bekannten URL-Hashes für ein Topic. + + Wird vom RSS-Fetcher verwendet um Duplikate zu vermeiden. + """ + results = self.db.query(AlertItemDB.url_hash).filter( + AlertItemDB.topic_id == topic_id + ).all() + + return {r[0] for r in results if r[0]} + + def create_from_alert_item(self, alert_item, topic_id: str) -> AlertItemDB: + """ + Erstellt einen Alert aus einem AlertItem-Objekt vom RSS-Fetcher. 
+ + Args: + alert_item: AlertItem from rss_fetcher + topic_id: Topic ID to associate with + + Returns: + Created AlertItemDB instance + """ + return self.create( + topic_id=topic_id, + title=alert_item.title, + url=alert_item.url, + snippet=alert_item.snippet or "", + source=alert_item.source.value if hasattr(alert_item.source, 'value') else str(alert_item.source), + published_at=alert_item.published_at, + ) + + # ==================== HELPER ==================== + + def _compute_url_hash(self, url: str) -> str: + """Berechnet SHA256 Hash der normalisierten URL.""" + normalized = self._normalize_url(url) + return hashlib.sha256(normalized.encode()).hexdigest()[:16] + + def _normalize_url(self, url: str) -> str: + """Normalisiert URL für Deduplizierung.""" + parsed = urllib.parse.urlparse(url) + + # Tracking-Parameter entfernen + tracking_params = { + "utm_source", "utm_medium", "utm_campaign", "utm_content", "utm_term", + "fbclid", "gclid", "ref", "source" + } + + query_params = urllib.parse.parse_qs(parsed.query) + cleaned_params = {k: v for k, v in query_params.items() + if k.lower() not in tracking_params} + + cleaned_query = urllib.parse.urlencode(cleaned_params, doseq=True) + + # Rekonstruiere URL ohne Fragment + normalized = urllib.parse.urlunparse(( + parsed.scheme, + parsed.netloc.lower(), + parsed.path.rstrip("/"), + parsed.params, + cleaned_query, + "" # No fragment + )) + + return normalized + + # ==================== CONVERSION ==================== + + def to_dict(self, alert: AlertItemDB) -> Dict[str, Any]: + """Konvertiert DB-Model zu Dictionary.""" + return { + "id": alert.id, + "topic_id": alert.topic_id, + "title": alert.title, + "url": alert.url, + "snippet": alert.snippet, + "source": alert.source.value, + "lang": alert.lang, + "published_at": alert.published_at.isoformat() if alert.published_at else None, + "fetched_at": alert.fetched_at.isoformat() if alert.fetched_at else None, + "status": alert.status.value, + "relevance": { + "score": 
alert.relevance_score, + "decision": alert.relevance_decision.value if alert.relevance_decision else None, + "reasons": alert.relevance_reasons, + "summary": alert.relevance_summary, + "model": alert.scored_by_model, + "scored_at": alert.scored_at.isoformat() if alert.scored_at else None, + }, + "user_feedback": { + "marked_relevant": alert.user_marked_relevant, + "tags": alert.user_tags, + "notes": alert.user_notes, + }, + } diff --git a/backend-lehrer/alerts_agent/db/profile_repository.py b/backend-lehrer/alerts_agent/db/profile_repository.py new file mode 100644 index 0000000..09cdebb --- /dev/null +++ b/backend-lehrer/alerts_agent/db/profile_repository.py @@ -0,0 +1,226 @@ +""" +Repository für Alert Profiles (Nutzer-Profile für Relevanz-Scoring). +""" +import uuid +from datetime import datetime +from typing import Optional, List, Dict, Any +from sqlalchemy.orm import Session as DBSession +from sqlalchemy.orm.attributes import flag_modified + +from .models import AlertProfileDB + + +class ProfileRepository: + """Repository für Alert Profiles (Nutzer-Profile für Relevanz-Scoring).""" + + def __init__(self, db: DBSession): + self.db = db + + # ==================== CREATE / GET-OR-CREATE ==================== + + def get_or_create(self, user_id: str = None) -> AlertProfileDB: + """Holt oder erstellt ein Profil.""" + profile = self.get_by_user_id(user_id) + if profile: + return profile + + # Neues Profil erstellen + profile = AlertProfileDB( + id=str(uuid.uuid4()), + user_id=user_id, + name="Default" if not user_id else f"Profile {user_id[:8]}", + ) + self.db.add(profile) + self.db.commit() + self.db.refresh(profile) + return profile + + def create_default_education_profile(self, user_id: str = None) -> AlertProfileDB: + """Erstellt ein Standard-Profil für Bildungsthemen.""" + profile = AlertProfileDB( + id=str(uuid.uuid4()), + user_id=user_id, + name="Bildung Default", + priorities=[ + { + "label": "Inklusion", + "weight": 0.9, + "keywords": ["inklusiv", 
"Förderbedarf", "Behinderung", "Barrierefreiheit"], + "description": "Inklusive Bildung, Förderschulen, Nachteilsausgleich" + }, + { + "label": "Datenschutz Schule", + "weight": 0.85, + "keywords": ["DSGVO", "Schülerfotos", "Einwilligung", "personenbezogene Daten"], + "description": "DSGVO in Schulen, Datenschutz bei Klassenfotos" + }, + { + "label": "Schulrecht Bayern", + "weight": 0.8, + "keywords": ["BayEUG", "Schulordnung", "Kultusministerium", "Bayern"], + "description": "Bayerisches Schulrecht, Verordnungen" + }, + { + "label": "Digitalisierung Schule", + "weight": 0.7, + "keywords": ["DigitalPakt", "Tablet-Klasse", "Lernplattform"], + "description": "Digitale Medien im Unterricht" + }, + ], + exclusions=["Stellenanzeige", "Praktikum gesucht", "Werbung", "Pressemitteilung"], + policies={ + "prefer_german_sources": True, + "max_age_days": 30, + "min_content_length": 100, + } + ) + self.db.add(profile) + self.db.commit() + self.db.refresh(profile) + return profile + + # ==================== READ ==================== + + def get_by_id(self, profile_id: str) -> Optional[AlertProfileDB]: + """Holt ein Profil nach ID.""" + return self.db.query(AlertProfileDB).filter( + AlertProfileDB.id == profile_id + ).first() + + def get_by_user_id(self, user_id: str) -> Optional[AlertProfileDB]: + """Holt ein Profil nach User-ID.""" + if not user_id: + # Default-Profil ohne User + return self.db.query(AlertProfileDB).filter( + AlertProfileDB.user_id.is_(None) + ).first() + + return self.db.query(AlertProfileDB).filter( + AlertProfileDB.user_id == user_id + ).first() + + # ==================== UPDATE ==================== + + def update_priorities( + self, + profile_id: str, + priorities: List[Dict], + ) -> Optional[AlertProfileDB]: + """Aktualisiert die Prioritäten eines Profils.""" + profile = self.get_by_id(profile_id) + if not profile: + return None + + profile.priorities = priorities + self.db.commit() + self.db.refresh(profile) + return profile + + def update_exclusions( + 
self, + profile_id: str, + exclusions: List[str], + ) -> Optional[AlertProfileDB]: + """Aktualisiert die Ausschlüsse eines Profils.""" + profile = self.get_by_id(profile_id) + if not profile: + return None + + profile.exclusions = exclusions + self.db.commit() + self.db.refresh(profile) + return profile + + def add_feedback( + self, + profile_id: str, + title: str, + url: str, + is_relevant: bool, + reason: str = "", + ) -> Optional[AlertProfileDB]: + """Fügt Feedback als Beispiel hinzu.""" + profile = self.get_by_id(profile_id) + if not profile: + return None + + example = { + "title": title, + "url": url, + "reason": reason, + "added_at": datetime.utcnow().isoformat(), + } + + if is_relevant: + examples = list(profile.positive_examples or []) + examples.append(example) + profile.positive_examples = examples[-20:] # Max 20 + profile.total_kept += 1 + flag_modified(profile, "positive_examples") + else: + examples = list(profile.negative_examples or []) + examples.append(example) + profile.negative_examples = examples[-20:] # Max 20 + profile.total_dropped += 1 + flag_modified(profile, "negative_examples") + + profile.total_scored += 1 + self.db.commit() + self.db.refresh(profile) + return profile + + def update_stats( + self, + profile_id: str, + kept: int = 0, + dropped: int = 0, + ) -> Optional[AlertProfileDB]: + """Aktualisiert die Statistiken eines Profils.""" + profile = self.get_by_id(profile_id) + if not profile: + return None + + profile.total_scored += kept + dropped + profile.total_kept += kept + profile.total_dropped += dropped + + self.db.commit() + self.db.refresh(profile) + return profile + + # ==================== DELETE ==================== + + def delete(self, profile_id: str) -> bool: + """Löscht ein Profil.""" + profile = self.get_by_id(profile_id) + if not profile: + return False + + self.db.delete(profile) + self.db.commit() + return True + + # ==================== CONVERSION ==================== + + def to_dict(self, profile: AlertProfileDB) 
-> Dict[str, Any]: + """Konvertiert DB-Model zu Dictionary.""" + return { + "id": profile.id, + "user_id": profile.user_id, + "name": profile.name, + "priorities": profile.priorities, + "exclusions": profile.exclusions, + "policies": profile.policies, + "examples": { + "positive": len(profile.positive_examples or []), + "negative": len(profile.negative_examples or []), + }, + "stats": { + "total_scored": profile.total_scored, + "total_kept": profile.total_kept, + "total_dropped": profile.total_dropped, + "accuracy_estimate": profile.accuracy_estimate, + }, + "created_at": profile.created_at.isoformat() if profile.created_at else None, + "updated_at": profile.updated_at.isoformat() if profile.updated_at else None, + } diff --git a/backend-lehrer/alerts_agent/db/repository.py b/backend-lehrer/alerts_agent/db/repository.py index b3e5b98..cd0739b 100644 --- a/backend-lehrer/alerts_agent/db/repository.py +++ b/backend-lehrer/alerts_agent/db/repository.py @@ -1,992 +1,20 @@ """ Repository für Alerts Agent - CRUD Operationen für Topics, Items, Rules und Profile. -Abstraktion der Datenbank-Operationen. 
+Barrel re-export — die eigentliche Logik lebt in: +- topic_repository.py +- item_repository.py +- rule_repository.py +- profile_repository.py """ -import hashlib -from datetime import datetime -from typing import Optional, List, Dict, Any -from sqlalchemy.orm import Session as DBSession -from sqlalchemy.orm.attributes import flag_modified -from sqlalchemy import or_, and_, func - -from .models import ( - AlertTopicDB, AlertItemDB, AlertRuleDB, AlertProfileDB, - AlertSourceEnum, AlertStatusEnum, RelevanceDecisionEnum, - FeedTypeEnum, RuleActionEnum -) - - -# ============================================================================= -# TOPIC REPOSITORY -# ============================================================================= - -class TopicRepository: - """Repository für Alert Topics (Feed-Quellen).""" - - def __init__(self, db: DBSession): - self.db = db - - # ==================== CREATE ==================== - - def create( - self, - name: str, - feed_url: str = None, - feed_type: str = "rss", - user_id: str = None, - description: str = "", - fetch_interval_minutes: int = 60, - is_active: bool = True, - ) -> AlertTopicDB: - """Erstellt ein neues Topic.""" - import uuid - topic = AlertTopicDB( - id=str(uuid.uuid4()), - user_id=user_id, - name=name, - description=description, - feed_url=feed_url, - feed_type=FeedTypeEnum(feed_type), - fetch_interval_minutes=fetch_interval_minutes, - is_active=is_active, - ) - self.db.add(topic) - self.db.commit() - self.db.refresh(topic) - return topic - - # ==================== READ ==================== - - def get_by_id(self, topic_id: str) -> Optional[AlertTopicDB]: - """Holt ein Topic nach ID.""" - return self.db.query(AlertTopicDB).filter( - AlertTopicDB.id == topic_id - ).first() - - def get_all( - self, - user_id: str = None, - is_active: bool = None, - limit: int = 100, - offset: int = 0, - ) -> List[AlertTopicDB]: - """Holt alle Topics mit optionalen Filtern.""" - query = self.db.query(AlertTopicDB) - - if user_id: 
- query = query.filter(AlertTopicDB.user_id == user_id) - if is_active is not None: - query = query.filter(AlertTopicDB.is_active == is_active) - - return query.order_by( - AlertTopicDB.created_at.desc() - ).offset(offset).limit(limit).all() - - def get_active_for_fetch(self) -> List[AlertTopicDB]: - """Holt alle aktiven Topics die gefetcht werden sollten.""" - # Topics wo fetch_interval_minutes vergangen ist - return self.db.query(AlertTopicDB).filter( - AlertTopicDB.is_active == True, - AlertTopicDB.feed_url.isnot(None), - ).all() - - # ==================== UPDATE ==================== - - def update( - self, - topic_id: str, - name: str = None, - description: str = None, - feed_url: str = None, - feed_type: str = None, - is_active: bool = None, - fetch_interval_minutes: int = None, - ) -> Optional[AlertTopicDB]: - """Aktualisiert ein Topic.""" - topic = self.get_by_id(topic_id) - if not topic: - return None - - if name is not None: - topic.name = name - if description is not None: - topic.description = description - if feed_url is not None: - topic.feed_url = feed_url - if feed_type is not None: - topic.feed_type = FeedTypeEnum(feed_type) - if is_active is not None: - topic.is_active = is_active - if fetch_interval_minutes is not None: - topic.fetch_interval_minutes = fetch_interval_minutes - - self.db.commit() - self.db.refresh(topic) - return topic - - def update_fetch_status( - self, - topic_id: str, - last_fetch_error: str = None, - items_fetched: int = 0, - ) -> Optional[AlertTopicDB]: - """Aktualisiert den Fetch-Status eines Topics.""" - topic = self.get_by_id(topic_id) - if not topic: - return None - - topic.last_fetched_at = datetime.utcnow() - topic.last_fetch_error = last_fetch_error - topic.total_items_fetched += items_fetched - - self.db.commit() - self.db.refresh(topic) - return topic - - def increment_stats( - self, - topic_id: str, - kept: int = 0, - dropped: int = 0, - ) -> Optional[AlertTopicDB]: - """Erhöht die Statistiken eines Topics.""" - 
topic = self.get_by_id(topic_id) - if not topic: - return None - - topic.items_kept += kept - topic.items_dropped += dropped - - self.db.commit() - self.db.refresh(topic) - return topic - - # ==================== DELETE ==================== - - def delete(self, topic_id: str) -> bool: - """Löscht ein Topic (und alle zugehörigen Items via CASCADE).""" - topic = self.get_by_id(topic_id) - if not topic: - return False - - self.db.delete(topic) - self.db.commit() - return True - - # ==================== CONVERSION ==================== - - def to_dict(self, topic: AlertTopicDB) -> Dict[str, Any]: - """Konvertiert DB-Model zu Dictionary.""" - return { - "id": topic.id, - "user_id": topic.user_id, - "name": topic.name, - "description": topic.description, - "feed_url": topic.feed_url, - "feed_type": topic.feed_type.value, - "is_active": topic.is_active, - "fetch_interval_minutes": topic.fetch_interval_minutes, - "last_fetched_at": topic.last_fetched_at.isoformat() if topic.last_fetched_at else None, - "last_fetch_error": topic.last_fetch_error, - "stats": { - "total_items_fetched": topic.total_items_fetched, - "items_kept": topic.items_kept, - "items_dropped": topic.items_dropped, - }, - "created_at": topic.created_at.isoformat() if topic.created_at else None, - "updated_at": topic.updated_at.isoformat() if topic.updated_at else None, - } - - -# ============================================================================= -# ALERT ITEM REPOSITORY -# ============================================================================= - -class AlertItemRepository: - """Repository für Alert Items (einzelne Alerts/Artikel).""" - - def __init__(self, db: DBSession): - self.db = db - - # ==================== CREATE ==================== - - def create( - self, - topic_id: str, - title: str, - url: str, - snippet: str = "", - source: str = "google_alerts_rss", - published_at: datetime = None, - lang: str = "de", - ) -> AlertItemDB: - """Erstellt einen neuen Alert.""" - import uuid - - # 
URL-Hash berechnen - url_hash = self._compute_url_hash(url) - - alert = AlertItemDB( - id=str(uuid.uuid4()), - topic_id=topic_id, - title=title, - url=url, - snippet=snippet, - source=AlertSourceEnum(source), - published_at=published_at, - lang=lang, - url_hash=url_hash, - canonical_url=self._normalize_url(url), - ) - self.db.add(alert) - self.db.commit() - self.db.refresh(alert) - return alert - - def create_if_not_exists( - self, - topic_id: str, - title: str, - url: str, - snippet: str = "", - source: str = "google_alerts_rss", - published_at: datetime = None, - ) -> Optional[AlertItemDB]: - """Erstellt einen Alert nur wenn URL noch nicht existiert.""" - url_hash = self._compute_url_hash(url) - - existing = self.db.query(AlertItemDB).filter( - AlertItemDB.url_hash == url_hash - ).first() - - if existing: - return None # Duplikat - - return self.create( - topic_id=topic_id, - title=title, - url=url, - snippet=snippet, - source=source, - published_at=published_at, - ) - - # ==================== READ ==================== - - def get_by_id(self, alert_id: str) -> Optional[AlertItemDB]: - """Holt einen Alert nach ID.""" - return self.db.query(AlertItemDB).filter( - AlertItemDB.id == alert_id - ).first() - - def get_by_url_hash(self, url_hash: str) -> Optional[AlertItemDB]: - """Holt einen Alert nach URL-Hash.""" - return self.db.query(AlertItemDB).filter( - AlertItemDB.url_hash == url_hash - ).first() - - def get_inbox( - self, - user_id: str = None, - topic_id: str = None, - decision: str = None, - status: str = None, - limit: int = 50, - offset: int = 0, - ) -> List[AlertItemDB]: - """ - Holt Inbox-Items mit Filtern. - - Ohne decision werden KEEP und REVIEW angezeigt. 
- """ - query = self.db.query(AlertItemDB) - - if topic_id: - query = query.filter(AlertItemDB.topic_id == topic_id) - - if decision: - query = query.filter( - AlertItemDB.relevance_decision == RelevanceDecisionEnum(decision) - ) - else: - # Default: KEEP und REVIEW - query = query.filter( - or_( - AlertItemDB.relevance_decision == RelevanceDecisionEnum.KEEP, - AlertItemDB.relevance_decision == RelevanceDecisionEnum.REVIEW, - AlertItemDB.relevance_decision.is_(None) - ) - ) - - if status: - query = query.filter(AlertItemDB.status == AlertStatusEnum(status)) - - return query.order_by( - AlertItemDB.relevance_score.desc().nullslast(), - AlertItemDB.fetched_at.desc() - ).offset(offset).limit(limit).all() - - def get_unscored( - self, - topic_id: str = None, - limit: int = 100, - ) -> List[AlertItemDB]: - """Holt alle unbewerteten Alerts.""" - query = self.db.query(AlertItemDB).filter( - AlertItemDB.status == AlertStatusEnum.NEW - ) - - if topic_id: - query = query.filter(AlertItemDB.topic_id == topic_id) - - return query.order_by(AlertItemDB.fetched_at.desc()).limit(limit).all() - - def get_by_topic( - self, - topic_id: str, - limit: int = 100, - offset: int = 0, - ) -> List[AlertItemDB]: - """Holt alle Alerts eines Topics.""" - return self.db.query(AlertItemDB).filter( - AlertItemDB.topic_id == topic_id - ).order_by( - AlertItemDB.fetched_at.desc() - ).offset(offset).limit(limit).all() - - def count_by_status(self, topic_id: str = None) -> Dict[str, int]: - """Zählt Alerts nach Status.""" - query = self.db.query( - AlertItemDB.status, - func.count(AlertItemDB.id).label('count') - ) - - if topic_id: - query = query.filter(AlertItemDB.topic_id == topic_id) - - results = query.group_by(AlertItemDB.status).all() - - return {r[0].value: r[1] for r in results} - - def count_by_decision(self, topic_id: str = None) -> Dict[str, int]: - """Zählt Alerts nach Relevanz-Entscheidung.""" - query = self.db.query( - AlertItemDB.relevance_decision, - 
func.count(AlertItemDB.id).label('count') - ) - - if topic_id: - query = query.filter(AlertItemDB.topic_id == topic_id) - - results = query.group_by(AlertItemDB.relevance_decision).all() - - return { - (r[0].value if r[0] else "unscored"): r[1] - for r in results - } - - # ==================== UPDATE ==================== - - def update_scoring( - self, - alert_id: str, - score: float, - decision: str, - reasons: List[str] = None, - summary: str = None, - model: str = None, - ) -> Optional[AlertItemDB]: - """Aktualisiert das Scoring eines Alerts.""" - alert = self.get_by_id(alert_id) - if not alert: - return None - - alert.relevance_score = score - alert.relevance_decision = RelevanceDecisionEnum(decision) - alert.relevance_reasons = reasons or [] - alert.relevance_summary = summary - alert.scored_by_model = model - alert.scored_at = datetime.utcnow() - alert.status = AlertStatusEnum.SCORED - alert.processed_at = datetime.utcnow() - - self.db.commit() - self.db.refresh(alert) - return alert - - def update_status( - self, - alert_id: str, - status: str, - ) -> Optional[AlertItemDB]: - """Aktualisiert den Status eines Alerts.""" - alert = self.get_by_id(alert_id) - if not alert: - return None - - alert.status = AlertStatusEnum(status) - - self.db.commit() - self.db.refresh(alert) - return alert - - def mark_reviewed( - self, - alert_id: str, - is_relevant: bool, - notes: str = None, - tags: List[str] = None, - ) -> Optional[AlertItemDB]: - """Markiert einen Alert als reviewed mit Feedback.""" - alert = self.get_by_id(alert_id) - if not alert: - return None - - alert.status = AlertStatusEnum.REVIEWED - alert.user_marked_relevant = is_relevant - if notes: - alert.user_notes = notes - if tags: - alert.user_tags = tags - - self.db.commit() - self.db.refresh(alert) - return alert - - def archive(self, alert_id: str) -> Optional[AlertItemDB]: - """Archiviert einen Alert.""" - return self.update_status(alert_id, "archived") - - # ==================== DELETE 
==================== - - def delete(self, alert_id: str) -> bool: - """Löscht einen Alert.""" - alert = self.get_by_id(alert_id) - if not alert: - return False - - self.db.delete(alert) - self.db.commit() - return True - - def delete_old(self, days: int = 90, topic_id: str = None) -> int: - """Löscht alte archivierte Alerts.""" - from datetime import timedelta - cutoff = datetime.utcnow() - timedelta(days=days) - - query = self.db.query(AlertItemDB).filter( - AlertItemDB.status == AlertStatusEnum.ARCHIVED, - AlertItemDB.fetched_at < cutoff, - ) - - if topic_id: - query = query.filter(AlertItemDB.topic_id == topic_id) - - count = query.delete() - self.db.commit() - return count - - # ==================== FOR RSS FETCHER ==================== - - def get_existing_urls(self, topic_id: str) -> set: - """ - Holt alle bekannten URL-Hashes für ein Topic. - - Wird vom RSS-Fetcher verwendet um Duplikate zu vermeiden. - """ - results = self.db.query(AlertItemDB.url_hash).filter( - AlertItemDB.topic_id == topic_id - ).all() - - return {r[0] for r in results if r[0]} - - def create_from_alert_item(self, alert_item, topic_id: str) -> AlertItemDB: - """ - Erstellt einen Alert aus einem AlertItem-Objekt vom RSS-Fetcher. 
- - Args: - alert_item: AlertItem from rss_fetcher - topic_id: Topic ID to associate with - - Returns: - Created AlertItemDB instance - """ - return self.create( - topic_id=topic_id, - title=alert_item.title, - url=alert_item.url, - snippet=alert_item.snippet or "", - source=alert_item.source.value if hasattr(alert_item.source, 'value') else str(alert_item.source), - published_at=alert_item.published_at, - ) - - # ==================== HELPER ==================== - - def _compute_url_hash(self, url: str) -> str: - """Berechnet SHA256 Hash der normalisierten URL.""" - normalized = self._normalize_url(url) - return hashlib.sha256(normalized.encode()).hexdigest()[:16] - - def _normalize_url(self, url: str) -> str: - """Normalisiert URL für Deduplizierung.""" - import urllib.parse - parsed = urllib.parse.urlparse(url) - - # Tracking-Parameter entfernen - tracking_params = { - "utm_source", "utm_medium", "utm_campaign", "utm_content", "utm_term", - "fbclid", "gclid", "ref", "source" - } - - query_params = urllib.parse.parse_qs(parsed.query) - cleaned_params = {k: v for k, v in query_params.items() - if k.lower() not in tracking_params} - - cleaned_query = urllib.parse.urlencode(cleaned_params, doseq=True) - - # Rekonstruiere URL ohne Fragment - normalized = urllib.parse.urlunparse(( - parsed.scheme, - parsed.netloc.lower(), - parsed.path.rstrip("/"), - parsed.params, - cleaned_query, - "" # No fragment - )) - - return normalized - - # ==================== CONVERSION ==================== - - def to_dict(self, alert: AlertItemDB) -> Dict[str, Any]: - """Konvertiert DB-Model zu Dictionary.""" - return { - "id": alert.id, - "topic_id": alert.topic_id, - "title": alert.title, - "url": alert.url, - "snippet": alert.snippet, - "source": alert.source.value, - "lang": alert.lang, - "published_at": alert.published_at.isoformat() if alert.published_at else None, - "fetched_at": alert.fetched_at.isoformat() if alert.fetched_at else None, - "status": alert.status.value, - 
"relevance": { - "score": alert.relevance_score, - "decision": alert.relevance_decision.value if alert.relevance_decision else None, - "reasons": alert.relevance_reasons, - "summary": alert.relevance_summary, - "model": alert.scored_by_model, - "scored_at": alert.scored_at.isoformat() if alert.scored_at else None, - }, - "user_feedback": { - "marked_relevant": alert.user_marked_relevant, - "tags": alert.user_tags, - "notes": alert.user_notes, - }, - } - - -# ============================================================================= -# ALERT RULE REPOSITORY -# ============================================================================= - -class RuleRepository: - """Repository für Alert Rules (Filterregeln).""" - - def __init__(self, db: DBSession): - self.db = db - - # ==================== CREATE ==================== - - def create( - self, - name: str, - conditions: List[Dict], - action_type: str = "keep", - action_config: Dict = None, - topic_id: str = None, - user_id: str = None, - description: str = "", - priority: int = 0, - ) -> AlertRuleDB: - """Erstellt eine neue Regel.""" - import uuid - rule = AlertRuleDB( - id=str(uuid.uuid4()), - topic_id=topic_id, - user_id=user_id, - name=name, - description=description, - conditions=conditions, - action_type=RuleActionEnum(action_type), - action_config=action_config or {}, - priority=priority, - ) - self.db.add(rule) - self.db.commit() - self.db.refresh(rule) - return rule - - # ==================== READ ==================== - - def get_by_id(self, rule_id: str) -> Optional[AlertRuleDB]: - """Holt eine Regel nach ID.""" - return self.db.query(AlertRuleDB).filter( - AlertRuleDB.id == rule_id - ).first() - - def get_active( - self, - topic_id: str = None, - user_id: str = None, - ) -> List[AlertRuleDB]: - """Holt alle aktiven Regeln, sortiert nach Priorität.""" - query = self.db.query(AlertRuleDB).filter( - AlertRuleDB.is_active == True - ) - - if topic_id: - # Topic-spezifische und globale Regeln - query = 
query.filter( - or_( - AlertRuleDB.topic_id == topic_id, - AlertRuleDB.topic_id.is_(None) - ) - ) - - if user_id: - query = query.filter( - or_( - AlertRuleDB.user_id == user_id, - AlertRuleDB.user_id.is_(None) - ) - ) - - return query.order_by(AlertRuleDB.priority.desc()).all() - - def get_all( - self, - user_id: str = None, - topic_id: str = None, - is_active: bool = None, - ) -> List[AlertRuleDB]: - """Holt alle Regeln mit optionalen Filtern.""" - query = self.db.query(AlertRuleDB) - - if user_id: - query = query.filter(AlertRuleDB.user_id == user_id) - if topic_id: - query = query.filter(AlertRuleDB.topic_id == topic_id) - if is_active is not None: - query = query.filter(AlertRuleDB.is_active == is_active) - - return query.order_by(AlertRuleDB.priority.desc()).all() - - # ==================== UPDATE ==================== - - def update( - self, - rule_id: str, - name: str = None, - description: str = None, - conditions: List[Dict] = None, - action_type: str = None, - action_config: Dict = None, - priority: int = None, - is_active: bool = None, - ) -> Optional[AlertRuleDB]: - """Aktualisiert eine Regel.""" - rule = self.get_by_id(rule_id) - if not rule: - return None - - if name is not None: - rule.name = name - if description is not None: - rule.description = description - if conditions is not None: - rule.conditions = conditions - if action_type is not None: - rule.action_type = RuleActionEnum(action_type) - if action_config is not None: - rule.action_config = action_config - if priority is not None: - rule.priority = priority - if is_active is not None: - rule.is_active = is_active - - self.db.commit() - self.db.refresh(rule) - return rule - - def increment_match_count(self, rule_id: str) -> Optional[AlertRuleDB]: - """Erhöht den Match-Counter einer Regel.""" - rule = self.get_by_id(rule_id) - if not rule: - return None - - rule.match_count += 1 - rule.last_matched_at = datetime.utcnow() - - self.db.commit() - self.db.refresh(rule) - return rule - - # 
==================== DELETE ==================== - - def delete(self, rule_id: str) -> bool: - """Löscht eine Regel.""" - rule = self.get_by_id(rule_id) - if not rule: - return False - - self.db.delete(rule) - self.db.commit() - return True - - # ==================== CONVERSION ==================== - - def to_dict(self, rule: AlertRuleDB) -> Dict[str, Any]: - """Konvertiert DB-Model zu Dictionary.""" - return { - "id": rule.id, - "topic_id": rule.topic_id, - "user_id": rule.user_id, - "name": rule.name, - "description": rule.description, - "conditions": rule.conditions, - "action_type": rule.action_type.value, - "action_config": rule.action_config, - "priority": rule.priority, - "is_active": rule.is_active, - "stats": { - "match_count": rule.match_count, - "last_matched_at": rule.last_matched_at.isoformat() if rule.last_matched_at else None, - }, - "created_at": rule.created_at.isoformat() if rule.created_at else None, - "updated_at": rule.updated_at.isoformat() if rule.updated_at else None, - } - - -# ============================================================================= -# ALERT PROFILE REPOSITORY -# ============================================================================= - -class ProfileRepository: - """Repository für Alert Profiles (Nutzer-Profile für Relevanz-Scoring).""" - - def __init__(self, db: DBSession): - self.db = db - - # ==================== CREATE / GET-OR-CREATE ==================== - - def get_or_create(self, user_id: str = None) -> AlertProfileDB: - """Holt oder erstellt ein Profil.""" - profile = self.get_by_user_id(user_id) - if profile: - return profile - - # Neues Profil erstellen - import uuid - profile = AlertProfileDB( - id=str(uuid.uuid4()), - user_id=user_id, - name="Default" if not user_id else f"Profile {user_id[:8]}", - ) - self.db.add(profile) - self.db.commit() - self.db.refresh(profile) - return profile - - def create_default_education_profile(self, user_id: str = None) -> AlertProfileDB: - """Erstellt ein 
Standard-Profil für Bildungsthemen.""" - import uuid - profile = AlertProfileDB( - id=str(uuid.uuid4()), - user_id=user_id, - name="Bildung Default", - priorities=[ - { - "label": "Inklusion", - "weight": 0.9, - "keywords": ["inklusiv", "Förderbedarf", "Behinderung", "Barrierefreiheit"], - "description": "Inklusive Bildung, Förderschulen, Nachteilsausgleich" - }, - { - "label": "Datenschutz Schule", - "weight": 0.85, - "keywords": ["DSGVO", "Schülerfotos", "Einwilligung", "personenbezogene Daten"], - "description": "DSGVO in Schulen, Datenschutz bei Klassenfotos" - }, - { - "label": "Schulrecht Bayern", - "weight": 0.8, - "keywords": ["BayEUG", "Schulordnung", "Kultusministerium", "Bayern"], - "description": "Bayerisches Schulrecht, Verordnungen" - }, - { - "label": "Digitalisierung Schule", - "weight": 0.7, - "keywords": ["DigitalPakt", "Tablet-Klasse", "Lernplattform"], - "description": "Digitale Medien im Unterricht" - }, - ], - exclusions=["Stellenanzeige", "Praktikum gesucht", "Werbung", "Pressemitteilung"], - policies={ - "prefer_german_sources": True, - "max_age_days": 30, - "min_content_length": 100, - } - ) - self.db.add(profile) - self.db.commit() - self.db.refresh(profile) - return profile - - # ==================== READ ==================== - - def get_by_id(self, profile_id: str) -> Optional[AlertProfileDB]: - """Holt ein Profil nach ID.""" - return self.db.query(AlertProfileDB).filter( - AlertProfileDB.id == profile_id - ).first() - - def get_by_user_id(self, user_id: str) -> Optional[AlertProfileDB]: - """Holt ein Profil nach User-ID.""" - if not user_id: - # Default-Profil ohne User - return self.db.query(AlertProfileDB).filter( - AlertProfileDB.user_id.is_(None) - ).first() - - return self.db.query(AlertProfileDB).filter( - AlertProfileDB.user_id == user_id - ).first() - - # ==================== UPDATE ==================== - - def update_priorities( - self, - profile_id: str, - priorities: List[Dict], - ) -> Optional[AlertProfileDB]: - 
"""Aktualisiert die Prioritäten eines Profils.""" - profile = self.get_by_id(profile_id) - if not profile: - return None - - profile.priorities = priorities - self.db.commit() - self.db.refresh(profile) - return profile - - def update_exclusions( - self, - profile_id: str, - exclusions: List[str], - ) -> Optional[AlertProfileDB]: - """Aktualisiert die Ausschlüsse eines Profils.""" - profile = self.get_by_id(profile_id) - if not profile: - return None - - profile.exclusions = exclusions - self.db.commit() - self.db.refresh(profile) - return profile - - def add_feedback( - self, - profile_id: str, - title: str, - url: str, - is_relevant: bool, - reason: str = "", - ) -> Optional[AlertProfileDB]: - """Fügt Feedback als Beispiel hinzu.""" - profile = self.get_by_id(profile_id) - if not profile: - return None - - example = { - "title": title, - "url": url, - "reason": reason, - "added_at": datetime.utcnow().isoformat(), - } - - if is_relevant: - examples = list(profile.positive_examples or []) - examples.append(example) - profile.positive_examples = examples[-20:] # Max 20 - profile.total_kept += 1 - flag_modified(profile, "positive_examples") - else: - examples = list(profile.negative_examples or []) - examples.append(example) - profile.negative_examples = examples[-20:] # Max 20 - profile.total_dropped += 1 - flag_modified(profile, "negative_examples") - - profile.total_scored += 1 - self.db.commit() - self.db.refresh(profile) - return profile - - def update_stats( - self, - profile_id: str, - kept: int = 0, - dropped: int = 0, - ) -> Optional[AlertProfileDB]: - """Aktualisiert die Statistiken eines Profils.""" - profile = self.get_by_id(profile_id) - if not profile: - return None - - profile.total_scored += kept + dropped - profile.total_kept += kept - profile.total_dropped += dropped - - self.db.commit() - self.db.refresh(profile) - return profile - - # ==================== DELETE ==================== - - def delete(self, profile_id: str) -> bool: - """Löscht ein 
Profil.""" - profile = self.get_by_id(profile_id) - if not profile: - return False - - self.db.delete(profile) - self.db.commit() - return True - - # ==================== CONVERSION ==================== - - def to_dict(self, profile: AlertProfileDB) -> Dict[str, Any]: - """Konvertiert DB-Model zu Dictionary.""" - return { - "id": profile.id, - "user_id": profile.user_id, - "name": profile.name, - "priorities": profile.priorities, - "exclusions": profile.exclusions, - "policies": profile.policies, - "examples": { - "positive": len(profile.positive_examples or []), - "negative": len(profile.negative_examples or []), - }, - "stats": { - "total_scored": profile.total_scored, - "total_kept": profile.total_kept, - "total_dropped": profile.total_dropped, - "accuracy_estimate": profile.accuracy_estimate, - }, - "created_at": profile.created_at.isoformat() if profile.created_at else None, - "updated_at": profile.updated_at.isoformat() if profile.updated_at else None, - } +from .topic_repository import TopicRepository +from .item_repository import AlertItemRepository +from .rule_repository import RuleRepository +from .profile_repository import ProfileRepository + +__all__ = [ + "TopicRepository", + "AlertItemRepository", + "RuleRepository", + "ProfileRepository", +] diff --git a/backend-lehrer/alerts_agent/db/rule_repository.py b/backend-lehrer/alerts_agent/db/rule_repository.py new file mode 100644 index 0000000..f969ee1 --- /dev/null +++ b/backend-lehrer/alerts_agent/db/rule_repository.py @@ -0,0 +1,187 @@ +""" +Repository für Alert Rules (Filterregeln). 
+""" +import uuid +from datetime import datetime +from typing import Optional, List, Dict, Any +from sqlalchemy.orm import Session as DBSession +from sqlalchemy import or_ + +from .models import AlertRuleDB, RuleActionEnum + + +class RuleRepository: + """Repository für Alert Rules (Filterregeln).""" + + def __init__(self, db: DBSession): + self.db = db + + # ==================== CREATE ==================== + + def create( + self, + name: str, + conditions: List[Dict], + action_type: str = "keep", + action_config: Dict = None, + topic_id: str = None, + user_id: str = None, + description: str = "", + priority: int = 0, + ) -> AlertRuleDB: + """Erstellt eine neue Regel.""" + rule = AlertRuleDB( + id=str(uuid.uuid4()), + topic_id=topic_id, + user_id=user_id, + name=name, + description=description, + conditions=conditions, + action_type=RuleActionEnum(action_type), + action_config=action_config or {}, + priority=priority, + ) + self.db.add(rule) + self.db.commit() + self.db.refresh(rule) + return rule + + # ==================== READ ==================== + + def get_by_id(self, rule_id: str) -> Optional[AlertRuleDB]: + """Holt eine Regel nach ID.""" + return self.db.query(AlertRuleDB).filter( + AlertRuleDB.id == rule_id + ).first() + + def get_active( + self, + topic_id: str = None, + user_id: str = None, + ) -> List[AlertRuleDB]: + """Holt alle aktiven Regeln, sortiert nach Priorität.""" + query = self.db.query(AlertRuleDB).filter( + AlertRuleDB.is_active == True + ) + + if topic_id: + # Topic-spezifische und globale Regeln + query = query.filter( + or_( + AlertRuleDB.topic_id == topic_id, + AlertRuleDB.topic_id.is_(None) + ) + ) + + if user_id: + query = query.filter( + or_( + AlertRuleDB.user_id == user_id, + AlertRuleDB.user_id.is_(None) + ) + ) + + return query.order_by(AlertRuleDB.priority.desc()).all() + + def get_all( + self, + user_id: str = None, + topic_id: str = None, + is_active: bool = None, + ) -> List[AlertRuleDB]: + """Holt alle Regeln mit optionalen 
Filtern.""" + query = self.db.query(AlertRuleDB) + + if user_id: + query = query.filter(AlertRuleDB.user_id == user_id) + if topic_id: + query = query.filter(AlertRuleDB.topic_id == topic_id) + if is_active is not None: + query = query.filter(AlertRuleDB.is_active == is_active) + + return query.order_by(AlertRuleDB.priority.desc()).all() + + # ==================== UPDATE ==================== + + def update( + self, + rule_id: str, + name: str = None, + description: str = None, + conditions: List[Dict] = None, + action_type: str = None, + action_config: Dict = None, + priority: int = None, + is_active: bool = None, + ) -> Optional[AlertRuleDB]: + """Aktualisiert eine Regel.""" + rule = self.get_by_id(rule_id) + if not rule: + return None + + if name is not None: + rule.name = name + if description is not None: + rule.description = description + if conditions is not None: + rule.conditions = conditions + if action_type is not None: + rule.action_type = RuleActionEnum(action_type) + if action_config is not None: + rule.action_config = action_config + if priority is not None: + rule.priority = priority + if is_active is not None: + rule.is_active = is_active + + self.db.commit() + self.db.refresh(rule) + return rule + + def increment_match_count(self, rule_id: str) -> Optional[AlertRuleDB]: + """Erhöht den Match-Counter einer Regel.""" + rule = self.get_by_id(rule_id) + if not rule: + return None + + rule.match_count += 1 + rule.last_matched_at = datetime.utcnow() + + self.db.commit() + self.db.refresh(rule) + return rule + + # ==================== DELETE ==================== + + def delete(self, rule_id: str) -> bool: + """Löscht eine Regel.""" + rule = self.get_by_id(rule_id) + if not rule: + return False + + self.db.delete(rule) + self.db.commit() + return True + + # ==================== CONVERSION ==================== + + def to_dict(self, rule: AlertRuleDB) -> Dict[str, Any]: + """Konvertiert DB-Model zu Dictionary.""" + return { + "id": rule.id, + "topic_id": 
rule.topic_id, + "user_id": rule.user_id, + "name": rule.name, + "description": rule.description, + "conditions": rule.conditions, + "action_type": rule.action_type.value, + "action_config": rule.action_config, + "priority": rule.priority, + "is_active": rule.is_active, + "stats": { + "match_count": rule.match_count, + "last_matched_at": rule.last_matched_at.isoformat() if rule.last_matched_at else None, + }, + "created_at": rule.created_at.isoformat() if rule.created_at else None, + "updated_at": rule.updated_at.isoformat() if rule.updated_at else None, + } diff --git a/backend-lehrer/alerts_agent/db/topic_repository.py b/backend-lehrer/alerts_agent/db/topic_repository.py new file mode 100644 index 0000000..77c5bab --- /dev/null +++ b/backend-lehrer/alerts_agent/db/topic_repository.py @@ -0,0 +1,185 @@ +""" +Repository für Alert Topics (Feed-Quellen). +""" +import uuid +from datetime import datetime +from typing import Optional, List, Dict, Any +from sqlalchemy.orm import Session as DBSession + +from .models import AlertTopicDB, FeedTypeEnum + + +class TopicRepository: + """Repository für Alert Topics (Feed-Quellen).""" + + def __init__(self, db: DBSession): + self.db = db + + # ==================== CREATE ==================== + + def create( + self, + name: str, + feed_url: str = None, + feed_type: str = "rss", + user_id: str = None, + description: str = "", + fetch_interval_minutes: int = 60, + is_active: bool = True, + ) -> AlertTopicDB: + """Erstellt ein neues Topic.""" + topic = AlertTopicDB( + id=str(uuid.uuid4()), + user_id=user_id, + name=name, + description=description, + feed_url=feed_url, + feed_type=FeedTypeEnum(feed_type), + fetch_interval_minutes=fetch_interval_minutes, + is_active=is_active, + ) + self.db.add(topic) + self.db.commit() + self.db.refresh(topic) + return topic + + # ==================== READ ==================== + + def get_by_id(self, topic_id: str) -> Optional[AlertTopicDB]: + """Holt ein Topic nach ID.""" + return 
self.db.query(AlertTopicDB).filter( + AlertTopicDB.id == topic_id + ).first() + + def get_all( + self, + user_id: str = None, + is_active: bool = None, + limit: int = 100, + offset: int = 0, + ) -> List[AlertTopicDB]: + """Holt alle Topics mit optionalen Filtern.""" + query = self.db.query(AlertTopicDB) + + if user_id: + query = query.filter(AlertTopicDB.user_id == user_id) + if is_active is not None: + query = query.filter(AlertTopicDB.is_active == is_active) + + return query.order_by( + AlertTopicDB.created_at.desc() + ).offset(offset).limit(limit).all() + + def get_active_for_fetch(self) -> List[AlertTopicDB]: + """Holt alle aktiven Topics die gefetcht werden sollten.""" + return self.db.query(AlertTopicDB).filter( + AlertTopicDB.is_active == True, + AlertTopicDB.feed_url.isnot(None), + ).all() + + # ==================== UPDATE ==================== + + def update( + self, + topic_id: str, + name: str = None, + description: str = None, + feed_url: str = None, + feed_type: str = None, + is_active: bool = None, + fetch_interval_minutes: int = None, + ) -> Optional[AlertTopicDB]: + """Aktualisiert ein Topic.""" + topic = self.get_by_id(topic_id) + if not topic: + return None + + if name is not None: + topic.name = name + if description is not None: + topic.description = description + if feed_url is not None: + topic.feed_url = feed_url + if feed_type is not None: + topic.feed_type = FeedTypeEnum(feed_type) + if is_active is not None: + topic.is_active = is_active + if fetch_interval_minutes is not None: + topic.fetch_interval_minutes = fetch_interval_minutes + + self.db.commit() + self.db.refresh(topic) + return topic + + def update_fetch_status( + self, + topic_id: str, + last_fetch_error: str = None, + items_fetched: int = 0, + ) -> Optional[AlertTopicDB]: + """Aktualisiert den Fetch-Status eines Topics.""" + topic = self.get_by_id(topic_id) + if not topic: + return None + + topic.last_fetched_at = datetime.utcnow() + topic.last_fetch_error = last_fetch_error + 
topic.total_items_fetched += items_fetched + + self.db.commit() + self.db.refresh(topic) + return topic + + def increment_stats( + self, + topic_id: str, + kept: int = 0, + dropped: int = 0, + ) -> Optional[AlertTopicDB]: + """Erhöht die Statistiken eines Topics.""" + topic = self.get_by_id(topic_id) + if not topic: + return None + + topic.items_kept += kept + topic.items_dropped += dropped + + self.db.commit() + self.db.refresh(topic) + return topic + + # ==================== DELETE ==================== + + def delete(self, topic_id: str) -> bool: + """Löscht ein Topic (und alle zugehörigen Items via CASCADE).""" + topic = self.get_by_id(topic_id) + if not topic: + return False + + self.db.delete(topic) + self.db.commit() + return True + + # ==================== CONVERSION ==================== + + def to_dict(self, topic: AlertTopicDB) -> Dict[str, Any]: + """Konvertiert DB-Model zu Dictionary.""" + return { + "id": topic.id, + "user_id": topic.user_id, + "name": topic.name, + "description": topic.description, + "feed_url": topic.feed_url, + "feed_type": topic.feed_type.value, + "is_active": topic.is_active, + "fetch_interval_minutes": topic.fetch_interval_minutes, + "last_fetched_at": topic.last_fetched_at.isoformat() if topic.last_fetched_at else None, + "last_fetch_error": topic.last_fetch_error, + "stats": { + "total_items_fetched": topic.total_items_fetched, + "items_kept": topic.items_kept, + "items_dropped": topic.items_dropped, + }, + "created_at": topic.created_at.isoformat() if topic.created_at else None, + "updated_at": topic.updated_at.isoformat() if topic.updated_at else None, + } diff --git a/backend-lehrer/services/pdf_models.py b/backend-lehrer/services/pdf_models.py new file mode 100644 index 0000000..6964d03 --- /dev/null +++ b/backend-lehrer/services/pdf_models.py @@ -0,0 +1,84 @@ +""" +PDF Service - Data Models and Shared Types. + +Dataclasses for letters, certificates, and corrections. 
+""" + +from dataclasses import dataclass +from typing import Any, Dict, Optional, List + + +@dataclass +class SchoolInfo: + """Schulinformationen für Header.""" + name: str + address: str + phone: str + email: str + logo_path: Optional[str] = None + website: Optional[str] = None + principal: Optional[str] = None + + +@dataclass +class LetterData: + """Daten für Elternbrief-PDF.""" + recipient_name: str + recipient_address: str + student_name: str + student_class: str + subject: str + content: str + date: str + teacher_name: str + teacher_title: Optional[str] = None + school_info: Optional[SchoolInfo] = None + letter_type: str = "general" # general, halbjahr, fehlzeiten, elternabend, lob + tone: str = "professional" + legal_references: Optional[List[Dict[str, str]]] = None + gfk_principles_applied: Optional[List[str]] = None + + +@dataclass +class CertificateData: + """Daten für Zeugnis-PDF.""" + student_name: str + student_birthdate: str + student_class: str + school_year: str + certificate_type: str # halbjahr, jahres, abschluss + subjects: List[Dict[str, Any]] # [{name, grade, note}] + attendance: Dict[str, int] # {days_absent, days_excused, days_unexcused} + remarks: Optional[str] = None + class_teacher: str = "" + principal: str = "" + school_info: Optional[SchoolInfo] = None + issue_date: str = "" + social_behavior: Optional[str] = None # A, B, C, D + work_behavior: Optional[str] = None # A, B, C, D + + +@dataclass +class StudentInfo: + """Schülerinformationen für Korrektur-PDFs.""" + student_id: str + name: str + class_name: str + + +@dataclass +class CorrectionData: + """Daten für Korrektur-Übersicht PDF.""" + student: StudentInfo + exam_title: str + subject: str + date: str + max_points: int + achieved_points: int + grade: str + percentage: float + corrections: List[Dict[str, Any]] # [{question, answer, points, feedback}] + teacher_notes: str = "" + ai_feedback: str = "" + grade_distribution: Optional[Dict[str, int]] = None # {note: anzahl} + 
class_average: Optional[float] = None diff --git a/backend-lehrer/services/pdf_service.py b/backend-lehrer/services/pdf_service.py index 9559964..356b5b4 100644 --- a/backend-lehrer/services/pdf_service.py +++ b/backend-lehrer/services/pdf_service.py @@ -7,101 +7,37 @@ Shared Service für: - Correction (Korrektur-Übersichten) Verwendet WeasyPrint für PDF-Rendering und Jinja2 für Templates. + +Split structure: +- pdf_models.py: Data classes (SchoolInfo, LetterData, CertificateData, etc.) +- pdf_templates.py: Inline HTML templates (letter, certificate, correction) +- pdf_service.py: Core PDFService class + convenience functions (this file) """ import logging -import os from datetime import datetime from pathlib import Path -from typing import Any, Dict, Optional, List -from dataclasses import dataclass +from typing import Any, Dict, Optional from jinja2 import Environment, FileSystemLoader, select_autoescape from weasyprint import HTML, CSS from weasyprint.text.fonts import FontConfiguration +from .pdf_models import ( + SchoolInfo, LetterData, CertificateData, StudentInfo, CorrectionData, +) +from .pdf_templates import ( + get_letter_template_html, + get_certificate_template_html, + get_correction_template_html, +) + logger = logging.getLogger(__name__) # Template directory TEMPLATES_DIR = Path(__file__).parent.parent / "templates" / "pdf" -@dataclass -class SchoolInfo: - """Schulinformationen für Header.""" - name: str - address: str - phone: str - email: str - logo_path: Optional[str] = None - website: Optional[str] = None - principal: Optional[str] = None - - -@dataclass -class LetterData: - """Daten für Elternbrief-PDF.""" - recipient_name: str - recipient_address: str - student_name: str - student_class: str - subject: str - content: str - date: str - teacher_name: str - teacher_title: Optional[str] = None - school_info: Optional[SchoolInfo] = None - letter_type: str = "general" # general, halbjahr, fehlzeiten, elternabend, lob - tone: str = "professional" - 
legal_references: Optional[List[Dict[str, str]]] = None - gfk_principles_applied: Optional[List[str]] = None - - -@dataclass -class CertificateData: - """Daten für Zeugnis-PDF.""" - student_name: str - student_birthdate: str - student_class: str - school_year: str - certificate_type: str # halbjahr, jahres, abschluss - subjects: List[Dict[str, Any]] # [{name, grade, note}] - attendance: Dict[str, int] # {days_absent, days_excused, days_unexcused} - remarks: Optional[str] = None - class_teacher: str = "" - principal: str = "" - school_info: Optional[SchoolInfo] = None - issue_date: str = "" - social_behavior: Optional[str] = None # A, B, C, D - work_behavior: Optional[str] = None # A, B, C, D - - -@dataclass -class StudentInfo: - """Schülerinformationen für Korrektur-PDFs.""" - student_id: str - name: str - class_name: str - - -@dataclass -class CorrectionData: - """Daten für Korrektur-Übersicht PDF.""" - student: StudentInfo - exam_title: str - subject: str - date: str - max_points: int - achieved_points: int - grade: str - percentage: float - corrections: List[Dict[str, Any]] # [{question, answer, points, feedback}] - teacher_notes: str = "" - ai_feedback: str = "" - grade_distribution: Optional[Dict[str, int]] = None # {note: anzahl} - class_average: Optional[float] = None - - class PDFService: """ Zentrale PDF-Generierung für BreakPilot. @@ -113,18 +49,9 @@ class PDFService: """ def __init__(self, templates_dir: Optional[Path] = None): - """ - Initialisiert den PDF-Service. 
- - Args: - templates_dir: Optionaler Pfad zu Templates (Standard: backend/templates/pdf) - """ self.templates_dir = templates_dir or TEMPLATES_DIR - - # Ensure templates directory exists self.templates_dir.mkdir(parents=True, exist_ok=True) - # Initialize Jinja2 environment self.jinja_env = Environment( loader=FileSystemLoader(str(self.templates_dir)), autoescape=select_autoescape(['html', 'xml']), @@ -132,13 +59,10 @@ class PDFService: lstrip_blocks=True ) - # Add custom filters self.jinja_env.filters['date_format'] = self._date_format self.jinja_env.filters['grade_color'] = self._grade_color - # Font configuration for WeasyPrint self.font_config = FontConfiguration() - logger.info(f"PDFService initialized with templates from {self.templates_dir}") @staticmethod @@ -156,16 +80,9 @@ class PDFService: def _grade_color(grade: str) -> str: """Gibt Farbe basierend auf Note zurück.""" grade_colors = { - "1": "#27ae60", # Grün - "2": "#2ecc71", # Hellgrün - "3": "#f1c40f", # Gelb - "4": "#e67e22", # Orange - "5": "#e74c3c", # Rot - "6": "#c0392b", # Dunkelrot - "A": "#27ae60", - "B": "#2ecc71", - "C": "#f1c40f", - "D": "#e74c3c", + "1": "#27ae60", "2": "#2ecc71", "3": "#f1c40f", + "4": "#e67e22", "5": "#e74c3c", "6": "#c0392b", + "A": "#27ae60", "B": "#2ecc71", "C": "#f1c40f", "D": "#e74c3c", } return grade_colors.get(str(grade), "#333333") @@ -181,291 +98,73 @@ class PDFService: color: #666; } } - body { font-family: 'DejaVu Sans', 'Liberation Sans', Arial, sans-serif; - font-size: 11pt; - line-height: 1.5; - color: #333; - } - - h1, h2, h3 { - font-weight: bold; - margin-top: 1em; - margin-bottom: 0.5em; - } - - h1 { font-size: 16pt; } - h2 { font-size: 14pt; } - h3 { font-size: 12pt; } - - .header { - border-bottom: 2px solid #2c3e50; - padding-bottom: 15px; - margin-bottom: 20px; - } - - .school-name { - font-size: 18pt; - font-weight: bold; - color: #2c3e50; - } - - .school-info { - font-size: 9pt; - color: #666; - } - - .letter-date { - text-align: right; - 
margin-bottom: 20px; - } - - .recipient { - margin-bottom: 30px; - } - - .subject { - font-weight: bold; - margin-bottom: 20px; - } - - .content { - text-align: justify; - margin-bottom: 30px; - } - - .signature { - margin-top: 40px; - } - - .legal-references { - font-size: 9pt; - color: #666; - border-top: 1px solid #ddd; - margin-top: 30px; - padding-top: 10px; - } - - .gfk-badge { - display: inline-block; - background: #e8f5e9; - color: #27ae60; - font-size: 8pt; - padding: 2px 8px; - border-radius: 10px; - margin-right: 5px; - } - - /* Zeugnis-Styles */ - .certificate-header { - text-align: center; - margin-bottom: 30px; - } - - .certificate-title { - font-size: 20pt; - font-weight: bold; - margin-bottom: 10px; - } - - .student-info { - margin-bottom: 20px; - padding: 15px; - background: #f9f9f9; - border-radius: 5px; - } - - .grades-table { - width: 100%; - border-collapse: collapse; - margin-bottom: 20px; - } - - .grades-table th, - .grades-table td { - border: 1px solid #ddd; - padding: 8px 12px; - text-align: left; - } - - .grades-table th { - background: #2c3e50; - color: white; - } - - .grades-table tr:nth-child(even) { - background: #f9f9f9; - } - - .grade-cell { - text-align: center; - font-weight: bold; - font-size: 12pt; - } - - .attendance-box { - background: #fff3cd; - padding: 15px; - border-radius: 5px; - margin-bottom: 20px; - } - - .signatures-row { - display: flex; - justify-content: space-between; - margin-top: 50px; - } - - .signature-block { - text-align: center; - width: 40%; - } - - .signature-line { - border-top: 1px solid #333; - margin-top: 40px; - padding-top: 5px; - } - - /* Korrektur-Styles */ - .exam-header { - background: #2c3e50; - color: white; - padding: 15px; - margin-bottom: 20px; - } - - .result-box { - background: #e8f5e9; - padding: 20px; - text-align: center; - margin-bottom: 20px; - border-radius: 5px; - } - - .result-grade { - font-size: 36pt; - font-weight: bold; - } - - .result-points { - font-size: 14pt; - color: 
#666; - } - - .corrections-list { - margin-bottom: 20px; - } - - .correction-item { - border: 1px solid #ddd; - padding: 15px; - margin-bottom: 10px; - border-radius: 5px; - } - - .correction-question { - font-weight: bold; - margin-bottom: 5px; - } - - .correction-feedback { - background: #fff8e1; - padding: 10px; - margin-top: 10px; - border-left: 3px solid #ffc107; - font-size: 10pt; - } - - .stats-table { - width: 100%; - margin-top: 20px; - } - - .stats-table td { - padding: 5px 10px; + font-size: 11pt; line-height: 1.5; color: #333; } + h1, h2, h3 { font-weight: bold; margin-top: 1em; margin-bottom: 0.5em; } + h1 { font-size: 16pt; } h2 { font-size: 14pt; } h3 { font-size: 12pt; } + .header { border-bottom: 2px solid #2c3e50; padding-bottom: 15px; margin-bottom: 20px; } + .school-name { font-size: 18pt; font-weight: bold; color: #2c3e50; } + .school-info { font-size: 9pt; color: #666; } + .letter-date { text-align: right; margin-bottom: 20px; } + .recipient { margin-bottom: 30px; } + .subject { font-weight: bold; margin-bottom: 20px; } + .content { text-align: justify; margin-bottom: 30px; } + .signature { margin-top: 40px; } + .legal-references { font-size: 9pt; color: #666; border-top: 1px solid #ddd; margin-top: 30px; padding-top: 10px; } + .gfk-badge { display: inline-block; background: #e8f5e9; color: #27ae60; font-size: 8pt; padding: 2px 8px; border-radius: 10px; margin-right: 5px; } + .certificate-header { text-align: center; margin-bottom: 30px; } + .certificate-title { font-size: 20pt; font-weight: bold; margin-bottom: 10px; } + .student-info { margin-bottom: 20px; padding: 15px; background: #f9f9f9; border-radius: 5px; } + .grades-table { width: 100%; border-collapse: collapse; margin-bottom: 20px; } + .grades-table th, .grades-table td { border: 1px solid #ddd; padding: 8px 12px; text-align: left; } + .grades-table th { background: #2c3e50; color: white; } + .grades-table tr:nth-child(even) { background: #f9f9f9; } + .grade-cell { text-align: 
center; font-weight: bold; font-size: 12pt; } + .attendance-box { background: #fff3cd; padding: 15px; border-radius: 5px; margin-bottom: 20px; } + .signatures-row { display: flex; justify-content: space-between; margin-top: 50px; } + .signature-block { text-align: center; width: 40%; } + .signature-line { border-top: 1px solid #333; margin-top: 40px; padding-top: 5px; } + .exam-header { background: #2c3e50; color: white; padding: 15px; margin-bottom: 20px; } + .result-box { background: #e8f5e9; padding: 20px; text-align: center; margin-bottom: 20px; border-radius: 5px; } + .result-grade { font-size: 36pt; font-weight: bold; } + .result-points { font-size: 14pt; color: #666; } + .corrections-list { margin-bottom: 20px; } + .correction-item { border: 1px solid #ddd; padding: 15px; margin-bottom: 10px; border-radius: 5px; } + .correction-question { font-weight: bold; margin-bottom: 5px; } + .correction-feedback { background: #fff8e1; padding: 10px; margin-top: 10px; border-left: 3px solid #ffc107; font-size: 10pt; } + .stats-table { width: 100%; margin-top: 20px; } + .stats-table td { padding: 5px 10px; } """ def generate_letter_pdf(self, data: LetterData) -> bytes: - """ - Generiert PDF für Elternbrief. 
- - Args: - data: LetterData mit allen Briefinformationen - - Returns: - PDF als bytes - """ + """Generiert PDF für Elternbrief.""" logger.info(f"Generating letter PDF for student: {data.student_name}") - template = self._get_letter_template() - html_content = template.render( - data=data, - generated_at=datetime.now().strftime("%d.%m.%Y %H:%M") - ) - + html_content = template.render(data=data, generated_at=datetime.now().strftime("%d.%m.%Y %H:%M")) css = CSS(string=self._get_base_css(), font_config=self.font_config) - pdf_bytes = HTML(string=html_content).write_pdf( - stylesheets=[css], - font_config=self.font_config - ) - + pdf_bytes = HTML(string=html_content).write_pdf(stylesheets=[css], font_config=self.font_config) logger.info(f"Letter PDF generated: {len(pdf_bytes)} bytes") return pdf_bytes def generate_certificate_pdf(self, data: CertificateData) -> bytes: - """ - Generiert PDF für Schulzeugnis. - - Args: - data: CertificateData mit allen Zeugnisinformationen - - Returns: - PDF als bytes - """ + """Generiert PDF für Schulzeugnis.""" logger.info(f"Generating certificate PDF for: {data.student_name}") - template = self._get_certificate_template() - html_content = template.render( - data=data, - generated_at=datetime.now().strftime("%d.%m.%Y %H:%M") - ) - + html_content = template.render(data=data, generated_at=datetime.now().strftime("%d.%m.%Y %H:%M")) css = CSS(string=self._get_base_css(), font_config=self.font_config) - pdf_bytes = HTML(string=html_content).write_pdf( - stylesheets=[css], - font_config=self.font_config - ) - + pdf_bytes = HTML(string=html_content).write_pdf(stylesheets=[css], font_config=self.font_config) logger.info(f"Certificate PDF generated: {len(pdf_bytes)} bytes") return pdf_bytes def generate_correction_pdf(self, data: CorrectionData) -> bytes: - """ - Generiert PDF für Korrektur-Übersicht. 
- - Args: - data: CorrectionData mit allen Korrekturinformationen - - Returns: - PDF als bytes - """ + """Generiert PDF für Korrektur-Übersicht.""" logger.info(f"Generating correction PDF for: {data.student.name}") - template = self._get_correction_template() - html_content = template.render( - data=data, - generated_at=datetime.now().strftime("%d.%m.%Y %H:%M") - ) - + html_content = template.render(data=data, generated_at=datetime.now().strftime("%d.%m.%Y %H:%M")) css = CSS(string=self._get_base_css(), font_config=self.font_config) - pdf_bytes = HTML(string=html_content).write_pdf( - stylesheets=[css], - font_config=self.font_config - ) - + pdf_bytes = HTML(string=html_content).write_pdf(stylesheets=[css], font_config=self.font_config) logger.info(f"Correction PDF generated: {len(pdf_bytes)} bytes") return pdf_bytes @@ -474,321 +173,27 @@ class PDFService: template_path = self.templates_dir / "letter.html" if template_path.exists(): return self.jinja_env.get_template("letter.html") - - # Inline-Template als Fallback - return self.jinja_env.from_string(self._get_letter_template_html()) + return self.jinja_env.from_string(get_letter_template_html()) def _get_certificate_template(self): """Gibt Certificate-Template zurück.""" template_path = self.templates_dir / "certificate.html" if template_path.exists(): return self.jinja_env.get_template("certificate.html") - - return self.jinja_env.from_string(self._get_certificate_template_html()) + return self.jinja_env.from_string(get_certificate_template_html()) def _get_correction_template(self): """Gibt Correction-Template zurück.""" template_path = self.templates_dir / "correction.html" if template_path.exists(): return self.jinja_env.get_template("correction.html") - - return self.jinja_env.from_string(self._get_correction_template_html()) - - @staticmethod - def _get_letter_template_html() -> str: - """Inline HTML-Template für Elternbriefe.""" - return """ - - - - - {{ data.subject }} - - -
- {% if data.school_info %} -
{{ data.school_info.name }}
-
- {{ data.school_info.address }}
- Tel: {{ data.school_info.phone }} | E-Mail: {{ data.school_info.email }} - {% if data.school_info.website %} | {{ data.school_info.website }}{% endif %} -
- {% else %} -
Schule
- {% endif %} -
- -
- {{ data.date }} -
- -
- {{ data.recipient_name }}
- {{ data.recipient_address | replace('\\n', '
') | safe }} -
- -
- Betreff: {{ data.subject }} -
- -
- Schüler/in: {{ data.student_name }} | Klasse: {{ data.student_class }} -
- -
- {{ data.content | replace('\\n', '
') | safe }} -
- - {% if data.gfk_principles_applied %} -
- {% for principle in data.gfk_principles_applied %} - ✓ {{ principle }} - {% endfor %} -
- {% endif %} - -
-

Mit freundlichen Grüßen

-

- {{ data.teacher_name }} - {% if data.teacher_title %}
{{ data.teacher_title }}{% endif %} -

-
- - {% if data.legal_references %} - - {% endif %} - -
- Erstellt mit BreakPilot | {{ generated_at }} -
- - -""" - - @staticmethod - def _get_certificate_template_html() -> str: - """Inline HTML-Template für Zeugnisse.""" - return """ - - - - - Zeugnis - {{ data.student_name }} - - -
- {% if data.school_info %} -
{{ data.school_info.name }}
- {% endif %} -
- {% if data.certificate_type == 'halbjahr' %} - Halbjahreszeugnis - {% elif data.certificate_type == 'jahres' %} - Jahreszeugnis - {% else %} - Abschlusszeugnis - {% endif %} -
-
Schuljahr {{ data.school_year }}
-
- -
- - - - - - - - - -
Name: {{ data.student_name }}Geburtsdatum: {{ data.student_birthdate }}
Klasse: {{ data.student_class }} 
-
- -

Leistungen

- - - - - - - - - - {% for subject in data.subjects %} - - - - - - {% endfor %} - -
FachNotePunkte
{{ subject.name }} - {{ subject.grade }} - {{ subject.points | default('-') }}
- - {% if data.social_behavior or data.work_behavior %} -

Verhalten

- - {% if data.social_behavior %} - - - - - {% endif %} - {% if data.work_behavior %} - - - - - {% endif %} -
Sozialverhalten{{ data.social_behavior }}
Arbeitsverhalten{{ data.work_behavior }}
- {% endif %} - -
- Versäumte Tage: {{ data.attendance.days_absent | default(0) }} - (davon entschuldigt: {{ data.attendance.days_excused | default(0) }}, - unentschuldigt: {{ data.attendance.days_unexcused | default(0) }}) -
- - {% if data.remarks %} -
- Bemerkungen:
- {{ data.remarks }} -
- {% endif %} - -
- Ausgestellt am: {{ data.issue_date }} -
- -
-
-
{{ data.class_teacher }}
-
Klassenlehrer/in
-
-
-
{{ data.principal }}
-
Schulleiter/in
-
-
- -
-
Siegel der Schule
-
- - -""" - - @staticmethod - def _get_correction_template_html() -> str: - """Inline HTML-Template für Korrektur-Übersichten.""" - return """ - - - - - Korrektur - {{ data.exam_title }} - - -
-

{{ data.exam_title }}

-
{{ data.subject }} | {{ data.date }}
-
- -
- {{ data.student.name }} | Klasse {{ data.student.class_name }} -
- -
-
- Note: {{ data.grade }} -
-
- {{ data.achieved_points }} von {{ data.max_points }} Punkten - ({{ data.percentage | round(1) }}%) -
-
- -

Detaillierte Auswertung

-
- {% for item in data.corrections %} -
-
- {{ item.question }} -
- {% if item.answer %} -
- Antwort: {{ item.answer }} -
- {% endif %} -
- Punkte: {{ item.points }} -
- {% if item.feedback %} -
- {{ item.feedback }} -
- {% endif %} -
- {% endfor %} -
- - {% if data.teacher_notes %} -
- Lehrerkommentar:
- {{ data.teacher_notes }} -
- {% endif %} - - {% if data.ai_feedback %} -
- KI-Feedback:
- {{ data.ai_feedback }} -
- {% endif %} - - {% if data.class_average or data.grade_distribution %} -

Klassenstatistik

- - {% if data.class_average %} - - - - - {% endif %} - {% if data.grade_distribution %} - - - - - {% endif %} -
Klassendurchschnitt:{{ data.class_average }}
Notenverteilung: - {% for grade, count in data.grade_distribution.items() %} - Note {{ grade }}: {{ count }}x{% if not loop.last %}, {% endif %} - {% endfor %} -
- {% endif %} - -
-

Datum: {{ data.date }}

-
- -
- Erstellt mit BreakPilot | {{ generated_at }} -
- - -""" + return self.jinja_env.from_string(get_correction_template_html()) +# ============================================================================= # Convenience functions for direct usage +# ============================================================================= + _pdf_service: Optional[PDFService] = None @@ -801,18 +206,8 @@ def get_pdf_service() -> PDFService: def generate_letter_pdf(data: Dict[str, Any]) -> bytes: - """ - Convenience function zum Generieren eines Elternbrief-PDFs. - - Args: - data: Dict mit allen Briefdaten - - Returns: - PDF als bytes - """ + """Convenience function zum Generieren eines Elternbrief-PDFs.""" service = get_pdf_service() - - # Convert dict to LetterData school_info = None if data.get("school_info"): school_info = SchoolInfo(**data["school_info"]) @@ -833,22 +228,12 @@ def generate_letter_pdf(data: Dict[str, Any]) -> bytes: legal_references=data.get("legal_references"), gfk_principles_applied=data.get("gfk_principles_applied") ) - return service.generate_letter_pdf(letter_data) def generate_certificate_pdf(data: Dict[str, Any]) -> bytes: - """ - Convenience function zum Generieren eines Zeugnis-PDFs. - - Args: - data: Dict mit allen Zeugnisdaten - - Returns: - PDF als bytes - """ + """Convenience function zum Generieren eines Zeugnis-PDFs.""" service = get_pdf_service() - school_info = None if data.get("school_info"): school_info = SchoolInfo(**data["school_info"]) @@ -869,30 +254,19 @@ def generate_certificate_pdf(data: Dict[str, Any]) -> bytes: social_behavior=data.get("social_behavior"), work_behavior=data.get("work_behavior") ) - return service.generate_certificate_pdf(cert_data) def generate_correction_pdf(data: Dict[str, Any]) -> bytes: - """ - Convenience function zum Generieren eines Korrektur-PDFs. 
- - Args: - data: Dict mit allen Korrekturdaten - - Returns: - PDF als bytes - """ + """Convenience function zum Generieren eines Korrektur-PDFs.""" service = get_pdf_service() - # Create StudentInfo from dict student = StudentInfo( student_id=data.get("student_id", "unknown"), name=data.get("student_name", data.get("name", "")), class_name=data.get("student_class", data.get("class_name", "")) ) - # Calculate percentage if not provided max_points = data.get("max_points", data.get("total_points", 0)) achieved_points = data.get("achieved_points", 0) percentage = data.get("percentage", (achieved_points / max_points * 100) if max_points > 0 else 0.0) @@ -912,5 +286,4 @@ def generate_correction_pdf(data: Dict[str, Any]) -> bytes: grade_distribution=data.get("grade_distribution"), class_average=data.get("class_average") ) - return service.generate_correction_pdf(correction_data) diff --git a/backend-lehrer/services/pdf_templates.py b/backend-lehrer/services/pdf_templates.py new file mode 100644 index 0000000..9c8f924 --- /dev/null +++ b/backend-lehrer/services/pdf_templates.py @@ -0,0 +1,298 @@ +""" +PDF Service - Inline HTML Templates. + +Fallback templates when external template files don't exist. +""" + + +def get_letter_template_html() -> str: + """Inline HTML-Template für Elternbriefe.""" + return """ + + + + + {{ data.subject }} + + +
+ {% if data.school_info %} +
{{ data.school_info.name }}
+
+ {{ data.school_info.address }}
+ Tel: {{ data.school_info.phone }} | E-Mail: {{ data.school_info.email }} + {% if data.school_info.website %} | {{ data.school_info.website }}{% endif %} +
+ {% else %} +
Schule
+ {% endif %} +
+ +
+ {{ data.date }} +
+ +
+ {{ data.recipient_name }}
+ {{ data.recipient_address | replace('\\n', '
') | safe }} +
+ +
+ Betreff: {{ data.subject }} +
+ +
+ Schüler/in: {{ data.student_name }} | Klasse: {{ data.student_class }} +
+ +
+ {{ data.content | replace('\\n', '
') | safe }} +
+ + {% if data.gfk_principles_applied %} +
+ {% for principle in data.gfk_principles_applied %} + ✓ {{ principle }} + {% endfor %} +
+ {% endif %} + +
+

Mit freundlichen Grüßen

+

+ {{ data.teacher_name }} + {% if data.teacher_title %}
{{ data.teacher_title }}{% endif %} +

+
+ + {% if data.legal_references %} + + {% endif %} + +
+ Erstellt mit BreakPilot | {{ generated_at }} +
+ + +""" + + +def get_certificate_template_html() -> str: + """Inline HTML-Template für Zeugnisse.""" + return """ + + + + + Zeugnis - {{ data.student_name }} + + +
+ {% if data.school_info %} +
{{ data.school_info.name }}
+ {% endif %} +
+ {% if data.certificate_type == 'halbjahr' %} + Halbjahreszeugnis + {% elif data.certificate_type == 'jahres' %} + Jahreszeugnis + {% else %} + Abschlusszeugnis + {% endif %} +
+
Schuljahr {{ data.school_year }}
+
+ +
+ + + + + + + + + +
Name: {{ data.student_name }}Geburtsdatum: {{ data.student_birthdate }}
Klasse: {{ data.student_class }} 
+
+ +

Leistungen

+ + + + + + + + + + {% for subject in data.subjects %} + + + + + + {% endfor %} + +
FachNotePunkte
{{ subject.name }} + {{ subject.grade }} + {{ subject.points | default('-') }}
+ + {% if data.social_behavior or data.work_behavior %} +

Verhalten

+ + {% if data.social_behavior %} + + + + + {% endif %} + {% if data.work_behavior %} + + + + + {% endif %} +
Sozialverhalten{{ data.social_behavior }}
Arbeitsverhalten{{ data.work_behavior }}
+ {% endif %} + +
+ Versäumte Tage: {{ data.attendance.days_absent | default(0) }} + (davon entschuldigt: {{ data.attendance.days_excused | default(0) }}, + unentschuldigt: {{ data.attendance.days_unexcused | default(0) }}) +
+ + {% if data.remarks %} +
+ Bemerkungen:
+ {{ data.remarks }} +
+ {% endif %} + +
+ Ausgestellt am: {{ data.issue_date }} +
+ +
+
+
{{ data.class_teacher }}
+
Klassenlehrer/in
+
+
+
{{ data.principal }}
+
Schulleiter/in
+
+
+ +
+
Siegel der Schule
+
+ + +""" + + +def get_correction_template_html() -> str: + """Inline HTML-Template für Korrektur-Übersichten.""" + return """ + + + + + Korrektur - {{ data.exam_title }} + + +
+

{{ data.exam_title }}

+
{{ data.subject }} | {{ data.date }}
+
+ +
+ {{ data.student.name }} | Klasse {{ data.student.class_name }} +
+ +
+
+ Note: {{ data.grade }} +
+
+ {{ data.achieved_points }} von {{ data.max_points }} Punkten + ({{ data.percentage | round(1) }}%) +
+
+ +

Detaillierte Auswertung

+
+ {% for item in data.corrections %} +
+
+ {{ item.question }} +
+ {% if item.answer %} +
+ Antwort: {{ item.answer }} +
+ {% endif %} +
+ Punkte: {{ item.points }} +
+ {% if item.feedback %} +
+ {{ item.feedback }} +
+ {% endif %} +
+ {% endfor %} +
+ + {% if data.teacher_notes %} +
+ Lehrerkommentar:
+ {{ data.teacher_notes }} +
+ {% endif %} + + {% if data.ai_feedback %} +
+ KI-Feedback:
+ {{ data.ai_feedback }} +
+ {% endif %} + + {% if data.class_average or data.grade_distribution %} +

Klassenstatistik

+ + {% if data.class_average %} + + + + + {% endif %} + {% if data.grade_distribution %} + + + + + {% endif %} +
Klassendurchschnitt:{{ data.class_average }}
Notenverteilung: + {% for grade, count in data.grade_distribution.items() %} + Note {{ grade }}: {{ count }}x{% if not loop.last %}, {% endif %} + {% endfor %} +
+ {% endif %} + +
+

Datum: {{ data.date }}

+
+ +
+ Erstellt mit BreakPilot | {{ generated_at }} +
# ==============================================
# Teacher Dashboard - Analytics & Progress Routes
# ==============================================
# Progress, class-analytics and content-resource endpoints, split out of
# teacher_dashboard_api. The assignment lookup + ownership check that was
# previously duplicated in three endpoints now lives in _resolve_assignment.

from fastapi import APIRouter, HTTPException, Query, Depends, Request
from typing import List, Optional, Dict, Any
from datetime import datetime, timedelta
import logging

from teacher_dashboard_models import (
    UnitAssignmentStatus, TeacherControlSettings,
    UnitAssignment, StudentUnitProgress, ClassUnitProgress,
    MisconceptionReport, ClassAnalyticsSummary, ContentResource,
    get_current_teacher, get_teacher_database,
    get_classes_for_teacher, get_students_in_class,
    REQUIRE_AUTH,
)

logger = logging.getLogger(__name__)

router = APIRouter(tags=["Teacher Dashboard"])

# Shared in-memory store reference (set from teacher_dashboard_api)
_assignments_store: Dict[str, Dict[str, Any]] = {}


def set_assignments_store(store: Dict[str, Dict[str, Any]]):
    """Share the in-memory assignments store from the main module."""
    global _assignments_store
    _assignments_store = store


async def _resolve_assignment(assignment_id: str, teacher: Dict[str, Any]):
    """Load an assignment (DB first, in-memory fallback) and enforce ownership.

    Returns:
        Tuple ``(db, assignment)`` so callers that need further DB access
        don't have to fetch the database handle a second time.

    Raises:
        HTTPException: 404 when the assignment is missing or belongs to
            another teacher (ownership is not disclosed separately).
    """
    db = await get_teacher_database()
    assignment = None
    if db:
        try:
            assignment = await db.get_assignment(assignment_id)
        except Exception as e:
            # Best-effort: DB errors fall through to the in-memory store.
            logger.error(f"Failed to get assignment: {e}")
    if not assignment and assignment_id in _assignments_store:
        assignment = _assignments_store[assignment_id]
    if not assignment or assignment["teacher_id"] != teacher["user_id"]:
        raise HTTPException(status_code=404, detail="Assignment not found")
    return db, assignment


# ==============================================
# API Endpoints - Progress & Analytics
# ==============================================

@router.get("/assignments/{assignment_id}/progress", response_model=ClassUnitProgress)
async def get_assignment_progress(
    assignment_id: str,
    teacher: Dict[str, Any] = Depends(get_current_teacher)
) -> ClassUnitProgress:
    """Get detailed progress for an assignment.

    Builds one StudentUnitProgress per student in the assigned class and
    aggregates completion, pre/post-check scores and time spent.
    """
    db, assignment = await _resolve_assignment(assignment_id, teacher)

    students = await get_students_in_class(assignment["class_id"])
    student_progress: List[StudentUnitProgress] = []
    total_completion = 0.0
    total_precheck = 0.0
    total_postcheck = 0.0
    total_time = 0
    precheck_count = 0
    postcheck_count = 0
    started = 0
    completed = 0

    for student in students:
        # NOTE(review): assumes student dicts carry "id" or "student_id" —
        # confirm against the school-service schema.
        student_id = student.get("id", student.get("student_id"))
        progress = StudentUnitProgress(
            student_id=student_id,
            student_name=student.get("name", f"Student {student_id[:8]}"),
            status="not_started", completion_rate=0.0, stops_completed=0, total_stops=0,
        )
        if db:
            try:
                session_data = await db.get_student_unit_session(
                    student_id=student_id, unit_id=assignment["unit_id"]
                )
                if session_data:
                    progress.session_id = session_data.get("session_id")
                    progress.status = "completed" if session_data.get("completed_at") else "in_progress"
                    progress.completion_rate = session_data.get("completion_rate", 0.0)
                    progress.precheck_score = session_data.get("precheck_score")
                    progress.postcheck_score = session_data.get("postcheck_score")
                    progress.time_spent_minutes = session_data.get("duration_seconds", 0) // 60
                    progress.last_activity = session_data.get("updated_at")
                    progress.stops_completed = session_data.get("stops_completed", 0)
                    progress.total_stops = session_data.get("total_stops", 0)
                    if progress.precheck_score is not None and progress.postcheck_score is not None:
                        progress.learning_gain = progress.postcheck_score - progress.precheck_score
                    total_completion += progress.completion_rate
                    total_time += progress.time_spent_minutes
                    if progress.precheck_score is not None:
                        total_precheck += progress.precheck_score
                        precheck_count += 1
                    if progress.postcheck_score is not None:
                        total_postcheck += progress.postcheck_score
                        postcheck_count += 1
                    if progress.status != "not_started":
                        started += 1
                    if progress.status == "completed":
                        completed += 1
            except Exception as e:
                # A single student's broken session must not fail the report.
                logger.error(f"Failed to get student progress: {e}")
        student_progress.append(progress)

    # Guard against division by zero for empty classes.
    total_students = len(students) or 1
    return ClassUnitProgress(
        assignment_id=assignment_id, unit_id=assignment["unit_id"],
        unit_title=f"Unit {assignment['unit_id']}", class_id=assignment["class_id"],
        class_name=f"Class {assignment['class_id'][:8]}", total_students=len(students),
        started_count=started, completed_count=completed,
        avg_completion_rate=total_completion / total_students,
        avg_precheck_score=total_precheck / precheck_count if precheck_count > 0 else None,
        avg_postcheck_score=total_postcheck / postcheck_count if postcheck_count > 0 else None,
        avg_learning_gain=(total_postcheck / postcheck_count - total_precheck / precheck_count)
        if precheck_count > 0 and postcheck_count > 0 else None,
        avg_time_minutes=total_time / started if started > 0 else 0,
        students=student_progress,
    )


@router.get("/classes/{class_id}/analytics", response_model=ClassAnalyticsSummary)
async def get_class_analytics(
    class_id: str,
    teacher: Dict[str, Any] = Depends(get_current_teacher)
) -> ClassAnalyticsSummary:
    """Get summary analytics for a class.

    Aggregates assignment counts, per-student scores, top/struggling
    students and the most frequent misconceptions (top 5).
    """
    db = await get_teacher_database()
    assignments = []
    if db:
        try:
            assignments = await db.list_assignments(teacher_id=teacher["user_id"], class_id=class_id)
        except Exception as e:
            logger.error(f"Failed to list assignments: {e}")
    if not assignments:
        # Fallback to the in-memory store shared with teacher_dashboard_api.
        assignments = [
            a for a in _assignments_store.values()
            if a["class_id"] == class_id and a["teacher_id"] == teacher["user_id"]
        ]

    total_units = len(assignments)
    completed_units = sum(1 for a in assignments if a.get("status") == "completed")
    active_units = sum(1 for a in assignments if a.get("status") == "active")

    students = await get_students_in_class(class_id)
    student_scores: Dict[str, Dict[str, Any]] = {}
    misconceptions: List[MisconceptionReport] = []
    if db:
        try:
            for student in students:
                student_id = student.get("id", student.get("student_id"))
                analytics = await db.get_student_analytics(student_id)
                if analytics:
                    student_scores[student_id] = {
                        "name": student.get("name", student_id[:8]),
                        "avg_score": analytics.get("avg_postcheck_score", 0),
                        "total_time": analytics.get("total_time_minutes", 0),
                    }
            misconceptions_data = await db.get_class_misconceptions(class_id)
            for m in misconceptions_data:
                misconceptions.append(MisconceptionReport(
                    concept_id=m["concept_id"], concept_label=m["concept_label"],
                    misconception=m["misconception"], affected_students=m["affected_students"],
                    frequency=m["frequency"], unit_id=m["unit_id"], stop_id=m["stop_id"],
                ))
        except Exception as e:
            logger.error(f"Failed to aggregate analytics: {e}")

    sorted_students = sorted(student_scores.items(), key=lambda x: x[1]["avg_score"], reverse=True)
    top_performers = [s[1]["name"] for s in sorted_students[:3]]
    # Bottom three, but only those actually below the 60% threshold.
    struggling_students = [s[1]["name"] for s in sorted_students[-3:] if s[1]["avg_score"] < 0.6]
    total_time = sum(s["total_time"] for s in student_scores.values())
    avg_scores = [s["avg_score"] for s in student_scores.values() if s["avg_score"] > 0]
    avg_completion = sum(avg_scores) / len(avg_scores) if avg_scores else 0

    return ClassAnalyticsSummary(
        class_id=class_id, class_name=f"Klasse {class_id[:8]}",
        total_units_assigned=total_units, units_completed=completed_units,
        active_units=active_units, avg_completion_rate=avg_completion,
        avg_learning_gain=None, total_time_hours=total_time / 60,
        top_performers=top_performers, struggling_students=struggling_students,
        common_misconceptions=misconceptions[:5],
    )


@router.get("/students/{student_id}/progress")
async def get_student_progress(
    student_id: str,
    teacher: Dict[str, Any] = Depends(get_current_teacher)
) -> Dict[str, Any]:
    """Get detailed progress for a specific student.

    Returns the DB's full-progress payload when available, otherwise an
    empty-progress stub so the dashboard can always render.
    """
    db = await get_teacher_database()
    if db:
        try:
            progress = await db.get_student_full_progress(student_id)
            return progress
        except Exception as e:
            logger.error(f"Failed to get student progress: {e}")
    return {
        "student_id": student_id, "units_attempted": 0, "units_completed": 0,
        "avg_score": 0.0, "total_time_minutes": 0, "sessions": [],
    }


# ==============================================
# API Endpoints - Content Resources
# ==============================================

@router.get("/assignments/{assignment_id}/resources", response_model=List[ContentResource])
async def get_assignment_resources(
    assignment_id: str,
    teacher: Dict[str, Any] = Depends(get_current_teacher),
    request: Request = None
) -> List[ContentResource]:
    """Get generated content resources (H5P, worksheet HTML/PDF) for an assignment."""
    _db, assignment = await _resolve_assignment(assignment_id, teacher)

    unit_id = assignment["unit_id"]
    # Derive absolute URLs from the incoming request when available.
    base_url = str(request.base_url).rstrip("/") if request else "http://localhost:8000"
    return [
        ContentResource(resource_type="h5p", title=f"{unit_id} - H5P Aktivitaeten",
                        url=f"{base_url}/api/units/content/{unit_id}/h5p",
                        generated_at=datetime.utcnow(), unit_id=unit_id),
        ContentResource(resource_type="worksheet", title=f"{unit_id} - Arbeitsblatt (HTML)",
                        url=f"{base_url}/api/units/content/{unit_id}/worksheet",
                        generated_at=datetime.utcnow(), unit_id=unit_id),
        ContentResource(resource_type="pdf", title=f"{unit_id} - Arbeitsblatt (PDF)",
                        url=f"{base_url}/api/units/content/{unit_id}/worksheet.pdf",
                        generated_at=datetime.utcnow(), unit_id=unit_id),
    ]


@router.post("/assignments/{assignment_id}/regenerate-content")
async def regenerate_content(
    assignment_id: str,
    resource_type: str = Query("all", description="h5p, pdf, or all"),
    teacher: Dict[str, Any] = Depends(get_current_teacher)
) -> Dict[str, Any]:
    """Trigger regeneration of content resources.

    Currently only acknowledges the request; actual generation is expected
    to be handled by a background worker.
    """
    _db, assignment = await _resolve_assignment(assignment_id, teacher)

    logger.info(f"Content regeneration triggered for {assignment['unit_id']}: {resource_type}")
    return {
        "status": "queued", "assignment_id": assignment_id,
        "unit_id": assignment["unit_id"], "resource_type": resource_type,
        "message": "Content regeneration has been queued",
    }
+# +# Split structure: +# - teacher_dashboard_models.py: Models, Auth, DB/School helpers +# - teacher_dashboard_analytics.py: Progress, analytics, content routes +# - teacher_dashboard_api.py: Assignment CRUD, dashboard, units (this file) -from fastapi import APIRouter, HTTPException, Query, Depends, Request -from pydantic import BaseModel, Field +from fastapi import APIRouter, HTTPException, Query, Depends from typing import List, Optional, Dict, Any from datetime import datetime, timedelta -from enum import Enum import uuid -import os import logging -import httpx + +from teacher_dashboard_models import ( + UnitAssignmentStatus, TeacherControlSettings, AssignUnitRequest, + UnitAssignment, + get_current_teacher, get_teacher_database, + get_classes_for_teacher, + REQUIRE_AUTH, +) +from teacher_dashboard_analytics import ( + router as analytics_router, + set_assignments_store, +) logger = logging.getLogger(__name__) -# Feature flags -USE_DATABASE = os.getenv("GAME_USE_DATABASE", "true").lower() == "true" -REQUIRE_AUTH = os.getenv("TEACHER_REQUIRE_AUTH", "true").lower() == "true" -SCHOOL_SERVICE_URL = os.getenv("SCHOOL_SERVICE_URL", "http://school-service:8084") - router = APIRouter(prefix="/api/teacher", tags=["Teacher Dashboard"]) - -# ============================================== -# Pydantic Models -# ============================================== - -class UnitAssignmentStatus(str, Enum): - """Status of a unit assignment""" - DRAFT = "draft" - ACTIVE = "active" - COMPLETED = "completed" - ARCHIVED = "archived" - - -class TeacherControlSettings(BaseModel): - """Unit settings that teachers can configure""" - allow_skip: bool = True - allow_replay: bool = True - max_time_per_stop_sec: int = 90 - show_hints: bool = True - require_precheck: bool = True - require_postcheck: bool = True - - -class AssignUnitRequest(BaseModel): - """Request to assign a unit to a class""" - unit_id: str - class_id: str - due_date: Optional[datetime] = None - settings: 
Optional[TeacherControlSettings] = None - notes: Optional[str] = None - - -class UnitAssignment(BaseModel): - """Unit assignment record""" - assignment_id: str - unit_id: str - class_id: str - teacher_id: str - status: UnitAssignmentStatus - settings: TeacherControlSettings - due_date: Optional[datetime] = None - notes: Optional[str] = None - created_at: datetime - updated_at: datetime - - -class StudentUnitProgress(BaseModel): - """Progress of a single student on a unit""" - student_id: str - student_name: str - session_id: Optional[str] = None - status: str # "not_started", "in_progress", "completed" - completion_rate: float = 0.0 - precheck_score: Optional[float] = None - postcheck_score: Optional[float] = None - learning_gain: Optional[float] = None - time_spent_minutes: int = 0 - last_activity: Optional[datetime] = None - current_stop: Optional[str] = None - stops_completed: int = 0 - total_stops: int = 0 - - -class ClassUnitProgress(BaseModel): - """Overall progress of a class on a unit""" - assignment_id: str - unit_id: str - unit_title: str - class_id: str - class_name: str - total_students: int - started_count: int - completed_count: int - avg_completion_rate: float - avg_precheck_score: Optional[float] = None - avg_postcheck_score: Optional[float] = None - avg_learning_gain: Optional[float] = None - avg_time_minutes: float - students: List[StudentUnitProgress] - - -class MisconceptionReport(BaseModel): - """Report of detected misconceptions""" - concept_id: str - concept_label: str - misconception: str - affected_students: List[str] - frequency: int - unit_id: str - stop_id: str - - -class ClassAnalyticsSummary(BaseModel): - """Summary analytics for a class""" - class_id: str - class_name: str - total_units_assigned: int - units_completed: int - active_units: int - avg_completion_rate: float - avg_learning_gain: Optional[float] - total_time_hours: float - top_performers: List[str] - struggling_students: List[str] - common_misconceptions: 
List[MisconceptionReport] - - -class ContentResource(BaseModel): - """Generated content resource""" - resource_type: str # "h5p", "pdf", "worksheet" - title: str - url: str - generated_at: datetime - unit_id: str - - -# ============================================== -# Auth Dependency -# ============================================== - -async def get_current_teacher(request: Request) -> Dict[str, Any]: - """Get current teacher from JWT token.""" - if not REQUIRE_AUTH: - # Dev mode: return demo teacher - return { - "user_id": "e9484ad9-32ee-4f2b-a4e1-d182e02ccf20", - "email": "demo@breakpilot.app", - "role": "teacher", - "name": "Demo Lehrer" - } - - auth_header = request.headers.get("Authorization", "") - if not auth_header.startswith("Bearer "): - raise HTTPException(status_code=401, detail="Missing authorization token") - - try: - import jwt - token = auth_header[7:] - secret = os.getenv("JWT_SECRET", "dev-secret-key") - payload = jwt.decode(token, secret, algorithms=["HS256"]) - - if payload.get("role") not in ["teacher", "admin"]: - raise HTTPException(status_code=403, detail="Teacher or admin role required") - - return payload - except jwt.ExpiredSignatureError: - raise HTTPException(status_code=401, detail="Token expired") - except jwt.InvalidTokenError: - raise HTTPException(status_code=401, detail="Invalid token") - - -# ============================================== -# Database Integration -# ============================================== - -_teacher_db = None - -async def get_teacher_database(): - """Get teacher database instance with lazy initialization.""" - global _teacher_db - if not USE_DATABASE: - return None - if _teacher_db is None: - try: - from unit.database import get_teacher_db - _teacher_db = await get_teacher_db() - logger.info("Teacher database initialized") - except ImportError: - logger.warning("Teacher database module not available") - except Exception as e: - logger.warning(f"Teacher database not available: {e}") - return _teacher_db - 
- -# ============================================== -# School Service Integration -# ============================================== - -async def get_classes_for_teacher(teacher_id: str) -> List[Dict[str, Any]]: - """Get classes assigned to a teacher from school service.""" - async with httpx.AsyncClient(timeout=10.0) as client: - try: - response = await client.get( - f"{SCHOOL_SERVICE_URL}/api/v1/school/classes", - headers={"X-Teacher-ID": teacher_id} - ) - if response.status_code == 200: - return response.json() - except Exception as e: - logger.error(f"Failed to get classes from school service: {e}") - return [] - - -async def get_students_in_class(class_id: str) -> List[Dict[str, Any]]: - """Get students in a class from school service.""" - async with httpx.AsyncClient(timeout=10.0) as client: - try: - response = await client.get( - f"{SCHOOL_SERVICE_URL}/api/v1/school/classes/{class_id}/students" - ) - if response.status_code == 200: - return response.json() - except Exception as e: - logger.error(f"Failed to get students from school service: {e}") - return [] - - -# ============================================== # In-Memory Storage (Fallback) -# ============================================== - _assignments_store: Dict[str, Dict[str, Any]] = {} +# Share the store with the analytics module and include its routes +set_assignments_store(_assignments_store) +router.include_router(analytics_router) + # ============================================== # API Endpoints - Unit Assignment @@ -250,28 +47,17 @@ async def assign_unit_to_class( request_data: AssignUnitRequest, teacher: Dict[str, Any] = Depends(get_current_teacher) ) -> UnitAssignment: - """ - Assign a unit to a class. - - Creates an assignment that allows students in the class to play the unit. - Teacher can configure settings like skip, replay, time limits. 
- """ + """Assign a unit to a class.""" assignment_id = str(uuid.uuid4()) now = datetime.utcnow() - settings = request_data.settings or TeacherControlSettings() assignment = { - "assignment_id": assignment_id, - "unit_id": request_data.unit_id, - "class_id": request_data.class_id, - "teacher_id": teacher["user_id"], - "status": UnitAssignmentStatus.ACTIVE, - "settings": settings.model_dump(), - "due_date": request_data.due_date, - "notes": request_data.notes, - "created_at": now, - "updated_at": now, + "assignment_id": assignment_id, "unit_id": request_data.unit_id, + "class_id": request_data.class_id, "teacher_id": teacher["user_id"], + "status": UnitAssignmentStatus.ACTIVE, "settings": settings.model_dump(), + "due_date": request_data.due_date, "notes": request_data.notes, + "created_at": now, "updated_at": now, } db = await get_teacher_database() @@ -281,22 +67,15 @@ async def assign_unit_to_class( except Exception as e: logger.error(f"Failed to store assignment: {e}") - # Fallback: store in memory _assignments_store[assignment_id] = assignment - logger.info(f"Unit {request_data.unit_id} assigned to class {request_data.class_id}") return UnitAssignment( - assignment_id=assignment_id, - unit_id=request_data.unit_id, - class_id=request_data.class_id, - teacher_id=teacher["user_id"], - status=UnitAssignmentStatus.ACTIVE, - settings=settings, - due_date=request_data.due_date, - notes=request_data.notes, - created_at=now, - updated_at=now, + assignment_id=assignment_id, unit_id=request_data.unit_id, + class_id=request_data.class_id, teacher_id=teacher["user_id"], + status=UnitAssignmentStatus.ACTIVE, settings=settings, + due_date=request_data.due_date, notes=request_data.notes, + created_at=now, updated_at=now, ) @@ -306,11 +85,7 @@ async def list_assignments( status: Optional[UnitAssignmentStatus] = Query(None, description="Filter by status"), teacher: Dict[str, Any] = Depends(get_current_teacher) ) -> List[UnitAssignment]: - """ - List all unit assignments for the 
teacher. - - Optionally filter by class or status. - """ + """List all unit assignments for the teacher.""" db = await get_teacher_database() assignments = [] @@ -325,7 +100,6 @@ async def list_assignments( logger.error(f"Failed to list assignments: {e}") if not assignments: - # Fallback: filter in-memory store for assignment in _assignments_store.values(): if assignment["teacher_id"] != teacher["user_id"]: continue @@ -337,16 +111,11 @@ async def list_assignments( return [ UnitAssignment( - assignment_id=a["assignment_id"], - unit_id=a["unit_id"], - class_id=a["class_id"], - teacher_id=a["teacher_id"], - status=a["status"], - settings=TeacherControlSettings(**a["settings"]), - due_date=a.get("due_date"), - notes=a.get("notes"), - created_at=a["created_at"], - updated_at=a["updated_at"], + assignment_id=a["assignment_id"], unit_id=a["unit_id"], + class_id=a["class_id"], teacher_id=a["teacher_id"], + status=a["status"], settings=TeacherControlSettings(**a["settings"]), + due_date=a.get("due_date"), notes=a.get("notes"), + created_at=a["created_at"], updated_at=a["updated_at"], ) for a in assignments ] @@ -359,41 +128,30 @@ async def get_assignment( ) -> UnitAssignment: """Get details of a specific assignment.""" db = await get_teacher_database() - if db: try: assignment = await db.get_assignment(assignment_id) if assignment and assignment["teacher_id"] == teacher["user_id"]: return UnitAssignment( - assignment_id=assignment["assignment_id"], - unit_id=assignment["unit_id"], - class_id=assignment["class_id"], - teacher_id=assignment["teacher_id"], + assignment_id=assignment["assignment_id"], unit_id=assignment["unit_id"], + class_id=assignment["class_id"], teacher_id=assignment["teacher_id"], status=assignment["status"], settings=TeacherControlSettings(**assignment["settings"]), - due_date=assignment.get("due_date"), - notes=assignment.get("notes"), - created_at=assignment["created_at"], - updated_at=assignment["updated_at"], + due_date=assignment.get("due_date"), 
notes=assignment.get("notes"), + created_at=assignment["created_at"], updated_at=assignment["updated_at"], ) except Exception as e: logger.error(f"Failed to get assignment: {e}") - # Fallback if assignment_id in _assignments_store: a = _assignments_store[assignment_id] if a["teacher_id"] == teacher["user_id"]: return UnitAssignment( - assignment_id=a["assignment_id"], - unit_id=a["unit_id"], - class_id=a["class_id"], - teacher_id=a["teacher_id"], - status=a["status"], - settings=TeacherControlSettings(**a["settings"]), - due_date=a.get("due_date"), - notes=a.get("notes"), - created_at=a["created_at"], - updated_at=a["updated_at"], + assignment_id=a["assignment_id"], unit_id=a["unit_id"], + class_id=a["class_id"], teacher_id=a["teacher_id"], + status=a["status"], settings=TeacherControlSettings(**a["settings"]), + due_date=a.get("due_date"), notes=a.get("notes"), + created_at=a["created_at"], updated_at=a["updated_at"], ) raise HTTPException(status_code=404, detail="Assignment not found") @@ -424,7 +182,6 @@ async def update_assignment( if not assignment or assignment["teacher_id"] != teacher["user_id"]: raise HTTPException(status_code=404, detail="Assignment not found") - # Update fields if settings: assignment["settings"] = settings.model_dump() if status: @@ -444,16 +201,11 @@ async def update_assignment( _assignments_store[assignment_id] = assignment return UnitAssignment( - assignment_id=assignment["assignment_id"], - unit_id=assignment["unit_id"], - class_id=assignment["class_id"], - teacher_id=assignment["teacher_id"], - status=assignment["status"], - settings=TeacherControlSettings(**assignment["settings"]), - due_date=assignment.get("due_date"), - notes=assignment.get("notes"), - created_at=assignment["created_at"], - updated_at=assignment["updated_at"], + assignment_id=assignment["assignment_id"], unit_id=assignment["unit_id"], + class_id=assignment["class_id"], teacher_id=assignment["teacher_id"], + status=assignment["status"], 
settings=TeacherControlSettings(**assignment["settings"]), + due_date=assignment.get("due_date"), notes=assignment.get("notes"), + created_at=assignment["created_at"], updated_at=assignment["updated_at"], ) @@ -464,7 +216,6 @@ async def delete_assignment( ) -> Dict[str, str]: """Delete/archive an assignment.""" db = await get_teacher_database() - if db: try: assignment = await db.get_assignment(assignment_id) @@ -485,339 +236,6 @@ async def delete_assignment( raise HTTPException(status_code=404, detail="Assignment not found") -# ============================================== -# API Endpoints - Progress & Analytics -# ============================================== - -@router.get("/assignments/{assignment_id}/progress", response_model=ClassUnitProgress) -async def get_assignment_progress( - assignment_id: str, - teacher: Dict[str, Any] = Depends(get_current_teacher) -) -> ClassUnitProgress: - """ - Get detailed progress for an assignment. - - Shows each student's status, scores, and time spent. 
- """ - db = await get_teacher_database() - assignment = None - - if db: - try: - assignment = await db.get_assignment(assignment_id) - except Exception as e: - logger.error(f"Failed to get assignment: {e}") - - if not assignment and assignment_id in _assignments_store: - assignment = _assignments_store[assignment_id] - - if not assignment or assignment["teacher_id"] != teacher["user_id"]: - raise HTTPException(status_code=404, detail="Assignment not found") - - # Get students in class - students = await get_students_in_class(assignment["class_id"]) - - # Get progress for each student - student_progress = [] - total_completion = 0.0 - total_precheck = 0.0 - total_postcheck = 0.0 - total_time = 0 - precheck_count = 0 - postcheck_count = 0 - started = 0 - completed = 0 - - for student in students: - student_id = student.get("id", student.get("student_id")) - progress = StudentUnitProgress( - student_id=student_id, - student_name=student.get("name", f"Student {student_id[:8]}"), - status="not_started", - completion_rate=0.0, - stops_completed=0, - total_stops=0, - ) - - if db: - try: - session_data = await db.get_student_unit_session( - student_id=student_id, - unit_id=assignment["unit_id"] - ) - if session_data: - progress.session_id = session_data.get("session_id") - progress.status = "completed" if session_data.get("completed_at") else "in_progress" - progress.completion_rate = session_data.get("completion_rate", 0.0) - progress.precheck_score = session_data.get("precheck_score") - progress.postcheck_score = session_data.get("postcheck_score") - progress.time_spent_minutes = session_data.get("duration_seconds", 0) // 60 - progress.last_activity = session_data.get("updated_at") - progress.stops_completed = session_data.get("stops_completed", 0) - progress.total_stops = session_data.get("total_stops", 0) - - if progress.precheck_score is not None and progress.postcheck_score is not None: - progress.learning_gain = progress.postcheck_score - progress.precheck_score - 
- # Aggregate stats - total_completion += progress.completion_rate - total_time += progress.time_spent_minutes - if progress.precheck_score is not None: - total_precheck += progress.precheck_score - precheck_count += 1 - if progress.postcheck_score is not None: - total_postcheck += progress.postcheck_score - postcheck_count += 1 - if progress.status != "not_started": - started += 1 - if progress.status == "completed": - completed += 1 - except Exception as e: - logger.error(f"Failed to get student progress: {e}") - - student_progress.append(progress) - - total_students = len(students) or 1 # Avoid division by zero - - return ClassUnitProgress( - assignment_id=assignment_id, - unit_id=assignment["unit_id"], - unit_title=f"Unit {assignment['unit_id']}", # Would load from unit definition - class_id=assignment["class_id"], - class_name=f"Class {assignment['class_id'][:8]}", # Would load from school service - total_students=len(students), - started_count=started, - completed_count=completed, - avg_completion_rate=total_completion / total_students, - avg_precheck_score=total_precheck / precheck_count if precheck_count > 0 else None, - avg_postcheck_score=total_postcheck / postcheck_count if postcheck_count > 0 else None, - avg_learning_gain=(total_postcheck / postcheck_count - total_precheck / precheck_count) - if precheck_count > 0 and postcheck_count > 0 else None, - avg_time_minutes=total_time / started if started > 0 else 0, - students=student_progress, - ) - - -@router.get("/classes/{class_id}/analytics", response_model=ClassAnalyticsSummary) -async def get_class_analytics( - class_id: str, - teacher: Dict[str, Any] = Depends(get_current_teacher) -) -> ClassAnalyticsSummary: - """ - Get summary analytics for a class. - - Includes all unit assignments, overall progress, and common misconceptions. 
- """ - db = await get_teacher_database() - - # Get all assignments for this class - assignments = [] - if db: - try: - assignments = await db.list_assignments( - teacher_id=teacher["user_id"], - class_id=class_id - ) - except Exception as e: - logger.error(f"Failed to list assignments: {e}") - - if not assignments: - assignments = [ - a for a in _assignments_store.values() - if a["class_id"] == class_id and a["teacher_id"] == teacher["user_id"] - ] - - total_units = len(assignments) - completed_units = sum(1 for a in assignments if a.get("status") == "completed") - active_units = sum(1 for a in assignments if a.get("status") == "active") - - # Aggregate student performance - students = await get_students_in_class(class_id) - student_scores = {} - misconceptions = [] - - if db: - try: - for student in students: - student_id = student.get("id", student.get("student_id")) - analytics = await db.get_student_analytics(student_id) - if analytics: - student_scores[student_id] = { - "name": student.get("name", student_id[:8]), - "avg_score": analytics.get("avg_postcheck_score", 0), - "total_time": analytics.get("total_time_minutes", 0), - } - - # Get common misconceptions - misconceptions_data = await db.get_class_misconceptions(class_id) - for m in misconceptions_data: - misconceptions.append(MisconceptionReport( - concept_id=m["concept_id"], - concept_label=m["concept_label"], - misconception=m["misconception"], - affected_students=m["affected_students"], - frequency=m["frequency"], - unit_id=m["unit_id"], - stop_id=m["stop_id"], - )) - except Exception as e: - logger.error(f"Failed to aggregate analytics: {e}") - - # Identify top and struggling students - sorted_students = sorted( - student_scores.items(), - key=lambda x: x[1]["avg_score"], - reverse=True - ) - top_performers = [s[1]["name"] for s in sorted_students[:3]] - struggling_students = [s[1]["name"] for s in sorted_students[-3:] if s[1]["avg_score"] < 0.6] - - total_time = sum(s["total_time"] for s in 
student_scores.values()) - avg_scores = [s["avg_score"] for s in student_scores.values() if s["avg_score"] > 0] - avg_completion = sum(avg_scores) / len(avg_scores) if avg_scores else 0 - - return ClassAnalyticsSummary( - class_id=class_id, - class_name=f"Klasse {class_id[:8]}", - total_units_assigned=total_units, - units_completed=completed_units, - active_units=active_units, - avg_completion_rate=avg_completion, - avg_learning_gain=None, # Would calculate from pre/post scores - total_time_hours=total_time / 60, - top_performers=top_performers, - struggling_students=struggling_students, - common_misconceptions=misconceptions[:5], - ) - - -@router.get("/students/{student_id}/progress") -async def get_student_progress( - student_id: str, - teacher: Dict[str, Any] = Depends(get_current_teacher) -) -> Dict[str, Any]: - """ - Get detailed progress for a specific student. - - Shows all units attempted and their performance. - """ - db = await get_teacher_database() - - if db: - try: - progress = await db.get_student_full_progress(student_id) - return progress - except Exception as e: - logger.error(f"Failed to get student progress: {e}") - - return { - "student_id": student_id, - "units_attempted": 0, - "units_completed": 0, - "avg_score": 0.0, - "total_time_minutes": 0, - "sessions": [], - } - - -# ============================================== -# API Endpoints - Content Resources -# ============================================== - -@router.get("/assignments/{assignment_id}/resources", response_model=List[ContentResource]) -async def get_assignment_resources( - assignment_id: str, - teacher: Dict[str, Any] = Depends(get_current_teacher), - request: Request = None -) -> List[ContentResource]: - """ - Get generated content resources for an assignment. - - Returns links to H5P activities and PDF worksheets. 
- """ - db = await get_teacher_database() - assignment = None - - if db: - try: - assignment = await db.get_assignment(assignment_id) - except Exception as e: - logger.error(f"Failed to get assignment: {e}") - - if not assignment and assignment_id in _assignments_store: - assignment = _assignments_store[assignment_id] - - if not assignment or assignment["teacher_id"] != teacher["user_id"]: - raise HTTPException(status_code=404, detail="Assignment not found") - - unit_id = assignment["unit_id"] - base_url = str(request.base_url).rstrip("/") if request else "http://localhost:8000" - - resources = [ - ContentResource( - resource_type="h5p", - title=f"{unit_id} - H5P Aktivitaeten", - url=f"{base_url}/api/units/content/{unit_id}/h5p", - generated_at=datetime.utcnow(), - unit_id=unit_id, - ), - ContentResource( - resource_type="worksheet", - title=f"{unit_id} - Arbeitsblatt (HTML)", - url=f"{base_url}/api/units/content/{unit_id}/worksheet", - generated_at=datetime.utcnow(), - unit_id=unit_id, - ), - ContentResource( - resource_type="pdf", - title=f"{unit_id} - Arbeitsblatt (PDF)", - url=f"{base_url}/api/units/content/{unit_id}/worksheet.pdf", - generated_at=datetime.utcnow(), - unit_id=unit_id, - ), - ] - - return resources - - -@router.post("/assignments/{assignment_id}/regenerate-content") -async def regenerate_content( - assignment_id: str, - resource_type: str = Query("all", description="h5p, pdf, or all"), - teacher: Dict[str, Any] = Depends(get_current_teacher) -) -> Dict[str, Any]: - """ - Trigger regeneration of content resources. - - Useful after updating unit definitions. 
- """ - db = await get_teacher_database() - assignment = None - - if db: - try: - assignment = await db.get_assignment(assignment_id) - except Exception as e: - logger.error(f"Failed to get assignment: {e}") - - if not assignment and assignment_id in _assignments_store: - assignment = _assignments_store[assignment_id] - - if not assignment or assignment["teacher_id"] != teacher["user_id"]: - raise HTTPException(status_code=404, detail="Assignment not found") - - # In production, this would trigger async job to regenerate content - logger.info(f"Content regeneration triggered for {assignment['unit_id']}: {resource_type}") - - return { - "status": "queued", - "assignment_id": assignment_id, - "unit_id": assignment["unit_id"], - "resource_type": resource_type, - "message": "Content regeneration has been queued", - } - - # ============================================== # API Endpoints - Available Units # ============================================== @@ -829,51 +247,30 @@ async def list_available_units( locale: str = Query("de-DE", description="Locale"), teacher: Dict[str, Any] = Depends(get_current_teacher) ) -> List[Dict[str, Any]]: - """ - List all available units for assignment. - - Teachers see all published units matching their criteria. 
- """ + """List all available units for assignment.""" db = await get_teacher_database() - if db: try: - units = await db.list_available_units( - grade=grade, - template=template, - locale=locale - ) + units = await db.list_available_units(grade=grade, template=template, locale=locale) return units except Exception as e: logger.error(f"Failed to list units: {e}") - - # Fallback: return demo units return [ { - "unit_id": "bio_eye_lightpath_v1", - "title": "Auge - Lichtstrahl-Flug", - "template": "flight_path", - "grade_band": ["5", "6", "7"], - "duration_minutes": 8, - "difficulty": "base", + "unit_id": "bio_eye_lightpath_v1", "title": "Auge - Lichtstrahl-Flug", + "template": "flight_path", "grade_band": ["5", "6", "7"], + "duration_minutes": 8, "difficulty": "base", "description": "Reise durch das Auge und folge dem Lichtstrahl", - "learning_objectives": [ - "Verstehen des Lichtwegs durch das Auge", - "Funktionen der Augenbestandteile benennen", - ], + "learning_objectives": ["Verstehen des Lichtwegs durch das Auge", + "Funktionen der Augenbestandteile benennen"], }, { "unit_id": "math_pizza_equivalence_v1", "title": "Pizza-Boxenstopp - Brueche und Prozent", - "template": "station_loop", - "grade_band": ["5", "6"], - "duration_minutes": 10, - "difficulty": "base", + "template": "station_loop", "grade_band": ["5", "6"], + "duration_minutes": 10, "difficulty": "base", "description": "Entdecke die Verbindung zwischen Bruechen, Dezimalzahlen und Prozent", - "learning_objectives": [ - "Brueche in Prozent umrechnen", - "Aequivalenzen erkennen", - ], + "learning_objectives": ["Brueche in Prozent umrechnen", "Aequivalenzen erkennen"], }, ] @@ -886,54 +283,38 @@ async def list_available_units( async def get_dashboard( teacher: Dict[str, Any] = Depends(get_current_teacher) ) -> Dict[str, Any]: - """ - Get teacher dashboard overview. - - Summary of all classes, active assignments, and alerts. 
- """ + """Get teacher dashboard overview.""" db = await get_teacher_database() - - # Get teacher's classes classes = await get_classes_for_teacher(teacher["user_id"]) - # Get all active assignments active_assignments = [] if db: try: active_assignments = await db.list_assignments( - teacher_id=teacher["user_id"], - status="active" + teacher_id=teacher["user_id"], status="active" ) except Exception as e: logger.error(f"Failed to list assignments: {e}") - if not active_assignments: active_assignments = [ a for a in _assignments_store.values() if a["teacher_id"] == teacher["user_id"] and a.get("status") == "active" ] - # Calculate alerts (students falling behind, due dates, etc.) alerts = [] for assignment in active_assignments: if assignment.get("due_date") and assignment["due_date"] < datetime.utcnow() + timedelta(days=2): alerts.append({ - "type": "due_soon", - "assignment_id": assignment["assignment_id"], - "message": f"Zuweisung endet in weniger als 2 Tagen", + "type": "due_soon", "assignment_id": assignment["assignment_id"], + "message": "Zuweisung endet in weniger als 2 Tagen", }) return { - "teacher": { - "id": teacher["user_id"], - "name": teacher.get("name", "Lehrer"), - "email": teacher.get("email"), - }, - "classes": len(classes), - "active_assignments": len(active_assignments), + "teacher": {"id": teacher["user_id"], "name": teacher.get("name", "Lehrer"), + "email": teacher.get("email")}, + "classes": len(classes), "active_assignments": len(active_assignments), "total_students": sum(c.get("student_count", 0) for c in classes), - "alerts": alerts, - "recent_activity": [], # Would load recent session completions + "alerts": alerts, "recent_activity": [], } @@ -942,10 +323,7 @@ async def health_check() -> Dict[str, Any]: """Health check for teacher dashboard API.""" db = await get_teacher_database() db_status = "connected" if db else "in-memory" - return { - "status": "healthy", - "service": "teacher-dashboard", - "database": db_status, - "auth_required": 
REQUIRE_AUTH, + "status": "healthy", "service": "teacher-dashboard", + "database": db_status, "auth_required": REQUIRE_AUTH, } diff --git a/backend-lehrer/teacher_dashboard_models.py b/backend-lehrer/teacher_dashboard_models.py new file mode 100644 index 0000000..88e6d9c --- /dev/null +++ b/backend-lehrer/teacher_dashboard_models.py @@ -0,0 +1,226 @@ +""" +Teacher Dashboard - Pydantic Models, Auth Dependency, and Service Helpers. +""" + +import os +import logging +from datetime import datetime +from typing import List, Optional, Dict, Any +from enum import Enum + +from fastapi import HTTPException, Request +from pydantic import BaseModel +import httpx + +logger = logging.getLogger(__name__) + +# Feature flags +USE_DATABASE = os.getenv("GAME_USE_DATABASE", "true").lower() == "true" +REQUIRE_AUTH = os.getenv("TEACHER_REQUIRE_AUTH", "true").lower() == "true" +SCHOOL_SERVICE_URL = os.getenv("SCHOOL_SERVICE_URL", "http://school-service:8084") + + +# ============================================== +# Pydantic Models +# ============================================== + +class UnitAssignmentStatus(str, Enum): + """Status of a unit assignment""" + DRAFT = "draft" + ACTIVE = "active" + COMPLETED = "completed" + ARCHIVED = "archived" + + +class TeacherControlSettings(BaseModel): + """Unit settings that teachers can configure""" + allow_skip: bool = True + allow_replay: bool = True + max_time_per_stop_sec: int = 90 + show_hints: bool = True + require_precheck: bool = True + require_postcheck: bool = True + + +class AssignUnitRequest(BaseModel): + """Request to assign a unit to a class""" + unit_id: str + class_id: str + due_date: Optional[datetime] = None + settings: Optional[TeacherControlSettings] = None + notes: Optional[str] = None + + +class UnitAssignment(BaseModel): + """Unit assignment record""" + assignment_id: str + unit_id: str + class_id: str + teacher_id: str + status: UnitAssignmentStatus + settings: TeacherControlSettings + due_date: Optional[datetime] = None + 
notes: Optional[str] = None + created_at: datetime + updated_at: datetime + + +class StudentUnitProgress(BaseModel): + """Progress of a single student on a unit""" + student_id: str + student_name: str + session_id: Optional[str] = None + status: str # "not_started", "in_progress", "completed" + completion_rate: float = 0.0 + precheck_score: Optional[float] = None + postcheck_score: Optional[float] = None + learning_gain: Optional[float] = None + time_spent_minutes: int = 0 + last_activity: Optional[datetime] = None + current_stop: Optional[str] = None + stops_completed: int = 0 + total_stops: int = 0 + + +class ClassUnitProgress(BaseModel): + """Overall progress of a class on a unit""" + assignment_id: str + unit_id: str + unit_title: str + class_id: str + class_name: str + total_students: int + started_count: int + completed_count: int + avg_completion_rate: float + avg_precheck_score: Optional[float] = None + avg_postcheck_score: Optional[float] = None + avg_learning_gain: Optional[float] = None + avg_time_minutes: float + students: List[StudentUnitProgress] + + +class MisconceptionReport(BaseModel): + """Report of detected misconceptions""" + concept_id: str + concept_label: str + misconception: str + affected_students: List[str] + frequency: int + unit_id: str + stop_id: str + + +class ClassAnalyticsSummary(BaseModel): + """Summary analytics for a class""" + class_id: str + class_name: str + total_units_assigned: int + units_completed: int + active_units: int + avg_completion_rate: float + avg_learning_gain: Optional[float] + total_time_hours: float + top_performers: List[str] + struggling_students: List[str] + common_misconceptions: List[MisconceptionReport] + + +class ContentResource(BaseModel): + """Generated content resource""" + resource_type: str # "h5p", "pdf", "worksheet" + title: str + url: str + generated_at: datetime + unit_id: str + + +# ============================================== +# Auth Dependency +# 
# ==============================================
# Auth Dependency
# ==============================================

async def get_current_teacher(request: Request) -> Dict[str, Any]:
    """Resolve the calling teacher from the request's JWT bearer token.

    With TEACHER_REQUIRE_AUTH disabled, a fixed demo identity is returned
    so the dashboard can run without an auth service. Raises 401 for a
    missing/invalid/expired token and 403 for a non-teacher role.
    """
    if not REQUIRE_AUTH:
        # Demo identity for local development and CI runs.
        return {
            "user_id": "e9484ad9-32ee-4f2b-a4e1-d182e02ccf20",
            "email": "demo@breakpilot.app",
            "role": "teacher",
            "name": "Demo Lehrer"
        }

    header = request.headers.get("Authorization", "")
    if not header.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing authorization token")

    try:
        import jwt  # PyJWT; imported lazily so the module loads without it

        payload = jwt.decode(
            header[7:],  # strip the "Bearer " prefix
            os.getenv("JWT_SECRET", "dev-secret-key"),
            algorithms=["HS256"],
        )
        if payload.get("role") not in ["teacher", "admin"]:
            raise HTTPException(status_code=403, detail="Teacher or admin role required")
        return payload
    except jwt.ExpiredSignatureError:
        raise HTTPException(status_code=401, detail="Token expired")
    except jwt.InvalidTokenError:
        raise HTTPException(status_code=401, detail="Invalid token")


# ==============================================
# Database Integration
# ==============================================

_teacher_db = None


async def get_teacher_database():
    """Return the cached teacher DB handle, creating it on first use.

    Returns None when GAME_USE_DATABASE is off or the backend is
    unavailable; a failed initialization is retried on the next call.
    """
    global _teacher_db
    if not USE_DATABASE:
        return None
    if _teacher_db is not None:
        return _teacher_db
    try:
        from unit.database import get_teacher_db
        _teacher_db = await get_teacher_db()
        logger.info("Teacher database initialized")
    except ImportError:
        logger.warning("Teacher database module not available")
    except Exception as e:
        logger.warning(f"Teacher database not available: {e}")
    return _teacher_db


# ==============================================
# School Service Integration
# ==============================================

async def get_classes_for_teacher(teacher_id: str) -> List[Dict[str, Any]]:
    """Fetch the classes assigned to *teacher_id* from the school service.

    Returns an empty list on any transport error or non-200 response.
    """
    async with httpx.AsyncClient(timeout=10.0) as client:
        try:
            response = await client.get(
                f"{SCHOOL_SERVICE_URL}/api/v1/school/classes",
                headers={"X-Teacher-ID": teacher_id},
            )
            if response.status_code == 200:
                return response.json()
        except Exception as e:
            logger.error(f"Failed to get classes from school service: {e}")
    return []


async def get_students_in_class(class_id: str) -> List[Dict[str, Any]]:
    """Fetch the students enrolled in *class_id* from the school service.

    Returns an empty list on any transport error or non-200 response.
    """
    async with httpx.AsyncClient(timeout=10.0) as client:
        try:
            response = await client.get(
                f"{SCHOOL_SERVICE_URL}/api/v1/school/classes/{class_id}/students"
            )
            if response.status_code == 200:
                return response.json()
        except Exception as e:
            logger.error(f"Failed to get students from school service: {e}")
    return []
+ +Lizenz: Apache 2.0 +""" + +import re +from dataclasses import dataclass, field +from datetime import datetime +from typing import List, Optional + +from template_sources import SourceConfig +from github_crawler import ExtractedDocument + + +# Chunking configuration defaults (can be overridden by env vars in ingestion module) +DEFAULT_CHUNK_SIZE = 1000 +DEFAULT_CHUNK_OVERLAP = 200 + + +@dataclass +class TemplateChunk: + """A chunk of template text ready for indexing.""" + text: str + chunk_index: int + document_title: str + template_type: str + clause_category: Optional[str] + language: str + jurisdiction: str + license_id: str + license_name: str + license_url: str + attribution_required: bool + share_alike: bool + no_derivatives: bool + commercial_use: bool + source_name: str + source_url: str + source_repo: Optional[str] + source_commit: Optional[str] + source_file: str + source_hash: str + attribution_text: Optional[str] + copyright_notice: Optional[str] + is_complete_document: bool + is_modular: bool + requires_customization: bool + placeholders: List[str] + training_allowed: bool + output_allowed: bool + modification_allowed: bool + distortion_prohibited: bool + + +@dataclass +class IngestionStatus: + """Status of a source ingestion.""" + source_name: str + status: str # "pending", "running", "completed", "failed" + documents_found: int = 0 + chunks_created: int = 0 + chunks_indexed: int = 0 + errors: List[str] = field(default_factory=list) + started_at: Optional[datetime] = None + completed_at: Optional[datetime] = None + + +def split_sentences(text: str) -> List[str]: + """Split text into sentences with basic abbreviation handling.""" + # Protect common abbreviations + abbreviations = ['bzw', 'ca', 'd.h', 'etc', 'ggf', 'inkl', 'u.a', 'usw', 'z.B', 'z.b', 'e.g', 'i.e', 'vs', 'no'] + protected = text + for abbr in abbreviations: + pattern = re.compile(r'\b' + re.escape(abbr) + r'\.', re.IGNORECASE) + protected = pattern.sub(abbr.replace('.', '') + '', 
protected) + + # Protect decimal numbers + protected = re.sub(r'(\d)\.(\d)', r'\1\2', protected) + + # Split on sentence endings + sentences = re.split(r'(?<=[.!?])\s+', protected) + + # Restore protected characters + result = [] + for s in sentences: + s = s.replace('', '.').replace('', '.').replace('', '.') + s = s.strip() + if s: + result.append(s) + + return result + + +def chunk_text( + text: str, + chunk_size: int = DEFAULT_CHUNK_SIZE, + overlap: int = DEFAULT_CHUNK_OVERLAP, +) -> List[str]: + """ + Split text into overlapping chunks. + Respects paragraph and sentence boundaries where possible. + """ + if not text: + return [] + + if len(text) <= chunk_size: + return [text.strip()] + + # Split into paragraphs first + paragraphs = text.split('\n\n') + chunks = [] + current_chunk: List[str] = [] + current_length = 0 + + for para in paragraphs: + para = para.strip() + if not para: + continue + + para_length = len(para) + + if para_length > chunk_size: + # Large paragraph: split by sentences + if current_chunk: + chunks.append('\n\n'.join(current_chunk)) + current_chunk = [] + current_length = 0 + + # Split long paragraph by sentences + sentences = split_sentences(para) + for sentence in sentences: + if current_length + len(sentence) + 1 > chunk_size: + if current_chunk: + chunks.append(' '.join(current_chunk)) + # Keep overlap + overlap_count = max(1, len(current_chunk) // 3) + current_chunk = current_chunk[-overlap_count:] + current_length = sum(len(s) + 1 for s in current_chunk) + current_chunk.append(sentence) + current_length += len(sentence) + 1 + + elif current_length + para_length + 2 > chunk_size: + # Paragraph would exceed chunk size + if current_chunk: + chunks.append('\n\n'.join(current_chunk)) + current_chunk = [] + current_length = 0 + current_chunk.append(para) + current_length = para_length + + else: + current_chunk.append(para) + current_length += para_length + 2 + + # Add final chunk + if current_chunk: + 
chunks.append('\n\n'.join(current_chunk)) + + return [c.strip() for c in chunks if c.strip()] + + +def infer_template_type(doc: ExtractedDocument, source: SourceConfig) -> str: + """Infer the template type from document content and metadata.""" + text_lower = doc.text.lower() + title_lower = doc.title.lower() + + # Check known indicators + type_indicators = { + "privacy_policy": ["datenschutz", "privacy", "personal data", "personenbezogen"], + "terms_of_service": ["nutzungsbedingungen", "terms of service", "terms of use", "agb"], + "cookie_banner": ["cookie", "cookies", "tracking"], + "impressum": ["impressum", "legal notice", "imprint"], + "widerruf": ["widerruf", "cancellation", "withdrawal", "right to cancel"], + "dpa": ["auftragsverarbeitung", "data processing agreement", "dpa"], + "sla": ["service level", "availability", "uptime"], + "nda": ["confidential", "non-disclosure", "geheimhaltung", "vertraulich"], + "community_guidelines": ["community", "guidelines", "conduct", "verhaltens"], + "acceptable_use": ["acceptable use", "acceptable usage", "nutzungsrichtlinien"], + } + + for template_type, indicators in type_indicators.items(): + for indicator in indicators: + if indicator in text_lower or indicator in title_lower: + return template_type + + # Fall back to source's first template type + if source.template_types: + return source.template_types[0] + + return "clause" # Generic fallback + + +def infer_clause_category(text: str) -> Optional[str]: + """Infer the clause category from text content.""" + text_lower = text.lower() + + categories = { + "haftung": ["haftung", "liability", "haftungsausschluss", "limitation"], + "datenschutz": ["datenschutz", "privacy", "personal data", "personenbezogen"], + "widerruf": ["widerruf", "cancellation", "withdrawal"], + "gewaehrleistung": ["gewaehrleistung", "warranty", "garantie"], + "kuendigung": ["kuendigung", "termination", "beendigung"], + "zahlung": ["zahlung", "payment", "preis", "price"], + "gerichtsstand": 
["gerichtsstand", "jurisdiction", "governing law"], + "aenderungen": ["aenderung", "modification", "amendment"], + "schlussbestimmungen": ["schlussbestimmung", "miscellaneous", "final provisions"], + } + + for category, indicators in categories.items(): + for indicator in indicators: + if indicator in text_lower: + return category + + return None + + +def create_chunks( + doc: ExtractedDocument, + source: SourceConfig, + chunk_size: int = DEFAULT_CHUNK_SIZE, + chunk_overlap: int = DEFAULT_CHUNK_OVERLAP, +) -> List[TemplateChunk]: + """Create template chunks from an extracted document.""" + license_info = source.license_info + template_type = infer_template_type(doc, source) + + # Chunk the text + text_chunks = chunk_text(doc.text, chunk_size, chunk_overlap) + + chunks = [] + for i, chunk_text_str in enumerate(text_chunks): + # Determine if this is a complete document or a clause + is_complete = len(text_chunks) == 1 and len(chunk_text_str) > 500 + is_modular = len(doc.sections) > 0 or '##' in doc.text + requires_customization = len(doc.placeholders) > 0 + + # Generate attribution text + attribution_text = None + if license_info.attribution_required: + attribution_text = license_info.get_attribution_text( + source.name, + doc.source_url or source.get_source_url() + ) + + chunk = TemplateChunk( + text=chunk_text_str, + chunk_index=i, + document_title=doc.title, + template_type=template_type, + clause_category=infer_clause_category(chunk_text_str), + language=doc.language, + jurisdiction=source.jurisdiction, + license_id=license_info.id.value, + license_name=license_info.name, + license_url=license_info.url, + attribution_required=license_info.attribution_required, + share_alike=license_info.share_alike, + no_derivatives=license_info.no_derivatives, + commercial_use=license_info.commercial_use, + source_name=source.name, + source_url=doc.source_url or source.get_source_url(), + source_repo=source.repo_url, + source_commit=doc.source_commit, + 
"""
Legal Templates CLI -- command-line entry point for ingestion and search.

Extracted from legal_templates_ingestion.py to keep files under 500 LOC.

Usage:
    python legal_templates_cli.py --ingest-all
    python legal_templates_cli.py --ingest-source github-site-policy
    python legal_templates_cli.py --status
    python legal_templates_cli.py --search "Datenschutzerklaerung"

Lizenz: Apache 2.0
"""

import asyncio
import json

from template_sources import TEMPLATE_SOURCES, LicenseType
from legal_templates_ingestion import LegalTemplatesIngestion


def _build_parser():
    """Construct the argument parser for the ingestion CLI."""
    import argparse

    parser = argparse.ArgumentParser(description="Legal Templates Ingestion")
    parser.add_argument("--ingest-all", action="store_true",
                        help="Ingest all enabled sources")
    parser.add_argument("--ingest-source", type=str, metavar="NAME",
                        help="Ingest a specific source by name")
    parser.add_argument("--ingest-license", type=str,
                        choices=["cc0", "mit", "cc_by_4", "public_domain"],
                        help="Ingest all sources of a specific license type")
    parser.add_argument("--max-priority", type=int, default=3,
                        help="Maximum priority level to ingest (1=highest, 5=lowest)")
    parser.add_argument("--status", action="store_true",
                        help="Show collection status")
    parser.add_argument("--search", type=str, metavar="QUERY",
                        help="Test search query")
    parser.add_argument("--template-type", type=str,
                        help="Filter search by template type")
    parser.add_argument("--language", type=str,
                        help="Filter search by language")
    parser.add_argument("--reset", action="store_true",
                        help="Reset (delete and recreate) the collection")
    parser.add_argument("--delete-source", type=str, metavar="NAME",
                        help="Delete all chunks from a source")
    return parser


async def main():
    """CLI entry point: parse arguments and dispatch the chosen action."""
    args = _build_parser().parse_args()

    ingestion = LegalTemplatesIngestion()

    try:
        if args.reset:
            ingestion.reset_collection()
            print("Collection reset successfully")

        elif args.delete_source:
            count = ingestion.delete_source(args.delete_source)
            print(f"Deleted {count} chunks from {args.delete_source}")

        elif args.status:
            status = ingestion.get_status()
            print(json.dumps(status, indent=2, default=str))

        elif args.ingest_all:
            print(f"Ingesting all sources (max priority: {args.max_priority})...")
            results = await ingestion.ingest_all(max_priority=args.max_priority)
            print("\nResults:")
            for name, status in results.items():
                print(f"  {name}: {status.chunks_indexed} chunks ({status.status})")
                if status.errors:
                    for error in status.errors:
                        print(f"    ERROR: {error}")
            total = sum(s.chunks_indexed for s in results.values())
            print(f"\nTotal: {total} chunks indexed")

        elif args.ingest_source:
            source = next(
                (s for s in TEMPLATE_SOURCES if s.name == args.ingest_source),
                None
            )
            if not source:
                print(f"Unknown source: {args.ingest_source}")
                print("Available sources:")
                for s in TEMPLATE_SOURCES:
                    print(f"  - {s.name}")
                return

            print(f"Ingesting: {source.name}")
            status = await ingestion.ingest_source(source)
            print(f"\nResult: {status.chunks_indexed} chunks ({status.status})")
            if status.errors:
                for error in status.errors:
                    print(f"  ERROR: {error}")

        elif args.ingest_license:
            license_type = LicenseType(args.ingest_license)
            print(f"Ingesting all {license_type.value} sources...")
            results = await ingestion.ingest_by_license(license_type)
            print("\nResults:")
            for name, status in results.items():
                print(f"  {name}: {status.chunks_indexed} chunks ({status.status})")

        elif args.search:
            print(f"Searching: {args.search}")
            results = await ingestion.search(
                args.search,
                template_type=args.template_type,
                language=args.language,
            )
            print(f"\nFound {len(results)} results:")
            for i, result in enumerate(results, 1):
                print(f"\n{i}. [{result['template_type']}] {result['document_title']}")
                print(f"   Score: {result['score']:.3f}")
                print(f"   License: {result['license_name']}")
                print(f"   Source: {result['source_name']}")
                print(f"   Language: {result['language']}")
                if result['attribution_required']:
                    print(f"   Attribution: {result['attribution_text']}")
                print(f"   Text: {result['text'][:200]}...")

        else:
            _build_parser().print_help()

    finally:
        await ingestion.close()


if __name__ == "__main__":
    asyncio.run(main())
Collection: bp_legal_templates Usage: - python legal_templates_ingestion.py --ingest-all - python legal_templates_ingestion.py --ingest-source github-site-policy - python legal_templates_ingestion.py --status - python legal_templates_ingestion.py --search "Datenschutzerklaerung" + python legal_templates_cli.py --ingest-all + python legal_templates_cli.py --ingest-source github-site-policy + python legal_templates_cli.py --status + python legal_templates_cli.py --search "Datenschutzerklaerung" """ import asyncio import hashlib -import json import logging import os -from dataclasses import dataclass, field from datetime import datetime from typing import Any, Dict, List, Optional from urllib.parse import urlparse @@ -50,6 +48,17 @@ from github_crawler import ( RepositoryDownloader, ) +# Re-export from chunking module for backward compatibility +from legal_templates_chunking import ( # noqa: F401 + IngestionStatus, + TemplateChunk, + chunk_text, + create_chunks, + infer_clause_category, + infer_template_type, + split_sentences, +) + # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -78,54 +87,6 @@ MAX_RETRIES = 3 RETRY_DELAY = 3.0 -@dataclass -class IngestionStatus: - """Status of a source ingestion.""" - source_name: str - status: str # "pending", "running", "completed", "failed" - documents_found: int = 0 - chunks_created: int = 0 - chunks_indexed: int = 0 - errors: List[str] = field(default_factory=list) - started_at: Optional[datetime] = None - completed_at: Optional[datetime] = None - - -@dataclass -class TemplateChunk: - """A chunk of template text ready for indexing.""" - text: str - chunk_index: int - document_title: str - template_type: str - clause_category: Optional[str] - language: str - jurisdiction: str - license_id: str - license_name: str - license_url: str - attribution_required: bool - share_alike: bool - no_derivatives: bool - commercial_use: bool - source_name: str - source_url: str - source_repo: 
Optional[str] - source_commit: Optional[str] - source_file: str - source_hash: str - attribution_text: Optional[str] - copyright_notice: Optional[str] - is_complete_document: bool - is_modular: bool - requires_customization: bool - placeholders: List[str] - training_allowed: bool - output_allowed: bool - modification_allowed: bool - distortion_prohibited: bool - - class LegalTemplatesIngestion: """Handles ingestion of legal templates into Qdrant.""" @@ -168,212 +129,6 @@ class LegalTemplatesIngestion: logger.error(f"Embedding generation failed: {e}") raise - def _chunk_text(self, text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[str]: - """ - Split text into overlapping chunks. - Respects paragraph and sentence boundaries where possible. - """ - if not text: - return [] - - if len(text) <= chunk_size: - return [text.strip()] - - # Split into paragraphs first - paragraphs = text.split('\n\n') - chunks = [] - current_chunk = [] - current_length = 0 - - for para in paragraphs: - para = para.strip() - if not para: - continue - - para_length = len(para) - - if para_length > chunk_size: - # Large paragraph: split by sentences - if current_chunk: - chunks.append('\n\n'.join(current_chunk)) - current_chunk = [] - current_length = 0 - - # Split long paragraph by sentences - sentences = self._split_sentences(para) - for sentence in sentences: - if current_length + len(sentence) + 1 > chunk_size: - if current_chunk: - chunks.append(' '.join(current_chunk)) - # Keep overlap - overlap_count = max(1, len(current_chunk) // 3) - current_chunk = current_chunk[-overlap_count:] - current_length = sum(len(s) + 1 for s in current_chunk) - current_chunk.append(sentence) - current_length += len(sentence) + 1 - - elif current_length + para_length + 2 > chunk_size: - # Paragraph would exceed chunk size - if current_chunk: - chunks.append('\n\n'.join(current_chunk)) - current_chunk = [] - current_length = 0 - current_chunk.append(para) - current_length = 
para_length - - else: - current_chunk.append(para) - current_length += para_length + 2 - - # Add final chunk - if current_chunk: - chunks.append('\n\n'.join(current_chunk)) - - return [c.strip() for c in chunks if c.strip()] - - def _split_sentences(self, text: str) -> List[str]: - """Split text into sentences with basic abbreviation handling.""" - import re - - # Protect common abbreviations - abbreviations = ['bzw', 'ca', 'd.h', 'etc', 'ggf', 'inkl', 'u.a', 'usw', 'z.B', 'z.b', 'e.g', 'i.e', 'vs', 'no'] - protected = text - for abbr in abbreviations: - pattern = re.compile(r'\b' + re.escape(abbr) + r'\.', re.IGNORECASE) - protected = pattern.sub(abbr.replace('.', '') + '', protected) - - # Protect decimal numbers - protected = re.sub(r'(\d)\.(\d)', r'\1\2', protected) - - # Split on sentence endings - sentences = re.split(r'(?<=[.!?])\s+', protected) - - # Restore protected characters - result = [] - for s in sentences: - s = s.replace('', '.').replace('', '.').replace('', '.') - s = s.strip() - if s: - result.append(s) - - return result - - def _infer_template_type(self, doc: ExtractedDocument, source: SourceConfig) -> str: - """Infer the template type from document content and metadata.""" - text_lower = doc.text.lower() - title_lower = doc.title.lower() - - # Check known indicators - type_indicators = { - "privacy_policy": ["datenschutz", "privacy", "personal data", "personenbezogen"], - "terms_of_service": ["nutzungsbedingungen", "terms of service", "terms of use", "agb"], - "cookie_banner": ["cookie", "cookies", "tracking"], - "impressum": ["impressum", "legal notice", "imprint"], - "widerruf": ["widerruf", "cancellation", "withdrawal", "right to cancel"], - "dpa": ["auftragsverarbeitung", "data processing agreement", "dpa"], - "sla": ["service level", "availability", "uptime"], - "nda": ["confidential", "non-disclosure", "geheimhaltung", "vertraulich"], - "community_guidelines": ["community", "guidelines", "conduct", "verhaltens"], - "acceptable_use": 
["acceptable use", "acceptable usage", "nutzungsrichtlinien"], - } - - for template_type, indicators in type_indicators.items(): - for indicator in indicators: - if indicator in text_lower or indicator in title_lower: - return template_type - - # Fall back to source's first template type - if source.template_types: - return source.template_types[0] - - return "clause" # Generic fallback - - def _infer_clause_category(self, text: str) -> Optional[str]: - """Infer the clause category from text content.""" - text_lower = text.lower() - - categories = { - "haftung": ["haftung", "liability", "haftungsausschluss", "limitation"], - "datenschutz": ["datenschutz", "privacy", "personal data", "personenbezogen"], - "widerruf": ["widerruf", "cancellation", "withdrawal"], - "gewaehrleistung": ["gewaehrleistung", "warranty", "garantie"], - "kuendigung": ["kuendigung", "termination", "beendigung"], - "zahlung": ["zahlung", "payment", "preis", "price"], - "gerichtsstand": ["gerichtsstand", "jurisdiction", "governing law"], - "aenderungen": ["aenderung", "modification", "amendment"], - "schlussbestimmungen": ["schlussbestimmung", "miscellaneous", "final provisions"], - } - - for category, indicators in categories.items(): - for indicator in indicators: - if indicator in text_lower: - return category - - return None - - def _create_chunks( - self, - doc: ExtractedDocument, - source: SourceConfig, - ) -> List[TemplateChunk]: - """Create template chunks from an extracted document.""" - license_info = source.license_info - template_type = self._infer_template_type(doc, source) - - # Chunk the text - text_chunks = self._chunk_text(doc.text) - - chunks = [] - for i, chunk_text in enumerate(text_chunks): - # Determine if this is a complete document or a clause - is_complete = len(text_chunks) == 1 and len(chunk_text) > 500 - is_modular = len(doc.sections) > 0 or '##' in doc.text - requires_customization = len(doc.placeholders) > 0 - - # Generate attribution text - attribution_text = None 
- if license_info.attribution_required: - attribution_text = license_info.get_attribution_text( - source.name, - doc.source_url or source.get_source_url() - ) - - chunk = TemplateChunk( - text=chunk_text, - chunk_index=i, - document_title=doc.title, - template_type=template_type, - clause_category=self._infer_clause_category(chunk_text), - language=doc.language, - jurisdiction=source.jurisdiction, - license_id=license_info.id.value, - license_name=license_info.name, - license_url=license_info.url, - attribution_required=license_info.attribution_required, - share_alike=license_info.share_alike, - no_derivatives=license_info.no_derivatives, - commercial_use=license_info.commercial_use, - source_name=source.name, - source_url=doc.source_url or source.get_source_url(), - source_repo=source.repo_url, - source_commit=doc.source_commit, - source_file=doc.file_path, - source_hash=doc.source_hash, - attribution_text=attribution_text, - copyright_notice=None, # Could be extracted from doc if present - is_complete_document=is_complete, - is_modular=is_modular, - requires_customization=requires_customization, - placeholders=doc.placeholders, - training_allowed=license_info.training_allowed, - output_allowed=license_info.output_allowed, - modification_allowed=license_info.modification_allowed, - distortion_prohibited=license_info.distortion_prohibited, - ) - chunks.append(chunk) - - return chunks - async def ingest_source(self, source: SourceConfig) -> IngestionStatus: """Ingest a single source into Qdrant.""" status = IngestionStatus( @@ -405,7 +160,7 @@ class LegalTemplatesIngestion: # Create chunks from all documents all_chunks: List[TemplateChunk] = [] for doc in documents: - chunks = self._create_chunks(doc, source) + chunks = create_chunks(doc, source, CHUNK_SIZE, CHUNK_OVERLAP) all_chunks.extend(chunks) status.chunks_created += len(chunks) @@ -637,21 +392,7 @@ class LegalTemplatesIngestion: attribution_required: Optional[bool] = None, top_k: int = 10, ) -> List[Dict[str, 
Any]]: - """ - Search the legal templates collection. - - Args: - query: Search query text - template_type: Filter by template type (e.g., "privacy_policy") - license_types: Filter by license types (e.g., ["cc0", "mit"]) - language: Filter by language (e.g., "de") - jurisdiction: Filter by jurisdiction (e.g., "DE") - attribution_required: Filter by attribution requirement - top_k: Number of results to return - - Returns: - List of search results with full metadata - """ + """Search the legal templates collection.""" # Generate query embedding embeddings = await self._generate_embeddings([query]) query_vector = embeddings[0] @@ -661,45 +402,27 @@ class LegalTemplatesIngestion: if template_type: must_conditions.append( - FieldCondition( - key="template_type", - match=MatchValue(value=template_type), - ) + FieldCondition(key="template_type", match=MatchValue(value=template_type)) ) - if language: must_conditions.append( - FieldCondition( - key="language", - match=MatchValue(value=language), - ) + FieldCondition(key="language", match=MatchValue(value=language)) ) - if jurisdiction: must_conditions.append( - FieldCondition( - key="jurisdiction", - match=MatchValue(value=jurisdiction), - ) + FieldCondition(key="jurisdiction", match=MatchValue(value=jurisdiction)) ) - if attribution_required is not None: must_conditions.append( - FieldCondition( - key="attribution_required", - match=MatchValue(value=attribution_required), - ) + FieldCondition(key="attribution_required", match=MatchValue(value=attribution_required)) ) # License type filter (OR condition) should_conditions = [] if license_types: - for license_type in license_types: + for lt in license_types: should_conditions.append( - FieldCondition( - key="license_id", - match=MatchValue(value=license_type), - ) + FieldCondition(key="license_id", match=MatchValue(value=lt)) ) # Construct filter @@ -747,196 +470,31 @@ class LegalTemplatesIngestion: def delete_source(self, source_name: str) -> int: """Delete all chunks from 
a specific source.""" - # First count how many we're deleting count_result = self.qdrant.count( collection_name=LEGAL_TEMPLATES_COLLECTION, count_filter=Filter( - must=[ - FieldCondition( - key="source_name", - match=MatchValue(value=source_name), - ) - ] + must=[FieldCondition(key="source_name", match=MatchValue(value=source_name))] ), ) - - # Delete by filter self.qdrant.delete( collection_name=LEGAL_TEMPLATES_COLLECTION, points_selector=Filter( - must=[ - FieldCondition( - key="source_name", - match=MatchValue(value=source_name), - ) - ] + must=[FieldCondition(key="source_name", match=MatchValue(value=source_name))] ), ) - return count_result.count def reset_collection(self): """Delete and recreate the collection.""" logger.warning(f"Resetting collection: {LEGAL_TEMPLATES_COLLECTION}") - - # Delete collection try: self.qdrant.delete_collection(LEGAL_TEMPLATES_COLLECTION) except Exception: - pass # Collection might not exist - - # Recreate + pass self._ensure_collection() self._ingestion_status.clear() - logger.info(f"Collection {LEGAL_TEMPLATES_COLLECTION} reset") async def close(self): """Close HTTP client.""" await self.http_client.aclose() - - -async def main(): - """CLI entry point.""" - import argparse - - parser = argparse.ArgumentParser(description="Legal Templates Ingestion") - parser.add_argument( - "--ingest-all", - action="store_true", - help="Ingest all enabled sources" - ) - parser.add_argument( - "--ingest-source", - type=str, - metavar="NAME", - help="Ingest a specific source by name" - ) - parser.add_argument( - "--ingest-license", - type=str, - choices=["cc0", "mit", "cc_by_4", "public_domain"], - help="Ingest all sources of a specific license type" - ) - parser.add_argument( - "--max-priority", - type=int, - default=3, - help="Maximum priority level to ingest (1=highest, 5=lowest)" - ) - parser.add_argument( - "--status", - action="store_true", - help="Show collection status" - ) - parser.add_argument( - "--search", - type=str, - 
metavar="QUERY", - help="Test search query" - ) - parser.add_argument( - "--template-type", - type=str, - help="Filter search by template type" - ) - parser.add_argument( - "--language", - type=str, - help="Filter search by language" - ) - parser.add_argument( - "--reset", - action="store_true", - help="Reset (delete and recreate) the collection" - ) - parser.add_argument( - "--delete-source", - type=str, - metavar="NAME", - help="Delete all chunks from a source" - ) - - args = parser.parse_args() - - ingestion = LegalTemplatesIngestion() - - try: - if args.reset: - ingestion.reset_collection() - print("Collection reset successfully") - - elif args.delete_source: - count = ingestion.delete_source(args.delete_source) - print(f"Deleted {count} chunks from {args.delete_source}") - - elif args.status: - status = ingestion.get_status() - print(json.dumps(status, indent=2, default=str)) - - elif args.ingest_all: - print(f"Ingesting all sources (max priority: {args.max_priority})...") - results = await ingestion.ingest_all(max_priority=args.max_priority) - print("\nResults:") - for name, status in results.items(): - print(f" {name}: {status.chunks_indexed} chunks ({status.status})") - if status.errors: - for error in status.errors: - print(f" ERROR: {error}") - total = sum(s.chunks_indexed for s in results.values()) - print(f"\nTotal: {total} chunks indexed") - - elif args.ingest_source: - source = next( - (s for s in TEMPLATE_SOURCES if s.name == args.ingest_source), - None - ) - if not source: - print(f"Unknown source: {args.ingest_source}") - print("Available sources:") - for s in TEMPLATE_SOURCES: - print(f" - {s.name}") - return - - print(f"Ingesting: {source.name}") - status = await ingestion.ingest_source(source) - print(f"\nResult: {status.chunks_indexed} chunks ({status.status})") - if status.errors: - for error in status.errors: - print(f" ERROR: {error}") - - elif args.ingest_license: - license_type = LicenseType(args.ingest_license) - print(f"Ingesting all 
{license_type.value} sources...") - results = await ingestion.ingest_by_license(license_type) - print("\nResults:") - for name, status in results.items(): - print(f" {name}: {status.chunks_indexed} chunks ({status.status})") - - elif args.search: - print(f"Searching: {args.search}") - results = await ingestion.search( - args.search, - template_type=args.template_type, - language=args.language, - ) - print(f"\nFound {len(results)} results:") - for i, result in enumerate(results, 1): - print(f"\n{i}. [{result['template_type']}] {result['document_title']}") - print(f" Score: {result['score']:.3f}") - print(f" License: {result['license_name']}") - print(f" Source: {result['source_name']}") - print(f" Language: {result['language']}") - if result['attribution_required']: - print(f" Attribution: {result['attribution_text']}") - print(f" Text: {result['text'][:200]}...") - - else: - parser.print_help() - - finally: - await ingestion.close() - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/klausur-service/backend/mail/mail_db.py b/klausur-service/backend/mail/mail_db.py index 36f5c44..8ddcce2 100644 --- a/klausur-service/backend/mail/mail_db.py +++ b/klausur-service/backend/mail/mail_db.py @@ -1,987 +1,70 @@ """ Unified Inbox Mail Database Service -PostgreSQL database operations for multi-account mail aggregation. 
+Barrel re-export -- the actual logic lives in: +- mail_db_pool.py: Connection pool and schema initialization +- mail_db_accounts.py: Email account CRUD +- mail_db_emails.py: Aggregated email operations +- mail_db_tasks.py: Inbox task operations +- mail_db_stats.py: Statistics and audit log """ -import os -import json -import uuid -from typing import Optional, List, Dict, Any -from datetime import datetime, timedelta - -# Database Configuration - from Vault or environment (test default for CI) -DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://test:test@localhost:5432/test") - -# Flag to check if using test defaults -_DB_CONFIGURED = DATABASE_URL != "postgresql://test:test@localhost:5432/test" - -# Connection pool (shared with metrics_db) -_pool = None - - -async def get_pool(): - """Get or create database connection pool.""" - global _pool - if _pool is None: - try: - import asyncpg - _pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10) - except ImportError: - print("Warning: asyncpg not installed. 
Mail database disabled.") - return None - except Exception as e: - print(f"Warning: Failed to connect to PostgreSQL: {e}") - return None - return _pool - - -async def init_mail_tables() -> bool: - """Initialize mail tables in PostgreSQL.""" - pool = await get_pool() - if pool is None: - return False - - create_tables_sql = """ - -- ============================================================================= - -- External Email Accounts - -- ============================================================================= - CREATE TABLE IF NOT EXISTS external_email_accounts ( - id VARCHAR(36) PRIMARY KEY, - user_id VARCHAR(36) NOT NULL, - tenant_id VARCHAR(36) NOT NULL, - email VARCHAR(255) NOT NULL, - display_name VARCHAR(255), - account_type VARCHAR(50) DEFAULT 'personal', - - -- IMAP Settings (password stored in Vault) - imap_host VARCHAR(255) NOT NULL, - imap_port INTEGER DEFAULT 993, - imap_ssl BOOLEAN DEFAULT TRUE, - - -- SMTP Settings - smtp_host VARCHAR(255) NOT NULL, - smtp_port INTEGER DEFAULT 465, - smtp_ssl BOOLEAN DEFAULT TRUE, - - -- Vault path for credentials - vault_path VARCHAR(500), - - -- Status tracking - status VARCHAR(20) DEFAULT 'pending', - last_sync TIMESTAMP, - sync_error TEXT, - email_count INTEGER DEFAULT 0, - unread_count INTEGER DEFAULT 0, - - -- Timestamps - created_at TIMESTAMP DEFAULT NOW(), - updated_at TIMESTAMP DEFAULT NOW(), - - -- Constraints - UNIQUE(user_id, email) - ); - - CREATE INDEX IF NOT EXISTS idx_mail_accounts_user ON external_email_accounts(user_id); - CREATE INDEX IF NOT EXISTS idx_mail_accounts_tenant ON external_email_accounts(tenant_id); - CREATE INDEX IF NOT EXISTS idx_mail_accounts_status ON external_email_accounts(status); - - -- ============================================================================= - -- Aggregated Emails - -- ============================================================================= - CREATE TABLE IF NOT EXISTS aggregated_emails ( - id VARCHAR(36) PRIMARY KEY, - account_id VARCHAR(36) 
REFERENCES external_email_accounts(id) ON DELETE CASCADE, - user_id VARCHAR(36) NOT NULL, - tenant_id VARCHAR(36) NOT NULL, - - -- Email identification - message_id VARCHAR(500) NOT NULL, - folder VARCHAR(100) DEFAULT 'INBOX', - - -- Email content - subject TEXT, - sender_email VARCHAR(255), - sender_name VARCHAR(255), - recipients JSONB DEFAULT '[]', - cc JSONB DEFAULT '[]', - body_preview TEXT, - body_text TEXT, - body_html TEXT, - has_attachments BOOLEAN DEFAULT FALSE, - attachments JSONB DEFAULT '[]', - headers JSONB DEFAULT '{}', - - -- Status flags - is_read BOOLEAN DEFAULT FALSE, - is_starred BOOLEAN DEFAULT FALSE, - is_deleted BOOLEAN DEFAULT FALSE, - - -- Dates - date_sent TIMESTAMP, - date_received TIMESTAMP, - - -- AI enrichment - category VARCHAR(50), - sender_type VARCHAR(50), - sender_authority_name VARCHAR(255), - detected_deadlines JSONB DEFAULT '[]', - suggested_priority VARCHAR(20), - ai_summary TEXT, - ai_analyzed_at TIMESTAMP, - - created_at TIMESTAMP DEFAULT NOW(), - - -- Prevent duplicate imports - UNIQUE(account_id, message_id) - ); - - CREATE INDEX IF NOT EXISTS idx_emails_account ON aggregated_emails(account_id); - CREATE INDEX IF NOT EXISTS idx_emails_user ON aggregated_emails(user_id); - CREATE INDEX IF NOT EXISTS idx_emails_tenant ON aggregated_emails(tenant_id); - CREATE INDEX IF NOT EXISTS idx_emails_date ON aggregated_emails(date_received DESC); - CREATE INDEX IF NOT EXISTS idx_emails_category ON aggregated_emails(category); - CREATE INDEX IF NOT EXISTS idx_emails_unread ON aggregated_emails(is_read) WHERE is_read = FALSE; - CREATE INDEX IF NOT EXISTS idx_emails_starred ON aggregated_emails(is_starred) WHERE is_starred = TRUE; - CREATE INDEX IF NOT EXISTS idx_emails_sender ON aggregated_emails(sender_email); - - -- ============================================================================= - -- Inbox Tasks (Arbeitsvorrat) - -- ============================================================================= - CREATE TABLE IF NOT EXISTS 
inbox_tasks ( - id VARCHAR(36) PRIMARY KEY, - user_id VARCHAR(36) NOT NULL, - tenant_id VARCHAR(36) NOT NULL, - email_id VARCHAR(36) REFERENCES aggregated_emails(id) ON DELETE SET NULL, - account_id VARCHAR(36) REFERENCES external_email_accounts(id) ON DELETE SET NULL, - - -- Task content - title VARCHAR(500) NOT NULL, - description TEXT, - priority VARCHAR(20) DEFAULT 'medium', - status VARCHAR(20) DEFAULT 'pending', - deadline TIMESTAMP, - - -- Source information - source_email_subject TEXT, - source_sender VARCHAR(255), - source_sender_type VARCHAR(50), - - -- AI extraction info - ai_extracted BOOLEAN DEFAULT FALSE, - confidence_score FLOAT, - - -- Completion tracking - completed_at TIMESTAMP, - reminder_at TIMESTAMP, - - -- Timestamps - created_at TIMESTAMP DEFAULT NOW(), - updated_at TIMESTAMP DEFAULT NOW() - ); - - CREATE INDEX IF NOT EXISTS idx_tasks_user ON inbox_tasks(user_id); - CREATE INDEX IF NOT EXISTS idx_tasks_tenant ON inbox_tasks(tenant_id); - CREATE INDEX IF NOT EXISTS idx_tasks_status ON inbox_tasks(status); - CREATE INDEX IF NOT EXISTS idx_tasks_deadline ON inbox_tasks(deadline) WHERE deadline IS NOT NULL; - CREATE INDEX IF NOT EXISTS idx_tasks_priority ON inbox_tasks(priority); - CREATE INDEX IF NOT EXISTS idx_tasks_email ON inbox_tasks(email_id) WHERE email_id IS NOT NULL; - - -- ============================================================================= - -- Email Templates - -- ============================================================================= - CREATE TABLE IF NOT EXISTS email_templates ( - id VARCHAR(36) PRIMARY KEY, - user_id VARCHAR(36), -- NULL for system templates - tenant_id VARCHAR(36), - - name VARCHAR(255) NOT NULL, - category VARCHAR(100), - subject_template TEXT, - body_template TEXT, - variables JSONB DEFAULT '[]', - - is_system BOOLEAN DEFAULT FALSE, - usage_count INTEGER DEFAULT 0, - - created_at TIMESTAMP DEFAULT NOW(), - updated_at TIMESTAMP DEFAULT NOW() - ); - - CREATE INDEX IF NOT EXISTS idx_templates_user ON 
email_templates(user_id); - CREATE INDEX IF NOT EXISTS idx_templates_tenant ON email_templates(tenant_id); - CREATE INDEX IF NOT EXISTS idx_templates_system ON email_templates(is_system); - - -- ============================================================================= - -- Mail Audit Log - -- ============================================================================= - CREATE TABLE IF NOT EXISTS mail_audit_log ( - id VARCHAR(36) PRIMARY KEY, - user_id VARCHAR(36) NOT NULL, - tenant_id VARCHAR(36), - action VARCHAR(100) NOT NULL, - entity_type VARCHAR(50), -- account, email, task - entity_id VARCHAR(36), - details JSONB, - ip_address VARCHAR(45), - user_agent TEXT, - created_at TIMESTAMP DEFAULT NOW() - ); - - CREATE INDEX IF NOT EXISTS idx_mail_audit_user ON mail_audit_log(user_id); - CREATE INDEX IF NOT EXISTS idx_mail_audit_created ON mail_audit_log(created_at DESC); - CREATE INDEX IF NOT EXISTS idx_mail_audit_action ON mail_audit_log(action); - - -- ============================================================================= - -- Sync Status Tracking - -- ============================================================================= - CREATE TABLE IF NOT EXISTS mail_sync_status ( - id VARCHAR(36) PRIMARY KEY, - account_id VARCHAR(36) REFERENCES external_email_accounts(id) ON DELETE CASCADE, - folder VARCHAR(100), - last_uid INTEGER DEFAULT 0, - last_sync TIMESTAMP, - sync_errors INTEGER DEFAULT 0, - created_at TIMESTAMP DEFAULT NOW(), - updated_at TIMESTAMP DEFAULT NOW(), - - UNIQUE(account_id, folder) - ); - """ - - try: - async with pool.acquire() as conn: - await conn.execute(create_tables_sql) - print("Mail tables initialized successfully") - return True - except Exception as e: - print(f"Failed to initialize mail tables: {e}") - return False - - -# ============================================================================= -# Email Account Operations -# ============================================================================= - -async def 
create_email_account( - user_id: str, - tenant_id: str, - email: str, - display_name: str, - account_type: str, - imap_host: str, - imap_port: int, - imap_ssl: bool, - smtp_host: str, - smtp_port: int, - smtp_ssl: bool, - vault_path: str, -) -> Optional[str]: - """Create a new email account. Returns the account ID.""" - pool = await get_pool() - if pool is None: - return None - - account_id = str(uuid.uuid4()) - try: - async with pool.acquire() as conn: - await conn.execute( - """ - INSERT INTO external_email_accounts - (id, user_id, tenant_id, email, display_name, account_type, - imap_host, imap_port, imap_ssl, smtp_host, smtp_port, smtp_ssl, vault_path) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) - """, - account_id, user_id, tenant_id, email, display_name, account_type, - imap_host, imap_port, imap_ssl, smtp_host, smtp_port, smtp_ssl, vault_path - ) - return account_id - except Exception as e: - print(f"Failed to create email account: {e}") - return None - - -async def get_email_accounts( - user_id: str, - tenant_id: Optional[str] = None, -) -> List[Dict]: - """Get all email accounts for a user.""" - pool = await get_pool() - if pool is None: - return [] - - try: - async with pool.acquire() as conn: - if tenant_id: - rows = await conn.fetch( - """ - SELECT * FROM external_email_accounts - WHERE user_id = $1 AND tenant_id = $2 - ORDER BY created_at - """, - user_id, tenant_id - ) - else: - rows = await conn.fetch( - """ - SELECT * FROM external_email_accounts - WHERE user_id = $1 - ORDER BY created_at - """, - user_id - ) - return [dict(r) for r in rows] - except Exception as e: - print(f"Failed to get email accounts: {e}") - return [] - - -async def get_email_account(account_id: str, user_id: str) -> Optional[Dict]: - """Get a single email account.""" - pool = await get_pool() - if pool is None: - return None - - try: - async with pool.acquire() as conn: - row = await conn.fetchrow( - """ - SELECT * FROM external_email_accounts - WHERE id = 
$1 AND user_id = $2 - """, - account_id, user_id - ) - return dict(row) if row else None - except Exception as e: - print(f"Failed to get email account: {e}") - return None - - -async def update_account_status( - account_id: str, - status: str, - sync_error: Optional[str] = None, - email_count: Optional[int] = None, - unread_count: Optional[int] = None, -) -> bool: - """Update account sync status.""" - pool = await get_pool() - if pool is None: - return False - - try: - async with pool.acquire() as conn: - await conn.execute( - """ - UPDATE external_email_accounts SET - status = $2, - sync_error = $3, - email_count = COALESCE($4, email_count), - unread_count = COALESCE($5, unread_count), - last_sync = NOW(), - updated_at = NOW() - WHERE id = $1 - """, - account_id, status, sync_error, email_count, unread_count - ) - return True - except Exception as e: - print(f"Failed to update account status: {e}") - return False - - -async def delete_email_account(account_id: str, user_id: str) -> bool: - """Delete an email account (cascades to emails).""" - pool = await get_pool() - if pool is None: - return False - - try: - async with pool.acquire() as conn: - result = await conn.execute( - """ - DELETE FROM external_email_accounts - WHERE id = $1 AND user_id = $2 - """, - account_id, user_id - ) - return "DELETE" in result - except Exception as e: - print(f"Failed to delete email account: {e}") - return False - - -# ============================================================================= -# Aggregated Email Operations -# ============================================================================= - -async def upsert_email( - account_id: str, - user_id: str, - tenant_id: str, - message_id: str, - subject: str, - sender_email: str, - sender_name: Optional[str], - recipients: List[str], - cc: List[str], - body_preview: Optional[str], - body_text: Optional[str], - body_html: Optional[str], - has_attachments: bool, - attachments: List[Dict], - headers: Dict, - folder: str, - 
date_sent: datetime, - date_received: datetime, -) -> Optional[str]: - """Insert or update an email. Returns the email ID.""" - pool = await get_pool() - if pool is None: - return None - - email_id = str(uuid.uuid4()) - try: - async with pool.acquire() as conn: - # Try insert, on conflict update (for re-sync scenarios) - row = await conn.fetchrow( - """ - INSERT INTO aggregated_emails - (id, account_id, user_id, tenant_id, message_id, subject, - sender_email, sender_name, recipients, cc, body_preview, - body_text, body_html, has_attachments, attachments, headers, - folder, date_sent, date_received) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19) - ON CONFLICT (account_id, message_id) DO UPDATE SET - subject = EXCLUDED.subject, - is_read = EXCLUDED.is_read, - folder = EXCLUDED.folder - RETURNING id - """, - email_id, account_id, user_id, tenant_id, message_id, subject, - sender_email, sender_name, json.dumps(recipients), json.dumps(cc), - body_preview, body_text, body_html, has_attachments, - json.dumps(attachments), json.dumps(headers), folder, - date_sent, date_received - ) - return row['id'] if row else None - except Exception as e: - print(f"Failed to upsert email: {e}") - return None - - -async def get_unified_inbox( - user_id: str, - account_ids: Optional[List[str]] = None, - categories: Optional[List[str]] = None, - is_read: Optional[bool] = None, - is_starred: Optional[bool] = None, - limit: int = 50, - offset: int = 0, -) -> List[Dict]: - """Get unified inbox with filtering.""" - pool = await get_pool() - if pool is None: - return [] - - try: - async with pool.acquire() as conn: - # Build dynamic query - conditions = ["user_id = $1", "is_deleted = FALSE"] - params = [user_id] - param_idx = 2 - - if account_ids: - conditions.append(f"account_id = ANY(${param_idx})") - params.append(account_ids) - param_idx += 1 - - if categories: - conditions.append(f"category = ANY(${param_idx})") - params.append(categories) - 
param_idx += 1 - - if is_read is not None: - conditions.append(f"is_read = ${param_idx}") - params.append(is_read) - param_idx += 1 - - if is_starred is not None: - conditions.append(f"is_starred = ${param_idx}") - params.append(is_starred) - param_idx += 1 - - where_clause = " AND ".join(conditions) - params.extend([limit, offset]) - - query = f""" - SELECT e.*, a.email as account_email, a.display_name as account_name - FROM aggregated_emails e - JOIN external_email_accounts a ON e.account_id = a.id - WHERE {where_clause} - ORDER BY e.date_received DESC - LIMIT ${param_idx} OFFSET ${param_idx + 1} - """ - - rows = await conn.fetch(query, *params) - return [dict(r) for r in rows] - except Exception as e: - print(f"Failed to get unified inbox: {e}") - return [] - - -async def get_email(email_id: str, user_id: str) -> Optional[Dict]: - """Get a single email by ID.""" - pool = await get_pool() - if pool is None: - return None - - try: - async with pool.acquire() as conn: - row = await conn.fetchrow( - """ - SELECT e.*, a.email as account_email, a.display_name as account_name - FROM aggregated_emails e - JOIN external_email_accounts a ON e.account_id = a.id - WHERE e.id = $1 AND e.user_id = $2 - """, - email_id, user_id - ) - return dict(row) if row else None - except Exception as e: - print(f"Failed to get email: {e}") - return None - - -async def update_email_ai_analysis( - email_id: str, - category: str, - sender_type: str, - sender_authority_name: Optional[str], - detected_deadlines: List[Dict], - suggested_priority: str, - ai_summary: Optional[str], -) -> bool: - """Update email with AI analysis results.""" - pool = await get_pool() - if pool is None: - return False - - try: - async with pool.acquire() as conn: - await conn.execute( - """ - UPDATE aggregated_emails SET - category = $2, - sender_type = $3, - sender_authority_name = $4, - detected_deadlines = $5, - suggested_priority = $6, - ai_summary = $7, - ai_analyzed_at = NOW() - WHERE id = $1 - """, - 
email_id, category, sender_type, sender_authority_name, - json.dumps(detected_deadlines), suggested_priority, ai_summary - ) - return True - except Exception as e: - print(f"Failed to update email AI analysis: {e}") - return False - - -async def mark_email_read(email_id: str, user_id: str, is_read: bool = True) -> bool: - """Mark email as read/unread.""" - pool = await get_pool() - if pool is None: - return False - - try: - async with pool.acquire() as conn: - await conn.execute( - """ - UPDATE aggregated_emails SET is_read = $3 - WHERE id = $1 AND user_id = $2 - """, - email_id, user_id, is_read - ) - return True - except Exception as e: - print(f"Failed to mark email read: {e}") - return False - - -async def mark_email_starred(email_id: str, user_id: str, is_starred: bool = True) -> bool: - """Mark email as starred/unstarred.""" - pool = await get_pool() - if pool is None: - return False - - try: - async with pool.acquire() as conn: - await conn.execute( - """ - UPDATE aggregated_emails SET is_starred = $3 - WHERE id = $1 AND user_id = $2 - """, - email_id, user_id, is_starred - ) - return True - except Exception as e: - print(f"Failed to mark email starred: {e}") - return False - - -# ============================================================================= -# Inbox Task Operations -# ============================================================================= - -async def create_task( - user_id: str, - tenant_id: str, - title: str, - description: Optional[str] = None, - priority: str = "medium", - deadline: Optional[datetime] = None, - email_id: Optional[str] = None, - account_id: Optional[str] = None, - source_email_subject: Optional[str] = None, - source_sender: Optional[str] = None, - source_sender_type: Optional[str] = None, - ai_extracted: bool = False, - confidence_score: Optional[float] = None, -) -> Optional[str]: - """Create a new inbox task.""" - pool = await get_pool() - if pool is None: - return None - - task_id = str(uuid.uuid4()) - try: - 
async with pool.acquire() as conn: - await conn.execute( - """ - INSERT INTO inbox_tasks - (id, user_id, tenant_id, title, description, priority, deadline, - email_id, account_id, source_email_subject, source_sender, - source_sender_type, ai_extracted, confidence_score) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) - """, - task_id, user_id, tenant_id, title, description, priority, deadline, - email_id, account_id, source_email_subject, source_sender, - source_sender_type, ai_extracted, confidence_score - ) - return task_id - except Exception as e: - print(f"Failed to create task: {e}") - return None - - -async def get_tasks( - user_id: str, - status: Optional[str] = None, - priority: Optional[str] = None, - include_completed: bool = False, - limit: int = 50, - offset: int = 0, -) -> List[Dict]: - """Get tasks for a user.""" - pool = await get_pool() - if pool is None: - return [] - - try: - async with pool.acquire() as conn: - conditions = ["user_id = $1"] - params = [user_id] - param_idx = 2 - - if not include_completed: - conditions.append("status != 'completed'") - - if status: - conditions.append(f"status = ${param_idx}") - params.append(status) - param_idx += 1 - - if priority: - conditions.append(f"priority = ${param_idx}") - params.append(priority) - param_idx += 1 - - where_clause = " AND ".join(conditions) - params.extend([limit, offset]) - - query = f""" - SELECT * FROM inbox_tasks - WHERE {where_clause} - ORDER BY - CASE priority - WHEN 'urgent' THEN 1 - WHEN 'high' THEN 2 - WHEN 'medium' THEN 3 - WHEN 'low' THEN 4 - END, - deadline ASC NULLS LAST, - created_at DESC - LIMIT ${param_idx} OFFSET ${param_idx + 1} - """ - - rows = await conn.fetch(query, *params) - return [dict(r) for r in rows] - except Exception as e: - print(f"Failed to get tasks: {e}") - return [] - - -async def get_task(task_id: str, user_id: str) -> Optional[Dict]: - """Get a single task.""" - pool = await get_pool() - if pool is None: - return None - - try: - 
async with pool.acquire() as conn: - row = await conn.fetchrow( - "SELECT * FROM inbox_tasks WHERE id = $1 AND user_id = $2", - task_id, user_id - ) - return dict(row) if row else None - except Exception as e: - print(f"Failed to get task: {e}") - return None - - -async def update_task( - task_id: str, - user_id: str, - title: Optional[str] = None, - description: Optional[str] = None, - priority: Optional[str] = None, - status: Optional[str] = None, - deadline: Optional[datetime] = None, -) -> bool: - """Update a task.""" - pool = await get_pool() - if pool is None: - return False - - try: - async with pool.acquire() as conn: - # Build dynamic update - updates = ["updated_at = NOW()"] - params = [task_id, user_id] - param_idx = 3 - - if title is not None: - updates.append(f"title = ${param_idx}") - params.append(title) - param_idx += 1 - - if description is not None: - updates.append(f"description = ${param_idx}") - params.append(description) - param_idx += 1 - - if priority is not None: - updates.append(f"priority = ${param_idx}") - params.append(priority) - param_idx += 1 - - if status is not None: - updates.append(f"status = ${param_idx}") - params.append(status) - param_idx += 1 - if status == "completed": - updates.append("completed_at = NOW()") - - if deadline is not None: - updates.append(f"deadline = ${param_idx}") - params.append(deadline) - param_idx += 1 - - set_clause = ", ".join(updates) - await conn.execute( - f"UPDATE inbox_tasks SET {set_clause} WHERE id = $1 AND user_id = $2", - *params - ) - return True - except Exception as e: - print(f"Failed to update task: {e}") - return False - - -async def get_task_dashboard_stats(user_id: str) -> Dict: - """Get dashboard statistics for tasks.""" - pool = await get_pool() - if pool is None: - return {} - - try: - async with pool.acquire() as conn: - now = datetime.now() - today_end = now.replace(hour=23, minute=59, second=59) - week_end = now + timedelta(days=7) - - stats = await conn.fetchrow( - """ - 
SELECT - COUNT(*) as total_tasks, - COUNT(*) FILTER (WHERE status = 'pending') as pending_tasks, - COUNT(*) FILTER (WHERE status = 'in_progress') as in_progress_tasks, - COUNT(*) FILTER (WHERE status = 'completed') as completed_tasks, - COUNT(*) FILTER (WHERE status != 'completed' AND deadline < $2) as overdue_tasks, - COUNT(*) FILTER (WHERE status != 'completed' AND deadline <= $3) as due_today, - COUNT(*) FILTER (WHERE status != 'completed' AND deadline <= $4) as due_this_week - FROM inbox_tasks - WHERE user_id = $1 - """, - user_id, now, today_end, week_end - ) - - by_priority = await conn.fetch( - """ - SELECT priority, COUNT(*) as count - FROM inbox_tasks - WHERE user_id = $1 AND status != 'completed' - GROUP BY priority - """, - user_id - ) - - by_sender = await conn.fetch( - """ - SELECT source_sender_type, COUNT(*) as count - FROM inbox_tasks - WHERE user_id = $1 AND status != 'completed' AND source_sender_type IS NOT NULL - GROUP BY source_sender_type - """, - user_id - ) - - return { - "total_tasks": stats['total_tasks'] or 0, - "pending_tasks": stats['pending_tasks'] or 0, - "in_progress_tasks": stats['in_progress_tasks'] or 0, - "completed_tasks": stats['completed_tasks'] or 0, - "overdue_tasks": stats['overdue_tasks'] or 0, - "due_today": stats['due_today'] or 0, - "due_this_week": stats['due_this_week'] or 0, - "by_priority": {r['priority']: r['count'] for r in by_priority}, - "by_sender_type": {r['source_sender_type']: r['count'] for r in by_sender}, - } - except Exception as e: - print(f"Failed to get task stats: {e}") - return {} - - -# ============================================================================= -# Statistics & Audit -# ============================================================================= - -async def get_mail_stats(user_id: str) -> Dict: - """Get overall mail statistics for a user.""" - pool = await get_pool() - if pool is None: - return {} - - try: - async with pool.acquire() as conn: - today = 
datetime.now().replace(hour=0, minute=0, second=0, microsecond=0) - - # Account stats - accounts = await conn.fetch( - """ - SELECT id, email, display_name, status, email_count, unread_count, last_sync - FROM external_email_accounts - WHERE user_id = $1 - """, - user_id - ) - - # Email counts - email_stats = await conn.fetchrow( - """ - SELECT - COUNT(*) as total_emails, - COUNT(*) FILTER (WHERE is_read = FALSE) as unread_emails, - COUNT(*) FILTER (WHERE date_received >= $2) as emails_today, - COUNT(*) FILTER (WHERE ai_analyzed_at >= $2) as ai_analyses_today - FROM aggregated_emails - WHERE user_id = $1 - """, - user_id, today - ) - - # Task counts - task_stats = await conn.fetchrow( - """ - SELECT - COUNT(*) as total_tasks, - COUNT(*) FILTER (WHERE status = 'pending') as pending_tasks, - COUNT(*) FILTER (WHERE status != 'completed' AND deadline < NOW()) as overdue_tasks - FROM inbox_tasks - WHERE user_id = $1 - """, - user_id - ) - - return { - "total_accounts": len(accounts), - "active_accounts": sum(1 for a in accounts if a['status'] == 'active'), - "error_accounts": sum(1 for a in accounts if a['status'] == 'error'), - "total_emails": email_stats['total_emails'] or 0, - "unread_emails": email_stats['unread_emails'] or 0, - "total_tasks": task_stats['total_tasks'] or 0, - "pending_tasks": task_stats['pending_tasks'] or 0, - "overdue_tasks": task_stats['overdue_tasks'] or 0, - "emails_today": email_stats['emails_today'] or 0, - "ai_analyses_today": email_stats['ai_analyses_today'] or 0, - "per_account": [ - { - "id": a['id'], - "email": a['email'], - "display_name": a['display_name'], - "status": a['status'], - "email_count": a['email_count'], - "unread_count": a['unread_count'], - "last_sync": a['last_sync'].isoformat() if a['last_sync'] else None, - } - for a in accounts - ], - } - except Exception as e: - print(f"Failed to get mail stats: {e}") - return {} - - -async def log_mail_audit( - user_id: str, - action: str, - entity_type: Optional[str] = None, - 
entity_id: Optional[str] = None, - details: Optional[Dict] = None, - tenant_id: Optional[str] = None, - ip_address: Optional[str] = None, - user_agent: Optional[str] = None, -) -> bool: - """Log a mail action for audit trail.""" - pool = await get_pool() - if pool is None: - return False - - try: - async with pool.acquire() as conn: - await conn.execute( - """ - INSERT INTO mail_audit_log - (id, user_id, tenant_id, action, entity_type, entity_id, details, ip_address, user_agent) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) - """, - str(uuid.uuid4()), user_id, tenant_id, action, entity_type, entity_id, - json.dumps(details) if details else None, ip_address, user_agent - ) - return True - except Exception as e: - print(f"Failed to log mail audit: {e}") - return False +from .mail_db_pool import get_pool, init_mail_tables + +from .mail_db_accounts import ( + create_email_account, + get_email_accounts, + get_email_account, + update_account_status, + delete_email_account, +) + +from .mail_db_emails import ( + upsert_email, + get_unified_inbox, + get_email, + update_email_ai_analysis, + mark_email_read, + mark_email_starred, +) + +from .mail_db_tasks import ( + create_task, + get_tasks, + get_task, + update_task, + get_task_dashboard_stats, +) + +from .mail_db_stats import ( + get_mail_stats, + log_mail_audit, +) + +__all__ = [ + # Pool + "get_pool", + "init_mail_tables", + # Accounts + "create_email_account", + "get_email_accounts", + "get_email_account", + "update_account_status", + "delete_email_account", + # Emails + "upsert_email", + "get_unified_inbox", + "get_email", + "update_email_ai_analysis", + "mark_email_read", + "mark_email_starred", + # Tasks + "create_task", + "get_tasks", + "get_task", + "update_task", + "get_task_dashboard_stats", + # Stats + "get_mail_stats", + "log_mail_audit", +] diff --git a/klausur-service/backend/mail/mail_db_accounts.py b/klausur-service/backend/mail/mail_db_accounts.py new file mode 100644 index 0000000..503dd5a --- /dev/null 
+++ b/klausur-service/backend/mail/mail_db_accounts.py @@ -0,0 +1,156 @@ +""" +Mail Database - Email Account Operations. +""" + +import uuid +from typing import Optional, List, Dict + +from .mail_db_pool import get_pool + + +async def create_email_account( + user_id: str, + tenant_id: str, + email: str, + display_name: str, + account_type: str, + imap_host: str, + imap_port: int, + imap_ssl: bool, + smtp_host: str, + smtp_port: int, + smtp_ssl: bool, + vault_path: str, +) -> Optional[str]: + """Create a new email account. Returns the account ID.""" + pool = await get_pool() + if pool is None: + return None + + account_id = str(uuid.uuid4()) + try: + async with pool.acquire() as conn: + await conn.execute( + """ + INSERT INTO external_email_accounts + (id, user_id, tenant_id, email, display_name, account_type, + imap_host, imap_port, imap_ssl, smtp_host, smtp_port, smtp_ssl, vault_path) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) + """, + account_id, user_id, tenant_id, email, display_name, account_type, + imap_host, imap_port, imap_ssl, smtp_host, smtp_port, smtp_ssl, vault_path + ) + return account_id + except Exception as e: + print(f"Failed to create email account: {e}") + return None + + +async def get_email_accounts( + user_id: str, + tenant_id: Optional[str] = None, +) -> List[Dict]: + """Get all email accounts for a user.""" + pool = await get_pool() + if pool is None: + return [] + + try: + async with pool.acquire() as conn: + if tenant_id: + rows = await conn.fetch( + """ + SELECT * FROM external_email_accounts + WHERE user_id = $1 AND tenant_id = $2 + ORDER BY created_at + """, + user_id, tenant_id + ) + else: + rows = await conn.fetch( + """ + SELECT * FROM external_email_accounts + WHERE user_id = $1 + ORDER BY created_at + """, + user_id + ) + return [dict(r) for r in rows] + except Exception as e: + print(f"Failed to get email accounts: {e}") + return [] + + +async def get_email_account(account_id: str, user_id: str) -> 
Optional[Dict]: + """Get a single email account.""" + pool = await get_pool() + if pool is None: + return None + + try: + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + SELECT * FROM external_email_accounts + WHERE id = $1 AND user_id = $2 + """, + account_id, user_id + ) + return dict(row) if row else None + except Exception as e: + print(f"Failed to get email account: {e}") + return None + + +async def update_account_status( + account_id: str, + status: str, + sync_error: Optional[str] = None, + email_count: Optional[int] = None, + unread_count: Optional[int] = None, +) -> bool: + """Update account sync status.""" + pool = await get_pool() + if pool is None: + return False + + try: + async with pool.acquire() as conn: + await conn.execute( + """ + UPDATE external_email_accounts SET + status = $2, + sync_error = $3, + email_count = COALESCE($4, email_count), + unread_count = COALESCE($5, unread_count), + last_sync = NOW(), + updated_at = NOW() + WHERE id = $1 + """, + account_id, status, sync_error, email_count, unread_count + ) + return True + except Exception as e: + print(f"Failed to update account status: {e}") + return False + + +async def delete_email_account(account_id: str, user_id: str) -> bool: + """Delete an email account (cascades to emails).""" + pool = await get_pool() + if pool is None: + return False + + try: + async with pool.acquire() as conn: + result = await conn.execute( + """ + DELETE FROM external_email_accounts + WHERE id = $1 AND user_id = $2 + """, + account_id, user_id + ) + return "DELETE" in result + except Exception as e: + print(f"Failed to delete email account: {e}") + return False diff --git a/klausur-service/backend/mail/mail_db_emails.py b/klausur-service/backend/mail/mail_db_emails.py new file mode 100644 index 0000000..7ee8966 --- /dev/null +++ b/klausur-service/backend/mail/mail_db_emails.py @@ -0,0 +1,225 @@ +""" +Mail Database - Aggregated Email Operations. 
+""" + +import json +import uuid +from typing import Optional, List, Dict +from datetime import datetime + +from .mail_db_pool import get_pool + + +async def upsert_email( + account_id: str, + user_id: str, + tenant_id: str, + message_id: str, + subject: str, + sender_email: str, + sender_name: Optional[str], + recipients: List[str], + cc: List[str], + body_preview: Optional[str], + body_text: Optional[str], + body_html: Optional[str], + has_attachments: bool, + attachments: List[Dict], + headers: Dict, + folder: str, + date_sent: datetime, + date_received: datetime, +) -> Optional[str]: + """Insert or update an email. Returns the email ID.""" + pool = await get_pool() + if pool is None: + return None + + email_id = str(uuid.uuid4()) + try: + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + INSERT INTO aggregated_emails + (id, account_id, user_id, tenant_id, message_id, subject, + sender_email, sender_name, recipients, cc, body_preview, + body_text, body_html, has_attachments, attachments, headers, + folder, date_sent, date_received) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19) + ON CONFLICT (account_id, message_id) DO UPDATE SET + subject = EXCLUDED.subject, + is_read = EXCLUDED.is_read, + folder = EXCLUDED.folder + RETURNING id + """, + email_id, account_id, user_id, tenant_id, message_id, subject, + sender_email, sender_name, json.dumps(recipients), json.dumps(cc), + body_preview, body_text, body_html, has_attachments, + json.dumps(attachments), json.dumps(headers), folder, + date_sent, date_received + ) + return row['id'] if row else None + except Exception as e: + print(f"Failed to upsert email: {e}") + return None + + +async def get_unified_inbox( + user_id: str, + account_ids: Optional[List[str]] = None, + categories: Optional[List[str]] = None, + is_read: Optional[bool] = None, + is_starred: Optional[bool] = None, + limit: int = 50, + offset: int = 0, +) -> List[Dict]: + """Get 
unified inbox with filtering.""" + pool = await get_pool() + if pool is None: + return [] + + try: + async with pool.acquire() as conn: + conditions = ["user_id = $1", "is_deleted = FALSE"] + params = [user_id] + param_idx = 2 + + if account_ids: + conditions.append(f"account_id = ANY(${param_idx})") + params.append(account_ids) + param_idx += 1 + + if categories: + conditions.append(f"category = ANY(${param_idx})") + params.append(categories) + param_idx += 1 + + if is_read is not None: + conditions.append(f"is_read = ${param_idx}") + params.append(is_read) + param_idx += 1 + + if is_starred is not None: + conditions.append(f"is_starred = ${param_idx}") + params.append(is_starred) + param_idx += 1 + + where_clause = " AND ".join(conditions) + params.extend([limit, offset]) + + query = f""" + SELECT e.*, a.email as account_email, a.display_name as account_name + FROM aggregated_emails e + JOIN external_email_accounts a ON e.account_id = a.id + WHERE {where_clause} + ORDER BY e.date_received DESC + LIMIT ${param_idx} OFFSET ${param_idx + 1} + """ + + rows = await conn.fetch(query, *params) + return [dict(r) for r in rows] + except Exception as e: + print(f"Failed to get unified inbox: {e}") + return [] + + +async def get_email(email_id: str, user_id: str) -> Optional[Dict]: + """Get a single email by ID.""" + pool = await get_pool() + if pool is None: + return None + + try: + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + SELECT e.*, a.email as account_email, a.display_name as account_name + FROM aggregated_emails e + JOIN external_email_accounts a ON e.account_id = a.id + WHERE e.id = $1 AND e.user_id = $2 + """, + email_id, user_id + ) + return dict(row) if row else None + except Exception as e: + print(f"Failed to get email: {e}") + return None + + +async def update_email_ai_analysis( + email_id: str, + category: str, + sender_type: str, + sender_authority_name: Optional[str], + detected_deadlines: List[Dict], + suggested_priority: str, + 
ai_summary: Optional[str], +) -> bool: + """Update email with AI analysis results.""" + pool = await get_pool() + if pool is None: + return False + + try: + async with pool.acquire() as conn: + await conn.execute( + """ + UPDATE aggregated_emails SET + category = $2, + sender_type = $3, + sender_authority_name = $4, + detected_deadlines = $5, + suggested_priority = $6, + ai_summary = $7, + ai_analyzed_at = NOW() + WHERE id = $1 + """, + email_id, category, sender_type, sender_authority_name, + json.dumps(detected_deadlines), suggested_priority, ai_summary + ) + return True + except Exception as e: + print(f"Failed to update email AI analysis: {e}") + return False + + +async def mark_email_read(email_id: str, user_id: str, is_read: bool = True) -> bool: + """Mark email as read/unread.""" + pool = await get_pool() + if pool is None: + return False + + try: + async with pool.acquire() as conn: + await conn.execute( + """ + UPDATE aggregated_emails SET is_read = $3 + WHERE id = $1 AND user_id = $2 + """, + email_id, user_id, is_read + ) + return True + except Exception as e: + print(f"Failed to mark email read: {e}") + return False + + +async def mark_email_starred(email_id: str, user_id: str, is_starred: bool = True) -> bool: + """Mark email as starred/unstarred.""" + pool = await get_pool() + if pool is None: + return False + + try: + async with pool.acquire() as conn: + await conn.execute( + """ + UPDATE aggregated_emails SET is_starred = $3 + WHERE id = $1 AND user_id = $2 + """, + email_id, user_id, is_starred + ) + return True + except Exception as e: + print(f"Failed to mark email starred: {e}") + return False diff --git a/klausur-service/backend/mail/mail_db_pool.py b/klausur-service/backend/mail/mail_db_pool.py new file mode 100644 index 0000000..41a50e5 --- /dev/null +++ b/klausur-service/backend/mail/mail_db_pool.py @@ -0,0 +1,253 @@ +""" +Mail Database - Connection Pool and Schema Initialization. 
+""" + +import os + +# Database Configuration - from Vault or environment (test default for CI) +DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://test:test@localhost:5432/test") + +# Flag to check if using test defaults +_DB_CONFIGURED = DATABASE_URL != "postgresql://test:test@localhost:5432/test" + +# Connection pool (shared with metrics_db) +_pool = None + + +async def get_pool(): + """Get or create database connection pool.""" + global _pool + if _pool is None: + try: + import asyncpg + _pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10) + except ImportError: + print("Warning: asyncpg not installed. Mail database disabled.") + return None + except Exception as e: + print(f"Warning: Failed to connect to PostgreSQL: {e}") + return None + return _pool + + +async def init_mail_tables() -> bool: + """Initialize mail tables in PostgreSQL.""" + pool = await get_pool() + if pool is None: + return False + + create_tables_sql = """ + -- ============================================================================= + -- External Email Accounts + -- ============================================================================= + CREATE TABLE IF NOT EXISTS external_email_accounts ( + id VARCHAR(36) PRIMARY KEY, + user_id VARCHAR(36) NOT NULL, + tenant_id VARCHAR(36) NOT NULL, + email VARCHAR(255) NOT NULL, + display_name VARCHAR(255), + account_type VARCHAR(50) DEFAULT 'personal', + + -- IMAP Settings (password stored in Vault) + imap_host VARCHAR(255) NOT NULL, + imap_port INTEGER DEFAULT 993, + imap_ssl BOOLEAN DEFAULT TRUE, + + -- SMTP Settings + smtp_host VARCHAR(255) NOT NULL, + smtp_port INTEGER DEFAULT 465, + smtp_ssl BOOLEAN DEFAULT TRUE, + + -- Vault path for credentials + vault_path VARCHAR(500), + + -- Status tracking + status VARCHAR(20) DEFAULT 'pending', + last_sync TIMESTAMP, + sync_error TEXT, + email_count INTEGER DEFAULT 0, + unread_count INTEGER DEFAULT 0, + + -- Timestamps + created_at TIMESTAMP DEFAULT NOW(), + updated_at 
TIMESTAMP DEFAULT NOW(), + + -- Constraints + UNIQUE(user_id, email) + ); + + CREATE INDEX IF NOT EXISTS idx_mail_accounts_user ON external_email_accounts(user_id); + CREATE INDEX IF NOT EXISTS idx_mail_accounts_tenant ON external_email_accounts(tenant_id); + CREATE INDEX IF NOT EXISTS idx_mail_accounts_status ON external_email_accounts(status); + + -- ============================================================================= + -- Aggregated Emails + -- ============================================================================= + CREATE TABLE IF NOT EXISTS aggregated_emails ( + id VARCHAR(36) PRIMARY KEY, + account_id VARCHAR(36) REFERENCES external_email_accounts(id) ON DELETE CASCADE, + user_id VARCHAR(36) NOT NULL, + tenant_id VARCHAR(36) NOT NULL, + + -- Email identification + message_id VARCHAR(500) NOT NULL, + folder VARCHAR(100) DEFAULT 'INBOX', + + -- Email content + subject TEXT, + sender_email VARCHAR(255), + sender_name VARCHAR(255), + recipients JSONB DEFAULT '[]', + cc JSONB DEFAULT '[]', + body_preview TEXT, + body_text TEXT, + body_html TEXT, + has_attachments BOOLEAN DEFAULT FALSE, + attachments JSONB DEFAULT '[]', + headers JSONB DEFAULT '{}', + + -- Status flags + is_read BOOLEAN DEFAULT FALSE, + is_starred BOOLEAN DEFAULT FALSE, + is_deleted BOOLEAN DEFAULT FALSE, + + -- Dates + date_sent TIMESTAMP, + date_received TIMESTAMP, + + -- AI enrichment + category VARCHAR(50), + sender_type VARCHAR(50), + sender_authority_name VARCHAR(255), + detected_deadlines JSONB DEFAULT '[]', + suggested_priority VARCHAR(20), + ai_summary TEXT, + ai_analyzed_at TIMESTAMP, + + created_at TIMESTAMP DEFAULT NOW(), + + -- Prevent duplicate imports + UNIQUE(account_id, message_id) + ); + + CREATE INDEX IF NOT EXISTS idx_emails_account ON aggregated_emails(account_id); + CREATE INDEX IF NOT EXISTS idx_emails_user ON aggregated_emails(user_id); + CREATE INDEX IF NOT EXISTS idx_emails_tenant ON aggregated_emails(tenant_id); + CREATE INDEX IF NOT EXISTS idx_emails_date 
ON aggregated_emails(date_received DESC); + CREATE INDEX IF NOT EXISTS idx_emails_category ON aggregated_emails(category); + CREATE INDEX IF NOT EXISTS idx_emails_unread ON aggregated_emails(is_read) WHERE is_read = FALSE; + CREATE INDEX IF NOT EXISTS idx_emails_starred ON aggregated_emails(is_starred) WHERE is_starred = TRUE; + CREATE INDEX IF NOT EXISTS idx_emails_sender ON aggregated_emails(sender_email); + + -- ============================================================================= + -- Inbox Tasks (Arbeitsvorrat) + -- ============================================================================= + CREATE TABLE IF NOT EXISTS inbox_tasks ( + id VARCHAR(36) PRIMARY KEY, + user_id VARCHAR(36) NOT NULL, + tenant_id VARCHAR(36) NOT NULL, + email_id VARCHAR(36) REFERENCES aggregated_emails(id) ON DELETE SET NULL, + account_id VARCHAR(36) REFERENCES external_email_accounts(id) ON DELETE SET NULL, + + -- Task content + title VARCHAR(500) NOT NULL, + description TEXT, + priority VARCHAR(20) DEFAULT 'medium', + status VARCHAR(20) DEFAULT 'pending', + deadline TIMESTAMP, + + -- Source information + source_email_subject TEXT, + source_sender VARCHAR(255), + source_sender_type VARCHAR(50), + + -- AI extraction info + ai_extracted BOOLEAN DEFAULT FALSE, + confidence_score FLOAT, + + -- Completion tracking + completed_at TIMESTAMP, + reminder_at TIMESTAMP, + + -- Timestamps + created_at TIMESTAMP DEFAULT NOW(), + updated_at TIMESTAMP DEFAULT NOW() + ); + + CREATE INDEX IF NOT EXISTS idx_tasks_user ON inbox_tasks(user_id); + CREATE INDEX IF NOT EXISTS idx_tasks_tenant ON inbox_tasks(tenant_id); + CREATE INDEX IF NOT EXISTS idx_tasks_status ON inbox_tasks(status); + CREATE INDEX IF NOT EXISTS idx_tasks_deadline ON inbox_tasks(deadline) WHERE deadline IS NOT NULL; + CREATE INDEX IF NOT EXISTS idx_tasks_priority ON inbox_tasks(priority); + CREATE INDEX IF NOT EXISTS idx_tasks_email ON inbox_tasks(email_id) WHERE email_id IS NOT NULL; + + -- 
============================================================================= + -- Email Templates + -- ============================================================================= + CREATE TABLE IF NOT EXISTS email_templates ( + id VARCHAR(36) PRIMARY KEY, + user_id VARCHAR(36), -- NULL for system templates + tenant_id VARCHAR(36), + + name VARCHAR(255) NOT NULL, + category VARCHAR(100), + subject_template TEXT, + body_template TEXT, + variables JSONB DEFAULT '[]', + + is_system BOOLEAN DEFAULT FALSE, + usage_count INTEGER DEFAULT 0, + + created_at TIMESTAMP DEFAULT NOW(), + updated_at TIMESTAMP DEFAULT NOW() + ); + + CREATE INDEX IF NOT EXISTS idx_templates_user ON email_templates(user_id); + CREATE INDEX IF NOT EXISTS idx_templates_tenant ON email_templates(tenant_id); + CREATE INDEX IF NOT EXISTS idx_templates_system ON email_templates(is_system); + + -- ============================================================================= + -- Mail Audit Log + -- ============================================================================= + CREATE TABLE IF NOT EXISTS mail_audit_log ( + id VARCHAR(36) PRIMARY KEY, + user_id VARCHAR(36) NOT NULL, + tenant_id VARCHAR(36), + action VARCHAR(100) NOT NULL, + entity_type VARCHAR(50), -- account, email, task + entity_id VARCHAR(36), + details JSONB, + ip_address VARCHAR(45), + user_agent TEXT, + created_at TIMESTAMP DEFAULT NOW() + ); + + CREATE INDEX IF NOT EXISTS idx_mail_audit_user ON mail_audit_log(user_id); + CREATE INDEX IF NOT EXISTS idx_mail_audit_created ON mail_audit_log(created_at DESC); + CREATE INDEX IF NOT EXISTS idx_mail_audit_action ON mail_audit_log(action); + + -- ============================================================================= + -- Sync Status Tracking + -- ============================================================================= + CREATE TABLE IF NOT EXISTS mail_sync_status ( + id VARCHAR(36) PRIMARY KEY, + account_id VARCHAR(36) REFERENCES external_email_accounts(id) ON DELETE CASCADE, + 
folder VARCHAR(100), + last_uid INTEGER DEFAULT 0, + last_sync TIMESTAMP, + sync_errors INTEGER DEFAULT 0, + created_at TIMESTAMP DEFAULT NOW(), + updated_at TIMESTAMP DEFAULT NOW(), + + UNIQUE(account_id, folder) + ); + """ + + try: + async with pool.acquire() as conn: + await conn.execute(create_tables_sql) + print("Mail tables initialized successfully") + return True + except Exception as e: + print(f"Failed to initialize mail tables: {e}") + return False diff --git a/klausur-service/backend/mail/mail_db_stats.py b/klausur-service/backend/mail/mail_db_stats.py new file mode 100644 index 0000000..193d117 --- /dev/null +++ b/klausur-service/backend/mail/mail_db_stats.py @@ -0,0 +1,118 @@ +""" +Mail Database - Statistics and Audit Log Operations. +""" + +import json +import uuid +from typing import Optional, Dict +from datetime import datetime + +from .mail_db_pool import get_pool + + +async def get_mail_stats(user_id: str) -> Dict: + """Get overall mail statistics for a user.""" + pool = await get_pool() + if pool is None: + return {} + + try: + async with pool.acquire() as conn: + today = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0) + + # Account stats + accounts = await conn.fetch( + """ + SELECT id, email, display_name, status, email_count, unread_count, last_sync + FROM external_email_accounts + WHERE user_id = $1 + """, + user_id + ) + + # Email counts + email_stats = await conn.fetchrow( + """ + SELECT + COUNT(*) as total_emails, + COUNT(*) FILTER (WHERE is_read = FALSE) as unread_emails, + COUNT(*) FILTER (WHERE date_received >= $2) as emails_today, + COUNT(*) FILTER (WHERE ai_analyzed_at >= $2) as ai_analyses_today + FROM aggregated_emails + WHERE user_id = $1 + """, + user_id, today + ) + + # Task counts + task_stats = await conn.fetchrow( + """ + SELECT + COUNT(*) as total_tasks, + COUNT(*) FILTER (WHERE status = 'pending') as pending_tasks, + COUNT(*) FILTER (WHERE status != 'completed' AND deadline < NOW()) as overdue_tasks + FROM 
inbox_tasks + WHERE user_id = $1 + """, + user_id + ) + + return { + "total_accounts": len(accounts), + "active_accounts": sum(1 for a in accounts if a['status'] == 'active'), + "error_accounts": sum(1 for a in accounts if a['status'] == 'error'), + "total_emails": email_stats['total_emails'] or 0, + "unread_emails": email_stats['unread_emails'] or 0, + "total_tasks": task_stats['total_tasks'] or 0, + "pending_tasks": task_stats['pending_tasks'] or 0, + "overdue_tasks": task_stats['overdue_tasks'] or 0, + "emails_today": email_stats['emails_today'] or 0, + "ai_analyses_today": email_stats['ai_analyses_today'] or 0, + "per_account": [ + { + "id": a['id'], + "email": a['email'], + "display_name": a['display_name'], + "status": a['status'], + "email_count": a['email_count'], + "unread_count": a['unread_count'], + "last_sync": a['last_sync'].isoformat() if a['last_sync'] else None, + } + for a in accounts + ], + } + except Exception as e: + print(f"Failed to get mail stats: {e}") + return {} + + +async def log_mail_audit( + user_id: str, + action: str, + entity_type: Optional[str] = None, + entity_id: Optional[str] = None, + details: Optional[Dict] = None, + tenant_id: Optional[str] = None, + ip_address: Optional[str] = None, + user_agent: Optional[str] = None, +) -> bool: + """Log a mail action for audit trail.""" + pool = await get_pool() + if pool is None: + return False + + try: + async with pool.acquire() as conn: + await conn.execute( + """ + INSERT INTO mail_audit_log + (id, user_id, tenant_id, action, entity_type, entity_id, details, ip_address, user_agent) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + """, + str(uuid.uuid4()), user_id, tenant_id, action, entity_type, entity_id, + json.dumps(details) if details else None, ip_address, user_agent + ) + return True + except Exception as e: + print(f"Failed to log mail audit: {e}") + return False diff --git a/klausur-service/backend/mail/mail_db_tasks.py b/klausur-service/backend/mail/mail_db_tasks.py new file 
mode 100644 index 0000000..af48ef7 --- /dev/null +++ b/klausur-service/backend/mail/mail_db_tasks.py @@ -0,0 +1,247 @@ +""" +Mail Database - Inbox Task Operations. +""" + +import uuid +from typing import Optional, List, Dict +from datetime import datetime, timedelta + +from .mail_db_pool import get_pool + + +async def create_task( + user_id: str, + tenant_id: str, + title: str, + description: Optional[str] = None, + priority: str = "medium", + deadline: Optional[datetime] = None, + email_id: Optional[str] = None, + account_id: Optional[str] = None, + source_email_subject: Optional[str] = None, + source_sender: Optional[str] = None, + source_sender_type: Optional[str] = None, + ai_extracted: bool = False, + confidence_score: Optional[float] = None, +) -> Optional[str]: + """Create a new inbox task.""" + pool = await get_pool() + if pool is None: + return None + + task_id = str(uuid.uuid4()) + try: + async with pool.acquire() as conn: + await conn.execute( + """ + INSERT INTO inbox_tasks + (id, user_id, tenant_id, title, description, priority, deadline, + email_id, account_id, source_email_subject, source_sender, + source_sender_type, ai_extracted, confidence_score) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) + """, + task_id, user_id, tenant_id, title, description, priority, deadline, + email_id, account_id, source_email_subject, source_sender, + source_sender_type, ai_extracted, confidence_score + ) + return task_id + except Exception as e: + print(f"Failed to create task: {e}") + return None + + +async def get_tasks( + user_id: str, + status: Optional[str] = None, + priority: Optional[str] = None, + include_completed: bool = False, + limit: int = 50, + offset: int = 0, +) -> List[Dict]: + """Get tasks for a user.""" + pool = await get_pool() + if pool is None: + return [] + + try: + async with pool.acquire() as conn: + conditions = ["user_id = $1"] + params = [user_id] + param_idx = 2 + + if not include_completed: + 
conditions.append("status != 'completed'") + + if status: + conditions.append(f"status = ${param_idx}") + params.append(status) + param_idx += 1 + + if priority: + conditions.append(f"priority = ${param_idx}") + params.append(priority) + param_idx += 1 + + where_clause = " AND ".join(conditions) + params.extend([limit, offset]) + + query = f""" + SELECT * FROM inbox_tasks + WHERE {where_clause} + ORDER BY + CASE priority + WHEN 'urgent' THEN 1 + WHEN 'high' THEN 2 + WHEN 'medium' THEN 3 + WHEN 'low' THEN 4 + END, + deadline ASC NULLS LAST, + created_at DESC + LIMIT ${param_idx} OFFSET ${param_idx + 1} + """ + + rows = await conn.fetch(query, *params) + return [dict(r) for r in rows] + except Exception as e: + print(f"Failed to get tasks: {e}") + return [] + + +async def get_task(task_id: str, user_id: str) -> Optional[Dict]: + """Get a single task.""" + pool = await get_pool() + if pool is None: + return None + + try: + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT * FROM inbox_tasks WHERE id = $1 AND user_id = $2", + task_id, user_id + ) + return dict(row) if row else None + except Exception as e: + print(f"Failed to get task: {e}") + return None + + +async def update_task( + task_id: str, + user_id: str, + title: Optional[str] = None, + description: Optional[str] = None, + priority: Optional[str] = None, + status: Optional[str] = None, + deadline: Optional[datetime] = None, +) -> bool: + """Update a task.""" + pool = await get_pool() + if pool is None: + return False + + try: + async with pool.acquire() as conn: + updates = ["updated_at = NOW()"] + params = [task_id, user_id] + param_idx = 3 + + if title is not None: + updates.append(f"title = ${param_idx}") + params.append(title) + param_idx += 1 + + if description is not None: + updates.append(f"description = ${param_idx}") + params.append(description) + param_idx += 1 + + if priority is not None: + updates.append(f"priority = ${param_idx}") + params.append(priority) + param_idx += 1 
+ + if status is not None: + updates.append(f"status = ${param_idx}") + params.append(status) + param_idx += 1 + if status == "completed": + updates.append("completed_at = NOW()") + + if deadline is not None: + updates.append(f"deadline = ${param_idx}") + params.append(deadline) + param_idx += 1 + + set_clause = ", ".join(updates) + await conn.execute( + f"UPDATE inbox_tasks SET {set_clause} WHERE id = $1 AND user_id = $2", + *params + ) + return True + except Exception as e: + print(f"Failed to update task: {e}") + return False + + +async def get_task_dashboard_stats(user_id: str) -> Dict: + """Get dashboard statistics for tasks.""" + pool = await get_pool() + if pool is None: + return {} + + try: + async with pool.acquire() as conn: + now = datetime.now() + today_end = now.replace(hour=23, minute=59, second=59) + week_end = now + timedelta(days=7) + + stats = await conn.fetchrow( + """ + SELECT + COUNT(*) as total_tasks, + COUNT(*) FILTER (WHERE status = 'pending') as pending_tasks, + COUNT(*) FILTER (WHERE status = 'in_progress') as in_progress_tasks, + COUNT(*) FILTER (WHERE status = 'completed') as completed_tasks, + COUNT(*) FILTER (WHERE status != 'completed' AND deadline < $2) as overdue_tasks, + COUNT(*) FILTER (WHERE status != 'completed' AND deadline <= $3) as due_today, + COUNT(*) FILTER (WHERE status != 'completed' AND deadline <= $4) as due_this_week + FROM inbox_tasks + WHERE user_id = $1 + """, + user_id, now, today_end, week_end + ) + + by_priority = await conn.fetch( + """ + SELECT priority, COUNT(*) as count + FROM inbox_tasks + WHERE user_id = $1 AND status != 'completed' + GROUP BY priority + """, + user_id + ) + + by_sender = await conn.fetch( + """ + SELECT source_sender_type, COUNT(*) as count + FROM inbox_tasks + WHERE user_id = $1 AND status != 'completed' AND source_sender_type IS NOT NULL + GROUP BY source_sender_type + """, + user_id + ) + + return { + "total_tasks": stats['total_tasks'] or 0, + "pending_tasks": stats['pending_tasks'] 
or 0, + "in_progress_tasks": stats['in_progress_tasks'] or 0, + "completed_tasks": stats['completed_tasks'] or 0, + "overdue_tasks": stats['overdue_tasks'] or 0, + "due_today": stats['due_today'] or 0, + "due_this_week": stats['due_this_week'] or 0, + "by_priority": {r['priority']: r['count'] for r in by_priority}, + "by_sender_type": {r['source_sender_type']: r['count'] for r in by_sender}, + } + except Exception as e: + print(f"Failed to get task stats: {e}") + return {} diff --git a/klausur-service/backend/ocr_merge_helpers.py b/klausur-service/backend/ocr_merge_helpers.py new file mode 100644 index 0000000..571c116 --- /dev/null +++ b/klausur-service/backend/ocr_merge_helpers.py @@ -0,0 +1,272 @@ +""" +OCR Merge Helpers — functions for combining PaddleOCR/RapidOCR with Tesseract results. + +Extracted from ocr_pipeline_ocr_merge.py. + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import logging +from typing import List + +logger = logging.getLogger(__name__) + + +def _split_paddle_multi_words(words: list) -> list: + """Split PaddleOCR multi-word boxes into individual word boxes. + + PaddleOCR often returns entire phrases as a single box, e.g. + "More than 200 singers took part in the" with one bounding box. + This splits them into individual words with proportional widths. + Also handles leading "!" (e.g. "!Betonung" -> ["!", "Betonung"]) + and IPA brackets (e.g. "badge[bxd3]" -> ["badge", "[bxd3]"]). + """ + import re + + result = [] + for w in words: + raw_text = w.get("text", "").strip() + if not raw_text: + continue + # Split on whitespace, before "[" (IPA), and after "!" 
before letter + tokens = re.split( + r'\s+|(?=\[)|(?<=!)(?=[A-Za-z\u00c0-\u024f])', raw_text + ) + tokens = [t for t in tokens if t] + + if len(tokens) <= 1: + result.append(w) + else: + # Split proportionally by character count + total_chars = sum(len(t) for t in tokens) + if total_chars == 0: + continue + n_gaps = len(tokens) - 1 + gap_px = w["width"] * 0.02 + usable_w = w["width"] - gap_px * n_gaps + cursor = w["left"] + for t in tokens: + token_w = max(1, usable_w * len(t) / total_chars) + result.append({ + "text": t, + "left": round(cursor), + "top": w["top"], + "width": round(token_w), + "height": w["height"], + "conf": w.get("conf", 0), + }) + cursor += token_w + gap_px + return result + + +def _group_words_into_rows(words: list, row_gap: int = 12) -> list: + """Group words into rows by Y-position clustering. + + Words whose vertical centers are within `row_gap` pixels are on the same row. + Returns list of rows, each row is a list of words sorted left-to-right. + """ + if not words: + return [] + # Sort by vertical center + sorted_words = sorted(words, key=lambda w: w["top"] + w.get("height", 0) / 2) + rows: list = [] + current_row: list = [sorted_words[0]] + current_cy = sorted_words[0]["top"] + sorted_words[0].get("height", 0) / 2 + + for w in sorted_words[1:]: + cy = w["top"] + w.get("height", 0) / 2 + if abs(cy - current_cy) <= row_gap: + current_row.append(w) + else: + # Sort current row left-to-right before saving + rows.append(sorted(current_row, key=lambda w: w["left"])) + current_row = [w] + current_cy = cy + if current_row: + rows.append(sorted(current_row, key=lambda w: w["left"])) + return rows + + +def _row_center_y(row: list) -> float: + """Average vertical center of a row of words.""" + if not row: + return 0.0 + return sum(w["top"] + w.get("height", 0) / 2 for w in row) / len(row) + + +def _merge_row_sequences(paddle_row: list, tess_row: list) -> list: + """Merge two word sequences from the same row using sequence alignment. 
+ + Both sequences are sorted left-to-right. Walk through both simultaneously: + - If words match (same/similar text): take Paddle text with averaged coords + - If they don't match: the extra word is unique to one engine, include it + """ + merged = [] + pi, ti = 0, 0 + + while pi < len(paddle_row) and ti < len(tess_row): + pw = paddle_row[pi] + tw = tess_row[ti] + + pt = pw.get("text", "").lower().strip() + tt = tw.get("text", "").lower().strip() + + is_same = (pt == tt) or (len(pt) > 1 and len(tt) > 1 and (pt in tt or tt in pt)) + + # Spatial overlap check + spatial_match = False + if not is_same: + overlap_left = max(pw["left"], tw["left"]) + overlap_right = min( + pw["left"] + pw.get("width", 0), + tw["left"] + tw.get("width", 0), + ) + overlap_w = max(0, overlap_right - overlap_left) + min_w = min(pw.get("width", 1), tw.get("width", 1)) + if min_w > 0 and overlap_w / min_w >= 0.4: + is_same = True + spatial_match = True + + if is_same: + pc = pw.get("conf", 80) + tc = tw.get("conf", 50) + total = pc + tc + if total == 0: + total = 1 + if spatial_match and pc < tc: + best_text = tw["text"] + else: + best_text = pw["text"] + merged.append({ + "text": best_text, + "left": round((pw["left"] * pc + tw["left"] * tc) / total), + "top": round((pw["top"] * pc + tw["top"] * tc) / total), + "width": round((pw["width"] * pc + tw["width"] * tc) / total), + "height": round((pw["height"] * pc + tw["height"] * tc) / total), + "conf": max(pc, tc), + }) + pi += 1 + ti += 1 + else: + paddle_ahead = any( + tess_row[t].get("text", "").lower().strip() == pt + for t in range(ti + 1, min(ti + 4, len(tess_row))) + ) + tess_ahead = any( + paddle_row[p].get("text", "").lower().strip() == tt + for p in range(pi + 1, min(pi + 4, len(paddle_row))) + ) + + if paddle_ahead and not tess_ahead: + if tw.get("conf", 0) >= 30: + merged.append(tw) + ti += 1 + elif tess_ahead and not paddle_ahead: + merged.append(pw) + pi += 1 + else: + if pw["left"] <= tw["left"]: + merged.append(pw) + pi += 1 + 
else: + if tw.get("conf", 0) >= 30: + merged.append(tw) + ti += 1 + + while pi < len(paddle_row): + merged.append(paddle_row[pi]) + pi += 1 + while ti < len(tess_row): + tw = tess_row[ti] + if tw.get("conf", 0) >= 30: + merged.append(tw) + ti += 1 + + return merged + + +def _merge_paddle_tesseract(paddle_words: list, tess_words: list) -> list: + """Merge word boxes from PaddleOCR and Tesseract using row-based sequence alignment.""" + if not paddle_words and not tess_words: + return [] + if not paddle_words: + return [w for w in tess_words if w.get("conf", 0) >= 40] + if not tess_words: + return list(paddle_words) + + paddle_rows = _group_words_into_rows(paddle_words) + tess_rows = _group_words_into_rows(tess_words) + + used_tess_rows: set = set() + merged_all: list = [] + + for pr in paddle_rows: + pr_cy = _row_center_y(pr) + best_dist, best_tri = float("inf"), -1 + for tri, tr in enumerate(tess_rows): + if tri in used_tess_rows: + continue + tr_cy = _row_center_y(tr) + dist = abs(pr_cy - tr_cy) + if dist < best_dist: + best_dist, best_tri = dist, tri + + max_row_dist = max( + max((w.get("height", 20) for w in pr), default=20), + 15, + ) + + if best_tri >= 0 and best_dist <= max_row_dist: + tr = tess_rows[best_tri] + used_tess_rows.add(best_tri) + merged_all.extend(_merge_row_sequences(pr, tr)) + else: + merged_all.extend(pr) + + for tri, tr in enumerate(tess_rows): + if tri not in used_tess_rows: + for tw in tr: + if tw.get("conf", 0) >= 40: + merged_all.append(tw) + + return merged_all + + +def _deduplicate_words(words: list) -> list: + """Remove duplicate words with same text at overlapping positions.""" + if not words: + return words + + result: list = [] + for w in words: + wt = w.get("text", "").lower().strip() + if not wt: + continue + is_dup = False + w_right = w["left"] + w.get("width", 0) + w_bottom = w["top"] + w.get("height", 0) + for existing in result: + et = existing.get("text", "").lower().strip() + if wt != et: + continue + ox_l = max(w["left"], 
existing["left"]) + ox_r = min(w_right, existing["left"] + existing.get("width", 0)) + ox = max(0, ox_r - ox_l) + min_w = min(w.get("width", 1), existing.get("width", 1)) + if min_w <= 0 or ox / min_w < 0.5: + continue + oy_t = max(w["top"], existing["top"]) + oy_b = min(w_bottom, existing["top"] + existing.get("height", 0)) + oy = max(0, oy_b - oy_t) + min_h = min(w.get("height", 1), existing.get("height", 1)) + if min_h > 0 and oy / min_h >= 0.5: + is_dup = True + break + if not is_dup: + result.append(w) + + removed = len(words) - len(result) + if removed: + logger.info("dedup: removed %d duplicate words", removed) + return result diff --git a/klausur-service/backend/ocr_pipeline_llm_review.py b/klausur-service/backend/ocr_pipeline_llm_review.py new file mode 100644 index 0000000..37e8df7 --- /dev/null +++ b/klausur-service/backend/ocr_pipeline_llm_review.py @@ -0,0 +1,209 @@ +""" +OCR Pipeline LLM Review — LLM-based correction endpoints. + +Extracted from ocr_pipeline_postprocess.py. + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
+""" + +import json +import logging +from datetime import datetime +from typing import Dict, List + +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import StreamingResponse + +from cv_vocab_pipeline import ( + OLLAMA_REVIEW_MODEL, + llm_review_entries, + llm_review_entries_streaming, +) +from ocr_pipeline_session_store import ( + get_session_db, + update_session_db, +) +from ocr_pipeline_common import ( + _cache, + _append_pipeline_log, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) + + +# --------------------------------------------------------------------------- +# Step 8: LLM Review +# --------------------------------------------------------------------------- + +@router.post("/sessions/{session_id}/llm-review") +async def run_llm_review(session_id: str, request: Request, stream: bool = False): + """Run LLM-based correction on vocab entries from Step 5. + + Query params: + stream: false (default) for JSON response, true for SSE streaming + """ + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + word_result = session.get("word_result") + if not word_result: + raise HTTPException(status_code=400, detail="No word result found — run Step 5 first") + + entries = word_result.get("vocab_entries") or word_result.get("entries") or [] + if not entries: + raise HTTPException(status_code=400, detail="No vocab entries found — run Step 5 first") + + # Optional model override from request body + body = {} + try: + body = await request.json() + except Exception: + pass + model = body.get("model") or OLLAMA_REVIEW_MODEL + + if stream: + return StreamingResponse( + _llm_review_stream_generator(session_id, entries, word_result, model, request), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}, + ) + + # 
Non-streaming path + try: + result = await llm_review_entries(entries, model=model) + except Exception as e: + import traceback + logger.error(f"LLM review failed for session {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=502, detail=f"LLM review failed ({type(e).__name__}): {e}") + + # Store result inside word_result as a sub-key + word_result["llm_review"] = { + "changes": result["changes"], + "model_used": result["model_used"], + "duration_ms": result["duration_ms"], + "entries_corrected": result["entries_corrected"], + } + await update_session_db(session_id, word_result=word_result, current_step=9) + + if session_id in _cache: + _cache[session_id]["word_result"] = word_result + + logger.info(f"LLM review session {session_id}: {len(result['changes'])} changes, " + f"{result['duration_ms']}ms, model={result['model_used']}") + + await _append_pipeline_log(session_id, "correction", { + "engine": "llm", + "model": result["model_used"], + "total_entries": len(entries), + "corrections_proposed": len(result["changes"]), + }, duration_ms=result["duration_ms"]) + + return { + "session_id": session_id, + "changes": result["changes"], + "model_used": result["model_used"], + "duration_ms": result["duration_ms"], + "total_entries": len(entries), + "corrections_found": len(result["changes"]), + } + + +async def _llm_review_stream_generator( + session_id: str, + entries: List[Dict], + word_result: Dict, + model: str, + request: Request, +): + """SSE generator that yields batch-by-batch LLM review progress.""" + try: + async for event in llm_review_entries_streaming(entries, model=model): + if await request.is_disconnected(): + logger.info(f"SSE: client disconnected during LLM review for {session_id}") + return + + yield f"data: {json.dumps(event, ensure_ascii=False)}\n\n" + + # On complete: persist to DB + if event.get("type") == "complete": + word_result["llm_review"] = { + "changes": event["changes"], + "model_used": 
event["model_used"], + "duration_ms": event["duration_ms"], + "entries_corrected": event["entries_corrected"], + } + await update_session_db(session_id, word_result=word_result, current_step=9) + if session_id in _cache: + _cache[session_id]["word_result"] = word_result + + logger.info(f"LLM review SSE session {session_id}: {event['corrections_found']} changes, " + f"{event['duration_ms']}ms, skipped={event['skipped']}, model={event['model_used']}") + + except Exception as e: + import traceback + logger.error(f"LLM review SSE failed for {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}") + error_event = {"type": "error", "detail": f"{type(e).__name__}: {e}"} + yield f"data: {json.dumps(error_event)}\n\n" + + +@router.post("/sessions/{session_id}/llm-review/apply") +async def apply_llm_corrections(session_id: str, request: Request): + """Apply selected LLM corrections to vocab entries.""" + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + word_result = session.get("word_result") + if not word_result: + raise HTTPException(status_code=400, detail="No word result found") + + llm_review = word_result.get("llm_review") + if not llm_review: + raise HTTPException(status_code=400, detail="No LLM review found — run /llm-review first") + + body = await request.json() + accepted_indices = set(body.get("accepted_indices", [])) # indices into changes[] + + changes = llm_review.get("changes", []) + entries = word_result.get("vocab_entries") or word_result.get("entries") or [] + + # Build a lookup: (row_index, field) -> new_value for accepted changes + corrections = {} + applied_count = 0 + for idx, change in enumerate(changes): + if idx in accepted_indices: + key = (change["row_index"], change["field"]) + corrections[key] = change["new"] + applied_count += 1 + + # Apply corrections to entries + for entry in entries: + row_idx = entry.get("row_index", -1) + for 
field_name in ("english", "german", "example"): + key = (row_idx, field_name) + if key in corrections: + entry[field_name] = corrections[key] + entry["llm_corrected"] = True + + # Update word_result + word_result["vocab_entries"] = entries + word_result["entries"] = entries + word_result["llm_review"]["applied_count"] = applied_count + word_result["llm_review"]["applied_at"] = datetime.utcnow().isoformat() + + await update_session_db(session_id, word_result=word_result) + + if session_id in _cache: + _cache[session_id]["word_result"] = word_result + + logger.info(f"Applied {applied_count}/{len(changes)} LLM corrections for session {session_id}") + + return { + "session_id": session_id, + "applied_count": applied_count, + "total_changes": len(changes), + } diff --git a/klausur-service/backend/ocr_pipeline_ocr_merge.py b/klausur-service/backend/ocr_pipeline_ocr_merge.py index d8b4c8c..c91f8b2 100644 --- a/klausur-service/backend/ocr_pipeline_ocr_merge.py +++ b/klausur-service/backend/ocr_pipeline_ocr_merge.py @@ -1,10 +1,8 @@ """ -OCR Merge Helpers and Kombi Endpoints. +OCR Merge Kombi Endpoints — paddle-kombi and rapid-kombi endpoints. -Contains merge helper functions for combining PaddleOCR/RapidOCR with Tesseract -results, plus the paddle-kombi and rapid-kombi endpoints. - -Extracted from ocr_pipeline_api.py for modularity. +Merge helper functions live in ocr_merge_helpers.py. +This module re-exports them for backward compatibility. Lizenz: Apache 2.0 DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. @@ -12,10 +10,8 @@ DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
import logging import time -from typing import Any, Dict, List import cv2 -import httpx import numpy as np from fastapi import APIRouter, HTTPException @@ -23,356 +19,23 @@ from cv_words_first import build_grid_from_words from ocr_pipeline_common import _cache, _append_pipeline_log from ocr_pipeline_session_store import get_session_image, update_session_db +# Re-export merge helpers for backward compatibility +from ocr_merge_helpers import ( # noqa: F401 + _split_paddle_multi_words, + _group_words_into_rows, + _row_center_y, + _merge_row_sequences, + _merge_paddle_tesseract, + _deduplicate_words, +) + logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) -# --------------------------------------------------------------------------- -# Merge helper functions -# --------------------------------------------------------------------------- - - -def _split_paddle_multi_words(words: list) -> list: - """Split PaddleOCR multi-word boxes into individual word boxes. - - PaddleOCR often returns entire phrases as a single box, e.g. - "More than 200 singers took part in the" with one bounding box. - This splits them into individual words with proportional widths. - Also handles leading "!" (e.g. "!Betonung" → ["!", "Betonung"]) - and IPA brackets (e.g. "badge[bxd3]" → ["badge", "[bxd3]"]). - """ - import re - - result = [] - for w in words: - raw_text = w.get("text", "").strip() - if not raw_text: - continue - # Split on whitespace, before "[" (IPA), and after "!" 
before letter - tokens = re.split( - r'\s+|(?=\[)|(?<=!)(?=[A-Za-z\u00c0-\u024f])', raw_text - ) - tokens = [t for t in tokens if t] - - if len(tokens) <= 1: - result.append(w) - else: - # Split proportionally by character count - total_chars = sum(len(t) for t in tokens) - if total_chars == 0: - continue - n_gaps = len(tokens) - 1 - gap_px = w["width"] * 0.02 - usable_w = w["width"] - gap_px * n_gaps - cursor = w["left"] - for t in tokens: - token_w = max(1, usable_w * len(t) / total_chars) - result.append({ - "text": t, - "left": round(cursor), - "top": w["top"], - "width": round(token_w), - "height": w["height"], - "conf": w.get("conf", 0), - }) - cursor += token_w + gap_px - return result - - -def _group_words_into_rows(words: list, row_gap: int = 12) -> list: - """Group words into rows by Y-position clustering. - - Words whose vertical centers are within `row_gap` pixels are on the same row. - Returns list of rows, each row is a list of words sorted left-to-right. - """ - if not words: - return [] - # Sort by vertical center - sorted_words = sorted(words, key=lambda w: w["top"] + w.get("height", 0) / 2) - rows: list = [] - current_row: list = [sorted_words[0]] - current_cy = sorted_words[0]["top"] + sorted_words[0].get("height", 0) / 2 - - for w in sorted_words[1:]: - cy = w["top"] + w.get("height", 0) / 2 - if abs(cy - current_cy) <= row_gap: - current_row.append(w) - else: - # Sort current row left-to-right before saving - rows.append(sorted(current_row, key=lambda w: w["left"])) - current_row = [w] - current_cy = cy - if current_row: - rows.append(sorted(current_row, key=lambda w: w["left"])) - return rows - - -def _row_center_y(row: list) -> float: - """Average vertical center of a row of words.""" - if not row: - return 0.0 - return sum(w["top"] + w.get("height", 0) / 2 for w in row) / len(row) - - -def _merge_row_sequences(paddle_row: list, tess_row: list) -> list: - """Merge two word sequences from the same row using sequence alignment. 
- - Both sequences are sorted left-to-right. Walk through both simultaneously: - - If words match (same/similar text): take Paddle text with averaged coords - - If they don't match: the extra word is unique to one engine, include it - - This prevents duplicates because both engines produce words in the same order. - """ - merged = [] - pi, ti = 0, 0 - - while pi < len(paddle_row) and ti < len(tess_row): - pw = paddle_row[pi] - tw = tess_row[ti] - - # Check if these are the same word - pt = pw.get("text", "").lower().strip() - tt = tw.get("text", "").lower().strip() - - # Same text or one contains the other - is_same = (pt == tt) or (len(pt) > 1 and len(tt) > 1 and (pt in tt or tt in pt)) - - # Spatial overlap check: if words overlap >= 40% horizontally, - # they're the same physical word regardless of OCR text differences. - # (40% catches borderline cases like "Stick"/"Stück" at 48% overlap) - spatial_match = False - if not is_same: - overlap_left = max(pw["left"], tw["left"]) - overlap_right = min( - pw["left"] + pw.get("width", 0), - tw["left"] + tw.get("width", 0), - ) - overlap_w = max(0, overlap_right - overlap_left) - min_w = min(pw.get("width", 1), tw.get("width", 1)) - if min_w > 0 and overlap_w / min_w >= 0.4: - is_same = True - spatial_match = True - - if is_same: - # Matched — average coordinates weighted by confidence - pc = pw.get("conf", 80) - tc = tw.get("conf", 50) - total = pc + tc - if total == 0: - total = 1 - # Text: prefer higher-confidence engine when texts differ - # (e.g. 
Tesseract "Stück" conf=98 vs PaddleOCR "Stick" conf=80) - if spatial_match and pc < tc: - best_text = tw["text"] - else: - best_text = pw["text"] - merged.append({ - "text": best_text, - "left": round((pw["left"] * pc + tw["left"] * tc) / total), - "top": round((pw["top"] * pc + tw["top"] * tc) / total), - "width": round((pw["width"] * pc + tw["width"] * tc) / total), - "height": round((pw["height"] * pc + tw["height"] * tc) / total), - "conf": max(pc, tc), - }) - pi += 1 - ti += 1 - else: - # Different text — one engine found something extra - # Look ahead: is the current Paddle word somewhere in Tesseract ahead? - paddle_ahead = any( - tess_row[t].get("text", "").lower().strip() == pt - for t in range(ti + 1, min(ti + 4, len(tess_row))) - ) - # Is the current Tesseract word somewhere in Paddle ahead? - tess_ahead = any( - paddle_row[p].get("text", "").lower().strip() == tt - for p in range(pi + 1, min(pi + 4, len(paddle_row))) - ) - - if paddle_ahead and not tess_ahead: - # Tesseract has an extra word (e.g. "!" or bullet) → include it - if tw.get("conf", 0) >= 30: - merged.append(tw) - ti += 1 - elif tess_ahead and not paddle_ahead: - # Paddle has an extra word → include it - merged.append(pw) - pi += 1 - else: - # Both have unique words or neither found ahead → take leftmost first - if pw["left"] <= tw["left"]: - merged.append(pw) - pi += 1 - else: - if tw.get("conf", 0) >= 30: - merged.append(tw) - ti += 1 - - # Remaining words from either engine - while pi < len(paddle_row): - merged.append(paddle_row[pi]) - pi += 1 - while ti < len(tess_row): - tw = tess_row[ti] - if tw.get("conf", 0) >= 30: - merged.append(tw) - ti += 1 - - return merged - - -def _merge_paddle_tesseract(paddle_words: list, tess_words: list) -> list: - """Merge word boxes from PaddleOCR and Tesseract using row-based sequence alignment. - - Strategy: - 1. Group each engine's words into rows (by Y-position clustering) - 2. Match rows between engines (by vertical center proximity) - 3. 
Within each matched row: merge sequences left-to-right, deduplicating - words that appear in both engines at the same sequence position - 4. Unmatched rows from either engine: keep as-is - - This prevents: - - Cross-line averaging (words from different lines being merged) - - Duplicate words (same word from both engines shown twice) - """ - if not paddle_words and not tess_words: - return [] - if not paddle_words: - return [w for w in tess_words if w.get("conf", 0) >= 40] - if not tess_words: - return list(paddle_words) - - # Step 1: Group into rows - paddle_rows = _group_words_into_rows(paddle_words) - tess_rows = _group_words_into_rows(tess_words) - - # Step 2: Match rows between engines by vertical center proximity - used_tess_rows: set = set() - merged_all: list = [] - - for pr in paddle_rows: - pr_cy = _row_center_y(pr) - best_dist, best_tri = float("inf"), -1 - for tri, tr in enumerate(tess_rows): - if tri in used_tess_rows: - continue - tr_cy = _row_center_y(tr) - dist = abs(pr_cy - tr_cy) - if dist < best_dist: - best_dist, best_tri = dist, tri - - # Row height threshold — rows must be within ~1.5x typical line height - max_row_dist = max( - max((w.get("height", 20) for w in pr), default=20), - 15, - ) - - if best_tri >= 0 and best_dist <= max_row_dist: - # Matched row — merge sequences - tr = tess_rows[best_tri] - used_tess_rows.add(best_tri) - merged_all.extend(_merge_row_sequences(pr, tr)) - else: - # No matching Tesseract row — keep Paddle row as-is - merged_all.extend(pr) - - # Add unmatched Tesseract rows - for tri, tr in enumerate(tess_rows): - if tri not in used_tess_rows: - for tw in tr: - if tw.get("conf", 0) >= 40: - merged_all.append(tw) - - return merged_all - - -def _deduplicate_words(words: list) -> list: - """Remove duplicate words with same text at overlapping positions. - - PaddleOCR can return overlapping phrases (e.g. "von jm." and "jm. =") - that produce duplicate words after splitting. This pass removes them. 
- - A word is a duplicate only when BOTH horizontal AND vertical overlap - exceed 50% — same text on the same visual line at the same position. - """ - if not words: - return words - - result: list = [] - for w in words: - wt = w.get("text", "").lower().strip() - if not wt: - continue - is_dup = False - w_right = w["left"] + w.get("width", 0) - w_bottom = w["top"] + w.get("height", 0) - for existing in result: - et = existing.get("text", "").lower().strip() - if wt != et: - continue - # Horizontal overlap - ox_l = max(w["left"], existing["left"]) - ox_r = min(w_right, existing["left"] + existing.get("width", 0)) - ox = max(0, ox_r - ox_l) - min_w = min(w.get("width", 1), existing.get("width", 1)) - if min_w <= 0 or ox / min_w < 0.5: - continue - # Vertical overlap — must also be on the same line - oy_t = max(w["top"], existing["top"]) - oy_b = min(w_bottom, existing["top"] + existing.get("height", 0)) - oy = max(0, oy_b - oy_t) - min_h = min(w.get("height", 1), existing.get("height", 1)) - if min_h > 0 and oy / min_h >= 0.5: - is_dup = True - break - if not is_dup: - result.append(w) - - removed = len(words) - len(result) - if removed: - logger.info("dedup: removed %d duplicate words", removed) - return result - - -# --------------------------------------------------------------------------- -# Kombi endpoints -# --------------------------------------------------------------------------- - - -@router.post("/sessions/{session_id}/paddle-kombi") -async def paddle_kombi(session_id: str): - """Run PaddleOCR + Tesseract on the preprocessed image and merge results. - - Both engines run on the same preprocessed (cropped/dewarped) image. - Word boxes are matched by IoU and coordinates are averaged weighted by - confidence. Unmatched Tesseract words (bullets, symbols) are added. 
- """ - img_png = await get_session_image(session_id, "cropped") - if not img_png: - img_png = await get_session_image(session_id, "dewarped") - if not img_png: - img_png = await get_session_image(session_id, "original") - if not img_png: - raise HTTPException(status_code=404, detail="No image found for this session") - - img_arr = np.frombuffer(img_png, dtype=np.uint8) - img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR) - if img_bgr is None: - raise HTTPException(status_code=400, detail="Failed to decode image") - - img_h, img_w = img_bgr.shape[:2] - - from cv_ocr_engines import ocr_region_paddle - - t0 = time.time() - - # --- PaddleOCR --- - paddle_words = await ocr_region_paddle(img_bgr, region=None) - if not paddle_words: - paddle_words = [] - - # --- Tesseract --- +def _run_tesseract_words(img_bgr) -> list: + """Run Tesseract OCR on an image and return word dicts.""" from PIL import Image import pytesseract @@ -397,15 +60,98 @@ async def paddle_kombi(session_id: str): "height": data["height"][i], "conf": conf, }) + return tess_words + + +def _build_kombi_word_result( + cells: list, + columns_meta: list, + img_w: int, + img_h: int, + duration: float, + engine_name: str, + raw_engine_words: list, + raw_engine_words_split: list, + tess_words: list, + merged_words: list, + raw_engine_key: str = "raw_paddle_words", + raw_split_key: str = "raw_paddle_words_split", +) -> dict: + """Build the word_result dict for kombi endpoints.""" + n_rows = len(set(c["row_index"] for c in cells)) if cells else 0 + n_cols = len(columns_meta) + col_types = {c.get("type") for c in columns_meta} + is_vocab = bool(col_types & {"column_en", "column_de"}) + + return { + "cells": cells, + "grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)}, + "columns_used": columns_meta, + "layout": "vocab" if is_vocab else "generic", + "image_width": img_w, + "image_height": img_h, + "duration_seconds": round(duration, 2), + "ocr_engine": engine_name, + "grid_method": engine_name, + 
raw_engine_key: raw_engine_words, + raw_split_key: raw_engine_words_split, + "raw_tesseract_words": tess_words, + "summary": { + "total_cells": len(cells), + "non_empty_cells": sum(1 for c in cells if c.get("text")), + "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), + raw_engine_key.replace("raw_", "").replace("_words", "_words"): len(raw_engine_words), + raw_split_key.replace("raw_", "").replace("_words_split", "_words_split"): len(raw_engine_words_split), + "tesseract_words": len(tess_words), + "merged_words": len(merged_words), + }, + } + + +async def _load_session_image(session_id: str): + """Load preprocessed image for kombi endpoints.""" + img_png = await get_session_image(session_id, "cropped") + if not img_png: + img_png = await get_session_image(session_id, "dewarped") + if not img_png: + img_png = await get_session_image(session_id, "original") + if not img_png: + raise HTTPException(status_code=404, detail="No image found for this session") + + img_arr = np.frombuffer(img_png, dtype=np.uint8) + img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR) + if img_bgr is None: + raise HTTPException(status_code=400, detail="Failed to decode image") + + return img_png, img_bgr + + +# --------------------------------------------------------------------------- +# Kombi endpoints +# --------------------------------------------------------------------------- + +@router.post("/sessions/{session_id}/paddle-kombi") +async def paddle_kombi(session_id: str): + """Run PaddleOCR + Tesseract on the preprocessed image and merge results.""" + img_png, img_bgr = await _load_session_image(session_id) + img_h, img_w = img_bgr.shape[:2] + + from cv_ocr_engines import ocr_region_paddle + + t0 = time.time() + + paddle_words = await ocr_region_paddle(img_bgr, region=None) + if not paddle_words: + paddle_words = [] + + tess_words = _run_tesseract_words(img_bgr) - # --- Split multi-word Paddle boxes into individual words --- paddle_words_split = 
_split_paddle_multi_words(paddle_words) logger.info( - "paddle_kombi: split %d paddle boxes → %d individual words", + "paddle_kombi: split %d paddle boxes -> %d individual words", len(paddle_words), len(paddle_words_split), ) - # --- Merge --- if not paddle_words_split and not tess_words: raise HTTPException(status_code=400, detail="Both OCR engines returned no words") @@ -418,49 +164,23 @@ async def paddle_kombi(session_id: str): for cell in cells: cell["ocr_engine"] = "kombi" - n_rows = len(set(c["row_index"] for c in cells)) if cells else 0 - n_cols = len(columns_meta) - col_types = {c.get("type") for c in columns_meta} - is_vocab = bool(col_types & {"column_en", "column_de"}) - - word_result = { - "cells": cells, - "grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)}, - "columns_used": columns_meta, - "layout": "vocab" if is_vocab else "generic", - "image_width": img_w, - "image_height": img_h, - "duration_seconds": round(duration, 2), - "ocr_engine": "kombi", - "grid_method": "kombi", - "raw_paddle_words": paddle_words, - "raw_paddle_words_split": paddle_words_split, - "raw_tesseract_words": tess_words, - "summary": { - "total_cells": len(cells), - "non_empty_cells": sum(1 for c in cells if c.get("text")), - "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), - "paddle_words": len(paddle_words), - "paddle_words_split": len(paddle_words_split), - "tesseract_words": len(tess_words), - "merged_words": len(merged_words), - }, - } + word_result = _build_kombi_word_result( + cells, columns_meta, img_w, img_h, duration, "kombi", + paddle_words, paddle_words_split, tess_words, merged_words, + "raw_paddle_words", "raw_paddle_words_split", + ) await update_session_db( - session_id, - word_result=word_result, - cropped_png=img_png, - current_step=8, + session_id, word_result=word_result, cropped_png=img_png, current_step=8, ) - # Update in-memory cache so detect-structure can access word_result if session_id in _cache: 
_cache[session_id]["word_result"] = word_result logger.info( "paddle_kombi session %s: %d cells (%d rows, %d cols) in %.2fs " "[paddle=%d, tess=%d, merged=%d]", - session_id, len(cells), n_rows, n_cols, duration, + session_id, len(cells), word_result["grid_shape"]["rows"], + word_result["grid_shape"]["cols"], duration, len(paddle_words), len(tess_words), len(merged_words), ) @@ -478,24 +198,8 @@ async def paddle_kombi(session_id: str): @router.post("/sessions/{session_id}/rapid-kombi") async def rapid_kombi(session_id: str): - """Run RapidOCR + Tesseract on the preprocessed image and merge results. - - Same merge logic as paddle-kombi, but uses local RapidOCR (ONNX Runtime) - instead of remote PaddleOCR service. - """ - img_png = await get_session_image(session_id, "cropped") - if not img_png: - img_png = await get_session_image(session_id, "dewarped") - if not img_png: - img_png = await get_session_image(session_id, "original") - if not img_png: - raise HTTPException(status_code=404, detail="No image found for this session") - - img_arr = np.frombuffer(img_png, dtype=np.uint8) - img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR) - if img_bgr is None: - raise HTTPException(status_code=400, detail="Failed to decode image") - + """Run RapidOCR + Tesseract on the preprocessed image and merge results.""" + img_png, img_bgr = await _load_session_image(session_id) img_h, img_w = img_bgr.shape[:2] from cv_ocr_engines import ocr_region_rapid @@ -503,7 +207,6 @@ async def rapid_kombi(session_id: str): t0 = time.time() - # --- RapidOCR (local, synchronous) --- full_region = PageRegion( type="full_page", x=0, y=0, width=img_w, height=img_h, ) @@ -511,40 +214,14 @@ async def rapid_kombi(session_id: str): if not rapid_words: rapid_words = [] - # --- Tesseract --- - from PIL import Image - import pytesseract + tess_words = _run_tesseract_words(img_bgr) - pil_img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)) - data = pytesseract.image_to_data( - pil_img, 
lang="eng+deu", - config="--psm 6 --oem 3", - output_type=pytesseract.Output.DICT, - ) - tess_words = [] - for i in range(len(data["text"])): - text = str(data["text"][i]).strip() - conf_raw = str(data["conf"][i]) - conf = int(conf_raw) if conf_raw.lstrip("-").isdigit() else -1 - if not text or conf < 20: - continue - tess_words.append({ - "text": text, - "left": data["left"][i], - "top": data["top"][i], - "width": data["width"][i], - "height": data["height"][i], - "conf": conf, - }) - - # --- Split multi-word RapidOCR boxes into individual words --- rapid_words_split = _split_paddle_multi_words(rapid_words) logger.info( - "rapid_kombi: split %d rapid boxes → %d individual words", + "rapid_kombi: split %d rapid boxes -> %d individual words", len(rapid_words), len(rapid_words_split), ) - # --- Merge --- if not rapid_words_split and not tess_words: raise HTTPException(status_code=400, detail="Both OCR engines returned no words") @@ -557,49 +234,23 @@ async def rapid_kombi(session_id: str): for cell in cells: cell["ocr_engine"] = "rapid_kombi" - n_rows = len(set(c["row_index"] for c in cells)) if cells else 0 - n_cols = len(columns_meta) - col_types = {c.get("type") for c in columns_meta} - is_vocab = bool(col_types & {"column_en", "column_de"}) - - word_result = { - "cells": cells, - "grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)}, - "columns_used": columns_meta, - "layout": "vocab" if is_vocab else "generic", - "image_width": img_w, - "image_height": img_h, - "duration_seconds": round(duration, 2), - "ocr_engine": "rapid_kombi", - "grid_method": "rapid_kombi", - "raw_rapid_words": rapid_words, - "raw_rapid_words_split": rapid_words_split, - "raw_tesseract_words": tess_words, - "summary": { - "total_cells": len(cells), - "non_empty_cells": sum(1 for c in cells if c.get("text")), - "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), - "rapid_words": len(rapid_words), - "rapid_words_split": len(rapid_words_split), - 
"tesseract_words": len(tess_words), - "merged_words": len(merged_words), - }, - } + word_result = _build_kombi_word_result( + cells, columns_meta, img_w, img_h, duration, "rapid_kombi", + rapid_words, rapid_words_split, tess_words, merged_words, + "raw_rapid_words", "raw_rapid_words_split", + ) await update_session_db( - session_id, - word_result=word_result, - cropped_png=img_png, - current_step=8, + session_id, word_result=word_result, cropped_png=img_png, current_step=8, ) - # Update in-memory cache so detect-structure can access word_result if session_id in _cache: _cache[session_id]["word_result"] = word_result logger.info( "rapid_kombi session %s: %d cells (%d rows, %d cols) in %.2fs " "[rapid=%d, tess=%d, merged=%d]", - session_id, len(cells), n_rows, n_cols, duration, + session_id, len(cells), word_result["grid_shape"]["rows"], + word_result["grid_shape"]["cols"], duration, len(rapid_words), len(tess_words), len(merged_words), ) diff --git a/klausur-service/backend/ocr_pipeline_postprocess.py b/klausur-service/backend/ocr_pipeline_postprocess.py index 3445800..388f5e2 100644 --- a/klausur-service/backend/ocr_pipeline_postprocess.py +++ b/klausur-service/backend/ocr_pipeline_postprocess.py @@ -1,929 +1,26 @@ """ -OCR Pipeline Postprocessing API — LLM review, reconstruction, export, validation, -image detection/generation, and handwriting removal endpoints. +OCR Pipeline Postprocessing API — composite router assembling LLM review, +reconstruction, export, validation, image detection/generation, and +handwriting removal endpoints. -Extracted from ocr_pipeline_api.py to keep the main module manageable. +Split into sub-modules: + ocr_pipeline_llm_review — LLM review + apply corrections + ocr_pipeline_reconstruction — reconstruction save, Fabric JSON, merged entries, PDF/DOCX + ocr_pipeline_validation — image detection, generation, validation, handwriting removal Lizenz: Apache 2.0 DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
""" -import json -import logging -import os -import re -from datetime import datetime -from typing import Any, Dict, List, Optional - -from fastapi import APIRouter, HTTPException, Request -from fastapi.responses import StreamingResponse -from pydantic import BaseModel - -from cv_vocab_pipeline import ( - OLLAMA_REVIEW_MODEL, - llm_review_entries, - llm_review_entries_streaming, -) -from ocr_pipeline_session_store import ( - get_session_db, - get_session_image, - get_sub_sessions, - update_session_db, -) -from ocr_pipeline_common import ( - _cache, - _load_session_to_cache, - _get_cached, - _get_base_image_png, - _append_pipeline_log, - RemoveHandwritingRequest, -) - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) - -# --------------------------------------------------------------------------- -# Pydantic Models -# --------------------------------------------------------------------------- - -STYLE_SUFFIXES = { - "educational": "educational illustration, textbook style, clear, colorful", - "cartoon": "cartoon, child-friendly, simple shapes", - "sketch": "pencil sketch, hand-drawn, black and white", - "clipart": "clipart, flat vector style, simple", - "realistic": "photorealistic, high detail", -} - - -class ValidationRequest(BaseModel): - notes: Optional[str] = None - score: Optional[int] = None - - -class GenerateImageRequest(BaseModel): - region_index: int - prompt: str - style: str = "educational" - - -# --------------------------------------------------------------------------- -# Step 8: LLM Review -# --------------------------------------------------------------------------- - -@router.post("/sessions/{session_id}/llm-review") -async def run_llm_review(session_id: str, request: Request, stream: bool = False): - """Run LLM-based correction on vocab entries from Step 5. 
- - Query params: - stream: false (default) for JSON response, true for SSE streaming - """ - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - word_result = session.get("word_result") - if not word_result: - raise HTTPException(status_code=400, detail="No word result found — run Step 5 first") - - entries = word_result.get("vocab_entries") or word_result.get("entries") or [] - if not entries: - raise HTTPException(status_code=400, detail="No vocab entries found — run Step 5 first") - - # Optional model override from request body - body = {} - try: - body = await request.json() - except Exception: - pass - model = body.get("model") or OLLAMA_REVIEW_MODEL - - if stream: - return StreamingResponse( - _llm_review_stream_generator(session_id, entries, word_result, model, request), - media_type="text/event-stream", - headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}, - ) - - # Non-streaming path - try: - result = await llm_review_entries(entries, model=model) - except Exception as e: - import traceback - logger.error(f"LLM review failed for session {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}") - raise HTTPException(status_code=502, detail=f"LLM review failed ({type(e).__name__}): {e}") - - # Store result inside word_result as a sub-key - word_result["llm_review"] = { - "changes": result["changes"], - "model_used": result["model_used"], - "duration_ms": result["duration_ms"], - "entries_corrected": result["entries_corrected"], - } - await update_session_db(session_id, word_result=word_result, current_step=9) - - if session_id in _cache: - _cache[session_id]["word_result"] = word_result - - logger.info(f"LLM review session {session_id}: {len(result['changes'])} changes, " - f"{result['duration_ms']}ms, model={result['model_used']}") - - await _append_pipeline_log(session_id, "correction", { - "engine": "llm", - 
"model": result["model_used"], - "total_entries": len(entries), - "corrections_proposed": len(result["changes"]), - }, duration_ms=result["duration_ms"]) - - return { - "session_id": session_id, - "changes": result["changes"], - "model_used": result["model_used"], - "duration_ms": result["duration_ms"], - "total_entries": len(entries), - "corrections_found": len(result["changes"]), - } - - -async def _llm_review_stream_generator( - session_id: str, - entries: List[Dict], - word_result: Dict, - model: str, - request: Request, -): - """SSE generator that yields batch-by-batch LLM review progress.""" - try: - async for event in llm_review_entries_streaming(entries, model=model): - if await request.is_disconnected(): - logger.info(f"SSE: client disconnected during LLM review for {session_id}") - return - - yield f"data: {json.dumps(event, ensure_ascii=False)}\n\n" - - # On complete: persist to DB - if event.get("type") == "complete": - word_result["llm_review"] = { - "changes": event["changes"], - "model_used": event["model_used"], - "duration_ms": event["duration_ms"], - "entries_corrected": event["entries_corrected"], - } - await update_session_db(session_id, word_result=word_result, current_step=9) - if session_id in _cache: - _cache[session_id]["word_result"] = word_result - - logger.info(f"LLM review SSE session {session_id}: {event['corrections_found']} changes, " - f"{event['duration_ms']}ms, skipped={event['skipped']}, model={event['model_used']}") - - except Exception as e: - import traceback - logger.error(f"LLM review SSE failed for {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}") - error_event = {"type": "error", "detail": f"{type(e).__name__}: {e}"} - yield f"data: {json.dumps(error_event)}\n\n" - - -@router.post("/sessions/{session_id}/llm-review/apply") -async def apply_llm_corrections(session_id: str, request: Request): - """Apply selected LLM corrections to vocab entries.""" - session = await get_session_db(session_id) - if not 
session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - word_result = session.get("word_result") - if not word_result: - raise HTTPException(status_code=400, detail="No word result found") - - llm_review = word_result.get("llm_review") - if not llm_review: - raise HTTPException(status_code=400, detail="No LLM review found — run /llm-review first") - - body = await request.json() - accepted_indices = set(body.get("accepted_indices", [])) # indices into changes[] - - changes = llm_review.get("changes", []) - entries = word_result.get("vocab_entries") or word_result.get("entries") or [] - - # Build a lookup: (row_index, field) -> new_value for accepted changes - corrections = {} - applied_count = 0 - for idx, change in enumerate(changes): - if idx in accepted_indices: - key = (change["row_index"], change["field"]) - corrections[key] = change["new"] - applied_count += 1 - - # Apply corrections to entries - for entry in entries: - row_idx = entry.get("row_index", -1) - for field_name in ("english", "german", "example"): - key = (row_idx, field_name) - if key in corrections: - entry[field_name] = corrections[key] - entry["llm_corrected"] = True - - # Update word_result - word_result["vocab_entries"] = entries - word_result["entries"] = entries - word_result["llm_review"]["applied_count"] = applied_count - word_result["llm_review"]["applied_at"] = datetime.utcnow().isoformat() - - await update_session_db(session_id, word_result=word_result) - - if session_id in _cache: - _cache[session_id]["word_result"] = word_result - - logger.info(f"Applied {applied_count}/{len(changes)} LLM corrections for session {session_id}") - - return { - "session_id": session_id, - "applied_count": applied_count, - "total_changes": len(changes), - } - - -# --------------------------------------------------------------------------- -# Step 9: Reconstruction + Fabric JSON export -# --------------------------------------------------------------------------- - 
-@router.post("/sessions/{session_id}/reconstruction") -async def save_reconstruction(session_id: str, request: Request): - """Save edited cell texts from reconstruction step.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - word_result = session.get("word_result") - if not word_result: - raise HTTPException(status_code=400, detail="No word result found") - - body = await request.json() - cell_updates = body.get("cells", []) - - if not cell_updates: - await update_session_db(session_id, current_step=10) - return {"session_id": session_id, "updated": 0} - - # Build update map: cell_id -> new text - update_map = {c["cell_id"]: c["text"] for c in cell_updates} - - # Separate sub-session updates (cell_ids prefixed with "box{N}_") - sub_updates: Dict[int, Dict[str, str]] = {} # box_index -> {original_cell_id: text} - main_updates: Dict[str, str] = {} - for cell_id, text in update_map.items(): - m = re.match(r'^box(\d+)_(.+)$', cell_id) - if m: - bi = int(m.group(1)) - original_id = m.group(2) - sub_updates.setdefault(bi, {})[original_id] = text - else: - main_updates[cell_id] = text - - # Update main session cells - cells = word_result.get("cells", []) - updated_count = 0 - for cell in cells: - if cell["cell_id"] in main_updates: - cell["text"] = main_updates[cell["cell_id"]] - cell["status"] = "edited" - updated_count += 1 - - word_result["cells"] = cells - - # Also update vocab_entries if present - entries = word_result.get("vocab_entries") or word_result.get("entries") or [] - if entries: - # Map cell_id pattern "R{row}_C{col}" to entry fields - for entry in entries: - row_idx = entry.get("row_index", -1) - # Check each field's cell - for col_idx, field_name in enumerate(["english", "german", "example"]): - cell_id = f"R{row_idx:02d}_C{col_idx}" - # Also try without zero-padding - cell_id_alt = f"R{row_idx}_C{col_idx}" - new_text = main_updates.get(cell_id) or 
main_updates.get(cell_id_alt) - if new_text is not None: - entry[field_name] = new_text - - word_result["vocab_entries"] = entries - if "entries" in word_result: - word_result["entries"] = entries - - await update_session_db(session_id, word_result=word_result, current_step=10) - - if session_id in _cache: - _cache[session_id]["word_result"] = word_result - - # Route sub-session updates - sub_updated = 0 - if sub_updates: - subs = await get_sub_sessions(session_id) - sub_by_index = {s.get("box_index"): s["id"] for s in subs} - for bi, updates in sub_updates.items(): - sub_id = sub_by_index.get(bi) - if not sub_id: - continue - sub_session = await get_session_db(sub_id) - if not sub_session: - continue - sub_word = sub_session.get("word_result") - if not sub_word: - continue - sub_cells = sub_word.get("cells", []) - for cell in sub_cells: - if cell["cell_id"] in updates: - cell["text"] = updates[cell["cell_id"]] - cell["status"] = "edited" - sub_updated += 1 - sub_word["cells"] = sub_cells - await update_session_db(sub_id, word_result=sub_word) - if sub_id in _cache: - _cache[sub_id]["word_result"] = sub_word - - total_updated = updated_count + sub_updated - logger.info(f"Reconstruction saved for session {session_id}: " - f"{updated_count} main + {sub_updated} sub-session cells updated") - - return { - "session_id": session_id, - "updated": total_updated, - "main_updated": updated_count, - "sub_updated": sub_updated, - } - - -@router.get("/sessions/{session_id}/reconstruction/fabric-json") -async def get_fabric_json(session_id: str): - """Return cell grid as Fabric.js-compatible JSON for the canvas editor. - - If the session has sub-sessions (box regions), their cells are merged - into the result at the correct Y positions. 
- """ - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - word_result = session.get("word_result") - if not word_result: - raise HTTPException(status_code=400, detail="No word result found") - - cells = list(word_result.get("cells", [])) - img_w = word_result.get("image_width", 800) - img_h = word_result.get("image_height", 600) - - # Merge sub-session cells at box positions - subs = await get_sub_sessions(session_id) - if subs: - column_result = session.get("column_result") or {} - zones = column_result.get("zones") or [] - box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")] - - for sub in subs: - sub_session = await get_session_db(sub["id"]) - if not sub_session: - continue - sub_word = sub_session.get("word_result") - if not sub_word or not sub_word.get("cells"): - continue - - bi = sub.get("box_index", 0) - if bi < len(box_zones): - box = box_zones[bi]["box"] - box_y, box_x = box["y"], box["x"] - else: - box_y, box_x = 0, 0 - - # Offset sub-session cells to absolute page coordinates - for cell in sub_word["cells"]: - cell_copy = dict(cell) - # Prefix cell_id with box index - cell_copy["cell_id"] = f"box{bi}_{cell_copy.get('cell_id', '')}" - cell_copy["source"] = f"box_{bi}" - # Offset bbox_px - bbox = cell_copy.get("bbox_px", {}) - if bbox: - bbox = dict(bbox) - bbox["x"] = bbox.get("x", 0) + box_x - bbox["y"] = bbox.get("y", 0) + box_y - cell_copy["bbox_px"] = bbox - cells.append(cell_copy) - - from services.layout_reconstruction_service import cells_to_fabric_json - fabric_json = cells_to_fabric_json(cells, img_w, img_h) - - return fabric_json - - -# --------------------------------------------------------------------------- -# Vocab entries merged + PDF/DOCX export -# --------------------------------------------------------------------------- - -@router.get("/sessions/{session_id}/vocab-entries/merged") -async def 
get_merged_vocab_entries(session_id: str): - """Return vocab entries from main session + all sub-sessions, sorted by Y position.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - word_result = session.get("word_result") or {} - entries = list(word_result.get("vocab_entries") or word_result.get("entries") or []) - - # Tag main entries - for e in entries: - e.setdefault("source", "main") - - # Merge sub-session entries - subs = await get_sub_sessions(session_id) - if subs: - column_result = session.get("column_result") or {} - zones = column_result.get("zones") or [] - box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")] - - for sub in subs: - sub_session = await get_session_db(sub["id"]) - if not sub_session: - continue - sub_word = sub_session.get("word_result") or {} - sub_entries = sub_word.get("vocab_entries") or sub_word.get("entries") or [] - - bi = sub.get("box_index", 0) - box_y = 0 - if bi < len(box_zones): - box_y = box_zones[bi]["box"]["y"] - - for e in sub_entries: - e_copy = dict(e) - e_copy["source"] = f"box_{bi}" - e_copy["source_y"] = box_y # for sorting - entries.append(e_copy) - - # Sort by approximate Y position - def _sort_key(e): - if e.get("source", "main") == "main": - return e.get("row_index", 0) * 100 # main entries by row index - return e.get("source_y", 0) * 100 + e.get("row_index", 0) - - entries.sort(key=_sort_key) - - return { - "session_id": session_id, - "entries": entries, - "total": len(entries), - "sources": list(set(e.get("source", "main") for e in entries)), - } - - -@router.get("/sessions/{session_id}/reconstruction/export/pdf") -async def export_reconstruction_pdf(session_id: str): - """Export the reconstructed cell grid as a PDF table.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - word_result = 
session.get("word_result") - if not word_result: - raise HTTPException(status_code=400, detail="No word result found") - - cells = word_result.get("cells", []) - columns_used = word_result.get("columns_used", []) - grid_shape = word_result.get("grid_shape", {}) - n_rows = grid_shape.get("rows", 0) - n_cols = grid_shape.get("cols", 0) - - # Build table data: rows x columns - table_data: list[list[str]] = [] - header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)] - if not header: - header = [f"Col {i}" for i in range(n_cols)] - table_data.append(header) - - for r in range(n_rows): - row_texts = [] - for ci in range(n_cols): - cell_id = f"R{r:02d}_C{ci}" - cell = next((c for c in cells if c.get("cell_id") == cell_id), None) - row_texts.append(cell.get("text", "") if cell else "") - table_data.append(row_texts) - - # Generate PDF with reportlab - try: - from reportlab.lib.pagesizes import A4 - from reportlab.lib import colors - from reportlab.platypus import SimpleDocTemplate, Table, TableStyle - import io as _io - - buf = _io.BytesIO() - doc = SimpleDocTemplate(buf, pagesize=A4) - if not table_data or not table_data[0]: - raise HTTPException(status_code=400, detail="No data to export") - - t = Table(table_data) - t.setStyle(TableStyle([ - ('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#0d9488')), - ('TEXTCOLOR', (0, 0), (-1, 0), colors.white), - ('FONTSIZE', (0, 0), (-1, -1), 9), - ('GRID', (0, 0), (-1, -1), 0.5, colors.grey), - ('VALIGN', (0, 0), (-1, -1), 'TOP'), - ('WORDWRAP', (0, 0), (-1, -1), True), - ])) - doc.build([t]) - buf.seek(0) - - from fastapi.responses import StreamingResponse - return StreamingResponse( - buf, - media_type="application/pdf", - headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.pdf"'}, - ) - except ImportError: - raise HTTPException(status_code=501, detail="reportlab not installed") - - -@router.get("/sessions/{session_id}/reconstruction/export/docx") -async def 
export_reconstruction_docx(session_id: str): - """Export the reconstructed cell grid as a DOCX table.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - word_result = session.get("word_result") - if not word_result: - raise HTTPException(status_code=400, detail="No word result found") - - cells = word_result.get("cells", []) - columns_used = word_result.get("columns_used", []) - grid_shape = word_result.get("grid_shape", {}) - n_rows = grid_shape.get("rows", 0) - n_cols = grid_shape.get("cols", 0) - - try: - from docx import Document - from docx.shared import Pt - import io as _io - - doc = Document() - doc.add_heading(f'Rekonstruktion – Session {session_id[:8]}', level=1) - - # Build header - header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)] - if not header: - header = [f"Col {i}" for i in range(n_cols)] - - table = doc.add_table(rows=1 + n_rows, cols=max(n_cols, 1)) - table.style = 'Table Grid' - - # Header row - for ci, h in enumerate(header): - table.rows[0].cells[ci].text = h - - # Data rows - for r in range(n_rows): - for ci in range(n_cols): - cell_id = f"R{r:02d}_C{ci}" - cell = next((c for c in cells if c.get("cell_id") == cell_id), None) - table.rows[r + 1].cells[ci].text = cell.get("text", "") if cell else "" - - buf = _io.BytesIO() - doc.save(buf) - buf.seek(0) - - from fastapi.responses import StreamingResponse - return StreamingResponse( - buf, - media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document", - headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.docx"'}, - ) - except ImportError: - raise HTTPException(status_code=501, detail="python-docx not installed") - - -# --------------------------------------------------------------------------- -# Step 8: Validation — Original vs. 
Reconstruction -# --------------------------------------------------------------------------- - -@router.post("/sessions/{session_id}/reconstruction/detect-images") -async def detect_image_regions(session_id: str): - """Detect illustration/image regions in the original scan using VLM. - - Sends the original image to qwen2.5vl to find non-text, non-table - image areas, returning bounding boxes (in %) and descriptions. - """ - import base64 - import httpx - import re - - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - # Get original image bytes - original_png = await get_session_image(session_id, "original") - if not original_png: - raise HTTPException(status_code=400, detail="No original image found") - - # Build context from vocab entries for richer descriptions - word_result = session.get("word_result") or {} - entries = word_result.get("vocab_entries") or word_result.get("entries") or [] - vocab_context = "" - if entries: - sample = entries[:10] - words = [f"{e.get('english', '')} / {e.get('german', '')}" for e in sample if e.get('english')] - if words: - vocab_context = f"\nContext: This is a vocabulary page with words like: {', '.join(words)}" - - ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434") - model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b") - - prompt = ( - "Analyze this scanned page. Find ALL illustration/image/picture regions " - "(NOT text, NOT table cells, NOT blank areas). " - "For each image region found, return its bounding box as percentage of page dimensions " - "and a short English description of what the image shows. " - "Reply with ONLY a JSON array like: " - '[{"x": 10, "y": 20, "w": 30, "h": 25, "description": "drawing of a cat"}] ' - "where x, y, w, h are percentages (0-100) of the page width/height. 
" - "If there are NO images on the page, return an empty array: []" - f"{vocab_context}" - ) - - img_b64 = base64.b64encode(original_png).decode("utf-8") - payload = { - "model": model, - "prompt": prompt, - "images": [img_b64], - "stream": False, - } - - try: - async with httpx.AsyncClient(timeout=120.0) as client: - resp = await client.post(f"{ollama_base}/api/generate", json=payload) - resp.raise_for_status() - text = resp.json().get("response", "") - - # Parse JSON array from response - match = re.search(r'\[.*?\]', text, re.DOTALL) - if match: - raw_regions = json.loads(match.group(0)) - else: - raw_regions = [] - - # Normalize to ImageRegion format - regions = [] - for r in raw_regions: - regions.append({ - "bbox_pct": { - "x": max(0, min(100, float(r.get("x", 0)))), - "y": max(0, min(100, float(r.get("y", 0)))), - "w": max(1, min(100, float(r.get("w", 10)))), - "h": max(1, min(100, float(r.get("h", 10)))), - }, - "description": r.get("description", ""), - "prompt": r.get("description", ""), - "image_b64": None, - "style": "educational", - }) - - # Enrich prompts with nearby vocab context - if entries: - for region in regions: - ry = region["bbox_pct"]["y"] - rh = region["bbox_pct"]["h"] - nearby = [ - e for e in entries - if e.get("bbox") and abs(e["bbox"].get("y", 0) - ry) < rh + 10 - ] - if nearby: - en_words = [e.get("english", "") for e in nearby if e.get("english")] - de_words = [e.get("german", "") for e in nearby if e.get("german")] - if en_words or de_words: - context = f" (vocabulary context: {', '.join(en_words[:5])}" - if de_words: - context += f" / {', '.join(de_words[:5])}" - context += ")" - region["prompt"] = region["description"] + context - - # Save to ground_truth JSONB - ground_truth = session.get("ground_truth") or {} - validation = ground_truth.get("validation") or {} - validation["image_regions"] = regions - validation["detected_at"] = datetime.utcnow().isoformat() - ground_truth["validation"] = validation - await 
update_session_db(session_id, ground_truth=ground_truth) - - if session_id in _cache: - _cache[session_id]["ground_truth"] = ground_truth - - logger.info(f"Detected {len(regions)} image regions for session {session_id}") - - return {"regions": regions, "count": len(regions)} - - except httpx.ConnectError: - logger.warning(f"VLM not available at {ollama_base} for image detection") - return {"regions": [], "count": 0, "error": "VLM not available"} - except Exception as e: - logger.error(f"Image detection failed for {session_id}: {e}") - return {"regions": [], "count": 0, "error": str(e)} - - -@router.post("/sessions/{session_id}/reconstruction/generate-image") -async def generate_image_for_region(session_id: str, req: GenerateImageRequest): - """Generate a replacement image for a detected region using mflux. - - Sends the prompt (with style suffix) to the mflux-service running - natively on the Mac Mini (Metal GPU required). - """ - import httpx - - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - ground_truth = session.get("ground_truth") or {} - validation = ground_truth.get("validation") or {} - regions = validation.get("image_regions") or [] - - if req.region_index < 0 or req.region_index >= len(regions): - raise HTTPException(status_code=400, detail=f"Invalid region_index {req.region_index}, have {len(regions)} regions") - - mflux_url = os.getenv("MFLUX_URL", "http://host.docker.internal:8095") - style_suffix = STYLE_SUFFIXES.get(req.style, STYLE_SUFFIXES["educational"]) - full_prompt = f"{req.prompt}, {style_suffix}" - - # Determine image size from region aspect ratio (snap to multiples of 64) - region = regions[req.region_index] - bbox = region["bbox_pct"] - aspect = bbox["w"] / max(bbox["h"], 1) - if aspect > 1.3: - width, height = 768, 512 - elif aspect < 0.7: - width, height = 512, 768 - else: - width, height = 512, 512 - - try: - async with 
httpx.AsyncClient(timeout=300.0) as client: - resp = await client.post(f"{mflux_url}/generate", json={ - "prompt": full_prompt, - "width": width, - "height": height, - "steps": 4, - }) - resp.raise_for_status() - data = resp.json() - image_b64 = data.get("image_b64") - - if not image_b64: - return {"image_b64": None, "success": False, "error": "No image returned"} - - # Save to ground_truth - regions[req.region_index]["image_b64"] = image_b64 - regions[req.region_index]["prompt"] = req.prompt - regions[req.region_index]["style"] = req.style - validation["image_regions"] = regions - ground_truth["validation"] = validation - await update_session_db(session_id, ground_truth=ground_truth) - - if session_id in _cache: - _cache[session_id]["ground_truth"] = ground_truth - - logger.info(f"Generated image for session {session_id} region {req.region_index}") - return {"image_b64": image_b64, "success": True} - - except httpx.ConnectError: - logger.warning(f"mflux-service not available at {mflux_url}") - return {"image_b64": None, "success": False, "error": f"mflux-service not available at {mflux_url}"} - except Exception as e: - logger.error(f"Image generation failed for {session_id}: {e}") - return {"image_b64": None, "success": False, "error": str(e)} - - -@router.post("/sessions/{session_id}/reconstruction/validate") -async def save_validation(session_id: str, req: ValidationRequest): - """Save final validation results for step 8. - - Stores notes, score, and preserves any detected/generated image regions. - Sets current_step = 10 to mark pipeline as complete. 
- """ - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - ground_truth = session.get("ground_truth") or {} - validation = ground_truth.get("validation") or {} - validation["validated_at"] = datetime.utcnow().isoformat() - validation["notes"] = req.notes - validation["score"] = req.score - ground_truth["validation"] = validation - - await update_session_db(session_id, ground_truth=ground_truth, current_step=11) - - if session_id in _cache: - _cache[session_id]["ground_truth"] = ground_truth - - logger.info(f"Validation saved for session {session_id}: score={req.score}") - - return {"session_id": session_id, "validation": validation} - - -@router.get("/sessions/{session_id}/reconstruction/validation") -async def get_validation(session_id: str): - """Retrieve saved validation data for step 8.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - ground_truth = session.get("ground_truth") or {} - validation = ground_truth.get("validation") - - return { - "session_id": session_id, - "validation": validation, - "word_result": session.get("word_result"), - } - - -# --------------------------------------------------------------------------- -# Remove handwriting -# --------------------------------------------------------------------------- - -@router.post("/sessions/{session_id}/remove-handwriting") -async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingRequest): - """ - Remove handwriting from a session image using inpainting. - - Steps: - 1. Load source image (auto -> deskewed if available, else original) - 2. Detect handwriting mask (filtered by target_ink) - 3. Dilate mask to cover stroke edges - 4. Inpaint the image - 5. Store result as clean_png in the session - - Returns metadata including the URL to fetch the clean image. 
- """ - import time as _time - t0 = _time.monotonic() - - from services.handwriting_detection import detect_handwriting - from services.inpainting_service import inpaint_image, dilate_mask as _dilate_mask, InpaintingMethod, image_to_png - - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - # 1. Determine source image - source = req.use_source - if source == "auto": - deskewed = await get_session_image(session_id, "deskewed") - source = "deskewed" if deskewed else "original" - - image_bytes = await get_session_image(session_id, source) - if not image_bytes: - raise HTTPException(status_code=404, detail=f"Source image '{source}' not available") - - # 2. Detect handwriting mask - detection = detect_handwriting(image_bytes, target_ink=req.target_ink) - - # 3. Convert mask to PNG bytes and dilate - import io - from PIL import Image as _PILImage - mask_img = _PILImage.fromarray(detection.mask) - mask_buf = io.BytesIO() - mask_img.save(mask_buf, format="PNG") - mask_bytes = mask_buf.getvalue() - - if req.dilation > 0: - mask_bytes = _dilate_mask(mask_bytes, iterations=req.dilation) - - # 4. 
Inpaint - method_map = { - "telea": InpaintingMethod.OPENCV_TELEA, - "ns": InpaintingMethod.OPENCV_NS, - "auto": InpaintingMethod.AUTO, - } - inpaint_method = method_map.get(req.method, InpaintingMethod.AUTO) - - result = inpaint_image(image_bytes, mask_bytes, method=inpaint_method) - if not result.success: - raise HTTPException(status_code=500, detail="Inpainting failed") - - elapsed_ms = int((_time.monotonic() - t0) * 1000) - - meta = { - "method_used": result.method_used.value if hasattr(result.method_used, "value") else str(result.method_used), - "handwriting_ratio": round(detection.handwriting_ratio, 4), - "detection_confidence": round(detection.confidence, 4), - "target_ink": req.target_ink, - "dilation": req.dilation, - "source_image": source, - "processing_time_ms": elapsed_ms, - } - - # 5. Persist clean image (convert BGR ndarray -> PNG bytes) - clean_png_bytes = image_to_png(result.image) - await update_session_db(session_id, clean_png=clean_png_bytes, handwriting_removal_meta=meta) - - return { - **meta, - "image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean", - "session_id": session_id, - } +from fastapi import APIRouter + +from ocr_pipeline_llm_review import router as _llm_review_router +from ocr_pipeline_reconstruction import router as _reconstruction_router +from ocr_pipeline_validation import router as _validation_router + +# Composite router — drop-in replacement for the old monolithic router. +# ocr_pipeline_api.py imports ``from ocr_pipeline_postprocess import router``. 
+router = APIRouter() +router.include_router(_llm_review_router) +router.include_router(_reconstruction_router) +router.include_router(_validation_router) diff --git a/klausur-service/backend/ocr_pipeline_reconstruction.py b/klausur-service/backend/ocr_pipeline_reconstruction.py new file mode 100644 index 0000000..99081c4 --- /dev/null +++ b/klausur-service/backend/ocr_pipeline_reconstruction.py @@ -0,0 +1,362 @@ +""" +OCR Pipeline Reconstruction — save edits, Fabric JSON export, merged entries, PDF/DOCX export. + +Extracted from ocr_pipeline_postprocess.py. + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import logging +import re +from typing import Dict + +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import StreamingResponse + +from ocr_pipeline_session_store import ( + get_session_db, + get_sub_sessions, + update_session_db, +) +from ocr_pipeline_common import _cache + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) + + +# --------------------------------------------------------------------------- +# Step 9: Reconstruction + Fabric JSON export +# --------------------------------------------------------------------------- + +@router.post("/sessions/{session_id}/reconstruction") +async def save_reconstruction(session_id: str, request: Request): + """Save edited cell texts from reconstruction step.""" + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + word_result = session.get("word_result") + if not word_result: + raise HTTPException(status_code=400, detail="No word result found") + + body = await request.json() + cell_updates = body.get("cells", []) + + if not cell_updates: + await update_session_db(session_id, current_step=10) + return {"session_id": session_id, "updated": 0} + + # Build update map: cell_id -> new text + update_map = 
{c["cell_id"]: c["text"] for c in cell_updates} + + # Separate sub-session updates (cell_ids prefixed with "box{N}_") + sub_updates: Dict[int, Dict[str, str]] = {} # box_index -> {original_cell_id: text} + main_updates: Dict[str, str] = {} + for cell_id, text in update_map.items(): + m = re.match(r'^box(\d+)_(.+)$', cell_id) + if m: + bi = int(m.group(1)) + original_id = m.group(2) + sub_updates.setdefault(bi, {})[original_id] = text + else: + main_updates[cell_id] = text + + # Update main session cells + cells = word_result.get("cells", []) + updated_count = 0 + for cell in cells: + if cell["cell_id"] in main_updates: + cell["text"] = main_updates[cell["cell_id"]] + cell["status"] = "edited" + updated_count += 1 + + word_result["cells"] = cells + + # Also update vocab_entries if present + entries = word_result.get("vocab_entries") or word_result.get("entries") or [] + if entries: + for entry in entries: + row_idx = entry.get("row_index", -1) + for col_idx, field_name in enumerate(["english", "german", "example"]): + cell_id = f"R{row_idx:02d}_C{col_idx}" + cell_id_alt = f"R{row_idx}_C{col_idx}" + new_text = main_updates.get(cell_id) or main_updates.get(cell_id_alt) + if new_text is not None: + entry[field_name] = new_text + + word_result["vocab_entries"] = entries + if "entries" in word_result: + word_result["entries"] = entries + + await update_session_db(session_id, word_result=word_result, current_step=10) + + if session_id in _cache: + _cache[session_id]["word_result"] = word_result + + # Route sub-session updates + sub_updated = 0 + if sub_updates: + subs = await get_sub_sessions(session_id) + sub_by_index = {s.get("box_index"): s["id"] for s in subs} + for bi, updates in sub_updates.items(): + sub_id = sub_by_index.get(bi) + if not sub_id: + continue + sub_session = await get_session_db(sub_id) + if not sub_session: + continue + sub_word = sub_session.get("word_result") + if not sub_word: + continue + sub_cells = sub_word.get("cells", []) + for cell in 
sub_cells: + if cell["cell_id"] in updates: + cell["text"] = updates[cell["cell_id"]] + cell["status"] = "edited" + sub_updated += 1 + sub_word["cells"] = sub_cells + await update_session_db(sub_id, word_result=sub_word) + if sub_id in _cache: + _cache[sub_id]["word_result"] = sub_word + + total_updated = updated_count + sub_updated + logger.info(f"Reconstruction saved for session {session_id}: " + f"{updated_count} main + {sub_updated} sub-session cells updated") + + return { + "session_id": session_id, + "updated": total_updated, + "main_updated": updated_count, + "sub_updated": sub_updated, + } + + +@router.get("/sessions/{session_id}/reconstruction/fabric-json") +async def get_fabric_json(session_id: str): + """Return cell grid as Fabric.js-compatible JSON for the canvas editor.""" + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + word_result = session.get("word_result") + if not word_result: + raise HTTPException(status_code=400, detail="No word result found") + + cells = list(word_result.get("cells", [])) + img_w = word_result.get("image_width", 800) + img_h = word_result.get("image_height", 600) + + # Merge sub-session cells at box positions + subs = await get_sub_sessions(session_id) + if subs: + column_result = session.get("column_result") or {} + zones = column_result.get("zones") or [] + box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")] + + for sub in subs: + sub_session = await get_session_db(sub["id"]) + if not sub_session: + continue + sub_word = sub_session.get("word_result") + if not sub_word or not sub_word.get("cells"): + continue + + bi = sub.get("box_index", 0) + if bi < len(box_zones): + box = box_zones[bi]["box"] + box_y, box_x = box["y"], box["x"] + else: + box_y, box_x = 0, 0 + + for cell in sub_word["cells"]: + cell_copy = dict(cell) + cell_copy["cell_id"] = f"box{bi}_{cell_copy.get('cell_id', '')}" + 
cell_copy["source"] = f"box_{bi}" + bbox = cell_copy.get("bbox_px", {}) + if bbox: + bbox = dict(bbox) + bbox["x"] = bbox.get("x", 0) + box_x + bbox["y"] = bbox.get("y", 0) + box_y + cell_copy["bbox_px"] = bbox + cells.append(cell_copy) + + from services.layout_reconstruction_service import cells_to_fabric_json + fabric_json = cells_to_fabric_json(cells, img_w, img_h) + + return fabric_json + + +# --------------------------------------------------------------------------- +# Vocab entries merged + PDF/DOCX export +# --------------------------------------------------------------------------- + +@router.get("/sessions/{session_id}/vocab-entries/merged") +async def get_merged_vocab_entries(session_id: str): + """Return vocab entries from main session + all sub-sessions, sorted by Y position.""" + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + word_result = session.get("word_result") or {} + entries = list(word_result.get("vocab_entries") or word_result.get("entries") or []) + + for e in entries: + e.setdefault("source", "main") + + subs = await get_sub_sessions(session_id) + if subs: + column_result = session.get("column_result") or {} + zones = column_result.get("zones") or [] + box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")] + + for sub in subs: + sub_session = await get_session_db(sub["id"]) + if not sub_session: + continue + sub_word = sub_session.get("word_result") or {} + sub_entries = sub_word.get("vocab_entries") or sub_word.get("entries") or [] + + bi = sub.get("box_index", 0) + box_y = 0 + if bi < len(box_zones): + box_y = box_zones[bi]["box"]["y"] + + for e in sub_entries: + e_copy = dict(e) + e_copy["source"] = f"box_{bi}" + e_copy["source_y"] = box_y + entries.append(e_copy) + + def _sort_key(e): + if e.get("source", "main") == "main": + return e.get("row_index", 0) * 100 + return e.get("source_y", 0) * 100 + e.get("row_index", 
0) + + entries.sort(key=_sort_key) + + return { + "session_id": session_id, + "entries": entries, + "total": len(entries), + "sources": list(set(e.get("source", "main") for e in entries)), + } + + +@router.get("/sessions/{session_id}/reconstruction/export/pdf") +async def export_reconstruction_pdf(session_id: str): + """Export the reconstructed cell grid as a PDF table.""" + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + word_result = session.get("word_result") + if not word_result: + raise HTTPException(status_code=400, detail="No word result found") + + cells = word_result.get("cells", []) + columns_used = word_result.get("columns_used", []) + grid_shape = word_result.get("grid_shape", {}) + n_rows = grid_shape.get("rows", 0) + n_cols = grid_shape.get("cols", 0) + + # Build table data: rows x columns + table_data: list[list[str]] = [] + header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)] + if not header: + header = [f"Col {i}" for i in range(n_cols)] + table_data.append(header) + + for r in range(n_rows): + row_texts = [] + for ci in range(n_cols): + cell_id = f"R{r:02d}_C{ci}" + cell = next((c for c in cells if c.get("cell_id") == cell_id), None) + row_texts.append(cell.get("text", "") if cell else "") + table_data.append(row_texts) + + try: + from reportlab.lib.pagesizes import A4 + from reportlab.lib import colors + from reportlab.platypus import SimpleDocTemplate, Table, TableStyle + import io as _io + + buf = _io.BytesIO() + doc = SimpleDocTemplate(buf, pagesize=A4) + if not table_data or not table_data[0]: + raise HTTPException(status_code=400, detail="No data to export") + + t = Table(table_data) + t.setStyle(TableStyle([ + ('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#0d9488')), + ('TEXTCOLOR', (0, 0), (-1, 0), colors.white), + ('FONTSIZE', (0, 0), (-1, -1), 9), + ('GRID', (0, 0), (-1, -1), 0.5, colors.grey), + 
('VALIGN', (0, 0), (-1, -1), 'TOP'), + ('WORDWRAP', (0, 0), (-1, -1), True), + ])) + doc.build([t]) + buf.seek(0) + + return StreamingResponse( + buf, + media_type="application/pdf", + headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.pdf"'}, + ) + except ImportError: + raise HTTPException(status_code=501, detail="reportlab not installed") + + +@router.get("/sessions/{session_id}/reconstruction/export/docx") +async def export_reconstruction_docx(session_id: str): + """Export the reconstructed cell grid as a DOCX table.""" + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + word_result = session.get("word_result") + if not word_result: + raise HTTPException(status_code=400, detail="No word result found") + + cells = word_result.get("cells", []) + columns_used = word_result.get("columns_used", []) + grid_shape = word_result.get("grid_shape", {}) + n_rows = grid_shape.get("rows", 0) + n_cols = grid_shape.get("cols", 0) + + try: + from docx import Document + from docx.shared import Pt + import io as _io + + doc = Document() + doc.add_heading(f'Rekonstruktion -- Session {session_id[:8]}', level=1) + + header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)] + if not header: + header = [f"Col {i}" for i in range(n_cols)] + + table = doc.add_table(rows=1 + n_rows, cols=max(n_cols, 1)) + table.style = 'Table Grid' + + for ci, h in enumerate(header): + table.rows[0].cells[ci].text = h + + for r in range(n_rows): + for ci in range(n_cols): + cell_id = f"R{r:02d}_C{ci}" + cell = next((c for c in cells if c.get("cell_id") == cell_id), None) + table.rows[r + 1].cells[ci].text = cell.get("text", "") if cell else "" + + buf = _io.BytesIO() + doc.save(buf) + buf.seek(0) + + return StreamingResponse( + buf, + media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document", + 
headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.docx"'}, + ) + except ImportError: + raise HTTPException(status_code=501, detail="python-docx not installed") diff --git a/klausur-service/backend/ocr_pipeline_validation.py b/klausur-service/backend/ocr_pipeline_validation.py new file mode 100644 index 0000000..3382a3f --- /dev/null +++ b/klausur-service/backend/ocr_pipeline_validation.py @@ -0,0 +1,362 @@ +""" +OCR Pipeline Validation — image detection, generation, validation save, +and handwriting removal endpoints. + +Extracted from ocr_pipeline_postprocess.py. + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import json +import logging +import os +from datetime import datetime +from typing import Optional + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from ocr_pipeline_session_store import ( + get_session_db, + get_session_image, + update_session_db, +) +from ocr_pipeline_common import ( + _cache, + RemoveHandwritingRequest, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) + +# --------------------------------------------------------------------------- +# Pydantic Models +# --------------------------------------------------------------------------- + +STYLE_SUFFIXES = { + "educational": "educational illustration, textbook style, clear, colorful", + "cartoon": "cartoon, child-friendly, simple shapes", + "sketch": "pencil sketch, hand-drawn, black and white", + "clipart": "clipart, flat vector style, simple", + "realistic": "photorealistic, high detail", +} + + +class ValidationRequest(BaseModel): + notes: Optional[str] = None + score: Optional[int] = None + + +class GenerateImageRequest(BaseModel): + region_index: int + prompt: str + style: str = "educational" + + +# --------------------------------------------------------------------------- +# Image detection + generation +# 
--------------------------------------------------------------------------- + +@router.post("/sessions/{session_id}/reconstruction/detect-images") +async def detect_image_regions(session_id: str): + """Detect illustration/image regions in the original scan using VLM.""" + import base64 + import httpx + import re + + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + original_png = await get_session_image(session_id, "original") + if not original_png: + raise HTTPException(status_code=400, detail="No original image found") + + word_result = session.get("word_result") or {} + entries = word_result.get("vocab_entries") or word_result.get("entries") or [] + vocab_context = "" + if entries: + sample = entries[:10] + words = [f"{e.get('english', '')} / {e.get('german', '')}" for e in sample if e.get('english')] + if words: + vocab_context = f"\nContext: This is a vocabulary page with words like: {', '.join(words)}" + + ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434") + model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b") + + prompt = ( + "Analyze this scanned page. Find ALL illustration/image/picture regions " + "(NOT text, NOT table cells, NOT blank areas). " + "For each image region found, return its bounding box as percentage of page dimensions " + "and a short English description of what the image shows. " + "Reply with ONLY a JSON array like: " + '[{"x": 10, "y": 20, "w": 30, "h": 25, "description": "drawing of a cat"}] ' + "where x, y, w, h are percentages (0-100) of the page width/height. 
" + "If there are NO images on the page, return an empty array: []" + f"{vocab_context}" + ) + + img_b64 = base64.b64encode(original_png).decode("utf-8") + payload = { + "model": model, + "prompt": prompt, + "images": [img_b64], + "stream": False, + } + + try: + async with httpx.AsyncClient(timeout=120.0) as client: + resp = await client.post(f"{ollama_base}/api/generate", json=payload) + resp.raise_for_status() + text = resp.json().get("response", "") + + match = re.search(r'\[.*?\]', text, re.DOTALL) + if match: + raw_regions = json.loads(match.group(0)) + else: + raw_regions = [] + + regions = [] + for r in raw_regions: + regions.append({ + "bbox_pct": { + "x": max(0, min(100, float(r.get("x", 0)))), + "y": max(0, min(100, float(r.get("y", 0)))), + "w": max(1, min(100, float(r.get("w", 10)))), + "h": max(1, min(100, float(r.get("h", 10)))), + }, + "description": r.get("description", ""), + "prompt": r.get("description", ""), + "image_b64": None, + "style": "educational", + }) + + # Enrich prompts with nearby vocab context + if entries: + for region in regions: + ry = region["bbox_pct"]["y"] + rh = region["bbox_pct"]["h"] + nearby = [ + e for e in entries + if e.get("bbox") and abs(e["bbox"].get("y", 0) - ry) < rh + 10 + ] + if nearby: + en_words = [e.get("english", "") for e in nearby if e.get("english")] + de_words = [e.get("german", "") for e in nearby if e.get("german")] + if en_words or de_words: + context = f" (vocabulary context: {', '.join(en_words[:5])}" + if de_words: + context += f" / {', '.join(de_words[:5])}" + context += ")" + region["prompt"] = region["description"] + context + + ground_truth = session.get("ground_truth") or {} + validation = ground_truth.get("validation") or {} + validation["image_regions"] = regions + validation["detected_at"] = datetime.utcnow().isoformat() + ground_truth["validation"] = validation + await update_session_db(session_id, ground_truth=ground_truth) + + if session_id in _cache: + _cache[session_id]["ground_truth"] = 
ground_truth + + logger.info(f"Detected {len(regions)} image regions for session {session_id}") + + return {"regions": regions, "count": len(regions)} + + except httpx.ConnectError: + logger.warning(f"VLM not available at {ollama_base} for image detection") + return {"regions": [], "count": 0, "error": "VLM not available"} + except Exception as e: + logger.error(f"Image detection failed for {session_id}: {e}") + return {"regions": [], "count": 0, "error": str(e)} + + +@router.post("/sessions/{session_id}/reconstruction/generate-image") +async def generate_image_for_region(session_id: str, req: GenerateImageRequest): + """Generate a replacement image for a detected region using mflux.""" + import httpx + + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + ground_truth = session.get("ground_truth") or {} + validation = ground_truth.get("validation") or {} + regions = validation.get("image_regions") or [] + + if req.region_index < 0 or req.region_index >= len(regions): + raise HTTPException(status_code=400, detail=f"Invalid region_index {req.region_index}, have {len(regions)} regions") + + mflux_url = os.getenv("MFLUX_URL", "http://host.docker.internal:8095") + style_suffix = STYLE_SUFFIXES.get(req.style, STYLE_SUFFIXES["educational"]) + full_prompt = f"{req.prompt}, {style_suffix}" + + region = regions[req.region_index] + bbox = region["bbox_pct"] + aspect = bbox["w"] / max(bbox["h"], 1) + if aspect > 1.3: + width, height = 768, 512 + elif aspect < 0.7: + width, height = 512, 768 + else: + width, height = 512, 512 + + try: + async with httpx.AsyncClient(timeout=300.0) as client: + resp = await client.post(f"{mflux_url}/generate", json={ + "prompt": full_prompt, + "width": width, + "height": height, + "steps": 4, + }) + resp.raise_for_status() + data = resp.json() + image_b64 = data.get("image_b64") + + if not image_b64: + return {"image_b64": None, "success": False, 
"error": "No image returned"} + + regions[req.region_index]["image_b64"] = image_b64 + regions[req.region_index]["prompt"] = req.prompt + regions[req.region_index]["style"] = req.style + validation["image_regions"] = regions + ground_truth["validation"] = validation + await update_session_db(session_id, ground_truth=ground_truth) + + if session_id in _cache: + _cache[session_id]["ground_truth"] = ground_truth + + logger.info(f"Generated image for session {session_id} region {req.region_index}") + return {"image_b64": image_b64, "success": True} + + except httpx.ConnectError: + logger.warning(f"mflux-service not available at {mflux_url}") + return {"image_b64": None, "success": False, "error": f"mflux-service not available at {mflux_url}"} + except Exception as e: + logger.error(f"Image generation failed for {session_id}: {e}") + return {"image_b64": None, "success": False, "error": str(e)} + + +# --------------------------------------------------------------------------- +# Validation save/get +# --------------------------------------------------------------------------- + +@router.post("/sessions/{session_id}/reconstruction/validate") +async def save_validation(session_id: str, req: ValidationRequest): + """Save final validation results for step 8.""" + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + ground_truth = session.get("ground_truth") or {} + validation = ground_truth.get("validation") or {} + validation["validated_at"] = datetime.utcnow().isoformat() + validation["notes"] = req.notes + validation["score"] = req.score + ground_truth["validation"] = validation + + await update_session_db(session_id, ground_truth=ground_truth, current_step=11) + + if session_id in _cache: + _cache[session_id]["ground_truth"] = ground_truth + + logger.info(f"Validation saved for session {session_id}: score={req.score}") + + return {"session_id": session_id, "validation": 
validation} + + +@router.get("/sessions/{session_id}/reconstruction/validation") +async def get_validation(session_id: str): + """Retrieve saved validation data for step 8.""" + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + ground_truth = session.get("ground_truth") or {} + validation = ground_truth.get("validation") + + return { + "session_id": session_id, + "validation": validation, + "word_result": session.get("word_result"), + } + + +# --------------------------------------------------------------------------- +# Remove handwriting +# --------------------------------------------------------------------------- + +@router.post("/sessions/{session_id}/remove-handwriting") +async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingRequest): + """Remove handwriting from a session image using inpainting.""" + import time as _time + + from services.handwriting_detection import detect_handwriting + from services.inpainting_service import inpaint_image, dilate_mask as _dilate_mask, InpaintingMethod, image_to_png + + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + t0 = _time.monotonic() + + # 1. Determine source image + source = req.use_source + if source == "auto": + deskewed = await get_session_image(session_id, "deskewed") + source = "deskewed" if deskewed else "original" + + image_bytes = await get_session_image(session_id, source) + if not image_bytes: + raise HTTPException(status_code=404, detail=f"Source image '{source}' not available") + + # 2. Detect handwriting mask + detection = detect_handwriting(image_bytes, target_ink=req.target_ink) + + # 3. 
Convert mask to PNG bytes and dilate + import io + from PIL import Image as _PILImage + mask_img = _PILImage.fromarray(detection.mask) + mask_buf = io.BytesIO() + mask_img.save(mask_buf, format="PNG") + mask_bytes = mask_buf.getvalue() + + if req.dilation > 0: + mask_bytes = _dilate_mask(mask_bytes, iterations=req.dilation) + + # 4. Inpaint + method_map = { + "telea": InpaintingMethod.OPENCV_TELEA, + "ns": InpaintingMethod.OPENCV_NS, + "auto": InpaintingMethod.AUTO, + } + inpaint_method = method_map.get(req.method, InpaintingMethod.AUTO) + + result = inpaint_image(image_bytes, mask_bytes, method=inpaint_method) + if not result.success: + raise HTTPException(status_code=500, detail="Inpainting failed") + + elapsed_ms = int((_time.monotonic() - t0) * 1000) + + meta = { + "method_used": result.method_used.value if hasattr(result.method_used, "value") else str(result.method_used), + "handwriting_ratio": round(detection.handwriting_ratio, 4), + "detection_confidence": round(detection.confidence, 4), + "target_ink": req.target_ink, + "dilation": req.dilation, + "source_image": source, + "processing_time_ms": elapsed_ms, + } + + # 5. Persist clean image + clean_png_bytes = image_to_png(result.image) + await update_session_db(session_id, clean_png=clean_png_bytes, handwriting_removal_meta=meta) + + return { + **meta, + "image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean", + "session_id": session_id, + } diff --git a/klausur-service/backend/ocr_pipeline_words.py b/klausur-service/backend/ocr_pipeline_words.py index c3eb7d6..a1d0f87 100644 --- a/klausur-service/backend/ocr_pipeline_words.py +++ b/klausur-service/backend/ocr_pipeline_words.py @@ -1,18 +1,18 @@ """ -OCR Pipeline Words - Word detection and ground truth endpoints. +OCR Pipeline Words — composite router for word detection, PaddleOCR direct, +and ground truth endpoints. -Extracted from ocr_pipeline_api.py. 
-Handles: -- POST /sessions/{session_id}/words — main SSE streaming word detection -- POST /sessions/{session_id}/paddle-direct — PaddleOCR direct endpoint -- POST /sessions/{session_id}/ground-truth/words — save ground truth -- GET /sessions/{session_id}/ground-truth/words — get ground truth +Split into sub-modules: + ocr_pipeline_words_detect — main detect_words endpoint (Step 7) + ocr_pipeline_words_stream — SSE streaming generators + +This barrel module contains the PaddleOCR direct endpoint and ground truth +endpoints, and assembles all word-related routers. Lizenz: Apache 2.0 DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. """ -import json import logging import time from datetime import datetime @@ -20,22 +20,9 @@ from typing import Any, Dict, List, Optional import cv2 import numpy as np -from fastapi import APIRouter, HTTPException, Request -from fastapi.responses import StreamingResponse +from fastapi import APIRouter, HTTPException from pydantic import BaseModel -from cv_vocab_pipeline import ( - PageRegion, - RowGeometry, - _cells_to_vocab_entries, - _fix_character_confusion, - _fix_phonetic_brackets, - fix_cell_phonetics, - build_cell_grid_v2, - build_cell_grid_v2_streaming, - create_ocr_image, - detect_column_geometry, -) from cv_words_first import build_grid_from_words from ocr_pipeline_session_store import ( get_session_db, @@ -44,15 +31,13 @@ from ocr_pipeline_session_store import ( ) from ocr_pipeline_common import ( _cache, - _load_session_to_cache, - _get_cached, - _get_base_image_png, _append_pipeline_log, ) +from ocr_pipeline_words_detect import router as _detect_router logger = logging.getLogger(__name__) -router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) +_local_router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) # --------------------------------------------------------------------------- @@ -65,689 +50,13 @@ class WordGroundTruthRequest(BaseModel): notes: Optional[str] = None -# 
--------------------------------------------------------------------------- -# Word Detection Endpoint (Step 7) -# --------------------------------------------------------------------------- - -@router.post("/sessions/{session_id}/words") -async def detect_words( - session_id: str, - request: Request, - engine: str = "auto", - pronunciation: str = "british", - stream: bool = False, - skip_heal_gaps: bool = False, - grid_method: str = "v2", -): - """Build word grid from columns × rows, OCR each cell. - - Query params: - engine: 'auto' (default), 'tesseract', 'rapid', or 'paddle' - pronunciation: 'british' (default) or 'american' — for IPA dictionary lookup - stream: false (default) for JSON response, true for SSE streaming - skip_heal_gaps: false (default). When true, cells keep exact row geometry - positions without gap-healing expansion. Better for overlay rendering. - grid_method: 'v2' (default) or 'words_first' — grid construction strategy. - 'v2' uses pre-detected columns/rows (top-down). - 'words_first' clusters words bottom-up (no column/row detection needed). 
- """ - # PaddleOCR is full-page remote OCR → force words_first grid method - if engine == "paddle" and grid_method != "words_first": - logger.info("detect_words: engine=paddle requires words_first, overriding grid_method=%s", grid_method) - grid_method = "words_first" - - if session_id not in _cache: - logger.info("detect_words: session %s not in cache, loading from DB", session_id) - await _load_session_to_cache(session_id) - cached = _get_cached(session_id) - - dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") - if dewarped_bgr is None: - logger.warning("detect_words: no cropped/dewarped image for session %s (cache keys: %s)", - session_id, [k for k in cached.keys() if k.endswith('_bgr')]) - raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before word detection") - - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - column_result = session.get("column_result") - row_result = session.get("row_result") - if not column_result or not column_result.get("columns"): - # No column detection — synthesize a single full-page pseudo-column. - # This enables the overlay pipeline which skips column detection. 
- img_h_tmp, img_w_tmp = dewarped_bgr.shape[:2] - column_result = { - "columns": [{ - "type": "column_text", - "x": 0, "y": 0, - "width": img_w_tmp, "height": img_h_tmp, - "classification_confidence": 1.0, - "classification_method": "full_page_fallback", - }], - "zones": [], - "duration_seconds": 0, - } - logger.info("detect_words: no column_result — using full-page pseudo-column %dx%d", img_w_tmp, img_h_tmp) - if grid_method != "words_first" and (not row_result or not row_result.get("rows")): - raise HTTPException(status_code=400, detail="Row detection must be completed first") - - # Convert column dicts back to PageRegion objects - col_regions = [ - PageRegion( - type=c["type"], - x=c["x"], y=c["y"], - width=c["width"], height=c["height"], - classification_confidence=c.get("classification_confidence", 1.0), - classification_method=c.get("classification_method", ""), - ) - for c in column_result["columns"] - ] - - # Convert row dicts back to RowGeometry objects - row_geoms = [ - RowGeometry( - index=r["index"], - x=r["x"], y=r["y"], - width=r["width"], height=r["height"], - word_count=r.get("word_count", 0), - words=[], - row_type=r.get("row_type", "content"), - gap_before=r.get("gap_before", 0), - ) - for r in row_result["rows"] - ] - - # Cell-First OCR (v2): no full-page word re-population needed. - # Each cell is cropped and OCR'd in isolation → no neighbour bleeding. - # We still need word_count > 0 for row filtering in build_cell_grid_v2, - # so populate from cached words if available (just for counting). 
- word_dicts = cached.get("_word_dicts") - if word_dicts is None: - ocr_img_tmp = create_ocr_image(dewarped_bgr) - geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr) - if geo_result is not None: - _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result - cached["_word_dicts"] = word_dicts - cached["_inv"] = inv - cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y) - - if word_dicts: - content_bounds = cached.get("_content_bounds") - if content_bounds: - _lx, _rx, top_y, _by = content_bounds - else: - top_y = min(r.y for r in row_geoms) if row_geoms else 0 - - for row in row_geoms: - row_y_rel = row.y - top_y - row_bottom_rel = row_y_rel + row.height - row.words = [ - w for w in word_dicts - if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel - ] - row.word_count = len(row.words) - - # Exclude rows that fall within box zones. - # Use inner box range (shrunk by border_thickness) so that rows at - # the boundary (overlapping with the box border) are NOT excluded. 
- zones = column_result.get("zones") or [] - box_ranges_inner = [] - for zone in zones: - if zone.get("zone_type") == "box" and zone.get("box"): - box = zone["box"] - bt = max(box.get("border_thickness", 0), 5) # minimum 5px margin - box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt)) - - if box_ranges_inner: - def _row_in_box(r): - center_y = r.y + r.height / 2 - return any(by_s <= center_y < by_e for by_s, by_e in box_ranges_inner) - - before_count = len(row_geoms) - row_geoms = [r for r in row_geoms if not _row_in_box(r)] - excluded = before_count - len(row_geoms) - if excluded: - logger.info(f"detect_words: excluded {excluded} rows inside box zones") - - # --- Words-First path: bottom-up grid from word boxes --- - if grid_method == "words_first": - t0 = time.time() - img_h, img_w = dewarped_bgr.shape[:2] - - # For paddle engine: run remote PaddleOCR full-page instead of Tesseract - if engine == "paddle": - from cv_ocr_engines import ocr_region_paddle - - wf_word_dicts = await ocr_region_paddle(dewarped_bgr, region=None) - # PaddleOCR returns absolute coordinates, no content_bounds offset needed - cached["_paddle_word_dicts"] = wf_word_dicts - else: - # Get word_dicts from cache or run Tesseract full-page - wf_word_dicts = cached.get("_word_dicts") - if wf_word_dicts is None: - ocr_img_tmp = create_ocr_image(dewarped_bgr) - geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr) - if geo_result is not None: - _geoms, left_x, right_x, top_y, bottom_y, wf_word_dicts, inv = geo_result - cached["_word_dicts"] = wf_word_dicts - cached["_inv"] = inv - cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y) - - if not wf_word_dicts: - raise HTTPException(status_code=400, detail="No words detected — cannot build words-first grid") - - # Convert word coordinates to absolute image coordinates if needed - # (detect_column_geometry returns words relative to content ROI) - # PaddleOCR already returns absolute coordinates — skip offset. 
- if engine != "paddle": - content_bounds = cached.get("_content_bounds") - if content_bounds: - lx, _rx, ty, _by = content_bounds - abs_words = [] - for w in wf_word_dicts: - abs_words.append({ - **w, - 'left': w['left'] + lx, - 'top': w['top'] + ty, - }) - wf_word_dicts = abs_words - - # Extract box rects for box-aware column clustering - box_rects = [] - for zone in zones: - if zone.get("zone_type") == "box" and zone.get("box"): - box_rects.append(zone["box"]) - - cells, columns_meta = build_grid_from_words( - wf_word_dicts, img_w, img_h, box_rects=box_rects or None, - ) - duration = time.time() - t0 - - # Apply IPA phonetic fixes - fix_cell_phonetics(cells, pronunciation=pronunciation) - - # Add zone_index for backward compat - for cell in cells: - cell.setdefault("zone_index", 0) - - col_types = {c['type'] for c in columns_meta} - is_vocab = bool(col_types & {'column_en', 'column_de'}) - n_rows = len(set(c['row_index'] for c in cells)) if cells else 0 - n_cols = len(columns_meta) - used_engine = "paddle" if engine == "paddle" else "words_first" - - word_result = { - "cells": cells, - "grid_shape": { - "rows": n_rows, - "cols": n_cols, - "total_cells": len(cells), - }, - "columns_used": columns_meta, - "layout": "vocab" if is_vocab else "generic", - "image_width": img_w, - "image_height": img_h, - "duration_seconds": round(duration, 2), - "ocr_engine": used_engine, - "grid_method": "words_first", - "summary": { - "total_cells": len(cells), - "non_empty_cells": sum(1 for c in cells if c.get("text")), - "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), - }, - } - - if is_vocab or 'column_text' in col_types: - entries = _cells_to_vocab_entries(cells, columns_meta) - entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation) - word_result["vocab_entries"] = entries - word_result["entries"] = entries - word_result["entry_count"] = len(entries) - word_result["summary"]["total_entries"] = len(entries) - 
word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english")) - word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german")) - - await update_session_db(session_id, word_result=word_result, current_step=8) - cached["word_result"] = word_result - - logger.info(f"OCR Pipeline: words-first session {session_id}: " - f"{len(cells)} cells ({duration:.2f}s), {n_rows} rows, {n_cols} cols") - - await _append_pipeline_log(session_id, "words", { - "grid_method": "words_first", - "total_cells": len(cells), - "non_empty_cells": word_result["summary"]["non_empty_cells"], - "ocr_engine": used_engine, - "layout": word_result["layout"], - }, duration_ms=int(duration * 1000)) - - return {"session_id": session_id, **word_result} - - if stream: - # Cell-First OCR v2: use batch-then-stream approach instead of - # per-cell streaming. The parallel ThreadPoolExecutor in - # build_cell_grid_v2 is much faster than sequential streaming. - return StreamingResponse( - _word_batch_stream_generator( - session_id, cached, col_regions, row_geoms, - dewarped_bgr, engine, pronunciation, request, - skip_heal_gaps=skip_heal_gaps, - ), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) - - # --- Non-streaming path (grid_method=v2) --- - t0 = time.time() - - # Create binarized OCR image (for Tesseract) - ocr_img = create_ocr_image(dewarped_bgr) - img_h, img_w = dewarped_bgr.shape[:2] - - # Build cell grid using Cell-First OCR (v2) — each cell cropped in isolation - cells, columns_meta = build_cell_grid_v2( - ocr_img, col_regions, row_geoms, img_w, img_h, - ocr_engine=engine, img_bgr=dewarped_bgr, - skip_heal_gaps=skip_heal_gaps, - ) - duration = time.time() - t0 - - # Add zone_index to each cell (default 0 for backward compatibility) - for cell in cells: - cell.setdefault("zone_index", 0) - - # Layout detection - col_types = {c['type'] for c in columns_meta} - 
is_vocab = bool(col_types & {'column_en', 'column_de'}) - - # Count content rows and columns for grid_shape - n_content_rows = len([r for r in row_geoms if r.row_type == 'content']) - n_cols = len(columns_meta) - - # Determine which engine was actually used - used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine - - # Apply IPA phonetic fixes directly to cell texts (for overlay mode) - fix_cell_phonetics(cells, pronunciation=pronunciation) - - # Grid result (always generic) - word_result = { - "cells": cells, - "grid_shape": { - "rows": n_content_rows, - "cols": n_cols, - "total_cells": len(cells), - }, - "columns_used": columns_meta, - "layout": "vocab" if is_vocab else "generic", - "image_width": img_w, - "image_height": img_h, - "duration_seconds": round(duration, 2), - "ocr_engine": used_engine, - "summary": { - "total_cells": len(cells), - "non_empty_cells": sum(1 for c in cells if c.get("text")), - "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), - }, - } - - # For vocab layout or single-column (box sub-sessions): map cells 1:1 - # to vocab entries (row→entry). 
- has_text_col = 'column_text' in col_types - if is_vocab or has_text_col: - entries = _cells_to_vocab_entries(cells, columns_meta) - entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation) - word_result["vocab_entries"] = entries - word_result["entries"] = entries - word_result["entry_count"] = len(entries) - word_result["summary"]["total_entries"] = len(entries) - word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english")) - word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german")) - - # Persist to DB - await update_session_db( - session_id, - word_result=word_result, - current_step=8, - ) - - cached["word_result"] = word_result - - logger.info(f"OCR Pipeline: words session {session_id}: " - f"layout={word_result['layout']}, " - f"{len(cells)} cells ({duration:.2f}s), summary: {word_result['summary']}") - - await _append_pipeline_log(session_id, "words", { - "total_cells": len(cells), - "non_empty_cells": word_result["summary"]["non_empty_cells"], - "low_confidence_count": word_result["summary"]["low_confidence"], - "ocr_engine": used_engine, - "layout": word_result["layout"], - "entry_count": word_result.get("entry_count", 0), - }, duration_ms=int(duration * 1000)) - - return { - "session_id": session_id, - **word_result, - } - - -async def _word_batch_stream_generator( - session_id: str, - cached: Dict[str, Any], - col_regions: List[PageRegion], - row_geoms: List[RowGeometry], - dewarped_bgr: np.ndarray, - engine: str, - pronunciation: str, - request: Request, - skip_heal_gaps: bool = False, -): - """SSE generator that runs batch OCR (parallel) then streams results. - - Unlike the old per-cell streaming, this uses build_cell_grid_v2 with - ThreadPoolExecutor for parallel OCR, then emits all cells as SSE events. - The 'preparing' event keeps the connection alive during OCR processing. 
- """ - import asyncio - - t0 = time.time() - ocr_img = create_ocr_image(dewarped_bgr) - img_h, img_w = dewarped_bgr.shape[:2] - - _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'} - n_content_rows = len([r for r in row_geoms if r.row_type == 'content']) - n_cols = len([c for c in col_regions if c.type not in _skip_types]) - col_types = {c.type for c in col_regions if c.type not in _skip_types} - is_vocab = bool(col_types & {'column_en', 'column_de'}) - total_cells = n_content_rows * n_cols - - # 1. Send meta event immediately - meta_event = { - "type": "meta", - "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells}, - "layout": "vocab" if is_vocab else "generic", - } - yield f"data: {json.dumps(meta_event)}\n\n" - - # 2. Send preparing event (keepalive for proxy) - yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR laeuft parallel...'})}\n\n" - - # 3. Run batch OCR in thread pool with periodic keepalive events. - # The OCR takes 30-60s and proxy servers (Nginx) may drop idle SSE - # connections after 30-60s. Send keepalive every 5s to prevent this. - loop = asyncio.get_event_loop() - ocr_future = loop.run_in_executor( - None, - lambda: build_cell_grid_v2( - ocr_img, col_regions, row_geoms, img_w, img_h, - ocr_engine=engine, img_bgr=dewarped_bgr, - skip_heal_gaps=skip_heal_gaps, - ), - ) - - # Send keepalive events every 5 seconds while OCR runs - keepalive_count = 0 - while not ocr_future.done(): - try: - cells, columns_meta = await asyncio.wait_for( - asyncio.shield(ocr_future), timeout=5.0, - ) - break # OCR finished - except asyncio.TimeoutError: - keepalive_count += 1 - elapsed = int(time.time() - t0) - yield f"data: {json.dumps({'type': 'keepalive', 'elapsed': elapsed, 'message': f'OCR laeuft... 
({elapsed}s)'})}\n\n" - if await request.is_disconnected(): - logger.info(f"SSE batch: client disconnected during OCR for {session_id}") - ocr_future.cancel() - return - else: - cells, columns_meta = ocr_future.result() - - if await request.is_disconnected(): - logger.info(f"SSE batch: client disconnected after OCR for {session_id}") - return - - # 4. Apply IPA phonetic fixes directly to cell texts (for overlay mode) - fix_cell_phonetics(cells, pronunciation=pronunciation) - - # 5. Send columns meta - if columns_meta: - yield f"data: {json.dumps({'type': 'columns', 'columns_used': columns_meta})}\n\n" - - # 6. Stream all cells - for idx, cell in enumerate(cells): - cell_event = { - "type": "cell", - "cell": cell, - "progress": {"current": idx + 1, "total": len(cells)}, - } - yield f"data: {json.dumps(cell_event)}\n\n" - - # 6. Build final result and persist - duration = time.time() - t0 - used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine - - word_result = { - "cells": cells, - "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(cells)}, - "columns_used": columns_meta, - "layout": "vocab" if is_vocab else "generic", - "image_width": img_w, - "image_height": img_h, - "duration_seconds": round(duration, 2), - "ocr_engine": used_engine, - "summary": { - "total_cells": len(cells), - "non_empty_cells": sum(1 for c in cells if c.get("text")), - "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), - }, - } - - vocab_entries = None - has_text_col = 'column_text' in col_types - if is_vocab or has_text_col: - entries = _cells_to_vocab_entries(cells, columns_meta) - entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation) - word_result["vocab_entries"] = entries - word_result["entries"] = entries - word_result["entry_count"] = len(entries) - word_result["summary"]["total_entries"] = len(entries) - word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english")) - 
word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german")) - vocab_entries = entries - - await update_session_db(session_id, word_result=word_result, current_step=8) - cached["word_result"] = word_result - - logger.info(f"OCR Pipeline SSE batch: words session {session_id}: " - f"layout={word_result['layout']}, {len(cells)} cells ({duration:.2f}s)") - - # 7. Send complete event - complete_event = { - "type": "complete", - "summary": word_result["summary"], - "duration_seconds": round(duration, 2), - "ocr_engine": used_engine, - } - if vocab_entries is not None: - complete_event["vocab_entries"] = vocab_entries - yield f"data: {json.dumps(complete_event)}\n\n" - - -async def _word_stream_generator( - session_id: str, - cached: Dict[str, Any], - col_regions: List[PageRegion], - row_geoms: List[RowGeometry], - dewarped_bgr: np.ndarray, - engine: str, - pronunciation: str, - request: Request, -): - """SSE generator that yields cell-by-cell OCR progress.""" - t0 = time.time() - - ocr_img = create_ocr_image(dewarped_bgr) - img_h, img_w = dewarped_bgr.shape[:2] - - # Compute grid shape upfront for the meta event - n_content_rows = len([r for r in row_geoms if r.row_type == 'content']) - _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'} - n_cols = len([c for c in col_regions if c.type not in _skip_types]) - - # Determine layout - col_types = {c.type for c in col_regions if c.type not in _skip_types} - is_vocab = bool(col_types & {'column_en', 'column_de'}) - - # Start streaming — first event: meta - columns_meta = None # will be set from first yield - total_cells = n_content_rows * n_cols - - meta_event = { - "type": "meta", - "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells}, - "layout": "vocab" if is_vocab else "generic", - } - yield f"data: {json.dumps(meta_event)}\n\n" - - # Keepalive: send preparing event so proxy doesn't timeout during OCR init - 
yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR wird initialisiert...'})}\n\n" - - # Stream cells one by one - all_cells: List[Dict[str, Any]] = [] - cell_idx = 0 - last_keepalive = time.time() - - for cell, cols_meta, total in build_cell_grid_v2_streaming( - ocr_img, col_regions, row_geoms, img_w, img_h, - ocr_engine=engine, img_bgr=dewarped_bgr, - ): - if await request.is_disconnected(): - logger.info(f"SSE: client disconnected during streaming for {session_id}") - return - - if columns_meta is None: - columns_meta = cols_meta - # Send columns_used as part of first cell or update meta - meta_update = { - "type": "columns", - "columns_used": cols_meta, - } - yield f"data: {json.dumps(meta_update)}\n\n" - - all_cells.append(cell) - cell_idx += 1 - - cell_event = { - "type": "cell", - "cell": cell, - "progress": {"current": cell_idx, "total": total}, - } - yield f"data: {json.dumps(cell_event)}\n\n" - - # All cells done — build final result - duration = time.time() - t0 - if columns_meta is None: - columns_meta = [] - - # Post-OCR: remove rows where ALL cells are empty (inter-row gaps - # that had stray Tesseract artifacts giving word_count > 0). 
- rows_with_text: set = set() - for c in all_cells: - if c.get("text", "").strip(): - rows_with_text.add(c["row_index"]) - before_filter = len(all_cells) - all_cells = [c for c in all_cells if c["row_index"] in rows_with_text] - empty_rows_removed = (before_filter - len(all_cells)) // max(n_cols, 1) - if empty_rows_removed > 0: - logger.info(f"SSE: removed {empty_rows_removed} all-empty rows after OCR") - - used_engine = all_cells[0].get("ocr_engine", "tesseract") if all_cells else engine - - # Apply IPA phonetic fixes directly to cell texts (for overlay mode) - fix_cell_phonetics(all_cells, pronunciation=pronunciation) - - word_result = { - "cells": all_cells, - "grid_shape": { - "rows": n_content_rows, - "cols": n_cols, - "total_cells": len(all_cells), - }, - "columns_used": columns_meta, - "layout": "vocab" if is_vocab else "generic", - "image_width": img_w, - "image_height": img_h, - "duration_seconds": round(duration, 2), - "ocr_engine": used_engine, - "summary": { - "total_cells": len(all_cells), - "non_empty_cells": sum(1 for c in all_cells if c.get("text")), - "low_confidence": sum(1 for c in all_cells if 0 < c.get("confidence", 0) < 50), - }, - } - - # For vocab layout or single-column (box sub-sessions): map cells 1:1 - # to vocab entries (row→entry). 
- vocab_entries = None - has_text_col = 'column_text' in col_types - if is_vocab or has_text_col: - entries = _cells_to_vocab_entries(all_cells, columns_meta) - entries = _fix_character_confusion(entries) - entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation) - word_result["vocab_entries"] = entries - word_result["entries"] = entries - word_result["entry_count"] = len(entries) - word_result["summary"]["total_entries"] = len(entries) - word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english")) - word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german")) - vocab_entries = entries - - # Persist to DB - await update_session_db( - session_id, - word_result=word_result, - current_step=8, - ) - cached["word_result"] = word_result - - logger.info(f"OCR Pipeline SSE: words session {session_id}: " - f"layout={word_result['layout']}, " - f"{len(all_cells)} cells ({duration:.2f}s)") - - # Final complete event - complete_event = { - "type": "complete", - "summary": word_result["summary"], - "duration_seconds": round(duration, 2), - "ocr_engine": used_engine, - } - if vocab_entries is not None: - complete_event["vocab_entries"] = vocab_entries - yield f"data: {json.dumps(complete_event)}\n\n" - - # --------------------------------------------------------------------------- # PaddleOCR Direct Endpoint # --------------------------------------------------------------------------- -@router.post("/sessions/{session_id}/paddle-direct") +@_local_router.post("/sessions/{session_id}/paddle-direct") async def paddle_direct(session_id: str): - """Run PaddleOCR on the preprocessed image and build a word grid directly. - - Expects orientation/deskew/dewarp/crop to be done already. - Uses the cropped image (falls back to dewarped, then original). - The used image is stored as cropped_png so OverlayReconstruction - can display it as the background. 
- """ - # Try preprocessed images first (crop > dewarp > original) + """Run PaddleOCR on the preprocessed image and build a word grid directly.""" img_png = await get_session_image(session_id, "cropped") if not img_png: img_png = await get_session_image(session_id, "dewarped") @@ -770,13 +79,9 @@ async def paddle_direct(session_id: str): if not word_dicts: raise HTTPException(status_code=400, detail="PaddleOCR returned no words") - # Reuse build_grid_from_words — same function that works in the regular - # pipeline with PaddleOCR (engine=paddle, grid_method=words_first). - # Handles phrase splitting, column clustering, and reading order. cells, columns_meta = build_grid_from_words(word_dicts, img_w, img_h) duration = time.time() - t0 - # Tag cells as paddle_direct for cell in cells: cell["ocr_engine"] = "paddle_direct" @@ -787,11 +92,7 @@ async def paddle_direct(session_id: str): word_result = { "cells": cells, - "grid_shape": { - "rows": n_rows, - "cols": n_cols, - "total_cells": len(cells), - }, + "grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)}, "columns_used": columns_meta, "layout": "vocab" if is_vocab else "generic", "image_width": img_w, @@ -806,7 +107,6 @@ async def paddle_direct(session_id: str): }, } - # Store preprocessed image as cropped_png so OverlayReconstruction shows it await update_session_db( session_id, word_result=word_result, @@ -832,7 +132,7 @@ async def paddle_direct(session_id: str): # Ground Truth Words Endpoints # --------------------------------------------------------------------------- -@router.post("/sessions/{session_id}/ground-truth/words") +@_local_router.post("/sessions/{session_id}/ground-truth/words") async def save_word_ground_truth(session_id: str, req: WordGroundTruthRequest): """Save ground truth feedback for the word recognition step.""" session = await get_session_db(session_id) @@ -857,7 +157,7 @@ async def save_word_ground_truth(session_id: str, req: WordGroundTruthRequest): return {"session_id": 
session_id, "ground_truth": gt} -@router.get("/sessions/{session_id}/ground-truth/words") +@_local_router.get("/sessions/{session_id}/ground-truth/words") async def get_word_ground_truth(session_id: str): """Retrieve saved ground truth for word recognition.""" session = await get_session_db(session_id) @@ -874,3 +174,12 @@ async def get_word_ground_truth(session_id: str): "words_gt": words_gt, "words_auto": session.get("word_result"), } + + +# --------------------------------------------------------------------------- +# Composite router +# --------------------------------------------------------------------------- + +router = APIRouter() +router.include_router(_detect_router) +router.include_router(_local_router) diff --git a/klausur-service/backend/ocr_pipeline_words_detect.py b/klausur-service/backend/ocr_pipeline_words_detect.py new file mode 100644 index 0000000..b70cff3 --- /dev/null +++ b/klausur-service/backend/ocr_pipeline_words_detect.py @@ -0,0 +1,393 @@ +""" +OCR Pipeline Words Detect — main word detection endpoint (Step 7). + +Extracted from ocr_pipeline_words.py. Contains the ``detect_words`` +endpoint which handles both v2 and words_first grid methods. + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
+""" + +import json +import logging +import time +from typing import Any, Dict, List + +import numpy as np +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import StreamingResponse + +from cv_vocab_pipeline import ( + PageRegion, + RowGeometry, + _cells_to_vocab_entries, + _fix_phonetic_brackets, + fix_cell_phonetics, + build_cell_grid_v2, + create_ocr_image, + detect_column_geometry, +) +from cv_words_first import build_grid_from_words +from ocr_pipeline_session_store import ( + get_session_db, + update_session_db, +) +from ocr_pipeline_common import ( + _cache, + _load_session_to_cache, + _get_cached, + _append_pipeline_log, +) +from ocr_pipeline_words_stream import ( + _word_batch_stream_generator, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) + + +# --------------------------------------------------------------------------- +# Word Detection Endpoint (Step 7) +# --------------------------------------------------------------------------- + +@router.post("/sessions/{session_id}/words") +async def detect_words( + session_id: str, + request: Request, + engine: str = "auto", + pronunciation: str = "british", + stream: bool = False, + skip_heal_gaps: bool = False, + grid_method: str = "v2", +): + """Build word grid from columns x rows, OCR each cell. + + Query params: + engine: 'auto' (default), 'tesseract', 'rapid', or 'paddle' + pronunciation: 'british' (default) or 'american' + stream: false (default) for JSON response, true for SSE streaming + skip_heal_gaps: false (default). When true, cells keep exact row geometry. 
+ grid_method: 'v2' (default) or 'words_first' + """ + # PaddleOCR is full-page remote OCR -> force words_first grid method + if engine == "paddle" and grid_method != "words_first": + logger.info("detect_words: engine=paddle requires words_first, overriding grid_method=%s", grid_method) + grid_method = "words_first" + + if session_id not in _cache: + logger.info("detect_words: session %s not in cache, loading from DB", session_id) + await _load_session_to_cache(session_id) + cached = _get_cached(session_id) + + dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") + if dewarped_bgr is None: + logger.warning("detect_words: no cropped/dewarped image for session %s (cache keys: %s)", + session_id, [k for k in cached.keys() if k.endswith('_bgr')]) + raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before word detection") + + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + column_result = session.get("column_result") + row_result = session.get("row_result") + if not column_result or not column_result.get("columns"): + img_h_tmp, img_w_tmp = dewarped_bgr.shape[:2] + column_result = { + "columns": [{ + "type": "column_text", + "x": 0, "y": 0, + "width": img_w_tmp, "height": img_h_tmp, + "classification_confidence": 1.0, + "classification_method": "full_page_fallback", + }], + "zones": [], + "duration_seconds": 0, + } + logger.info("detect_words: no column_result -- using full-page pseudo-column %dx%d", img_w_tmp, img_h_tmp) + if grid_method != "words_first" and (not row_result or not row_result.get("rows")): + raise HTTPException(status_code=400, detail="Row detection must be completed first") + + # Convert column dicts back to PageRegion objects + col_regions = [ + PageRegion( + type=c["type"], + x=c["x"], y=c["y"], + width=c["width"], height=c["height"], + 
classification_confidence=c.get("classification_confidence", 1.0), + classification_method=c.get("classification_method", ""), + ) + for c in column_result["columns"] + ] + + # Convert row dicts back to RowGeometry objects + row_geoms = [ + RowGeometry( + index=r["index"], + x=r["x"], y=r["y"], + width=r["width"], height=r["height"], + word_count=r.get("word_count", 0), + words=[], + row_type=r.get("row_type", "content"), + gap_before=r.get("gap_before", 0), + ) + for r in row_result["rows"] + ] + + # Populate word counts from cached words + word_dicts = cached.get("_word_dicts") + if word_dicts is None: + ocr_img_tmp = create_ocr_image(dewarped_bgr) + geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr) + if geo_result is not None: + _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result + cached["_word_dicts"] = word_dicts + cached["_inv"] = inv + cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y) + + if word_dicts: + content_bounds = cached.get("_content_bounds") + if content_bounds: + _lx, _rx, top_y, _by = content_bounds + else: + top_y = min(r.y for r in row_geoms) if row_geoms else 0 + + for row in row_geoms: + row_y_rel = row.y - top_y + row_bottom_rel = row_y_rel + row.height + row.words = [ + w for w in word_dicts + if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel + ] + row.word_count = len(row.words) + + # Exclude rows that fall within box zones + zones = column_result.get("zones") or [] + box_ranges_inner = [] + for zone in zones: + if zone.get("zone_type") == "box" and zone.get("box"): + box = zone["box"] + bt = max(box.get("border_thickness", 0), 5) + box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt)) + + if box_ranges_inner: + def _row_in_box(r): + center_y = r.y + r.height / 2 + return any(by_s <= center_y < by_e for by_s, by_e in box_ranges_inner) + + before_count = len(row_geoms) + row_geoms = [r for r in row_geoms if not _row_in_box(r)] + excluded = before_count - 
len(row_geoms) + if excluded: + logger.info(f"detect_words: excluded {excluded} rows inside box zones") + + # --- Words-First path --- + if grid_method == "words_first": + return await _words_first_path( + session_id, cached, dewarped_bgr, engine, pronunciation, zones, + ) + + if stream: + return StreamingResponse( + _word_batch_stream_generator( + session_id, cached, col_regions, row_geoms, + dewarped_bgr, engine, pronunciation, request, + skip_heal_gaps=skip_heal_gaps, + ), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + + # --- Non-streaming path (grid_method=v2) --- + return await _v2_path( + session_id, cached, col_regions, row_geoms, + dewarped_bgr, engine, pronunciation, skip_heal_gaps, + ) + + +async def _words_first_path( + session_id: str, + cached: Dict[str, Any], + dewarped_bgr: np.ndarray, + engine: str, + pronunciation: str, + zones: list, +) -> dict: + """Words-first grid construction path.""" + t0 = time.time() + img_h, img_w = dewarped_bgr.shape[:2] + + if engine == "paddle": + from cv_ocr_engines import ocr_region_paddle + wf_word_dicts = await ocr_region_paddle(dewarped_bgr, region=None) + cached["_paddle_word_dicts"] = wf_word_dicts + else: + wf_word_dicts = cached.get("_word_dicts") + if wf_word_dicts is None: + ocr_img_tmp = create_ocr_image(dewarped_bgr) + geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr) + if geo_result is not None: + _geoms, left_x, right_x, top_y, bottom_y, wf_word_dicts, inv = geo_result + cached["_word_dicts"] = wf_word_dicts + cached["_inv"] = inv + cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y) + + if not wf_word_dicts: + raise HTTPException(status_code=400, detail="No words detected -- cannot build words-first grid") + + # Convert word coordinates to absolute if needed + if engine != "paddle": + content_bounds = cached.get("_content_bounds") + if content_bounds: + lx, _rx, ty, _by = 
content_bounds + abs_words = [] + for w in wf_word_dicts: + abs_words.append({**w, 'left': w['left'] + lx, 'top': w['top'] + ty}) + wf_word_dicts = abs_words + + box_rects = [] + for zone in zones: + if zone.get("zone_type") == "box" and zone.get("box"): + box_rects.append(zone["box"]) + + cells, columns_meta = build_grid_from_words( + wf_word_dicts, img_w, img_h, box_rects=box_rects or None, + ) + duration = time.time() - t0 + + fix_cell_phonetics(cells, pronunciation=pronunciation) + for cell in cells: + cell.setdefault("zone_index", 0) + + col_types = {c['type'] for c in columns_meta} + is_vocab = bool(col_types & {'column_en', 'column_de'}) + n_rows = len(set(c['row_index'] for c in cells)) if cells else 0 + n_cols = len(columns_meta) + used_engine = "paddle" if engine == "paddle" else "words_first" + + word_result = { + "cells": cells, + "grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)}, + "columns_used": columns_meta, + "layout": "vocab" if is_vocab else "generic", + "image_width": img_w, + "image_height": img_h, + "duration_seconds": round(duration, 2), + "ocr_engine": used_engine, + "grid_method": "words_first", + "summary": { + "total_cells": len(cells), + "non_empty_cells": sum(1 for c in cells if c.get("text")), + "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), + }, + } + + if is_vocab or 'column_text' in col_types: + entries = _cells_to_vocab_entries(cells, columns_meta) + entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation) + word_result["vocab_entries"] = entries + word_result["entries"] = entries + word_result["entry_count"] = len(entries) + word_result["summary"]["total_entries"] = len(entries) + word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english")) + word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german")) + + await update_session_db(session_id, word_result=word_result, current_step=8) + cached["word_result"] = word_result 
+ + logger.info(f"OCR Pipeline: words-first session {session_id}: " + f"{len(cells)} cells ({duration:.2f}s), {n_rows} rows, {n_cols} cols") + + await _append_pipeline_log(session_id, "words", { + "grid_method": "words_first", + "total_cells": len(cells), + "non_empty_cells": word_result["summary"]["non_empty_cells"], + "ocr_engine": used_engine, + "layout": word_result["layout"], + }, duration_ms=int(duration * 1000)) + + return {"session_id": session_id, **word_result} + + +async def _v2_path( + session_id: str, + cached: Dict[str, Any], + col_regions: List[PageRegion], + row_geoms: List[RowGeometry], + dewarped_bgr: np.ndarray, + engine: str, + pronunciation: str, + skip_heal_gaps: bool, +) -> dict: + """Cell-First OCR v2 non-streaming path.""" + t0 = time.time() + ocr_img = create_ocr_image(dewarped_bgr) + img_h, img_w = dewarped_bgr.shape[:2] + + cells, columns_meta = build_cell_grid_v2( + ocr_img, col_regions, row_geoms, img_w, img_h, + ocr_engine=engine, img_bgr=dewarped_bgr, + skip_heal_gaps=skip_heal_gaps, + ) + duration = time.time() - t0 + + for cell in cells: + cell.setdefault("zone_index", 0) + + col_types = {c['type'] for c in columns_meta} + is_vocab = bool(col_types & {'column_en', 'column_de'}) + n_content_rows = len([r for r in row_geoms if r.row_type == 'content']) + n_cols = len(columns_meta) + used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine + + fix_cell_phonetics(cells, pronunciation=pronunciation) + + word_result = { + "cells": cells, + "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(cells)}, + "columns_used": columns_meta, + "layout": "vocab" if is_vocab else "generic", + "image_width": img_w, + "image_height": img_h, + "duration_seconds": round(duration, 2), + "ocr_engine": used_engine, + "summary": { + "total_cells": len(cells), + "non_empty_cells": sum(1 for c in cells if c.get("text")), + "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), + }, + } + + 
has_text_col = 'column_text' in col_types + if is_vocab or has_text_col: + entries = _cells_to_vocab_entries(cells, columns_meta) + entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation) + word_result["vocab_entries"] = entries + word_result["entries"] = entries + word_result["entry_count"] = len(entries) + word_result["summary"]["total_entries"] = len(entries) + word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english")) + word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german")) + + await update_session_db(session_id, word_result=word_result, current_step=8) + cached["word_result"] = word_result + + logger.info(f"OCR Pipeline: words session {session_id}: " + f"layout={word_result['layout']}, " + f"{len(cells)} cells ({duration:.2f}s), summary: {word_result['summary']}") + + await _append_pipeline_log(session_id, "words", { + "total_cells": len(cells), + "non_empty_cells": word_result["summary"]["non_empty_cells"], + "low_confidence_count": word_result["summary"]["low_confidence"], + "ocr_engine": used_engine, + "layout": word_result["layout"], + "entry_count": word_result.get("entry_count", 0), + }, duration_ms=int(duration * 1000)) + + return {"session_id": session_id, **word_result} diff --git a/klausur-service/backend/ocr_pipeline_words_stream.py b/klausur-service/backend/ocr_pipeline_words_stream.py new file mode 100644 index 0000000..bb7d990 --- /dev/null +++ b/klausur-service/backend/ocr_pipeline_words_stream.py @@ -0,0 +1,303 @@ +""" +OCR Pipeline Words Stream — SSE streaming generators for word detection. + +Extracted from ocr_pipeline_words.py. + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
+""" + +import json +import logging +import time +from typing import Any, Dict, List + +import numpy as np +from fastapi import Request + +from cv_vocab_pipeline import ( + PageRegion, + RowGeometry, + _cells_to_vocab_entries, + _fix_character_confusion, + _fix_phonetic_brackets, + fix_cell_phonetics, + build_cell_grid_v2, + build_cell_grid_v2_streaming, + create_ocr_image, +) +from ocr_pipeline_session_store import update_session_db +from ocr_pipeline_common import _cache + +logger = logging.getLogger(__name__) + + +async def _word_batch_stream_generator( + session_id: str, + cached: Dict[str, Any], + col_regions: List[PageRegion], + row_geoms: List[RowGeometry], + dewarped_bgr: np.ndarray, + engine: str, + pronunciation: str, + request: Request, + skip_heal_gaps: bool = False, +): + """SSE generator that runs batch OCR (parallel) then streams results. + + Uses build_cell_grid_v2 with ThreadPoolExecutor for parallel OCR, + then emits all cells as SSE events. + """ + import asyncio + + t0 = time.time() + ocr_img = create_ocr_image(dewarped_bgr) + img_h, img_w = dewarped_bgr.shape[:2] + + _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'} + n_content_rows = len([r for r in row_geoms if r.row_type == 'content']) + n_cols = len([c for c in col_regions if c.type not in _skip_types]) + col_types = {c.type for c in col_regions if c.type not in _skip_types} + is_vocab = bool(col_types & {'column_en', 'column_de'}) + total_cells = n_content_rows * n_cols + + # 1. Send meta event immediately + meta_event = { + "type": "meta", + "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells}, + "layout": "vocab" if is_vocab else "generic", + } + yield f"data: {json.dumps(meta_event)}\n\n" + + # 2. Send preparing event (keepalive for proxy) + yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR laeuft parallel...'})}\n\n" + + # 3. 
Run batch OCR in thread pool with periodic keepalive events. + loop = asyncio.get_event_loop() + ocr_future = loop.run_in_executor( + None, + lambda: build_cell_grid_v2( + ocr_img, col_regions, row_geoms, img_w, img_h, + ocr_engine=engine, img_bgr=dewarped_bgr, + skip_heal_gaps=skip_heal_gaps, + ), + ) + + # Send keepalive events every 5 seconds while OCR runs + keepalive_count = 0 + while not ocr_future.done(): + try: + cells, columns_meta = await asyncio.wait_for( + asyncio.shield(ocr_future), timeout=5.0, + ) + break # OCR finished + except asyncio.TimeoutError: + keepalive_count += 1 + elapsed = int(time.time() - t0) + yield f"data: {json.dumps({'type': 'keepalive', 'elapsed': elapsed, 'message': f'OCR laeuft... ({elapsed}s)'})}\n\n" + if await request.is_disconnected(): + logger.info(f"SSE batch: client disconnected during OCR for {session_id}") + ocr_future.cancel() + return + else: + cells, columns_meta = ocr_future.result() + + if await request.is_disconnected(): + logger.info(f"SSE batch: client disconnected after OCR for {session_id}") + return + + # 4. Apply IPA phonetic fixes + fix_cell_phonetics(cells, pronunciation=pronunciation) + + # 5. Send columns meta + if columns_meta: + yield f"data: {json.dumps({'type': 'columns', 'columns_used': columns_meta})}\n\n" + + # 6. Stream all cells + for idx, cell in enumerate(cells): + cell_event = { + "type": "cell", + "cell": cell, + "progress": {"current": idx + 1, "total": len(cells)}, + } + yield f"data: {json.dumps(cell_event)}\n\n" + + # 7. 
Build final result and persist + duration = time.time() - t0 + used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine + + word_result = { + "cells": cells, + "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(cells)}, + "columns_used": columns_meta, + "layout": "vocab" if is_vocab else "generic", + "image_width": img_w, + "image_height": img_h, + "duration_seconds": round(duration, 2), + "ocr_engine": used_engine, + "summary": { + "total_cells": len(cells), + "non_empty_cells": sum(1 for c in cells if c.get("text")), + "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), + }, + } + + vocab_entries = None + has_text_col = 'column_text' in col_types + if is_vocab or has_text_col: + entries = _cells_to_vocab_entries(cells, columns_meta) + entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation) + word_result["vocab_entries"] = entries + word_result["entries"] = entries + word_result["entry_count"] = len(entries) + word_result["summary"]["total_entries"] = len(entries) + word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english")) + word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german")) + vocab_entries = entries + + await update_session_db(session_id, word_result=word_result, current_step=8) + cached["word_result"] = word_result + + logger.info(f"OCR Pipeline SSE batch: words session {session_id}: " + f"layout={word_result['layout']}, {len(cells)} cells ({duration:.2f}s)") + + # 8. 
Send complete event + complete_event = { + "type": "complete", + "summary": word_result["summary"], + "duration_seconds": round(duration, 2), + "ocr_engine": used_engine, + } + if vocab_entries is not None: + complete_event["vocab_entries"] = vocab_entries + yield f"data: {json.dumps(complete_event)}\n\n" + + +async def _word_stream_generator( + session_id: str, + cached: Dict[str, Any], + col_regions: List[PageRegion], + row_geoms: List[RowGeometry], + dewarped_bgr: np.ndarray, + engine: str, + pronunciation: str, + request: Request, +): + """SSE generator that yields cell-by-cell OCR progress.""" + t0 = time.time() + + ocr_img = create_ocr_image(dewarped_bgr) + img_h, img_w = dewarped_bgr.shape[:2] + + n_content_rows = len([r for r in row_geoms if r.row_type == 'content']) + _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'} + n_cols = len([c for c in col_regions if c.type not in _skip_types]) + + col_types = {c.type for c in col_regions if c.type not in _skip_types} + is_vocab = bool(col_types & {'column_en', 'column_de'}) + + columns_meta = None + total_cells = n_content_rows * n_cols + + meta_event = { + "type": "meta", + "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells}, + "layout": "vocab" if is_vocab else "generic", + } + yield f"data: {json.dumps(meta_event)}\n\n" + + yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR wird initialisiert...'})}\n\n" + + all_cells: List[Dict[str, Any]] = [] + cell_idx = 0 + last_keepalive = time.time() + + for cell, cols_meta, total in build_cell_grid_v2_streaming( + ocr_img, col_regions, row_geoms, img_w, img_h, + ocr_engine=engine, img_bgr=dewarped_bgr, + ): + if await request.is_disconnected(): + logger.info(f"SSE: client disconnected during streaming for {session_id}") + return + + if columns_meta is None: + columns_meta = cols_meta + meta_update = {"type": "columns", "columns_used": cols_meta} + yield 
f"data: {json.dumps(meta_update)}\n\n" + + all_cells.append(cell) + cell_idx += 1 + + cell_event = { + "type": "cell", + "cell": cell, + "progress": {"current": cell_idx, "total": total}, + } + yield f"data: {json.dumps(cell_event)}\n\n" + + # All cells done + duration = time.time() - t0 + if columns_meta is None: + columns_meta = [] + + # Remove all-empty rows + rows_with_text: set = set() + for c in all_cells: + if c.get("text", "").strip(): + rows_with_text.add(c["row_index"]) + before_filter = len(all_cells) + all_cells = [c for c in all_cells if c["row_index"] in rows_with_text] + empty_rows_removed = (before_filter - len(all_cells)) // max(n_cols, 1) + if empty_rows_removed > 0: + logger.info(f"SSE: removed {empty_rows_removed} all-empty rows after OCR") + + used_engine = all_cells[0].get("ocr_engine", "tesseract") if all_cells else engine + + fix_cell_phonetics(all_cells, pronunciation=pronunciation) + + word_result = { + "cells": all_cells, + "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(all_cells)}, + "columns_used": columns_meta, + "layout": "vocab" if is_vocab else "generic", + "image_width": img_w, + "image_height": img_h, + "duration_seconds": round(duration, 2), + "ocr_engine": used_engine, + "summary": { + "total_cells": len(all_cells), + "non_empty_cells": sum(1 for c in all_cells if c.get("text")), + "low_confidence": sum(1 for c in all_cells if 0 < c.get("confidence", 0) < 50), + }, + } + + vocab_entries = None + has_text_col = 'column_text' in col_types + if is_vocab or has_text_col: + entries = _cells_to_vocab_entries(all_cells, columns_meta) + entries = _fix_character_confusion(entries) + entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation) + word_result["vocab_entries"] = entries + word_result["entries"] = entries + word_result["entry_count"] = len(entries) + word_result["summary"]["total_entries"] = len(entries) + word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english")) 
+ word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german")) + vocab_entries = entries + + await update_session_db(session_id, word_result=word_result, current_step=8) + cached["word_result"] = word_result + + logger.info(f"OCR Pipeline SSE: words session {session_id}: " + f"layout={word_result['layout']}, " + f"{len(all_cells)} cells ({duration:.2f}s)") + + complete_event = { + "type": "complete", + "summary": word_result["summary"], + "duration_seconds": round(duration, 2), + "ocr_engine": used_engine, + } + if vocab_entries is not None: + complete_event["vocab_entries"] = vocab_entries + yield f"data: {json.dumps(complete_event)}\n\n" diff --git a/klausur-service/frontend/src/components/KorrekturDocumentViewer.tsx b/klausur-service/frontend/src/components/KorrekturDocumentViewer.tsx new file mode 100644 index 0000000..f75d5a4 --- /dev/null +++ b/klausur-service/frontend/src/components/KorrekturDocumentViewer.tsx @@ -0,0 +1,79 @@ +/** + * KorrekturDocumentViewer — center panel document display. + * + * Extracted from KorrekturPage.tsx. + */ + +import { StudentKlausur } from '../services/api' + +interface KorrekturDocumentViewerProps { + currentStudent: StudentKlausur | null +} + +export default function KorrekturDocumentViewer({ currentStudent }: KorrekturDocumentViewerProps) { + return ( +
+
+
+
+ {currentStudent ? currentStudent.student_name : 'Dokument-Ansicht'} +
+
+ {currentStudent && ( + <> + + + + )} +
+
+ +
+ {!currentStudent ? ( +
+
{'\uD83D\uDCC4'}
+
Keine Arbeit ausgewaehlt
+
+ Waehlen Sie eine Schuelerarbeit aus der Liste oder laden Sie eine neue hoch +
+
+ ) : currentStudent.file_path ? ( +
+
+ {'\uD83D\uDCC4'} {currentStudent.student_name} + {'\u2713'} Hochgeladen +
+
+ {currentStudent.file_path.endsWith('.pdf') ? ( +